Skip to content

Instantly share code, notes, and snippets.

@qiaoxu123
Last active June 15, 2025 12:07
Show Gist options
  • Select an option

  • Save qiaoxu123/d35d35414df45158ac06699e9fea13cf to your computer and use it in GitHub Desktop.

Select an option

Save qiaoxu123/d35d35414df45158ac06699e9fea13cf to your computer and use it in GitHub Desktop.
手写数字识别
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms import functional as TF
import tkinter as tk
from PIL import Image, ImageDraw, ImageOps
import numpy as np
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# File used to cache the trained weights between runs.
model_path = "lenet5_mnist.pth"
# LeNet-5 style convolutional network
class LeNet5(nn.Module):
    """LeNet-5 style CNN mapping 1-channel 28x28 images to 10 class scores.

    The layer attribute names (``conv1``, ``conv2``, ``fc1``, ``fc2``,
    ``fc3``) are the keys of the saved ``state_dict`` and must stay stable.
    """

    # Feature-map shape after both conv/pool stages for a 28x28 input:
    # 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, with 16 channels.
    _FEATURE_SHAPE = (16, 4, 4)
    _FLAT_FEATURES = 16 * 4 * 4  # 256, the input width of fc1

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute raw class scores (logits) for ``x``.

        Args:
            x: tensor whose trailing dimensions are (1, 28, 28), with an
                optional leading batch dimension.

        Returns:
            Tensor of shape (N, 10) holding unnormalized class scores.

        Raises:
            RuntimeError: if the spatial size of ``x`` does not reduce to a
                16x4x4 feature map (i.e. the input is not 28x28).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # A bare view(-1, 256) would silently regroup features across
        # samples whenever batch * features happens to be divisible by 256
        # (e.g. 16 images of 32x32), so reject unexpected shapes explicitly.
        if tuple(x.shape[-3:]) != self._FEATURE_SHAPE:
            raise RuntimeError(
                "LeNet5 expects 1x28x28 inputs; got a feature map of shape "
                f"{tuple(x.shape[-3:])} instead of {self._FEATURE_SHAPE}"
            )
        # reshape (unlike view) also accepts non-contiguous activations.
        x = x.reshape(-1, self._FLAT_FEATURES)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# Build the network and move it to the selected device
model = LeNet5().to(device)

if not os.path.exists(model_path):
    # No cached weights on disk: train on MNIST, cache the weights, evaluate.

    # Data preparation (same normalization constants as predict_digit)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataset = datasets.MNIST('', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=256, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1000, shuffle=False
    )

    # Optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    print("开始训练模型...")
    epochs = 5
    for epoch in range(epochs):
        running_loss = 0.0
        model.train()
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Mean per-batch loss for this epoch
        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

    # Cache the trained weights so later runs can skip training
    torch.save(model.state_dict(), model_path)
    print("模型已训练并保存。")

    # Accuracy on the held-out test split
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'测试准确率: {100 * correct / total:.2f}%')
else:
    # Cached weights exist: load them instead of retraining.
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print("模型已加载,无需重新训练。")
# Inference helper used by the Tkinter drawing pad
def predict_digit(img):
    """Return the class index the module-level ``model`` predicts for ``img``.

    The PIL image is resized to 28x28, converted to 8-bit grayscale and
    colour-inverted, then turned into a normalized 1x1x28x28 tensor on
    ``device`` before being fed to the model.

    Args:
        img: a PIL image (the drawing pad passes a white-background RGB image).

    Returns:
        The predicted class as a Python int.
    """
    shrunk = img.resize((28, 28))
    gray = shrunk.convert('L')
    # The drawing pad paints black strokes on white; invert so strokes are bright.
    inverted = ImageOps.invert(gray)
    batch = TF.to_tensor(inverted).unsqueeze(0)  # add the batch dimension
    batch = TF.normalize(batch, [0.1307], [0.3081])
    batch = batch.to(device)
    with torch.no_grad():
        logits = model(batch)
        best = logits.argmax(dim=1)
    return best.item()
# Drawing pad window
class App(tk.Tk):
    """Tkinter window with a square drawing canvas and a recognize button.

    Every brush dab is drawn twice: on the visible canvas and on an
    off-screen PIL image of the same size. The off-screen image is what gets
    handed to ``predict_digit`` when the recognize button is pressed.

    Args:
        canvas_size: side length of the square canvas in pixels.
        brush_radius: radius in pixels of each painted dot.
    """

    def __init__(self, canvas_size=200, brush_radius=8):
        super().__init__()
        self.title("手写数字识别")
        # Underscore-prefixed so they cannot shadow tk.Misc methods such as
        # ``size``.
        self._canvas_size = canvas_size
        self._brush_radius = brush_radius

        self.canvas = tk.Canvas(
            self, width=canvas_size, height=canvas_size, bg="white"
        )
        self.canvas.pack()
        # Off-screen mirror of the canvas; Tk canvases cannot be read back
        # as pixels directly, so strokes are replayed onto this image.
        self.image = Image.new("RGB", (canvas_size, canvas_size), "white")
        self.draw = ImageDraw.Draw(self.image)

        # <B1-Motion> alone only fires while dragging, so a plain click
        # would draw nothing; bind the press event as well.
        self.canvas.bind("<Button-1>", self.paint)
        self.canvas.bind("<B1-Motion>", self.paint)

        tk.Button(self, text="识别", command=self.recognize).pack()
        tk.Button(self, text="清除", command=self.clear).pack()
        self.result = tk.Label(self, text="", font=("Helvetica", 20))
        self.result.pack()

    def paint(self, event):
        """Paint one black dot at the pointer position on canvas and image."""
        r = self._brush_radius
        box = (event.x - r, event.y - r, event.x + r, event.y + r)
        self.canvas.create_oval(*box, fill='black')
        self.draw.ellipse(list(box), fill='black')

    def recognize(self):
        """Classify the current drawing and show the predicted digit."""
        digit = predict_digit(self.image)
        self.result.config(text=f"识别结果:{digit}")

    def clear(self):
        """Erase the canvas, the off-screen image and the result label."""
        self.canvas.delete("all")
        side = self._canvas_size
        self.draw.rectangle([0, 0, side, side], fill="white")
        self.result.config(text="")
# Create the drawing-pad window and enter the Tk event loop
window = App()
window.mainloop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment