当前位置: 首页 > news >正文

基于pytorch卷积神经网络的汉字识别系统

基于pytorch卷积神经网络的汉字识别系统

源代码如下(在 PyCharm 中运行,文末附运行结果):

import os
import shutil
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.metrics import accuracy_score
import warnings
from tqdm import tqdm # 进度条显示

# Suppress every Python warning for the rest of the run; note that this also
# hides deprecation notices raised by the imported libraries.
warnings.filterwarnings('ignore')


# ======================== 1. 配置参数 ========================
class Config:
    """Static settings read by the data-preparation, model and training code."""

    # ---- Data locations ----
    TXT_PATH = "C:/Users/33946/Downloads/hd_chinese/hd_chinese/train.txt"
    RAW_PNG_DIR = "C:/Users/33946/Downloads/hd_chinese/hd_chinese/test_data"
    OUTPUT_DATASET_ROOT = "C:/Users/33946/Downloads/hd_chinese/hd_chinese/dataset"

    # ---- Training hyper-parameters ----
    IMAGE_SIZE = (64, 64)
    # 64 is the value intended for GPU runs; the author suggests lowering it
    # to 32 by hand when training on CPU.
    BATCH_SIZE = 64
    EPOCHS = 100
    LR = 0.0001
    # First CUDA device when one is available, CPU otherwise.
    DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # ---- Checkpointing ----
    SAVE_DIR = "saved_models"
    SAVE_INTERVAL = 10  # in epochs

    # ---- Augmentation strength ----
    ROTATION_DEGREES = 5
    TRANSLATE = (0.05, 0.05)


# Create the checkpoint directory up front (no error if it already exists).
os.makedirs(Config.SAVE_DIR, exist_ok=True)


# ======================== 2. 数据集处理 ========================
def process_train_txt_and_generate_dataset(txt_path=None, raw_png_dir=None,
                                           output_root=None, seed=42):
    """Split the labelled PNG files into train/val/test folders, one sub-folder per character.

    Each non-empty line of the label file is expected to look like
    ``<png path>\\t<text>``; the class of an image is the first character of
    ``<text>`` and only the basename of the path is used to locate the file.

    Args:
        txt_path: Label file. Defaults to ``Config.TXT_PATH``.
        raw_png_dir: Folder holding the source PNG files. Defaults to
            ``Config.RAW_PNG_DIR``.
        output_root: Folder that receives ``train/``, ``val/`` and ``test/``.
            Defaults to ``Config.OUTPUT_DATASET_ROOT``.
        seed: Seed of the private shuffle generator. With a fixed seed the
            same label file always yields the same split, so running the
            function again does not spread one image over several splits.
            Pass ``None`` for a different split on every call.

    Returns:
        dict mapping each character to the (shuffled) list of PNG file names
        listed for it in the label file, including names whose file is missing.

    Note:
        Files copied by an earlier run are never removed from ``output_root``.
    """
    txt_path = Config.TXT_PATH if txt_path is None else txt_path
    raw_png_dir = Config.RAW_PNG_DIR if raw_png_dir is None else raw_png_dir
    output_root = Config.OUTPUT_DATASET_ROOT if output_root is None else output_root

    print("===== 开始处理数据集 =====")
    for split in ('train', 'val', 'test'):
        os.makedirs(os.path.join(output_root, split), exist_ok=True)

    # Group file names by label character while reading the label file.
    char_groups = {}
    with open(txt_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            # partition() never raises: a line without a tab (or whose text was
            # only whitespace and got stripped) simply yields an empty text.
            png_rel_path, _, text = line.partition('\t')
            text = text.strip()
            if not text:
                print(f"⚠️ 跳过空文本:{png_rel_path}")
                continue
            char_groups.setdefault(text[0], []).append(os.path.basename(png_rel_path))

    rng = random.Random(seed)
    total_images = 0
    for char, png_list in char_groups.items():
        rng.shuffle(png_list)

        # Resolve the files first so the split ratios apply to images that exist.
        available = []
        for png_filename in png_list:
            src_path = os.path.join(raw_png_dir, png_filename)
            if os.path.exists(src_path):
                available.append((png_filename, src_path))
            else:
                print(f"⚠️ 跳过不存在的文件:{src_path}")

        total = len(available)
        if total == 0:
            continue
        total_images += total

        # Every class gets at least one training image, even a one-image class.
        train_num = max(1, int(total * 0.7))
        val_num = int(total * 0.2)

        for i, (png_filename, src_path) in enumerate(available):
            if i < train_num:
                split = 'train'
            elif i < train_num + val_num:
                split = 'val'
            else:
                split = 'test'

            dst_dir = os.path.join(output_root, split, char)
            os.makedirs(dst_dir, exist_ok=True)
            shutil.copy(src_path, os.path.join(dst_dir, png_filename))

    print(f"✅ 数据集处理完成!共处理 {total_images} 张图像,{len(char_groups)} 个汉字类别")
    print(f" 数据集目录:{output_root}")
    return char_groups


# The author intended this step for the first run only. Note that
# `char_groups` is read right below to build the class tables, so commenting
# the call out requires another source for the class list.
char_groups = process_train_txt_and_generate_dataset()

# ======================== 3. Data loading ========================
# Class index = position of the character in the sorted label set.
CHINESE_CHARS = sorted(char_groups)
IDX_TO_CHAR = dict(enumerate(CHINESE_CHARS))
CHAR_TO_IDX = {char: idx for idx, char in IDX_TO_CHAR.items()}
NUM_CLASSES = len(CHINESE_CHARS)
print("\n===== 模型配置 =====")
print(f" 识别类别数:{NUM_CLASSES},示例汉字:{CHINESE_CHARS[:10]}...")


class ChineseCharDataset(Dataset):
    """Image dataset laid out as ``<data_dir>/<character>/<image>.png``.

    Args:
        data_dir: Folder whose sub-folders are named after the characters.
        char_to_idx: Mapping from character to integer class index. Sub-folders
            whose name is not a key of this mapping are ignored.
        transform: Optional callable applied to the grayscale PIL image.

    Attributes:
        image_paths: Paths of all indexed images.
        labels: Class index of each entry of ``image_paths``.
    """

    def __init__(self, data_dir, char_to_idx, transform=None):
        self.data_dir = data_dir
        self.char_to_idx = char_to_idx
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # os.listdir() returns entries in arbitrary order; sorting makes the
        # sample order identical from run to run.
        for char in sorted(os.listdir(data_dir)):
            char_dir = os.path.join(data_dir, char)
            if not os.path.isdir(char_dir) or char not in char_to_idx:
                continue
            for img_name in sorted(os.listdir(char_dir)):
                # Accept ".PNG" as well as ".png".
                if img_name.lower().endswith(".png"):
                    self.image_paths.append(os.path.join(char_dir, img_name))
                    self.labels.append(char_to_idx[char])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the label is a 0-dim ``torch.long`` tensor."""
        # convert() returns a fully loaded copy, so the file can be closed
        # right away instead of staying open until garbage collection.
        with Image.open(self.image_paths[idx]) as raw:
            img = raw.convert("L")
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.long)


def get_transforms():
    """Build the two preprocessing pipelines.

    Returns:
        ``(train_transform, val_test_transform)``. Both resize to
        ``Config.IMAGE_SIZE``, convert to a tensor and normalise with
        mean 0.5 / std 0.5; the training pipeline additionally applies random
        rotation, translation, resized cropping and random erasing.
    """
    size = Config.IMAGE_SIZE

    # Geometric augmentations work on the PIL image, i.e. before ToTensor().
    geometric_augmentations = [
        transforms.RandomRotation(Config.ROTATION_DEGREES),
        transforms.RandomAffine(0, translate=Config.TRANSLATE),
        transforms.RandomResizedCrop(size, scale=(0.9, 1.0)),
    ]
    train_steps = [
        transforms.Resize(size),
        *geometric_augmentations,
        transforms.ToTensor(),
        # RandomErasing operates on tensors, hence its place after ToTensor().
        transforms.RandomErasing(p=0.1, scale=(0.02, 0.05)),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
    eval_steps = [
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
    return transforms.Compose(train_steps), transforms.Compose(eval_steps)


train_transform, val_test_transform = get_transforms()
train_dir, val_dir, test_dir = (
    os.path.join(Config.OUTPUT_DATASET_ROOT, split_name)
    for split_name in ("train", "val", "test")
)

# Only the training split is augmented.
train_dataset = ChineseCharDataset(train_dir, CHAR_TO_IDX, transform=train_transform)
val_dataset = ChineseCharDataset(val_dir, CHAR_TO_IDX, transform=val_test_transform)
test_dataset = ChineseCharDataset(test_dir, CHAR_TO_IDX, transform=val_test_transform)

# Worker processes are disabled (num_workers=0); the author did this for
# Windows. Only the training loader shuffles, and memory is pinned only when
# training on CUDA.
train_loader, val_loader, test_loader = (
    DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=shuffle_samples,
        num_workers=0,
        pin_memory=Config.DEVICE.type == 'cuda',
    )
    for dataset, shuffle_samples in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

print("\n===== 数据集加载 =====")
print(f" 训练集:{len(train_dataset)} 张图像")
print(f" 验证集:{len(val_dataset)} 张图像")
print(f" 测试集:{len(test_dataset)} 张图像")


# ======================== 4. 模型定义 ========================
class ImprovedChineseCharCNN(nn.Module):
    """CNN classifier for single-channel character images.

    Three convolutional stages (1→32→64, 64→128→128 and 128→256→256 channels),
    each made of two 3x3 conv + BatchNorm + ReLU units followed by 2x2
    max-pooling and dropout, feed a 1024 → 512 → ``num_classes`` MLP head.

    Args:
        num_classes: Number of output logits.
        image_size: ``(height, width)`` of the input images, used to size the
            first linear layer. Defaults to ``Config.IMAGE_SIZE``.
    """

    def __init__(self, num_classes, image_size=None):
        super().__init__()
        if image_size is None:
            image_size = Config.IMAGE_SIZE

        # The layer order fixes the state_dict keys ("conv_layers.<index>..."),
        # so it must stay as is for existing checkpoints to load.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.05),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.05),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.05)
        )

        # Probe the flattened feature size with one dummy image. Run it in eval
        # mode and without autograd so the probe neither updates the BatchNorm
        # running statistics nor builds a graph.
        self.conv_layers.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 1, image_size[0], image_size[1])
            self.fc_input_dim = self.conv_layers(probe).flatten(1).size(1)
        self.conv_layers.train()

        self.fc_layers = nn.Sequential(
            nn.Linear(self.fc_input_dim, 1024),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),

            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),

            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Map a ``(batch, 1, H, W)`` tensor to ``(batch, num_classes)`` logits."""
        features = self.conv_layers(x)
        flat = features.view(features.size(0), -1)
        return self.fc_layers(flat)


# Build the network, then move its parameters to the configured device.
model = ImprovedChineseCharCNN(NUM_CLASSES)
model = model.to(Config.DEVICE)
print("\n===== 模型信息 =====")
print(f" 设备:{Config.DEVICE}")
print(f" 模型结构:{model}")


# ======================== 5. 训练与评估函数 ========================
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; switched to training mode here.
        train_loader: Loader yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, acc)``: the loss averaged over ``len(train_loader.dataset)``
        samples and the accuracy of the predictions made during the pass.
    """
    model.train()
    loss_sum = 0.0
    predictions, targets = [], []

    for images, labels in tqdm(train_loader, desc="训练中", leave=False):
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        batch_loss = criterion(logits, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so the final division by the
        # dataset size gives a per-sample average.
        loss_sum += batch_loss.item() * images.size(0)
        predictions += logits.argmax(dim=1).tolist()
        targets += labels.tolist()

    return loss_sum / len(train_loader.dataset), accuracy_score(targets, predictions)


def evaluate(model, dataloader, criterion, device, split="验证"):
    """Measure loss and accuracy on ``dataloader`` without updating the model.

    Args:
        model: Network to evaluate; switched to eval mode here.
        dataloader: Loader yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.
        split: Name of the split, shown in the progress-bar description.

    Returns:
        ``(avg_loss, acc)``: the loss averaged over ``len(dataloader.dataset)``
        samples and the classification accuracy.
    """
    model.eval()
    loss_sum = 0.0
    predictions, targets = [], []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc=f"{split}中", leave=False):
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            batch_loss = criterion(logits, labels)

            loss_sum += batch_loss.item() * images.size(0)
            predictions += logits.argmax(dim=1).tolist()
            targets += labels.tolist()

    return loss_sum / len(dataloader.dataset), accuracy_score(targets, predictions)


# ======================== 新增:输出识别文字结果 ========================
def print_recognition_results(model, dataloader, device, idx_to_char, num_samples=5):
    """Print predicted vs. true character for randomly chosen samples.

    Args:
        model: Trained network; switched to eval mode here.
        dataloader: Only its ``dataset`` attribute is used; samples are read
            one by one as ``(image, label)`` pairs.
        device: Device the image is moved to before the forward pass.
        idx_to_char: Mapping from class index to character.
        num_samples: Upper bound on the number of samples printed; fewer are
            printed when the dataset is smaller.
    """
    model.eval()
    dataset = dataloader.dataset
    # Draw distinct indices so repeated calls show different samples.
    how_many = min(num_samples, len(dataset))
    chosen = random.sample(range(len(dataset)), how_many)

    with torch.no_grad():
        for position, sample_idx in enumerate(chosen, start=1):
            image, label = dataset[sample_idx]
            batch = image.unsqueeze(0).to(device)  # add the batch dimension
            logits = model(batch)

            pred_char = idx_to_char[torch.argmax(logits, 1).cpu().item()]
            true_char = idx_to_char[label.item()]

            mark = '✅' if pred_char == true_char else '❌'
            print(f"样本 {position}:预测='{pred_char}',真实='{true_char}',{mark}")


# ======================== 6. 主训练函数(支持断点续训) ========================
def main_train(load_from_checkpoint=True, checkpoint_path="saved_models/best_model.pth"):
    """Train the module-level ``model``, optionally resuming from a checkpoint.

    Trains up to ``Config.EPOCHS`` epochs with AdamW and a
    ``ReduceLROnPlateau`` scheduler driven by validation accuracy. A numbered
    checkpoint is written every ``Config.SAVE_INTERVAL`` epochs and
    ``best_model.pth`` is rewritten whenever validation accuracy improves.
    Afterwards the best model is reloaded, evaluated on the test loader and a
    few sample predictions are printed.

    Args:
        load_from_checkpoint: Resume from ``checkpoint_path`` when that file
            exists.
        checkpoint_path: Checkpoint to resume from.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=Config.LR,
        weight_decay=1e-4
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        patience=3,
        factor=0.5
    )

    def _snapshot(epoch, val_acc):
        # Everything needed to resume training or to run inference later.
        return {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            # Store a plain float so the checkpoint holds no NumPy scalar,
            # whatever type the accuracy metric returned.
            "val_acc": float(val_acc),
            "char_to_idx": CHAR_TO_IDX,
            "idx_to_char": IDX_TO_CHAR
        }

    best_val_acc = 0.0
    start_epoch = 1

    if load_from_checkpoint and os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on a GPU machine load on a
        # CPU-only machine (and the other way round).
        checkpoint = torch.load(checkpoint_path, map_location=Config.DEVICE)
        model.load_state_dict(checkpoint["model_state_dict"])
        if "optimizer_state_dict" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Older checkpoints carry no scheduler state; the scheduler then
        # simply starts fresh.
        if "scheduler_state_dict" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        if "val_acc" in checkpoint:
            best_val_acc = checkpoint["val_acc"]
        if "epoch" in checkpoint:
            start_epoch = checkpoint["epoch"] + 1
        print(f"📌 已加载历史模型,从第{start_epoch}轮继续训练(历史最佳准确率:{best_val_acc:.4f})")

    print(f"\n===== 开始训练 =====")
    for epoch in range(start_epoch, Config.EPOCHS + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, Config.DEVICE)
        val_loss, val_acc = evaluate(model, val_loader, criterion, Config.DEVICE, split="验证")

        # mode='max': the learning rate drops when validation accuracy stalls.
        scheduler.step(val_acc)

        print(f"Epoch [{epoch:3d}/{Config.EPOCHS}] | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | "
              f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        if epoch % Config.SAVE_INTERVAL == 0:
            torch.save(_snapshot(epoch, val_acc),
                       os.path.join(Config.SAVE_DIR, f"model_epoch_{epoch}.pth"))
            print(f"💾 已保存第{epoch}轮模型")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(_snapshot(epoch, best_val_acc),
                       os.path.join(Config.SAVE_DIR, "best_model.pth"))
            print(f"🌟 最佳模型更新(Val Acc: {best_val_acc:.4f})")

    # After training: reload the best weights, test them and show sample predictions.
    best_model_path = os.path.join(Config.SAVE_DIR, "best_model.pth")
    if os.path.exists(best_model_path):
        best_model = torch.load(best_model_path, map_location=Config.DEVICE)
        model.load_state_dict(best_model["model_state_dict"])
        _, test_acc = evaluate(model, test_loader, criterion, Config.DEVICE, split="测试")
        print(f"\n===== 训练完成 =====")
        print(f" 测试集准确率:{test_acc:.4f}")

        print(f"\n===== 随机抽取5个测试样本的识别结果 =====")
        print_recognition_results(model, test_loader, Config.DEVICE, IDX_TO_CHAR, num_samples=5)
    else:
        print("\n⚠️ 未找到最佳模型文件")


# ======================== Start training ========================
if __name__ == "__main__":
    # Resumes from the default checkpoint when it exists, otherwise trains from scratch.
    main_train(load_from_checkpoint=True)

运行结果:准确率达 90% 以上。

http://www.gsyq.cn/news/43082.html

相关文章:

  • 2025年热门成人自考机构推荐
  • CANopen转Profinet是一种构建于控制局域网设备之上的协议网关
  • 2025年国内成人自考机构口碑推荐榜单:如何选择靠谱的学历提升平台
  • 2025年11月星光喷头厂家推荐排行榜:专业选购与维护指南
  • Spring Cloud Alibaba + Sentinel
  • 德鲁克管理哲学:管理是知行统一的实践创新 - 详解
  • 2025 年 11 月食堂送菜平台推荐排行榜,送菜上门,食堂送菜公司,饭堂送菜平台,专业高效与新鲜直达服务口碑之选
  • 2025 年 11 月电能质量分析仪厂家推荐排行榜,A类/B类电能质量分析仪,动态电能质量监测仪,三相电能质量分析仪,在线检测装置系统公司推荐
  • 2025 年 11 月开窗器厂家推荐排行榜,链条开窗器,机芯开窗器,配件开窗器,优质开窗器公司推荐
  • 2025 年 11 月包装机厂家推荐排行榜,全自动/定量/FFS/25公斤/粉料/颗粒料/肥料/树脂/抽真空/底充式/锂电/零排放/吨袋包装机公司推荐
  • 2025 年 11 月码垛机厂家推荐排行榜,全自动码垛机,高位码垛机,低位码垛机,立柱码垛机,编织袋码垛机,纸箱码垛机,桶码垛机,粉料码垛机,肥料码垛机公司推荐
  • 2025 年 11 月包装称厂家推荐排行榜,全自动/定量/FFS重膜/高速/锂电/零排放/螺旋/吨袋包装称,铜精粉/肥料吨包包装称公司精选
  • gxyz圣经
  • 涡街流量计温度数据的协议桥梁:ModbusRTU转Profinet网关的自动化应用
  • git 添加大文件
  • 第一周--3:使用远程终端登录系统(ubuntu和rocky),并且总结linux系统基础命令
  • 2025年聚硅氧烷漆批发厂家权威推荐榜单:聚硅氮烷漆/防腐油漆厂家/工业防腐漆源头厂家精选
  • 2025 年 11 月民航机票购买,儿童机票购买,国内机票预定平台最新推荐,聚焦资质、服务与口碑的深度解析!
  • 权威认证!EasyCVR平台检测全达标,GB/T28181合规实力再升级
  • mongo内存
  • OIFC 2025.11.7 模拟赛总结
  • Linux - 9 定时任务篇(crontab)
  • Elasticsearch、OpenSearch 与 Easysearch:三代搜索引擎的演化与抉择 - 指南
  • 分布式专题——35 Netty的使用和常用组件辨析 - 详解
  • 2025年11月油脂提取设备知名品牌与破碎仪厂家介绍
  • 开发笔记|PHP+AJAX前后端交互调试的关键注意事项
  • 2025年耐用的高精度内圆磨床订制厂家权威推荐榜单:比较好的高精度内圆磨床/好的高精度内圆磨床/靠谱的高精度内圆磨床源头厂家精选
  • 工业主板VS商用主板:五大核心差异,选对才能高效运行
  • 2025 年最新推荐!国内胶粘剂源头厂家优质品牌排行榜:聚焦实力厂商,助力企业精准选品水性胶粘剂 / 电子胶粘剂 / 注塑胶粘剂公司推荐
  • 【IEEE出版|往届均已完成EI检索】第四届地理信息与遥感技术国际学术会议(GIRST 2025)