当前位置: 首页 > news >正文

P66实训题

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
import matplotlib.pyplot as plt
import numpy as np

1. 数据加载与预处理

transform = Compose([
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

类别标签

classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')

查看数据集样本(可选)

def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()

dataiter = iter(train_loader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images[:4]))
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

2. 构建网络

class Net(nn.Module):
def init(self):
super(Net, self).init()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
self.fc1 = nn.Linear(128 * 4 * 4, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout = nn.Dropout(0.3)

def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.dropout(x)x = self.pool(torch.relu(self.conv2(x)))x = self.dropout(x)x = self.pool(torch.relu(self.conv3(x)))x = self.dropout(x)x = x.view(-1, 128 * 4 * 4)x = torch.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x

net = Net()

3. 定义损失函数和优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

4. 训练网络

epochs = 30
train_losses = []
train_accs = []
test_losses = []
test_accs = []

for epoch in range(epochs):
running_loss = 0.0
correct = 0
total = 0
net.train()
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
train_loss = running_loss / len(train_loader)
train_acc = 100. * correct / total
train_losses.append(train_loss)
train_accs.append(train_acc)

# 测试
net.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:outputs = net(inputs)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()
test_loss = test_loss / len(test_loader)
test_acc = 100. * correct / total
test_losses.append(test_loss)
test_accs.append(test_acc)print(f'Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}% | Test Loss: {test_loss:.3f} | Test Acc: {test_acc:.2f}%')

print('Finished Training')

绘制训练曲线

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.legend()
plt.title('Loss')

plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Acc')
plt.plot(test_accs, label='Test Acc')
plt.legend()
plt.title('Accuracy')
plt.show()

5. 测试模型精度

net.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
outputs = net(inputs)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()

print(f'测试集精度: {100. * correct / total:.2f}%')

查看各类别预测精度(可选)

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for inputs, labels in test_loader:
outputs = net(inputs)
_, predicted = outputs.max(1)
c = (predicted == labels).squeeze()
for i in range(len(labels)):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1

for i in range(10):
print(f'类别 {classes[i]} 的精度: {100. * class_correct[i] / class_total[i]:.2f}%')

http://www.gsyq.cn/news/21807.html

相关文章:

  • 非主流网站程序IndexNow添加方法
  • 卷积神经网络视频读书报告
  • 以*this返回局部对象的两种情况
  • 2025.10.15
  • Kali 自定义ISO镜像
  • pytorch实训题
  • 【Azure App Service】App Service是否支持PHP的版本选择呢?
  • Markdown转换为Word:Pandoc模板使用指南 - 实践
  • 复习CSharp
  • C语言学习——运算符的学习
  • 实用指南:NXP - 用MCUXpresso IDE v25.6.136的工具链编译Smoothieware固件工程
  • cifar10
  • 感知节点@4@ ESP32+arduino+ 第二个程序 LED灯显示
  • WebGL学习及项目实战(第02期:绘制一个点)
  • display ip routing-table protocol ospf 概念及题目 - 详解
  • C语言学习——小数数据类型
  • 高敏感人应对焦虑
  • 2025 年执业兽医资格证备考服务机构推荐榜,执业兽医资格证培训机构/执兽考试机构/考试辅导机构获得行业推荐
  • [LangChain] 基本介绍
  • Palantir 的“本体工程”的核心思路、技术架构与实践示例
  • P14164 [ICPC 2022 Nanjing R] 命题作文
  • display ospf peer brief 概念及题目 - 实践
  • 记录一次客户现场环境,银河麒麟V10操作系统重启后,进入登录页面后卡死,鼠标键盘无响应的解决过程
  • ManySpeech.AliParaformerAsr 使用指南
  • 易路:以“薪酬科技+AI”重塑中国企业薪酬管理新范式
  • Web 编写 22
  • 下雪了 - L
  • 【html】canvas实现一个时钟 - 实践
  • 特殊函数
  • 一行代码也能行?极简实现GPIO按键关机(支持短按/长按)