Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Created May 7, 2025 11:27
Show Gist options
  • Select an option

  • Save llandsmeer/e4dab6b39ba43711bc7a544436e1d698 to your computer and use it in GitHub Desktop.

Select an option

Save llandsmeer/e4dab6b39ba43711bc7a544436e1d698 to your computer and use it in GitHub Desktop.

Revisions

  1. llandsmeer created this gist May 7, 2025.
    48 changes: 48 additions & 0 deletions sdf_jax.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,48 @@
    import jax
    import jax.numpy as jnp

    import matplotlib.pyplot as plt

    def normalize(x):
    return x / jnp.linalg.norm(x)

    def vec(x, y, z):
    return jnp.array([x, y, z], dtype='float32')

    screen_x, screen_y = jnp.meshgrid(jnp.linspace(-1, 1), jnp.linspace(-1, 1))
    screen_z = 1 * jnp.ones_like(screen_x)

    light = normalize(vec(0, -1, 0.))
    light_color = vec(1, 1, 1)
    base_color = vec(0.5, 0.5, 0.5)

    def sdf(x):
    return jnp.sqrt(((x - vec(0.2, .1, 1.0))**2).sum()) - 0.5

    def sdf_normal(x):
    return normalize(jax.grad(sdf)(x))

    rays = jnp.vstack([
    screen_x.flatten(),
    screen_y.flatten(),
    screen_z.flatten()
    ]).T

    origins = jnp.zeros_like(rays)
    rays = jax.vmap(normalize)(rays)

    def march(origin, ray):
    def loop(at, _):
    d = sdf(at)
    at = at + d * ray * 0.1
    return at, at
    at, _ = jax.lax.scan(loop, origin, length=100)
    return (sdf(at) < 0.01) * (
    base_color +
    light_color * (sdf_normal(at) @ light)
    )

    img = jax.vmap(march)(origins, rays).reshape(*screen_z.shape, 3)
    plt.imshow(img)
    plt.show()