Sequential Monte Carlo
from typing import Callable, List, NamedTuple
import torch
import matplotlib.pyplot as plt
class Constants(NamedTuple):
phi: torch.Tensor = torch.tensor([0.9])
q: torch.Tensor = torch.tensor([10.0])
beta: torch.Tensor = torch.tensor([0.5])
r: torch.Tensor = torch.tensor([1.0])
init_constants = Constants()
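For reference, reading the class below, the generative model it samples (per particle, with q and r entering as variances, since the code passes their square roots as scales) appears to be

$z_t \sim \mathcal{N}(\phi\, z_{t-1},\ q), \qquad \mu_t = \beta\, \mu_{t-1} + z_t, \qquad y_t \sim \mathcal{N}(\mu_t,\ r)$

with $\phi = 0.9$, $q = 10$, $\beta = 0.5$, $r = 1$ taken from the constants above.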
class GaussianSequenceModel:
def __init__(self, init_constants: Constants, N_s):
self.init_constants = init_constants
self.N_s = N_s
def sample_z_step(self, prev_zt):
phi = self.init_constants.phi
q = self.init_constants.q
zt = torch.distributions.Normal(loc=phi*prev_zt, scale=torch.sqrt(q)).sample()
return zt
def y_step(self, prev_y_mu_t, zt):
beta = self.init_constants.beta
r = self.init_constants.r
mu_t = torch.vmap(lambda mu_i,z_i_t : beta*mu_i+z_i_t)(prev_y_mu_t,zt)
y_t_dist = torch.distributions.Normal(loc=mu_t.ravel(), scale=torch.sqrt(r))
return mu_t, y_t_dist
def sample_y_step(self, prev_y_mu_t, zt):
mu_t, y_t_dist = self.y_step(prev_y_mu_t, zt)
y_t = y_t_dist.sample()
return mu_t.ravel(), y_t.ravel(), y_t_dist
def sample_step(self, prev_zt, prev_y_mu_t):
zt = self.sample_z_step(prev_zt)
mu_t, y_t, _ = self.sample_y_step(prev_y_mu_t,zt)
return (zt, mu_t), torch.cat((zt.unsqueeze(0), y_t.unsqueeze(0)))
def sample(self, nsteps):
mu_0, z_i_0, _ = self.inits(self.N_s)
        init_values = (z_i_0, mu_0)  # ordered as (prev_zt, prev_y_mu_t), matching sample_step's signature
def scan(func, init_values, length):
carry = init_values
logs_chain = []
for i in range(length):
carry, logs = func(carry[0], carry[1])
logs_chain.append(logs)
logs_output = torch.stack(logs_chain,dim=1)
return carry, logs_output
_, logs_chain = scan(self.sample_step, init_values, nsteps)
return logs_chain
def inits(self, N_s):
mu_0 = torch.zeros(N_s)
z_i_0 = torch.zeros(N_s)
w_i_0 = torch.zeros(N_s)
return mu_0, z_i_0, w_i_0
def sis(self, y_ts, T):
z_1 = torch.tensor([0])
z_i_t = self.sample_z_step(z_1.repeat((self.N_s)))
w_i_t = y_ts[0]/z_i_t
W_i_t = w_i_t/w_i_t.sum()
mc_post_t = (W_i_t*torch.abs(z_1-z_i_t)).sum()
normalized_w, posteriors = [],[]
posteriors.append(mc_post_t)
normalized_w.append(W_i_t)
for t in range(1,T):
z_i_t = self.sample_z_step(z_i_t)
alpha = y_ts[t]/(y_ts[t-1]*z_i_t)
w_i_t = w_i_t*alpha
W_i_t = torch.nn.functional.normalize(w_i_t, dim=-1)
normalized_w.append(W_i_t)
mc_post_t = (W_i_t*torch.abs(z_i_t.mean()-z_i_t)).sum()
posteriors.append(mc_post_t)
return torch.stack(posteriors), torch.stack(normalized_w)
def resample_permute(self, w_i_t, z_i_t):
indexs = torch.distributions.Categorical(logits=w_i_t).sample((self.N_s,))
z_i_t = torch.stack([z_i_t[i] for i in indexs])
w_i_t = 1/self.N_s * torch.ones((self.N_s))
return indexs, z_i_t, w_i_t
def sisr(self, y_ts, T):
z_1 = torch.tensor([0])
z_i_t = self.sample_z_step(z_1.repeat((self.N_s)))
w_i_t = y_ts[0]/z_i_t
W_i_t = w_i_t/w_i_t.sum()
mc_post_t = (W_i_t*torch.abs(z_1-z_i_t)).sum()
normalized_w, selected_indexs, posteriors = [], [], []
normalized_w.append(W_i_t)
posteriors.append(mc_post_t)
for t in range(1,T):
indexes, z_i_t, w_i_t = self.resample_permute(w_i_t, z_i_t)
selected_indexs.append(indexes)
z_i_t = self.sample_z_step(z_i_t)
w_i_t = y_ts[t]/(y_ts[t-1] * z_i_t)
W_i_t = torch.nn.functional.normalize(w_i_t, dim=-1)
mc_post_t = (W_i_t*torch.abs(z_i_t.mean()-z_i_t)).sum()
normalized_w.append(W_i_t)
posteriors.append(mc_post_t)
return torch.stack(posteriors), torch.stack(normalized_w), torch.stack(selected_indexs)
def sisr_with_adaptive_resampling(self, y_ts, T, min_weight):
z_1 = torch.tensor([0])
z_i_t = self.sample_z_step(z_1.repeat((self.N_s)))
w_i_t = y_ts[0]/z_i_t
W_i_t = w_i_t/w_i_t.sum()
mc_post_t = (W_i_t*torch.abs(z_1-z_i_t)).sum()
normalized_w, Z_ts, posteriors, selected_indexs = [], [], [], []
prev_w_i_t = torch.ones(self.N_s)
Z_t = 1
normalized_w.append(W_i_t)
posteriors.append(mc_post_t)
for t in range(1, T):
z_i_t = self.sample_z_step(z_i_t)
alpha = y_ts[t]/(y_ts[t-1]*z_i_t)
w_i_t = w_i_t*alpha
            # running estimate of the normalization constant Z_t
            Z_t = self.calc_Z(w_i_t, prev_w_i_t, Z_t)
            Z_ts.append(Z_t)
            prev_w_i_t = w_i_t
if self.ESS(w_i_t, ess_min=min_weight):
indexes, z_i_t, w_i_t = self.resample_permute(w_i_t, z_i_t)
selected_indexs.append((t, indexes))
W_i_t = torch.nn.functional.normalize(w_i_t, dim=-1)
mc_post_t = (W_i_t*torch.abs(z_i_t.mean()-z_i_t)).sum()
normalized_w.append(W_i_t)
posteriors.append(mc_post_t)
return torch.stack(posteriors), torch.stack(normalized_w), torch.stack(Z_ts), selected_indexs
def ESS(self, w_i_t, ess_min):
ess = w_i_t.sum().square() / w_i_t.square().sum()
if ess < ess_min:
return True
else:
return False
def calc_Z(self, w_i_t, prev_w_i_t, prev_Z_t):
# Z_t/Z_{t-1}
Z_t_Z_t_1 = w_i_t.sum() / prev_w_i_t.sum()
# Z_t
return prev_Z_t * Z_t_Z_t_1
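Two of the helper methods above encode standard particle-filter quantities: ESS is the usual effective sample size and calc_Z keeps a running estimate of the normalization constant,

$\mathrm{ESS}_t = \frac{\left(\sum_i w_t^i\right)^2}{\sum_i \left(w_t^i\right)^2}, \qquad Z_t \approx Z_{t-1}\,\frac{\sum_i w_t^i}{\sum_i w_{t-1}^i},$

and the adaptive variant resamples only when $\mathrm{ESS}_t$ drops below the min_weight threshold passed in.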
particles = 5
T = 10
model = GaussianSequenceModel(init_constants, N_s=particles)
samples = model.sample(T)
print('sample (zts, yts) shape', samples.shape)
y_ts = torch.squeeze(samples.split(1)[-1])
print('only yts shape : ',y_ts.shape)
# SIS
p_sis, w_sis = model.sis(y_ts, T)
# SISR
p_smc, w_smc, s_idxs = model.sisr(y_ts, T)
# SISR_Adaptive
p_g_sisr_ad, w_g_sisr_ad, nc_g_sisr_ad , s_idxs_sisr_ad = model.sisr_with_adaptive_resampling(y_ts, T, min_weight=particles/5)
# print(f'Normalized weights for sequential_importance_sampling : \n {w_sis}')
print(f'Normalized weights for sequential_importance_resampling : \n {w_smc}')
print(f'Normalized weights for sequential_importance_resampling_adaptive : \n {w_g_sisr_ad}')
sample (zts, yts) shape torch.Size([2, 10, 5])
only yts shape : torch.Size([10, 5])
Normalized weights for sequential_importance_resampling :
tensor([[-6.6174e-01, 6.6187e-02, -9.7740e-02, -4.4313e-01, 2.1364e+00],
[-1.8866e-01, 8.9358e-01, 3.5519e-01, 7.0895e-02, 1.8638e-01],
[-8.6953e-02, -3.1492e-04, -2.4521e-01, 9.1163e-01, -3.1818e-01],
[-4.1938e-02, 9.9866e-01, -1.2471e-02, -2.3220e-02, -1.5050e-02],
[-1.6795e-02, -5.5831e-02, -1.2844e-03, -3.8192e-02, 9.9757e-01],
[ 4.9757e-01, -4.2751e-01, -7.5116e-01, 2.7587e-02, -6.8301e-02],
[ 4.9504e-02, -9.9616e-01, -5.9280e-02, 3.9084e-02, 1.3173e-02],
[-3.2842e-02, -1.6739e-01, 1.7384e-01, 9.2731e-01, 2.8420e-01],
[-9.8761e-01, 1.2042e-01, 7.7607e-02, 2.1090e-02, -6.0432e-02],
[ 5.8333e-01, 5.7038e-01, -2.4771e-01, -1.4537e-01, -5.0191e-01]])
Normalized weights for sequential_importance_resampling_adaptive :
tensor([[ 1.4396e+00, -1.6778e-02, -5.5203e-02, 1.5816e-01, -5.2582e-01],
[ 2.9889e-01, -3.5871e-03, -7.2293e-03, 2.7421e-02, 9.5386e-01],
[ 8.6965e-01, -4.0461e-06, -1.5039e-03, 5.4041e-03, 4.9363e-01],
[ 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01],
[ 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01],
[ 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01],
[ 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01],
[-2.4316e-02, -9.8074e-02, -5.3355e-01, -2.8268e-01, -7.9070e-01],
[ 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01],
[ 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01, 4.4721e-01]])
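One thing to keep in mind when reading these tensors: torch.nn.functional.normalize divides by the L2 norm rather than the sum, so the "normalized" weights can be negative, and every adaptive step that triggered a resample (which resets all weights to 1/N_s) prints as a row of $1/\sqrt{5} \approx 0.4472$. If sum-to-one weights are wanted instead, a minimal alternative (assuming non-negative raw weights) is a one-liner:

W_i_t = w_i_t / w_i_t.sum()  # probability-style normalization instead of F.normalize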
def plot_sis_weights(weights, n_steps, spacing=1.5, max_size=0.3):
"""
Plot the evolution of weights in the sequential importance sampling (SIS) algorithm.
Parameters
----------
    weights: array(n_steps, n_particles)
Weights at each time step.
n_steps: int
Number of steps to plot.
spacing: float
Spacing between particles.
max_size: float
Maximum size of the particles.
"""
fig, ax = plt.subplots(figsize=(8, 6))
ax.set_aspect(1)
weights_subset = weights[:n_steps]
for col, weights_row in enumerate(weights_subset):
norm_cst = weights_row.sum()
radii = weights_row / norm_cst * max_size
for row, rad in enumerate(radii):
if col != n_steps - 1:
plt.arrow(spacing * (col + 0.25), row, 0.6, 0, width=0.05,
edgecolor="white", facecolor="tab:gray")
circle = plt.Circle((spacing * col, row), rad, color="tab:red")
ax.add_artist(circle)
plt.xlim(-1, n_steps * spacing)
plt.xlabel("Iteration (t)")
plt.ylabel("Particle index (i)")
xticks_pos = torch.arange(0, n_steps * spacing - 1, 2)
xticks_lab = torch.arange(1, n_steps + 1)
plt.xticks(xticks_pos, xticks_lab)
return fig, ax
spacing = 2
fig, ax = plot_sis_weights(w_sis, n_steps=10, spacing=spacing)
plt.tight_layout()
def plot_smc_weights(weights, indexes, n_steps, spacing=1.5, max_size=0.3):
"""
Plot the evolution of weights in the sequential Monte Carlo (SMC) algorithm.
Parameters
----------
    weights: array(n_steps, n_particles)
        Weights at each time step.
    indexes: array(n_steps - 1, n_particles)
        Ancestor indices selected at each resampling step.
n_steps: int
Number of steps to plot.
spacing: float
Spacing between particles.
max_size: float
Maximum size of the particles.
Returns
-------
fig: matplotlib.figure.Figure
Figure containing the plot.
"""
fig, ax = plt.subplots(figsize=(8, 6))
ax.set_aspect(1)
weights_subset = weights[:n_steps]
# sampled indices represent the "position" of weights at the next time step
ix_subset = indexes[:n_steps]
for it, (weights_row, p_target) in enumerate(zip(weights_subset, ix_subset)):
norm_cst = weights_row.sum()
radii = weights_row / norm_cst * max_size
for particle_ix, (rad, target_ix) in enumerate(zip(radii, p_target)):
# print(rad, target_ix)
if it != n_steps - 2:
diff = particle_ix - target_ix
plt.arrow(spacing * (it + 0.15), target_ix, 1.3, diff, width=0.05,
edgecolor="white", facecolor="tab:gray", length_includes_head=True)
circle = plt.Circle((spacing * it, particle_ix), rad, color="tab:blue")
ax.add_artist(circle)
plt.xlim(-1, n_steps * spacing - 2)
plt.xlabel("Iteration (t)")
plt.ylabel("Particle index (i)")
xticks_pos = torch.arange(0, n_steps * spacing - 2, 2)
xticks_lab = torch.arange(1, n_steps)
plt.xticks(xticks_pos, xticks_lab)
# ylims = ax.axes.get_ylim() # to-do: grab this value for SCM-particle descendents' plot
return fig, ax
fig, ax = plot_smc_weights(w_smc, s_idxs, n_steps=10, spacing=spacing)
ylims = ax.axes.get_ylim()
plt.tight_layout()
class SampleGeneratorSSM:
def __init__(self, z_function: Callable, q=0.001, r=0.05):
self.z_function = z_function
self.q = q
self.r = r
def z_dist(self, zt: torch.Tensor):
        return torch.distributions.MultivariateNormal(loc=self.z_function(zt),
covariance_matrix=self.q * torch.eye(zt.shape[-1]))
def y_dist(self, zt: torch.Tensor):
return torch.distributions.MultivariateNormal(loc=zt,
covariance_matrix=self.r * torch.eye(zt.shape[-1]))
def sample_y(self, zt_next):
p_y = self.y_dist(zt_next)
yt_next = p_y.sample()
return yt_next
def sample_z(self, zt):
p_z = self.z_dist(zt)
zt_next = p_z.sample()
return zt_next
def _step(self, zt = torch.tensor([1.5, 0.0])):
zt_next = self.sample_z(zt)
yt_next = self.sample_y(zt_next)
return zt_next, torch.stack((zt_next, yt_next))
def samples(self, z0, N_s = 100):
def scan(func, init_values, length):
carry = init_values
logs_chain = []
for i in range(length):
carry, logs = func(carry)
logs_chain.append(logs)
logs_output = torch.stack(logs_chain,dim=1)
return carry, logs_output
_, logs_chain = scan(self._step, z0, N_s)
# (zt,yt)
return logs_chain
def z_function(z, delta=0.4):
z_x = z[0] + delta * torch.sin(z[1])
z_y = z[1] + delta * torch.cos(z[0])
return torch.cat((torch.unsqueeze(z_x,0), torch.unsqueeze(z_y,0)))
def z_function_sampler(z, delta=0.4):
z_x = z[0] + delta * torch.sin(z[1])
z_y = z[1] + delta * torch.cos(z[0])
return torch.cat((torch.unsqueeze(z_x,0), torch.unsqueeze(z_y,0)))
def z_function_predictor(z, delta=0.8):
z_x = z[0] + delta * z[1]
z_y = z[1] + delta * z[0]
return torch.cat((torch.unsqueeze(z_x,0), torch.unsqueeze(z_y,0)))
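Written out, the transition used to generate data and the different linear dynamics handed to the particle filter further down are

$z_{x,t} = z_{x,t-1} + \delta \sin(z_{y,t-1}), \quad z_{y,t} = z_{y,t-1} + \delta \cos(z_{x,t-1}) \qquad \text{(sampler, } \delta = 0.4)$

$z_{x,t} = z_{x,t-1} + \delta\, z_{y,t-1}, \quad z_{y,t} = z_{y,t-1} + \delta\, z_{x,t-1} \qquad \text{(predictor, } \delta = 0.8)$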
T = 100
N_s = 5
# Samples generation for N_s particles
samples_list = []
samples_actual_z_list = []
generator_model = SampleGeneratorSSM(z_function_sampler)
for _ in range(N_s):
z0 = torch.tensor([1.5, 0.0])
samples = generator_model.samples(z0, T)
samples_list.append(samples[1])
samples_actual_z_list.append(samples[0])
ns_samples = torch.stack(samples_list)
ns_zt_actual_samples = torch.stack(samples_actual_z_list)
# ns_samples = ns_samples.transpose(0,1)
print(ns_samples.shape)
# print(ns_zt_actual_samples.shape)
torch.Size([5, 100, 2])
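Note the layout: the generator stacks trajectories particle-first, so ns_samples is (N_s, T, 2), whereas the filter below walks over time first. That is why it is later called with the axes swapped:

y_ts = ns_samples.transpose(0, 1)  # (T, N_s, 2): one observation per particle at every time step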
class RandomSSMSampler:
def __init__(self, z_function: Callable, N_s=10, q=0.001, r=0.05):
self.z_function = z_function
self.q = q
self.r = r
self.N_s = N_s
def z_dist(self, zt: torch.Tensor):
        return torch.distributions.MultivariateNormal(loc=self.z_function(zt),
covariance_matrix=self.q * torch.eye(zt.shape[-1]), validate_args=False)
def y_dist(self, zt: torch.Tensor):
return torch.distributions.MultivariateNormal(loc=zt,
covariance_matrix=self.r * torch.eye(zt.shape[-1]), validate_args=False)
def z_step(self, zt: torch.Tensor):
return torch.stack([self.z_dist(z).sample() for z in zt])
def alpha_step(self, zt: torch.Tensor, yt: torch.Tensor):
def alpha_per_sample(zt_per_sample: torch.Tensor, yt_per_sample: torch.Tensor):
return self.y_dist(zt_per_sample).log_prob(yt_per_sample)
return torch.vmap(alpha_per_sample)(zt, yt)
def inits(self, sample_shape):
mu_0 = torch.zeros(sample_shape)
        z_i_0 = torch.tensor([1.5, 0.0]).repeat((self.N_s, 1))
w_i_0 = torch.zeros(sample_shape[0])
return mu_0, z_i_0, w_i_0
def resample_permute(self, w_i_t, z_i_t):
def resample():
indexs = [torch.distributions.Categorical(logits=w_i_t).sample((self.N_s,))]
return indexs, torch.cat([z_i_t[index] for index in indexs])
indexs, z_i_t = resample()
w_i_t = torch.log(torch.tensor(1/self.N_s)) * torch.ones((self.N_s))
return indexs, z_i_t, w_i_t
def sisr_with_adaptive_resampling(self, y_ts, T, min_weight):
z_1 = torch.tensor([1.5, 0.0])
z_i_t = self.z_step(z_1.repeat((self.N_s,1)))
w_i_t = torch.stack([self.y_dist(zt).log_prob(yt) for zt,yt in zip(z_i_t, y_ts[0])])
W_i_t = torch.exp(w_i_t - torch.logsumexp(w_i_t, dim=0))
mc_post_t = torch.matmul(W_i_t,torch.abs(z_1-z_i_t))
normalized_w, Z_ts, posteriors, selected_indexs, particle_zs = [], [], [], [], []
prev_w_i_t = torch.ones(self.N_s)
Z_t = torch.tensor([1,1])
normalized_w.append(W_i_t)
posteriors.append(mc_post_t)
particle_zs.append(z_i_t)
for t in range(1, T):
z_i_t = self.z_step(z_i_t)
            alpha = torch.stack([self.y_dist(zt).log_prob(yt) for zt, yt in zip(z_i_t, y_ts[t])])  # weight update against the current observation
w_i_t = w_i_t + alpha
if self.ESS(w_i_t, ess_min=min_weight):
indexes, z_i_t, w_i_t = self.resample_permute(w_i_t, z_i_t)
selected_indexs.append((t, indexes))
# print('resample')
W_i_t = torch.exp(w_i_t - torch.logsumexp(w_i_t, dim=0))
mc_post_t = torch.matmul(W_i_t,torch.abs(z_1-z_i_t))
normalized_w.append(W_i_t)
posteriors.append(mc_post_t)
# normalization const
Z_t = self.calc_Z(w_i_t, prev_w_i_t, Z_t)
Z_ts.append(Z_t)
prev_w_i_t = w_i_t
particle_zs.append(z_i_t)
# print(W_i_t, w_i_t)
states = self.extract_states(torch.stack(normalized_w), torch.stack(particle_zs))
return torch.stack(posteriors), states, torch.stack(Z_ts), selected_indexs
def extract_states(self, normalized_w, particle_zs):
states = []
for t in range(normalized_w.shape[0]):
index = torch.distributions.Categorical(probs=normalized_w[t]).sample()
states.append(particle_zs[t][index])
return torch.stack(states)
def ESS(self, w_i_t, ess_min):
ess = w_i_t.sum().square() / w_i_t.square().sum()
if ess < ess_min:
return True
else:
return False
def calc_Z(self, w_i_t, prev_w_i_t, prev_Z_t):
# Z_t/Z_{t-1}
Z_t_Z_t_1 = torch.sum(w_i_t,dim=0) / torch.sum(prev_w_i_t,dim=0)
# Z_t
return prev_Z_t * Z_t_Z_t_1
def _step(self, zt = torch.tensor([1.5, 0.0])):
zt_next = self.z_dist(zt).sample()
yt_next = self.y_dist(zt_next).sample()
return zt_next, torch.stack((zt_next, yt_next))
def samples(self, z0, T = 100):
def scan(func, init_values, length):
carry = init_values
logs_chain = []
for i in range(length):
carry, logs = func(carry)
logs_chain.append(logs)
logs_output = torch.stack(logs_chain,dim=1)
return carry, logs_output
_, logs_chain = scan(self._step, z0, T)
# (zt,yt)
return logs_chain
def z_function_sampler(z, delta=0.4):
z_x = z[0] + delta * torch.sin(z[1])
z_y = z[1] + delta * torch.cos(z[0])
return torch.cat((torch.unsqueeze(z_x,0), torch.unsqueeze(z_y,0)))
def z_function_predictor(z, delta=0.8):
z_x = z[0] + delta * z[1]
z_y = z[1] + delta * z[0]
return torch.cat((torch.unsqueeze(z_x,0), torch.unsqueeze(z_y,0)))
particles = 5
model = RandomSSMSampler(z_function_predictor, N_s=particles)
# model.sisr_with_adaptive_resampling(ns_samples.transpose(0,1), T, min_weight=particles/5)
predicted_posteriors, predicted_zs, predicted_Zts, resample_indexes = model.sisr_with_adaptive_resampling(ns_samples.transpose(0,1), T, min_weight=particles*0.78)
# print(predicted_zs.shape)
after_learning = model.samples(predicted_zs[-1], 10)
print(after_learning[1].shape)
torch.Size([10, 2])
def plot_inference(N_s, states, emissions, estimates=None, after_learning_estimates=None, est_type="", ax=None, title="", aspect=0.8, show_states=True):
    lines = []
    # print(estimates.shape, after_learning_estimates.shape)
for i in range(N_s):
s, e = states[i], emissions[i]
if ax is None:
fig, ax = plt.subplots()
if show_states:
line, = ax.plot(*s.T, label=f'{i} particle')
lines.append(line)
ax.plot(*e.T, "ok", fillstyle="none", ms=4, label="Observations")
if estimates is not None:
line, = ax.plot(*estimates.T, color="r", linewidth=2, label=f"{est_type} Estimate")
lines.append(line)
    if after_learning_estimates is not None:
        line, = ax.plot(*after_learning_estimates.T, color="b", linewidth=2, label=f"{est_type} after learning")
lines.append(line)
ax.legend(borderpad=0.5, handlelength=4, fancybox=False, edgecolor="k", handles=lines)
#ax.set_aspect(aspect)
ax.set_title(title)
ax.axis('equal')
return ax
plot_inference(N_s, ns_samples, ns_zt_actual_samples, predicted_zs, after_learning[1], title='custom linear gaussian state space model', show_states=True)
<Axes: title={'center': 'custom linear gaussian state space model'}>
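As a rough sanity check (not part of the original run, and only indicative since every particle followed its own latent path), the filtered states can be compared against one of the true trajectories:

# hypothetical diagnostic: RMSE between the filtered states and the first true latent trajectory
rmse = torch.sqrt(((predicted_zs - ns_zt_actual_samples[0]) ** 2).mean())
print('filter RMSE vs. trajectory 0:', rmse.item())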