I never had a good grasp of recurrent neural networks, so I recently went through the PyTorch documentation and tried to classify MNIST with an RNN myself. The idea is to treat each 28×28 image as a length-28 sequence of 28-dimensional vectors and feed it through an LSTM. The code is below.

import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter
import datetime


# %%
trainset = torchvision.datasets.MNIST('./mnist_data',download=True,train=True)
testset = torchvision.datasets.MNIST('./mnist_data',download=True,train=False)


# %%
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def trans(data):
    # custom collate_fn: turn a batch of (PIL image, label) pairs into
    # float image tensors scaled to [0, 1] and long label tensors
    ims = [np.array(i[0]) for i in data]
    ims = np.stack(ims)
    ims = ims/255
    labels = np.array([i[1] for i in data])
    ims = torch.from_numpy(ims).float()
    labels = torch.from_numpy(labels).long()
    return ims.to(device), labels.to(device)


# %%
# the collate_fn runs trans() on every batch, which is why no transform
# was passed to the datasets above
trainloader = torch.utils.data.DataLoader(trainset, batch_size=500,
                                          shuffle=True, collate_fn=trans)

testloader = torch.utils.data.DataLoader(testset, batch_size=500,
                                         shuffle=False, collate_fn=trans)


# %%
class RNNModel(nn.Module):
    def __init__(self):
        super(RNNModel, self).__init__()
        self.rnn = nn.LSTM(input_size=28, hidden_size=400, num_layers=1,
                           batch_first=True, bidirectional=False)
        self.linear = nn.Linear(in_features=400, out_features=10)

    def forward(self, input):
        # input shape: N x 28 x 28; swap the last two dims so the LSTM
        # reads each image as a sequence of 28 column vectors of length 28
        input = input.permute(0, 2, 1)
        # o: N x 28 x 400 (output at every step); h, c: 1 x N x 400 (final states)
        o, (h, c) = self.rnn(input)
        # use the final hidden state as the image representation: N x 400
        output = torch.flatten(h.permute(1, 0, 2), 1, 2)
        output = self.linear(output)
        output = F.relu(output)
        return output


# %%
# quick sanity check: show one input image and confirm the output shape
model = RNNModel().to(device)
for x, y in trainloader:
    plt.imshow(x[0].cpu().numpy())
    o = model(x)
    print(o.shape)  # expect torch.Size([500, 10])
    break


# %%
model = RNNModel().to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()
epochs = 30
writer = SummaryWriter('./runs/rnn_mnist '+str(datetime.datetime.now()).replace(':','_'))
for epoch_idx in range(epochs):
    # training pass
    for index, (x, y) in enumerate(trainloader):
        output = model(x)
        loss = criterion(output, y)
        model.zero_grad()
        loss.backward()
        optimizer.step()
        writer.add_scalar('train/loss', loss.item(), epoch_idx*len(trainloader)+index)
    # evaluation pass on the test set
    with torch.no_grad():
        correct_cnt = 0
        test_mean_loss = 0
        for index, (x, y) in enumerate(testloader):
            output = model(x)
            loss = criterion(output, y)
            correct_cnt += torch.sum(torch.argmax(output, dim=1) == y).item()
            writer.add_scalar('test/loss', loss.item(), epoch_idx*len(testloader)+index)
            test_mean_loss += loss.item()
        writer.add_scalar('test/loss_per_epoch', test_mean_loss/len(testloader), epoch_idx)
        # divide by the dataset size rather than batches x batch_size,
        # which would be wrong if the last batch were smaller
        writer.add_scalar('test/acc', correct_cnt/len(testset), epoch_idx)


# %%

I trained with exactly the parameters above and repeated the run four times with the same settings. The resulting TensorBoard curves (viewable with tensorboard --logdir ./runs) are below:

[TensorBoard screenshots: train loss, test loss, test loss per epoch, test accuracy]

Out of the four runs with identical parameters, three reached over 98.5% test accuracy, but one got stuck around 80%; perhaps training it for more epochs would eventually bring it above 98.5% as well. Since the same parameters and epoch count still produced different results, there is clearly some randomness in training.
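That randomness comes mainly from the random weight initialization and the shuffled training batches. To make runs repeatable, the usual approach is to fix the seeds before building the model and the DataLoaders; a minimal sketch (set_seed is just an illustrative helper, and the seed value 0 is arbitrary):

import random

def set_seed(seed=0):
    # fix the Python, NumPy and PyTorch RNGs so repeated runs start identically
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)             # also seeds the CUDA generators
    torch.cuda.manual_seed_all(seed)
    # ask cuDNN for deterministic kernels (can be slower)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(0)

Even then, some CUDA operations (including certain cuDNN LSTM kernels) are only deterministic under specific settings, so small run-to-run differences can remain.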
