
from datetime import datetime

import torch
import torch.nn as nn
import torch.distributions as D
from torchvision import datasets
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from matplotlib_inline import backend_inline
import matplotlib.pyplot as plt
from IPython import display
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

training_data = datasets.MNIST(root="../data/MNIST", download=True, train=True, transform=transform)
validation_data = datasets.MNIST(root="../data/MNIST", download=True, train=False, transform=transform)

train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=100, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=100, shuffle=False)
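Before training, it is worth confirming what the loaders actually emit; a quick sanity check (not part of the original run):

batch, labels = next(iter(train_dataloader))
print(batch.shape)                             # torch.Size([100, 1, 28, 28])
print(batch.min().item(), batch.max().item())  # ToTensor scales pixels into [0, 1]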
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Set the axes for matplotlib."""
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    axes.grid()
    
class Animator:
    """For plotting data in animation."""

    def __init__(
        self,
        xlabel=None,
        ylabel=None,
        legend=None,
        xlim=None,
        ylim=None,
        xscale="linear",
        yscale="linear",
        fmts=("-", "m--", "g-.", "r:"),
        nrows=1,
        ncols=1,
        figsize=(3.5, 2.5),
    ):
        # Incrementally plot multiple lines
        if legend is None:
            legend = []
        backend_inline.set_matplotlib_formats("svg")  # SVG output; replaces the removed IPython.display helper
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes]
        # Use a lambda function to capture arguments
        self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # Add multiple data points into the figure
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for i, (x, y) in enumerate(zip(self.X, self.Y)):
            if y:
                # First series uses the primary format; later series share the second.
                self.axes[0].plot(x, y, self.fmts[0] if i == 0 else self.fmts[1])
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)
softmax = nn.Softmax(dim=0)
sigmoid = nn.Sigmoid()

class MNISTModel(nn.Module):
    """Mixture of independent Bernoullis over flattened, binarized MNIST images."""

    def __init__(self, total_class=12, input_shape=28 * 28):
        super(MNISTModel, self).__init__()
        # Unnormalized mixing-weight scores; softmax in forward() turns them
        # into mixture probabilities.
        random_uniform_val = torch.rand(total_class)
        self.mixture_logits = nn.Parameter(random_uniform_val / random_uniform_val.sum())
        # Per-component, per-pixel Bernoulli logits; sigmoid maps them to
        # pixel-on probabilities, initialized uniformly at 1/total_class.
        self.pixel_logits = nn.Parameter(
            torch.special.logit(torch.full((total_class, input_shape), 1 / total_class))
        )

    def forward(self, data):
        current_m = self.update(softmax(self.mixture_logits), sigmoid(self.pixel_logits))
        model_prob = current_m.log_prob(data)
        # Average negative log-likelihood over the batch.
        return -model_prob.sum() / data.shape[0]

    def update(self, weights, probs):
        # Mixture whose components are products of 784 independent Bernoullis.
        return D.MixtureSameFamily(
            D.Categorical(weights),
            D.Independent(D.Bernoulli(probs), reinterpreted_batch_ndims=1),
        )
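For reference, the quantity forward() returns is the batch-averaged negative log-likelihood of a K-component Bernoulli mixture. With mixture weights pi = softmax(mixture_logits) and pixel probabilities mu = sigmoid(pixel_logits), for binarized images x_n in {0,1}^784:

$$
\mathcal{L} = -\frac{1}{N}\sum_{n=1}^{N}\log\sum_{k=1}^{K}\pi_k\prod_{d=1}^{784}\mu_{kd}^{x_{nd}}\,(1-\mu_{kd})^{1-x_{nd}}
$$

MixtureSameFamily.log_prob evaluates the sum over components with logsumexp, so the product over pixels never has to be formed explicitly.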

model = MNISTModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train_one_epoch(animator, step_offset):
    last_loss = 0.
    for i, data in enumerate(train_dataloader):
        inputs, _ = data
        batch_size = inputs.shape[0]
        # Flatten to 784-d vectors and binarize: any nonzero pixel becomes 1.
        inputs = torch.reshape(inputs, (batch_size, 28 * 28))
        inputs = torch.where(inputs > 0, 1.0, 0.0)
        optimizer.zero_grad()
        loss = model(inputs)  # batch-averaged negative log-likelihood
        loss.backward()
        optimizer.step()
        last_loss = loss.item()
        writer.add_scalar('Loss/train', last_loss, step_offset + i + 1)
        animator.add(step_offset + i + 1, last_loss)
    return last_loss

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/MNIST_{}'.format(timestamp))
EPOCHS = 15
animator = Animator(
    xlabel=f'batch ({EPOCHS} epochs x {len(train_dataloader)} batches)',
    xlim=[1, EPOCHS * len(train_dataloader)],
    ylim=[0, 300],
    legend=["training loss"],
)
for epoch in range(EPOCHS):
    model.train(True)
    last_loss = train_one_epoch(animator, epoch * len(train_dataloader))
(Figure: training loss per batch across all epochs.)
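One way to probe what the mixture has learned (a sketch, not in the original notebook) is the posterior responsibility of each component for a held-out image, p(k | x) proportional to pi_k p(x | k):

x, _ = validation_data[0]
x = torch.where(x.reshape(28 * 28) > 0, 1.0, 0.0)       # same binarization as training
with torch.no_grad():
    weights = softmax(model.mixture_logits)              # pi_k
    comps = D.Independent(D.Bernoulli(sigmoid(model.pixel_logits)), 1)
    log_joint = torch.log(weights) + comps.log_prob(x)   # log pi_k + log p(x | k)
    responsibilities = torch.softmax(log_joint, dim=-1)  # p(k | x)
print(responsibilities)                                  # one probability per component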
# Visualize each component's learned per-pixel Bernoulli means as an image.
images = sigmoid(model.pixel_logits.detach())
fig, axes = plt.subplots(3, 4)
for i, (image, ax) in enumerate(zip(images, axes.flatten())):
    ax.imshow(image.reshape(28, 28), cmap='gray')
    ax.set_title(f'Component {i}')
    ax.axis("off")
fig.tight_layout(pad=1.0)
plt.show()
(Figure: 3x4 grid of the 12 learned component mean images.)
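test_dataloader is defined above but never consumed; a short held-out evaluation, assuming the same binarization as training, might look like this:

model.train(False)
total_nll, total_images = 0.0, 0
with torch.no_grad():
    for inputs, _ in test_dataloader:
        inputs = torch.where(inputs.reshape(inputs.shape[0], 28 * 28) > 0, 1.0, 0.0)
        total_nll += model(inputs).item() * inputs.shape[0]  # undo the per-batch average
        total_images += inputs.shape[0]
print(f"held-out NLL per image: {total_nll / total_images:.2f}")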