当前位置: 首页 > news >正文

pytorch实训题

代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time

1. 数据预处理与加载

transform = transforms.Compose([
transforms.RandomCrop(32, padding=4), # 数据增强:随机裁剪
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # CIFAR-10标准归一化参数
])

加载数据集

trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=128, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=100, shuffle=False, num_workers=2
)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')

2. 定义卷积神经网络模型

class CIFAR10CNN(nn.Module):
def init(self):
super(CIFAR10CNN, self).init()
# 卷积层部分
self.conv_layers = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1), # 3输入通道,64输出通道,3x3卷积核
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # 2x2池化,步长2

        nn.Conv2d(64, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.Conv2d(128, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),nn.Conv2d(128, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Conv2d(256, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2))# 全连接层部分self.fc_layers = nn.Sequential(nn.Dropout(0.5),nn.Linear(256 * 4 * 4, 512),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(512, 10)  # 10个分类)def forward(self, x):x = self.conv_layers(x)x = x.view(-1, 256 * 4 * 4)  # 展平特征图x = self.fc_layers(x)return x

3. 初始化模型、损失函数和优化器

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

net = CIFAR10CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

4. 训练模型

def train(epochs=20):
train_losses = []
test_losses = []
best_acc = 0.0

print("开始训练...")
start_time = time.time()for epoch in range(epochs):net.train()  # 训练模式running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)# 清零梯度optimizer.zero_grad()# 前向传播、计算损失、反向传播、参数更新outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 统计损失running_loss += loss.item()if i % 100 == 99:  # 每100个batch打印一次print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}')running_loss = 0.0# 每个epoch结束后测试test_loss, acc = test()train_losses.append(running_loss / len(trainloader))test_losses.append(test_loss)# 学习率调整scheduler.step(test_loss)# 保存最佳模型if acc > best_acc:best_acc = acctorch.save(net.state_dict(), 'best_model.pth')print(f'Epoch {epoch+1} 测试准确率: {acc:.2f}%')print(f'训练完成,耗时: {time.time() - start_time:.2f}秒')
print(f'最佳测试准确率: {best_acc:.2f}%')# 绘制损失曲线
plt.plot(train_losses, label='训练损失')
plt.plot(test_losses, label='测试损失')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('loss_curve.png')
plt.close()return train_losses, test_losses

5. 测试模型

def test():
net.eval() # 评估模式
correct = 0
total = 0
test_loss = 0.0

with torch.no_grad():  # 不计算梯度for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = net(images)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()acc = 100 * correct / total
avg_loss = test_loss / len(testloader)
return avg_loss, acc

6. 测试每个类别的准确率

def test_class_accuracy():
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print(f'类别 {classes[i]} 的准确率: {100 * class_correct[i] / class_total[i]:.2f}%')

7. 显示一些测试图像和预测结果

def show_predictions(num_images=5):
dataiter = iter(testloader)
images, labels = next(dataiter)

# 打印原始图像
imshow(torchvision.utils.make_grid(images))
print('真实标签: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))# 预测
outputs = net(images.to(device))
_, predicted = torch.max(outputs, 1)
print('预测标签: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(4)))

def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.figure(figsize=(10, 4))
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.savefig('predictions.png')
plt.close()

主程序

if name == 'main':
# 训练模型(20个epochs)
train_losses, test_losses = train(epochs=20)

# 加载最佳模型
net.load_state_dict(torch.load('best_model.pth'))# 评估模型
print("\n每个类别的准确率:")
test_class_accuracy()# 显示预测结果
show_predictions()

运行结果
使用设备: cuda:0
开始训练...
[1, 100] loss: 1.762
[1, 200] loss: 1.421
[1, 300] loss: 1.285
Epoch 1 测试准确率: 57.32%
[2, 100] loss: 1.135
[2, 200] loss: 1.052
...
训练完成,耗时: 456.23秒
最佳测试准确率: 85.67%

每个类别的准确率:
类别 plane 的准确率: 89.20%
类别 car 的准确率: 92.50%
类别 bird 的准确率: 78.30%
类别 cat 的准确率: 72.10%
类别 deer 的准确率: 84.50%
类别 dog 的准确率: 79.80%
类别 frog 的准确率: 88.70%
类别 horse 的准确率: 87.60%
类别 ship 的准确率: 91.20%
类别 truck 的准确率: 89.40%

http://www.gsyq.cn/news/21795.html

相关文章:

  • 【Azure App Service】App Service是否支持PHP的版本选择呢?
  • Markdown转换为Word:Pandoc模板使用指南 - 实践
  • 复习CSharp
  • C语言学习——运算符的学习
  • 实用指南:NXP - 用MCUXpresso IDE v25.6.136的工具链编译Smoothieware固件工程
  • cifar10
  • 感知节点@4@ ESP32+arduino+ 第二个程序 LED灯显示
  • WebGL学习及项目实战(第02期:绘制一个点)
  • display ip routing-table protocol ospf 概念及题目 - 详解
  • C语言学习——小数数据类型
  • 高敏感人应对焦虑
  • 2025 年执业兽医资格证备考服务机构推荐榜,执业兽医资格证培训机构/执兽考试机构/考试辅导机构获得行业推荐
  • [LangChain] 基本介绍
  • Palantir 的“本体工程”的核心思路、技术架构与实践示例
  • P14164 [ICPC 2022 Nanjing R] 命题作文
  • display ospf peer brief 概念及题目 - 实践
  • 记录一次客户现场环境,银河麒麟V10操作系统重启后,进入登录页面后卡死,鼠标键盘无响应的解决过程
  • ManySpeech.AliParaformerAsr 使用指南
  • 易路:以“薪酬科技+AI”重塑中国企业薪酬管理新范式
  • Web 编写 22
  • 下雪了 - L
  • 【html】canvas实现一个时钟 - 实践
  • 特殊函数
  • 一行代码也能行?极简实现GPIO按键关机(支持短按/长按)
  • 抖音麒麟福袋软件操作指南
  • 平面图最小割与对偶图最短路 - 干
  • 2025 苏州注册公司服务机构实用推荐:选择深度解析
  • LeetCode | 45. 跳跃游戏 II(转载)
  • 实用指南:mysql_query函数:数据库世界的信使
  • 基于MATLAB的车道线检测