Last active
April 3, 2026 13:48
-
-
Save nagadomi/efad14292b98cfe83ec2778cfad54cd3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import ctypes | |
| import os | |
| import sys | |
| import torch | |
| import numpy as np | |
| import torchvision | |
# NVOF Constants
# API version handed to NvOFAPICreateInstanceCuda: the major version is
# shifted into the high bits and the minor version occupies the low 4 bits
# (2.0 -> 0x20).
NV_OF_API_MAJOR_VERSION = 2
NV_OF_API_MINOR_VERSION = 0
NV_OF_API_VERSION = (NV_OF_API_MAJOR_VERSION << 4) | NV_OF_API_MINOR_VERSION
# Length of the per-plane stride array in NV_OF_CUDA_BUFFER_STRIDE_INFO.
MAX_NUM_PLANES = 3
class NV_OF_STATUS:
    """Status codes returned by NVOF API calls; 0 (SUCCESS) means the call succeeded."""

    SUCCESS = 0
    ERR_OF_NOT_AVAILABLE = 1
    ERR_UNSUPPORTED_DEVICE = 2
    ERR_DEVICE_DOES_NOT_EXIST = 3
    ERR_INVALID_PTR = 4
    ERR_INVALID_PARAM = 5
    ERR_INVALID_CALL = 6
    ERR_INVALID_VERSION = 7
    ERR_OUT_OF_MEMORY = 8
    ERR_NOT_INITIALIZED = 9
    ERR_UNSUPPORTED_FEATURE = 10
    ERR_GENERIC = 11

    @staticmethod
    def to_string(status):
        """Return the constant name for ``status`` (e.g. ``"ERR_INVALID_PARAM"``).

        Only the public integer constants declared above are considered.
        Class machinery in ``__dict__`` is skipped so it can never be reported
        as a status name. Examples are ``__doc__`` and, on Python 3.13+, the
        integer ``__firstlineno__``.

        Returns:
            The matching constant name, or ``"UNKNOWN(<status>)"`` if none matches.
        """
        for name, value in vars(NV_OF_STATUS).items():
            if name.startswith("_") or not isinstance(value, int):
                continue
            if value == status:
                return name
        return f"UNKNOWN({status})"
class NV_OF_MODE:
    """Operating mode stored in ``NV_OF_INIT_PARAMS.mode``.

    The integers are passed to the NVOF driver unchanged, so they must match
    the SDK's C enum of the same name.
    """

    UNDEFINED = 0
    OPTICALFLOW = 1  # the mode NVOF.__init__ requests
    STEREODISPARITY = 2
class NV_OF_PERF_LEVEL:
    """Performance preset stored in ``NV_OF_INIT_PARAMS.perfLevel``.

    The integers are passed to the NVOF driver unchanged, so they must match
    the SDK's C enum of the same name.
    """

    UNDEFINED = 0
    SLOW = 5  # the preset NVOF.__init__ requests
    MEDIUM = 10
    FAST = 20
class NV_OF_OUTPUT_VECTOR_GRID_SIZE:
    """Output vector grid size stored in ``NV_OF_INIT_PARAMS.outGridSize``.

    Each value equals the grid size in pixels. NVOF sizes its output as
    ``ceil(width / grid) x ceil(height / grid)``, i.e. one vector per
    grid x grid block.
    """

    UNDEFINED = 0
    GRID_1 = 1
    GRID_2 = 2
    GRID_4 = 4
class NV_OF_HINT_VECTOR_GRID_SIZE:
    """Grid size for external hint vectors (``NV_OF_INIT_PARAMS.hintGridSize``).

    Not set by the NVOF wrapper in this file. The field stays 0 (UNDEFINED)
    because external hints are never enabled there.
    """

    UNDEFINED = 0
    GRID_1 = 1
    GRID_2 = 2
    GRID_4 = 4
    GRID_8 = 8
class NV_OF_BUFFER_USAGE:
    """Role of a buffer created through ``NVOF.create_buffer``.

    ``create_buffer`` sizes INPUT buffers to the frame and every other usage
    to the output vector grid.
    """

    UNDEFINED = 0
    INPUT = 1
    OUTPUT = 2
    HINT = 3
    COST = 4
class NV_OF_BUFFER_FORMAT:
    """Element format of an NVOF buffer (``NV_OF_BUFFER_DESCRIPTOR.bufferFormat``).

    This file uses ABGR8 for input frames and SHORT2 for the flow output.
    ``NVOF.copy_from_buffer`` reads a SHORT2 element as 4 bytes: two int16
    values.
    """

    UNDEFINED = 0
    GRAYSCALE8 = 1
    NV12 = 2
    ABGR8 = 3
    SHORT = 4
    SHORT2 = 5
    UINT = 6
    UINT8 = 7
class NV_OF_CUDA_BUFFER_TYPE:
    """How an NVOF buffer is backed in CUDA.

    ``NVOF.create_buffer`` always requests CUDEVICEPTR. Its address is then
    obtained with ``nvOFGPUBufferGetCUdeviceptr``.
    """

    UNDEFINED = 0
    CUARRAY = 1
    CUDEVICEPTR = 2
class NV_OF_CAPS:
    """Capability selectors for the ``capsParam`` argument of ``nvOFGetCaps``.

    Only SUPPORTED_OUTPUT_GRID_SIZES is queried in this file.
    """

    SUPPORTED_OUTPUT_GRID_SIZES = 0
    SUPPORTED_HINT_GRID_SIZES = 1
    SUPPORT_HINT_WITH_OF_MODE = 2
    SUPPORT_HINT_WITH_ST_MODE = 3
    WIDTH_MIN = 4
    HEIGHT_MIN = 5
    WIDTH_MAX = 6
    HEIGHT_MAX = 7
    SUPPORT_ROI = 8
    SUPPORT_ROI_MAX_NUM = 9
    SUPPORT_MAX = 10  # presumably an end-of-enum sentinel rather than a real query; confirm
| # Structs | |
class NV_OF_INIT_PARAMS(ctypes.Structure):
    """ctypes mirror of the parameters passed to ``nvOFInit``.

    Field order and types define the binary layout the driver reads, so they
    must match the C struct for the requested API version exactly. Fields the
    caller does not set stay zero.
    """

    # NOTE(review): every field already sits at its natural alignment (eight
    # 4-byte fields precede the pointer), so _pack_ = 1 should not change the
    # layout; confirm against the SDK header.
    _pack_ = 1
    _fields_ = [
        ("width", ctypes.c_uint32),  # input frame width in pixels
        ("height", ctypes.c_uint32),  # input frame height in pixels
        ("outGridSize", ctypes.c_int),  # NV_OF_OUTPUT_VECTOR_GRID_SIZE value
        ("hintGridSize", ctypes.c_int),  # NV_OF_HINT_VECTOR_GRID_SIZE value
        ("mode", ctypes.c_int),  # NV_OF_MODE value
        ("perfLevel", ctypes.c_int),  # NV_OF_PERF_LEVEL value
        ("enableExternalHints", ctypes.c_int),  # presumably a 0/1 flag
        ("enableOutputCost", ctypes.c_int),  # presumably a 0/1 flag
        ("hPrivData", ctypes.c_void_p),
        ("disparityRange", ctypes.c_int),  # presumably only used in stereo mode
        ("enableRoi", ctypes.c_int),  # presumably a 0/1 flag
    ]
class NV_OF_BUFFER_DESCRIPTOR(ctypes.Structure):
    """Describes a buffer to create with ``nvOFCreateGPUBufferCuda``.

    ``NVOF.create_buffer`` sets width/height to the frame size for INPUT
    buffers and to the output grid size otherwise. Field order must match the
    C struct.
    """

    _pack_ = 1
    _fields_ = [
        ("width", ctypes.c_uint32),
        ("height", ctypes.c_uint32),
        ("bufferUsage", ctypes.c_int),  # NV_OF_BUFFER_USAGE value
        ("bufferFormat", ctypes.c_int),  # NV_OF_BUFFER_FORMAT value
    ]
class NV_OF_ROI_RECT(ctypes.Structure):
    """Region-of-interest rectangle referenced by ``NV_OF_EXECUTE_INPUT_PARAMS.roiData``.

    Not used by the visible code (``numRois`` is left at 0). Coordinates are
    presumably in input-frame pixels; confirm against the SDK documentation.
    """

    _fields_ = [
        ("start_x", ctypes.c_uint32),
        ("start_y", ctypes.c_uint32),
        ("width", ctypes.c_uint32),
        ("height", ctypes.c_uint32),
    ]
class NV_OF_EXECUTE_INPUT_PARAMS(ctypes.Structure):
    """Input side of ``nvOFExecute``: the frame pair plus optional hints/ROIs.

    Field order, including the explicit padding fields, is the binary layout
    the driver reads and must not be changed.
    """

    _pack_ = 1
    _fields_ = [
        ("inputFrame", ctypes.c_void_p),  # NVOF buffer handle of the input frame
        ("referenceFrame", ctypes.c_void_p),  # NVOF buffer handle of the reference frame
        ("externalHints", ctypes.c_void_p),  # hint buffer handle; NULL when unused
        ("disableTemporalHints", ctypes.c_int),  # NVOF.execute sets this to 1
        ("padding", ctypes.c_uint32),  # explicit padding to keep hPrivData 8-byte aligned
        ("hPrivData", ctypes.c_void_p),
        ("padding2", ctypes.c_uint32),
        ("numRois", ctypes.c_uint32),  # 0 in this file
        ("roiData", ctypes.POINTER(NV_OF_ROI_RECT)),  # presumably numRois entries
    ]
class NV_OF_EXECUTE_OUTPUT_PARAMS(ctypes.Structure):
    """Output side of ``nvOFExecute``: where the flow (and optional cost) is written."""

    _pack_ = 1
    _fields_ = [
        ("outputBuffer", ctypes.c_void_p),  # NVOF buffer handle receiving the vectors
        ("outputCostBuffer", ctypes.c_void_p),  # left NULL here (cost output is not enabled)
        ("hPrivData", ctypes.c_void_p),
    ]
class NV_OF_BUFFER_STRIDE(ctypes.Structure):
    """Stride of one plane of an NVOF buffer."""

    _fields_ = [
        # Row pitch in bytes; the copy helpers use this to address each row.
        ("strideXInBytes", ctypes.c_uint32),
        # Not read in this file; presumably the plane pitch in bytes — confirm.
        ("strideYInBytes", ctypes.c_uint32),
    ]
class NV_OF_CUDA_BUFFER_STRIDE_INFO(ctypes.Structure):
    """Per-plane strides filled in by ``nvOFGPUBufferGetStrideInfo``.

    Only plane 0 is read in this file; ``numPlanes`` tells how many of the
    ``MAX_NUM_PLANES`` entries are valid.
    """

    _fields_ = [
        ("strideInfo", NV_OF_BUFFER_STRIDE * MAX_NUM_PLANES),
        ("numPlanes", ctypes.c_uint32),
    ]
| # Function Pointers | |
# ctypes prototypes for the entries of NV_OF_CUDA_API_FUNCTION_LIST. Every
# entry returns an NV_OF_STATUS code (0 on success) except
# nvOFGPUBufferGetCUdeviceptr, which returns a device address. Argument lists
# must match the C signatures exactly.

# nvCreateOpticalFlowCuda(cuContext, *hOf_out): create an NVOF handle on the
# given CUDA context.
PFNNVCREATEOPTICALFLOWCUDA = ctypes.CFUNCTYPE(
    ctypes.c_int, ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p)
)
# nvOFInit(hOf, *initParams)
PFNNVOFINIT = ctypes.CFUNCTYPE(
    ctypes.c_int, ctypes.c_void_p, ctypes.POINTER(NV_OF_INIT_PARAMS)
)
# nvOFCreateGPUBufferCuda(hOf, *bufferDesc, NV_OF_CUDA_BUFFER_TYPE, *buffer_out)
PFNNVOFCREATEGPUBUFFERCUDA = ctypes.CFUNCTYPE(
    ctypes.c_int,
    ctypes.c_void_p,
    ctypes.POINTER(NV_OF_BUFFER_DESCRIPTOR),
    ctypes.c_int,
    ctypes.POINTER(ctypes.c_void_p),
)
# nvOFGPUBufferGetCUdeviceptr(buffer) -> 64-bit device address (not a status).
PFNNVOFGPUBUFFERGETCUDEVICEPTR = ctypes.CFUNCTYPE(ctypes.c_uint64, ctypes.c_void_p)
# nvOFGPUBufferGetStrideInfo(buffer, *strideInfo_out)
PFNVOFGPUBUFFERGETSTRIDEINFO = ctypes.CFUNCTYPE(
    ctypes.c_int, ctypes.c_void_p, ctypes.POINTER(NV_OF_CUDA_BUFFER_STRIDE_INFO)
)
# nvOFExecute(hOf, *inParams, *outParams)
PFNNVOFEXECUTE = ctypes.CFUNCTYPE(
    ctypes.c_int,
    ctypes.c_void_p,
    ctypes.POINTER(NV_OF_EXECUTE_INPUT_PARAMS),
    ctypes.POINTER(NV_OF_EXECUTE_OUTPUT_PARAMS),
)
# nvOFDestroy(hOf)
PFNNVOFDESTROY = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)
# nvOFDestroyGPUBufferCuda(buffer)
PFNNVOFDESTROYGPUBUFFERCUDA = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)
# nvOFGetCaps(hOf, NV_OF_CAPS, *capsVal, *size). Passing NULL for capsVal
# queries the number of values into *size (as get_supported_grid_sizes does).
PFNNVOFGETCAPS = ctypes.CFUNCTYPE(
    ctypes.c_int,
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.POINTER(ctypes.c_uint32),
    ctypes.POINTER(ctypes.c_uint32),
)
class NV_OF_CUDA_API_FUNCTION_LIST(ctypes.Structure):
    """Function-pointer table filled in by ``NvOFAPICreateInstanceCuda``.

    The table is a C struct, so entry order defines which pointer lands in
    which slot and must match the SDK header exactly. Slots typed as plain
    ``c_void_p`` are never called from Python in this file; they exist only to
    keep the following entries at the right offsets.
    """

    _fields_ = [
        ("nvCreateOpticalFlowCuda", PFNNVCREATEOPTICALFLOWCUDA),
        ("nvOFInit", PFNNVOFINIT),
        ("nvOFCreateGPUBufferCuda", PFNNVOFCREATEGPUBUFFERCUDA),
        ("nvOFGPUBufferGetCUarray", ctypes.c_void_p),  # placeholder, not called
        ("nvOFGPUBufferGetCUdeviceptr", PFNNVOFGPUBUFFERGETCUDEVICEPTR),
        ("nvOFGPUBufferGetStrideInfo", PFNVOFGPUBUFFERGETSTRIDEINFO),
        ("nvOFSetIOCudaStreams", ctypes.c_void_p),  # placeholder, not called
        ("nvOFExecute", PFNNVOFEXECUTE),
        ("nvOFDestroyGPUBufferCuda", PFNNVOFDESTROYGPUBUFFERCUDA),
        ("nvOFDestroy", PFNNVOFDESTROY),
        ("nvOFGetLastError", ctypes.c_void_p),  # placeholder, not called
        ("nvOFGetCaps", PFNNVOFGETCAPS),
    ]
def find_cudart():
    """Locate the CUDA runtime shared library to load with ``ctypes.CDLL``.

    Search order:
      1. the ``lib`` directory inside the installed ``torch`` package;
      2. the ``nvidia-cuda-runtime`` wheel (``nvidia/cuda_runtime/lib`` on
         Linux, ``.../bin`` on Windows);
      3. the system loader search path via ``ctypes.util.find_library``.

    Returns:
        A path or library name for ``ctypes.CDLL``. Falls back to the bare
        name ``"libcudart.so"`` when nothing is found.
    """
    import ctypes.util  # the submodule is not loaded by a plain `import ctypes`

    torch_dir = os.path.dirname(torch.__file__)
    site_packages = os.path.dirname(torch_dir)
    search_dirs = [
        os.path.join(torch_dir, "lib"),
        os.path.join(site_packages, "nvidia", "cuda_runtime", "lib"),
        os.path.join(site_packages, "nvidia", "cuda_runtime", "bin"),  # Windows
    ]

    if sys.platform == "win32":
        def is_cudart(name):
            return name.startswith("cudart64") and name.endswith(".dll")
    else:
        def is_cudart(name):
            return name.startswith("libcudart") and ".so" in name

    for lib_dir in search_dirs:
        if not os.path.isdir(lib_dir):
            continue
        # sorted() makes the pick deterministic when several versions or
        # symlinks (e.g. libcudart.so.12 and libcudart.so.12.x.y) coexist.
        for name in sorted(os.listdir(lib_dir)):
            if is_cudart(name):
                return os.path.join(lib_dir, name)

    # Ask the system loader before falling back to a name that often only
    # exists when the CUDA development toolkit is installed.
    system_lib = ctypes.util.find_library("cudart")
    return system_lib or "libcudart.so"
class NVOF:
    """Minimal ctypes wrapper around the NVIDIA Optical Flow (NVOF) CUDA API.

    An instance loads the NVOF, CUDA runtime and CUDA driver libraries. It
    creates an NVOF handle on the CUDA context PyTorch uses for ``gpu_id`` and
    initialises it for ``width`` x ``height`` optical flow.

    GPU buffers are created with ``create_buffer`` and must be released by the
    caller with ``destroy_buffer``. The NVOF handle itself is released by
    ``close()``, which ``__del__`` also calls.
    """

    # cudaMemcpyKind value for device-to-device copies.
    _MEMCPY_DEVICE_TO_DEVICE = 3
    # Bytes per element of an NV_OF_BUFFER_FORMAT.SHORT2 buffer (two int16).
    _SHORT2_BYTES = 4

    def __init__(self, width, height, grid_size=4, gpu_id=0):
        """Load the libraries, create the NVOF handle and initialise it.

        Args:
            width: Input frame width in pixels.
            height: Input frame height in pixels.
            grid_size: Output vector grid size (1, 2 or 4); must be one of the
                sizes reported by ``get_supported_grid_sizes``.
            gpu_id: CUDA device index whose PyTorch-initialised context is used.

        Raises:
            RuntimeError: The NVOF library cannot be loaded or an NVOF/CUDA
                call fails.
            ValueError: ``grid_size`` is not supported by the device.
        """
        self.width = width
        self.height = height
        self.gpu_id = gpu_id
        self.device = torch.device("cuda", gpu_id)

        # Load libraries. On Linux the Debian/Ubuntu multiarch path is tried
        # first, then the driver's versioned soname through the dynamic loader.
        if sys.platform == "win32":
            nvof_candidates = ["nvofapi64.dll"]
            cuda_driver_name = "nvcuda.dll"
        else:
            nvof_candidates = [
                "/usr/lib/x86_64-linux-gnu/libnvidia-opticalflow.so",
                "libnvidia-opticalflow.so.1",
            ]
            cuda_driver_name = "libcuda.so.1"
        self.nvof_lib = self._load_library(nvof_candidates, "NVOF library")

        self.libcudart = ctypes.CDLL(find_cudart())
        self.libcudart.cudaMemcpy.argtypes = [
            ctypes.c_void_p,  # dst
            ctypes.c_void_p,  # src
            ctypes.c_size_t,  # count in bytes
            ctypes.c_int,  # cudaMemcpyKind
        ]
        self.libcudart.cudaMemcpy2D.argtypes = [
            ctypes.c_void_p,  # dst
            ctypes.c_size_t,  # dpitch in bytes
            ctypes.c_void_p,  # src
            ctypes.c_size_t,  # spitch in bytes
            ctypes.c_size_t,  # width in bytes (per row)
            ctypes.c_size_t,  # height in rows
            ctypes.c_int,  # cudaMemcpyKind
        ]
        self.cuda_lib = ctypes.CDLL(cuda_driver_name)

        # Get API instance: the driver fills the function-pointer table.
        self.api = NV_OF_CUDA_API_FUNCTION_LIST()
        res = self.nvof_lib.NvOFAPICreateInstanceCuda(
            NV_OF_API_VERSION, ctypes.byref(self.api)
        )
        self._check(res, "NvOFAPICreateInstanceCuda")

        # Ensure PyTorch has initialized CUDA so a current context exists.
        torch.cuda.set_device(gpu_id)
        _ = torch.zeros(1, device=self.device)
        self.cuContext = ctypes.c_void_p()
        res = self.cuda_lib.cuCtxGetCurrent(ctypes.byref(self.cuContext))
        if res != 0 or not self.cuContext:
            raise RuntimeError(f"cuCtxGetCurrent failed: {res}")

        # Create NVOF Handle
        self.hOf = ctypes.c_void_p()
        res = self.api.nvCreateOpticalFlowCuda(self.cuContext, ctypes.byref(self.hOf))
        self._check(res, "nvCreateOpticalFlowCuda")

        # Check supported grid sizes
        supported_grids = self.get_supported_grid_sizes()
        if grid_size not in supported_grids:
            raise ValueError(
                f"Requested grid_size {grid_size} is not supported. Supported: {supported_grids}"
            )

        # Initialize OF
        init_params = NV_OF_INIT_PARAMS()
        init_params.width = width
        init_params.height = height
        init_params.outGridSize = {
            1: NV_OF_OUTPUT_VECTOR_GRID_SIZE.GRID_1,
            2: NV_OF_OUTPUT_VECTOR_GRID_SIZE.GRID_2,
            4: NV_OF_OUTPUT_VECTOR_GRID_SIZE.GRID_4,
        }[grid_size]
        init_params.mode = NV_OF_MODE.OPTICALFLOW
        init_params.perfLevel = NV_OF_PERF_LEVEL.SLOW
        res = self.api.nvOFInit(self.hOf, ctypes.byref(init_params))
        self._check(res, "nvOFInit")

        self.grid_size = grid_size
        # One output vector per grid_size x grid_size block, rounding up.
        self.out_width = (width + self.grid_size - 1) // self.grid_size
        self.out_height = (height + self.grid_size - 1) // self.grid_size

    @staticmethod
    def _check(res, what):
        """Raise RuntimeError naming ``what`` if an NVOF call returned non-success."""
        if res != 0:
            raise RuntimeError(f"{what} failed: {NV_OF_STATUS.to_string(res)}")

    @staticmethod
    def _load_library(candidates, what):
        """Return a ``ctypes.CDLL`` for the first loadable name in ``candidates``.

        Absolute paths that do not exist are skipped without calling the loader;
        bare names are resolved by the platform's dynamic loader.

        Raises:
            RuntimeError: None of the candidates could be loaded.
        """
        errors = []
        for name in candidates:
            if os.path.isabs(name) and not os.path.exists(name):
                errors.append(f"{name}: not found")
                continue
            try:
                return ctypes.CDLL(name)
            except OSError as err:
                errors.append(f"{name}: {err}")
        raise RuntimeError(f"{what} could not be loaded (tried: {'; '.join(errors)})")

    @staticmethod
    def _row_bytes(tensor):
        """Bytes in one row of a ``(rows, cols[, channels])`` tensor."""
        row = tensor.shape[1] * tensor.element_size()
        if tensor.ndim > 2:
            row *= tensor.shape[2]
        return row

    def create_buffer(self, usage, format):
        """Create a device-pointer-backed NVOF buffer and return its handle.

        INPUT buffers get the frame size; every other usage gets the output
        grid size. Release the handle with ``destroy_buffer``.
        """
        desc = NV_OF_BUFFER_DESCRIPTOR()
        if usage == NV_OF_BUFFER_USAGE.INPUT:
            desc.width = self.width
            desc.height = self.height
        else:
            desc.width = self.out_width
            desc.height = self.out_height
        desc.bufferUsage = usage
        desc.bufferFormat = format
        buffer_handle = ctypes.c_void_p()
        res = self.api.nvOFCreateGPUBufferCuda(
            self.hOf,
            ctypes.byref(desc),
            NV_OF_CUDA_BUFFER_TYPE.CUDEVICEPTR,
            ctypes.byref(buffer_handle),
        )
        self._check(res, "nvOFCreateGPUBufferCuda")
        return buffer_handle

    def get_device_ptr(self, buffer_handle):
        """Return the CUDA device address (int) backing ``buffer_handle``."""
        return self.api.nvOFGPUBufferGetCUdeviceptr(buffer_handle)

    def get_stride_info(self, buffer_handle):
        """Return the ``NV_OF_CUDA_BUFFER_STRIDE_INFO`` of ``buffer_handle``."""
        stride_info = NV_OF_CUDA_BUFFER_STRIDE_INFO()
        res = self.api.nvOFGPUBufferGetStrideInfo(
            buffer_handle, ctypes.byref(stride_info)
        )
        self._check(res, "nvOFGPUBufferGetStrideInfo")
        return stride_info

    def get_supported_grid_sizes(self):
        """Return the output grid sizes the device supports, as a list of ints."""
        size = ctypes.c_uint32(0)
        # First call with NULL values only reports how many values exist.
        res = self.api.nvOFGetCaps(
            self.hOf, NV_OF_CAPS.SUPPORTED_OUTPUT_GRID_SIZES, None, ctypes.byref(size)
        )
        self._check(res, "nvOFGetCaps size")
        caps_val = (ctypes.c_uint32 * size.value)()
        res = self.api.nvOFGetCaps(
            self.hOf,
            NV_OF_CAPS.SUPPORTED_OUTPUT_GRID_SIZES,
            caps_val,
            ctypes.byref(size),
        )
        self._check(res, "nvOFGetCaps values")
        # Slicing clamps to the allocated length should the driver report more.
        return list(caps_val[: size.value])

    def execute(self, input_h, ref_h, output_h):
        """Compute flow between the ``input_h`` and ``ref_h`` buffers into ``output_h``.

        Temporal hints are disabled, so each call is independent of earlier ones.
        """
        in_params = NV_OF_EXECUTE_INPUT_PARAMS()
        in_params.inputFrame = input_h
        in_params.referenceFrame = ref_h
        in_params.disableTemporalHints = 1
        out_params = NV_OF_EXECUTE_OUTPUT_PARAMS()
        out_params.outputBuffer = output_h
        res = self.api.nvOFExecute(
            self.hOf, ctypes.byref(in_params), ctypes.byref(out_params)
        )
        self._check(res, "nvOFExecute")

    def copy_to_buffer(self, buffer_handle, tensor):
        """Copy a CUDA tensor into an NVOF buffer, honouring the buffer's row pitch.

        Args:
            buffer_handle: Handle from ``create_buffer``.
            tensor: CUDA tensor laid out as ``(rows, cols[, channels])``, for
                example ``(H, W, 4)`` uint8 for ABGR8 input. It is made
                contiguous first. The caller must ensure it has no more rows
                than the buffer.

        Raises:
            ValueError: A tensor row is wider than the buffer pitch.
            RuntimeError: The CUDA copy fails.
        """
        # data_ptr() is only a valid row-major view of a contiguous tensor.
        tensor = tensor.contiguous()
        ptr = self.get_device_ptr(buffer_handle)
        pitch = self.get_stride_info(buffer_handle).strideInfo[0].strideXInBytes
        row_size = self._row_bytes(tensor)
        if row_size > pitch:
            raise ValueError(
                f"tensor row ({row_size} bytes) exceeds buffer pitch ({pitch} bytes)"
            )
        # A single 2D copy covers both the tightly packed (pitch == row_size)
        # and the padded case, instead of one cudaMemcpy call per row.
        res = self.libcudart.cudaMemcpy2D(
            ptr,
            pitch,
            tensor.data_ptr(),
            row_size,
            row_size,
            tensor.shape[0],
            self._MEMCPY_DEVICE_TO_DEVICE,
        )
        if res != 0:
            raise RuntimeError(f"cudaMemcpy2D failed: {res}")

    def copy_from_buffer(self, buffer_handle):
        """Copy a SHORT2 output buffer into a new tensor.

        Returns:
            An ``(out_height, out_width, 2)`` int16 tensor on ``self.device``,
            with the buffer's row padding removed.

        Raises:
            RuntimeError: The CUDA copy fails.
        """
        ptr = self.get_device_ptr(buffer_handle)
        pitch = self.get_stride_info(buffer_handle).strideInfo[0].strideXInBytes
        out = torch.empty(
            (self.out_height, self.out_width, 2), dtype=torch.int16, device=self.device
        )
        row_size = self.out_width * self._SHORT2_BYTES
        # Copy only the valid part of each pitched row straight into the result.
        res = self.libcudart.cudaMemcpy2D(
            out.data_ptr(),
            row_size,
            ptr,
            pitch,
            row_size,
            self.out_height,
            self._MEMCPY_DEVICE_TO_DEVICE,
        )
        if res != 0:
            raise RuntimeError(f"cudaMemcpy2D failed: {res}")
        return out

    def destroy_buffer(self, buffer_handle):
        """Release a buffer created by ``create_buffer``."""
        res = self.api.nvOFDestroyGPUBufferCuda(buffer_handle)
        self._check(res, "nvOFDestroyGPUBufferCuda")

    def close(self):
        """Destroy the NVOF handle. Safe to call more than once."""
        hof = getattr(self, "hOf", None)
        if hof:
            # Clear first so the handle is never destroyed twice. The status is
            # deliberately ignored because this also runs from __del__.
            self.hOf = None
            self.api.nvOFDestroy(hof)

    def __del__(self):
        self.close()
def flow_to_rgb(flow, max_magnitude=None):
    """
    Visualize optical flow in RGB color space using PyTorch.

    Zero flow is white, hue encodes direction and saturation encodes
    magnitude relative to ``max_magnitude``.

    Args:
        flow: (H, W, 2) float tensor; channel 0 is the x and channel 1 the y
            component of each vector.
        max_magnitude: Magnitude mapped to full saturation. Defaults to the
            largest magnitude in ``flow``. Pass a fixed value to keep colors
            comparable across frames.

    Returns:
        (3, H, W) float tensor in [0, 1].
    """
    flow = flow.permute(2, 0, 1)  # (2, H, W)
    mag = torch.norm(flow, dim=0)
    angle = torch.atan2(flow[1], flow[0])
    # Map angle in [-pi, pi] to hue in [0, 1].
    hue = (angle + np.pi) / (2 * np.pi)

    if max_magnitude is None:
        # max() of an empty tensor raises, so handle 0-pixel input explicitly.
        max_magnitude = mag.max() if mag.numel() > 0 else 0.0
    # Zero flow -> saturation 0 (white); larger magnitude -> more saturated.
    sat = torch.clamp(mag / (max_magnitude + 1e-6), 0, 1)
    val = torch.ones_like(hue)

    # HSV -> RGB: chroma c, secondary component x, offset m.
    h = hue * 6
    c = val * sat
    x = c * (1 - torch.abs(h % 2 - 1))
    m = val - c
    sector = h.long() % 6  # hue == 1.0 wraps to sector 0

    # For each hue sector, which of (c, x, 0) becomes R, G and B.
    sector_table = torch.tensor(
        [
            [0, 1, 2],  # sector 0: (c, x, 0)
            [1, 0, 2],  # sector 1: (x, c, 0)
            [2, 0, 1],  # sector 2: (0, c, x)
            [2, 1, 0],  # sector 3: (0, x, c)
            [1, 2, 0],  # sector 4: (x, 0, c)
            [0, 2, 1],  # sector 5: (c, 0, x)
        ],
        device=h.device,
    )
    components = torch.stack([c, x, torch.zeros_like(h)], dim=0)  # (3, H, W)
    index = sector_table[sector].permute(2, 0, 1)  # (3, H, W)
    rgb = torch.gather(components, 0, index) + m
    return rgb
def test():
    """Run NVOF on two sample frames and save a flow visualization.

    Reads ``images/001.png`` as the input frame and ``images/002.png`` as the
    reference frame. It computes flow on a 1x1 grid, prints basic statistics
    and writes ``flow_vis.png``. Requires Pillow and an NVOF-capable GPU.

    Raises:
        ValueError: The two images differ in size. The NVOF input buffers are
            sized from the first image, so a larger second image would overrun
            them.
    """
    from PIL import Image

    img1_path = "images/001.png"
    img2_path = "images/002.png"
    # Context managers close the image files once the pixels are converted.
    with Image.open(img1_path) as im:
        img1 = im.convert("RGB")
    with Image.open(img2_path) as im:
        img2 = im.convert("RGB")
    if img1.size != img2.size:
        raise ValueError(
            f"image sizes differ: {img1_path} is {img1.size}, {img2_path} is {img2.size}"
        )
    width, height = img1.size

    def to_abgr(img):
        """Return an (H, W, 4) uint8 CUDA tensor with bytes ordered A, B, G, R."""
        # NOTE(review): NVENC documents its ABGR format as word-ordered (R in
        # the lowest byte, i.e. R, G, B, A in memory on little-endian hosts).
        # Confirm NVOF's ABGR8 follows the same convention. If it does, this
        # swizzle should be dropped and the RGBA bytes uploaded unchanged.
        rgba = np.array(img.convert("RGBA"))
        # RGBA -> ABGR: [R, G, B, A] -> [A, B, G, R]; copy() removes the
        # negative stride, which torch.from_numpy cannot handle.
        abgr = rgba[:, :, ::-1].copy()
        return torch.from_numpy(abgr).cuda()

    t1 = to_abgr(img1)
    t2 = to_abgr(img2)

    nvof = NVOF(width, height, grid_size=1)
    buffers = []  # everything created here is destroyed in the finally clause
    try:
        b1 = nvof.create_buffer(NV_OF_BUFFER_USAGE.INPUT, NV_OF_BUFFER_FORMAT.ABGR8)
        buffers.append(b1)
        b2 = nvof.create_buffer(NV_OF_BUFFER_USAGE.INPUT, NV_OF_BUFFER_FORMAT.ABGR8)
        buffers.append(b2)
        out_b = nvof.create_buffer(NV_OF_BUFFER_USAGE.OUTPUT, NV_OF_BUFFER_FORMAT.SHORT2)
        buffers.append(out_b)

        print("Copying data to NVOF buffers...")
        nvof.copy_to_buffer(b1, t1)
        nvof.copy_to_buffer(b2, t2)
        print("Executing NVOF...")
        nvof.execute(b1, b2, out_b)
        print("NVOF executed successfully.")
        out_t = nvof.copy_from_buffer(out_b)

        # Flow is in S10.5 fixed point: 5 fractional bits, so divide by 2**5.
        flow = out_t.float() / 32.0
        print(f"Flow shape: {flow.shape}")
        print(f"Max flow: {torch.max(torch.abs(flow)).item()}")

        # Visualize
        vis_rgb = flow_to_rgb(flow)
        torchvision.utils.save_image(vis_rgb, "flow_vis.png")
        print("Flow visualization saved to flow_vis.png using torchvision")
    finally:
        # Clean up, even when one of the steps above failed.
        for buf in buffers:
            nvof.destroy_buffer(buf)


if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment