Image reconstruction using various activation functions

Let’s import the necessary libraries

import torch
import torch.nn as nn
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn import preprocessing
import os
from PIL import Image
from skimage.filters import sobel, laplace
import matplotlib.gridspec as gridspec
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import MNIST

try:
    from einops import rearrange
except ImportError:
    %pip install einops
    from einops import rearrange

Loading image
data = MNIST(root='/root', train=True, download=True)
data

Dataset MNIST
    Number of datapoints: 60000
    Root location: /root
    Split: Train
data.data.shape
torch.Size([60000, 28, 28])
image = data.data[0]
image.shape
torch.Size([28, 28])
plt.imshow(image, cmap='gray')
plt.tight_layout()
= "cuda" if torch.cuda.is_available() else "cpu"
device print(device)
= preprocessing.MinMaxScaler().fit(image.reshape(-1, 1))
scaler_img = scaler_img.transform(image.reshape(-1, 1)).reshape(image.shape)
img_scaled = torch.tensor(img_scaled)
img_scaled = img_scaled.to(device)
img_scaled img_scaled.shape
cpu
torch.Size([28, 28])
img_scaled = torch.unsqueeze(img_scaled, dim=0)
image.min(), image.max(), img_scaled.min(), img_scaled.max()
(tensor(0, dtype=torch.uint8),
tensor(255, dtype=torch.uint8),
tensor(0., dtype=torch.float64),
tensor(1., dtype=torch.float64))
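Since the min/max check above confirms the digit spans the full 0–255 range, the MinMaxScaler step here reduces to a division by 255. A minimal torch-only sketch of the same normalization (the name img_scaled_alt is mine, not from the notebook):

# Hypothetical torch-only equivalent of the MinMaxScaler normalization above.
img_scaled_alt = (image.float() / 255.0).unsqueeze(0).to(device)
assert torch.allclose(img_scaled_alt, img_scaled.float())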
def create_coordinate_map(img):
    """
    img: torch.Tensor of shape (num_channels, height, width)
    return: tuple of torch.Tensor of shape (height * width, 2) and torch.Tensor of shape (height * width, num_channels)
    """
    num_channels, height, width = img.shape

    # Create a 2D grid of (x, y) coordinates (h, w)
    # width values change faster than height values
    w_coords = torch.arange(width).repeat(height, 1)
    h_coords = torch.arange(height).repeat(width, 1).t()
    w_coords = w_coords.reshape(-1)
    h_coords = h_coords.reshape(-1)

    # Combine the x and y coordinates into a single tensor
    X = torch.stack([h_coords, w_coords], dim=1).float()

    # Move X to GPU if available
    X = X.to(device)

    # Reshape the image to (h * w, num_channels)
    Y = rearrange(img, 'c h w -> (h w) c').float()
    return X, Y
Coords, pixels = create_coordinate_map(img_scaled)
Coords.shape, pixels.shape
(torch.Size([784, 2]), torch.Size([784, 1]))
Coords[:5], pixels[:5]
(tensor([[0., 0.],
[0., 1.],
[0., 2.],
[0., 3.],
[0., 4.]]),
tensor([[0.],
[0.],
[0.],
[0.],
[0.]]))
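As a quick round-trip check (a sketch I am adding here, not part of the original notebook), the flattened pixels can be folded back into the image with the inverse rearrange pattern:

# Hypothetical sanity check: un-flatten pixels back into (c, h, w) form.
img_back = rearrange(pixels, '(h w) c -> c h w', h=28, w=28)
assert torch.equal(img_back, img_scaled.float())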
# MinMaxScaler from -1 to 1
scaler_X = preprocessing.MinMaxScaler(feature_range=(-1, 1)).fit(Coords.cpu())

# Scale the X coordinates
Coords_scaled = scaler_X.transform(Coords.cpu())

# Move the scaled X coordinates to the GPU
Coords_scaled = torch.tensor(Coords_scaled).to(device)

# Set to dtype float32
Coords_scaled = Coords_scaled.float()

Coords.shape, Coords.min(), Coords.max()
(torch.Size([784, 2]), tensor(0.), tensor(27.))
Coords_scaled.shape, Coords_scaled.min(), Coords_scaled.max()
(torch.Size([784, 2]), tensor(-1.), tensor(1.))
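Since both coordinate axes span 0 to 27, the feature_range=(-1, 1) scaler reduces to the linear map 2x/27 − 1; a one-line torch equivalent for reference (my addition, not from the notebook):

# Hypothetical torch-only equivalent of the coordinate scaling above.
coords_alt = 2 * Coords / 27 - 1
assert torch.allclose(coords_alt, Coords_scaled)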
def gradient(y, x, grad_outputs=None):
    if grad_outputs is None:
        grad_outputs = torch.ones_like(y)
    grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]
    return grad

def Laplace(y, x):
    grad = gradient(y, x)
    return divergence(grad, x)

def divergence(y, x):
    div = 0.
    for i in range(y.shape[-1]):
        div += torch.autograd.grad(y[..., i], x, torch.ones_like(y[..., i]), create_graph=True)[0][..., i:i+1]
    return div
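Before using these helpers on the network, a quick analytic check (my own snippet, not from the original): for y = x1² + x2², gradient() should return 2x and Laplace() the constant 4.

# Hypothetical sanity check of the autograd helpers on f(x) = x1^2 + x2^2.
xt = torch.randn(5, 2, requires_grad=True)
yt = (xt ** 2).sum(dim=-1, keepdim=True)
assert torch.allclose(gradient(yt, xt), 2 * xt)                  # grad f = 2x
assert torch.allclose(Laplace(yt, xt), torch.full((5, 1), 4.0))  # div(grad f) = 2 + 2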
# Model
class NN(nn.Module):
    def _init_siren(self, omega):
        # SIREN weight init: first layer ~ U(-1/n, 1/n),
        # hidden layers ~ U(-sqrt(6/n)/omega, sqrt(6/n)/omega)
        self.fc1.weight.data.uniform_(-1/self.fc1.in_features, 1/self.fc1.in_features)
        for layers in [self.fc2, self.fc3, self.fc4, self.fc5]:
            layers.weight.data.uniform_(-np.sqrt(6/self.fc2.in_features)/omega,
                                        np.sqrt(6/self.fc2.in_features)/omega)

    def __init__(self, activation=torch.sin, n_out=1, omega=1.0, s=128):
        super().__init__()
        self.activation = activation
        self.omega = omega
        self.fc1 = nn.Linear(2, s)
        self.fc2 = nn.Linear(s, s)
        self.fc3 = nn.Linear(s, s)
        self.fc4 = nn.Linear(s, s)
        self.fc5 = nn.Linear(s, n_out)  # grayscale image (1) or RGB (3)
        if self.activation == torch.sin:
            # init weights and biases for sine activation
            self._init_siren(omega=self.omega)

    def forward(self, x):
        x = self.activation(self.omega * self.fc1(x))
        x = self.activation(self.omega * self.fc2(x))
        x = self.activation(self.omega * self.fc3(x))
        x = self.activation(self.omega * self.fc4(x))
        return self.fc5(x)
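A small check that _init_siren applied the intended bounds (a sketch under the assumption that the standard SIREN scheme above is what we want: first layer within ±1/fan_in, hidden layers within ±sqrt(6/fan_in)/omega):

# Hypothetical check of the SIREN initialization bounds.
siren = NN(activation=torch.sin, omega=5.0)
assert siren.fc1.weight.abs().max() <= 1 / siren.fc1.in_features
assert siren.fc2.weight.abs().max() <= np.sqrt(6 / siren.fc2.in_features) / 5.0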
# Shuffle data
# shuffled index
sh_index = torch.randperm(Coords_scaled.shape[0])

# Shuffle the dataset
Coords_sh = Coords_scaled[sh_index]
pixels_sh = pixels[sh_index]
nns = {}
nns["mnist"] = {}
nns["mnist"]["relu"] = NN(activation=torch.relu).to(device)
nns["mnist"]["sin"] = NN(activation=torch.sin, omega=5).to(device)
nns["mnist"]["tanh"] = NN(activation=torch.tanh).to(device)

nns["mnist"]["relu"](Coords_sh).shape, nns["mnist"]["sin"](Coords_sh).shape
(torch.Size([784, 1]), torch.Size([784, 1]))
Coords_sh.get_device()  # returns -1 for a CPU tensor

-1
n_iter = 2000
def train(net, lr, X, Y, epochs, verbose=True):
    """
    net: torch.nn.Module
    lr: float
    X: torch.Tensor of shape (num_samples, 2)
    Y: torch.Tensor of shape (num_samples, n_out)
    """
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = net(X)
        loss = criterion(outputs, Y)
        loss.backward()
        optimizer.step()
        if verbose and epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {loss.item():.6f}")
    return loss.item()

train(nns["mnist"]["relu"], lr=3e-4, X=Coords_sh, Y=pixels_sh, epochs=n_iter)
Epoch 0 loss: 0.124708
Epoch 100 loss: 0.057384
Epoch 200 loss: 0.015448
Epoch 300 loss: 0.007028
Epoch 400 loss: 0.004833
Epoch 500 loss: 0.003784
Epoch 600 loss: 0.003161
Epoch 700 loss: 0.002692
Epoch 800 loss: 0.002465
Epoch 900 loss: 0.002187
Epoch 1000 loss: 0.001870
Epoch 1100 loss: 0.001652
Epoch 1200 loss: 0.001581
Epoch 1300 loss: 0.001353
Epoch 1400 loss: 0.001245
Epoch 1500 loss: 0.001054
Epoch 1600 loss: 0.000952
Epoch 1700 loss: 0.000897
Epoch 1800 loss: 0.000800
Epoch 1900 loss: 0.000816
0.0008279115427285433
grad = np.stack([sobel(img_scaled.to("cpu").detach().numpy(), axis=0),
                 sobel(img_scaled.to("cpu").detach().numpy(), axis=1)], axis=-1)
grad_norm = np.linalg.norm(grad, axis=-1)
grad.shape, grad_norm.shape
((1, 28, 28, 2), (1, 28, 28))
"chw -> hwc",torch.tensor(grad_norm)),cmap='gray')
plt.imshow(torch.einsum( plt.show()
original_img_np = img_scaled.cpu().permute(1, 2, 0).numpy()
original_laplacian = laplace(original_img_np)
plt.imshow(np.linalg.norm(original_laplacian, axis=2).reshape(28, 28, 1), cmap='gray')
plt.title("Original Laplacian")
Text(0.5, 1.0, 'Original Laplacian')
def plot_reconstructed_and_original_image(original_img, net, X, title=""):
    """
    net: torch.nn.Module
    X: torch.Tensor of shape (num_samples, 2)
    """
    num_channels, height, width = original_img.shape

    original_img = original_img.detach()
    X.requires_grad_(True)
    net.eval()
    outputs = net(X)
    outputs = outputs.reshape(height, width, num_channels)
    grad = gradient(outputs, X)
    lp = Laplace(outputs, X)

    # Compute gradient and Laplacian of original and reconstructed images using skimage
    original_img_np = original_img.cpu().permute(1, 2, 0).numpy()
    reconstructed_img_np = outputs.detach().cpu().numpy()

    original_gradient_x = sobel(original_img_np, axis=0)
    original_gradient_y = sobel(original_img_np, axis=1)
    original_gradient = np.stack([sobel(original_img.to("cpu").detach().numpy(), axis=0),
                                  sobel(original_img.to("cpu").detach().numpy(), axis=1)], axis=-1)
    grad_norm = np.linalg.norm(original_gradient, axis=-1)
    original_laplacian = laplace(original_img_np)

    fig = plt.figure(figsize=(12, 8))
    gs = gridspec.GridSpec(2, 4, width_ratios=[1, 1, 1, 1])

    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])
    ax2 = plt.subplot(gs[2])
    ax3 = plt.subplot(gs[3])
    ax4 = plt.subplot(gs[4])
    ax5 = plt.subplot(gs[5])
    ax6 = plt.subplot(gs[6])
    ax7 = plt.subplot(gs[7])

    ax0.imshow(outputs.detach().cpu().numpy())
    ax0.set_title("Reconstructed Image")

    ax1.imshow(original_img.cpu().permute(1, 2, 0))
    ax1.set_title("Original Image")

    ax2.imshow(torch.einsum("chw -> hwc", torch.tensor(np.linalg.norm(grad_norm, axis=0).reshape(1, 28, 28))), cmap='gray')
    ax2.set_title("Original Gradient")

    ax3.imshow(grad.norm(dim=-1).cpu().view(28, 28).detach().numpy(), cmap='gray')
    ax3.set_title("Reconstructed Gradient")

    ax4.imshow(original_gradient_x, cmap='gray')
    ax4.set_title("Original Gradient X")

    ax5.imshow(original_gradient_y, cmap='gray')
    ax5.set_title("Original Gradient Y")

    ax6.imshow(np.linalg.norm(original_laplacian, axis=2).reshape(28, 28, 1), cmap='gray')
    ax6.set_title("Original Laplacian")

    ax7.imshow(lp.norm(dim=-1).cpu().view(28, 28).detach().numpy(), cmap='gray')
    ax7.set_title("Reconstructed Laplacian")

    for a in [ax0, ax1, ax2, ax3, ax4, ax5, ax6, ax7]:
        a.axis("off")

    fig.suptitle(title, y=0.9)
    plt.tight_layout()
    plt.show()
    return outputs, grad, lp
outputs_relu, grad_relu, lp_relu = plot_reconstructed_and_original_image(img_scaled, nns["mnist"]["relu"], Coords_scaled, title="ReLU")
imgs_sin = train(nns["mnist"]["sin"], lr=3e-4, X=Coords_sh, Y=pixels_sh, epochs=n_iter)
Epoch 0 loss: 0.135896
Epoch 100 loss: 0.006196
Epoch 200 loss: 0.002919
Epoch 300 loss: 0.001405
Epoch 400 loss: 0.000737
Epoch 500 loss: 0.000398
Epoch 600 loss: 0.000234
Epoch 700 loss: 0.000134
Epoch 800 loss: 0.000075
Epoch 900 loss: 0.000082
Epoch 1000 loss: 0.000028
Epoch 1100 loss: 0.000019
Epoch 1200 loss: 0.000064
Epoch 1300 loss: 0.000120
Epoch 1400 loss: 0.000006
Epoch 1500 loss: 0.000005
Epoch 1600 loss: 0.000005
Epoch 1700 loss: 0.000003
Epoch 1800 loss: 0.000189
Epoch 1900 loss: 0.000002
outputs_sin, grad_sin, lp_sin = plot_reconstructed_and_original_image(img_scaled, nns["mnist"]["sin"], Coords_scaled, title="Sine")
imgs_tanh = train(nns["mnist"]["tanh"], lr=3e-4, X=Coords_sh, Y=pixels_sh, epochs=n_iter)
Epoch 0 loss: 0.127715
Epoch 100 loss: 0.076765
Epoch 200 loss: 0.072533
Epoch 300 loss: 0.065501
Epoch 400 loss: 0.045376
Epoch 500 loss: 0.032519
Epoch 600 loss: 0.023787
Epoch 700 loss: 0.019351
Epoch 800 loss: 0.016675
Epoch 900 loss: 0.014709
Epoch 1000 loss: 0.012673
Epoch 1100 loss: 0.011636
Epoch 1200 loss: 0.010895
Epoch 1300 loss: 0.010314
Epoch 1400 loss: 0.009828
Epoch 1500 loss: 0.009476
Epoch 1600 loss: 0.009905
Epoch 1700 loss: 0.008863
Epoch 1800 loss: 0.008430
Epoch 1900 loss: 0.008168
outputs_tanh, grad_tanh, lp_tanh = plot_reconstructed_and_original_image(img_scaled, nns["mnist"]["tanh"], Coords_scaled, title="tanh")
# plot figure comparisons
def plot_comparisons():
    fig = plt.figure(figsize=(12, 8))
    gs = gridspec.GridSpec(3, 4, width_ratios=[1, 1, 1, 1])

    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])
    ax2 = plt.subplot(gs[2])
    ax3 = plt.subplot(gs[3])
    ax4 = plt.subplot(gs[4])
    ax5 = plt.subplot(gs[5])
    ax6 = plt.subplot(gs[6])
    ax7 = plt.subplot(gs[7])
    ax8 = plt.subplot(gs[8])
    ax9 = plt.subplot(gs[9])
    ax10 = plt.subplot(gs[10])
    ax11 = plt.subplot(gs[11])

    ax0.imshow(torch.einsum("chw -> hwc", img_scaled).detach().cpu().numpy(), cmap='gray')
    ax0.set_title("Ground Truth Image", size=9)
    ax0.axis("off")

    ax1.imshow(outputs_relu.detach().cpu().numpy(), cmap='gray')
    ax1.set_title("ReLU Image", size=9)
    ax1.axis("off")

    ax2.imshow(outputs_sin.detach().cpu().numpy(), cmap='gray')
    ax2.set_title("Sine Image", size=9)
    ax2.axis("off")

    ax3.imshow(outputs_tanh.detach().cpu().numpy(), cmap='gray')
    ax3.set_title("tanh Image", size=9)
    ax3.axis("off")

    ax4.imshow(torch.einsum("chw -> hwc", torch.tensor(grad_norm)), cmap='gray')
    ax4.set_title("Ground Truth Gradient", size=9)
    ax4.axis("off")

    ax5.imshow(grad_relu.norm(dim=-1).cpu().view(28, 28).detach().numpy(), cmap='gray')
    ax5.set_title("ReLU Gradient", size=9)
    ax5.axis("off")

    ax6.imshow(grad_sin.norm(dim=-1).cpu().view(28, 28).detach().numpy(), cmap='gray')
    ax6.set_title("Sine Gradient", size=9)
    ax6.axis("off")

    ax7.imshow(grad_tanh.norm(dim=-1).cpu().view(28, 28).detach().numpy(), cmap='gray')
    ax7.set_title("tanh Gradient", size=9)
    ax7.axis("off")

    ax8.imshow(np.linalg.norm(original_laplacian, axis=2).reshape(28, 28, 1), cmap='gray')
    ax8.set_title("Original Laplacian")
    ax8.axis("off")

    ax9.imshow(lp_relu.norm(dim=-1).cpu().view(28, 28).detach().numpy(), cmap='gray')
    ax9.set_title("ReLU Laplacian", size=9)
    ax9.axis("off")

    ax10.imshow(lp_sin.norm(dim=-1).cpu().view(28, 28).detach().numpy(), cmap='gray')
    ax10.set_title("Sine Laplacian", size=9)
    ax10.axis("off")

    ax11.imshow(lp_tanh.norm(dim=-1).cpu().view(28, 28).detach().numpy(), cmap='gray')
    ax11.set_title("tanh Laplacian", size=9)
    ax11.axis("off")

    plt.tight_layout()
    plt.show()

plot_comparisons()
grad_relu.shape
torch.Size([784, 2])
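To put numbers on the visual comparison above, here is a hypothetical snippet (the MSE/PSNR computation is my addition; the peak signal value is 1 because the image was min-max scaled):

# Hypothetical quantitative comparison of the three reconstructions.
target = img_scaled.float().reshape(28, 28, 1).cpu()
for name, out in [("relu", outputs_relu), ("sin", outputs_sin), ("tanh", outputs_tanh)]:
    mse = torch.mean((out.detach().cpu() - target) ** 2).item()
    psnr = 10 * np.log10(1.0 / mse)  # PSNR with peak value 1
    print(f"{name}: MSE = {mse:.6f}, PSNR = {psnr:.2f} dB")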
# Analyzing SIREN for different values of omega
omega_values = [w for w in range(1, 34, 5)]
omega_values
[1, 6, 11, 16, 21, 26, 31]
def plot_omega_comparisons():
    fig, ax = plt.subplots(3, 7, figsize=(15, 5))
    Coords_scaled.requires_grad_(True)
    for w in range(len(omega_values)):
        nns["mnist"][f"{omega_values[w]}"] = NN(activation=torch.sin, omega=omega_values[w]).to(device)
        net = nns["mnist"][f"{omega_values[w]}"]
        imgs_sin_w = train(net, lr=3e-4, X=Coords_sh, Y=pixels_sh, epochs=n_iter, verbose=False)
        net.eval()
        outputs = net(Coords_scaled)
        outputs = outputs.reshape(28, 28, 1)
        grad = gradient(outputs, Coords_scaled)
        lp = Laplace(outputs, Coords_scaled)

        ax[0, w].imshow(outputs.detach().cpu().numpy(), cmap='gray')
        ax[0, w].set_title(f"Reconstructed Image, omega={omega_values[w]}", size=7)
        ax[0, w].axis("off")

        ax[1, w].imshow(grad.norm(dim=-1).cpu().view(28, 28).detach().numpy(), cmap='gray')
        ax[1, w].set_title("Reconstructed Gradient", size=7)
        ax[1, w].axis("off")

        ax[2, w].imshow(lp.norm(dim=-1).cpu().view(28, 28).detach().numpy(), cmap='gray')
        ax[2, w].set_title("Reconstructed Laplacian", size=7)
        ax[2, w].axis("off")

plot_omega_comparisons()
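As a numeric complement to the sweep (a hypothetical follow-up; it reuses the models that plot_omega_comparisons stored in nns):

# Hypothetical: report the reconstruction MSE of each stored omega model.
for w in omega_values:
    net = nns["mnist"][f"{w}"]
    with torch.no_grad():
        mse = torch.mean((net(Coords_scaled) - pixels) ** 2).item()
    print(f"omega={w}: MSE {mse:.6f}")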