工具
用于AES的crypto,二进制字符转换十六进制字符用的binascii,以及torch。
python 在Windows下使用AES时要安装的是pycryptodome 模块
pip install pycryptodome
python 在Linux下使用AES时要安装的是pycrypto模块
pip install pycrypto
关于AES
crypto里面的AES加密需要一个初始向量v,在下面的代码里变量叫做iv,需要是byte类型;还需要加密/解密用的key,在代码里的变量叫key。
代码
Net是我随便定义的一个网络。
- torch的save方法是用二进制数据保存的torch模型,我们也只能用二进制方法加密和解密。
- torch的load方法只能是load一个文件,并不能load一个对象,所以即便是你decrypte出来模型原文件,也只能先把它到文件里,然后在用torch load进来。
- 最后为了比较两个网络是否相同,我用了一个test_tensor去测试两个网络,得到相同的输出。注,比较前要先eval。具体见:这个博客
import torch
from Crypto.Cipher import AES
from binascii import b2a_hex, a2b_hex
from Util import Net
# 如果byte_string不足16位的倍数就用空格补足为16位
def add_to_16(byte_string):
if len(byte_string) % 16:
add = 16 - (len(byte_string) % 16)
else:
add = 0
byte_string = byte_string + (b'\0' * add)
return byte_string
key = '9999999999999999'.encode('utf-8')
mode = AES.MODE_CBC
iv = b'qqqqqqqqqqqqqqqq'
# 加密函数
def Encrypt(byte_string):
byte_string = add_to_16(byte_string)
cryptos = AES.new(key, mode, iv)
cipher_text = cryptos.encrypt(byte_string)
# 因为AES加密后的字符串不一定是ascii字符集的,输出保存可能存在问题,所以这里转为16进制字符串
return b2a_hex(cipher_text)
# 解密后,去掉补足的空格用strip() 去掉
def Decrypt(byte_string):
cryptos = AES.new(key, mode, iv)
plain_text = cryptos.decrypt(a2b_hex(byte_string))
return plain_text.rstrip(b'\0')
if __name__ == '__main__':
#测试向量
test_tensor=torch.randn(1,14)
net=Net()
#全部保存
torch.save(net,'model')
#加密&写加密文件
with open('model','rb') as f1:
encrypted=Encrypt(f1.read())
with open('model-encrypted','wb') as f2:
f2.write(encrypted)
#解密 加密过的文件
with open('model-decrypted','wb') as f:
content=open('model-encrypted','rb').read()
f.write(Decrypt(content))
#load原本model
net1=torch.load('./model')
net1=net1.eval()
#load解密后的model
net2=torch.load('model-decrypted')
net2=net2.eval()
#测试结果
print(net2(test_tensor), net1(test_tensor))
exit(0)
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。