Skip to content

Instantly share code, notes, and snippets.

@scheidan
Created July 16, 2025 09:36
Show Gist options
  • Select an option

  • Save scheidan/cafaa89250aaef712d7be43bfd2ba6cc to your computer and use it in GitHub Desktop.

Select an option

Save scheidan/cafaa89250aaef712d7be43bfd2ba6cc to your computer and use it in GitHub Desktop.
Toy implementation of the Particle Gibbs with Ancestral Sampling for state space models
## -------------------------------------------------------
## Simple implementation of the
## Particle Gibbs with Ancestral Sampling
## proposed in Lindsten et al. 2014
##
## THIS CODE WAS NOT TESTED OR OPTIMIZED!!! Do not use it for anything
## important without checking.
##
## July 16, 2025 -- Andreas Scheidegger
## andreas.scheidegger@eawag.ch
## -------------------------------------------------------
import Pkg
Pkg.activate(".")
using Distributions
using Plots
## -----------
## 1. Particle Gibbs
"""
    particle_gibbs(y, N, M, sample_x1, sample_move, log_p_obs; x_ref)

Run a Particle Gibbs with Ancestral Sampling sampler for a state space model

```
X_t ~ f(x_t | x_{t−1})
Y_t ~ p(y_t | x_t)
```

to sample from the smoothed distribution

```
p(x_{1:T} | y_{1:T}).
```

# Arguments
* `y` – vector of observations `y_1, ..., y_T`
* `N` – number of trajectories sampled from the PGAS Markov kernel
* `M` – number of Particle Gibbs iterations (size of final samples)
* `sample_x1` – `() -> x` : draw from p(x₁)
* `sample_move` – `(x_prev) -> x_next` : draw from f(x_t | x_{t−1})
* `log_p_obs` – `(y_t, x_t) -> loglik` : log p(y_t | x_t)

# Keywords
* `x_ref` – initial reference trajectory (default: T independent draws from p(x₁))

Returns a vector of `M` trajectories, each itself a `Vector`.

# Reference
Lindsten, F., Jordan, M.I., Schön, T.B., 2014. Particle Gibbs with
ancestor sampling. J. Mach. Learn. Res. 15, 2145–2184.
"""
function particle_gibbs(y, N, M, sample_x1, sample_move, log_p_obs;
                        x_ref = [sample_x1() for _ in 1:length(y)])
    T = length(y)
    samples = Vector{Vector{eltype(x_ref)}}(undef, M)
    for m in 1:M
        ## -- conditional SMC step, create N trajectories with particle 1
        ##    pinned to the reference trajectory (alg. 3, lines 1:11)
        xs, anc, logw_last = csmc(y, N, x_ref,
                                  sample_x1, sample_move, log_p_obs)
        ## -- backward trace to construct one sampled trajectory
        ##    (alg. 3, lines 12:13)
        x_star = traceback(xs, anc, logw_last)
        samples[m] = x_star
        x_ref = x_star # the drawn trajectory conditions the next iteration
    end
    return samples
end
## -----------
## 2. PGAS Markov kernel (alg 3, line 1:11)
## Note, we use r(x_t | x_{t-1}, y_t) := f(x_t | x_{t-1})
"""
    csmc(y, N, x_ref, sample_x1, sample_move, log_p_obs)

Conditional SMC sweep (the PGAS Markov kernel, alg. 3, lines 1:11).
Particle 1 is pinned to the reference trajectory `x_ref` at every time
step; the remaining N-1 particles are propagated by drawing from the
transition (i.e. the proposal is r(x_t | x_{t-1}, y_t) := f(x_t | x_{t-1}))
after multinomial resampling of their ancestors.

Returns `(xs, anc, logw)`:
* `xs`   – `xs[t][i]` is particle `i` at time `t`
* `anc`  – `anc[t-1][i]` is the ancestor index of particle `i` at time `t`
* `logw` – unnormalised log-weights at the final time `T`

NOTE(review): the ancestor of the reference particle is fixed to 1
(`keep_index = 1` below) rather than resampled with transition-density
weights as in full ancestor sampling; the interface provides no
transition log-density, so only plain conditional SMC is possible as
written — confirm whether this is intended.
"""
function csmc(y, N, x_ref, sample_x1, sample_move, log_p_obs)
T = length(y)
xs = [Vector{eltype(x_ref)}(undef, N) for _ in 1:T]
anc = [Vector{Int}(undef, N) for _ in 1:T-1] # ancestor index
logw = Vector{Float64}(undef, N)
weights = Vector{Float64}(undef, N)
## -- time 1, initialise particles
for i in 1:N
if i == 1
xs[1][i] = x_ref[1] # particle 1 carries the reference trajectory
else
xs[1][i] = sample_x1() # p(x_1)
end
logw[i] = log_p_obs(y[1], xs[1][i])
end
normalise!(weights, logw)
## -- times 2...T
for t in 2:T
## sample ancestors; the reference particle's ancestor is pinned: anc[t-1][1] = 1
anc[t-1] .= resample(weights; keep_index = 1)
for i in 1:N
if i == 1
xs[t][i] = x_ref[t] # keep following the reference trajectory
else
parent_idx = anc[t-1][i]
xs[t][i] = sample_move(xs[t-1][parent_idx]) # p(x_t | x_{t-1})
end
logw[i] = log_p_obs(y[t], xs[t][i]) # p(y_t | x_t)
end
normalise!(weights, logw)
end
return xs, anc, logw
end
## -----------
## 3. Back-trace a single trajectory
"""
    traceback(xs, anc, logw_last)

Sample one complete trajectory from the particle system: draw the
final-time particle index with probability proportional to
`exp.(logw_last)`, then follow the ancestor indices backwards to time 1.
"""
function traceback(xs, anc, logw_last)
    nsteps = length(xs)
    trajectory = Vector{eltype(xs[1])}(undef, nsteps)
    k = categorical_exp(logw_last)   # particle index at the final time
    for t in nsteps:-1:2
        trajectory[t] = xs[t][k]
        k = anc[t-1][k]              # move to the parent particle
    end
    trajectory[1] = xs[1][k]
    return trajectory
end
## -----------
## 4. Small helpers
## -- convert (unnormalised) log-weights to probabilities in-place
"""
    normalise!(w, logw)

Overwrite `w` with the probabilities proportional to `exp.(logw)`.
Shifts by `maximum(logw)` before exponentiating for numerical stability.
Returns `nothing`.
"""
function normalise!(w, logw)
    shift = maximum(logw)
    @. w = exp(logw - shift)
    total = sum(w)
    w ./= total
    return nothing
end
## -- simple multinomial resampling, first index kept unchanged
"""
    resample(w; keep_index = 1)

Simple multinomial resampling of `length(w)` ancestor indices, with the
first index pinned to `keep_index` (the conditioned/reference particle).

`w` may be unnormalised: each draw uses `u = rand() * cdf[end]`. This
also guards against the out-of-range index `N + 1` that
`searchsortedfirst` would return when floating-point rounding makes the
cumulative sum of normalised weights fall slightly below 1.
"""
function resample(w; keep_index = 1)
    N = length(w)
    idx = Vector{Int}(undef, N)
    idx[1] = keep_index
    cdf = cumsum(w)
    total = cdf[end]
    for i in 2:N
        u = rand() * total
        ## invert the empirical cdf; clamp for safety at the upper edge
        idx[i] = min(searchsortedfirst(cdf, u), N)
    end
    return idx
end
## -- draw Prob(index) ∝ exp(logw)
"""
    categorical_exp(logw)

Draw a random index with probability proportional to `exp.(logw)`.
The log-weights are shifted by their maximum for numerical stability;
they need not be normalised.
"""
function categorical_exp(logw)
    probs = exp.(logw .- maximum(logw))
    cum = cumsum(probs)
    threshold = rand() * last(cum)
    return searchsortedfirst(cum, threshold)
end
# -------------------------------------------------
# Test with 1d SSM
# simulate data
T = 100
x = [sin(t/20)*12 for t in 1:T]            # smooth "true" trajectory
y = x .+ rand(Normal(0, 0.5), T)           # noisy observations
# define SSM
sample_x1() = rand(Normal(0, 2))                      # p(x_1)
sample_move(xprev) = xprev + rand(Normal(0, 0.7))     # f(x_t | x_{t-1})
log_p_obs(y_t, x_t) = logpdf(Normal(x_t, 0.5), y_t)   # log p(y_t | x_t)
# run sampler
samples = particle_gibbs(y, 500, 200, sample_x1, sample_move, log_p_obs);
samples = stack(samples);   # T × M matrix, one column per trajectory
# plot (fix: `samples` is already stacked — stacking a Matrix of scalars again fails)
plot(samples, label="", alpha = 0.1, color=:grey, xlabel="time", ylabel="x");
scatter!(y, label="y");
plot!(x, linewidth=2, color=:red, label="true trajectory")
# -------------------------------------------------
# Test with 2d SSM (states and observations are 2-tuples)
# NOTE(review): sample_x1, sample_move and log_p_obs below overwrite the
# methods defined for the 1d example above.
# simulate data
T = 100
x = [(sin(t/20)*12, t/4) for t in 1:T]
y = [(xi[1] + rand(Normal(0, 0.5)), xi[2] + rand(Normal(0, 0.5))) for xi in x]
# gap in the data: NaN marks a missing observation (handled in log_p_obs)
for t in 50:75
y[t] = (NaN, NaN)
end
# define SSM
sample_x1() = (rand(Normal(0, 2)), rand(Normal(0, 2)))
sample_move(xprev) = xprev .+ (rand(Normal(0, 0.7)), rand(Normal(0, 0.7)))
function log_p_obs(y_t, x_t) # log p(y_t | x_t)
if isnan(y_t[1]) # missing observation contributes log-likelihood 0
zero(y_t[1])
else
sum(logpdf.(Normal.(x_t, 0.5), y_t)) # independent Gaussian noise per component
end
end
# run sampler
# very simple reference trajectory. Works usually.
x_ref = [(-0.2*t+10, 0.1*randn()) for t in 1:T]
# A "better" reference trajectory. However, the sampler mostly fails with this one...
# x_ref = [(10 - 0.2*t + 0.1*randn(), 5 + 0.2*t) for t in 1:T]
samples = particle_gibbs(y, 200, 500, sample_x1, sample_move, log_p_obs; x_ref = x_ref);
samples = stack(samples); # T × M matrix of 2-tuples
# plot
# NOTE(review): `samples` was already stacked above; stacking the tuple
# matrix again presumably yields a 2 × T × M array — confirm.
plot(stack(samples)[1,:,:], stack(samples)[2,:,:],
label="", alpha = 0.1, color=:grey,
xlabel="x₁", ylabel="x₂");
scatter!(stack(y)[1,:], stack(y)[2,:], color=:green, label="y obs");
plot!(stack(x)[1,:], stack(x)[2,:], linewidth=2, color=:red, label="true trajectory");
plot!(stack(x_ref)[1,:], stack(x_ref)[2,:], color = :lightblue, label="x_ref")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment