1. .pth文件详解
在pytorch进行模型保存的时候,一般有两种保存方式,一种是保存整个模型,另一种是只保存模型的参数。
torch.save(model.state_dict(), "my_model.pth") # 只保存模型的参数
torch.save(model, "my_model.pth") # 保存整个模型
保存的模型参数实际上一个字典类型,通过key-value的形式来存储模型的所有参数。
2. .pth文件基本信息查看
import torch
# resnet50的参数文件
path = r'C:\Users\Administrator\Desktop\resnet50-19c8e357.pth'
parameter = torch.load(path)
print(type(parameter)) # 类型是 dict
print(len(parameter)) # 长度为 267,即存在267个 key-value 键值对
3. 查看所有的键和值
for key in parameter.keys():
print(key) # 查看所有的键
for value in parameter.values():
print(value) # 查看所有的值
4. 输出单个键值对的值
print(parameter["conv1.weight"])
版权声明:本文为kingonlyuserjava原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。