Interoperability

Dr.Jit can exchange data with various other array programming frameworks. Currently, the following ones are officially supported: NumPy, PyTorch, JAX, and TensorFlow.

There isn’t much to it: given an input array from another framework, simply pass it to the constructor of the Dr.Jit array or tensor type you wish to construct.

import numpy as np
from drjit.llvm import Array3f, TensorXf

a = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)

# Load into a dynamic 3D array with shape (3, N)
b = Array3f(a)

# Load into a generic tensor, can represent any input shape.
c = TensorXf(a)

Dr.Jit uses a zero-copy strategy whenever possible, by simply exposing the existing data using a different type. This is possible thanks to the DLPack data exchange protocol.

The reverse direction is in principle analogous, though not all frameworks correctly detect that Dr.Jit arrays implement the DLPack specification. To avoid unnecessary copies, use the .numpy(), .torch(), .jax(), or .tf() members, which always do the right thing for each target.

b_np    = b.numpy()
c_torch = c.torch()

Note that these operations evaluate the input Dr.Jit array if it has not already been evaluated.
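
If you want to control when this evaluation happens (e.g., to batch it with other pending work), you can trigger it explicitly beforehand via drjit.eval(). A minimal sketch, reusing the array b from the example above:

import drjit as dr

# Evaluate any pending computation explicitly; the subsequent .numpy()
# call then only needs to expose the already-evaluated buffer.
dr.eval(b)
b_np = b.numpy()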

Differentiability

The operations shown above only convert data but do not track derivatives.

>>> import torch, drjit as dr
>>> from drjit.cuda.ad import Float
>>> a = torch.tensor([1.0], requires_grad=True)
>>> b = Float(a)
>>> dr.grad_enabled(b)
False

Multi-framework differentiation requires a clear interface within the AD system of each participating framework. The @drjit.wrap decorator provides such an interface.

This decorator can either expose a differentiable Dr.Jit function in another framework or the reverse, and it supports both forward and reverse-mode differentiation.

You can combine it with further decorators such as @drjit.syntax and use the full set of symbolic or evaluated operations available in normal Dr.Jit programs.

Below is an example computing the derivative of a Dr.Jit subroutine within a larger PyTorch program:

>>> from drjit.cuda import Int
>>> @dr.wrap(source='torch', target='drjit')
... @dr.syntax
... def pow2(n, x):
...    i, n = Int(0), Int(n)
...    while dr.hint(i < n, max_iterations=10):
...        x *= x
...        i += 1
...    return x
...
>>> n = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
>>> x = torch.tensor([4, 4, 4, 4], dtype=torch.float32, requires_grad=True)
>>> y = pow2(n, x)
>>> print(y)
tensor([4.0000e+00, 1.6000e+01, 2.5600e+02, 6.5536e+04],
       grad_fn=<TorchWrapperBackward>)
>>> y.sum().backward()
>>> print(x.grad)
tensor([1.0000e+00, 8.0000e+00, 2.5600e+02, 1.3107e+05])
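
The opposite direction works analogously: swapping the source and target arguments exposes a differentiable PyTorch routine to a Dr.Jit program. Below is a minimal sketch along these lines (the function torch_op and its inputs are purely illustrative):

import torch
import drjit as dr
from drjit.cuda.ad import TensorXf

@dr.wrap(source='drjit', target='torch')
def torch_op(x):
    # Ordinary PyTorch code, differentiated by PyTorch's autograd
    return torch.sin(x) * 2

x = TensorXf([0.0, 0.5, 1.0])
dr.enable_grad(x)
y = torch_op(x)
dr.backward(y)       # Reverse-mode derivative propagates through PyTorch
print(dr.grad(x))    # ≈ 2*cos(x)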

See the documentation of @drjit.wrap for further details.

Caveats

Some frameworks are extremely greedy in their use of resources, especially when working with CUDA, and they must be reined in to build software that effectively combines multiple frameworks. This is where things stand as of early 2024:

  • Dr.Jit behaves nicely and only allocates memory on demand.

  • PyTorch behaves nicely and only allocates memory on demand.

  • JAX preallocates 75% of the total GPU memory when the first JAX operation is run, which only leaves a small remainder for Dr.Jit and the operating system.

    To disable this behavior, you must set the environment variable XLA_PYTHON_CLIENT_PREALLOCATE=false before launching Python or the Jupyter notebook.

    Alternatively, you can run

    import os
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    

    before importing JAX.

  • TensorFlow preallocates “nearly all” of the GPU memory visible to the process, which will likely prevent Dr.Jit from functioning correctly.

    To disable this behavior, you must call the tf.config.experimental.set_memory_growth function before using any other TensorFlow API, which will cause it to use a less aggressive on-demand allocation policy (see the sketch below).
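
    For example, a minimal sketch that requests on-demand allocation for every visible GPU:

    import tensorflow as tf

    # Must run before any other TensorFlow API call touches the GPU
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)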

Once they allocate memory, these frameworks also keep it to themselves: for example, if your program temporarily creates a huge PyTorch tensor that uses nearly all GPU memory, then that memory is blocked from further use in Dr.Jit.

This behavior is technically justified: allocating and releasing memory is a rather slow operation, especially on CUDA, so every framework (including Dr.Jit) implements some type of internal memory cache. These caches can be manually freed if necessary. Here is how this can be accomplished (a combined snippet for the first two frameworks follows the list):

  • Dr.Jit: call drjit.flush_malloc_cache().

  • PyTorch: call torch.cuda.empty_cache().

  • JAX: there is no way to free the cache at runtime. The only option is to set XLA_PYTHON_CLIENT_ALLOCATOR=platform, either before launching Python or the Jupyter notebook, or via os.environ at the beginning of the program. This disables the JAX memory cache, which may have a negative impact on performance.

  • TensorFlow: there is no way to free the cache at runtime. The only option is to set TF_GPU_ALLOCATOR=cuda_malloc_async, either before launching Python or the Jupyter notebook, or via os.environ at the beginning of the program. This disables the TensorFlow memory cache, which may have a negative impact on performance.
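
For instance, the following sketch frees both the Dr.Jit and the PyTorch cache between two memory-hungry phases of a program (assuming both frameworks are in use on a CUDA device):

import drjit as dr
import torch

# Return cached device memory to the CUDA driver so that the other
# framework (or the operating system) can reuse it. Both calls are slow.
dr.flush_malloc_cache()
torch.cuda.empty_cache()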

As a side remark, clearing such allocation caches is an expensive operation in any of these frameworks. You likely don’t want to do so within a performance-sensitive program region (e.g., an optimization loop).