jaxbind.jaxbind module

get_linear_call(f, f_T, /, abstract, abstract_T, *, first_n_args_fixed=0, func_can_batch=False) → partial

Create a JAX primitive for the provided linear function f and its transpose f_T.

Parameters:
  • f (linear function) – The function signature must be (out, args, kwargs_dump), where out and args are tuples. The results must be written into the numpy.ndarrays contained in the out tuple; supported dtypes are float32, float64, complex64, and complex128. The args tuple contains the inputs to the function. kwargs_dump holds any keyword arguments in serialized form; they can be deserialized via jaxbind.load_kwargs(kwargs_dump).

  • f_T (transpose of f) – Must follow the same conventions as f: the signature is (out, args, kwargs_dump), where out and args are tuples; the results must be written into the numpy.ndarrays of the out tuple (float32, float64, complex64, or complex128); args contains the inputs to f_T; and any keyword arguments can be deserialized via jaxbind.load_kwargs(kwargs_dump).

  • abstract (function) – Computes the shape and dtype of the output of f from the shapes and dtypes of its inputs. Its signature must be (*args, **kwargs): args is a tuple containing an abstract tracer array, carrying shape and dtype, for each input argument of f, and any keyword arguments are passed via **kwargs. The function must return a tuple containing one (shape_out, dtype_out) tuple for each output argument of f.

  • abstract_T (function) – Computes the shape and dtype of the output of f_T from the shapes and dtypes of its inputs. Its signature must be (*args, **kwargs): args is a tuple containing an abstract tracer array, carrying shape and dtype, for each input argument of f_T, and any keyword arguments are passed via **kwargs. The function must return a tuple containing one (shape_out, dtype_out) tuple for each output argument of f_T.

  • first_n_args_fixed (int) – If the function cannot be differentiated with respect to some of its arguments, these can be passed as the first arguments to the function; first_n_args_fixed gives the number of such non-differentiable arguments. Note: the function does not need to be linear in these arguments. Default 0 (all arguments are differentiable).

  • func_can_batch (bool) – Whether the function natively supports batching. If True, the function receives one additional argument, batch_axes, which is a tuple of tuples, or None if no batching is currently performed. batch_axes has one entry per input argument, and each entry is a tuple of the integer axes along which that input is batched.
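
To make the structure of batch_axes concrete, the following sketch shows one way a callback registered with func_can_batch=True might normalize its inputs. The helpers move_batch_axes_to_front and split_inputs are hypothetical (they are not part of jaxbind), and the sketch assumes batch_axes has already been obtained inside the callback; it only illustrates how the per-input tuples of axes can be interpreted with numpy.moveaxis.

```python
import numpy as np


def move_batch_axes_to_front(x, axes):
    """Hypothetical helper: move the batch axes of one input to the front.

    `axes` is the entry of `batch_axes` belonging to `x`, i.e. a tuple of
    integer axis indices along which `x` is batched (possibly empty).
    Returns the transposed view and the number of leading batch dimensions.
    """
    axes = tuple(axes)
    x_front = np.moveaxis(x, axes, tuple(range(len(axes))))
    return x_front, len(axes)


def split_inputs(args, batch_axes):
    """Hypothetical helper: normalize every input; None means no batching."""
    if batch_axes is None:
        return [(x, 0) for x in args]
    # batch_axes holds one tuple of axes per input argument
    assert len(batch_axes) == len(args)
    return [move_batch_axes_to_front(x, ax) for x, ax in zip(args, batch_axes)]
```

Moving the batch axes to the front is only one convention; a callback could equally loop over the batch indices or hand the axes to a vectorized library routine.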

Returns:

op

Return type:

functools.partial that applies the JAX primitive corresponding to the function f.

Notes

  • f and f_T must not return anything; the result of the computation must be written into the member arrays of out.

  • The contents of args must not be modified.

  • No reference to the contents of args or out may be kept after f or f_T returns.
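
A minimal sketch of the four callables for a linear operator, assuming a hypothetical fixed real matrix A that is captured by closure rather than passed as an argument (so first_n_args_fixed stays 0). The callbacks write into the preallocated arrays in out and return nothing, as the Notes require; kwargs_dump is ignored because the sketch takes no keyword arguments. Plain NumPy arrays can stand in for the abstract tracer arrays when testing, since the abstract functions only read shape and dtype.

```python
import numpy as np

# Hypothetical fixed operator: y = A @ x with a constant real 3x2 matrix.
A = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])


def f(out, args, kwargs_dump):
    # Forward linear map. Keyword arguments, if any, could be recovered with
    # jaxbind.load_kwargs(kwargs_dump); this sketch does not use them.
    (x,) = args
    # Write into the preallocated output array instead of returning a result.
    np.matmul(A, x, out=out[0])


def f_T(out, args, kwargs_dump):
    # Transpose map: A is real, so the transpose needs no conjugation.
    (y_bar,) = args
    np.matmul(A.T, y_bar, out=out[0])


def abstract(*args, **kwargs):
    # f produces one vector of length A.shape[0] with the input's dtype.
    (x,) = args
    return ((A.shape[:1], x.dtype),)


def abstract_T(*args, **kwargs):
    # f_T produces one vector of length A.shape[1] with the input's dtype.
    (y_bar,) = args
    return ((A.shape[1:], y_bar.dtype),)
```

Per the signature above, these would then be passed as get_linear_call(f, f_T, abstract, abstract_T). A standard sanity check for such a pair is the dot-product test: np.dot(A @ x, y) must equal np.dot(x, A.T @ y) up to rounding, which confirms that f_T really is the transpose of f.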

get_nonlinear_call(f, f_derivative, /, abstract, abstract_reverse, *, first_n_args_fixed=0, func_can_batch=False) → partial

Create a JAX primitive for the provided (nonlinear) function f.

Parameters:
  • f (function) – The function signature must be (out, args, kwargs_dump), where out and args are tuples. The results must be written into the numpy.ndarrays contained in the out tuple; supported dtypes are float32, float64, complex64, and complex128. The args tuple contains the inputs to the function. kwargs_dump holds any keyword arguments in serialized form; they can be deserialized via jaxbind.load_kwargs(kwargs_dump).

  • f_derivative (tuple of functions) – Tuple of functions evaluating the Jacobian-vector product (jvp) and the vector-Jacobian product (vjp) of f: the first entry must evaluate the jvp, the second the vjp. Both must have the signature (out, args, kwargs_dump), analogous to f.

  • abstract (function) – Computes the shapes and dtypes of the outputs of f and of the jvp from the shapes and dtypes of the inputs. Its signature must be (*args, **kwargs): args is a tuple containing an abstract tracer array, carrying shape and dtype, for each input argument of f, and any keyword arguments are passed via **kwargs. The function must return a tuple containing one (shape_out, dtype_out) tuple for each output argument.

  • abstract_reverse (function) – Computes the shapes and dtypes of the outputs of the vjp. It follows the same conventions as abstract: its signature is (*args, **kwargs), where args is a tuple containing an abstract tracer array for each input argument of f and any keyword arguments are passed via **kwargs, and it must return a tuple containing one (shape_out, dtype_out) tuple for each output argument of the vjp.

  • first_n_args_fixed (int) – If the function cannot be differentiated with respect to some of its arguments, these can be passed as the first arguments to the function; first_n_args_fixed gives the number of such non-differentiable arguments. Default 0 (all arguments are differentiable).

  • func_can_batch (bool) – Whether the function natively supports batching. If True, the function receives one additional argument, batch_axes, which is a tuple of tuples, or None if no batching is currently performed. batch_axes has one entry per input argument, and each entry is a tuple of the integer axes along which that input is batched.
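
A minimal sketch of the callables for the elementwise product z = x * y. The argument layouts of the jvp and vjp callbacks are assumptions, not stated in this reference: the jvp is assumed to receive the primal inputs followed by their tangents (x, y, dx, dy), and the vjp the primal inputs followed by the output cotangent (x, y, dz). Check these layouts against the jaxbind examples before relying on them. All outputs are written in place through the out= argument of NumPy ufuncs.

```python
import numpy as np


def f(out, args, kwargs_dump):
    # z = x * y, written into the preallocated output array.
    x, y = args
    np.multiply(x, y, out=out[0])


def f_jvp(out, args, kwargs_dump):
    # ASSUMED layout: primal inputs followed by their tangents.
    x, y, dx, dy = args
    # dz = dx * y + x * dy, accumulated in place in out[0].
    np.multiply(dx, y, out=out[0])
    np.add(out[0], x * dy, out=out[0])


def f_vjp(out, args, kwargs_dump):
    # ASSUMED layout: primal inputs followed by the output cotangent.
    x, y, dz = args
    # Transpose of the jvp with respect to (dx, dy).
    np.multiply(dz, y, out=out[0])
    np.multiply(dz, x, out=out[1])


def abstract(*args, **kwargs):
    # Shape and dtype of the single output of f (and of the jvp).
    x = args[0]
    return ((x.shape, x.dtype),)


def abstract_reverse(*args, **kwargs):
    # The vjp returns one cotangent per differentiable input (x and y).
    x, y = args[0], args[1]
    return ((x.shape, x.dtype), (y.shape, y.dtype))
```

Per the signature above, these would be passed as get_nonlinear_call(f, (f_jvp, f_vjp), abstract, abstract_reverse). Note that out[0] += ... would raise a TypeError, because augmented assignment to a tuple element tries to rebind it; np.add(..., out=out[0]) or out[0][()] += ... update the array in place instead.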

Returns:

op

Return type:

functools.partial that applies the JAX primitive corresponding to the function f.

Notes

  • f and the members of f_derivative must not return anything; the result of the computation must be written into the member arrays of out.

  • The contents of args must not be modified.

  • No reference to the contents of args or out may be kept after f or the members of f_derivative return.
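
The first two rules above, and whether every output element actually gets written, can be checked mechanically before a callback is handed to jaxbind. The following check_callback is a hypothetical test helper (not part of jaxbind): it calls a callback with freshly allocated, NaN-filled output arrays and copies of its inputs, then verifies that the callback returns None, leaves args unchanged, and overwrites every output element. The third rule, not keeping references to args or out, cannot be verified this way.

```python
import numpy as np


def check_callback(func, out_specs, args, kwargs_dump=b""):
    """Hypothetical test helper (not part of jaxbind).

    `out_specs` is a tuple of (shape, dtype) pairs, e.g. as returned by an
    abstract function. `kwargs_dump` is a placeholder here; inside jaxbind
    the callback receives the library's serialized keyword arguments.
    """
    # Preallocate outputs filled with NaN so unwritten entries are detectable.
    out = tuple(np.full(shape, np.nan, dtype=dtype) for shape, dtype in out_specs)
    # Keep copies of the inputs to detect in-place modification.
    originals = tuple(np.array(a, copy=True) for a in args)

    result = func(out, args, kwargs_dump)

    if result is not None:
        raise AssertionError("callback must not return anything")
    for a, orig in zip(args, originals):
        if not np.array_equal(a, orig, equal_nan=True):
            raise AssertionError("callback modified an input in args")
    for o in out:
        if np.isnan(o).any():
            raise AssertionError("callback did not fill every output element")
    return out
```

A common mistake this catches is rebinding the local name, as in out = (result,), which leaves the caller's arrays untouched; writing through out[0][()] = ... or a ufunc's out= argument fills them in place.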