Skip to content

Instantly share code, notes, and snippets.

@phuang1024
Created April 25, 2025 19:48
Show Gist options
  • Select an option

  • Save phuang1024/9a9c99b7a75d708c04d1c0012d77aa55 to your computer and use it in GitHub Desktop.

Select an option

Save phuang1024/9a9c99b7a75d708c04d1c0012d77aa55 to your computer and use it in GitHub Desktop.
Motion blur in 3D rendering with PyTorch autograd
import random
import cv2
import numpy as np
import torch
# Camera position in world space; the camera looks along +X (X is depth).
CAM_POS = torch.tensor([-5, 0, 2], dtype=torch.float)
FOCAL_LEN = 1
# Each object is a line segment stored as its two 3D endpoints
# ((x, y, 0), (x, y, 1)), laid out on a 5 x 10 grid (x outer, y inner).
# Built in one expression so no loop variables leak into module scope and
# the annotation matches the value from the start.
OBJS: torch.Tensor = torch.tensor(
    [
        ([x, y, 0.0], [x, y, 1.0])
        for x in torch.linspace(0, 10, 5).tolist()
        for y in torch.linspace(-10, 0, 10).tolist()
    ],
    dtype=torch.float,
    requires_grad=True,  # render_mb differentiates the projection w.r.t. OBJS
)
RES = 800  # output image is RES x RES pixels
FPS = 60
MB_FAC = 0.005  # scales the length of the motion-blur trails
def project(objs):
    """Pinhole-project world-space points onto the camera's image plane.

    The camera sits at CAM_POS and looks along +X, so the X offset from the
    camera is used as depth. The last axis of *objs* is (x, y, z); every
    leading axis is kept, and the returned last axis is
    (u, v) = (dy, dz) / dx * FOCAL_LEN. The computation is differentiable.
    Points at or behind the camera plane (dx <= 0) are not special-cased.
    """
    rel = objs - CAM_POS
    depth = rel[..., :1]  # keep a trailing axis of size 1 so it broadcasts
    lateral = rel[..., 1:]
    return (lateral / depth) * FOCAL_LEN
def iter_projs_as_pixels(projection, res=None):
    """Map image-plane (u, v) coordinates to integer pixel coordinates.

    Despite its name this returns a NumPy array, not an iterator. The array
    has the same shape as *projection*, with last axis (column, row). The
    vertical axis is negated (image rows grow downward). The coordinates are
    then scaled by the resolution and offset so (0, 0) lands at the image
    center.

    Args:
        projection: tensor whose last axis is (u, v), e.g. the output of
            ``project``. It is not modified.
        res: image side length in pixels; defaults to the module's RES.

    Returns:
        Integer ndarray of rounded pixel coordinates.
    """
    if res is None:
        res = RES
    # Tensor.numpy() shares memory with the tensor, so copy before the
    # in-place flip below; otherwise the caller's tensor would be negated.
    pixels = projection.detach().cpu().numpy().copy()
    pixels[..., 1] *= -1
    pixels = pixels * res + res / 2
    return pixels.round().astype(int)
def render(objs):
    """Rasterize every segment of *objs* as a 1-pixel white line.

    The segments are projected with ``project`` and converted to pixels with
    ``iter_projs_as_pixels``. Returns a new black (RES, RES, 3) uint8 image
    with the lines drawn on it.
    """
    white = (255, 255, 255)
    canvas = np.zeros((RES, RES, 3), dtype=np.uint8)
    pixel_segments = iter_projs_as_pixels(project(objs))
    for start, end in pixel_segments:
        cv2.line(canvas, tuple(start), tuple(end), white, 1)
    return canvas
def render_mb(objs, velocities):
    """Draw motion-blur trails behind every line segment.

    Each endpoint's screen-space velocity is the directional derivative of
    ``project`` along that endpoint's 3D velocity, obtained with autograd.
    Each segment is extruded backwards along those screen velocities
    (offset length scaled by RES * MB_FAC). The resulting quad is filled
    with dark gray.

    Args:
        objs: tensor of segment endpoints with requires_grad=True (the
            module's OBJS has shape (N, 2, 3)).
        velocities: per-endpoint 3D velocities, same shape as *objs*.

    Returns:
        A (RES, RES, 3) uint8 image containing only the trails.
    """
    projection = project(objs)
    # Each projected point depends only on its own 3D point, so the gradient
    # of a summed screen coordinate gives every point's own Jacobian row.
    # autograd.grad leaves objs.grad untouched; retain_graph lets the single
    # forward pass serve both axes.
    (grad_u,) = torch.autograd.grad(projection[..., 0].sum(), objs, retain_graph=True)
    (grad_v,) = torch.autograd.grad(projection[..., 1].sum(), objs)

    with torch.no_grad():
        du = (grad_u * velocities).sum(dim=-1).cpu().numpy()
        dv = (grad_v * velocities).sum(dim=-1).cpu().numpy()
        # iter_projs_as_pixels negates the vertical axis, so v-velocity must be
        # negated as well; otherwise vertical trails point ahead of the motion.
        trail = RES * MB_FAC * np.stack((du, -dv), axis=-1)
        lines = iter_projs_as_pixels(projection)

    img = np.zeros((RES, RES, 3), dtype=np.uint8)
    for (start, end), (start_off, end_off) in zip(lines, trail):
        quad = np.array((start, end, end - end_off, start - start_off))
        cv2.fillConvexPoly(img, quad.round().astype(np.int32), (30, 30, 30))
    return img
def main():
    """Run the interactive motion-blur demo until "q" is pressed.

    Controls (read from the OpenCV window):
        w / s  -- multiply / divide the horizontal speed by 1.1
        m      -- toggle motion-blur trails
        space  -- pause / resume the animation
        q      -- quit

    The OpenCV window is closed on exit, including on exceptions.
    """
    global OBJS
    x_vel = 1  # applied to world axis 1 (horizontal on screen)
    y_vel = 0  # applied to world axis 2 (vertical on screen)
    do_mb = True
    run = True
    running_delta = 0  # axis-1 distance travelled since the last wrap-around

    try:
        while True:
            velocities = torch.zeros_like(OBJS)
            velocities[..., 1] = x_vel
            velocities[..., 2] = y_vel

            img = np.zeros((RES, RES, 3), dtype=np.uint8)
            if do_mb:
                img += render_mb(OBJS, velocities)
            with torch.no_grad():
                img = np.maximum(img, render(OBJS))

            # OBJS is a leaf that requires grad, so the in-place updates below
            # must run under no_grad.
            with torch.no_grad():
                if run:
                    # Colors are inverted only while running; presumably a
                    # visual running/paused indicator.
                    img = 255 - img
                cv2.imshow("frame", img)
                # Keep only the low byte (older OpenCV builds on some platforms
                # report extra high bits). A timeout (-1) becomes 255, which
                # matches no binding below.
                key = cv2.waitKey(1000 // FPS) & 0xFF
                if key == ord("q"):
                    break
                elif key == ord("w"):
                    x_vel *= 1.1
                elif key == ord("s"):
                    x_vel /= 1.1
                elif key == ord("m"):
                    do_mb = not do_mb
                elif key == ord(" "):
                    run = not run

                if run:
                    OBJS += velocities / FPS
                    running_delta += x_vel / FPS
                    # The initial grid spans 10 units along axis 1; shift it back
                    # by the distance travelled so the field keeps scrolling.
                    if running_delta > 10:
                        OBJS[..., 1] -= running_delta
                        running_delta = 0
                    # Occasionally pick a new random vertical velocity.
                    if random.random() < 0.01:
                        y_vel = random.uniform(-1, 1)
    finally:
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Start the interactive viewer only when run as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment