Skip to content

Instantly share code, notes, and snippets.

@maneesh29s
Created December 21, 2024 11:30
Show Gist options
  • Select an option

  • Save maneesh29s/a59a9f32db25d4d3c764e80c590026b8 to your computer and use it in GitHub Desktop.

Select an option

Save maneesh29s/a59a9f32db25d4d3c764e80c590026b8 to your computer and use it in GitHub Desktop.
Gist contains notebooks, which demonstrate how we can use JAX to use vectorized functions with dask arrays. Also contains a Numba example for CPU.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Point this process at the CUDA 12 toolkit before JAX is imported.\n",
"# NOTE(review): mutating LD_LIBRARY_PATH after interpreter start does not\n",
"# affect libraries the dynamic linker has already resolved -- confirm this\n",
"# runs early enough for JAX's CUDA plugin to pick it up.\n",
"import os\n",
"\n",
"os.environ[\"LD_LIBRARY_PATH\"] = \"/usr/local/cuda-12/lib64\"\n",
"os.environ[\"PATH\"] = \"/usr/local/cuda-12/bin:\" + os.environ[\"PATH\"]\n",
"# os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"]=\"false\"\n",
"# os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"0.2\"\n",
"\n",
"# Show full, unfiltered JAX tracebacks while debugging.\n",
"from jax import config\n",
"\n",
"config.update(\"jax_traceback_filtering\", \"off\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Third-party stack: array backends (numpy / cupy / jax), the dask +\n",
"# xarray distributed tooling, and numba for the CPU example.\n",
"import cupy\n",
"import dask\n",
"import dask.array\n",
"import jax\n",
"import numpy as np\n",
"import xarray as xr\n",
"from numba import guvectorize, int64"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating a local cluster"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# LocalCluster hosts workers in-process; WorkerPlugin provides a\n",
"# per-worker setup hook (used by EnvPlugin below).\n",
"from dask.distributed import LocalCluster, WorkerPlugin"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class EnvPlugin(WorkerPlugin):\n",
"    \"\"\"Dask worker plugin replicating the driver's CUDA/JAX environment.\n",
"\n",
"    Workers run in separate processes, so the environment tweaks from the\n",
"    first cell must be repeated inside every worker before JAX is used\n",
"    there. (The previous no-op __init__ was redundant and has been\n",
"    removed; object construction is unchanged.)\n",
"    \"\"\"\n",
"\n",
"    def setup(self, worker):\n",
"        \"\"\"Called once per worker: set CUDA paths, disable JAX traceback filtering.\"\"\"\n",
"        import os\n",
"        os.environ[\"LD_LIBRARY_PATH\"] = \"/usr/local/cuda-12/lib64\"\n",
"        os.environ[\"PATH\"] = \"/usr/local/cuda-12/bin:\" + os.environ[\"PATH\"]\n",
"        # os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"]=\"false\"\n",
"        # os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"0.2\"\n",
"\n",
"        from jax import config\n",
"        config.update(\"jax_traceback_filtering\", \"off\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One worker, one thread: the single GPU is never oversubscribed, and the\n",
"# \"GPU\" resource tag mirrors that constraint on the scheduler side.\n",
"cluster = LocalCluster(\n",
"    n_workers=1,\n",
"    threads_per_worker=1,\n",
"    scheduler_port=8786,\n",
"    dashboard_address=\":8787\",\n",
"    resources={\"GPU\": 1},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Attach a client, install the per-worker env plugin, and mirror worker\n",
"# logs into this notebook's output.\n",
"client = cluster.get_client()\n",
"client.register_plugin(EnvPlugin())\n",
"client.forward_logging()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating temp data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# shape=(4096,1024,1024)\n",
"# chunksizes=(256,1024,1024)\n",
"# # # vis = xr.DataArray(np.random.random(shape).astype(dtype=np.float32), dims=[\"z\",\"y\",\"x\"])\n",
"# # # flag = xr.DataArray(np.ones(shape, dtype=bool), dims=[\"z\",\"y\",\"x\"])\n",
"# # # weights = xr.DataArray(np.random.random(shape).astype(dtype=np.float32), dims=[\"z\",\"y\",\"x\"])\n",
"# # # kernel = xr.DataArray(np.random.random((16,16)).astype(dtype=np.float32), dims=[\"ky\",\"kx\"])\n",
"\n",
"# vis = xr.DataArray(dask.array.random.random(shape, chunks=chunksizes,).astype(dtype=np.float32), dims=[\"z\",\"y\",\"x\"])\n",
"# flag = xr.DataArray(dask.array.ones(shape, chunks=chunksizes, dtype=bool), dims=[\"z\",\"y\",\"x\"])\n",
"# weights = xr.DataArray(dask.array.random.random(shape, chunks=chunksizes,).astype(dtype=np.float32), dims=[\"z\",\"y\",\"x\"])\n",
"\n",
"# xds = xr.Dataset(dict(vis=vis, flag=flag, weights=weights))\n",
"# xds.to_zarr(\"xds_big.zarr\", mode='w')\n",
"# # # xds.to_zarr(\"xds_small.zarr\", mode='w')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read zarr dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Lazily open the pre-chunked dataset; no bytes are read until compute time.\n",
"xds = xr.open_zarr(\"xds_big_rechunked.zarr\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rich display of the visibility array (shape, dtype, chunk layout).\n",
"xds.vis"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Jax operation on GPU"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dask.distributed import Semaphore\n",
"\n",
"# One lease == at most one dask task touching the GPU at any moment.\n",
"gpusem = Semaphore(max_leases=1, name=\"gpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"@jax.jit\n",
"@partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)\n",
"def jax_vmap_gpu_operation(vis, weights, flag):\n",
"    \"\"\"Per-slice kernel: (vis @ vis) * weights, zeroed where flag is False.\n",
"\n",
"    jax.vmap maps over the leading (batch) axis of all three inputs, and\n",
"    jax.jit compiles the whole batched computation into one XLA program.\n",
"    \"\"\"\n",
"    weighted = jnp.einsum(\"ij,jk->ik\", vis, vis) * weights\n",
"    return jnp.where(flag, weighted, 0.0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def jax_gpu_operation(vis_batch, weights_batch, flag_batch, semaphore):\n",
"    \"\"\"Run one dask chunk through the jitted kernel while holding the GPU lease.\n",
"\n",
"    The semaphore serializes GPU access across concurrent dask tasks.\n",
"    NOTE(review): device_put(..., donate=True) requires a recent JAX\n",
"    release -- confirm against the pinned version.\n",
"    \"\"\"\n",
"    with semaphore:\n",
"        vis = jax.device_put(vis_batch, donate=True)\n",
"        weights = jax.device_put(weights_batch, donate=True)\n",
"        flag = jax.device_put(flag_batch, donate=True)\n",
"\n",
"        result = jax_vmap_gpu_operation(vis, weights, flag)\n",
"\n",
"        # Copy back to host memory (still under the lease) so dask can\n",
"        # serialize the chunk result.\n",
"        return jax.device_get(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Map the GPU kernel over every chunk; each task receives a full\n",
"# (batch, y, x) block because y/x are declared as core dims.\n",
"result_jax_gpu_vec = xr.apply_ufunc(\n",
"    jax_gpu_operation,\n",
"    xds.vis,\n",
"    xds.weights,\n",
"    xds.flag,\n",
"    input_core_dims=[[\"y\", \"x\"], [\"y\", \"x\"], [\"y\", \"x\"]],\n",
"    output_core_dims=[[\"y\", \"x\"]],\n",
"    dask=\"parallelized\",\n",
"    vectorize=False,  # the jax.vmap wrapper already handles batching\n",
"    output_dtypes=[np.float32],\n",
"    kwargs=dict(semaphore=gpusem),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Single end-to-end wall-clock measurement of the full graph (-n 1 -r 1).\n",
"%timeit -n 1 -r 1 client.compute(result_jax_gpu_vec, sync=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Cupy operations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# From here on, newly created dask arrays are backed by CuPy, not NumPy.\n",
"dask.config.set({\"array.backend\": \"cupy\"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reload the (smaller) dataset for the CuPy experiments.\n",
"xds = xr.open_zarr(\"xds.zarr\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Convert every variable's chunks from NumPy to CuPy device arrays.\n",
"for name in xds.data_vars:\n",
"    xds[name].data = xds[name].data.map_blocks(cupy.asarray)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# First 128 z-slices as a raw dask array.\n",
"# NOTE(review): despite the name, vis2d is 3-D (z, y, x).\n",
"vis2d = xds.vis[0:128, :, :].data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Batched matmul over the leading axis, dispatched to CuPy through dask.\n",
"np.einsum(\"mij,mjk->mik\", vis2d, vis2d).compute()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def cupy_operation(vis, weights, flag):\n",
"    \"\"\"Same per-slice kernel as the JAX version, expressed with CuPy.\"\"\"\n",
"    weighted = cupy.einsum(\"ij,jk->ik\", vis, vis) * weights\n",
"    return cupy.where(flag, weighted, 0.0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Same mapping as the JAX variant, but vectorize=True: xarray loops the\n",
"# scalar (per-slice) kernel over the batch axis itself.\n",
"result_cupy_gpu_vec = xr.apply_ufunc(\n",
"    cupy_operation,\n",
"    xds.vis,\n",
"    xds.weights,\n",
"    xds.flag,\n",
"    input_core_dims=[[\"y\", \"x\"], [\"y\", \"x\"], [\"y\", \"x\"]],\n",
"    output_core_dims=[[\"y\", \"x\"]],\n",
"    dask=\"parallelized\",\n",
"    vectorize=True,\n",
"    output_dtypes=[cupy.float32],\n",
")"
]
},
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client.compute(result_cupy_gpu_vec, sync=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build (but do not compute) the batched matmul task graph.\n",
"# Fixed: the previous version referenced an undefined name `vis` and used\n",
"# 2-D subscripts; use the 3-D `vis2d` array with a batch axis, matching\n",
"# the earlier einsum cell.\n",
"task = np.einsum(\"mij,mjk->mik\", vis2d, vis2d)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Numba operation on GPU"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TODO"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "spec_line",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment