当前位置: 首页 > 编程笔记 >

pytorch 利用lstm做mnist手写数字识别分类的实例

白嘉石
2023-03-14
本文向大家介绍pytorch 利用lstm做mnist手写数字识别分类的实例,包括了pytorch 利用lstm做mnist手写数字识别分类的实例的使用技巧和注意事项,需要的朋友参考一下

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)    

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持小牛知识库。

 类似资料:
  • 什么是tensorflow tensor意思是张量,flow是流。 张量原本是力学里的术语,表示弹性介质中各点应力状态。在数学中,张量表示的是一种广义的“数量”,0阶张量就是标量(比如:0、1、2……),1阶张量就是向量(比如:(1,3,4)),2阶张量就是矩阵,本来这几种形式是不相关的,但是都归为张量,是因为他们同时满足一些特性:1)可以用坐标系表示;2)在坐标变换中遵守同样的变换法则;3)有着

  • 本文向大家介绍pytorch cnn 识别手写的字实现自建图片数据,包括了pytorch cnn 识别手写的字实现自建图片数据的使用技巧和注意事项,需要的朋友参考一下 本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下: 图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》 结果如下: 以上就是本文的全部内容,希望对大家的学习

  • 本文向大家介绍MNIST数据如何与TensorFlow一起用于识别手写数字?,包括了MNIST数据如何与TensorFlow一起用于识别手写数字?的使用技巧和注意事项,需要的朋友参考一下 Tensorflow是Google提供的一种机器学习框架。它是一个开放源代码框架,与Python结合使用以实现算法,深度学习应用程序等等。它用于研究和生产目的。 可以使用下面的代码行在Windows上安装'ten

  • 本文向大家介绍Pytorch入门之mnist分类实例,包括了Pytorch入门之mnist分类实例的使用技巧和注意事项,需要的朋友参考一下 本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持呐喊教程。

  • 本文向大家介绍Python实现识别手写数字大纲,包括了Python实现识别手写数字大纲的使用技巧和注意事项,需要的朋友参考一下 写在前面 其实我之前写过一个简单的识别手写数字的程序,但是因为逻辑比较简单,而且要求比较严苛,是在50x50大小像素的白底图上手写黑色数字,并且给的训练材料也不够多,导致准确率只能五五开。所以这一次准备写一个加强升级版的,借此来提升我对Python处理文件与图片的能力。

  • 本文向大家介绍tensorflow实现softma识别MNIST,包括了tensorflow实现softma识别MNIST的使用技巧和注意事项,需要的朋友参考一下 识别MNIST已经成了深度学习的hello world,所以每次例程基本都会用到这个数据集,这个数据集在tensorflow内部用着很好的封装,因此可以方便地使用。 这次我们用tensorflow搭建一个softmax多分类器,和之前搭