本文共 5179 字,大约阅读时间需要 17 分钟。
1、torch.nn.Conv2d(1,10,kernel_size=3,stride=2,bias=False)
2、self.fc = torch.nn.Linear(320, 10)
卷积(线性变换),激活函数(非线性变换)
# Data pipeline: download MNIST and expose batched train/test loaders.
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

# prepare dataset
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    # Per-channel normalization with the standard MNIST mean / std.
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST(root='./资料/data/mnist/', train=True,
                               download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

test_dataset = datasets.MNIST(root='./资料/data/mnist/', train=False,
                              download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
D:\common_software\Anaconda\lib\site-packages\torchvision\datasets\mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ..\torch\csrc\utils\tensor_numpy.cpp:180.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
class Net(torch.nn.Module):
    """Small CNN for 28x28 single-channel MNIST digits.

    Layout: conv(1->10, 5x5) -> maxpool(2) -> ReLU ->
            conv(10->20, 5x5) -> maxpool(2) -> ReLU -> fc(320 -> 10).
    Returns raw class scores (logits); pair with CrossEntropyLoss.
    """

    def __init__(self):
        super(Net, self).__init__()
        # First conv stage: 1 input channel -> 10 feature maps, 5x5 kernels.
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        # Second conv stage: 10 -> 20 feature maps, 5x5 kernels.
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        # Shared 2x2 max pooling (stride 2) applied after each conv.
        self.pooling = torch.nn.MaxPool2d(2)
        # After two conv+pool stages a 28x28 input is 20 maps of 4x4 = 320.
        self.fc = torch.nn.Linear(320, 10)

    def forward(self, x):
        n = x.size(0)  # batch size
        h = F.relu(self.pooling(self.conv1(x)))   # (n,1,28,28) -> (n,10,12,12)
        h = F.relu(self.pooling(self.conv2(h)))   # -> (n,20,4,4)
        flat = h.view(n, -1)                      # -> (n,320)
        return self.fc(flat)


model = Net()
model
Net( (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1)) (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1)) (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (fc): Linear(in_features=320, out_features=10, bias=True))
# construct loss and optimizer
# CrossEntropyLoss fuses log-softmax + NLL, so the network emits raw logits.
criterion = torch.nn.CrossEntropyLoss()
# Plain SGD with momentum over all learnable parameters of the model.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
def train(epoch):
    """Run one training epoch over train_loader, logging mean loss every 300 batches.

    Args:
        epoch: zero-based epoch index, used only in the log line.
    """
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        # Clear gradients accumulated by the previous step before backprop.
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:  # print a running average every 300 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
            running_loss = 0.0


def test():
    """Evaluate the model on test_loader and print integer-percent accuracy."""
    correct = 0
    total = 0
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            # argmax over the class dimension. The original used
            # `torch.max(outputs.data, dim=1)`; `.data` bypasses autograd
            # version tracking and is discouraged — under no_grad it is
            # unnecessary anyway.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on test set: %d %% ' % (100*correct/total))
if __name__ == '__main__':
    # Train for 10 epochs, evaluating on the held-out test set after each.
    for epoch in range(10):
        train(epoch)
        test()
D:\common_software\Anaconda\lib\site-packages\torch\nn\functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ..\c10/core/TensorImpl.h:1156.) return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)[1, 300] loss: 0.691[1, 600] loss: 0.211[1, 900] loss: 0.155accuracy on test set: 96 % [2, 300] loss: 0.124[2, 600] loss: 0.108[2, 900] loss: 0.099accuracy on test set: 97 % [3, 300] loss: 0.087[3, 600] loss: 0.081[3, 900] loss: 0.079accuracy on test set: 97 % [4, 300] loss: 0.068[4, 600] loss: 0.070[4, 900] loss: 0.069accuracy on test set: 97 % [5, 300] loss: 0.065[5, 600] loss: 0.058[5, 900] loss: 0.060accuracy on test set: 98 % [6, 300] loss: 0.055[6, 600] loss: 0.055[6, 900] loss: 0.056accuracy on test set: 98 % [7, 300] loss: 0.052[7, 600] loss: 0.048[7, 900] loss: 0.051accuracy on test set: 98 % [8, 300] loss: 0.048[8, 600] loss: 0.044[8, 900] loss: 0.050accuracy on test set: 98 % [9, 300] loss: 0.040[9, 600] loss: 0.048[9, 900] loss: 0.044accuracy on test set: 98 % [10, 300] loss: 0.036[10, 600] loss: 0.041[10, 900] loss: 0.046accuracy on test set: 98 %
转载地址:http://ufali.baihongyu.com/