General information

Optimizations

This section lists optimizations performed by Dr.Jit while tracing code. The examples all use the following imports:

>>> import drjit as dr
>>> from drjit.llvm import Int

Vectorization and parallelization

Dr.Jit automatically vectorizes and parallelizes traced code. The implications of these transformations are backend-specific.

Consider the following simple calculation, which squares an integer sequence with 10000 elements.

>>> dr.arange(dr.llvm.Int, 10000)**2
[0, 1, 4, .. 9994 skipped .., 99940009, 99960004, 99980001]

On the LLVM backend, vectorization means that generated code uses instruction set extensions such as Intel AVX/AVX2/AVX512, or ARM NEON when they are available. For example, when the machine supports the AVX512 extensions, each machine instruction processes a packet of 16 values, which means that a total of 625 packets need to be evaluated.

The system uses the built-in nanothread thread pool to distribute these packets among the available processor cores. In this simple example, there is not enough work to truly benefit from multi-core parallelism, but the approach pays off in more complex workloads.

You can use drjit.thread_count() to query and drjit.set_thread_count() to specify the number of threads used for parallel processing.
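
For example (the reported count is machine-dependent, and the value 4 below is an arbitrary choice):

>>> dr.thread_count()       # query the current size of the worker thread pool
8
>>> dr.set_thread_count(4)  # restrict parallel processing to 4 threads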

On the CUDA backend, the system automatically determines a number of threads that maximizes occupancy along with a suitable number of blocks, and then launches a parallel program that spreads over the entire GPU (assuming that there is enough work to do so).

Copy-on-Write

Arrays are reference-counted and use a Copy-on-Write (CoW) strategy. This means that copying an array is cheap since the copy can reference the original array without requiring a device memory copy. The matching variable indices in the example below demonstrate the lack of an actual copy.

>>> a = Int(1, 2, 3)
>>> b = Int(a)        # <- create a copy of 'a'
>>> a.index, b.index
(1, 1)

However, a subsequent in-place modification triggers an actual copy.

>>> b[0] = 0
>>> a.index, b.index
(1, 2)

This optimization is always active and cannot be disabled.

Constant propagation

Dr.Jit immediately performs arithmetic involving literal constant arrays:

>>> a = Int(4) + Int(5)
>>> a.state
dr.VarState.Literal

In other words, the addition does not become part of the generated device code. This optimization reduces the size of the generated LLVM/PTX IR and can be controlled via drjit.JitFlag.ConstantPropagation.
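
For instance, the optimization can be temporarily disabled via the drjit.scoped_set_flag() context manager, in which case the addition is traced like any other operation. A small sketch of this control mechanism:

>>> with dr.scoped_set_flag(dr.JitFlag.ConstantPropagation, False):
...     a = Int(4) + Int(5)
>>> a.state == dr.VarState.Literal   # the sum is now traced rather than folded
False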

Dead code elimination

When generating code, Dr.Jit excludes unnecessary operations that do not influence arrays evaluated by the kernel. It also removes dead branches in loops and conditional statements.

This optimization is always active and cannot be disabled.

Value numbering

Dr.Jit collapses identical expressions into the same variable (this is safe given the CoW strategy explained above).

>>> a, b = Int(1, 2, 3), Int(4, 5, 6)
>>> c = a + b
>>> d = a + b
>>> c.index == d.index
True

This optimization reduces the size of the generated LLVM/PTX IR and can be controlled via drjit.JitFlag.ValueNumbering.
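
As a sketch of the analogous control mechanism, disabling the flag within a drjit.scoped_set_flag() scope causes identical expressions to receive distinct variables:

>>> with dr.scoped_set_flag(dr.JitFlag.ValueNumbering, False):
...     c = a + b
...     d = a + b
>>> c.index == d.index
False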

Local atomic reduction

Atomic memory operations can be a bottleneck when they encounter write contention, which refers to a situation where many threads attempt to write to the same array element at once.

For example, the following operation causes 1’000’000 threads to write to a[0].

>>> a = dr.zeros(Int, 10)
>>> dr.scatter_add(target=a, index=dr.zeros(Int, 1000000), value=...)

Since Dr.Jit vectorizes the program during execution, the computation is grouped into packets that typically contain 16 to 32 elements. By locally pre-accumulating the values within each packet and then performing only 31,250 to 62,500 atomic memory operations (instead of 1’000’000), performance can be improved considerably.

This issue is particularly important when automatically differentiating a computation in reverse mode (e.g. using drjit.backward()), since this transformation turns differentiable global memory reads into atomic scatter-additions. A single differentiable scalar read is all it takes to create such an atomic memory bottleneck.

The following plots illustrate the expected level of performance in a microbenchmark that scatter-adds \(10^8\) random integers into a buffer at uniformly distributed positions. The size of the target buffer varies along the horizontal axis. Generally, we expect to see significant contention on the left, since this involves a large number of writes to only a few elements. The behavior of GPU and CPU atomics is somewhat different, hence we look at them in turn, starting with the CUDA backend.

The drjit.ReduceMode.Direct strategy generates a plain atomic operation without additional handling. It generally performs poorly, except in two special cases: when writing to a scalar array, the NVIDIA compiler detects this pattern and applies a specialized optimization (which is, however, quite specific to this microbenchmark and unlikely to help in general). Towards the right of the plot, there is essentially no contention, and multiple writes to the same destination rarely occur within the same warp, hence drjit.ReduceMode.Direct outperforms the other methods.

[Plot: scatter_add() performance vs. target array size on the CUDA backend; source: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/01/scatter_add_cuda.svg]

The drjit.ReduceMode.Local strategy in the above plot performs a butterfly reduction to locally pre-reduce writes targeting the same region of memory, which significantly reduces the dangers of atomic memory contention.

On the CPU (LLVM) backend, Direct mode can become so slow that it essentially breaks the program. The Local strategy works analogously to the CUDA backend and improves performance by an order of magnitude when many writes target the same element. In this benchmark, that becomes less likely as the target array grows, and the optimization eventually becomes ineffective.

[Plot: scatter_add() performance vs. target array size on the LLVM backend; source: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/01/scatter_add_llvm.svg]

The drjit.ReduceMode.Expand strategy produces a near-flat profile. It replicates the target array to avoid write conflicts altogether, which enables the use of non-atomic memory operations. This is significantly faster but also very memory-intensive: the storage cost of a 1 MiB array targeted by a drjit.scatter_reduce() operation grows to N MiB, where N is the number of cores. The functions drjit.expand_threshold() and drjit.set_expand_threshold() can be used to query and adjust the threshold that determines when Dr.Jit is willing to use this strategy automatically.
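
As a sketch of how these controls can be exercised, a strategy can also be requested explicitly through the mode argument accepted by the scatter functions, and drjit.set_expand_threshold() adjusts when Dr.Jit resorts to Expand on its own (the values below are arbitrary; consult the reference for the exact meaning of the threshold):

>>> a = dr.zeros(Int, 10)
>>> dr.scatter_add(target=a, index=dr.zeros(Int, 1000000),
...                value=Int(1), mode=dr.ReduceMode.Expand)
>>> dr.set_expand_threshold(10000000)   # permit automatic expansion for larger targets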

Other

Some other optimizations are specific to symbolic operations such as loops, conditionals, and function calls; they are controlled by dedicated drjit.JitFlag entries.

Please refer to the documentation of these flags for details.

PyTrees

The word PyTree (borrowed from JAX) refers to a tree-like data structure made of Python container types including

  • list,

  • tuple,

  • dict,

  • data classes, and

  • custom Python classes or C++ bindings with a DRJIT_STRUCT annotation.

Various Dr.Jit operations will automatically traverse such PyTrees to process any Dr.Jit arrays or tensors found within. For example, it might be convenient to store differentiable parameters of an optimization within a dictionary and then batch-enable gradients:

from drjit.cuda.ad import Array3f, Float

params = {
    'foo': Array3f(...),
    'bar': Float(...)
}

dr.enable_grad(params)

PyTrees can similarly be used as state variables in symbolic loops and conditionals, as arguments and return values of symbolic calls, as arguments of scatter/gather operations, and in many other places (the reference explicitly mentions the term PyTree wherever an operation supports it).
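
For instance, assuming gradients were enabled as in the snippet above, drjit.detach() traverses the same dictionary and returns a new PyTree whose arrays no longer track derivatives:

# A structurally identical PyTree without gradient tracking
detached = dr.detach(params)
dr.grad_enabled(detached)   # -> False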

Limitations

You may not use Dr.Jit types as keys of a dictionary occurring within a PyTree. Furthermore, PyTrees may not contain cycles. For example, the following data structure will cause PyTree-compatible operations to fail with a RecursionError.

x = []
x.append(x)

Finally, Dr.Jit automatically traverses tuples, lists, and dictionaries, but it does not traverse subclasses of basic containers and other generalized sequences or mappings. This is intentional.
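
For instance, a hypothetical subclass of list is treated as an opaque object rather than being descended into:

from drjit.cuda.ad import Float

class MyList(list):          # subclass of a basic container
    pass

params = MyList([Float(1), Float(2)])
dr.enable_grad(params)       # no effect: the entries of 'params' are not traversed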

Custom types

There are two ways of extending PyTrees with custom data types. The first is to register a Python data class. Note that this type must be default-constructible, which means that its members should have default initializers.

from drjit.cuda.ad import Float
from dataclasses import dataclass

@dataclass
class MyPoint2f:
    x: Float = Float(0)
    y: Float = Float(0)

# Create a vector representing 100 2D points. Dr.Jit will
# automatically populate the 'x' and 'y' members
value = dr.zeros(MyPoint2f, 100)

The second option is to annotate an existing non-dataclass type (e.g. a standard Python class or a C++ binding) with a static DRJIT_STRUCT member. This is simply a dictionary describing the names and types of all fields. Such custom types must also be default-constructible (i.e., the constructor should work if called without arguments). The following is analogous to the above dataclass version:

from drjit.cuda.ad import Float

class MyPoint2f:
    DRJIT_STRUCT = { 'x' : Float, 'y': Float }

    def __init__(self, x: Float | None = None, y: Float | None = None):
        self.x = x if x is not None else Float()
        self.y = y if y is not None else Float()

# Create a vector representing 100 2D points. Dr.Jit will
# automatically populate the 'x' and 'y' members
value = dr.zeros(MyPoint2f, 100)

Fields are not limited to containers or Dr.Jit types. For example, we could have added an extra datetime entry to record when a set of points was captured. Such fields are simply ignored by traversal operations.
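
For instance, a hypothetical variant of the earlier dataclass could carry such a field; traversal-based operations like drjit.zeros() simply leave it alone:

from dataclasses import dataclass, field
from datetime import datetime
from drjit.cuda.ad import Float

@dataclass
class TimedPoint2f:
    x: Float = Float(0)
    y: Float = Float(0)
    captured_at: datetime = field(default_factory=datetime.now)  # ignored by traversal

# 'x' and 'y' are populated with zeros, 'captured_at' keeps its default
value = dr.zeros(TimedPoint2f, 100)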

Accuracy of transcendental operations

Single precision

Note

The trigonometric functions sin, cos, and tan are optimized for low error on the domain \(|x| < 8192\) and don’t perform as well beyond this range.

| Function | Tested domain | Abs. error (mean) | Abs. error (max) | Rel. error (mean) | Rel. error (max) |
|----------|---------------|-------------------|------------------|-------------------|------------------|
| \(\text{sin}()\) | \(-8192 < x < 8192\) | \(1.2 \cdot 10^{-8}\) | \(1.2 \cdot 10^{-7}\) | \(1.9 \cdot 10^{-8}\,(0.25\,\text{ulp})\) | \(1.8 \cdot 10^{-6}\,(19\,\text{ulp})\) |
| \(\text{cos}()\) | \(-8192 < x < 8192\) | \(1.2 \cdot 10^{-8}\) | \(1.2 \cdot 10^{-7}\) | \(1.9 \cdot 10^{-8}\,(0.25\,\text{ulp})\) | \(3.1 \cdot 10^{-6}\,(47\,\text{ulp})\) |
| \(\text{tan}()\) | \(-8192 < x < 8192\) | \(4.7 \cdot 10^{-6}\) | \(8.1 \cdot 10^{-1}\) | \(3.4 \cdot 10^{-8}\,(0.42\,\text{ulp})\) | \(3.1 \cdot 10^{-6}\,(30\,\text{ulp})\) |
| \(\text{asin}()\) | \(-1 < x < 1\) | \(2.3 \cdot 10^{-8}\) | \(1.2 \cdot 10^{-7}\) | \(2.9 \cdot 10^{-8}\,(0.33\,\text{ulp})\) | \(2.3 \cdot 10^{-7}\,(2\,\text{ulp})\) |
| \(\text{acos}()\) | \(-1 < x < 1\) | \(4.7 \cdot 10^{-8}\) | \(2.4 \cdot 10^{-7}\) | \(2.9 \cdot 10^{-8}\,(0.33\,\text{ulp})\) | \(1.2 \cdot 10^{-7}\,(1\,\text{ulp})\) |
| \(\text{atan}()\) | \(-1 < x < 1\) | \(1.8 \cdot 10^{-7}\) | \(6 \cdot 10^{-7}\) | \(4.2 \cdot 10^{-7}\,(4.9\,\text{ulp})\) | \(8.2 \cdot 10^{-7}\,(12\,\text{ulp})\) |
| \(\text{sinh}()\) | \(-10 < x < 10\) | \(2.6 \cdot 10^{-5}\) | \(2 \cdot 10^{-3}\) | \(2.8 \cdot 10^{-8}\,(0.34\,\text{ulp})\) | \(2.7 \cdot 10^{-7}\,(3\,\text{ulp})\) |
| \(\text{cosh}()\) | \(-10 < x < 10\) | \(2.9 \cdot 10^{-5}\) | \(2 \cdot 10^{-3}\) | \(2.9 \cdot 10^{-8}\,(0.35\,\text{ulp})\) | \(2.5 \cdot 10^{-7}\,(4\,\text{ulp})\) |
| \(\text{tanh}()\) | \(-10 < x < 10\) | \(4.8 \cdot 10^{-8}\) | \(4.2 \cdot 10^{-7}\) | \(5 \cdot 10^{-8}\,(0.76\,\text{ulp})\) | \(5 \cdot 10^{-7}\,(7\,\text{ulp})\) |
| \(\text{asinh}()\) | \(-30 < x < 30\) | \(2.8 \cdot 10^{-8}\) | \(4.8 \cdot 10^{-7}\) | \(1 \cdot 10^{-8}\,(0.13\,\text{ulp})\) | \(1.7 \cdot 10^{-7}\,(2\,\text{ulp})\) |
| \(\text{acosh}()\) | \(1 < x < 10\) | \(2.9 \cdot 10^{-8}\) | \(2.4 \cdot 10^{-7}\) | \(1.5 \cdot 10^{-8}\,(0.18\,\text{ulp})\) | \(2.4 \cdot 10^{-7}\,(3\,\text{ulp})\) |
| \(\text{atanh}()\) | \(-1 < x < 1\) | \(9.9 \cdot 10^{-9}\) | \(2.4 \cdot 10^{-7}\) | \(1.5 \cdot 10^{-8}\,(0.18\,\text{ulp})\) | \(1.2 \cdot 10^{-7}\,(1\,\text{ulp})\) |
| \(\text{exp}()\) | \(-20 < x < 30\) | \(0.72 \cdot 10^{4}\) | \(0.1 \cdot 10^{7}\) | \(2.4 \cdot 10^{-8}\,(0.27\,\text{ulp})\) | \(1.2 \cdot 10^{-7}\,(1\,\text{ulp})\) |
| \(\text{log}()\) | \(10^{-20} < x < 2\cdot 10^{30}\) | \(9.6 \cdot 10^{-9}\) | \(7.6 \cdot 10^{-6}\) | \(1.4 \cdot 10^{-10}\,(0.0013\,\text{ulp})\) | \(1.2 \cdot 10^{-7}\,(1\,\text{ulp})\) |
| \(\text{erf}()\) | \(-1 < x < 1\) | \(3.2 \cdot 10^{-8}\) | \(1.8 \cdot 10^{-7}\) | \(6.4 \cdot 10^{-8}\,(0.78\,\text{ulp})\) | \(3.3 \cdot 10^{-7}\,(4\,\text{ulp})\) |
| \(\text{erfc}()\) | \(-1 < x < 1\) | \(3.4 \cdot 10^{-8}\) | \(2.4 \cdot 10^{-7}\) | \(6.4 \cdot 10^{-8}\,(0.79\,\text{ulp})\) | \(1 \cdot 10^{-6}\,(11\,\text{ulp})\) |

Double precision

| Function | Tested domain | Abs. error (mean) | Abs. error (max) | Rel. error (mean) | Rel. error (max) |
|----------|---------------|-------------------|------------------|-------------------|------------------|
| \(\text{sin}()\) | \(-8192 < x < 8192\) | \(2.2 \cdot 10^{-17}\) | \(2.2 \cdot 10^{-16}\) | \(3.6 \cdot 10^{-17}\,(0.25\,\text{ulp})\) | \(3.1 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{cos}()\) | \(-8192 < x < 8192\) | \(2.2 \cdot 10^{-17}\) | \(2.2 \cdot 10^{-16}\) | \(3.6 \cdot 10^{-17}\,(0.25\,\text{ulp})\) | \(3 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{tan}()\) | \(-8192 < x < 8192\) | \(6.8 \cdot 10^{-16}\) | \(1.2 \cdot 10^{-10}\) | \(5.4 \cdot 10^{-17}\,(0.35\,\text{ulp})\) | \(4.1 \cdot 10^{-16}\,(3\,\text{ulp})\) |
| \(\text{cot}()\) | \(-8192 < x < 8192\) | \(4.9 \cdot 10^{-16}\) | \(1.2 \cdot 10^{-10}\) | \(5.5 \cdot 10^{-17}\,(0.36\,\text{ulp})\) | \(4.4 \cdot 10^{-16}\,(3\,\text{ulp})\) |
| \(\text{asin}()\) | \(-1 < x < 1\) | \(1.3 \cdot 10^{-17}\) | \(2.2 \cdot 10^{-16}\) | \(1.5 \cdot 10^{-17}\,(0.098\,\text{ulp})\) | \(2.2 \cdot 10^{-16}\,(1\,\text{ulp})\) |
| \(\text{acos}()\) | \(-1 < x < 1\) | \(5.4 \cdot 10^{-17}\) | \(4.4 \cdot 10^{-16}\) | \(3.5 \cdot 10^{-17}\,(0.23\,\text{ulp})\) | \(2.2 \cdot 10^{-16}\,(1\,\text{ulp})\) |
| \(\text{atan}()\) | \(-1 < x < 1\) | \(4.3 \cdot 10^{-17}\) | \(3.3 \cdot 10^{-16}\) | \(1 \cdot 10^{-16}\,(0.65\,\text{ulp})\) | \(7.1 \cdot 10^{-16}\,(5\,\text{ulp})\) |
| \(\text{sinh}()\) | \(-10 < x < 10\) | \(3.1 \cdot 10^{-14}\) | \(1.8 \cdot 10^{-12}\) | \(3.3 \cdot 10^{-17}\,(0.22\,\text{ulp})\) | \(4.3 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{cosh}()\) | \(-10 < x < 10\) | \(2.2 \cdot 10^{-14}\) | \(1.8 \cdot 10^{-12}\) | \(2 \cdot 10^{-17}\,(0.13\,\text{ulp})\) | \(2.9 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{tanh}()\) | \(-10 < x < 10\) | \(5.6 \cdot 10^{-17}\) | \(3.3 \cdot 10^{-16}\) | \(6.1 \cdot 10^{-17}\,(0.52\,\text{ulp})\) | \(5.5 \cdot 10^{-16}\,(3\,\text{ulp})\) |
| \(\text{asinh}()\) | \(-30 < x < 30\) | \(5.1 \cdot 10^{-17}\) | \(8.9 \cdot 10^{-16}\) | \(1.9 \cdot 10^{-17}\,(0.13\,\text{ulp})\) | \(4.4 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{acosh}()\) | \(1 < x < 10\) | \(4.9 \cdot 10^{-17}\) | \(4.4 \cdot 10^{-16}\) | \(2.6 \cdot 10^{-17}\,(0.17\,\text{ulp})\) | \(6.6 \cdot 10^{-16}\,(5\,\text{ulp})\) |
| \(\text{atanh}()\) | \(-1 < x < 1\) | \(1.8 \cdot 10^{-17}\) | \(4.4 \cdot 10^{-16}\) | \(3.2 \cdot 10^{-17}\,(0.21\,\text{ulp})\) | \(3 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{exp}()\) | \(-20 < x < 30\) | \(4.7 \cdot 10^{-6}\) | \(2 \cdot 10^{-3}\) | \(2.5 \cdot 10^{-17}\,(0.16\,\text{ulp})\) | \(3.3 \cdot 10^{-16}\,(2\,\text{ulp})\) |
| \(\text{log}()\) | \(10^{-20} < x < 2\cdot 10^{30}\) | \(1.9 \cdot 10^{-17}\) | \(1.4 \cdot 10^{-14}\) | \(2.7 \cdot 10^{-19}\,(0.0013\,\text{ulp})\) | \(2.2 \cdot 10^{-16}\,(1\,\text{ulp})\) |
| \(\text{erf}()\) | \(-1 < x < 1\) | \(4.7 \cdot 10^{-17}\) | \(4.4 \cdot 10^{-16}\) | \(9.6 \cdot 10^{-17}\,(0.63\,\text{ulp})\) | \(5.9 \cdot 10^{-16}\,(5\,\text{ulp})\) |
| \(\text{erfc}()\) | \(-1 < x < 1\) | \(4.8 \cdot 10^{-17}\) | \(4.4 \cdot 10^{-16}\) | \(9.6 \cdot 10^{-17}\,(0.64\,\text{ulp})\) | \(2.5 \cdot 10^{-15}\,(16\,\text{ulp})\) |

Type signatures

The drjit.ArrayBase class and various core functions have relatively complicated-looking type signatures involving Python generics and type variables. This enables type checking of arithmetic expressions and improves autocompletion in editors such as VS Code. This section explains how these type annotations work.

The drjit.ArrayBase class is both an abstract and a generic Python type parameterized by several auxiliary type parameters. They help static type checkers like MyPy and PyRight make sense of how subclasses of this type transform when passed to various built-in operations. These auxiliary parameters are:

  • SelfT: the type of the array subclass (i.e., a forward reference of the type to itself).

  • SelfCpT: a union of compatible types, for which self + other or self | other produce a result of type SelfT.

  • ValT: the value type (i.e., the type of self[0])

  • ValCpT: a union of compatible types, for which self[0] + other or self[0] | other produce a result of type ValT.

  • RedT: type following reduction by drjit.sum() or drjit.all().

  • PlainT: the plain type underlying a special array (e.g. dr.scalar.Complex2f -> dr.scalar.Array2f, dr.llvm.TensorXi -> dr.llvm.Int).

  • MaskT: type produced by comparisons such as __eq__.

For example, here is the declaration of llvm.ad.Array2f shipped as part of Dr.Jit’s stub file drjit/llvm/ad.pyi:

class Array2f(drjit.ArrayBase['Array2f', '_Array2fCp', Float, '_FloatCp', Float, 'Array2f', Array2b]):
    pass

String arguments provide forward references that the type checker will resolve at a later point. So here, SelfT is Array2f, SelfCpT is the alias _Array2fCp, ValT is Float, ValCpT is the alias _FloatCp, RedT is Float, PlainT is Array2f, and MaskT is Array2b.

The mysterious-looking underscored forward references can be found at the bottom of the same stub, for example:

_Array2fCp: TypeAlias = Union['Array2f', '_FloatCp', 'drjit.llvm._Array2fCp',
                              'drjit.scalar._Array2fCp', 'Array2f', '_Array2f16Cp']

This alias creates a union of types that are compatible (as implied by the "Cp" suffix) with the type Array2f, for example when encountered in an arithmetic operation such as an addition. This includes:

  • Whatever is compatible with the value type of the array (drjit.llvm.ad._FloatCp)

  • Types compatible with the non-AD version of the array (drjit.llvm._Array2fCp)

  • Types compatible with the scalar version of the array (drjit.scalar._Array2fCp)

  • Types compatible with a representative lower-precision version of that same array type (drjit.llvm.ad._Array2f16Cp)

These are all themselves type aliases for unions continuing in the same vein, so this in principle expands into a rather large combined union. This enables static type inference based on Dr.Jit’s promotion rules.
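
As a small sketch of what this enables, a static type checker can now infer the result types of mixed expressions without running them (the comments describe the expected inference, not program output):

from drjit.llvm.ad import Array2f, Float

a = Array2f(1, 2)
b = a + 1.0        # 'float' is covered by '_Array2fCp', so 'b' is inferred as Array2f
c = a + Float(3)   # 'Float' is likewise compatible: 'c' is inferred as Array2f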

With this background, we can now try to understand a type signature such as that of drjit.maximum():

@overload
def maximum(a: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], b: SelfCpT, /) -> SelfT: ...
@overload
def maximum(a: SelfCpT, b: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ...
@overload
def maximum(a: T, b: T, /) -> T: ...

Suppose we are computing the maximum of two 3D arrays:

a: Array3u = ...
b: Array3f = ...
c: WhatIsThis = dr.maximum(a, b)

In this case, WhatIsThis is Array3f due to the type promotion rules, but how does the type checker know this? When it tries the first overload, it realizes that b: Array3f is not part of the SelfCpT (compatible with self) type parameter of Array3u. In second overload, the test is reversed and succeeds, and the result is the SelfT of Array3f, which is also Array3f. The third overload exists to handle cases where neither input is a Dr.Jit array type. (e.g. dr.maximum(1, 2))