当前位置: 首页 > news >正文

人工智能大作业:植物病害检测系统

本项目使用卷积神经网络算法实现了植物病害检测系统,下面我将以代码来详细说明实现思路

首先,本项目核心算法就是Resnet 50+迁移学习+数据增强

我使用了公共数据集PlantVillage / New Plant Diseases,在该数据集上进行训练,实现植物叶子病害的自动诊断。

数据预处理:dataset.py

from torchvision import datasets, transforms
from torch.utils.data import DataLoaderdef get_transforms():train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])return train_transform, val_transformdef get_dataloaders(data_dir='data', batch_size=32, num_workers=4):train_transform, val_transform = get_transforms()# ==================== 情况B 的正确路径 ====================train_path = f'{data_dir}/train'val_path   = f'{data_dir}/valid'# ======================================================
    train_dataset = datasets.ImageFolder(root=train_path, transform=train_transform)val_dataset = datasets.ImageFolder(root=val_path, transform=val_transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)print(f" 数据集加载完成")print(f"类别数量: {len(train_dataset.classes)}")print(f"训练集图片数量: {len(train_dataset)}")print(f"验证集图片数量: {len(val_dataset)}")print(f"第一个类别示例: {train_dataset.classes[0]}")return train_loader, val_loader, train_dataset.classes

  • RandomResizedCrop(224):随机裁剪并缩放到 224×224(ResNet 输入标准尺寸)
  • RandomHorizontalFlip():随机水平翻转,模拟叶子不同方向
  • RandomRotation(15):随便挑一个小角度旋转
  • ColorJitter:调整亮度、对比度、饱和度、色调,增加对光照变化的鲁棒性(鲁棒性是指是指一个计算机系统在执行过程中处理错误,以及算法在遭遇输入、运算等异常时维持正常运行的能力。简单来说就是稳定性,参数来自于维基百科)
  • Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):使用 ImageNet 预训练均值和标准差(迁移学习必须一致)
  • 训练时随机变换,保证训练质量,同时防止过拟合
  • 但是验证的时候要使用固定裁剪模式,保证验证成果的稳定性

 

定义获取神经网络模型:model.py

 

import torch.nn as nn
from torchvision import modelsdef get_model(num_classes=38,model_name='resnet50'):if model_name == 'resnet50':model = models.resnet50(weights='IMAGENET1K_V1')model.fc = nn.Linear(model.fc.in_features,num_classes)elif model_name == 'mobilenet_v3':model = models.mobilenet_v3_small(weights='IMAGENET1K_V1')model.classifier[3] = nn.Linear(model.classifier[3].in_features,num_classes)return model    

 

  • 残差网络(Residual Network):核心创新是残差连接(Shortcut Connection),解决深度网络的“退化问题”(Degradation Problem)和梯度消失。

    残差连接的核心思想是引入一个“快捷连接”(shortcut connection)或“跳跃连接”(skip connection),允许数据绕过一些层直接传播。这样,网络中的一部分可以直接学习到输入与输出之间的残差(即差异),而不是直接学习到映射本身。具体来说,如果我们希望学习的目标映射是 H(x),我们让网络学习残差映射 F(x)=H(x)−x。因此,原始的目标映射可以表示为 F(x)+x。

  • 公式:y = F(x) + x(残差块),让网络更容易学习恒等映射。
  • ResNet50 有 50 层,包含多个 Bottleneck 残差块。
  • 迁移学习:利用在 ImageNet(1400 万张图片)上预训练的权重,只替换最后的全连接层(fc),大幅减少训练时间和数据需求,同时获得优秀的特征提取能力。

为什么选 ResNet50?

  • 精度高、结构成熟、在医学/农业图像任务中表现优秀
  • 参数量适中(约 25M),适合大多数显卡

 

具体训练过程:train.py

import torch
from torch import nn, optim
from tqdm import tqdm
import torchmetrics
from src.dataset import get_dataloaders
from src.model import get_model
import osdef train():device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f"使用设备: {device}")train_loader, val_loader, class_names = get_dataloaders(batch_size=64)model = get_model(num_classes=len(class_names), model_name='resnet50')model = model.to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)best_acc = 0.0os.makedirs('models', exist_ok=True)for epoch in range(20):  # 可根据需要增加 epochs
        model.train()running_loss = 0.0for inputs, labels in tqdm(train_loader):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# 验证
        model.eval()accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=len(class_names)).to(device)with torch.no_grad():for inputs, labels in val_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)preds = outputs.argmax(dim=1)accuracy.update(preds, labels)acc = accuracy.compute().item()scheduler.step()print(f"Epoch {epoch+1}/20 | Loss: {running_loss/len(train_loader):.4f} | Val Acc: {acc:.4f}")if acc > best_acc:best_acc = acctorch.save(model.state_dict(), 'models/best_plant_disease.pth')print(" 保存最佳模型")print(f"训练完成!最佳验证准确率: {best_acc:.4f}")return model, class_namesimport jsonclass_names_path = 'models/class_names.json'with open(class_names_path, 'w', encoding='utf-8') as f:json.dump(class_names, f, ensure_ascii=False, indent=2)print(f"class_names 已保存到 {class_names_path}")return model, class_names
# 在 src/train.py 的 train() 函数末尾(训练完成后)添加:# 保存 class_names

 

作用:完整训练流程、模型保存、日志记录。

 

使用的算法和技术

 

  1. 损失函数:CrossEntropyLoss(交叉熵损失)
    • 多分类任务的标准损失
  2. 优化器:AdamW
    • Adam + Weight Decay(权重衰减)
    • 比传统 Adam 更好地处理正则化,防止过拟合
  3. 学习率调度器:CosineAnnealingLR
    • 余弦退火:学习率按余弦曲线下降,能帮助模型在后期更精细收敛(模拟退火:模拟物理降温过程参数的变化,基于能量最低原理,可以让参数以随机且靠近最小的方法改变,适合解决单峰问题,不适合多峰问题)
  4. 评估指标:torchmetrics.Accuracy(多分类准确率)
  5. 训练流程
    • model.train() / model.eval() —— 切换模式(影响 BatchNorm、Dropout)
    • torch.no_grad() —— 验证时关闭梯度计算,节省显存
    • optimizer.zero_grad() → loss.backward() → optimizer.step()(标准反向传播)
  6. Early Saving:只保存验证准确率最好的模型(防止过拟合)

外部模型接口:train.py

from src.train import trainif __name__ == "__main__":model, classes = train()

外部模型接口,方便调用

 

部署模块:app.py

import gradio as gr
import torch
from PIL import Image
import jsonfrom src.model import get_model
from src.dataset import get_transforms# ==================== 全局加载模型和类别 ====================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载 class_names
with open('models/class_names.json', 'r', encoding='utf-8') as f:class_names = json.load(f)# 加载模型
def load_model():num_classes = len(class_names)model = get_model(num_classes=num_classes, model_name='resnet50')model.load_state_dict(torch.load('models/best_plant_disease.pth', map_location=device))model.to(device)model.eval()return modelmodel = load_model()
_, val_transform = get_transforms()# ==================== 预测函数 ====================
def predict_image(image):if image is None:return "请上传图片"# 预处理input_tensor = val_transform(image).unsqueeze(0).to(device)with torch.no_grad():output = model(input_tensor)probabilities = torch.softmax(output[0], dim=0)confidence, predicted_idx = torch.max(probabilities, 0)predicted_class = class_names[predicted_idx.item()]confidence_pct = confidence.item() * 100if "healthy" in predicted_class.lower():result = f" **健康**\n\n类别:{predicted_class}\n置信度:{confidence_pct:.2f}%"else:result = f"**疑似病害**\n\n类别:{predicted_class}\n置信度:{confidence_pct:.2f}%"return result# ==================== Gradio 界面 ====================
interface = gr.Interface(fn=predict_image,inputs=gr.Image(type="pil", label="上传植物叶子照片"),outputs=gr.Textbox(label="诊断结果"),title=" 植物病害智能检测系统",description="上传一张叶子照片,AI 将帮助你判断是否生病及病害类型",examples=[["examples/healthy.jpg"], ["examples/diseased.jpg"]],  # 可选allow_flagging="never"
)if __name__ == "__main__":interface.launch(share=False)   # share=True 可生成公网链接

提供使用模型入口

 

http://www.gsyq.cn/news/1297565.html

相关文章:

  • CodeCursor配置全攻略:自定义API密钥与模型选择的最佳实践
  • TestableMock多场景应用:从基础Mock到复杂业务逻辑测试
  • Linux驱动开发:自旋锁实现GPIO LED互斥访问的实战解析
  • 终极指南:如何使用public-apis开源项目快速找到免费API资源
  • 3mux常见问题解决:10个用户最常遇到的错误及其修复方法
  • OMS-ERP库存WMS管理:实现库存共享与仓位优化的完整指南 [特殊字符]
  • 跟我一起学“仓颉”算法-二叉查找树练习题
  • 基于Adafruit Gemma M0与NeoPixel的可编程交互发光头饰制作全攻略
  • 参数失控?画风平庸?Midjourney抽象表现主义进阶必修课,含5套已验证Prompt模板+权重调试日志
  • AI写教材必备:低查重工具实测,30分钟生成10万字专业教材!
  • 5分钟掌握英雄联盟国服换肤:R3nzSkin完整解决方案
  • Opengrep性能优化终极指南:如何实现秒级代码扫描
  • 机器人基础模型 π0.7:一个模型做咖啡、叠衣服、洗盘子——通用机器人从「实验室」走进「厨房」
  • Microsoft-OpenAI 分手进行时:独家云合作终结,Sam Altman 抛「超级智能新政」——AI 行业进入多极时代
  • Apple Music JS核心组件深度解析:从播放器到界面交互
  • Bootstrap Application Wizard最佳实践总结:避免常见陷阱的15个要点
  • Spectre:支持编译时契约评估,可转换 C 代码的安全底层编程语言!
  • Promises/A+完全指南:深入理解JavaScript异步编程标准规范
  • 终极指南:如何让苹果触控板在Windows上获得专业级体验
  • ISG系统三大电机结构深度解析:永磁同步、感应与开关磁阻电机对比
  • 手机的智能体AI,正在因为天玑全面跃升
  • TestableMock与Kotlin完美结合:解决协程和扩展函数Mock难题终极指南
  • 海底生物检测-目标检测数据集(包括VOC格式、YOLO格式)
  • 今起,老年旅客12306购票有打折优惠服务!
  • 超越点灯:用JTAG调试XCZU3EG MPSOC时,你可能会忽略的3个硬件细节与1个Vivado设置
  • 基于RK3568核心板的智能家居控制器:从芯片选型到量产实战
  • RT-Thread Smart在QEMU RISC-V虚拟机上的开发环境搭建与调试实践
  • Raiden Network API开发教程:构建去中心化应用的完整指南
  • React Native Picker Select 自定义扩展教程:创建专属选择器组件的3种方法
  • TIDoS-Framework核心架构解析:理解5个阶段的设计原理