MobileNet部署和微调
# MobileNet部署和微调
MobileNet 是一系列专为移动端和嵌入式设备设计的轻量级、高效的卷积神经网络架构。主要用于图像分类等任务。
# 快速开始
import torch
from torchvision import transforms, models
from PIL import Image
# 加载预训练的 MobileNetV3 模型
model = models.mobilenet_v3_large(pretrained=True) # 自动加载ImageNet权重
model.eval() # 设置为评估模式
# 定义图像预处理变换 (必须匹配模型训练时的预处理)
preprocess = transforms.Compose([
transforms.Resize(256), # 先缩放到256x256
transforms.CenterCrop(224), # 中心裁剪到224x224
transforms.ToTensor(), # 转为张量 (0-1)
transforms.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet均值和标准差
std=[0.229, 0.224, 0.225]),
])
# 加载并预处理图像
img_path = '02.jpg'
img = Image.open(img_path)
input_tensor = preprocess(img) # 应用预处理变换
input_batch = input_tensor.unsqueeze(0) # 添加批次维度 (NCHW)
# 进行预测 (确保输入在GPU上运行, 如果模型也在GPU上)
with torch.no_grad():
output = model(input_batch)
# 获取预测类别索引 (top1)
_, predicted_idx = torch.max(output, 1)
# 输出预测类别索引
print(f"Predicted class: {predicted_idx.item()}")
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# 训练(微调)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import copy
# 设置中文字体(微软雅黑)
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 或者使用 ['SimHei'] 黑体
plt.rcParams['axes.unicode_minus'] = False # 正确显示负号
# 设置随机种子以确保可重复性
torch.manual_seed(42)
np.random.seed(42)
# 检测GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 数据集路径
data_dir = './dataset' # 替换为你的数据集路径
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'val')
# 图像预处理
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224), # 随机裁剪并调整大小
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(15), # 随机旋转
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # 颜色抖动
transforms.ToTensor(), # 转换为张量
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet标准化
]),
'val': transforms.Compose([
transforms.Resize(256), # 调整大小
transforms.CenterCrop(224), # 中心裁剪
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
# 加载数据集
image_datasets = {
'train': datasets.ImageFolder(train_dir, data_transforms['train']),
'val': datasets.ImageFolder(val_dir, data_transforms['val'])
}
dataloaders = {
'train': torch.utils.data.DataLoader(
image_datasets['train'],
batch_size=32,
shuffle=True,
num_workers=0
),
'val': torch.utils.data.DataLoader(
image_datasets['val'],
batch_size=32,
shuffle=False,
num_workers=0
)
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(f"训练集大小: {dataset_sizes['train']}")
print(f"验证集大小: {dataset_sizes['val']}")
print(f"类别: {class_names}")
# 加载预训练模型
model = models.mobilenet_v3_large(pretrained=True) # 使用Large版本,也可以选择Small版本
print("原始模型结构:")
print(model)
# 冻结所有参数
for param in model.parameters():
param.requires_grad = False
# 修改分类器
num_ftrs = model.classifier[-1].in_features # 获取原始分类器的输入特征数
model.classifier[-1] = nn.Linear(num_ftrs, len(class_names)) # 替换最后一层
# 将模型移到GPU
model = model.to(device)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
# 只优化分类器参数
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)
# 学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# 训练函数
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
since = time.time()
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0
# 记录训练历史
history = {
'train_loss': [],
'train_acc': [],
'val_loss': [],
'val_acc': []
}
for epoch in range(num_epochs):
print(f'Epoch {epoch}/{num_epochs - 1}')
print('-' * 10)
# 每个epoch都有训练和验证阶段
for phase in ['train', 'val']:
if phase == 'train':
model.train() # 设置训练模式
else:
model.eval() # 设置评估模式
running_loss = 0.0
running_corrects = 0
# 迭代数据
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)
# 梯度清零
optimizer.zero_grad()
# 前向传播
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
# 只在训练阶段反向传播和优化
if phase == 'train':
loss.backward()
optimizer.step()
# 统计
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
if phase == 'train':
scheduler.step()
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase]
# 记录历史
if phase == 'train':
history['train_loss'].append(epoch_loss)
history['train_acc'].append(epoch_acc.item())
else:
history['val_loss'].append(epoch_loss)
history['val_acc'].append(epoch_acc.item())
print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
# 深度复制模型
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict())
print()
time_elapsed = time.time() - since
print(f'训练完成于 {time_elapsed // 60:.0f}分 {time_elapsed % 60:.0f}秒')
print(f'最佳验证准确率: {best_acc:.4f}')
# 加载最佳模型权重
model.load_state_dict(best_model_wts)
return model, history
# 训练模型
num_epochs = 15
model, history = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=num_epochs)
# 保存模型
torch.save(model.state_dict(), 'mobilenet_v3_finetuned.pth')
# 可视化训练过程
def plot_training_history(history):
plt.figure(figsize=(12, 4))
# 损失曲线
plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='训练损失')
plt.plot(history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend()
# 准确率曲线
plt.subplot(1, 2, 2)
plt.plot(history['train_acc'], label='训练准确率')
plt.plot(history['val_acc'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend()
plt.tight_layout()
plt.savefig('training_history.png')
plot_training_history(history)
# 可视化模型预测
def visualize_model(model, num_images=12):
was_training = model.training
model.eval()
images_so_far = 0
fig = plt.figure(figsize=(10, 8))
with torch.no_grad():
for i, (inputs, labels) in enumerate(dataloaders['val']):
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
for j in range(inputs.size()[0]):
images_so_far += 1
ax = plt.subplot(num_images//2, 2, images_so_far)
ax.axis('off')
ax.set_title(f'predicted: {class_names[preds[j]]}\groundtruth: {class_names[labels[j]]}')
# 反标准化图像
inp = inputs.cpu().data[j].numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if images_so_far == num_images:
model.train(mode=was_training)
plt.savefig('model_predictions.png')
return
model.train(mode=was_training)
visualize_model(model)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
# 附录-环境配置
# 虚拟环境
python -m venv venv
.\venv\Scripts\activate
# 清华源: -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install matplotlib
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
1
2
3
4
5
6
2
3
4
5
6
编辑 (opens new window)