```python
# SPDX-License-Identifier: BSD-2-Clause
# Authors: Martin Reinecke, Jakob Roth, Gordian Edenhofer

# Copyright(C) 2024 Max-Planck-Society
```

```python
import jax
import jax.numpy as jnp
from jax import random

import jaxbind

jax.config.update("jax_enable_x64", True)
```

# Binding a multi-linear function to JAX

This demo shows how to use JAXbind to bind multi-linear functions to JAX. As an example, we bind the Python function 'mlin', which computes (x, y) -> (x*y, x*y), to a JAX primitive.

Note: multi-linear functions can also be regarded as general non-linear functions. For the JAXbind interface for non-linear functions, see 'demo_nonlin.py' and the docstring of 'jaxbind.get_nonlinear_call'. Additional information on linear functions can be found in 'demo_scipy_fft.py'.

```python
def mlin(out, args, kwargs_dump):
    # extract the input from the input tuple
    x, y = args[0], args[1]

    # do the computation and write the result into the out tuple
    out[0][()] = x * y
    out[1][()] = x * y
```

Besides the function 'mlin' itself, JAXbind requires the linear transpose of each partial derivative of 'mlin'. Because 'mlin' is linear in each argument separately, its partial derivative with respect to x at fixed y is the map dx -> (dx*y, dx*y), and the transpose of that map is (da, db) -> y*da + y*db. The same holds for y with the roles of x and y exchanged.

```python
# linear transpose of the partial derivative of 'mlin' with respect to the
# first variable x.
def mlin_T1(out, args, kwargs_dump):
    y, da, db = args[0], args[1], args[2]
    out[0][()] = y * da + y * db


# linear transpose of the partial derivative of 'mlin' with respect to the
# second variable y.
def mlin_T2(out, args, kwargs_dump):
    x, da, db = args[0], args[1], args[2]
    out[0][()] = x * da + x * db
```

JAX evaluates code abstractly, so it must be able to determine the shape and dtype of a function's output from the shape and dtype of its input. We therefore have to provide abstract evaluation functions for 'mlin', 'mlin_T1', and 'mlin_T2'. Each abstract evaluation function returns, for every output, a tuple containing the shape and dtype of that output.
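
To illustrate what "shape and dtype in, shape and dtype out" means, the sketch below uses JAX's own `jax.eval_shape` on a pure-JAX stand-in for 'mlin'. The name `mlin_reference` is hypothetical and not part of JAXbind, and the snippet is not how JAXbind performs abstract evaluation internally. It only shows the kind of information an abstract evaluation function has to supply. It uses float32 so that it runs without enabling 64-bit mode.

```python
import jax
import jax.numpy as jnp


# Hypothetical pure-JAX stand-in for 'mlin' (assumption, for illustration only).
def mlin_reference(x, y):
    return x * y, x * y


# Abstract inputs: only shape and dtype are specified, no actual data.
x_spec = jax.ShapeDtypeStruct((10, 10), jnp.float32)
y_spec = jax.ShapeDtypeStruct((10, 10), jnp.float32)

# jax.eval_shape traces the function without performing the computation and
# returns one ShapeDtypeStruct per output.
out_specs = jax.eval_shape(mlin_reference, x_spec, y_spec)

# Convert to (shape, dtype) pairs, one pair per output.
abstract_result = tuple((o.shape, o.dtype) for o in out_specs)
print(abstract_result)
```

For a bound Python function, JAX cannot trace the computation in this way, which is why the (shape, dtype) pairs have to be provided by hand.
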
More details on abstract evaluation functions are in 'demo_scipy_fft.py'.

```python
def mlin_abstract(*args, **kwargs):
    # both inputs must have the same shape; both outputs inherit it
    assert args[0].shape == args[1].shape
    return (
        (args[0].shape, args[0].dtype),
        (args[0].shape, args[0].dtype),
    )


def mlin_abstract_T1(*args, **kwargs):
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)


def mlin_abstract_T2(*args, **kwargs):
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)
```

Now we can register the JAX primitive corresponding to the Python function 'mlin'.

```python
func_T = (mlin_T1, mlin_T2)
func_abstract_T = (mlin_abstract_T1, mlin_abstract_T2)
mlin_jax = jaxbind.get_linear_call(
    mlin,
    func_T,
    mlin_abstract,
    func_abstract_T,
)

# generate some random input to showcase the use of the newly registered
# JAX primitive
key = random.PRNGKey(42)
key, subkey = random.split(key)
inp0 = jax.random.uniform(subkey, shape=(10, 10), dtype=jnp.float64)
key, subkey = random.split(key)
inp1 = jax.random.uniform(subkey, shape=(10, 10), dtype=jnp.float64)
inp = (inp0, inp1)

# apply the new primitive
res = mlin_jax(*inp)

# jit compile the new primitive
mlin_jit = jax.jit(mlin_jax)
res_jit = mlin_jit(*inp)

# compute the jvp
res_jvp = jax.jvp(mlin_jit, inp, inp)
```
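
Before registering transposes like 'mlin_T1' and 'mlin_T2', it can help to check them against the forward function with an adjoint (dot-product) test. The sketch below is a minimal NumPy-only version of such a check. It repeats the three functions from this demo so that it is self-contained, and it calls them directly with preallocated output arrays and an empty dict in place of 'kwargs_dump'. That calling pattern is an assumption made for illustration and is not how JAXbind invokes the functions.

```python
import numpy as np


def mlin(out, args, kwargs_dump):
    x, y = args[0], args[1]
    out[0][()] = x * y
    out[1][()] = x * y


def mlin_T1(out, args, kwargs_dump):
    y, da, db = args[0], args[1], args[2]
    out[0][()] = y * da + y * db


def mlin_T2(out, args, kwargs_dump):
    x, da, db = args[0], args[1], args[2]
    out[0][()] = x * da + x * db


rng = np.random.default_rng(0)
x, y, da, db = (rng.uniform(size=(10, 10)) for _ in range(4))

# Forward application: (x, y) -> (x*y, x*y), written into preallocated outputs.
fwd = (np.empty_like(x), np.empty_like(x))
mlin(fwd, (x, y), {})  # {} is a placeholder for the unused kwargs_dump

# Transposes applied to the cotangents (da, db).
t1 = (np.empty_like(x),)
mlin_T1(t1, (y, da, db), {})
t2 = (np.empty_like(y),)
mlin_T2(t2, (x, da, db), {})

# Adjoint test: <A v, w> must equal <v, A^T w> for a linear map A.
lhs = np.vdot(fwd[0], da) + np.vdot(fwd[1], db)
rhs_x = np.vdot(x, t1[0])  # linear in x at fixed y
rhs_y = np.vdot(y, t2[0])  # linear in y at fixed x

print(np.isclose(lhs, rhs_x), np.isclose(lhs, rhs_y))
```

Both comparisons should agree up to floating-point rounding. A mismatch would point to an error in one of the transpose functions.
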