PyTorch 基础3

在 PyTorch 中,保存和加载模型主要有两种方式。一种是只保存参数(推荐),另一种是保存整个模型(不推荐)

下面我为你详细对比这两种方式的区别、写法以及优缺点。


方式一:只保存参数 (官方推荐)

这是最标准、最通用的做法。 原理:只把模型的权重字典(比如卷积层的卷积核数值、全连接层的权重矩阵)存下来。不保存网络结构的代码。

1. 保存代码

1
2
3
# 保存
# 只保存 state_dict (它是 python 的字典格式)
torch.save(tudui.state_dict(), "tudui_parameter.pth")

2. 加载代码

Python

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
from model_file import Net # 假设你的网络定义在 model_file.py 里

# 第一步:必须先创建网络结构 (空壳子)
model = Net()

# 第二步:加载参数
# map_location 确保加载到正确的设备
state_dict = torch.load("tudui_parameter.pth", map_location=torch.device('cpu'))

# 第三步:将参数填入壳子
model.load_state_dict(state_dict)

model.eval() # 准备预测
  • 优点
    • 文件小:只存数据,不存多余的元数据。
    • 灵活:如果以后想修改网络结构的一小部分,或者做迁移学习,这种方式更容易操作。
    • 通用:完全不依赖项目目录结构。
  • 缺点
    • 加载时,你的代码里必须有 class Net 的定义,否则无法创建第一步的空壳子。

方式二:保存整个模型

这种方式利用了 Python 的 pickle 序列化机制,把整个对象(包括网络结构定义和参数)一起打包存下来。

1. 保存代码

Python

1
2
3
# 保存
# 直接保存整个模型对象
torch.save(tudui, "tudui_whole_model.pth")

2. 加载代码

Python

1
2
3
4
5
6
7
import torch

# 加载
# 不需要手动写 model = Net(),它自动帮你连结构带参数都加载出来
model = torch.load("tudui_whole_model.pth", map_location=torch.device('cpu'))

model.eval()
  • 优点
    • 写法简单:少写两行代码,加载时不需要手动实例化对象。
  • 致命缺点 (大坑)
    • 高度依赖路径:它保存时会记录 Net 类所在的具体文件路径
    • 举例:如果你在 project/train.py 里定义了 Net 并保存了模型。如果你把在这个 .pth 文件发给朋友,或者你把代码移动到了 project/src/train.py,加载时就会直接报错 AttributeError: Can't get attribute 'Net'...
    • 很难迁移:如果你想修改网络结构,这种全量保存的文件很难处理。

1. 核心作用详解

A. tudui.train() —— 开启“练习模式”

  • 默认状态:模型刚实例化出来时,默认就是这个模式。
  • Dropout 层会随机“偷懒”
    • 它会按照设定的概率(比如 50%)随机把一部分神经元扔掉(置为 0)。
    • 目的:防止模型死记硬背(过拟合),增加训练难度。
  • BatchNorm 层现学现卖
    • 它会根据当前这一个 Batch 数据的均值和方差来进行归一化。
    • 同时,它会悄悄记录一个“全局平均值”,留着给考试时候用。

B. tudui.eval() —— 开启“考试模式”

  • Dropout 层全员上阵
    • 所有神经元必须 100% 工作,一个都不能少。
    • 目的:为了得到最强、最稳定的预测输出。
  • BatchNorm 层使用经验
    • 不再看当前数据的均值方差了(因为考试时可能只有 1 张图,算均值没意义)。
    • 它会拿出在训练时记录下来的那个“全局平均值”来做归一化。

2. 图解对比:以 Dropout 为例

  • Training (左):神经元随机断开,网络结构每次都不一样。
  • Evaluation (右):所有连接全部接通,数值会乘以一个系数来保持平衡。

只有网络模型和数据以及loss能放在GPU上运行