Created
July 16, 2025 09:36
-
-
Save scheidan/cafaa89250aaef712d7be43bfd2ba6cc to your computer and use it in GitHub Desktop.
Toy implementation of the Particle Gibbs with Ancestral Sampling for state space models
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## -------------------------------------------------------
## Simple implementation of the
## Particle Gibbs with Ancestral Sampling
## proposed in Lindsten et al. 2014
##
## THIS CODE WAS NOT TESTED OR OPTIMIZED!!! Do not use it for anything
## important without checking.
##
## July 16, 2025 -- Andreas Scheidegger
## andreas.scheidegger@eawag.ch
## -------------------------------------------------------
| import Pkg | |
| Pkg.activate(".") | |
| using Distributions | |
| using Plots | |
| ## ----------- | |
| ## 1. Particle Gibbs | |
"""
    particle_gibbs(y, N, M, sample_x1, sample_move, log_p_obs;
                   x_ref = [sample_x1() for _ in 1:length(y)])

Run a Particle Gibbs with Ancestral Sampling sampler for a state space model

    X_t ~ f(x_t | x_{t−1})
    Y_t ~ p_obs(y_t | x_t)

to sample from the smoothed distribution

    p(x_{1:T} | y_{1:T}).

### Arguments:
* `y` – vector of observations `y_1, ..., y_T`
* `N` – number of trajectories sampled from the PGAS Markov kernel
* `M` – number of Particle Gibbs iterations (size of final sample)
* `sample_x1` – `() -> x` : draw from p(x₁)
* `sample_move` – `(x_prev) -> x_next` : draw from f(x_t | x_{t−1})
* `log_p_obs` – `(y_t, x_t) -> loglik` : log p(y_t | x_t)
* `x_ref` – initial reference trajectory (keyword; defaults to iid draws from p(x₁))

Returns a vector of `M` trajectories, each itself a `Vector`.

### Reference:
Lindsten, F., Jordan, M.I., Schön, T.B., 2014. Particle gibbs with
ancestor sampling. J. Mach. Learn. Res. 15, 2145–2184.
"""
function particle_gibbs(y, N, M, sample_x1, sample_move, log_p_obs;
                        x_ref = [sample_x1() for _ in 1:length(y)])
    samples = Vector{Vector{eltype(x_ref)}}(undef, M)
    for m in 1:M
        ## -- conditional SMC step, create N trajectories
        ## (alg 3, line 1:11)
        xs, anc, logw_last = csmc(y, N, x_ref,
                                  sample_x1, sample_move, log_p_obs)
        ## -- backward trace to construct one trajectory
        ## (alg 3, line 12:13)
        x_star = traceback(xs, anc, logw_last)
        samples[m] = x_star
        x_ref = x_star   # refresh reference trajectory for next iteration
    end
    return samples
end
## -----------
## 2. PGAS Markov kernel (alg 3, line 1:11)
##
## Runs one conditional SMC sweep with `N` particles in which particle 1
## is pinned to the reference trajectory `x_ref`.
##
## If a log transition density `log_move_pdf` `(x_prev, x_next) -> logpdf`
## is supplied, the ancestor index of the reference particle is drawn by
## ancestor sampling (alg 3, line 6), i.e. with probability
## ∝ w_{t-1}^j * f(x_ref[t] | x_{t-1}^j). Without it (the default, as
## only `sample_move` is available) the ancestor stays fixed at index 1,
## which reduces the kernel to plain conditional SMC without ancestor
## sampling.
##
## Note, we use the bootstrap proposal r(x_t | x_{t-1}, y_t) := f(x_t | x_{t-1})
function csmc(y, N, x_ref, sample_x1, sample_move, log_p_obs;
              log_move_pdf = nothing)
    T = length(y)
    xs = [Vector{eltype(x_ref)}(undef, N) for _ in 1:T]
    anc = [Vector{Int}(undef, N) for _ in 1:T-1]  # ancestor indices
    logw = Vector{Float64}(undef, N)
    weights = Vector{Float64}(undef, N)
    ## -- time 1, initialise particles
    for i in 1:N
        ## particle 1 carries the reference trajectory
        xs[1][i] = i == 1 ? x_ref[1] : sample_x1()  # p(x_1)
        logw[i] = log_p_obs(y[1], xs[1][i])
    end
    normalise!(weights, logw)
    ## -- times 2...T
    for t in 2:T
        ## resample ancestors; index 1 is provisionally kept for the reference
        anc[t-1] .= resample(weights; keep_index = 1)
        if !isnothing(log_move_pdf)
            ## ancestor sampling for the reference particle (alg 3, line 6):
            ## Prob(j) ∝ w_{t-1}^j * f(x_ref[t] | x_{t-1}^j)
            logw_as = [log(weights[j]) + log_move_pdf(xs[t-1][j], x_ref[t])
                       for j in 1:N]
            anc[t-1][1] = categorical_exp(logw_as)
        end
        for i in 1:N
            if i == 1
                xs[t][i] = x_ref[t]
            else
                parent_idx = anc[t-1][i]
                xs[t][i] = sample_move(xs[t-1][parent_idx])  # p(x_t | x_{t-1})
            end
            logw[i] = log_p_obs(y[t], xs[t][i])              # p(y_t | x_t)
        end
        normalise!(weights, logw)
    end
    return xs, anc, logw
end
## -----------
## 3. Back-trace a single trajectory
##
## Draw a final-time particle index ∝ exp(logw_last), then follow the
## ancestor indices `anc` backwards to assemble one complete trajectory.
function traceback(xs, anc, logw_last)
    T = length(xs)
    path = Vector{eltype(xs[1])}(undef, T)
    idx = categorical_exp(logw_last)  # sample final-time index
    path[T] = xs[T][idx]
    for t in (T-1):-1:1
        idx = anc[t][idx]             # step to the particle's ancestor
        path[t] = xs[t][idx]
    end
    return path
end
## -----------
## 4. Small helpers

## -- convert (unnormalised) log-weights to probabilities in-place
##
## Subtracts the maximum log-weight before exponentiating (numerical
## stability), then rescales so the entries of `w` sum to one.
## Writes into `w`; returns `nothing`.
function normalise!(w, logw)
    shift = maximum(logw)
    for i in eachindex(logw)
        w[i] = exp(logw[i] - shift)
    end
    w ./= sum(w)
    return nothing
end
## -- simple multinomial resampling, first index kept unchanged
##
## Returns `N` ancestor indices drawn with Prob(i) ∝ `w[i]` (multinomial),
## except that the first entry is pinned to `keep_index` (the reference
## particle in conditional SMC). Assumes `w` is normalised (or at least
## non-negative with positive sum).
##
## Bug fix: the uniform draw is scaled by `cdf[end]` and the result is
## clamped to `N`, so floating-point round-off in `cumsum` (total slightly
## below 1) can never make `searchsortedfirst` return the out-of-range
## index `N + 1`.
function resample(w; keep_index = 1)
    N = length(w)
    idx = Vector{Int}(undef, N)
    idx[1] = keep_index
    cdf = cumsum(w)
    for i in 2:N
        u = rand() * cdf[end]   # scale by total mass: robust to round-off
        idx[i] = min(searchsortedfirst(cdf, u), N)
    end
    return idx
end
## -- draw Prob(index) ∝ exp(logw)
##
## Max-shifted before exponentiating for numerical stability; the uniform
## draw is scaled by the total unnormalised mass, so no explicit
## normalisation is needed.
function categorical_exp(logw)
    shift = maximum(logw)
    cdf = cumsum(exp.(logw .- shift))
    return searchsortedfirst(cdf, rand() * cdf[end])
end
# -------------------------------------------------
# Test with 1d SSM

# simulate data: smooth sine signal observed with Gaussian noise
T = 100
x = [sin(t/20)*12 for t in 1:T]
y = x .+ rand(Normal(0, 0.5), T)

# define SSM
sample_x1() = rand(Normal(0, 2))
sample_move(xprev) = xprev + rand(Normal(0, 0.7))    # f(x_t | x_{t-1})
log_p_obs(y_t, x_t) = logpdf(Normal(x_t, 0.5), y_t)  # log p(y_t | x_t)

# run sampler
samples = particle_gibbs(y, 500, 200, sample_x1, sample_move, log_p_obs);
samples = stack(samples);   # T × M matrix, one trajectory per column

# plot
# (bug fix: `samples` is already a stacked matrix here; calling
#  `stack` on it again fails / is redundant)
plot(samples, label="", alpha = 0.1, color=:grey, xlabel="time", ylabel="x");
scatter!(y, label="y");
plot!(x, linewidth=2, color=:red, label="true trajectory")
# -------------------------------------------------
# Test with 2d SSM

# simulate data: 2d trajectory observed with iid Gaussian noise per coordinate
T = 100
x = [(sin(t/20)*12, t/4) for t in 1:T]
y = [(xi[1] + rand(Normal(0, 0.5)), xi[2] + rand(Normal(0, 0.5))) for xi in x]

# gap in the data (missing observations encoded as NaN)
for t in 50:75
    y[t] = (NaN, NaN)
end

# define SSM
sample_x1() = (rand(Normal(0, 2)), rand(Normal(0, 2)))
sample_move(xprev) = xprev .+ (rand(Normal(0, 0.7)), rand(Normal(0, 0.7)))
function log_p_obs(y_t, x_t)   # log p(y_t | x_t)
    if isnan(y_t[1])
        zero(y_t[1])   # missing observation: log-likelihood contribution 0
    else
        sum(logpdf.(Normal.(x_t, 0.5), y_t))
    end
end

# run sampler
# very simple reference trajectory. Works usually.
x_ref = [(-0.2*t+10, 0.1*randn()) for t in 1:T]
# A "better" reference trajectory. However, the sampler mostly fails with this one...
# x_ref = [(10 - 0.2*t + 0.1*randn(), 5 + 0.2*t) for t in 1:T]
samples = particle_gibbs(y, 200, 500, sample_x1, sample_move, log_p_obs; x_ref = x_ref);
samples = stack(samples);   # T × M matrix of 2-tuples

# plot
# (fix: stack once into a 2 × T × M array instead of re-stacking
#  the same matrix in every plotting call)
S = stack(samples)
plot(S[1,:,:], S[2,:,:],
     label="", alpha = 0.1, color=:grey,
     xlabel="x₁", ylabel="x₂");
scatter!(stack(y)[1,:], stack(y)[2,:], color=:green, label="y obs");
plot!(stack(x)[1,:], stack(x)[2,:], linewidth=2, color=:red, label="true trajectory");
plot!(stack(x_ref)[1,:], stack(x_ref)[2,:], color = :lightblue, label="x_ref")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment