pytorch模型的保存和加载!🍧

pytorch模型的保存和加载

torch提供了两种方式进行保存:

  1. 保存整个模型:保存整个模型的结构(代码)、参数。
  2. 保存模型参数:仅保存模型的参数,而不保存模型的结构(代码)。

先看一下第一种保存方式,保存整个模型的结构(代码)、参数:

1
2
# 保存模型
torch.save(model, 'model.pth')

那如何使用呢?特别简单:

1
2
3
4
# 加载整个模型
loaded_model = torch.load('model.pth')
# 直接进行推理
output = loaded_model(input_tensor)

第二种方式是只保存了保存模型的参数,不保存模型的结构(代码):

1
2
# 保存模型参数
torch.save(model.state_dict(), 'model_params.pth'

那使用呢?和第一种方式有很大的差别!要先实例化模型,也是说要有模型结构的代码,才能加载参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 模型结构
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(10, 20) # 输入大小为10,输出大小为20
self.fc2 = nn.Linear(20, 1) # 输入大小为20,输出大小为1

def forward(self, x):
x = torch.relu(self.fc1(x)) # 使用ReLU作为激活函数
x = self.fc2(x)
return x

# 加载模型参数
model = SimpleNN() # 创建模型实例
model.load_state_dict(torch.load('model_params.pth'))
# 直接进行推理
output = model(input_tensor)

两种保存方式的差别,但是还是要注意:第一种方式其实是在保存模型的时候,序列化的数据被绑定到了特定的类(代码中的模型类)和确切的目录,本质上是不保存模型结构(代码)本身,而是保存这个模型结构(代码)的路径,并且在加载的时候会使用,因此当在其他项目里使用或者重构的时候,这种方式加载模型的时候会出错。所以一般建议使用第二种方式!

tips:

huggingface保存的bin文件和pth文件有什么区别?

答案是bin文件保存的是模型的参数,是上述的torch的第二种方式,想要完整加载模型是需要模型结构(代码)的。


pytorch模型的保存和加载!🍧
https://yangchuanzhi20.github.io/2024/04/28/人工智能/Pytorch/基础知识/pytorch模型的保存和加载/
作者
白色很哇塞
发布于
2024年4月28日
许可协议