Miscellaneous

PyTrees

The word PyTree (borrowed from JAX) refers to a tree-like data structure made of Python container types including

  • list,

  • tuple,

  • dict,

  • data classes,

  • custom Python classes or C++ bindings with a DRJIT_STRUCT annotation.

Various Dr.Jit operations will automatically traverse such PyTrees to process any Dr.Jit arrays or tensors found within. For example, it might be convenient to store differentiable parameters of an optimization within a dictionary and then batch-enable gradients:

import drjit as dr
from drjit.cuda.ad import Array3f, Float

params = {
    'foo': Array3f(...),
    'bar': Float(...)
}

dr.enable_grad(params)

PyTrees can similarly be used as state variables in symbolic loops and conditionals, as arguments and return values of symbolic calls, as arguments of scatter/gather operations, and many others (the reference explicitly lists the word PyTree in all supported operations).

Limitations

You may not use Dr.Jit types as keys of a dictionary occurring within a PyTree. Furthermore, PyTrees may not contain cycles. For example, the following data structure will cause PyTree-compatible operations to fail with a RecursionError.

x = []
x.append(x)

Finally, Dr.Jit automatically traverses tuples, lists, and dictionaries, but it does not traverse subclasses of basic containers and other generalized sequences or mappings. This is intentional.
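The traversal rules above can be pictured with a small pure-Python sketch. It is not Dr.Jit's implementation: collect_leaves and the is_leaf predicate are hypothetical names, and plain float values stand in for Dr.Jit arrays. The sketch descends only into exact list, tuple, and dict instances, so a list subclass is skipped and a self-referencing list exhausts Python's recursion limit.

```python
def collect_leaves(tree, is_leaf):
    """Return every object in `tree` accepted by `is_leaf` (hypothetical helper)."""
    found = []

    def visit(node):
        if is_leaf(node):
            found.append(node)
        elif type(node) in (list, tuple):   # exact types only, no subclasses
            for item in node:
                visit(item)
        elif type(node) is dict:            # only the values are visited
            for item in node.values():
                visit(item)
        # Everything else (including container subclasses) is ignored.

    visit(tree)
    return found


def is_float(value):
    # Stand-in for "is this a Dr.Jit array or tensor?"
    return isinstance(value, float)


class MyList(list):
    """A subclass of a basic container; the sketch does not descend into it."""


params = {'foo': [1.0, (2.0, 'label')], 'bar': 3.0}
leaves = collect_leaves(params, is_float)          # [1.0, 2.0, 3.0]
skipped = collect_leaves(MyList([4.0]), is_float)  # []

# A container that contains itself can never be fully traversed.
cyclic = []
cyclic.append(cyclic)
try:
    collect_leaves(cyclic, is_float)
    cycle_error = None
except RecursionError as exc:
    cycle_error = exc
```

The exact-type checks are what make the subclass case fall through silently, which mirrors the intentional behavior described above.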

Custom types

There are two ways of extending PyTrees with custom data types. The first is to register a Python data class.

import drjit as dr
from drjit.cuda.ad import Float
from dataclasses import dataclass

@dataclass
class MyPoint2f:
    x: Float
    y: Float

# Create a vector representing 100 2D points. Dr.Jit will
# automatically populate the 'x' and 'y' members
value = dr.zeros(MyPoint2f, 100)

The second option is to annotate an existing non-dataclass type (e.g. a standard Python class or a C++ binding) with a static DRJIT_STRUCT member. This is simply a dictionary describing the names and types of all fields. Such custom types must be default-constructible (i.e., the constructor should work if called without arguments).

import drjit as dr
from drjit.cuda.ad import Float

class MyPoint2f:
    DRJIT_STRUCT = { 'x' : Float, 'y': Float }

# Create a vector representing 100 2D points. Dr.Jit will
# automatically populate the 'x' and 'y' members
value = dr.zeros(MyPoint2f, 100)

Fields don’t exclusively have to be containers or Dr.Jit types. For example, we could have added an extra datetime entry to record when a set of points was captured. Such fields will be ignored by traversal operations.
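To illustrate how such field descriptions could drive a traversal, here is a pure-Python sketch (again, not Dr.Jit's implementation). The helper struct_leaves, the classes Capture and CaptureDC, and the use of float as a stand-in for a Dr.Jit type are all assumptions made for illustration. The datetime field is part of the description but is skipped because it is neither a container nor the stand-in leaf type.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from datetime import datetime


def struct_leaves(obj, leaf_type):
    """Map field name -> value for fields holding `leaf_type` (hypothetical helper)."""
    if is_dataclass(obj):
        names = [f.name for f in fields(obj)]
    else:
        names = list(getattr(type(obj), 'DRJIT_STRUCT', {}))
    return {n: getattr(obj, n) for n in names
            if isinstance(getattr(obj, n), leaf_type)}


class Capture:
    # Field description in the style of DRJIT_STRUCT; float replaces Float here.
    DRJIT_STRUCT = {'x': float, 'y': float, 'when': datetime}

    def __init__(self):                 # works without arguments
        self.x = 0.0
        self.y = 0.0
        self.when = datetime(2024, 1, 1)


@dataclass
class CaptureDC:
    x: float = 0.0
    y: float = 0.0
    when: datetime = field(default_factory=lambda: datetime(2024, 1, 1))


from_struct = struct_leaves(Capture(), float)       # {'x': 0.0, 'y': 0.0}
from_dataclass = struct_leaves(CaptureDC(), float)  # {'x': 0.0, 'y': 0.0}
```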

Local memory

Local memory is a relatively advanced feature of Dr.Jit. You may need it if you encounter the following circumstances:

  1. A symbolic loop in your program must both read and write the same memory buffer using computed indices.

  2. The buffer is entirely local to a thread of the computation (i.e., local to an element of an array program).

  3. The buffer is small (e.g., a few hundred to a few thousand entries).

The drjit.alloc_local() function allocates a local memory buffer of size n and type T:

buf: dr.Local[T] = dr.alloc_local(T, n)

You may further specify an optional value=... argument to default-initialize the buffer entries. The returned instance of type drjit.Local represents the allocation. Its elements can be accessed using the regular [] indexing syntax.

Example uses of local memory might include a local stack to traverse a tree data structure, insertion sort to maintain a small sorted list, or an LU factorization of a small (e.g. 32×32) matrix with column pivoting. In contrast to what the name might suggest, local memory is neither particularly fast nor local to the processor. In fact, it is based on standard global device memory. Local memory is also not to be confused with shared memory present on CUDA architectures.

The purpose of local memory is that it exposes global memory in a different way to provide a local scratch space within a larger parallel computation. Normally, one would use drjit.gather() and drjit.scatter() to dynamically read and write memory. However, they cannot be used in this situation because read-after-write (RAW) dependencies would trigger variable evaluations that aren’t permitted in a symbolic context. Local memory legalizes such programs because RAW dependencies among multiple threads are simply not possible.

Local memory may also appear similar to dynamic array types like drjit.cuda.ArrayXf, which group multiple variables/registers into an array for convenience. The key difference is that ArrayXf does not support element access with computed indices, while local memory buffers do.

Allocating, reading, and writing local memory are all symbolic operations that don’t consume any memory by themselves. However, when local memory appears in a kernel being launched, the system must conceptually allocate extra memory for the duration of the kernel (the details of this are backend-dependent). While this does not contribute to the long-term memory requirements of a program, the short-term memory requirements can be significant because local memory is separately allocated for each thread. On a CUDA device, there could be as many as 1 million simultaneously resident threads across thread blocks. A seemingly small local 1024-element single-precision array then expands into a whopping 4 GiB of memory.
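The 4 GiB figure follows from simple arithmetic. The sketch below assumes that "1 million" threads means 2^20 and that each single-precision entry takes 4 bytes; the function name is made up for illustration.

```python
def local_memory_bytes(threads: int, entries: int, bytes_per_entry: int) -> int:
    """Scratch space needed when every resident thread owns its own buffer."""
    return threads * entries * bytes_per_entry


resident_threads = 2**20      # assumed reading of "1 million" threads
entries = 1024                # elements per local buffer
bytes_per_entry = 4           # single precision

total = local_memory_bytes(resident_threads, entries, bytes_per_entry)
total_gib = total / 2**30
print(f"{total_gib:.1f} GiB")  # prints "4.0 GiB"
```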

See the snippet below for an example that calls a function f() n times to compute a histogram (stored in local memory) of its outputs to then find the largest histogram bucket.

import drjit as dr
from drjit.auto import UInt32

# A function returning results in the range 0..9
def f(i: UInt32) -> UInt32: ...

@dr.syntax
def g(n: UInt32):
    # Create zero-initialized buffer of type 'drjit.Local[UInt32]'
    hist = dr.alloc_local(UInt32, 10, value=dr.zeros(UInt32))

    # Fill histogram
    i = UInt32(0)
    while i < n:        # <-- symbolic loop
        hist[f(i)] += 1 # <-- read+write with computed index
        i += 1

    # Get the largest histogram entry
    i, maxval = UInt32(0), UInt32(0)
    while i < 10:
        maxval = dr.maximum(maxval, hist[i])
        i += 1
    return maxval

When this function is evaluated with an array of inputs (e.g. n=UInt32(n1, n2, ...)), it will create several histograms with different numbers of function evaluations in parallel. Each evaluation conceptually gets its own hist variable in this case.
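As a mental model, the per-thread semantics can be emulated in plain Python by running the loop once per input element, each with its own list. This reference sketch does not use Dr.Jit; g_reference and the particular f (which maps i to 7·i mod 10) are illustrative assumptions.

```python
def f(i: int) -> int:
    """Hypothetical stand-in for f(): returns a value in the range 0..9."""
    return (7 * i) % 10


def g_reference(ns):
    """Evaluate the histogram computation independently for each entry of `ns`."""
    results = []
    for n in ns:                 # one "thread" per input element
        hist = [0] * 10          # private scratch space of this thread
        for i in range(n):
            hist[f(i)] += 1      # read + write with a computed index
        results.append(max(hist))
    return results


largest = g_reference([0, 5, 25])  # [0, 1, 3]
```

Each input element gets a fresh hist list, which corresponds to the separate per-thread allocation that local memory provides.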

Dr.Jit can also allocate local memory for PyTrees: instead of a plain array type such as UInt32, the dtype passed to drjit.alloc_local() could be a complex number, 3x3 matrix, tuple, or dataclass type. Indexing into the drjit.Local instance then fetches or stores one instance of the PyTree.

Note

Local memory reads/writes are not tracked by Dr.Jit’s automatic differentiation layer. However, you may use local memory in implementations of custom differentiable operations based on the drjit.CustomOp interface.

The implication of the above two points is that when you want to differentiate a local memory-based computation, you have to implement the forward/backward derivative yourself. This is intentional because the default AD-provided derivative would be extremely inefficient (it would increase the size of the scratch space many-fold).

Accuracy of transcendental operations

Single precision

Note

The trigonometric functions sin, cos, and tan are optimized for low error on the domain \(|x| < 8192\) and don’t perform as well beyond this range.

| Function | Tested domain | Abs. error (mean) | Abs. error (max) | Rel. error (mean) | Rel. error (max) |
|---|---|---|---|---|---|
| \(\text{sin}()\) | \(-8192 < x < 8192\) | \(1.2 \cdot 10^{-8}\) | \(1.2 \cdot 10^{-7}\) | \(1.9 \cdot 10^{-8}\,(0.25\,\text{ulp})\) | \(1.8 \cdot 10^{-6}\,(19\,\text{ulp})\) |
| \(\text{cos}()\) | \(-8192 < x < 8192\) | \(1.2 \cdot 10^{-8}\) | \(1.2 \cdot 10^{-7}\) | \(1.9 \cdot 10^{-8}\,(0.25\,\text{ulp})\) | \(3.1 \cdot 10^{-6}\,(47\,\text{ulp})\) |
| \(\text{tan}()\) | \(-8192 < x < 8192\) | \(4.7 \cdot 10^{-6}\) | \(8.1 \cdot 10^{-1}\) | \(3.4 \cdot 10^{-8}\,(0.42\,\text{ulp})\) | \(3.1 \cdot 10^{-6}\,(30\,\text{ulp})\) |
| \(\text{asin}()\) | \(-1 < x < 1\) | \(2.3 \cdot 10^{-8}\) | \(1.2 \cdot 10^{-7}\) | \(2.9 \cdot 10^{-8}\,(0.33\,\text{ulp})\) | \(2.3 \cdot 10^{-7}\,(2\,\text{ulp})\) |
| \(\text{acos}()\) | \(-1 < x < 1\) | \(4.7 \cdot 10^{-8}\) | \(2.4 \cdot 10^{-7}\) | \(2.9 \cdot 10^{-8}\,(0.33\,\text{ulp})\) | \(1.2 \cdot 10^{-7}\,(1\,\text{ulp})\) |
| \(\text{atan}()\) | \(-1 < x < 1\) | \(1.8 \cdot 10^{-7}\) | \(6 \cdot 10^{-7}\) | \(4.2 \cdot 10^{-7}\,(4.9\,\text{ulp})\) | \(8.2 \cdot 10^{-7}\,(12\,\text{ulp})\) |
| \(\text{sinh}()\) | \(-10 < x < 10\) | \(2.6 \cdot 10^{-5}\) | \(2 \cdot 10^{-3}\) | \(2.8 \cdot 10^{-8}\,(0.34\,\text{ulp})\) | \(2.7 \cdot 10^{-7}\,(3\,\text{ulp})\) |
| \(\text{cosh}()\) | \(-10 < x < 10\) | \(2.9 \cdot 10^{-5}\) | \(2 \cdot 10^{-3}\) | \(2.9 \cdot 10^{-8}\,(0.35\,\text{ulp})\) | \(2.5 \cdot 10^{-7}\,(4\,\text{ulp})\) |
| \(\text{tanh}()\) | \(-10 < x < 10\) | \(4.8 \cdot 10^{-8}\) | \(4.2 \cdot 10^{-7}\) | \(5 \cdot 10^{-8}\,(0.76\,\text{ulp})\) | \(5 \cdot 10^{-7}\,(7\,\text{ulp})\) |
| \(\text{asinh}()\) | \(-30 < x < 30\) | \(2.8 \cdot 10^{-8}\) | \(4.8 \cdot 10^{-7}\) | \(1 \cdot 10^{-8}\,(0.13\,\text{ulp})\) | \(1.7 \cdot 10^{-7}\,(2\,\text{ulp})\) |
| \(\text{acosh}()\) | \(1 < x < 10\) | \(2.9 \cdot 10^{-8}\) | \(2.4 \cdot 10^{-7}\) | \(1.5 \cdot 10^{-8}\,(0.18\,\text{ulp})\) | \(2.4 \cdot 10^{-7}\,(3\,\text{ulp})\) |
| \(\text{atanh}()\) | \(-1 < x < 1\) | \(9.9 \cdot 10^{-9}\) | \(2.4 \cdot 10^{-7}\) | \(1.5 \cdot 10^{-8}\,(0.18\,\text{ulp})\) | \(1.2 \cdot 10^{-7}\,(1\,\text{ulp})\) |
| \(\text{exp}()\) | \(-20 < x < 30\) | \(0.72 \cdot 10^{4}\) | \(0.1 \cdot 10^{7}\) | \(2.4 \cdot 10^{-8}\,(0.27\,\text{ulp})\) | \(1.2 \cdot 10^{-7}\,(1\,\text{ulp})\) |
| \(\text{log}()\) | \(10^{-20} < x < 2\cdot 10^{30}\) | \(9.6 \cdot 10^{-9}\) | \(7.6 \cdot 10^{-6}\) | \(1.4 \cdot 10^{-10}\,(0.0013\,\text{ulp})\) | \(1.2 \cdot 10^{-7}\,(1\,\text{ulp})\) |
| \(\text{erf}()\) | \(-1 < x < 1\) | \(3.2 \cdot 10^{-8}\) | \(1.8 \cdot 10^{-7}\) | \(6.4 \cdot 10^{-8}\,(0.78\,\text{ulp})\) | \(3.3 \cdot 10^{-7}\,(4\,\text{ulp})\) |
| \(\text{erfc}()\) | \(-1 < x < 1\) | \(3.4 \cdot 10^{-8}\) | \(2.4 \cdot 10^{-7}\) | \(6.4 \cdot 10^{-8}\,(0.79\,\text{ulp})\) | \(1 \cdot 10^{-6}\,(11\,\text{ulp})\) |
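The accuracy tables report relative errors both as plain numbers and in units in the last place (ulp). The source does not state how these figures were measured; the sketch below only shows one common convention for the conversion, using the Python standard library (math.ulp requires Python 3.9 or newer). The helper names are hypothetical, and ulp_f32 handles normal (non-zero, non-subnormal) values only.

```python
import math
import struct


def to_f32(x: float) -> float:
    """Round a Python float (double) to the nearest single-precision value."""
    return struct.unpack('f', struct.pack('f', x))[0]


def ulp_f32(x: float) -> float:
    """Spacing between adjacent single-precision values near x (normal range)."""
    _, exponent = math.frexp(abs(x))   # abs(x) = m * 2**exponent, 0.5 <= m < 1
    return 2.0 ** (exponent - 24)


def error_in_ulp_f32(approx: float, reference: float) -> float:
    """Absolute error of `approx`, in single-precision ulps of `reference`."""
    return abs(approx - reference) / ulp_f32(reference)


one_ulp_single = ulp_f32(1.0)      # 2**-23, about 1.2e-7
one_ulp_double = math.ulp(1.0)     # 2**-52, about 2.2e-16

reference = math.sin(0.5)          # double-precision reference value
rounded = to_f32(reference)        # nearest single-precision value
err = error_in_ulp_f32(rounded, reference)   # at most 0.5 by construction
```

The two spacings at 1.0 match the magnitudes that the tables list next to "1 ulp" for single and double precision.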

Double precision

| Function | Tested domain | Abs. error (mean) | Abs. error (max) | Rel. error (mean) | Rel. error (max) |
|---|---|---|---|---|---|
| \(\text{sin}()\) | \(-8192 < x < 8192\) | \(2.2 \cdot 10^{-17}\) | \(2.2 \cdot 10^{-16}\) | \(3.6 \cdot 10^{-17}\,(0.25\,\text{ulp})\) | \(3.1 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{cos}()\) | \(-8192 < x < 8192\) | \(2.2 \cdot 10^{-17}\) | \(2.2 \cdot 10^{-16}\) | \(3.6 \cdot 10^{-17}\,(0.25\,\text{ulp})\) | \(3 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{tan}()\) | \(-8192 < x < 8192\) | \(6.8 \cdot 10^{-16}\) | \(1.2 \cdot 10^{-10}\) | \(5.4 \cdot 10^{-17}\,(0.35\,\text{ulp})\) | \(4.1 \cdot 10^{-16}\,(3\,\text{ulp})\) |
| \(\text{cot}()\) | \(-8192 < x < 8192\) | \(4.9 \cdot 10^{-16}\) | \(1.2 \cdot 10^{-10}\) | \(5.5 \cdot 10^{-17}\,(0.36\,\text{ulp})\) | \(4.4 \cdot 10^{-16}\,(3\,\text{ulp})\) |
| \(\text{asin}()\) | \(-1 < x < 1\) | \(1.3 \cdot 10^{-17}\) | \(2.2 \cdot 10^{-16}\) | \(1.5 \cdot 10^{-17}\,(0.098\,\text{ulp})\) | \(2.2 \cdot 10^{-16}\,(1\,\text{ulp})\) |
| \(\text{acos}()\) | \(-1 < x < 1\) | \(5.4 \cdot 10^{-17}\) | \(4.4 \cdot 10^{-16}\) | \(3.5 \cdot 10^{-17}\,(0.23\,\text{ulp})\) | \(2.2 \cdot 10^{-16}\,(1\,\text{ulp})\) |
| \(\text{atan}()\) | \(-1 < x < 1\) | \(4.3 \cdot 10^{-17}\) | \(3.3 \cdot 10^{-16}\) | \(1 \cdot 10^{-16}\,(0.65\,\text{ulp})\) | \(7.1 \cdot 10^{-16}\,(5\,\text{ulp})\) |
| \(\text{sinh}()\) | \(-10 < x < 10\) | \(3.1 \cdot 10^{-14}\) | \(1.8 \cdot 10^{-12}\) | \(3.3 \cdot 10^{-17}\,(0.22\,\text{ulp})\) | \(4.3 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{cosh}()\) | \(-10 < x < 10\) | \(2.2 \cdot 10^{-14}\) | \(1.8 \cdot 10^{-12}\) | \(2 \cdot 10^{-17}\,(0.13\,\text{ulp})\) | \(2.9 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{tanh}()\) | \(-10 < x < 10\) | \(5.6 \cdot 10^{-17}\) | \(3.3 \cdot 10^{-16}\) | \(6.1 \cdot 10^{-17}\,(0.52\,\text{ulp})\) | \(5.5 \cdot 10^{-16}\,(3\,\text{ulp})\) |
| \(\text{asinh}()\) | \(-30 < x < 30\) | \(5.1 \cdot 10^{-17}\) | \(8.9 \cdot 10^{-16}\) | \(1.9 \cdot 10^{-17}\,(0.13\,\text{ulp})\) | \(4.4 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{acosh}()\) | \(1 < x < 10\) | \(4.9 \cdot 10^{-17}\) | \(4.4 \cdot 10^{-16}\) | \(2.6 \cdot 10^{-17}\,(0.17\,\text{ulp})\) | \(6.6 \cdot 10^{-16}\,(5\,\text{ulp})\) |
| \(\text{atanh}()\) | \(-1 < x < 1\) | \(1.8 \cdot 10^{-17}\) | \(4.4 \cdot 10^{-16}\) | \(3.2 \cdot 10^{-17}\,(0.21\,\text{ulp})\) | \(3 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{exp}()\) | \(-20 < x < 30\) | \(4.7 \cdot 10^{-6}\) | \(2 \cdot 10^{-3}\) | \(2.5 \cdot 10^{-17}\,(0.16\,\text{ulp})\) | \(3.3 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{log}()\) | \(10^{-20} < x < 2\cdot 10^{30}\) | \(1.9 \cdot 10^{-17}\) | \(1.4 \cdot 10^{-14}\) | \(2.7 \cdot 10^{-19}\,(0.0013\,\text{ulp})\) | \(2.2 \cdot 10^{-16}\,(1\,\text{ulp})\) |
| \(\text{erf}()\) | \(-1 < x < 1\) | \(4.7 \cdot 10^{-17}\) | \(4.4 \cdot 10^{-16}\) | \(9.6 \cdot 10^{-17}\,(0.63\,\text{ulp})\) | \(5.9 \cdot 10^{-16}\,(5\,\text{ulp})\) |
| \(\text{erfc}()\) | \(-1 < x < 1\) | \(4.8 \cdot 10^{-17}\) | \(4.4 \cdot 10^{-16}\) | \(9.6 \cdot 10^{-17}\,(0.64\,\text{ulp})\) | \(2.5 \cdot 10^{-15}\,(16\,\text{ulp})\) |

Type signatures

The drjit.ArrayBase class and various core functions have relatively complicated-looking type signatures involving Python generics and type variables. This enables type-checking of arithmetic expressions and improves visual autocomplete in editors such as VS Code. This section explains how these type annotations work.

The drjit.ArrayBase class is both an abstract and a generic Python type parameterized by several auxiliary type parameters. They help static type checkers like MyPy and PyRight make sense of how subclasses of this type transform when passed to various builtin operations. These auxiliary parameters are:

  • SelfT: the type of the array subclass (i.e., a forward reference of the type to itself).

  • SelfCpT: a union of compatible types, for which self + other or self | other produce a result of type SelfT.

  • ValT: the value type (i.e., the type of self[0])

  • ValCpT: a union of compatible types, for which self[0] + other or self[0] | other produce a result of type ValT.

  • RedT: type following reduction by drjit.sum() or drjit.all().

  • PlainT: the plain type underlying a special array (e.g. dr.scalar.Complex2f -> dr.scalar.Array2f, dr.llvm.TensorXi -> dr.llvm.Int).

  • MaskT: type produced by comparisons such as __eq__.

For example, here is the declaration of llvm.ad.Array2f shipped as part of Dr.Jit’s stub file drjit/llvm/ad.pyi:

class Array2f(drjit.ArrayBase['Array2f', '_Array2fCp', Float, '_FloatCp', Float, Array2b]):
    pass

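String arguments provide forward references that the type checker will resolve at a later point. So here, we have

  • SelfT: Array2f,

  • SelfCpT: _Array2fCp,

  • ValT: Float,

  • ValCpT: _FloatCp,

  • RedT: Float, and

  • MaskT: Array2b.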

The mysterious-looking underscored forward references can be found at the bottom of the same stub, for example:

_Array2fCp: TypeAlias = Union['Array2f', '_FloatCp', 'drjit.llvm._Array2fCp',
                              'drjit.scalar._Array2fCp', 'Array2f', '_Array2f16Cp']

This alias creates a union of types that are compatible (as implied by the "Cp" suffix) with the type Array2f, for example when encountered in an arithmetic operation such as an addition. This includes:

  • Whatever is compatible with the value type of the array (drjit.llvm.ad._FloatCp)

  • Types compatible with the non-AD version of the array (drjit.llvm._Array2fCp)

  • Types compatible with the scalar version of the array (drjit.scalar._Array2fCp)

  • Types compatible with a representative lower-precision version of that same array type (drjit.llvm.ad._Array2f16Cp)

These are all themselves type aliases representing unions that continue in the same vein, so the alias in principle expands into a rather large combined union. This enables static type inference based on Dr.Jit’s promotion rules.
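How nested aliases expand into one flat union can be seen with the standard typing module. The member types chosen for the aliases below are placeholders for illustration, not the contents of Dr.Jit's stub files.

```python
from typing import Union, get_args

# Placeholder aliases; the real stubs contain many more entries.
_FloatCp = Union[float, int, bool]
_Array2f16Cp = Union['Array2f16', _FloatCp]
_Array2fCp = Union['Array2f', _FloatCp, _Array2f16Cp]

# Nested unions are flattened and duplicate members are removed.
members = get_args(_Array2fCp)

# Forward references expose their name via __forward_arg__, classes via __name__.
member_names = [getattr(m, '__forward_arg__', getattr(m, '__name__', None))
                for m in members]
```

Although _Array2fCp mentions _FloatCp twice (once directly and once through _Array2f16Cp), the flattened union lists each member only once.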

With this background, we can now try to understand a type signature such as that of drjit.maximum():

@overload
def maximum(a: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], b: SelfCpT, /) -> SelfT: ...
@overload
def maximum(a: SelfCpT, b: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ...
@overload
def maximum(a: T, b: T, /) -> T: ...

Suppose we are computing the maximum of two 3D arrays:

a: Array3u = ...
b: Array3f = ...
c: WhatIsThis = dr.maximum(a, b)

In this case, WhatIsThis is Array3f due to the type promotion rules, but how does the type checker know this? When it tries the first overload, it realizes that b: Array3f is not part of the SelfCpT (compatible with self) type parameter of Array3u. In the second overload, the test is reversed and succeeds, and the result is the SelfT of Array3f, which is also Array3f. The third overload exists to handle cases where neither input is a Dr.Jit array type (e.g. dr.maximum(1, 2)).
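The resolution described above can be reproduced with a reduced, self-contained model. The sketch below is a miniature built on assumptions, not Dr.Jit's stubs: ArrayBase keeps only the SelfT and SelfCpT parameters, the two array classes and their compatibility unions are made up, and the runtime body of maximum uses a hypothetical rank attribute to mimic promotion. A static type checker is expected to reject the first overload for the (Array3u, Array3f) call and select the second, though the exact diagnostics depend on the checker.

```python
from typing import Generic, TypeVar, Union, overload

SelfT = TypeVar('SelfT')
SelfCpT = TypeVar('SelfCpT')
T = TypeVar('T')


class ArrayBase(Generic[SelfT, SelfCpT]):
    rank = 0          # hypothetical promotion rank (higher wins)
    scalar = int      # hypothetical element type

    def __init__(self, *values):
        self.values = [self.scalar(v) for v in values]


class Array3u(ArrayBase['Array3u', Union['Array3u', int]]):
    rank, scalar = 1, int


class Array3f(ArrayBase['Array3f', Union['Array3f', 'Array3u', float, int]]):
    rank, scalar = 2, float


@overload
def maximum(a: ArrayBase[SelfT, SelfCpT], b: SelfCpT, /) -> SelfT: ...
@overload
def maximum(a: SelfCpT, b: ArrayBase[SelfT, SelfCpT], /) -> SelfT: ...
@overload
def maximum(a: T, b: T, /) -> T: ...
def maximum(a, b, /):
    arrays = [v for v in (a, b) if isinstance(v, ArrayBase)]
    if not arrays:                      # third overload: plain Python values
        return max(a, b)
    # Pick the array type with the highest rank as the result type.
    result_t = max((type(v) for v in arrays), key=lambda t: t.rank)
    size = len(arrays[0].values)
    va = a.values if isinstance(a, ArrayBase) else [a] * size
    vb = b.values if isinstance(b, ArrayBase) else [b] * size
    return result_t(*(max(x, y) for x, y in zip(va, vb)))


a = Array3u(1, 2, 3)
b = Array3f(0.5, 2.5, 1.0)
c = maximum(a, b)       # an Array3f at runtime
d = maximum(1, 2)       # plain int, handled by the third overload
```

Since Array3f is absent from the compatibility union of Array3u while Array3u is present in that of Array3f, only the second overload matches the mixed call, which is the asymmetry the explanation above relies on.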