长生栈 长生栈
首页
  • 编程语言

    • C语言
    • C++
    • Java
    • Python
  • 数据结构和算法

    • 全排列算法实现
    • 动态规划算法
  • CMake
  • gitlab 安装和配置
  • docker快速搭建wordpress
  • electron+react开发和部署
  • Electron-创建你的应用程序
  • ImgUI编译环境
  • 搭建图集网站
  • 使用PlantUml画时序图
  • 友情链接
关于
收藏
  • 分类
  • 标签
  • 归档
GitHub (opens new window)

Living Team

编程技术分享
首页
  • 编程语言

    • C语言
    • C++
    • Java
    • Python
  • 数据结构和算法

    • 全排列算法实现
    • 动态规划算法
  • CMake
  • gitlab 安装和配置
  • docker快速搭建wordpress
  • electron+react开发和部署
  • Electron-创建你的应用程序
  • ImgUI编译环境
  • 搭建图集网站
  • 使用PlantUml画时序图
  • 友情链接
关于
收藏
  • 分类
  • 标签
  • 归档
GitHub (opens new window)
  • 初识Python
  • 变量和运算符
  • python之正则表达式
  • 机器学习pytorch虚拟环境搭建
  • AI

    • pytorch-quickstart
    • pytorch-DATASETS & DATALOADERS
    • pytorch-TENSORS
    • pytorch-BUILD THE NEURAL NETWORK
    • pytorch-OPTIMIZING MODEL PARAMETERS
    • pytorch-SAVE AND LOAD THE MODEL
      • Saving and Loading Model Weights
      • Saving and Loading Models with Shapes
    • YOLO - You only look once
    • 知识蒸馏
  • Python
  • AI
DC Wang
2022-03-25
目录

pytorch-SAVE AND LOAD THE MODEL

# SAVE AND LOAD THE MODEL

In this section we will look at how to persist model state with saving, loading and running model predictions.

import torch
import torchvision.models as models
1
2

# Saving and Loading Model Weights

PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
1
2

To load model weights, you need to create an instance of the same model first, and then load the parameters using load_state_dict() method.

model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
1
2
3

NOTE

be sure to call model.eval() method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.

# Saving and Loading Models with Shapes

When loading model weights, we needed to instantiate the model class first, because the class defines the structure of a network. We might want to save the structure of this class together with the model, in which case we can pass model (and not model.state_dict()) to the saving function:

torch.save(model, 'model.pth')
1

We can then load the model like this:

model = torch.load('model.pth')
1

NOTE

This approach uses Python pickle (opens new window) module when serializing the model, thus it relies on the actual class definition to be available when loading the model.

编辑 (opens new window)
#Python#AI#机器学习
上次更新: 2022/10/03, 09:24:26
pytorch-OPTIMIZING MODEL PARAMETERS
YOLO - You only look once

← pytorch-OPTIMIZING MODEL PARAMETERS YOLO - You only look once→

最近更新
01
ESP32-网络摄像头方案
06-14
02
ESP32-PWM驱动SG90舵机
06-14
03
ESP32-实时操作系统freertos
06-14
更多文章>
Theme by Vdoing | Copyright © 2019-2025 DC Wang All right reserved | 辽公网安备 21021102001125号 | 吉ICP备20001966号-2
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式