API Reference

Array creation

drjit.zeros(dtype: type, shape: int = 1) object
drjit.zeros(dtype: type, shape: collections.abc.Sequence[int]) object

Return a zero-initialized instance of the desired type and shape.

This function can create zero-initialized instances of various types. In particular, dtype can be:

  • A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.zeros(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.

  • A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.

  • A PyTree. In this case, drjit.zeros() will invoke itself recursively to zero-initialize each field of the data structure.

  • A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.

Note that when dtype refers to a scalar mask or a mask array, it will be initialized to False as opposed to zero.

The function returns a literal constant array that consumes no device memory.

  • dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.

  • shape (Sequence[int] | int) – Shape of the desired array


A zero-initialized instance of type dtype.

Return type: object
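The recursive PyTree case and the mask rule can be pictured with a small plain-Python model. This sketch does not call Dr.Jit: zeros_model, Point, and Ray are hypothetical names, and dataclasses merely stand in for PyTree types. It only mirrors the rules described above (recurse into fields, zero for numbers, False for booleans).

```python
from dataclasses import dataclass, fields, is_dataclass

def zeros_model(dtype):
    """Plain-Python stand-in for the recursion described above (not Dr.Jit)."""
    if dtype is bool:
        return False            # masks are initialized to False, not 0
    if dtype in (int, float):
        return dtype(0)         # scalar Python types: the shape is irrelevant
    if is_dataclass(dtype):
        # PyTree-like container: recurse into every field
        return dtype(**{f.name: zeros_model(f.type) for f in fields(dtype)})
    raise TypeError(f"unsupported type: {dtype!r}")

@dataclass
class Point:
    x: float
    y: float

@dataclass
class Ray:
    origin: Point
    t: float
    active: bool

ray = zeros_model(Ray)
print(ray)
```

The real function additionally handles Dr.Jit arrays and tensors, which this model leaves out on purpose.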


drjit.empty(dtype: type, shape: int = 1) object
drjit.empty(dtype: type, shape: collections.abc.Sequence[int]) object

Return an uninitialized Dr.Jit array of the desired type and shape.

This function can create uninitialized buffers of various types. It should only be used in combination with a subsequent call to an operation like drjit.scatter() that fills the array contents with valid data.

The dtype parameter can be used to request:

  • A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.empty(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.

  • A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.

  • A PyTree. In this case, drjit.empty() will invoke itself recursively to allocate memory for each field of the data structure.

  • A scalar Python type like int, float, or bool. The shape parameter is ignored in this case, and the function returns a zero-initialized result (there is little point in instantiating uninitialized versions of scalar Python types).

drjit.empty() delays allocation of the underlying buffer until an operation tries to read/write the actual array contents.

  • dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.

  • shape (Sequence[int] | int) – Shape of the desired array


An instance of type dtype with arbitrary/undefined contents.

Return type: object


drjit.ones(dtype: type, shape: int = 1) object
drjit.ones(dtype: type, shape: collections.abc.Sequence[int]) object

Return an instance of the desired type and shape filled with ones.

This function can create one-initialized instances of various types. In particular, dtype can be:

  • A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.ones(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.

  • A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.

  • A PyTree. In this case, drjit.ones() will invoke itself recursively to initialize each field of the data structure.

  • A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.

Note that when dtype refers to a scalar mask or a mask array, it will be initialized to True as opposed to one.

The function returns a literal constant array that consumes no device memory.

  • dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.

  • shape (Sequence[int] | int) – Shape of the desired array


An instance of type dtype filled with ones.

Return type: object


drjit.full(dtype: type, value: object, shape: int = 1) object
drjit.full(dtype: type, value: object, shape: collections.abc.Sequence[int]) object

Return a constant-valued instance of the desired type and shape.

This function can create constant-valued instances of various types. In particular, dtype can be:

  • A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.full(dr.cuda.Array2f, value=1.0, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.

  • A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.

  • A PyTree. In this case, drjit.full() will invoke itself recursively to initialize each field of the data structure.

  • A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.

The function returns a literal constant array that consumes no device memory.

  • dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.

  • value (object) – An instance of the underlying scalar type (float/int/bool, etc.) that will be used to initialize the array contents.

  • shape (Sequence[int] | int) – Shape of the desired array


An instance of type dtype filled with value.

Return type: object


drjit.opaque(dtype: type, value: object, shape: int = 1) object
drjit.opaque(dtype: type, value: object, shape: collections.abc.Sequence[int]) object

Return an opaque constant-valued instance of the desired type and shape.

This function is very similar to drjit.full() in that it creates constant-valued instances of various types including (potentially nested) Dr.Jit arrays, tensors, and PyTrees. Please refer to the documentation of drjit.full() for details on the function signature. However, drjit.full() creates literal constant arrays, which means that Dr.Jit is fully aware of the array contents.

In contrast, drjit.opaque() produces an opaque array backed by a representation in device memory.

Why is this useful?

Consider the following snippet, where a complex calculation is parameterized by the constant 1.

from drjit.llvm import Float

result = complex_function(Float(1), ...) # Float(1) is equivalent to dr.full(Float, 1)
print(result)

The print() statement will cause Dr.Jit to evaluate the queued computation, which likely also requires compilation of a new kernel (if that exact pattern of steps hasn’t been observed before). Kernel compilation is costly and may be much slower than the actual computation that needs to be done.

Suppose we later wish to evaluate the function with a different parameter:

result = complex_function(Float(2), ...)

The constant 2 is essentially copy-pasted into the generated program, causing a mismatch with the previously compiled kernel that therefore cannot be reused. This unfortunately means that we must once more wait a few tens or even hundreds of milliseconds until a new kernel has been compiled and uploaded to the device.

This motivates the existence of drjit.opaque(). By making a variable opaque to Dr.Jit’s tracing mechanism, we can keep constants out of the generated program and improve the effectiveness of the kernel cache:

# The following lines reuse the compiled kernel regardless of the constant
value = dr.opaque(Float, 2)
result = complex_function(value, ...)

This function is related to drjit.make_opaque(), which can turn an already existing Dr.Jit array, tensor, or PyTree into an opaque representation.
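The caching argument above can be illustrated with a toy model in plain Python. This is not how Dr.Jit is implemented: the compiled dictionary, the run() helper, and the use of eval() as a stand-in for kernel compilation are all assumptions made for illustration. The sketch only shows why a constant that is baked into the program text changes the cache key, while a constant passed as a parameter does not.

```python
compiled = {}   # hypothetical kernel cache: program text -> callable

def run(x, scale, opaque):
    """Multiply x by scale, "compiling" a kernel on a cache miss."""
    if opaque:
        source = "lambda x, p: x * p"        # constant passed in as a parameter
        args = (x, scale)
    else:
        source = f"lambda x: x * {scale!r}"  # constant baked into the program
        args = (x,)
    if source not in compiled:
        compiled[source] = eval(source)      # stands in for costly compilation
    return compiled[source](*args)

# Literal constants: every new value produces a new program text
run(3.0, 1, opaque=False)
run(3.0, 2, opaque=False)
literal_kernels = len(compiled)

# Opaque constants: one program text serves all values
compiled.clear()
run(3.0, 1, opaque=True)
run(3.0, 2, opaque=True)
opaque_kernels = len(compiled)

print(literal_kernels, opaque_kernels)
```

In this model the literal variant compiles twice and the opaque variant once, which matches the behavior described in the text.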

  • dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.

  • value (object) – An instance of the underlying scalar type (float/int/bool, etc.) that will be used to initialize the array contents.

  • shape (Sequence[int] | int) – Shape of the desired array


An instance of type dtype filled with value.

Return type: object


drjit.arange(dtype: type[T], size: int) T
drjit.arange(dtype: type[T], start: int, stop: int, step: int = 1) T

This function generates an integer sequence on the interval [start, stop) with step size step, where start = 0 and step = 1 if not specified.

  • dtype (type) – Desired Dr.Jit array type. The dtype must refer to a dynamically sized 1D Dr.Jit array such as drjit.scalar.ArrayXu or drjit.cuda.Float.

  • start (int) – Start of the interval. The default value is 0.

  • stop/size (int) – End of the interval (not included). The name of this parameter differs between the two provided overloads.

  • step (int) – Spacing between values. The default value is 1.


The computed sequence of type dtype.

Return type:
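The two overloads and the half-open interval can be summarized with a plain-Python model. The helper arange_model is hypothetical and returns a list instead of a Dr.Jit array; it is only meant to show which values the sequence contains.

```python
def arange_model(start, stop=None, step=1):
    """Plain-Python model of the integer sequence on [start, stop)."""
    if stop is None:
        # Single-argument overload: the value is the size, start defaults to 0
        start, stop = 0, start
    return list(range(start, stop, step))

print(arange_model(5))         # [0, 1, 2, 3, 4]
print(arange_model(2, 10, 3))  # [2, 5, 8]
```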


drjit.linspace(dtype: type[T], start: float, stop: float, num: int, endpoint: bool = True) T

This function generates an evenly spaced floating point sequence of size num covering the interval [start, stop].

  • dtype (type) – Desired Dr.Jit array type. The dtype must refer to a dynamically sized 1D Dr.Jit floating point array, such as drjit.scalar.ArrayXf or drjit.cuda.Float.

  • start (float) – Start of the interval.

  • stop (float) – End of the interval.

  • num (int) – Number of samples to generate.

  • endpoint (bool) – Should the interval endpoint be included? The default is True.


The computed sequence of type dtype.

Return type:
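A plain-Python model of the sampling rule may help to see the effect of endpoint. The helper linspace_model is hypothetical, and the spacing used for endpoint=False follows the NumPy convention (stop - start) / num, which is an assumption rather than something this reference states.

```python
def linspace_model(start, stop, num, endpoint=True):
    """Plain-Python model of an evenly spaced sequence with num samples."""
    if num <= 0:
        return []
    if num == 1:
        return [float(start)]
    # With endpoint=True the last sample equals stop; otherwise stop is left out
    intervals = num - 1 if endpoint else num
    step = (stop - start) / intervals
    return [start + i * step for i in range(num)]

print(linspace_model(0.0, 1.0, 5))                  # [0.0, 0.25, 0.5, 0.75, 1.0]
print(linspace_model(0.0, 1.0, 4, endpoint=False))  # [0.0, 0.25, 0.5, 0.75]
```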


Control flow

drjit.syntax(f: None = None, *, recursive: bool = False, print_ast: bool = False, print_code: bool = False) Callable[[T], T]
drjit.syntax(f: T, *, recursive: bool = False, print_ast: bool = False, print_code: bool = False) T
drjit.hint(arg: T, /, *, mode: Literal['scalar', 'evaluated', 'symbolic', None] | None = None, max_iterations: int | None = None, label: str | None = None, include: List[object] | None = None, exclude: List[object] | None = None, strict: bool = True) T

Within ordinary Python code, this function is unremarkable: it returns the positional-only argument arg while ignoring any specified keyword arguments.

The main purpose of drjit.hint() is to provide hints that influence the transformation performed by the @drjit.syntax decorator. The following kinds of hints are supported:

  1. mode overrides the compilation mode of a while loop or if statement. The available choices are 'scalar', 'evaluated', and 'symbolic'; None (the default) selects the mode automatically.

  2. The optional strict=False reduces the strictness of variable consistency checks.

    Consider the following snippet:

    import drjit as dr
    from drjit.llvm import UInt32

    @dr.syntax
    def f(x: UInt32):
        if x < 4:
            y = 3
        else:
            y = 5
        return y

    This code will raise an exception.

    >> f(UInt32(1))
    RuntimeError: drjit.if_stmt(): the non-array state variable 'y' of type 'int' changed from '5' to '10'.
    Please review the interface and assumptions of 'drjit.while_loop()' as explained in the documentation

    This is because the computed variable y of type int has an inconsistent value depending on the taken branch. Furthermore, y is a scalar Python type that isn’t tracked by Dr.Jit. The fix here is to initialize y with UInt32(<integer value>).

    However, there may also be legitimate situations where such an inconsistency is needed by the implementation. This can be fine as long as y is not used after the if statement. In this case, you can annotate the conditional or loop with dr.hint(..., strict=False), which disables the check.

  3. max_iterations specifies a maximum number of loop iterations for reverse-mode automatic differentiation.

    Naive reverse-mode differentiation of loops (unless replaced by a smarter problem-specific strategy via drjit.custom and drjit.CustomOp) requires allocation of large buffers that hold loop state for all iterations.

    Dr.Jit requires an upper bound on the maximum number of loop iterations so that it can allocate such buffers, which can be provided via this hint. Otherwise, reverse-mode differentiation of loops will fail with an error message.

  4. label provides a descriptive label.

    Dr.Jit will include this label as a comment in the generated intermediate representation, which can be helpful when debugging the compilation of large programs.

  5. include and exclude indicate to the @drjit.syntax decorator that a local variable should or should not be considered to be part of the set of state variables passed to drjit.while_loop() or drjit.if_stmt().

    While transforming a function, the @drjit.syntax decorator sequentially steps through a program to identify the set of read and written variables. It then forwards referenced variables to recursive drjit.while_loop() and drjit.if_stmt() calls. In rare cases, it may be useful to manually include or exclude a local variable from this process. To do so, pass a list of such variables to the drjit.hint() annotation.

drjit.while_loop(state: tuple[*Ts], cond: typing.Callable[[*Ts], AnyArray | bool], body: typing.Callable[[*Ts], tuple[*Ts]], labels: typing.Sequence[str] = (), label: str | None = None, mode: typing.Literal['scalar', 'symbolic', 'evaluated', None] = None, strict: bool = True, compress: bool | None = None, max_iterations: int | None = None) tuple[*Ts]

Repeatedly execute a function while a loop condition holds.


This function provides a vectorized generalization of a standard Python while loop. For example, consider the following Python snippet

i: int = 1
while i < 10:
    x *= x
    i += 1

This code would fail when i is replaced by an array with multiple entries (e.g., of type drjit.llvm.Int). In that case, the loop condition evaluates to a boolean array of per-component comparisons that are not necessarily consistent with each other. In other words, each entry of the array may need to run the loop for a different number of iterations. A standard Python while loop is not able to do so.

The drjit.while_loop() function realizes such a fine-grained looping mechanism. It takes three main input arguments:

  1. state, a tuple of state variables that are modified by the loop iteration,

    Dr.Jit optimizes away superfluous state variables, so there isn’t any harm in specifying variables that aren’t actually modified by the loop.

  2. cond, a function that takes the state variables as input and uses them to evaluate and return the loop condition in the form of a boolean array,

  3. body, a function that also takes the state variables as input and runs one loop iteration. It must return an updated set of state variables.

The function calls cond and body to execute the loop. It then returns a tuple containing the final version of the state variables. With this functionality, a vectorized version of the above loop can be written as follows:

i, x = dr.while_loop(
    state=(i, x),
    cond=lambda i, x: i < 10,
    body=lambda i, x: (i+1, x*x)
)

Lambda functions are convenient when the condition and body are simple enough to fit onto a single line. In general you may prefer to define local functions (def loop_cond(i, x): ...) and pass them to the cond and body arguments.

Dr.Jit also provides the @drjit.syntax decorator, which automatically rewrites standard Python control flow constructs into the form shown above. It combines vectorization with the readability of natural Python syntax and is the recommended way of (indirectly) using drjit.while_loop(). With this decorator, the above example would be written as follows:

@dr.syntax
def f(i, x):
    while i < 10:
        x *= x
        i += 1
    return i, x

Evaluation modes

Dr.Jit uses one of three different modes to compile this operation depending on the inputs and active compilation flags (the text below this overview will explain how this mode is automatically selected).

  1. Scalar mode: Scalar loops that don’t need any vectorization can fall back to a simple Python loop construct.

    while cond(*state):
        state = body(*state)

    This is the default strategy when cond(*state) returns a scalar Python bool.

    The loop body may still use Dr.Jit types, but note that this effectively unrolls the loop, generating a potentially long sequence of instructions that may take a long time to compile. Symbolic mode (discussed next) may be advantageous in such cases.

  2. Symbolic mode: Here, Dr.Jit runs a single loop iteration to capture its effect on the state variables. It embeds this captured computation into the generated machine code. The loop will eventually run on the device (e.g., the GPU) but unlike a Python while statement, the loop does not run on the host CPU (besides the mentioned tentative evaluation for symbolic tracing).

    When loop optimizations are enabled (drjit.JitFlag.OptimizeLoops), Dr.Jit may re-trace the loop body so that it runs twice in total. This happens transparently and has no influence on the semantics of this operation.

  3. Evaluated mode: in this mode, Dr.Jit will repeatedly evaluate the loop state variables and update active elements using the loop body function until all of them are done. Conceptually, this is equivalent to the following Python code:

    active = True
    while True:
       active &= cond(state)
       if not dr.any(active):
           break
       state = dr.select(active, body(state), state)

    (In practice, the implementation does a few additional things like suppressing side effects associated with inactive entries.)

    Dr.Jit will typically compile a kernel when it runs the first loop iteration. Subsequent iterations can then reuse this cached kernel since they perform the same exact sequence of operations. Kernel caching tends to be crucial to achieve good performance, and it is good to be aware of pitfalls that can effectively disable it.

    For example, when you update a scalar (e.g. a Python int) in each loop iteration, this changing counter might be merged into the generated program, forcing the system to re-generate and re-compile code at every iteration, and this can ultimately dominate the execution time. If in doubt, increase the log level of Dr.Jit (drjit.set_log_level() to drjit.LogLevel.Info) and check if the kernels being launched contain the term cache miss. You can also inspect the Kernels launched line in the output of drjit.whos(). If you observe soft or hard misses at every loop iteration, then kernel caching isn’t working and you should carefully inspect your code to ensure that the computation stays consistent across iterations.

    When the loop processes many elements, and when each element requires a different number of loop iterations, there is the question of what should be done with inactive elements. The default implementation keeps them around and does redundant calculations that are, however, masked out. Consequently, later loop iterations don’t run faster despite fewer elements being active.

    Alternatively, you may specify the parameter compress=True or set the flag drjit.JitFlag.CompressLoops, which causes the removal of inactive elements after every iteration. This reorganization is not for free and does not benefit all use cases, which is why it isn’t enabled by default.
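The evaluated-mode pseudocode above can be turned into a runnable plain-Python model in which lists stand in for Dr.Jit arrays. The helpers select and while_loop_model are hypothetical names, and the sketch ignores side effects and compression. It only demonstrates that each element leaves the loop after its own number of iterations, while inactive elements keep their previous state.

```python
def select(mask, if_true, if_false):
    """Element-wise choice between two lists, like a masked blend."""
    return [t if m else f for m, t, f in zip(mask, if_true, if_false)]

def while_loop_model(state, cond, body):
    """Run body on all elements, but only commit results for active ones."""
    active = [True] * len(state[0])
    while True:
        active = [a and c for a, c in zip(active, cond(*state))]
        if not any(active):
            break
        new_state = body(*state)
        state = tuple(select(active, new, old)
                      for new, old in zip(new_state, state))
    return state

i, x = while_loop_model(
    state=([7, 9, 12], [1, 1, 1]),
    cond=lambda i, x: [v < 10 for v in i],
    body=lambda i, x: ([v + 1 for v in i], [v * 2 for v in x]),
)
print(i, x)  # [10, 10, 12] [8, 2, 1]
```

The three entries run for three, one, and zero iterations respectively. The body is still evaluated for every entry in every iteration, which is the redundant work mentioned above.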

A separate section about symbolic and evaluated modes discusses these two options in further detail.

The drjit.while_loop() function chooses the evaluation mode as follows:

  1. When the mode argument is set to None (the default), the function examines the loop condition. It uses scalar mode when this produces a Python bool, otherwise it inspects the drjit.JitFlag.SymbolicLoops flag to switch between symbolic (the default) and evaluated mode.

    To change this automatic choice for a region of code, you may specify the mode= keyword argument, nest code into a drjit.scoped_set_flag() block, or change the behavior globally via drjit.set_flag():

    with dr.scoped_set_flag(dr.JitFlag.SymbolicLoops, False):
        # .. nested code will use evaluated loops ..
  2. When mode is set to "scalar", "symbolic", or "evaluated", it directly uses that method without inspecting the compilation flags or loop condition type.
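The selection rules above amount to a short decision procedure. The following plain-Python sketch restates them. The function choose_mode and its symbolic_loops argument (standing in for the drjit.JitFlag.SymbolicLoops flag) are hypothetical and not part of the Dr.Jit API.

```python
def choose_mode(cond_result, mode=None, symbolic_loops=True):
    """Model of the mode selection described above (not Dr.Jit code)."""
    if mode is not None:
        if mode not in ("scalar", "symbolic", "evaluated"):
            raise ValueError(f"unknown mode: {mode!r}")
        return mode                      # an explicit request always wins
    if isinstance(cond_result, bool):
        return "scalar"                  # Python bool: plain Python loop
    return "symbolic" if symbolic_loops else "evaluated"

print(choose_mode(True))                                 # scalar
print(choose_mode([True, False]))                        # symbolic
print(choose_mode([True, False], symbolic_loops=False))  # evaluated
```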

When using the @drjit.syntax decorator to automatically convert Python while loops into drjit.while_loop() calls, you can also use the drjit.hint() function to pass keyword arguments including mode, label, or max_iterations to the generated looping construct:

while dr.hint(i < 10, label='My loop', mode='evaluated'):
   # ...


The loop condition function must be pure (i.e., it should never modify the state variables or any other kind of program state). The loop body should not write to variables besides the officially declared state variables:

y = ..
def loop_body(x):
    y[0] += x     # <-- don't do this. 'y' is not a loop state variable

dr.while_loop(body=loop_body, ...)

Dr.Jit automatically tracks dependencies of indirect reads (done via drjit.gather()) and indirect writes (done via drjit.scatter(), drjit.scatter_reduce(), drjit.scatter_add(), drjit.scatter_inc(), etc.). Such operations create implicit inputs and outputs of a loop, and these do not need to be specified as loop state variables (however, doing so causes no harm). This auto-discovery mechanism is helpful when performing vectorized method calls within loops, where the set of implicit inputs and outputs can often be difficult to know a priori (in principle, any public/private field in any instance could be accessed in this way).

y = ..
def loop_body(x):
    # Scattering to 'y' is okay even if it is not declared as loop state
    dr.scatter(target=y, value=x, index=0)

Another important assumption is that the loop state remains consistent across iterations, which means:

  1. The type of state variables is not allowed to change. You may not declare a Python float before a loop and then overwrite it with a drjit.cuda.Float (or vice versa).

  2. Their structure/size must be consistent. The loop body may not turn a variable with 3 entries into one that has 5.

  3. Analogously, state variables must always be initialized prior to the loop. This is the case even if you know that the loop body is guaranteed to overwrite the variable with a well-defined result. An initial value of None would violate condition 1 (type invariance), while an empty array would violate condition 2 (shape compatibility).

The implementation will check for violations and, if applicable, raise an exception identifying problematic state variables.
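The three consistency rules can be pictured as a check that compares the loop state before and after one iteration. The sketch below is a plain-Python model in which lists stand in for arrays. The name check_state_consistency and the exception types are assumptions and do not reflect the actual Dr.Jit implementation.

```python
def check_state_consistency(before, after, labels):
    """Raise if a state variable changed its type or size (model only)."""
    for old, new, label in zip(before, after, labels):
        if type(old) is not type(new):
            raise TypeError(
                f"state variable '{label}' changed type from "
                f"'{type(old).__name__}' to '{type(new).__name__}'")
        if isinstance(old, list) and len(old) != len(new):
            raise ValueError(
                f"state variable '{label}' changed size from "
                f"{len(old)} to {len(new)}")

# Consistent state: same types, same sizes
check_state_consistency((0.0, [1, 2, 3]), (1.0, [4, 5, 6]), ("t", "v"))

# An initial value of None violates rule 1 (type invariance)
try:
    check_state_consistency((None,), (1.0,), ("t",))
except TypeError as e:
    print(e)
```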

Potential pitfalls

  1. Long compilation times.

    In the example below, i < 100000 is scalar, causing drjit.while_loop() to use the scalar evaluation strategy that effectively copy-pastes the loop body 100000 times to produce a giant program. Code written in this way will be bottlenecked by the CUDA/LLVM compilation stage.

    @dr.syntax
    def f():
        i = 0
        while i < 100000:
            # .. costly computation
            i += 1
  2. Incorrect behavior in symbolic mode.

    Let’s fix the above program by casting the loop condition into a Dr.Jit type to ensure that a symbolic loop is used. Problem solved, right?

    from drjit.cuda import Bool
    @dr.syntax
    def f():
        i = 0
        while Bool(i < 100000):
            # .. costly computation
            i += 1

    Unfortunately, no: this loop never terminates when run in symbolic mode. Symbolic mode does not track modifications of scalar/non-Dr.Jit types across loop iterations such as the int-valued loop counter i. It’s as if we had written while Bool(0 < 100000), which of course never finishes.

    Evaluated mode does not have this problem—if your loop behaves differently in symbolic and evaluated modes, then some variation of this mistake is likely to blame. To fix this, we must declare the loop counter as a vector type before the loop and then modify it as follows:

    from drjit.cuda import Int
    @dr.syntax
    def f():
        i = Int(0)
        while i < 100000:
            # .. costly computation
            i += 1


This new implementation of the drjit.while_loop() abstraction still lacks the functionality to break or return from the loop, or to continue to the next loop iteration. We plan to add these capabilities in the near future.


  • state (tuple) – A tuple containing the initial values of the loop’s state variables. This tuple normally consists of Dr.Jit arrays or PyTrees. Other values are permissible as well and will be forwarded to the loop body. However, such variables will not be captured by the symbolic tracing process.

  • cond (Callable) – a function/callable that will be invoked with *state (i.e., the state variables will be unpacked and turned into function arguments). It should return a scalar Python bool or a boolean-typed Dr.Jit array representing the loop condition.

  • body (Callable) – a function/callable that will be invoked with *state (i.e., the state variables will be unpacked and turned into function arguments). It should update the loop state and then return a new tuple of state variables that are compatible with the previous state (see the earlier description regarding what such compatibility entails).

  • mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides None are: "scalar", "symbolic", "evaluated". If not specified, the function first checks if the loop is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flag drjit.JitFlag.SymbolicLoops and then either performs a symbolic or an evaluated loop.

  • compress (Optional[bool]) – Set this parameter to True or False to enable or disable loop state compression in evaluated loops (see the text above for a description of this feature). The function queries the value of drjit.JitFlag.CompressLoops when the parameter is not specified. Symbolic loops ignore this parameter.

  • labels (list[str]) – An optional list of labels associated with each state entry. Dr.Jit uses this to provide better error messages in case of a detected inconsistency. The @drjit.syntax decorator automatically provides these labels based on the transformed code.

  • label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.

  • max_iterations (int) – The maximum number of loop iterations (default: -1). You must specify a correct upper bound here if you wish to differentiate the loop in reverse mode. In that case, the maximum iteration count is used to reserve memory to store intermediate loop state.

  • strict (bool) – You can specify this parameter to reduce the strictness of variable consistency checks performed by the implementation. See the documentation of drjit.hint() for an example. The default is strict=True.


The function returns the final state of the loop variables following termination of the loop.

Return type:


drjit.if_stmt(args: tuple[*Ts], cond: AnyArray | bool, true_fn: typing.Callable[[*Ts], T], false_fn: typing.Callable[[*Ts], T], arg_labels: typing.Sequence[str] = (), rv_labels: typing.Sequence[str] = (), label: str | None = None, mode: typing.Literal['scalar', 'symbolic', 'evaluated', None] = None, strict: bool = True) T

Conditionally execute code.


This function provides a vectorized generalization of a standard Python if statement. For example, consider the following Python snippet

i: int = .. some expression ..
if i > 0:
    x = f(i) # <-- some costly function 'f' that depends on 'i'
    y += 1

This code would fail if i is replaced by an array containing multiple entries (e.g., of type drjit.llvm.Int). In that case, the conditional expression produces a boolean array of per-component comparisons that are not necessarily consistent with each other. In other words, some of the entries may want to run the body of the if statement, while others must skip to the else block. This is not compatible with the semantics of a standard Python if statement.

The drjit.if_stmt() function realizes a more fine-grained conditional operation that accommodates these requirements, while avoiding execution of the costly branch unless this is truly needed. It takes the following input arguments:

  1. cond, a boolean array that specifies whether the body of the if statement should execute.

  2. A tuple of input arguments (args) that will be forwarded to true_fn and false_fn. It is important to specify all inputs to ensure correct derivative tracking of the operation.

  3. true_fn, a callable that implements the body of the if block.

  4. false_fn, a callable that implements the body of the else block.

The implementation will invoke true_fn(*args) and false_fn(*args) to trace their contents. The return values of these functions must be compatible with each other (a precise definition of compatibility is described below). A vectorized version of the earlier example can then be written as follows:

x, y = dr.if_stmt(
    args=(i, x, y),
    cond=i > 0,
    true_fn=lambda i, x, y: (f(i), y + 1),
    false_fn=lambda i, x, y: (x, y)
)

Lambda functions are convenient when true_fn and false_fn are simple enough to fit onto a single line. In general you may prefer to define local functions (def true_fn(i, x, y): ...) and pass them to the true_fn and false_fn arguments.

Dr.Jit later optimizes away superfluous inputs/outputs of drjit.if_stmt(), so there isn’t any harm in, e.g., specifying an identical element of a return value in both true_fn and false_fn.

Dr.Jit also provides the @drjit.syntax decorator, which automatically rewrites standard Python control flow constructs into the form shown above. It combines vectorization with the readability of natural Python syntax and is the recommended way of (indirectly) using drjit.if_stmt(). With this decorator, the above example would be written as follows:

@dr.syntax
def g(i, x, y):
    if i > 0:
        x = f(i)  # 'f' is the costly function from the earlier snippet
        y += 1
    return x, y

Evaluation modes

Dr.Jit uses one of three different modes to realize this operation depending on the inputs and active compilation flags (the text below this overview will explain how this mode is automatically selected).

  1. Scalar mode: Scalar if statements that don’t need any vectorization can be reduced to normal Python branching constructs:

    if cond:
        state = true_fn(*args)
    else:
        state = false_fn(*args)

    This strategy is the default when cond is a scalar Python bool.

  2. Symbolic mode: Dr.Jit runs true_fn and false_fn to capture the computation performed by each function, which allows it to generate an equivalent branch in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.

  3. Evaluated mode: in this mode, Dr.Jit runs both branches of the if statement and then combines the results via drjit.select(). This is nearly equivalent to the following Python code:

    true_state = true_fn(*args)
    false_state = false_fn(*args) if false_fn else args
    state = dr.select(cond, true_state, false_state)

    (In practice, the implementation does a few additional things like suppressing side effects associated with inactive entries.)

    Evaluated mode is conceptually simpler but also slower, since the device executes both sides of a branch when only one of them is actually needed.
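A runnable plain-Python model of evaluated mode is shown below, with lists standing in for Dr.Jit arrays. The helpers select, if_stmt_model, and the placeholder function f are hypothetical. The sketch mirrors the earlier example (x = f(i) and y += 1 where i > 0) and shows that both branches run for all entries before the results are blended.

```python
def select(mask, if_true, if_false):
    """Element-wise choice between two lists, like a masked blend."""
    return [t if m else f for m, t, f in zip(mask, if_true, if_false)]

def if_stmt_model(args, cond, true_fn, false_fn):
    """Evaluate both branches, then pick per element according to cond."""
    true_state = true_fn(*args)
    false_state = false_fn(*args)
    return tuple(select(cond, t, f) for t, f in zip(true_state, false_state))

def f(i):
    """Placeholder for the costly function of the example."""
    return [v * 10 for v in i]

i, x, y = [-1, 2], [0, 0], [5, 5]
x, y = if_stmt_model(
    args=(i, x, y),
    cond=[v > 0 for v in i],
    true_fn=lambda i, x, y: (f(i), [v + 1 for v in y]),
    false_fn=lambda i, x, y: (x, y),
)
print(x, y)  # [0, 20] [5, 6]
```

Note that f is called on the entry -1 as well, even though its result is discarded. This is the extra cost that symbolic mode avoids.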

The mode is chosen as follows:

  1. When the mode argument is set to None (the default), the function examines the type of the cond input and uses scalar mode if the type is a builtin Python bool.

    Otherwise, it chooses between symbolic and evaluated mode based on the drjit.JitFlag.SymbolicConditionals flag, which is set by default. To change this choice for a region of code, you may specify the mode= keyword argument, nest it into a drjit.scoped_set_flag() block, or change the behavior globally via drjit.set_flag():

    with dr.scoped_set_flag(dr.JitFlag.SymbolicConditionals, False):
        # .. nested code will use evaluated mode ..

  2. When mode is set to "scalar", "symbolic", or "evaluated", it directly uses that mode without inspecting the compilation flags or condition type.
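The selection rules above can be summarized as a small decision function. This is a sketch under stated assumptions: choose_if_stmt_mode is a hypothetical helper, its symbolic_conditionals parameter stands in for the state of drjit.JitFlag.SymbolicConditionals, and rejecting unknown mode strings with a ValueError is an assumption rather than documented behavior.

```python
def choose_if_stmt_mode(cond, mode=None, symbolic_conditionals=True):
    """Return the name of the mode that the rules above would select.

    'symbolic_conditionals' is a stand-in for the state of
    drjit.JitFlag.SymbolicConditionals (set by default).
    """
    if mode is not None:
        # Assumption: anything besides the three documented names is an error
        if mode not in ("scalar", "symbolic", "evaluated"):
            raise ValueError(f"unknown mode: {mode!r}")
        return mode                      # explicit request wins
    if isinstance(cond, bool):
        return "scalar"                  # builtin Python bool
    return "symbolic" if symbolic_conditionals else "evaluated"


print(choose_if_stmt_mode(True))                                  # scalar
print(choose_if_stmt_mode([True, False]))                         # symbolic
print(choose_if_stmt_mode([True], symbolic_conditionals=False))   # evaluated
```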

When using the @drjit.syntax decorator to automatically convert Python if statements into drjit.if_stmt() calls, you can also use the drjit.hint() function to pass keyword arguments including the mode and label parameters.

if dr.hint(i < 10, mode='evaluated'):
   # ...


The return values of true_fn and false_fn must be of the same type. This requirement applies recursively if the return value is a PyTree.

Dr.Jit will refuse to compile vectorized conditionals in which true_fn and false_fn return a scalar that is inconsistent between the branches.

>>> @dr.syntax
... def f(x):
...    if x > 0:
...        y = 1
...    else:
...        y = 0
...    return y
>>> print(f(dr.llvm.Float(-1,2)))
RuntimeError: dr.if_stmt(): detected an inconsistency when comparing the return
values of 'true_fn' and 'false_fn': drjit.detail.check_compatibility(): inconsistent
scalar Python object of type 'int' for field 'y'.

Please review the interface and assumptions of dr.if_stmt() as explained in the
Dr.Jit documentation.

The problem can be solved by assigning an instance of a capitalized Dr.Jit type (e.g., y=Int(1)) so that the operation can be tracked.

The functions true_fn and false_fn should not write to variables besides the explicitly declared return value(s):

vec = drjit.cuda.Array3f(1, 2, 3)
def true_fn(x):
    vec.x += x     # <-- don't do this. 'vec' is not a declared output

dr.if_stmt(args=(x,), true_fn=true_fn, ...)

This example can be fixed as follows:

def true_fn(x, vec):
    vec.x += x
    return vec

vec = dr.if_stmt(args=(x, vec), true_fn=true_fn, ...)

drjit.if_stmt() is differentiable in both forward and reverse modes. Correct derivative tracking requires that regular differentiable inputs are specified via the args parameter. The @drjit.syntax decorator ensures that these assumptions are satisfied.

Dr.Jit also tracks dependencies of indirect reads (done via drjit.gather()) and indirect writes (done via drjit.scatter(), drjit.scatter_reduce(), drjit.scatter_add(), drjit.scatter_inc(), etc.). Such operations create implicit inputs and outputs, and these do not need to be specified as part of args or the return value of true_fn and false_fn (however, doing so causes no harm). This auto-discovery mechanism is helpful when performing vectorized method calls within conditional statements, where the set of implicit inputs and outputs can often be difficult to know a priori (in principle, any public/private field in any instance could be accessed in this way).

y = ..
def true_fn(x):
    # 'y' is neither declared as input nor output of 'true_fn', which is fine
    dr.scatter(target=y, value=x, index=0)

dr.if_stmt(args=(x,), true_fn=true_fn, ...)


  • cond (bool|drjit.ArrayBase) – a scalar Python bool or a boolean-valued Dr.Jit array.

  • args (tuple) – A list of positional arguments that will be forwarded to true_fn and false_fn.

  • true_fn (Callable) – a callable that implements the body of the if block.

  • false_fn (Callable) – a callable that implements the body of the else block.

  • mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides None are: "scalar", "symbolic", "evaluated".

  • arg_labels (list[str]) – An optional list of labels associated with each input argument. Dr.Jit uses this feature in combination with the @drjit.syntax decorator to provide better error messages in case of detected inconsistencies.

  • rv_labels (list[str]) – An optional list of labels associated with each element of the return value. This parameter should only be specified when the return value is a tuple. Dr.Jit uses this feature in combination with the @drjit.syntax decorator to provide better error messages in case of detected inconsistencies.

  • label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.

  • strict (bool) – You can specify this parameter to reduce the strictness of variable consistency checks performed by the implementation. See the documentation of drjit.hint() for an example. The default is strict=True.


Combined return value mixing the results of true_fn and false_fn.

Return type:


drjit.switch(index: object, targets: collections.abc.Sequence, *args, **kwargs) object

Selectively invoke functions based on a provided index array.

When called with a scalar index (of type int), this function is equivalent to the following Python expression:

targets[index](*args, **kwargs)

When called with a Dr.Jit index array (specifically, 32-bit unsigned integers), it performs the vectorized equivalent of the above and assembles an array of return values containing the result of all referenced functions. It does so efficiently using at most a single invocation of each function in targets.

import drjit as dr
from drjit.llvm import UInt32

res = dr.switch(
    UInt32(0, 0, 1, 1),       # <-- index: selects the function
    [                         # <-- targets: arbitrary list of callables
        lambda x: x,
        lambda x: x*10
    ],
    UInt32(1, 2, 3, 4)        # <-- argument passed to function
)

# res now contains [1, 2, 30, 40]

The function traverses the set of positional (*args) and keyword arguments (**kwargs) to find all Dr.Jit arrays including arrays contained within PyTrees. It routes a subset of array entries to each function as specified by the index argument.

Dr.Jit will use one of two possible strategies to compile this operation depending on the active compilation flags (see drjit.set_flag(), drjit.scoped_set_flag()):

  1. Symbolic mode: Dr.Jit transcribes every function into a counterpart in the generated low-level intermediate representation (LLVM IR or PTX) and targets them via an indirect jump instruction.

    This mode is used when drjit.JitFlag.SymbolicCalls is set, which is the default.

  2. Evaluated mode: Dr.Jit evaluates the inputs index, args, kwargs via drjit.eval(), groups them by index, and invokes each function with the subset of inputs that reference it. Callables that are not referenced by any element of index are ignored.

    In this mode, a drjit.switch() statement will cause Dr.Jit to launch a series of kernels processing subsets of the input data (one per function).
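The grouping performed by evaluated mode can be illustrated with a plain-Python model in which lists stand in for Dr.Jit arrays. This is a sketch, not the actual implementation: switch_evaluated and the three sample callables are hypothetical, and the optional active list mimics the mask handling documented for this function.

```python
def switch_evaluated(index, targets, values, active=None):
    """Toy model of evaluated mode; lists stand in for Dr.Jit arrays."""
    if active is None:
        active = [True] * len(index)

    # Group the active lanes by the function they reference
    lanes_by_target = {}
    for lane, (i, enabled) in enumerate(zip(index, active)):
        if enabled:
            lanes_by_target.setdefault(i, []).append(lane)

    result = [0] * len(index)            # masked lanes stay zero
    for i, lanes in lanes_by_target.items():
        # One call per referenced function, on its subset of the input
        out = targets[i]([values[lane] for lane in lanes])
        for lane, v in zip(lanes, out):
            result[lane] = v
    return result


calls = []

def identity(xs):
    calls.append("identity")
    return xs

def times_ten(xs):
    calls.append("times_ten")
    return [x * 10 for x in xs]

def unused(xs):
    calls.append("unused")
    return xs

res = switch_evaluated([0, 0, 1, 1], [identity, times_ten, unused], [1, 2, 3, 4])
print(res)    # [1, 2, 30, 40]
print(calls)  # ['identity', 'times_ten']
```

Each referenced callable runs exactly once, and the callable that no index refers to is never invoked.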

A separate section about symbolic and evaluated modes discusses these two options in detail.

To switch the compilation mode locally, use drjit.scoped_set_flag() as shown below:

with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, False):
    result = dr.switch(..)

When a boolean Dr.Jit array (e.g., drjit.llvm.Bool, drjit.cuda.ad.Bool, etc.) is specified as last positional argument or as a keyword argument named active, that argument is treated specially: entries of the input arrays associated with a False mask entry are ignored and never passed to the functions. Associated entries of the return value will be zero-initialized. The function will still receive the mask argument as input, but it will always be set to True.


The indices provided to this operation are unchecked by default. Attempting to call functions beyond the end of the targets array is undefined behavior and may crash the application, unless such calls are explicitly disabled via the active parameter. Negative indices are not permitted.

If debug mode is enabled via the drjit.JitFlag.Debug flag, Dr.Jit will insert range checks into the program. These checks disable out-of-bound calls and furthermore report warnings to identify problematic source locations:

>>> print(dr.switch(UInt32(0, 100), [lambda x:x], UInt32(1)))
Attempted to invoke callable with index 100, but this value must be smaller than 1. (<stdin>:2)

  • index (int|drjit.ArrayBase) – a list of indices to choose the functions

  • targets (Sequence[Callable]) – a list of callables to which calls will be dispatched based on the index argument.

  • mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides None are: "symbolic", "evaluated". If not specified, the function first checks if the index is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flag drjit.JitFlag.SymbolicCalls and then either performs a symbolic or an evaluated call.

  • label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.

  • *args (tuple) – a variable-length list of positional arguments passed to the functions. PyTrees are supported.

  • **kwargs (dict) – a variable-length list of keyword arguments passed to the functions. PyTrees are supported.


When index is a scalar Python integer, the return value simply forwards the return value of the selected function. Otherwise, the function returns a Dr.Jit array or PyTree combining the results from each referenced callable.

Return type:


drjit.dispatch(inst: drjit.ArrayBase, target: collections.abc.Callable, *args, **kwargs) object

Invoke a provided Python function for each instance in an instance array.

This function invokes the provided target for each instance in the instance array inst and assembles the return values into a result array. Conceptually, it does the following:

def dispatch(inst, target, *args, **kwargs):
    result = []
    for i in inst:
        result.append(target(i, *args, **kwargs))
    return result

However, the implementation accomplishes this more efficiently using only a single call per unique instance. Instead of a Python list, it returns a Dr.Jit array or PyTree.

In practice, this function is mainly good for two things:

  • Dr.Jit instance arrays contain C++ instances, and these will typically expose a set of methods. Adding further methods requires re-compiling C++ code and adding bindings, which may impede quick prototyping. With drjit.dispatch(), a developer can quickly implement additional vectorized method calls within Python (with the caveat that these can only access public members of the underlying type).

  • Dynamic dispatch is a relatively costly operation. When multiple calls are performed on the same set of instances, it may be preferable to merge them into a single and potentially significantly faster use of drjit.dispatch(). An example is shown below:

    inst = ...  # Array of C++ instances
    result_1 = inst.func_1(arg1)
    result_2 = inst.func_2(arg2)

    The following alternative implementation instead uses drjit.dispatch():

    def my_func(self, arg1, arg2):
        return (self.func_1(arg1),
                self.func_2(arg2))

    result_1, result_2 = dr.dispatch(inst, my_func, arg1, arg2)

This function is otherwise very similar to drjit.switch() and similarly provides two different compilation modes, differentiability, and special handling of mask arguments. Please review the documentation of drjit.switch() for details.

  • inst (drjit.ArrayBase) – a Dr.Jit instance array.

  • target (Callable) – function to dispatch on all instances

  • mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides None are: "symbolic", "evaluated". If not specified, the function first checks if the index is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flag drjit.JitFlag.SymbolicCalls and then either performs a symbolic or an evaluated call.

  • label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.

  • *args (tuple) – a variable-length list of positional arguments passed to the function. PyTrees are supported.

  • **kwargs (dict) – a variable-length list of keyword arguments passed to the function. PyTrees are supported.


A Dr.Jit array or PyTree containing the result of each performed function call.

Return type:


Horizontal operations

These operations are horizontal in the sense that [..]

drjit.gather(dtype: type[T], source: object, index: AnyArray | Sequence[int] | int, active: AnyArray | Sequence[bool] | bool = True, mode: drjit.ReduceMode = drjit.ReduceMode.Auto) T

Gather values from a flat array or nested data structure.

This function performs a gather (i.e., indirect memory read) from source at position index. It expects a dtype argument and will return an instance of this type. The optional active argument can be used to disable some of the components, which is useful when not all indices are valid; the corresponding output will be zero in this case.

This operation can be used in the following different ways:

  1. When dtype is a 1D Dr.Jit array like drjit.llvm.ad.Float, this operation implements a parallelized version of the Python array indexing expression source[index] with optional masking. Example:

    source = dr.cuda.Float([...])
    index = dr.cuda.UInt([...]) # Note: negative indices are not permitted
    result = dr.gather(dtype=type(source), source=source, index=index)

  2. When dtype is a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:

    • When type(source) matches dtype, the gather operation threads through entries and invokes itself recursively. For example, the gather operation in

      source = dr.cuda.Array3f(...)
      index = dr.cuda.UInt([...])
      result = dr.gather(dr.cuda.Array3f, source, index)

      is equivalent to

      result = dr.cuda.Array3f(
          dr.gather(dr.cuda.Float, source.x, index),
          dr.gather(dr.cuda.Float, source.y, index),
          dr.gather(dr.cuda.Float, source.z, index))

      A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.

    • Otherwise, the operation reconstructs the requested dtype from a flat source array, using C-style ordering with a suitably modified index. For example, the gather below reads 3D vectors from a 1D array.

      source = dr.cuda.Float([...])
      index = dr.cuda.UInt([...])
      result = dr.gather(dr.cuda.Array3f, source, index)

      and is equivalent to

      result = dr.cuda.Array3f(
          dr.gather(dr.cuda.Float, source, index*3 + 0),
          dr.gather(dr.cuda.Float, source, index*3 + 1),
          dr.gather(dr.cuda.Float, source, index*3 + 2))
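The index arithmetic of the flat-array case can be checked with a sequential model that uses Python lists in place of Dr.Jit arrays. The helpers gather and gather_vec3 below are hypothetical stand-ins written for this illustration; they reproduce the index*3 + k addressing and the zero result for masked entries, nothing more.

```python
def gather(source, index, active=None):
    """Sequential model of a masked 1D gather; masked lanes yield 0."""
    if active is None:
        active = [True] * len(index)
    return [source[i] if enabled else 0 for i, enabled in zip(index, active)]


def gather_vec3(source, index, active=None):
    """Read 3D vectors stored as x0, y0, z0, x1, y1, z1, ... in a flat list."""
    return tuple(gather(source, [i * 3 + k for i in index], active)
                 for k in range(3))


flat = [0, 1, 2, 10, 11, 12, 20, 21, 22]   # three packed 3D vectors
x, y, z = gather_vec3(flat, [2, 0])
print(x, y, z)  # [20, 0] [21, 1] [22, 2]

# A masked lane is never read, so its invalid index does no harm
print(gather(flat, [3, 100], [True, False]))  # [10, 0]
```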


The indices provided to this operation are unchecked by default. Attempting to read beyond the end of the source array is undefined behavior and may crash the application, unless such reads are explicitly disabled via the active parameter. Negative indices are not permitted.

If debug mode is enabled via the drjit.JitFlag.Debug flag, Dr.Jit will insert range checks into the program. These checks disable out-of-bound reads and furthermore report warnings to identify problematic source locations:

>>> dr.gather(dtype=UInt, source=UInt(1, 2, 3), index=UInt(0, 1, 100))
drjit.gather(): out-of-bounds read from position 100 in an array of size 3. (<stdin>:2)

  • dtype (type) – The desired output type (typically equal to type(source), but other variations are possible as well, see the description above.)

  • source (object) – The object from which data should be read (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)

  • index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g., drjit.scalar.ArrayXu or drjit.cuda.UInt) specifying gather indices. Dr.Jit will attempt an implicit conversion if another type is provided.

  • active (object) – an optional 1D dynamic Dr.Jit mask array (e.g., drjit.scalar.ArrayXb or drjit.cuda.Bool) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default is True.

  • mode (drjit.ReduceMode) – The reverse-mode derivative of a gather is an atomic scatter-reduction. The execution of such atomics can be rather performance-sensitive (see the discussion of drjit.ReduceMode for details), hence Dr.Jit offers a few different compilation strategies to realize them. Specifying this parameter selects a strategy for the derivative of a particular gather operation. The default is drjit.ReduceMode.Auto.

drjit.scatter(target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None

Scatter values into a flat array or nested data structure.

This operation performs a scatter (i.e., indirect memory write) of the value parameter to the target array at position index. The optional active argument can be used to disable some of the individual write operations, which is useful when not all provided values or indices are valid.

This operation can be used in the following different ways:

  1. When target is a 1D Dr.Jit array like drjit.llvm.ad.Float, this operation implements a parallelized version of the Python array indexing expression target[index] = value with optional masking. Example:

    target = dr.empty(dr.cuda.Float, 1024*1024)
    value = dr.cuda.Float([...])
    index = dr.cuda.UInt([...]) # Note: negative indices are not permitted
    dr.scatter(target, value=value, index=index)

  2. When target is a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:

    • When target and value are of the same type, the scatter operation threads through entries and invokes itself recursively. For example, the scatter operation in

      target = dr.cuda.Array3f(...)
      value = dr.cuda.Array3f(...)
      index = dr.cuda.UInt([...])
      dr.scatter(target, value, index)

      is equivalent to

      dr.scatter(target.x, value.x, index)
      dr.scatter(target.y, value.y, index)
      dr.scatter(target.z, value.z, index)

      A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.

    • Otherwise, the operation flattens the value array and writes it using C-style ordering with a suitably modified index. For example, the scatter below writes 3D vectors into a 1D array.

      target = dr.cuda.Float(...)
      value = dr.cuda.Array3f(...)
      index = dr.cuda.UInt([...])
      dr.scatter(target, value, index)

      and is equivalent to

      dr.scatter(target, value.x, index*3 + 0)
      dr.scatter(target, value.y, index*3 + 1)
      dr.scatter(target, value.z, index*3 + 2)


The indices provided to this operation are unchecked by default. Out-of-bound writes are considered undefined behavior and may crash the application (unless they are disabled via the active parameter). Negative indices are not permitted.

If debug mode is enabled via the drjit.JitFlag.Debug flag, Dr.Jit will insert range checks into the program. These will catch out-of-bound writes and print an error message identifying the responsible line of code.

Dr.Jit makes no guarantees about the expected behavior when a scatter operation has conflicts, i.e., when a specific position is written multiple times by a single drjit.scatter() operation.
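A sequential model with Python lists makes the flattening rule and the effect of the active mask concrete. The helpers scatter and scatter_vec3 are hypothetical and written only for this illustration. Because the model processes entries one after another, a conflicting write would simply be resolved in favor of the last entry; that is a property of this sketch and not something Dr.Jit guarantees.

```python
def scatter(target, value, index, active=None):
    """Sequential model of a masked 1D scatter (modifies 'target')."""
    if active is None:
        active = [True] * len(index)
    for v, i, enabled in zip(value, index, active):
        if enabled:
            target[i] = v


def scatter_vec3(target, value, index, active=None):
    """Write (x, y, z) component lists into a flat list in C-style order."""
    for k, component in enumerate(value):
        scatter(target, component, [i * 3 + k for i in index], active)


flat = [0] * 9
value = ([1, 4], [2, 5], [3, 6])          # two 3D vectors, stored per component
scatter_vec3(flat, value, index=[2, 0])
print(flat)  # [4, 5, 6, 0, 0, 0, 1, 2, 3]

# The out-of-bounds write to position 100 is disabled by the mask
small = [0, 0, 0]
scatter(small, [7, 8], index=[1, 100], active=[True, False])
print(small)  # [0, 7, 0]
```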

  • target (object) – The object into which data should be written (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)

  • value (object) – The values to be written (typically of type type(target), but other variations are possible as well, see the description above.) Dr.Jit will attempt an implicit conversion if the input is not an array type.

  • index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g., drjit.scalar.ArrayXu or drjit.cuda.UInt) specifying scatter indices. Dr.Jit will attempt an implicit conversion if another type is provided.

  • active (object) – an optional 1D dynamic Dr.Jit mask array (e.g., drjit.scalar.ArrayXb or drjit.cuda.Bool) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default is True.

enum drjit.ReduceOp(value)

List of different atomic read-modify-write (RMW) operations supported by drjit.scatter_reduce().

Valid values are as follows:

Identity = ReduceOp.Identity

Perform an ordinary scatter operation that ignores the current entry.

Add = ReduceOp.Add

Addition.

Mul = ReduceOp.Mul

Multiplication.

Min = ReduceOp.Min

Minimum.

Max = ReduceOp.Max

Maximum.

And = ReduceOp.And

Binary AND operation.

Or = ReduceOp.Or

Binary OR operation.

enum drjit.ReduceMode(value)

Compilation strategy for atomic scatter-reductions.

Elements of this enumeration determine how Dr.Jit executes atomic scatter-reductions, which refers to indirect writes that update an existing element in an array, while avoiding problems arising due to concurrency.

Atomic scatter-reductions can have a significant detrimental impact on performance. When many threads in a parallel computation attempt to modify the same element, this can lead to contention—essentially a fight over which part of the processor owns the associated memory region, which can slow down a computation by many orders of magnitude.

This parameter also plays an important role for drjit.gather(), which is nominally a read-only operation. This is because the reverse-mode derivative of a gather turns it into an atomic scatter-addition, where further context on how to compile the operation is needed.

Dr.Jit implements several strategies to address contention, which can be selected by passing the optional mode parameter to drjit.scatter_reduce(), drjit.scatter_add(), and drjit.gather().

If you find that a part of your program is bottlenecked by atomic writes, then it may be worth explicitly specifying some of the strategies below to see which one performs best.

Valid values are as follows:

Auto = ReduceMode.Auto

Select the first valid option from the following list:

Direct = ReduceMode.Direct

Insert an ordinary atomic reduction operation into the program.

This mode is ideal when no or little contention is expected, for example because the target indices of scatters are well spread throughout the target array. This mode generates a minimal amount of code, which can help improve performance especially on GPU backends.

Local = ReduceMode.Local

Locally pre-reduce operands.

In this mode, Dr.Jit adds extra code to the compiled program to examine the target indices of atomic updates. For example, CUDA programs run with an instruction granularity referred to as a warp, which is a group of 32 threads. When some of these threads want to write to the same location, then those operands can be pre-processed to reduce the total number of necessary atomic memory transactions (potentially to just a single one!)

On the CPU/LLVM backend, the same process works at the granularity of packets. The details depend on the underlying instruction set. For example, there are 16 threads per packet on a machine with AVX512, so there is a potential for reducing atomic write traffic by that factor.
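The benefit of local pre-reduction can be estimated with a small counting model. The sketch below is an assumption-laden illustration rather than generated kernel code: scatter_add_local is a hypothetical helper, packet_size plays the role of the warp or packet width, and each update of target counts as one simulated atomic transaction.

```python
def scatter_add_local(target, value, index, packet_size=32):
    """Pre-reduce operands per packet, then issue one update per distinct
    index in the packet. Returns the number of simulated atomics."""
    atomics = 0
    for start in range(0, len(index), packet_size):
        partial = {}                          # index -> pre-reduced operand
        for v, i in zip(value[start:start + packet_size],
                        index[start:start + packet_size]):
            partial[i] = partial.get(i, 0) + v
        for i, v in partial.items():
            target[i] += v                    # stands in for one atomic add
            atomics += 1
    return atomics


index = [0] * 64            # worst case: every lane targets element 0
value = [1] * 64
target = [0]
print(scatter_add_local(target, value, index))  # 2 (one per 32-wide packet)
print(target)                                   # [64]
```

Without the pre-reduction, the same input would require 64 separate atomic updates of a single element. When all indices are distinct, the model issues one update per lane and nothing is saved.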

NoConflicts = ReduceMode.NoConflicts

Perform a non-atomic read-modify-write operation.

This mode is only safe in specific situations. The caller must guarantee that there are no conflicts (i.e., scatters targeting the same elements). If specified, Dr.Jit will generate a non-atomic read-modify-write operation that potentially runs significantly faster, especially on the LLVM backend.

Permute = ReduceMode.Permute

In contrast to prior enumeration entries, this one modifies plain (non-reductive) scatters and gathers. It exists to enable internal optimizations that Dr.Jit uses when differentiating vectorized function calls and compressed loops. You likely should not use it in your own code.

When setting this mode, the caller guarantees that there will be no conflicts, and that every entry is written exactly once using an index vector representing a permutation (it’s fine if this permutation is accomplished by multiple separate write operations, but there should be no more than 1 write to each element).

Giving ‘Permute’ as an argument to a (nominally read-only) gather operation is helpful because we then know that the reverse-mode derivative of this operation can be a plain scatter instead of a more costly atomic scatter-add.

Giving ‘Permute’ as an argument to a scatter operation is helpful because we then know that the forward-mode derivative does not depend on any prior derivative values associated with that array, as all current entries will be overwritten.

Expand = ReduceMode.Expand

Expand the target array to avoid write conflicts, then scatter non-atomically.

This feature is only supported on the LLVM backend. Other backends interpret this flag as if drjit.ReduceMode.Auto had been specified.

This mode internally expands the storage underlying the target array to a much larger size that is proportional to the number of CPU cores. Scalar (length-1) target arrays are expanded even further to ensure that each CPU gets an entirely separate cache line.

Following this one-time expansion step, the array can then accommodate an arbitrary sequence of scatter-reduction operations that the system will internally perform using non-atomic read-modify-write operations (i.e., analogous to the NoConflicts mode). Dr.Jit automatically re-compresses the array into the ordinary representation.

On bigger arrays and on machines with many cores, the storage costs resulting from this mode can be prohibitive.
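The expand-then-merge idea and its storage cost can be sketched in plain Python. Everything in this model is an assumption made for illustration: scatter_add_expand is a hypothetical helper, lanes are assigned to cores in round-robin fashion, and the cache-line padding applied to scalar targets is not modeled.

```python
def scatter_add_expand(target, value, index, num_cores=4):
    """Toy model: one private copy of 'target' per core, merged at the end."""
    # Expansion: storage grows proportionally to the number of cores
    copies = [[0] * len(target) for _ in range(num_cores)]

    # Each core updates only its private copy, so a plain (non-atomic)
    # read-modify-write is sufficient.
    for lane, (v, i) in enumerate(zip(value, index)):
        core = lane % num_cores           # assumed work distribution
        copies[core][i] += v

    # Re-compression: fold the copies back into the ordinary representation
    for i in range(len(target)):
        target[i] += sum(copy[i] for copy in copies)
    return len(copies) * len(target)      # entries of temporary storage


target = [0, 0]
storage = scatter_add_expand(target, [1] * 8, [0, 0, 0, 0, 0, 0, 1, 1])
print(target, storage)  # [6, 2] 8
```

The temporary storage scales with the product of array size and core count, which is why the mode becomes expensive for large arrays on machines with many cores.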

drjit.scatter_reduce(op: drjit.ReduceOp, target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None

Atomically update values in a flat array or nested data structure.

This function performs an atomic scatter-reduction, which is a read-modify-write operation that applies one of several possible mathematical functions to selected entries of an array. The following are supported:

Here, a refers to an entry of target selected by index, and b denotes the associated element of value. The operation resolves potential conflicts arising due to the parallel execution of this operation.
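The update rule can be written down as a sequential reference model. The sketch below is not Dr.Jit code: the local ReduceOp enumeration only mirrors the names of drjit.ReduceOp (the numeric values are arbitrary), scatter_reduce is a hypothetical list-based helper, and the model says nothing about which combinations a given backend supports.

```python
import operator
from enum import Enum


class ReduceOp(Enum):
    """Local stand-in mirroring the names of drjit.ReduceOp."""
    Identity = 0
    Add = 1
    Mul = 2
    Min = 3
    Max = 4
    And = 5
    Or = 6


# a: current entry of 'target', b: associated element of 'value'
UPDATE = {
    ReduceOp.Identity: lambda a, b: b,   # plain scatter, ignores 'a'
    ReduceOp.Add: operator.add,
    ReduceOp.Mul: operator.mul,
    ReduceOp.Min: min,
    ReduceOp.Max: max,
    ReduceOp.And: operator.and_,
    ReduceOp.Or: operator.or_,
}


def scatter_reduce(op, target, value, index, active=None):
    """Sequential model of target[index] = op(target[index], value)."""
    if active is None:
        active = [True] * len(index)
    for b, i, enabled in zip(value, index, active):
        if enabled:
            target[i] = UPDATE[op](target[i], b)


target = [0, 0, 0]
scatter_reduce(ReduceOp.Add, target, value=[1, 2, 3, 4], index=[0, 2, 2, 2])
print(target)  # [1, 0, 9]
```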

The optional active argument can be used to disable some of the updates, e.g., when not all provided values or indices are valid.

Atomic additions are subject to non-deterministic rounding errors. The reason for this is that IEEE-754 addition is non-associative. The execution order is scheduling-dependent, which can lead to small variations across program runs.
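The order dependence is easy to reproduce with ordinary double precision arithmetic. In this minimal illustration, accumulate is a hypothetical helper that adds values strictly in the given order, as a stand-in for one possible schedule of atomic additions.

```python
def accumulate(values):
    """Add values one at a time, in the given order."""
    total = 0.0
    for v in values:
        total += v
    return total


a = accumulate([1e16, 1.0, -1e16])   # the 1.0 is absorbed by 1e16
b = accumulate([1e16, -1e16, 1.0])   # the large terms cancel first
print(a, b)  # 0.0 1.0
```

The same three operands produce two different sums, depending only on the order in which they are combined.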

Atomic scatter-reductions can have a significant detrimental impact on performance. When many threads in a parallel computation attempt to modify the same element, this can lead to contention—essentially a fight over which part of the processor owns the associated memory region, which can slow down a computation by many orders of magnitude. Dr.Jit provides several different compilation strategies to reduce these costs, which can be selected via the mode parameter. The documentation of drjit.ReduceMode provides more detail and performance plots.

Backend support

Many combinations of reductions and variable types are not supported. Some combinations depend on the compute capability (CC) of the underlying CUDA device or on the LLVM version (LV) and the host architecture (ARM64, x86_64). The following matrices display the level of support.

[Support matrices for the CUDA and LLVM backends: cell layout not recoverable. The surviving entries mark conditional support: CC≥60 and CC≥90 on CUDA; LV≥15, LV≥16, and LV≥16 with ARM64 on LLVM.]

The function raises an exception when the operation is not supported by the backend.

Scatter-reducing nested types

This operation can be used in the following different ways:

  1. When target is a 1D Dr.Jit array like drjit.llvm.ad.Float, this operation implements a parallelized version of the Python array indexing expression target[index] = op(target[index], value) with optional masking. Example:

    target = dr.zeros(dr.cuda.Float, 1024*1024)
    value = dr.cuda.Float([...])
    index = dr.cuda.UInt([...]) # Note: negative indices are not permitted
    dr.scatter_reduce(dr.ReduceOp.Add, target, value=value, index=index)

  2. When target is a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:

    • When target and value are of the same type, the scatter-reduction threads through entries and invokes itself recursively. For example, the scatter operation in

      op = dr.ReduceOp.Add
      target = dr.cuda.Array3f(...)
      value = dr.cuda.Array3f(...)
      index = dr.cuda.UInt([...])
      dr.scatter_reduce(op, target, value, index)

      is equivalent to

      dr.scatter_reduce(op, target.x, value.x, index)
      dr.scatter_reduce(op, target.y, value.y, index)
      dr.scatter_reduce(op, target.z, value.z, index)

      A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.

    • Otherwise, the operation flattens the value array and writes it using C-style ordering with a suitably modified index. For example, the scatter-reduction below writes 3D vectors into a 1D array.

      op = dr.ReduceOp.Add
      target = dr.cuda.Float(...)
      value = dr.cuda.Array3f(...)
      index = dr.cuda.UInt([...])
      dr.scatter_reduce(op, target, value, index)

      and is equivalent to

      dr.scatter_reduce(op, target, value.x, index*3 + 0)
      dr.scatter_reduce(op, target, value.y, index*3 + 1)
      dr.scatter_reduce(op, target, value.z, index*3 + 2)


The indices provided to this operation are unchecked by default. Out-of-bound writes are considered undefined behavior and may crash the application (unless they are disabled via the active parameter). Negative indices are not permitted.

If debug mode is enabled via the drjit.JitFlag.Debug flag, Dr.Jit will insert range checks into the program. These will catch out-of-bound writes and print an error message identifying the responsible line of code.

Dr.Jit makes no guarantees about the relative ordering of atomic operations when a drjit.scatter_reduce() writes to the same element multiple times. Combined with the non-associative nature of floating point operations, concurrent writes will generally introduce non-deterministic rounding error.

  • op (drjit.ReduceOp) – Specifies the type of update that should be performed.

  • target (object) – The object into which data should be written (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)

  • value (object) – The values to be used in the RMW operation (typically of type type(target), but other variations are possible as well, see the description above.) Dr.Jit will attempt an implicit conversion if the input is not an array type.

  • index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g., drjit.scalar.ArrayXu or drjit.cuda.UInt) specifying scatter indices. Dr.Jit will attempt an implicit conversion if another type is provided.

  • active (object) – an optional 1D dynamic Dr.Jit mask array (e.g., drjit.scalar.ArrayXb or drjit.cuda.Bool) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default is True.

  • mode (drjit.ReduceMode) – Dr.Jit offers several different strategies to implement atomic scatter-reductions that can be selected via this parameter. They achieve different best/worst case performance and, in the case of drjit.ReduceMode.Expand, involve additional memory storage overheads. The default is drjit.ReduceMode.Auto.

drjit.scatter_add(target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None

Atomically add values to a flat array or nested data structure.

This function is equivalent to drjit.scatter_reduce(drjit.ReduceOp.Add, ...) and exists for reasons of convenience. Please refer to drjit.scatter_reduce() for details on atomic scatter-reductions.

drjit.scatter_add_kahan(target_1: drjit.ArrayBase, target_2: drjit.ArrayBase, value: object, index: object, active: object = True) None

Perform a Kahan-compensated atomic scatter-addition.

Atomic floating point accumulation can incur significant rounding error when many values contribute to a single element. This function implements an error-compensating version of drjit.scatter_add() based on the Kahan-Babuška-Neumaier algorithm that simultaneously accumulates into two target buffers.

In particular, the operation first accumulates values into entries of a dynamic 1D array target_1. It tracks the round-off error caused by this operation and then accumulates this error into a second 1D array named target_2. At the end, the buffers can simply be added together to obtain the error-compensated result.
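
The two-buffer scheme can be illustrated with a sequential plain-Python model. This is only a sketch under stated assumptions: scatter_add_kahan_model is a hypothetical helper that follows Neumaier's variant of compensated summation, and it does not reflect how Dr.Jit implements the atomic, parallel version.

```python
def scatter_add_kahan_model(target_1, target_2, values, indices):
    """Sequential model: sums go to target_1, round-off errors to target_2."""
    for v, i in zip(values, indices):
        s = target_1[i]
        t = s + v
        # Neumaier's variant: recover the low-order bits lost in 's + v'.
        if abs(s) >= abs(v):
            err = (s - t) + v
        else:
            err = (v - t) + s
        target_1[i] = t
        target_2[i] += err


target_1 = [0.0]
target_2 = [0.0]
scatter_add_kahan_model(target_1, target_2, [1e16, 1.0, -1e16], [0, 0, 0])

naive = (1e16 + 1.0) + -1e16             # plain accumulation loses the 1.0
compensated = target_1[0] + target_2[0]  # adding both buffers recovers it
print(naive, compensated)  # 0.0 1.0
```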

This function has a few limitations: in contrast to drjit.scatter_reduce() and drjit.scatter_add(), it does not perform a local reduction (see flag JitFlag.ScatterReduceLocal), which can be an important optimization when atomic accumulation is a performance bottleneck.

Furthermore, the function currently only works with flat 1D arrays. This is an implementation limitation that could in principle be removed in the future.

Finally, the function is differentiable, but derivatives currently only propagate into target_1. This means that forward derivatives don’t enjoy the error compensation of the primal computation. This limitation is of no relevance for backward derivatives.

drjit.scatter_inc(target: drjit.ArrayBase, index: object, active: object = True) object

Atomically increment a value within an unsigned 32-bit integer array and return the value prior to the update.

This operation works just like drjit.scatter_reduce() for 32-bit unsigned integer operands, but with a fixed value=1 parameter and op=drjit.ReduceOp.Add.

The main difference is that this variant additionally returns the old value of the target array prior to the atomic update, whereas the more general scatter-reduction just returns None. The operation also supports masking: the return value of masked (inactive) entries is undefined. Both target and index parameters must be 1D unsigned 32-bit arrays.

This operation is a building block for stream compaction: threads can scatter-increment a global counter to request a spot in an array and then write their result there. The recipe looks as follows:

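import drjit as dr
from drjit.cuda import Bool, Float, UInt32  # assumed backend; drjit.llvm provides the same types

data_1 = ...
data_2 = ...
active = dr.ones(Bool, len(data_1)) # .. or a more complex condition

# This will hold the counter
ctr = UInt32(0)

# Allocate output buffers
max_size = 1024
data_compact_1 = dr.empty(Float, max_size)
data_compact_2 = dr.empty(Float, max_size)

# Request a slot in the output buffers for each active entry
idx = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)

# Disable dr.scatter() operations below in case of a buffer overflow
active &= idx < max_size

# Write the active entries to their assigned slots
dr.scatter(target=data_compact_1, value=data_1, index=idx, active=active)
dr.scatter(target=data_compact_2, value=data_2, index=idx, active=active)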

When following this approach, be sure to provide the same mask value to the drjit.scatter_inc() and subsequent drjit.scatter() operations.
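
The counter-based slot assignment can be modeled sequentially in plain Python. The following is only a sketch of the semantics: compact_model is a hypothetical helper that is not part of Dr.Jit, and since the real operation runs in parallel, the order of the compacted elements is presumably not guaranteed there.

```python
def compact_model(data, active, max_size):
    counter = 0              # plays the role of the counter 'ctr'
    out = [None] * max_size  # plays the role of an output buffer
    for value, is_active in zip(data, active):
        if not is_active:
            continue         # masked entries do not touch the counter
        slot = counter       # scatter_inc returns the old counter value ...
        counter += 1         # ... and increments the counter
        if slot < max_size:  # overflow guard, disables the write
            out[slot] = value
    return out[:min(counter, max_size)], counter


compacted, count = compact_model(
    [10, 11, 12, 13], [True, False, True, True], max_size=8
)
print(compacted, count)  # [10, 12, 13] 3
```

The final counter value tells how many entries requested a slot, which may exceed max_size when the buffer overflows.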

The function drjit.reshape() can be used to resize the resulting arrays to their compacted size. Please refer to the documentation of this function, specifically the code example illustrating the use of the shrink=True argument.

The function drjit.scatter_inc() exhibits the following unusual behavior compared to regular Dr.Jit operations: the return value references the instantaneous state during a potentially large sequence of atomic operations. This instantaneous state is not reproducible in later kernel evaluations, and Dr.Jit will refuse to do so when the computed index is reused. In essence, the variable is “consumed” by the process of evaluation.

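my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
dr.scatter(target=data_compact_1, value=data_1, index=my_index, active=active)

dr.eval(data_compact_1) # Run Kernel #1

dr.scatter(
    target=data_compact_2,
    value=data_2,
    index=my_index, # <-- oops, reusing my_index in another kernel.
    active=active   #     This raises an exception.
)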

To get the above code to work, you will need to evaluate my_index at the same time to materialize it into a stored (and therefore trivially reproducible) representation. For this, ensure that the size of the active mask matches len(data_*) and that it is not the trivial True default mask (otherwise, the evaluated my_index will be scalar).

dr.eval(data_compact_1, my_index)

Such multi-stage evaluation is potentially inefficient and may defeat the purpose of performing stream compaction in the first place. In general, prefer keeping all scatter operations involving the computed index in the same kernel, and then this issue does not arise.

The implementation of drjit.scatter_inc() performs a local reduction first, followed by a single atomic write per SIMD packet/warp. This is done to reduce contention from a potentially very large number of atomic operations targeting the same memory address. Fully masked updates do not cause memory traffic.

There is some conceptual overlap between this function and drjit.compress(), which can likewise be used to reduce a stream to a smaller subset of active items. The downside of drjit.compress() is that it requires evaluating the variables to be reduced, which can be very costly in terms of memory traffic and storage footprint. Reducing through drjit.scatter_inc() does not have this limitation: it can operate on symbolic arrays that greatly exceed the available device memory. One advantage of drjit.compress() is that it essentially boils down to a relatively simple prefix sum, which does not require atomic memory operations (these can be slow in some cases).

drjit.block_reduce(op: ReduceOp, value: T, block_size: int, mode: Literal['evaluated', 'symbolic', None] = None) T

Reduce elements within blocks.

This function reduces all elements within contiguous blocks of size block_size along the trailing dimension of the input array value, returning a correspondingly smaller output array. Various types of reductions are supported (see drjit.ReduceOp for details).

For example, a sum reduction of a hypothetical array [a, b, c, d, e, f] with block_size=2 produces the output [a+b, c+d, e+f].

The function raises an exception when the length of the trailing dimension is not a multiple of the block size. It recursively threads through nested arrays and PyTrees.
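
The semantics described above can be sketched for flat Python lists as follows. Here, block_reduce_model is a hypothetical reference helper, and the choice of ValueError is an assumption, since the text does not name the exception type.

```python
from functools import reduce
import operator


def block_reduce_model(op, values, block_size):
    """Reduce each contiguous block of 'block_size' elements with 'op'."""
    if len(values) % block_size != 0:
        raise ValueError("length must be a multiple of block_size")
    return [
        reduce(op, values[start:start + block_size])
        for start in range(0, len(values), block_size)
    ]


sums = block_reduce_model(operator.add, [1, 2, 3, 4, 5, 6], 2)
maxima = block_reduce_model(max, [1, 5, 2, 8, 3, 0], 3)
print(sums, maxima)  # [3, 7, 11] [5, 8]
```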

Dr.Jit uses one of two strategies to realize this operation, which can be optionally forced by specifying the mode parameter.

  • mode="evaluated" first evaluates the input array via drjit.eval() and then launches a specialized reduction kernel.

    On the CUDA backend, this kernel makes efficient use of shared memory and cooperative warp instructions with the limitation that it requires block_size to be a power of two. The LLVM backend parallelizes the operation via the built-in thread pool and has no block_size limitations.

  • mode="symbolic" uses drjit.scatter_reduce() to atomically scatter-reduce values into the output array. This strategy can be advantageous when the input array is symbolic (making evaluation impossible) or both unevaluated and extremely large (making evaluation costly or impossible if there isn’t enough memory).

    Disadvantages of this mode are that

    • Atomic scatters can suffer from memory contention (though drjit.scatter_reduce() takes steps to reduce contention, see its documentation for details).

    • Atomic floating point scatter-addition is subject to non-deterministic rounding errors that arise from its non-associative nature. Coupled with the scheduling-dependent execution order, this can lead to small variations across program runs. Integer and floating point min/max reductions are unaffected by this.

  • mode=None (default) automatically picks a reasonable strategy according to the following logic:

    • Symbolic mode is admissible when the necessary atomic reduction is supported by the backend.

    • Evaluated mode is admissible when the input does not involve symbolic variables. On the CUDA backend block_size must furthermore be a power of two.

    • If only one strategy remains, then pick that one. Raise an exception when no strategy works out.

    • Otherwise, use evaluated mode when the input array is already evaluated, or when evaluating it would consume less than 1 GiB of memory.

    • Use symbolic mode in all other cases.
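
The selection logic listed above can be written down as a small decision function. This is a sketch of the documented rules only: the function name, its parameters, and the use of RuntimeError are all hypothetical and do not correspond to Dr.Jit internals.

```python
GIB = 1 << 30


def pick_block_reduce_mode(atomic_supported, input_is_symbolic, on_cuda,
                           block_size, already_evaluated, eval_bytes):
    power_of_two = block_size > 0 and (block_size & (block_size - 1)) == 0

    # Admissibility of each strategy.
    symbolic_ok = atomic_supported
    evaluated_ok = (not input_is_symbolic) and (power_of_two or not on_cuda)

    if not symbolic_ok and not evaluated_ok:
        raise RuntimeError("no admissible strategy")
    if symbolic_ok != evaluated_ok:
        # Only one strategy remains.
        return "symbolic" if symbolic_ok else "evaluated"
    # Both are admissible: prefer evaluation when it is cheap.
    if already_evaluated or eval_bytes < GIB:
        return "evaluated"
    return "symbolic"


print(pick_block_reduce_mode(True, False, True, 3, False, 100))  # symbolic
```

In the example call, the block size of 3 is not a power of two, so evaluated mode is ruled out on the CUDA backend and only symbolic mode remains.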

For some inputs, no strategy works out (e.g., multiplicative reduction of an array with a non-power-of-two block size on the CUDA backend). The function will raise an exception in such cases.

Since evaluated mode can be quite a bit faster and is guaranteed to be deterministic, it is recommended that you design your program so that it invokes drjit.block_reduce() with a power-of-two block_size.


Tensor inputs are not supported. To reduce blocks within tensors, apply the regular axis-wide reductions (drjit.sum(), drjit.prod(), drjit.min(), drjit.max()) to reshaped tensors. For example, to sum-reduce a (16, 16) tensor by a factor of (4, 2) (i.e., to a (4, 8)-sized tensor), write dr.sum(dr.reshape(value, shape=(4, 4, 8, 2)), axis=(1, 3)).
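
To see why the reshape-based approach reduces 4x2 blocks, the index arithmetic can be spelled out in plain Python. The sketch assumes C-style (row-major) layout and uses nested lists with made-up contents in place of a tensor.

```python
rows, cols = 16, 16
value = [[r * cols + c for c in range(cols)] for r in range(rows)]

# Reshaping (16, 16) -> (4, 4, 8, 2) in C order maps element (r, c) to
# index (r // 4, r % 4, c // 2, c % 2). Summing over axes 1 and 3 thus
# accumulates all elements that share the same (r // 4, c // 2).
out = [[0] * 8 for _ in range(4)]
for r in range(rows):
    for c in range(cols):
        out[r // 4][c // 2] += value[r][c]

print(len(out), len(out[0]), out[0][0])  # 4 8 196
```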

  • value (object) – A Dr.Jit array or PyTree

  • block_size (int) – size of the block

  • mode (str | None) – optional parameter to force an evaluation strategy.


The block-reduced array or PyTree as specified above.

drjit.block_sum(value: T, block_size: int, mode: Literal['evaluated', 'symbolic', None] = None) T

Sum elements within blocks.

This is a convenience alias for drjit.block_reduce() with op set to drjit.ReduceOp.Add.

drjit.reduce(op: ReduceOp, value: object, axis: int | tuple[int, ...] | None = 0, mode: str | None = None) object

Reduce the input array, tensor, or iterable along the specified axis/axes.

This function reduces arrays, tensors and other iterable Python types along one or multiple axes, where op selects the operation to be performed (see drjit.ReduceOp for the available operations).

The functions drjit.sum(), drjit.prod(), drjit.min(), and drjit.max() are convenience aliases that call drjit.reduce() with specific values of op.

By default, the reduction is along axis 0 (i.e., the outermost one), returning an instance of the array’s element type. For instance, sum-reducing an array a of type drjit.cuda.Array3f is equivalent to writing a[0] + a[1] + a[2] and produces a result of type drjit.cuda.Float. Dr.Jit can trace this operation and include it in the generated kernel.

Negative indices (e.g. axis=-1) count backward from the innermost axis. Multiple axes can be specified as a tuple. The value axis=None requests a simultaneous reduction over all axes.
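
As a rough plain-Python model of the axis semantics, a nested list can stand in for a 3xN array. The data is illustrative. Note also that for JIT-compiled arrays, Dr.Jit returns a single-element array where this model produces a Python scalar.

```python
from functools import reduce
import operator

# Nested list standing in for a 3xN array (here N = 2).
a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# axis=0: combine the outermost entries, i.e. a[0] + a[1] + a[2] per lane.
axis0 = [reduce(operator.add, lane) for lane in zip(*a)]

# axis=None: reduce over all axes at once.
all_axes = reduce(operator.add, (x for row in a for x in row))

print(axis0, all_axes)  # [9.0, 12.0] 21.0
```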

When reducing axes of a tensor, or when reducing the trailing dimension of a Jit-compiled array, some special precautions apply: these axes correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. This can be done using the following strategies:

  • mode="evaluated" first evaluates the input array via drjit.eval() and then launches a specialized reduction kernel.

    On the CUDA backend, this kernel makes efficient use of shared memory and cooperative warp instructions. The LLVM backend parallelizes the reduction via the built-in thread pool.

  • mode="symbolic" uses drjit.scatter_reduce() to atomically scatter-reduce values into the output array. This strategy can be advantageous when the input is symbolic (making evaluation impossible) or both unevaluated and extremely large (making evaluation costly or impossible if there isn’t enough memory).

    Disadvantages of this mode are that

    • Atomic scatters can suffer from memory contention (though the drjit.scatter_reduce() function takes steps to reduce contention, see its documentation for details).

    • Atomic floating point scatter-addition is subject to non-deterministic rounding errors that arise from its non-associative nature. Coupled with the scheduling-dependent execution order, this can lead to small variations across program runs. Integer reductions and floating point min/max reductions are unaffected by this.

  • mode=None (default) automatically picks a reasonable strategy according to the following logic:

    • Use evaluated mode when the input array is already evaluated, or when evaluating it would consume less than 1 GiB of memory.

    • Use evaluated mode when the necessary atomic reduction operation is not supported by the backend.

    • Otherwise, use symbolic mode.

This function generally strips away reduced axes, but there is one notable exception: it will never remove a trailing dynamic dimension, if present in the input array.

For example, reducing an instance of type drjit.cuda.Float along axis 0 does not produce a scalar Python float. Instead, the operation returns another array of the same type with a single element. This is intentional–unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.

  • op (ReduceOp) – The operation that should be applied along the reduced axis/axes.

  • value (ArrayBase | Iterable | float | int) – An input Dr.Jit array or tensor.

  • axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.

  • mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.


The reduced array or tensor as specified above.

drjit.sum(value: object, axis: int | tuple[int, ...] | None = 0, mode: str | None = None) object

Sum-reduce the input array, tensor, or iterable along the specified axis/axes.

This function sum-reduces arrays, tensors and other iterable Python types along one or multiple axes. It is equivalent to dr.reduce(dr.ReduceOp.Add, ...). See the documentation of this function for further information.

  • value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.

  • axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.

  • mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.


The reduced array or tensor as specified above.

Return type:


drjit.prod(value: object, axis: int | tuple[int, ...] | None = 0, mode: str | None = None) object

Multiplicatively reduce the input array, tensor, or iterable along the specified axis/axes.

This function performs a multiplicative reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to dr.reduce(dr.ReduceOp.Mul, ...). See the documentation of this function for further information.

  • value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.

  • axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.

  • mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.


The reduced array or tensor as specified above.

Return type:


drjit.min(value: object, axis: int | tuple[int, ...] | None = 0, mode: str | None = None) object

Perform a minimum reduction of the input array, tensor, or iterable along the specified axis/axes.

(Not to be confused with drjit.minimum(), which computes the smaller of two values).

This function performs a minimum reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to dr.reduce(dr.ReduceOp.Min, ...). See the documentation of this function for further information.

  • value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.

  • axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.

  • mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.


The reduced array or tensor as specified above.

Return type:


drjit.max(value: object, axis: int | tuple[int, ...] | None = 0, mode: str | None = None) object

Perform a maximum reduction of the input array, tensor, or iterable along the specified axis/axes.

(Not to be confused with drjit.maximum(), which computes the larger of two values).

This function performs a maximum reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to dr.reduce(dr.ReduceOp.Max, ...). See the documentation of this function for further information.

  • value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.

  • axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.

  • mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.


The reduced array or tensor as specified above.

drjit.mean(value: object, axis: int | Tuple[int, ...] | None = 0, mode: Literal['symbolic', 'evaluated', None] | None = None) object

Compute the mean of the input array or tensor along one or multiple axes.

This function performs a horizontal sum reduction by adding values of the input array, tensor, or Python sequence along one or multiple axes and then dividing by the number of entries. By default, it sums along the outermost axis; specify axis=None to sum over all of them at once. The mean of an empty array is considered to be zero.

See the section on horizontal reductions for important general information about their properties.

  • value (float | int | Sequence | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type

  • axis (int | None) – The axis along which to reduce (Default: 0). A value of None causes a simultaneous reduction along all axes. Currently, only values of 0 and None are supported.


Result of the reduction operation

Return type:

float | int | drjit.ArrayBase

drjit.all(value: object, axis: int | tuple[int, ...] | None = 0) object

Check if all elements along the specified axis are active.

Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the & (AND) operator.

By default, it reduces along index 0, which refers to the outermost axis. Negative indices (e.g. -1) count backwards from the innermost axis. The special argument axis=None causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to be True.
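
The convention for empty inputs follows from the identity element of the AND operation, as the following plain-Python model suggests. The helper all_model is hypothetical and only mirrors the reduction for flat Python sequences.

```python
from functools import reduce
import operator


def all_model(mask):
    # AND-reduction starting from the identity element True, which is
    # also the result for an empty input.
    return reduce(operator.and_, mask, True)


print(all_model([True, True]), all_model([True, False]), all_model([]))
# True False True
```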

The function is internally based on dr.reduce(). See the documentation of this function for further information.

Like dr.reduce(), this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducing drjit.cuda.Bool does not produce a scalar Python bool. Instead, the operation returns another array of the same type with a single element. This is intentional–unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.

Boolean 1D arrays automatically convert to bool if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:

import drjit as dr
from drjit.cuda import Float

x = Float(...)
if dr.all(x < 0):
   # ...

A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves drjit.eval() to evaluate and store the array in memory and then launch a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.

To avoid Boolean reductions, one can often use symbolic operations such as if_stmt(), while_loop(), etc. The @dr.syntax decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...) based on the condition.

@dr.syntax
def f(x: Float):
    if x < 0:
       # ...

  • value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.

  • axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.


The reduced array or tensor as specified above.

Return type:


drjit.any(value: object, axis: int | tuple[int, ...] | None = 0) object

Check if any elements along the specified axis are active.

Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the | (OR) operator.

By default, it reduces along index 0, which refers to the outermost axis. Negative indices (e.g. -1) count backwards from the innermost axis. The special argument axis=None causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to be False.

The function is internally based on dr.reduce(). See the documentation of this function for further information.

Like dr.reduce(), this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducing drjit.cuda.Bool does not produce a scalar Python bool. Instead, the operation returns another array of the same type with a single element. This is intentional–unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.

Boolean 1D arrays automatically convert to bool if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:

import drjit as dr
from drjit.cuda import Float

x = Float(...)
if dr.any(x < 0):
   # ...

A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves drjit.eval() to evaluate and store the array in memory and then launch a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.

To avoid Boolean reductions, one can often use symbolic operations such as if_stmt(), while_loop(), etc. The @dr.syntax decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...) based on the condition.

@dr.syntax
def f(x: Float):
    if x < 0:
       # ...

  • value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.

  • axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.


Result of the reduction operation

Return type:

bool | drjit.ArrayBase

drjit.none(value: object, axis: int | tuple[int, ...] | None = 0) object

Check if none of the elements along the specified axis are active.

Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the | (OR) operator and finally returns the bit-wise inverse of the result.

The function is internally based on dr.reduce(). See the documentation of this function for further information.

Like dr.reduce(), this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducing drjit.cuda.Bool does not produce a scalar Python bool. Instead, the operation returns another array of the same type with a single element. This is intentional–unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.

Boolean 1D arrays automatically convert to bool if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:

import drjit as dr
from drjit.cuda import Float

x = Float(...)
if dr.none(x < 0):
   # ...

A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves drjit.eval() to evaluate and store the array in memory and then launch a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.

To avoid Boolean reductions, one can often use symbolic operations such as if_stmt(), while_loop(), etc. The @dr.syntax decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...) based on the condition.

@dr.syntax
def f(x: Float):
    if x < 0:
       # ...

  • value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.

  • axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.


Result of the reduction operation

Return type:

bool | drjit.ArrayBase

drjit.count(value: object, axis: int | tuple[int, ...] | None = 0) object

Compute the number of active entries along the given axis.

Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the + operator (interpreting True elements as 1 and False elements as 0). It returns an unsigned 32-bit version of the input array.

By default, it reduces along index 0, which refers to the outermost axis. Negative indices (e.g. -1) count backwards from the innermost axis. The special argument axis=None causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to be zero.

See the section on horizontal reductions for important general information about their properties.

  • value (bool | Sequence | drjit.ArrayBase) – A Python or Dr.Jit mask type

  • axis (int | None) – The axis along which to reduce. The default value of 0 refers to the outermost axis. Negative values count backwards from the innermost axis. A value of None causes a simultaneous reduction along all axes.


Result of the reduction operation

Return type:

int | drjit.ArrayBase

drjit.dot(arg0: object, arg1: object, /) object

Compute the dot product of two arrays.

Whenever possible, the implementation uses a sequence of fma() (fused multiply-add) operations to evaluate the dot product. When the input is a 1D JIT array like drjit.cuda.Float, the function evaluates the product of the input arrays via drjit.eval() and then performs a sum reduction via drjit.sum().

See the section on horizontal reductions for details on the properties of such horizontal reductions.


Dot product of inputs

Return type:

float | int | drjit.ArrayBase

drjit.abs_dot(arg0: object, arg1: object, /) object

Compute the absolute value of the dot product of two arrays.

This function implements a convenience short-hand for abs(dot(arg0, arg1)).

See the section on horizontal reductions for details on the properties of such horizontal reductions.


Absolute value of the dot product of inputs

Return type:

float | int | drjit.ArrayBase

drjit.squared_norm(arg: object, /) object

Computes the squared 2-norm of a Dr.Jit array, tensor, or Python sequence.

The operation is equivalent to

dr.dot(arg, arg)

The squared_norm() operation performs a horizontal reduction. Please see the section on horizontal reductions for details on their properties.


arg (Sequence | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type


squared 2-norm of the input

Return type:

float | int | drjit.ArrayBase

drjit.norm(arg: object, /) object

Computes the 2-norm of a Dr.Jit array, tensor, or Python sequence.

The operation is equivalent to

dr.sqrt(dr.dot(arg, arg))

The norm() operation performs a horizontal reduction. Please see the section on horizontal reductions for details on their properties.
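
For plain Python sequences, the relationship between the dot product, the squared norm, and the norm can be sketched as follows. The *_model helpers are hypothetical stand-ins; where this model uses a separate multiply and add, Dr.Jit would use fused multiply-add operations whenever possible.

```python
import math


def dot_model(a, b):
    acc = 0.0
    for x, y in zip(a, b):
        acc = x * y + acc  # multiply-add step of the accumulation
    return acc


def squared_norm_model(a):
    return dot_model(a, a)


def norm_model(a):
    return math.sqrt(dot_model(a, a))


print(squared_norm_model([1.0, 2.0, 2.0]), norm_model([3.0, 4.0]))  # 9.0 5.0
```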


arg (Sequence | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type


2-norm of the input

Return type:

float | int | drjit.ArrayBase

drjit.prefix_sum(value: ArrayT, exclusive: bool = True, axis: int | None = 0) ArrayT

Compute an exclusive or inclusive prefix sum of the input array.

By default, the function returns an output array \(\mathbf{y}\) of the same size as the input \(\mathbf{x}\), where

\[y_i = \sum_{j=0}^{i-1} x_j.\]

which is known as an exclusive prefix sum, as each element of the output array excludes the corresponding input in its sum. When the exclusive argument is set to False, the function instead returns an inclusive prefix sum defined as

\[y_i = \sum_{j=0}^i x_j.\]

There is also a convenience alias drjit.cumsum() that computes an inclusive sum analogous to various other nd-array frameworks.
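
The two definitions can be checked against a small plain-Python model. The helpers prefix_sum_model and cumsum_model are hypothetical and operate on Python lists; they only mirror the formulas above.

```python
from itertools import accumulate


def prefix_sum_model(x, exclusive=True):
    if exclusive:
        # y_i sums x_0 .. x_{i-1}: start from 0 and drop the final total.
        return list(accumulate(x, initial=0))[:-1]
    # y_i sums x_0 .. x_i.
    return list(accumulate(x))


def cumsum_model(x):
    return prefix_sum_model(x, exclusive=False)


print(prefix_sum_model([1, 2, 3, 4]))  # [0, 1, 3, 6]
print(cumsum_model([1, 2, 3, 4]))      # [1, 3, 6, 10]
```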

Not all numeric data types are supported by prefix_sum(): presently, the function accepts Int32, UInt32, UInt64, Float32, and Float64-typed arrays.

The CUDA backend implementation for “large” numeric types (Float64, UInt64) has the following technical limitation: when reducing 64-bit integers, their values must be smaller than \(2^{62}\). When reducing double precision arrays, the two least significant mantissa bits are clamped to zero when forwarding the prefix from one 512-wide block to the next (at a very minor, probably negligible loss in accuracy). See the implementation for details on the rationale of this limitation.

  • value (drjit.ArrayBase) – A Python or Dr.Jit arithmetic type

  • exclusive (bool) – Specifies whether or not the prefix sum should be exclusive (the default) or inclusive.


An array of the same type containing the computed prefix sum.

Return type:


drjit.cumsum(arg, /)

Compute a cumulative sum (a.k.a. inclusive prefix sum) of the input array.

This function wraps drjit.prefix_sum() and is implemented as

def cumsum(arg, /):
    return prefix_sum(arg, exclusive=False)

drjit.reverse(value, axis: int = 0)

Reverses the given Dr.Jit array or Python sequence along the specified axis.

  • value (ArrayBase|Sequence) – Dr.Jit array or Python sequence type

  • axis (int) – Axis along which the reversal should be performed. Only axis==0 is supported for now.


An output of the same type as value containing a copy of the reversed array.

Return type:


drjit.compress(arg: drjit.ArrayBase, /) object

Compress a mask into an array of nonzero indices.

This function takes a boolean array as input and then returns an unsigned 32-bit integer array containing the indices of nonzero entries.

It can be used to reduce a stream to a subset of active entries via the following recipe:

# Input: an active mask and several arrays data_1, data_2, ...
dr.schedule(active, data_1, data_2, ...)
indices = dr.compress(active)
data_1 = dr.gather(type(data_1), data_1, indices)
data_2 = dr.gather(type(data_2), data_2, indices)
# ...

There is some conceptual overlap between this function and drjit.scatter_inc(), which can likewise be used to reduce a stream to a smaller subset of active items. Please see the documentation of drjit.scatter_inc() for details.
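
The compress-then-gather recipe can be mimicked with Python lists to make the data flow explicit. The helpers compress_model and gather_model are hypothetical stand-ins for the corresponding Dr.Jit operations, and the input data is made up for illustration.

```python
def compress_model(active):
    # Indices of the entries that are True.
    return [i for i, flag in enumerate(active) if flag]


def gather_model(source, indices):
    return [source[i] for i in indices]


active = [True, False, False, True, True]
data_1 = [0.5, 1.5, 2.5, 3.5, 4.5]

indices = compress_model(active)
data_1 = gather_model(data_1, indices)
print(indices, data_1)  # [0, 3, 4] [0.5, 3.5, 4.5]
```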


This function internally performs a synchronization step.


arg (bool | drjit.ArrayBase) – A Python or Dr.Jit boolean type


Array of nonzero indices

drjit.ravel(array: object, order: str = 'A') object

Convert the input into a contiguous flat array.

This operation takes a Dr.Jit array, typically with some static and some dynamic dimensions (e.g., drjit.cuda.Array3f with shape 3xN), and converts it into a flattened 1D dynamically sized array (e.g., drjit.cuda.Float) using either a C or Fortran-style ordering convention.

It can also convert Dr.Jit tensors into a flat representation, though only C-style ordering is supported in this case.

Internally, drjit.ravel() performs a series of calls to drjit.scatter() to suitably reorganize the array contents.

For example,

x = dr.cuda.Array3f([1, 2], [3, 4], [5, 6])
y = dr.ravel(x, order=...)

will produce

  • [1, 3, 5, 2, 4, 6] with order='F' (the default for Dr.Jit arrays), which means that X/Y/Z components alternate.

  • [1, 2, 3, 4, 5, 6] with order='C', in which case all X coordinates are written as a contiguous block followed by the Y- and then Z-coordinates.
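
The two orderings can be reproduced with a small plain-Python model, in which a nested array is represented as a list of per-component lists. The function ravel_model is a hypothetical illustration of the index order only, not of the scatter-based implementation.

```python
def ravel_model(components, order="F"):
    """Flatten a list of equally sized component lists."""
    if order == "F":
        # First (component) index changes fastest: x0, y0, z0, x1, y1, z1, ...
        width = len(components[0])
        return [c[i] for i in range(width) for c in components]
    if order == "C":
        # Last (dynamic) index changes fastest: all x, then all y, then all z
        return [v for c in components for v in c]
    raise ValueError("order must be 'F' or 'C'")


x = [[1, 2], [3, 4], [5, 6]]  # X, Y and Z components of two 3D vectors
print(ravel_model(x, order="F"))  # [1, 3, 5, 2, 4, 6]
print(ravel_model(x, order="C"))  # [1, 2, 3, 4, 5, 6]
```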

  • array (drjit.ArrayBase) – An arbitrary Dr.Jit array or tensor

  • order (str) – A single character indicating the index order. 'F' indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency. The alternative 'C' specifies row-major/C-style ordering, in which case the last index changes at the highest frequency. The default value 'A' (automatic) will use F-style ordering for arrays and C-style ordering for tensors.


A dynamic 1D array containing the flattened representation of array with the desired ordering. The type of the return value depends on the type of the input. When array is already contiguous/flattened, this function returns it without making a copy.

Return type:

object


drjit.unravel(dtype: type[ArrayT], array: AnyArray, order: Literal['A', 'C', 'F'] = 'A') ArrayT

Load a sequence of Dr.Jit vectors/matrices/etc. from a contiguous flat array.

This operation implements the inverse of drjit.ravel(). In contrast to drjit.ravel(), it requires one additional parameter (dtype) specifying the type of the return value. For example,

x = dr.cuda.Float([1, 2, 3, 4, 5, 6])
y = dr.unravel(dr.cuda.Array3f, x, order=...)

will produce an array of two 3D vectors with different contents depending on the indexing convention:

  • [1, 2, 3] and [4, 5, 6] when unraveled with order='F' (the default for Dr.Jit arrays), and

  • [1, 3, 5] and [2, 4, 6] when unraveled with order='C'

Internally, drjit.unravel() performs a series of calls to drjit.gather() to suitably reorganize the array contents.
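
As a sketch of the index arithmetic, the following plain-Python model splits a flat list into per-component lists. The name unravel_model and the list-of-lists representation are assumptions made for illustration; the gather-based implementation is not modeled.

```python
def unravel_model(flat, n_components, order="F"):
    """Split a flat list into n_components equally sized component lists."""
    if len(flat) % n_components != 0:
        raise ValueError("input size must be a multiple of n_components")
    width = len(flat) // n_components
    if order == "F":
        # Component k reads every n_components-th entry, starting at k
        return [flat[k::n_components] for k in range(n_components)]
    if order == "C":
        # Component k reads one contiguous block of the input
        return [flat[k * width:(k + 1) * width] for k in range(n_components)]
    raise ValueError("order must be 'F' or 'C'")


x = [1, 2, 3, 4, 5, 6]
# zip(*components) lists the result vector by vector
print(list(zip(*unravel_model(x, 3, order="F"))))  # [(1, 2, 3), (4, 5, 6)]
print(list(zip(*unravel_model(x, 3, order="C"))))  # [(1, 3, 5), (2, 4, 6)]
```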

  • dtype (type) – An arbitrary Dr.Jit array type

  • array (drjit.ArrayBase) – A dynamically sized 1D Dr.Jit array instance that is compatible with dtype. In other words, both must have the same underlying scalar type and be imported from the same package (e.g., drjit.llvm.ad).

  • order (str) – A single character indicating the index order. 'F' (the default) indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency. The alternative 'C' specifies row-major/C-style ordering, in which case the last index changes at the highest frequency.


An instance of type dtype containing the result of the unravel operation.

Return type:


drjit.reshape(dtype: type, value: object, shape: collections.abc.Sequence[int], order: str = 'A', shrink: bool = False) object
drjit.reshape(dtype: type, value: object, shape: int, order: str = 'A', shrink: bool = False) object

Converts value into an array of type dtype by rearranging the contents according to the specified shape.

The parameter shape may contain a single -1-valued target dimension, in which case its value is inferred from the remaining shape entries and the size of the input. When shape is of type int, it is interpreted as a 1-tuple (shape,).
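
The shape inference described above can be modeled as follows. The helper infer_shape is hypothetical and covers only the element-count-preserving case; it does not model shrink=True, which permits a smaller output.

```python
from math import prod


def infer_shape(shape, size):
    """Resolve a single -1 entry of 'shape' for an input with 'size' elements."""
    if isinstance(shape, int):
        shape = (shape,)  # an integer is interpreted as a 1-tuple
    shape = tuple(shape)
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    if -1 in shape:
        known = prod(d for d in shape if d != -1)
        if known == 0 or size % known != 0:
            raise ValueError("cannot infer the missing dimension")
        shape = tuple(size // known if d == -1 else d for d in shape)
    elif prod(shape) != size:
        raise ValueError("total number of elements must remain unchanged")
    return shape


print(infer_shape((3, -1), 6))  # (3, 2)
print(infer_shape(-1, 6))       # (6,)
```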

This function supports the following behaviors:

  1. Reshaping tensors: Dr.Jit tensors admit arbitrary shapes. drjit.reshape() can convert between them as long as the total number of elements remains unchanged.

    >>> from drjit.llvm.ad import TensorXf
    >>> value = dr.arange(TensorXf, 6)
    >>> dr.reshape(dtype=TensorXf, value=value, shape=(3, -1))
    [[0, 1]
     [2, 3]
     [4, 5]]
  2. Reshaping nested arrays: The function can ravel and unravel nested arrays (which have some static dimensions). This provides a high-level interface that subsumes the functions drjit.ravel() and drjit.unravel().

    >>> from drjit.llvm.ad import Array2f, Array3f
    >>> value = Array2f([1, 2, 3], [4, 5, 6])
    >>> dr.reshape(dtype=Array3f, value=value, shape=(3, -1), order='C')
    [[1, 4, 2],
     [5, 3, 6]]
    >>> dr.reshape(dtype=Array3f, value=value, shape=(3, -1), order='F')
    [[1, 3, 5],
     [2, 4, 6]]

    (By convention, Dr.Jit nested arrays are always printed in transposed form, which explains the difference in output compared to the identically shaped Tensor example just above.)

    The order argument can be used to specify C ("C") or Fortran ("F")-style ordering when rearranging the array. The default value "A" corresponds to Fortran-style ordering.

  3. PyTrees: When value is a PyTree, the operation recursively threads through the tree’s elements.

  4. Stream compression and loops that fork recursive work. When called with shrink=True, the function creates a view of the original data that potentially has a smaller number of elements.

    The main use of this feature is to implement loops that process large numbers of elements in parallel, and which need to occasionally “fork” some recursive work. On modern compute accelerators, an efficient way to handle this requirement is to append this work into a queue that is processed in a subsequent pass until no work is left. The reshape operation with shrink=True then resizes the preallocated queue to the actual number of collected items, which are the input of the next iteration.

    Please refer to the following example that illustrates how drjit.scatter_inc(), drjit.scatter(), and drjit.reshape(..., shrink=True) can be combined to realize a parallel loop with a fork condition

    def f():
        # Loop state variables (an arbitrary array or PyTree)
        state = ...
        # Determine how many elements should be processed
        size = dr.width(state)
        # Run the following loop until no work is left
        while size > 0:
            # 1-element array used as an atomic counter
            queue_index = UInt(0)
            # Preallocate memory for the queue. The necessary
            # amount of memory is task-dependent
            queue_size = size
            queue = dr.empty(dtype=type(state), shape=queue_size)
            # Create an opaque variable representing the number 'queue_size'.
            # This keeps this changing value from being baked into the program,
            # which is needed for proper kernel caching
            queue_size_o = dr.opaque(UInt32, queue_size)
            while not stopping_criterion(state):
                # This line represents the loop body that processes work
                state = loop_body(state)
                # if the condition 'fork' is True, spawn a new work item that
                # will be handled in a future iteration of the parent loop.
                if fork(state):
                    # Atomically reserve a slot in 'queue'
                    slot = dr.scatter_inc(target=queue_index, index=0)
                    # Work item for the next iteration, task dependent
                    todo = state
                    # Be careful not to write beyond the end of the queue
                    valid = slot < queue_size_o
                    # Write 'todo' into the reserved slot
                    dr.scatter(target=queue, index=slot, value=todo, active=valid)
            # Determine how many fork operations took place
            size = queue_index[0]
            if size > queue_size:
                raise RuntimeError('Preallocated queue was too small: tried to store '
                                   f'{size} elements in a queue of size {queue_size}')
            # Reshape the queue and re-run the loop
            state = dr.reshape(dtype=type(state), value=queue, shape=size, shrink=True)
  • dtype (type) – Desired output type of the reshaped array. This could equal type(value) or refer to an entirely different array type.

  • value (object) – An arbitrary Dr.Jit array, tensor, or PyTree. The function returns unknown objects of other types unchanged.

  • shape (int|tuple[int, ...]) – The target shape.

  • order (str) – A single character indicating the index order used to reinterpret the input. 'F' indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency. The alternative 'C' specifies row-major/C-style ordering, in which case the last index changes at the highest frequency. The default value 'A' (automatic) will use F-style ordering for arrays and C-style ordering for tensors.

  • shrink (bool) – Cheaply construct a view of the input that potentially has a smaller number of elements. The main use case of this method is explained above.


The reshaped array or PyTree.

Return type:

object


drjit.slice(value: object, index: object = 0) object

Select a subset of the input array or PyTree along the trailing dynamic dimension.

Given a Dr.Jit array value with shape (..., N) (where N represents a dynamically sized dimension), this operation effectively evaluates the expression value[..., index]. It recursively traverses PyTrees and transforms each compatible array element. Other values are returned unchanged.

The following properties of index determine the return type:

  • When index is a 1D integer array, the operation reduces to one or more calls to drjit.gather(), and slice() returns a reduced output object of the same type and structure.

  • When index is a scalar Python int, the trailing dimension is entirely removed, and the operation returns an array from the drjit.scalar namespace containing the extracted values.

drjit.tile(value: T, count: int) T

Tile the input array count times along the trailing dimension.

This function replicates the input count times along the trailing dynamic dimension. It recursively threads through nested arrays and PyTrees. Static arrays and tensors currently aren’t supported. When count==1, the function returns the input without changes.

For example, tiling the array [1, 2] with count=3 produces [1, 2, 1, 2, 1, 2].


The tiled input as described above. The return type matches that of value.

Return type:


drjit.repeat(value: T, count: int) T

Repeat each successive entry of the input count times along the trailing dimension.

This function replicates each entry of the input count times along the trailing dynamic dimension, so that copies of the same entry are adjacent. It recursively threads through nested arrays and PyTrees. Static arrays and tensors currently aren’t supported. When count==1, the function returns the input without changes.

For example, repeating the entries of the array [1, 2] with count=3 produces [1, 1, 1, 2, 2, 2].


The repeated input as described above. The return type matches that of value.
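
The difference between drjit.tile() and drjit.repeat() is easy to confuse. The following plain-Python model contrasts the two on flat lists; tile_model and repeat_model are hypothetical names, and nested arrays and PyTrees are not covered.

```python
def tile_model(values, count):
    """Concatenate 'count' copies of the whole input."""
    return list(values) * count


def repeat_model(values, count):
    """Replicate each entry 'count' times before moving to the next one."""
    return [v for v in values for _ in range(count)]


print(tile_model([1, 2], 3))    # [1, 2, 1, 2, 1, 2]
print(repeat_model([1, 2], 3))  # [1, 1, 1, 2, 2, 2]
```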

Return type:


Mask operations

Also relevant here are any(), all(), none(), and count().

drjit.select(arg0: object, arg1: object, arg2: object, /) object
drjit.select(arg0: bool, arg1: object, arg2: object, /) object

Select elements from inputs based on a condition

This function uses a first mask argument to select between the subsequent two arguments. It implements the following component-wise operation:

\[\mathrm{result}_i = \begin{cases} \texttt{arg1}_i,\quad&\text{if }\texttt{arg0}_i,\\ \texttt{arg2}_i,\quad&\text{otherwise.} \end{cases}\]
  • arg0 (bool | drjit.ArrayBase) – A Python or Dr.Jit mask type

  • arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit type, whose entries should be returned for True-valued mask entries.

  • arg2 (int | float | drjit.ArrayBase) – A Python or Dr.Jit type, whose entries should be returned for False-valued mask entries.


Component-wise result of the selection operation

Return type:

float | int | drjit.ArrayBase
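
The component-wise rule above corresponds to the following plain-Python model. The helper select_model is hypothetical and assumes three lists of equal length; broadcasting of scalars is not modeled.

```python
def select_model(mask, if_true, if_false):
    """Pick if_true[i] where mask[i] is True and if_false[i] otherwise."""
    return [t if m else f for m, t, f in zip(mask, if_true, if_false)]


print(select_model([True, False, True], [1, 2, 3], [10, 20, 30]))  # [1, 20, 3]
```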

drjit.isinf(arg, /)

Performs an elementwise test for positive or negative infinity


arg (object) – A Dr.Jit array or other kind of numeric sequence type.


A mask value describing the result of the test

Return type:


drjit.isnan(arg, /)

Performs an elementwise test for NaN (Not a Number) values


arg (object) – A Dr.Jit array or other kind of numeric sequence type.


A mask value describing the result of the test.

Return type:


drjit.isfinite(arg, /)

Performs an elementwise test that checks whether values are finite and not equal to NaN (Not a Number)


arg (object) – A Dr.Jit array or other kind of numeric sequence type.


A mask value describing the result of the test

Return type:


drjit.allclose(a: object, b: object, rtol: float | None = None, atol: float | None = None, equal_nan: bool = False) bool

Returns True if two arrays are element-wise equal within a given error tolerance.

The function considers both absolute and relative error thresholds. In particular, a and b are considered equal if all elements satisfy

\[|a - b| \le |b| \cdot \texttt{rtol} + \texttt{atol}. \]

If not specified, the constants atol and rtol are chosen depending on the precision of the input arrays:













Note that these constants are fairly loose and far larger than the roundoff error of the underlying floating point representation. The double precision parameters were chosen to match the default behavior of numpy.allclose() (rtol=1e-5, atol=1e-8).

  • a (object) – A Dr.Jit array or other kind of numeric sequence type.

  • b (object) – A Dr.Jit array or other kind of numeric sequence type.

  • rtol (float) – A relative error threshold chosen according to the above table.

  • atol (float) – An absolute error threshold according to the above table.

  • equal_nan (bool) – If a and b both contain a NaN (Not a Number) entry, should they be considered equal? The default is False.


The result of the comparison.
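
The tolerance test can be sketched in plain Python as shown below. The function allclose_model is a hypothetical illustration of the inequality for flat lists of floats. Its default thresholds are the numpy.allclose() defaults, which is an assumption that corresponds to the double precision case, and the treatment of infinite values is not modeled.

```python
import math


def allclose_model(a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Check |a - b| <= |b| * rtol + atol for every pair of entries."""
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            # NaN entries only compare equal when both are NaN and the
            # caller explicitly requested this behavior
            if not (equal_nan and math.isnan(x) and math.isnan(y)):
                return False
            continue
        if not abs(x - y) <= abs(y) * rtol + atol:
            return False
    return True


nan = float("nan")
print(allclose_model([1.0, 2.0], [1.0, 2.0 + 1e-9]))  # True
print(allclose_model([1.0], [1.1]))                   # False
print(allclose_model([nan], [nan]))                   # False
print(allclose_model([nan], [nan], equal_nan=True))   # True
```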

Return type:

bool


Miscellaneous operations

drjit.shape(arg: object) tuple[int, ...]

Return a tuple describing dimension and shape of the provided Dr.Jit array, tensor, or standard sequence type.

When the input array is ragged, the function raises a RuntimeError. The term ragged refers to an array whose components have mismatched sizes, such as [[1, 2], [3, 4, 5]]. Note that scalar entries (e.g. [[1, 2], [3]]) are acceptable, since broadcasting can effectively convert them to any size.

The expressions drjit.shape(arg) and arg.shape are equivalent.


arg (drjit.ArrayBase) – an arbitrary Dr.Jit array or tensor


A tuple describing the dimension and shape of the provided Dr.Jit input array or tensor.

Return type:

tuple[int, …]

drjit.width(arg: object, /) int
drjit.width(*args) int

Returns the vectorization width of the provided input(s), which is defined as the length of the last dynamic dimension.

When working with Jit-compiled CUDA or LLVM-based arrays, this corresponds to the number of items being processed in parallel.

The function raises an exception when the input(s) is ragged, i.e., when it contains arrays with incompatible sizes. It returns 1 if the input is scalar and/or does not contain any Dr.Jit arrays.


arg (object) – An arbitrary Dr.Jit array or PyTree.


The width of the provided input(s).

Return type:

int


drjit.slice_index(dtype: type[drjit.ArrayBase], shape: tuple, indices: tuple) tuple[tuple, object]

Computes an index array that can be used to slice a tensor. It is used internally by Dr.Jit to implement complex cases of the __getitem__ operation.

It must be called with the desired output dtype, which must be a dynamically sized 1D array of 32-bit integers. The shape parameter specifies the dimensions of a hypothetical input tensor, and indices contains the entries that would appear in a complex slicing operation, but as a tuple. For example, [5:10:2, ..., None] would be specified as (slice(5, 10, 2), Ellipsis, None).

An example is shown below:

>>> dr.slice_index(dtype=dr.scalar.ArrayXu,
                   shape=(10, 1),
                   indices=(slice(0, 10, 2), 0))
[0, 2, 4, 6, 8]
  • dtype (type) – A dynamic 32-bit unsigned integer Dr.Jit array type, such as drjit.scalar.ArrayXu or drjit.cuda.UInt.

  • shape (tuple[int, ...]) – The shape of the tensor to be sliced.

  • indices (tuple[int|slice|ellipsis|None|dr.ArrayBase, ...]) – A set of indices used to slice the tensor. Its entries can be slice instances, integers, integer arrays, ... (ellipsis) or None.


Tuple consisting of the output array shape and a flattened unsigned integer array of type dtype containing element indices.

Return type:

tuple[tuple[int, …], drjit.ArrayBase]
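
To make the index computation concrete, the following plain-Python sketch handles the simplest case, in which every entry of indices is an integer or a slice. The function slice_index_model is hypothetical: it assumes C-style (row-major) storage of the tensor and omits Ellipsis, None, and integer-array entries.

```python
from itertools import product


def slice_index_model(shape, indices):
    """Return (output shape, flat element indices) for int/slice entries."""
    if len(indices) != len(shape):
        raise ValueError("this sketch expects one index entry per dimension")
    per_axis, out_shape = [], []
    for size, idx in zip(shape, indices):
        if isinstance(idx, slice):
            positions = list(range(*idx.indices(size)))
            per_axis.append(positions)
            out_shape.append(len(positions))  # sliced axes are kept
        elif isinstance(idx, int):
            if idx < 0:
                idx += size
            if not 0 <= idx < size:
                raise IndexError("index out of bounds")
            per_axis.append([idx])  # integer-indexed axes are dropped
        else:
            raise TypeError("unsupported index entry in this sketch")
    # Row-major strides: the last axis is contiguous
    strides = [1] * len(shape)
    for axis in range(len(shape) - 2, -1, -1):
        strides[axis] = strides[axis + 1] * shape[axis + 1]
    flat = [sum(p * s for p, s in zip(pos, strides))
            for pos in product(*per_axis)]
    return tuple(out_shape), flat


print(slice_index_model((10, 1), (slice(0, 10, 2), 0)))
# ((5,), [0, 2, 4, 6, 8])
```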

drjit.meshgrid(*args, indexing='xy') tuple

Return flattened N-D coordinate arrays from a sequence of 1D coordinate vectors.

This function constructs flattened coordinate arrays that are convenient for evaluating and plotting functions on a regular grid. An example is shown below:

import drjit as dr

x, y = dr.meshgrid(
    dr.arange(dr.llvm.UInt, 4),
    dr.arange(dr.llvm.UInt, 4)
)

# x = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
# y = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]

This function carefully reproduces the behavior of numpy.meshgrid except for one major difference: the output coordinates are returned in flattened/raveled form. Like the NumPy version, the indexing=='xy' case internally reorders the first two elements of *args.
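
For two inputs, the flattened output can be modeled in plain Python as follows. The function meshgrid_model is a hypothetical sketch that follows the NumPy conventions for the 'xy' and 'ij' modes and then flattens the result in row-major order.

```python
def meshgrid_model(x, y, indexing="xy"):
    """Flattened 2D coordinate grids for two 1D coordinate lists."""
    x, y = list(x), list(y)
    if indexing == "xy":
        # Grid shape (len(y), len(x)): x varies fastest
        xs = x * len(y)
        ys = [v for v in y for _ in x]
    elif indexing == "ij":
        # Grid shape (len(x), len(y)): y varies fastest
        xs = [v for v in x for _ in y]
        ys = y * len(x)
    else:
        raise ValueError("indexing must be 'xy' or 'ij'")
    return xs, ys


xs, ys = meshgrid_model(range(4), range(4))
print(xs)  # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
print(ys)  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
```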

  • *args – A sequence of 1D coordinate arrays

  • indexing (str) – Specifies the indexing convention. Must be set to either 'xy' (the default) or 'ij'.


A tuple of flattened coordinate arrays (one per input)

Return type:

tuple


drjit.make_opaque(arg: object, /) None
drjit.make_opaque(*args) None

Forcefully evaluate arrays (including literal constants).

This function implements a more drastic version of drjit.eval() that additionally converts literal constant arrays into evaluated (device memory-based) representations.

It is related to the function drjit.opaque() that can be used to directly construct such opaque arrays. Please see the documentation of this function regarding the rationale of making array contents opaque to Dr.Jit’s symbolic tracing mechanism.


*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)

drjit.copy(arg: T, /) T

Create a deep copy of a PyTree

This function recursively traverses PyTrees and replaces Dr.Jit arrays with copies created via the ordinary copy constructor. It also rebuilds tuples, lists, dictionaries, and other custom data structures.

Just-in-time compilation

enum drjit.JitBackend(value)

List of just-in-time compilation backends supported by Dr.Jit. See also drjit.backend_v().

Valid values are as follows:

Invalid = JitBackend.Invalid

Indicates that a type is not handled by a Dr.Jit backend (e.g., a scalar type)

CUDA = JitBackend.CUDA

Dr.Jit backend targeting NVIDIA GPUs using PTX (“Parallel Thread Execution”) IR.

LLVM = JitBackend.LLVM

Dr.Jit backend targeting various processors via the LLVM compiler infrastructure.

enum drjit.VarType(value)

List of possible scalar array types (not all of them are supported).

Valid values are as follows:

Void = VarType.Void

Unknown/unspecified type.

Bool = VarType.Bool

Boolean/mask type.

Int8 = VarType.Int8

Signed 8-bit integer.

UInt8 = VarType.UInt8

Unsigned 8-bit integer.

Int16 = VarType.Int16

Signed 16-bit integer.

UInt16 = VarType.UInt16

Unsigned 16-bit integer.

Int32 = VarType.Int32

Signed 32-bit integer.

UInt32 = VarType.UInt32

Unsigned 32-bit integer.

Int64 = VarType.Int64

Signed 64-bit integer.

UInt64 = VarType.UInt64

Unsigned 64-bit integer.

Pointer = VarType.Pointer

Pointer to a memory address.

Float16 = VarType.Float16

16-bit floating point format (IEEE 754).

Float32 = VarType.Float32

32-bit floating point format (IEEE 754).

Float64 = VarType.Float64

64-bit floating point format (IEEE 754).

enum drjit.VarState(value)

The drjit.ArrayBase.state property returns one of the following enumeration values describing possible evaluation states of a Dr.Jit variable.

Valid values are as follows:

Invalid = VarState.Invalid

The variable has length 0 and effectively does not exist.

Literal = VarState.Literal

A literal constant. Does not consume device memory.

Undefined = VarState.Undefined

An undefined memory region. Does not (yet) consume device memory.

Unevaluated = VarState.Unevaluated

An ordinary unevaluated variable that is neither a literal constant nor symbolic.

Evaluated = VarState.Evaluated

Evaluated variable backed by a device memory region.

Dirty = VarState.Dirty

An evaluated variable backed by a device memory region. The variable furthermore has pending side effects (i.e., the user has performed a drjit.scatter(), drjit.scatter_reduce(), drjit.scatter_inc(), drjit.scatter_add(), or drjit.scatter_add_kahan() operation, and the effect of this operation has not been realized yet). The array’s status will automatically change to Evaluated the next time that Dr.Jit evaluates computation, e.g. via drjit.eval().

Symbolic = VarState.Symbolic

A symbolic variable that could take on various inputs. Cannot be evaluated.

Mixed = VarState.Mixed

This is a nested array, and the components have mixed states.

enum drjit.JitFlag(value)

Flags that control how Dr.Jit compiles and optimizes programs.

This enumeration lists various flags that control how Dr.Jit compiles and optimizes programs, most of which are enabled by default. The status of each flag can be queried via drjit.flag() and enabled/disabled via drjit.set_flag() or the recommended drjit.scoped_set_flag() function, e.g.:

with dr.scoped_set_flag(dr.JitFlag.SymbolicLoops, False):
    # code that has this flag disabled goes here

The most common reason to update the flags is to switch between symbolic and evaluated execution of loops and functions. The former records computation symbolically to assemble large megakernels, while the latter eagerly executes programs by breaking them into many smaller kernels. See the explanations below along with the documentation of drjit.switch() and drjit.while_loop() for more details on these two modes.

Dr.Jit flags are a thread-local property. This means that multiple independent threads using Dr.Jit can set them independently without interfering with each other.

Member Type:


Valid values are as follows:

Debug = JitFlag.Debug

Debug mode: Enable functionality to uncover errors in application code.

When debug mode is enabled, Dr.Jit inserts a number of additional runtime checks to locate sources of undefined behavior.

Debug mode comes at a significant cost: it interferes with kernel caching, reduces tracing performance, and produces kernels that run slower. We recommend that you only use it periodically before a release, or when encountering a serious problem like a crashing kernel.

First, debug mode enables assertion checks in user code such as those performed by drjit.assert_true(), drjit.assert_false(), and drjit.assert_equal().

Second, Dr.Jit inserts additional checks to intercept out-of-bounds reads and writes performed by operations such as drjit.scatter(), drjit.gather(), drjit.scatter_reduce(), drjit.scatter_inc(), etc. It also detects calls to invalid callables performed via drjit.switch() and drjit.dispatch(). Such invalid operations are masked, and they generate a warning message on the console, e.g.:

>>> dr.gather(dtype=UInt, source=UInt(1, 2, 3), index=UInt(0, 1, 100))
RuntimeWarning: drjit.gather(): out-of-bounds read from position 100 in an array of size 3. (<stdin>:2)

Finally, Dr.Jit also installs a Python tracing hook that associates all Jit variables with their Python source code location, and this information is propagated all the way to the final intermediate representation (PTX, LLVM IR). This is useful for low-level debugging and development of Dr.Jit itself. You can query the source location information of a variable x by writing x.label.

Due to limitations of the Python tracing interface, this handler becomes active within the next called function (or Jupyter notebook cell) following activation of the drjit.JitFlag.Debug flag. It does not apply to code within the same scope/function.

C++ code using Dr.Jit also benefits from debug mode but will lack accurate source code location information. In mixed-language projects, the reported file and line number information will reflect that of the last operation on the Python side of the interface.

ReuseIndices = JitFlag.ReuseIndices

Index reuse: Dr.Jit consists of two main parts: the just-in-time compiler, and the automatic differentiation layer. Both maintain an internal data structure representing captured computation, in which each variable is associated with an index (e.g., r1234 in the JIT compiler, and a1234 in the AD graph).

The index of a Dr.Jit array in these graphs can be queried via the drjit.index and drjit.index_ad variables, and they are also visible in debug messages (if drjit.set_log_level() is set to a more verbose debug level).

Dr.Jit aggressively reuses the indices of expired variables by default, but this can make debug output difficult to interpret. When debugging Dr.Jit itself, it is often helpful to investigate the history of a particular variable. In such cases, set this flag to False to disable variable reuse both at the JIT and AD levels. This comes at a cost: the internal data structures keep on growing, so it is not suitable for long-running computations.

Index reuse is enabled by default.

ConstantPropagation = JitFlag.ConstantPropagation

Constant propagation: immediately evaluate arithmetic involving literal constants on the host and don’t generate any device-specific code for them.

For example, the following assertion holds when constant propagation is enabled in Dr.Jit.

from drjit.llvm import Int

# Create two literal constant arrays
a, b = Int(4), Int(5)

# This addition operation can be immediately performed and does not need to be recorded
c1 = a + b

# Double-check that c1 and c2 refer to the same Dr.Jit variable
c2 = Int(9)
assert c1.index == c2.index

Constant propagation is enabled by default.

ValueNumbering = JitFlag.ValueNumbering

Local value numbering: a simple variant of common subexpression elimination that collapses identical expressions within basic blocks. For example, the following assertion holds when value numbering is enabled in Dr.Jit.

from drjit.llvm import Int

# Create two non-literal arrays stored in device memory
a, b = Int(1, 2, 3), Int(4, 5, 6)

# Perform the same arithmetic operation twice
c1 = a + b
c2 = a + b

# Verify that c1 and c2 reference the same Dr.Jit variable
assert c1.index == c2.index

Local value numbering is enabled by default.
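
The interplay of constant propagation and local value numbering can be illustrated with a toy tracer written in plain Python. The class ToyJit and its methods are hypothetical and greatly simplified: variables are plain integer indices, and a dictionary plays the role of the cache of previously seen expressions. It is not a description of Dr.Jit's internal data structures.

```python
class ToyJit:
    """Toy expression recorder with constant folding and value numbering."""

    def __init__(self):
        self.nodes = []   # index -> expression key
        self.cache = {}   # expression key -> index

    def _intern(self, key):
        # Value numbering: identical expressions map to the same index
        if key not in self.cache:
            self.cache[key] = len(self.nodes)
            self.nodes.append(key)
        return self.cache[key]

    def literal(self, value):
        return self._intern(("literal", value))

    def load(self, name):
        return self._intern(("load", name))

    def add(self, a, b):
        ka, kb = self.nodes[a], self.nodes[b]
        if ka[0] == "literal" and kb[0] == "literal":
            # Constant propagation: fold the addition right away
            return self.literal(ka[1] + kb[1])
        return self._intern(("add", a, b))


jit = ToyJit()

# Literal operands are folded into a single literal
assert jit.add(jit.literal(4), jit.literal(5)) == jit.literal(9)

# Repeating the same operation on non-literal operands yields the same index
a, b = jit.load("a"), jit.load("b")
assert jit.add(a, b) == jit.add(a, b)
```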

FastMath = JitFlag.FastMath

Fast Math: this flag is analogous to the -ffast-math flag in C compilers. When set, the system may use approximations and simplifications that sacrifice strict IEEE-754 compatibility.

Currently, it changes two behaviors:

  • expressions of the form a * 0 will be simplified to 0 (which is technically not correct when a is infinite or NaN-valued).

  • Dr.Jit will use slightly approximate division and square root operations in CUDA mode. Note that disabling fast math mode is costly on CUDA devices, as the strict IEEE-754 compliant version of these operations uses software-based emulation.

Fast math mode is enabled by default.
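
The reason why the a * 0 simplification is not strictly IEEE-754 compliant can be verified with ordinary Python floats. The helpers strict_mul and fast_mul below are hypothetical names used only for this demonstration.

```python
import math


def strict_mul(a, b):
    """IEEE-754 multiplication as performed by Python floats."""
    return a * b


def fast_mul(a, b):
    """Multiplication with the 'a * 0 == 0' simplification applied."""
    if a == 0 or b == 0:
        return 0.0
    return a * b


inf, nan = float("inf"), float("nan")

# Strict arithmetic propagates NaN in both cases
assert math.isnan(strict_mul(inf, 0.0))
assert math.isnan(strict_mul(nan, 0.0))

# The simplified version returns zero instead
assert fast_mul(inf, 0.0) == 0.0
assert fast_mul(nan, 0.0) == 0.0
```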

SymbolicLoops = JitFlag.SymbolicLoops

Dr.Jit provides two main ways of compiling loops involving Dr.Jit arrays.

  1. Symbolic mode (the default): Dr.Jit executes the loop a single time regardless of how many iterations it requires in practice. It does so with symbolic (abstract) arguments to capture the loop condition and body and then turns it into an equivalent loop in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.

  2. Evaluated mode: Dr.Jit evaluates the loop’s state variables and reduces the loop condition to a single element (bool) that expresses whether any elements are still alive. If so, it runs the loop body and the process repeats.

A separate section about symbolic and evaluated modes discusses these two options in detail.

Symbolic loops are enabled by default.
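
Evaluated mode can be pictured with the following plain-Python model, in which the loop state is a list with one entry per element. The function evaluated_loop_model and its cond/body callbacks are hypothetical; the sketch only mirrors the control flow (reduce the condition, run the body for active elements, repeat) and does not involve any kernels.

```python
def evaluated_loop_model(state, cond, body):
    """Run 'body' on all active entries until no entry satisfies 'cond'."""
    state = list(state)
    iterations = 0
    while True:
        # Evaluate the loop condition for every element
        active = [cond(s) for s in state]
        # Reduce it to a single bool: is any element still alive?
        if not any(active):
            return state, iterations
        # Run the loop body; inactive elements keep their value
        state = [body(s) if a else s for s, a in zip(state, active)]
        iterations += 1


result, iterations = evaluated_loop_model(
    [3, 1, 0], cond=lambda s: s > 0, body=lambda s: s - 1)
print(result, iterations)  # [0, 0, 0] 3
```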

OptimizeLoops = JitFlag.OptimizeLoops

Perform basic optimizations for loops involving Dr.Jit arrays.

This flag enables two optimizations:

  • Constant arrays: loop state variables that aren’t modified by the loop are automatically removed. This shortens the generated code, which can be helpful especially in combination with the automatic transformations performed by @drjit.syntax that can be somewhat conservative in classifying too many local variables as potential loop state.

  • Literal constant arrays: In addition to the above point, constant loop state variables that are literal constants are propagated into the loop body, where this may unlock further optimization opportunities.

    This is useful in combination with automatic differentiation, where it helps to detect code that does not influence the computed derivatives.

A practical implication of this optimization flag is that it may cause drjit.while_loop() to run the loop body twice instead of just once.

This flag is enabled by default. Note that it is only meaningful in combination with SymbolicLoops.

CompressLoops = JitFlag.CompressLoops

Compress the loop state of evaluated loops after every iteration.

When an evaluated loop processes many elements, and when each element requires a different number of loop iterations, there is the question of what should be done with inactive elements. The default implementation keeps them around and does redundant calculations that are, however, masked out. Consequently, later loop iterations don’t run faster despite fewer elements being active.

Setting this flag causes the removal of inactive elements after every iteration. This reorganization is not for free and does not benefit all use cases.

This flag is disabled by default. Note that it only applies to evaluated loops (i.e., when SymbolicLoops is disabled, or when the mode='evaluated' parameter was passed to the loop in question).
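
The trade-off can be quantified with a small plain-Python model that counts how many elements ("lanes") each strategy touches. The functions run_masked and run_compressed are hypothetical; the cost of the reorganization step itself, which is the reason why compression does not always pay off, is not modeled.

```python
def run_masked(state, cond, body):
    """Default strategy: every iteration processes all lanes."""
    state = list(state)
    lanes = 0
    while any(cond(s) for s in state):
        # Inactive lanes are still visited, but their result is masked out
        state = [body(s) if cond(s) else s for s in state]
        lanes += len(state)
    return state, lanes


def run_compressed(state, cond, body):
    """Compressed strategy: inactive lanes are removed after each iteration."""
    state = list(state)
    alive = [i for i, s in enumerate(state) if cond(s)]
    lanes = 0
    while alive:
        for i in alive:
            state[i] = body(state[i])
        lanes += len(alive)
        # Keep only the indices of lanes that are still active
        alive = [i for i in alive if cond(state[i])]
    return state, lanes


def cond(s):
    return s > 0


def body(s):
    return s - 1


print(run_masked([3, 1, 0], cond, body))      # ([0, 0, 0], 9)
print(run_compressed([3, 1, 0], cond, body))  # ([0, 0, 0], 4)
```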

SymbolicCalls = JitFlag.SymbolicCalls

Dr.Jit provides two main ways of compiling function calls targeting instance arrays.

  1. Symbolic mode (the default): Dr.Jit invokes each callable with symbolic (abstract) arguments. It does this to capture a transcript of the computation that it can turn into a function in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.

  2. Evaluated mode: Dr.Jit evaluates all inputs and groups them by instance ID. Following this, it launches a kernel per instance to process the rearranged inputs and assemble the function return value.

A separate section about symbolic and evaluated modes discusses these two options in detail.

Besides calls to instance arrays, this flag also controls the behavior of the functions drjit.switch() and drjit.dispatch().

Symbolic calls are enabled by default.
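
The grouping step of evaluated mode can be sketched in plain Python. In this hypothetical model, instances maps an instance ID to an ordinary Python function, and each per-instance list comprehension stands in for one kernel launch. The name evaluated_call_model and this data layout are assumptions made for illustration.

```python
from collections import defaultdict


def evaluated_call_model(instances, instance_ids, args):
    """Group lanes by instance ID, call each instance once, merge results."""
    # Group the lane indices by the instance they target
    groups = defaultdict(list)
    for lane, inst in enumerate(instance_ids):
        groups[inst].append(lane)

    result = [None] * len(args)
    for inst, lanes in groups.items():
        fn = instances[inst]
        # One "launch" per instance, operating on the rearranged inputs
        outputs = [fn(args[lane]) for lane in lanes]
        # Write the outputs back to their original positions
        for lane, out in zip(lanes, outputs):
            result[lane] = out
    return result


instances = {0: lambda x: x * 2, 1: lambda x: x + 100}
print(evaluated_call_model(instances, [0, 1, 0, 1], [1, 2, 3, 4]))
# [2, 102, 6, 104]
```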

OptimizeCalls = JitFlag.OptimizeCalls

Perform basic optimizations for function calls on instance arrays.

This flag enables two optimizations:

  • Constant propagation: Dr.Jit will propagate literal constants across function boundaries while tracing, which can unlock simplifications within. This is especially useful in combination with automatic differentiation, where it helps to detect code that does not influence the computed derivatives.

  • Devirtualization: When an element of the return value has the same computation graph in all instances, it is removed from the function call interface and moved to the caller.

The flag is enabled by default. Note that it is only meaningful in combination with SymbolicCalls. Besides calls to instance arrays, this flag also controls the behavior of the functions drjit.switch() and drjit.dispatch().

MergeFunctions = JitFlag.MergeFunctions

Deduplicate code generated by function calls on instance arrays.

When arr is an instance array (potentially with thousands of instances), a method call on arr can potentially generate vast numbers of different functions in the generated code. At the same time, many of these functions may contain identical code (or code that is identical except for data references).

Dr.Jit can exploit such redundancy and merge such functions during code generation. Besides generating shorter programs, this also helps to reduce thread divergence.

This flag is enabled by default. Note that it is only meaningful in combination with SymbolicCalls. Besides calls to instance arrays, this flag also controls the behavior of the functions drjit.switch() and drjit.dispatch().

ForceOptiX = JitFlag.ForceOptiX

Force execution through OptiX even if a kernel doesn’t use ray tracing. This only applies to the CUDA backend and is mainly helpful for automated tests done by the Dr.Jit team.

This flag is disabled by default.

PrintIR = JitFlag.PrintIR

Print the low-level IR representation when launching a kernel.

If enabled, this flag causes Dr.Jit to print the low-level IR (LLVM IR, NVIDIA PTX) representation of the generated code onto the console (or Jupyter notebook).

This flag is disabled by default.

KernelHistory = JitFlag.KernelHistory

Maintain a history of kernel launches to profile/debug programs.

Programs written on top of Dr.Jit execute in an extremely asynchronous manner. By default, the system postpones the computation to build large fused kernels. Even when this computation eventually runs, it does so asynchronously with respect to the host, which can make benchmarking difficult.

In general, beware of the following benchmarking anti-pattern:

import time
a = time.time()
# Some Dr.Jit computation
b = time.time()
print("took %.2f ms" % ((b-a) * 1000))

In the worst case, the measured time interval may only capture the tracing time, without any actual computation having taken place. Another common mistake with this pattern is that Dr.Jit or the target device may still be busy with computation that started prior to the a = time.time() line, which is now incorrectly added to the measured period.

Dr.Jit provides a kernel history feature, where it creates an entry in a list whenever it launches a kernel or related operation (memory copies, etc.). This not only gives accurate and isolated timings (measured with counters on the CPU/GPU) but also reveals if a kernel was launched at all. To capture the kernel history, set this flag just before the region to be benchmarked and call drjit.kernel_history() at the end.

Capturing the history has a (very) small cost and is therefore disabled by default.

LaunchBlocking = JitFlag.LaunchBlocking

Force synchronization after every kernel launch. This is useful to isolate severe problems (e.g. crashes) to a specific kernel.

This flag has a severe performance impact and is disabled by default.

ScatterReduceLocal = JitFlag.ScatterReduceLocal

Reduce locally before performing atomic scatter-reductions.

Atomic memory operations are expensive when many writes target the same region of memory. This leads to a phenomenon called contention that is normally associated with significant slowdowns (10-100x aren’t unusual).

This issue is particularly common when automatically differentiating computation in reverse mode (e.g. drjit.backward()), since this transformation turns differentiable global memory reads into atomic scatter-additions. A differentiable scalar read is all it takes to create such an atomic memory bottleneck.

To reduce this cost, Dr.Jit can perform a local reduction that uses cooperation between SIMD/warp lanes to resolve all requests targeting the same address and then issue only a single atomic memory transaction per unique target. This can reduce atomic memory traffic 32-fold on the GPU (CUDA) and 16-fold on the CPU (AVX512). On the CUDA backend, local reduction is currently only supported for 32-bit operands (signed/unsigned integers and single precision variables).
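The effect of a local reduction can be illustrated with a pure-Python model that counts simulated atomic transactions. This is a sketch under stated assumptions, not Dr.Jit's implementation: the function names and the lanes parameter are hypothetical, and a dictionary stands in for the cooperation between lanes.

```python
from collections import defaultdict


def scatter_add_naive(target, indices, values):
    """One simulated atomic transaction per written element."""
    atomics = 0
    for idx, val in zip(indices, values):
        target[idx] += val
        atomics += 1
    return atomics


def scatter_add_local(target, indices, values, lanes=32):
    """Pre-reduce each group of `lanes` elements, then write once per address."""
    atomics = 0
    for start in range(0, len(indices), lanes):
        partial = defaultdict(float)
        for idx, val in zip(indices[start:start + lanes],
                            values[start:start + lanes]):
            partial[idx] += val      # resolved within one group of lanes
        for idx, total in partial.items():
            target[idx] += total     # single transaction per unique target
            atomics += 1
    return atomics


# Worst case: all 64 writes target the same address
indices = [0] * 64
values = [1.0] * 64
naive_target, local_target = [0.0], [0.0]
naive_atomics = scatter_add_naive(naive_target, indices, values)
local_atomics = scatter_add_local(local_target, indices, values)
```

With 32 lanes per group, the model issues 2 transactions instead of 64, matching the 32-fold reduction mentioned above for the fully contended case.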

The section on optimizations presents plots that demonstrate the impact of this optimization.

The JIT flag drjit.JitFlag.ScatterReduceLocal affects the behavior of scatter_add() and scatter_reduce(), along with the reverse-mode derivative of gather(). Setting the flag to True will usually cause a mode= argument value of drjit.ReduceOp.Auto to be interpreted as drjit.ReduceOp.Local. Another LLVM-specific optimization takes precedence in certain situations; refer to the discussion of this flag for details.

This flag is enabled by default.

SymbolicConditionals = JitFlag.SymbolicConditionals

Dr.Jit provides two main ways of compiling conditionals involving Dr.Jit arrays.

  1. Symbolic mode (the default): Dr.Jit captures the computation performed by the True and False branches and generates an equivalent branch in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.

  2. Evaluated mode: Dr.Jit always executes both branches and blends their outputs.
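The blending performed by evaluated mode can be modeled in a few lines of plain Python. The helper evaluated_if below is hypothetical and only mimics the described behavior on Python lists; it is not a Dr.Jit function.

```python
def evaluated_if(mask, true_fn, false_fn, values):
    """Toy model: run both branches for every element, then blend by mask."""
    true_out = [true_fn(v) for v in values]    # executed for all elements
    false_out = [false_fn(v) for v in values]  # executed for all elements
    return [t if m else f for m, t, f in zip(mask, true_out, false_out)]


values = [1, -2, 3, -4]
mask = [v > 0 for v in values]
blended = evaluated_if(mask, lambda v: v * 2, lambda v: -v, values)
```

Both branch functions run on all four elements here, which illustrates why evaluated mode does not save work when only a few elements take one of the branches.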

A separate section about symbolic and evaluated modes discusses these two options in detail.

Symbolic conditionals are enabled by default.

SymbolicScope = JitFlag.SymbolicScope

This flag is set to True when Dr.Jit is currently capturing symbolic computation. The flag is automatically managed and should not be updated by application code.

User code may query this flag to check if it is legal to perform certain operations (e.g., evaluating array contents).

Note that this information can also be queried in a more fine-grained manner (per variable) using the drjit.ArrayBase.state field.

Default = JitFlag.Default

The default set of optimization flags consisting of

drjit.has_backend(arg: drjit.JitBackend, /) int

Check if the specified Dr.Jit backend was successfully initialized.

drjit.schedule(arg: object, /) bool
drjit.schedule(*args) bool

Schedule the provided JIT variable(s) for later evaluation

This function causes args to be evaluated by the next kernel launch. In other words, the effect of this operation is deferred: the next time that Dr.Jit’s LLVM or CUDA backends compile and execute code, they will include the trace of the specified variables in the generated kernel and turn them into an explicit memory-based representation.

Scheduling and evaluation of traced computation happen automatically, hence it is rare that a user would need to call this function explicitly. Explicit scheduling can improve performance in certain cases. For example, consider the following code:

# Computation that produces Dr.Jit arrays
a, b = ...

# The following line launches a kernel that computes 'a'
print(a)

# The following line launches a kernel that computes 'b'
print(b)

If the traces of a and b overlap (perhaps they reference computation from an earlier step not shown here), then this is inefficient as these steps will be executed twice. It is preferable to launch bigger kernels that leverage common subexpressions, which is what drjit.schedule() enables:

a, b = ... # Computation that produces Dr.Jit arrays

# Schedule both arrays for deferred evaluation, but don't evaluate yet
dr.schedule(a, b)

# The following line launches a kernel that computes both 'a' and 'b'
print(a)

# References the stored array, no kernel launch
print(b)

Note that drjit.eval() would also have been a suitable alternative in the above example; the main difference to drjit.schedule() is that it does the evaluation immediately without deferring the kernel launch.

This function accepts a variable number of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).

During recursion, the function gathers all unevaluated Dr.Jit arrays. Evaluated arrays and incompatible types are ignored. Multiple variables can be equivalently scheduled with a single drjit.schedule() call or a sequence of calls to drjit.schedule(). Variables that are garbage collected between the original drjit.schedule() call and the next kernel launch are ignored and will not be stored in memory.


*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)


True if a variable was scheduled, False if the operation did not do anything.

Return type:


drjit.eval(arg: object, /) bool
drjit.eval(*args) bool

Evaluate the provided JIT variable(s)

Dr.Jit automatically evaluates variables as needed, hence it is usually not necessary to call this function explicitly. That said, explicit evaluation may sometimes improve performance—refer to the documentation of drjit.schedule() for an example of such a use case.

drjit.eval() invokes Dr.Jit’s LLVM or CUDA backends to compile and then execute a kernel containing all steps that are needed to evaluate the specified variables, which will turn them into a memory-based representation. The generated kernel(s) will also include computation that was previously scheduled via drjit.schedule(). In fact, drjit.eval() internally calls drjit.schedule(), as

dr.eval(arg_1, arg_2, ...)

is equivalent to

dr.schedule(arg_1, arg_2, ...)
dr.eval()

This function accepts a variable number of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).

During this recursive traversal, the function collects all unevaluated Dr.Jit arrays, while ignoring previously evaluated arrays along with non-array types. The function also does not evaluate literal constant arrays (this refers to potentially large arrays that are entirely uniform), as this is generally not wanted. Use the function drjit.make_opaque() if you wish to evaluate literal constant arrays as well.


*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)


True if a variable was evaluated, False if the operation did not do anything.

Return type:


drjit.set_flag(arg0: drjit.JitFlag, arg1: bool, /) None

Set the value of the given Dr.Jit compilation flag.

drjit.flag(arg: drjit.JitFlag, /) bool

Query whether the given Dr.Jit compilation flag is active.

class drjit.scoped_set_flag

Context manager, which sets or unsets a Dr.Jit compilation flag in a local execution scope.

For example, the following snippet shows how to temporarily disable a flag:

with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, False):
    # Code affected by the change should be placed here
    pass

# Flag is returned to its original status

__init__(self, flag: drjit.JitFlag, value: bool = True) None
__enter__(self) None
__exit__(self, arg0: object | None, arg1: object | None, arg2: object | None) None

Type traits

The functions in this section can be used to infer properties or types of Dr.Jit arrays.

The naming convention with a trailing _v or _t indicates whether a function returns a value or a type. Evaluation takes place at runtime within Python. In C++, these expressions are all constexpr (i.e., evaluated at compile time).

Array type tests

drjit.is_array_v(arg: object | None) bool

Check if the input is a Dr.Jit array instance or type


arg (object) – An arbitrary Python object


True if arg or type(arg) is a Dr.Jit array type, and False otherwise.

Return type:


drjit.is_mask_v(arg: object, /) bool

Check whether the input array instance or type is a Dr.Jit mask array or a Python bool value/type.


arg (object) – An arbitrary Python object


True if arg represents a Dr.Jit mask array or Python bool instance or type.

Return type:


drjit.is_half_v(arg: object, /) bool

Check whether the input array instance or type is a Dr.Jit half-precision floating point array.


arg (object) – An arbitrary Python object


True if arg represents a Dr.Jit half-precision floating point array instance or type.

Return type:


drjit.is_float_v(arg: object, /) bool

Check whether the input array instance or type is a Dr.Jit floating point array or a Python float value/type.


arg (object) – An arbitrary Python object


True if arg represents a Dr.Jit floating point array or Python float instance or type.

Return type:


drjit.is_integral_v(arg: object, /) bool

Check whether the input array instance or type is an integral Dr.Jit array or a Python int value/type.

Note that a mask array is not considered to be integral.


arg (object) – An arbitrary Python object


True if arg represents an integral Dr.Jit array or Python int instance or type.

Return type:


drjit.is_arithmetic_v(arg: object, /) bool

Check whether the input array instance or type is an arithmetic Dr.Jit array or a Python int or float value/type.

Note that a mask type (e.g. bool, drjit.scalar.Array2b, etc.) is not considered to be arithmetic.


arg (object) – An arbitrary Python object


True if arg represents an arithmetic Dr.Jit array or Python int or float instance or type.

Return type:


drjit.is_signed_v(arg: object, /) bool

Check whether the input array instance or type is a signed Dr.Jit array or a Python int or float value/type.


arg (object) – An arbitrary Python object


True if arg represents a signed Dr.Jit array or Python int or float instance or type.

Return type:


drjit.is_unsigned_v(arg: object, /) bool

Check whether the input array instance or type is an unsigned integer Dr.Jit array or a Python bool value/type (masks and boolean values are also considered to be unsigned).


arg (object) – An arbitrary Python object


True if arg represents an unsigned Dr.Jit array or Python bool instance or type.

Return type:


drjit.is_dynamic_v(arg: object, /) bool

Check whether the input instance or type represents a dynamically sized Dr.Jit array type.


arg (object) – An arbitrary Python object


True if the test was successful, and False otherwise.

Return type:


drjit.is_jit_v(arg: object, /) bool

Check whether the input array instance or type represents a type that undergoes just-in-time compilation.


arg (object) – An arbitrary Python object


True if arg represents an array type from the drjit.cuda.* or drjit.llvm.* namespaces, and False otherwise.

Return type:


drjit.is_diff_v(arg: object, /) bool

Check whether the input is a differentiable Dr.Jit array instance or type.

Note that this is a type-based statement that is unrelated to mathematical differentiability. For example, the integral type drjit.cuda.ad.Int from the CUDA AD namespace satisfies is_diff_v(..) == True.


arg (object) – An arbitrary Python object


True if arg represents an array type from the drjit.[cuda/llvm].ad.* namespace, and False otherwise.

Return type:


drjit.is_vector_v(arg: object, /) bool

Check whether the input is a Dr.Jit array instance or type representing a vectorial array type.


arg (object) – An arbitrary Python object


True if the test was successful, and False otherwise.

Return type:


drjit.is_complex_v(arg: object, /) bool

Check whether the input is a Dr.Jit array instance or type representing a complex number.


arg (object) – An arbitrary Python object


True if the test was successful, and False otherwise.

Return type:


drjit.is_matrix_v(arg: object, /) bool

Check whether the input is a Dr.Jit array instance or type representing a matrix.


arg (object) – An arbitrary Python object


True if the test was successful, and False otherwise.

Return type:


drjit.is_quaternion_v(arg: object, /) bool

Check whether the input is a Dr.Jit array instance or type representing a quaternion.


arg (object) – An arbitrary Python object


True if the test was successful, and False otherwise.

Return type:


drjit.is_tensor_v(arg: object, /) bool

Check whether the input is a Dr.Jit array instance or type representing a tensor.


arg (object) – An arbitrary Python object


True if the test was successful, and False otherwise.

Return type:


drjit.is_special_v(arg: object, /) bool

Check whether the input is a special Dr.Jit array instance or type.

A special array type requires precautions when performing arithmetic operations like multiplications (complex numbers, quaternions, matrices).


arg (object) – An arbitrary Python object


True if the test was successful, and False otherwise.

Return type:


drjit.is_struct_v(arg: object, /) bool

Check if the input is a Dr.Jit-compatible data structure

Custom data structures can be made compatible with various Dr.Jit operations by specifying a DRJIT_STRUCT member. See the section on PyTrees for details. This type trait can be used to check for the existence of such a field.


arg (object) – An arbitrary Python object


True if arg has a DRJIT_STRUCT member

Return type:


Array properties (shape, type, etc.)

drjit.type_v(arg: object, /) drjit.VarType

Returns the scalar type associated with the given Dr.Jit array instance or type.


arg (object) – An arbitrary Python object


The associated type or drjit.VarType.Void.

Return type:


drjit.backend_v(arg: object, /) drjit.JitBackend

Returns the backend responsible for the given Dr.Jit array instance or type.


arg (object) – An arbitrary Python object


The associated Jit backend or drjit.JitBackend.None.

Return type:


drjit.size_v(arg: object, /) int

Return the (static) size of the outermost dimension of the provided Dr.Jit array instance or type

Note that this function mainly exists to query type-level information. Use the Python len() function to query the size in a way that does not distinguish between static and dynamic arrays.


arg (object) – An arbitrary Python object


Returns either the static size or drjit.Dynamic when arg is a dynamic Dr.Jit array. Returns 1 for all other types.

Return type:


drjit.depth_v(arg: object, /) int

Return the depth of the provided Dr.Jit array instance or type

For example, an array consisting of floating point values (for example, drjit.scalar.Array3f) has depth 1, while an array consisting of sub-arrays (e.g., drjit.cuda.Array3f) has depth 2.


arg (object) – An arbitrary Python object


Returns the depth of the input, if it is a Dr.Jit array instance or type. Returns 0 for all other inputs.

Return type:


drjit.itemsize_v(arg: object, /) int

Return the per-item size (in bytes) of the scalar type underlying a Dr.Jit array


arg (object) – A Dr.Jit array instance or array type.


Returns the size of the array elements in bytes.

Return type:


drjit.Dynamic: int = -1

Special size value used to identify dynamic arrays in size_v().

drjit.newaxis: NoneType = None

This variable stores an alias of None. It is used to create new axes in tensor slicing operations (analogous to np.newaxis in NumPy). See the discussion of tensors for an example.

Bit-level operations

drjit.reinterpret_array(arg0: type[drjit.ArrayBase], arg1: drjit.ArrayBase, /) object

Reinterpret the provided Dr.Jit array or tensor as a different type.

This operation reinterprets the input type as another type provided that it has a compatible in-memory layout (this operation is also known as a bit-cast).

  • dtype (type) – Target type.

  • value (object) – A compatible Dr.Jit input array or tensor.


Result of the conversion as described above.

Return type:


drjit.popcnt(arg: ArrayT, /) ArrayT
drjit.popcnt(arg: int, /) int

Return the number of nonzero bits.

This function evaluates the component-wise population count of the input scalar, array, or tensor. This function assumes that arg is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.


arg (int | drjit.ArrayBase) – A Python or Dr.Jit array


number of nonzero bits in arg

Return type:

int | drjit.ArrayBase
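As a point of reference, the population count of a 32-bit scalar can be written in plain Python. The helper popcnt32 is a hypothetical name used for illustration only; it mirrors the scalar case described above and is not how Dr.Jit implements the operation.

```python
def popcnt32(x):
    """Reference sketch: count the set bits of a 32-bit integer."""
    # Masking restricts the value to 32 bits (assumption: two's complement view)
    return bin(x & 0xFFFFFFFF).count("1")


examples = [popcnt32(0), popcnt32(0b1011), popcnt32(0xFFFFFFFF)]
```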

drjit.lzcnt(arg: ArrayT, /) ArrayT
drjit.lzcnt(arg: int, /) int

Return the number of leading zero bits.

This function evaluates the component-wise leading zero count of the input scalar, array, or tensor. This function assumes that arg is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.

The operation is well-defined when arg is zero.


arg (int | drjit.ArrayBase) – A Python or Dr.Jit array


number of leading zero bits in arg

Return type:

int | drjit.ArrayBase

drjit.tzcnt(arg: ArrayT, /) ArrayT
drjit.tzcnt(arg: int, /) int

Return the number of trailing zero bits.

This function evaluates the component-wise trailing zero count of the input scalar, array, or tensor. This function assumes that arg is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.

The operation is well-defined when arg is zero.


arg (int | drjit.ArrayBase) – A Python or Dr.Jit array


number of trailing zero bits in arg

Return type:

int | drjit.ArrayBase
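The leading and trailing zero counts of a 32-bit scalar can be sketched in plain Python as follows. The names lzcnt32, tzcnt32, and BITS are hypothetical. The sketch assumes that a zero input yields the full bit width (32); the text above only states that this case is well-defined.

```python
BITS = 32  # width of the scalar case described above


def lzcnt32(x):
    """Reference sketch: number of zero bits above the highest set bit."""
    x &= 0xFFFFFFFF
    return BITS - x.bit_length()  # yields 32 for x == 0 (assumption)


def tzcnt32(x):
    """Reference sketch: number of zero bits below the lowest set bit."""
    x &= 0xFFFFFFFF
    if x == 0:
        return BITS  # assumption: an all-zero input counts every bit
    # x & -x isolates the lowest set bit
    return (x & -x).bit_length() - 1
```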

drjit.brev(arg: ArrayT, /) ArrayT
drjit.brev(arg: int, /) int

Reverse the bit representation of an integer value or array.

This function assumes that arg is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.


arg (int | drjit.ArrayBase) – A Python int or Dr.Jit integer array.


the bit-reversed version of arg.

Return type:

int | drjit.ArrayBase
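For intuition, bit reversal of a 32-bit scalar can be expressed in plain Python via a fixed-width binary string. The helper brev32 is a hypothetical name; this is a readable reference sketch rather than an efficient implementation.

```python
def brev32(x):
    """Reference sketch: reverse the 32-bit pattern of an integer."""
    bits = format(x & 0xFFFFFFFF, "032b")  # zero-padded 32-character string
    return int(bits[::-1], 2)


reversed_one = brev32(1)       # lowest bit moves to the highest position
round_trip = brev32(brev32(0x12345678))
```

Applying the reversal twice returns the original value, which is a convenient sanity check.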

drjit.log2i(arg: T, /) T

Return the floor of the base-2 logarithm.

This function evaluates the component-wise floor of the base-2 logarithm of the input scalar, array, or tensor. This function assumes that arg is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.

The operation overflows when arg is zero.


arg (int | drjit.ArrayBase) – A Python or Dr.Jit array


floor of the base-2 logarithm of the input

Return type:

int | drjit.ArrayBase
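The floor of the base-2 logarithm equals the index of the highest set bit, which the following plain-Python sketch computes for a 32-bit scalar. The helper log2i32 is hypothetical, and raising an exception for zero is a choice made for this sketch; the text above only states that the operation overflows in that case.

```python
def log2i32(x):
    """Reference sketch: floor(log2(x)) for a positive 32-bit integer."""
    x &= 0xFFFFFFFF
    if x == 0:
        # Sketch-specific choice: the logarithm of zero is not representable
        raise ValueError("log2i is not meaningful for zero")
    return x.bit_length() - 1  # index of the highest set bit


results = [log2i32(1), log2i32(8), log2i32(9), log2i32(0xFFFFFFFF)]
```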

Standard mathematical functions

drjit.fma(arg0: object, arg1: object, arg2: object, /) object
drjit.fma(arg0: int, arg1: int, arg2: int, /) int
drjit.fma(arg0: float, arg1: float, arg2: float, /) float

Perform a fused multiply-addition (FMA) of the inputs.

Given arguments arg0, arg1, and arg2, this operation computes arg0 * arg1 + arg2 using only one final rounding step. The operation is not only more accurate, but also more efficient, since FMA maps to a native machine instruction on all platforms targeted by Dr.Jit.

When the input is complex- or quaternion-valued, the function internally uses a complex or quaternion product. In this case, it reduces the number of internal rounding steps instead of avoiding them altogether.

While FMA is traditionally a floating point operation, Dr.Jit also implements FMA for integer arrays and maps it onto dedicated instructions provided by the backend if possible (e.g. mad.lo.* for CUDA/PTX).


Result of the FMA operation

Return type:

float | drjit.ArrayBase
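The difference between one and two rounding steps can be demonstrated in plain Python. The sketch below emulates a fused operation with exact rational arithmetic followed by a single conversion to float. The helper fma_exact is hypothetical and only serves as an illustration; it is unrelated to how Dr.Jit or the hardware evaluates the operation.

```python
from fractions import Fraction


def fma_exact(a, b, c):
    """Emulate a*b+c with a single final rounding via exact rationals."""
    return float(Fraction(a) * Fraction(b) + Fraction(c))


a = 1.0 + 2.0 ** -27
b = 1.0 - 2.0 ** -27
c = -1.0

# The exact product is 1 - 2**-54, which rounds to 1.0 in double precision,
# so the separately rounded multiply-add loses the small term entirely.
two_roundings = a * b + c
one_rounding = fma_exact(a, b, c)
```

Here, two_roundings evaluates to zero while one_rounding retains the value -2**-54.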

drjit.abs(arg: ArrayT, /) ArrayT
drjit.abs(arg: int, /) int
drjit.abs(arg: float, /) float

Compute the absolute value of the provided input.

This function evaluates the component-wise absolute value of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.


arg (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type


Absolute value of the input

Return type:

int | float | drjit.ArrayBase

drjit.minimum(arg0: int, arg1: int, /) int
drjit.minimum(arg0: object, arg1: object, /) object
drjit.minimum(arg0: float, arg1: float, /) float

Compute the element-wise minimum value of the provided inputs.

(Not to be confused with drjit.min(), which reduces the input along the specified axes to determine the minimum)

  • arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type

  • arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type


Minimum of the input(s)

Return type:

int | float | drjit.ArrayBase

drjit.maximum(arg0: int, arg1: int, /) int
drjit.maximum(arg0: object, arg1: object, /) object
drjit.maximum(arg0: float, arg1: float, /) float

Compute the element-wise maximum value of the provided inputs.

(Not to be confused with drjit.max(), which reduces the input along the specified axes to determine the maximum)

  • arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type

  • arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type


Maximum of the input(s)

Return type:

int | float | drjit.ArrayBase

drjit.sqrt(arg: ArrayT, /) ArrayT
drjit.sqrt(arg: float, /) float

Evaluate the square root of the provided input.

This function evaluates the component-wise square root of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.

Negative inputs produce a NaN output value. Consider using the safe_sqrt() function to work around issues where the input might occasionally be negative due to prior round-off errors.

Another noteworthy behavior of the square root function is that it has an infinite derivative at arg=0, which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. The safe_sqrt() function contains a workaround to ensure a finite derivative in this case.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Square root of the input

Return type:

float | drjit.ArrayBase

drjit.cbrt(arg: ArrayT, /) ArrayT
drjit.cbrt(arg: float, /) float

Evaluate the cube root of the provided input.

This function is currently only implemented for real-valued inputs.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Cube root of the input

Return type:

float | drjit.ArrayBase

drjit.rcp(arg: ArrayT, /) ArrayT
drjit.rcp(arg: float, /) float

Evaluate the reciprocal (1 / arg) of the provided input.

When arg is a CUDA single precision array, the operation is implemented slightly approximately—see the documentation of the instruction rcp.approx.ftz.f32 in the NVIDIA PTX manual for details. For full IEEE-754 compatibility, unset drjit.JitFlag.FastMath.

When called with a matrix-, complex- or quaternion-valued array, this function uses the matrix, complex, or quaternion multiplicative inverse to evaluate the reciprocal.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Reciprocal of the input

Return type:

float | drjit.ArrayBase

drjit.rsqrt(arg: ArrayT, /) ArrayT
drjit.rsqrt(arg: float, /) float

Evaluate the reciprocal square root (1 / sqrt(arg)) of the provided input.

This function evaluates the component-wise reciprocal square root of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.

When arg is a CUDA single precision array, the operation is implemented slightly approximately—see the documentation of the instruction rsqrt.approx.ftz.f32 in the NVIDIA PTX manual for details. For full IEEE-754 compatibility, unset drjit.JitFlag.FastMath.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Reciprocal square root of the input

Return type:

float | drjit.ArrayBase

drjit.clip(value, min, max)

Clip the provided input to the given interval.

This function is equivalent to

dr.maximum(dr.minimum(value, max), min)

Clipped input

Return type:

float | drjit.ArrayBase
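The same composition can be tried on plain Python scalars using the built-in min() and max(). The helper clip_reference and its parameter names lower and upper are hypothetical (chosen to avoid shadowing the built-ins); the sketch only mirrors the expression shown above.

```python
def clip_reference(value, lower, upper):
    """Scalar sketch of maximum(minimum(value, upper), lower)."""
    return max(min(value, upper), lower)


clipped = [clip_reference(v, 0, 3) for v in (-1, 2, 5)]
```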

drjit.ceil(arg: ArrayT, /) ArrayT
drjit.ceil(arg: float, /) float

Evaluate the ceiling, i.e. the smallest integer >= arg.

The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Ceiling of the input

Return type:

float | drjit.ArrayBase

drjit.floor(arg: ArrayT, /) ArrayT
drjit.floor(arg: float, /) float

Evaluate the floor, i.e. the largest integer <= arg.

The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Floor of the input

Return type:

float | drjit.ArrayBase

drjit.trunc(arg: ArrayT, /) ArrayT
drjit.trunc(arg: float, /) float

Truncates arg to the nearest integer by rounding towards zero.

The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Truncated result

Return type:

float | drjit.ArrayBase

drjit.round(arg: ArrayT, /) ArrayT
drjit.round(arg: float, /) float

Rounds the input to the nearest integer using Banker’s rounding for half-way values.

This function is equivalent to std::rint in C++. It does not convert the type of the input array. A separate cast is necessary when integer output is desired.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Rounded result

Return type:

float | drjit.ArrayBase
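Banker's rounding resolves half-way values to the nearest even integer. Python's built-in round() follows the same rule for floats, so it can be used to preview the behavior on scalars; the conversion back to float below mimics the fact that the function described above does not change the type. This is an illustrative sketch using only the Python standard library, and the sign of a zero result may differ from std::rint.

```python
import math

values = [0.5, 1.5, 2.5, 3.5, -0.5, -1.5]

# Half-way cases go to the nearest even integer
half_to_even = [float(round(v)) for v in values]

# For comparison: truncation always rounds towards zero
truncated = [float(math.trunc(v)) for v in values]
```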

drjit.sign(arg, /)

Return the element-wise sign of the provided array.

The function returns

\[\mathrm{sign}(\texttt{arg}) = \begin{cases} 1&\texttt{arg}\geq 0,\\ -1&\mathrm{otherwise}. \end{cases}\]

arg (int | float | drjit.ArrayBase) – A Python or Dr.Jit array


Sign of the input array

Return type:

float | int | drjit.ArrayBase

drjit.copysign(arg0, arg1, /)

Copy the sign of arg1 to arg0 element-wise.

  • arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to change the sign of

  • arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to copy the sign from


The values of arg0 with the sign of arg1

Return type:

float | int | drjit.ArrayBase

drjit.mulsign(arg0, arg1, /)

Multiply arg0 by the sign of arg1 element-wise.

This function is equivalent to

arg0 * dr.sign(arg1)

  • arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to be multiplied by the sign of arg1

  • arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to take the sign from


The values of arg0 multiplied with the sign of arg1

Return type:

float | int | drjit.ArrayBase
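The sign-related helpers above can be mirrored on plain Python scalars. The functions sign_reference and mulsign_reference are hypothetical names that follow the formulas given in this section; math.copysign from the standard library plays the role of the sign-copying operation. The sketch only covers ordinary nonzero and positive-zero inputs.

```python
import math


def sign_reference(x):
    """Return 1 for x >= 0 and -1 otherwise, as in the formula above."""
    return 1 if x >= 0 else -1


def mulsign_reference(a, b):
    """Multiply a by the sign of b."""
    return a * sign_reference(b)


signs = [sign_reference(v) for v in (-2.5, 0, 4)]
flipped = mulsign_reference(-3.0, -1.0)   # negative times negative sign
copied = math.copysign(-3.0, 1.0)         # magnitude of arg0, sign of arg1
```

Note the difference between the two operations: multiplying by the sign can flip a negative input to a positive one, whereas copying the sign discards the original sign entirely.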

Operations for vectors and matrices

drjit.cross(arg0: ArrayT, arg1: ArrayT, /) ArrayT

Returns the cross-product of the two input 3D arrays


Cross-product of the two input 3D arrays

Return type:


drjit.det(arg, /)

Compute the determinant of the provided Dr.Jit matrix.


arg (drjit.ArrayBase) – A Dr.Jit matrix type


The determinant value of the input matrix

Return type:


drjit.diag(arg, /)

This function either returns the diagonal entries of the provided Dr.Jit matrix, or it constructs a new matrix from the diagonal entries.


arg (drjit.ArrayBase) – A Dr.Jit matrix type


The diagonal entries of the input matrix, or a diagonal matrix constructed from the provided entries

Return type:


drjit.trace(arg, /)

Returns the trace of the provided Dr.Jit matrix.


arg (drjit.ArrayBase) – A Dr.Jit matrix type


The trace of the input matrix

Return type:


drjit.matmul(arg0: object, arg1: object, /) object

Compute a matrix-matrix, matrix-vector, vector-matrix, or inner product.

This function implements the semantics of the @ operator introduced in Python’s PEP 465. There is no practical difference between using drjit.matmul() or @ in Dr.Jit-based code. Multiplication of matrix types (e.g., drjit.scalar.Matrix2f) using the standard multiplication operator (*) is also based on matrix multiplication.

This function takes two Dr.Jit arrays and picks one of the following 5 cases based on their leading fixed-size dimensions.

  • Matrix-matrix product: If both arrays have leading static dimensions (n, n), they are multiplied like conventional matrices.

  • Matrix-vector product: If arg0 has leading static dimensions (n, n) and arg1 has leading static dimension (n,), the operation conceptually appends a trailing 1-sized dimension to arg1, multiplies, and then removes the extra dimension from the result.

  • Vector-matrix product: If arg0 has leading static dimensions (n,) and arg1 has leading static dimension (n, n), the operation conceptually prepends a leading 1-sized dimension to arg0, multiplies, and then removes the extra dimension from the result.

  • Inner product: If arg0 and arg1 have leading static dimensions (n,), the operation returns the sum of the elements of arg0*arg1.

  • Scalar product: If arg0 or arg1 is a scalar, the operation scales the elements of the other argument.

It is legal to combine vectorized and non-vectorized types, e.g.

dr.matmul(dr.scalar.Matrix4f(...), dr.cuda.Matrix4f(...))

Also, note that it doesn’t matter whether an input is an instance of a matrix type or a similarly-shaped nested array; for example, drjit.scalar.Matrix3f() and drjit.scalar.Array33f() have the same shape and are treated identically.
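
The five cases listed above can be sketched in plain Python with nested lists standing in for fixed-size Dr.Jit arrays. The names toy_matmul, _rank and _scale are hypothetical helpers introduced for this illustration; the real function dispatches on the static dimensions of Dr.Jit types rather than on list nesting.

```python
def _rank(x):
    """Number of nested list levels: 0 = scalar, 1 = vector, 2 = matrix."""
    r = 0
    while isinstance(x, (list, tuple)):
        x = x[0]
        r += 1
    return r


def _scale(s, x):
    """Multiply every entry of a (possibly nested) list by the scalar s."""
    if isinstance(x, (list, tuple)):
        return [_scale(s, v) for v in x]
    return s * x


def toy_matmul(a, b):
    ra, rb = _rank(a), _rank(b)
    if ra == 0:                      # scalar product
        return _scale(a, b)
    if rb == 0:                      # scalar product
        return _scale(b, a)
    if ra == 1 and rb == 1:          # inner product
        return sum(x * y for x, y in zip(a, b))
    if ra == 2 and rb == 1:          # matrix-vector product
        return [sum(a[i][k] * b[k] for k in range(len(b))) for i in range(len(a))]
    if ra == 1 and rb == 2:          # vector-matrix product
        return [sum(a[k] * b[k][j] for k in range(len(a))) for j in range(len(b[0]))]
    n = len(a)                       # matrix-matrix product
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


M = [[1, 2], [3, 4]]
v = [1, 1]
print(toy_matmul(M, v), toy_matmul(v, M), toy_matmul(v, v))
```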


This operation only handles fixed-size arrays. A different approach is needed for multiplications involving potentially large dynamic arrays/tensors. Other tools like PyTorch, JAX, or TensorFlow will be preferable in such situations (e.g., to train neural networks).

  • arg0 (dr.ArrayBase) – Dr.Jit array type

  • arg1 (dr.ArrayBase) – Dr.Jit array type


The result of the operation as defined above

Return type:


drjit.hypot(a, b)

Computes \(\sqrt{a^2+b^2}\) while avoiding overflow and underflow.


a, b (float | drjit.ArrayBase) – Python or Dr.Jit arithmetic types


The computed hypotenuse.

Return type:

float | drjit.ArrayBase
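
To see why overflow and underflow matter for drjit.hypot(), the following plain-Python sketch compares the naive formula with a rescaled variant. The rescaling shown here is one common technique and is an assumption of this example; the source does not state which method Dr.Jit uses. The name hypot_scaled is hypothetical.

```python
import math


def hypot_scaled(a, b):
    """sqrt(a^2 + b^2) computed after dividing by the larger magnitude."""
    a, b = abs(a), abs(b)
    hi, lo = max(a, b), min(a, b)
    if hi == 0.0:
        return 0.0
    r = lo / hi                      # r is in [0, 1], so r*r cannot overflow
    return hi * math.sqrt(1.0 + r * r)


big_a, big_b = 3e200, 4e200
# The squares exceed the double-precision range, so the naive result is infinite
naive = math.sqrt(big_a * big_a + big_b * big_b)
print(naive, hypot_scaled(big_a, big_b))
```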

drjit.normalize(arg: T, /) T

Normalize the input vector so that it has unit length and return the result.

This operation is equivalent to

arg * dr.rsqrt(dr.squared_norm(arg))

arg (drjit.ArrayBase) – A Dr.Jit array type


Unit-norm version of the input

Return type:
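
The equivalence stated above can be mirrored in plain Python, with a list standing in for a Dr.Jit vector. The helpers squared_norm, rsqrt and toy_normalize are local stand-ins written for this sketch, not the Dr.Jit functions of the same name.

```python
import math


def squared_norm(v):
    """Sum of squared components."""
    return sum(x * x for x in v)


def rsqrt(x):
    """Reciprocal square root."""
    return 1.0 / math.sqrt(x)


def toy_normalize(v):
    # arg * rsqrt(squared_norm(arg)), applied component-wise
    s = rsqrt(squared_norm(v))
    return [x * s for x in v]


n = toy_normalize([3.0, 4.0])
print(n, squared_norm(n))
```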


drjit.lerp(a, b, t)

Linearly blend between two values.

This function computes

\[\mathrm{lerp}(t) = (1-t) a + t b\]

In other words, it linearly blends between \(a\) and \(b\) based on the value \(t\) that is typically on the interval \([0, 1]\).

It does so using two fused multiply-additions (drjit.fma()) to improve performance and avoid numerical errors.


Interpolated result

Return type:

float | drjit.ArrayBase
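
One possible way to arrange \((1-t) a + t b\) as two nested multiply-adds is sketched below. The helper fma here is an ordinary (unfused) Python stand-in, and the specific nesting is an assumption of this example rather than a statement about Dr.Jit's source code.

```python
def fma(x, y, z):
    """Stand-in for a fused multiply-add: x*y + z with ordinary rounding."""
    return x * y + z


def toy_lerp(a, b, t):
    # (1 - t)*a + t*b  ==  b*t + (a - a*t), written as two nested multiply-adds
    return fma(b, t, fma(-a, t, a))


print(toy_lerp(2.0, 10.0, 0.25))
```

The endpoints are reproduced exactly: t = 0 yields a and t = 1 yields b.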

drjit.sh_eval(d: ArrayBase, order: int) list

Evaluate real spherical harmonics basis functions up to a specified order.

The input d must be a normalized 3D Cartesian coordinate vector. The function returns a list containing all spherical harmonic basis functions evaluated with respect to d up to the desired order, for a total of (order+1)**2 output values.

The implementation relies on efficient pre-generated branch-free code with aggressive constant folding and common subexpression elimination. It admits scalar and Jit-compiled input arrays. Evaluation routines are included for orders 0 to 10. Requesting higher orders triggers a runtime exception.

This automatically generated code is based on the paper Efficient Spherical Harmonic Evaluation, Journal of Computer Graphics Techniques (JCGT), vol. 2, no. 2, 84-90, 2013 by Peter-Pike Sloan.

The SciPy equivalent of this function is given by

import numpy as np

def sh_eval(d, order: int):
    from scipy.special import sph_harm
    # Convert the direction to spherical coordinates
    theta, phi = np.arccos(d.z), np.arctan2(d.y, d.x)
    r = []
    for l in range(order + 1):
        for m in range(-l, l + 1):
            Y = sph_harm(abs(m), l, phi, theta)
            if m > 0:
                Y = np.sqrt(2) * Y.real
            elif m < 0:
                Y = np.sqrt(2) * Y.imag
            # Collect the real-valued basis function for (l, m)
            r.append(Y.real)
    return r

The Mathematica equivalent of a specific entry is given by:

SphericalHarmonicQ[l_, m_, d_] := Block[{θ, ϕ},
  θ = ArcCos[d[[3]]];
  ϕ = ArcTan[d[[1]], d[[2]]];
  Piecewise[{
    {SphericalHarmonicY[l, m, θ, ϕ], m == 0},
    {Sqrt[2] * Re[SphericalHarmonicY[l,  m, θ, ϕ]], m > 0},
    {Sqrt[2] * Im[SphericalHarmonicY[l, -m, θ, ϕ]], m < 0}
  }]
]
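
Assuming the returned list follows the same loop order as the SciPy equivalent above (l ascending, then m from -l to l), the position of a given (l, m) pair in the output can be computed directly. The helper names sh_pairs and sh_index are hypothetical and only illustrate the bookkeeping.

```python
def sh_pairs(order):
    """All (l, m) pairs in the order used by the nested loops."""
    return [(l, m) for l in range(order + 1) for m in range(-l, l + 1)]


def sh_index(l, m):
    """Flat list index of the basis function with degree l and order m."""
    return l * (l + 1) + m


pairs = sh_pairs(3)
print(len(pairs))       # number of basis functions for order 3
print(sh_index(2, -1))  # position of (l=2, m=-1) in the list
```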

Operations for complex values and quaternions

drjit.conj(arg, /)

Returns the conjugate of the provided complex or quaternion-valued array. For all other types, it returns the input unchanged.


arg (drjit.ArrayBase) – A Dr.Jit 3D array


Conjugate form of the input

Return type:


drjit.arg(z, /)

Return the argument of a complex Dr.Jit array.

The argument refers to the angle (in radians) between the positive real axis and a vector towards z in the complex plane. When the input isn’t complex-valued, the function returns \(0\) or \(\pi\) depending on the sign of z.


z (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array


Argument of the complex input array

Return type:

float | drjit.ArrayBase

drjit.real(arg, /)

Return the real part of a complex or quaternion-valued input.

When the input isn’t complex- or quaternion-valued, the function returns the input unchanged.


arg (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array


Real part of the input array

Return type:

float | drjit.ArrayBase

drjit.imag(arg, /)

Return the imaginary part of a complex or quaternion-valued input.

When the input isn’t complex- or quaternion-valued, the function returns zero.


arg (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array


Imaginary part of the input array

Return type:

float | drjit.ArrayBase
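
The conventions described for drjit.conj(), drjit.arg(), drjit.real() and drjit.imag() match those of Python's built-in complex type, which makes for a convenient scalar sketch. The toy_* wrappers below are hypothetical names and operate on Python numbers only, not on Dr.Jit arrays.

```python
import cmath
import math


def toy_conj(z):
    return z.conjugate()   # real inputs are returned unchanged


def toy_arg(z):
    return cmath.phase(z)  # angle to the positive real axis, in radians


def toy_real(z):
    return z.real


def toy_imag(z):
    return z.imag          # zero for real inputs


z = complex(1.0, 1.0)
print(toy_conj(z), toy_arg(z), toy_real(z), toy_imag(z))
print(toy_arg(3.0), toy_arg(-2.0))  # real inputs: 0 or pi depending on the sign
```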

Transcendental functions

Dr.Jit implements the most common transcendental functions using methods that are based on the CEPHES math library. The accuracy of these approximations is documented in a set of tables below.

Trigonometric functions

drjit.sin(arg: ArrayT, /) ArrayT
drjit.sin(arg: float, /) float

Evaluate the sine function.

This function evaluates the component-wise sine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.

The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.

When arg is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Sine of the input

Return type:

float | drjit.ArrayBase

drjit.cos(arg: ArrayT, /) ArrayT
drjit.cos(arg: float, /) float

Evaluate the cosine function.

This function evaluates the component-wise cosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.

The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.

When arg is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Cosine of the input

Return type:

float | drjit.ArrayBase

drjit.sincos(arg: ArrayT, /) tuple[ArrayT, ArrayT]
drjit.sincos(arg: float, /) tuple[float, float]

Evaluate both sine and cosine functions at the same time.

This function simultaneously evaluates the component-wise sine and cosine of the input scalar, array, or tensor. This is more efficient than two separate calls to drjit.sin() and drjit.cos() when both are required. The function uses a suitable generalization of the operation when the input is complex-valued.

The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.

When arg is a CUDA single precision array, the operation instead uses the hardware’s built-in multi-function (“MUFU”) unit.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Sine and cosine of the input

Return type:

(float, float) | (drjit.ArrayBase, drjit.ArrayBase)

drjit.tan(arg: ArrayT, /) ArrayT
drjit.tan(arg: float, /) float

Evaluate the tangent function.

This function evaluates the component-wise tangent function associated with each entry of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.

The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.

When arg is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Tangent of the input

Return type:

float | drjit.ArrayBase

drjit.asin(arg: ArrayT, /) ArrayT
drjit.asin(arg: float, /) float

Evaluate the arcsine function.

This function evaluates the component-wise arcsine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when called with a complex-valued input.

The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.

Real-valued inputs outside of the domain \([-1, 1]\) produce a NaN output value. Consider using the safe_asin() function to work around issues where the input might occasionally lie outside of this range due to prior round-off errors.

Another noteworthy behavior of the arcsine function is that it has an infinite derivative at \(\texttt{arg}=\pm 1\), which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. The safe_asin() function contains a workaround to ensure a finite derivative in this case.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Arcsine of the input

Return type:

float | drjit.ArrayBase

drjit.acos(arg: ArrayT, /) ArrayT
drjit.acos(arg: float, /) float

Evaluate the arccosine function.

This function evaluates the component-wise arccosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.

The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.

Real-valued inputs outside of the domain \([-1, 1]\) produce a NaN output value. Consider using the safe_acos() function to work around issues where the input might occasionally lie outside of this range due to prior round-off errors.

Another noteworthy behavior of the arccosine function is that it has an infinite derivative at \(\texttt{arg}=\pm 1\), which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. The safe_acos() function contains a workaround to ensure a finite derivative in this case.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Arccosine of the input

Return type:

float | drjit.ArrayBase

drjit.atan(arg: ArrayT, /) ArrayT
drjit.atan(arg: float, /) float

Evaluate the arctangent function.

This function evaluates the component-wise arctangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.

The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Arctangent of the input

Return type:

float | drjit.ArrayBase

drjit.atan2(arg0: object, arg1: object, /) object
drjit.atan2(arg0: float, arg1: float, /) float

Evaluate the four-quadrant arctangent function.

This function is currently only implemented for real-valued inputs.

See the section on transcendental function approximations for details regarding accuracy.


Arctangent of y/x (where y is arg0 and x is arg1), using the argument signs to determine the quadrant of the return value

Return type:

float | drjit.ArrayBase
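
The difference between the four-quadrant arctangent and a plain arctangent of the quotient is easiest to see on a point in the second quadrant. The sketch below uses Python's math module on scalars as a stand-in; it assumes the usual (y, x) argument order.

```python
import math

y, x = 1.0, -1.0                 # point in the second quadrant

four_quadrant = math.atan2(y, x) # uses the signs of both arguments
quotient_only = math.atan(y / x) # the sign information is lost in y / x

print(four_quadrant, quotient_only)
```

The quotient y / x is identical for (1, -1) and (-1, 1), so the single-argument arctangent cannot distinguish the two directions.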

Hyperbolic functions

drjit.sinh(arg: ArrayT, /) ArrayT
drjit.sinh(arg: float, /) float

Evaluate the hyperbolic sine function.

This function evaluates the component-wise hyperbolic sine of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.

The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Hyperbolic sine of the input

Return type:

float | drjit.ArrayBase

drjit.cosh(arg: ArrayT, /) ArrayT
drjit.cosh(arg: float, /) float

Evaluate the hyperbolic cosine function.

This function evaluates the component-wise hyperbolic cosine of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.

The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Hyperbolic cosine of the input

Return type:

float | drjit.ArrayBase

drjit.sincosh(arg: ArrayT, /) tuple[ArrayT, ArrayT]
drjit.sincosh(arg: float, /) tuple[float, float]

Evaluate both hyperbolic sine and cosine functions at the same time.

This function simultaneously evaluates the component-wise hyperbolic sine and cosine of the input scalar, array, or tensor. This is more efficient than two separate calls to drjit.sinh() and drjit.cosh() when both are required. The function uses a suitable generalization of the operation when the input is complex-valued.

The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Hyperbolic sine and cosine of the input

Return type:

(float, float) | (drjit.ArrayBase, drjit.ArrayBase)

drjit.tanh(arg: ArrayT, /) ArrayT
drjit.tanh(arg: float, /) float

Evaluate the hyperbolic tangent function.

This function evaluates the component-wise hyperbolic tangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.

The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Hyperbolic tangent of the input

Return type:

float | drjit.ArrayBase

drjit.asinh(arg: ArrayT, /) ArrayT
drjit.asinh(arg: float, /) float

Evaluate the hyperbolic arcsine function.

This function evaluates the component-wise hyperbolic arcsine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.

The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Hyperbolic arcsine of the input

Return type:

float | drjit.ArrayBase

drjit.acosh(arg: ArrayT, /) ArrayT
drjit.acosh(arg: float, /) float

Evaluate the hyperbolic arccosine function.

This function evaluates the component-wise hyperbolic arccosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.

The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Hyperbolic arccosine of the input

Return type:

float | drjit.ArrayBase

drjit.atanh(arg: ArrayT, /) ArrayT
drjit.atanh(arg: float, /) float

Evaluate the hyperbolic arctangent function.

This function evaluates the component-wise hyperbolic arctangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.

The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Hyperbolic arctangent of the input

Return type:

float | drjit.ArrayBase

Exponentials, logarithms, power function

drjit.log2(arg: ArrayT, /) ArrayT
drjit.log2(arg: float, /) float

Evaluate the base-2 logarithm.

This function evaluates the component-wise base-2 logarithm of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.

See the section on transcendental function approximations for details regarding accuracy.

When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Base-2 logarithm of the input

Return type:

float | drjit.ArrayBase

drjit.log(arg: ArrayT, /) ArrayT
drjit.log(arg: float, /) float

Evaluate the natural logarithm.

This function evaluates the component-wise natural logarithm of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.

See the section on transcendental function approximations for details regarding accuracy.

When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Natural logarithm of the input

Return type:

float | drjit.ArrayBase

drjit.exp2(arg: ArrayT, /) ArrayT
drjit.exp2(arg: float, /) float

Evaluate 2 raised to a given power.

This function evaluates the component-wise base-2 exponential function of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.

See the section on transcendental function approximations for details regarding accuracy.

When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Base-2 exponential of the input

Return type:

float | drjit.ArrayBase

drjit.exp(arg: ArrayT, /) ArrayT
drjit.exp(arg: float, /) float

Evaluate the natural exponential function.

This function evaluates the component-wise natural exponential function of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.

See the section on transcendental function approximations for details regarding accuracy.

When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Natural exponential of the input

Return type:

float | drjit.ArrayBase

drjit.power(arg0: int, arg1: int, /) float
drjit.power(arg0: float, arg1: float, /) float
drjit.power(arg0: object, arg1: object, /) object

Raise the first argument to a power specified via the second argument.

This function evaluates the component-wise power of the input scalar, array, or tensor arguments. When called with complex- or quaternion-valued inputs, it uses a suitable generalization of the operation.

When arg1 is a Python int or integral float value, the function reduces the operation to a sequence of multiplies and adds (potentially followed by a reciprocation operation when arg1 is negative).

The general case involves recursive use of the identity pow(arg0, arg1) = exp2(log2(arg0) * arg1).

There is no difference between using drjit.power() and the builtin Python ** operator.


arg0, arg1 (object) – Python or Dr.Jit arithmetic types


The result of the operation arg0**arg1

Return type:
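
The two evaluation strategies described above can be sketched for Python scalars as follows. Exponentiation by squaring is an assumed way of reducing integer powers to multiplications; the source does not specify the exact multiplication sequence. The names power_int and power_general are hypothetical.

```python
import math


def power_int(base, exponent):
    """Integer power via repeated squaring, plus a reciprocal for negative exponents."""
    n = abs(exponent)
    result = 1
    factor = base
    while n:
        if n & 1:
            result = result * factor
        factor = factor * factor
        n >>= 1
    return 1 / result if exponent < 0 else result


def power_general(base, exponent):
    """General case for positive real bases: pow(x, y) = exp2(log2(x) * y)."""
    return 2.0 ** (math.log2(base) * exponent)


print(power_int(3, 5), power_int(2.0, -3), power_general(2.0, 0.5))
```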



drjit.erf(arg: ArrayT, /) ArrayT
drjit.erf(arg: float, /) float

Evaluate the error function.

The error function <https://en.wikipedia.org/wiki/Error_function> is defined as

\[\operatorname{erf}(z) = \frac{2}{\sqrt\pi}\int_0^z e^{-t^2}\,\mathrm{d}t.\]

See the section on transcendental function approximations for details regarding accuracy.

This function is currently only implemented for real-valued inputs.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type



Return type:

float | drjit.ArrayBase

drjit.erfinv(arg: ArrayT, /) ArrayT
drjit.erfinv(arg: float, /) float

Evaluate the inverse error function.

This function evaluates the inverse of drjit.erf(). Its implementation is based on the paper Approximating the erfinv function by Mike Giles.

This function is currently only implemented for real-valued inputs.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type



Return type:

float | drjit.ArrayBase
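
The relationship between drjit.erf() and drjit.erfinv() can be checked on scalars with a simple Newton iteration on Python's math.erf. This is not the approximation from the paper cited above; it is only a slow reference sketch, and the name erfinv_newton is hypothetical.

```python
import math


def erfinv_newton(y, tol=1e-15, max_iter=100):
    """Solve erf(x) = y for x using Newton's method (scalar reference version)."""
    if not -1.0 < y < 1.0:
        raise ValueError("erfinv is only finite on the open interval (-1, 1)")
    sign = -1.0 if y < 0 else 1.0    # erf is odd, so solve for |y| and restore the sign
    y = abs(y)
    x = 0.0
    c = 2.0 / math.sqrt(math.pi)     # erf'(x) = c * exp(-x^2)
    for _ in range(max_iter):
        step = (math.erf(x) - y) / (c * math.exp(-x * x))
        x -= step
        if abs(step) <= tol * max(1.0, abs(x)):
            break
    return sign * x


for y in (-0.9, -0.3, 0.0, 0.5, 0.9):
    x = erfinv_newton(y)
    print(y, x, math.erf(x))
```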

drjit.lgamma(arg: ArrayT, /) ArrayT
drjit.lgamma(arg: float, /) float

Evaluate the natural logarithm of the absolute value of the gamma function.

The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.

This function is currently only implemented for real-valued inputs.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type



Return type:

float | drjit.ArrayBase

drjit.rad2deg(arg: T, /) T

Convert angles from radians to degrees.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


The equivalent angle in degrees.

Return type:

float | drjit.ArrayBase

drjit.deg2rad(arg: T, /) T

Convert angles from degrees to radians.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


The equivalent angle in radians.

Return type:

float | drjit.ArrayBase

Safe mathematical functions

Dr.Jit provides “safe” variants of a few standard mathematical operations that are prone to out-of-domain errors in calculations with floating point rounding errors. Such errors could, e.g., cause the argument of a square root to become negative, which would ordinarily require complex arithmetic. At zero, the derivative of the square root function is infinite. The following operations clamp the input to a safe range to avoid these extremes.

drjit.safe_sqrt(arg: T, /) T

Safely evaluate the square root of the provided input avoiding domain errors.

Negative inputs produce zero-valued output. When differentiated via AD, this function also avoids generating infinite derivatives at x=0.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Square root of the input

Return type:

float | drjit.ArrayBase

drjit.safe_asin(arg: T, /) T

Safe wrapper around drjit.asin() that avoids domain errors.

Input values are clipped to the \([-1, 1]\) domain. When differentiated via AD, this function also avoids generating infinite derivatives at the boundaries of the domain.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Arcsine approximation

Return type:

float | drjit.ArrayBase

drjit.safe_acos(arg: T, /) T

Safe wrapper around drjit.acos() that avoids domain errors.

Input values are clipped to the \([-1, 1]\) domain. When differentiated via AD, this function also avoids generating infinite derivatives at the boundaries of the domain.


arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type


Arccosine approximation

Return type:

float | drjit.ArrayBase
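
The clamping behavior of the safe variants can be sketched for Python scalars as shown below. The sketch covers only the primal (value) computation; the derivative workaround mentioned above is not modeled. Note that Python's math.asin and math.acos raise ValueError for out-of-domain inputs instead of returning NaN. The toy_* names are hypothetical.

```python
import math


def clamp(x, lo, hi):
    return max(lo, min(x, hi))


def toy_safe_sqrt(x):
    return math.sqrt(max(x, 0.0))            # negative inputs produce zero


def toy_safe_asin(x):
    return math.asin(clamp(x, -1.0, 1.0))


def toy_safe_acos(x):
    return math.acos(clamp(x, -1.0, 1.0))


# Inputs that drifted slightly out of the valid domain due to round-off
print(toy_safe_sqrt(-1e-9), toy_safe_asin(1.0 + 1e-7), toy_safe_acos(-1.0 - 1e-7))
```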

Automatic differentiation

enum drjit.ADMode(value)

Enumeration to distinguish different types of primal/derivative computation.

See also drjit.enqueue(), drjit.traverse().

Valid values are as follows:

Primal = ADMode.Primal

Primal/original computation without derivative tracking. Note that this is not a valid input to Dr.Jit AD routines, but it is sometimes useful to have this entry to indicate to a computation that derivative propagation should not be performed.

Forward = ADMode.Forward

Propagate derivatives in forward mode (from inputs to outputs)

Backward = ADMode.Backward

Propagate derivatives in backward/reverse mode (from outputs to inputs)

enum drjit.ADFlag(value)

By default, Dr.Jit’s AD system destructs the enqueued input graph during forward/backward mode traversal. This frees up resources, which is useful when working with large wavefronts or very complex computation graphs. However, this also prevents repeated propagation of gradients through a shared subgraph that is being differentiated multiple times.

To support more fine-grained use cases that require this, the following flags can be used to control what should and should not be destructed.

Member Type:


Valid values are as follows:

ClearNone = ADFlag.ClearNone

Clear nothing.

ClearEdges = ADFlag.ClearEdges

Delete all traversed edges from the computation graph

ClearInput = ADFlag.ClearInput

Clear the gradients of processed input vertices (in-degree == 0)

ClearInterior = ADFlag.ClearInterior

Clear the gradients of processed interior vertices (out-degree != 0)

ClearVertices = ADFlag.ClearVertices

Clear gradients of processed vertices only, but leave edges intact. Equal to ClearInput | ClearInterior.

AllowNoGrad = ADFlag.AllowNoGrad

Don’t fail when the input to a drjit.forward or backward operation is not a differentiable array.

Default = ADFlag.Default

Default: clear everything (edges, gradients of processed vertices). Equal to ClearEdges | ClearVertices.
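
The relationships between the flags can be modeled with a standard-library IntFlag. The class name ToyADFlag and all numeric values below are assumptions made for this sketch; only the "Equal to" relations are taken from the descriptions above.

```python
from enum import IntFlag


class ToyADFlag(IntFlag):
    # Numeric values are illustrative assumptions, not taken from Dr.Jit
    ClearNone = 0
    ClearEdges = 1
    ClearInput = 2
    ClearInterior = 4
    AllowNoGrad = 8
    ClearVertices = ClearInput | ClearInterior
    Default = ClearEdges | ClearVertices


# Keep the edges so that the same graph can be traversed again,
# but still clear the gradients of processed vertices
keep_edges = ToyADFlag.Default & ~ToyADFlag.ClearEdges
print(keep_edges == ToyADFlag.ClearVertices)
```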

drjit.detach(arg: T, preserve_type: bool = True) T

Transforms the input variable into its non-differentiable version (detaches it from the AD computational graph).

This function supports arbitrary Dr.Jit arrays/tensors and PyTrees as input. In the latter case, it applies the transformation recursively. When the input variable is not a PyTree or Dr.Jit array, it is returned as it is.

While the type of the returned array is preserved by default, it is possible to set the preserve_type argument to False to force the returned type to be non-differentiable. For example, this will convert an array of type drjit.llvm.ad.Float into one of type drjit.llvm.Float.

  • arg (object) – An arbitrary Dr.Jit array, tensor, or PyTree.

  • preserve_type (bool) – Defines whether the returned variable should preserve the type of the input variable.


The detached variable.

Return type:


drjit.enable_grad(arg: object, /) None
drjit.enable_grad(*args) None

Enable gradient tracking for the provided variables.

This function accepts a variable number of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).

During this recursive traversal, the function enables gradient tracking for all encountered Dr.Jit arrays. Variables of other types are ignored.


*args (tuple) – A variable-length list of Dr.Jit arrays/tensors or PyTrees.

drjit.disable_grad(arg: object, /) None
drjit.disable_grad(*args) None

Disable gradient tracking for the provided variables.

This function accepts a variable number of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).

During this recursive traversal, the function disables gradient tracking for all encountered Dr.Jit arrays. Variables of other types are ignored.


*args (tuple) – A variable-length list of Dr.Jit arrays/tensors or PyTrees.

drjit.set_grad_enabled(arg0: object, arg1: bool, /) None

Enable or disable gradient tracking on the provided variables.

  • arg (object) – An arbitrary Dr.Jit array, tensor, PyTree, sequence, or mapping.

  • value (bool) – Defines whether gradient tracking should be enabled or disabled.

drjit.grad_enabled(arg: object, /) bool
drjit.grad_enabled(*args) bool

Return whether gradient tracking is enabled on any of the given variables.


*args (tuple) – A variable-length list of Dr.Jit array/tensor instances or PyTrees. The function recursively traverses them to find all differentiable variables.


True if any of the input variables has gradient tracking enabled, False otherwise.

Return type:


drjit.grad(arg: T, preserve_type: bool = True) T

Return the gradient value associated to a given variable.

When the variable doesn’t have gradient tracking enabled, this function returns 0.

  • arg (object) – An arbitrary Dr.Jit array, tensor or PyTree.

  • preserve_type (bool) – Should the operation preserve the input type in the return value? (This is the default). Otherwise, Dr.Jit will, e.g., return a type of drjit.cuda.Float for an input of type drjit.cuda.ad.Float.


The gradient value associated to the input variable.

Return type:


drjit.set_grad(target: T, source: T) None

Set the gradient associated with the provided variable.

This operation internally decomposes into two sub-steps:

dr.clear_grad(target)
dr.accum_grad(target, source)

When source is not of the same type as target, Dr.Jit will try to broadcast its contents into the right shape.

  • target (object) – An arbitrary Dr.Jit array, tensor, or PyTree.

  • source (object) – An arbitrary Dr.Jit array, tensor, or PyTree.

drjit.accum_grad(target: T, source: T) None

Accumulate the contents of one variable into the gradient of another variable.

When source is not of the same type as target, Dr.Jit will try to broadcast its contents into the right shape.

  • target (object) – An arbitrary Dr.Jit array, tensor, or PyTree.

  • source (object) – An arbitrary Dr.Jit array, tensor, or PyTree.

drjit.replace_grad(arg0: T, arg1: T, /) T

Replace the gradient value of arg0 with the one of arg1.

This is a relatively specialized operation to be used with care when implementing advanced automatic differentiation-related features.

One example use would be to inform Dr.Jit that there is a better way to compute the gradient of a particular expression than what the normal AD traversal of the primal computation graph would yield.

The function promotes and broadcasts arg0 and arg1 if they are not of the same type.

  • arg0 (object) – An arbitrary Dr.Jit array, tensor, Python arithmetic type, or PyTree.

  • arg1 (object) – An arbitrary Dr.Jit array, tensor, or PyTree.


A new Dr.Jit array combining the primal value of arg0 and the derivative of arg1.

Return type:


drjit.clear_grad(arg: object, /) None

Clear the gradient of the given variable.


arg (object) – An arbitrary Dr.Jit array, tensor, or PyTree.

drjit.traverse(mode: drjit.ADMode, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None

Propagate gradients along the enqueued set of AD graph edges.

Given prior use of drjit.enqueue() to enqueue AD nodes for gradient propagation, this function performs the actual gradient propagation in either the forward or reverse direction (as specified by the mode parameter).

By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. The term leaf node is defined as follows:

  • In forward AD, leaf nodes have no forward edges. They are outputs of a computation, and no other differentiable variable depends on them.

  • In backward AD, leaf nodes have no backward edges. They are inputs to a computation.

By default, the traversal also removes the edges of visited nodes to isolate them. These defaults are usually good ones: cleaning up the graph frees up resources, which is useful when working with large wavefronts or very complex computation graphs. It also avoids potentially undesired derivative contributions that can arise when the AD graphs of two unrelated computations are connected by an edge and subsequently separately differentiated.

In advanced applications that require multiple AD traversals of the same graph, specify different combinations of the enumeration drjit.ADFlag via the flags parameter.

  • mode (drjit.ADMode) – Specifies the direction in which gradients should be propagated. drjit.ADMode.Forward and drjit.ADMode.Backward refer to forward and backward traversal.

  • flags (drjit.ADFlag | int) – Controls what parts of the AD graph are cleared during traversal. The default value is drjit.ADFlag.Default.

drjit.enqueue(mode: drjit.ADMode, arg: object) None
drjit.enqueue(mode: drjit.ADMode, *args) None

Enqueues the input variable(s) for subsequent gradient propagation.

Dr.Jit splits the process of automatic differentiation into three parts:

  1. Initializing the gradient of one or more input or output variables. The most common initialization entails setting the gradient of an output (e.g., an optimization loss) to 1.0.

  2. Enqueuing nodes that should partake in the gradient propagation pass. Dr.Jit will follow variable dependences (edges in the AD graph) to find variables that are reachable from the enqueued variable.

  3. Finally propagating gradients to all of the enqueued variables.

This function is responsible for step 2 of the above list and works differently depending on the specified mode:

  • drjit.ADMode.Forward: Dr.Jit will recursively enqueue all variables that are reachable along forward edges. That is, given a differentiable operation a = b+c, enqueuing c will also enqueue a for later traversal.

  • drjit.ADMode.Backward: Dr.Jit will recursively enqueue all variables that are reachable along backward edges. That is, given a differentiable operation a = b+c, enqueuing a will also enqueue b and c for later traversal.
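The reachability rule behind both modes can be sketched in plain Python. This is a toy model under stated assumptions, not Dr.Jit's implementation: the dictionary-based graph and the helper reachable() are hypothetical names introduced only for illustration.

```python
# Toy AD graph for the operation a = b + c.
# Forward edges point from an input to the value computed from it.
forward_edges = {"b": ["a"], "c": ["a"], "a": []}

# Backward edges are the forward edges reversed.
backward_edges = {node: [] for node in forward_edges}
for src, dsts in forward_edges.items():
    for dst in dsts:
        backward_edges[dst].append(src)


def reachable(start, edges):
    """Collect every node reachable from 'start' by following 'edges'."""
    seen, stack = set(), [start]
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(edges[node])
    return seen


# Forward mode: enqueuing 'c' also enqueues 'a'.
enqueued_forward = reachable("c", forward_edges)

# Backward mode: enqueuing 'a' also enqueues 'b' and 'c'.
enqueued_backward = reachable("a", backward_edges)
```

The enqueued variable itself is part of the result in both cases, which mirrors the description above.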

For example, a typical chain of operations to forward propagate the gradients from a to b might look as follows:

a = dr.llvm.ad.Float(1.0)
dr.enable_grad(a)  # 'a' must track gradients before it is used
b = f(a) # some computation involving 'a'

# The below three operations can also be written more compactly as dr.forward_from(a)
dr.set_grad(a, 1.0)
dr.enqueue(dr.ADMode.Forward, a)
dr.traverse(dr.ADMode.Forward)

grad = dr.grad(b)

One interesting aspect of this design is that enqueuing and traversal don’t necessarily need to follow the same direction.

For example, we may only be interested in forward gradients reaching a specific output node c, which can be expressed as follows:

a = dr.llvm.ad.Float(1.0)
dr.enable_grad(a)

b, c, d, e = f(a)

dr.set_grad(a, 1.0)
dr.enqueue(dr.ADMode.Backward, c)
dr.traverse(dr.ADMode.Forward)

grad = dr.grad(c)
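To see why mixing directions restricts the work, consider the following plain-Python sketch. It is only an illustration of the idea and not how Dr.Jit is implemented: the graph, its edge weights (made-up local derivatives), and the helpers enqueue_backward() and traverse_forward() are all hypothetical.

```python
# Each edge is (source, destination, local derivative of destination w.r.t. source).
edges = [
    ("a", "t", 2.0),
    ("t", "b", 3.0),
    ("t", "c", 1.0),
    ("a", "c", 1.0),
    ("a", "d", 5.0),
]
topological_order = ["a", "t", "b", "c", "d"]


def enqueue_backward(target):
    """Find all nodes from which 'target' can be reached."""
    seen, stack = set(), [target]
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(src for src, dst, _ in edges if dst == node)
    return seen


def traverse_forward(enqueued, seeds):
    """Propagate tangents forward, visiting only enqueued nodes."""
    grad = dict(seeds)
    for node in topological_order:
        if node not in enqueued:
            continue
        for src, dst, weight in edges:
            if src == node and dst in enqueued:
                grad[dst] = grad.get(dst, 0.0) + weight * grad.get(node, 0.0)
    return grad


enqueued = enqueue_backward("c")               # nodes on some path to 'c'
grad = traverse_forward(enqueued, {"a": 1.0})  # forward pass limited to them
```

Here the outputs b and d are never visited, so no work is spent on derivatives that were not requested.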

The same naturally also works in the reverse direction. Dr.Jit provides a higher-level API that encapsulates such logic in a few different flavors: drjit.forward_from(), drjit.forward_to(), drjit.backward_from(), and drjit.backward_to().

  • mode (drjit.ADMode) – Specifies the set of edges that Dr.Jit should follow to enqueue variables to be visited by a later gradient propagation phase. drjit.ADMode.Forward and drjit.ADMode.Backward refer to forward and backward edges, respectively.

  • arg (object) – An arbitrary Dr.Jit array, tensor or PyTree.

drjit.forward_from(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None

Forward-propagate gradients from the provided Dr.Jit array or tensor.

This function sets the gradient of the provided Dr.Jit array or tensor arg to 1.0 and then forward-propagates derivatives through forward-connected components of the computation graph (i.e., reaching all variables that directly or indirectly depend on arg).

The operation is equivalent to

dr.set_grad(arg, 1.0)
dr.enqueue(dr.ADMode.Forward, arg)
dr.traverse(dr.ADMode.Forward, flags=flags)

Refer to the documentation of the functions drjit.set_grad(), drjit.enqueue(), and drjit.traverse() for further details on the nuances of forward derivative propagation.

By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.traverse() and the meaning of the flags parameter.

When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. To suppress this exception, specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad).

  • arg (object) – A Dr.Jit array, tensor, or PyTree.

  • flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is drjit.ADFlag.Default.

drjit.forward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) ArrayT
drjit.forward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) tuple[*Ts]

Forward-propagate gradients to the provided set of Dr.Jit arrays/tensors.

The operation is equivalent to

dr.enqueue(dr.ADMode.Backward, *args)
dr.traverse(dr.ADMode.Forward, flags=flags)
return dr.grad(*args)

Internally, the operation first traverses the computation graph backwards from args to find potential paths along which gradients can flow to the given set of arrays. Then, it performs a gradient propagation pass along the detected variables.

For this to work, you must have previously enabled gradient tracking and specified gradients for the inputs of the computation (see drjit.enable_grad() and drjit.set_grad()).

Refer to the documentation of the functions drjit.enqueue() and drjit.traverse() for further details on the nuances of forward derivative propagation.

By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.traverse() and the meaning of the flags parameter.

When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. To suppress this exception, specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad).

  • *args (tuple) – A variable-length list of differentiable Dr.Jit arrays, tensors, or PyTrees.

  • flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is drjit.ADFlag.Default.


the gradient value(s) associated with *args following the traversal.

Return type:


drjit.forward(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None

Forward-propagate gradients from the provided Dr.Jit array or tensor.

This function is an alias of drjit.forward_from(). Please refer to the documentation of that function.

  • arg (object) – A Dr.Jit array, tensor, or PyTree.

  • flags (drjit.ADFlag | int) – Controls what parts of the AD graph are cleared during traversal. The default value is drjit.ADFlag.Default.

drjit.backward_from(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None

Backpropagate gradients from the provided Dr.Jit array or tensor.

This function sets the gradient of the provided Dr.Jit array or tensor arg to 1.0 and then backpropagates derivatives through backward-connected components of the computation graph (i.e., reaching differentiable variables that potentially influence the value of arg).

The operation is equivalent to

dr.set_grad(arg, 1.0)
dr.enqueue(dr.ADMode.Backward, arg)
dr.traverse(dr.ADMode.Backward, flags=flags)

Refer to the documentation of the functions drjit.set_grad(), drjit.enqueue(), and drjit.traverse() for further details on the nuances of derivative backpropagation.

By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.traverse() and the meaning of the flags parameter.

When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. To suppress this exception, specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad).

  • arg (object) – A Dr.Jit array, tensor, or PyTree.

  • flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is drjit.ADFlag.Default.

drjit.backward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) ArrayT
drjit.backward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) tuple[*Ts]

Backpropagate gradients to the provided set of Dr.Jit arrays/tensors.

The operation is equivalent to

dr.enqueue(dr.ADMode.Forward, *args)
dr.traverse(dr.ADMode.Backward, flags=flags)
return dr.grad(*args)

Internally, the operation first traverses the computation graph forwards from args to find potential paths along which reverse-mode gradients can flow to the given set of input variables. Then, it performs a backpropagation pass along the detected variables.

For this to work, you must have previously enabled gradient tracking and specified gradients for the outputs of the computation (see drjit.enable_grad() and drjit.set_grad()).

Refer to the documentation of the functions drjit.enqueue() and drjit.traverse() for further details on the nuances of derivative backpropagation.

By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.traverse() and the meaning of the flags parameter.

When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. To suppress this exception, specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad).

  • *args (tuple) – A variable-length list of differentiable Dr.Jit arrays, tensors, or PyTrees.

  • flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is drjit.ADFlag.Default.


the gradient value(s) associated with *args following the traversal.

Return type:


drjit.backward(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None

Backpropagate gradients from the provided Dr.Jit array or tensor.

This function is an alias of drjit.backward_from(). Please refer to the documentation of that function.

  • arg (object) – A Dr.Jit array, tensor, or PyTree.

  • flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is drjit.ADFlag.Default.

drjit.suspend_grad(*args, when=True)

Python context manager to temporarily disable gradient tracking globally, or for a specific set of variables.

This context manager can be used as follows to completely disable all gradient tracking. Newly created variables will be detached from Dr.Jit’s AD graph.

with dr.suspend_grad():
    # .. code goes here ..

You may also specify any number of Dr.Jit arrays, tensors, or PyTrees. In this case, the context manager behaves differently by disabling gradient tracking more selectively for the specified variables.

with dr.suspend_grad(x):
    z = x + y  # 'z' will not track any gradients arising from 'x'

The suspend_grad() and resume_grad() context managers can be arbitrarily nested and suitably update the set of tracked variables.

A note about the interaction with drjit.enable_grad(): it is legal to register further AD variables within a scope that disables gradient tracking for specific variables.

with dr.suspend_grad(x):
    y = Float(1)
    dr.enable_grad(y)

    # The following condition holds
    assert not dr.grad_enabled(x) and dr.grad_enabled(y)

In contrast, a suspend_grad() environment without arguments that completely disables AD does not allow further variables to be registered:

with dr.suspend_grad():
    y = Float(1)
    dr.enable_grad(y) # ignored

    # The following condition holds
    assert not dr.grad_enabled(x) and not dr.grad_enabled(y)

  • *args (tuple) – Arbitrary list of Dr.Jit arrays, tensors, or PyTrees. Elements of data structures that could not possibly be attached to the AD graph (e.g., Python scalars) are ignored.

  • when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via when=False. The default value is when=True.

drjit.resume_grad(*args, when=True)

Python context manager to temporarily resume gradient tracking globally, or for a specific set of variables.

This context manager can be used as follows to fully re-enable all gradient tracking following a previous call to drjit.suspend_grad(). Newly created variables will then again be attached to Dr.Jit’s AD graph.

with dr.suspend_grad():
    # ..

    with dr.resume_grad():
        # In this scope, the effect of the outer context
        # manager is effectively disabled

You may also specify any number of Dr.Jit arrays, tensors, or PyTrees. In this case, the context manager behaves differently by enabling gradient tracking more selectively for the specified variables.

with dr.suspend_grad():
    with dr.resume_grad(x):
        z = x + y  # 'z' will only track gradients arising from 'x'

The suspend_grad() and resume_grad() context managers can be arbitrarily nested and suitably update the set of tracked variables.

  • *args (tuple) – Arbitrary list of Dr.Jit arrays, tensors, or PyTrees. Elements of data structures that could not possibly be attached to the AD graph (e.g., Python scalars) are ignored.

  • when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via when=False. The default value is when=True.


drjit.isolate_grad(when=True)

Python context manager to isolate and partition AD traversals into multiple distinct phases.

Consider a sequence of steps being differentiated in reverse mode, like so:

x = ..
dr.enable_grad(x)

y = f(x)
z = g(y)
dr.backward(z)

The drjit.backward() call would automatically traverse the AD graph nodes created during the execution of the functions f() and g().

However, sometimes this is undesirable and more control is needed. For example, Dr.Jit may be in an execution context (a symbolic loop or call) that temporarily disallows differentiation of the f() part. The drjit.isolate_grad() context manager addresses this need:

y = f(x)

with dr.isolate_grad():
    z = g(y)
    dr.backward(z)  # traversal of the f() part is postponed until the scope ends

Any reverse-mode AD traversal of an edge that crosses the isolation boundary is postponed until leaving the scope. This is mathematically equivalent but produces two smaller separate AD graph traversals.
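The claimed equivalence follows from the chain rule, which the plain-Python sketch below illustrates. The functions f and g and their hand-written derivatives are arbitrary choices made for this example and have no counterpart in Dr.Jit.

```python
import math


def f(x):
    return x * x


def df(x):
    return 2.0 * x


def g(y):
    return math.sin(y)


def dg(y):
    return math.cos(y)


x = 0.7
y = f(x)
z = g(y)

# One traversal: backpropagate from z through g and f in a single pass.
grad_x_single = 1.0 * dg(y) * df(x)

# Two phases: stop at the boundary after g, then continue through f later.
grad_y = 1.0 * dg(y)            # phase 1, inside the isolated scope
grad_x_split = grad_y * df(x)   # phase 2, performed when leaving the scope
```

The intermediate gradient grad_y is the only state that has to survive between the two phases.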

Dr.Jit operations like symbolic loops and calls internally create such an isolation boundary, hence it is rare that you would need to do so yourself. isolate_grad() is not useful for forward mode AD.


when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via when=False. The default value is when=True.

class drjit.CustomOp

Base class for implementing custom differentiable operations.

Dr.Jit can compute derivatives of builtin operations in both forward and reverse mode. In some cases, it may be useful or even necessary to control how a particular operation should be differentiated.

To do so, you may extend this class to provide three callback functions:

  1. CustomOp.eval(): Implements the primal evaluation of the function with detached inputs.

  2. CustomOp.forward(): Implements the forward derivative that propagates derivatives from input arguments to the return value.

  3. CustomOp.backward(): Implements the backward derivative that propagates derivatives from the return value to the input arguments.

An example of a hypothetical custom addition operation is shown below:

class Addition(dr.CustomOp):
    def eval(self, x, y):
        # Primal calculation without derivative tracking
        return x + y

    def forward(self):
        # Compute forward derivatives
        self.set_grad_out(self.grad_in('x') + self.grad_in('y'))

    def backward(self):
        # .. compute backward derivatives ..
        self.set_grad_in('x', self.grad_out())
        self.set_grad_in('y', self.grad_out())

    def name(self):
        # Optional: a descriptive name shown in GraphViz visualizations
        return "Addition"

You should never need to call these functions yourself—Dr.Jit will do so when appropriate. To weave such a custom operation into the AD graph, use the drjit.custom() function, which expects a subclass of drjit.CustomOp as first argument, followed by arguments to the actual operation that are directly forwarded to the .eval() callback.

# Add two numbers 'x' and 'y'. Calls our '.eval()' callback with detached arguments
result = dr.custom(Addition, x, y)

Forward or backward derivatives are then automatically handled through the standard operations. For example,

dr.backward(result)

will invoke the .backward() callback from above.

Many derivatives are more complex than the above examples and require access to inputs or intermediate steps of the primal evaluation routine. You can simply stash them in the instance (self.field = ...), which is shown below for a differentiable multiplication operation that implements the product rule:

class Multiplication(dr.CustomOp):
    def eval(self, x, y):
        # Stash input arguments
        self.x = x
        self.y = y

        return x * y

    def forward(self):
        self.set_grad_out(self.y * self.grad_in('x') + self.x * self.grad_in('y'))

    def backward(self):
        self.set_grad_in('x', self.y * self.grad_out())
        self.set_grad_in('y', self.x * self.grad_out())

    def name(self):
        return "Multiplication"
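
When writing such callbacks by hand, it can help to check them numerically before wiring them into the AD graph. The following plain-Python sketch mirrors the product rule from the example with standalone functions (mul_eval, mul_forward, and mul_backward are hypothetical helpers, not part of Dr.Jit). It compares the forward derivative against a central finite difference and checks that the backward derivative is its adjoint.

```python
def mul_eval(x, y):
    return x * y


def mul_forward(x, y, grad_x, grad_y):
    # Product rule: tangent of the output given tangents of the inputs
    return y * grad_x + x * grad_y


def mul_backward(x, y, grad_out):
    # Adjoint of the product rule: gradients of the inputs given the output gradient
    return y * grad_out, x * grad_out


x, y = 3.0, -2.0
grad_x, grad_y = 0.5, 0.25

# Check 1: forward derivative vs. central finite difference
h = 1e-6
finite_diff = (mul_eval(x + h * grad_x, y + h * grad_y)
               - mul_eval(x - h * grad_x, y - h * grad_y)) / (2 * h)
forward_value = mul_forward(x, y, grad_x, grad_y)

# Check 2: adjoint identity <J v, w> == <v, J^T w>
grad_out = 1.5
back_x, back_y = mul_backward(x, y, grad_out)
lhs = forward_value * grad_out
rhs = back_x * grad_x + back_y * grad_y
```

If the two checks disagree, one of the derivative callbacks most likely contains a sign or factor error.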

eval(self, *args, **kwargs) object

Evaluate the custom operation in primal mode.

You must implement this method when subclassing CustomOp, since the default implementation raises an exception. It should realize the original (non-derivative-aware) form of a computation and may take an arbitrary sequence of positional, keyword, and variable-length positional/keyword arguments.

You should not need to call this function yourself—Dr.Jit will automatically do so when performing custom operations through the drjit.custom() interface.

Note that the input arguments passed to .eval() will be detached (i.e. they don’t have derivative tracking enabled). This is intentional, since derivative tracking is handled by the custom operation along with the other callbacks forward() and backward().

forward(self) None

Evaluate the forward derivative of the custom operation.

You must implement this method when subclassing CustomOp, since the default implementation raises an exception. It takes no arguments and has no return value.

An implementation will generally perform repeated calls to grad_in() to query the gradients of all function inputs, followed by a single call to set_grad_out() to set the gradient of the return value.

For example, this is how one would implement the product rule of the primal calculation x*y, assuming that the .eval() routine stashed the inputs in the custom operation object.

def forward(self):
    self.set_grad_out(self.y * self.grad_in('x') + self.x * self.grad_in('y'))

backward(self) None

Evaluate the backward derivative of the custom operation.

You must implement this method when subclassing CustomOp, since the default implementation raises an exception. It takes no arguments and has no return value.

An implementation will generally perform a single call to grad_out() to query the gradient of the function return value followed by a sequence of calls to set_grad_in() to assign the gradients of the function inputs.

For example, this is how one would implement the product rule of the primal calculation x*y, assuming that the .eval() routine stashed the inputs in the custom operation object.

def backward(self):
    self.set_grad_in('x', self.y * self.grad_out())
    self.set_grad_in('y', self.x * self.grad_out())

name(self) str

Return a descriptive name of the CustomOp instance.

Amongst other things, this name is used to document the presence of the custom operation in GraphViz debug output. (See graphviz_ad().)

grad_out(self) object

Query the gradient of the return value.

Returns an object, whose type matches the original return value produced in eval(). This function should only be used within the backward() callback.

set_grad_out(self, arg: object, /) None

Accumulate a gradient into the return value.

This function should only be used within the forward() callback.

grad_in(self, arg: object, /) object

Query the gradient of a specified input parameter.

The second argument specifies the parameter name as a string. Gradients of variable-length positional arguments (*args) can be queried by providing an integer index instead.

This function should only be used within the forward() callback.

set_grad_in(self, arg0: object, arg1: object, /) None

Accumulate a gradient into the specified input parameter.

The second argument specifies the parameter name as a string. Gradients of variable-length positional arguments (*args) can be assigned by providing an integer index instead.

This function should only be used within the backward() callback.

add_input(self, arg: object, /) None

Register an implicit input dependency of the operation on an AD variable.

This function should be called by the eval() implementation when an operation has a differentiable dependence on an input that is not an ordinary input argument of the function (e.g., a global program variable or a field of a class).

add_output(self, arg: object, /) None

Register an implicit output dependency of the operation on an AD variable.

This function should be called by the eval() implementation when an operation has a differentiable dependence on an output that is not part of the function return value (e.g., a global program variable or a field of a class).

drjit.custom(arg0: type[drjit.CustomOp], /, *args, **kwargs) object

Evaluate a custom differentiable operation.

It can be useful or even necessary to control how a particular operation should be differentiated by Dr.Jit’s automatic differentiation (AD) layer. The drjit.custom() function enables such use cases by stitching an opaque operation with user-defined primal and forward/backward derivative implementations into the AD graph.

The function expects a subclass of the CustomOp interface as first argument. The remaining positional and keyword arguments are forwarded to the CustomOp.eval() callback.

See the documentation of CustomOp for examples on how to realize such a custom operation.

drjit.wrap(source: str | ModuleType, target: str | ModuleType) Callable[[T], T]

Differentiable bridge between Dr.Jit and other array programming frameworks.

This function wraps computation performed using one array programming framework to expose it in another. Currently, PyTorch and JAX are supported, though other frameworks may be added in the future.

Annotating a function with @drjit.wrap adds code that suitably converts arguments and return values. Furthermore, it stitches the operation into the automatic differentiation (AD) graph of the other framework to ensure correct gradient propagation.

When exposing code written using another framework, the wrapped function can take and return any PyTree including flat or nested Dr.Jit arrays, tensors, and arbitrary nested lists/tuples, dictionaries, and custom data structures. The arguments don’t need to be differentiable—for example, integer/boolean arrays that don’t carry derivative information can be passed as well.

The wrapped function should be pure: in other words, it should read its input(s) and compute an associated output so that re-evaluating the function again produces the same answer. Multi-framework derivative tracking of impure computation will likely not behave as expected.

The following table lists the currently supported conversions:



Forward-mode AD

Reverse-mode AD



Everything just works.


Limitation: The passed/returned PyTrees can contain arbitrary arrays or tensors, but other types (e.g., a custom Python object not understood by PyTorch) will raise errors when differentiating in forward mode (backward mode works fine).

An issue was filed on the PyTorch bugtracker.


You may want to further annotate the wrapped function with jax.jit to trace and just-in-time compile it in the JAX environment, i.e.,

@dr.wrap(source='drjit', target='jax')
@jax.jit
def jax_func(x):  # 'jax_func' is a placeholder name
    return x      # placeholder body

Limitation: The passed/returned PyTrees can contain arbitrary arrays or Python scalar types, but other types (e.g., a custom Python object not understood by JAX) will raise errors.


This direction is currently unsupported. We plan to add it in the future.

Please also refer to the documentation sections on multi-framework differentiation and the associated caveats.


Types that have no equivalent on the other side (e.g. a quaternion array) will convert to generic tensors.

Data exchange is limited to representations that exist on both sides. There are a few limitations:

  • PyTorch lacks support for most unsigned integer types (uint16, uint32, or uint64-typed arrays). Use signed integer types to work around this issue.

  • Dr.Jit currently lacks support for most 8- and 16-bit numeric types (besides half precision floats).

  • JAX refuses to exchange boolean-valued tensors with other frameworks.

  • source (str | module) – The framework used outside of the wrapped function. The argument is currently limited to either 'drjit', 'torch', or 'jax'. For convenience, the associated Python module can be specified as well.

  • target (str | module) – The framework used inside of the wrapped function. The argument is currently limited to either 'drjit', 'torch', or 'jax'. For convenience, the associated Python module can be specified as well.


The decorated function.



The exponential constant \(e\) represented as a Python float.


The value \(\log(2)\) represented as a Python float.


The value \(\frac{1}{\log(2)}\) represented as a Python float.


The value \(\pi\) represented as a Python float.


The value \(\frac{1}{\pi}\) represented as a Python float.


The value \(\sqrt{\pi}\) represented as a Python float.


The value \(\frac{1}{\sqrt{\pi}}\) represented as a Python float.


The value \(2\pi\) represented as a Python float.


The value \(\frac{1}{2\pi}\) represented as a Python float.


The value \(\sqrt{2\pi}\) represented as a Python float.


The value \(\frac{1}{\sqrt{2\pi}}\) represented as a Python float.


The value \(4\pi\) represented as a Python float.


The value \(\frac{1}{4\pi}\) represented as a Python float.


The value \(\sqrt{4\pi}\) represented as a Python float.


The value \(\sqrt{2}\) represented as a Python float.


The value \(\frac{1}{\sqrt{2}}\) represented as a Python float.


The value float('inf') represented as a Python float.


The value float('nan') represented as a Python float.
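For orientation, the listed quantities can be reproduced with Python's math module as sketched below. The dictionary keys are descriptive labels chosen for this example and are not the attribute names used by Dr.Jit.

```python
import math

# Values corresponding to the descriptions above, computed in double precision
values = {
    "e": math.e,
    "log(2)": math.log(2.0),
    "1/log(2)": 1.0 / math.log(2.0),
    "pi": math.pi,
    "1/pi": 1.0 / math.pi,
    "sqrt(pi)": math.sqrt(math.pi),
    "1/sqrt(pi)": 1.0 / math.sqrt(math.pi),
    "2 pi": 2.0 * math.pi,
    "1/(2 pi)": 1.0 / (2.0 * math.pi),
    "sqrt(2 pi)": math.sqrt(2.0 * math.pi),
    "1/sqrt(2 pi)": 1.0 / math.sqrt(2.0 * math.pi),
    "4 pi": 4.0 * math.pi,
    "1/(4 pi)": 1.0 / (4.0 * math.pi),
    "sqrt(4 pi)": math.sqrt(4.0 * math.pi),
    "inf": math.inf,
    "nan": math.nan,
}

for label, value in values.items():
    print(f"{label:>14} = {value!r}")
```

Precomputed reciprocals such as 1/pi are mainly useful because they turn a division into a cheaper multiplication.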

drjit.epsilon(arg, /)

Returns the machine epsilon.

The machine epsilon gives an upper bound on the relative approximation error due to rounding in floating point arithmetic.
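The bound can be observed directly in double precision using exact rational arithmetic. The sketch below relies only on the Python standard library and assumes IEEE 754 doubles with round-to-nearest. It makes no statement about which of the two common conventions (spacing 2**-52 or unit roundoff 2**-53) Dr.Jit returns.

```python
import sys
from fractions import Fraction

spacing = sys.float_info.epsilon      # gap between 1.0 and the next double
unit_roundoff = Fraction(1, 2**53)    # half of that gap

# Rounding 1/10 to the nearest double introduces a small relative error.
exact = Fraction(1, 10)
rounded = Fraction(0.1)               # exact value of the stored double
relative_error = abs(rounded - exact) / exact

print(float(relative_error), "<=", float(unit_roundoff))
```

Both conventions are valid upper bounds on the rounding error. The unit roundoff is simply the tighter of the two.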


arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.


The machine epsilon.

Return type:


drjit.one_minus_epsilon(arg, /)

Returns one minus the machine epsilon value.


arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.


One minus the machine epsilon.

Return type:


drjit.recip_overflow(arg, /)

Returns the reciprocal overflow threshold value.

Any number equal to or smaller than this threshold will overflow to infinity when reciprocated.
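The effect is easy to reproduce with ordinary Python floats, which are IEEE 754 doubles in CPython. The threshold 2**-1024 used below is derived for double precision in this sketch. It is an assumption for illustration and not a value quoted from Dr.Jit.

```python
import math

threshold = math.ldexp(1.0, -1024)     # 2**-1024, a subnormal double
just_above = math.ldexp(1.0, -1023)    # 2**-1023, twice as large

# 1 / 2**-1024 would be 2**1024, which exceeds the largest finite double.
recip_at_threshold = 1.0 / threshold

# 1 / 2**-1023 equals 2**1023, which is still representable.
recip_above = 1.0 / just_above

print(recip_at_threshold, recip_above)
```

Code that divides by very small positive values can therefore compare against such a threshold to avoid producing infinities.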


arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.


The reciprocal overflow threshold value.

Return type:


drjit.smallest(arg, /)

Returns the smallest representable normalized floating point value.


arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.


The smallest representable normalized floating point value.

Return type:


drjit.largest(arg, /)

Returns the largest representable finite floating point value.
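The two range limits described here and under drjit.smallest() have direct double-precision counterparts in Python's sys.float_info. The sketch below uses only the standard library. It illustrates the concepts and does not show Dr.Jit's return values.

```python
import math
import sys

smallest_normal = sys.float_info.min   # 2**-1022
largest_finite = sys.float_info.max    # (2 - 2**-52) * 2**1023

# Halving the smallest normalized value yields a subnormal number, not zero.
below_smallest = smallest_normal / 2

# Doubling the largest finite value overflows to infinity.
beyond_largest = largest_finite * 2

print(smallest_normal, below_smallest, largest_finite, beyond_largest)
```

Note that the smallest normalized value is not the smallest positive value: subnormal numbers lie below it, at reduced precision.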


arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.


The largest representable finite floating point value.

Return type:


Array base class

class drjit.ArrayBase

This is the base class of all Dr.Jit arrays and tensors. It provides an abstract version of the array API that becomes usable when the type is extended by a concrete specialization. ArrayBase itself cannot be instantiated.

See the section on Dr.Jit type signatures to learn about the type parameters of ArrayBase.

property array

This member plays multiple roles:

  • When self is a tensor, this property returns the storage representation of the tensor in the form of a linearized dynamic 1D array.

  • When self is a special arithmetic object (matrix, quaternion, or complex number), array provides a copy of the same data with ordinary array semantics.

  • In all other cases, array is simply a reference to self.



property ndim

This property represents the dimension of the provided Dr.Jit array or tensor.



property shape

This property provides a tuple describing dimension and shape of the provided Dr.Jit array or tensor.

When the input array is ragged, the function raises a RuntimeError. The term ragged refers to an array whose components have mismatched sizes, such as [[1, 2], [3, 4, 5]]. Note that scalar entries (e.g. [[1, 2], [3]]) are acceptable, since broadcasting can effectively convert them to any size.
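The distinction between ragged and broadcastable components can be sketched for nested Python lists as follows. The helper nested_shape() is hypothetical and only models the rule described above. It is not Dr.Jit's implementation.

```python
def nested_shape(value):
    """Return the shape of nested lists, raising RuntimeError if ragged."""
    if not isinstance(value, (list, tuple)):
        return ()
    child_shapes = [nested_shape(v) for v in value]
    if not child_shapes:
        return (0,)
    if len({len(shape) for shape in child_shapes}) != 1:
        raise RuntimeError("ragged: components have different nesting depth")

    merged = []
    for sizes in zip(*child_shapes):
        # Size-1 components broadcast, so only other sizes must agree.
        distinct = {size for size in sizes if size != 1}
        if len(distinct) > 1:
            raise RuntimeError("ragged: components have mismatched sizes")
        merged.append(distinct.pop() if distinct else 1)
    return (len(value), *merged)


regular = nested_shape([[1, 2], [3, 4]])
broadcast = nested_shape([[1, 2], [3]])

try:
    nested_shape([[1, 2], [3, 4, 5]])
    ragged_error = None
except RuntimeError as exc:
    ragged_error = exc
```

In this model, the size-1 component [3] is treated as if it were [3, 3], which is why the second example is not considered ragged.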

The expressions drjit.shape(arg) and arg.shape are equivalent.


tuple[int, …]

property state

This read-only property returns an enumeration value describing the evaluation state of this Dr.Jit array.



property x

If self is a static Dr.Jit array of size 1 (or larger), the property self.x can be used synonymously with self[0]. Otherwise, accessing this field will generate a RuntimeError.



property y

If self is a static Dr.Jit array of size 2 (or larger), the property self.y can be used synonymously with self[1]. Otherwise, accessing this field will generate a RuntimeError.



property z

If self is a static Dr.Jit array of size 3 (or larger), the property self.z can be used synonymously with self[2]. Otherwise, accessing this field will generate a RuntimeError.



property w

If self is a static Dr.Jit array of size 4 (or larger), the property self.w can be used synonymously with self[3]. Otherwise, accessing this field will generate a RuntimeError.



property T

This property returns the transpose of self. When the underlying array is not a matrix type, it raises a TypeError.

property index

If self is a leaf Dr.Jit array managed by a just-in-time compiled backend (i.e., CUDA or LLVM), this property contains the associated variable index in the graph data structure storing the computation trace. This graph can be visualized using drjit.graphviz(). Otherwise, the value of this property equals zero. A non-leaf array (e.g. drjit.cuda.Array2i) consists of several JIT variables, whose indices must be queried separately.

Note that Dr.Jit maintains two computation traces at the same time: one capturing the raw computation, and a higher-level graph for automatic differentiation (AD). The property index_ad keeps track of the variable index within the AD computation graph, if applicable.



property index_ad

If self is a leaf Dr.Jit array represented by an AD backend, this property contains the variable index in the graph data structure storing the computation trace for later differentiation (this graph can be visualized using drjit.graphviz_ad()). A non-leaf array (e.g. drjit.cuda.ad.Array2f) consists of several AD variables, whose indices must be queried separately.

Note that Dr.Jit maintains two computation traces at the same time: one capturing the raw computation, and a higher-level graph for automatic differentiation (AD). The property index keeps track of the variable index within the raw computation graph, if applicable.



property grad

This property can be used to retrieve or set the gradient associated with the Dr.Jit array or tensor.

The expressions drjit.grad(arg) and arg.grad are equivalent when arg is a Dr.Jit array/tensor.




Return len(self).


Implement iter(self).


Return repr(self).


True if self else False

Casts the array to a Python bool type. This is only permissible when self represents a boolean array of both depth and size 1.

__add__(value, /)

Return self+value.

__radd__(value, /)

Return value+self.

__iadd__(value, /)

Return self+=value.

__sub__(value, /)

Return self-value.

__rsub__(value, /)

Return value-self.

__isub__(value, /)

Return self-=value.

__mul__(value, /)

Return self*value.

__rmul__(value, /)

Return value*self.

__imul__(value, /)

Return self*=value.

__matmul__(value, /)

Return self@value.

__rmatmul__(value, /)

Return value@self.

__imatmul__(value, /)

Return self@=value.

__truediv__(value, /)

Return self/value.

__rtruediv__(value, /)

Return value/self.

__itruediv__(value, /)

Return self/=value.

__floordiv__(value, /)

Return self//value.

__rfloordiv__(value, /)

Return value//self.

__ifloordiv__(value, /)

Return self//=value.

__mod__(value, /)

Return self%value.

__rmod__(value, /)

Return value%self.

__imod__(value, /)

Return self%=value.

__rshift__(value, /)

Return self>>value.

__rrshift__(value, /)

Return value>>self.

__irshift__(value, /)

Return self>>=value.

__lshift__(value, /)

Return self<<value.

__rlshift__(value, /)

Return value<<self.

__ilshift__(value, /)

Return self<<=value.

__and__(value, /)

Return self&value.

__rand__(value, /)

Return value&self.

__iand__(value, /)

Return self&=value.

__or__(value, /)

Return self|value.

__ror__(value, /)

Return value|self.

__ior__(value, /)

Return self|=value.

__xor__(value, /)

Return self^value.

__rxor__(value, /)

Return value^self.

__ixor__(value, /)

Return self^=value.



__le__(value, /)

Return self<=value.

__lt__(value, /)

Return self<value.

__ge__(value, /)

Return self>=value.

__gt__(value, /)

Return self>value.

__ne__(value, /)

Return self!=value.

__eq__(value, /)

Return self==value.

__dlpack__(self, stream: object | None = None) ndarray[]

Returns a DLPack capsule representing the data in this array.

This operation may potentially perform a copy. For example, nested arrays like drjit.llvm.Array3f or drjit.cuda.Matrix4f need to be rearranged into a contiguous memory representation before they can be exposed.

In other cases, e.g. for drjit.llvm.Float, drjit.scalar.Array3f, or drjit.scalar.ArrayXf, the data is already contiguous and a zero-copy approach is used instead.

__array__(self, dtype: object | None = None) object

Returns a NumPy array representing the data in this array.

This operation may potentially perform a copy. For example, nested arrays like drjit.llvm.Array3f or drjit.cuda.Matrix4f need to be rearranged into a contiguous memory representation before they can be wrapped.

In other cases, e.g. for drjit.llvm.Float, drjit.scalar.Array3f, or drjit.scalar.ArrayXf, the data is already contiguous and a zero-copy approach is used instead.

numpy(self) numpy.ndarray[]

Returns a NumPy array representing the data in this array.

This operation may potentially perform a copy. For example, nested arrays like drjit.llvm.Array3f or drjit.cuda.Matrix4f need to be rearranged into a contiguous memory representation before they can be wrapped.

In other cases, e.g. for drjit.llvm.Float, drjit.scalar.Array3f, or drjit.scalar.ArrayXf, the data is already contiguous and a zero-copy approach is used instead.

torch(self) object

Returns a PyTorch tensor representing the data in this array.

For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.


This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function drjit.wrap() for further information on how to accomplish this.

jax(self) object

Returns a JAX tensor representing the data in this array.

For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.


This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function drjit.wrap() for further information on how to accomplish this.

tf(self) object

Returns a TensorFlow tensor representing the data in this array.

For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.


This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function drjit.wrap() for further information on how to accomplish this.

Computation graph analysis

The following operations visualize the contents of Dr.Jit’s computation graphs (of which there are two: one for Jit compilation, and one for automatic differentiation).

drjit.graphviz(as_string: bool = False) object

Return a GraphViz diagram describing registered JIT variables and their connectivity.

This function returns a representation of the computation graph underlying the Dr.Jit just-in-time compiler, which is separate from the automatic differentiation layer. See the graphviz_ad() function to visualize the computation graph of the latter.

Run dr.graphviz().view() to open up a PDF viewer that shows the resulting output in a separate window.

The function depends on the graphviz Python package when as_string=False (the default).


as_string (bool) – if set to True, the function will return raw GraphViz markup as a string. (Default: False)


GraphViz object or raw markup.

Return type:


drjit.graphviz_ad(as_string: bool = False) object

Return a GraphViz diagram describing variables registered with the automatic differentiation layer, as well as their connectivity.

This function returns a representation of the computation graph underlying the Dr.Jit AD layer, which is one architectural layer above the just-in-time compiler. See the graphviz() function to visualize the computation graph of the latter.

Run dr.graphviz_ad().view() to open up a PDF viewer that shows the resulting output in a separate window.

The function depends on the graphviz Python package when as_string=False (the default).


as_string (bool) – if set to True, the function will return raw GraphViz markup as a string. (Default: False)


GraphViz object or raw markup.

Return type:


drjit.whos(as_string: bool) str | None

Return/print a list of live JIT variables.

This function provides information about the set of variables that are currently registered with the Dr.Jit just-in-time compiler, which is separate from the automatic differentiation layer. See the whos_ad() function for the latter.


as_string (bool) – if set to True, the function will return the list in string form. Otherwise, it will print directly onto the console and return None. (Default: False)


a human-readable list (if requested).

Return type:

None | str

drjit.whos_ad(as_string: bool) str | None

Return/print a list of live variables registered with the automatic differentiation layer.

This function provides information about the set of variables that are currently registered with the Dr.Jit automatic differentiation layer, which is one architectural layer above the just-in-time compiler. See the whos() function to obtain information about the latter.


as_string (bool) – if set to True, the function will return the list in string form. Otherwise, it will print directly onto the console and return None. (Default: False)


a human-readable list (if requested).

Return type:

None | str

drjit.set_label(arg0: object, arg1: str, /) None
drjit.set_label(**kwargs) None

Assign a label to the provided Dr.Jit array.

This can be helpful to identify computation in GraphViz output (see drjit.graphviz(), graphviz_ad()).

The operation assumes that the array is tracked by the just-in-time compiler. It has no effect on unsupported inputs (e.g., arrays from the drjit.scalar package). It recurses through PyTrees (tuples, lists, dictionaries, custom data structures) and appends names (indices, dictionary keys, field names) separated by underscores to uniquely identify each element.

The following **kwargs-based shorthand notation can be used to assign multiple labels at once:

set_label(x=x, y=y)
  • *arg (tuple) – a Dr.Jit array instance and its corresponding label str value.

  • **kwarg (dict) – A set of (keyword, object) pairs.
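
The recursive naming scheme described above can be illustrated with a small pure-Python sketch. It does not call Dr.Jit: the helper pytree_labels is hypothetical, and the exact label strings that Dr.Jit generates may differ. The sketch only shows the idea of joining a base name with indices and dictionary keys using underscores.

```python
def pytree_labels(value, prefix):
    """Return {label: leaf} for a nested structure of tuples, lists, and dicts.

    Hypothetical illustration only: it mimics the 'prefix_key_index' naming
    idea and is not Dr.Jit's implementation.
    """
    if isinstance(value, dict):
        items = value.items()
    elif isinstance(value, (list, tuple)):
        items = enumerate(value)
    else:
        return {prefix: value}  # leaf: label it with the accumulated name

    labels = {}
    for key, child in items:
        labels.update(pytree_labels(child, f"{prefix}_{key}"))
    return labels


labels = pytree_labels({"pos": (1.0, 2.0), "mask": True}, "ray")
print(sorted(labels))
```

Every leaf receives a distinct name because the path from the root to each leaf is unique.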


drjit.assert_true(cond, fmt: str | None = None, *args, tb_depth: int = 3, tb_skip: int = 0, **kwargs)

Generate an assertion failure message when any of the entries in cond are False.

This function resembles the built-in assert keyword in that it raises an AssertionError when the condition cond is False.

In contrast to the built-in keyword, it also works when cond is an array of boolean values. In this case, the function raises an exception when any entry of cond is False.

The function accepts an optional format string along with positional and keyword arguments, and it processes them like drjit.print(). When only a subset of the entries of cond is False, the function reduces the generated output to only include the associated entries.

>>> x = Float(1, -4, -2, 3)
>>> dr.assert_true(x >= 0, 'Found negative values: {}', x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "drjit/__init__.py", line 1327, in assert_true
    raise AssertionError(msg)
AssertionError: Assertion failure: Found negative values: [-4, -2]

This function also works when some of the function inputs are symbolic. In this case, the check is delayed and potential failures are reported asynchronously: drjit.assert_true() then generates output on sys.stderr instead of raising an exception, as the original execution context no longer exists at that point.

Assertion checks carry a performance cost, hence they are disabled by default. To enable them, set the JIT flag dr.JitFlag.Debug.

  • cond (bool | drjit.ArrayBase) – The condition used to trigger the assertion. This should be a scalar Python boolean or a 1D boolean array.

  • fmt (str) – An optional format string that will be appended to the error message. It can reference positional or keyword arguments specified via *args and **kwargs.

  • *args (tuple) – Optional variable-length positional arguments referenced by fmt, see drjit.print() for details on this.

  • tb_depth (int) – Depth of the backtrace that should be appended to the assertion message. This only applies to cases where some of the inputs are symbolic, and printing of the error message must be delayed.

  • tb_skip (int) – The first tb_skip entries of the backtrace will be removed. This only applies to cases where some of the inputs are symbolic, and printing of the error message must be delayed. This is helpful when the assertion check is called from a helper function that should not be shown.

  • **kwargs (dict) – Optional variable-length keyword arguments referenced by fmt, see drjit.print() for details on this.
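
The way assert_true() reduces its message to the failing entries can be mimicked in plain Python. The following is a minimal sketch that operates on Python lists instead of Dr.Jit arrays; assert_all_true is a hypothetical helper and not part of Dr.Jit.

```python
def assert_all_true(cond, fmt=None, *args):
    """Raise AssertionError if any entry of 'cond' is False.

    'cond' and every entry of 'args' are plain Python lists of equal length.
    Only the entries at failing positions are formatted into the message.
    """
    failed = [i for i, ok in enumerate(cond) if not ok]
    if not failed:
        return
    msg = "Assertion failure"
    if fmt is not None:
        reduced = [[a[i] for i in failed] for a in args]
        msg += ": " + fmt.format(*reduced)
    raise AssertionError(msg)


x = [1, -4, -2, 3]
try:
    assert_all_true([v >= 0 for v in x], "Found negative values: {}", x)
except AssertionError as e:
    message = str(e)
```

Filtering the arguments before formatting keeps the message short even when the input is large and only a few entries violate the condition.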

drjit.assert_false(cond, fmt: str | None = None, *args, tb_depth: int = 3, tb_skip: int = 0, **kwargs)

Equivalent to assert_true() with a flipped condition cond. Please refer to the documentation of this function for further details.

drjit.assert_equal(arg0, arg1, fmt: str | None = None, *args, limit: int = 3, tb_skip: int = 0, **kwargs)

Equivalent to assert_true() with the condition arg0==arg1. Please refer to the documentation of this function for further details.

drjit.print(fmt: str, *args, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) None
drjit.print(value: object, /, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) None

Generate a formatted string representation and print it immediately or in a delayed fashion (if any of the inputs are symbolic).

This function combines the behavior of the built-in Python format() and print() functions: it generates a formatted string representation as specified by a format string fmt and then outputs it on the console. The operation fetches referenced positional and keyword arguments and pretty-prints Dr.Jit arrays, tensors, and PyTrees with indentation, field names, etc.

>>> from drjit.cuda import Array3f
>>> dr.print("{}:\n{foo}",
...          "A PyTree containing an array",
...          foo={ 'a' : Array3f(1, 2, 3) })
A PyTree containing an array:
  'a': [[1, 2, 3]]

The key advantage of drjit.print() compared to the built-in Python print() statement is that it can run asynchronously, which allows it to print symbolic variables without requiring their evaluation. Dr.Jit uses symbolic variables to trace loops (drjit.while_loop()), conditionals (drjit.if_stmt()), and calls (drjit.switch(), drjit.dispatch()). Such symbolic variables represent values that are unknown at trace time, and which cannot be printed using the built-in Python print() function (attempting to do so will raise an exception).

When the print statement does not reference any symbolic arrays or tensors, it will execute immediately. Otherwise, the output will appear after the next drjit.eval() statement, or whenever any subsequent computation is evaluated. Here is an example from an interactive Python session demonstrating printing from a symbolic call performed via drjit.switch():

>>> from drjit.llvm import UInt, Float
>>> def f1(x):
...     dr.print("in f1: {x=}", x=x)
>>> def f2(x):
...     dr.print("in f2: {x=}", x=x)
>>> dr.switch(index=UInt(0, 0, 0, 1, 1, 1),
...           targets=[f1, f2],
...           x=Float(1, 2, 3, 4, 5, 6))
>>> # No output (yet)
>>> dr.eval()
in f1: x=[1, 2, 3]
in f2: x=[4, 5, 6]

Dynamic arrays with more than 20 entries will be abbreviated. Specify the limit=.. argument to reveal the contents of larger arrays.

>>> dr.print(dr.arange(dr.llvm.Int, 30))
[0, 1, 2, .. 24 skipped .., 27, 28, 29]

>>> dr.print(dr.arange(dr.llvm.Int, 30), limit=30)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,↵
 23, 24, 25, 26, 27, 28, 29]

This function lacks many features of Python’s (rather intricate) format string mini language and f-string interpolation. However, a subset of the functionality is supported:

  • Positional arguments (in *args) can be referenced implicitly ({}), or using indices ({0}, {1}, etc.). Those conventions should not be mixed. Unreferenced positional arguments will be silently ignored.

  • Keyword arguments (in **kwargs) can be referenced via their keyword name ({foo}). Unreferenced keywords will be silently ignored.

  • A trailing = in a brace expression repeats the string within the braces followed by the output:

    >>> dr.print('{foo=}', foo=1)

When the format string fmt is omitted, it is implicitly set to {}, and the function formats a single positional argument.

The function implicitly appends end to the format string, which is set to a newline by default. The final result is sent to sys.stdout (by default) or file. When a file argument is given, it must implement the method write(arg: str).

A related operation drjit.format() admits the same format string syntax but returns a Python str instead of printing to the console. This operation, however, does not support symbolic inputs—use drjit.print() with a custom file argument to stringify symbolic inputs asynchronously.
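
Any object with a write(arg: str) method can serve as the file argument. The sketch below defines such an object; StringSink is a hypothetical name. To stay runnable without Dr.Jit, it is exercised with the built-in print(), which relies on the same write() protocol.

```python
class StringSink:
    """Minimal file-like object: collects every chunk passed to write()."""

    def __init__(self):
        self.chunks = []

    def write(self, arg: str) -> None:
        self.chunks.append(arg)

    def getvalue(self) -> str:
        return "".join(self.chunks)


sink = StringSink()

# Demonstrated with the built-in print(). With Dr.Jit, one would instead pass
# the sink via dr.print(..., file=sink) and read it once the queued
# computation has been evaluated.
print("x=[1, 2, 3]", file=sink)
captured = sink.getvalue()
```

Because symbolic output arrives late, the sink should only be inspected after evaluation has taken place.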


This operation is not suitable for extracting large amounts of data from Dr.Jit kernels, as the conversion to a string representation incurs a nontrivial runtime cost.


Technical details on symbolic printing

When Dr.Jit compiles and executes queued computation on the target device, it includes additional code for symbolic print operations that captures referenced arguments and copies them back to the host (CPU). The information is then printed following the end of that process.

Only a limited amount of memory is set aside to capture the output of symbolic print operations. This is because the amount of data produced within long-running symbolic loops can often exceed the total device memory. Also, printing gigabytes of ASCII text into a Python console or Jupyter notebook is likely not a good idea.

For the electronically inclined, the operation is best thought of as hooking up an oscilloscope to a high-frequency circuit. The oscilloscope provides a limited view into a vast torrent of data to assist the user, who would be overwhelmed if the oscilloscope worked by capturing and showing everything.

The operation warns when the size of the buffers was insufficient. In this case, the output is still printed in the correct order, but chunks of the data are missing. The position of the resulting holes is unspecified and non-deterministic.

>>> dr.print(dr.arange(Float, 10000000), mode='symbolic')
>>> dr.eval()
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

RuntimeWarning: dr.print(): symbolic print statement only captured 20 of 10000000 available outputs.
The above is a non-deterministic sample, in which entries are in the right order but not necessarily
contiguous. Specify `limit=..` to capture more information and/or add the special format field
`{thread_id}` show the thread ID/array index associated with each entry of the captured output.

This is because the (many) parallel threads of the program all try to append their state to the output buffer, but only the first limit (20 by default) succeed. The host subsequently re-sorts the captured data by thread ID. This means that the output [5, 6, 102, 188, 1026, ..] would also be a valid result of the prior command. When a print statement references multiple arrays, the operation either shows all array entries associated with a particular execution thread, or none of them.
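
The capture process can be modeled conceptually in a few lines of plain Python. This is a hypothetical simulation rather than Dr.Jit code: a random shuffle stands in for the hardware-dependent order in which threads reach the buffer, and simulate_capture is an illustrative name.

```python
import random


def simulate_capture(values, limit=20, seed=0):
    """Emulate a bounded symbolic print buffer.

    Each index of 'values' plays the role of a thread. Threads reach the
    buffer in an arbitrary order (modeled by a shuffle), only the first
    'limit' arrivals are stored, and the host re-sorts them by thread ID.
    """
    rng = random.Random(seed)
    arrival_order = list(range(len(values)))
    rng.shuffle(arrival_order)
    captured_ids = sorted(arrival_order[:limit])
    return captured_ids, [values[i] for i in captured_ids]


thread_ids, sample = simulate_capture(list(range(10_000)), limit=20)
```

The result is ordered but generally not contiguous, which matches the behavior that the warning message describes.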

To refine what is captured, you can specify the active argument to disable the print statement for a subset of the entries (a “trigger” in the oscilloscope analogy). Printing from an inactive thread within a symbolic loop (drjit.while_loop()), conditional (drjit.if_stmt()), or call (drjit.switch(), drjit.dispatch()) will likewise not generate any output.

A potential gotcha of the current design is that a symbolic print within a symbolic loop counts as one print statement and will only generate a single combined output string. The output of each thread is arranged in one contiguous block. You can add the special format string keyword {thread_id} to reveal the mapping between output values and the execution thread that generated them:

>>> from drjit.llvm import Int
>>> @dr.syntax
... def f(j: Int):
...     i = Int(0)
...     while i < j:
...         dr.print('{thread_id=} {i=}', i=i)
...         i += 1
>>> f(Int(2, 3))
>>> dr.eval()
thread_id=[0, 0, 1, 1, 1], i=[0, 1, 0, 1, 2]

The example above runs a symbolic loop twice in parallel: the first thread runs for 2 iterations, and the second runs for 3 iterations. The loop prints the iteration counter i, which then leads to the output [0, 1, 0, 1, 2] where the first two entries are produced by the first thread, and the trailing three belong to the second thread. The thread_id output clarifies this mapping.
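
The block layout of this combined output can be reproduced with a plain-Python sketch. It is a hypothetical model (simulate_loop_print is not a Dr.Jit function) that lists each thread's loop iterations one after another.

```python
def simulate_loop_print(iteration_counts):
    """Emulate the combined output of a print inside a symbolic loop.

    iteration_counts[t] is the number of loop iterations of thread t.
    Returns (thread_id, i) lists in which each thread's entries form one
    contiguous block, as described in the text.
    """
    thread_id, i_values = [], []
    for t, count in enumerate(iteration_counts):
        for i in range(count):
            thread_id.append(t)
            i_values.append(i)
    return thread_id, i_values


thread_id, i_values = simulate_loop_print([2, 3])
print(f"thread_id={thread_id}, i={i_values}")
```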

  • fmt (str) – A format string that potentially references input arguments from *args and **kwargs.

  • active (drjit.ArrayBase | bool) – A mask argument that can be used to disable a subset of the entries. The print statement will be completely suppressed when there is no output. (default: True).

  • end (str) – This string will be appended to the format string. It is set to a newline character ("\n") by default.

  • file (object) – The print operation will eventually invoke file.write(arg:str) to print the formatted string. Specify this argument to route the output somewhere other than the default output stream sys.stdout.

  • mode (str) – Specify this parameter to override the evaluation mode. Possible values are: "symbolic", "evaluated", or "auto". The default value of "auto" causes the function to use evaluated mode (which prints immediately) unless a symbolic input is detected, in which case printing takes place symbolically (i.e., in a delayed fashion).

  • limit (int) – The operation will abbreviate dynamic arrays with more than limit (default: 20) entries.

drjit.format(fmt: str, *args, limit: int = 20, **kwargs)
drjit.format(value: object, *, limit: int = 20, **kwargs) str

Return a formatted string representation.

This function generates a formatted string representation as specified by a format string fmt and then returns it as a Python str object. The operation fetches referenced positional and keyword arguments and pretty-prints Dr.Jit arrays, tensors, and PyTrees with indentation, field names, etc.

>>> from drjit.cuda import Array3f
>>> s = dr.format("{}:\n{foo}",
...               "A PyTree containing an array",
...               foo={ 'a' : Array3f(1, 2, 3) })
>>> print(s)
A PyTree containing an array:
  'a': [[1, 2, 3]]

Dynamic arrays with more than 20 entries will be abbreviated. Specify the limit=.. argument to reveal the contents of larger arrays.

>>> dr.format(dr.arange(dr.llvm.Int, 30))
[0, 1, 2, .. 24 skipped .., 27, 28, 29]

>>> dr.format(dr.arange(dr.llvm.Int, 30), limit=30)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,↵
 23, 24, 25, 26, 27, 28, 29]

This function lacks many features of Python’s (rather intricate) format string mini language and f-string interpolation. However, a subset of the functionality is supported:

  • Positional arguments (in *args) can be referenced implicitly ({}), or using indices ({0}, {1}, etc.). Those conventions should not be mixed. Unreferenced positional arguments will be silently ignored.

  • Keyword arguments (in **kwargs) can be referenced via their keyword name ({foo}). Unreferenced keywords will be silently ignored.

  • A trailing = in a brace expression repeats the string within the braces followed by the output:

    >>> dr.format('{foo=}', foo=1)

When the format string fmt is omitted, it is implicitly set to {}, and the function formats a single positional argument.
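
For comparison, the supported subset corresponds to the following built-in Python constructs. This sketch uses only str.format() and f-strings, not Dr.Jit, so the way Dr.Jit arrays are rendered is not shown here.

```python
# Implicit positional references: arguments are consumed left to right.
implicit = "{} + {}".format(1, 2)

# Indexed positional references: arguments can be reordered or reused.
indexed = "{1} + {0}".format(1, 2)

# Keyword references.
keyword = "{foo}".format(foo=1)

# Trailing '=': in built-in Python this form exists only in f-strings.
foo = 1
debug = f"{foo=}"

print(implicit, indexed, keyword, debug, sep=" | ")
```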

In contrast to the related drjit.print(), this function does not output the result on the console, and it cannot support symbolic inputs. This is because returning a string right away is incompatible with the requirement of evaluating/formatting symbolic inputs in a delayed fashion. If you wish to format symbolic arrays, you must call drjit.print() with a custom file object that implements the .write() function. Dr.Jit will call this function with the generated string when it is ready.

  • fmt (str) – A format string that potentially references input arguments from *args and **kwargs.

  • limit (int) – The operation will abbreviate dynamic arrays with more than limit (default: 20) entries.


The formatted string representation created as specified above.

Return type:


drjit.log_level() drjit.LogLevel
drjit.set_log_level(arg: drjit.LogLevel, /) None
drjit.set_log_level(arg: int, /) None


drjit.profile_mark(arg: str, /) None

Mark an event on the timeline of profiling tools.

Currently, this function uses NVTX to report events that can be captured using NVIDIA Nsight Systems. The operation is a no-op when no profile collection tool is attached.

Note that this event will be recorded on the CPU timeline.


Context manager to mark a region (e.g. a function call) on the timeline of profiling tools.

You can use this context manager to wrap parts of your code and track when and for how long it runs. Regions can be arbitrarily nested, which profiling tools visualize as a stack.

Note that this function is intended to track activity on the CPU timeline. If the wrapped region launches asynchronous GPU kernels, then those won’t generally be included in the length of the range unless drjit.sync_thread() or some other type of synchronization operation waits for their completion (which is generally not advisable, since keeping CPU and GPU asynchronous with respect to each other improves performance).

Currently, this function uses NVTX to report events that can be captured using NVIDIA Nsight Systems. The operation is a no-op when no profile collection tool is attached.

Low-level bits

drjit.set_backend(arg: Literal['cuda', 'llvm', 'scalar'], /)
drjit.set_backend(arg: drjit.JitBackend, /) None

Adjust the drjit.auto.* module so that it refers to types from the specified backend.

drjit.thread_count() int

Return the number of threads that Dr.Jit uses to parallelize computation on the CPU.

drjit.set_thread_count(arg: int, /) None

Adjust the number of threads that Dr.Jit uses to parallelize computation on the CPU.

The thread pool is primarily used by Dr.Jit’s LLVM backend. Other projects using the underlying nanothread thread pool library will also be affected by changes made through this function. It is legal to call it even while parallel computation is ongoing.

drjit.sync_thread() None

Wait for all currently running computation to finish.

This function synchronizes the device (e.g. the GPU) with the host (CPU) by waiting for the termination of all computation enqueued by the current host thread.

One potential use of this function is to measure the runtime of a kernel launched by Dr.Jit. We instead recommend drjit.kernel_history(), which exposes more accurate device timers.

In general, calling this function in user code is considered bad practice. Dr.Jit programs “run ahead” of the device to keep it fed with work. This is important for performance, and drjit.sync_thread() breaks this optimization.

All operations sent to a device (including reads) are strictly ordered, so there is generally no reason to wait for this queue to empty. If you find that drjit.sync_thread() is needed for your program to run correctly, then you have found a bug. Please report it on the project’s GitHub issue tracker.

drjit.flush_kernel_cache() None

Release all currently cached kernels.

When Dr.Jit evaluates a previously unseen computation, it compiles a kernel and then maps it into the memory of the CPU or GPU. The kernel stays resident so that it can be immediately reused when that same computation reoccurs at a later point.

In long development sessions (e.g., Jupyter notebook-based prototyping), this cache may eventually become unreasonably large, and calling flush_kernel_cache() to free it may be advisable.

Note that this does not free the disk cache that also exists to preserve compiled programs across sessions. To clear this cache as well, delete the directory $HOME/.drjit on Linux/macOS, and %AppData%\Local\Temp\drjit on Windows. (The AppData folder is typically found in C:\Users\<your username>).

drjit.flush_malloc_cache() None

Free the memory allocation cache maintained by Dr.Jit.

Allocating and releasing large chunks of memory tends to be relatively expensive, and Dr.Jit programs often need to do so at high rates.

Like most other array programming frameworks, Dr.Jit implements an internal cache to reduce such allocation-related costs. This cache starts out empty and grows on demand. Allocated memory is never released by default, which can be problematic when using multiple array programming frameworks within the same Python session, or when running multiple processes in parallel.

The drjit.flush_malloc_cache() function releases all currently unused memory back to the operating system. This is a relatively expensive step: you likely don’t want to use it within a performance-sensitive program region (e.g. an optimization loop).

drjit.expand_threshold() int

Query the threshold for performing scatter-reductions via expansion.

Getter for the quantity set in drjit.set_expand_threshold().

drjit.set_expand_threshold(arg: int, /) None

Set the threshold for performing scatter-reductions via expansion.

The documentation of drjit.ReduceOp explains the cost of atomic scatter-reductions and introduces various optimization strategies.

One particularly effective optimization named drjit.ReduceOp.Expand (see the section on optimizations for plots) is specific to the LLVM backend. It replicates the target array to avoid write conflicts altogether, which enables the use of non-atomic memory operations. This is significantly faster but also very memory-intensive. The storage cost of a 1 MB array targeted by a drjit.scatter_reduce() operation now grows to N megabytes, where N is the number of cores.

For this reason, Dr.Jit implements a user-controllable threshold exposed via the functions drjit.expand_threshold() and drjit.set_expand_threshold(). When the array has more entries than the value specified here, the drjit.ReduceOp.Expand strategy will not be used unless specifically requested via the mode= parameter of operations like drjit.scatter_reduce(), drjit.scatter_add(), and drjit.gather().

The default value of this parameter is 1000000 (1 million entries).
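
The memory trade-off can be made concrete with a short calculation. The sketch below is a simplified model written for illustration (expansion_plan is a hypothetical helper): it assumes that expansion is chosen automatically whenever the entry count does not exceed the threshold, and it ignores an explicit mode= override as well as any other criteria Dr.Jit may apply.

```python
DEFAULT_EXPAND_THRESHOLD = 1_000_000  # default stated in the text above


def expansion_plan(num_entries, bytes_per_entry, num_cores,
                   threshold=DEFAULT_EXPAND_THRESHOLD):
    """Return (use_expand, storage_bytes) for a scatter-reduction target.

    Simplified model: expansion is chosen automatically only when the
    array does not exceed the threshold, and it replicates the target
    once per core.
    """
    use_expand = num_entries <= threshold
    copies = num_cores if use_expand else 1
    return use_expand, num_entries * bytes_per_entry * copies


# A 1 MiB array of 32-bit values (262144 entries) on a 16-core machine.
small = expansion_plan(262_144, 4, 16)

# A 2-million-entry array exceeds the default threshold: no replication.
large = expansion_plan(2_000_000, 4, 16)
```

In this model the small array grows from 1 MiB to 16 MiB, while the large array keeps its original 8 MB footprint.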

drjit.kernel_history(types: collections.abc.Sequence[drjit.KernelType] = []) list

Return the history of captured kernel launches.

Dr.Jit can optionally capture performance-related metadata. To do so, set the drjit.JitFlag.KernelHistory flag as follows:

with dr.scoped_set_flag(dr.JitFlag.KernelHistory):
   # .. computation to be analyzed ..

hist = dr.kernel_history()

The drjit.kernel_history() function returns a list of dictionaries characterizing each major operation performed by the analyzed region. Each dictionary has the following entries:

  • backend: The used JIT backend.

  • execution_time: The time (in microseconds) used by this operation.

    On the CUDA backend, this value is captured via CUDA events. On the LLVM backend, this involves querying CLOCK_MONOTONIC (Linux/macOS) or QueryPerformanceCounter (Windows).

  • type: The type of computation expressed by an enumeration value of type drjit.KernelType. The most interesting workloads generated by Dr.Jit are just-in-time compiled kernels, which are identified by drjit.KernelType.JIT.

    These have several additional entries:

    • hash: The hash code identifying the kernel. (The same hash code is also shown when increasing the log level via drjit.set_log_level()).

    • ir: A capture of the intermediate representation used in this kernel.

    • operation_count: The number of low-level IR operations. (A rough proxy for the complexity of the operation.)

    • cache_hit: Was this kernel present in Dr.Jit’s in-memory cache? Otherwise, it was either loaded from disk or had to be recompiled from scratch.

    • cache_disk: Was this kernel present in Dr.Jit’s on-disk cache? Otherwise, it had to be recompiled from scratch.

    • codegen_time: The time (in microseconds) which Dr.Jit needed to generate the textual low-level IR representation of the kernel. This step is always needed even if the resulting kernel is already cached.

    • backend_time: The time (in microseconds) which the backend (either the LLVM compiler framework or the CUDA PTX just-in-time compiler) required to compile and link the low-level IR into machine code. This step is only needed when the kernel did not already exist in the in-memory or on-disk cache.

    • uses_optix: Was this kernel compiled by the NVIDIA OptiX ray tracing engine?

Note that drjit.kernel_history() clears the history while extracting this information. A related operation, drjit.kernel_history_clear(), only clears the history without returning any information.
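The entries described above lend themselves to simple aggregation. The following is a minimal sketch, not part of Dr.Jit: the helper name summarize_history is hypothetical, and the sample entries are hand-written stand-ins (with plain strings in place of the backend enumeration) rather than real output. It relies only on the documented keys and identifies JIT-compiled kernels by the presence of the codegen_time entry.

```python
from typing import Any

def summarize_history(hist: list[dict[str, Any]]) -> dict[str, float]:
    """Aggregate statistics from a list shaped like the result of dr.kernel_history()."""
    # Only JIT-compiled kernels carry the additional 'codegen_time' entry
    jit = [e for e in hist if "codegen_time" in e]
    return {
        "operations": len(hist),
        "jit_kernels": len(jit),
        "execution_time_us": sum(e["execution_time"] for e in hist),
        "codegen_time_us": sum(e["codegen_time"] for e in jit),
        "backend_time_us": sum(e.get("backend_time", 0.0) for e in jit),
        "memory_cache_hits": sum(1 for e in jit if e.get("cache_hit", False)),
        "disk_cache_hits": sum(1 for e in jit if e.get("cache_disk", False)),
    }

# Hypothetical entries standing in for real kernel history output
example = [
    {"backend": "llvm", "execution_time": 120.0, "codegen_time": 15.0,
     "backend_time": 900.0, "cache_hit": False, "cache_disk": False},
    {"backend": "llvm", "execution_time": 80.0, "codegen_time": 10.0,
     "backend_time": 0.0, "cache_hit": True, "cache_disk": False},
    {"backend": "llvm", "execution_time": 5.0},  # a non-JIT operation
]
summary = summarize_history(example)
```

Because drjit.kernel_history() clears the history, a summary like this should be computed from the list returned by a single call.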

drjit.kernel_history_clear() None

Clear the kernel history.

This operation clears the kernel history without returning any information about it. See drjit.kernel_history() for details.


Digital Differential Analyzer

The drjit.dda module provides a general implementation of a digital differential analyzer (DDA) that steps through the intersection of a ray segment and an N-dimensional grid, performing a custom computation at every cell.

The drjit.dda.integrate() function builds on this functionality to compute differentiable line integrals of bi- or trilinear interpolants stored on a grid.

drjit.dda.dda(ray_o: ArrayNfT, ray_d: ArrayNfT, ray_max: object, grid_res: ArrayNuT, grid_min: ArrayNfT, grid_max: ArrayNfT, func: Callable[[StateT, ArrayNuT, ArrayNfT, ArrayNfT, BoolT], Tuple[StateT, BoolT]], state: StateT, active: BoolT, mode: Literal['scalar', 'evaluated', 'symbolic', None] | None = None, max_iterations: int | None = None) StateT

N-dimensional digital differential analyzer (DDA).

This function traverses the intersection of a Cartesian coordinate grid and a specified ray or ray segment. The following snippet shows how to use it to enumerate the intersection of a grid with a single ray.

from drjit.scalar import Array3f, Array3i, Float, Bool
from drjit.dda import dda

def dda_fun(state: list, index: Array3i,
            pt_in: Array3f, pt_out: Array3f,
            active: Bool) -> tuple[list, bool]:
    # Entered a grid cell, stash a copy of its index in the 'state' variable
    state.append(Array3i(index))
    return state, Bool(True)

result = dda(
     ray_o = Array3f(-.1),
     ray_d = Array3f(.1, .2, .3),
     ray_max = Float(float('inf')),
     grid_res = Array3i(10),
     grid_min = Array3f(0),
     grid_max = Array3f(1),
     func = dda_fun,
     state = [],
     active = Bool(True)
)


Since all input elements are Dr.Jit arrays, everything works analogously when processing N rays and N potentially different grid configurations. The entire process can be captured symbolically.
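To make the traversal concrete, the following pure-Python sketch walks a single ray through a grid using the classic incremental stepping scheme. It is an illustration only and not Dr.Jit's implementation: the name dda_reference is hypothetical, it handles one ray with plain Python sequences, it assumes a nonzero ray direction, and it keeps every quantity in XYZ order for simplicity. The callback follows the five-argument contract of drjit.dda.dda().

```python
import math

def dda_reference(ray_o, ray_d, ray_max, grid_res, grid_min, grid_max, func, state):
    """Visit the grid cells pierced by one ray (illustrative sketch, XYZ order)."""
    n = len(ray_o)
    # Map the ray into index space, where cell i covers [i, i + 1] on each axis
    scale = [grid_res[k] / (grid_max[k] - grid_min[k]) for k in range(n)]
    o = [(ray_o[k] - grid_min[k]) * scale[k] for k in range(n)]
    d = [ray_d[k] * scale[k] for k in range(n)]

    # Slab test: clip the parameter range [0, ray_max] to the grid bounds
    t0, t1 = 0.0, ray_max
    for k in range(n):
        if d[k] == 0.0:
            if not 0.0 <= o[k] <= grid_res[k]:
                return state  # parallel to this slab and outside of it
            continue
        ta, tb = -o[k] / d[k], (grid_res[k] - o[k]) / d[k]
        t0, t1 = max(t0, min(ta, tb)), min(t1, max(ta, tb))
    if not t0 < t1:
        return state  # the ray misses the grid

    # Initial cell plus per-axis stepping data
    idx, step, t_next, t_delta = [], [], [], []
    for k in range(n):
        p = o[k] + t0 * d[k]
        i = min(max(math.floor(p), 0), grid_res[k] - 1)
        idx.append(i)
        if d[k] > 0.0:
            step.append(1)
            t_next.append((i + 1 - o[k]) / d[k])
            t_delta.append(1.0 / d[k])
        elif d[k] < 0.0:
            step.append(-1)
            t_next.append((i - o[k]) / d[k])
            t_delta.append(-1.0 / d[k])
        else:
            step.append(0)
            t_next.append(math.inf)
            t_delta.append(math.inf)

    t = t0
    while t < t1:
        # The axis whose cell boundary is crossed first determines the exit point
        axis = min(range(n), key=lambda k: t_next[k])
        t_exit = min(t_next[axis], t1)
        pt_in = [o[k] + t * d[k] - idx[k] for k in range(n)]
        pt_out = [o[k] + t_exit * d[k] - idx[k] for k in range(n)]
        state, keep_going = func(state, tuple(idx), pt_in, pt_out, True)
        if not keep_going:
            break  # the callback requested early termination
        idx[axis] += step[axis]
        if not 0 <= idx[axis] < grid_res[axis]:
            break  # stepped out of the grid
        t = t_exit
        t_next[axis] += t_delta[axis]
    return state

def collect(state, index, pt_in, pt_out, active):
    state.append(index)
    return state, True

# A horizontal ray crossing the bottom row of a 4x4 grid on the unit square
cells = dda_reference((-0.5, 0.125), (1.0, 0.0), math.inf,
                      (4, 4), (0.0, 0.0), (1.0, 1.0), collect, [])
```

The sketch stops as soon as the callback returns False, the ray leaves the grid, or the parameter reaches ray_max, mirroring the termination conditions described for the library function.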

The function takes the following arguments. Note that many of them are generic type variables (signaled by ending with a capital T). To support different dimensions and precisions, the implementation must be able to deal with various input types, which is communicated by these type variables.

  • ray_o (ArrayNfT) – the ray origin, where the ArrayNfT type variable refers to an n-dimensional scalar or Jit-compiled floating point array.

  • ray_d (ArrayNfT) – the ray direction. Does not need to be normalized.

  • ray_max (object) – the maximum extent along the ray, which is permitted to be infinite. The value is specified as a multiple of the norm of ray_d, which is not necessarily unit-length. Must be of type dr.value_t(ArrayNfT).

  • grid_res (ArrayNuT) – the grid resolution, where the ArrayNuT type variable refers to a matched 32-bit integer array (i.e., ArrayNuT = dr.int32_array_t(ArrayNfT)).

  • grid_min (ArrayNfT) – the minimum position of the grid bounds.

  • grid_max (ArrayNfT) – the maximum position of the grid bounds.

  • func (Callable[[StateT, ArrayNuT, ArrayNfT, ArrayNfT, BoolT], tuple[StateT, BoolT]]) –

    a callback that will be invoked when the DDA traverses a grid cell. It must take the following five positional arguments:

    1. arg0: StateT: An arbitrary state value.

    2. arg1: ArrayNuT: An integer array specifying the cell index along each dimension.

    3. arg2: ArrayNfT: The fractional position (\(\in [0, 1]^n\)) where the ray enters the current cell.

    4. arg3: ArrayNfT: The fractional position (\(\in [0, 1]^n\)) where the ray leaves the current cell.

    5. arg4: BoolT: A boolean array specifying which elements are active.

    The callback should then return a tuple of type tuple[StateT, BoolT] containing

    1. An updated state value.

    2. A boolean array that can be used to exit the loop prematurely for some or all rays. The iteration stops if the associated entry of the return value equals False.

  • state (StateT) – an arbitrary initial state that will be passed to the callback.

  • active (BoolT) – an array specifying which elements of the input are active, where the BoolT type variable refers to a matched boolean array (i.e., BoolT = dr.mask_t(ray_o.x)).

  • mode (str | None) – the operation can run in scalar, symbolic, or evaluated mode. See the mode argument in the documentation of drjit.while_loop() for details.

  • max_iterations (int | None) – bound on the iteration count that is needed for reverse-mode differentiation. Forwarded to the max_iterations parameter of drjit.while_loop().
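Since ray_max is a multiple of the norm of ray_d rather than a distance, a short worked example may help. The sketch below uses plain Python lists and two hypothetical helper names (segment_endpoint, segment_length); it shows that the segment ends at ray_o + ray_max * ray_d and that its length is ray_max times the norm of ray_d.

```python
import math

def segment_endpoint(ray_o, ray_d, ray_max):
    """Point at which the segment ends: ray_o + ray_max * ray_d."""
    return [o + ray_max * d for o, d in zip(ray_o, ray_d)]

def segment_length(ray_d, ray_max):
    """Euclidean length of the segment: ray_max * ||ray_d||."""
    return ray_max * math.sqrt(sum(d * d for d in ray_d))

# ray_d = (3, 4) has norm 5, so ray_max = 2 describes a segment of length 10
end = segment_endpoint([0.0, 0.0], [3.0, 4.0], 2.0)
length = segment_length([3.0, 4.0], 2.0)
```

Normalizing ray_d beforehand makes ray_max coincide with the distance along the ray.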


The function returns the final state value of the callback upon termination.

Return type: StateT



Just like the Dr.Jit texture interface, the implementation uses the convention that voxel sizes and positions are specified from last to first component (e.g. (Z, Y, X)), while regular 3D positions use the opposite (X, Y, Z) order.

In particular, all ArrayNuT-typed parameters of the function and the callback use the ZYX convention, while ArrayNfT-typed parameters use the XYZ convention.
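The two orderings are easy to mix up, so here is a small sketch of the conversion in plain Python. The helper name cell_bounds_xyz is hypothetical, and the code assumes that the cells partition the bounds uniformly along each axis. It reverses a ZYX cell index and resolution to obtain the XYZ extents of that cell.

```python
def cell_bounds_xyz(index_zyx, grid_res_zyx, grid_min_xyz, grid_max_xyz):
    """Return the (lower, upper) XYZ corners of the cell with the given ZYX index."""
    # Integer quantities arrive in ZYX order; reverse them to match XYZ positions
    index_xyz = list(reversed(index_zyx))
    res_xyz = list(reversed(grid_res_zyx))
    lower, upper = [], []
    for i, r, a, b in zip(index_xyz, res_xyz, grid_min_xyz, grid_max_xyz):
        size = (b - a) / r  # extent of a single cell along this axis
        lower.append(a + i * size)
        upper.append(a + (i + 1) * size)
    return lower, upper

# A grid with 8 cells along X, 4 along Y, and 2 along Z has ZYX resolution (2, 4, 8)
lower, upper = cell_bounds_xyz((1, 0, 3), (2, 4, 8),
                               (0.0, 0.0, 0.0), (1.0, 1.0, 1.0))
```

In this example, the ZYX index (1, 0, 3) refers to the fourth cell along X, the first along Y, and the second along Z.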

drjit.dda.integrate(ray_o: ArrayNfT, ray_d: ArrayNfT, ray_max: FloatT, grid_min: ArrayNfT, grid_max: ArrayNfT, vol: ArrayBase[Any, Any, Any, Any, Any, Any, Any], active: object | None = None, mode: Literal['scalar', 'evaluated', 'symbolic', None] | None = None) FloatT

Compute an analytic definite integral of a bi- or trilinear interpolant.

This function uses DDA (drjit.dda.dda()) to step along the voxels of a 2D/3D volume traversed by a finite segment or an infinite-length ray. It analytically computes and accumulates the definite integral of the interpolant in each voxel.

The input 2D/3D volume is provided using a tensor vol (e.g., of type drjit.cuda.ad.TensorXf) with an implicitly specified grid resolution vol.shape. This data volume is placed into an axis-aligned region with bounds (grid_min, grid_max).

The operation provides efficient forward and backward derivatives.


Just like the Dr.Jit texture interface, the implementation uses the convention that voxel sizes and positions are specified from last to first component (e.g. (Z, Y, X)), while regular 3D positions use the opposite (X, Y, Z) order.

In particular, vol.shape uses the ZYX convention, while the ArrayNfT-typed parameters use the XYZ convention.

One important difference to the texture classes is that the interpolant is sampled at integer grid positions, whereas the Dr.Jit texture classes place values at cell centers, i.e. with a .5 fractional offset.
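The per-voxel integral can be illustrated with a small 2D sketch in plain Python. It is not necessarily how Dr.Jit evaluates the integral, and the helper names bilerp and cell_integral are hypothetical. The sketch relies on a general fact: restricted to a straight segment, a bilinear interpolant is a quadratic polynomial in the segment parameter (a trilinear one is cubic), so Simpson's rule integrates it exactly. The four values sit at the cell corners, matching the integer-grid sampling described above, and positions are expressed in cell-local coordinates.

```python
import math

def bilerp(v, x, y):
    """Bilinear interpolant on the unit cell; v[j][i] is the value at corner (x=i, y=j)."""
    return ((1 - x) * (1 - y) * v[0][0] + x * (1 - y) * v[0][1]
            + (1 - x) * y * v[1][0] + x * y * v[1][1])

def cell_integral(v, p_in, p_out):
    """Exact line integral of the interpolant along the segment from p_in to p_out."""
    p_mid = [(a + b) / 2 for a, b in zip(p_in, p_out)]
    f_in, f_mid, f_out = (bilerp(v, *p) for p in (p_in, p_mid, p_out))
    # Simpson's rule is exact for polynomials of degree three or lower
    return math.dist(p_in, p_out) * (f_in + 4 * f_mid + f_out) / 6

# Corner values chosen so that the interpolant equals x * y
corner_values = [[0.0, 0.0],
                 [0.0, 1.0]]
diagonal = cell_integral(corner_values, (0.0, 0.0), (1.0, 1.0))
```

Along the diagonal, the interpolant x * y reduces to t² for t in [0, 1], so the result equals √2 / 3. Accumulating such contributions over every voxel visited by the DDA yields the integral along the full segment.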