当前位置: 首页 > news >正文

简单的CNN实现

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
transform = transforms.Compose([transforms.ToTensor(),  # 转为张量transforms.Normalize((0.5,), (0.5,))  # 归一化到 [-1, 1]
])

transforms.ToTensor():
将PIL图像或NumPy数组转换为PyTorch张量(Tensor)
自动将像素值从 [0, 255] 归一化到 [0, 1] 范围
将图像格式从 (H, W, C) 转换为 (C, H, W),即通道优先
transforms.Normalize((0.5,), (0.5,)) :
对图像进行标准化处理
输出范围:[(0-0.5)/0.5, (1-0.5)/0.5] = [-1, 1]

train_dataset = datasets.MNIST(root= './data',train= True,transform = transform,download=True)
test_dataset = datasets.MNIST(root='./data',train= False, transform= transform,download=True)

下载MNIST数据集:
transform = transform表示用上述定义的操作处理数据

train_loader = torch.utils.data.DataLoader(dataset= train_dataset,batch_size = 64,shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset= test_dataset,batch_size = 64,shuffle = False)class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN,self).__init__()self.conv1 = nn.Conv2d(1,32,kernel_size=3,stride=1,padding=1)self.conv2 = nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1)self.fc1 = nn.Linear(64*7*7,128)self.fc2 = nn.Linear(128,10)def forward(self,x):x = F.relu(self.conv1(x))x = F.max_pool2d(x,2)x = F.relu(self.conv2(x))x = F.max_pool2d(x,2)x = x.view(-1,64*7*7)x = F.relu(self.fc1(x))x = self.fc2(x)return x

Conv2d(输入通道数,输出通道数,核大小,步长,padding)
输入通道由输入图像通道数决定
输出通道数由核数决定
linear(x,y) x->y


model = SimpleCNN()criterion = nn.CrossEntropyLoss()#交叉熵损失函数
optimizer = optim.SGD(model.parameters(),lr=0.01,momentum=0.9)num_epochs= 5
model.train()for epoch in range(num_epochs):total_loss = 0for images,labels in train_loader:outputs = model(images)loss = criterion(outputs,labels)#criterion函数是用于计算模型预测值与真实值之间差距的损失函数。#先将梯度归零(optimizer.zero_grad()),然后反向传播计算得到每个参数的梯度值(loss.backward()),最后通过梯度下降执行一步参数              #更新(optimizer.step())optimizer.zero_grad()#使用mini-batch方法,所以如果不将梯度清零的话,梯度会与上一个batch的数据相关loss.backward()#如果你做完运算后使用tensor.backward(),所有的梯度就会自动运算,tensor的梯度将会累加到它的.grad属性里面去optimizer.step()#optimizer只负责通过梯度下降进行优化total_loss += loss.item()print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}")
model.eval()
correct = 0
total = 0with torch.no_grad():for images,labels in test_loader:outputs = model(images)_,predicted = torch.max(outputs,1)total += labels.size(0)correct +=(predicted == labels).sum().item()accuracy = 100*correct/total
print(f"Test Accuracy: {accuracy:.2f}%")dataiter = iter(test_loader)
images, labels = next(dataiter)
outputs = model(images)
_, predictions = torch.max(outputs,1)fig,axes = plt.subplots(1,6,figsize=(12,4))
for i in range(6):axes[i].imshow(images[i][0], cmap='gray')axes[i].set_title(f"Label: {labels[i]}\nPred: {predictions[i]}")axes[i].axis('off')
plt.show()
http://www.gsyq.cn/news/30937.html

相关文章:

  • 解包魔改pyinstaller
  • 反编译解包微信小程序
  • 浅谈C++中的作用域
  • 2025年锡条厂家推荐排行榜,高温抗氧化锡条,焊接专用锡条,电子行业锡条,工业级锡条公司精选
  • 2025年冠晶石厂家推荐排行榜,外墙冠晶石,内墙冠晶石,防霉冠晶石,水包水冠晶石,水包砂冠晶石,耐污冠晶石,自洁冠晶石公司推荐
  • 学弟欢乐赛 - T3 T4 题解
  • 2025年空调维保厂家推荐排行榜,空调维保/末端保养/空调保养/空调清洗/水处理公司专业服务与高效维护首选
  • 2025 ICPC Xian Regional Contest
  • 2025 年 10 月系统门窗厂商榜单揭晓,专业智造实力与品牌保障口碑优选
  • 2025年环境试验设备厂家推荐排行榜,冷热冲击/高低温/快速温变试验箱,氙灯/紫外耐候气候环境试验箱,步入式/恒温恒湿试验箱,高压加速老化/HAST/PCT试验箱,机械环境/淋雨/砂尘试验箱公司推荐
  • python3: ubuntu上安装时报错: No module named zlib
  • [java 锁 - 03 重入写法 ]
  • 2025年包装机厂家权威推荐榜:自动包装机,半自动包装机,高效包装设备源头厂家精选与选购指南
  • 完整教程:iOS 抓包工具有哪些?实战对比、场景分工与开发者排查流程
  • 使用pyautogui完成简单的游戏功能--皇室战争降杯
  • 彻底清除浏览器缓存
  • 2025 年 10 月系统门窗厂商榜单揭晓,专业工艺制造与品牌保障口碑优选
  • 实用指南:MySQL进阶知识点(八)---- SQL优化
  • 2025年饮料包装设备厂家权威推荐榜:缠膜机/吹瓶机/膜包机/杀菌机/水处理/套标机/贴标机/洗瓶机/卸垛机/旋盖机/液氮机/装箱机/灌装生产线/一条龙生产线/配件/灌装机
  • AI浏览器comet拉新,一单20美元(附详细教程)
  • 若依前后端分离版学习笔记(十八)——页面权限,页签缓存以及图标,字典,参数的利用
  • 【c++】红黑树的部分构建
  • ssh原理
  • 我的学习方式破局思考 ——读《认真听讲》、《做中学》与《做教练》有感
  • Unity协程除了实现功能还可以增加可读性
  • Nginx程序结构及核心配置
  • Nginx部署星益小游戏平台(静态页面)
  • 序列密码基本模型
  • 企业级Web应用及Nginx介绍
  • 11种排序算法的Python代码实现