In this blog, let's build an MNIST classifier (one vs rest): we will classify whether a given image is of the digit 2 or not.
Let's first import the necessary libraries.
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.autograd import Variable
# Download MNIST dataset
mnist_data = MNIST(root='./data', download=True)
# 60000 images of size 28x28
data = mnist_data.data.numpy()
data.shape
(60000, 28, 28)
train_image, test_image, train_label, test_label = train_test_split(data, mnist_data.targets.numpy(), test_size=0.2, random_state=42)
train_image.shape, train_label.shape, test_image.shape, test_label.shape
((48000, 28, 28), (48000,), (12000, 28, 28), (12000,))
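A quick aside: train_test_split as used above splits at random and does not preserve the per-digit proportions. If you care about that, sklearn's stratify argument keeps the digit distribution identical across the two splits; a minimal sketch of the alternative call (not what the pipeline above uses):
# Hypothetical alternative: stratified split so each digit appears in the same
# proportion in train and test
labels = mnist_data.targets.numpy()
train_image, test_image, train_label, test_label = train_test_split(
    data, labels, test_size=0.2, random_state=42, stratify=labels)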
sample_image = train_image[0]
plt.imshow(sample_image, cmap='gray')
<matplotlib.image.AxesImage at 0x294a76e90>
print(train_label[12:17])
[8 1 1 2 6]
# 1 as label for digit 2 and 0 otherwise
train_label[train_label!=2] = 0
train_label[train_label==2] = 1
test_label[test_label!=2] = 0
test_label[test_label==2] = 1
# Convert to torch tensors
train_image = torch.from_numpy(train_image).long()
test_image = torch.from_numpy(test_image).long()
train_label = torch.from_numpy(train_label).float()
test_label = torch.from_numpy(test_label).float()
train_label[12:17]
tensor([0., 0., 0., 1., 0.])
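One thing worth keeping in mind for the rest of the post: after binarising, the labels are heavily imbalanced, since roughly one MNIST image in ten is a 2. A classifier that always predicts 0 would therefore already score around 90% accuracy, so the training accuracies below should be read against that baseline. A quick check of the positive fraction (output omitted, it depends on the split):
# Fraction of positive (digit 2) examples in each split -- close to 0.1 for MNIST
print(train_label.mean().item(), test_label.mean().item())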
train_image = train_image.view(-1, 1, 28, 28).float()
test_image = test_image.view(-1, 1, 28, 28).float()

train_image = train_image/255.
test_image = test_image/255.
train = torch.utils.data.TensorDataset(train_image, train_label)
test = torch.utils.data.TensorDataset(test_image, test_label)

BATCH_SIZE = 64
train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)
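If you want to confirm that the loaders produce what the network expects, peeking at a single batch is a cheap sanity check (the shapes follow from the preprocessing above, so no output is reproduced here):
# One batch: images should be (64, 1, 28, 28) floats in [0, 1], labels (64,)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape, images.max())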
# Define the model
class cnn_Model(nn.Module):
    def __init__(self):
        super(cnn_Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, 5, padding=0)
        self.R = nn.ReLU()
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)
        self.S = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.R(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = self.R(x)
        x = self.pool(x)
        x = x.view(-1, 400)

        x = self.fc1(x)
        x = self.R(x)
        x = self.fc2(x)
        x = self.R(x)
        x = self.fc3(x)
        # x = self.S(x)   # left out: BCEWithLogitsLoss applies the sigmoid itself
        return x.squeeze()
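The x.view(-1, 400) works because of how the spatial dimensions shrink: conv1 with padding 2 keeps 28x28, the first pool halves it to 14x14, conv2 with no padding gives 10x10, the second pool gives 5x5, and 16 channels x 5 x 5 = 400. A quick way to convince yourself (and to catch shape bugs early) is to push a dummy batch through an untrained instance:
# Sanity check of the forward pass on random data (untrained weights)
dummy = torch.randn(4, 1, 28, 28)
out = cnn_Model()(dummy)
print(out.shape)  # torch.Size([4]): one value per image after squeeze()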
def accuracy(y_pred, y_true):
    # Thresholds the raw model outputs at 0.5 and returns (binarised predictions, labels)
    y_pred = y_pred.detach().numpy()
    y_true = y_true.detach().numpy()
    y_pred[y_pred>=0.5] = 1
    y_pred[y_pred<0.5] = 0
    return y_pred, y_true
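Despite its name, accuracy() does not return a score; it returns the thresholded predictions together with the labels as numpy arrays. A tiny made-up example to show the behaviour:
# Made-up tensors purely for illustration
accuracy(torch.tensor([-1.2, 0.3, 2.5]), torch.tensor([0., 0., 1.]))
# -> (array([0., 0., 1.], ...), array([0., 0., 1.], ...))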
def fit(model, train_loader, learning_rate=0.001, random_seed=42):
    torch.manual_seed(random_seed)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    error = nn.BCEWithLogitsLoss()
    EPOCHS = 5
    predictions = []
    # model.train()
    for epoch in range(EPOCHS):
        correct = 0
        for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
            var_X_batch = Variable(X_batch).float()
            var_y_batch = Variable(y_batch)

            optimizer.zero_grad()
            output = model(var_X_batch)
            loss = error(output, var_y_batch)

            loss.backward()
            optimizer.step()

            y_pred, y_true = accuracy(output.data, var_y_batch)
            predictions.extend(y_pred)
            # accuracy() binarises output.data in place (its numpy view shares storage),
            # so this comparison counts correct predictions rather than raw logits
            correct += (output.data == var_y_batch).sum()
            if batch_idx % 50 == 0:
                print('Epoch : {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t Accuracy:{:.3f}%'.format(
                    epoch, batch_idx*len(X_batch), len(train_loader.dataset),
                    100.*batch_idx / len(train_loader), loss.data.item(),
                    float(correct*100) / float(BATCH_SIZE*(batch_idx+1))))
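Because the loss is nn.BCEWithLogitsLoss, the model can return raw scores and the sigmoid stays commented out in forward(): the loss applies the sigmoid internally in a numerically stabler way. A small check that the two formulations agree:
# BCEWithLogitsLoss(logit, y) matches BCELoss()(sigmoid(logit), y) up to float error
logit = torch.tensor([0.8, -1.5])
target = torch.tensor([1., 0.])
print(nn.BCEWithLogitsLoss()(logit, target))
print(nn.BCELoss()(torch.sigmoid(logit), target))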
def instantiate_and_train(train_loader, Model, random_seed=42):
    torch.manual_seed(random_seed)
    model = Model()
    fit(model, train_loader, random_seed=random_seed)
    return model

model = instantiate_and_train(train_loader, cnn_Model)
Epoch : 0 [0/48000 (0%)] Loss: 0.711395 Accuracy:92.188%
Epoch : 0 [3200/48000 (7%)] Loss: 0.320103 Accuracy:89.951%
Epoch : 0 [6400/48000 (13%)] Loss: 0.072405 Accuracy:90.470%
Epoch : 0 [9600/48000 (20%)] Loss: 0.037177 Accuracy:92.643%
Epoch : 0 [12800/48000 (27%)] Loss: 0.114377 Accuracy:93.905%
Epoch : 0 [16000/48000 (33%)] Loss: 0.013866 Accuracy:94.746%
Epoch : 0 [19200/48000 (40%)] Loss: 0.019222 Accuracy:95.359%
Epoch : 0 [22400/48000 (47%)] Loss: 0.003240 Accuracy:95.873%
Epoch : 0 [25600/48000 (53%)] Loss: 0.037296 Accuracy:96.228%
Epoch : 0 [28800/48000 (60%)] Loss: 0.054712 Accuracy:96.504%
Epoch : 0 [32000/48000 (67%)] Loss: 0.059329 Accuracy:96.775%
Epoch : 0 [35200/48000 (73%)] Loss: 0.007022 Accuracy:97.008%
Epoch : 0 [38400/48000 (80%)] Loss: 0.015875 Accuracy:97.166%
Epoch : 0 [41600/48000 (87%)] Loss: 0.054735 Accuracy:97.312%
Epoch : 0 [44800/48000 (93%)] Loss: 0.003340 Accuracy:97.437%
Epoch : 1 [0/48000 (0%)] Loss: 0.066569 Accuracy:98.438%
Epoch : 1 [3200/48000 (7%)] Loss: 0.006697 Accuracy:98.836%
Epoch : 1 [6400/48000 (13%)] Loss: 0.005530 Accuracy:99.087%
Epoch : 1 [9600/48000 (20%)] Loss: 0.082764 Accuracy:99.214%
Epoch : 1 [12800/48000 (27%)] Loss: 0.007963 Accuracy:99.285%
Epoch : 1 [16000/48000 (33%)] Loss: 0.001485 Accuracy:99.247%
Epoch : 1 [19200/48000 (40%)] Loss: 0.007370 Accuracy:99.232%
Epoch : 1 [22400/48000 (47%)] Loss: 0.000315 Accuracy:99.270%
Epoch : 1 [25600/48000 (53%)] Loss: 0.009400 Accuracy:99.299%
Epoch : 1 [28800/48000 (60%)] Loss: 0.012171 Accuracy:99.321%
Epoch : 1 [32000/48000 (67%)] Loss: 0.005416 Accuracy:99.345%
Epoch : 1 [35200/48000 (73%)] Loss: 0.000906 Accuracy:99.373%
Epoch : 1 [38400/48000 (80%)] Loss: 0.002171 Accuracy:99.373%
Epoch : 1 [41600/48000 (87%)] Loss: 0.081926 Accuracy:99.357%
Epoch : 1 [44800/48000 (93%)] Loss: 0.000962 Accuracy:99.376%
Epoch : 2 [0/48000 (0%)] Loss: 0.019608 Accuracy:98.438%
Epoch : 2 [3200/48000 (7%)] Loss: 0.003214 Accuracy:99.112%
Epoch : 2 [6400/48000 (13%)] Loss: 0.001873 Accuracy:99.319%
Epoch : 2 [9600/48000 (20%)] Loss: 0.093223 Accuracy:99.452%
Epoch : 2 [12800/48000 (27%)] Loss: 0.001541 Accuracy:99.502%
Epoch : 2 [16000/48000 (33%)] Loss: 0.000616 Accuracy:99.508%
Epoch : 2 [19200/48000 (40%)] Loss: 0.009107 Accuracy:99.481%
Epoch : 2 [22400/48000 (47%)] Loss: 0.000131 Accuracy:99.533%
Epoch : 2 [25600/48000 (53%)] Loss: 0.004520 Accuracy:99.556%
Epoch : 2 [28800/48000 (60%)] Loss: 0.003475 Accuracy:99.553%
Epoch : 2 [32000/48000 (67%)] Loss: 0.001568 Accuracy:99.557%
Epoch : 2 [35200/48000 (73%)] Loss: 0.000462 Accuracy:99.577%
Epoch : 2 [38400/48000 (80%)] Loss: 0.001937 Accuracy:99.568%
Epoch : 2 [41600/48000 (87%)] Loss: 0.062798 Accuracy:99.568%
Epoch : 2 [44800/48000 (93%)] Loss: 0.001069 Accuracy:99.572%
Epoch : 3 [0/48000 (0%)] Loss: 0.008862 Accuracy:98.438%
Epoch : 3 [3200/48000 (7%)] Loss: 0.003903 Accuracy:99.449%
Epoch : 3 [6400/48000 (13%)] Loss: 0.001003 Accuracy:99.489%
Epoch : 3 [9600/48000 (20%)] Loss: 0.070207 Accuracy:99.607%
Epoch : 3 [12800/48000 (27%)] Loss: 0.000268 Accuracy:99.658%
Epoch : 3 [16000/48000 (33%)] Loss: 0.000322 Accuracy:99.664%
Epoch : 3 [19200/48000 (40%)] Loss: 0.013567 Accuracy:99.631%
Epoch : 3 [22400/48000 (47%)] Loss: 0.000157 Accuracy:99.631%
Epoch : 3 [25600/48000 (53%)] Loss: 0.001247 Accuracy:99.642%
Epoch : 3 [28800/48000 (60%)] Loss: 0.001019 Accuracy:99.636%
Epoch : 3 [32000/48000 (67%)] Loss: 0.000877 Accuracy:99.638%
Epoch : 3 [35200/48000 (73%)] Loss: 0.000389 Accuracy:99.654%
Epoch : 3 [38400/48000 (80%)] Loss: 0.000376 Accuracy:99.649%
Epoch : 3 [41600/48000 (87%)] Loss: 0.052261 Accuracy:99.647%
Epoch : 3 [44800/48000 (93%)] Loss: 0.001428 Accuracy:99.655%
Epoch : 4 [0/48000 (0%)] Loss: 0.001804 Accuracy:100.000%
Epoch : 4 [3200/48000 (7%)] Loss: 0.004144 Accuracy:99.632%
Epoch : 4 [6400/48000 (13%)] Loss: 0.000248 Accuracy:99.675%
Epoch : 4 [9600/48000 (20%)] Loss: 0.056268 Accuracy:99.741%
Epoch : 4 [12800/48000 (27%)] Loss: 0.000263 Accuracy:99.775%
Epoch : 4 [16000/48000 (33%)] Loss: 0.000082 Accuracy:99.795%
Epoch : 4 [19200/48000 (40%)] Loss: 0.019946 Accuracy:99.740%
Epoch : 4 [22400/48000 (47%)] Loss: 0.000092 Accuracy:99.742%
Epoch : 4 [25600/48000 (53%)] Loss: 0.002394 Accuracy:99.751%
Epoch : 4 [28800/48000 (60%)] Loss: 0.000842 Accuracy:99.747%
Epoch : 4 [32000/48000 (67%)] Loss: 0.000388 Accuracy:99.747%
Epoch : 4 [35200/48000 (73%)] Loss: 0.000847 Accuracy:99.762%
Epoch : 4 [38400/48000 (80%)] Loss: 0.000160 Accuracy:99.761%
Epoch : 4 [41600/48000 (87%)] Loss: 0.032126 Accuracy:99.762%
Epoch : 4 [44800/48000 (93%)] Loss: 0.000399 Accuracy:99.766%
predictions = []
for idx, (images, labels) in enumerate(test_loader):
    # Use the trained model returned by instantiate_and_train (a fresh cnn_Model()
    # here would evaluate untrained weights)
    outputs = model(images)
    y_pred, y_true = accuracy(outputs, labels)
    predictions.extend(y_pred)
predictions = np.array(predictions)
predictions.shape
(12000,)
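With the predictions collected into a flat array, the overall test accuracy is just the fraction of matches against the binarised test labels. The exact number depends on the trained model in memory, so it is not reproduced here:
# Overall test accuracy for the one-vs-rest task
print((predictions == test_label.numpy()).mean())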
fig, ax = plt.subplots(4, 4, figsize=(10, 10))
axes = ax.flatten()

for i in range(len(axes)):
    axes[i].imshow(test_image[i].reshape(28, 28), cmap='gray')
    axes[i].set_xticks([])
    axes[i].set_yticks([])
    axes[i].set_title("Prediction: {}".format(predictions[i]))