长生栈 长生栈
首页
  • 编程语言

    • C语言
    • C++
    • Java
    • Python
  • 数据结构和算法

    • 全排列算法实现
    • 动态规划算法
  • CMake
  • gitlab 安装和配置
  • docker快速搭建wordpress
  • electron+react开发和部署
  • Electron-创建你的应用程序
  • ImgUI编译环境
  • 搭建图集网站
  • 使用PlantUml画时序图
  • 友情链接
关于
收藏
  • 分类
  • 标签
  • 归档
GitHub (opens new window)

Living Team

编程技术分享
首页
  • 编程语言

    • C语言
    • C++
    • Java
    • Python
  • 数据结构和算法

    • 全排列算法实现
    • 动态规划算法
  • CMake
  • gitlab 安装和配置
  • docker快速搭建wordpress
  • electron+react开发和部署
  • Electron-创建你的应用程序
  • ImgUI编译环境
  • 搭建图集网站
  • 使用PlantUml画时序图
  • 友情链接
关于
收藏
  • 分类
  • 标签
  • 归档
GitHub (opens new window)
  • 计算机视觉

    • 基本环境搭建
    • 数据集制作-FFMPEG从视频提取关键帧
    • 数据集制作-图像分类
    • 数据集制作-目标检测
    • 数据集制作-LabelImg的使用
    • ResNet50部署和微调
      • 自定义数据集
      • 训练(微调)
      • 测试
      • 附录-环境配置
    • MobileNet部署和微调
    • YOLO部署和微调
    • Janus-Pro部署和使用
  • ESP32开发

  • Linux系统移植

  • 快速开始

  • 编程小知识

  • 技术
  • 计算机视觉
DC Wang
2025-06-07
目录

ResNet50部署和微调

# ResNet50部署和微调

ResNet(残差网络)是由微软研究院的何恺明(Kaiming He)等人于2015年提出的深度卷积神经网络架构。它通过残差连接(shortcut connection)解决了深度网络训练中的核心难题——退化问题(即网络层数增加时,训练精度反而下降的现象),同时让梯度更容易在深层网络中传播,从而缓解了梯度消失。arXiv链接:arXiv:1512.03385 (opens new window)

# 自定义数据集

path-to-dataset-base/    # 注意:train/val/test 下的类别文件夹名必须完全一致,ImageFolder 按文件夹名排序分配标签
├── train/               # 训练集
│   ├── class1/          # 类别1的图片
│   │   ├── img1.jpg
│   │   ├── img2.jpg
│   │   └── ...
│   └── class2/          # 类别2的图片
│       ├── img1.jpg
│       └── ...
│
├── val/                 # 验证集(训练脚本每轮评估时使用)
│   ├── class1/
│   │   ├── img1.jpg
│   │   └── ...
│   └── class2/
│       ├── img1.jpg
│       └── ...
│
└── test/                # 测试集(测试脚本使用)
    ├── class1/
    │   ├── test_img1.jpg
    │   └── ...
    └── class2/
        ├── img1.jpg
        └── ...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

# 训练(微调)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torchvision.models as models
import time
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
from torchinfo import summary

# 配置参数
# Batch size shared by the training and evaluation loaders.
BATCH_SIZE = 64
# Project root; every other path below is derived from it.
base_dir = "E:/Resources/Projects/Python/ResNet50"
train_dir = f"{base_dir}/dataset/train"
# The split evaluated after each training epoch is the val/ folder.
test_dir = f"{base_dir}/dataset/val"
# Destination for .pth checkpoints.
models_saved_path = f"{base_dir}/models_saved"
# Destination for the CSV log and the loss/accuracy plot.
output_path = f"{base_dir}/out"

# 工具函数
def mkdirp(path):
    """Create directory ``path`` and any missing parents, like ``mkdir -p``.

    Calling it on an existing directory is a no-op. ``exist_ok=True`` replaces
    the former exists-then-create check, which could raise FileExistsError if
    the directory appeared between the check and the creation.
    """
    os.makedirs(path, exist_ok=True)

def save_last_model(model, epoch):
    """Save ``model``'s state_dict as the latest checkpoint for ``epoch``.

    The file is written to
    ``{models_saved_path}/model-last-{epoch}-<local timestamp>.pth``. The
    directory is created first if needed. Returns the path that was written.
    """
    mkdirp(models_saved_path)
    stamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    checkpoint_path = f'{models_saved_path}/model-last-{epoch}-{stamp}.pth'
    torch.save(model.state_dict(), checkpoint_path)
    return checkpoint_path

def save_best_model(model, epoch):
    """Save ``model``'s state_dict as the best checkpoint reached at ``epoch``.

    Writes ``{models_saved_path}/model-best-{epoch}.pth``, creating the
    directory if needed, and returns that path.
    """
    mkdirp(models_saved_path)
    checkpoint_path = '{}/model-best-{}.pth'.format(models_saved_path, epoch)
    torch.save(model.state_dict(), checkpoint_path)
    return checkpoint_path

def draw(epoch_list, train_loss_list, test_acc_list):
    """Plot training loss and test accuracy against epoch and save the figure.

    Two vertically stacked panels are saved to
    ``{output_path}/train_loss_acc.png``. The directory is created if needed,
    and the figure is closed afterwards so repeated calls do not accumulate
    open figures.
    """
    fig = plt.figure(figsize=(10, 8))

    loss_ax = fig.add_subplot(211)
    loss_ax.plot(epoch_list, train_loss_list, color='darkorange')
    loss_ax.set_title('train loss')

    acc_ax = fig.add_subplot(212)
    acc_ax.plot(epoch_list, test_acc_list, color='deepskyblue')
    acc_ax.set_title('test accuracy')

    fig.tight_layout()
    # Extra vertical gap so the lower title does not touch the upper plot.
    fig.subplots_adjust(wspace=None, hspace=0.3)
    mkdirp(output_path)
    fig.savefig(f"{output_path}/train_loss_acc.png")
    plt.close(fig)

def save_csv(train_data):
    """Write the per-epoch training history to ``{output_path}/result.csv``.

    ``train_data`` is a sequence of rows laid out as
    ``[epoch, train_loss, test_accuracy]``. The DataFrame's default integer
    index is written as the first CSV column.
    """
    mkdirp(output_path)
    history = pd.DataFrame(train_data, columns=["epoch", "train_loss", "test_accuracy"])
    csv_path = f"{output_path}/result.csv"
    history.to_csv(csv_path, encoding='utf-8')

# 数据加载
# Both splits use the same deterministic preprocessing: resize to 224x224,
# convert to a tensor, then normalize with the usual ImageNet channel mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# One sub-folder per class; ImageFolder derives the labels from folder names.
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

# Only the training data is shuffled; evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 模型定义
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Start from ImageNet-pretrained ResNet-50 weights and replace the final
# fully-connected layer with one sized for this dataset.
pretrained_resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# Take the class count from the training folders instead of hard-coding it.
# With a fixed value, a dataset with more class folders produces targets
# outside the output range and CrossEntropyLoss fails. The evaluation script
# must build its head with the same number of classes.
num_classes = len(train_dataset.classes)
pretrained_resnet.fc = nn.Linear(pretrained_resnet.fc.in_features, num_classes)
pretrained_resnet = pretrained_resnet.to(device)

# Training setup: cross-entropy loss, and Adam over all parameters
# (the whole network is fine-tuned, not only the new head).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(pretrained_resnet.parameters(), lr=0.001)

def train_epoch(epoch):
    """Run one training pass over ``train_loader`` and return the mean loss.

    The result is a per-sample average: each batch's (mean-reduced) loss is
    weighted by its batch size. A smaller final batch therefore no longer
    counts as much as a full one, which it did when batch means were averaged
    directly.

    Args:
        epoch: epoch number; used only in the progress-bar label.

    Returns:
        Mean training loss per sample, or NaN if the loader yielded no samples.
    """
    pretrained_resnet.train()
    loss_sum = 0.0
    sample_count = 0
    # The context manager closes the bar even if a batch raises.
    with tqdm(total=len(train_loader), desc=f"Processing train {epoch}", leave=False, position=1) as progress_train_bar:
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = pretrained_resnet(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            sample_count += batch_size
            progress_train_bar.update(1)

    # NaN rather than 0.0 so an empty epoch cannot masquerade as a perfect loss.
    return loss_sum / sample_count if sample_count else float("nan")

def validate(epoch):
    """Measure classification accuracy of ``pretrained_resnet`` on ``test_loader``.

    Runs in eval mode with gradients disabled. The predicted class is the
    arg-max over the model's outputs.

    Args:
        epoch: epoch number; used only in the progress-bar label.

    Returns:
        Fraction of correctly classified samples.
    """
    pretrained_resnet.eval()
    n_correct = 0
    n_seen = 0

    bar = tqdm(total=len(test_loader), desc=f"Processing test  {epoch}", leave=False, position=1)
    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            predictions = pretrained_resnet(batch_inputs).argmax(dim=1)
            n_seen += batch_labels.size(0)
            n_correct += predictions.eq(batch_labels).sum().item()
            bar.update(1)
    bar.close()

    return n_correct / n_seen

def main(epochs=10, resume_path=None, min_best_accuracy=0.75):
    """Fine-tune ``pretrained_resnet`` and log the per-epoch results.

    Each epoch does the following:

    - trains once (``train_epoch``) and saves a "last" checkpoint;
    - evaluates on ``test_loader`` (``validate``);
    - saves a "best" checkpoint whenever accuracy exceeds the best value seen
      so far;
    - rewrites the CSV log.

    The loss/accuracy plot is redrawn from epoch 3 onward and always after
    the final epoch.

    Args:
        epochs: number of training epochs.
        resume_path: optional state_dict checkpoint (as written by
            ``save_last_model`` / ``save_best_model``) to continue from.
            When ``None``, training starts from the ImageNet-pretrained
            weights loaded at module level. The previous code always loaded
            one hard-coded, machine-specific file and crashed when it was
            absent.
        min_best_accuracy: accuracy a model must exceed before any "best"
            checkpoint is written.
    """
    max_accuracy = min_best_accuracy
    epoch_list = []
    train_loss_list = []
    test_accuracy_list = []
    train_data = []

    if resume_path is not None:
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        # weights_only restricts unpickling to tensors and plain containers.
        state_dict = torch.load(resume_path, map_location=device, weights_only=True)
        pretrained_resnet.load_state_dict(state_dict)
    summary(pretrained_resnet, input_size=(BATCH_SIZE, 3, 224, 224))

    with tqdm(total=epochs, desc="Epochs", leave=False, position=0) as progress_bar:
        for epoch in range(1, epochs + 1):
            train_loss = train_epoch(epoch)
            save_last_model(pretrained_resnet, epoch)
            test_accuracy = validate(epoch)

            if test_accuracy > max_accuracy:
                max_accuracy = test_accuracy
                save_best_model(pretrained_resnet, epoch)

            epoch_list.append(epoch)
            train_loss_list.append(train_loss)
            test_accuracy_list.append(test_accuracy)

            # The plot only becomes informative after a few points, but the
            # final epoch is always drawn so short runs still get a figure.
            if epoch > 2 or epoch == epochs:
                draw(epoch_list, train_loss_list, test_accuracy_list)

            train_data.append([epoch, train_loss, test_accuracy])
            # Rewritten every epoch so partial results survive an interrupted run.
            save_csv(train_data)
            progress_bar.update(1)


if __name__ == '__main__':
    main()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158

# 测试

import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import time
import os
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from tqdm import tqdm
from torchinfo import summary

# 配置参数
# Batch size for the evaluation loader.
BATCH_SIZE = 64
# Project root; the other paths are derived from it.
base_dir = "E:/Resources/Projects/Python/ResNet50"
# Held-out split evaluated by this script.
test_dir = f"{base_dir}/dataset/test"
# Destination for the annotated prediction grids.
output_path = f"{base_dir}/out"

# 工具函数
def mkdirp(path):
    """Create directory ``path`` and any missing parents, like ``mkdir -p``.

    Calling it on an existing directory is a no-op. ``exist_ok=True`` replaces
    the former exists-then-create check, which could raise FileExistsError if
    the directory appeared between the check and the creation.
    """
    os.makedirs(path, exist_ok=True)

def imshow(img, labels, predicted, index, nrow=8, padding=2):
    """Save an image grid annotated with each image's predicted class index.

    Each prediction is drawn in the top-left area of its tile. The background
    is green when the prediction matches the label and red otherwise. The
    figure is saved to ``{output_path}/output_{index}.png``.

    Args:
        img: image grid tensor laid out channel-first (the code transposes
            axes (1, 2, 0) for display), as produced by
            ``torchvision.utils.make_grid`` on inputs normalized with the
            mean/std used in this script's transform. That normalization is
            undone here before display.
        labels: ground-truth class indices, one per image in the grid.
        predicted: predicted class indices, one per image in the grid.
        index: number embedded in the output file name.
        nrow: images per grid row; must match the ``nrow`` given to make_grid.
        padding: border width in pixels; must match make_grid's ``padding``.
    """
    plt.clf()
    np_img = np.transpose(img.cpu().numpy(), (1, 2, 0))
    np_img = (np_img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
    # Clip before the uint8 cast: float round-off can land just outside
    # [0, 255], and casting out-of-range floats to uint8 is not well defined.
    np_img = np.clip(np_img, 0, 255)

    plt.imshow(np_img.astype('uint8'))
    plt.axis('off')

    count = len(predicted)
    if count:
        # make_grid uses min(nrow, count) columns; each cell spans one image
        # plus `padding` pixels. Derive the cell pitch from the grid itself
        # instead of assuming a fixed 225 px (the default grid pitch is 226).
        ncols = min(nrow, count)
        nrows = (count + ncols - 1) // ncols
        grid_h, grid_w = np_img.shape[:2]
        cell_w = (grid_w - padding) / ncols
        cell_h = (grid_h - padding) / nrows
        for i in range(count):
            # int() turns a 0-d tensor into a plain number; formatting the
            # tensor directly would print e.g. "tensor(1, device='cuda:0')".
            pred = int(predicted[i])
            is_correct = int(labels[i]) == pred
            plt.text(
                (i % ncols) * cell_w + 4, (i // ncols) * cell_h + 60,
                f'{pred}',
                color='black',
                backgroundcolor='green' if is_correct else 'red',
                fontsize=8
            )

    mkdirp(output_path)
    plt.savefig(f"{output_path}/output_{index}.png", dpi=300, bbox_inches='tight', pad_inches=0)
    plt.close()

# 数据加载
# Same preprocessing as training: resize to 224x224, convert to a tensor,
# then normalize with the usual ImageNet channel mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Fixed order (no shuffling) so output_<n>.png files map to stable batches.
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 模型定义
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the ResNet-50 architecture without downloading ImageNet weights.
# main() loads the fine-tuned checkpoint with load_state_dict (strict by
# default), which overwrites every parameter, so pretrained weights would be
# fetched only to be discarded.
pretrained_resnet = models.resnet50(weights=None)
# Must equal the number of classes the checkpoint was trained with.
num_classes = 2
pretrained_resnet.fc = nn.Linear(pretrained_resnet.fc.in_features, num_classes)
pretrained_resnet = pretrained_resnet.to(device)

# 测试函数
def test_model():
    """Evaluate ``pretrained_resnet`` on ``test_loader`` and return its accuracy.

    For every batch, the inputs are tiled with ``torchvision.utils.make_grid``
    and passed to ``imshow`` together with labels and predictions. Batches are
    numbered from 1, and that number is used in the saved file name.

    Returns:
        Fraction of correctly classified samples.
    """
    pretrained_resnet.eval()
    hits = 0
    seen = 0

    bar = tqdm(total=len(test_loader), desc="Processing test", leave=False, position=0)
    with torch.no_grad():
        for batch_no, (inputs, labels) in enumerate(test_loader, start=1):
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = pretrained_resnet(inputs).argmax(dim=1)

            imshow(torchvision.utils.make_grid(inputs), labels, predicted, batch_no)

            seen += labels.size(0)
            hits += predicted.eq(labels).sum().item()
            bar.update(1)
    bar.close()

    return hits / seen

def main(checkpoint_path='models_saved/model-last-10-2023-12-24-11-30-46.pth'):
    """Load a fine-tuned checkpoint, print a model summary and report test accuracy.

    Args:
        checkpoint_path: state_dict file produced by the training script.
            Relative paths are resolved against the current working directory.
    """
    # map_location lets a checkpoint written on a GPU machine load on a
    # CPU-only host. weights_only restricts unpickling to tensors and plain
    # containers.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    pretrained_resnet.load_state_dict(state_dict)
    summary(pretrained_resnet, input_size=(BATCH_SIZE, 3, 224, 224))

    accuracy = test_model()
    print(f"Test Accuracy: {accuracy}")


if __name__ == '__main__':
    main()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96

# 附录-环境配置

# Virtual environment
python -m venv venv
# Windows activation; on Linux/macOS use: source venv/bin/activate
.\venv\Scripts\activate
# Optional Tsinghua PyPI mirror: -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install torchinfo
pip install tqdm
pip install numpy
pip install matplotlib
# Required by the training script (CSV logging via pandas)
pip install pandas
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
1
2
3
4
5
6
7
8
9
编辑 (opens new window)
#AI#CV#Python
上次更新: 2025/06/07, 21:53:36
数据集制作-LabelImg的使用
MobileNet部署和微调

← 数据集制作-LabelImg的使用 MobileNet部署和微调→

最近更新
01
ESP32-网络摄像头方案
06-14
02
ESP32-PWM驱动SG90舵机
06-14
03
ESP32-实时操作系统freertos
06-14
更多文章>
Theme by Vdoing | Copyright © 2019-2025 DC Wang All right reserved | 辽公网安备 21021102001125号 | 吉ICP备20001966号-2
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式