郑州网站关键词推广,网站制作公司大型,网站建设与数据库维护 pdf,黄骅港旅游景点在做深度学习项目时#xff0c;从头训练一个模型是需要大量时间和算力的#xff0c;我们通常采用加载预训练权重的方法#xff0c;而我们往往面临以下几种情况#xff1a;
未修改网络#xff0c;A与B一致
很简单#xff0c;直接.load_state_dict()
net ANet(num_cla…在做深度学习项目时从头训练一个模型是需要大量时间和算力的我们通常采用加载预训练权重的方法而我们往往面临以下几种情况
未修改网络A与B一致
很简单直接.load_state_dict()
net ANet(num_classses 5,init_weightsTrue)
net.to(device)
net.load_state_dict(torch.load(weight/B_weight.pth))修改了网络A与B不一致
[pytorch官方文档](Search — PyTorch master documentation):
load_state_dict(state_dict, strictTrue)
将 state_dict 中的参数和缓冲区复制到此模块及其后代中。如果 strict 为 True则 state_dict 的键必须与该模块的 state_dict() 函数返回的键完全匹配。
state_dict是包含参数和持久缓冲区的字典可以看出 strict默认为True所以默认状态下是严格要求state_dict中的key与torch.nn.Module.state_dict返回的key完全一致的
load_state_dict()函数有两个返回值
missing_keys 是包含缺失键的 str 列表 unexpected_keys 是包含意外键的 str 列表
方法一
将strict改为false加载键值相同的部分。
model NET2()
state_dict model.state_dict()
weights torch.load(weights_path)[model_state_dict] #读取预训练模型权重
model.load_state_dict(weights, strictFalse) #strict但是此时还存在一种情况键值相同但shape不同故应进行if…in…的判断
ANet torch.load(ANet.pt) # 加载预训练权重模型(.pt文件)参数
#现成的模型的话如resnet50 models.resnet50(pretrainedTrue)
#采用:pretrained_dict resnet50().state_dict()
model Model() # 创建模型
model_dict model.state_dict() # 得到模型的参数字典# 判断预训练模型中网络的模块是否修改后的网络中也存在并且shape相同如果相同则取出
pretrained_dict {k: v for k, v in ANet.items() if k in model_dict and (v.shape model_dict[k].shape)}# 更新修改之后的 model_dict
model_dict.update(pretrained_dict)# 加载我们真正需要的 state_dict
model.load_state_dict(model_dict, strictFalse)方法二:
1.将权重导入原模型之后在加载后的原模型基础上进行修改。 2.修改权重文件参数再进行导入 适用于改动不大的模型