
@wangz10
Created September 3, 2019 01:22
In [1]: import numpy as np
   ...: import jax.numpy as jnp
   ...: from jax import random, jit

In [2]: def slow_f(x):
   ...:     # Element-wise ops see a large benefit from fusion
   ...:     return x * x + x * 2.0
   ...:
   ...: # use XLA to compile the function
   ...: fast_f = jit(slow_f)
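
A quick sanity check (a minimal sketch, assuming `jax` and `numpy` are installed) that the jitted function computes the same values as the plain Python version; `jit` should change performance, not results:

```python
import numpy as np
from jax import jit

def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x + x * 2.0

fast_f = jit(slow_f)

x = np.ones((100, 100))
# The first call triggers XLA compilation; values should match NumPy's.
assert np.allclose(np.asarray(fast_f(x)), slow_f(x))
```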
In [3]: x = np.ones((5000, 5000))
   ...: type(x)
Out[3]: numpy.ndarray

In [4]: %timeit fast_f(x)
/Users/zichen/venv37/lib/python3.7/site-packages/jax/lib/xla_bridge.py:114: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
63.6 ms ± 577 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [5]: %timeit slow_f(x)
249 ms ± 3.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
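
One caveat not visible in the session: JAX dispatches work asynchronously, so timing a jitted function can under-measure unless the result is forced to completion with `.block_until_ready()`. A minimal sketch (array shape chosen arbitrarily here):

```python
import numpy as np
from jax import jit

def slow_f(x):
    return x * x + x * 2.0

fast_f = jit(slow_f)
x = np.ones((1000, 1000))

# Warm up first: compilation happens on the initial call ...
fast_f(x)
# ... then force the computation to finish before reading the clock.
result = fast_f(x).block_until_ready()
print(result.shape)
```

The same idea applies to the `%timeit` cells above: without `block_until_ready()` they may partly measure dispatch rather than compute time.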
In [6]: x_j = jnp.ones((5000, 5000))
   ...: type(x_j)
Out[6]: jax.lax.lax._FilledConstant

In [7]: %timeit fast_f(x_j)
29.8 ms ± 739 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [8]: %timeit slow_f(x_j)
101 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
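
To see what XLA has to fuse, you can inspect the primitives JAX traces out of `slow_f` with `jax.make_jaxpr` (a sketch; the exact printed form of the jaxpr varies by JAX version, and newer versions return a `jax.Array` from `jnp.ones` rather than the `_FilledConstant` shown above):

```python
import jax
import jax.numpy as jnp

def slow_f(x):
    return x * x + x * 2.0

# The jaxpr lists the traced primitive ops (two muls and an add)
# that XLA can fuse into a single element-wise kernel.
jaxpr = jax.make_jaxpr(slow_f)(jnp.ones((2, 2)))
print(jaxpr)
```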