Remove "module." from PyTorch .pth state dict
Author: XD / Published: February 11, 2022 04:25 / Updated: February 11, 2022 04:28 / Programming Notes
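Sometimes a model is trained on multiple GPUs, for example by wrapping it in nn.DataParallel. If the state dict is saved from the wrapper, every key in it starts with the prefix "module.". Loading it directly into an unwrapped model then fails with missing and unexpected keys. Here is a way to remove "module." from the keys of the .pth state dict.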
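To see where the prefix comes from, here is a minimal sketch. It assumes the multi-GPU training used nn.DataParallel. The tiny nn.Linear model is only a hypothetical stand-in for the real network. The sketch also shows that saving the inner module's state dict avoids the prefix in the first place:

```python
import torch.nn as nn

# Hypothetical stand-in model; any nn.Module behaves the same way
model = nn.Linear(4, 2)

# DataParallel stores the original model as its submodule named "module",
# so every key in the wrapper's state dict gains the "module." prefix.
# (Without CUDA, the wrapper simply keeps the module on the CPU.)
parallel_model = nn.DataParallel(model)
wrapped_keys = list(parallel_model.state_dict().keys())
print(wrapped_keys)  # ['module.weight', 'module.bias']

# Saving the inner module's state dict instead avoids the prefix at the source
plain_keys = list(parallel_model.module.state_dict().keys())
print(plain_keys)  # ['weight', 'bias']
```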
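import torch
import torch.nn as nn
# mobilenetv2 is the MobileNetV2 constructor from your own model code; its import is not shown here

weight_path = "best_model.pth"
class_number = 5
model = mobilenetv2(width_mult=0.25)
# Replace the classifier head so its shape matches the 5-class weights being loaded
model.classifier = nn.Linear(model.last_channel, class_number)
# Load onto the CPU so weights saved on GPUs also load on a CPU-only machine
temp = torch.load(weight_path, map_location="cpu")
update_weights = {}
for key in temp.keys():
    # Strip the 7-character "module." prefix only from keys that actually have it
    new_key = key[7:] if key.startswith("module.") else key
    update_weights[new_key] = temp[key]
model.load_state_dict(update_weights)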
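The same key rewriting can be packaged as a reusable helper. This is a minimal sketch. The function name strip_module_prefix is hypothetical, and a small nn.Linear stands in for the real model. It round-trips weights saved from a DataParallel wrapper into a plain model:

```python
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Keys without the prefix are kept unchanged, so this is safe to call on
    weights saved from either a wrapped or an unwrapped model.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Hypothetical round trip: weights saved from a DataParallel wrapper
# are loaded into a plain, unwrapped model of the same shape.
saved = nn.DataParallel(nn.Linear(4, 2)).state_dict()
plain_model = nn.Linear(4, 2)
result = plain_model.load_state_dict(strip_module_prefix(saved))
print(result.missing_keys, result.unexpected_keys)  # [] []
```

Returning a new dict leaves the loaded checkpoint untouched. Recent PyTorch releases also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")`, which strips the prefix in place instead.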