In deep learning, saving and loading model weights is a very common need. The simplest case is training a model on a single GPU: you save the weights after training, and reloading them later generally causes no problems. Taking PyTorch as an example, the following is all it takes:
import torch
from model import Model  # assume the model class is defined in model.py
# `model` should be defined and trained before saving
# save model weights
model = Model(args)
torch.save(model.state_dict(), 'best_model.pth')
# load the previously saved `model` weights
model = Model(args)
pretrained_model = torch.load('best_model.pth')
model.load_state_dict(pretrained_model)
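One detail worth adding before moving to the multi-GPU case: torch.load by default restores tensors to the device they were saved from, so a checkpoint saved on GPU can fail to load in a CPU-only environment. Passing map_location sidesteps this; the sketch below reuses the placeholder Model(args) from the snippet above.
import torch
from model import Model  # assume the model class is defined in model.py
# load the checkpoint onto CPU no matter which device it was saved from
pretrained_model = torch.load('best_model.pth', map_location='cpu')
model = Model(args)
model.load_state_dict(pretrained_model)
# move the model to a GPU only if one is available
model.to('cuda:0' if torch.cuda.is_available() else 'cpu')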
In practice, however, people often use multiple GPUs to speed up training and shorten the experiment cycle. In that case there are some $\color{red}{caveats}$ when saving and loading model weights; ignore them and you may well not get correct prediction or fine-tuning results. Taking PyTorch as an example, after training a model on multiple GPUs, one way to save it is as follows.
Option 1
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from model import Model  # assume the model class is defined in model.py
# `Model` should be defined earlier
model = Model(args)
# wrap with DDP for training on multiple GPUs;
# `rank` specifies which GPU this process uses and is passed in
# by the PyTorch multiprocessing launcher
model_ddp = DDP(model, device_ids=[rank])
# after training, save the `model_ddp` weights
torch.save(model_ddp.state_dict(), 'best_model.pth')
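It helps to look at what Option 1 actually writes to disk: because the whole model is wrapped in DDP, every key in the saved state_dict carries a `module.` prefix. A quick check (the layer names in the comment are only illustrative):
import torch
state = torch.load('best_model.pth', map_location='cpu')
# keys look like 'module.features.0.weight', 'module.fc.bias', ...
# whereas a model saved without DDP would have 'features.0.weight', 'fc.bias', ...
for key in list(state.keys())[:5]:
    print(key)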
Now, if you want to load the best_model.pth produced by multi-GPU training, the loading procedure depends on the situation; get it wrong and some weights will not be loaded properly (sometimes with no error at all, so everything looks fine until the model underperforms and the cause is nowhere to be found), or an exception is raised.
Case 1
Multi-GPU loading (e.g. for fine-tuning or a downstream task). Because in Option 1 the model was saved while wrapped in DDP, it must also be wrapped in DDP when loading:
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from model import Model  # assume the model class is defined in model.py
# `Model` should be defined earlier
model = Model(args)
# wrap with DDP, just as during training on multiple GPUs;
# `rank` specifies which GPU this process uses and is passed in
# by the PyTorch multiprocessing launcher
model_ddp = DDP(model, device_ids=[rank])
# load the `model_ddp` weights trained on multiple GPUs
pretrained_model = torch.load('best_model.pth')
model_ddp.load_state_dict(pretrained_model)
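A refinement I would treat as common practice rather than a requirement: since torch.load restores tensors to the device they were saved from, every rank may first deserialize the checkpoint onto GPU 0. A per-rank map_location keeps each process on its own device (reusing rank and model_ddp from the snippet above):
import torch
# remap tensors saved from GPU 0 onto this process's own GPU
map_location = {'cuda:0': f'cuda:{rank}'}
pretrained_model = torch.load('best_model.pth', map_location=map_location)
model_ddp.load_state_dict(pretrained_model)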
Case 2
Single-GPU loading (e.g. for testing/inference). Because in Option 1 the model was saved in DDP mode, and single-GPU testing cannot use DDP, the keys of the saved best_model.pth have to be renamed to fit the single-GPU model:
import torch
from model import Model  # assume the model class is defined in model.py
# `Model` should be defined earlier
model = Model(args)
# load the previously saved `model_ddp` weights
pretrained_model = torch.load('best_model.pth')
# best_model.pth stores a state_dict: an ordered dict of `key: value` pairs
# mapping parameter names to tensors. Here we need to rename the keys,
# e.g. module.features.0.weight  ->  features.0.weight
#      (key in DDP mode)             (key in single-GPU mode)
# the `module.` prefix is added by the DDP wrapper
pretrained_model = {key.replace("module.", ""): value for key, value in pretrained_model.items()}
# load the weights into the single-GPU model
model.load_state_dict(pretrained_model)
As shown above, to load a multi-GPU checkpoint on a single GPU, the keys of the DDP state_dict must be renamed to match the keys of the single-GPU model. If you skip this step and call model.load_state_dict(pretrained_model, strict=False), the code runs but $\color{red}{silently}$ ignores the unmatched keys, whose weights are therefore never loaded; inference or fine-tuning quality then drops sharply and the cause is hard to track down.
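A cheap safety net against this silent failure: load_state_dict returns the lists of keys it could not match, so you can check them explicitly instead of trusting strict=False. A minimal sketch, reusing model and pretrained_model from Case 2 above:
# load_state_dict returns a named tuple with `missing_keys` and `unexpected_keys`
result = model.load_state_dict(pretrained_model, strict=False)
if result.missing_keys or result.unexpected_keys:
    # fail loudly instead of silently running with partially loaded weights
    raise RuntimeError(f'state_dict mismatch: missing={result.missing_keys}, '
                       f'unexpected={result.unexpected_keys}')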
Of course, there is a second way to save the model after multi-GPU training:
Option 2
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from model import Model  # assume the model class is defined in model.py
# `Model` should be defined earlier
model = Model(args)
# wrap with DDP for training on multiple GPUs;
# `rank` specifies which GPU this process uses and is passed in
# by the PyTorch multiprocessing launcher
model_ddp = DDP(model, device_ids=[rank])
# after training, save the model weights;
# `model_ddp.module.state_dict()` is equivalent to the state_dict of a model
# trained on a single GPU, because DDP's gradient synchronization keeps the
# replicas on all GPUs identical after every optimization step
torch.save(model_ddp.module.state_dict(), 'best_model.pth')
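One point both options gloss over: in DDP every process runs the same script, so without a guard every rank writes best_model.pth simultaneously. A common pattern, sketched here under the assumption that the process group is already initialized, is to let only rank 0 save and make the other ranks wait:
import torch
import torch.distributed as dist
# DDP keeps all replicas identical, so one copy is enough;
# rank 0 writes the file and the remaining ranks wait at the barrier
if rank == 0:
    torch.save(model_ddp.module.state_dict(), 'best_model.pth')
dist.barrier()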
Since the saving scheme has changed, loading must change accordingly; again there are two cases.
Case 1
Multi-GPU loading (e.g. for fine-tuning or a downstream task). Because in Option 2 the checkpoint is effectively in single-GPU form while multi-GPU training requires DDP, simply load the weights into the plain model first and then wrap it with DDP:
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from model import Model  # assume the model class is defined in model.py
# `Model` should be defined earlier
model = Model(args)
# load the weights that were saved from `model_ddp.module`
pretrained_model = torch.load('best_model.pth')
model.load_state_dict(pretrained_model)
# then wrap with DDP for training on multiple GPUs;
# `rank` specifies which GPU this process uses and is passed in
# by the PyTorch multiprocessing launcher
model_ddp = DDP(model, device_ids=[rank])
Case 2
Single-GPU loading (e.g. for testing/inference). Because in Option 2 the checkpoint is effectively in single-GPU form, it can be loaded directly on a single GPU:
import torch
from model import Model  # assume the model class is defined in model.py
# `Model` should be defined earlier
model = Model(args)
# load the `model` weights directly
pretrained_model = torch.load('best_model.pth')
model.load_state_dict(pretrained_model)
Summary
Many people have had this experience: a reproduced model falls short of the results reported in the paper, sometimes by several points, with no obvious cause; a gap that large cannot be closed by hyperparameter tuning, and, stranger still, the code came straight from the paper's authors. I ran into this myself when reproducing CrossPoint: the keys of the pretrained model did not match at load time, so the final results were far below those in the paper. The root cause was exactly the save-load mismatch described above (I had switched the code to distributed training while the original trained on a single GPU, but overlooked that the loading code had to change as well). Once I sorted this out, the model's performance improved markedly, and on some datasets the metrics became comparable to the paper.
In short, single-GPU save with single-GPU load, and multi-GPU save with multi-GPU load, rarely bring surprises; single-GPU save with multi-GPU load, and multi-GPU save with single-GPU load, are the combinations that call for the adjustments above.
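If you would rather not keep track of which combination you are in, a small helper that strips the `module.` prefix when present and fails loudly on any remaining mismatch handles both kinds of checkpoint. This is my own sketch, not code from CrossPoint or any particular repository:
import torch

def load_weights(model, path, device='cpu'):
    """Load a checkpoint saved either with or without the DDP `module.` prefix."""
    state = torch.load(path, map_location=device)
    # strip the prefix that the DDP wrapper adds, if it is present
    state = {(k[len('module.'):] if k.startswith('module.') else k): v
             for k, v in state.items()}
    # fail loudly instead of silently running with partially loaded weights
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        raise RuntimeError(f'missing={result.missing_keys}, unexpected={result.unexpected_keys}')
    return model
For multi-GPU fine-tuning, call it on the plain model first and wrap the result in DDP afterwards, exactly as in Option 2, Case 1.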