Encrypt ONNX Model
作者:XD / 发表: 2022年4月2日 00:27 / 更新: 2022年4月2日 00:27 / 编程笔记 / 阅读量:2877
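An ONNX model can be encrypted with the following code. The script itself can then be compiled into a .so shared library, so the code and key are not distributed as readable Python source, which improves code safety.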
import hashlib
from Crypto import Random
from Crypto.Cipher import AES
def load_graph(path):
    """Return the entire contents of the file at *path* as raw bytes.

    Used both for the plaintext .onnx model and for the encrypted file.
    """
    with open(path, 'rb') as stream:
        return stream.read()
def encrypt_file(raw, _key):
    """Encrypt *raw* bytes with AES-256 in CBC mode under a password-derived key.

    The AES key is the SHA-256 digest of the password. The plaintext is
    padded to a multiple of 32 bytes, PKCS#7-style: every pad byte holds the
    pad length (1..32), so a whole extra 32-byte run is appended when the
    input is already aligned. A fresh random IV is generated per call and
    prepended to the ciphertext.

    Args:
        raw: plaintext bytes (e.g. the contents of an .onnx file).
        _key: password as ``str`` (UTF-8 encoded) or as ``bytes``/``bytearray``
            (used as-is).

    Returns:
        ``iv + ciphertext`` as bytes; ``decrypt_file`` reverses it.
    """
    # NOTE(review): a single SHA-256 is a fast hash, not a password KDF, and
    # CBC without a MAC does not detect tampering. Changing either would break
    # compatibility with files already encrypted by this function.
    bs = 32  # pad granularity; a multiple of AES.block_size (16), as CBC requires
    secret = _key if isinstance(_key, (bytes, bytearray)) else _key.encode()
    key = hashlib.sha256(secret).digest()
    pad_len = bs - len(raw) % bs  # always in 1..bs so decryption can strip it
    padded = raw + bytes([pad_len]) * pad_len
    iv = Random.new().read(AES.block_size)
    cipher = AES.new(key, AES.MODE_CBC, iv)
    return iv + cipher.encrypt(padded)
def decrypt_file(enc, _key):
    """Decrypt data produced by ``encrypt_file`` and strip its padding.

    Args:
        enc: bytes laid out as ``iv + ciphertext``; the IV is the first
            ``AES.block_size`` bytes.
        _key: the password used for encryption, as ``str`` (UTF-8 encoded)
            or ``bytes``/``bytearray`` (used as-is).

    Returns:
        The original plaintext bytes.

    Raises:
        ValueError: if nothing follows the IV or the padding is malformed.
            Either usually means a wrong password or corrupted data.
    """
    secret = _key if isinstance(_key, (bytes, bytearray)) else _key.encode()
    key = hashlib.sha256(secret).digest()
    iv = enc[:AES.block_size]
    cipher = AES.new(key, AES.MODE_CBC, iv)
    s = cipher.decrypt(enc[AES.block_size:])
    if not s:
        raise ValueError("no ciphertext after the IV")
    # The last byte is the pad length, and every pad byte must equal it.
    # Checking this turns a wrong key into an error rather than silently
    # returning garbage (or b'' when the final byte happens to be 0).
    pad_len = s[-1]
    if not 1 <= pad_len <= len(s) or s[-pad_len:] != bytes([pad_len]) * pad_len:
        raise ValueError("invalid padding: wrong key or corrupted data")
    return s[:-pad_len]
def main():
    """Demo: encrypt ``test.onnx`` into ``test_encode.onnx``, then decrypt and load it.

    The encrypted model is written to ``test_encode.onnx`` in the current
    directory. It is then read back and decrypted in memory, and the
    decrypted bytes are passed directly to ``onnxruntime.InferenceSession``.
    """
    # onnxruntime is needed only to load the decrypted model; import it where used.
    import onnxruntime

    input_path = 'test.onnx'
    output_path = 'test_encode.onnx'
    # NOTE(review): hard-coded demo password. Anyone holding this source, and
    # likely anyone holding a compiled .so, can recover it.
    _key = 'Password123!'

    # Encrypt the model and write the ciphertext to disk.
    model_bytes = load_graph(input_path)
    encrypted = encrypt_file(model_bytes, _key)
    with open(output_path, 'wb') as f:
        f.write(encrypted)

    # Read the ciphertext back, decrypt it in memory, and load the session
    # from the resulting bytes.
    encrypted = load_graph(output_path)
    decrypted = decrypt_file(encrypted, _key)
    session = onnxruntime.InferenceSession(decrypted)
    # Example inference (onnx_input must be supplied by the caller):
    # ort_inputs = {session.get_inputs()[0].name: onnx_input}
    # outputs = session.run(None, ort_inputs)
if __name__ == "__main__":
    # Run the encrypt/decrypt demo only when executed as a script, not on import.
    main()