博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
9.PyTorch实现MNIST(手写数字识别)(2卷积1全连接)
阅读量:4203 次
发布时间:2019-05-26

本文共 5179 字,大约阅读时间需要 17 分钟。

0 写在前面

0.1 流程图

在这里插入图片描述

在这里插入图片描述

0.2 文中代码解释

  • 1、torch.nn.Conv2d(1,10,kernel_size=3,stride=2,bias=False)

    • 1是指输入的Channel,灰色图像是1维的;
    • 10是指输出的Channel,也可以说第一个卷积层需要10个卷积核;
    • kernel_size=3,卷积核大小是3x3;stride=2进行卷积运算时的步长,默认为1;
    • bias=False卷积运算是否需要偏置bias,默认为False。
    • padding = 0,卷积操作是否补0。
  • 2、self.fc = torch.nn.Linear(320, 10)

    • 这个320获取的方式,可以通过x = x.view(batch_size, -1)
    • print(x.shape)可得到(64,320),64指的是batch,320就是指要进行全连接操作时,输入的特征维度。

0.3 注意事项

卷积(线性变换),激活函数(非线性变换)

1 prepare dataset

import torchfrom torchvision import transformsfrom torchvision import datasetsfrom torch.utils.data import DataLoaderimport torch.nn.functional as Fimport torch.optim as optim # prepare dataset batch_size = 64transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) train_dataset = datasets.MNIST(root='./资料/data/mnist/', train=True, download=True, transform=transform)train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)test_dataset = datasets.MNIST(root='./资料/data/mnist/', train=False, download=True, transform=transform)test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
D:\common_software\Anaconda\lib\site-packages\torchvision\datasets\mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ..\torch\csrc\utils\tensor_numpy.cpp:180.)  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

2 design model using class

class Net(torch.nn.Module):    def __init__(self):        super(Net, self).__init__()                # 卷积1操作,1个通道,10个卷积核,卷积核大小5*5        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)                # 卷积2操作        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)                # 最大池化        self.pooling = torch.nn.MaxPool2d(2)                # 全连接        self.fc = torch.nn.Linear(320, 10)      def forward(self, x):        # flatten data from (n,1,28,28) to (n, 784)        batch_size = x.size(0)        x = F.relu(self.pooling(self.conv1(x)))        x = F.relu(self.pooling(self.conv2(x)))        x = x.view(batch_size, -1) # -1 此处自动算出的是320        x = self.fc(x)         return x model = Net()
model
Net(  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))  (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  (fc): Linear(in_features=320, out_features=10, bias=True))

3 construct loss and optimizer

# construct loss and optimizercriterion = torch.nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

4 training cycle forward, backward, update

def train(epoch):    running_loss = 0.0    for batch_idx, data in enumerate(train_loader, 0):        inputs, target = data        optimizer.zero_grad()         outputs = model(inputs)        loss = criterion(outputs, target)        loss.backward()        optimizer.step()         running_loss += loss.item()        if batch_idx % 300 == 299:            print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))            running_loss = 0.0  def test():    correct = 0    total = 0    with torch.no_grad():        for data in test_loader:            images, labels = data            outputs = model(images)            _, predicted = torch.max(outputs.data, dim=1)            total += labels.size(0)            correct += (predicted == labels).sum().item()    print('accuracy on test set: %d %% ' % (100*correct/total))
if __name__ == '__main__':    for epoch in range(10):        train(epoch)        test()
D:\common_software\Anaconda\lib\site-packages\torch\nn\functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ..\c10/core/TensorImpl.h:1156.)  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)[1,   300] loss: 0.691[1,   600] loss: 0.211[1,   900] loss: 0.155accuracy on test set: 96 % [2,   300] loss: 0.124[2,   600] loss: 0.108[2,   900] loss: 0.099accuracy on test set: 97 % [3,   300] loss: 0.087[3,   600] loss: 0.081[3,   900] loss: 0.079accuracy on test set: 97 % [4,   300] loss: 0.068[4,   600] loss: 0.070[4,   900] loss: 0.069accuracy on test set: 97 % [5,   300] loss: 0.065[5,   600] loss: 0.058[5,   900] loss: 0.060accuracy on test set: 98 % [6,   300] loss: 0.055[6,   600] loss: 0.055[6,   900] loss: 0.056accuracy on test set: 98 % [7,   300] loss: 0.052[7,   600] loss: 0.048[7,   900] loss: 0.051accuracy on test set: 98 % [8,   300] loss: 0.048[8,   600] loss: 0.044[8,   900] loss: 0.050accuracy on test set: 98 % [9,   300] loss: 0.040[9,   600] loss: 0.048[9,   900] loss: 0.044accuracy on test set: 98 % [10,   300] loss: 0.036[10,   600] loss: 0.041[10,   900] loss: 0.046accuracy on test set: 98 %

转载地址:http://ufali.baihongyu.com/

你可能感兴趣的文章
【Error】chsh: PAM: Authentication failure
查看>>
【Error】zsh历史记录丢失
查看>>
解析漏洞总结
查看>>
有趣的二进制 读书笔记
查看>>
记一次vmware磁盘扩容part2:真正扩展根目录
查看>>
【Error】zsh: corrupt history file /home/myusername/.zsh_history
查看>>
记一次编译linux 2.6 和4.10内核源码
查看>>
【Error】couldn't be accessed by user '_apt'. - pkgAcquire::Run (13: Permission denied) [duplicate]
查看>>
qemu 文件系统制作:自己制作根目录和应用程序 + busybox
查看>>
关闭CSDN广告必备插件:adblock plus
查看>>
【pwnable.kr】fd
查看>>
【pwnable.kr】 collision
查看>>
【pwnable.kr】bof
查看>>
【pwnable.kr】flag
查看>>
【pwnable.kr】 passcode
查看>>
【pwnable.kr】input
查看>>
【Windows C++】调用powershell上传指定目录下所有文件
查看>>
【Error】ropgadget依赖选项capstone报错ImportError: ERROR: fail to load the dynamic library.
查看>>
【Error】西部数据磁盘插上不显示盘符
查看>>
【Windows C++】powershell 获取chrome密码并上传
查看>>