PyTorch 基础3
在 PyTorch 中,保存和加载模型主要有两种方式。一种是只保存参数(推荐),另一种是保存整个模型(不推荐)。
下面我为你详细对比这两种方式的区别、写法以及优缺点。
方式一:只保存参数 (官方推荐)
这是最标准、最通用的做法。 原理:只把模型的权重字典(比如卷积层的卷积核数值、全连接层的权重矩阵)存下来。不保存网络结构的代码。
1. 保存代码
1 | # 保存 |
2. 加载代码
Python
1 | import torch |
- 优点:
- 文件小:只存数据,不存多余的元数据。
- 灵活:如果以后想修改网络结构的一小部分,或者做迁移学习,这种方式更容易操作。
- 通用:完全不依赖项目目录结构。
- 缺点:
- 加载时,你的代码里必须有
class Net的定义,否则无法创建第一步的空壳子。
- 加载时,你的代码里必须有
方式二:保存整个模型
这种方式利用了 Python 的 pickle 序列化机制,把整个对象(包括网络结构定义和参数)一起打包存下来。
1. 保存代码
Python
1 | # 保存 |
2. 加载代码
Python
1 | import torch |
- 优点:
- 写法简单:少写两行代码,加载时不需要手动实例化对象。
- 致命缺点 (大坑):
- 高度依赖路径:它保存时会记录
Net类所在的具体文件路径。 - 举例:如果你在
project/train.py里定义了Net并保存了模型。如果你把在这个.pth文件发给朋友,或者你把代码移动到了project/src/train.py,加载时就会直接报错AttributeError: Can't get attribute 'Net'...。 - 很难迁移:如果你想修改网络结构,这种全量保存的文件很难处理。
- 高度依赖路径:它保存时会记录
1. 核心作用详解
A. tudui.train() —— 开启“练习模式”
- 默认状态:模型刚实例化出来时,默认就是这个模式。
- Dropout 层:会随机“偷懒”。
- 它会按照设定的概率(比如 50%)随机把一部分神经元扔掉(置为 0)。
- 目的:防止模型死记硬背(过拟合),增加训练难度。
- BatchNorm 层:现学现卖。
- 它会根据当前这一个 Batch 数据的均值和方差来进行归一化。
- 同时,它会悄悄记录一个“全局平均值”,留着给考试时候用。
B. tudui.eval() —— 开启“考试模式”
- Dropout 层:全员上阵。
- 所有神经元必须 100% 工作,一个都不能少。
- 目的:为了得到最强、最稳定的预测输出。
- BatchNorm 层:使用经验。
- 它不再看当前数据的均值方差了(因为考试时可能只有 1 张图,算均值没意义)。
- 它会拿出在训练时记录下来的那个“全局平均值”来做归一化。
2. 图解对比:以 Dropout 为例
- Training (左):神经元随机断开,网络结构每次都不一样。
- Evaluation (右):所有连接全部接通,数值会乘以一个系数来保持平衡。
只有网络模型和数据以及loss能放在GPU上运行