# SPDX-License-Identifier: BSD-2-Clause
# Authors: Martin Reinecke, Jakob Roth, Gordian Edenhofer
# Copyright(C) 2024 Max-Planck-Society
import jax
import jax.numpy as jnp
from jax import random
import jaxbind
jax.config.update("jax_enable_x64", True)
Binding a multi-linear function to JAX
This demo showcases the use of JAXbind for binding multi-linear functions to JAX. As an example, we bind the Python function mlin, which computes (x, y) -> (xy, xy), to a JAX primitive. Note: multi-linear functions can also be regarded as general non-linear functions. For the JAXbind interface for non-linear functions, see ‘demo_nonlin.py’ and the docstring of ‘jaxbind.get_nonlinear_call’. Additional information on linear functions can be found in ‘demo_scipy_fft.py’.
def mlin(out, args, kwargs_dump):
    """Evaluate the bilinear map ``(x, y) -> (x*y, x*y)`` into output buffers.

    Parameters
    ----------
    out : sequence of writable arrays
        The first two entries receive the result; both are filled in place
        with the elementwise product ``x * y``.
    args : sequence
        ``args[0]`` is ``x`` and ``args[1]`` is ``y``.
    kwargs_dump :
        Not used by this function (presumably a serialized keyword-argument
        payload forwarded by JAXbind — confirm against jaxbind docs).
    """
    first, second = args[0], args[1]
    product = first * second
    # Both outputs of mlin are the same product; the assignment copies it
    # into each pre-allocated buffer.
    out[0][()] = product
    out[1][()] = product
Besides the function ‘mlin’ itself, JAXbind requires the linear transposes of the partial derivatives of ‘mlin’.
# linear transpose of the partial derivative of 'mlin' with respect to the fist
# variable x.
def mlin_T1(out, args, kwargs_dump):
y, da, db = args[0], args[1], args[2]
out[0][()] = y * da + y * db
# linear transpose of the partial derivative of 'mlin' with respect to the second
# variable y.
def mlin_T2(out, args, kwargs_dump):
    """Apply the transposed y-derivative of ``mlin``.

    The derivative of ``(x*y, x*y)`` with respect to ``y`` multiplies by ``x``
    in both outputs, so its transpose maps the output-space pair
    ``(da, db)`` to ``x*da + x*db``.

    Parameters
    ----------
    out : sequence of writable arrays
        ``out[0]`` is filled in place with the result.
    args : sequence
        ``(x, da, db)``: the first input of ``mlin`` followed by one value
        per output of ``mlin``.
    kwargs_dump :
        Not used by this function.
    """
    other, d_first, d_second = args[0], args[1], args[2]
    from_first = other * d_first
    from_second = other * d_second
    out[0][()] = from_first + from_second
JAX needs to evaluate the code abstractly, and thus needs to be able to determine the shape and dtype of a function's output given the shape and dtype of its input. For this, we have to provide abstract evaluation functions for mlin, mlin_T1, and mlin_T2. For each output argument, an abstract evaluation function returns a tuple containing the shape and dtype of that output. More details can be found in ‘demo_scipy_fft.py’.
def mlin_abstract(*args, **kwargs):
    """Abstract evaluation of ``mlin``: describe its two outputs.

    ``args[0]`` and ``args[1]`` are the abstract inputs ``x`` and ``y``.
    Only their ``.shape`` and ``.dtype`` attributes are read, and their
    shapes must agree. Keyword arguments are accepted and ignored.

    Returns
    -------
    tuple
        One ``(shape, dtype)`` pair per output of ``mlin``. Both pairs are
        taken from ``x``.
    """
    lhs, rhs = args[0], args[1]
    assert lhs.shape == rhs.shape
    out_spec = (lhs.shape, lhs.dtype)
    return (out_spec, out_spec)
def mlin_abstract_T1(*args, **kwargs):
    """Abstract evaluation of ``mlin_T1``.

    ``args`` mirrors the concrete call ``(y, da, db)``. Only ``args[0]`` and
    ``args[1]`` are inspected, and their shapes must agree. The single output
    takes the shape and dtype of ``args[0]``.
    """
    primal, cotangent = args[0], args[1]
    assert primal.shape == cotangent.shape
    result_spec = (primal.shape, primal.dtype)
    return (result_spec,)
def mlin_abstract_T2(*args, **kwargs):
    """Abstract evaluation of ``mlin_T2``.

    ``args`` mirrors the concrete call ``(x, da, db)``. Only ``args[0]`` and
    ``args[1]`` are inspected, and their shapes must agree. The single output
    takes the shape and dtype of ``args[0]``.
    """
    primal, cotangent = args[0], args[1]
    assert primal.shape == cotangent.shape
    result_spec = (primal.shape, primal.dtype)
    return (result_spec,)
Now we can register the JAX primitive corresponding to the Python function mlin.
# Bundle the transposed partial derivatives and their abstract evaluations.
# Both tuples hold one entry per input of mlin in argument order
# (first entry: w.r.t. x, second entry: w.r.t. y); presumably jaxbind pairs
# them with the inputs by position — see the get_linear_call docstring.
func_abstract_T = (mlin_abstract_T1, mlin_abstract_T2)
func_T = (mlin_T1, mlin_T2)

# Register mlin as a JAX primitive.
mlin_jax = jaxbind.get_linear_call(mlin, func_T, mlin_abstract, func_abstract_T)
# Generate some random input to showcase the use of the newly registered JAX
# primitive: two 10x10 float64 arrays, each drawn with a fresh subkey split
# off the running PRNG key.
key = random.PRNGKey(42)
draws = []
for _ in range(2):
    key, subkey = random.split(key)
    draws.append(random.uniform(subkey, shape=(10, 10), dtype=jnp.float64))
inp0, inp1 = draws
inp = (inp0, inp1)

# Apply the new primitive eagerly.
res = mlin_jax(*inp)

# Apply it again through a jit-compiled wrapper.
mlin_jit = jax.jit(mlin_jax)
res_jit = mlin_jit(*inp)

# Forward-mode derivative: the inputs double as the primal point and the
# tangent direction.
res_jvp = jax.jvp(mlin_jit, inp, inp)