Created
April 25, 2025 19:48
-
-
Save phuang1024/9a9c99b7a75d708c04d1c0012d77aa55 to your computer and use it in GitHub Desktop.
Motion blur in 3D rendering with PyTorch autograd
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import random | |
| import cv2 | |
| import numpy as np | |
| import torch | |
# Camera position in world space. The camera faces +X, so world X is depth;
# project() maps world Y to screen-horizontal and world Z to screen-vertical.
CAM_POS = torch.tensor([-5, 0, 2], dtype=torch.float)
# Scale factor applied after the perspective divide in project().
FOCAL_LEN = 1

# Scene: a 5 x 10 grid of unit-height vertical line segments.
# Each object is a (start, end) pair of 3D points, giving shape (50, 2, 3).
# Grid coordinates are converted to plain Python floats first so the nested
# list handed to torch.tensor() holds only numbers, not 0-dim tensors.
_segments = [
    ([x, y, 0.0], [x, y, 1.0])
    for x in torch.linspace(0, 10, 5).tolist()
    for y in torch.linspace(-10, 0, 10).tolist()
]
# requires_grad=True: render_mb() differentiates the projection w.r.t. OBJS.
OBJS: torch.Tensor = torch.tensor(
    _segments, dtype=torch.float, requires_grad=True
)

# Rendered frames are RES x RES pixels.
RES = 800
# Target display rate; also the time step used to advance the scene.
FPS = 60
# Motion-blur length factor: the trail is RES * MB_FAC * screen-space velocity.
MB_FAC = 0.005
def project(objs, cam_pos=None, focal_len=None):
    """Perspective-project world-space points onto the image plane.

    The camera looks along +X, so the first coordinate of each
    camera-relative point is its depth and the remaining two are divided
    by it. The operation is differentiable with respect to ``objs``.

    Args:
        objs: Tensor whose last dimension holds (x, y, z) world coordinates.
        cam_pos: Camera position, broadcastable against ``objs``.
            Defaults to the module-level ``CAM_POS``.
        focal_len: Scale applied after the perspective divide.
            Defaults to the module-level ``FOCAL_LEN``.

    Returns:
        Tensor shaped like ``objs`` except that the last dimension has
        size 2: (horizontal, vertical) image-plane coordinates.
    """
    if cam_pos is None:
        cam_pos = CAM_POS
    if focal_len is None:
        focal_len = FOCAL_LEN

    offset = objs - cam_pos
    # Slice with :1 (not index 0) to keep the axis so the divide broadcasts.
    depth = offset[..., :1]
    return offset[..., 1:] / depth * focal_len
def iter_projs_as_pixels(projection, res=None):
    """Convert image-plane coordinates to integer pixel coordinates.

    The vertical axis is negated (image rows grow downward), then the
    coordinates are scaled by ``res`` and shifted by ``res / 2`` so that
    the image-plane origin lands at the centre of the frame.

    Args:
        projection: Tensor whose last dimension is (horizontal, vertical).
            It is not modified.
        res: Frame size in pixels. Defaults to the module-level ``RES``.

    Returns:
        Integer NumPy array with the same shape as ``projection``.
    """
    if res is None:
        res = RES

    # Tensor.numpy() shares storage with the tensor; copy so the in-place
    # sign flip below cannot leak back into the caller's data.
    coords = projection.detach().cpu().numpy().copy()
    coords[..., 1] *= -1
    coords = coords * res + res / 2
    return coords.round().astype(int)
def render(objs):
    """Draw every object as a 1-pixel white line on a black frame.

    Args:
        objs: World-space objects; each one is projected with ``project``
            and its first two points are used as the line's endpoints.

    Returns:
        ``RES x RES x 3`` ``uint8`` image.
    """
    canvas = np.zeros((RES, RES, 3), dtype=np.uint8)
    white = (255, 255, 255)
    thickness = 1

    pixel_segments = iter_projs_as_pixels(project(objs))
    for segment in pixel_segments:
        start = tuple(segment[0])
        end = tuple(segment[1])
        cv2.line(canvas, start, end, white, thickness)

    return canvas
def render_mb(objs, velocities):
    """Render motion-blur trails for line segments as dark-grey quads.

    The screen-space velocity of every endpoint is obtained with autograd:
    the gradient of each projected coordinate with respect to the world
    position, dotted with the world-space velocity. Each segment is then
    swept backwards along that velocity and the swept quad is filled.

    Args:
        objs: Tensor of shape (N, 2, 3) holding segment endpoints. Must
            require grad, since the projection is differentiated w.r.t. it.
        velocities: World-space velocity per endpoint, same shape as
            ``objs``.

    Returns:
        ``RES x RES x 3`` ``uint8`` image containing only the blur quads.

    Note:
        ``objs.grad`` is neither read nor written; gradients are taken with
        ``torch.autograd.grad`` so no stale gradient is left on the tensor.
    """
    projection = project(objs)
    # Each projected point depends only on its own world point, so the
    # gradient of the summed coordinate is the per-point gradient.
    (grad_x,) = torch.autograd.grad(
        projection[..., 0].sum(), objs, retain_graph=True
    )
    (grad_y,) = torch.autograd.grad(projection[..., 1].sum(), objs)

    with torch.no_grad():
        vel_x = (grad_x * velocities).sum(dim=-1)
        vel_y = (grad_y * velocities).sum(dim=-1)
        # Pixel rows grow downward (the pixel conversion negates the
        # vertical axis), so the vertical velocity must be negated too for
        # the trail to point backwards in time on both axes.
        pixel_vel = torch.stack((vel_x, -vel_y), dim=-1).cpu().numpy()
        # Pass a private copy so the pixel conversion cannot touch
        # `projection` itself.
        lines = iter_projs_as_pixels(projection.detach().clone())

    img = np.zeros((RES, RES, 3), dtype=np.uint8)

    # Endpoints displaced backwards along their screen-space velocity.
    trail = lines - RES * MB_FAC * pixel_vel
    # Quad corners in winding order: start, end, trailed end, trailed start.
    quads = np.stack(
        (lines[:, 0], lines[:, 1], trail[:, 1], trail[:, 0]), axis=1
    ).astype(np.int32)

    for quad in quads:
        cv2.fillConvexPoly(img, quad, (30, 30, 30))
    return img
def main():
    """Run the interactive motion-blur demo in an OpenCV window.

    Keys:
        q      quit
        w / s  multiply / divide the horizontal speed by 1.1
        m      toggle motion blur
        space  pause / resume the animation

    The window is destroyed when the loop exits, including on an exception.
    """
    # `OBJS += ...` below assigns to the name, so it must be declared global.
    global OBJS

    x_vel = 1
    y_vel = 0
    do_mb = True
    run = True
    running_delta = 0

    try:
        while True:
            # Index 1 (world Y) moves the scene horizontally on screen,
            # index 2 (world Z) moves it vertically.
            velocities = torch.zeros_like(OBJS)
            velocities[..., 1] = x_vel
            velocities[..., 2] = y_vel

            img = np.zeros((RES, RES, 3), dtype=np.uint8)
            if do_mb:
                # Needs autograd enabled, so it stays outside no_grad().
                img += render_mb(OBJS, velocities)
            with torch.no_grad():
                img = np.maximum(img, render(OBJS))

            # NOTE(review): indentation was lost in the source this was
            # recovered from; the frame appears to be inverted only while
            # the animation is running -- confirm.
            if run:
                img = 255 - img
            cv2.imshow("frame", img)

            # Mask to the low byte so modifier/high bits reported by some
            # OpenCV builds do not defeat the comparisons below.
            key = cv2.waitKey(1000 // FPS) & 0xFF
            if key == ord("q"):
                break
            elif key == ord("w"):
                x_vel *= 1.1
            elif key == ord("s"):
                x_vel /= 1.1
            elif key == ord("m"):
                do_mb = not do_mb
            elif key == ord(" "):
                run = not run

            if run:
                # OBJS is a leaf that requires grad; in-place updates are
                # only permitted with autograd disabled.
                with torch.no_grad():
                    OBJS += velocities / FPS
                    running_delta += x_vel / FPS
                    # Wrap the grid back once it has drifted more than 10
                    # world units so the scene scrolls indefinitely.
                    if running_delta > 10:
                        OBJS[..., 1] -= running_delta
                        running_delta = 0

                # NOTE(review): assumed to belong to the `if run:` body so
                # the vertical velocity is frozen while paused -- confirm.
                if random.random() < 0.01:
                    y_vel = random.uniform(-1, 1)
    finally:
        cv2.destroyAllWindows()
# Start the interactive demo only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment