Projects & Blogs

  • PROBABILISTIC APPR...

  • MINI-UNET

  • MINI-ALEXNET

  • SEQUENTIAL MONTE ...

  • TRUNCATED SVD

  • CUSTOM DATALOADER...

  • PROBABILITY

from typing import Callable, List, NamedTuple
import torch
import matplotlib.pyplot as plt

class Constants(NamedTuple):
    """Default parameters of the scalar Gaussian sequence model.

    Every field is a one-element float tensor so it broadcasts against the
    per-particle tensors used by ``GaussianSequenceModel``.

    phi:  coefficient applied to the previous latent, z_t ~ N(phi * z_{t-1}, sqrt(q)).
    q:    variance of the latent noise (the model uses sqrt(q) as the scale).
    beta: decay applied to the previous observation mean, mu_t = beta * mu_{t-1} + z_t.
    r:    variance of the observation noise (the model uses sqrt(r) as the scale).
    """

    phi: torch.Tensor = torch.full((1,), 0.9)
    q: torch.Tensor = torch.full((1,), 10.0)
    beta: torch.Tensor = torch.full((1,), 0.5)
    r: torch.Tensor = torch.full((1,), 1.0)

# Module-level default constants, passed to the GaussianSequenceModel demo further down.
init_constants = Constants()

class GaussianSequenceModel:
    """Scalar Gaussian state-space model with SIS / SISR particle routines.

    Generative model, applied elementwise to ``N_s`` parallel chains::

        z_t  ~ Normal(phi * z_{t-1}, sqrt(q))
        mu_t = beta * mu_{t-1} + z_t
        y_t  ~ Normal(mu_t, sqrt(r))

    NOTE(review): the importance weights used by ``sis``, ``sisr`` and
    ``sisr_with_adaptive_resampling`` are ``y_t / (y_{t-1} * z_t)`` rather
    than the observation likelihood from ``y_step``, so they can be negative;
    treat the filtering outputs as illustrative.
    """

    def __init__(self, init_constants: "Constants", N_s):
        """
        Args:
            init_constants: object exposing tensor attributes ``phi``, ``q``,
                ``beta`` and ``r`` (e.g. ``Constants``).
            N_s: number of parallel chains / particles.
        """
        self.init_constants = init_constants
        self.N_s = N_s

    def sample_z_step(self, prev_zt):
        """Draw ``z_t ~ Normal(phi * prev_zt, sqrt(q))`` elementwise."""
        phi = self.init_constants.phi
        q = self.init_constants.q
        return torch.distributions.Normal(loc=phi * prev_zt, scale=torch.sqrt(q)).sample()

    def y_step(self, prev_y_mu_t, zt):
        """Return ``mu_t = beta * prev_y_mu_t + zt`` and ``Normal(mu_t, sqrt(r))``.

        ``mu_t`` carries a trailing singleton axis when ``beta`` has shape
        ``(1,)`` (as in ``Constants``); the distribution uses the flattened mean.
        """
        beta = self.init_constants.beta
        r = self.init_constants.r
        mu_t = torch.vmap(lambda mu_i, z_i_t: beta * mu_i + z_i_t)(prev_y_mu_t, zt)
        y_t_dist = torch.distributions.Normal(loc=mu_t.ravel(), scale=torch.sqrt(r))
        return mu_t, y_t_dist

    def sample_y_step(self, prev_y_mu_t, zt):
        """Sample ``y_t``; return flattened ``mu_t``, flattened ``y_t`` and its distribution."""
        mu_t, y_t_dist = self.y_step(prev_y_mu_t, zt)
        y_t = y_t_dist.sample()
        return mu_t.ravel(), y_t.ravel(), y_t_dist

    def sample_step(self, prev_zt, prev_y_mu_t):
        """Advance one step; return the new carry ``(z_t, mu_t)`` and ``stack(z_t, y_t)``."""
        zt = self.sample_z_step(prev_zt)
        mu_t, y_t, _ = self.sample_y_step(prev_y_mu_t, zt)
        return (zt, mu_t), torch.cat((zt.unsqueeze(0), y_t.unsqueeze(0)))

    def sample(self, nsteps):
        """Simulate ``nsteps`` steps of every chain starting from zero state.

        Returns:
            Tensor of shape ``(2, nsteps, N_s)``: index 0 holds ``z_t``,
            index 1 holds ``y_t``.
        """
        mu_0, z_i_0, _ = self.inits(self.N_s)
        # Carry is ordered (z, mu) to match sample_step(prev_zt, prev_y_mu_t).
        carry = (z_i_0, mu_0)
        logs_chain = []
        for _ in range(nsteps):
            carry, logs = self.sample_step(*carry)
            logs_chain.append(logs)
        return torch.stack(logs_chain, dim=1)

    def inits(self, N_s):
        """Return zero initial ``mu``, ``z`` and weight vectors of length ``N_s``."""
        mu_0 = torch.zeros(N_s)
        z_i_0 = torch.zeros(N_s)
        w_i_0 = torch.zeros(N_s)
        return mu_0, z_i_0, w_i_0

    @staticmethod
    def _normalize_weights(w_i_t):
        """Self-normalise importance weights so they sum to one along the last axis."""
        return w_i_t / w_i_t.sum(dim=-1, keepdim=True)

    @staticmethod
    def _weighted_spread(W_i_t, z_i_t, centre):
        """Weighted Monte Carlo estimate of ``E[|centre - z|]`` under weights ``W_i_t``."""
        return (W_i_t * torch.abs(centre - z_i_t)).sum()

    def _init_particles(self, y_ts):
        """Draw the first particle generation from z = 0 and weight it against ``y_ts[0]``.

        Returns ``(z, w, W, estimate)``: particles, unnormalised weights,
        normalised weights and the first Monte Carlo estimate.
        """
        z_0 = torch.zeros(self.N_s)
        z_i_t = self.sample_z_step(z_0)
        w_i_t = y_ts[0] / z_i_t
        W_i_t = self._normalize_weights(w_i_t)
        return z_i_t, w_i_t, W_i_t, self._weighted_spread(W_i_t, z_i_t, z_0)

    def sis(self, y_ts, T):
        """Sequential importance sampling (no resampling) over ``T`` steps.

        Args:
            y_ts: observations indexed by time on the first axis; ``y_ts[t]``
                is paired elementwise with the ``N_s`` particles.
            T: number of time steps to process.

        Returns:
            ``(posteriors, normalized_w)`` with shapes ``(T,)`` and
            ``(T, N_s)``; every row of ``normalized_w`` sums to one.
        """
        z_i_t, w_i_t, W_i_t, mc_post_t = self._init_particles(y_ts)
        posteriors, normalized_w = [mc_post_t], [W_i_t]
        for t in range(1, T):
            z_i_t = self.sample_z_step(z_i_t)
            alpha = y_ts[t] / (y_ts[t - 1] * z_i_t)
            w_i_t = w_i_t * alpha
            W_i_t = self._normalize_weights(w_i_t)
            normalized_w.append(W_i_t)
            posteriors.append(self._weighted_spread(W_i_t, z_i_t, z_i_t.mean()))
        return torch.stack(posteriors), torch.stack(normalized_w)

    def resample_permute(self, w_i_t, z_i_t):
        """Multinomially resample particles and reset the weights to ``1 / N_s``.

        NOTE(review): ``logits=`` interprets ``w_i_t`` as log-weights (sampling
        probability proportional to ``exp(w)``), while the callers pass
        linear-scale weights that may be negative; confirm the intended scheme.

        Returns:
            ``(indexs, z_i_t, w_i_t)``: ancestor indices, resampled particles
            and uniform weights.
        """
        indexs = torch.distributions.Categorical(logits=w_i_t).sample((self.N_s,))
        resampled = z_i_t[indexs]
        uniform = torch.full((self.N_s,), 1 / self.N_s)
        return indexs, resampled, uniform

    def sisr(self, y_ts, T):
        """SIS with resampling before every propagation step.

        Returns:
            ``(posteriors, normalized_w, selected_indexs)`` with shapes
            ``(T,)``, ``(T, N_s)`` and ``(T - 1, N_s)``. Row ``k`` of
            ``selected_indexs`` gives, for each particle at step ``k + 1``,
            the index of its ancestor at step ``k``.
        """
        z_i_t, w_i_t, W_i_t, mc_post_t = self._init_particles(y_ts)
        normalized_w, selected_indexs, posteriors = [W_i_t], [], [mc_post_t]
        for t in range(1, T):
            indexes, z_i_t, w_i_t = self.resample_permute(w_i_t, z_i_t)
            selected_indexs.append(indexes)
            z_i_t = self.sample_z_step(z_i_t)
            # Weights are uniform after resampling, so the incremental factor
            # alone is proportional to the new weight.
            w_i_t = y_ts[t] / (y_ts[t - 1] * z_i_t)
            W_i_t = self._normalize_weights(w_i_t)
            normalized_w.append(W_i_t)
            posteriors.append(self._weighted_spread(W_i_t, z_i_t, z_i_t.mean()))
        return torch.stack(posteriors), torch.stack(normalized_w), torch.stack(selected_indexs)

    def sisr_with_adaptive_resampling(self, y_ts, T, min_weight):
        """SIS that resamples only when the effective sample size drops below ``min_weight``.

        Returns:
            ``(posteriors, normalized_w, Z_ts, selected_indexs)`` where
            ``Z_ts`` (shape ``(T - 1,)``) is the running normalising-constant
            estimate relative to the first step and ``selected_indexs`` is a
            list of ``(t, indexes)`` resampling events.
        """
        z_i_t, w_i_t, W_i_t, mc_post_t = self._init_particles(y_ts)
        normalized_w, posteriors = [W_i_t], [mc_post_t]
        Z_ts, selected_indexs = [], []
        Z_t = 1.0
        for t in range(1, T):
            # Weights entering this step (after any previous resampling reset).
            prev_w_i_t = w_i_t
            z_i_t = self.sample_z_step(z_i_t)
            alpha = y_ts[t] / (y_ts[t - 1] * z_i_t)
            w_i_t = w_i_t * alpha
            # Chain the estimate: Z_t = Z_{t-1} * sum(w_t) / sum(w_{t-1}),
            # evaluated before resampling overwrites the weights.
            Z_t = self.calc_Z(w_i_t, prev_w_i_t, Z_t)
            Z_ts.append(Z_t)
            if self.ESS(w_i_t, ess_min=min_weight):
                indexes, z_i_t, w_i_t = self.resample_permute(w_i_t, z_i_t)
                selected_indexs.append((t, indexes))
            W_i_t = self._normalize_weights(w_i_t)
            normalized_w.append(W_i_t)
            posteriors.append(self._weighted_spread(W_i_t, z_i_t, z_i_t.mean()))
        return torch.stack(posteriors), torch.stack(normalized_w), torch.stack(Z_ts), selected_indexs

    def ESS(self, w_i_t, ess_min):
        """Return True when the effective sample size ``(sum w)^2 / sum w^2`` is below ``ess_min``."""
        ess = w_i_t.sum().square() / w_i_t.square().sum()
        return bool(ess < ess_min)

    def calc_Z(self, w_i_t, prev_w_i_t, prev_Z_t):
        """Return ``Z_t = Z_{t-1} * sum(w_i_t) / sum(prev_w_i_t)``."""
        Z_t_Z_t_1 = w_i_t.sum() / prev_w_i_t.sum()
        return prev_Z_t * Z_t_Z_t_1
    
particles = 5
T = 10
model = GaussianSequenceModel(init_constants, N_s=particles)
samples = model.sample(T)
print('sample (zts, yts) shape', samples.shape)
# samples is stacked as (z_t, y_t) on axis 0; index 1 selects the observations,
# shape (T, particles). Direct indexing keeps both axes even when T or particles
# is 1, where squeeze() would silently drop them.
y_ts = samples[1]
print('only yts shape : ',y_ts.shape)
# SIS
p_sis, w_sis = model.sis(y_ts, T)
# SISR
p_smc, w_smc, s_idxs = model.sisr(y_ts, T)
# SISR with adaptive resampling (ESS threshold = particles / 5)
p_g_sisr_ad, w_g_sisr_ad, nc_g_sisr_ad, s_idxs_sisr_ad = model.sisr_with_adaptive_resampling(y_ts, T, min_weight=particles/5)
print(f'Normalized weights for sequential_importance_resampling : \n {w_smc}')
print(f'Normalized weights for sequential_importance_resampling_adaptive : \n {w_g_sisr_ad}')
sample (zts, yts) shape torch.Size([2, 10, 5])
only yts shape :  torch.Size([10, 5])
Normalized weights for sequential_importance_resampling : 
 tensor([[-6.6174e-01,  6.6187e-02, -9.7740e-02, -4.4313e-01,  2.1364e+00],
        [-1.8866e-01,  8.9358e-01,  3.5519e-01,  7.0895e-02,  1.8638e-01],
        [-8.6953e-02, -3.1492e-04, -2.4521e-01,  9.1163e-01, -3.1818e-01],
        [-4.1938e-02,  9.9866e-01, -1.2471e-02, -2.3220e-02, -1.5050e-02],
        [-1.6795e-02, -5.5831e-02, -1.2844e-03, -3.8192e-02,  9.9757e-01],
        [ 4.9757e-01, -4.2751e-01, -7.5116e-01,  2.7587e-02, -6.8301e-02],
        [ 4.9504e-02, -9.9616e-01, -5.9280e-02,  3.9084e-02,  1.3173e-02],
        [-3.2842e-02, -1.6739e-01,  1.7384e-01,  9.2731e-01,  2.8420e-01],
        [-9.8761e-01,  1.2042e-01,  7.7607e-02,  2.1090e-02, -6.0432e-02],
        [ 5.8333e-01,  5.7038e-01, -2.4771e-01, -1.4537e-01, -5.0191e-01]])
Normalized weights for sequential_importance_resampling_adaptive : 
 tensor([[ 1.4396e+00, -1.6778e-02, -5.5203e-02,  1.5816e-01, -5.2582e-01],
        [ 2.9889e-01, -3.5871e-03, -7.2293e-03,  2.7421e-02,  9.5386e-01],
        [ 8.6965e-01, -4.0461e-06, -1.5039e-03,  5.4041e-03,  4.9363e-01],
        [ 4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01],
        [ 4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01],
        [ 4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01],
        [ 4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01],
        [-2.4316e-02, -9.8074e-02, -5.3355e-01, -2.8268e-01, -7.9070e-01],
        [ 4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01],
        [ 4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01,  4.4721e-01]])
def plot_sis_weights(weights, n_steps, spacing=1.5, max_size=0.3):
    """
    Plot the evolution of weights in the sequential importance sampling (SIS) algorithm.

    Each time step is drawn as a column of circles, one per particle, whose
    radius is proportional to that particle's share of the step's total weight.

    Parameters
    ----------
    weights: array(n_steps, n_particles)
        Weights at each time step; row ``t`` holds all particle weights at step ``t``.
    n_steps: int
        Maximum number of steps to plot.
    spacing: float
        Horizontal distance between consecutive time steps.
    max_size: float
        Maximum size of the particles.

    Returns
    -------
    fig: matplotlib.figure.Figure
        Figure containing the plot.
    ax: matplotlib.axes.Axes
        Axes the circles and arrows were drawn on.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.set_aspect(1)
    weights_subset = weights[:n_steps]
    # Columns actually drawn; fewer than n_steps when `weights` is shorter.
    n_cols = len(weights_subset)
    for col, weights_row in enumerate(weights_subset):
        radii = weights_row / weights_row.sum() * max_size
        for row, rad in enumerate(radii):
            if col != n_cols - 1:
                # Arrow geometry scales with spacing (0.5 offset / 0.6 length at spacing=2).
                ax.arrow(spacing * (col + 0.25), row, 0.3 * spacing, 0, width=0.05,
                         edgecolor="white", facecolor="tab:gray")
            ax.add_artist(plt.Circle((spacing * col, row), rad, color="tab:red"))

    ax.set_xlim(-1, n_cols * spacing)
    ax.set_xlabel("Iteration (t)")
    ax.set_ylabel("Particle index (i)")

    # One tick per drawn column at its x-position; a fixed tick step of 2
    # only matched the label count when spacing == 2.
    ax.set_xticks([spacing * i for i in range(n_cols)])
    ax.set_xticklabels([str(i) for i in range(1, n_cols + 1)])

    return fig, ax

# Horizontal distance between time-step columns; also reused by the SMC weight plot.
spacing = 2
# Draw the SIS weights for the first 10 steps.
fig, ax = plot_sis_weights(w_sis, n_steps=10, spacing=spacing)
plt.tight_layout()
<Figure size 800x600 with 1 Axes>
def plot_smc_weights(weights, indexes, n_steps, spacing=1.5, max_size=0.3):
    """
    Plot the evolution of weights in the sequential Monte Carlo (SMC) algorithm.

    Each time step is a column of circles (radius proportional to the
    particle's share of the step's total weight); arrows connect every
    ancestor to the particle it produced at the next step.

    Parameters
    ----------
    weights: array(n_steps, n_particles)
        Weights at each time step; row ``t`` holds all particle weights at step ``t``.
    indexes: array(n_steps - 1, n_particles)
        Resampling indices: row ``t`` gives, for each particle at step ``t + 1``,
        the index of its ancestor at step ``t`` (the layout returned by ``sisr``).
    n_steps: int
        Maximum number of steps to plot.
    spacing: float
        Horizontal distance between consecutive time steps.
    max_size: float
        Maximum size of the particles.

    Returns
    -------
    fig: matplotlib.figure.Figure
        Figure containing the plot.
    ax: matplotlib.axes.Axes
        Axes the circles and arrows were drawn on.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.set_aspect(1)

    weights_subset = weights[:n_steps]
    ix_subset = indexes[:n_steps]
    # zip() below stops at the shorter input, so this is the number of drawn columns.
    n_cols = min(len(weights_subset), len(ix_subset))

    for it, (weights_row, p_target) in enumerate(zip(weights_subset, ix_subset)):
        radii = weights_row / weights_row.sum() * max_size
        for particle_ix, (rad, target_ix) in enumerate(zip(radii, p_target)):
            target_ix = int(target_ix)
            if it != n_cols - 1:
                # Arrow from the ancestor at step `it` towards descendant
                # `particle_ix`; geometry scales with spacing (0.3 / 1.3 at spacing=2).
                ax.arrow(spacing * (it + 0.15), target_ix, 0.65 * spacing, particle_ix - target_ix,
                         width=0.05, edgecolor="white", facecolor="tab:gray",
                         length_includes_head=True)
            ax.add_artist(plt.Circle((spacing * it, particle_ix), rad, color="tab:blue"))

    ax.set_xlim(-1, n_cols * spacing)
    ax.set_xlabel("Iteration (t)")
    ax.set_ylabel("Particle index (i)")

    # One tick per drawn column at its x-position; a fixed tick step of 2
    # only matched the label count when spacing == 2.
    ax.set_xticks([spacing * i for i in range(n_cols)])
    ax.set_xticklabels([str(i) for i in range(1, n_cols + 1)])

    # TODO: reuse ax.get_ylim() for an SMC particle-descendants plot.

    return fig, ax

# Draw the SISR weights with ancestor arrows taken from the resampling indices.
fig, ax = plot_smc_weights(w_smc, s_idxs, n_steps=10, spacing=spacing)
# y-limits captured for reuse; not read anywhere in the code shown here.
ylims = ax.axes.get_ylim()
plt.tight_layout()
<Figure size 800x600 with 1 Axes>
class SampleGeneratorSSM:
    """Simulate a state-space model with Gaussian transition and emission noise.

    Dynamics and emission::

        z_{t+1} ~ MultivariateNormal(z_function(z_t), q * I)
        y_{t+1} ~ MultivariateNormal(z_{t+1}, r * I)
    """

    def __init__(self, z_function: Callable, q=0.001, r=0.05):
        """
        Args:
            z_function: deterministic transition mean, mapping z_t to E[z_{t+1}].
            q: isotropic transition-noise variance.
            r: isotropic observation-noise variance.
        """
        self.z_function = z_function
        self.q = q
        self.r = r

    def z_dist(self, zt: torch.Tensor):
        """Transition distribution ``p(z_{t+1} | z_t)``."""
        # Use the transition supplied to the constructor, not a module-level name.
        return torch.distributions.MultivariateNormal(loc=self.z_function(zt),
                                                      covariance_matrix=self.q * torch.eye(zt.shape[-1]))

    def y_dist(self, zt: torch.Tensor):
        """Emission distribution ``p(y_t | z_t)``."""
        return torch.distributions.MultivariateNormal(loc=zt,
                                                      covariance_matrix=self.r * torch.eye(zt.shape[-1]))

    def sample_y(self, zt_next):
        """Draw an observation for the latent state ``zt_next``."""
        return self.y_dist(zt_next).sample()

    def sample_z(self, zt):
        """Draw the next latent state given ``zt``."""
        return self.z_dist(zt).sample()

    def _step(self, zt=None):
        """Advance one step from ``zt`` (default ``[1.5, 0.0]``).

        Returns the new latent state (the scan carry) and ``stack(z_{t+1}, y_{t+1})``.
        """
        if zt is None:
            zt = torch.tensor([1.5, 0.0])
        zt_next = self.sample_z(zt)
        yt_next = self.sample_y(zt_next)
        return zt_next, torch.stack((zt_next, yt_next))

    def samples(self, z0, N_s=100):
        """Roll the model forward from ``z0``.

        Args:
            z0: initial latent state.
            N_s: number of time steps to simulate (the name is kept for
                backward compatibility; it is not a particle count).

        Returns:
            Tensor of shape ``(2, N_s, *z0.shape)``: index 0 holds the latent
            states, index 1 the observations.
        """
        carry = z0
        logs_chain = []
        for _ in range(N_s):
            carry, logs = self._step(carry)
            logs_chain.append(logs)
        return torch.stack(logs_chain, dim=1)
    
def z_function(z, delta=0.4):
    """Nonlinear drift along axis 0: (x, y) -> (x + delta*sin(y), y + delta*cos(x))."""
    x, y = z[0], z[1]
    return torch.stack((x + delta * torch.sin(y), y + delta * torch.cos(x)))
def z_function_sampler(z, delta=0.4):
    """Transition mean used for data generation: (x + delta*sin(y), y + delta*cos(x))."""
    first, second = z[0], z[1]
    next_first = first + delta * torch.sin(second)
    next_second = second + delta * torch.cos(first)
    return torch.stack((next_first, next_second))

def z_function_predictor(z, delta=0.8):
    """Linear coupling along axis 0: (x, y) -> (x + delta*y, y + delta*x)."""
    x, y = z[0], z[1]
    return torch.stack((x + delta * y, y + delta * x))

T = 100
N_s = 5
# Samples generation for N_s particles
# Each iteration draws one independent trajectory of T steps from z0 = [1.5, 0.0];
# generator_model.samples returns (z_t, y_t) stacked on axis 0, each (T, 2).
samples_list = []  # observations y_t, one (T, 2) tensor per trajectory
samples_actual_z_list = []  # latent states z_t, one (T, 2) tensor per trajectory
generator_model = SampleGeneratorSSM(z_function_sampler)
for _ in range(N_s):
    z0 = torch.tensor([1.5, 0.0])
    samples = generator_model.samples(z0, T)
    samples_list.append(samples[1])
    samples_actual_z_list.append(samples[0])

ns_samples = torch.stack(samples_list)  # (N_s, T, 2) observations
ns_zt_actual_samples = torch.stack(samples_actual_z_list)  # (N_s, T, 2) latent states
# ns_samples = ns_samples.transpose(0,1)
print(ns_samples.shape)
# print(ns_zt_actual_samples.shape)
torch.Size([5, 100, 2])
class RandomSSMSampler:
    """Particle filter (SISR with adaptive resampling) for a Gaussian SSM.

    Model assumed by the filter::

        z_{t+1} ~ MultivariateNormal(z_function(z_t), q * I)
        y_t     ~ MultivariateNormal(z_t, r * I)

    Importance weights are kept in log space throughout.

    NOTE(review): ``z_dist`` evaluates the module-level ``z_function`` rather
    than ``self.z_function``, so the transition passed to the constructor is
    currently ignored. Switching it is not a drop-in fix for the demo that
    uses this class: ``z_function_predictor`` (x' = x + 0.8 y, y' = y + 0.8 x)
    has a growth factor of 1.8 per step and is likely to overflow float32
    within T = 100 steps. Confirm the intended transition before changing it.
    """

    def __init__(self, z_function: Callable, N_s=10, q=0.001, r=0.05):
        """
        Args:
            z_function: intended transition mean (see the class NOTE).
            N_s: number of particles.
            q: isotropic transition-noise variance.
            r: isotropic observation-noise variance.
        """
        self.z_function = z_function
        self.q = q
        self.r = r
        self.N_s = N_s

    def z_dist(self, zt: torch.Tensor):
        """Transition distribution ``p(z_{t+1} | z_t)`` (argument validation disabled)."""
        return torch.distributions.MultivariateNormal(loc=z_function(zt),
                                                      covariance_matrix=self.q * torch.eye(zt.shape[-1]), validate_args=False)

    def y_dist(self, zt: torch.Tensor):
        """Emission distribution ``p(y_t | z_t)`` (argument validation disabled)."""
        return torch.distributions.MultivariateNormal(loc=zt,
                                                      covariance_matrix=self.r * torch.eye(zt.shape[-1]), validate_args=False)

    def z_step(self, zt: torch.Tensor):
        """Propagate each particle (row of ``zt``) through the transition."""
        return torch.stack([self.z_dist(z).sample() for z in zt])

    def alpha_step(self, zt: torch.Tensor, yt: torch.Tensor):
        """Vectorised per-particle emission log-likelihood ``log p(yt[i] | zt[i])``."""
        def alpha_per_sample(zt_per_sample: torch.Tensor, yt_per_sample: torch.Tensor):
            return self.y_dist(zt_per_sample).log_prob(yt_per_sample)
        return torch.vmap(alpha_per_sample)(zt, yt)

    def _log_likelihoods(self, z_i_t, y_t):
        """Per-particle ``log p(y_t[i] | z_i_t[i])``, pairing particles and observations by position."""
        return torch.stack([self.y_dist(zt).log_prob(yt) for zt, yt in zip(z_i_t, y_t)])

    def inits(self, sample_shape):
        """Return zero means, particles at ``[1.5, 0.0]`` and zero log-weights.

        ``sample_shape`` is presumably ``(n_particles, dim)``; the particle
        count for ``z_i_0`` is taken from ``sample_shape[0]``.
        """
        mu_0 = torch.zeros(sample_shape)
        z_i_0 = torch.tensor([1.5, 0.0]).repeat((sample_shape[0], 1))
        w_i_0 = torch.zeros(sample_shape[0])
        return mu_0, z_i_0, w_i_0

    def resample_permute(self, w_i_t, z_i_t):
        """Multinomially resample particles from log-weights ``w_i_t``.

        Returns:
            ``(indexs, z_i_t, w_i_t)``: ``indexs`` is a one-element list
            holding the ancestor indices (list form kept for backward
            compatibility), ``z_i_t`` the resampled particles and ``w_i_t``
            uniform log-weights ``log(1 / N_s)``.
        """
        ancestors = torch.distributions.Categorical(logits=w_i_t).sample((self.N_s,))
        resampled = z_i_t[ancestors]
        log_uniform = torch.full((self.N_s,), 1.0 / self.N_s).log()
        return [ancestors], resampled, log_uniform

    def sisr_with_adaptive_resampling(self, y_ts, T, min_weight):
        """Run SISR over ``T`` steps, resampling when the ESS drops below ``min_weight``.

        Args:
            y_ts: observations indexed by time on the first axis; ``y_ts[t]``
                holds one observation per particle (paired by position).
            T: number of time steps.
            min_weight: ESS threshold that triggers resampling.

        Returns:
            ``(posteriors, states, Z_ts, selected_indexs)``: per-step weighted
            estimates of ``|z_1 - z|``, one state drawn per step from the
            weighted particles, running normalising-constant estimates (one
            column per latent dimension) and a list of ``(t, indexes)``
            resampling events.
        """
        z_1 = torch.tensor([1.5, 0.0])
        z_i_t = self.z_step(z_1.repeat((self.N_s, 1)))
        w_i_t = self._log_likelihoods(z_i_t, y_ts[0])
        W_i_t = torch.exp(w_i_t - torch.logsumexp(w_i_t, dim=0))
        mc_post_t = torch.matmul(W_i_t, torch.abs(z_1 - z_i_t))
        normalized_w, posteriors, particle_zs = [W_i_t], [mc_post_t], [z_i_t]
        Z_ts, selected_indexs = [], []
        # One entry per latent dimension so Z_ts keeps its (T - 1, dim) layout.
        # NOTE: accumulated in linear space; may overflow on long runs.
        Z_t = torch.ones(z_1.shape[-1])

        for t in range(1, T):
            # Log-weights entering this step (after any previous resampling reset).
            prev_w_i_t = w_i_t
            z_i_t = self.z_step(z_i_t)
            # Incremental log-weight: likelihood of the observation at step t.
            w_i_t = w_i_t + self._log_likelihoods(z_i_t, y_ts[t])
            # Z_t / Z_{t-1} is estimated before resampling overwrites the weights.
            Z_t = self.calc_Z(w_i_t, prev_w_i_t, Z_t)
            Z_ts.append(Z_t)
            if self.ESS(w_i_t, ess_min=min_weight):
                indexes, z_i_t, w_i_t = self.resample_permute(w_i_t, z_i_t)
                selected_indexs.append((t, indexes))
            W_i_t = torch.exp(w_i_t - torch.logsumexp(w_i_t, dim=0))
            mc_post_t = torch.matmul(W_i_t, torch.abs(z_1 - z_i_t))
            normalized_w.append(W_i_t)
            posteriors.append(mc_post_t)
            particle_zs.append(z_i_t)

        states = self.extract_states(torch.stack(normalized_w), torch.stack(particle_zs))
        return torch.stack(posteriors), states, torch.stack(Z_ts), selected_indexs

    def extract_states(self, normalized_w, particle_zs):
        """Draw one particle per time step according to that step's normalised weights."""
        states = []
        for t in range(normalized_w.shape[0]):
            index = torch.distributions.Categorical(probs=normalized_w[t]).sample()
            states.append(particle_zs[t][index])
        return torch.stack(states)

    def ESS(self, w_i_t, ess_min):
        """Return True when the effective sample size of log-weights ``w_i_t`` is below ``ess_min``.

        ESS = 1 / sum_i W_i^2 with W = softmax(w_i_t), i.e. it is computed on
        the normalised linear-scale weights, not on the log-weights themselves.
        """
        W = torch.softmax(w_i_t, dim=0)
        ess = 1.0 / W.square().sum()
        return bool(ess < ess_min)

    def calc_Z(self, w_i_t, prev_w_i_t, prev_Z_t):
        """Update the running normalising-constant estimate from log-weights.

        ``Z_t / Z_{t-1} = sum exp(w_i_t) / sum exp(prev_w_i_t)``, computed
        stably with ``logsumexp``.
        """
        log_ratio = torch.logsumexp(w_i_t, dim=0) - torch.logsumexp(prev_w_i_t, dim=0)
        return prev_Z_t * torch.exp(log_ratio)

    def _step(self, zt=None):
        """Advance one step from ``zt`` (default ``[1.5, 0.0]``); return carry and ``stack(z, y)``."""
        if zt is None:
            zt = torch.tensor([1.5, 0.0])
        zt_next = self.z_dist(zt).sample()
        yt_next = self.y_dist(zt_next).sample()
        return zt_next, torch.stack((zt_next, yt_next))

    def samples(self, z0, T = 100):
        """Simulate ``T`` steps from ``z0``; returns ``(2, T, *z0.shape)`` with (z_t, y_t) on axis 0."""
        carry = z0
        logs_chain = []
        for _ in range(T):
            carry, logs = self._step(carry)
            logs_chain.append(logs)
        return torch.stack(logs_chain, dim=1)
    
def z_function_sampler(z, delta=0.4):
    """Sin/cos drift along axis 0: (a, b) -> (a + delta*sin(b), b + delta*cos(a))."""
    a, b = z[0], z[1]
    shifted_a = a + delta * torch.sin(b)
    shifted_b = b + delta * torch.cos(a)
    return torch.stack((shifted_a, shifted_b))

def z_function_predictor(z, delta=0.8):
    """Linear coupling along axis 0: (a, b) -> (a + delta*b, b + delta*a)."""
    a, b = z[0], z[1]
    coupled_a = a + delta * b
    coupled_b = b + delta * a
    return torch.stack((coupled_a, coupled_b))

particles = 5
model = RandomSSMSampler(z_function_predictor, N_s=particles)
# model.sisr_with_adaptive_resampling(ns_samples.transpose(0,1), T, min_weight=particles/5)
# ns_samples is (N_s, T, 2); transpose to (T, N_s, 2) so the filter indexes time first.
# T is the value set in the trajectory-generation cell (100); min_weight is the
# threshold passed to model.ESS to decide when to resample.
predicted_posteriors, predicted_zs, predicted_Zts, resample_indexes = model.sisr_with_adaptive_resampling(ns_samples.transpose(0,1), T, min_weight=particles*0.78)
# print(predicted_zs.shape)
# Roll the model forward 10 steps from the last extracted state; index 1 of the
# result holds the simulated observations.
after_learning = model.samples(predicted_zs[-1], 10)
print(after_learning[1].shape)
torch.Size([10, 2])
def plot_inference(N_s, states, emissions, estimates=None, after_learning_eastimates=None, est_type="", ax=None, title="", aspect=0.8, show_states=True):
    """Plot latent trajectories, observations and optional filter estimates.

    Parameters
    ----------
    N_s: int
        Number of trajectories to draw from ``states`` / ``emissions``.
    states, emissions:
        Indexable by trajectory; each item is expected to be a (T, 2) array
        of x/y coordinates (it is unpacked via ``*item.T``). States are drawn
        as lines, emissions as hollow black circles.
    estimates: optional (T, 2) array drawn as a red line.
    after_learning_eastimates: optional (T, 2) array drawn as a blue line
        (parameter name kept as-is for backward compatibility).
    est_type: str
        Prefix used in the estimate legend labels.
    ax: matplotlib Axes or None
        Axes to draw on; a new figure is created when None.
    title: str
        Axes title.
    aspect: float
        Currently unused; the axes use ``axis('equal')``.
    show_states: bool
        Whether to draw the state trajectories.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    # Create the axes up front so estimates can still be drawn when N_s == 0.
    if ax is None:
        _, ax = plt.subplots()
    lines = []
    for i in range(N_s):
        s, e = states[i], emissions[i]
        if show_states:
            line, = ax.plot(*s.T, label=f'{i} particle')
            lines.append(line)
        ax.plot(*e.T, "ok", fillstyle="none", ms=4, label="Observations")

    if estimates is not None:
        line, = ax.plot(*estimates.T, color="r", linewidth=2, label=f"{est_type} Estimate")
        lines.append(line)
    if after_learning_eastimates is not None:
        line, = ax.plot(*after_learning_eastimates.T, color="b", linewidth=2, label=f"{est_type} after learning")
        lines.append(line)
    # Only the collected lines go in the legend (not one entry per observation set).
    ax.legend(borderpad=0.5, handlelength=4, fancybox=False, edgecolor="k", handles=lines)
    ax.set_title(title)
    ax.axis('equal')
    return ax

# plot_inference takes (states, emissions): the latent z_t trajectories are the
# states (drawn as lines) and the y_t observations are the emissions (circles).
plot_inference(N_s, ns_zt_actual_samples, ns_samples, predicted_zs, after_learning[1], title='custom linear gaussian state space model', show_states=True)
<Axes: title={'center': 'custom linear gaussian state space model'}>
<Figure size 640x480 with 1 Axes>
2024 Debashis Blogs...
Contact
LinkedIn
Privacy