当前位置: 首页 > news >正文

cifar10

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from multiprocessing import freeze_support
import sys

1. 加载和预处理数据

def load_data():
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True,transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True,num_workers=2
)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True,transform=transform
)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False,num_workers=2
)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')return trainloader, testloader, classes

2. 构建网络

class Net(nn.Module):
def init(self):
super().init()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = torch.flatten(x, 1)  # 展平x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x

3. 编译网络(定义损失函数和优化器)

def compile_model(net):
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
return criterion, optimizer

4. 训练网络(已同步设备)

def train(net, trainloader, criterion, optimizer, device, epochs=2):
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
# 核心:数据与模型设备同步
inputs, labels = inputs.to(device), labels.to(device)

        # 梯度清零optimizer.zero_grad()# 前向计算 + 反向传播 + 优化参数outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印训练日志running_loss += loss.item()if i % 2000 == 1999:  # 每2000个batch打印一次print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')running_loss = 0.0print('训练完成')

5. 测试网络(已同步设备)

def test(net, testloader, classes, device):
correct = 0
total = 0
# 测试时不计算梯度,加快速度
with torch.no_grad():
for data in testloader:
images, labels = data
# 数据与模型设备同步
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

print(f'测试集整体准确率: {100 * correct // total} %')# 按类别统计准确率
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predictions = torch.max(outputs, 1)# 统计每个类别的预测结果for label, prediction in zip(labels, predictions):if label == prediction:correct_pred[classes[label]] += 1total_pred[classes[label]] += 1# 打印各类别准确率
for classname, correct_count in correct_pred.items():accuracy = 100 * float(correct_count) / total_pred[classname]print(f'类别: {classname:5s} 准确率: {accuracy:.1f} %')

if name == 'main':
freeze_support() # 解决Windows多进程问题
# 自动选择设备(有GPU用GPU,无则用CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"当前使用设备: {device}")

# 加载数据、初始化模型和优化器
trainloader, testloader, classes = load_data()
net = Net().to(device)  # 模型放到指定设备
criterion, optimizer = compile_model(net)# 重定向输出到文件,同时保留控制台打印
original_stdout = sys.stdout
with open('cifar10_result.txt', 'w') as f:sys.stdout = fprint(f"当前使用设备: {device}")train(net, trainloader, criterion, optimizer, device)test(net, testloader, classes, device)sys.stdout = original_stdout  # 恢复控制台输出print("训练完成!结果已保存到 cifar10_result.txt ")

image

http://www.gsyq.cn/news/21774.html

相关文章:

  • 感知节点@4@ ESP32+arduino+ 第二个程序 LED灯显示
  • WebGL学习及项目实战(第02期:绘制一个点)
  • display ip routing-table protocol ospf 概念及题目 - 详解
  • C语言学习——小数数据类型
  • 高敏感人应对焦虑
  • 2025 年执业兽医资格证备考服务机构推荐榜,执业兽医资格证培训机构/执兽考试机构/考试辅导机构获得行业推荐
  • [LangChain] 基本介绍
  • Palantir 的“本体工程”的核心思路、技术架构与实践示例
  • P14164 [ICPC 2022 Nanjing R] 命题作文
  • display ospf peer brief 概念及题目 - 实践
  • 记录一次客户现场环境,银河麒麟V10操作系统重启后,进入登录页面后卡死,鼠标键盘无响应的解决过程
  • ManySpeech.AliParaformerAsr 使用指南
  • 易路:以“薪酬科技+AI”重塑中国企业薪酬管理新范式
  • Web 编写 22
  • 下雪了 - L
  • 【html】canvas实现一个时钟 - 实践
  • 特殊函数
  • 一行代码也能行?极简实现GPIO按键关机(支持短按/长按)
  • 抖音麒麟福袋软件操作指南
  • 平面图最小割与对偶图最短路 - 干
  • 2025 苏州注册公司服务机构实用推荐:选择深度解析
  • LeetCode | 45. 跳跃游戏 II(转载)
  • 实用指南:mysql_query函数:数据库世界的信使
  • 基于MATLAB的车道线检测
  • 断言
  • 2025 年国内小程序开发优质机构最新推荐排行榜:覆盖多领域需求,助力政企精准选型
  • Python 受保护成员和私有成员
  • 2025 单招综评培训机构推荐榜:济南易升教育 5 星领跑,适配基础/冲刺/面试全流程备考
  • 深入解析:Scikit-learn Python机器学习 - 聚类分析算法 - Agglomerative Clustering(凝聚层次聚类)
  • “一切皆文件”:揭秘LINUX I/O与虚拟内存的底层设计哲学