Introduction to Jax
An introduction to Jax and why to use it even outside neural networks and deep learning.
Jax is a relatively new library developed at Google that grew out of Autograd, which automatically computes gradients of Python functions. I've been using Jax for roughly two years and I'm still excited about it and the direction it is heading in.
While it is really nice to use Jax for neural networks, it is also a perfect fit for other types of calculations, especially if you want automatic gradients.
Jax by itself can be used to speed up Python code, similarly to Numba, but the distinguishing features are:
- The same code can run efficiently on CPU, GPU or TPU
- Automatic differentiation (a short sketch of this follows right after the list)
- Parallelization is easy
- Use the NumPy and SciPy APIs as you normally would (via jax.numpy and jax.scipy)
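To give a flavour of the automatic differentiation mentioned in the list: jax.grad takes a scalar-valued Python function and returns a function that computes its gradient. A minimal sketch (the loss function and its arguments are made-up examples, not part of what follows):

import jax
import jax.numpy as jnp

def loss(theta, x):
    # An arbitrary scalar-valued function of the parameter theta.
    return jnp.sum((theta * x - 1.0) ** 2)

grad_loss = jax.grad(loss)        # gradient with respect to the first argument, theta
grad_loss(0.5, jnp.arange(3.0))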
Jax has a limitation which I've come to see as an advantage: it traces the computation instead of your code. Initially this makes it awkward, as if statements don't work as expected, but it can be leveraged to make super-fast, specialized implementations while still maintaining a single piece of code. We'll go a little bit into how Jax works and why, but this specific topic will be covered later.
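To make the tracing point a bit more concrete already, here is a small sketch (not part of the example below) of the kind of branch that trips people up. Inside a jitted function, x is an abstract tracer rather than a concrete number, so a plain Python if on it fails; expressing the branch as data flow with jnp.where (or jax.lax.cond) keeps it inside the traced computation:

import jax
import jax.numpy as jnp

@jax.jit
def clip_negative(x):
    # A Python `if x > 0:` would raise an error during tracing,
    # because Jax only sees an abstract value for x here.
    # Written as data flow, the branch becomes part of the trace:
    return jnp.where(x > 0, x, 0.0)

clip_negative(-3.0)  # -> 0.0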
Parallelize simulations
Recently I gave a talk at Insurance Data Science about how many tasks in insurance are embarrassingly parallel and how the GPU is a great fit for them. While you can write your own CUDA code and wrap it in Python... it is a pain. No particular part is hard, but getting CMake to work, pinning the correct versions of libraries and then writing a Python wrapper is tedious and error prone. If you need absolute control over memory you have to do this, as we do at my current company, but it takes a large time investment and is often not worth it.
Instead I'll make a simple example with Jax and you can see how trivial it looks.
As an example we use an AR(1) mean-reverting process. We generate values $x_0, \dots, x_T$ where each value $x_i$ depends on its previous value plus some noise, i.e. $x_i := \rho x_{i-1} + \epsilon_i$ with $\epsilon_i \sim N(0, 1)$.
The goal is to generate $N$ paths of length $T$.
import jax
import numpy as np
import matplotlib.pyplot as plt
N = 500
T = 120
rho = 0.9
Generate normally distributed random values. Randomness works slightly differently in Jax to make it reproducible and fit for parallel computation: it is stateless and always requires a key (the generator state) to be passed in explicitly. For more background please see https://jax.readthedocs.io/en/latest/jax.random.html
key = jax.random.PRNGKey(1)
es = jax.random.normal(key, shape=(N, T))
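When you need more draws later you do not reuse the same key; you split it into fresh subkeys. A short sketch of that pattern (new_key and more_es are only illustrative and are not used below):

new_key, subkey = jax.random.split(key)            # derive fresh, independent keys
more_es = jax.random.normal(subkey, shape=(N, T))  # a second, independent batch of noise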
Then we define a function to get the next value of $x$:
def next_x(rho, x, e):
x = rho * x + e
return x
next_x(0.9, 1.0, 0.033)
Then we define a function that generates a series of $x_1,...,x_T$. First I'll do it in vanilla Python:
def path_python(rho, es):
    # x_0 stays 0; value i is built from the previous value and the previous noise term.
    values = np.zeros(es.shape)
    for i in range(1, len(es)):
        values[i] = next_x(rho, values[i-1], es[i-1])
    return values
path_python(rho, es[0]).shape
For this we use scan from Jax to generate the sequence. For good reasons we do not use Python for loops; practically, Jax limits their usage :) The underlying reason is that a compiler has to work hard to figure out what your loop means, since the loop body can touch global state. With scan we make the structure explicit, which is less work for the compiler.
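If you have not used scan before: it takes a step function from (carry, x) to (new_carry, output), an initial carry and a sequence, and returns the final carry together with the stacked outputs. A tiny cumulative-sum sketch of that interface (assuming jnp is jax.numpy):

import jax
import jax.numpy as jnp

def step(total, x):
    total = total + x
    return total, total   # new carry, per-step output

final, cumulative = jax.lax.scan(step, init=0.0, xs=jnp.arange(1.0, 5.0))
# final is 10.0, cumulative is [1., 3., 6., 10.]

With that interface in mind, the path function for our process looks like this: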
def path(rho, es):
    # A small wrapper to adapt `next_x` to the (carry, x) -> (carry, output) style that `scan` expects
    def _next_x(x_prev, e):
        x = next_x(rho, x_prev, e)
        return x, x_prev  # carry the new value forward, emit the previous one (so each path starts at 0)
# Apply scan over the errors and start with x=0
_, values = jax.lax.scan(
_next_x,
init=0.,
xs=es
)
return values
path(rho, es[0]).shape
We now use vmap to parallelize over the first axis of the $N \times T$ matrix, passing $T$-sized vectors to the path function. Because rho is a scalar we do not map over any of its axes, as it has none.
def simulate_python(rho, es):
return np.vstack([path(rho, esi) for esi in es])
simulate_python(rho, es).shape
def simulate(rho, es):
# We indicate how we map over the parameters:
# - Rho is a scalar, so do not map over it, simply pass the same value to all
# - es' first dimension is the number of paths so map over the first axis
return jax.vmap(path, in_axes=(None, 0))(rho, es)
simulate(rho, es).shape
To get a NumPy array use np.asarray, which will result in a zero-copy array:
plt.plot(np.asarray(simulate(rho, es)).T)
plt.show()
%timeit simulate(rho, es)
fast = jax.jit(simulate) # Define a compiled version
fast(rho, es) # Let it compile to exclude it from the timeit
%timeit fast(rho, es)
The plain Python version takes quite a bit longer. It is not really a fair comparison, but it serves as a reference.
%timeit simulate_python(rho, es)
For comparison, here is a vectorized NumPy implementation: unrolling the recursion $x_i = \rho x_{i-1} + \epsilon_i$ (starting from zero) gives $x_i = \sum_{j \le i} \rho^{i-j} \epsilon_j$, which is a matrix product of the noise with an upper-triangular matrix of powers of $\rho$ (indexed slightly differently than the scan version):
factor = np.triu(np.fromfunction(lambda i, j: j-i, (T, T)))
def numpy(rho, es):
    # Upper-triangular matrix with rho^(j-i) in entry (i, j): the weight of noise term i on value j.
    rho_factor = np.triu(rho**factor)
    # A single matrix product then computes all N paths at once.
    xs = es @ rho_factor
    return xs
%timeit numpy(rho, es)