API Reference (Main)
This document explains the public API (behaviors, signatures) in exhaustive detail. If you’re new to Dr.Jit, it may be easier to start by reading the other sections first.
The reference documentation is also exposed through docstrings, which many visual editors (e.g., VS Code, neovim with LSP) will show during code completion, or when hovering over an expression.
The reference makes extensive use of type variables, which can be
recognized because their name equals or ends with a capital T (e.g., T,
ArrayT, MaskT, etc.). Type variables serve as placeholders that show
how types propagate through function calls. For example, a function with
signature
def f(arg: T, /) -> tuple[T, T]: ...
will return a pair of int instances when called with an int-typed
arg value.
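To make the notation concrete, the sketch below declares a function with exactly that signature using the standard typing module. The function body is an assumption made for illustration only; the reference specifies just the signature.

```python
from typing import TypeVar

T = TypeVar("T")

def f(arg: T, /) -> tuple[T, T]:
    # Hypothetical body: return the positional-only argument twice
    return (arg, arg)

pair = f(3)      # T is bound to int, so the result is a pair of ints
words = f("a")   # T is bound to str, so the result is a pair of strs
```

A static type checker infers tuple[int, int] for the first call and tuple[str, str] for the second, which is the kind of type propagation the signatures in this reference describe.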
Array creation
- drjit.zeros(dtype: type, shape: int = 1) -> object
- drjit.zeros(dtype: type, shape: collections.abc.Sequence[int]) -> object
Overloaded function.
zeros(dtype: type, shape: int = 1) -> object
Return a zero-initialized instance of the desired type and shape.
This function can create zero-initialized instances of various types. In particular, dtype can be:
- A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with the static dimensions of the dtype. For example, dr.zeros(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
- A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
- A PyTree. In this case, drjit.zeros() will invoke itself recursively to zero-initialize each field of the data structure.
- A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.
Note that when dtype refers to a scalar mask or a mask array, it will be initialized to False as opposed to zero.
The function returns a literal constant array that consumes no device memory.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
shape (Sequence[int] | int) – Shape of the desired array
- Returns:
A zero-initialized instance of type dtype.
- Return type:
object
zeros(dtype: type, shape: collections.abc.Sequence[int]) -> object
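The recursive treatment of PyTrees and scalar Python types described above can be modeled in plain Python. The following is only a sketch of the idea and not Dr.Jit's implementation: zeros_model, Hit, and Record are hypothetical names, and dataclasses stand in for PyTrees.

```python
from dataclasses import dataclass, fields, is_dataclass

def zeros_model(dtype: type):
    """Plain-Python model of recursive zero-initialization (not Dr.Jit code)."""
    if is_dataclass(dtype):
        # Structured type: recurse into the declared type of every field
        return dtype(**{f.name: zeros_model(f.type) for f in fields(dtype)})
    if dtype in (int, float, bool):
        # Scalar Python type: bool(0) yields False rather than a numeric zero
        return dtype(0)
    raise TypeError(f"unsupported type: {dtype!r}")

@dataclass
class Hit:
    t: float
    valid: bool

@dataclass
class Record:
    index: int
    hit: Hit

record = zeros_model(Record)
```

Note how the boolean field ends up as False, mirroring the remark about mask types above.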
- drjit.empty(dtype: type, shape: int = 1) -> object
- drjit.empty(dtype: type, shape: collections.abc.Sequence[int]) -> object
Overloaded function.
empty(dtype: type, shape: int = 1) -> object
Return an uninitialized Dr.Jit array of the desired type and shape.
This function can create uninitialized buffers of various types. It should only be used in combination with a subsequent call to an operation like drjit.scatter() that fills the array contents with valid data.
The dtype parameter can be used to request:
- A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with the static dimensions of the dtype. For example, dr.empty(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
- A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
- A PyTree. In this case, drjit.empty() will invoke itself recursively to allocate memory for each field of the data structure.
- A scalar Python type like int, float, or bool. The shape parameter is ignored in this case, and the function returns a zero-initialized result (there is little point in instantiating uninitialized versions of scalar Python types).
drjit.empty() delays allocation of the underlying buffer until an operation tries to read/write the actual array contents.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
shape (Sequence[int] | int) – Shape of the desired array
- Returns:
An instance of type dtype with arbitrary/undefined contents.
- Return type:
object
empty(dtype: type, shape: collections.abc.Sequence[int]) -> object
- drjit.ones(dtype: type, shape: int = 1) -> object
- drjit.ones(dtype: type, shape: collections.abc.Sequence[int]) -> object
Overloaded function.
ones(dtype: type, shape: int = 1) -> object
Return an instance of the desired type and shape filled with ones.
This function can create one-initialized instances of various types. In particular, dtype can be:
- A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with the static dimensions of the dtype. For example, dr.ones(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
- A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
- A PyTree. In this case, drjit.ones() will invoke itself recursively to initialize each field of the data structure.
- A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.
Note that when dtype refers to a scalar mask or a mask array, it will be initialized to True as opposed to one.
The function returns a literal constant array that consumes no device memory.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
shape (Sequence[int] | int) – Shape of the desired array
- Returns:
An instance of type dtype filled with ones.
- Return type:
object
ones(dtype: type, shape: collections.abc.Sequence[int]) -> object
- drjit.full(dtype: type, value: object, shape: int = 1) -> object
- drjit.full(dtype: type, value: object, shape: collections.abc.Sequence[int]) -> object
Overloaded function.
full(dtype: type, value: object, shape: int = 1) -> object
Return a constant-valued instance of the desired type and shape.
This function can create constant-valued instances of various types. In particular, dtype can be:
- A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with the static dimensions of the dtype. For example, dr.full(dr.cuda.Array2f, value=1.0, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
- A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
- A PyTree. In this case, drjit.full() will invoke itself recursively to initialize each field of the data structure.
- A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.
The function returns a literal constant array that consumes no device memory.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
value (object) – An instance of the underlying scalar type (float/int/bool, etc.) that will be used to initialize the array contents.
shape (Sequence[int] | int) – Shape of the desired array
- Returns:
An instance of type dtype filled with value.
- Return type:
object
full(dtype: type, value: object, shape: collections.abc.Sequence[int]) -> object
- drjit.opaque(dtype: type, value: object, shape: int = 1) -> object
- drjit.opaque(dtype: type, value: object, shape: collections.abc.Sequence[int]) -> object
Overloaded function.
opaque(dtype: type, value: object, shape: int = 1) -> object
Return an opaque constant-valued instance of the desired type and shape.
This function is very similar to drjit.full() in that it creates constant-valued instances of various types including (potentially nested) Dr.Jit arrays, tensors, and PyTrees. Please refer to the documentation of drjit.full() for details on the function signature. However, drjit.full() creates literal constant arrays, which means that Dr.Jit is fully aware of the array contents.
In contrast, drjit.opaque() produces an opaque array backed by a representation in device memory.
Why is this useful?
Consider the following snippet, where a complex calculation is parameterized by the constant 1.
    from drjit.llvm import Float

    result = complex_function(Float(1), ...) # Float(1) is equivalent to dr.full(Float, 1)
    print(result)
The print() statement will cause Dr.Jit to evaluate the queued computation, which likely also requires compilation of a new kernel (if that exact pattern of steps hasn’t been observed before). Kernel compilation is costly and may be much slower than the actual computation that needs to be done.
Suppose we later wish to evaluate the function with a different parameter:
    result = complex_function(Float(2), ...)
    print(result)
The constant 2 is essentially copy-pasted into the generated program, causing a mismatch with the previously compiled kernel, which therefore cannot be reused. This unfortunately means that we must once more wait a few tens or even hundreds of milliseconds until a new kernel has been compiled and uploaded to the device.
This motivates the existence of drjit.opaque(). By making a variable opaque to Dr.Jit’s tracing mechanism, we can keep constants out of the generated program and improve the effectiveness of the kernel cache:
    # The following lines reuse the compiled kernel regardless of the constant
    value = dr.opaque(Float, 2)
    result = complex_function(value, ...)
    print(result)
This function is related to drjit.make_opaque(), which can turn an already existing Dr.Jit array, tensor, or PyTree into an opaque representation.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
value (object) – An instance of the underlying scalar type (float/int/bool, etc.) that will be used to initialize the array contents.
shape (Sequence[int] | int) – Shape of the desired array
- Returns:
An instance of type dtype filled with value.
- Return type:
object
opaque(dtype: type, value: object, shape: collections.abc.Sequence[int]) -> object
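The effect of literal versus opaque constants on kernel caching can be illustrated with a toy model in plain Python. This is a sketch under strong simplifying assumptions and not how Dr.Jit works internally: kernel_cache and launch are hypothetical names, a "kernel" is just a Python expression string, and the tuple p plays the role of device memory holding opaque values.

```python
kernel_cache = {}  # maps generated program text to its "compiled" form

def launch(source, params=()):
    """Toy model: compile 'source' unless cached, then run it.

    Returns the computed value and whether the cache was hit.
    """
    hit = source in kernel_cache
    if not hit:
        kernel_cache[source] = compile(source, "<kernel>", "eval")
    value = eval(kernel_cache[source], {"p": params})
    return value, hit

# Literal constants are pasted into the program text, so each new
# constant yields a different program and therefore a cache miss.
r1, hit1 = launch("1.0 * 2.0 + 1.0")
r2, hit2 = launch("2.0 * 2.0 + 1.0")

# Opaque constants are read from memory ('p'), so the program text
# stays identical and the second launch reuses the cached kernel.
r3, hit3 = launch("p[0] * 2.0 + 1.0", (1.0,))
r4, hit4 = launch("p[0] * 2.0 + 1.0", (2.0,))
```

Both variants compute the same values; only the second one benefits from the cache when the constant changes.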
- drjit.arange(dtype: type[T], size: int) -> T
- drjit.arange(dtype: type[T], start: int, stop: int, step: int = 1) -> T
Overloaded function.
arange(dtype: type[T], size: int) -> T
This function generates an integer sequence on the interval [start, stop) with step size step, where start = 0 and step = 1 if not specified.
- Parameters:
dtype (type) – Desired Dr.Jit array type. The dtype must refer to a dynamically sized 1D Dr.Jit array such as drjit.scalar.ArrayXu or drjit.cuda.Float.
start (int) – Start of the interval. The default value is 0.
stop/size (int) – End of the interval (not included). The name of this parameter differs between the two provided overloads.
step (int) – Spacing between values. The default value is 1.
- Returns:
The computed sequence of type dtype.
- Return type:
object
arange(dtype: type[T], start: int, stop: int, step: int = 1) -> T
- drjit.linspace(dtype: type[T], start: float, stop: float, num: int, endpoint: bool = True) -> T
This function generates an evenly spaced floating point sequence of size num covering the interval [start, stop].
When half-precision output is requested, the function first computes an intermediate result in 32-bit precision and then casts it to 16 bit to limit the effect of rounding errors.
- Parameters:
dtype (type) – Desired Dr.Jit array type. The dtype must refer to a dynamically sized 1D Dr.Jit floating point array, such as drjit.scalar.ArrayXf or drjit.cuda.Float.
start (float) – Start of the interval.
stop (float) – End of the interval.
num (int) – Number of samples to generate.
endpoint (bool) – Should the interval endpoint be included? The default is True.
- Returns:
The computed sequence of type dtype.
- Return type:
object
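The value sequences produced by drjit.arange() and drjit.linspace() can be modeled with plain Python lists. The helpers below are hypothetical and omit the dtype argument. The spacing used for endpoint=False follows the NumPy convention of dividing by num, which is an assumption here rather than something stated in this reference.

```python
def arange_model(start, stop=None, step=1):
    """Model of both arange overloads: (size) and (start, stop, step)."""
    if stop is None:
        start, stop = 0, start  # single-argument form: [0, size)
    return list(range(start, stop, step))

def linspace_model(start, stop, num, endpoint=True):
    """Model of linspace: 'num' evenly spaced samples starting at 'start'."""
    if num <= 0:
        return []
    # Assumed convention: with endpoint=False the interval is divided into
    # 'num' steps so that 'stop' itself is not generated.
    divisor = (num - 1) if (endpoint and num > 1) else num
    step = (stop - start) / divisor
    return [start + i * step for i in range(num)]

a = arange_model(5)
b = arange_model(2, 10, 3)
c = linspace_model(0.0, 1.0, 5)
d = linspace_model(0.0, 1.0, 4, endpoint=False)
```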
- drjit.zeros_like(arg: T, /) -> T
Return an array of zeros with the same shape and type as a given array.
- drjit.empty_like(arg: T, /) -> T
Return an empty array with the same shape and type as a given array.
- drjit.ones_like(arg: T, /) -> T
Return an array of ones with the same shape and type as a given array.
Control flow
- drjit.syntax(f: None = None, *, recursive: bool = False, print_ast: bool = False, print_code: bool = False) -> Callable[[T], T]
- drjit.syntax(f: T, *, recursive: bool = False, print_ast: bool = False, print_code: bool = False) -> T
- drjit.hint(arg: T, /, *, mode: Literal['scalar', 'evaluated', 'symbolic', None] = None, max_iterations: int | None = None, label: str | None = None, include: List[object] | None = None, exclude: List[object] | None = None, strict: bool = True) -> T
Within ordinary Python code, this function is unremarkable: it returns the positional-only argument arg while ignoring any specified keyword arguments.
The main purpose of drjit.hint() is to provide hints that influence the transformation performed by the @drjit.syntax decorator. The following kinds of hints are supported:
- mode overrides the compilation mode of a while loop or if statement. The following choices are available:
  - mode='scalar' disables code transformations, which is permitted when the predicate of a loop or if statement is a scalar Python bool.
        i: int = 0
        while dr.hint(i < 10, mode='scalar'):
            # ...
    Routing such code through drjit.while_loop() or drjit.if_stmt() still works but may add small overheads, which motivates the existence of this flag. Note that this annotation does not cause mode='scalar' to be passed to drjit.while_loop() and drjit.if_stmt() (which happens to be a valid input of both). Instead, it disables the code transformation altogether so that the above example translates into ordinary Python code:
        i: int = 0
        while i < 10:
            # ...
  - mode='evaluated' forces execution in evaluated mode and causes the code transformation to forward this argument to the relevant drjit.while_loop() or drjit.if_stmt() call.
    Refer to the discussion of drjit.while_loop(), drjit.JitFlag.SymbolicLoops, drjit.if_stmt(), and drjit.JitFlag.SymbolicConditionals for details.
  - mode='symbolic' forces execution in symbolic mode and causes the code transformation to forward this argument to the relevant drjit.while_loop() or drjit.if_stmt() call.
    Refer to the discussion of drjit.while_loop(), drjit.JitFlag.SymbolicLoops, drjit.if_stmt(), and drjit.JitFlag.SymbolicConditionals for details.
- The optional strict=False reduces the strictness of variable consistency checks.
  Consider the following snippet:
      from drjit.llvm import UInt32

      @dr.syntax
      def f(x: UInt32):
          if x < 4:
              y = 3
          else:
              y = 5
          return y
  This code will raise an exception.
      >> f(UInt32(1))
      RuntimeError: drjit.if_stmt(): the non-array state variable 'y' of type 'int' changed from '5' to '10'. Please review the interface and assumptions of 'drjit.while_loop()' as explained in the documentation (https://drjit.readthedocs.io/en/latest/reference.html#drjit.while_loop).
  This is because the computed variable y of type int has an inconsistent value depending on the taken branch. Furthermore, y is a scalar Python type that isn’t tracked by Dr.Jit. The fix here is to initialize y with UInt32(<integer value>).
  However, there may also be legitimate situations where such an inconsistency is needed by the implementation. This can be fine as y is not used below the if statement. In this case, you can annotate the conditional or loop with dr.hint(..., strict=False), which disables the check.
- max_iterations specifies a maximum number of loop iterations for reverse-mode automatic differentiation.
  Naive reverse-mode differentiation of loops (unless replaced by a smarter problem-specific strategy via drjit.custom and drjit.CustomOp) requires allocation of large buffers that hold loop state for all iterations.
  Dr.Jit requires an upper bound on the maximum number of loop iterations so that it can allocate such buffers, which can be provided via this hint. Otherwise, reverse-mode differentiation of loops will fail with an error message.
- label provides a descriptive label.
  Dr.Jit will include this label as a comment in the generated intermediate representation, which can be helpful when debugging the compilation of large programs.
- include and exclude indicate to the @drjit.syntax decorator that a local variable should or should not be considered to be part of the set of state variables passed to drjit.while_loop() or drjit.if_stmt().
  While transforming a function, the @drjit.syntax decorator sequentially steps through a program to identify the set of read and written variables. It then forwards referenced variables to recursive drjit.while_loop() and drjit.if_stmt() calls. In rare cases, it may be useful to manually include or exclude a local variable from this process. To do so, specify a list of such variables to the drjit.hint() annotation.
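The statement that drjit.hint() is unremarkable within ordinary Python code can be illustrated with a stand-in that has the same calling convention. The function below is a hypothetical re-implementation for illustration, not the library's code: it returns its positional-only argument and ignores all keyword arguments.

```python
def hint(arg, /, **kwargs):
    # Hypothetical stand-in for drjit.hint() outside of @drjit.syntax:
    # the keyword arguments (mode, label, ...) have no effect here.
    return arg

i = 0
iterations = 0
while hint(i < 10, mode='scalar', label='demo'):
    i += 1
    iterations += 1
```

The loop behaves exactly like `while i < 10:`, which is also what the decorator produces for mode='scalar'.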
- drjit.while_loop(state: tuple[*Ts], cond: typing.Callable[[*Ts], AnyArray | bool], body: typing.Callable[[*Ts], tuple[*Ts]], labels: typing.Sequence[str] = (), label: str | None = None, mode: typing.Literal['scalar', 'symbolic', 'evaluated', None] = None, strict: bool = True, compress: bool | None = None, max_iterations: int | None = None) -> tuple[*Ts]
Repeatedly execute a function while a loop condition holds.
Motivation
This function provides a vectorized generalization of a standard Python while loop. For example, consider the following Python snippet:
    i: int = 1
    while i < 10:
        x *= x
        i += 1
This code would fail when i is replaced by an array with multiple entries (e.g., of type drjit.llvm.Int). In that case, the loop condition evaluates to a boolean array of per-component comparisons that are not necessarily consistent with each other. In other words, each entry of the array may need to run the loop for a different number of iterations. A standard Python while loop is not able to do so.
The drjit.while_loop() function realizes such a fine-grained looping mechanism. It takes three main input arguments:
- state, a tuple of state variables that are modified by the loop iteration,
  Dr.Jit optimizes away superfluous state variables, so there isn’t any harm in specifying variables that aren’t actually modified by the loop.
- cond, a function that takes the state variables as input and uses them to evaluate and return the loop condition in the form of a boolean array,
- body, a function that also takes the state variables as input and runs one loop iteration. It must return an updated set of state variables.
The function calls cond and body to execute the loop. It then returns a tuple containing the final version of the state variables. With this functionality, a vectorized version of the above loop can be written as follows:
    i, x = dr.while_loop(
        state=(i, x),
        cond=lambda i, x: i < 10,
        body=lambda i, x: (i+1, x*x)
    )
Lambda functions are convenient when the condition and body are simple enough to fit onto a single line. In general you may prefer to define local functions (def loop_cond(i, x): ...) and pass them to the cond and body arguments.
Dr.Jit also provides the @drjit.syntax decorator, which automatically rewrites standard Python control flow constructs into the form shown above. It combines vectorization with the readability of natural Python syntax and is the recommended way of (indirectly) using drjit.while_loop(). With this decorator, the above example would be written as follows:
    @dr.syntax
    def f(i, x):
        while i < 10:
            x *= x
            i += 1
        return i, x
Evaluation modes
Dr.Jit uses one of three different modes to compile this operation depending on the inputs and active compilation flags (the text below this overview will explain how this mode is automatically selected).
- Scalar mode: Scalar loops that don’t need any vectorization can fall back to a simple Python loop construct.
      while cond(*state):
          state = body(*state)
  This is the default strategy when cond(*state) returns a scalar Python bool.
  The loop body may still use Dr.Jit types, but note that this effectively unrolls the loop, generating a potentially long sequence of instructions that may take a long time to compile. Symbolic mode (discussed next) may be advantageous in such cases.
- Symbolic mode: Here, Dr.Jit runs a single loop iteration to capture its effect on the state variables. It embeds this captured computation into the generated machine code. The loop will eventually run on the device (e.g., the GPU) but unlike a Python while statement, the loop does not run on the host CPU (besides the mentioned tentative evaluation for symbolic tracing).
  When loop optimizations are enabled (drjit.JitFlag.OptimizeLoops), Dr.Jit may re-trace the loop body so that it runs twice in total. This happens transparently and has no influence on the semantics of this operation.
- Evaluated mode: in this mode, Dr.Jit will repeatedly evaluate the loop state variables and update active elements using the loop body function until all of them are done. Conceptually, this is equivalent to the following Python code:
      active = True
      while True:
          dr.eval(state)
          active &= cond(*state)
          if not dr.any(active):
              break
          state = dr.select(active, body(*state), state)
  (In practice, the implementation does a few additional things like suppressing side effects associated with inactive entries.)
  Dr.Jit will typically compile a kernel when it runs the first loop iteration. Subsequent iterations can then reuse this cached kernel since they perform the same exact sequence of operations. Kernel caching tends to be crucial to achieve good performance, and it is good to be aware of pitfalls that can effectively disable it.
  For example, when you update a scalar (e.g. a Python int) in each loop iteration, this changing counter might be merged into the generated program, forcing the system to re-generate and re-compile code at every iteration, and this can ultimately dominate the execution time. If in doubt, increase the log level of Dr.Jit (drjit.set_log_level() to drjit.LogLevel.Info) and check if the kernels being launched contain the term cache miss. You can also inspect the Kernels launched line in the output of drjit.whos(). If you observe soft or hard misses at every loop iteration, then kernel caching isn’t working and you should carefully inspect your code to ensure that the computation stays consistent across iterations.
  When the loop processes many elements, and when each element requires a different number of loop iterations, there is a question of what should be done with inactive elements. The default implementation keeps them around and does redundant calculations that are, however, masked out. Consequently, later loop iterations don’t run faster despite fewer elements being active.
  Alternatively, you may specify the parameter compress=True or set the flag drjit.JitFlag.CompressLoops, which causes the removal of inactive elements after every iteration. This reorganization is not for free and does not benefit all use cases, which is why it isn’t enabled by default.
A separate section about symbolic and evaluated modes discusses these two options in further detail.
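The masked update performed by evaluated mode can be modeled with plain Python lists, where each list position plays the role of one array entry. This is a sketch of the conceptual loop shown above and not Dr.Jit's implementation; evaluated_while is a hypothetical name, and list comprehensions stand in for dr.select() and the array arithmetic.

```python
def evaluated_while(state, cond, body):
    """Plain-Python model of an evaluated loop over per-lane lists."""
    active = [True] * len(state[0])
    while True:
        # A lane stays active only while its own condition holds
        active = [a and c for a, c in zip(active, cond(*state))]
        if not any(active):
            break
        updated = body(*state)  # computed for all lanes, including inactive ones
        # Masked update: inactive lanes keep their previous value
        state = tuple(
            [new if a else old for a, new, old in zip(active, new_col, old_col)]
            for new_col, old_col in zip(updated, state)
        )
    return state

i, x = evaluated_while(
    state=([7, 9, 12], [1, 1, 1]),
    cond=lambda i, x: [v < 10 for v in i],
    body=lambda i, x: ([v + 1 for v in i], [v * 2 for v in x]),
)
```

The three lanes run for three, one, and zero iterations respectively, yet the body is evaluated for every lane in each pass. This is the redundant, masked-out work that the text above attributes to the default strategy.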
The drjit.while_loop() function chooses the evaluation mode as follows:
- When the mode argument is set to None (the default), the function examines the loop condition. It uses scalar mode when this produces a Python bool, otherwise it inspects the drjit.JitFlag.SymbolicLoops flag to switch between symbolic (the default) and evaluated mode.
  To change this automatic choice for a region of code, you may specify the mode= keyword argument, nest code into a drjit.scoped_set_flag() block, or change the behavior globally via drjit.set_flag():
      with dr.scoped_set_flag(dr.JitFlag.SymbolicLoops, False):
          # .. nested code will use evaluated loops ..
- When mode is set to "scalar", "symbolic", or "evaluated", it directly uses that method without inspecting the compilation flags or loop condition type.
When using the @drjit.syntax decorator to automatically convert Python while loops into drjit.while_loop() calls, you can also use the drjit.hint() function to pass keyword arguments including mode, label, or max_iterations to the generated looping construct:
    while dr.hint(i < 10, label='My loop', mode='evaluated'):
        # ...
Assumptions
The loop condition function must be pure (i.e., it should never modify the state variables or any other kind of program state). The loop body should not write to variables besides the officially declared state variables:
    y = ..
    def loop_body(x):
        y[0] += x # <-- don't do this. 'y' is not a loop state variable

    dr.while_loop(body=loop_body, ...)
Dr.Jit automatically tracks dependencies of indirect reads (done via drjit.gather()) and indirect writes (done via drjit.scatter(), drjit.scatter_reduce(), drjit.scatter_add(), drjit.scatter_inc(), etc.). Such operations create implicit inputs and outputs of a loop, and these do not need to be specified as loop state variables (however, doing so causes no harm). This auto-discovery mechanism is helpful when performing vectorized method calls (within loops), where the set of implicit inputs and outputs can often be difficult to know a priori (in principle, any public/private field in any instance could be accessed in this way).
    y = ..
    def loop_body(x):
        # Scattering to 'y' is okay even if it is not declared as loop state
        dr.scatter(target=y, value=x, index=0)
Another important assumption is that the loop state remains consistent across iterations, which means:
1. The type of state variables is not allowed to change. You may not declare a Python float before a loop and then overwrite it with a drjit.cuda.Float (or vice versa).
2. Their structure/size must be consistent. The loop body may not turn a variable with 3 entries into one that has 5.
3. Analogously, state variables must always be initialized prior to the loop. This is the case even if you know that the loop body is guaranteed to overwrite the variable with a well-defined result. An initial value of None would violate condition 1 (type invariance), while an empty array would violate condition 2 (shape compatibility).
The implementation will check for violations and, if applicable, raise an exception identifying problematic state variables.
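The consistency conditions above can be pictured with a small checker written in plain Python. This is only a sketch of the kind of validation described, not the check Dr.Jit performs: check_state is a hypothetical helper, and Python lists stand in for arrays.

```python
def check_state(before, after, labels):
    """Toy consistency check between the loop state before and after an iteration."""
    for label, old, new in zip(labels, before, after):
        if type(old) is not type(new):
            # Condition 1: the type of a state variable may not change
            raise TypeError(
                f"state variable '{label}' changed type from "
                f"{type(old).__name__} to {type(new).__name__}"
            )
        if isinstance(old, list) and len(old) != len(new):
            # Condition 2: the size of a state variable may not change
            raise ValueError(
                f"state variable '{label}' changed size from "
                f"{len(old)} to {len(new)}"
            )

initial = (1.0, [1, 2, 3])
check_state(initial, (2.0, [4, 5, 6]), ("x", "v"))  # consistent: no exception

errors = []
for bad in [(None, [1, 2, 3]), (1.0, [1, 2, 3, 4, 5])]:
    try:
        check_state(initial, bad, ("x", "v"))
    except (TypeError, ValueError) as exc:
        errors.append(type(exc).__name__)
```

Passing the variable labels along is what allows the error message to name the offending variable, which is also the stated purpose of the labels parameter of drjit.while_loop().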
Potential pitfalls
Long compilation times.
In the example below, i < 100000 is scalar, causing drjit.while_loop() to use the scalar evaluation strategy that effectively copy-pastes the loop body 100000 times to produce a giant program. Code written in this way will be bottlenecked by the CUDA/LLVM compilation stage.
    @dr.syntax
    def f():
        i = 0
        while i < 100000:
            # .. costly computation
            i += 1
Incorrect behavior in symbolic mode.
Let’s fix the above program by casting the loop condition into a Dr.Jit type to ensure that a symbolic loop is used. Problem solved, right?
    from drjit.cuda import Bool

    @dr.syntax
    def f():
        i = 0
        while Bool(i < 100000):
            # .. costly computation
            i += 1
Unfortunately, no: this loop never terminates when run in symbolic mode. Symbolic mode does not track modifications of scalar/non-Dr.Jit types across loop iterations such as the int-valued loop counter i. It’s as if we had written while Bool(0 < 100000), which of course never finishes.
Evaluated mode does not have this problem. If your loop behaves differently in symbolic and evaluated modes, then some variation of this mistake is likely to blame. To fix this, we must declare the loop counter as a vector type before the loop and then modify it as follows:
    from drjit.cuda import Int

    @dr.syntax
    def f():
        i = Int(0)
        while i < 100000:
            # .. costly computation
            i += 1
Warning
This new implementation of the drjit.while_loop() abstraction still lacks the functionality to break or return from the loop, or to continue to the next loop iteration. We plan to add these capabilities in the near future.
Interface
- Parameters:
state (tuple) – A tuple containing the initial values of the loop’s state variables. This tuple normally consists of Dr.Jit arrays or PyTrees. Other values are permissible as well and will be forwarded to the loop body. However, such variables will not be captured by the symbolic tracing process.
cond (Callable) – a function/callable that will be invoked with *state (i.e., the state variables will be unpacked and turned into function arguments). It should return a scalar Python bool or a boolean-typed Dr.Jit array representing the loop condition.
body (Callable) – a function/callable that will be invoked with *state (i.e., the state variables will be unpacked and turned into function arguments). It should update the loop state and then return a new tuple of state variables that are compatible with the previous state (see the earlier description regarding what such compatibility entails).
mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides None are: "scalar", "symbolic", "evaluated". If not specified, the function first checks if the loop is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flag drjit.JitFlag.SymbolicLoops and then either performs a symbolic or an evaluated loop.
compress (Optional[bool]) – Set this parameter to True or False to enable or disable loop state compression in evaluated loops (see the text above for a description of this feature). The function queries the value of drjit.JitFlag.CompressLoops when the parameter is not specified. Symbolic loops ignore this parameter.
labels (list[str]) – An optional list of labels associated with each state entry. Dr.Jit uses this to provide better error messages in case of a detected inconsistency. The @drjit.syntax decorator automatically provides these labels based on the transformed code.
label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
max_iterations (int) – The maximum number of loop iterations (default: -1). You must specify a correct upper bound here if you wish to differentiate the loop in reverse mode. In that case, the maximum iteration count is used to reserve memory to store intermediate loop state.
strict (bool) – You can specify this parameter to reduce the strictness of variable consistency checks performed by the implementation. See the documentation of drjit.hint() for an example. The default is strict=True.
- Returns:
The function returns the final state of the loop variables following termination of the loop.
- Return type:
tuple
- drjit.if_stmt(args: tuple[*Ts], cond: AnyArray | bool, true_fn: typing.Callable[[*Ts], T], false_fn: typing.Callable[[*Ts], T], arg_labels: typing.Sequence[str] = (), rv_labels: typing.Sequence[str] = (), label: str | None = None, mode: typing.Literal['scalar', 'symbolic', 'evaluated', None] = None, strict: bool = True) -> T
Conditionally execute code.
Motivation
This function provides a vectorized generalization of a standard Python
ifstatement. For example, consider the following Python snippeti: int = .. some expression .. if i > 0: x = f(i) # <-- some costly function 'f' that depends on 'i' else: y += 1
This code would fail if
iis replaced by an array containing multiple entries (e.g., of typedrjit.llvm.Int). In that case, the conditional expression produces a boolean array of per-component comparisons that are not necessarily consistent with each other. In other words, some of the entries may want to run the body of theifstatement, while others must skip to theelseblock. This is not compatible with the semantics of a standard Pythonifstatement.The
drjit.if_stmt()function realizes a more fine-grained conditional operation that accommodates these requirements, while avoiding execution of the costly branch unless this is truly needed. It takes the following input arguments:cond, a boolean array that specifies whether the body of theifstatement should execute.A tuple of input arguments (
args) that will be forwarded totrue_fnandfalse_fn. It is important to specify all inputs to ensure correct derivative tracking of the operation.true_fn, a callable that implements the body of theifblock.false_fn, a callable that implements the body of theelseblock.
The implementation will invoke
true_fn(*args)andfalse_fn(*args)to trace their contents. The return values of these functions must be compatible with each other (a precise definition of compatibility is described below). A vectorized version of the earlier example can then be written as follows:x, y = dr.if_stmt( args=(i, x, y), cond=i > 0, true_fn=lambda i, x, y: (f(i), y), false_fn=lambda i, x, y: (x, y + 1) )
Lambda functions are convenient when
true_fnandfalse_fnare simple enough to fit onto a single line. In general you may prefer to define local functions (def true_fn(i, x, y): ...) and pass them to thetrue_fnandfalse_fnarguments.Dr.Jit later optimizes away superfluous inputs/outputs of
drjit.if_stmt(), so there isn’t any harm in, e.g., specifying an identical element of a return value in bothtrue_fnandfalse_fn.Dr.Jit also provides the
@drjit.syntax decorator, which automatically rewrites standard Python control flow constructs into the form shown above. It combines vectorization with the readability of natural Python syntax and is the recommended way of (indirectly) using drjit.if_stmt(). With this decorator, the above example would be written as follows:

    @dr.syntax
    def g(i, x, y):
        if i > 0:
            x = f(i)
        else:
            y += 1
        return x, y

Evaluation modes
Dr.Jit uses one of three different modes to realize this operation depending on the inputs and active compilation flags (the text below this overview will explain how this mode is automatically selected).
Scalar mode: Scalar if statements that don’t need any vectorization can be reduced to normal Python branching constructs:

    if cond:
        state = true_fn(*args)
    else:
        state = false_fn(*args)

This strategy is the default when cond is a scalar Python bool.
Symbolic mode: Dr.Jit runs
true_fnandfalse_fnto capture the computation performed by each function, which allows it to generate an equivalent branch in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.Evaluated mode: in this mode, Dr.Jit runs both branches of the
if statement and then combines the results via drjit.select(). This is nearly equivalent to the following Python code:

    true_state = true_fn(*args)
    false_state = false_fn(*args)
    state = dr.select(cond, true_state, false_state)

(In practice, the implementation does a few additional things like suppressing side effects associated with inactive entries.)
Evaluated mode is conceptually simpler but also slower, since the device executes both sides of a branch when only one of them is actually needed.
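The cost of evaluated mode can be illustrated with a small element-wise model written in plain Python. This is only a sketch of the semantics on lists, not Dr.Jit's implementation: the helper name if_stmt_evaluated is hypothetical, and the expression i * i merely stands in for a costly function.

```python
def if_stmt_evaluated(args, cond, true_fn, false_fn):
    """Element-wise model of evaluated mode using plain Python lists."""
    # Run *both* branches on every element, regardless of the condition.
    true_state = [true_fn(*a) for a in zip(*args)]
    false_state = [false_fn(*a) for a in zip(*args)]
    # Blend the two results per element, similar in spirit to a select.
    return [t if c else f for c, t, f in zip(cond, true_state, false_state)]


i = [-2, 3, 0, 5]
x = [0, 0, 0, 0]
y = [10, 10, 10, 10]
cond = [v > 0 for v in i]

result = if_stmt_evaluated(
    args=(i, x, y),
    cond=cond,
    true_fn=lambda i, x, y: (i * i, y),    # stand-in for the costly branch
    false_fn=lambda i, x, y: (x, y + 1),
)
print(result)
```

Both callables run for all four elements here, even though each element only keeps one of the two results. This is the overhead that symbolic mode avoids.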
The mode is chosen as follows:
When the mode argument is set to None (the default), the function examines the type of the cond input and uses scalar mode if the type is a builtin Python bool.
Otherwise, it chooses between symbolic and evaluated mode based on the drjit.JitFlag.SymbolicConditionals flag, which is set by default. To change this choice for a region of code, you may specify the mode= keyword argument, nest it into a drjit.scoped_set_flag() block, or change the behavior globally via drjit.set_flag():

    with dr.scoped_set_flag(dr.JitFlag.SymbolicConditionals, False):
        # .. nested code will use evaluated mode ..

When mode is set to "scalar", "symbolic", or "evaluated", it directly uses that mode without inspecting the compilation flags or condition type.
When using the
@drjit.syntaxdecorator to automatically convert Pythonifstatements intodrjit.if_stmt()calls, you can also use thedrjit.hint()function to pass keyword arguments including themodeandlabelparameters.if dr.hint(i < 10, mode='evaluated'): # ...
Assumptions
The return values of
true_fn and false_fn must be of the same type. This requirement applies recursively if the return value is a PyTree.
Dr.Jit will refuse to compile vectorized conditionals in which true_fn and false_fn return a scalar that is inconsistent between the branches.

    >>> @dr.syntax
    ... def f(x):
    ...     if x > 0:
    ...         y = 1
    ...     else:
    ...         y = 0
    ...     return y
    ...
    >>> print(f(dr.llvm.Float(-1,2)))
    RuntimeError: dr.if_stmt(): detected an inconsistency when comparing the return values of 'true_fn' and 'false_fn': drjit.detail.check_compatibility(): inconsistent scalar Python object of type 'int' for field 'y'. Please review the interface and assumptions of dr.if_stmt() as explained in the Dr.Jit documentation.

The problem can be solved by assigning an instance of a capitalized Dr.Jit type (e.g., y=Int(1)) so that the operation can be tracked.
The functions
true_fn and false_fn should not write to variables besides the explicitly declared return value(s):

    vec = drjit.cuda.Array3f(1, 2, 3)
    def true_fn(x):
        vec.x += x # <-- don't do this. 'vec' is not a declared output

    dr.if_stmt(args=(x,), true_fn=true_fn, ...)

This example can be fixed as follows:

    def true_fn(x, vec):
        vec.x += x
        return vec

    vec = dr.if_stmt(args=(x, vec), true_fn=true_fn, ...)

drjit.if_stmt()is differentiable in both forward and reverse modes. Correct derivative tracking requires that regular differentiable inputs are specified via theargsparameter. The@drjit.syntaxdecorator ensures that these assumptions are satisfied.Dr.Jit also tracks dependencies of indirect reads (done via
drjit.gather()) and indirect writes (done via drjit.scatter(), drjit.scatter_reduce(), drjit.scatter_add(), drjit.scatter_inc(), etc.). Such operations create implicit inputs and outputs, and these do not need to be specified as part of args or the return value of true_fn and false_fn (however, doing so causes no harm). This auto-discovery mechanism is helpful when performing vectorized method calls within conditional statements, where the set of implicit inputs and outputs can be difficult to know a priori (in principle, any public/private field in any instance could be accessed in this way).

    y = ..
    def true_fn(x):
        # 'y' is neither declared as input nor output of 'true_fn', which is fine
        dr.scatter(target=y, value=x, index=0)

    dr.if_stmt(args=(x,), true_fn=true_fn, ...)

Interface
- Parameters:
cond (bool|drjit.ArrayBase) – a scalar Python
boolor a boolean-valued Dr.Jit array.args (tuple) – A list of positional arguments that will be forwarded to
true_fnandfalse_fn.true_fn (Callable) – a callable that implements the body of the
ifblock.false_fn (Callable) – a callable that implements the body of the
elseblock.mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides
Noneare:"scalar","symbolic","evaluated".arg_labels (list[str]) – An optional list of labels associated with each input argument. Dr.Jit uses this feature in combination with the
@drjit.syntaxdecorator to provide better error messages in case of detected inconsistencies.rv_labels (list[str]) – An optional list of labels associated with each element of the return value. This parameter should only be specified when the return value is a tuple. Dr.Jit uses this feature in combination with the
@drjit.syntaxdecorator to provide better error messages in case of detected inconsistencies.label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
strict (bool) – You can specify this parameter to reduce the strictness of variable consistency checks performed by the implementation. See the documentation of
drjit.hint()for an example. The default isstrict=True.
- Returns:
Combined return value mixing the results of
true_fnandfalse_fn.- Return type:
object
- drjit.switch(index: object, targets: collections.abc.Sequence, *args, **kwargs) object¶
Selectively invoke functions based on a provided index array.
When called with a scalar
index(of typeint), this function is equivalent to the following Python expression:targets[index](*args, **kwargs)
When called with a Dr.Jit index array (specifically, 32-bit unsigned integers), it performs the vectorized equivalent of the above and assembles an array of return values containing the result of all referenced functions. It does so efficiently using at most a single invocation of each function in
targets.

    import drjit as dr
    from drjit.llvm import UInt32

    res = dr.switch(
        UInt32(0, 0, 1, 1),    # <-- selects the function
        [                      # <-- arbitrary list of callables
            lambda x: x,
            lambda x: x*10
        ],
        UInt32(1, 2, 3, 4)     # <-- argument passed to function
    )

    # res now contains [1, 2, 30, 40]

The function traverses the set of positional (
*args) and keyword arguments (**kwargs) to find all Dr.Jit arrays including arrays contained within PyTrees. It routes a subset of array entries to each function as specified by theindexargument.Dr.Jit will use one of two possible strategies to compile this operation depending on the active compilation flags (see
drjit.set_flag(),drjit.scoped_set_flag()):Symbolic mode: Dr.Jit transcribes every function into a counterpart in the generated low-level intermediate representation (LLVM IR or PTX) and targets them via an indirect jump instruction.
This mode is used when
drjit.JitFlag.SymbolicCallsis set, which is the default.Evaluated mode: Dr.Jit evaluates the inputs
index,args,kwargsviadrjit.eval(), groups them byindex, and invokes each function with the subset of inputs that reference it. Callables that are not referenced by any element ofindexare ignored.In this mode, a
drjit.switch()statement will cause Dr.Jit to launch a series of kernels processing subsets of the input data (one per function).
A separate section about symbolic and evaluated modes discusses these two options in detail.
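The grouping strategy of evaluated mode can be sketched in plain Python. The following is a sequential model on lists under simplifying assumptions (a single positional argument, no mask, no PyTrees); the helper name switch_reference is hypothetical and the code is not how Dr.Jit implements the operation.

```python
def switch_reference(index, targets, values):
    """Sequential model of evaluated-mode dispatch on plain Python lists."""
    result = [None] * len(index)
    for fn_id, fn in enumerate(targets):
        # Collect the lanes that reference this callable.
        lanes = [k for k, idx in enumerate(index) if idx == fn_id]
        if not lanes:
            continue  # callables that are never referenced are skipped
        # One call per callable, operating on the gathered subset of inputs.
        outputs = fn([values[k] for k in lanes])
        # Write the outputs back to the lanes they came from.
        for k, out in zip(lanes, outputs):
            result[k] = out
    return result


res = switch_reference(
    index=[0, 0, 1, 1],
    targets=[
        lambda xs: [x for x in xs],
        lambda xs: [x * 10 for x in xs],
    ],
    values=[1, 2, 3, 4],
)
print(res)
```

Each callable is invoked at most once, and it only sees the entries whose index selects it.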
To switch the compilation mode locally, use drjit.scoped_set_flag() as shown below:

    with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, False):
        result = dr.switch(..)

When a boolean Dr.Jit array (e.g.,
drjit.llvm.Bool,drjit.cuda.ad.Bool, etc.) is specified as last positional argument or as a keyword argument namedactive, that argument is treated specially: entries of the input arrays associated with aFalsemask entry are ignored and never passed to the functions. Associated entries of the return value will be zero-initialized. The function will still receive the mask argument as input, but it will always be set toTrue.Danger
The indices provided to this operation are unchecked by default. Attempting to call functions beyond the end of the
targetsarray is undefined behavior and may crash the application, unless such calls are explicitly disabled via theactiveparameter. Negative indices are not permitted.If debug mode is enabled via the
drjit.JitFlag.Debugflag, Dr.Jit will insert range checks into the program. These checks disable out-of-bound calls and furthermore report warnings to identify problematic source locations:>>> print(dr.switch(UInt32(0, 100), [lambda x:x], UInt32(1))) Attempted to invoke callable with index 100, but this↵ value must be smaller than 1. (<stdin>:2)
- Parameters:
index (int|drjit.ArrayBase) – a list of indices to choose the functions
targets (Sequence[Callable]) – a list of callables to which calls will be dispatched based on the
indexargument.mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides
Noneare:"symbolic","evaluated". If not specified, the function first checks if the index is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flagdrjit.JitFlag.SymbolicCallsand then either performs a symbolic or an evaluated call.label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
*args (tuple) – a variable-length list of positional arguments passed to the functions. PyTrees are supported.
**kwargs (dict) – a variable-length list of keyword arguments passed to the functions. PyTrees are supported.
- Returns:
When
indexis a scalar Python integer, the return value simply forwards the return value of the selected function. Otherwise, the function returns a Dr.Jit array or PyTree combining the results from each referenced callable.- Return type:
object
- drjit.dispatch(inst: drjit.ArrayBase, target: collections.abc.Callable, *args, **kwargs) object¶
Invoke a provided Python function for each instance in an instance array.
This function invokes the provided
target for each instance in the instance array inst and assembles the return values into a result array. Conceptually, it does the following:

    def dispatch(inst, target, *args, **kwargs):
        result = []
        for inst_ in inst:
            result.append(target(inst_, *args, **kwargs))
        return result

However, the implementation accomplishes this more efficiently using only a single call per unique instance. Instead of a Python
list, it returns a Dr.Jit array or PyTree.In practice, this function is mainly good for two things:
Dr.Jit instance arrays contain C++ instances, and these will typically expose a set of methods. Adding further methods requires re-compiling C++ code and adding bindings, which may impede quick prototyping. With
drjit.dispatch(), a developer can quickly implement additional vectorized method calls within Python (with the caveat that these can only access public members of the underlying type).Dynamic dispatch is a relatively costly operation. When multiple calls are performed on the same set of instances, it may be preferable to merge them into a single and potentially significantly faster use of
drjit.dispatch(). An example is shown below:inst = # .. Array of C++ instances .. result_1 = inst.func_1(arg1) result_2 = inst.func_2(arg2)
The following alternative implementation instead uses
drjit.dispatch():def my_func(self, arg1, arg2): return (self.func_1(arg1), self.func_2(arg2)) result_1, result_2 = dr.dispatch(inst, my_func, arg1, arg2)
This function is otherwise very similar to
drjit.switch()and similarly provides two different compilation modes, differentiability, and special handling of mask arguments. Please review the documentation ofdrjit.switch()for details.- Parameters:
inst (drjit.ArrayBase) – a Dr.Jit instance array.
target (Callable) – function to dispatch on all instances
mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides
Noneare:"symbolic","evaluated". If not specified, the function first checks if the index is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flagdrjit.JitFlag.SymbolicCallsand then either performs a symbolic or an evaluated call.label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
*args (tuple) – a variable-length list of positional arguments passed to the function. PyTrees are supported.
**kwargs (dict) – a variable-length list of keyword arguments passed to the function. PyTrees are supported.
- Returns:
A Dr.Jit array or PyTree containing the result of each performed function call.
- Return type:
object
Scatter/gather operations¶
- drjit.gather(dtype: type[T], source: object, index: AnyArray | Sequence[int] | int, active: AnyArray | Sequence[bool] | bool = True, mode: drjit.ReduceMode = drjit.ReduceMode.Auto, shape: tuple[int, ...] | None = None) T¶
Gather values from a flat array or nested data structure.
This function performs a gather (i.e., indirect memory read) from
sourceat positionindex. It expects adtypeargument and will return an instance of this type. The optionalactiveargument can be used to disable some of the components, which is useful when not all indices are valid; the corresponding output will be zero in this case.This operation can be used in the following different ways:
When
dtypeis a 1D Dr.Jit array likedrjit.llvm.ad.Float, this operation implements a parallelized version of the Python array indexing expressionsource[index]with optional masking. Example:source = dr.cuda.Float([...]) index = dr.cuda.UInt([...]) # Note: negative indices are not permitted result = dr.gather(dtype=type(source), source=source, index=index)
When
dtypeis a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:When
type(source)matchesdtype, the gather operation threads through entries and invokes itself recursively. For example, the gather operation insource = dr.cuda.Array3f(...) index = dr.cuda.UInt([...]) result = dr.gather(dr.cuda.Array3f, source, index)
is equivalent to
result = dr.cuda.Array3f( dr.gather(dr.cuda.Float, source.x, index), dr.gather(dr.cuda.Float, source.y, index), dr.gather(dr.cuda.Float, source.z, index) )
A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.
Otherwise, the operation reconstructs the requested
dtypefrom a flatsourcearray, using C-style ordering with a suitably modifiedindex. For example, the gather below reads 3D vectors from a 1D array.source = dr.cuda.Float([...]) index = dr.cuda.UInt([...]) result = dr.gather(dr.cuda.Array3f, source, index)
and is equivalent to
result = dr.cuda.Array3f( dr.gather(dr.cuda.Float, source, index*3 + 0), dr.gather(dr.cuda.Float, source, index*3 + 1), dr.gather(dr.cuda.Float, source, index*3 + 2))
By default, Dr.Jit bundles sequences of gather operations that access a power-of-two number of contiguous elements (e.g. dr.gather(Array4f, ...) or dr.gather(ArrayXf, ..., shape=(16, N))) into one or more packet loads (the precise number depending on the hardware’s capabilities) that are potentially significantly faster. This optimization can be controlled via the drjit.JitFlag.PacketOps flag.
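The C-style index arithmetic used when gathering 3D vectors from a flat array, together with the zero-fill behavior of masked entries, can be modeled in a few lines of plain Python. This is a sketch only: the helper name gather_vec3 is hypothetical, and it returns a list of tuples, whereas a Dr.Jit Array3f stores its x, y, and z components as separate arrays.

```python
def gather_vec3(source, index, active=None):
    """Read 3D vectors from a flat list laid out as x0, y0, z0, x1, y1, z1, ..."""
    if active is None:
        active = [True] * len(index)
    result = []
    for i, a in zip(index, active):
        if a:
            # Component c of vector i lives at flat position i*3 + c.
            result.append(tuple(source[i * 3 + c] for c in range(3)))
        else:
            # Masked entries are never read and produce zeros.
            result.append((0.0, 0.0, 0.0))
    return result


source = [float(v) for v in range(12)]   # four packed 3D vectors
vectors = gather_vec3(source, index=[2, 0, 3], active=[True, True, False])
print(vectors)
```

The last entry is disabled via the mask, so it yields zeros without touching the source array.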
Danger
The indices provided to this operation are unchecked by default. Attempting to read beyond the end of the
sourcearray is undefined behavior and may crash the application, unless such reads are explicitly disabled via theactiveparameter. Negative indices are not permitted.If debug mode is enabled via the
drjit.JitFlag.Debugflag, Dr.Jit will insert range checks into the program. These checks disable out-of-bound reads and furthermore report warnings to identify problematic source locations:>>> dr.gather(dtype=UInt, source=UInt(1, 2, 3), index=UInt(0, 1, 100)) drjit.gather(): out-of-bounds read from position 100 in an array↵ of size 3. (<stdin>:2)
- Parameters:
dtype (type) – The desired output type (typically equal to
type(source), but other variations are possible as well, see the description above.)source (object) – The object from which data should be read (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)
index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g.,
drjit.scalar.ArrayXuordrjit.cuda.UInt) specifying gather indices. Dr.Jit will attempt an implicit conversion if another type is provided.active (object) – an optional 1D dynamic Dr.Jit mask array (e.g.,
drjit.scalar.ArrayXbordrjit.cuda.Bool) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default isTrue.mode (drjit.ReduceMode) – The reverse-mode derivative of a gather is an atomic scatter-reduction. The execution of such atomics can be rather performance-sensitive (see the discussion of
drjit.ReduceModefor details), hence Dr.Jit offers a few different compilation strategies to realize them. Specifying this parameter selects a strategy for the derivative of a particular gather operation. The default isdrjit.ReduceMode.Auto.shape (tuple[int, ...] | None) – When gathering into a dynamically sized array type (e.g.
drjit.cuda.ArrayXf), this parameter can be used to specify the shape of the unknown dimensions. Otherwise, it is not needed. The default isNone.
- drjit.scatter(target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None¶
Scatter values into a flat array or nested data structure.
This operation performs a scatter (i.e., indirect memory write) of the
valueparameter to thetargetarray at positionindex. The optionalactiveargument can be used to disable some of the individual write operations, which is useful when not all provided values or indices are valid.This operation can be used in the following different ways:
When
targetis a 1D Dr.Jit array likedrjit.llvm.ad.Float, this operation implements a parallelized version of the Python array indexing expressiontarget[index] = valuewith optional masking. Example:target = dr.empty(dr.cuda.Float, 1024*1024) value = dr.cuda.Float([...]) index = dr.cuda.UInt([...]) # Note: negative indices are not permitted dr.scatter(target, value=value, index=index)
When
targetis a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:When
targetandvalueare of the same type, the scatter operation threads through entries and invokes itself recursively. For example, the scatter operation intarget = dr.cuda.Array3f(...) value = dr.cuda.Array3f(...) index = dr.cuda.UInt([...]) dr.scatter(target, value, index)
is equivalent to
dr.scatter(target.x, value.x, index) dr.scatter(target.y, value.y, index) dr.scatter(target.z, value.z, index)
A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.
Otherwise, the operation flattens the
valuearray and writes it using C-style ordering with a suitably modifiedindex. For example, the scatter below writes 3D vectors into a 1D array.target = dr.cuda.Float(...) value = dr.cuda.Array3f(...) index = dr.cuda.UInt([...]) dr.scatter(target, value, index)
and is equivalent to
dr.scatter(target, value.x, index*3 + 0) dr.scatter(target, value.y, index*3 + 1) dr.scatter(target, value.z, index*3 + 2)
By default, Dr.Jit bundles sequences of scatter operations that write a power-of-two number of contiguous elements into one or more packet stores (the precise number depending on the hardware’s capabilities) that is potentially significantly faster. This optimization can be controlled via the
drjit.JitFlag.PacketOpsflag.
Danger
The indices provided to this operation are unchecked by default. Out-of-bound writes are considered undefined behavior and may crash the application (unless they are disabled via the
activeparameter). Negative indices are not permitted.If debug mode is enabled via the
drjit.JitFlag.Debugflag, Dr.Jit will insert range checks into the program. These will catch out-of-bound writes and print an error message identifying the responsible line of code.Dr.Jit makes no guarantees about the expected behavior when a scatter operation has conflicts, i.e., when a specific position is written multiple times by a single
drjit.scatter()operation.- Parameters:
target (object) – The object into which data should be written (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)
value (object) – The values to be written (typically of type
type(target), but other variations are possible as well, see the description above.) Dr.Jit will attempt an implicit conversion if the input is not an array type.index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g.,
drjit.scalar.ArrayXu or drjit.cuda.UInt) specifying scatter indices. Dr.Jit will attempt an implicit conversion if another type is provided.
active (object) – an optional 1D dynamic Dr.Jit mask array (e.g.,
drjit.scalar.ArrayXbordrjit.cuda.Bool) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default isTrue.
- drjit.scatter_reduce(op: drjit.ReduceOp, target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None¶
Atomically update values in a flat array or nested data structure.
This function performs an atomic scatter-reduction, which is a read-modify-write operation that applies one of several possible mathematical functions to selected entries of an array. The following are supported:
drjit.ReduceOp.Add:a=a+b.drjit.ReduceOp.Max:a=max(a, b).drjit.ReduceOp.Min:a=min(a, b).drjit.ReduceOp.Or:a=a | b(integer arrays only).drjit.ReduceOp.And:a=a & b(integer arrays only).
Here,
arefers to an entry oftargetselected byindex, andbdenotes the associated element ofvalue. The operation resolves potential conflicts arising due to the parallel execution of this operation.The optional
active argument can be used to disable some of the updates, e.g., when not all provided values or indices are valid.
Atomic additions are subject to non-deterministic rounding errors. The reason for this is that IEEE-754 addition is non-associative. The execution order is scheduling-dependent, which can lead to small variations across program runs.
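The order dependence of floating point accumulation can be reproduced with ordinary Python floats, which are IEEE-754 double precision values. The snippet below is a standalone illustration and does not involve Dr.Jit: regrouping the same three terms changes the rounded result, while swapping the two operands of a single addition does not.

```python
a, b, c = 0.1, 0.2, 0.3

# Same three terms, two different groupings.
left = (a + b) + c
right = a + (b + c)

# Swapping the operands of a single addition gives the same result.
commutes = (a + b) == (b + a)

print(left == right, commutes, abs(left - right))
```

An atomic scatter-addition effectively picks one such grouping at run time, depending on how the updates happen to be scheduled.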
Atomic scatter-reductions can have a significant detrimental impact on performance. When many threads in a parallel computation attempt to modify the same element, this can lead to contention—essentially a fight over which part of the processor owns the associated memory region, which can slow down a computation by many orders of magnitude. Dr.Jit provides several different compilation strategies to reduce these costs, which can be selected via the
modeparameter. The documentation ofdrjit.ReduceModeprovides more detail and performance plots.Backend support
Many combinations of reductions and variable types are not supported. Some combinations depend on the compute capability (CC) of the underlying CUDA device or on the LLVM version (LV) and the host architecture (ARM64, x86_64). The following matrices display the level of support.
For CUDA:
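| Reduction | Bool | [U]Int{32,64} | Float16 | Float32 | Float64 |
|---|---|---|---|---|---|
| ReduceOp.Identity | ✅ | ✅ | ✅ | ✅ | ✅ |
| ReduceOp.Add | ❌ | ✅ | ⚠️ CC≥60 | ✅ | ⚠️ CC≥60 |
| ReduceOp.Mul | ❌ | ❌ | ❌ | ❌ | ❌ |
| ReduceOp.Min | ❌ | ✅ | ⚠️ CC≥90 | ❌ | ❌ |
| ReduceOp.Max | ❌ | ✅ | ⚠️ CC≥90 | ❌ | ❌ |
| ReduceOp.And | ❌ | ✅ | ❌ | ❌ | ❌ |
| ReduceOp.Or | ❌ | ✅ | ❌ | ❌ | ❌ |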
For LLVM:
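| Reduction | Bool | [U]Int{32,64} | Float16 | Float32 | Float64 |
|---|---|---|---|---|---|
| ReduceOp.Identity | ✅ | ✅ | ✅ | ✅ | ✅ |
| ReduceOp.Add | ❌ | ✅ | ⚠️ LV≥16 | ✅ | ✅ |
| ReduceOp.Mul | ❌ | ❌ | ❌ | ❌ | ❌ |
| ReduceOp.Min | ❌ | ⚠️ LV≥15 | ⚠️ LV≥16, ARM64 | ⚠️ LV≥15 | ⚠️ LV≥15 |
| ReduceOp.Max | ❌ | ⚠️ LV≥15 | ⚠️ LV≥16, ARM64 | ⚠️ LV≥15 | ⚠️ LV≥15 |
| ReduceOp.And | ❌ | ✅ | ❌ | ❌ | ❌ |
| ReduceOp.Or | ❌ | ✅ | ❌ | ❌ | ❌ |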
The function raises an exception when the operation is not supported by the backend.
Scatter-reducing nested types
This operation can be used in the following different ways:
When
targetis a 1D Dr.Jit array likedrjit.llvm.ad.Float, this operation implements a parallelized version of the Python array indexing expressiontarget[index] = op(target[index], value)with optional masking. Example:target = dr.zeros(dr.cuda.Float, 1024*1024) value = dr.cuda.Float([...]) index = dr.cuda.UInt([...]) # Note: negative indices are not permitted dr.scatter_reduce(dr.ReduceOp.Add, target, value=value, index=index)
When
targetis a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:When
targetandvalueare of the same type, the scatter-reduction threads through entries and invokes itself recursively. For example, the scatter operation inop = dr.ReduceOp.Add target = dr.cuda.Array3f(...) value = dr.cuda.Array3f(...) index = dr.cuda.UInt([...]) dr.scatter_reduce(op, target, value, index)
is equivalent to
dr.scatter_reduce(op, target.x, value.x, index) dr.scatter_reduce(op, target.y, value.y, index) dr.scatter_reduce(op, target.z, value.z, index)
A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.
Otherwise, the operation flattens the
valuearray and writes it using C-style ordering with a suitably modifiedindex. For example, the scatter-reduction below writes 3D vectors into a 1D array.op = dr.ReduceOp.Add target = dr.cuda.Float(...) value = dr.cuda.Array3f(...) index = dr.cuda.UInt([...]) dr.scatter_reduce(op, target, value, index)
and is equivalent to
dr.scatter_reduce(op, target, value.x, index*3 + 0) dr.scatter_reduce(op, target, value.y, index*3 + 1) dr.scatter_reduce(op, target, value.z, index*3 + 2)
Dr.Jit may be able to bundle sequences of scatter-reductions that write a power-of-two number of contiguous elements into one or more packet scatter-updates that are potentially significantly faster. This optimization can be controlled via the drjit.JitFlag.PacketOps flag. Currently, this is only possible in a narrow set of circumstances requiring:
1. use of the LLVM backend.
2. integer or floating point scatter-additions.
3. use of the drjit.ReduceMode.Expand reduction strategy.
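The read-modify-write semantics target[index] = op(target[index], value) and the flattened C-style layout described earlier can be modeled sequentially in plain Python. This is a sketch under stated assumptions: the helper scatter_reduce_vec3 and the string keys of REDUCE_OPS are hypothetical, and the loop processes updates in a fixed order, whereas the real operation performs them atomically and in no particular order.

```python
import operator

# Hypothetical lookup table mirroring the reductions listed in the text.
REDUCE_OPS = {
    "add": operator.add,
    "max": max,
    "min": min,
    "or": operator.or_,
    "and": operator.and_,
}


def scatter_reduce_vec3(op, target, value, index, active=None):
    """Sequentially apply target[i*3 + c] = op(target[i*3 + c], value[c])."""
    fn = REDUCE_OPS[op]
    if active is None:
        active = [True] * len(index)
    for vec, i, a in zip(value, index, active):
        if not a:
            continue  # masked updates are skipped entirely
        for c in range(3):
            pos = i * 3 + c
            target[pos] = fn(target[pos], vec[c])


target = [0.0] * 6   # room for two packed 3D vectors
scatter_reduce_vec3(
    "add",
    target,
    value=[(1.0, 2.0, 3.0), (10.0, 20.0, 30.0), (0.5, 0.5, 0.5)],
    index=[1, 1, 0],   # the first two updates collide on vector 1
)
print(target)
```

The two updates that share index 1 are combined rather than overwriting each other, which is the property that distinguishes a scatter-reduction from a plain scatter.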
Danger
The indices provided to this operation are unchecked by default. Out-of-bound writes are considered undefined behavior and may crash the application (unless they are disabled via the
activeparameter). Negative indices are not permitted.If debug mode is enabled via the
drjit.JitFlag.Debugflag, Dr.Jit will insert range checks into the program. These will catch out-of-bound writes and print an error message identifying the responsible line of code.Dr.Jit makes no guarantees about the relative ordering of atomic operations when a
drjit.scatter_reduce() writes to the same element multiple times. Combined with the non-associative nature of floating point operations, concurrent writes will generally introduce non-deterministic rounding error.
- Parameters:
op (drjit.ReduceOp) – Specifies the type of update that should be performed.
target (object) – The object into which data should be written (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)
value (object) – The values to be used in the RMW operation (typically of type
type(target), but other variations are possible as well, see the description above.) Dr.Jit will attempt an implicit conversion if the input is not an array type.
index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g., drjit.scalar.ArrayXu or drjit.cuda.UInt) specifying scatter indices. Dr.Jit will attempt an implicit conversion if another type is provided.
active (object) – an optional 1D dynamic Dr.Jit mask array (e.g.,
drjit.scalar.ArrayXbordrjit.cuda.Bool) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default isTrue.mode (drjit.ReduceMode) – Dr.Jit offers several different strategies to implement atomic scatter-reductions that can be selected via this parameter. They achieve different best/worst case performance and, in the case of
drjit.ReduceMode.Expand, involve additional memory storage overheads. The default isdrjit.ReduceMode.Auto.
- drjit.scatter_add(target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None¶
Atomically add values to a flat array or nested data structure.
This function is equivalent to
drjit.scatter_reduce(drjit.ReduceOp.Add, ...)and exists for reasons of convenience. Please refer todrjit.scatter_reduce()for details on atomic scatter-reductions.
- drjit.scatter_add_kahan(target_1: drjit.ArrayBase, target_2: drjit.ArrayBase, value: object, index: object, active: object = True) None¶
Perform a Kahan-compensated atomic scatter-addition.
Atomic floating point accumulation can incur significant rounding error when many values contribute to a single element. This function implements an error-compensating version of
drjit.scatter_add() based on the Kahan-Babuška-Neumaier algorithm that simultaneously accumulates into two target buffers.
In particular, the operation first accumulates values into entries of a dynamic 1D array
target_1. It tracks the round-off error caused by this operation and then accumulates this error into a second 1D array namedtarget_2. At the end, the buffers can simply be added together to obtain the error-compensated result.This function has a few limitations: in contrast to
drjit.scatter_reduce()anddrjit.scatter_add(), it does not perform a local reduction (see flagJitFlag.ScatterReduceLocal), which can be an important optimization when atomic accumulation is a performance bottleneck.Furthermore, the function currently works with flat 1D arrays. This is an implementation limitation that could in principle be removed in the future.
Finally, the function is differentiable, but derivatives currently only propagate into
target_1. This means that forward derivatives don’t enjoy the error compensation of the primal computation. This limitation is of no relevance for backward derivatives.
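The two-buffer scheme can be sketched sequentially in plain Python using the Neumaier variant of compensated summation. This is an illustrative model, not the atomic implementation: the helper name scatter_add_kahan_reference is hypothetical, and the loop processes contributions one after another.

```python
def scatter_add_kahan_reference(target_1, target_2, value, index):
    """Accumulate into target_1 and collect the round-off error in target_2."""
    for v, i in zip(value, index):
        s = target_1[i]
        t = s + v
        # Recover the part of the smaller operand that was lost in 's + v'.
        if abs(s) >= abs(v):
            target_2[i] += (s - t) + v
        else:
            target_2[i] += (v - t) + s
        target_1[i] = t


values = [1.0, 1e100, 1.0, -1e100]
index = [0, 0, 0, 0]               # every value lands in the same element

target_1, target_2 = [0.0], [0.0]
scatter_add_kahan_reference(target_1, target_2, values, index)
compensated = target_1[0] + target_2[0]

naive = 0.0
for v in values:
    naive += v

print(naive, compensated)
```

Plain accumulation loses both unit contributions next to the huge intermediate value, whereas adding the two buffers at the end recovers them.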
- drjit.scatter_inc(target: drjit.ArrayBase, index: object, active: object = True) object¶
Atomically increment a value within an unsigned 32-bit integer array and return the value prior to the update.
This operation works just like the drjit.scatter_reduce() operation for 32-bit unsigned integer operands, but with a fixed value=1 parameter and op=ReduceOp.Add.
The main difference is that this variant additionally returns the old value of the target array prior to the atomic update, in contrast to the more general scatter-reduction, which just returns None. The operation also supports masking; the return value in the masked case is undefined. Both target and index parameters must be 1D unsigned 32-bit arrays.
This operation is a building block for stream compaction: threads can scatter-increment a global counter to request a spot in an array and then write their result there. The recipe for this looks as follows:
    data_1 = ...
    data_2 = ...
    active = dr.ones(Bool, len(data_1)) # .. or a more complex condition
    # This will hold the counter
    ctr = UInt32(0)
    # Allocate output buffers
    max_size = 1024
    data_compact_1 = dr.empty(Float, max_size)
    data_compact_2 = dr.empty(Float, max_size)
    # Request one output slot per active lane
    my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
    # Disable dr.scatter() operations below in case of a buffer overflow
    active &= my_index < max_size
    dr.scatter(
        target=data_compact_1,
        value=data_1,
        index=my_index,
        active=active
    )
    dr.scatter(
        target=data_compact_2,
        value=data_2,
        index=my_index,
        active=active
    )
When following this approach, be sure to provide the same mask value to the drjit.scatter_inc() and subsequent drjit.scatter() operations.
The function drjit.reshape() can be used to resize the resulting arrays to their compacted size. Please refer to the documentation of this function, specifically the code example illustrating the use of the shrink=True argument.
The function drjit.scatter_inc() exhibits the following unusual behavior compared to regular Dr.Jit operations: the return value references the instantaneous state during a potentially large sequence of atomic operations. This instantaneous state is not reproducible in later kernel evaluations, and Dr.Jit will refuse to do so when the computed index is reused. In essence, the variable is “consumed” by the process of evaluation.
    my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
    dr.scatter(
        target=data_compact_1,
        value=data_1,
        index=my_index,
        active=active
    )
    dr.eval(data_compact_1) # Run Kernel #1
    dr.scatter(
        target=data_compact_2,
        value=data_2,
        index=my_index, # <-- oops, reusing my_index in another kernel.
        active=active   # This raises an exception.
    )
To get the above code to work, you will need to evaluate my_index at the same time to materialize it into a stored (and therefore trivially reproducible) representation. For this, ensure that the size of the active mask matches len(data_*) and that it is not the trivial True default mask (otherwise, the evaluated my_index will be scalar).
    dr.eval(data_compact_1, my_index)
Such multi-stage evaluation is potentially inefficient and may defeat the purpose of performing stream compaction in the first place. In general, prefer keeping all scatter operations involving the computed index in the same kernel, and then this issue does not arise.
The implementation of drjit.scatter_inc() performs a local reduction first, followed by a single atomic write per SIMD packet/warp. This is done to reduce contention from a potentially very large number of atomic operations targeting the same memory address. Fully masked updates do not cause memory traffic.
There is some conceptual overlap between this function and drjit.compress(), which can likewise be used to reduce a stream to a smaller subset of active items. The downside of drjit.compress() is that it requires evaluating the variables to be reduced, which can be very costly in terms of memory traffic and storage footprint. Reducing through drjit.scatter_inc() does not have this limitation: it can operate on symbolic arrays that greatly exceed the available device memory. One advantage of drjit.compress() is that it essentially boils down to a relatively simple prefix sum, which does not require atomic memory operations (these can be slow in some cases).
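The compaction recipe described above can be mimicked in plain Python to make the data flow explicit. The following is a sequential sketch under stated assumptions: scatter_inc_model is a hypothetical stand-in for the atomic increment, lists replace Dr.Jit arrays, and lanes run one after another (a real kernel assigns slots in a scheduling-dependent order).

```python
def scatter_inc_model(counter, active):
    """Return the counter value seen by each active lane, then increment it."""
    old = []
    for is_active in active:
        if is_active:
            old.append(counter[0])
            counter[0] += 1
        else:
            old.append(None)  # the result of a masked lane is undefined
    return old

data = [0.5, -1.0, 2.0, -3.0, 4.0]
active = [x > 0 for x in data]   # keep only the positive entries
max_size = 4                     # capacity of the output buffer

counter = [0]
slots = scatter_inc_model(counter, active)

# Disable writes that would overflow the output buffer.
active = [a and s < max_size for a, s in zip(active, slots)]

compact = [None] * max_size
for x, s, a in zip(data, slots, active):
    if a:
        compact[s] = x

# The final counter value is the number of requested slots.
print(compact[:counter[0]])
```

The slice at the end plays the role of the shrinking step: only the first counter[0] entries of the output buffer hold compacted data.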
- drjit.scatter_exch(target: drjit.ArrayBase, value: object, index: object, active: object = True) object¶
Atomically exchange values in an array and return the previous values.
This operation performs an atomic exchange (swap) operation: it writes value to target[index] and returns the original value that was stored at that location before the write. This is essentially an atomic read-modify-write operation that combines a gather and scatter in a single atomic step.
The operation is similar to drjit.scatter(), but with the additional functionality of returning the old values. This makes it useful for implementing various synchronization primitives and lock-free data structures.
Each lane of this operation executes atomically. However, when multiple lanes write to the same index, the order of operations is non-deterministic due to the parallel execution model. While each individual exchange is atomic, the final value at a contested location and the returned old values will depend on the hardware scheduling.
The optional active argument can be used to mask out some of the updates. Masked lanes will return zero and will not perform the exchange operation.
The function drjit.scatter_exch() exhibits the following unusual behavior compared to regular Dr.Jit operations: the return value references the instantaneous state during a potentially large sequence of atomic operations. This instantaneous state is not reproducible in later kernel evaluations, and Dr.Jit will refuse to do so when the returned value is reused. In essence, the variable is “consumed” by the process of evaluation.
    old = dr.scatter_exch(target=buffer, value=new_val, index=idx, active=active)
    dr.scatter(
        target=data_compact_1,
        value=data_1,
        index=old, # Using the returned old values as indices
        active=active
    )
    dr.eval(data_compact_1) # Run Kernel #1
    dr.scatter(
        target=data_compact_2,
        value=data_2,
        index=old,    # <-- oops, reusing 'old' in another kernel.
        active=active # This raises an exception.
    )
To get the above code to work, you will need to evaluate old at the same time to materialize it into a stored (and therefore trivially reproducible) representation.
    dr.eval(data_compact_1, old)
- Parameters:
target (object) – a JIT-compiled 1D dynamic Dr.Jit array (e.g., drjit.cuda.Float or drjit.llvm.Float) that will be updated with new values.
value (object) – values to write into the target array. Must be convertible to the same type as target.
index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g., drjit.cuda.UInt32 or drjit.llvm.UInt32) specifying the write indices. Dr.Jit will attempt an implicit conversion if another type is provided.
active (object) – an optional 1D dynamic Dr.Jit mask array (e.g., drjit.cuda.Bool or drjit.llvm.Bool) specifying active lanes. The default is True.
- Returns:
The previous values at the specified indices before the exchange operation. Masked lanes return zero.
- Return type:
- drjit.scatter_cas(target: drjit.ArrayBase, compare: object, value: object, index: object, active: object = True) object¶
Atomically perform a compare-and-swap operation on array elements.
This operation performs an atomic compare-and-swap (CAS): it compares the current value at target[index] with compare, and only if they match, it writes value to that location. The operation returns a tuple containing the original value at the target location and a boolean mask indicating which comparisons succeeded.
This atomic primitive is fundamental for building lock-free data structures and synchronization mechanisms. It enables threads to conditionally update shared memory only if it hasn’t been modified by another thread.
Each lane of this operation executes atomically. When multiple lanes target the same index, the order of operations is non-deterministic due to parallel execution. Only one lane will succeed in performing the swap if they all provide the same comparison value.
The optional active argument can be used to mask out some of the operations. Masked lanes will return zero for the old value and False for the success mask, and will not perform any memory transactions.
The function drjit.scatter_cas() exhibits the following unusual behavior compared to regular Dr.Jit operations: the returned old and success values reference the instantaneous state during a potentially large sequence of atomic operations. This instantaneous state is not reproducible in later kernel evaluations, and Dr.Jit will refuse to do so when the returned values are reused. In essence, these variables are “consumed” by the process of evaluation.
    old, success = dr.scatter_cas(target=buffer, compare=expected, value=new_val, index=idx, active=active)
    # Use the old values immediately in the same kernel
    dr.scatter(target=other_buffer, value=data, index=old, active=success)
    dr.eval(other_buffer) # Run Kernel #1
    # This would fail - can't reuse 'old' or 'success' in another kernel
    # dr.scatter(target=another_buffer, value=data2, index=old, active=success)
To reuse these values across multiple kernels, evaluate them together with the first operation to materialize them into stored representations:
dr.eval(other_buffer, old, success)
- Parameters:
target (object) – a JIT-compiled 1D dynamic Dr.Jit array (e.g., drjit.cuda.Float or drjit.llvm.UInt32) that will be conditionally updated.
compare (object) – values to compare against the current contents of target[index]. Must be convertible to the same type as target.
value (object) – values to write into the target array if the comparison succeeds. Must be convertible to the same type as target.
index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g., drjit.cuda.UInt32 or drjit.llvm.UInt32) specifying the indices to operate on. Dr.Jit will attempt an implicit conversion if another type is provided.
active (object) – an optional 1D dynamic Dr.Jit mask array (e.g., drjit.cuda.Bool or drjit.llvm.Bool) specifying active lanes. The default is True.
- Returns:
A tuple (old, success) where:
old (drjit.ArrayBase): The original values at the specified indices before any modification. Masked lanes return zero.
success (drjit.ArrayBase): A boolean mask indicating which compare-and-swap operations succeeded. Masked lanes return False.
- Return type:
tuple
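The compare-and-swap semantics described above can be modeled sequentially in plain Python. This sketch assumes list-based storage and a fixed lane order; scatter_cas_model is a hypothetical name, and on real hardware the winning lane is determined by scheduling rather than by lane index.

```python
def scatter_cas_model(target, compare, value, index, active=None):
    """Sequential model of a per-lane compare-and-swap."""
    n = len(index)
    if active is None:
        active = [True] * n
    old, success = [], []
    for lane in range(n):
        if not active[lane]:
            # Masked lanes return zero / False and leave memory untouched.
            old.append(0)
            success.append(False)
            continue
        i = index[lane]
        current = target[i]
        matched = current == compare[lane]
        if matched:
            target[i] = value[lane]
        old.append(current)
        success.append(matched)
    return old, success

# Four lanes try to claim the same slot, which holds 0 while it is free.
lock = [0]
old, success = scatter_cas_model(
    lock, compare=[0, 0, 0, 0], value=[1, 2, 3, 4], index=[0, 0, 0, 0]
)
print(old, success, lock)
```

Exactly one lane observes the expected value and performs the swap; all later lanes see the updated contents and fail the comparison.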
- drjit.slice(value: object, index: object = 0) object¶
Select a subset of the input array or PyTree along the trailing dynamic dimension.
Given a Dr.Jit array value with shape (..., N) (where N represents a dynamically sized dimension), this operation effectively evaluates the expression value[..., index]. It recursively traverses PyTrees and transforms each compatible array element. Other values are returned unchanged.
The following properties of index determine the return type:
When index is a 1D integer array, the operation reduces to one or more calls to drjit.gather(), and slice() returns a reduced output object of the same type and structure.
When index is a scalar Python int, the trailing dimension is entirely removed, and the operation returns an array from the drjit.scalar namespace containing the extracted values.
Reductions¶
- enum drjit.ReduceOp(value)¶
List of different atomic read-modify-write (RMW) operations supported by drjit.scatter_reduce().
Valid values are as follows:
- Identity = ReduceOp.Identity¶
Perform an ordinary scatter operation that ignores the current entry.
- Add = ReduceOp.Add¶
Addition.
- Mul = ReduceOp.Mul¶
Multiplication.
- Min = ReduceOp.Min¶
Minimum.
- Max = ReduceOp.Max¶
Maximum.
- And = ReduceOp.And¶
Binary AND operation.
- Or = ReduceOp.Or¶
Binary OR operation.
- enum drjit.ReduceMode(value)¶
Compilation strategy for atomic scatter-reductions.
Elements of this enumeration determine how Dr.Jit executes atomic scatter-reductions, which refers to indirect writes that update an existing element in an array, while avoiding problems arising due to concurrency.
Atomic scatter-reductions can have a significant detrimental impact on performance. When many threads in a parallel computation attempt to modify the same element, this can lead to contention—essentially a fight over which part of the processor owns the associated memory region, which can slow down a computation by many orders of magnitude.
This parameter also plays an important role for drjit.gather(), which is nominally a read-only operation. This is because the reverse-mode derivative of a gather turns it into an atomic scatter-addition, where further context on how to compile the operation is needed.
Dr.Jit implements several strategies to address contention, which can be selected by passing the optional mode parameter to drjit.scatter_reduce(), drjit.scatter_add(), and drjit.gather().
If you find that a part of your program is bottlenecked by atomic writes, then it may be worth explicitly specifying some of the strategies below to see which one performs best.
Valid values are as follows:
- Auto = ReduceMode.Auto¶
Select the first valid option from the following list:
use drjit.ReduceMode.Expand if the computation uses the LLVM backend and the target array storage size is smaller than or equal to the value given by drjit.expand_threshold(). This threshold can be changed using the drjit.set_expand_threshold() function.
use drjit.ReduceMode.Local if drjit.JitFlag.ScatterReduceLocal is set.
fall back to drjit.ReduceMode.Direct.
- Direct = ReduceMode.Direct¶
Insert an ordinary atomic reduction operation into the program.
This mode is ideal when no or little contention is expected, for example because the target indices of scatters are well spread throughout the target array. This mode generates a minimal amount of code, which can help improve performance especially on GPU backends.
- Local = ReduceMode.Local¶
Locally pre-reduce operands.
In this mode, Dr.Jit adds extra code to the compiled program to examine the target indices of atomic updates. For example, CUDA programs run with an instruction granularity referred to as a warp, which is a group of 32 threads. When some of these threads want to write to the same location, then those operands can be pre-processed to reduce the total number of necessary atomic memory transactions (potentially to just a single one!)
On the CPU/LLVM backend, the same process works at the granularity of packets. The details depend on the underlying instruction set. For example, there are 16 threads per packet on a machine with AVX512, so there is a potential for reducing atomic write traffic by that factor.
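The effect of local pre-reduction can be illustrated with a sequential pure-Python model. This is only a sketch: the function name scatter_add_local_model, the packet_size parameter, and the dictionary-based grouping are illustrative assumptions and do not reflect the code that Dr.Jit generates.

```python
from collections import defaultdict

def scatter_add_local_model(target, values, indices, packet_size=16):
    """Pre-reduce each packet, then issue one update per distinct index.

    Returns the number of simulated atomic updates.
    """
    atomics = 0
    for start in range(0, len(values), packet_size):
        packet_values = values[start:start + packet_size]
        packet_indices = indices[start:start + packet_size]
        # Combine all operands of this packet that target the same element.
        partial = defaultdict(float)
        for v, i in zip(packet_values, packet_indices):
            partial[i] += v
        # One (simulated) atomic update per distinct target index.
        for i, v in partial.items():
            target[i] += v
            atomics += 1
    return atomics

values = [1.0] * 64
indices = [0] * 64        # worst case: every lane hits the same element
target = [0.0]
atomics = scatter_add_local_model(target, values, indices)
print(atomics, target[0])
```

With 64 lanes and a packet size of 16, the fully contended case needs 4 updates instead of 64.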
- NoConflicts = ReduceMode.NoConflicts¶
Perform a non-atomic read-modify-write operation.
This mode is only safe in specific situations. The caller must guarantee that there are no conflicts (i.e., scatters targeting the same elements). If specified, Dr.Jit will generate a non-atomic read-modify-update operation that potentially runs significantly faster, especially on the LLVM backend.
- Permute = ReduceMode.Permute¶
In contrast to prior enumeration entries, this one modifies plain (non-reductive) scatters and gathers. It exists to enable internal optimizations that Dr.Jit uses when differentiating vectorized function calls and compressed loops. You likely should not use it in your own code.
When setting this mode, the caller guarantees that there will be no conflicts, and that every entry is written exactly once using an index vector representing a permutation (it’s fine if this permutation is accomplished by multiple separate write operations, but there should be no more than one write to each element).
Giving ‘Permute’ as an argument to a (nominally read-only) gather operation is helpful because we then know that the reverse-mode derivative of this operation can be a plain scatter instead of a more costly atomic scatter-add.
Giving ‘Permute’ as an argument to a scatter operation is helpful because we then know that the forward-mode derivative does not depend on any prior derivative values associated with that array, as all current entries will be overwritten.
- Expand = ReduceMode.Expand¶
Expand the target array to avoid write conflicts, then scatter non-atomically.
This feature is only supported on the LLVM backend. Other backends interpret this flag as if drjit.ReduceMode.Auto had been specified.
This mode internally expands the storage underlying the target array to a much larger size that is proportional to the number of CPU cores. Scalar (length-1) target arrays are expanded even further to ensure that each CPU gets an entirely separate cache line.
Following this one-time expansion step, the array can then accommodate an arbitrary sequence of scatter-reduction operations that the system will internally perform using non-atomic read-modify-write operations (i.e., analogous to the NoConflicts mode). Dr.Jit automatically re-compresses the array into the ordinary representation.
On bigger arrays and on machines with many cores, the storage costs resulting from this mode can be prohibitive.
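The expansion strategy can be pictured with a small pure-Python model in which every worker owns a private copy of the target. The sketch below makes several assumptions that the text does not specify: the name scatter_add_expand_model, the round-robin assignment of lanes to workers, and the final summation step are all illustrative.

```python
def scatter_add_expand_model(target, values, indices, num_workers=4):
    """Accumulate into per-worker private copies, then merge them."""
    # One private copy per worker, initialized to the identity of addition.
    private = [[0.0] * len(target) for _ in range(num_workers)]
    for lane, (v, i) in enumerate(zip(values, indices)):
        worker = lane % num_workers   # hypothetical static work assignment
        private[worker][i] += v       # non-atomic: the copy is not shared
    # Re-compress: fold the private copies back into the original array.
    for i in range(len(target)):
        target[i] += sum(copy[i] for copy in private)

target = [10.0, 0.0]
scatter_add_expand_model(
    target, values=[1.0, 2.0, 3.0, 4.0, 5.0], indices=[0, 1, 0, 1, 0]
)
print(target)
```

The model also shows where the storage cost comes from: the temporary footprint is num_workers times the size of the target array.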
- drjit.reduce(op: ReduceOp, value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None, keepdims: bool = False) object¶
Reduce the input array, tensor, or iterable along the specified axis/axes.
This function reduces arrays, tensors and other iterable Python types along one or multiple axes, where op selects the operation to be performed:
drjit.ReduceOp.Add: a[0] + a[1] + ....
drjit.ReduceOp.Mul: a[0] * a[1] * ....
drjit.ReduceOp.Min: min(a[0], a[1], ...).
drjit.ReduceOp.Max: max(a[0], a[1], ...).
drjit.ReduceOp.Or: a[0] | a[1] | ... (integer arrays only).
drjit.ReduceOp.And: a[0] & a[1] & ... (integer arrays only).
The functions drjit.sum(), drjit.prod(), drjit.min(), and drjit.max() are convenience aliases that call drjit.reduce() with specific values of op.
By default, the reduction for tensor types is performed over all axes (axis=None), while for all other types, the default is axis=0 (i.e., the outermost one), returning an instance of the array’s element type. For instance, sum-reducing an array a of type drjit.cuda.Array3f is equivalent to writing a[0] + a[1] + a[2] and produces a result of type drjit.cuda.Float. Dr.Jit can trace this operation and include it in the generated kernel.
Negative indices (e.g. axis=-1) count backward from the innermost axis. Multiple axes can be specified as a tuple. The value axis=None requests a simultaneous reduction over all axes.
When reducing axes of a tensor, or when reducing the trailing dimension of a Jit-compiled array, some special precautions apply: these axes correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. This can be done using the following strategies:
mode="evaluated" first evaluates the input array via drjit.eval() and then launches a specialized reduction kernel.
On the CUDA backend, this kernel makes efficient use of shared memory and cooperative warp instructions. The LLVM backend parallelizes the reduction via the built-in thread pool.
mode="symbolic" uses drjit.scatter_reduce() to atomically scatter-reduce values into the output array. This strategy can be advantageous when the input is symbolic (making evaluation impossible) or both unevaluated and extremely large (making evaluation costly or impossible if there isn’t enough memory).
Disadvantages of this mode are that
Atomic scatters can suffer from memory contention (though the drjit.scatter_reduce() function takes steps to reduce contention, see its documentation for details).
Atomic floating point scatter-addition is subject to non-deterministic rounding errors that arise from its non-associative nature. Coupled with the scheduling-dependent execution order, this can lead to small variations across program runs. Integer reductions and floating point min/max reductions are unaffected by this.
mode=None (default) automatically picks a reasonable strategy according to the following logic:
Use evaluated mode when the input array is already evaluated, or when evaluating it would consume less than 1 GiB of memory.
Use evaluated mode when the necessary atomic reduction operation is not supported by the backend.
Otherwise, use symbolic mode.
This function generally strips away reduced axes, but there is one notable exception: it will never remove a trailing dynamic dimension, if present in the input array.
For example, reducing an instance of type drjit.cuda.Float along axis 0 does not produce a scalar Python float. Instead, the operation returns another array of the same type with a single element. This is intentional: unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.
- Parameters:
op (ReduceOp) – The operation that should be applied along the reduced axis/axes.
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array or tensor.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
keepdims (bool) – if True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
The reduced array or tensor as specified above.
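The meaning of the individual reduction operations, and the reason why the evaluation order of a floating point sum matters, can be checked with plain Python scalars. The sketch below is a stdlib-only model: the REDUCERS table and the reduce_1d helper are hypothetical names and are not part of Dr.Jit.

```python
import functools
import operator

# Map each operation name to the binary function that it folds over the input.
REDUCERS = {
    "Add": operator.add,
    "Mul": operator.mul,
    "Min": min,
    "Max": max,
    "And": operator.and_,
    "Or": operator.or_,
}

def reduce_1d(op, values):
    """Fold 'values' from left to right using the selected operation."""
    return functools.reduce(REDUCERS[op], values)

total = reduce_1d("Add", [1, 2, 3, 4])
product = reduce_1d("Mul", [1, 2, 3, 4])
smallest = reduce_1d("Min", [3, 1, 2])
bits = reduce_1d("Or", [0b001, 0b100])

# Floating point addition is not associative: regrouping changes the result.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(total, product, smallest, bits, left == right)
```

Because a parallel reduction may group operands differently from run to run, a floating point sum computed with atomics can differ in its last bits, whereas the integer and min/max results above do not depend on the grouping.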
- drjit.sum(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None, keepdims: bool = False) object¶
Sum-reduce the input array, tensor, or iterable along the specified axis/axes.
This function sum-reduces arrays, tensors and other iterable Python types along one or multiple axes. It is equivalent to dr.reduce(dr.ReduceOp.Add, ...). See the documentation of dr.reduce() for further information.
- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
keepdims (bool) – if True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.prod(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None, keepdims: bool = False) object¶
Multiplicatively reduce the input array, tensor, or iterable along the specified axis/axes.
This function performs a multiplicative reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to dr.reduce(dr.ReduceOp.Mul, ...). See the documentation of dr.reduce() for further information.
- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
keepdims (bool) – if True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.min(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None, keepdims: bool = False) object¶
Perform a minimum reduction of the input array, tensor, or iterable along the specified axis/axes.
(Not to be confused with drjit.minimum(), which computes the smaller of two values).
This function performs a minimum reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to dr.reduce(dr.ReduceOp.Min, ...). See the documentation of dr.reduce() for further information.
- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
keepdims (bool) – if True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.max(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None, keepdims: bool = False) object¶
Perform a maximum reduction of the input array, tensor, or iterable along the specified axis/axes.
(Not to be confused with drjit.maximum(), which computes the larger of two values).
This function performs a maximum reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to dr.reduce(dr.ReduceOp.Max, ...). See the documentation of dr.reduce() for further information.
- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
keepdims (bool) – if True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
The reduced array or tensor as specified above.
- drjit.mean(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None, keepdims: bool = False) object¶
Compute the mean of the input array or tensor along one or multiple axes.
This function performs a horizontal sum reduction by adding values of the input array, tensor, or Python sequence along one or multiple axes and then dividing by the number of entries. The mean of an empty array is considered to be zero.
See the discussion of dr.reduce() for important general information about the properties of horizontal reductions.
- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
keepdims (bool) – if True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
The reduced array or tensor as specified above.
- drjit.var(value, axis=Ellipsis, mode=None, keepdims: bool = False, ddof: int = 0)¶
Compute the variance of the input along the specified axis/axes.
This is the population variance by default. Set ddof=1 to obtain the (Bessel-corrected) sample variance, which divides by N - 1 instead of N.
The implementation uses the standard two-pass algorithm
    m = dr.mean(value, axis=axis, keepdims=True)
    dr.sum((value - m) ** 2, axis=axis) / (N - ddof)
where N is the number of input elements that contribute to each output element. The two-pass form avoids the catastrophic cancellation of the algebraically-equivalent E[X^2] - E[X]^2 expression when the mean is large compared to the variance.
See dr.reduce() for the meaning of the axis, mode, and keepdims arguments.
- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
keepdims (bool) – if True, the reduced axes are retained as size-1 dimensions in the output. Defaults to False.
ddof (int) – “delta degrees of freedom”; the divisor used is N - ddof. Defaults to 0 (population variance).
- Returns:
The variance along the specified axis/axes.
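The cancellation problem mentioned above is easy to reproduce with double-precision Python floats. The following stdlib-only sketch compares the two formulations on data with a large mean; the helper names var_two_pass and var_naive are hypothetical and only mirror the formulas in the text.

```python
def var_two_pass(values, ddof=0):
    """Two-pass variance: compute the mean first, then the squared deviations."""
    n = len(values)
    m = sum(values) / n
    return sum((x - m) ** 2 for x in values) / (n - ddof)

def var_naive(values):
    """Single-expression form E[X^2] - E[X]^2 (population variance)."""
    n = len(values)
    mean_sq = sum(x * x for x in values) / n
    mean = sum(values) / n
    return mean_sq - mean * mean

# Large mean (about 1e9), small spread: the exact population variance is 22.5.
data = [1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16]

print(var_two_pass(data))          # population variance
print(var_two_pass(data, ddof=1))  # sample variance
print(var_naive(data))             # affected by cancellation
```

The squared values are close to 1e18, where adjacent double-precision numbers are 128 apart, so the naive form cannot represent a result such as 22.5. The deviations from the mean in the two-pass form are small integers and are computed exactly.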
- drjit.std(value, axis=Ellipsis, mode=None, keepdims: bool = False, ddof: int = 0)¶
Compute the standard deviation of the input along the specified axis/axes.
Equivalent to dr.sqrt(dr.var(value, ...)). See dr.var() for details on the algorithm and the meaning of the ddof parameter, and dr.reduce() for axis, mode, and keepdims.
- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce.
mode (str | None) – optional evaluation strategy.
keepdims (bool) – if True, the reduced axes are retained as size-1 dimensions in the output.
ddof (int) – delta degrees of freedom (default 0).
- Returns:
The standard deviation along the specified axis/axes.
- drjit.argmin(value: ArrayBase, /, axis: int | None = None, keepdims: bool = False) ArrayBase¶
Return the indices of the minimum values along an axis.
When axis is None, the index refers to the flattened input. When axis is an integer, the reduction operates along that axis and the result has one dimension fewer than the input, unless keepdims is True.
When multiple elements share the minimum value, the smallest index is returned.
- Parameters:
value – Input array or tensor.
axis (int | None) – Axis along which to operate. None reduces over all elements.
keepdims (bool) – If True, the reduced axis is retained as a length-one dimension.
- Returns:
An unsigned 32-bit integer array or tensor containing the indices of the minimum values.
- Return type:
object
- drjit.argmax(value: ArrayBase, /, axis: int | None = None, keepdims: bool = False) ArrayBase¶
Return the indices of the maximum values along an axis.
When axis is None, the index refers to the flattened input. When axis is an integer, the reduction operates along that axis and the result has one dimension fewer than the input, unless keepdims is True.
When multiple elements share the maximum value, the smallest index is returned.
- Parameters:
value – Input array or tensor.
axis (int | None) – Axis along which to operate. None reduces over all elements.
keepdims (bool) – If True, the reduced axis is retained as a length-one dimension.
- Returns:
An unsigned 32-bit integer array or tensor containing the indices of the maximum values.
- Return type:
object
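The tie-breaking rule shared by drjit.argmin() and drjit.argmax() (the smallest index wins) matches the behavior of the following plain-Python helpers. They are a hypothetical 1D model for illustration only and rely on the fact that Python's built-in min() and max() return the first extremal item they encounter.

```python
def argmin_model(values):
    """Index of the first occurrence of the minimum."""
    return min(range(len(values)), key=values.__getitem__)

def argmax_model(values):
    """Index of the first occurrence of the maximum."""
    return max(range(len(values)), key=values.__getitem__)

data = [2, 7, 1, 7, 1]
print(argmin_model(data), argmax_model(data))
```

Here the minimum 1 occurs at positions 2 and 4 and the maximum 7 at positions 1 and 3; in both cases the earlier position is reported.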
- drjit.sort(value: ArrayT, /, axis: int = -1, descending: bool = False) ArrayT¶
Sort the elements of an array or tensor.
For 1D arrays, returns a sorted copy. For tensors, sorts along the specified axis (default: last axis). The sort is stable: elements with equal values preserve their original relative order.
Uses a multi-pass radix sort internally (3 passes for 32-bit types on CPU, 4 on GPU).
- Parameters:
value – Input array or tensor.
axis (int) – Axis along which to sort. Only -1 (last axis) and 0 are currently supported for tensors.
descending (bool) – If True, sort in descending order.
- Returns:
A sorted copy of the input with the same type and shape.
- Return type:
object
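A least-significant-digit radix sort reaches a stable result by running one stable bucketing pass per digit. The sketch below is a generic pure-Python version for unsigned 32-bit keys, not Dr.Jit's implementation. The digit widths are assumptions chosen to match the pass counts quoted above (8 bits per pass gives 4 passes, 11 bits gives 3), and radix_sort_model is a hypothetical name.

```python
def radix_sort_model(items, bits_per_pass=8):
    """Stable LSD radix sort of (key, payload) pairs by unsigned 32-bit key."""
    mask = (1 << bits_per_pass) - 1
    passes = -(-32 // bits_per_pass)   # ceil(32 / bits_per_pass)
    for p in range(passes):
        shift = p * bits_per_pass
        buckets = [[] for _ in range(mask + 1)]
        for item in items:
            digit = (item[0] >> shift) & mask
            # Appending preserves arrival order, which keeps the pass stable.
            buckets[digit].append(item)
        items = [item for bucket in buckets for item in bucket]
    return items

data = [(3, "a"), (1, "b"), (3, "c"), (0xFFFFFFFF, "d"), (1, "e")]
by_8_bits = radix_sort_model(data, bits_per_pass=8)    # 4 passes
by_11_bits = radix_sort_model(data, bits_per_pass=11)  # 3 passes
print(by_8_bits)
```

The payload letters make the stability visible: entries with equal keys keep their original relative order regardless of the digit width.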
- drjit.argsort(value: ArrayBase, /, axis: int = -1, descending: bool = False) ArrayBase¶
Return the indices that would sort an array or tensor.
For 1D arrays, returns an index array such that dr.gather(type(value), value, result) is sorted. For tensors, returns indices along the specified axis.
The sort is stable: among equal elements, the original index order is preserved.
- Parameters:
value – Input array or tensor.
axis (int) – Axis along which to sort. Only -1 (last axis) and 0 are currently supported for tensors.
descending (bool) – If True, sort in descending order.
- Returns:
An unsigned 32-bit integer array or tensor containing the sorting indices.
- Return type:
object
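The relationship between argsort(), gather(), and stability can be sketched in plain Python. This is not Dr.Jit code: stable_argsort_ref is a hypothetical helper, and it assumes that ties keep their original index order for descending sorts as well.

```python
def stable_argsort_ref(values, descending=False):
    """Indices that sort `values`; equal elements keep their original order."""
    # Python's sorted() is stable, also when reverse=True
    return sorted(range(len(values)), key=lambda i: values[i], reverse=descending)


values = [2, 1, 2, 1]
idx = stable_argsort_ref(values)
gathered = [values[i] for i in idx]  # plays the role of dr.gather(...)
print(idx, gathered)
```

Gathering with the returned indices reproduces the sorted sequence, which is the property stated for 1D arrays above.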
- drjit.all(value: object, axis: int | tuple[int, ...] | ... | None = ..., keepdims: bool = False) object¶
Check if all elements along the specified axis are active.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the & (AND) operator.
Reductions along index 0 refer to the outermost axis and negative indices (e.g. -1) count backwards from the innermost axis. The special argument axis=None causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to be True.
The function is internally based on dr.reduce(). See the documentation of this function for further information.
Like dr.reduce(), this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducing drjit.cuda.Bool does not produce a scalar Python bool. Instead, the operation returns another array of the same type with a single element. This is intentional: unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.
Boolean 1D arrays automatically convert to bool if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:
import drjit as dr
from drjit.cuda import Float
x = Float(...)
if dr.all(x < 0):
    # ...
A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves calling drjit.eval() to evaluate and store the array in memory and then launching a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.
To avoid Boolean reductions, one can often use symbolic operations such as if_stmt(), while_loop(), etc. The @dr.syntax decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...) based on the condition.
@dr.syntax
def f(x: Float):
    if x < 0:
        # ...
- Parameters:
value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
keepdims (bool) – If True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.any(value: object, axis: int | tuple[int, ...] | ... | None = ..., keepdims: bool = False) object¶
Check if any elements along the specified axis are active.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the | (OR) operator.
Reductions along index 0 refer to the outermost axis and negative indices (e.g. -1) count backwards from the innermost axis. The special argument axis=None causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to be False.
The function is internally based on dr.reduce(). See the documentation of this function for further information.
Like dr.reduce(), this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducing drjit.cuda.Bool does not produce a scalar Python bool. Instead, the operation returns another array of the same type with a single element. This is intentional: unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.
Boolean 1D arrays automatically convert to bool if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:
import drjit as dr
from drjit.cuda import Float
x = Float(...)
if dr.any(x < 0):
    # ...
A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves calling drjit.eval() to evaluate and store the array in memory and then launching a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.
To avoid Boolean reductions, one can often use symbolic operations such as if_stmt(), while_loop(), etc. The @dr.syntax decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...) based on the condition.
@dr.syntax
def f(x: Float):
    if x < 0:
        # ...
- Parameters:
value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
keepdims (bool) – If True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
Result of the reduction operation
- Return type:
bool | drjit.ArrayBase
- drjit.none(value: object, axis: int | tuple[int, ...] | ... | None = ..., keepdims: bool = False) object¶
Check if none of the elements along the specified axis are active.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the | (OR) operator and finally returns the bit-wise inverse of the result.
The function is internally based on dr.reduce(). See the documentation of this function for further information.
Like dr.reduce(), this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducing drjit.cuda.Bool does not produce a scalar Python bool. Instead, the operation returns another array of the same type with a single element. This is intentional: unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.
Boolean 1D arrays automatically convert to bool if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:
import drjit as dr
from drjit.cuda import Float
x = Float(...)
if dr.none(x < 0):
    # ...
A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves calling drjit.eval() to evaluate and store the array in memory and then launching a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.
To avoid Boolean reductions, one can often use symbolic operations such as if_stmt(), while_loop(), etc. The @dr.syntax decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...) based on the condition.
@dr.syntax
def f(x: Float):
    if x < 0:
        # ...
- Parameters:
value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
keepdims (bool) – If True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
Result of the reduction operation
- Return type:
bool | drjit.ArrayBase
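The reduction operators and empty-input conventions of all(), any(), and none() can be modeled in plain Python. This is a sketch on Python lists rather than Dr.Jit arrays; the *_ref helper names are hypothetical. The empty-input result for none() is not stated above and simply follows from inverting the OR reduction.

```python
import operator
from functools import reduce


def all_ref(mask):
    # AND reduction; the identity True is the result for an empty input
    return reduce(operator.and_, mask, True)


def any_ref(mask):
    # OR reduction; the identity False is the result for an empty input
    return reduce(operator.or_, mask, False)


def none_ref(mask):
    # Inverse of the OR reduction
    return not any_ref(mask)


print(all_ref([]), any_ref([]), none_ref([]))
print(all_ref([True, False]), any_ref([True, False]), none_ref([True, False]))
```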
- drjit.count(value: object, axis: int | tuple[int, ...] | ... | None = ..., keepdims: bool = False) object¶
Compute the number of active entries along the given axis.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the + operator (interpreting True elements as 1 and False elements as 0). It returns an unsigned 32-bit version of the input array.
Reductions along index 0 refer to the outermost axis and negative indices (e.g. -1) count backwards from the innermost axis. The special argument axis=None causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to be zero.
See the discussion of dr.reduce() for important general information about the properties of horizontal operations.
- Parameters:
value (bool | Sequence | drjit.ArrayBase) – A Python or Dr.Jit mask type
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
keepdims (bool) – If True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
Result of the reduction operation
- Return type:
int | drjit.ArrayBase
- drjit.dot(arg0: object, arg1: object, /) object¶
Compute the dot product of two arrays.
Whenever possible, the implementation uses a sequence of fma() (fused multiply-add) operations to evaluate the dot product. When the input is a 1D JIT array like drjit.cuda.Float, the function evaluates the product of the input arrays via drjit.eval() and then performs a sum reduction via drjit.sum().
See the discussion of dr.reduce() for important general information about the properties of horizontal reductions.
- Parameters:
arg0 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Dot product of inputs
- Return type:
float | int | drjit.ArrayBase
- drjit.abs_dot(arg0: object, arg1: object, /) object¶
Compute the absolute value of the dot product of two arrays.
This function implements a convenience short-hand for abs(dot(arg0, arg1)).
See the discussion of dr.reduce() for important general information about the properties of horizontal reductions.
- Parameters:
arg0 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Absolute value of the dot product of inputs
- Return type:
float | int | drjit.ArrayBase
- drjit.squared_norm(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None, keepdims: bool = False) object¶
Computes the squared 2-norm of a Dr.Jit array, tensor, or Python sequence along the specified axis/axes.
The operation is equivalent to
dr.sum(value*value, axis=axis, keepdims=keepdims, mode=mode)
For the default arguments, the implementation routes through dr.dot(value, value) to use the specialized dot-product kernels (e.g., the fused dot reduction for 1D JIT arrays).
The squared_norm() operation performs a horizontal reduction. See the discussion of dr.reduce() for important general information about the properties of horizontal reductions.
- Parameters:
value (Sequence | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
mode (str | None) – Optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
keepdims (bool) – If True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
squared 2-norm of the input
- Return type:
float | int | drjit.ArrayBase
- drjit.norm(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None, keepdims: bool = False) object¶
Computes the 2-norm of a Dr.Jit array, tensor, or Python sequence along the specified axis/axes.
The operation is equivalent to
dr.sqrt(dr.sum(value*value, axis=axis, keepdims=keepdims, mode=mode))
For the default arguments, the implementation routes through dr.sqrt(dr.dot(value, value)) to use the specialized dot-product kernels (e.g., the fused dot reduction for 1D JIT arrays).
The norm() operation performs a horizontal reduction. See the discussion of dr.reduce() for important general information about the properties of horizontal reductions.
- Parameters:
value (Sequence | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument axis=None causes a simultaneous reduction over all axes. The default axis=... applies a reduction over all axes for tensor types and index 0 otherwise.
mode (str | None) – Optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
keepdims (bool) – If True, the reduced axes are retained in the output as size-1 dimensions. Defaults to False.
- Returns:
2-norm of the input
- Return type:
float | int | drjit.ArrayBase
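The stated equivalences between dot(), squared_norm(), and norm() can be checked on plain Python lists. The following is a sketch only; the *_ref helpers are hypothetical and do not reproduce Dr.Jit's FMA-based evaluation or its reduction kernels.

```python
import math


def dot_ref(a, b):
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return sum(x * y for x, y in zip(a, b))


def squared_norm_ref(v):
    # Equivalent to summing value*value
    return dot_ref(v, v)


def norm_ref(v):
    return math.sqrt(squared_norm_ref(v))


print(squared_norm_ref([1, 2, 3]), norm_ref([3.0, 4.0]))
```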
Prefix reductions¶
- drjit.prefix_reduce(op: ReduceOp, value: T, axis: int | tuple[int, ...] = 0, exclusive: bool = True, reverse: bool = False) T¶
Compute an exclusive or inclusive prefix reduction of the input array, tensor, or iterable along the specified axis/axes.
The function returns an output array of the same shape as the input. The op parameter selects the operation to be performed.
For example, when reducing a 1D array using exclusive=True (the default), this produces the following output:
drjit.ReduceOp.Add: [0, a[0], a[0] + a[1], ...].
drjit.ReduceOp.Mul: [1, a[0], a[0] * a[1], ...].
drjit.ReduceOp.Min: [inf, a[0], min(a[0], a[1]), ...].
drjit.ReduceOp.Max: [-inf, a[0], max(a[0], a[1]), ...].
drjit.ReduceOp.Or: [0, a[0], a[0] | a[1], ...] (integer arrays only).
drjit.ReduceOp.And: [-1, a[0], a[0] & a[1], ...] (integer arrays only).
With exclusive=False, the function instead performs an inclusive prefix reduction, which effectively shifts the output by one entry:
drjit.ReduceOp.Add: [a[0], a[0] + a[1], ...].
drjit.ReduceOp.Mul: [a[0], a[0] * a[1], ...].
drjit.ReduceOp.Min: [a[0], min(a[0], a[1]), ...].
drjit.ReduceOp.Max: [a[0], max(a[0], a[1]), ...].
drjit.ReduceOp.Or: [a[0], a[0] | a[1], ...] (integer arrays only).
drjit.ReduceOp.And: [a[0], a[0] & a[1], ...] (integer arrays only).
By default, the reduction is along axis 0 (i.e., the outermost one). Negative indices (e.g. axis=-1) count backward from the innermost axis. Multiple axes can be specified as a tuple and are handled iteratively.
- Parameters:
op (ReduceOp) – The operation that should be applied along the specified axis/axes.
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array or tensor.
axis (int | tuple[int, ...]) – The axis/axes along which to reduce. The default value is 0.
exclusive (bool) – Whether to perform an exclusive (the default) or inclusive prefix reduction.
reverse (bool) – If set to True, the prefix reduction is done from the end of the selected axis.
- Returns:
The prefix-reduced array or tensor as specified above. It has the same shape and type as the input.
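The exclusive and inclusive variants can be modeled for 1D Python lists with itertools.accumulate. This is a sketch, not Dr.Jit code: prefix_reduce_ref is a hypothetical helper that takes the identity element explicitly, and it assumes that reverse=True yields suffix reductions stored at the original positions.

```python
import math
import operator
from itertools import accumulate


def prefix_reduce_ref(op, identity, values, exclusive=True, reverse=False):
    seq = list(values)
    if not seq:
        return []
    if reverse:
        seq.reverse()
    if exclusive:
        # Start from the identity and drop the last input element
        out = list(accumulate(seq[:-1], op, initial=identity))
    else:
        out = list(accumulate(seq, op))
    if reverse:
        out.reverse()
    return out


a = [1, 2, 3, 4]
print(prefix_reduce_ref(operator.add, 0, a))                   # [0, 1, 3, 6]
print(prefix_reduce_ref(operator.add, 0, a, exclusive=False))  # [1, 3, 6, 10]
print(prefix_reduce_ref(operator.add, 0, a, reverse=True))     # [9, 7, 4, 0]
print(prefix_reduce_ref(min, math.inf, [3, 1, 2]))             # [inf, 3, 1]
```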
- drjit.prefix_sum(value: T, axis: int | tuple[int, ...] = 0, reverse: bool = False) T¶
Compute an exclusive prefix sum of the input array.
This function is a convenience wrapper that internally calls drjit.prefix_reduce(dr.ReduceOp.Add, value, exclusive=True, ...). Please refer to this function for further detail.
- drjit.cumsum(value: T, axis: int | tuple[int, ...] = 0, reverse: bool = False) T¶
Compute a cumulative sum (a.k.a. inclusive prefix sum) of the input array.
This function is a convenience wrapper that internally calls drjit.prefix_reduce(dr.ReduceOp.Add, value, exclusive=False, ...). Please refer to this function for further detail.
Block reductions¶
- drjit.block_reduce(op: ReduceOp, value: T, block_size: int, mode: Literal['evaluated', 'symbolic', None] = None) T¶
Reduce elements within blocks.
This function reduces all elements within contiguous blocks of size block_size along the trailing dimension of the input array value, returning a correspondingly smaller output array. Various types of reductions are supported (see drjit.ReduceOp for details).
For example, a sum reduction of a hypothetical array [a, b, c, d, e, f] with block_size=2 produces the output [a+b, c+d, e+f].
The function raises an exception when the length of the trailing dimension is not a multiple of the block size. It recursively threads through nested arrays and PyTrees.
Dr.Jit uses one of two strategies to realize this operation, which can be optionally forced by specifying the mode parameter.
mode="evaluated" first evaluates the input array via drjit.eval() and then launches a specialized reduction kernel.
    On the CUDA backend, this kernel makes efficient use of shared memory and cooperative warp instructions. The LLVM backend parallelizes the operation via the built-in thread pool.
    This strategy uses an increased intermediate precision (single precision) when reducing half precision arrays.
mode="symbolic" uses drjit.scatter_reduce() to atomically scatter-reduce values into the output array. This strategy can be advantageous when the input array is symbolic (making evaluation impossible) or both unevaluated and extremely large (making evaluation costly or impossible if there isn’t enough memory).
    Disadvantages of this mode are that
    Atomic scatters can suffer from memory contention (though drjit.scatter_reduce() takes steps to reduce contention, see its documentation for details).
    Atomic floating point scatter-addition is subject to non-deterministic rounding errors that arise from its non-associative nature. Coupled with the scheduling-dependent execution order, this can lead to small variations across program runs. Integer and floating point min/max reductions are unaffected by this.
mode=None (default) automatically picks a reasonable strategy according to the following logic:
    Symbolic mode is admissible when the necessary atomic reduction is supported by the backend.
    Evaluated mode is admissible when the input does not involve symbolic variables.
    If only one strategy remains, then pick that one. Raise an exception when no strategy works out.
    Otherwise, use evaluated mode when the input array is already evaluated, or when evaluating it would consume less than 1 GiB of memory.
    Use symbolic mode in all other cases.
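The selection logic for mode=None can be written out as a small decision function. This is an illustrative sketch only: pick_mode_ref and its four parameters are hypothetical names that stand in for information Dr.Jit determines internally.

```python
GIB = 1 << 30  # 1 GiB in bytes


def pick_mode_ref(atomic_supported, input_is_symbolic, input_is_evaluated, eval_bytes):
    candidates = []
    if atomic_supported:       # symbolic mode needs a backend atomic reduction
        candidates.append("symbolic")
    if not input_is_symbolic:  # evaluated mode cannot evaluate symbolic inputs
        candidates.append("evaluated")
    if not candidates:
        raise RuntimeError("no admissible block reduction strategy")
    if len(candidates) == 1:
        return candidates[0]
    # Both strategies are admissible: prefer evaluation when it is cheap
    if input_is_evaluated or eval_bytes < GIB:
        return "evaluated"
    return "symbolic"


print(pick_mode_ref(True, False, False, 4 * GIB))  # large unevaluated input
```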
On the CUDA backend, you may observe speedups if you design your program so that it invokes drjit.block_reduce() with power-of-two block sizes, as the underlying kernel optimizes this case.
Note
This operation traverses PyTrees and transforms any dynamically sized Dr.Jit arrays it encounters. Everything else is left as-is.
Tensors are not supported and will cause an exception to be raised. While drjit.block_reduce() is internally used by Dr.Jit when reducing tensors, the function on its own does not support tensor-valued inputs.
To reduce blocks within a tensor, reshape it and call the regular tensor-compatible reduction operations (e.g., drjit.sum(), drjit.prod(), drjit.min(), drjit.max(), or the general drjit.reduce()).
For example, to sum-reduce a (16, 16) tensor by a factor of (4, 2) (i.e., to a (4, 8)-sized tensor), write
result = dr.sum(dr.reshape(value, shape=(4, 4, 8, 2)), axis=(1, 3))
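The index arithmetic behind this reshape-and-sum recipe can be verified with nested Python lists. The sketch below does not use Dr.Jit; it relies on the fact that, in row-major order, element (i, a, j, b) of the reshaped (4, 4, 8, 2) tensor is element (4*i + a, 2*j + b) of the original.

```python
rows, cols = 16, 16
fr, fc = 4, 2  # reduction factors along the two axes

# Test data: value[r][c] = r * 16 + c
value = [[r * cols + c for c in range(cols)] for r in range(rows)]

# Summing over axes 1 and 3 of the reshaped tensor adds up each 4x2 tile
result = [
    [
        sum(value[fr * i + a][fc * j + b] for a in range(fr) for b in range(fc))
        for j in range(cols // fc)
    ]
    for i in range(rows // fr)
]

print(len(result), len(result[0]))  # 4 8
print(result[0][0])                 # 196
```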
- Parameters:
value (object) – A Dr.Jit array or PyTree
block_size (int) – The size of contiguous blocks to be reduced.
mode (str | None) – optional parameter to force an evaluation strategy.
- Returns:
The block-reduced array or PyTree as specified above.
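The block semantics, including the divisibility check, can be sketched for a flat Python list. block_reduce_ref is a hypothetical helper that takes a binary Python operator in place of a drjit.ReduceOp value.

```python
import operator
from functools import reduce


def block_reduce_ref(op, values, block_size):
    if block_size <= 0 or len(values) % block_size != 0:
        raise ValueError("length must be a multiple of block_size")
    # Reduce each contiguous block of `block_size` entries independently
    return [
        reduce(op, values[i:i + block_size])
        for i in range(0, len(values), block_size)
    ]


print(block_reduce_ref(operator.add, [1, 2, 3, 4, 5, 6], 2))  # [3, 7, 11]
print(block_reduce_ref(max, [1, 2, 3, 4, 5, 6], 3))           # [3, 6]
```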
- drjit.block_sum(value: T, block_size: int, mode: Literal['evaluated', 'symbolic', None] = None) T¶
Sum elements within blocks.
This is a convenience alias for drjit.block_reduce() with op set to drjit.ReduceOp.Add.
- drjit.block_prefix_reduce(op: ReduceOp, value: ArrayT, block_size: int, exclusive: bool = True, reverse: bool = False) ArrayT¶
Compute a blocked exclusive or inclusive prefix reduction of the input array.
Starting from the identity element of the specified reduction op, this function reduces along increasingly large prefixes, returning an output array of the same size and type.
For example, when reducing a 1D array using exclusive=True (the default), this produces the following output:
drjit.ReduceOp.Add: [0, a[0], a[0] + a[1], ...].
drjit.ReduceOp.Mul: [1, a[0], a[0] * a[1], ...].
drjit.ReduceOp.Min: [inf, a[0], min(a[0], a[1]), ...].
drjit.ReduceOp.Max: [-inf, a[0], max(a[0], a[1]), ...].
drjit.ReduceOp.Or: [0, a[0], a[0] | a[1], ...] (integer arrays only).
drjit.ReduceOp.And: [-1, a[0], a[0] & a[1], ...] (integer arrays only).
With exclusive=False, the function instead performs an inclusive prefix reduction, which effectively shifts the output by one entry:
drjit.ReduceOp.Add: [a[0], a[0] + a[1], ...].
drjit.ReduceOp.Mul: [a[0], a[0] * a[1], ...].
drjit.ReduceOp.Min: [a[0], min(a[0], a[1]), ...].
drjit.ReduceOp.Max: [a[0], max(a[0], a[1]), ...].
drjit.ReduceOp.Or: [a[0], a[0] | a[1], ...] (integer arrays only).
drjit.ReduceOp.And: [a[0], a[0] & a[1], ...] (integer arrays only).
The reduction is furthermore blocked, which means that it restarts after block_size entries. To reduce the entire array, simply set block_size=len(value).
Finally, the reduction can optionally be done from the end of each block by specifying reverse=True.
This operation traverses PyTrees and transforms any dynamically sized Dr.Jit arrays it encounters. Everything else is left as-is.
- Parameters:
value (object) – A Dr.Jit array or PyTree
block_size (int) – The size of contiguous blocks to be reduced.
exclusive (bool) – Specifies whether the prefix reduction should be exclusive (the default) or inclusive.
reverse (bool) – If set to True, the reduction is done from the end of each block.
- Returns:
The block-wise prefix-reduced array or PyTree as specified above.
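The restart-per-block behavior can be sketched for the Add case on a Python list. block_prefix_sum_ref is a hypothetical helper; the examples use lengths that are multiples of block_size, and the handling of reverse=True (suffix sums within each block) is an assumption of this sketch.

```python
def block_prefix_sum_ref(values, block_size, exclusive=True, reverse=False):
    out = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        if reverse:
            block = block[::-1]
        acc, res = 0, []  # the accumulator restarts at the identity per block
        for x in block:
            if exclusive:
                res.append(acc)
                acc += x
            else:
                acc += x
                res.append(acc)
        out.extend(res[::-1] if reverse else res)
    return out


a = [1, 2, 3, 4, 5, 6]
print(block_prefix_sum_ref(a, 3))                   # [0, 1, 3, 0, 4, 9]
print(block_prefix_sum_ref(a, 3, exclusive=False))  # [1, 3, 6, 4, 9, 15]
print(block_prefix_sum_ref(a, 3, reverse=True))     # [5, 3, 0, 11, 6, 0]
```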
- drjit.block_prefix_sum(value: T, block_size: int, exclusive: bool = True, reverse: bool = False) T¶
Convenience wrapper around dr.block_prefix_reduce(dr.ReduceOp.Add, ...).
Rearranging array contents¶
- drjit.concat(arr: Sequence[ArrayT], /, axis: int | None = 0) ArrayT¶
Concatenate a sequence of arrays or tensors along a given axis.
The inputs must all be of the same type, and they must have the same shape except for the axis being concatenated. Negative axis values count backwards from the last dimension.
When axis=None, the function ravels the input arrays or tensors prior to concatenating them.
- drjit.stack(arrays: Sequence[ArrayT], /, axis: int = 0) ArrayT¶
Join a sequence of tensors along a new axis.
All input tensors must have the same type and shape. The result has one more dimension than the inputs: a new axis of size len(arrays) is inserted at position axis.
For example, stacking two tensors of shape (M, N) with axis=0 produces shape (2, M, N), while axis=1 produces (M, 2, N).
Unlike concat(), which joins arrays along an existing axis, stack always creates a new one.
- Parameters:
arrays – Sequence of tensors. All must have the same type and shape.
axis (int) – The position of the new axis in the result. Negative values count backwards from the last dimension.
- Returns:
A tensor with ndim + 1 dimensions.
- Return type:
object
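The difference between concat() and stack() is easiest to see in terms of shapes. The following sketch computes result shapes only and does not touch Dr.Jit; concat_shape_ref and stack_shape_ref are hypothetical helpers, and negative axis values for stack are assumed to refer to the output dimensions.

```python
def concat_shape_ref(shapes, axis=0):
    first = tuple(shapes[0])
    ndim = len(first)
    if not -ndim <= axis < ndim:
        raise ValueError("axis out of range")
    axis %= ndim
    for s in shapes:
        # All dimensions except the concatenation axis must agree
        if len(s) != ndim or any(s[i] != first[i] for i in range(ndim) if i != axis):
            raise ValueError("incompatible shapes")
    return first[:axis] + (sum(s[axis] for s in shapes),) + first[axis + 1:]


def stack_shape_ref(shape, count, axis=0):
    shape = tuple(shape)
    ndim_out = len(shape) + 1
    if not -ndim_out <= axis < ndim_out:
        raise ValueError("axis out of range")
    axis %= ndim_out
    # A new axis of size `count` is inserted at position `axis`
    return shape[:axis] + (count,) + shape[axis:]


print(concat_shape_ref([(2, 3), (4, 3)]))  # (6, 3)
print(stack_shape_ref((5, 3), 2, axis=0))  # (2, 5, 3)
print(stack_shape_ref((5, 3), 2, axis=1))  # (5, 2, 3)
```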
- drjit.vstack(arrays: Sequence[ArrayT], /) ArrayT¶
Stack tensors vertically (row-wise).
Concatenates tensors along the first axis. 1D tensors of shape (N,) are first reshaped to (1, N) so that they become rows. The result is always at least 2D.
The inputs must have the same shape along all but the first axis.
row_stack is an alias for this function.
- Parameters:
arrays – Sequence of tensors to stack.
- Returns:
A tensor with at least two dimensions formed by vertical concatenation.
- Return type:
object
- drjit.hstack(arrays: Sequence[ArrayT], /) ArrayT¶
Stack tensors horizontally (column-wise).
For tensors with two or more dimensions, this concatenates along the second axis. For 1D tensors, it concatenates along the first (and only) axis.
The inputs must have the same shape along all but the concatenation axis.
- Parameters:
arrays – Sequence of tensors to stack.
- Returns:
A tensor formed by horizontal concatenation.
- Return type:
object
- drjit.column_stack(arrays: Sequence[ArrayT], /) ArrayT¶
Stack 1D tensors as columns into a 2D tensor.
Takes a sequence of 1D tensors and stacks them as columns to produce a 2D result. Each 1D tensor of shape (N,) is first reshaped to (N, 1), then all inputs are concatenated along axis 1.
2D (or higher) tensors are concatenated along axis 1 as-is, like hstack().
All inputs must have the same first dimension.
- Parameters:
arrays – Sequence of tensors to stack.
- Returns:
A tensor formed by column-wise concatenation.
- Return type:
object
- drjit.dstack(arrays: Sequence[ArrayT], /) ArrayT¶
Stack tensors depth-wise (along the third axis).
Concatenates tensors along the third axis after reshaping them to at least 3D. 1D tensors of shape (N,) become (1, N, 1) and 2D tensors of shape (M, N) become (M, N, 1) before concatenation. The result is always at least 3D.
The inputs must have the same shape along all but the third axis.
- Parameters:
arrays – Sequence of tensors to stack.
- Returns:
A tensor with at least three dimensions formed by depth-wise concatenation.
- Return type:
object
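The reshaping rules described for vstack(), column_stack(), and dstack() can be summarized as a shape computation. This is a sketch that works on shape tuples only; promote_ref and stacked_shape_ref are hypothetical helpers, not part of Dr.Jit.

```python
AXIS = {"vstack": 0, "column_stack": 1, "dstack": 2}


def promote_ref(shape, kind):
    shape = tuple(shape)
    if kind == "vstack" and len(shape) == 1:
        return (1, shape[0])     # (N,) becomes a row
    if kind == "column_stack" and len(shape) == 1:
        return (shape[0], 1)     # (N,) becomes a column
    if kind == "dstack":
        if len(shape) == 1:
            return (1, shape[0], 1)
        if len(shape) == 2:
            return shape + (1,)
    return shape


def stacked_shape_ref(shapes, kind):
    axis = AXIS[kind]
    promoted = [promote_ref(s, kind) for s in shapes]
    first = promoted[0]
    for p in promoted:
        if len(p) != len(first) or any(
            p[i] != first[i] for i in range(len(first)) if i != axis
        ):
            raise ValueError("incompatible shapes")
    return first[:axis] + (sum(p[axis] for p in promoted),) + first[axis + 1:]


print(stacked_shape_ref([(3,), (3,)], "vstack"))        # (2, 3)
print(stacked_shape_ref([(3,), (3,)], "column_stack"))  # (3, 2)
print(stacked_shape_ref([(3,), (3,)], "dstack"))        # (1, 3, 2)
```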
- drjit.expand_dims(value: ArrayT, /, axis: int | Tuple[int, ...]) ArrayT¶
Expand the shape of a tensor by inserting new length-one axes.
The axis parameter specifies where new axes are placed in the output shape. For example, given a tensor of shape (3, 4):
axis=0 produces shape (1, 3, 4)
axis=1 produces shape (3, 1, 4)
axis=-1 produces shape (3, 4, 1)
When axis is a tuple, multiple axes are inserted simultaneously. The axis positions refer to the output shape, and negative values count backwards from the last output dimension.
This operation does not copy data; it returns a view of the input with a different shape.
- Parameters:
value – Input tensor.
axis (int | tuple[int, ...]) – Position(s) in the output shape where new length-one axes should be inserted.
- Returns:
A tensor with the same underlying data and one or more additional length-one dimensions.
- Return type:
object
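Because the axis positions refer to the output shape, the resulting shape can be computed as follows. This is a sketch on shape tuples; expand_dims_shape_ref is a hypothetical helper, and the error handling is an assumption.

```python
def expand_dims_shape_ref(shape, axis):
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    ndim_out = len(shape) + len(axes)
    positions = set()
    for a in axes:
        if not -ndim_out <= a < ndim_out:
            raise ValueError("axis out of range")
        positions.add(a % ndim_out)  # normalize relative to the OUTPUT rank
    if len(positions) != len(axes):
        raise ValueError("repeated axis")
    it = iter(shape)
    # New axes get length one; the remaining slots take the input dimensions
    return tuple(1 if i in positions else next(it) for i in range(ndim_out))


print(expand_dims_shape_ref((3, 4), 0))        # (1, 3, 4)
print(expand_dims_shape_ref((3, 4), -1))       # (3, 4, 1)
print(expand_dims_shape_ref((3, 4), (0, -1)))  # (1, 3, 4, 1)
```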
- drjit.squeeze(value: ArrayT, /, axis: int | Tuple[int, ...] | None = None) ArrayT¶
Remove length-one axes from a tensor.
When axis is None (the default), all length-one axes are removed. When axis is an integer or tuple of integers, only those axes are removed, and an error is raised if any of them does not have length one.
Negative axis values count backwards from the last dimension.
This is the inverse of expand_dims().
- Parameters:
value – Input tensor.
axis (int | tuple[int, ...] | None) – The axis or axes to remove. If None, all length-one axes are removed.
- Returns:
A tensor with the specified length-one dimensions removed.
- Return type:
object
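A shape-level sketch of these rules is shown below. squeeze_shape_ref is a hypothetical helper operating on shape tuples; the exception type is an assumption.

```python
def squeeze_shape_ref(shape, axis=None):
    ndim = len(shape)
    if axis is None:
        return tuple(s for s in shape if s != 1)
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    drop = set()
    for a in axes:
        if not -ndim <= a < ndim:
            raise ValueError("axis out of range")
        a %= ndim
        if shape[a] != 1:
            raise ValueError("cannot squeeze an axis whose length is not one")
        drop.add(a)
    return tuple(s for i, s in enumerate(shape) if i not in drop)


print(squeeze_shape_ref((1, 3, 1, 4)))      # (3, 4)
print(squeeze_shape_ref((1, 3, 1, 4), 0))   # (3, 1, 4)
print(squeeze_shape_ref((1, 3, 1, 4), -2))  # (1, 3, 4)
```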
- drjit.split(value: ArrayT, indices_or_sections: int | Sequence[int], /, axis: int = 0) List[ArrayT]¶
Split a tensor into multiple parts along an axis.
When indices_or_sections is an integer N, the tensor is divided into N equal parts along axis. The axis size must be divisible by N; use array_split() to allow unequal parts.
When indices_or_sections is a sequence of indices [i, j, ...], the splits occur before those positions along the given axis, producing sections [:i], [i:j], [j:], etc.
- Parameters:
value – Input tensor.
indices_or_sections (int | Sequence[int]) – Either the number of equal-sized sections or a sequence of split points.
axis (int) – The axis along which to split. Negative values count backwards from the last dimension.
- Returns:
A list of tensors.
- Return type:
list
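Both forms of indices_or_sections reduce to a list of (start, stop) ranges along the split axis. The sketch below computes those ranges in plain Python; split_bounds_ref is a hypothetical helper that assumes the split indices are sorted and lie within the axis.

```python
def split_bounds_ref(n, indices_or_sections):
    """Return (start, stop) ranges along an axis of size n."""
    if isinstance(indices_or_sections, int):
        sections = indices_or_sections
        if n % sections != 0:
            raise ValueError("axis size must be divisible by the section count")
        step = n // sections
        edges = list(range(0, n + 1, step))
    else:
        # Splits occur before each listed position
        edges = [0, *indices_or_sections, n]
    return list(zip(edges[:-1], edges[1:]))


print(split_bounds_ref(12, 3))       # [(0, 4), (4, 8), (8, 12)]
print(split_bounds_ref(10, [3, 7]))  # [(0, 3), (3, 7), (7, 10)]
```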
- drjit.array_split(value: ArrayT, sections: int, /, axis: int = 0) List[ArrayT]¶
Split a tensor into approximately equal parts along an axis.
Like split(), but allows the axis size to not be evenly divisible by sections. If n denotes the axis size, the first n % sections chunks have one extra element.
- Parameters:
value – Input tensor.
sections (int) – Number of output sections.
axis (int) – The axis along which to split. Negative values count backwards from the last dimension.
- Returns:
A list of tensors.
- Return type:
list
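The chunk sizes implied by this rule follow from a single divmod. The helper name array_split_sizes_ref is hypothetical; the sketch computes sizes only.

```python
def array_split_sizes_ref(n, sections):
    base, extra = divmod(n, sections)
    # The first `extra` (= n % sections) chunks receive one additional element
    return [base + 1] * extra + [base] * (sections - extra)


print(array_split_sizes_ref(10, 3))  # [4, 3, 3]
print(array_split_sizes_ref(12, 3))  # [4, 4, 4]
```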
- drjit.reverse(value, axis: int = 0)¶
Reverses the given Dr.Jit array or Python sequence along the specified axis.
- Parameters:
value (ArrayBase|Sequence) – Dr.Jit array or Python sequence type
axis (int) – Axis along which the reversal should be performed. Only axis==0 is supported for now.
- Returns:
An output of the same type as value containing a copy of the reversed array.
- Return type:
object
- drjit.moveaxis(arg: ArrayBase, /, source: int | Tuple[int, ...], destination: int | Tuple[int, ...])¶
Move one or more axes of an input tensor to another position.
Dimensions that are not explicitly moved remain in their original order and appear at the positions not specified in destination. Negative axis values count backwards from the end.
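The resulting axis order can be derived as sketched below, following the same construction that NumPy uses for its moveaxis(). moveaxis_order_ref is a hypothetical helper that returns the permutation; the output shape is the input shape indexed by that permutation.

```python
def moveaxis_order_ref(ndim, source, destination):
    src = [source] if isinstance(source, int) else list(source)
    dst = [destination] if isinstance(destination, int) else list(destination)
    if len(src) != len(dst):
        raise ValueError("source and destination must have the same length")
    src = [s % ndim for s in src]
    dst = [d % ndim for d in dst]
    # Axes that are not moved keep their relative order
    order = [i for i in range(ndim) if i not in src]
    # Insert the moved axes at their destinations, lowest position first
    for d, s in sorted(zip(dst, src)):
        order.insert(d, s)
    return order


shape = (2, 3, 4)
order = moveaxis_order_ref(len(shape), 0, -1)
print(order, tuple(shape[i] for i in order))  # [1, 2, 0] (3, 4, 2)
```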
- drjit.transpose(value: ArrayT, /, axes: Tuple[int, ...] | None = None) ArrayT¶
Permute the axes of a tensor.
When axes is None (the default), the axis order is reversed. When axes is a tuple, it must be a permutation of (0, 1, ..., ndim-1) specifying the new axis order.
For example, given a tensor of shape (2, 3, 4):
transpose(a) produces shape (4, 3, 2)
transpose(a, (2, 0, 1)) produces shape (4, 2, 3)
- Parameters:
value – Input tensor.
axes (tuple[int, ...] | None) – The desired axis order. If None, reverses all axes.
- Returns:
A tensor with permuted axes.
- Return type:
object
- drjit.swapaxes(value: ArrayT, /, axis1: int, axis2: int) ArrayT¶
Swap two axes of a tensor.
For example, given a tensor of shape (2, 3, 4):
swapaxes(a, 0, 2) produces shape (4, 3, 2)
swapaxes(a, 0, 1) produces shape (3, 2, 4)
Negative axis values count backwards from the last dimension.
- Parameters:
value – Input tensor.
axis1 (int) – First axis.
axis2 (int) – Second axis.
- Returns:
A tensor with the two axes exchanged.
- Return type:
object
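The shape examples given for transpose() and swapaxes() can be reproduced with a short shape-only sketch. Both helpers are hypothetical and operate on tuples rather than tensors; swapaxes is expressed as a transpose with two entries of the identity permutation exchanged.

```python
def transpose_shape_ref(shape, axes=None):
    ndim = len(shape)
    axes = list(range(ndim))[::-1] if axes is None else list(axes)
    if sorted(axes) != list(range(ndim)):
        raise ValueError("axes must be a permutation of (0, ..., ndim-1)")
    return tuple(shape[a] for a in axes)


def swapaxes_shape_ref(shape, axis1, axis2):
    order = list(range(len(shape)))
    # Python list indexing also handles negative axis values here
    order[axis1], order[axis2] = order[axis2], order[axis1]
    return transpose_shape_ref(shape, order)


print(transpose_shape_ref((2, 3, 4)))             # (4, 3, 2)
print(transpose_shape_ref((2, 3, 4), (2, 0, 1)))  # (4, 2, 3)
print(swapaxes_shape_ref((2, 3, 4), 0, 1))        # (3, 2, 4)
```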
- drjit.take(value: ArrayT, index: int | ArrayBase, axis: int = 0) ArrayT¶
Select values from a tensor along a specified axis using an index or index array.
This function evaluates value[..., index, ...] where index is applied at position axis. The output tensor has one fewer dimension than the input.
- Parameters:
value (drjit.ArrayBase) – Input tensor
index (Union[int, drjit.ArrayBase]) – Integer or 1D integer array.
axis (int) – Axis along which to select values. Negative values count from the end. The default is 0.
- Returns:
Output tensor with shape equal to the input shape minus the indexed axis dimension. The dtype matches the input tensor.
- Return type:
- drjit.take_interp(value: ArrayT, pos: float | ArrayBase, axis: int = 0) ArrayT¶
Select and interpolate values from a tensor along a specified axis using fractional indices.
Similar to drjit.take(), but accepts fractional positions and performs linear interpolation between adjacent values along the specified axis. This is useful for smooth sampling from discrete data.
- Parameters:
value (drjit.ArrayBase) – Input tensor
pos (Union[float, drjit.ArrayBase]) – Python float or 1D float array.
axis (int) – Axis along which to interpolate values. Negative values count from the end. Default is 0.
- Returns:
Output tensor with shape equal to the input shape minus the indexed axis dimension. Values are linearly interpolated based on the fractional indices. The dtype matches the input tensor.
- Return type:
- drjit.compress(arg: drjit.ArrayBase, /) object¶
Compress a mask into an array of nonzero indices.
This function takes a boolean array as input and then returns an unsigned 32-bit integer array containing the indices of nonzero entries.
It can be used to reduce a stream to a subset of active entries via the following recipe:
# Input: an active mask and several arrays data_1, data_2, ... dr.schedule(active, data_1, data_2, ...) indices = dr.compress(active) data_1 = dr.gather(type(data_1), data_1, indices) data_2 = dr.gather(type(data_2), data_2, indices) # ...
There is some conceptual overlap between this function and
drjit.scatter_inc(), which can likewise be used to reduce a stream to a smaller subset of active items. Please see the documentation ofdrjit.scatter_inc()for details.Danger
This function internally performs a synchronization step.
- Parameters:
arg (bool | drjit.ArrayBase) – A Python or Dr.Jit boolean type
- Returns:
Array of nonzero indices
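To make the compress-then-gather recipe concrete, here is a minimal pure-Python sketch of the same data flow on ordinary lists. The helpers compress_model and gather_model are hypothetical stand-ins for illustration only; they say nothing about how Dr.Jit evaluates or synchronizes the operation.

```python
def compress_model(mask):
    """Hypothetical model: indices of the True entries of a boolean list."""
    return [i for i, flag in enumerate(mask) if flag]

def gather_model(source, indices):
    """Hypothetical model: read source at each of the given indices."""
    return [source[i] for i in indices]

active = [True, False, False, True, True]
data_1 = [1.0, 2.0, 3.0, 4.0, 5.0]

# Reduce the stream to the active subset
indices = compress_model(active)          # [0, 3, 4]
data_1 = gather_model(data_1, indices)    # [1.0, 4.0, 5.0]
print(indices, data_1)
```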
- drjit.ravel(array: object, order: str = 'A') object¶
Convert the input into a contiguous flat array.
This operation takes a Dr.Jit array, typically with some static and some dynamic dimensions (e.g., drjit.cuda.Array3f with shape 3xN), and converts it into a flattened 1D dynamically sized array (e.g., drjit.cuda.Float) using either a C or Fortran-style ordering convention.
It can also convert Dr.Jit tensors into a flat representation, though only C-style ordering is supported in this case.
Internally, drjit.ravel() performs a series of calls to drjit.scatter() to suitably reorganize the array contents.
For example,
x = dr.cuda.Array3f([1, 2], [3, 4], [5, 6])
y = dr.ravel(x, order=...)
will produce
[1, 3, 5, 2, 4, 6] with order='F' (the default for Dr.Jit arrays), which means that X/Y/Z components alternate.
[1, 2, 3, 4, 5, 6] with order='C', in which case all X coordinates are written as a contiguous block followed by the Y- and then Z-coordinates.
- Parameters:
array (drjit.ArrayBase) – An arbitrary Dr.Jit array or tensor
order (str) – A single character indicating the index order. 'F' indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency. The alternative 'C' specifies row-major/C-style ordering, in which case the last index changes at the highest frequency. The default value 'A' (automatic) will use F-style ordering for arrays and C-style ordering for tensors.
- Returns:
A dynamic 1D array containing the flattened representation of array with the desired ordering. The type of the return value depends on the type of the input. When array is already contiguous/flattened, this function returns it without making a copy.
- Return type:
object
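The two ordering conventions can be reproduced with a minimal pure-Python sketch that stores the array as one list per component (X, Y, Z). The function ravel_model is a hypothetical illustration of the index order only, not of the scatter-based implementation mentioned above.

```python
def ravel_model(components, order="F"):
    """Hypothetical model: flatten a list of per-component lists."""
    if order == "C":
        # All X values first, then all Y values, then all Z values
        return [v for comp in components for v in comp]
    if order == "F":
        # Components alternate: x0, y0, z0, x1, y1, z1, ...
        width = len(components[0])
        return [comp[j] for j in range(width) for comp in components]
    raise ValueError("order must be 'C' or 'F'")

# Same layout as Array3f([1, 2], [3, 4], [5, 6]): X=[1, 2], Y=[3, 4], Z=[5, 6]
x = [[1, 2], [3, 4], [5, 6]]
print(ravel_model(x, "F"))  # [1, 3, 5, 2, 4, 6]
print(ravel_model(x, "C"))  # [1, 2, 3, 4, 5, 6]
```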
- drjit.unravel(dtype: type[ArrayT], array: AnyArray, order: Literal['A', 'C', 'F'] = 'A') ArrayT¶
Load a sequence of Dr.Jit vectors/matrices/etc. from a contiguous flat array.
This operation implements the inverse of drjit.ravel(). In contrast to drjit.ravel(), it requires one additional parameter (dtype) specifying the type of the return value. For example,
x = dr.cuda.Float([1, 2, 3, 4, 5, 6])
y = dr.unravel(dr.cuda.Array3f, x, order=...)
will produce an array of two 3D vectors with different contents depending on the indexing convention:
[1, 2, 3] and [4, 5, 6] when unraveled with order='F' (the default for Dr.Jit arrays), and
[1, 3, 5] and [2, 4, 6] when unraveled with order='C'
Internally, drjit.unravel() performs a series of calls to drjit.gather() to suitably reorganize the array contents.
- Parameters:
dtype (type) – An arbitrary Dr.Jit array type
array (drjit.ArrayBase) – A dynamically sized 1D Dr.Jit array instance that is compatible with dtype. In other words, both must have the same underlying scalar type and be imported from the same package (e.g., drjit.llvm.ad).
order (str) – A single character indicating the index order. 'F' (the default for Dr.Jit arrays) indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency. The alternative 'C' specifies row-major/C-style ordering, in which case the last index changes at the highest frequency.
- Returns:
An instance of type dtype containing the result of the unravel operation.
- Return type:
object
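The inverse mapping can be sketched the same way in pure Python. The function unravel_model is hypothetical and models only the index order described above; it returns one list per component, and zipping those lists yields the individual vectors.

```python
def unravel_model(flat, n_components, order="F"):
    """Hypothetical model: split a flat list into per-component lists."""
    width, remainder = divmod(len(flat), n_components)
    if remainder != 0:
        raise ValueError("flat size must be a multiple of n_components")
    if order == "F":
        # Components alternate in the flat input
        return [flat[c::n_components] for c in range(n_components)]
    if order == "C":
        # Each component occupies one contiguous block
        return [flat[c * width:(c + 1) * width] for c in range(n_components)]
    raise ValueError("order must be 'C' or 'F'")

flat = [1, 2, 3, 4, 5, 6]
vectors_f = list(zip(*unravel_model(flat, 3, "F")))  # [(1, 2, 3), (4, 5, 6)]
vectors_c = list(zip(*unravel_model(flat, 3, "C")))  # [(1, 3, 5), (2, 4, 6)]
print(vectors_f, vectors_c)
```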
- drjit.reshape(dtype: type, value: object, shape: collections.abc.Sequence[int], order: str = 'A', shrink: bool = False) object¶
- drjit.reshape(dtype: type, value: object, shape: int, order: str = 'A', shrink: bool = False) object
- drjit.reshape(value: object, shape: collections.abc.Sequence[int], order: str = 'A', shrink: bool = False) object
- drjit.reshape(value: object, shape: int, order: str = 'A', shrink: bool = False) object
Overloaded function.
reshape(dtype: type, value: object, shape: collections.abc.Sequence[int], order: str = 'A', shrink: bool = False) -> object
Converts value into an array of type dtype by rearranging the contents according to the specified shape.
The parameter shape may contain a single -1-valued target dimension, in which case its value is inferred from the remaining shape entries and the size of the input. When shape is of type int, it is interpreted as a 1-tuple (shape,).
This function supports the following behaviors:
Reshaping tensors: Dr.Jit tensors admit arbitrary shapes. The drjit.reshape() function can convert between them as long as the total number of elements remains unchanged.
>>> from drjit.llvm.ad import TensorXf
>>> value = dr.arange(TensorXf, 6)
>>> dr.reshape(value, (3, -1))
[[0, 1]
 [2, 3]
 [4, 5]]
Reshaping nested arrays: The function can ravel and unravel nested arrays (which have some static dimensions). This provides a high-level interface that subsumes the functions drjit.ravel() and drjit.unravel(). In this case, the target dtype must be specified:
>>> from drjit.llvm.ad import Array2f, Array3f
>>> value = Array2f([1, 2, 3], [4, 5, 6])
>>> dr.reshape(Array3f, value, shape=(3, -1), order='C')
[[1, 4, 2],
 [5, 3, 6]]
>>> dr.reshape(Array3f, value, shape=(3, -1), order='F')
[[1, 3, 5],
 [2, 4, 6]]
(By convention, Dr.Jit nested arrays are always printed in transposed form, which explains the difference in output compared to the identically shaped Tensor example just above.)
The order argument can be used to specify C ("C") or Fortran ("F")-style ordering when rearranging the array. The default value "A" corresponds to Fortran-style ordering.
PyTrees: When value is a PyTree, the operation recursively threads through the tree’s elements.
Stream compression and loops that fork recursive work. When called with shrink=True, the function creates a view of the original data that potentially has a smaller number of elements.
The main use of this feature is to implement loops that process large numbers of elements in parallel, and which need to occasionally “fork” some recursive work. On modern compute accelerators, an efficient way to handle this requirement is to append this work into a queue that is processed in a subsequent pass until no work is left. The reshape operation with shrink=True then resizes the preallocated queue to the actual number of collected items, which are the input of the next iteration.
Please refer to the following example that illustrates how drjit.scatter_inc(), drjit.scatter(), and drjit.reshape(..., shrink=True) can be combined to realize a parallel loop with a fork condition:
@drjit.syntax
def f():
    # Loop state variables (an arbitrary array or PyTree)
    state = ...

    # Determine how many elements should be processed
    size = dr.width(state)

    # Run the following loop until no work is left
    while size > 0:
        # 1-element array used as an atomic counter
        queue_index = UInt32(0)

        # Preallocate memory for the queue. The necessary
        # amount of memory is task-dependent
        queue_size = size
        queue = dr.empty(dtype=type(state), shape=queue_size)

        # Create an opaque variable representing the number 'queue_size'.
        # This keeps this changing value from being baked into the program,
        # which is needed for proper kernel caching
        queue_size_o = dr.opaque(UInt32, queue_size)

        while not stopping_criterion(state):
            # This line represents the loop body that processes work
            state = loop_body(state)

            # If the condition 'fork' is True, spawn a new work item that
            # will be handled in a future iteration of the parent loop.
            if fork(state):
                # Atomically reserve a slot in 'queue'
                slot = dr.scatter_inc(target=queue_index, index=0)

                # Work item for the next iteration, task dependent
                todo = state

                # Be careful not to write beyond the end of the queue
                valid = slot < queue_size_o

                # Write 'todo' into the reserved slot
                dr.scatter(target=queue, index=slot, value=todo, active=valid)

        # Determine how many fork operations took place
        size = queue_index[0]
        if size > queue_size:
            raise RuntimeError('Preallocated queue was too small: tried to store '
                               f'{size} elements in a queue of size {queue_size}')

        # Reshape the queue and re-run the loop
        state = dr.reshape(queue, shape=size, shrink=True)
- Parameters:
dtype (type) – Desired output type of the reshaped array. This could equal type(value) or refer to an entirely different array type. It only needs to be specified if the target dtype differs from type(value).
value (object) – An arbitrary Dr.Jit array, tensor, or PyTree. The function returns unknown objects of other types unchanged.
shape (int|tuple[int, ...]) – The target shape.
order (str) – A single character indicating the index order used to reinterpret the input. 'F' indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency. The alternative 'C' specifies row-major/C-style ordering, in which case the last index changes at the highest frequency. The default value 'A' (automatic) will use F-style ordering for arrays and C-style ordering for tensors.
shrink (bool) – Cheaply construct a view of the input that potentially has a smaller number of elements. The main use case of this method is explained above.
- Returns:
The reshaped array or PyTree.
- Return type:
object
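The rule for inferring a -1 entry in shape can be written down as a short pure-Python sketch. The helper resolve_shape is hypothetical, and it covers only the case where the element count is preserved (it ignores shrink=True, which deliberately changes the size).

```python
from math import prod

def resolve_shape(shape, size):
    """Hypothetical model: expand an int and infer a single -1 entry."""
    # An integer shape is interpreted as a 1-tuple
    if isinstance(shape, int):
        shape = (shape,)
    shape = tuple(shape)
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    if -1 in shape:
        # The missing entry follows from the remaining entries and the input size
        known = prod(d for d in shape if d != -1)
        if known == 0 or size % known != 0:
            raise ValueError("cannot infer the -1 dimension")
        shape = tuple(size // known if d == -1 else d for d in shape)
    if prod(shape) != size:
        raise ValueError("total number of elements must remain unchanged")
    return shape

print(resolve_shape((3, -1), 6))  # (3, 2)
print(resolve_shape(6, 6))        # (6,)
```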
reshape(dtype: type, value: object, shape: int, order: str = 'A', shrink: bool = False) -> object
reshape(value: object, shape: collections.abc.Sequence[int], order: str = 'A', shrink: bool = False) -> object
reshape(value: object, shape: int, order: str = 'A', shrink: bool = False) -> object
- drjit.tile(value: T, count: int) T¶
Tile the input array count times along the trailing dimension.
This function replicates the input count times along the trailing dynamic dimension. It recursively threads through nested arrays and PyTrees. Static arrays and tensors currently aren’t supported. When count==1, the function returns the input without changes.
An example is shown below:
- Parameters:
value (drjit.ArrayBase) – A Dr.Jit type or PyTree.
count (int) – Number of repetitions
- Returns:
The tiled input as described above. The return type matches that of value.
- Return type:
object
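As a minimal pure-Python sketch of the tiling semantics on a flat list (tile_model is a hypothetical helper, not the Dr.Jit function), the whole input sequence is laid out count times back to back:

```python
def tile_model(values, count):
    """Hypothetical model: concatenate `count` copies of the whole input."""
    return list(values) * count

tiled = tile_model([1, 2, 3], 2)
print(tiled)  # [1, 2, 3, 1, 2, 3]
```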
- drjit.repeat(value: T, count: int) T¶
Repeat each successive entry of the input count times along the trailing dimension.
This function repeats each entry of the input count times along the trailing dynamic dimension. It recursively threads through nested arrays and PyTrees. Static arrays and tensors currently aren’t supported. When count==1, the function returns the input without changes.
An example is shown below:
- Parameters:
value (drjit.ArrayBase) – A Dr.Jit type or PyTree.
count (int) – Number of repetitions
- Returns:
The repeated input as described above. The return type matches that of value.
- Return type:
object
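The difference from tiling is easiest to see on a flat list. In this minimal pure-Python sketch (repeat_model is a hypothetical helper, not the Dr.Jit function), each entry is duplicated in place before moving on to the next one:

```python
def repeat_model(values, count):
    """Hypothetical model: emit every entry `count` times in succession."""
    return [v for v in values for _ in range(count)]

repeated = repeat_model([1, 2, 3], 2)
print(repeated)  # [1, 1, 2, 2, 3, 3]
```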
Other operations on tensors¶
- drjit.resample(source: ArrayT, shape: Sequence[int], filter: Literal['box', 'linear', 'hamming', 'cubic', 'lanczos', 'gaussian'] | Callable[[float], float] = 'cubic', filter_radius: float | None = None) ArrayT¶
Resample an input array/tensor to increase or decrease its resolution along a set of axes.
This function up- and/or downsamples a given array or tensor along a specified set of axes. Given an input array (source) and target shape (shape), it returns a compatible array of the specified configuration. This is implemented using a sequence of successive 1D resampling steps for each mismatched axis.
Example usage:
image: TensorXf = ... # a RGB image
width, height, channels = image.shape
scaled_image = dr.resample(
    image,
    (width // 2, height // 2, channels)
)
Resampling uses a reconstruction filter. The following options are available, where \(n\) refers to the number of dimensions being resampled:
"box": use nearest-neighbor interpolation/averaging. This is very efficient but generally produces sub-par output that is either pixelated (when upsampling) or aliased (when downsampling).
"linear": use linear ramp / tent filter that uses \(2^n\) neighbors to reconstruct each output sample when upsampling. Tends to produce relatively blurry results.
"hamming": uses the same number of input samples as "linear" but better preserves sharpness when downscaling. Do not use for upscaling.
"cubic": use cubic filter kernel that queries \(4^n\) neighbors to reconstruct each output sample when upsampling. Produces high-quality results. This is the default.
"lanczos": use a windowed Lanczos filter that queries \(6^n\) neighbors to reconstruct each output sample when upsampling. This is the best filter for smooth signals, but also the costliest. The Lanczos filter is susceptible to ringing when the input array contains discontinuities.
"gaussian": use a Gaussian filter that queries \(4^n\) neighbors to reconstruct each output sample when upsampling. The kernel has a standard deviation of 0.5 and is truncated after 4 standard deviations. This filter is mainly useful when intending to blur a signal.
Besides the above choices, it is also possible to specify a custom filter. To do so, use the filter argument to pass a Python callable with signature Callable[[float], float]. In this case, you must also specify a filter radius via the filter_radius parameter.
The implementation was extensively tested against Image.resize() from the Pillow library and should be a drop-in replacement with added support for JIT tracing / GPU evaluation, differentiability, and compatibility with higher-dimensional tensors.
Warning
When using filter="hamming", "cubic", or "lanczos", the range of the output array can exceed that of the input array. For example, positive-valued data may contain negative values following resampling. Clamp the output in case it is important that array values remain within a fixed range (e.g., \([0,1]\)).
- Parameters:
source (dr.ArrayBase) – The Dr.Jit tensor or 1D array to be resampled.
shape (Sequence[int]) – The desired output shape.
filter (str | Callable[[float], float]) – The desired reconstruction filter, see the above text for an overview. Alternatively, a custom reconstruction filter function can also be specified.
filter_radius (float | None) – The radius of the pixel filter in the output sample space. Should only be specified when using a custom reconstruction filter.
- Returns:
The resampled output array. Its type matches source, and its shape matches shape.
- Return type:
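As an illustration of the custom-filter option, the following sketch defines a callable with the Callable[[float], float] signature mentioned above. The tent function is a hypothetical example filter with radius 1; the call to drjit.resample() is shown only as a comment because it requires Dr.Jit and an input tensor, and the names image and new_shape in it are placeholders.

```python
def tent(x: float) -> float:
    """Hypothetical custom filter: linear ramp that vanishes for |x| >= 1."""
    return max(0.0, 1.0 - abs(x))

# Intended usage (not executed here, requires Dr.Jit; 'image' and
# 'new_shape' are placeholders):
#   scaled = dr.resample(image, new_shape, filter=tent, filter_radius=1.0)

# Sample the filter across its support to check its shape
weights = [tent(x) for x in (-1.0, -0.5, 0.0, 0.5, 1.0)]
print(weights)  # [0.0, 0.5, 1.0, 0.5, 0.0]
```

A custom filter should fall to zero at the chosen filter_radius; otherwise the truncation introduces a discontinuity in the kernel.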
- drjit.convolve(source: ArrayT, filter: Literal['box', 'linear', 'hamming', 'cubic', 'lanczos', 'gaussian'] | Callable[[float], float], filter_radius: float, axis: int | Tuple[int, ...] | None = None) ArrayT¶
Convolve one or more axes of an input array/tensor with a 1D filter.
This function filters one or more axes of a Dr.Jit array or tensor, for example to convolve an image with a 2D Gaussian filter to blur spatial detail.
image: TensorXf = ... # a RGB image
blurred_image = dr.convolve(
    image,
    filter='gaussian',
    filter_radius=10
)
The filter weights are renormalized to reduce edge effects near the boundary of the array.
The function supports a set of provided filters, and custom filters can also be specified. This works analogously to the resample() function; please refer to its documentation for details.
- Parameters:
source (dr.ArrayBase) – The Dr.Jit tensor or 1D array to be convolved.
filter (str | Callable[[float], float]) – The desired reconstruction filter, see the above text for an overview. Alternatively, a custom reconstruction filter function can also be specified.
filter_radius (float) – The radius of the continuous function to be used in the convolution.
axis (int | tuple[int, ...] | ... | None) – The axis or set of axes along which to convolve. The default argument axis=None causes all axes to be convolved. Negative values count from the last dimension.
- Returns:
The convolved output array. Its type matches source.
- Return type:
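The effect of renormalizing filter weights near the array boundary can be sketched in pure Python for the 1D case. This is one plausible reading of the statement above, not Dr.Jit's implementation: the helper convolve_1d_model and the tent filter are hypothetical, and each output value is divided by the sum of the weights that actually fall inside the array.

```python
import math

def tent(x, radius=2.0):
    """Hypothetical filter: linear ramp that vanishes at |x| = radius."""
    return max(0.0, 1.0 - abs(x) / radius)

def convolve_1d_model(signal, filt, radius):
    """Hypothetical model: 1D convolution with per-sample weight renormalization."""
    reach = math.ceil(radius)
    out = []
    for i in range(len(signal)):
        # Only taps that lie inside the array contribute
        lo = max(0, i - reach)
        hi = min(len(signal), i + reach + 1)
        weights = [filt(j - i) for j in range(lo, hi)]
        total = sum(w * signal[j] for w, j in zip(weights, range(lo, hi)))
        # Dividing by the in-bounds weight sum avoids darkening at the edges
        out.append(total / sum(weights))
    return out

flat = convolve_1d_model([3.0] * 5, tent, 2.0)
spike = convolve_1d_model([0.0, 0.0, 4.0, 0.0, 0.0], tent, 2.0)
print(flat, spike)
```

In this sketch a constant signal stays constant all the way to the boundary, which is the property the renormalization is meant to provide.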
- drjit.matmul(A: object, B: object, At: bool = False, Bt: bool = False) object¶
Compute a matrix-matrix, matrix-vector, vector-matrix, or inner product.
This function implements the semantics of the @ operator introduced in Python’s PEP 465. There is no practical difference between using drjit.matmul() or @ in Dr.Jit-based code. Multiplication of matrix types (e.g., drjit.scalar.Matrix2f) using the standard multiplication operator (*) is also based on matrix multiplication.
This function takes two Dr.Jit arrays and picks one of the following 5 cases based on their leading fixed-size dimensions.
Matrix-matrix product: If both arrays have leading static dimensions (n, n), they are multiplied like conventional matrices.
Matrix-vector product: If arg0 has leading static dimensions (n, n) and arg1 has leading static dimension (n,), the operation conceptually appends a trailing 1-sized dimension to arg1, multiplies, and then removes the extra dimension from the result.
Vector-matrix product: If arg0 has leading static dimensions (n,) and arg1 has leading static dimension (n, n), the operation conceptually prepends a leading 1-sized dimension to arg0, multiplies, and then removes the extra dimension from the result.
Inner product: If arg0 and arg1 have leading static dimensions (n,), the operation returns the sum of the elements of arg0*arg1.
Scalar product: If arg0 or arg1 is a scalar, the operation scales the elements of the other argument.
It is legal to combine vectorized and non-vectorized types, e.g.
dr.matmul(dr.scalar.Matrix4f(...), dr.cuda.Matrix4f(...))
Also, note that it doesn’t matter whether an input is an instance of a matrix type or a similarly-shaped nested array. For example, drjit.scalar.Matrix3f() and drjit.scalar.Array33f() have the same shape and are treated identically.
Tensor matrix multiplication. In addition to the fixed-size cases above, this function also accepts two Dr.Jit tensors of matching element type and backend. It fully replicates the NumPy / PyTorch matmul semantics including batched matrix products, broadcasting of leading batch axes, matrix-vector products, and inner products. The optional At/Bt flags transpose the associated operand on the fly at essentially no cost on either backend.
Under the hood, the CUDA backend uses a block-matrix multiplication in which each thread block cooperatively stages tiles of A and B through shared memory and accumulates the output tile in registers. The CPU backend uses a GotoBLAS-style tiled GEMM with a vectorized microkernel, parallelized over both axes of the output tile grid via nanothread, so shapes with few output rows still use every core. Broadcasts along batch dimensions are consumed directly by the kernel via zero strides, and under automatic differentiation the reverse-mode gradient of a broadcast operand folds its sum-over-batch into the backward GEMM’s contraction, so no expanded copy of a broadcast operand is materialized in either the primal or the derivative.
Half-precision inputs are multiplied and summed in single precision throughout the reduction; the result is narrowed to half precision only at the final store.
Note that the CUDA implementation does not use the tensor cores available on recent NVIDIA GPUs, which can greatly accelerate half-precision math. For fp16 matmuls, Dr.Jit is therefore not competitive with PyTorch.
Supported element types are drjit.VarType.Float16, drjit.VarType.Float32, drjit.VarType.Float64, drjit.VarType.Int32, and drjit.VarType.UInt32.
Note
Performance tips (CUDA). The tensor matmul ships a small family of precompiled kernels and picks the largest tile it can align to the operand strides. To land on the fastest path, the contiguous dimensions of both operands and N should be divisible by V, where V = 8 for Float16, V = 4 for Float32/Int32/UInt32, V = 2 for Float64.
When this divisibility doesn’t hold, the kernel falls back to a smaller tile with scalar loads, which can be an order of magnitude slower. The CPU (LLVM) backend is not affected by this.
The fixed-size array path (types such as drjit.cuda.Matrix4f(), as opposed to Dr.Jit tensors) only handles small matrices whose dimensions are known at compile time. To multiply large dynamic matrices, use the N-D tensor path described above.
- Parameters:
arg0 (dr.ArrayBase) – Dr.Jit array or Dr.Jit tensor.
arg1 (dr.ArrayBase) – Dr.Jit array or Dr.Jit tensor.
At (bool) – If True, transpose the last two dimensions of arg0 on the fly. Only applies to the tensor path and is invalid when arg0 is 1-D. Defaults to False.
Bt (bool) – If True, transpose the last two dimensions of arg1 on the fly. Only applies to the tensor path and is invalid when arg1 is 1-D. Defaults to False.
- Returns:
The result of the operation as defined above
- Return type:
object
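The five fixed-size cases can be mirrored by a minimal pure-Python sketch that dispatches on whether each operand is a scalar, a vector (flat list), or a matrix (list of rows). The function matmul_model is hypothetical; it models only the case analysis, using row-major nested lists as an assumed layout, and does not reflect Dr.Jit's storage or kernels.

```python
def matmul_model(a, b):
    """Hypothetical model of the five fixed-size cases on nested lists."""
    scalar = (int, float)
    if isinstance(a, scalar) or isinstance(b, scalar):
        # Scalar product: scale the elements of the other argument
        s, other = (a, b) if isinstance(a, scalar) else (b, a)
        if isinstance(other, scalar):
            return s * other
        return [matmul_model(s, entry) for entry in other]
    a_is_matrix = isinstance(a[0], list)
    b_is_matrix = isinstance(b[0], list)
    if not a_is_matrix and not b_is_matrix:
        # Inner product: sum of elementwise products
        return sum(x * y for x, y in zip(a, b))
    if a_is_matrix and not b_is_matrix:
        # Matrix-vector product: one dot product per row of a
        return [sum(x * y for x, y in zip(row, b)) for row in a]
    if not a_is_matrix:
        # Vector-matrix product: a acts as a row vector
        return [sum(a[k] * b[k][j] for k in range(len(a)))
                for j in range(len(b[0]))]
    # Matrix-matrix product
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

M = [[1, 2], [3, 4]]
v = [5, 6]
print(matmul_model(M, M))  # [[7, 10], [15, 22]]
print(matmul_model(M, v))  # [17, 39]
print(matmul_model(v, M))  # [23, 34]
print(matmul_model(v, v))  # 61
print(matmul_model(2, v))  # [10, 12]
```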
Random number generation¶
- drjit.rng(seed: ArrayBase | int = 0, method='philox4x32', symbolic: bool = False) Generator¶
Return a seeded random number generator.
This function returns a drjit.random.Generator object. Note the following:
Differently seeded random number generators produce statistically independent streams of random variates.
seed can be a Python int or Dr.Jit UInt64-typed array. The default value 0 is used when no seed is specified, making the generator’s behavior deterministic across runs.
Only method='philox4x32' is supported at the moment. This returns a generator object wrapping the Philox4x32 counter-based PRNG.
When symbolic=True is specified, the internal sampler state will never be explicitly evaluated. This is useful in cases where you wish to explicitly bake these constants into the generated program. Dr.Jit also detects when the sampler is used in a symbolic code block (e.g., a symbolic loop) and automatically sets this flag in such a case.
- class drjit.random.Generator¶
- random(dtype: Type[ArrayT], shape: int | Tuple[int, ...]) ArrayT¶
Return a Dr.Jit array or tensor containing uniformly distributed pseudorandom variates.
This function supports floating point arrays/tensors of various configurations and precisions, e.g.:
import drjit as dr
from drjit.cuda import Float, TensorXf16, Array3f, Matrix4f64

# Example usage
rng = dr.rng(seed=0)
rand_array = rng.random(Float, 128)
rand_tensor = rng.random(TensorXf16, shape=(128, 128))
rand_vec = rng.random(Array3f, (3, 128))
rand_mat = rng.random(Matrix4f64, (4, 4, 128))
The output is uniformly distributed on the half-open interval \([0, 1)\). Integer arrays are not supported.
- Parameters:
dtype (type[ArrayT]) – A Dr.Jit tensor or array type.
shape (int | tuple[int, ...]) – The target shape
- Returns:
The generated array of random variates.
- Return type:
ArrayT
- uniform(dtype: Type[ArrayT], shape: int | Tuple[int, ...], low: float | ArrayBase = 0.0, high: float | ArrayBase = 1.0)¶
Return a Dr.Jit array or tensor containing uniformly distributed pseudorandom variates.
This function resembles random() but additionally ensures that variates are distributed on the half-open interval \([\texttt{low}, \texttt{high})\).
- Parameters:
dtype (type[ArrayT]) – A Dr.Jit tensor or array type.
shape (int | tuple[int, ...]) – The target shape
low (float | drjit.ArrayBase) – The low value of the desired interval
high (float | drjit.ArrayBase) – The high value of the desired interval
- Returns:
The generated array of random variates.
- Return type:
ArrayT
- normal(dtype: Type[ArrayT], shape: int | Tuple[int, ...], loc: float | ArrayBase = 0.0, scale: float | ArrayBase = 1.0) ArrayT¶
Return a Dr.Jit array or tensor containing pseudorandom variates following a normal distribution with mean loc and standard deviation scale (a standard normal distribution by default).
This function supports arrays/tensors of various configurations and precisions; see the similar random() method for examples on how to call this function.
- Parameters:
dtype (type[ArrayT]) – A Dr.Jit tensor or array type.
shape (int | tuple[int, ...]) – The target shape
loc (float | drjit.ArrayBase) – The mean of the normal distribution (0.0 by default)
scale (float | drjit.ArrayBase) – The standard deviation of the normal distribution (1.0 by default)
- Returns:
The generated array of random variates.
- Return type:
ArrayT
- integers(dtype: Type[ArrayT], shape: int | Tuple[int, ...], low: int | ArrayBase = 0, high: int | ArrayBase | None = None, endpoint: bool = False) ArrayT¶
Return a Dr.Jit array or tensor containing uniformly distributed pseudorandom 32-bit integers.
If high is None, integers are drawn from [0, low). Otherwise, integers are drawn from [low, high), or [low, high] if endpoint=True.
- Parameters:
dtype (type[ArrayT]) – A Dr.Jit 32-bit integer array type (Int32 or UInt32).
shape (int | tuple[int, ...]) – The target shape
low (int | drjit.ArrayBase) – Lowest integer to be drawn (inclusive). If high is None, this parameter specifies the exclusive upper bound, and 0 becomes the lower bound.
high (int | drjit.ArrayBase | None) – If provided, one above the largest integer to be drawn. If endpoint=True, this is the largest integer to be drawn (inclusive).
endpoint (bool) – If True, high is inclusive. Default: False.
- Returns:
The generated array of random integers.
- Return type:
ArrayT
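The interplay of low, high, and endpoint can be summarized by a small pure-Python sketch that resolves the arguments into a half-open range. The helper resolve_bounds is hypothetical; how endpoint=True combines with high=None is not spelled out above, so the sketch assumes it makes the single given bound inclusive. Python's random module stands in for the generator here.

```python
import random

def resolve_bounds(low=0, high=None, endpoint=False):
    """Hypothetical model: map the arguments to a half-open range [lo, hi)."""
    if high is None:
        # A single bound is the upper limit, and 0 becomes the lower bound
        low, high = 0, low
    if endpoint:
        # Assumption: an inclusive upper bound widens the range by one
        high += 1
    return low, high

prng = random.Random(0)  # stand-in generator, not Dr.Jit's Philox4x32
lo, hi = resolve_bounds(2, 5, endpoint=True)
draws = [prng.randrange(lo, hi) for _ in range(100)]
print(resolve_bounds(10), resolve_bounds(2, 5), (lo, hi))
```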
- permutation(dtype: Type[ArrayT], n: int, /) ArrayT¶
- permutation(x: ArrayT, /, axis: int = 0) ArrayT
Generate a random permutation of elements.
If the first argument is a type (dtype) and n is an integer, return a random permutation of dr.arange(dtype, n).
If x is a 1D Dr.Jit array, return a new array containing the same elements in a uniformly random order.
If x is a tensor, shuffle along the given axis (default 0): each slice along that axis is kept intact, but their order is randomized.
The implementation generates uniform random float keys and calls drjit.argsort() to obtain a permutation using an efficient parallel radix sort on both CPU and GPU backends.
Overloaded signatures:
- permutation(dtype, n, /)
Return a random permutation of indices [0, n).
- Parameters:
dtype (type) – A Dr.Jit array type (e.g., dr.cuda.UInt32) that selects the backend. The return type is the corresponding UInt32 type for that backend.
n (int) – Length of the permutation (must be non-negative).
- permutation(x, /, axis=0)
Return a shuffled copy of an array or tensor.
- Parameters:
x (drjit.ArrayBase) – Input array or tensor to shuffle.
axis (int) – Axis along which to shuffle (tensors only, default: 0). Negative indices are supported.
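The sort-by-random-keys strategy described above can be sketched in pure Python. The helper permutation_model is hypothetical, Python's random module stands in for the generator, and the built-in sort replaces the parallel radix sort; only the overall idea (draw one key per element, then argsort the keys) carries over.

```python
import random

def permutation_model(n, prng):
    """Hypothetical model: argsort of n uniform random float keys."""
    keys = [prng.random() for _ in range(n)]
    # The indices that would sort the keys form a random permutation
    return sorted(range(n), key=keys.__getitem__)

prng = random.Random(0)  # stand-in generator, not Dr.Jit's Philox4x32
perm = permutation_model(8, prng)

# Shuffling an array amounts to gathering it through the permutation
data = [10, 11, 12, 13, 14, 15, 16, 17]
shuffled = [data[i] for i in perm]
print(perm, shuffled)
```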
Mask operations¶
Also relevant here are any(), all(), none(), and count().
- drjit.select(arg0: object, arg1: object | None, arg2: object | None) object¶
- drjit.select(arg0: bool, arg1: object | None, arg2: object | None) object
Overloaded function.
select(arg0: object, arg1: object | None, arg2: object | None) -> object
Select elements from inputs based on a condition.
This function uses a first mask argument to select between the subsequent two arguments. It implements the following component-wise operation:
\[\mathrm{result}_i = \begin{cases} \texttt{arg1}_i,\quad&\text{if }\texttt{arg0}_i,\\ \texttt{arg2}_i,\quad&\text{otherwise.} \end{cases}\]
- Parameters:
arg0 (bool | drjit.ArrayBase) – A Python or Dr.Jit mask type
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit type, whose entries should be returned for True-valued mask entries.
arg2 (int | float | drjit.ArrayBase) – A Python or Dr.Jit type, whose entries should be returned for False-valued mask entries.
- Returns:
Component-wise result of the selection operation
- Return type:
float | int | drjit.ArrayBase
select(arg0: bool, arg1: object | None, arg2: object | None) -> object
- drjit.isinf(arg, /)¶
Performs an elementwise test for positive or negative infinity
- Parameters:
arg (object) – A Dr.Jit array or other kind of numeric sequence type.
- Returns:
A mask value describing the result of the test
- Return type:
- drjit.isnan(arg, /)¶
Performs an elementwise test for NaN (Not a Number) values
- Parameters:
arg (object) – A Dr.Jit array or other kind of numeric sequence type.
- Returns:
A mask value describing the result of the test.
- Return type:
- drjit.isfinite(arg, /)¶
Performs an elementwise test that checks whether values are finite and not equal to NaN (Not a Number)
- Parameters:
arg (object) – A Dr.Jit array or other kind of numeric sequence type.
- Returns:
A mask value describing the result of the test
- Return type:
- drjit.allclose(a: object, b: object, rtol: float | None = None, atol: float | None = None, equal_nan: bool = False) bool¶
Returns True if two arrays are element-wise equal within a given error tolerance.
The function considers both absolute and relative error thresholds. In particular, a and b are considered equal if all elements satisfy
\[|a - b| \le |b| \cdot \texttt{rtol} + \texttt{atol}. \]
If not specified, the constants atol and rtol are chosen depending on the precision of the input arrays:
Precision   rtol   atol
float64     1e-5   1e-8
float32     1e-3   1e-5
float16     1e-1   1e-2
Note that these constants are fairly loose and far larger than the roundoff error of the underlying floating point representation. The double precision parameters were chosen to match the behavior of numpy.allclose().
- Parameters:
a (object) – A Dr.Jit array or other kind of numeric sequence type.
b (object) – A Dr.Jit array or other kind of numeric sequence type.
rtol (float) – A relative error threshold chosen according to the above table.
atol (float) – An absolute error threshold according to the above table.
equal_nan (bool) – If a and b both contain a NaN (Not a Number) entry, should they be considered equal? The default is
False.
- Returns:
The result of the comparison.
- Return type:
bool
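The tolerance test can be written out for plain Python floats. This is a sketch under stated assumptions, not Dr.Jit's code: `allclose_sketch` and its `precision` argument are hypothetical, inputs are assumed to have equal length, and infinities receive no special treatment.

```python
import math

# Default (rtol, atol) pairs copied from the table above
DEFAULT_TOL = {
    "float64": (1e-5, 1e-8),
    "float32": (1e-3, 1e-5),
    "float16": (1e-1, 1e-2),
}

def allclose_sketch(a, b, rtol=None, atol=None, equal_nan=False,
                    precision="float64"):
    """Check |a - b| <= |b| * rtol + atol for every pair of entries."""
    default_rtol, default_atol = DEFAULT_TOL[precision]
    rtol = default_rtol if rtol is None else rtol
    atol = default_atol if atol is None else atol
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            # NaN entries only match when both are NaN and equal_nan is set
            if equal_nan and math.isnan(x) and math.isnan(y):
                continue
            return False
        if not abs(x - y) <= abs(y) * rtol + atol:
            return False
    return True

same_f64 = allclose_sketch([1.0], [1.0005])                        # too far apart
same_f32 = allclose_sketch([1.0], [1.0005], precision="float32")   # within tolerance
```

Note that the bound scales with the magnitude of `b` only, so the test is not symmetric in its two arguments.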
- drjit.assert_allclose(actual: object, desired: object, rtol: float | None = None, atol: float | None = None, equal_nan: bool = False, err_msg: str = '') None¶
Raise an AssertionError if two arrays are not element-wise equal within a given error tolerance.
This is the assertion-raising counterpart of drjit.allclose() and is analogous to numpy.testing.assert_allclose. Elements are considered equal if
\[|\texttt{actual} - \texttt{desired}| \le |\texttt{desired}| \cdot \texttt{rtol} + \texttt{atol}. \]
See drjit.allclose() for the default values of rtol and atol (they depend on the input precision).
Unlike numpy.testing.assert_allclose, the happy path stays on device: the only host-visible reduction is a single boolean, and mismatch count / worst-case absolute and relative differences are only read back when the comparison fails.
- Parameters:
actual (object) – A Dr.Jit array or other kind of numeric sequence type.
desired (object) – A Dr.Jit array or other kind of numeric sequence type. Its magnitude sets the relative-tolerance scale.
rtol (float) – Relative error threshold. Defaults depend on precision (see drjit.allclose()).
atol (float) – Absolute error threshold. Defaults depend on precision.
equal_nan (bool) – If actual and desired both contain a NaN entry at the same position, should they compare equal?
err_msg (str) – Optional prefix prepended to the generated error message.
- Raises:
AssertionError – If the arrays are not equal within the given tolerance.
Miscellaneous operations¶
- drjit.shape(arg: object) tuple[int, ...]¶
Return a tuple describing dimension and shape of the provided Dr.Jit array, tensor, or standard sequence type.
When the input array is ragged, the function raises a RuntimeError. The term ragged refers to an array whose components have mismatched sizes, such as [[1, 2], [3, 4, 5]]. Note that scalar entries (e.g. [[1, 2], [3]]) are acceptable, since broadcasting can effectively convert them to any size.
The expressions drjit.shape(arg) and arg.shape are equivalent.
- Parameters:
arg (drjit.ArrayBase) – an arbitrary Dr.Jit array or tensor
- Returns:
A tuple describing the dimension and shape of the provided Dr.Jit input array or tensor.
- Return type:
tuple[int, …]
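The distinction between ragged and broadcastable inputs can be illustrated on nested Python lists. The following is a simplified model, not Dr.Jit's implementation: `shape_sketch` and `_merge` are hypothetical helpers, and the handling of mismatched nesting depth is an assumption of this sketch.

```python
def _merge(a, b):
    """Combine two child shapes, letting size-1 dimensions broadcast."""
    if len(a) != len(b):
        raise RuntimeError("ragged input: mismatched nesting depth")
    out = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise RuntimeError("ragged input: sizes %d and %d" % (x, y))
    return tuple(out)

def shape_sketch(arg):
    """Return the shape of a nested list, or raise RuntimeError if ragged."""
    if not isinstance(arg, (list, tuple)):
        return ()                      # scalar leaf
    if len(arg) == 0:
        return (0,)
    merged = shape_sketch(arg[0])
    for child in arg[1:]:
        merged = _merge(merged, shape_sketch(child))
    return (len(arg),) + merged

regular = shape_sketch([[1, 2], [3, 4]])
broadcast = shape_sketch([[1, 2], [3]])   # the size-1 entry broadcasts
```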
- drjit.width(arg: object, /) int¶
- drjit.width(*args) int
Overloaded function.
width(arg: object, /) -> int
Returns the vectorization width of the provided input(s), which is defined as the length of the last dynamic dimension.
When working with Jit-compiled CUDA or LLVM-based arrays, this corresponds to the number of items being processed in parallel.
The function raises an exception when the input(s) is ragged, i.e., when it contains arrays with incompatible sizes. It returns 1 if the input is scalar and/or does not contain any Dr.Jit arrays.
- Parameters:
arg (object) – An arbitrary Dr.Jit array or PyTree.
- Returns:
The width of the provided input(s).
- Return type:
int
width(*args) -> int
- drjit.slice_index(dtype: type[drjit.ArrayBase], shape: tuple, indices: tuple) tuple[tuple, object]¶
Computes an index array that can be used to slice a tensor. It is used internally by Dr.Jit to implement complex cases of the __getitem__ operation.
It must be called with the desired output dtype, which must be a dynamically sized 1D array of 32-bit integers. The shape parameter specifies the dimensions of a hypothetical input tensor, and indices contains the entries that would appear in a complex slicing operation, but as a tuple. For example, [5:10:2, ..., None] would be specified as (slice(5, 10, 2), Ellipsis, None).
An example is shown below:
>>> dr.slice_index(dtype=dr.scalar.ArrayXu,
...                shape=(10, 1),
...                indices=(slice(0, 10, 2), 0))
[0, 2, 4, 6, 8]
- Parameters:
dtype (type) – A dynamic 32-bit unsigned integer Dr.Jit array type, such as drjit.scalar.ArrayXu or drjit.cuda.UInt.
shape (tuple[int, ...]) – The shape of the tensor to be sliced.
indices (tuple[int|slice|ellipsis|None|dr.ArrayBase, ...]) – A set of indices used to slice the tensor. Its entries can be slice instances, integers, integer arrays, ... (ellipsis) or None.
- Returns:
Tuple consisting of the output array shape and a flattened unsigned integer array of type dtype containing element indices.
- Return type:
tuple[tuple[int, …], drjit.ArrayBase]
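To make the relationship between slices and flattened indices concrete, here is a pure-Python sketch. It is an illustration under assumptions, not Dr.Jit's implementation: `slice_index_sketch` is a hypothetical helper that supports only integers and `slice` objects (no ellipsis, `None`, or index arrays) and assumes row-major storage.

```python
from itertools import product

def slice_index_sketch(shape, indices):
    """Return (output shape, flat row-major indices) for int/slice entries."""
    if len(indices) != len(shape):
        raise IndexError("this sketch needs one index entry per dimension")
    per_axis, out_shape = [], []
    for size, idx in zip(shape, indices):
        if isinstance(idx, slice):
            picks = list(range(*idx.indices(size)))
            out_shape.append(len(picks))      # a slice keeps its dimension
        else:
            if not -size <= idx < size:
                raise IndexError("index out of range")
            picks = [idx % size]              # an integer drops its dimension
        per_axis.append(picks)
    flat = []
    for combo in product(*per_axis):
        offset = 0
        for size, i in zip(shape, combo):
            offset = offset * size + i        # row-major linearization
        flat.append(offset)
    return tuple(out_shape), flat

out_shape, flat_index = slice_index_sketch((10, 1), (slice(0, 10, 2), 0))
```

The flattened indices can then be used to gather from the tensor's underlying 1D storage.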
- drjit.meshgrid(*args, indexing='xy') tuple¶
Return flattened N-D coordinate arrays from a sequence of 1D coordinate vectors.
This function constructs flattened coordinate arrays that are convenient for evaluating and plotting functions on a regular grid. An example is shown below:
import drjit as dr

x, y = dr.meshgrid(
    dr.arange(dr.llvm.UInt, 4),
    dr.arange(dr.llvm.UInt, 4)
)
# x = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
# y = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
This function carefully reproduces the behavior of numpy.meshgrid except for one major difference: the output coordinates are returned in flattened/raveled form. Like the NumPy version, the indexing=='xy' case internally reorders the first two elements of *args.
- Parameters:
*args – A sequence of 1D coordinate arrays
indexing (str) – Specifies the indexing convention. Must be either set to 'xy' (the default) or 'ij'.
- Returns:
A tuple of flattened coordinate arrays (one per input)
- Return type:
tuple
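The flattened output and the effect of the indexing convention can be reproduced with plain lists. This is a sketch, not Dr.Jit's implementation: `meshgrid_sketch` is a hypothetical helper that assumes row-major flattening and handles the 'xy' case by swapping the first two inputs and outputs, as described above.

```python
def meshgrid_sketch(*args, indexing="xy"):
    """Return flattened coordinate lists, one per 1D input sequence."""
    if indexing not in ("xy", "ij"):
        raise ValueError("indexing must be 'xy' or 'ij'")
    args = [list(a) for a in args]
    swap = indexing == "xy" and len(args) >= 2
    if swap:
        args[0], args[1] = args[1], args[0]
    total = 1
    for a in args:
        total *= len(a)
    if total == 0:
        return tuple([] for _ in args)
    result, stride = [], total
    for a in args:
        stride //= len(a)   # run length of consecutive entries sharing one value
        result.append([a[(i // stride) % len(a)] for i in range(total)])
    if swap:
        result[0], result[1] = result[1], result[0]
    return tuple(result)

x, y = meshgrid_sketch(range(4), range(4))
i, j = meshgrid_sketch([0, 1], [10, 20, 30], indexing="ij")
```

With 'xy' indexing the first input varies fastest, while with 'ij' indexing the last input does.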
- drjit.binary_search(start, end, pred)¶
Perform a binary search over a range given a predicate pred, which monotonically decreases over this range (i.e. max one True -> False transition).
Given a (scalar) start and end index of a range, this function evaluates a predicate floor(log2(end-start) + 1) times with index values on the interval [start, end] (inclusive) to find the first index that no longer satisfies it. Note that the template parameter Index is automatically inferred from the supplied predicate. Specifically, the predicate takes an index array as input argument. When pred is False for all entries, the function returns start, and when it is True for all cases, it returns end.
The following code example shows a typical use case: data contains a sorted list of floating point numbers, and the goal is to map each floating point entry of threshold to the first index j such that data[j] >= threshold (and all of this of course in parallel for each vector element).
dtype = dr.llvm.Float
data = dtype(...)
threshold = dtype(...)

index = dr.binary_search(
    0, len(data) - 1,
    lambda index: dr.gather(dtype, data, index) < threshold
)
- Parameters:
start (int) – Starting index for the search range
end (int) – Ending index for the search range
pred (function) – The predicate function to be evaluated
- Returns:
Index array resulting from the binary search
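The search semantics can be checked with a scalar stand-in. This is only a sketch: `binary_search_sketch` is a hypothetical helper that handles one query at a time with a data-dependent loop, whereas the Dr.Jit function processes whole index arrays and runs a fixed number of iterations, so the indices it probes may differ.

```python
def binary_search_sketch(start, end, pred):
    """Return the first index in [start, end] at which pred stops holding."""
    lo, hi = start, end
    while lo < hi:
        mid = (lo + hi) // 2
        if pred(mid):
            lo = mid + 1    # predicate still holds: the transition lies to the right
        else:
            hi = mid        # predicate fails: the transition is here or to the left
    return lo

data = [1.0, 2.0, 4.0, 8.0]

def first_not_below(threshold):
    # Mirrors the usage pattern above for a single threshold value
    return binary_search_sketch(0, len(data) - 1, lambda i: data[i] < threshold)

j = first_not_below(3.0)
```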
- drjit.make_opaque(arg: object, /) None¶
- drjit.make_opaque(*args) None
Overloaded function.
make_opaque(arg: object, /) -> None
Forcefully evaluate arrays (including literal constants).
This function implements a more drastic version of drjit.eval() that additionally converts literal constant arrays into evaluated (device memory-based) representations.
It is related to the function drjit.opaque() that can be used to directly construct such opaque arrays. Please see the documentation of this function regarding the rationale of making array contents opaque to Dr.Jit’s symbolic tracing mechanism.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)
make_opaque(*args) -> None
- drjit.copy(arg: T, /) T¶
Create a deep copy of a PyTree
This function recursively traverses PyTrees and replaces Dr.Jit arrays with copies created via the ordinary copy constructor. It also rebuilds tuples, lists, dictionaries, and other custom data structures.
- drjit.linear_to_srgb(x: ArrayT, clip: bool = True) ArrayT¶
Convert linear intensity values to sRGB gamma-corrected values.
Applies the sRGB transfer function (gamma compression) to convert linear RGB values to gamma-encoded sRGB values suitable for display or storage. The sRGB transfer function uses a piecewise curve with a linear segment near black for numerical stability.
When clip=True (default), input values are clamped to [0, 1] before conversion. When clip=False, the transformation preserves the sign and applies the sRGB curve to the absolute value, enabling handling of out-of-gamut colors in wide-gamut workflows.
The transformation is defined as:
\[f(x) = \begin{cases} 12.92 \cdot x & \text{if } x \leq 0.0031308 \\ 1.055 \cdot x^{1/2.4} - 0.055 & \text{if } x > 0.0031308 \end{cases}\]
When clip=False, the transformation becomes:
\[f(x) = \text{sign}(x) \cdot f(|x|)\]
- Parameters:
x (ArrayT) – Dr.Jit array containing linear intensity values. Typically in range [0, 1] for in-gamut colors.
clip (bool) – If True, clamp input values to [0, 1]. If False, preserve sign and apply transformation to absolute values. Default: True.
- Returns:
- sRGB gamma-corrected values. When clip=True, output is in [0, 1].
When clip=False, output preserves the sign of input values.
- Return type:
ArrayT
- drjit.srgb_to_linear(x: ArrayT, clip: bool = True) ArrayT¶
Convert sRGB gamma-corrected values to linear intensity values.
Applies the inverse sRGB transfer function (gamma expansion) to convert gamma-encoded sRGB values to linear RGB values. The sRGB transfer function uses a piecewise curve with a linear segment near black for numerical stability.
When clip=True (default), input values are clamped to [0, 1] before conversion. When clip=False, the transformation preserves the sign and applies the sRGB curve to the absolute value, enabling round-trip conversion of out-of-gamut colors. This is useful for wide-gamut workflows where colors may temporarily exceed the [0, 1] range during processing.
The transformation is defined as:
\[f(x) = \begin{cases} \frac{x}{12.92} & \text{if } x \leq 0.04045 \\ \left(\frac{x + 0.055}{1.055}\right)^{2.4} & \text{if } x > 0.04045 \end{cases}\]
When clip=False, the transformation becomes:
\[f(x) = \text{sign}(x) \cdot f(|x|)\]
- Parameters:
x (ArrayT) – Dr.Jit array containing sRGB gamma-corrected values. Typically in range [0, 1] for in-gamut colors.
clip (bool) – If True, clamp input values to [0, 1]. If False, preserve sign and apply transformation to absolute values. Default: True.
- Returns:
- Linear intensity values. When clip=True, output is in [0, 1].
When clip=False, output preserves the sign of input values.
- Return type:
ArrayT
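The two transfer functions and the sign-preserving clip=False behavior can be sketched for scalar Python floats. The helper names `linear_to_srgb_sketch` and `srgb_to_linear_sketch` are hypothetical; the formulas are the ones given above, and the actual Dr.Jit functions operate on arrays.

```python
import math

def linear_to_srgb_sketch(x, clip=True):
    """Apply the sRGB gamma compression curve to one linear value."""
    if clip:
        x = min(max(x, 0.0), 1.0)
    a = abs(x)
    y = 12.92 * a if a <= 0.0031308 else 1.055 * a ** (1.0 / 2.4) - 0.055
    return math.copysign(y, x)      # sign(x) * f(|x|)

def srgb_to_linear_sketch(x, clip=True):
    """Apply the inverse sRGB curve (gamma expansion) to one encoded value."""
    if clip:
        x = min(max(x, 0.0), 1.0)
    a = abs(x)
    y = a / 12.92 if a <= 0.04045 else ((a + 0.055) / 1.055) ** 2.4
    return math.copysign(y, x)

# Out-of-gamut values survive a round trip when clipping is disabled
samples = [-0.5, 0.001, 0.2, 1.5]
round_trip = [srgb_to_linear_sketch(linear_to_srgb_sketch(v, clip=False), clip=False)
              for v in samples]
```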
- drjit.linear_srgb_to_oklab(value: ArrayT) ArrayT¶
Convert colors from linear sRGB to Oklab color space.
Oklab is a perceptual color space designed for image processing that provides better perceptual uniformity than CIELAB or HSV. The L, a, b coordinates are perceptually orthogonal, enabling independent manipulation of lightness, green-red axis, and blue-yellow axis without perceived changes in the other dimensions. Oklab produces smooth color transitions and accurately predicts human perception of hue and chroma.
For more details, see: https://bottosson.github.io/posts/oklab/
This function supports Dr.Jit tensors and arrays with RGB or RGBA data. For tensors, the color channels must be in the trailing dimension. For arrays, the color channels must be in the leading dimension. Alpha channels are preserved unchanged in RGBA inputs.
The function expects linear sRGB values as input, not gamma-encoded sRGB. For typical gamma-encoded image data (e.g., from image files), first apply dr.srgb_to_linear() to convert to linear sRGB before using this function.
- Parameters:
value (dr.ArrayBase) – Dr.Jit tensor or array containing linear sRGB colors. For tensors: shape […, 3] for RGB or […, 4] for RGBA. For arrays: shape [3, …] for RGB or [4, …] for RGBA.
- Returns:
- Colors converted to Oklab space with same shape as input.
L represents lightness, a represents green-red axis, b represents blue-yellow axis.
- Return type:
dr.ArrayBase
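For reference, the forward conversion can be sketched for a single RGB triple. The coefficients below are the ones published in the Oklab post linked above; whether Dr.Jit uses exactly these constants is an assumption, and `linear_srgb_to_oklab_sketch` and `_cbrt` are hypothetical helper names.

```python
import math

def _cbrt(v):
    # Sign-preserving cube root
    return math.copysign(abs(v) ** (1.0 / 3.0), v)

def linear_srgb_to_oklab_sketch(r, g, b):
    """Convert one linear sRGB triple to Oklab (L, a, b)."""
    # Linear sRGB -> approximate cone responses
    l = 0.4122214708 * r + 0.5363325363 * g + 0.0514459929 * b
    m = 0.2119034982 * r + 0.6806995451 * g + 0.1073969566 * b
    s = 0.0883024619 * r + 0.2817188376 * g + 0.6299787005 * b
    # Nonlinearity
    l_, m_, s_ = _cbrt(l), _cbrt(m), _cbrt(s)
    # Cone responses -> Lab coordinates
    lab_l = 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_
    lab_a = 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_
    lab_b = 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_
    return lab_l, lab_a, lab_b

white = linear_srgb_to_oklab_sketch(1.0, 1.0, 1.0)
grey = linear_srgb_to_oklab_sketch(0.5, 0.5, 0.5)
```

Neutral inputs map to a and b values of approximately zero, with L carrying the lightness.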
- drjit.oklab_to_linear_srgb(value: ArrayT) ArrayT¶
Convert colors from Oklab color space back to linear sRGB.
This function performs the inverse transformation of linear_srgb_to_oklab, converting Oklab L, a, b coordinates back to linear sRGB values.
This function supports Dr.Jit tensors and arrays with Lab or LabA data. For tensors, the color channels must be in the trailing dimension. For arrays, the color channels must be in the leading dimension. Alpha channels are preserved unchanged in LabA inputs.
This function returns linear sRGB values, not gamma-encoded sRGB. To obtain gamma-encoded sRGB for typical image output (e.g., for display or saving to image files), apply dr.linear_to_srgb() to the output of this function.
- Parameters:
value (dr.ArrayBase) – Dr.Jit tensor or array containing Oklab colors. For tensors: shape […, 3] for Lab or […, 4] for LabA. For arrays: shape [3, …] for Lab or [4, …] for LabA.
- Returns:
- Colors converted to linear sRGB space with same shape as input.
Values are in linear sRGB (not gamma-corrected).
- Return type:
dr.ArrayBase
- drjit.reorder_threads(key: drjit.ArrayBase, num_bits: int, value: object) object¶
Trigger a call to the Shader Execution Reordering (SER) feature of the GPU.
This function performs a hardware-assisted shuffle of the GPU threads to improve the kernel occupancy by reducing warp-level divergence in certain workloads. In order to perform this shuffle, it requires a sorting key to indicate which threads should be grouped into coherent warps.
An extra value argument must be passed to the function. This argument will be returned as is but internally Dr.Jit will add some tracking to it to guarantee that, on any subsequent use of it, a reordering operation will be inserted in the kernel.
Example usage:
arg = dr.cuda.Array3f(...)
key = dr.cuda.UInt32(...) % 4

# Reorder threads before `dr.switch()` to reduce divergence
# Only do it if `arg` is used
arg = dr.reorder_threads(key, 2, arg)

callables = [...]
callable_idx = dr.cuda.UInt32(...)
out = dr.switch(callable_idx, callables, arg)
When drjit.JitFlag.ShaderExecutionReordering is not set, or when using the LLVM backend, this operation is a no-op.
- Parameters:
key (drjit.ArrayBase) – A 1D unsigned integer 32-bit array that serves as a sorting key for the shuffle operation. Only the lower num_bits are used.
num_bits (int) – Number of bits from the key to use (starting from the least significant bit). It is recommended to use as few as possible. At most, 16 bits can be used.
value (object) – An arbitrary Dr.Jit array, tensor, or PyTree. This argument is returned without modification. The reordering will only happen if the returned version of this argument is used.
- Returns:
The updated value variable that will trigger the reordering if used.
- Return type:
object
Just-in-time compilation¶
- enum drjit.JitBackend(value)¶
List of just-in-time compilation backends supported by Dr.Jit. See also drjit.backend_v().
Valid values are as follows:
- Invalid = JitBackend.Invalid¶
Indicates that a type is not handled by a Dr.Jit backend (e.g., a scalar type)
- CUDA = JitBackend.CUDA¶
Dr.Jit backend targeting NVIDIA GPUs using PTX (“Parallel Thread Execution”) IR.
- LLVM = JitBackend.LLVM¶
Dr.Jit backend targeting various processors via the LLVM compiler infrastructure.
- enum drjit.VarType(value)¶
List of possible scalar array types (not all of them are supported).
Valid values are as follows:
- Void = VarType.Void¶
Unknown/unspecified type.
- Bool = VarType.Bool¶
Boolean/mask type.
- Int8 = VarType.Int8¶
Signed 8-bit integer.
- UInt8 = VarType.UInt8¶
Unsigned 8-bit integer.
- Int16 = VarType.Int16¶
Signed 16-bit integer.
- UInt16 = VarType.UInt16¶
Unsigned 16-bit integer.
- Int32 = VarType.Int32¶
Signed 32-bit integer.
- UInt32 = VarType.UInt32¶
Unsigned 32-bit integer.
- Int64 = VarType.Int64¶
Signed 64-bit integer.
- UInt64 = VarType.UInt64¶
Unsigned 64-bit integer.
- Pointer = VarType.Pointer¶
Pointer to a memory address.
- Float16 = VarType.Float16¶
16-bit floating point format (IEEE 754).
- Float32 = VarType.Float32¶
32-bit floating point format (IEEE 754).
- Float64 = VarType.Float64¶
64-bit floating point format (IEEE 754).
- enum drjit.VarState(value)¶
The drjit.ArrayBase.state property returns one of the following enumeration values describing possible evaluation states of a Dr.Jit variable.
Valid values are as follows:
- Invalid = VarState.Invalid¶
The variable has length 0 and effectively does not exist.
- Literal = VarState.Literal¶
A literal constant. Does not consume device memory.
- Undefined = VarState.Undefined¶
An undefined memory region. Does not (yet) consume device memory.
- Unevaluated = VarState.Unevaluated¶
An ordinary unevaluated variable that is neither a literal constant nor symbolic.
- Evaluated = VarState.Evaluated¶
Evaluated variable backed by a device memory region.
- Dirty = VarState.Dirty¶
An evaluated variable backed by a device memory region. The variable furthermore has pending side effects (i.e. the user has performed a drjit.scatter(), drjit.scatter_reduce(), drjit.scatter_inc(), drjit.scatter_add(), or drjit.scatter_add_kahan() operation, and the effect of this operation has not been realized yet). The array’s status will automatically change to Evaluated the next time that Dr.Jit evaluates computation, e.g. via drjit.eval().
- Symbolic = VarState.Symbolic¶
A symbolic variable that could take on various inputs. Cannot be evaluated.
- Mixed = VarState.Mixed¶
This is a nested array, and the components have mixed states.
- enum drjit.JitFlag(value)¶
Flags that control how Dr.Jit compiles and optimizes programs.
This enumeration lists various flags that control how Dr.Jit compiles and optimizes programs, most of which are enabled by default. The status of each flag can be queried via drjit.flag() and enabled/disabled via drjit.set_flag() or the recommended drjit.scoped_set_flag() function, e.g.:
with dr.scoped_set_flag(dr.JitFlag.SymbolicLoops, False):
    # code that has this flag disabled goes here
    ...
The most common reason to update the flags is to switch between symbolic and evaluated execution of loops and functions. Evaluated mode eagerly executes programs by breaking them into many smaller kernels, while symbolic mode records computation symbolically to assemble large megakernels. See explanations below along with the documentation of drjit.switch() and drjit.while_loop() for more details on these two modes.
Dr.Jit flags are a thread-local property. This means that multiple independent threads using Dr.Jit can set them independently without interfering with each other.
- Member Type:
int
Valid values are as follows:
- Debug = JitFlag.Debug¶
Debug mode: Enable functionality to uncover errors in application code.
When debug mode is enabled, Dr.Jit inserts a number of additional runtime checks to locate sources of undefined behavior.
Debug mode comes at a significant cost: it interferes with kernel caching, reduces tracing performance, and produces kernels that run slower. We recommend that you only use it periodically before a release, or when encountering a serious problem like a crashing kernel.
First, debug mode enables assertion checks in user code such as those performed by drjit.assert_true(), drjit.assert_false(), and drjit.assert_equal().
Second, Dr.Jit inserts additional checks to intercept out-of-bound reads and writes performed by operations such as drjit.scatter(), drjit.gather(), drjit.scatter_reduce(), drjit.scatter_inc(), etc. It also detects calls to invalid callables performed via drjit.switch() and drjit.dispatch(). Such invalid operations are masked, and they generate a warning message on the console, e.g.:
>>> dr.gather(dtype=UInt, source=UInt(1, 2, 3), index=UInt(0, 1, 100))
RuntimeWarning: drjit.gather(): out-of-bounds read from position 100 in an array of size 3. (<stdin>:2)
Third, for kernels launched with OptiX, additional validation and debug layers will be enabled. For this functionality to be fully enabled, it is recommended to set the drjit.JitFlag.Debug flag at the very beginning of your Python process.
Finally, Dr.Jit also installs a Python tracing hook that associates all Jit variables with their Python source code location, and this information is propagated all the way to the final intermediate representation (PTX, LLVM IR). This is useful for low-level debugging and development of Dr.Jit itself. You can query the source location information of a variable x by writing x.label.
Due to limitations of the Python tracing interface, this handler becomes active within the next called function (or Jupyter notebook cell) following activation of the drjit.JitFlag.Debug flag. It does not apply to code within the same scope/function.
C++ code using Dr.Jit also benefits from debug mode but will lack accurate source code location information. In mixed-language projects, the reported file and line number information will reflect that of the last operation on the Python side of the interface.
- ReuseIndices = JitFlag.ReuseIndices¶
Index reuse: Dr.Jit consists of two main parts: the just-in-time compiler, and the automatic differentiation layer. Both maintain an internal data structure representing captured computation, in which each variable is associated with an index (e.g., r1234 in the JIT compiler, and a1234 in the AD graph).
The index of a Dr.Jit array in these graphs can be queried via the drjit.index and drjit.index_ad variables, and they are also visible in debug messages (if drjit.set_log_level() is set to a more verbose debug level).
Dr.Jit aggressively reuses the indices of expired variables by default, but this can make debug output difficult to interpret. When debugging Dr.Jit itself, it is often helpful to investigate the history of a particular variable. In such cases, set this flag to False to disable variable reuse both at the JIT and AD levels. This comes at a cost: the internal data structures keep on growing, so it is not suitable for long-running computations.
Index reuse is enabled by default.
- ConstantPropagation = JitFlag.ConstantPropagation¶
Constant propagation: immediately evaluate arithmetic involving literal constants on the host and don’t generate any device-specific code for them.
For example, the following assertion holds when constant propagation is enabled in Dr.Jit.
from drjit.llvm import Int

# Create two literal constant arrays
a, b = Int(4), Int(5)

# This addition operation can be immediately performed and does not need to be recorded
c1 = a + b

# Double-check that c1 and c2 refer to the same Dr.Jit variable
c2 = Int(9)
assert c1.index == c2.index
Constant propagation is enabled by default.
- ValueNumbering = JitFlag.ValueNumbering¶
Local value numbering: a simple variant of common subexpression elimination that collapses identical expressions within basic blocks. For example, the following assertion holds when value numbering is enabled in Dr.Jit.
from drjit.llvm import Int

# Create two non-literal arrays stored in device memory
a, b = Int(1, 2, 3), Int(4, 5, 6)

# Perform the same arithmetic operation twice
c1 = a + b
c2 = a + b

# Verify that c1 and c2 reference the same Dr.Jit variable
assert c1.index == c2.index
Local value numbering is enabled by default.
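The interplay of constant propagation and local value numbering can be illustrated with a toy expression recorder. This is a deliberately simplified model, not Dr.Jit's internal data structure: the class `ToyTrace` and its methods are hypothetical, and indices here are simply positions in a Python list.

```python
class ToyTrace:
    """Toy expression recorder with constant folding and value numbering."""

    def __init__(self):
        self.nodes = []     # index -> node description
        self.cache = {}     # node description -> index

    def _intern(self, node):
        # Value numbering: an identical node reuses its existing index
        if node not in self.cache:
            self.cache[node] = len(self.nodes)
            self.nodes.append(node)
        return self.cache[node]

    def literal(self, value):
        return self._intern(("literal", value))

    def load(self, name):
        # Stands in for an evaluated array stored in device memory
        return self._intern(("load", name))

    def add(self, i, j):
        a, b = self.nodes[i], self.nodes[j]
        if a[0] == "literal" and b[0] == "literal":
            # Constant propagation: fold on the host, record no operation
            return self.literal(a[1] + b[1])
        return self._intern(("add", i, j))

t = ToyTrace()
c1 = t.add(t.literal(4), t.literal(5))   # folded to the literal 9
c2 = t.literal(9)
a, b = t.load("a"), t.load("b")
s1, s2 = t.add(a, b), t.add(a, b)        # recorded only once
```

In this model, `c1 == c2` and `s1 == s2` play the role of the `index` comparisons in the examples above.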
- FastMath = JitFlag.FastMath¶
Fast Math: this flag is analogous to the -ffast-math flag in C compilers. When set, the system may use approximations and simplifications that sacrifice strict IEEE-754 compatibility.
Currently, it changes two behaviors:
expressions of the form a * 0 will be simplified to 0 (which is technically not correct when a is infinite or NaN-valued).
Dr.Jit will use slightly approximate division and square root operations in CUDA mode. Note that disabling fast math mode is costly on CUDA devices, as the strict IEEE-754 compliant version of these operations uses software-based emulation.
Fast math mode is enabled by default.
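The caveat about the `a * 0` simplification follows directly from IEEE-754 arithmetic, which plain Python floats also obey. The short check below only demonstrates the strict semantics; the variable `simplified` is a stand-in for the rewritten result and does not involve Dr.Jit.

```python
import math

strict_inf = math.inf * 0.0   # IEEE-754: infinity times zero is NaN
strict_nan = math.nan * 0.0   # IEEE-754: NaN propagates through multiplication
simplified = 0.0              # value produced by the a * 0 -> 0 rewrite

differs = math.isnan(strict_inf) and simplified == 0.0
```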
- SymbolicLoops = JitFlag.SymbolicLoops¶
Dr.Jit provides two main ways of compiling loops involving Dr.Jit arrays.
Symbolic mode (the default): Dr.Jit executes the loop a single time regardless of how many iterations it requires in practice. It does so with symbolic (abstract) arguments to capture the loop condition and body and then turns it into an equivalent loop in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.
Evaluated mode: Dr.Jit evaluates the loop’s state variables and reduces the loop condition to a single element (bool) that expresses whether any elements are still alive. If so, it runs the loop body and the process repeats.
A separate section about symbolic and evaluated modes discusses these two options in detail.
Symbolic loops are enabled by default.
- OptimizeLoops = JitFlag.OptimizeLoops¶
Perform basic optimizations for loops involving Dr.Jit arrays.
This flag enables two optimizations:
Constant arrays: loop state variables that aren’t modified by the loop are automatically removed. This shortens the generated code, which can be helpful especially in combination with the automatic transformations performed by @drjit.syntax that can be somewhat conservative in classifying too many local variables as potential loop state.
Literal constant arrays: In addition to the above point, constant loop state variables that are literal constants are propagated into the loop body, where this may unlock further optimization opportunities.
This is useful in combination with automatic differentiation, where it helps to detect code that does not influence the computed derivatives.
A practical implication of this optimization flag is that it may cause drjit.while_loop() to run the loop body twice instead of just once.
This flag is enabled by default. Note that it is only meaningful in combination with SymbolicLoops.
- CompressLoops = JitFlag.CompressLoops¶
Compress the loop state of evaluated loops after every iteration.
When an evaluated loop processes many elements, and when each element requires a different number of loop iterations, there is the question of what should be done with inactive elements. The default implementation keeps them around and does redundant calculations that are, however, masked out. Consequently, later loop iterations don’t run faster despite fewer elements being active.
Setting this flag causes the removal of inactive elements after every iteration. This reorganization is not for free and does not benefit all use cases.
This flag is disabled by default. Note that it only applies to evaluated loops (i.e., when SymbolicLoops is disabled, or when the mode='evaluated' parameter is passed to the loop in question).
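The trade-off behind loop state compression can be modeled with a toy evaluated loop in plain Python. This is a sketch, not Dr.Jit's implementation: `run_evaluated_sketch` is a hypothetical helper, and its `work` counter merely counts how many lanes the loop body touches.

```python
def run_evaluated_sketch(counters, compress):
    """Count each entry down to zero; return (final state, lanes processed)."""
    state = list(counters)
    lanes = list(range(len(state)))     # entries the loop body currently visits
    work = 0
    # The loop condition is reduced to a single bool: is any lane still active?
    while any(state[i] > 0 for i in lanes):
        for i in lanes:
            work += 1                   # every visited lane costs work
            if state[i] > 0:            # masked update: inactive lanes are skipped
                state[i] -= 1
        if compress:
            # Drop inactive lanes so later iterations visit fewer entries
            lanes = [i for i in lanes if state[i] > 0]
    return state, work

final_a, work_masked = run_evaluated_sketch([1, 1, 1, 8], compress=False)
final_b, work_compressed = run_evaluated_sketch([1, 1, 1, 8], compress=True)
```

Here one long-running element keeps three finished ones alive in the uncompressed variant. The sketch ignores the cost of the compaction step itself, which is why compression does not pay off in every workload.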
- SymbolicCalls = JitFlag.SymbolicCalls¶
Dr.Jit provides two main ways of compiling function calls targeting instance arrays.
Symbolic mode (the default): Dr.Jit invokes each callable with symbolic (abstract) arguments. It does this to capture a transcript of the computation that it can turn into a function in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.
Evaluated mode: Dr.Jit evaluates all inputs and groups them by instance ID. Following this, it launches a kernel per instance to process the rearranged inputs and assemble the function return value.
A separate section about symbolic and evaluated modes discusses these two options in detail.
Besides calls to instance arrays, this flag also controls the behavior of the functions drjit.switch() and drjit.dispatch().
Symbolic calls are enabled by default.
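The group-then-launch strategy of evaluated mode can be sketched in plain Python. This is an illustration under assumptions rather than Dr.Jit's implementation: `dispatch_evaluated_sketch` is a hypothetical helper in which a list comprehension stands in for one kernel launch per instance.

```python
from collections import defaultdict

def dispatch_evaluated_sketch(instance_ids, callables, values):
    """Group inputs by instance ID, process each group, scatter results back."""
    groups = defaultdict(list)
    for pos, inst in enumerate(instance_ids):
        groups[inst].append(pos)
    result = [None] * len(values)
    launches = 0
    for inst, positions in sorted(groups.items()):
        launches += 1                                # one "kernel" per instance
        outputs = [callables[inst](values[p]) for p in positions]
        for p, out in zip(positions, outputs):
            result[p] = out                          # restore the original order
    return result, launches

callables = [lambda x: x + 1, lambda x: x * 10]
result, launches = dispatch_evaluated_sketch([0, 1, 0, 1, 1], callables,
                                             [1, 2, 3, 4, 5])
```

The number of launches grows with the number of distinct instances, which is one reason symbolic mode is the default.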
- OptimizeCalls = JitFlag.OptimizeCalls¶
Perform basic optimizations for function calls on instance arrays.
This flag enables two optimizations:
Constant propagation: Dr.Jit will propagate literal constants across function boundaries while tracing, which can unlock simplifications within. This is especially useful in combination with automatic differentiation, where it helps to detect code that does not influence the computed derivatives.
Devirtualization: When an element of the return value has the same computation graph in all instances, it is removed from the function call interface and moved to the caller.
The flag is enabled by default. Note that it is only meaningful in combination with SymbolicCalls. Besides calls to instance arrays, this flag also controls the behavior of the functions drjit.switch() and drjit.dispatch().
- MergeFunctions = JitFlag.MergeFunctions¶
Deduplicate code generated by function calls on instance arrays.
When arr is an instance array (potentially with thousands of instances), a function call like
arr.f(inputs...)
can potentially generate vast numbers of different functions in the generated code. At the same time, many of these functions may contain identical code (or code that is identical except for data references).
Dr.Jit can exploit such redundancy and merge such functions during code generation. Besides generating shorter programs, this also helps to reduce thread divergence.
This flag is enabled by default. Note that it is only meaningful in combination with SymbolicCalls. Besides calls to instance arrays, this flag also controls the behavior of the functions drjit.switch() and drjit.dispatch().
- PacketOps = JitFlag.PacketOps¶
Turn sequences of contiguous gather/scatter operations into packet loads/stores.
When indirect memory accesses (gathers/scatters) access multiple subsequent elements, it should be possible to exploit this to perform the operation more efficiently.
If drjit.JitFlag.PacketOps is set, Dr.Jit realizes this optimization opportunity when
The size of the leading dimension of the source/target array is a power of two. (Non-power-of-two sizes are decomposed into sequences of smaller packet operations; for example, size 24 is realized as 3 packets with width 8.)
The array is read/written via drjit.gather(), drjit.scatter(), drjit.scatter_add(), drjit.ravel(), or drjit.unravel().
For example, the following operation gathers 4D vectors from a flat array:
from drjit.auto import Array4f, Float, UInt32 source = Float(...) result = dr.gather(Array4f, source, index=UInt32(...))
Wider cases are supported as well by providing a shape argument:
from drjit.auto import ArrayXf, Float, UInt32
source = Float(...)
index = UInt32(...)
result = dr.gather(ArrayXf, source, index=index, shape=(16, len(index)))
Packet gathers yield a modest performance improvement on the CUDA backend, where they produce fewer and larger memory transactions. For example, a single 128 bit load can fetch 8 half-precision values at once. Speedups of 5-30% have been observed.
On the LLVM backend, the operation replaces vector gather/scatter instructions with a combination of aligned packet loads/stores and a matrix transpose (implemented for 2, 4, and 8-D inputs; larger vectors perform several 8D transposes). Speedups here are rather dramatic (up to >20× for scatters and 1.5-2× for gathers have been measured).
This optimization is expected to make an even bigger difference following microcode-based security mitigations for a recent side-channel attack that effectively break the performance characteristics of the native gather operation on Intel CPUs. Packet gathers don’t use the regular gather instruction and thus aren’t affected by the mitigation.
This flag is enabled by default.
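The decomposition of non-power-of-two sizes mentioned above can be illustrated with a small pure-Python sketch. It is not Dr.Jit's implementation: the rule used here (take the largest power of two that divides the size, clamped to a hypothetical `max_width` of 8) is an assumption chosen only because it reproduces the documented example of size 24 becoming 3 packets of width 8.

```python
def split_into_packets(size: int, max_width: int = 8) -> tuple[int, int]:
    """Split `size` consecutive elements into equal power-of-two packets.

    Returns (packet_count, packet_width). `max_width` is a hypothetical
    cap introduced for this illustration only.
    """
    if size <= 0:
        raise ValueError("size must be positive")
    # Largest power of two that divides `size` (its lowest set bit) ...
    width = size & -size
    # ... clamped to the assumed maximum packet width.
    width = min(width, max_width)
    return size // width, width


print(split_into_packets(24))  # (3, 8)
print(split_into_packets(4))   # (1, 4)
```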
- ForceOptiX = JitFlag.ForceOptiX¶
Force execution through OptiX even if a kernel doesn’t use ray tracing. This only applies to the CUDA backend and is mainly helpful for automated tests done by the Dr.Jit team.
This flag is disabled by default.
- PrintIR = JitFlag.PrintIR¶
Print the low-level IR representation when launching a kernel.
If enabled, this flag causes Dr.Jit to print the low-level IR (LLVM IR, NVIDIA PTX) representation of the generated code onto the console (or Jupyter notebook).
This flag is disabled by default.
- KernelHistory = JitFlag.KernelHistory¶
Maintain a history of kernel launches to profile/debug programs.
Programs written on top of Dr.Jit execute in an extremely asynchronous manner. By default, the system postpones the computation to build large fused kernels. Even when this computation eventually runs, it does so asynchronously with respect to the host, which can make benchmarking difficult.
In general, beware of the following benchmarking anti-pattern:
import time
a = time.time()
# Some Dr.Jit computation
b = time.time()
print("took %.2f ms" % ((b-a) * 1000))
In the worst case, the measured time interval may only capture the tracing time, without any actual computation having taken place. Another common mistake with this pattern is that Dr.Jit or the target device may still be busy with computation that started prior to the a = time.time() line, which is now incorrectly added to the measured period.
Dr.Jit provides a kernel history feature, where it creates an entry in a list whenever it launches a kernel or related operation (memory copies, etc.). This not only gives accurate and isolated timings (measured with counters on the CPU/GPU) but also reveals if a kernel was launched at all. To capture the kernel history, set this flag just before the region to be benchmarked and call drjit.kernel_history() at the end.
Capturing the history has a (very) small cost and is therefore disabled by default.
- LaunchBlocking = JitFlag.LaunchBlocking¶
Force synchronization after every kernel launch. This is useful to isolate severe problems (e.g. crashes) to a specific kernel.
This flag has a severe performance impact and is disabled by default.
- ForbidSynchronization = JitFlag.ForbidSynchronization¶
Treat any kind of synchronization as an error and raise an exception when it is encountered.
Operations like drjit.sync_thread() are costly because they prevent the system from overlapping CPU/GPU work. Enable this flag to find places in a larger codebase that are responsible for this.
This flag is disabled by default.
- ScatterReduceLocal = JitFlag.ScatterReduceLocal¶
Reduce locally before performing atomic scatter-reductions.
Atomic memory operations are expensive when many writes target the same region of memory. This leads to a phenomenon called contention that is normally associated with significant slowdowns (10-100x aren’t unusual).
This issue is particularly common when automatically differentiating computation in reverse mode (e.g. drjit.backward()), since this transformation turns differentiable global memory reads into atomic scatter-additions. A differentiable scalar read is all it takes to create such an atomic memory bottleneck.
To reduce this cost, Dr.Jit can perform a local reduction that uses cooperation between SIMD/warp lanes to resolve all requests targeting the same address and then issues only a single atomic memory transaction per unique target. This can reduce atomic memory traffic 32-fold on the GPU (CUDA) and 16-fold on the CPU (AVX512). On the CUDA backend, local reduction is currently only supported for 32-bit operands (signed/unsigned integers and single precision variables).
The section on optimizations presents plots that demonstrate the impact of this optimization.
The JIT flag drjit.JitFlag.ScatterReduceLocal affects the behavior of scatter_add(), scatter_reduce() along with the reverse-mode derivative of gather(). Setting the flag to True will usually cause a mode= argument value of drjit.ReduceOp.Auto to be interpreted as drjit.ReduceOp.Local. Another LLVM-specific optimization takes precedence in certain situations, refer to the discussion of this flag for details.
This flag is enabled by default.
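The effect of a local reduction on atomic memory traffic can be modeled in plain Python. The sketch below is only a conceptual simulation, not how Dr.Jit implements the optimization: lanes are list entries, a "warp" is a slice of `warp_size` lanes (an assumed parameter), and each write to `target` stands in for one atomic transaction.

```python
from collections import defaultdict


def scatter_add_naive(target, indices, values):
    """One atomic addition per lane; returns the number of atomics issued."""
    for i, v in zip(indices, values):
        target[i] += v
    return len(indices)


def scatter_add_local(target, indices, values, warp_size=32):
    """Combine lanes of a warp that hit the same address before writing."""
    atomics = 0
    for start in range(0, len(indices), warp_size):
        partial = defaultdict(float)
        for i, v in zip(indices[start:start + warp_size],
                        values[start:start + warp_size]):
            partial[i] += v        # cooperation between lanes, no atomics
        for i, v in partial.items():
            target[i] += v         # one atomic per unique target address
            atomics += 1
    return atomics


indices = [0] * 64                 # 64 lanes all writing to address 0
values = [1.0] * 64
naive_target, local_target = [0.0], [0.0]
naive_atomics = scatter_add_naive(naive_target, indices, values)
local_atomics = scatter_add_local(local_target, indices, values)
print(naive_atomics, local_atomics)  # 64 2
```

Both variants produce the same sum; only the number of simulated atomic transactions differs (a 32-fold reduction for this fully contended input).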
- SymbolicConditionals = JitFlag.SymbolicConditionals¶
Dr.Jit provides two main ways of compiling conditionals involving Dr.Jit arrays.
- Symbolic mode (the default): Dr.Jit captures the computation performed by the True and False branches and generates an equivalent branch in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.
- Evaluated mode: Dr.Jit always executes both branches and blends their outputs.
A separate section about symbolic and evaluated modes discusses these two options in detail.
Symbolic conditionals are enabled by default.
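To make the idea behind evaluated mode concrete, here is a minimal pure-Python sketch in which arrays are ordinary lists. The helper names `select` and `evaluated_conditional` are hypothetical and exist only for this illustration; the point is that both branch functions run on every element and the mask merely picks which result to keep.

```python
def select(mask, if_true, if_false):
    """Element-wise blend: take `if_true` where the mask is set."""
    return [t if m else f for m, t, f in zip(mask, if_true, if_false)]


def evaluated_conditional(cond, true_fn, false_fn, x):
    """Run both branches on all elements, then blend according to `cond`."""
    return select(cond, true_fn(x), false_fn(x))


x = [-2.0, -1.0, 3.0]
cond = [v > 0 for v in x]
result = evaluated_conditional(
    cond,
    lambda a: [v * 2 for v in a],   # "True" branch
    lambda a: [-v for v in a],      # "False" branch
    x,
)
print(result)  # [2.0, 1.0, 6.0]
```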
- SymbolicScope = JitFlag.SymbolicScope¶
This flag is set to True when Dr.Jit is currently capturing symbolic computation. The flag is automatically managed and should not be updated by application code.
User code may query this flag to check if it is legal to perform certain operations (e.g., evaluating array contents).
Note that this information can also be queried in a more fine-grained manner (per variable) using the drjit.ArrayBase.state field.
- ShaderExecutionReordering = JitFlag.ShaderExecutionReordering¶
Enable OptiX’s SER feature in ray tracing functions and in reorder_threads(). This flag only applies to the CUDA backend.
This flag is enabled by default.
- KernelFreezing = JitFlag.KernelFreezing¶
Enable recording and replay of functions annotated with freeze().
If KernelFreezing is enabled, all Dr.Jit operations executed in a function annotated with freeze() are recorded during its first execution and replayed without re-tracing on subsequent calls.
If this flag is disabled, replay of previously frozen functions is disabled as well.
- FreezingScope = JitFlag.FreezingScope¶
This flag is set to True when Dr.Jit is currently recording a frozen function. The flag is automatically managed and should not be updated by application code.
User code may query this flag to conditionally adapt kernels for frozen function recording, for example by re-seeding the sampler used for rendering.
- EnableObjectTraversal = JitFlag.EnableObjectTraversal¶
This flag is set to True when Dr.Jit is currently traversing inputs and outputs of a frozen function. The flag is automatically managed and should not be updated by application code.
When the flag is set, Dr.Jit also traverses complex objects that are usually opaque to loops and conditionals.
Enable spilling of excess registers into shared memory.
This flag activates an optimization that stores registers in shared memory when register pressure is high, reducing the need to spill to slower local memory. This can improve performance by lowering memory latency on register-intensive kernels. This flag only applies to the CUDA backend.
This flag is enabled by default.
- Default = JitFlag.Default¶
The default set of optimization flags consisting of
- drjit.has_backend(arg: drjit.JitBackend, /) int¶
Check if the specified Dr.Jit backend was successfully initialized.
- drjit.schedule(arg: object, /) bool¶
- drjit.schedule(*args) bool
Overloaded function.
schedule(arg: object, /) -> bool
Schedule the provided JIT variable(s) for later evaluation
This function causes args to be evaluated by the next kernel launch. In other words, the effect of this operation is deferred: the next time that Dr.Jit’s LLVM or CUDA backends compile and execute code, they will include the trace of the specified variables in the generated kernel and turn them into an explicit memory-based representation.
Scheduling and evaluation of traced computation happens automatically, hence it is rare that a user would need to call this function explicitly. Explicit scheduling can improve performance in certain cases. For example, consider the following code:
# Computation that produces Dr.Jit arrays
a, b = ...
# The following line launches a kernel that computes 'a'
print(a)
# The following line launches a kernel that computes 'b'
print(b)
If the traces of a and b overlap (perhaps they reference computation from an earlier step not shown here), then this is inefficient as these steps will be executed twice. It is preferable to launch bigger kernels that leverage common subexpressions, which is what drjit.schedule() enables:
a, b = ... # Computation that produces Dr.Jit arrays
# Schedule both arrays for deferred evaluation, but don't evaluate yet
dr.schedule(a, b)
# The following line launches a kernel that computes both 'a' and 'b'
print(a)
# References the stored array, no kernel launch
print(b)
Note that drjit.eval() would also have been a suitable alternative in the above example; the main difference to drjit.schedule() is that it does the evaluation immediately without deferring the kernel launch.
This function accepts a variable number of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During recursion, the function gathers all unevaluated Dr.Jit arrays. Evaluated arrays and incompatible types are ignored. Multiple variables can be equivalently scheduled with a single drjit.schedule() call or a sequence of calls to drjit.schedule(). Variables that are garbage collected between the original drjit.schedule() call and the next kernel launch are ignored and will not be stored in memory.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)
- Returns:
True if a variable was scheduled, False if the operation did not do anything.
- Return type:
bool
schedule(*args) -> bool
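Why a single launch covering several variables is preferable can be seen in a toy model written in plain Python. Everything here (`Node`, `launch_kernel`) is a hypothetical stand-in for the tracing machinery and not part of the Dr.Jit API: a node is a traced operation, and one "kernel launch" evaluates every node reachable from the requested outputs exactly once.

```python
class Node:
    """A traced operation: `fn` applied to the values of `deps`."""

    def __init__(self, fn, *deps):
        self.fn, self.deps = fn, deps


def launch_kernel(*outputs):
    """Evaluate `outputs` in one launch; returns (values, ops_executed)."""
    cache = {}

    def visit(node):
        if id(node) not in cache:
            cache[id(node)] = node.fn(*(visit(d) for d in node.deps))
        return cache[id(node)]

    values = [visit(o) for o in outputs]
    return values, len(cache)


shared = Node(lambda: 21)                 # common step used by both results
a = Node(lambda s: s * 2, shared)
b = Node(lambda s: s + 1, shared)

# Two separate launches: the shared step runs in both of them.
(_, ops_a), (_, ops_b) = launch_kernel(a), launch_kernel(b)
print(ops_a + ops_b)   # 4

# One combined launch: the shared step runs only once.
values, ops = launch_kernel(a, b)
print(values, ops)     # [42, 22] 3
```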
- drjit.eval(arg: object, /) bool¶
- drjit.eval(*args) bool
Overloaded function.
eval(arg: object, /) -> bool
Evaluate the provided JIT variable(s)
Dr.Jit automatically evaluates variables as needed, hence it is usually not necessary to call this function explicitly. That said, explicit evaluation may sometimes improve performance; refer to the documentation of drjit.schedule() for an example of such a use case.
drjit.eval() invokes Dr.Jit’s LLVM or CUDA backends to compile and then execute a kernel containing all steps that are needed to evaluate the specified variables, which will turn them into a memory-based representation. The generated kernel(s) will also include computation that was previously scheduled via drjit.schedule(). In fact, drjit.eval() internally calls drjit.schedule(), as
dr.eval(arg_1, arg_2, ...)
is equivalent to
dr.schedule(arg_1, arg_2, ...)
dr.eval()
This function accepts a variable number of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During this recursive traversal, the function collects all unevaluated Dr.Jit arrays, while ignoring previously evaluated arrays along with non-array types. The function also does not evaluate literal constant arrays (this refers to potentially large arrays that are entirely uniform), as this is generally not wanted. Use the function drjit.make_opaque() if you wish to evaluate literal constant arrays as well.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)
- Returns:
True if a variable was evaluated, False if the operation did not do anything.
- Return type:
bool
eval(*args) -> bool
- drjit.set_flag(arg0: drjit.JitFlag, arg1: bool, /) None¶
Set the value of the given Dr.Jit compilation flag.
- drjit.flag(arg: drjit.JitFlag, /) bool¶
Query whether the given Dr.Jit compilation flag is active.
- class drjit.scoped_set_flag(*args, **kwargs)¶
Context manager, which sets or unsets a Dr.Jit compilation flag in a local execution scope.
For example, the following snippet shows how to temporarily disable a flag:
with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, False):
    # Code affected by the change should be placed here
    ...
# Flag is returned to its original status
- __init__(self, flag: drjit.JitFlag, value: bool = True) None¶
- __enter__(self) None¶
- __exit__(self, arg0: object | None, arg1: object | None, arg2: object | None) None¶
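The save-and-restore behavior of such a scoped flag can be sketched without Dr.Jit at all. The snippet below is a generic illustration of the pattern, assuming a plain dictionary `flags` as a hypothetical stand-in for the JIT flag state; `scoped_flag` is likewise a made-up helper and not the library's implementation.

```python
from contextlib import contextmanager

# Hypothetical stand-in for the JIT flag state
flags = {"SymbolicCalls": True}


@contextmanager
def scoped_flag(name, value=True):
    previous = flags[name]        # remember the current value
    flags[name] = value
    try:
        yield
    finally:
        flags[name] = previous    # restored even if the body raises


with scoped_flag("SymbolicCalls", False):
    inside = flags["SymbolicCalls"]

print(inside, flags["SymbolicCalls"])  # False True
```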
Function freezing¶
- drjit.freeze(f: None = None, *, state_fn: Callable | None, limit: int | None = None, warn_after: int = 10, backend: JitBackend | None = None, auto_opaque: bool = True, enabled: bool = True) Callable[[F], F]¶
- drjit.freeze(f: F, *, state_fn: Callable | None = None, limit: int | None = None, warn_after: int = 10, backend: JitBackend | None = None, auto_opaque: bool = True, enabled: bool = True) F
Decorator to “freeze” functions, which improves efficiency by removing repeated JIT tracing overheads.
In general, Dr.Jit traces computation and then compiles and launches kernels containing this trace (see the section on evaluation for details). While the compilation step can often be skipped via caching, the tracing cost can still be significant especially when repeatedly evaluating complex models, e.g., as part of an optimization loop.
The @dr.freeze decorator addresses this problem by altogether removing the need to trace repeatedly. For example, consider the following decorated function:
@dr.freeze
def f(x, y, z):
    return ... # Complicated code involving the arguments
Dr.Jit will trace the first call to the decorated function f(), while collecting additional information regarding the nature of the function’s inputs and regarding the CPU/GPU kernel launches representing the body of f().
If the function is subsequently called with compatible arguments (more on this below), it will immediately launch the previously made CPU/GPU kernels without re-tracing, which can substantially improve performance.
When @dr.freeze detects incompatibilities (e.g., x having a different type compared to the previous call), it will conservatively re-trace the body and keep track of another potential input configuration.
Frozen functions support arbitrary PyTrees as function arguments and return values.
The following may trigger re-tracing:
- Changes in the type of an argument or PyTree element.
- Changes in the length of a container (list, tuple, dict).
- Changes of dictionary keys or field names of dataclasses.
- Changes in the AD status (dr.grad_enabled()) of a variable.
- Changes of (non-PyTree) Python objects, as detected by mismatching hash() or id() if they are not hashable.
The following more technical conditions also trigger re-tracing:
- A Dr.Jit variable changes from/to a scalar configuration (size 1).
- The sets of variables of the same size change. In the example above, this would be the case if len(x) == len(y) in one call, and len(x) != len(y) subsequently.
- When Dr.Jit variables reference external memory (e.g. mapped NumPy arrays), the memory can be aligned or unaligned. A re-tracing step is needed when this status changes.
These all correspond to situations where the generated kernel code may need to change, and the system conservatively re-traces to ensure correctness.
Frozen functions support arguments with a different variable width (see dr.width()) without re-tracing, as long as the sets of variables of the same width stay consistent.
Some constructions are problematic and should be avoided in frozen functions.
The function dr.width() returns an integer literal that may be merged into the generated code. If the frozen function is later rerun with differently-sized arguments, the executed kernels will still reference the old size. One exception to this rule is a construction like dr.arange(UInt32, dr.width(a)), where the result only implicitly depends on the width value.
When calling a frozen function from within an outer frozen function, the content of the inner function will be executed and recorded by the outer function. No separate recording will be made for the inner function, and its n_recordings count will not change. Calling the inner function separately from outside a frozen function will therefore require re-tracing for the provided inputs.
Advanced features. The @dr.freeze decorator takes several optional parameters that are helpful in certain situations.
Warning when re-tracing happens too often: Incompatible arguments trigger re-tracing, which can mask issues where accidentally incompatible arguments keep @dr.freeze from producing the expected performance benefits.
In such situations, it can be helpful to warn and identify changing parameters by name. This feature is enabled and set to 10 by default.
>>> @dr.freeze(warn_after=1)
... def f(x):
...     return x
...
>>> f(Int(1))
>>> f(Float(1))
The frozen function has been recorded 2 times, this indicates a problem with how the frozen function is being called. For example, calling it with changing python values such as an index. For more information about which variables changed set the log level to ``LogLevel::Debug``.
Limiting memory usage. Storing kernels for many possible input configurations requires device memory, which can become problematic. Set the limit= parameter to enable an LRU cache. This is useful when calls to a function are mostly compatible but require occasional re-tracing.
- Parameters:
limit (Optional[int]) – An optional integer specifying the maximum number of stored configurations. Once this limit is reached, incompatible calls requiring re-tracing will cause the least recently used configuration to be dropped.
warn_after (int) – When the number of re-tracing steps exceeds this value, Dr.Jit will generate a warning that explains which variables changed between calls to the function.
state_fn (Optional[Callable]) – This optional callable can specify additional state that identifies the configuration. state_fn will be called with the same arguments as that of the decorated function. It should return a traversable object (e.g., a list or tuple) that is conceptually treated as if it was another input of the function.
backend (Optional[JitBackend]) – If no inputs are given when calling the frozen function, the backend used has to be specified using this argument. It must match the backend used for computation within the function.
auto_opaque (bool) – If this flag is set true and only literal values or their size changes between calls to the function, these variables will be marked and made opaque. This reduces the memory usage, traversal overhead, and can improve the performance of generated kernels. If the flag is set to false, all input variables will be made opaque.
enabled (bool) – If this flag is set to false, the function will not be frozen, and the call will be forwarded to the inner function.
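The interplay between re-tracing and the limit= parameter can be illustrated with a heavily simplified model. The `toy_freeze` helper below is hypothetical and does not reflect Dr.Jit internals: it treats the tuple of argument types as the only compatibility criterion (the real checks listed above are far richer) and keeps at most `limit` configurations in least-recently-used order.

```python
from collections import OrderedDict


def toy_freeze(fn, limit=None):
    """Track one 'recording' per argument-type configuration (LRU-bounded)."""
    recordings = OrderedDict()

    def wrapper(*args):
        key = tuple(type(a) for a in args)      # simplified compatibility key
        if key in recordings:
            recordings.move_to_end(key)         # mark as most recently used
        else:
            wrapper.n_recordings += 1           # incompatible call: "re-trace"
            recordings[key] = True
            if limit is not None and len(recordings) > limit:
                recordings.popitem(last=False)  # drop least recently used
        return fn(*args)

    wrapper.n_recordings = 0
    return wrapper


f = toy_freeze(lambda x: x + x, limit=1)
f(1)
f(2)      # same type as before: the existing recording is reused
f(1.5)    # new type: second recording, which evicts the first
f(3)      # the int configuration was evicted, so it is recorded again
print(f.n_recordings)  # 3
```

With a larger limit, the final call would have found its configuration in the cache, which is the situation the parameter is meant for: mostly compatible calls with occasional re-tracing.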
Type traits¶
The functions in this section can be used to infer properties or types of Dr.Jit arrays.
The naming convention with a trailing _v or _t indicates whether a
function returns a value or a type. Evaluation takes place at runtime within
Python. In C++, these expressions are all constexpr (i.e., evaluated at
compile time).
Array type tests¶
- drjit.is_array_v(arg: object | None) bool¶
Check if the input is a Dr.Jit array instance or type
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if arg or type(arg) is a Dr.Jit array type, and False otherwise
- Return type:
bool
- drjit.is_mask_v(arg: object, /) bool¶
Check whether the input array instance or type is a Dr.Jit mask array or a Python bool value/type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if arg represents a Dr.Jit mask array or Python bool instance or type.
- Return type:
bool
- drjit.is_half_v(arg: object, /) bool¶
Check whether the input array instance or type is a Dr.Jit half-precision floating point array or a Python half value/type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if arg represents a Dr.Jit half-precision floating point array or Python half instance or type.
- Return type:
bool
- drjit.is_float_v(arg: object, /) bool¶
Check whether the input array instance or type is a Dr.Jit floating point array or a Python float value/type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if arg represents a Dr.Jit floating point array or Python float instance or type.
- Return type:
bool
- drjit.is_integral_v(arg: object, /) bool¶
Check whether the input array instance or type is an integral Dr.Jit array or a Python int value/type.
Note that a mask array is not considered to be integral.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if arg represents an integral Dr.Jit array or Python int instance or type.
- Return type:
bool
- drjit.is_arithmetic_v(arg: object, /) bool¶
Check whether the input array instance or type is an arithmetic Dr.Jit array or a Python int or float value/type.
Note that a mask type (e.g. bool, drjit.scalar.Array2b, etc.) is not considered to be arithmetic.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if arg represents an arithmetic Dr.Jit array or Python int or float instance or type.
- Return type:
bool
- drjit.is_signed_v(arg: object, /) bool¶
Check whether the input array instance or type is a signed Dr.Jit array or a Python int or float value/type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if arg represents a signed Dr.Jit array or Python int or float instance or type.
- Return type:
bool
- drjit.is_unsigned_v(arg: object, /) bool¶
Check whether the input array instance or type is an unsigned integer Dr.Jit array or a Python bool value/type (masks and boolean values are also considered to be unsigned).
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if arg represents an unsigned Dr.Jit array or Python bool instance or type.
- Return type:
bool
- drjit.is_dynamic_v(arg: object, /) bool¶
Check whether the input instance or type represents a dynamically sized Dr.Jit array type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if the test was successful, and False otherwise.
- Return type:
bool
- drjit.is_jit_v(arg: object, /) bool¶
Check whether the input array instance or type represents a type that undergoes just-in-time compilation.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if arg represents an array type from the drjit.cuda.* or drjit.llvm.* namespaces, and False otherwise.
- Return type:
bool
- drjit.is_diff_v(arg: object, /) bool¶
Check whether the input is a differentiable Dr.Jit array instance or type.
Note that this is a type-based statement that is unrelated to mathematical differentiability. For example, the integral type drjit.cuda.ad.Int from the CUDA AD namespace satisfies is_diff_v(..) == True.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if arg represents an array type from the drjit.[cuda/llvm].ad.* namespace, and False otherwise.
- Return type:
bool
- drjit.is_vector_v(arg: object, /) bool¶
Check whether the input is a Dr.Jit array instance or type representing a vectorial array type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if the test was successful, and False otherwise.
- Return type:
bool
- drjit.is_complex_v(arg: object, /) bool¶
Check whether the input is a Dr.Jit array instance or type representing a complex number.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if the test was successful, and False otherwise.
- Return type:
bool
- drjit.is_matrix_v(arg: object, /) bool¶
Check whether the input is a Dr.Jit array instance or type representing a matrix.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if the test was successful, and False otherwise.
- Return type:
bool
- drjit.is_quaternion_v(arg: object, /) bool¶
Check whether the input is a Dr.Jit array instance or type representing a quaternion.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if the test was successful, and False otherwise.
- Return type:
bool
- drjit.is_tensor_v(arg: object, /) bool¶
Check whether the input is a Dr.Jit array instance or type representing a tensor.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if the test was successful, and False otherwise.
- Return type:
bool
- drjit.is_special_v(arg: object, /) bool¶
Check whether the input is a special Dr.Jit array instance or type.
A special array type requires precautions when performing arithmetic operations like multiplications (complex numbers, quaternions, matrices).
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if the test was successful, and False otherwise.
- Return type:
bool
- drjit.is_struct_v(arg: object, /) bool¶
Check if the input is a Dr.Jit-compatible data structure
Custom data structures can be made compatible with various Dr.Jit operations by specifying a DRJIT_STRUCT member. See the section on PyTrees for details. This type trait can be used to check for the existence of such a field.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True if arg has a DRJIT_STRUCT member
- Return type:
bool
Array properties (shape, type, etc.)¶
- drjit.type_v(arg: object, /) drjit.VarType¶
Returns the scalar type associated with the given Dr.Jit array instance or type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
The associated type or drjit.VarType.Void.
- Return type:
drjit.VarType
- drjit.backend_v(arg: object, /) drjit.JitBackend¶
Returns the backend responsible for the given Dr.Jit array instance or type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
The associated Jit backend or drjit.JitBackend.None.
- Return type:
drjit.JitBackend
- drjit.size_v(arg: object, /) int¶
Return the (static) size of the outermost dimension of the provided Dr.Jit array instance or type
Note that this function mainly exists to query type-level information. Use the Python len() function to query the size in a way that does not distinguish between static and dynamic arrays.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
Returns either the static size or drjit.Dynamic when arg is a dynamic Dr.Jit array. Returns 1 for all other types.
- Return type:
int
- drjit.depth_v(arg: object, /) int¶
Return the depth of the provided Dr.Jit array instance or type
For example, an array consisting of floating point values (for example, drjit.scalar.Array3f) has depth 1, while an array consisting of sub-arrays (e.g., drjit.cuda.Array3f) has depth 2.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
Returns the depth of the input, if it is a Dr.Jit array instance or type. Returns 0 for all other inputs.
- Return type:
int
- drjit.itemsize_v(arg: object, /) int¶
Return the per-item size (in bytes) of the scalar type underlying a Dr.Jit array
- Parameters:
arg (object) – A Dr.Jit array instance or array type.
- Returns:
Returns the size of the array elements in bytes.
- Return type:
int
Bit-level operations¶
- drjit.reinterpret_array(arg0: type[drjit.ArrayBase], arg1: drjit.ArrayBase, /) object¶
Reinterpret the provided Dr.Jit array or tensor as a different type.
This operation reinterprets the input type as another type provided that it has a compatible in-memory layout (this operation is also known as a bit-cast).
- Parameters:
dtype (type) – Target type.
value (object) – A compatible Dr.Jit input array or tensor.
- Returns:
Result of the conversion as described above.
- Return type:
object
- drjit.popcnt(arg: ArrayT, /) ArrayT¶
- drjit.popcnt(arg: int, /) int
Overloaded function.
popcnt(arg: ArrayT, /) -> ArrayT
Return the number of nonzero bits.
This function evaluates the component-wise population count of the input scalar, array, or tensor. This function assumes that arg is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.
- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
number of nonzero bits in arg
- Return type:
int | drjit.ArrayBase
popcnt(arg: int, /) -> int
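For reference, the population count of a 32-bit value can be written in a few lines of plain Python. This is only a scalar model of the semantics described above (the helper name `popcnt32` is made up for this sketch) and says nothing about how Dr.Jit computes the result.

```python
def popcnt32(x: int) -> int:
    """Number of set bits in the low 32 bits of `x`."""
    return bin(x & 0xFFFFFFFF).count("1")


print(popcnt32(0b1011))       # 3
print(popcnt32(0xFFFFFFFF))   # 32
```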
- drjit.lzcnt(arg: ArrayT, /) ArrayT¶
- drjit.lzcnt(arg: int, /) int
Overloaded function.
lzcnt(arg: ArrayT, /) -> ArrayT
Return the number of leading zero bits.
This function evaluates the component-wise leading zero count of the input scalar, array, or tensor. This function assumes that arg is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.
The operation is well-defined when arg is zero.
- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
number of leading zero bits in arg
- Return type:
int | drjit.ArrayBase
lzcnt(arg: int, /) -> int
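A scalar reference model of the leading zero count is sketched below. The helper `lzcnt32` is hypothetical, and the value returned for a zero input (32, i.e. all bits are leading zeros) is an assumption based on the usual convention for 32-bit operands; the text above only states that the case is well-defined.

```python
def lzcnt32(x: int) -> int:
    """Leading zero bits of `x` interpreted as a 32-bit unsigned value."""
    x &= 0xFFFFFFFF
    # bit_length() is the position of the highest set bit (0 for x == 0)
    return 32 - x.bit_length()


print(lzcnt32(1))           # 31
print(lzcnt32(0x80000000))  # 0
print(lzcnt32(0))           # 32 (assumed convention)
```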
- drjit.tzcnt(arg: ArrayT, /) ArrayT¶
- drjit.tzcnt(arg: int, /) int
Overloaded function.
tzcnt(arg: ArrayT, /) -> ArrayT
Return the number of trailing zero bits.
This function evaluates the component-wise trailing zero count of the input scalar, array, or tensor. This function assumes that arg is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.
The operation is well-defined when arg is zero.
- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
number of trailing zero bits in arg
- Return type:
int | drjit.ArrayBase
tzcnt(arg: int, /) -> int
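The trailing zero count admits a similarly short scalar model. As before, `tzcnt32` is a made-up helper, and returning 32 for a zero input is an assumed convention rather than something stated in the text above.

```python
def tzcnt32(x: int) -> int:
    """Trailing zero bits of `x` interpreted as a 32-bit unsigned value."""
    x &= 0xFFFFFFFF
    if x == 0:
        return 32                     # assumed convention for a zero input
    # `x & -x` isolates the lowest set bit; its index is the answer
    return (x & -x).bit_length() - 1


print(tzcnt32(8))   # 3
print(tzcnt32(1))   # 0
```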
- drjit.brev(arg: ArrayT, /) ArrayT¶
- drjit.brev(arg: int, /) int
Overloaded function.
brev(arg: ArrayT, /) -> ArrayT
Reverse the bit representation of an integer value or array.
This function assumes that arg is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.
- Parameters:
arg (int | drjit.ArrayBase) – A Python int or Dr.Jit integer array.
- Returns:
the bit-reversed version of arg.
- Return type:
int | drjit.ArrayBase
brev(arg: int, /) -> int
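Bit reversal of a 32-bit value can be modeled in plain Python by reversing its zero-padded binary string. The helper `brev32` is a hypothetical reference written for clarity rather than speed, and it is not how the operation is implemented by the backends.

```python
def brev32(x: int) -> int:
    """Reverse the order of the 32 bits of `x`."""
    x &= 0xFFFFFFFF
    # Format as exactly 32 binary digits, reverse them, parse again
    return int(f"{x:032b}"[::-1], 2)


print(hex(brev32(1)))      # 0x80000000
print(hex(brev32(0xF0)))   # 0xf000000
```

Applying the function twice returns the original value, which is a convenient sanity check.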
- drjit.log2i(arg: T, /) T¶
Return the floor of the base-2 logarithm.
This function evaluates the component-wise floor of the base-2 logarithm of the input scalar, array, or tensor. This function assumes that arg is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.
The operation overflows when arg is zero.
- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
floor of the base-2 logarithm of the input
- Return type:
int | drjit.ArrayBase
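For a positive 32-bit integer, the floor of the base-2 logarithm is the index of the highest set bit, which equals 31 minus the number of leading zero bits. The scalar sketch below uses a hypothetical helper `log2i32`; raising an exception for a zero input is a choice made for this illustration only, since the text above merely notes that the operation overflows in that case.

```python
def log2i32(x: int) -> int:
    """Floor of log2(x) for a nonzero 32-bit unsigned value."""
    x &= 0xFFFFFFFF
    if x == 0:
        # Illustration-only behavior; the documented operation overflows here
        raise ValueError("log2i is not meaningful for zero")
    return x.bit_length() - 1         # index of the highest set bit


print(log2i32(1))   # 0
print(log2i32(8))   # 3
print(log2i32(9))   # 3
```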
Standard mathematical functions¶
- drjit.fma(arg0: object, arg1: object, arg2: object, /) object¶
- drjit.fma(arg0: object, arg1: object, arg2: object, /) object
- drjit.fma(arg0: int, arg1: int, arg2: int, /) int
- drjit.fma(arg0: float, arg1: float, arg2: float, /) float
Overloaded function.
fma(arg0: object, arg1: object, arg2: object, /) -> objectfma(arg0: object, arg1: object, arg2: object, /) -> objectfma(arg0: int, arg1: int, arg2: int, /) -> int
Perform a fused multiply-addition (FMA) of the inputs.
Given arguments arg0, arg1, and arg2, this operation computes arg0*arg1+arg2 using only one final rounding step. The operation is not only more accurate, but also more efficient, since FMA maps to a native machine instruction on all platforms targeted by Dr.Jit.
When the input is complex- or quaternion-valued, the function internally uses a complex or quaternion product. In this case, it reduces the number of internal rounding steps instead of avoiding them altogether.
While FMA is traditionally a floating point operation, Dr.Jit also implements FMA for integer arrays and maps it onto dedicated instructions provided by the backend if possible (e.g. mad.lo.* for CUDA/PTX).
- Parameters:
arg0 (float | drjit.ArrayBase) – First multiplication operand
arg1 (float | drjit.ArrayBase) – Second multiplication operand
arg2 (float | drjit.ArrayBase) – Additive operand
- Returns:
Result of the FMA operation
- Return type:
float | drjit.ArrayBase
fma(arg0: float, arg1: float, arg2: float, /) -> float
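The effect of a single final rounding step can be reproduced in plain Python with exact rational arithmetic. The sketch below is only an illustration of the rounding behavior (the helper `fma_ref` is hypothetical and unrelated to how Dr.Jit evaluates the operation).

```python
from fractions import Fraction


def fma_ref(a: float, b: float, c: float) -> float:
    """Evaluate a*b+c exactly, then round once (hypothetical reference helper)."""
    return float(Fraction(a) * Fraction(b) + Fraction(c))


a = 1.0 + 2.0**-30
b = 1.0 - 2.0**-30
c = -1.0

unfused = a * b + c       # the product 1 - 2**-60 is rounded to 1.0 first
fused = fma_ref(a, b, c)  # the exact product survives until the final rounding
```

The two-step evaluation loses the result entirely (`0.0`), while the single-rounding variant returns `-2**-60`.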
- drjit.abs(arg: drjit.nn.CoopVec, /) drjit.nn.CoopVec¶
- drjit.abs(arg: ArrayT, /) ArrayT
- drjit.abs(arg: int, /) int
- drjit.abs(arg: float, /) float
Overloaded function.
abs(arg: drjit.nn.CoopVec, /) -> drjit.nn.CoopVec
abs(arg: ArrayT, /) -> ArrayT
Compute the absolute value of the provided input.
This function evaluates the component-wise absolute value of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.
- Parameters:
arg (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Absolute value of the input
- Return type:
int | float | drjit.ArrayBase
abs(arg: int, /) -> int
abs(arg: float, /) -> float
- drjit.minimum(arg0: object, arg1: object, /) object¶
- drjit.minimum(arg0: int, arg1: int, /) int
- drjit.minimum(arg0: object, arg1: object, /) object
- drjit.minimum(arg0: float, arg1: float, /) float
Overloaded function.
minimum(arg0: object, arg1: object, /) -> object
minimum(arg0: int, arg1: int, /) -> int
minimum(arg0: object, arg1: object, /) -> object
Compute the element-wise minimum value of the provided inputs.
(Not to be confused with drjit.min(), which reduces the input along the specified axes to determine the minimum)
- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Minimum of the input(s)
- Return type:
int | float | drjit.ArrayBase
minimum(arg0: float, arg1: float, /) -> float
- drjit.maximum(arg0: object, arg1: object, /) object¶
- drjit.maximum(arg0: int, arg1: int, /) int
- drjit.maximum(arg0: object, arg1: object, /) object
- drjit.maximum(arg0: float, arg1: float, /) float
Overloaded function.
maximum(arg0: object, arg1: object, /) -> object
maximum(arg0: int, arg1: int, /) -> int
maximum(arg0: object, arg1: object, /) -> object
Compute the element-wise maximum value of the provided inputs.
(Not to be confused with drjit.max(), which reduces the input along the specified axes to determine the maximum)
- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Maximum of the input(s)
- Return type:
int | float | drjit.ArrayBase
maximum(arg0: float, arg1: float, /) -> float
- drjit.sqrt(arg: ArrayT, /) ArrayT¶
- drjit.sqrt(arg: float, /) float
Overloaded function.
sqrt(arg: ArrayT, /) -> ArrayT
Evaluate the square root of the provided input.
This function evaluates the component-wise square root of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.
Negative inputs produce a NaN output value. Consider using the safe_sqrt() function to work around issues where the input might occasionally be negative due to prior round-off errors.
Another noteworthy behavior of the square root function is that it has an infinite derivative at arg=0, which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. The safe_sqrt() function contains a workaround to ensure a finite derivative in this case.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Square root of the input
- Return type:
float | drjit.ArrayBase
sqrt(arg: float, /) -> float
- drjit.cbrt(arg: ArrayT, /) ArrayT¶
- drjit.cbrt(arg: float, /) float
Overloaded function.
cbrt(arg: ArrayT, /) -> ArrayT
Evaluate the cube root of the provided input.
This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Cube root of the input
- Return type:
float | drjit.ArrayBase
cbrt(arg: float, /) -> float
- drjit.rcp(arg: ArrayT, /) ArrayT¶
- drjit.rcp(arg: float, /) float
Overloaded function.
rcp(arg: ArrayT, /) -> ArrayT
Evaluate the reciprocal (1 / arg) of the provided input.
When arg is a CUDA single precision array, the operation is evaluated approximately; see the documentation of the instruction rcp.approx.ftz.f32 in the NVIDIA PTX manual for details. For full IEEE-754 compatibility, unset drjit.JitFlag.FastMath.
When called with a matrix-, complex- or quaternion-valued array, this function uses the matrix, complex, or quaternion multiplicative inverse to evaluate the reciprocal.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Reciprocal of the input
- Return type:
float | drjit.ArrayBase
rcp(arg: float, /) -> float
- drjit.rsqrt(arg: ArrayT, /) ArrayT¶
- drjit.rsqrt(arg: float, /) float
Overloaded function.
rsqrt(arg: ArrayT, /) -> ArrayT
Evaluate the reciprocal square root (1 / sqrt(arg)) of the provided input.
This function evaluates the component-wise reciprocal square root of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.
When arg is a single precision array, the operation is evaluated approximately; see the documentation of the instruction rsqrt.approx.ftz.f32 in the NVIDIA PTX manual for details on how this works on the CUDA backend. For full IEEE-754 compatibility, unset drjit.JitFlag.FastMath.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Reciprocal square root of the input
- Return type:
float | drjit.ArrayBase
rsqrt(arg: float, /) -> float
- drjit.clip(value, min, max)¶
Clip the provided input to the given interval.
This function is equivalent to
dr.maximum(dr.minimum(value, max), min)
- Parameters:
value (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
min (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
max (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
- Returns:
Clipped input
- Return type:
int | float | drjit.ArrayBase
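The identity above can be checked with Python's built-in `min` and `max`. This is a scalar sketch only; the helper name `clip_ref` and the parameter names `lo` and `hi` are hypothetical.

```python
def clip_ref(value, lo, hi):
    """max(min(value, hi), lo), mirroring the documented identity."""
    return max(min(value, hi), lo)


inside = clip_ref(1.5, 0, 3)    # already within [0, 3]
above = clip_ref(5, 0, 3)       # clamped to the upper bound
below = clip_ref(-2, 0, 3)      # clamped to the lower bound
inverted = clip_ref(1, 5, 3)    # lower bound > upper bound
```

Because the outer operation is the maximum, the lower bound takes precedence when the interval is inverted.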
- drjit.ceil(arg: ArrayT, /) ArrayT¶
- drjit.ceil(arg: float, /) float
Overloaded function.
ceil(arg: ArrayT, /) -> ArrayT
Evaluate the ceiling, i.e. the smallest integer >= arg.
The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Ceiling of the input
- Return type:
float | drjit.ArrayBase
ceil(arg: float, /) -> float
- drjit.floor(arg: ArrayT, /) ArrayT¶
- drjit.floor(arg: float, /) float
Overloaded function.
floor(arg: ArrayT, /) -> ArrayT
Evaluate the floor, i.e. the largest integer <= arg.
The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Floor of the input
- Return type:
float | drjit.ArrayBase
floor(arg: float, /) -> float
- drjit.trunc(arg: ArrayT, /) ArrayT¶
- drjit.trunc(arg: float, /) float
Overloaded function.
trunc(arg: ArrayT, /) -> ArrayT
Truncate arg to the nearest integer by rounding towards zero.
The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Truncated result
- Return type:
float | drjit.ArrayBase
trunc(arg: float, /) -> float
- drjit.round(arg: ArrayT, /) ArrayT¶
- drjit.round(arg: float, /) float
Overloaded function.
round(arg: ArrayT, /) -> ArrayT
Rounds the input to the nearest integer using Banker’s rounding for half-way values.
This function is equivalent to std::rint in C++. It does not convert the type of the input array. A separate cast is necessary when integer output is desired.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Rounded result
- Return type:
float | drjit.ArrayBase
round(arg: float, /) -> float
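The four rounding modes documented above can be compared on scalars using the Python standard library. The sketch assumes finite inputs; the helper `round_ref` is hypothetical. Python's `math` functions return `int`, so the values are converted back to `float` to mimic the type-preserving behavior described here.

```python
import math


def round_ref(x: float) -> float:
    """Round a finite float to the nearest integer, ties to even, keeping a float."""
    return float(round(x))  # Python's built-in round() also rounds ties to even


ties = [round_ref(x) for x in (0.5, 1.5, 2.5, 3.5)]

x = -1.7
directed = {
    "ceil": float(math.ceil(x)),    # towards +infinity
    "floor": float(math.floor(x)),  # towards -infinity
    "trunc": float(math.trunc(x)),  # towards zero
    "round": round_ref(x),          # to nearest
}
```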
- drjit.sign(arg, /)¶
Return the element-wise sign of the provided array.
The function returns
\[\mathrm{sign}(\texttt{arg}) = \begin{cases} 1&\texttt{arg}\geq 0,\\ -1&\mathrm{otherwise}. \end{cases}\]
- Parameters:
arg (int | float | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Sign of the input array
- Return type:
float | int | drjit.ArrayBase
- drjit.copysign(arg0, arg1, /)¶
Copy the sign of arg1 to arg0 element-wise.
- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to change the sign of
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to copy the sign from
- Returns:
The values of arg0 with the sign of arg1
- Return type:
float | int | drjit.ArrayBase
- drjit.mulsign(arg0, arg1, /)¶
Multiply arg0 by the sign of arg1 element-wise.
This function is equivalent to
arg0 * dr.sign(arg1)
- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to multiply by the sign
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to take the sign from
- Returns:
The values of arg0 multiplied by the sign of arg1
- Return type:
float | int | drjit.ArrayBase
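The difference between the three sign-related functions is easiest to see on scalars. The following plain-Python sketch follows the formulas given above; all helper names are hypothetical, and `math.copysign` stands in for the floating point case of `copysign`.

```python
import math


def sign_ref(x):
    """Return 1 for x >= 0 and -1 otherwise, following the documented formula."""
    return 1 if x >= 0 else -1


def mulsign_ref(a, b):
    """a * sign(b): flips the sign of `a` when `b` is negative."""
    return a * sign_ref(b)


def copysign_ref(a, b):
    """Magnitude of `a` combined with the sign of `b`."""
    return math.copysign(a, b)


flipped = mulsign_ref(-2.0, -3.0)   # negative times negative sign
copied = copysign_ref(-2.0, -3.0)   # magnitude 2 with a negative sign
```

Note that `sign_ref(0)` yields `1` under this definition, and that `mulsign` and `copysign` disagree whenever the first argument is negative.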
- drjit.step(arg0: object, arg1: object, /) object¶
- drjit.step(arg0: object, arg1: object, /) object
Overloaded function.
step(arg0: object, arg1: object, /) -> object
Step function.
This function generates a step function by comparing arg0 to arg1. The function is equivalent to
dr.select(
    arg0 < arg1,
    0, # if arg0 < arg1
    1, # if arg0 >= arg1
)
- Parameters:
arg0 (object) – A Dr.Jit array/tensor or Python arithmetic type
arg1 (object) – A Dr.Jit array/tensor or Python arithmetic type
- Returns:
The computed array as described above
- Return type:
object
step(arg0: object, arg1: object, /) -> object
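On scalars, the selection above reduces to a single comparison. A minimal plain-Python sketch (the helper `step_ref` is hypothetical):

```python
def step_ref(arg0, arg1):
    """Return 0 where arg0 < arg1 and 1 elsewhere (hypothetical reference helper)."""
    return 0 if arg0 < arg1 else 1


values = [-1.0, 0.0, 0.5, 2.0]
edges = [step_ref(v, 0.5) for v in values]  # threshold at 0.5
```

The value at the threshold itself (`arg0 == arg1`) is `1`.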
- drjit.mul_hi(arg0: object, arg1: object, /) object¶
- drjit.mul_hi(arg0: int, arg1: int, /) int
- drjit.mul_hi(arg0: int, arg1: int, /) int
Overloaded function.
mul_hi(arg0: object, arg1: object, /) -> object
Return the high part of an integer product.
This function multiplies two signed or unsigned 32 bit operands and returns the upper 32 bits of the result.
- Parameters:
arg0 (int | drjit.ArrayBase) – A Python or Dr.Jit array
arg1 (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
High part of the multiplication result
- Return type:
int | drjit.ArrayBase
mul_hi(arg0: int, arg1: int, /) -> int
mul_hi(arg0: int, arg1: int, /) -> int
- drjit.mul_wide(arg0: object, arg1: object, /) object¶
- drjit.mul_wide(arg0: int, arg1: int, /) int
- drjit.mul_wide(arg0: int, arg1: int, /) int
Overloaded function.
mul_wide(arg0: object, arg1: object, /) -> object
Return all bits of an integer product.
This function multiplies two signed or unsigned 32 bit operands and returns the full 64 bit result.
- Parameters:
arg0 (int | drjit.ArrayBase) – A Python or Dr.Jit array
arg1 (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Full 64 bit multiplication result
- Return type:
int | drjit.ArrayBase
mul_wide(arg0: int, arg1: int, /) -> int
mul_wide(arg0: int, arg1: int, /) -> int
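The relationship between the wide product, its high part, and an ordinary wrapping 32-bit product can be sketched with Python's arbitrary-precision integers. The sketch covers the unsigned case only, and all helper names are hypothetical.

```python
MASK32 = 0xFFFFFFFF


def mul_wide_u32(a: int, b: int) -> int:
    """Full 64-bit product of two unsigned 32-bit operands."""
    return (a & MASK32) * (b & MASK32)


def mul_hi_u32(a: int, b: int) -> int:
    """Upper 32 bits of the 64-bit product."""
    return mul_wide_u32(a, b) >> 32


def mul_lo_u32(a: int, b: int) -> int:
    """Lower 32 bits, i.e. what a wrapping 32-bit multiplication keeps."""
    return mul_wide_u32(a, b) & MASK32


a = b = 0xFFFFFFFF
wide, hi, lo = mul_wide_u32(a, b), mul_hi_u32(a, b), mul_lo_u32(a, b)
```

The high and low halves recombine into the wide result via `(hi << 32) | lo`.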
Operations for vectors and matrices¶
- drjit.cross(arg0: ArrayT, arg1: ArrayT, /) ArrayT¶
Returns the cross-product of the two input 3D arrays
- Parameters:
arg0 (drjit.ArrayBase) – A Dr.Jit 3D array
arg1 (drjit.ArrayBase) – A Dr.Jit 3D array
- Returns:
Cross-product of the two input 3D arrays
- Return type:
- drjit.det(arg, /)¶
Compute the determinant of the provided Dr.Jit matrix.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The determinant value of the input matrix
- Return type:
- drjit.diag(arg, /)¶
This function either returns the diagonal entries of the provided Dr.Jit matrix, or it constructs a new matrix from the diagonal entries.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The diagonal matrix of the input matrix
- Return type:
- drjit.trace(arg, /)¶
Returns the trace of the provided Dr.Jit matrix.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The trace of the input matrix
- Return type:
drjit.value_t(arg)
- drjit.matmul(A: object, B: object, At: bool = False, Bt: bool = False) object¶
Compute a matrix-matrix, matrix-vector, vector-matrix, or inner product.
This function implements the semantics of the @ operator introduced in Python’s PEP 465. There is no practical difference between using drjit.matmul() or @ in Dr.Jit-based code. Multiplication of matrix types (e.g., drjit.scalar.Matrix2f) using the standard multiplication operator (*) is also based on matrix multiplication.
This function takes two Dr.Jit arrays and picks one of the following 5 cases based on their leading fixed-size dimensions.
Matrix-matrix product: If both arrays have leading static dimensions (n, n), they are multiplied like conventional matrices.
Matrix-vector product: If arg0 has leading static dimensions (n, n) and arg1 has leading static dimension (n,), the operation conceptually appends a trailing 1-sized dimension to arg1, multiplies, and then removes the extra dimension from the result.
Vector-matrix product: If arg0 has leading static dimensions (n,) and arg1 has leading static dimension (n, n), the operation conceptually prepends a leading 1-sized dimension to arg0, multiplies, and then removes the extra dimension from the result.
Inner product: If arg0 and arg1 have leading static dimensions (n,), the operation returns the sum of the elements of arg0*arg1.
Scalar product: If arg0 or arg1 is a scalar, the operation scales the elements of the other argument.
It is legal to combine vectorized and non-vectorized types, e.g.
dr.matmul(dr.scalar.Matrix4f(...), dr.cuda.Matrix4f(...))
Also, note that it doesn’t matter whether an input is an instance of a matrix type or a similarly-shaped nested array. For example, drjit.scalar.Matrix3f() and drjit.scalar.Array33f() have the same shape and are treated identically.
Tensor matrix multiplication. In addition to the fixed-size cases above, this function also accepts two Dr.Jit tensors of matching element type and backend. It fully replicates the NumPy / PyTorch matmul semantics including batched matrix products, broadcasting of leading batch axes, matrix-vector products, and inner products. The optional At/Bt flags transpose the associated operand on the fly at essentially no cost on either backend.
Under the hood, the CUDA backend uses a block-matrix multiplication in which each thread block cooperatively stages tiles of A and B through shared memory and accumulates the output tile in registers. The CPU backend uses a GotoBLAS-style tiled GEMM with a vectorized microkernel, parallelized over both axes of the output tile grid via nanothread, so shapes with few output rows still use every core. Broadcasts along batch dimensions are consumed directly by the kernel via zero strides, and under automatic differentiation the reverse-mode gradient of a broadcast operand folds its sum-over-batch into the backward GEMM’s contraction, so no expanded copy of a broadcast operand is materialized in either the primal or the derivative.
Half-precision inputs are multiplied and summed in single precision throughout the reduction; the result is narrowed to half precision only at the final store.
Note that the CUDA implementation does not use the tensor cores available on recent NVIDIA GPUs, which can greatly accelerate half-precision math. For fp16 matmuls, Dr.Jit is therefore not competitive with PyTorch.
Supported element types are drjit.VarType.Float16, drjit.VarType.Float32, drjit.VarType.Float64, drjit.VarType.Int32, and drjit.VarType.UInt32.
Note
Performance tips (CUDA). The tensor matmul ships a small family of precompiled kernels and picks the largest tile it can align to the operand strides. To land on the fastest path, the contiguous dimensions of both operands and N should be divisible by V, where V = 8 for Float16, V = 4 for Float32/Int32/UInt32, V = 2 for Float64.
When this divisibility doesn’t hold the kernel falls back to a smaller tile with scalar loads, which can be an order of magnitude slower. The CPU (LLVM) backend is not affected by this.
The fixed-size array path (types such as drjit.cuda.Matrix4f(), as opposed to Dr.Jit tensors) only handles small matrices whose dimensions are known at compile time. To multiply large dynamic matrices, use the N-D tensor path described above.
- Parameters:
arg0 (dr.ArrayBase) – Dr.Jit array or Dr.Jit tensor.
arg1 (dr.ArrayBase) – Dr.Jit array or Dr.Jit tensor.
At (bool) – If True, transpose the last two dimensions of arg0 on the fly. Only applies to the tensor path and is invalid when arg0 is 1-D. Defaults to False.
Bt (bool) – If True, transpose the last two dimensions of arg1 on the fly. Only applies to the tensor path and is invalid when arg1 is 1-D. Defaults to False.
- Returns:
The result of the operation as defined above
- Return type:
object
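The five fixed-size cases can be mirrored on nested Python lists. The sketch below only illustrates the dispatch rules listed in this entry. It is not Dr.Jit code, it ignores vectorization and tensors, and the helper names `_rank` and `matmul_ref` are hypothetical.

```python
from numbers import Number


def _rank(x):
    """0 for scalars, 1 for flat lists, 2 for lists of rows."""
    if isinstance(x, Number):
        return 0
    return 2 if isinstance(x[0], list) else 1


def matmul_ref(a, b):
    """Pick one of the five product cases based on the operand ranks."""
    ra, rb = _rank(a), _rank(b)
    if ra == 0 or rb == 0:  # scalar product: scale the other operand
        s, other = (a, b) if ra == 0 else (b, a)
        if _rank(other) == 0:
            return s * other
        if _rank(other) == 1:
            return [s * v for v in other]
        return [[s * v for v in row] for row in other]
    if ra == 1 and rb == 1:  # inner product
        return sum(x * y for x, y in zip(a, b))
    if ra == 2 and rb == 1:  # matrix-vector product
        return [sum(x * y for x, y in zip(row, b)) for row in a]
    if ra == 1 and rb == 2:  # vector-matrix product
        return [sum(a[k] * b[k][j] for k in range(len(a)))
                for j in range(len(b[0]))]
    # matrix-matrix product
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


M = [[1, 2], [3, 4]]
v = [1, 1]
```

With these inputs, `matmul_ref(M, v)` and `matmul_ref(v, M)` give different results, which reflects the appended versus prepended unit dimension described above.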
- drjit.hypot(a, b)¶
Computes \(\sqrt{a^2+b^2}\) while avoiding overflow and underflow.
- Parameters:
a (float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
b (float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
The computed hypotenuse.
- Return type:
float | drjit.ArrayBase
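The overflow and underflow issue arises because the squares can leave the representable range even when the final result does not. The sketch below contrasts the naive formula with one common rescaling strategy. Whether Dr.Jit uses this particular strategy is an assumption, and both helper names are hypothetical.

```python
import math


def hypot_naive(a: float, b: float) -> float:
    """Direct evaluation; the squares may overflow to inf or underflow to 0."""
    return math.sqrt(a * a + b * b)


def hypot_scaled(a: float, b: float) -> float:
    """Factor out the larger magnitude so that the squared ratio lies in [0, 1]."""
    a, b = abs(a), abs(b)
    hi, lo = max(a, b), min(a, b)
    if hi == 0.0:
        return 0.0
    ratio = lo / hi
    return hi * math.sqrt(1.0 + ratio * ratio)


large_naive = hypot_naive(3e200, 4e200)    # squares overflow
large_scaled = hypot_scaled(3e200, 4e200)  # stays finite
small_naive = hypot_naive(3e-200, 4e-200)  # squares underflow
small_scaled = hypot_scaled(3e-200, 4e-200)
```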
- drjit.normalize(arg: T, /) T¶
Normalize the input vector so that it has unit length and return the result.
This operation is equivalent to
arg * dr.rsqrt(dr.squared_norm(arg))
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit array type
- Returns:
Unit-norm version of the input
- Return type:
- drjit.lerp(a, b, t)¶
Linearly blend between two values.
This function computes
\[\mathrm{lerp}(a, b, t) = (1-t) a + t b\]
In other words, it linearly blends between \(a\) and \(b\) based on the value \(t\) that is typically on the interval \([0, 1]\).
It does so using two fused multiply-additions (drjit.fma()) to improve performance and avoid numerical errors.
- Parameters:
a (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
b (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
t (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
- Returns:
Interpolated result
- Return type:
float | drjit.ArrayBase
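One way to regroup \((1-t) a + t b\) into two multiply-add steps is shown below. The exact grouping used by Dr.Jit is an assumption here, and `fma_stub` is a hypothetical stand-in that is not actually fused in plain Python.

```python
def fma_stub(x, y, z):
    """Stand-in for a fused multiply-add; plain Python rounds twice here."""
    return x * y + z


def lerp_ref(a, b, t):
    """(1 - t) * a + t * b, regrouped as b*t + (a*(-t) + a)."""
    return fma_stub(b, t, fma_stub(a, -t, a))


quarter = lerp_ref(2.0, 10.0, 0.25)
start = lerp_ref(2.0, 10.0, 0.0)
end = lerp_ref(2.0, 10.0, 1.0)
```

This grouping reproduces both endpoints exactly for `t = 0` and `t = 1` in the example.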
- drjit.slerp(a: ArrayT, b: ArrayT, t: dr.ArrayBase | float) ArrayT¶
Linearly interpolate between two rotation quaternions a and b. Works analogously to lerp().
- Parameters:
a (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
b (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
t (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
- Returns:
Interpolated result
- Return type:
float | drjit.ArrayBase
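For orientation, the textbook spherical interpolation formula for unit quaternions is sketched below in plain Python. It is not taken from Dr.Jit. Details such as the handling of nearly parallel inputs are choices of this sketch, and the name `slerp_ref` is hypothetical.

```python
import math


def slerp_ref(a, b, t):
    """Textbook spherical interpolation between two unit quaternions (4-tuples)."""
    cos_theta = sum(x * y for x, y in zip(a, b))
    cos_theta = max(-1.0, min(1.0, cos_theta))  # guard acos against round-off
    theta = math.acos(cos_theta)
    sin_theta = math.sin(theta)
    if sin_theta < 1e-6:
        # Nearly parallel inputs: fall back to a plain linear blend
        w0, w1 = 1.0 - t, t
    else:
        w0 = math.sin((1.0 - t) * theta) / sin_theta
        w1 = math.sin(t * theta) / sin_theta
    return tuple(w0 * x + w1 * y for x, y in zip(a, b))


identity = (0.0, 0.0, 0.0, 1.0)
# 90 degree rotation about the z axis, real part stored last (an assumption)
quarter_turn = (0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4))
halfway = slerp_ref(identity, quarter_turn, 0.5)
```

The halfway result corresponds to a 45 degree rotation about the same axis, i.e. the rotation angle is interpolated at constant speed.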
- drjit.sh_eval(d: ArrayBase, order: int) list¶
Evaluate real spherical harmonics basis functions up to a specified order.
The input d must be a normalized 3D Cartesian coordinate vector. The function returns a list containing all spherical harmonic basis functions evaluated with respect to d up to the desired order, for a total of (order+1)**2 output values.
The implementation relies on efficient pre-generated branch-free code with aggressive constant folding and common subexpression elimination. It admits scalar and Jit-compiled input arrays. Evaluation routines are included for orders 0 to 10. Requesting higher orders triggers a runtime exception.
This automatically generated code is based on the paper Efficient Spherical Harmonic Evaluation, Journal of Computer Graphics Techniques (JCGT), vol. 2, no. 2, 84-90, 2013 by Peter-Pike Sloan.
The SciPy equivalent of this function is given by
def sh_eval(d, order: int):
    import numpy as np
    from scipy.special import sph_harm_y
    # Polar angle (theta) and azimuth (phi) of the unit direction d
    theta, phi = np.arccos(d.z), np.arctan2(d.y, d.x)
    r = []
    for l in range(order + 1):
        for m in range(-l, l + 1):
            Y = sph_harm_y(l, abs(m), theta, phi)
            if m > 0:
                Y = np.sqrt(2) * Y.real
            elif m < 0:
                Y = np.sqrt(2) * Y.imag
            r.append(Y.real)
    return r
The Mathematica equivalent of a specific entry is given by:
SphericalHarmonicQ[l_, m_, d_] := Block[{θ, ϕ},
  θ = ArcCos[d[[3]]];
  ϕ = ArcTan[d[[1]], d[[2]]];
  Piecewise[{
    {SphericalHarmonicY[l, m, θ, ϕ], m == 0},
    {Sqrt[2] * Re[SphericalHarmonicY[l, m, θ, ϕ]], m > 0},
    {Sqrt[2] * Im[SphericalHarmonicY[l, -m, θ, ϕ]], m < 0}
  }]
]
- drjit.frob(arg, /)¶
Returns the squared Frobenius norm of the provided Dr.Jit matrix.
The squared Frobenius norm is defined as the sum of the squares of its elements:
\[\sum_{i=1}^m \sum_{j=1}^n a_{i j}^2\]
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The squared Frobenius norm of the input matrix
- Return type:
- drjit.rotate(dtype: Type[ArrayT], axis: dr.ArrayBase, angle: dr.ArrayBase | float) ArrayT¶
Constructs a rotation quaternion encoding a rotation by angle radians around axis.
The function requires axis to be normalized.
- Parameters:
axis (drjit.ArrayBase) – A 3-dimensional Dr.Jit array representing the rotation axis
angle (float | drjit.ArrayBase) – Rotation angle.
- Returns:
The resulting rotation quaternion
- Return type:
- drjit.polar_decomp(arg, it=10)¶
Returns the polar decomposition of the provided Dr.Jit matrix.
The polar decomposition separates the matrix into a rotation followed by a scaling along each of its eigenvectors. This decomposition always exists for square matrices.
The implementation relies on an iterative algorithm, where the number of iterations can be controlled by the argument it (tradeoff between precision and computational cost).
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
it (int) – Number of iterations to be taken by the algorithm.
- Returns:
A tuple containing the rotation matrix and the scaling matrix resulting from the decomposition.
- Return type:
tuple
- drjit.matrix_to_quat(arg: drjit.scalar.Matrix3f, /) drjit.scalar.Quaternion4f¶
- drjit.matrix_to_quat(arg: drjit.scalar.Matrix4f, /) drjit.scalar.Quaternion4f
- drjit.matrix_to_quat(arg: drjit.scalar.Matrix3f64, /) drjit.scalar.Quaternion4f64
- drjit.matrix_to_quat(arg: drjit.scalar.Matrix4f64, /) drjit.scalar.Quaternion4f64
- drjit.matrix_to_quat(arg: drjit.cuda.ad.Matrix3f, /) drjit.cuda.ad.Quaternion4f
- drjit.matrix_to_quat(arg: drjit.cuda.ad.Matrix4f, /) drjit.cuda.ad.Quaternion4f
- drjit.matrix_to_quat(arg: drjit.cuda.ad.Matrix3f64, /) drjit.cuda.ad.Quaternion4f64
- drjit.matrix_to_quat(arg: drjit.cuda.ad.Matrix4f64, /) drjit.cuda.ad.Quaternion4f64
- drjit.matrix_to_quat(arg: drjit.cuda.Matrix3f, /) drjit.cuda.Quaternion4f
- drjit.matrix_to_quat(arg: drjit.cuda.Matrix4f, /) drjit.cuda.Quaternion4f
- drjit.matrix_to_quat(arg: drjit.cuda.Matrix3f64, /) drjit.cuda.Quaternion4f64
- drjit.matrix_to_quat(arg: drjit.cuda.Matrix4f64, /) drjit.cuda.Quaternion4f64
- drjit.matrix_to_quat(arg: drjit.llvm.ad.Matrix3f, /) drjit.llvm.ad.Quaternion4f
- drjit.matrix_to_quat(arg: drjit.llvm.ad.Matrix4f, /) drjit.llvm.ad.Quaternion4f
- drjit.matrix_to_quat(arg: drjit.llvm.ad.Matrix3f64, /) drjit.llvm.ad.Quaternion4f64
- drjit.matrix_to_quat(arg: drjit.llvm.ad.Matrix4f64, /) drjit.llvm.ad.Quaternion4f64
- drjit.matrix_to_quat(arg: drjit.llvm.Matrix3f, /) drjit.llvm.Quaternion4f
- drjit.matrix_to_quat(arg: drjit.llvm.Matrix4f, /) drjit.llvm.Quaternion4f
- drjit.matrix_to_quat(arg: drjit.llvm.Matrix3f64, /) drjit.llvm.Quaternion4f64
- drjit.matrix_to_quat(arg: drjit.llvm.Matrix4f64, /) drjit.llvm.Quaternion4f64
Converts a 3x3 matrix containing a pure rotation into a rotation quaternion.
- Parameters:
arg (drjit.ArrayBase) – A 3x3 Dr.Jit matrix instance
- Returns:
The Dr.Jit quaternion corresponding to the input matrix.
- Return type:
- drjit.quat_to_matrix(dtype: Type[ArrayT], q: dr.ArrayBase) ArrayT¶
Converts a rotation quaternion into a 3x3 rotation matrix.
- Parameters:
arg (drjit.ArrayBase) – A 3x3 Dr.Jit matrix type
- Returns:
The resulting 3x3 matrix
- Return type:
- drjit.quat_to_euler(arg: drjit.scalar.Quaternion4f, /) drjit.scalar.Array3f¶
- drjit.quat_to_euler(arg: drjit.scalar.Quaternion4f64, /) drjit.scalar.Array3f64
- drjit.quat_to_euler(arg: drjit.cuda.ad.Quaternion4f, /) drjit.cuda.ad.Array3f
- drjit.quat_to_euler(arg: drjit.cuda.ad.Quaternion4f64, /) drjit.cuda.ad.Array3f64
- drjit.quat_to_euler(arg: drjit.cuda.Quaternion4f, /) drjit.cuda.Array3f
- drjit.quat_to_euler(arg: drjit.cuda.Quaternion4f64, /) drjit.cuda.Array3f64
- drjit.quat_to_euler(arg: drjit.llvm.ad.Quaternion4f, /) drjit.llvm.ad.Array3f
- drjit.quat_to_euler(arg: drjit.llvm.ad.Quaternion4f64, /) drjit.llvm.ad.Array3f64
- drjit.quat_to_euler(arg: drjit.llvm.Quaternion4f, /) drjit.llvm.Array3f
- drjit.quat_to_euler(arg: drjit.llvm.Quaternion4f64, /) drjit.llvm.Array3f64
Converts a rotation quaternion into its Euler angle representation. The Euler angle order is “XYZ”.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit quaternion type
- Returns:
A 3D Dr.Jit array containing Euler angles.
- Return type:
- drjit.euler_to_quat(arg: drjit.scalar.Array3f, /) drjit.scalar.Quaternion4f¶
- drjit.euler_to_quat(arg: drjit.scalar.Array3f64, /) drjit.scalar.Quaternion4f64
- drjit.euler_to_quat(arg: drjit.cuda.ad.Array3f, /) drjit.cuda.ad.Quaternion4f
- drjit.euler_to_quat(arg: drjit.cuda.ad.Array3f64, /) drjit.cuda.ad.Quaternion4f64
- drjit.euler_to_quat(arg: drjit.cuda.Array3f, /) drjit.cuda.Quaternion4f
- drjit.euler_to_quat(arg: drjit.cuda.Array3f64, /) drjit.cuda.Quaternion4f64
- drjit.euler_to_quat(arg: drjit.llvm.ad.Array3f, /) drjit.llvm.ad.Quaternion4f
- drjit.euler_to_quat(arg: drjit.llvm.ad.Array3f64, /) drjit.llvm.ad.Quaternion4f64
- drjit.euler_to_quat(arg: drjit.llvm.Array3f, /) drjit.llvm.Quaternion4f
- drjit.euler_to_quat(arg: drjit.llvm.Array3f64, /) drjit.llvm.Quaternion4f64
Converts Euler angles into a rotation quaternion. The order of the Euler angles is “XYZ”.
- Parameters:
arg (drjit.ArrayBase) – A 3D Dr.Jit array containing the Euler angles.
- Returns:
A rotation quaternion.
- Return type:
- drjit.transform_decompose(arg, it=10)¶
Performs a polar decomposition of a non-perspective 4x4 homogeneous coordinate matrix and returns a tuple of
A positive definite 3x3 matrix containing an inhomogeneous scaling operation
A rotation quaternion
A 3D translation vector
This representation is helpful when animating keyframe animations.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
it (int) – Number of iterations to be taken by the polar decomposition algorithm.
- Returns:
The tuple containing the scaling matrix, rotation quaternion and 3D translation vector.
- Return type:
tuple
- drjit.transform_compose(S, Q, T, /)¶
This function composes a 4x4 homogeneous coordinate transformation from the given scale, rotation, and translation. It performs the reverse of transform_decompose().
- Parameters:
S (drjit.ArrayBase) – A Dr.Jit matrix type representing the scaling part
Q (drjit.ArrayBase) – A Dr.Jit quaternion type representing the rotation part
T (drjit.ArrayBase) – A 3D Dr.Jit array type representing the translation part
- Returns:
The Dr.Jit matrix resulting from the composition described above.
- Return type:
- drjit.unit_angle(a, b)¶
Numerically well-behaved routine for computing the angle between two normalized 3D direction vectors.
This should be used wherever one is tempted to compute the angle via acos(dot(a, b)). It yields significantly more accurate results when the angle is close to zero.
By Don Hatch.
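The loss of accuracy of the arccosine approach near zero can be reproduced in plain Python. The sketch below uses the chord-based arcsine formulation commonly attributed to Don Hatch. Whether it matches Dr.Jit's code line by line is an assumption, and the name `unit_angle_ref` is hypothetical.

```python
import math


def unit_angle_ref(a, b):
    """Angle between two unit vectors via the arcsine of half the chord length."""
    dot = sum(x * y for x, y in zip(a, b))
    sign = 1.0 if dot >= 0.0 else -1.0
    # Chord between b and a (or -a when the vectors lie in opposite hemispheres)
    chord = math.sqrt(sum((y - sign * x) ** 2 for x, y in zip(a, b)))
    angle = 2.0 * math.asin(0.5 * chord)
    return angle if dot >= 0.0 else math.pi - angle


eps = 1e-9
a = (1.0, 0.0, 0.0)
b = (math.cos(eps), math.sin(eps), 0.0)

dot_ab = sum(x * y for x, y in zip(a, b))
naive = math.acos(min(1.0, dot_ab))  # the cosine rounds to 1.0, the angle is lost
robust = unit_angle_ref(a, b)        # recovers an angle close to eps
```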
- drjit.quat_apply(arg0: drjit.scalar.Quaternion4f, arg1: drjit.scalar.Array3f, /) drjit.scalar.Array3f¶
- drjit.quat_apply(arg0: drjit.scalar.Quaternion4f64, arg1: drjit.scalar.Array3f64, /) drjit.scalar.Array3f64
- drjit.quat_apply(arg0: drjit.cuda.ad.Quaternion4f, arg1: drjit.cuda.ad.Array3f, /) drjit.cuda.ad.Array3f
- drjit.quat_apply(arg0: drjit.cuda.ad.Quaternion4f64, arg1: drjit.cuda.ad.Array3f64, /) drjit.cuda.ad.Array3f64
- drjit.quat_apply(arg0: drjit.cuda.Quaternion4f, arg1: drjit.cuda.Array3f, /) drjit.cuda.Array3f
- drjit.quat_apply(arg0: drjit.cuda.Quaternion4f64, arg1: drjit.cuda.Array3f64, /) drjit.cuda.Array3f64
- drjit.quat_apply(arg0: drjit.llvm.ad.Quaternion4f, arg1: drjit.llvm.ad.Array3f, /) drjit.llvm.ad.Array3f
- drjit.quat_apply(arg0: drjit.llvm.ad.Quaternion4f64, arg1: drjit.llvm.ad.Array3f64, /) drjit.llvm.ad.Array3f64
- drjit.quat_apply(arg0: drjit.llvm.Quaternion4f, arg1: drjit.llvm.Array3f, /) drjit.llvm.Array3f
- drjit.quat_apply(arg0: drjit.llvm.Quaternion4f64, arg1: drjit.llvm.Array3f64, /) drjit.llvm.Array3f64
Applies a rotation quaternion to a 3D vector.
This operation is equivalent to quat_to_matrix(arg0) @ arg1 but more efficient.
- Parameters:
arg0 (drjit.ArrayBase) – A rotation quaternion
arg1 (drjit.ArrayBase) – A 3D Dr.Jit array
- Returns:
The rotated vector
- Return type:
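As an illustration of how a rotation quaternion acts on a vector without first building a matrix, the following plain-Python sketch constructs an axis-angle quaternion and applies it using two cross products. The component layout (x, y, z, w) with the real part last is an assumption, the formula is the standard one and not necessarily what Dr.Jit evaluates, and all helper names are hypothetical.

```python
import math


def rotate_ref(axis, angle):
    """Unit quaternion (x, y, z, w) for a rotation by `angle` radians about a unit axis."""
    s, c = math.sin(0.5 * angle), math.cos(0.5 * angle)
    return (axis[0] * s, axis[1] * s, axis[2] * s, c)


def cross(u, v):
    """Cross product of two 3D vectors."""
    return (u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0])


def quat_apply_ref(q, v):
    """Rotate v by q using v + 2*w*(q_xyz x v) + 2*q_xyz x (q_xyz x v)."""
    qv, w = q[:3], q[3]
    t = cross(qv, v)
    tt = cross(qv, t)
    return tuple(v[i] + 2.0 * (w * t[i] + tt[i]) for i in range(3))


q = rotate_ref((0.0, 0.0, 1.0), math.pi / 2)  # quarter turn about the z axis
rotated = quat_apply_ref(q, (1.0, 0.0, 0.0))  # the x axis maps to the y axis
```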
Operations for complex values and quaternions¶
- drjit.conj(arg, /)¶
Returns the conjugate of the provided complex or quaternion-valued array. For all other types, it returns the input unchanged.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit array
- Returns:
Conjugate form of the input
- Return type:
- drjit.arg(z, /)¶
Return the argument of a complex Dr.Jit array.
The argument refers to the angle (in radians) between the positive real axis and a vector towards z in the complex plane. When the input isn’t complex-valued, the function returns \(0\) or \(\pi\) depending on the sign of z.
- Parameters:
z (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Argument of the complex input array
- Return type:
float | drjit.ArrayBase
- drjit.real(arg, /)¶
Return the real part of a complex or quaternion-valued input.
When the input isn’t complex- or quaternion-valued, the function returns the input unchanged.
- Parameters:
arg (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Real part of the input array
- Return type:
float | drjit.ArrayBase
- drjit.imag(arg, /)¶
Return the imaginary part of a complex or quaternion-valued input.
When the input isn’t complex- or quaternion-valued, the function returns zero.
- Parameters:
arg (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Imaginary part of the input array
- Return type:
float | drjit.ArrayBase
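Python's own numeric types follow the same conventions as the three functions above: a real number is its own conjugate and real part, and its imaginary part is zero. The sketch below uses only built-in attributes and is meant as an analogy rather than as Dr.Jit code.

```python
z = complex(3.0, -4.0)   # complex-valued input
x = 2.5                  # real-valued input

# (real part, imaginary part, conjugate) for each input
complex_parts = (z.real, z.imag, z.conjugate())
real_parts = (x.real, x.imag, x.conjugate())
print(complex_parts)
print(real_parts)
```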
Transcendental functions¶
Dr.Jit implements the most common transcendental functions using methods that are based on the CEPHES math library. The accuracy of these approximations is documented in a set of tables below.
Trigonometric functions¶
- drjit.sin(arg: ArrayT, /) ArrayT¶
- drjit.sin(arg: float, /) float
Overloaded function.
sin(arg: ArrayT, /) -> ArrayT
Evaluate the sine function.
This function evaluates the component-wise sine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Sine of the input
- Return type:
float | drjit.ArrayBase
sin(arg: float, /) -> float
- drjit.cos(arg: ArrayT, /) ArrayT¶
- drjit.cos(arg: float, /) float
Overloaded function.
cos(arg: ArrayT, /) -> ArrayT
Evaluate the cosine function.
This function evaluates the component-wise cosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Cosine of the input
- Return type:
float | drjit.ArrayBase
cos(arg: float, /) -> float
- drjit.sincos(arg: ArrayT, /) tuple[ArrayT, ArrayT]¶
- drjit.sincos(arg: float, /) tuple[float, float]
Overloaded function.
sincos(arg: ArrayT, /) -> tuple[ArrayT, ArrayT]
Evaluate both sine and cosine functions at the same time.
This function simultaneously evaluates the component-wise sine and cosine of the input scalar, array, or tensor. This is more efficient than two separate calls to drjit.sin() and drjit.cos() when both are required. The function uses a suitable generalization of the operation when the input is complex-valued.
The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation instead uses the hardware’s built-in multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Sine and cosine of the input
- Return type:
(float, float) | (drjit.ArrayBase, drjit.ArrayBase)
sincos(arg: float, /) -> tuple[float, float]
- drjit.tan(arg: ArrayT, /) ArrayT¶
- drjit.tan(arg: float, /) float
Overloaded function.
tan(arg: ArrayT, /) -> ArrayT
Evaluate the tangent function.
This function evaluates the component-wise tangent function associated with each entry of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.
The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Tangent of the input
- Return type:
float | drjit.ArrayBase
tan(arg: float, /) -> float
- drjit.asin(arg: ArrayT, /) ArrayT¶
- drjit.asin(arg: float, /) float
Overloaded function.
asin(arg: ArrayT, /) -> ArrayT
Evaluate the arcsine function.
This function evaluates the component-wise arcsine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when called with a complex-valued input.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
Real-valued inputs outside of the domain \([-1, 1]\) produce a NaN output value. Consider using the safe_asin() function to work around issues where the input might occasionally lie outside of this range due to prior round-off errors.
Another noteworthy behavior of the arcsine function is that it has an infinite derivative at \(\texttt{arg}=\pm 1\), which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. The safe_asin() function contains a workaround to ensure a finite derivative in this case.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arcsine of the input
- Return type:
float | drjit.ArrayBase
asin(arg: float, /) -> float
- drjit.acos(arg: ArrayT, /) ArrayT¶
- drjit.acos(arg: float, /) float
Overloaded function.
acos(arg: ArrayT, /) -> ArrayT
Evaluate the arccosine function.
This function evaluates the component-wise arccosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
Real-valued inputs outside of the domain \([-1, 1]\) produce a NaN output value. Consider using the safe_acos() function to work around issues where the input might occasionally lie outside of this range due to prior round-off errors.
Another noteworthy behavior of the arccosine function is that it has an infinite derivative at \(\texttt{arg}=\pm 1\), which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. The safe_acos() function contains a workaround to ensure a finite derivative in this case.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arccosine of the input
- Return type:
float | drjit.ArrayBase
acos(arg: float, /) -> float
- drjit.atan(arg: ArrayT, /) ArrayT¶
- drjit.atan(arg: float, /) float
Overloaded function.
atan(arg: ArrayT, /) -> ArrayT
Evaluate the arctangent function.
This function evaluates the component-wise arctangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arctangent of the input
- Return type:
float | drjit.ArrayBase
atan(arg: float, /) -> float
- drjit.atan2(arg0: object, arg1: object, /) object¶
- drjit.atan2(arg0: float, arg1: float, /) float
Overloaded function.
atan2(arg0: object, arg1: object, /) -> object
Evaluate the four-quadrant arctangent function.
This function is currently only implemented for real-valued inputs.
See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
y (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
x (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arctangent of y/x, using the argument signs to determine the quadrant of the return value
- Return type:
float | drjit.ArrayBase
atan2(arg0: float, arg1: float, /) -> float
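The difference between the four-quadrant arctangent and a plain arctangent of the ratio y/x is easiest to see on a point in the second quadrant. The following sketch uses Python's math module to illustrate the mathematical behavior; it does not call Dr.Jit.

```python
import math

y, x = 1.0, -1.0              # point in the second quadrant

naive = math.atan(y / x)      # the ratio alone loses the quadrant
full = math.atan2(y, x)       # the signs of y and x select the quadrant
print(naive, full)
```

The naive variant reports an angle in the fourth quadrant, while the four-quadrant variant returns the angle of the actual point.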
Hyperbolic functions¶
- drjit.sinh(arg: ArrayT, /) ArrayT¶
- drjit.sinh(arg: float, /) float
Overloaded function.
sinh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic sine function.
This function evaluates the component-wise hyperbolic sine of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic sine of the input
- Return type:
float | drjit.ArrayBase
sinh(arg: float, /) -> float
- drjit.cosh(arg: ArrayT, /) ArrayT¶
- drjit.cosh(arg: float, /) float
Overloaded function.
cosh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic cosine function.
This function evaluates the component-wise hyperbolic cosine of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic cosine of the input
- Return type:
float | drjit.ArrayBase
cosh(arg: float, /) -> float
- drjit.sincosh(arg: ArrayT, /) tuple[ArrayT, ArrayT]¶
- drjit.sincosh(arg: float, /) tuple[float, float]
Overloaded function.
sincosh(arg: ArrayT, /) -> tuple[ArrayT, ArrayT]
Evaluate both hyperbolic sine and cosine functions at the same time.
This function simultaneously evaluates the component-wise hyperbolic sine and cosine of the input scalar, array, or tensor. This is more efficient than two separate calls to drjit.sinh() and drjit.cosh() when both are required. The function uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic sine and cosine of the input
- Return type:
(float, float) | (drjit.ArrayBase, drjit.ArrayBase)
sincosh(arg: float, /) -> tuple[float, float]
- drjit.tanh(arg: drjit.nn.CoopVec, /) drjit.nn.CoopVec¶
- drjit.tanh(arg: ArrayT, /) ArrayT
- drjit.tanh(arg: float, /) float
Overloaded function.
tanh(arg: drjit.nn.CoopVec, /) -> drjit.nn.CoopVec
tanh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic tangent function.
This function evaluates the component-wise hyperbolic tangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic tangent of the input
- Return type:
float | drjit.ArrayBase
tanh(arg: float, /) -> float
- drjit.asinh(arg: ArrayT, /) ArrayT¶
- drjit.asinh(arg: float, /) float
Overloaded function.
asinh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic arcsine function.
This function evaluates the component-wise hyperbolic arcsine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic arcsine of the input
- Return type:
float | drjit.ArrayBase
asinh(arg: float, /) -> float
- drjit.acosh(arg: ArrayT, /) ArrayT¶
- drjit.acosh(arg: float, /) float
Overloaded function.
acosh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic arccosine function.
This function evaluates the component-wise hyperbolic arccosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic arccosine of the input
- Return type:
float | drjit.ArrayBase
acosh(arg: float, /) -> float
- drjit.atanh(arg: ArrayT, /) ArrayT¶
- drjit.atanh(arg: float, /) float
Overloaded function.
atanh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic arctangent function.
This function evaluates the component-wise hyperbolic arctangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic arctangent of the input
- Return type:
float | drjit.ArrayBase
atanh(arg: float, /) -> float
Exponentials, logarithms, power function¶
- drjit.log2(arg: drjit.nn.CoopVec, /) drjit.nn.CoopVec¶
- drjit.log2(arg: ArrayT, /) ArrayT
- drjit.log2(arg: float, /) float
Overloaded function.
log2(arg: drjit.nn.CoopVec, /) -> drjit.nn.CoopVec
log2(arg: ArrayT, /) -> ArrayT
Evaluate the base-2 logarithm.
This function evaluates the component-wise base-2 logarithm of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Base-2 logarithm of the input
- Return type:
float | drjit.ArrayBase
log2(arg: float, /) -> float
- drjit.log(arg: ArrayT, /) ArrayT¶
- drjit.log(arg: float, /) float
Overloaded function.
log(arg: ArrayT, /) -> ArrayT
Evaluate the natural logarithm.
This function evaluates the component-wise natural logarithm of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Natural logarithm of the input
- Return type:
float | drjit.ArrayBase
log(arg: float, /) -> float
- drjit.exp2(arg: drjit.nn.CoopVec, /) drjit.nn.CoopVec¶
- drjit.exp2(arg: ArrayT, /) ArrayT
- drjit.exp2(arg: float, /) float
Overloaded function.
exp2(arg: drjit.nn.CoopVec, /) -> drjit.nn.CoopVec
exp2(arg: ArrayT, /) -> ArrayT
Evaluate 2 raised to a given power.
This function evaluates the component-wise base-2 exponential function of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Base-2 exponential of the input
- Return type:
float | drjit.ArrayBase
exp2(arg: float, /) -> float
- drjit.exp(arg: ArrayT, /) ArrayT¶
- drjit.exp(arg: float, /) float
Overloaded function.
exp(arg: ArrayT, /) -> ArrayT
Evaluate the natural exponential function.
This function evaluates the component-wise natural exponential function of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Natural exponential of the input
- Return type:
float | drjit.ArrayBase
exp(arg: float, /) -> float
- drjit.power(arg0: int, arg1: int, /) float¶
- drjit.power(arg0: float, arg1: float, /) float
- drjit.power(arg0: object, arg1: object, /) object
Overloaded function.
power(arg0: int, arg1: int, /) -> float
Raise the first argument to a power specified via the second argument.
This function evaluates the component-wise power of the input scalar, array, or tensor arguments. When called with complex- or quaternion-valued inputs, it uses a suitable generalization of the operation.
When arg1 is a Python int or integral float value, the function reduces the operation to a sequence of multiplies and adds (potentially followed by a reciprocation operation when arg1 is negative).
The general case involves recursive use of the identity power(arg0, arg1) = exp2(log2(arg0) * arg1).
There is no difference between using drjit.power() and the builtin Python ** operator.
- Parameters:
arg (object) – A Python or Dr.Jit arithmetic type
- Returns:
The result of the operation arg0**arg1
- Return type:
object
power(arg0: float, arg1: float, /) -> float
power(arg0: object, arg1: object, /) -> object
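The two evaluation strategies described for drjit.power() can be sketched in plain Python. The helpers int_power and general_power are hypothetical names introduced here; the repeated-squaring scheme is one common way to reduce an integer power to multiplications and is not claimed to match Dr.Jit's exact instruction sequence.

```python
import math

def int_power(base, n):
    """base**n for an integer n using repeated squaring."""
    result, factor, k = 1.0, base, abs(n)
    while k:
        if k & 1:            # this bit of the exponent contributes a factor
            result *= factor
        factor *= factor     # square for the next bit
        k >>= 1
    # A negative exponent is handled by one final reciprocal
    return 1.0 / result if n < 0 else result

def general_power(base, exponent):
    """base**exponent for base > 0 via the identity exp2(log2(base) * exponent)."""
    t = math.log2(base) * exponent
    return math.exp(t * math.log(2.0))   # exp2(t) written with math.exp

print(int_power(3.0, 5), int_power(2.0, -3), general_power(2.0, 0.5))
```

The integer path avoids the logarithm entirely, so it also works for negative bases, where the general identity would require complex arithmetic.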
Other¶
- drjit.erf(arg: ArrayT, /) ArrayT¶
- drjit.erf(arg: float, /) float
Overloaded function.
erf(arg: ArrayT, /) -> ArrayT
Evaluate the error function.
The error function is defined as
\[\operatorname{erf}(z) = \frac{2}{\sqrt\pi}\int_0^z e^{-t^2}\,\mathrm{d}t.\]
See the section on transcendental function approximations for details regarding accuracy.
This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
\(\mathrm{erf}(\texttt{arg})\)
- Return type:
float | drjit.ArrayBase
erf(arg: float, /) -> float
- drjit.erfinv(arg: ArrayT, /) ArrayT¶
- drjit.erfinv(arg: float, /) float
Overloaded function.
erfinv(arg: ArrayT, /) -> ArrayT
Evaluate the inverse error function.
This function evaluates the inverse of drjit.erf(). Its implementation is based on the paper Approximating the erfinv function by Mike Giles.
This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
\(\mathrm{erf}^{-1}(\texttt{arg})\)
- Return type:
float | drjit.ArrayBase
erfinv(arg: float, /) -> float
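The inverse relationship between the error function and its inverse can be checked with the Python standard library. The sketch below obtains the inverse from the normal quantile function, which is a different technique from the Giles approximation mentioned above; the local helper erfinv is defined only for this illustration.

```python
import math
from statistics import NormalDist

def erfinv(y):
    """Inverse error function for -1 < y < 1, via the standard normal quantile."""
    # Phi(x) = (1 + erf(x / sqrt(2))) / 2, so erf^-1(y) = Phi^-1((1 + y) / 2) / sqrt(2)
    return NormalDist().inv_cdf((y + 1.0) / 2.0) / math.sqrt(2.0)

y = 0.5
round_trip = math.erf(erfinv(y))
print(erfinv(y), round_trip)
```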
- drjit.lgamma(arg: ArrayT, /) ArrayT¶
- drjit.lgamma(arg: float, /) float
Overloaded function.
lgamma(arg: ArrayT, /) -> ArrayT
Evaluate the natural logarithm of the absolute value of the gamma function.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
\(\log|\Gamma(\texttt{arg})|\)
- Return type:
float | drjit.ArrayBase
lgamma(arg: float, /) -> float
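The quantity \(\log|\Gamma(\texttt{arg})|\) can be cross-checked against known gamma function values using Python's math.lgamma, which computes the same mathematical function. This is a standard-library illustration rather than Dr.Jit code.

```python
import math

# Gamma(5) = 4! = 24 and Gamma(-1/2) = -2 * sqrt(pi); the absolute value
# makes the logarithm well defined for the negative case.
values = [5.0, -0.5]
log_abs_gamma = {v: math.lgamma(v) for v in values}
print(log_abs_gamma)
```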
- drjit.rad2deg(arg: T, /) T¶
Convert angles from radians to degrees.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
The equivalent angle in degrees.
- Return type:
float | drjit.ArrayBase
- drjit.deg2rad(arg: T, /) T¶
Convert angles from degrees to radians.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
The equivalent angle in radians.
- Return type:
float | drjit.ArrayBase
- drjit.sphdir(theta, phi)¶
Spherical coordinate parameterization of the unit sphere
- Parameters:
theta (float | drjit.ArrayBase) – Elevation angle in radians, measured from the positive Z axis. Valid range is [0, π].
phi (float | drjit.ArrayBase) – Azimuth angle in radians, measured from the positive X axis in the XY plane. Valid range is [0, 2π].
- Returns:
A 3D unit direction vector corresponding to the input spherical coordinates. The result is a 3-component array with unit length.
- Return type:
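The parameter description above corresponds to the usual spherical parameterization with the elevation measured from the Z axis. A scalar sketch under that assumption, written with Python's math module and returning a plain tuple rather than a Dr.Jit array, might look as follows.

```python
import math

def sphdir(theta, phi):
    """Unit direction for elevation theta (from +Z) and azimuth phi (from +X)."""
    sin_theta = math.sin(theta)
    return (sin_theta * math.cos(phi),
            sin_theta * math.sin(phi),
            math.cos(theta))

north_pole = sphdir(0.0, 1.0)            # theta = 0 points along +Z for any phi
x_axis = sphdir(math.pi / 2, 0.0)        # on the equator, phi = 0 points along +X
d = sphdir(0.7, 2.1)
length_sq = sum(c * c for c in d)        # should equal 1 up to rounding
print(north_pole, x_axis, length_sq)
```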
Safe mathematical functions¶
Dr.Jit provides “safe” variants of a few standard mathematical operations that are prone to out-of-domain errors in calculations with floating point rounding errors. Such errors could, e.g., cause the argument of a square root to become negative, which would ordinarily require complex arithmetic. At zero, the derivative of the square root function is infinite. The following operations clamp the input to a safe range to avoid these extremes.
- drjit.safe_sqrt(arg: T, /) T¶
Safely evaluate the square root of the provided input avoiding domain errors.
Negative inputs produce zero-valued output. When differentiated via AD, this function also avoids generating infinite derivatives at x=0.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Square root of the input
- Return type:
float | drjit.ArrayBase
- drjit.safe_asin(arg: T, /) T¶
Safe wrapper around drjit.asin() that avoids domain errors.
Input values are clipped to the \([-1, 1]\) domain. When differentiated via AD, this function also avoids generating infinite derivatives at the boundaries of the domain.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arcsine approximation
- Return type:
float | drjit.ArrayBase
- drjit.safe_acos(arg: T, /) T¶
Safe wrapper around drjit.acos() that avoids domain errors.
Input values are clipped to the \([-1, 1]\) domain. When differentiated via AD, this function also avoids generating infinite derivatives at the boundaries of the domain.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arccosine approximation
- Return type:
float | drjit.ArrayBase
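The clamping idea behind the three safe variants can be sketched for scalar Python floats. This illustration only covers the primal values; the derivative handling mentioned above is not modeled. Note also that Python's math functions raise ValueError for out-of-domain inputs instead of returning NaN.

```python
import math

def safe_sqrt(x):
    """Square root with negative inputs clamped to zero."""
    return math.sqrt(max(x, 0.0))

def safe_asin(x):
    """Arcsine with the input clamped to [-1, 1]."""
    return math.asin(min(max(x, -1.0), 1.0))

def safe_acos(x):
    """Arccosine with the input clamped to [-1, 1]."""
    return math.acos(min(max(x, -1.0), 1.0))

# A value one rounding step above 1, as might result from prior round-off
slightly_above_one = 1.0 + 2.0 ** -52

try:
    math.asin(slightly_above_one)
    unclamped_failed = False
except ValueError:
    unclamped_failed = True      # the unclamped call is out of domain

print(unclamped_failed, safe_asin(slightly_above_one), safe_sqrt(-1e-12))
```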
Automatic differentiation¶
- enum drjit.ADMode(value)¶
Enumeration to distinguish different types of primal/derivative computation.
See also
drjit.enqueue(), drjit.traverse().
Valid values are as follows:
- Primal = ADMode.Primal¶
Primal/original computation without derivative tracking. Note that this is not a valid input to Dr.Jit AD routines, but it is sometimes useful to have this entry to indicate to a computation that derivative propagation should not be performed.
- Forward = ADMode.Forward¶
Propagate derivatives in forward mode (from inputs to outputs)
- Backward = ADMode.Backward¶
Propagate derivatives in backward/reverse mode (from outputs to inputs)
- enum drjit.ADFlag(value)¶
By default, Dr.Jit’s AD system destructs the enqueued input graph during forward/backward mode traversal. This frees up resources, which is useful when working with large wavefronts or very complex computation graphs. However, this also prevents repeated propagation of gradients through a shared subgraph that is being differentiated multiple times.
To support more fine-grained use cases that require this, the following flags can be used to control what should and should not be destructed.
Members of this enumeration can be combined using the | operator.
- Member Type:
int
Valid values are as follows:
- ClearNone = ADFlag.ClearNone¶
Clear nothing.
- ClearEdges = ADFlag.ClearEdges¶
Delete all traversed edges from the computation graph
- ClearInput = ADFlag.ClearInput¶
Clear the gradients of processed input vertices (in-degree == 0)
- ClearInterior = ADFlag.ClearInterior¶
Clear the gradients of processed interior vertices (out-degree != 0)
- ClearVertices = ADFlag.ClearVertices¶
Clear gradients of processed vertices only, but leave edges intact. Equal to ClearInput | ClearInterior.
- AllowNoGrad = ADFlag.AllowNoGrad¶
Don’t fail when the input to a drjit.forward or backward operation is not a differentiable array.
- Default = ADFlag.Default¶
Default: clear everything (edges, gradients of processed vertices). Equal to ClearEdges | ClearVertices.
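How the combined flags relate to each other can be modeled with a stand-in enumeration built on Python's enum.IntFlag. The class ADFlagSketch and its numeric values are assumptions chosen for this illustration; only the relationships between the members are taken from the descriptions above.

```python
from enum import IntFlag

class ADFlagSketch(IntFlag):
    """Stand-in for drjit.ADFlag; the bit values are hypothetical."""
    ClearNone = 0
    ClearEdges = 1
    ClearInput = 2
    ClearInterior = 4
    ClearVertices = ClearInput | ClearInterior
    AllowNoGrad = 8
    Default = ClearEdges | ClearVertices

# Clear gradients of processed vertices but keep the edges, so that the
# same graph could be traversed again.
keep_edges = ADFlagSketch.ClearInput | ADFlagSketch.ClearInterior
print(keep_edges == ADFlagSketch.ClearVertices)
print(ADFlagSketch.ClearEdges in keep_edges)
```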
- drjit.detach(arg: T, preserve_type: bool = True) T¶
Transforms the input variable into its non-differentiable version (detaches it from the AD computational graph).
This function supports arbitrary Dr.Jit arrays/tensors and PyTrees as input. In the latter case, it applies the transformation recursively. When the input variable is not a PyTree or Dr.Jit array, it is returned as it is.
While the type of the returned array is preserved by default, it is possible to set the preserve_type argument to false to force the returned type to be non-differentiable. For example, this will convert an array of type drjit.llvm.ad.Float into one of type drjit.llvm.Float.
- Parameters:
arg (object) – An arbitrary Dr.Jit array, tensor, or PyTree.
preserve_type (bool) – Defines whether the returned variable should preserve the type of the input variable.
- Returns:
The detached variable.
- Return type:
object
- drjit.enable_grad(arg: object, /) None¶
- drjit.enable_grad(*args) None
Overloaded function.
enable_grad(arg: object, /) -> None
Enable gradient tracking for the provided variables.
This function accepts a variable number of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During this recursive traversal, the function enables gradient tracking for all encountered Dr.Jit arrays. Variables of other types are ignored.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit arrays/tensors or PyTrees.
enable_grad(*args) -> None
- drjit.disable_grad(arg: object, /) None¶
- drjit.disable_grad(*args) None
Overloaded function.
disable_grad(arg: object, /) -> None
Disable gradient tracking for the provided variables.
This function accepts a variable number of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During this recursive traversal, the function disables gradient tracking for all encountered Dr.Jit arrays. Variables of other types are ignored.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit arrays/tensors or PyTrees.
disable_grad(*args) -> None
- drjit.set_grad_enabled(arg0: object, arg1: bool, /) None¶
Enable or disable gradient tracking on the provided variables.
- Parameters:
arg (object) – An arbitrary Dr.Jit array, tensor, PyTree, sequence, or mapping.
value (bool) – Defines whether gradient tracking should be enabled or disabled.
- drjit.grad_enabled(arg: object, /) bool¶
- drjit.grad_enabled(*args) bool
Overloaded function.
grad_enabled(arg: object, /) -> bool
Return whether gradient tracking is enabled on any of the given variables.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit array/tensor instances or PyTrees. The function recursively traverses them to find all differentiable variables.
- Returns:
True if any of the input variables has gradient tracking enabled, False otherwise.
- Return type:
bool
grad_enabled(*args) -> bool
- drjit.grad(arg: T, preserve_type: bool = True) T¶
Return the gradient value associated to a given variable.
When the variable doesn’t have gradient tracking enabled, this function returns 0.
- Parameters:
arg (object) – An arbitrary Dr.Jit array, tensor or PyTree.
preserve_type (bool) – Should the operation preserve the input type in the return value? (This is the default). Otherwise, Dr.Jit will, e.g., return a type of drjit.cuda.Float for an input of type drjit.cuda.ad.Float.
- Returns:
the gradient value associated to the input variable.
- Return type:
object
- drjit.set_grad(target: T, source: T) None¶
Set the gradient associated with the provided variable.
This operation internally decomposes into two sub-steps:
    dr.clear_grad(target)
    dr.accum_grad(target, source)
When source is not of the same type as target, Dr.Jit will try to broadcast its contents into the right shape.
- drjit.accum_grad(target: T, source: T) None¶
Accumulate the contents of one variable into the gradient of another variable.
When source is not of the same type as target, Dr.Jit will try to broadcast its contents into the right shape.
- drjit.replace_grad(arg0: T, arg1: T, /) None¶
Replace the gradient value of arg0 with the one of arg1.
This is a relatively specialized operation to be used with care when implementing advanced automatic differentiation-related features.
One example use would be to inform Dr.Jit that there is a better way to compute the gradient of a particular expression than what the normal AD traversal of the primal computation graph would yield.
The function promotes and broadcasts arg0 and arg1 if they are not of the same type.
- drjit.clear_grad(arg: object, /) None¶
Clear the gradient of the given variable.
- Parameters:
arg (object) – An arbitrary Dr.Jit array, tensor, or PyTree.
- drjit.traverse(mode: drjit.ADMode, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None¶
Propagate gradients along the enqueued set of AD graph edges.
Given prior use of drjit.enqueue() to enqueue AD nodes for gradient propagation, this function performs the actual gradient propagation in either the forward or reverse direction (as specified by the mode parameter).
By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. The term leaf node is defined as follows:
In forward AD, leaf nodes have no forward edges. They are outputs of a computation, and no other differentiable variable depends on them.
In backward AD, leaf nodes have no backward edges. They are inputs to a computation.
By default, the traversal also removes the edges of visited nodes to isolate them. These defaults are usually good ones: cleaning up the graph frees up resources, which is useful when working with large wavefronts or very complex computation graphs. It also avoids potentially undesired derivative contributions that can arise when the AD graphs of two unrelated computations are connected by an edge and subsequently separately differentiated.
In advanced applications that require multiple AD traversals of the same graph, specify different combinations of the enumeration drjit.ADFlag via the flags parameter.
- Parameters:
mode (drjit.ADMode) – Specifies the direction in which gradients should be propagated. drjit.ADMode.Forward and drjit.ADMode.Backward refer to forward and backward traversal.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph are cleared during traversal. The default value is drjit.ADFlag.Default.
- drjit.enqueue(mode: drjit.ADMode, arg: object) None¶
- drjit.enqueue(mode: drjit.ADMode, *args) None
Overloaded function.
enqueue(mode: drjit.ADMode, arg: object) -> None
Enqueues the input variable(s) for subsequent gradient propagation
Dr.Jit splits the process of automatic differentiation into three parts:
Initializing the gradient of one or more input or output variables. The most common initialization entails setting the gradient of an output (e.g., an optimization loss) to 1.0.
Enqueuing nodes that should partake in the gradient propagation pass. Dr.Jit will follow variable dependences (edges in the AD graph) to find variables that are reachable from the enqueued variable.
Finally propagating gradients to all of the enqueued variables.
This function is responsible for step 2 of the above list and works differently depending on the specified mode:
- drjit.ADMode.Forward: Dr.Jit will recursively enqueue all variables that are reachable along forward edges. That is, given a differentiable operation a = b+c, enqueuing c will also enqueue a for later traversal.
- drjit.ADMode.Backward: Dr.Jit will recursively enqueue all variables that are reachable along backward edges. That is, given a differentiable operation a = b+c, enqueuing a will also enqueue b and c for later traversal.
For example, a typical chain of operations to forward propagate the gradients from
atobmight look as follow:a = dr.llvm.ad.Float(1.0) dr.enable_grad(a) b = f(a) # some computation involving 'a' # The below three operations can also be written more compactly as dr.forward(a) dr.set_gradient(a, 1.0) dr.enqueue(dr.ADMode.Forward, a) dr.traverse(dr.ADMode.Forward) grad = dr.grad(b)
One interesting aspect of this design is that enqueuing and traversal don’t necessarily need to follow the same direction.
For example, we may only be interested in forward gradients reaching a specific output node c, which can be expressed as follows:
a = dr.llvm.ad.Float(1.0)
dr.enable_grad(a)
b, c, d, e = f(a)
dr.set_grad(a, 1.0)
dr.enqueue(dr.ADMode.Backward, c)
dr.traverse(dr.ADMode.Forward)
grad = dr.grad(c)
The same naturally also works in the reverse direction. Dr.Jit provides a higher-level API that encapsulates such logic in a few different flavors:
- drjit.forward_from(), drjit.forward_to(), and drjit.forward().
- drjit.backward_from(), drjit.backward_to(), and drjit.backward().
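The three-step process can be illustrated with a toy model in plain Python. This is a sketch for intuition only and not Dr.Jit's implementation: the Node class, the connect/enqueue/traverse helpers, and the string-valued modes are hypothetical, and the computation is assumed to be linear (t = 2a, b = 5t, c = 3t) so that every edge carries a constant local derivative.

```python
import itertools

_ids = itertools.count()

class Node:
    """Hypothetical AD variable: a gradient slot plus forward/backward edges."""
    def __init__(self, name):
        self.name, self.index, self.grad = name, next(_ids), 0.0
        self.fwd, self.bwd = [], []  # lists of (neighbor, local derivative)

def connect(src, dst, weight):
    """Record the local derivative d(dst)/d(src) = weight."""
    src.fwd.append((dst, weight))
    dst.bwd.append((src, weight))

def enqueue(mode, node, queue):
    """Step 2: collect all nodes reachable along forward or backward edges."""
    if node in queue:
        return
    queue.append(node)
    for other, _ in (node.fwd if mode == "forward" else node.bwd):
        enqueue(mode, other, queue)

def traverse(mode, queue):
    """Step 3: propagate gradients, but only between enqueued nodes."""
    forward = mode == "forward"
    # Creation order is a valid topological order in this toy model.
    for node in sorted(queue, key=lambda n: n.index, reverse=not forward):
        for other, weight in (node.fwd if forward else node.bwd):
            if other in queue:
                other.grad += weight * node.grad
    queue.clear()

def build():
    """Toy graph: t = 2*a, b = 5*t, c = 3*t."""
    a, t, b, c = (Node(n) for n in "atbc")
    connect(a, t, 2.0)
    connect(t, b, 5.0)
    connect(t, c, 3.0)
    return a, b, c

queue = []

# Enqueue and traverse in the same direction: every output receives a gradient.
a, b, c = build()
a.grad = 1.0                   # step 1: initialize the input gradient
enqueue("forward", a, queue)   # step 2
traverse("forward", queue)     # step 3
full = (b.grad, c.grad)

# Enqueue backward from 'c', then traverse forward: only 'c' is reached.
a, b, c = build()
a.grad = 1.0
enqueue("backward", c, queue)
traverse("forward", queue)
selective = (b.grad, c.grad)
```

In the second run, b is never enqueued, so its gradient stays at zero even though it depends on a. This mirrors the mixed-direction pattern described above.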
- Parameters:
mode (drjit.ADMode) – Specifies the set of edges which Dr.Jit should follow to enqueue variables to be visited by a later gradient propagation phase. drjit.ADMode.Forward and drjit.ADMode.Backward refer to forward and backward edges, respectively.
arg (object) – An arbitrary Dr.Jit array, tensor or PyTree.
enqueue(mode: drjit.ADMode, *args) -> None
- drjit.forward_from(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None¶
Forward-propagate gradients from the provided Dr.Jit array or tensor.
This function checks if a derivative has already been assigned to the provided Dr.Jit array or tensor arg. If not, it assigns the value 1.
Following this, it forward-propagates derivatives through forward-connected components of the computation graph (i.e., reaching all variables that directly or indirectly depend on arg).
The operation is equivalent to
if arg.grad has not been set yet:
    dr.set_grad(arg, 1)
dr.enqueue(dr.ADMode.Forward, arg)
dr.traverse(dr.ADMode.Forward, flags=flags)
Refer to the documentation of the functions drjit.set_grad(), drjit.enqueue(), and drjit.traverse() for further details on the nuances of forward derivative propagation.
By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.enqueue() and the meaning of the flags parameter.
When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. To permit such inputs, specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad to the function).
- Parameters:
arg (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is drjit.ADFlag.Default.
- drjit.forward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) ArrayT¶
- drjit.forward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) tuple[*Ts]
Overloaded function.
forward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) -> ArrayT
Forward-propagate gradients to the provided set of Dr.Jit arrays/tensors.
The operation is equivalent to
dr.enqueue(dr.ADMode.Backward, *args)
dr.traverse(dr.ADMode.Forward, flags=flags)
return dr.grad(*args)
Internally, the operation first traverses the computation graph backwards from args to find potential paths along which gradients can flow to the given set of arrays. Then, it performs a gradient propagation pass along the detected variables.
For this to work, you must have previously enabled gradient tracking for the inputs of the computation and specified their gradients (see drjit.enable_grad() and drjit.set_grad()).
Refer to the documentation of the functions drjit.enqueue() and drjit.traverse() for further details on the nuances of forward derivative propagation.
By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.enqueue() and the meaning of the flags parameter.
When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. To permit such inputs, specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad to the function).
- Parameters:
*args (tuple) – A variable-length list of differentiable Dr.Jit arrays, tensors, or PyTrees.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is drjit.ADFlag.Default.
- Returns:
The gradient value(s) associated with *args following the traversal.
- Return type:
object
forward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) -> tuple[*Ts]
- drjit.forward(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None¶
Forward-propagate gradients from the provided Dr.Jit array or tensor.
This operation is equivalent to
dr.set_grad(arg, 1)
dr.forward_from(arg)
In other words, it assigns an initial gradient of 1 to arg and then forward-propagates it through the rest of the computation.
Please refer to the function drjit.forward_from() for further detail.
- Parameters:
arg (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is drjit.ADFlag.Default.
- drjit.backward_from(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None¶
Backpropagate gradients from the provided Dr.Jit array or tensor.
This function checks if a derivative has already been assigned to the provided Dr.Jit array or tensor arg. If not, it assigns the value 1.
Following this, it backpropagates derivatives through backward-connected components of the computation graph (i.e., reaching differentiable variables that potentially influence the value of arg).
The operation is conceptually equivalent to
if arg.grad has not been set yet:
    dr.set_grad(arg, 1)
dr.enqueue(dr.ADMode.Backward, arg)
dr.traverse(dr.ADMode.Backward, flags=flags)
Refer to the documentation of the functions drjit.set_grad(), drjit.enqueue(), and drjit.traverse() for further details on the nuances of derivative backpropagation.
By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.enqueue() and the meaning of the flags parameter.
When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. To permit such inputs, specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad to the function).
- Parameters:
arg (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is drjit.ADFlag.Default.
- drjit.backward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) ArrayT¶
- drjit.backward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) tuple[*Ts]
Overloaded function.
backward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) -> ArrayT
Backpropagate gradients to the provided set of Dr.Jit arrays/tensors.
The operation is equivalent to
dr.enqueue(dr.ADMode.Forward, *args)
dr.traverse(dr.ADMode.Backward, flags=flags)
return dr.grad(*args)
Internally, the operation first traverses the computation graph forwards from args to find potential paths along which reverse-mode gradients can flow to the given set of input variables. Then, it performs a backpropagation pass along the detected variables.
For this to work, you must have previously enabled gradient tracking for the inputs of the computation and specified gradients for its outputs (see drjit.enable_grad() and drjit.set_grad()).
Refer to the documentation of the functions drjit.enqueue() and drjit.traverse() for further details on the nuances of derivative backpropagation.
By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.enqueue() and the meaning of the flags parameter.
When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. To permit such inputs, specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad to the function).
- Parameters:
*args (tuple) – A variable-length list of differentiable Dr.Jit arrays, tensors, or PyTrees.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is drjit.ADFlag.Default.
- Returns:
The gradient value(s) associated with *args following the traversal.
- Return type:
object
backward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) -> tuple[*Ts]
- drjit.backward(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None¶
Backpropagate gradients from the provided Dr.Jit array or tensor.
This operation is equivalent to
dr.set_grad(arg, 1)
dr.backward_from(arg)
In other words, it assigns an initial gradient of 1 to arg and then backpropagates it through the rest of the computation.
Please refer to the function drjit.backward_from() for further detail.
- Parameters:
arg (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is drjit.ADFlag.Default.
- drjit.suspend_grad(*args, when=True)¶
Python context manager to temporarily disable gradient tracking globally, or for a specific set of variables.
This context manager can be used as follows to completely disable all gradient tracking. Newly created variables will be detached from Dr.Jit’s AD graph.
with dr.suspend_grad():
    # .. code goes here ..
You may also specify any number of Dr.Jit arrays, tensors, or PyTrees. In this case, the context manager behaves differently by disabling gradient tracking more selectively for the specified variables.
with dr.suspend_grad(x):
    z = x + y # 'z' will not track any gradients arising from 'x'
The suspend_grad() and resume_grad() context managers can be arbitrarily nested and suitably update the set of tracked variables.
A note about the interaction with drjit.enable_grad(): it is legal to register further AD variables within a scope that disables gradient tracking for specific variables.
with dr.suspend_grad(x):
    y = Float(1)
    dr.enable_grad(y)
    # The following condition holds
    assert not dr.grad_enabled(x) and dr.grad_enabled(y)
In contrast, a suspend_grad() environment without arguments that completely disables AD does not allow further variables to be registered:
with dr.suspend_grad():
    y = Float(1)
    dr.enable_grad(y) # ignored
    # The following condition holds
    assert not dr.grad_enabled(x) and not dr.grad_enabled(y)
- Parameters:
*args (tuple) – Arbitrary list of Dr.Jit arrays, tuples, or PyTrees. Elements of data structures that could not possibly be attached to the AD graph (e.g., Python scalars) are ignored.
when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via when=False. The default value is when=True.
- drjit.resume_grad(*args, when=True)¶
Python context manager to temporarily resume gradient tracking globally, or for a specific set of variables.
This context manager can be used as follows to fully re-enable all gradient tracking following a previous call to drjit.suspend_grad(). Newly created variables will then again be attached to Dr.Jit’s AD graph.
with dr.suspend_grad():
    # ..
    with dr.resume_grad():
        # In this scope, the effect of the outer context
        # manager is effectively disabled
You may also specify any number of Dr.Jit arrays, tensors, or PyTrees. In this case, the context manager behaves differently by enabling gradient tracking more selectively for the specified variables.
with dr.suspend_grad():
    with dr.resume_grad(x):
        z = x + y # 'z' will only track gradients arising from 'x'
The suspend_grad() and resume_grad() context managers can be arbitrarily nested and suitably update the set of tracked variables.
- Parameters:
*args (tuple) – Arbitrary list of Dr.Jit arrays, tuples, or PyTrees. Elements of data structures that could not possibly be attached to the AD graph (e.g., Python scalars) are ignored.
when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via when=False. The default value is when=True.
- drjit.isolate_grad(when=True)¶
Python context manager to isolate and partition AD traversals into multiple distinct phases.
Consider a sequence of steps being differentiated in reverse mode, like so:
x = ..
dr.enable_grad(x)
y = f(x)
z = g(y)
dr.backward(z)
The drjit.backward() call would automatically traverse the AD graph nodes created during the execution of the functions f() and g().
However, sometimes this is undesirable and more control is needed. For example, Dr.Jit may be in an execution context (a symbolic loop or call) that temporarily disallows differentiation of the f() part. The drjit.isolate_grad() context manager addresses this need:
dr.enable_grad(x)
y = f(x)
with dr.isolate_grad():
    z = g(y)
    dr.backward(z)
Any reverse-mode AD traversal of an edge that crosses the isolation boundary is postponed until leaving the scope. This is mathematically equivalent but produces two smaller separate AD graph traversals.
Dr.Jit operations like symbolic loops and calls internally create such an isolation boundary, hence it is rare that you would need to do so yourself. isolate_grad() is not useful for forward mode AD.
- Parameters:
when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via when=False. The default value is when=True.
- class drjit.CustomOp(*args, **kwargs)¶
Base class for implementing custom differentiable operations.
Dr.Jit can compute derivatives of builtin operations in both forward and reverse mode. In some cases, it may be useful or even necessary to control how a particular operation should be differentiated.
To do so, you may extend this class to provide three callback functions:
- CustomOp.eval(): Implements the primal evaluation of the function with detached inputs.
- CustomOp.forward(): Implements the forward derivative that propagates derivatives from input arguments to the return value.
- CustomOp.backward(): Implements the backward derivative that propagates derivatives from the return value to the input arguments.
An example for a hypothetical custom addition operation is shown below:
class Addition(dr.CustomOp):
    def eval(self, x, y):
        # Primal calculation without derivative tracking
        return x + y
    def forward(self):
        # Compute forward derivatives
        self.set_grad_out(self.grad_in('x') + self.grad_in('y'))
    def backward(self):
        # .. compute backward derivatives ..
        self.set_grad_in('x', self.grad_out())
        self.set_grad_in('y', self.grad_out())
    def name(self):
        # Optional: a descriptive name shown in GraphViz visualizations
        return "Addition"
You should never need to call these functions yourself; Dr.Jit will do so when appropriate. To weave such a custom operation into the AD graph, use the drjit.custom() function, which expects a subclass of drjit.CustomOp as first argument, followed by arguments to the actual operation that are directly forwarded to the .eval() callback.
# Add two numbers 'x' and 'y'. Calls our '.eval()' callback with detached arguments
result = dr.custom(Addition, x, y)
Forward or backward derivatives are then automatically handled through the standard operations. For example,
dr.backward(result)
will invoke the .backward() callback from above.
Many derivatives are more complex than the above examples and require access to inputs or intermediate steps of the primal evaluation routine. You can simply stash them in the instance (self.field = ...), which is shown below for a differentiable multiplication operation that implements the product rule:
class Multiplication(dr.CustomOp):
    def eval(self, x, y):
        # Stash input arguments
        self.x = x
        self.y = y
        return x * y
    def forward(self):
        self.set_grad_out(self.y * self.grad_in('x') + self.x * self.grad_in('y'))
    def backward(self):
        self.set_grad_in('x', self.y * self.grad_out())
        self.set_grad_in('y', self.x * self.grad_out())
    def name(self):
        return "Multiplication"
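When writing a custom operation, it can help to validate the derivative formulas outside of Dr.Jit first. The following sketch restates the product-rule expressions from the Multiplication example as plain Python functions (eval_mul, forward_mul, and backward_mul are hypothetical helper names) and checks them in two ways: the forward rule against a finite difference, and the forward and backward rules against each other.

```python
def eval_mul(x, y):
    # Primal computation
    return x * y

def forward_mul(x, y, grad_x, grad_y):
    # Same expression as in Multiplication.forward()
    return y * grad_x + x * grad_y

def backward_mul(x, y, grad_out):
    # Same expressions as in Multiplication.backward()
    return y * grad_out, x * grad_out

x, y = 3.0, 4.0
vx, vy = 0.5, -2.0   # arbitrary input tangent
w = 1.5              # arbitrary output gradient

# Forward rule vs. a finite-difference directional derivative
h = 1e-6
fd = (eval_mul(x + h * vx, y + h * vy) - eval_mul(x, y)) / h
jvp = forward_mul(x, y, vx, vy)
forward_ok = abs(fd - jvp) < 1e-4

# The two rules must be adjoint to each other: w * (J v) == (J^T w) . v
gx, gy = backward_mul(x, y, w)
adjoint_ok = abs(w * jvp - (gx * vx + gy * vy)) < 1e-12
```

The adjoint check is independent of the step size h and tends to catch sign or argument-order mistakes in backward() that a visual inspection can miss.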
- eval(self, *args, **kwargs) object¶
Evaluate the custom operation in primal mode.
You must implement this method when subclassing CustomOp, since the default implementation raises an exception. It should realize the original (non-derivative-aware) form of a computation and may take an arbitrary sequence of positional, keyword, and variable-length positional/keyword arguments.
You should not need to call this function yourself; Dr.Jit will automatically do so when performing custom operations through the drjit.custom() interface.
Note that the input arguments passed to .eval() will be detached (i.e. they don’t have derivative tracking enabled). This is intentional, since derivative tracking is handled by the custom operation along with the other callbacks forward() and backward().
- forward(self) None¶
Evaluate the forward derivative of the custom operation.
You must implement this method when subclassing CustomOp, since the default implementation raises an exception. It takes no arguments and has no return value.
An implementation will generally perform repeated calls to grad_in() to query the gradients of all function inputs, followed by a single call to set_grad_out() to set the gradient of the return value.
For example, this is how one would implement the product rule of the primal calculation x*y, assuming that the .eval() routine stashed the inputs in the custom operation object.
def forward(self):
    self.set_grad_out(self.y * self.grad_in('x') + self.x * self.grad_in('y'))
- backward(self) None¶
Evaluate the backward derivative of the custom operation.
You must implement this method when subclassing CustomOp, since the default implementation raises an exception. It takes no arguments and has no return value.
An implementation will generally perform a single call to grad_out() to query the gradient of the function return value, followed by a sequence of calls to set_grad_in() to assign the gradients of the function inputs.
For example, this is how one would implement the product rule of the primal calculation x*y, assuming that the .eval() routine stashed the inputs in the custom operation object.
def backward(self):
    self.set_grad_in('x', self.y * self.grad_out())
    self.set_grad_in('y', self.x * self.grad_out())
- name(self) str¶
Return a descriptive name of the CustomOp instance.
Amongst other things, this name is used to document the presence of the custom operation in GraphViz debug output. (See graphviz_ad().)
- grad_out(self) object¶
Query the gradient of the return value.
Returns an object whose type matches the original return value produced in eval(). This function should only be used within the backward() callback.
- set_grad_out(self, arg: object, /) None¶
Accumulate a gradient into the return value.
This function should only be used within the forward() callback.
- grad_in(self, arg: object, /) object¶
Query the gradient of a specified input parameter.
The second argument specifies the parameter name as a string. Gradients of variable-length positional arguments (*args) can be queried by providing an integer index instead.
This function should only be used within the forward() callback.
- set_grad_in(self, arg0: object, arg1: object, /) None¶
Accumulate a gradient into the specified input parameter.
The second argument specifies the parameter name as a string. Gradients of variable-length positional arguments (*args) can be assigned by providing an integer index instead.
This function should only be used within the backward() callback.
- add_input(self, arg: object, /) None¶
Register an implicit input dependency of the operation on an AD variable.
This function should be called by the eval() implementation when an operation has a differentiable dependence on an input that is not an ordinary input argument of the function (e.g., a global program variable or a field of a class).
- add_output(self, arg: object, /) None¶
Register an implicit output dependency of the operation on an AD variable.
This function should be called by the eval() implementation when an operation has a differentiable dependence on an output that is not part of the function return value (e.g., a global program variable or a field of a class).
- drjit.custom(arg0: type[drjit.CustomOp], /, *args, **kwargs) object¶
Evaluate a custom differentiable operation.
It can be useful or even necessary to control how a particular operation should be differentiated by Dr.Jit’s automatic differentiation (AD) layer. The drjit.custom() function enables such use cases by stitching an opaque operation with user-defined primal and forward/backward derivative implementations into the AD graph.
The function expects a subclass of the CustomOp interface as first argument. The remaining positional and keyword arguments are forwarded to the CustomOp.eval() callback.
See the documentation of CustomOp for examples on how to realize such a custom operation.
- drjit.wrap(source: str | ModuleType, target: str | ModuleType) Callable[[T], T]¶
Differentiable bridge between Dr.Jit and other array programming frameworks.
This function wraps computation performed using one array programming framework to expose it in another. Currently, PyTorch, TensorFlow, and JAX are supported, though other frameworks may be added in the future.
Annotating a function with @drjit.wrap adds code that suitably converts arguments and return values. Furthermore, it stitches the operation into the automatic differentiation (AD) graph of the other framework to ensure correct gradient propagation.
When exposing code written using another framework, the wrapped function can take and return any PyTree including flat or nested Dr.Jit arrays, tensors, and arbitrary nested lists/tuples, dictionaries, and custom data structures. The arguments don’t need to be differentiable; for example, integer/boolean arrays that don’t carry derivative information can be passed as well.
The wrapped function should be pure: in other words, it should read its input(s) and compute an associated output so that re-evaluating the function produces the same answer. Multi-framework derivative tracking of impure computation will likely not behave as expected.
The following table lists the currently supported conversions:
Direction | Primal | Forward-mode AD | Reverse-mode AD | Remarks
drjit → torch | ✅ | ✅ | ✅ | Everything just works.
torch → drjit | ✅ | ✅ | ✅ | Limitation: The passed/returned PyTrees can contain arbitrary arrays or tensors, but other types (e.g., a custom Python object not understood by PyTorch) will raise errors when differentiating in forward mode (backward mode works fine). An issue was filed on the PyTorch bugtracker.
drjit → tf | ✅ | ✅ | ✅ | You may want to further annotate the wrapped function with tf.function to trace and just-in-time compile it in the TensorFlow environment, i.e., @dr.wrap(source='drjit', target='tf') followed by @tf.function(jit_compile=False) (set to True for XLA mode). Limitation: There is an issue for tf.int32 tensors, which are wrongly placed on the CPU by DLPack. This can lead to inconsistent device placement of tensors. An issue was filed on the TensorFlow bugtracker.
tf → drjit | ✅ | ❌ | ✅ | TensorFlow has some limitations with respect to custom gradients in forward-mode AD. Limitation: TensorFlow does not allow for non-tensor input structures in functions with custom gradients. TensorFlow has a bug for functions with custom gradients and keyword arguments. An issue was filed on the TensorFlow bugtracker.
drjit → jax | ✅ | ✅ | ✅ | You may want to further annotate the wrapped function with jax.jit to trace and just-in-time compile it in the JAX environment, i.e., @dr.wrap(source='drjit', target='jax') followed by @jax.jit. Limitation: The passed/returned PyTrees can contain arbitrary arrays or Python scalar types, but other types (e.g., a custom Python object not understood by JAX) will raise errors.
jax → drjit | ❌ | ❌ | ❌ | This direction is currently unsupported. We plan to add it in the future.
Please also refer to the documentation sections on multi-framework differentiation and associated caveats.
Note
Types that have no equivalent on the other side (e.g. a quaternion array) will convert to generic tensors.
Data exchange is limited to representations that exist on both sides. There are a few limitations:
- PyTorch lacks support for most unsigned integer types (uint16, uint32, or uint64-typed arrays). Use signed integer types to work around this issue.
- TensorFlow has limitations with respect to forward-mode AD for functions with custom gradients. There is also an issue for functions with keyword arguments.
- Dr.Jit currently lacks support for most 8- and 16-bit numeric types (besides half precision floats).
- JAX refuses to exchange boolean-valued tensors with other frameworks.
- Parameters:
source (str | module) – The framework used outside of the wrapped function. The argument is currently limited to either 'drjit', 'torch', 'tf', or 'jax'. For convenience, the associated Python module can be specified as well.
target (str | module) – The framework used inside of the wrapped function. The argument is currently limited to either 'drjit', 'torch', 'tf', or 'jax'. For convenience, the associated Python module can be specified as well.
- Returns:
The decorated function.
Constants¶
- drjit.e¶
The exponential constant \(e\) represented as a Python float.
- drjit.log_two¶
The value \(\log(2)\) represented as a Python float.
- drjit.inv_log_two¶
The value \(\frac{1}{\log(2)}\) represented as a Python float.
- drjit.pi¶
The value \(\pi\) represented as a Python float.
- drjit.inv_pi¶
The value \(\frac{1}{\pi}\) represented as a Python float.
- drjit.sqrt_pi¶
The value \(\sqrt{\pi}\) represented as a Python float.
- drjit.inv_sqrt_pi¶
The value \(\frac{1}{\sqrt{\pi}}\) represented as a Python float.
- drjit.two_pi¶
The value \(2\pi\) represented as a Python float.
- drjit.inv_two_pi¶
The value \(\frac{1}{2\pi}\) represented as a Python float.
- drjit.sqrt_two_pi¶
The value \(\sqrt{2\pi}\) represented as a Python float.
- drjit.inv_sqrt_two_pi¶
The value \(\frac{1}{\sqrt{2\pi}}\) represented as a Python float.
- drjit.four_pi¶
The value \(4\pi\) represented as a Python float.
- drjit.inv_four_pi¶
The value \(\frac{1}{4\pi}\) represented as a Python float.
- drjit.sqrt_four_pi¶
The value \(\sqrt{4\pi}\) represented as a Python float.
- drjit.sqrt_two¶
The value \(\sqrt{2}\) represented as a Python float.
- drjit.inv_sqrt_two¶
The value \(\frac{1}{\sqrt{2}}\) represented as a Python float.
- drjit.inf¶
The value float('inf') represented as a Python float.
- drjit.nan¶
The value float('nan') represented as a Python float.
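The relationships between several of these constants can be checked with Python's standard math module. The sketch below is illustrative only: the variables are local names that mirror the Dr.Jit attributes and are computed here rather than read from drjit, so they may differ from the stored values in the last bit.

```python
import math

# Local stand-ins named after the corresponding drjit constants
sqrt_two = math.sqrt(2.0)
inv_sqrt_two = 1.0 / sqrt_two
two_pi = 2.0 * math.pi
sqrt_two_pi = math.sqrt(two_pi)
inv_sqrt_two_pi = 1.0 / sqrt_two_pi
log_two = math.log(2.0)
inv_log_two = 1.0 / log_two

# Identities that tie the constants together (up to floating point rounding)
checks = [
    math.isclose(sqrt_two * sqrt_two, 2.0),
    math.isclose(sqrt_two * inv_sqrt_two, 1.0),
    math.isclose(sqrt_two_pi, sqrt_two * math.sqrt(math.pi)),
    math.isclose(inv_sqrt_two_pi * sqrt_two_pi, 1.0),
    math.isclose(math.exp(log_two), 2.0),
    math.isclose(inv_log_two, math.log2(math.e)),
]
```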
- drjit.epsilon(arg, /)¶
Returns the machine epsilon.
The machine epsilon gives an upper bound on the relative approximation error due to rounding in floating point arithmetic.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The machine epsilon.
- Return type:
float
- drjit.one_minus_epsilon(arg, /)¶
Returns one minus the machine epsilon value.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
One minus the machine epsilon.
- Return type:
float
- drjit.recip_overflow(arg, /)¶
Returns the reciprocal overflow threshold value.
Any numbers equal to this threshold or a smaller value will overflow to infinity when reciprocated.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The reciprocal overflow threshold value.
- Return type:
float
- drjit.smallest(arg, /)¶
Returns the smallest representable normalized floating point value.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The smallest representable normalized floating point value.
- Return type:
float
- drjit.largest(arg, /)¶
Returns the largest representable finite floating point value.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The largest representable finite floating point value.
- Return type:
float
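For double precision, the quantities described above can be explored with Python's sys.float_info. This is a sketch under stated assumptions: it only covers Python's own float type (IEEE 754 binary64), and this section does not say whether drjit.epsilon() returns the spacing of floats near 1.0 or half of that value (the unit roundoff), so both are shown.

```python
import sys

info = sys.float_info             # describes Python's float (binary64)

# Gap between 1.0 and the next larger double, and half of that gap
ulp_of_one = info.epsilon         # 2**-52
unit_roundoff = ulp_of_one / 2    # 2**-53, bounds the relative rounding error

rounds_back = (1.0 + unit_roundoff == 1.0)  # tie rounds to the even neighbor 1.0
visible = (1.0 + ulp_of_one > 1.0)          # a full gap is always representable

smallest_normal = info.min        # smallest positive normalized double
largest_finite = info.max         # largest finite double

# The reciprocal of the smallest normalized value is still finite ...
finite = (1.0 / smallest_normal < float("inf"))
# ... but sufficiently small subnormal inputs overflow to infinity.
tiny = smallest_normal / 256      # 2**-1030, a subnormal number
overflows = (1.0 / tiny == float("inf"))
```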
Array base class¶
- class drjit.ArrayBase(*args, **kwargs)¶
This is the base class of all Dr.Jit arrays and tensors. It provides an abstract version of the array API that becomes usable when the type is extended by a concrete specialization. ArrayBase itself cannot be instantiated.
See the section on Dr.Jit type signatures to learn about the type parameters of ArrayBase.
- property array¶
This member plays multiple roles:
- When self is a tensor, this property returns the storage representation of the tensor in the form of a linearized dynamic 1D array.
- When self is a special arithmetic object (matrix, quaternion, or complex number), array provides a copy of the same data with ordinary array semantics.
- In all other cases, array is simply a reference to self.
- Type:
- property ndim¶
This property represents the dimension of the provided Dr.Jit array or tensor.
- Type:
int
- property shape¶
This property provides a tuple describing dimension and shape of the provided Dr.Jit array or tensor.
When the input array is ragged, the function raises a RuntimeError. The term ragged refers to an array whose components have mismatched sizes, such as [[1, 2], [3, 4, 5]]. Note that scalar entries (e.g. [[1, 2], [3]]) are acceptable, since broadcasting can effectively convert them to any size.
The expressions drjit.shape(arg) and arg.shape are equivalent.
- Type:
tuple[int, …]
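The distinction between ragged and broadcastable inputs can be sketched for nested Python lists. The helper nested_shape below is hypothetical and only approximates the rule described above (size-1 entries broadcast, other size mismatches are ragged); it is not Dr.Jit's implementation and makes the additional assumption that mismatched nesting depths are also rejected.

```python
def nested_shape(value):
    """Return the shape of a nested list/tuple, raising RuntimeError if ragged."""
    if not isinstance(value, (list, tuple)):
        return ()  # a Python scalar has an empty shape
    merged = None
    for child in value:
        shape = nested_shape(child)
        if merged is None:
            merged = list(shape)
        elif len(shape) != len(merged):
            raise RuntimeError("ragged input: mismatched nesting depth")
        else:
            for i, (a, b) in enumerate(zip(merged, shape)):
                # Sizes must agree unless one of them is 1 (broadcastable)
                if a != b and 1 not in (a, b):
                    raise RuntimeError("ragged input: mismatched sizes")
                merged[i] = max(a, b)
    return (len(value), *(merged or ()))

broadcast_shape = nested_shape([[1, 2], [3]])   # size-1 entry broadcasts

try:
    nested_shape([[1, 2], [3, 4, 5]])
    ragged_detected = False
except RuntimeError:
    ragged_detected = True
```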
- property state¶
This read-only property returns an enumeration value describing the evaluation state of this Dr.Jit array.
- Type:
- property x¶
If self is a static Dr.Jit array of size 1 (or larger), the property self.x can be used synonymously with self[0]. Otherwise, accessing this field will generate a RuntimeError.
- Type:
- property y¶
If self is a static Dr.Jit array of size 2 (or larger), the property self.y can be used synonymously with self[1]. Otherwise, accessing this field will generate a RuntimeError.
- Type:
- property z¶
If self is a static Dr.Jit array of size 3 (or larger), the property self.z can be used synonymously with self[2]. Otherwise, accessing this field will generate a RuntimeError.
- Type:
- property w¶
If self is a static Dr.Jit array of size 4 (or larger), the property self.w can be used synonymously with self[3]. Otherwise, accessing this field will generate a RuntimeError.
- Type:
- property T¶
Transpose of self.
For fixed-size matrix types, returns a matrix with rows and columns swapped.
For Dr.Jit tensors, returns the transpose of a 2-D tensor. This matches PyTorch’s semantics: only rank-2 tensors are accepted, and higher-rank tensors raise a TypeError. Use mT to transpose only the last two dimensions of a higher-rank tensor.
Other array types raise TypeError.
- property mT¶
Matrix transpose: swaps the last two dimensions of self, leaving any leading batch dimensions unchanged. Requires at least 2 dimensions and mirrors the semantics of PyTorch’s .mT / NumPy’s .mT.
For fixed-size matrix types this is equivalent to T. For tensors, the operation is evaluated as a permuted gather and is fully differentiable. Non-matrix, non-tensor types raise TypeError.
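The difference between the two transpose properties can be illustrated with nested Python lists standing in for tensors. This is a sketch of the documented semantics only; the helper names `rank`, `transpose_T`, and `transpose_mT` are hypothetical and the inputs are assumed to be non-empty and rectangular.

```python
def rank(x):
    """Nesting depth of a non-empty rectangular nested list."""
    depth = 0
    while isinstance(x, list):
        depth += 1
        x = x[0]
    return depth


def transpose_T(x):
    """Model of `.T`: only rank-2 inputs are accepted."""
    if rank(x) != 2:
        raise TypeError("T requires exactly two dimensions")
    return [list(column) for column in zip(*x)]


def transpose_mT(x):
    """Model of `.mT`: swap the last two dimensions, keep leading ones."""
    r = rank(x)
    if r < 2:
        raise TypeError("mT requires at least two dimensions")
    if r == 2:
        return transpose_T(x)
    return [transpose_mT(sub) for sub in x]  # recurse over batch dimensions


batch = [[[1, 2, 3], [4, 5, 6]],
         [[7, 8, 9], [10, 11, 12]]]  # shape (2, 2, 3)
print(transpose_mT(batch))           # shape (2, 3, 2)
```

Calling `transpose_T(batch)` raises `TypeError` in this model, mirroring the rank-2 restriction described for `T`.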
- property index¶
If self is a leaf Dr.Jit array managed by a just-in-time compiled backend (i.e., CUDA or LLVM), this property contains the associated variable index in the graph data structure storing the computation trace. This graph can be visualized using drjit.graphviz(). Otherwise, the value of this property equals zero. A non-leaf array (e.g. drjit.cuda.Array2i) consists of several JIT variables, whose indices must be queried separately.
Note that Dr.Jit maintains two computation traces at the same time: one capturing the raw computation, and a higher-level graph for automatic differentiation (AD). The index index_ad keeps track of the variable index within the AD computation graph, if applicable.
- Type:
int
- property index_ad¶
If self is a leaf Dr.Jit array represented by an AD backend, this property contains the variable index in the graph data structure storing the computation trace for later differentiation (this graph can be visualized using drjit.graphviz_ad()). A non-leaf array (e.g. drjit.cuda.ad.Array2f) consists of several AD variables, whose indices must be queried separately.
Note that Dr.Jit maintains two computation traces at the same time: one capturing the raw computation, and a higher-level graph for automatic differentiation (AD). The index index keeps track of the variable index within the raw computation graph, if applicable.
- Type:
int
- property grad¶
This property can be used to retrieve or set the gradient associated with the Dr.Jit array or tensor.
The expressions
drjit.grad(arg)andarg.gradare equivalent whenargis a Dr.Jit array/tensor.- Type:
- item(self) object¶
Return the content of the array/tensor as a Python scalar.
This operation is only permitted when the array/tensor has exactly a single element, otherwise it raises an exception.
- Type:
- __len__()¶
Return len(self).
- __iter__()¶
Implement iter(self).
- __repr__()¶
Return repr(self).
- __bool__()¶
True if self else False
Casts the array to a Python bool type. This is only permissible when self represents a boolean array of both depth and size 1.
- __add__(value, /)¶
Return self+value.
- __radd__(value, /)¶
Return value+self.
- __iadd__(value, /)¶
Return self+=value.
- __sub__(value, /)¶
Return self-value.
- __rsub__(value, /)¶
Return value-self.
- __isub__(value, /)¶
Return self-=value.
- __mul__(value, /)¶
Return self*value.
- __rmul__(value, /)¶
Return value*self.
- __imul__(value, /)¶
Return self*=value.
- __matmul__(value, /)¶
Return self@value.
- __rmatmul__(value, /)¶
Return value@self.
- __imatmul__(value, /)¶
Return self@=value.
- __truediv__(value, /)¶
Return self/value.
- __rtruediv__(value, /)¶
Return value/self.
- __itruediv__(value, /)¶
Return self/=value.
- __floordiv__(value, /)¶
Return self//value.
- __rfloordiv__(value, /)¶
Return value//self.
- __ifloordiv__(value, /)¶
Return self//=value.
- __mod__(value, /)¶
Return self%value.
- __rmod__(value, /)¶
Return value%self.
- __imod__(value, /)¶
Return self%=value.
- __rshift__(value, /)¶
Return self>>value.
- __rrshift__(value, /)¶
Return value>>self.
- __irshift__(value, /)¶
Return self>>=value.
- __lshift__(value, /)¶
Return self<<value.
- __rlshift__(value, /)¶
Return value<<self.
- __ilshift__(value, /)¶
Return self<<=value.
- __and__(value, /)¶
Return self&value.
- __rand__(value, /)¶
Return value&self.
- __iand__(value, /)¶
Return self&=value.
- __or__(value, /)¶
Return self|value.
- __ror__(value, /)¶
Return value|self.
- __ior__(value, /)¶
Return self|=value.
- __xor__(value, /)¶
Return self^value.
- __rxor__(value, /)¶
Return value^self.
- __ixor__(value, /)¶
Return self^=value.
- __abs__()¶
abs(self)
- __le__(value, /)¶
Return self<=value.
- __lt__(value, /)¶
Return self<value.
- __ge__(value, /)¶
Return self>=value.
- __gt__(value, /)¶
Return self>value.
- __ne__(value, /)¶
Return self!=value.
- __eq__(value, /)¶
Return self==value.
- __dlpack__(self, stream: object | None = None) ndarray[]¶
Returns a DLPack capsule representing the data in this array.
This operation may potentially perform a copy. For example, nested arrays like drjit.llvm.Array3f or drjit.cuda.Matrix4f need to be rearranged into a contiguous memory representation before they can be exposed.
In other cases, e.g. for drjit.llvm.Float, drjit.scalar.Array3f, or drjit.scalar.ArrayXf, the data is already contiguous and a zero-copy approach is used instead.
- __array__(self, dtype: object | None = None, copy: object | None = None) object¶
Returns a NumPy array representing the data in this array.
This operation may potentially perform a copy. For example, nested arrays like drjit.llvm.Array3f or drjit.cuda.Matrix4f need to be rearranged into a contiguous memory representation before they can be wrapped.
In other cases, e.g. for drjit.llvm.Float, drjit.scalar.Array3f, or drjit.scalar.ArrayXf, the data is already contiguous and a zero-copy approach is used instead.
- numpy(self) numpy.ndarray[]¶
Returns a NumPy array representing the data in this array.
This operation may potentially perform a copy. For example, nested arrays like drjit.llvm.Array3f or drjit.cuda.Matrix4f need to be rearranged into a contiguous memory representation before they can be wrapped.
In other cases, e.g. for drjit.llvm.Float, drjit.scalar.Array3f, or drjit.scalar.ArrayXf, the data is already contiguous and a zero-copy approach is used instead.
- torch(self) object¶
Returns a PyTorch tensor representing the data in this array.
For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.
Warning
This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function
drjit.wrap()for further information on how to accomplish this.
- jax(self) object¶
Returns a JAX tensor representing the data in this array.
For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.
Warning
This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function
drjit.wrap()for further information on how to accomplish this.
- tf(self) object¶
Returns a TensorFlow tensor representing the data in this array.
For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.
Warning
This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function
drjit.wrap()for further information on how to accomplish this.
- memview(self) memoryview[]¶
Returns a memoryview representing the data in this array.
For flat CPU arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created memoryview provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form. GPU arrays must be copied to the CPU first.
This function is sometimes preferable to .numpy() when data needs to be shuffled between two frameworks. Loading NumPy consumes resources and starts a background thread pool on some platforms. This can be avoided when using the simpler memoryview class that is part of stock Python.
Warning
This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function
drjit.wrap()for further information on how to accomplish this.
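The zero-copy behavior described above (a view that reflects later modifications, as opposed to an independent copy) can be demonstrated with stock Python alone. In this sketch an `array.array` serves as a stand-in for a flat CPU array; no Dr.Jit call is made, and the variable names are purely illustrative.

```python
from array import array

# Stand-in for the contiguous storage of a flat CPU array (assumption).
storage = array("f", [1.0, 2.0, 3.0])

view = memoryview(storage)   # zero-copy: shares the underlying buffer
snapshot = view.tolist()     # explicit copy: independent of later writes

storage[0] = 10.0            # modify the original storage afterwards

print(view[0])       # the view observes the modification
print(snapshot[0])   # the copy still holds the old value
```

A consumer that needs a stable snapshot should therefore copy the data explicitly instead of holding on to the view.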
Computation graph analysis¶
The following operations visualize the contents of Dr.Jit’s computation graphs (of which there are two: one for Jit compilation, and one for automatic differentiation).
- drjit.graphviz(as_string: bool = False) object¶
Return a GraphViz diagram describing registered JIT variables and their connectivity.
This function returns a representation of the computation graph underlying the Dr.Jit just-in-time compiler, which is separate from the automatic differentiation layer. See the graphviz_ad() function to visualize the computation graph of the latter.
Run dr.graphviz().view() to open up a PDF viewer that shows the resulting output in a separate window.
The function depends on the graphviz Python package when as_string=False (the default).
- Parameters:
as_string (bool) – if set to
True, the function will return raw GraphViz markup as a string. (Default:False)- Returns:
GraphViz object or raw markup.
- Return type:
object
- drjit.graphviz_ad(as_string: bool = False) object¶
Return a GraphViz diagram describing variables registered with the automatic differentiation layer, as well as their connectivity.
This function returns a representation of the computation graph underlying the Dr.Jit AD layer, which is one architectural layer above the just-in-time compiler. See the graphviz() function to visualize the computation graph of the latter.
Run dr.graphviz_ad().view() to open up a PDF viewer that shows the resulting output in a separate window.
The function depends on the graphviz Python package when as_string=False (the default).
- Parameters:
as_string (bool) – if set to
True, the function will return raw GraphViz markup as a string. (Default:False)- Returns:
GraphViz object or raw markup.
- Return type:
object
- drjit.whos(as_string: bool) str | None¶
Return/print a list of live JIT variables.
This function provides information about the set of variables that are currently registered with the Dr.Jit just-in-time compiler, which is separate from the automatic differentiation layer. See the
whos_ad()function for the latter.- Parameters:
as_string (bool) – if set to
True, the function will return the list in string form. Otherwise, it will print directly onto the console and returnNone. (Default:False)- Returns:
a human-readable list (if requested).
- Return type:
None | str
- drjit.whos_ad(as_string: bool) str | None¶
Return/print a list of live variables registered with the automatic differentiation layer.
This function provides information about the set of variables that are currently registered with the Dr.Jit automatic differentiation layer, which is one architectural layer above the just-in-time compiler. See the whos() function to obtain information about the latter.
- Parameters:
as_string (bool) – if set to
True, the function will return the list in string form. Otherwise, it will print directly onto the console and returnNone. (Default:False)- Returns:
a human-readable list (if requested).
- Return type:
None | str
- drjit.set_label(arg0: object, arg1: str, /) None¶
- drjit.set_label(**kwargs) None
Overloaded function.
set_label(arg0: object, arg1: str, /) -> None
Assign a label to the provided Dr.Jit array.
This can be helpful to identify computation in GraphViz output (see drjit.graphviz(), graphviz_ad()).
The operation assumes that the array is tracked by the just-in-time compiler. It has no effect on unsupported inputs (e.g., arrays from the drjit.scalar package). It recurses through PyTrees (tuples, lists, dictionaries, custom data structures) and appends names (indices, dictionary keys, field names) separated by underscores to uniquely identify each element.
The following **kwargs-based shorthand notation can be used to assign multiple labels at once:
set_label(x=x, y=y)
- Parameters:
*arg (tuple) – a Dr.Jit array instance and its corresponding label
strvalue.**kwarg (dict) – A set of (keyword, object) pairs.
set_label(**kwargs) -> None
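The recursive naming scheme for PyTrees can be sketched in plain Python. This is only a model of the rule stated above (names joined by underscores); the function `label_leaves` is hypothetical, it handles only dictionaries, lists, and tuples, and the exact label strings produced by Dr.Jit may differ.

```python
def label_leaves(value, prefix):
    """Return a {label: leaf} mapping for a nested container.

    Hypothetical model: dictionary keys and sequence indices are appended
    to the prefix, separated by underscores.
    """
    labels = {}
    if isinstance(value, dict):
        for key, child in value.items():
            labels.update(label_leaves(child, f"{prefix}_{key}"))
    elif isinstance(value, (list, tuple)):
        for index, child in enumerate(value):
            labels.update(label_leaves(child, f"{prefix}_{index}"))
    else:
        labels[prefix] = value  # leaf: the prefix itself is the label
    return labels


tree = {"pos": [1.0, 2.0], "weight": 0.5}
print(label_leaves(tree, "state"))
```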
Debugging¶
- enum drjit.LogLevel(value)¶
Valid values are as follows:
- Disable = LogLevel.Disable¶
- Error = LogLevel.Error¶
- Warn = LogLevel.Warn¶
- Info = LogLevel.Info¶
- InfoSym = LogLevel.InfoSym¶
- Debug = LogLevel.Debug¶
- Trace = LogLevel.Trace¶
- drjit.assert_true(cond, fmt: str | None = None, *args, tb_depth: int = 3, tb_skip: int = 0, **kwargs)¶
Generate an assertion failure message when any of the entries in cond are False.
This function resembles the built-in assert keyword in that it raises an AssertionError when the condition cond is False.
In contrast to the built-in keyword, it also works when cond is an array of boolean values. In this case, the function raises an exception when any entry of cond is False.
The function accepts an optional format string along with positional and keyword arguments, and it processes them like drjit.print(). When only a subset of the entries of cond is False, the function reduces the generated output to only include the associated entries.
>>> x = Float(1, -4, -2, 3)
>>> dr.assert_true(x >= 0, 'Found negative values: {}', x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "drjit/__init__.py", line 1327, in assert_true
    raise AssertionError(msg)
AssertionError: Assertion failure: Found negative values: [-4, -2]
This function also works when some of the function inputs are symbolic. In this case, the check is delayed and potential failures will be reported asynchronously: drjit.assert_true() generates output on sys.stderr instead of raising an exception, as the original execution context no longer exists at that point.
Assertion checks carry a performance cost, hence they are disabled by default. To enable them, set the JIT flag dr.JitFlag.Debug.
- Parameters:
cond (bool | drjit.ArrayBase) – The condition used to trigger the assertion. This should be a scalar Python boolean or a 1D boolean array.
fmt (str) – An optional format string that will be appended to the error message. It can reference positional or keyword arguments specified via *args and **kwargs.
*args (tuple) – Optional variable-length positional arguments referenced by fmt, see drjit.print() for details on this.
tb_depth (int) – Depth of the backtrace that should be appended to the assertion message. This only applies to cases where some of the inputs are symbolic, and printing of the error message must be delayed.
tb_skip (int) – The first tb_skip entries of the backtrace will be removed. This only applies to cases where some of the inputs are symbolic, and printing of the error message must be delayed. This is helpful when the assertion check is called from a helper function that should not be shown.
**kwargs (dict) – Optional variable-length keyword arguments referenced by fmt, see drjit.print() for details on this.
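The "reduce the output to the failing entries" behavior can be modeled with Python lists. The sketch below is not Dr.Jit's implementation: `assert_true_model` and `select_failing` are hypothetical names, conditions are plain lists of booleans, and only the evaluated (non-symbolic) case is covered.

```python
def assert_true_model(cond, fmt=None, *args, **kwargs):
    """Raise AssertionError if any entry of `cond` is False (list-based model)."""
    entries = [cond] if isinstance(cond, bool) else list(cond)
    failing = [i for i, ok in enumerate(entries) if not ok]
    if not failing:
        return

    def select_failing(arg):
        # Keep only the entries of array-like arguments that failed the check
        if isinstance(arg, list) and len(arg) == len(entries):
            return [arg[i] for i in failing]
        return arg

    message = "Assertion failure"
    if fmt is not None:
        message += ": " + fmt.format(
            *(select_failing(a) for a in args),
            **{k: select_failing(v) for k, v in kwargs.items()},
        )
    raise AssertionError(message)


x = [1, -4, -2, 3]
try:
    assert_true_model([v >= 0 for v in x], "Found negative values: {}", x)
except AssertionError as exc:
    print(exc)
```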
- drjit.assert_false(cond, fmt: str | None = None, *args, tb_depth: int = 3, tb_skip: int = 0, **kwargs)¶
Equivalent to
assert_true()with a flipped conditioncond. Please refer to the documentation of this function for further details.
- drjit.assert_equal(arg0, arg1, fmt: str | None = None, *args, limit: int = 3, tb_skip: int = 0, **kwargs)¶
Equivalent to
assert_true()with the conditionarg0==arg1. Please refer to the documentation of this function for further details.
- drjit.print(fmt: str, *args, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) None¶
- drjit.print(value: object, /, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) None
Overloaded function.
print(fmt: str, *args, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) -> None
Generate a formatted string representation and print it immediately or in a delayed fashion (if any of the inputs are symbolic).
This function combines the behavior of the built-in Python
format()andprint()functions: it generates a formatted string representation as specified by a format stringfmtand then outputs it on the console. The operation fetches referenced positional and keyword arguments and pretty-prints Dr.Jit arrays, tensors, and PyTrees with indentation, field names, etc.>>> from drjit.cuda import Array3f >>> dr.print("{}:\n{foo}", ... "A PyTree containing an array", ... foo={ 'a' : Array3f(1, 2, 3) }) A PyTree containing an array: { 'a': [[1, 2, 3]] }
The key advantage of drjit.print() compared to the built-in Python print() statement is that it can run asynchronously, which allows it to print symbolic variables without requiring their evaluation. Dr.Jit uses symbolic variables to trace loops (drjit.while_loop()), conditionals (drjit.if_stmt()), and calls (drjit.switch(), drjit.dispatch()). Such symbolic variables represent values that are unknown at trace time, and which cannot be printed using the built-in Python print() function (attempting to do so will raise an exception).
When the print statement does not reference any symbolic arrays or tensors, it will execute immediately. Otherwise, the output will appear after the next
drjit.eval()statement, or whenever any subsequent computation is evaluated. Here is an example from an interactive Python session demonstrating printing from a symbolic call performed viadrjit.switch():>>> from drjit.llvm import UInt, Float >>> def f1(x): ... dr.print("in f1: {x=}", x=x) ... >>> def f2(x): ... dr.print("in f2: {x=}", x=x) ... >>> dr.switch(index=UInt(0, 0, 0, 1, 1, 1), ... targets=[f1, f2], ... x=Float(1, 2, 3, 4, 5, 6)) >>> # No output (yet) >>> dr.eval() in f1: x=[1, 2, 3] in f2: x=[4, 5, 6]
Dynamic arrays with more than 20 entries will be abbreviated. Specify the
limit=..argument to reveal the contents of larger arrays.>>> dr.print(dr.arange(dr.llvm.Int, 30)) [0, 1, 2, .. 24 skipped .., 27, 28, 29] >>> dr.format(dr.arange(dr.llvm.Int, 30), limit=30) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,↵ 23, 24, 25, 26, 27, 28, 29]
This function lacks many features of Python’s (rather intricate) format string mini language and f-string interpolation. However, a subset of the functionality is supported:
Positional arguments (in
*args) can be referenced implicitly ({}), or using indices ({0},{1}, etc.). Those conventions should not be mixed. Unreferenced positional arguments will be silently ignored.Keyword arguments (in
**kwargs) can be referenced via their keyword name ({foo}). Unreferenced keywords will be silently ignored.A trailing
=in a brace expression repeats the string within the braces followed by the output:>>> dr.print('{foo=}', foo=1) foo=1
When the format string
fmtis omitted, it is implicitly set to{}, and the function formats a single positional argument.The function implicitly appends
endto the format string, which is set to a newline by default. The final result is sent tosys.stdout(by default) orfile. When afileargument is given, it must implement the methodwrite(arg: str).A related operation
drjit.format()admits the same format string syntax but returns a Pythonstrinstead of printing to the console. This operation, however, does not support symbolic inputs—usedrjit.print()with a customfileargument to stringify symbolic inputs asynchronously.Note
This operation is not suitable for extracting large amounts of data from Dr.Jit kernels, as the conversion to a string representation incurs a nontrivial runtime cost.
Note
Technical details on symbolic printing
When Dr.Jit compiles and executes queued computation on the target device, it includes additional code for symbolic print operations that captures referenced arguments and copies them back to the host (CPU). The information is then printed following the end of that process.
Only a limited amount of memory is set aside to capture the output of symbolic print operations. This is because the amount of data produced within long-running symbolic loops can often exceed the total device memory. Also, printing gigabytes of ASCII text into a Python console or Jupyter notebook is likely not a good idea.
For the electronically inclined, the operation is best thought of as hooking up an oscilloscope to a high-frequency circuit. The oscilloscope provides a limited view into a vast torrent of data to assist the user, who would be overwhelmed if the oscilloscope worked by capturing and showing everything.
The operation warns when the size of the buffers was insufficient. In this case, the output is still printed in the correct order, but chunks of the data are missing. The position of the resulting holes is unspecified and non-deterministic.
>>> dr.print(dr.arange(Float, 10000000), mode='symbolic')
>>> dr.eval()
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
RuntimeWarning: dr.print(): symbolic print statement only captured 20 of 10000000 available outputs. The above is a non-deterministic sample, in which entries are in the right order but not necessarily contiguous. Specify `limit=..` to capture more information and/or add the special format field `{thread_id}` show the thread ID/array index associated with each entry of the captured output.
This is because the (many) parallel threads of the program all try to append their state to the output buffer, but only the first limit (20 by default) succeed. The host subsequently re-sorts the captured data by thread ID. This means that the output [5, 6, 102, 188, 1026, ..] would also be a valid result of the prior command. When a print statement references multiple arrays, the operation either shows all array entries associated with a particular execution thread, or none of them.
To refine what is captured, you can specify the
activeargument to disable the print statement for a subset of the entries (a “trigger” in the oscilloscope analogy). Printing from an inactive thread within a symbolic loop (drjit.while_loop()), conditional (drjit.if_stmt()), or call (drjit.switch(),drjit.dispatch()) will likewise not generate any output.A potential gotcha of the current design is that a symbolic print within a symbolic loop counts as one print statement and will only generate a single combined output string. The output of each thread is arranged in one contiguous block. You can add the special format string keyword
{thread_id} to reveal the mapping between output values and the execution thread that generated them:
>>> from drjit.llvm import Int
>>> @dr.syntax
... def f(j: Int):
...     i = Int(0)
...     while i < j:
...         dr.print('{thread_id=} {i=}', i=i)
...         i += 1
...
>>> f(Int(2, 3))
>>> dr.eval()
thread_id=[0, 0, 1, 1, 1], i=[0, 1, 0, 1, 2]
The example above runs a symbolic loop twice in parallel: the first thread runs for 2 iterations, and the second runs for 3 iterations. The loop prints the iteration counter
i, which then leads to the output[0, 1, 0, 1, 2]where the first two entries are produced by the first thread, and the trailing three belong to the second thread. Thethread_idoutput clarifies this mapping.- Parameters:
fmt (str) – A format string that potentially references input arguments from
*argsand**kwargs.active (drjit.ArrayBase | bool) – A mask argument that can be used to disable a subset of the entries. The print statement will be completely suppressed when there is no output. (default:
True).end (str) – This string will be appended to the format string. It is set to a newline character (
"\n") by default.file (object) – The print operation will eventually invoke
file.write(arg:str)to print the formatted string. Specify this argument to route the output somewhere other than the default output streamsys.stdout.mode (str) – Specify this parameter to override the evaluation mode. Possible values are:
"symbolic","evaluated", or"auto". The default value of"auto"causes the function to use evaluated mode (which prints immediately) unless a symbolic input is detected, in which case printing takes place symbolically (i.e., in a delayed fashion).limit (int) – The operation will abbreviate dynamic arrays with more than
limit(default: 20) entries.
print(value: object, /, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) -> None
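The capture behavior described in the technical note (many threads race for a bounded buffer, only the first `limit` succeed, and the host re-sorts by thread ID) can be simulated in a few lines. This is a toy model under stated assumptions: `capture_symbolic_print` is a hypothetical name, and a seeded shuffle stands in for the unspecified arrival order of parallel threads.

```python
import random


def capture_symbolic_print(values, limit=20, seed=None):
    """Simulate a bounded capture buffer filled by racing threads.

    Returns (thread_ids, captured_values), both sorted by thread ID.
    """
    rng = random.Random(seed)
    arrival_order = list(range(len(values)))
    rng.shuffle(arrival_order)                    # arbitrary arrival order
    captured_ids = sorted(arrival_order[:limit])  # first `limit` win, host re-sorts
    return captured_ids, [values[i] for i in captured_ids]


thread_ids, sample = capture_symbolic_print(list(range(10_000)), limit=20, seed=1)
print(sample)  # ordered, but generally not contiguous
```

The result is ordered yet has holes at arbitrary positions, which is the situation the warning message above describes.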
- drjit.format(fmt: str, *args, limit: int = 20, **kwargs)¶
- drjit.format(value: object, *, limit: int = 20, **kwargs) str
Overloaded function.
format(fmt: str, *args, limit: int = 20, **kwargs)
Return a formatted string representation.
This function generates a formatted string representation as specified by a format string
fmtand then returns it as a Pythonstrobject. The operation fetches referenced positional and keyword arguments and pretty-prints Dr.Jit arrays, tensors, and PyTrees with indentation, field names, etc.>>> from drjit.cuda import Array3f >>> s = dr.format("{}:\n{foo}", ... "A PyTree containing an array", ... foo={ 'a' : Array3f(1, 2, 3) }) >>> print(s) A PyTree containing an array: { 'a': [[1, 2, 3]] }
Dynamic arrays with more than 20 entries will be abbreviated. Specify the
limit=..argument to reveal the contents of larger arrays.>>> dr.format(dr.arange(dr.llvm.Int, 30)) [0, 1, 2, .. 24 skipped .., 27, 28, 29] >>> dr.format(dr.arange(dr.llvm.Int, 30), limit=30) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,↵ 23, 24, 25, 26, 27, 28, 29]
This function lacks many features of Python’s (rather intricate) format string mini language and f-string interpolation. However, a subset of the functionality is supported:
Positional arguments (in
*args) can be referenced implicitly ({}), or using indices ({0},{1}, etc.). Those conventions should not be mixed. Unreferenced positional arguments will be silently ignored.Keyword arguments (in
**kwargs) can be referenced via their keyword name ({foo}). Unreferenced keywords will be silently ignored.A trailing
=in a brace expression repeats the string within the braces followed by the output:>>> dr.format('{foo=}', foo=1) foo=1
When the format string
fmtis omitted, it is implicitly set to{}, and the function formats a single positional argument.In contrast to the related
drjit.print(), this function does not output the result on the console, and it cannot support symbolic inputs. This is because returning a string right away is incompatible with the requirement of evaluating/formatting symbolic inputs in a delayed fashion. If you wish to format symbolic arrays, you must calldrjit.print()with a customfileobject that implements the.write()function. Dr.Jit will call this function with the generated string when it is ready.- Parameters:
fmt (str) – A format string that potentially references input arguments from
*argsand**kwargs.limit (int) – The operation will abbreviate dynamic arrays with more than
limit(default: 20) entries.
- Returns:
The formatted string representation created as specified above.
- Return type:
str
format(value: object, *, limit: int = 20, **kwargs)
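The supported subset of the format-string syntax can be approximated on top of Python's own `str.format`. The following sketch is an assumption-laden model, not the Dr.Jit implementation: `mini_format` is a hypothetical helper that only handles plain Python values and does not pretty-print arrays or PyTrees.

```python
import re


def mini_format(fmt, *args, **kwargs):
    """Model of the documented subset: {}, {0}, {name}, and {name=}."""
    # Expand a trailing '=' ('{foo=}') into 'foo={foo}' before formatting
    expanded = re.sub(r"\{(\w+)=\}", r"\1={\1}", fmt)
    # str.format silently ignores unreferenced positional/keyword arguments
    return expanded.format(*args, **kwargs)


print(mini_format("{foo=}", foo=1))
print(mini_format("{}: {value}", "result", "unused", value=42, other=0))
```

As in the description above, implicit (`{}`) and indexed (`{0}`) references should not be mixed; in this model `str.format` rejects such a mix with a `ValueError`.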
- drjit.log_level() drjit.LogLevel¶
- drjit.set_log_level(arg: drjit.LogLevel, /) None¶
- drjit.set_log_level(arg: int, /) None
Profiling¶
- class drjit.profile_enable(*args, **kwargs)¶
Context manager to selectively activate profiling for a region of a program.
Some profiling tools (e.g., Nsight Compute) support targeted profiling of smaller parts of a program. Use this context manager to locally enable profiling.
Note the difference between this context manager and
dr.profile_range(), which annotates a profiled region with a label.
- drjit.profile_mark(arg: str, /) None¶
Mark an event on the timeline of profiling tools.
Currently, this function uses NVTX to report events that can be captured using NVIDIA Nsight Systems. The operation is a no-op when no profile collection tool is attached.
Note that this event will be recorded on the CPU timeline.
- class drjit.profile_range(*args, **kwargs)¶
Context manager to mark a region (e.g. a function call) on the timeline of profiling tools.
You can use this context manager to wrap parts of your code and track when and for how long it runs. Regions can be arbitrarily nested, which profiling tools visualize as a stack.
Note that this function is intended to track activity on the CPU timeline. If the wrapped region launches asynchronous GPU kernels, then those won’t generally be included in the length of the range unless
drjit.sync_thread()or some other type of synchronization operation waits for their completion (which is generally not advisable, since keeping CPU and GPU asynchronous with respect to each other improves performance).Currently, this function uses NVTX to report events that can be captured using NVIDIA Nsight Systems. The operation is a no-op when no profile collection tool is attached.
Note the difference between this context manager and
dr.profile_enable(), which enables targeted profiling of a smaller region of code (as opposed to profiling the entire program).
Textures¶
The texture implementations are defined in the various backends (e.g. drjit.llvm.ad.Texture3f16). However, they reference enumerations provided here.
Low-level bits¶
- drjit.set_backend(arg: Literal['cuda', 'llvm', 'scalar'], /)¶
- drjit.set_backend(arg: drjit.JitBackend, /) None
Overloaded function.
set_backend(arg: Literal['cuda', 'llvm', 'scalar'], /)
Adjust the
drjit.auto.*module so that it refers to types from the specified backend.set_backend(arg: drjit.JitBackend, /) -> None
- drjit.thread_count() int¶
Return the number of threads that Dr.Jit uses to parallelize computation on the CPU.
The Python main thread will generally also participate in parallel execution, and it implicitly counts as one of the threads. A return value of
Ntherefore corresponds toN - 1worker threads plus the calling thread.
- drjit.set_thread_count(arg: int, /) None¶
Adjust the number of threads that Dr.Jit uses to parallelize computation on the CPU.
The Python main thread will generally also participate in parallel execution, and it implicitly counts as one of the threads. Passing N therefore spawns N - 1 worker threads; the values 0 and 1 both disable worker threads entirely, in which case all parallel work runs on the calling thread.
The thread pool is primarily used by Dr.Jit’s LLVM backend. Other projects using the underlying nanothread thread pool library will also be affected by changes made by this function. It is legal to call it even while parallel computation is currently ongoing.
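The relationship between the requested thread count and the number of spawned worker threads can be written down as a one-line rule. The helper `worker_threads` below is hypothetical and simply restates the description above.

```python
def worker_threads(thread_count: int) -> int:
    """Number of worker threads implied by a requested thread count.

    The calling thread counts as one participant, so 0 and 1 both
    yield no worker threads.
    """
    return max(thread_count - 1, 0)


for n in (0, 1, 2, 8):
    print(n, "->", worker_threads(n))
```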
- drjit.sync_thread() None¶
Wait for all currently running computation to finish.
This function synchronizes the device (e.g. the GPU) with the host (CPU) by waiting for the termination of all computation enqueued by the current host thread.
One potential use of this function is to measure the runtime of a kernel launched by Dr.Jit. We instead recommend the use of drjit.kernel_history(), which exposes more accurate device timers.
In general, calling this function in user code is considered bad practice. Dr.Jit programs “run ahead” of the device to keep it fed with work. This is important for performance, and drjit.sync_thread() breaks this optimization.
All operations sent to a device (including reads) are strictly ordered, so there is generally no reason to wait for this queue to empty. If you find that drjit.sync_thread() is needed for your program to run correctly, then you have found a bug. Please report it on the project’s GitHub issue tracker.
- drjit.flush_kernel_cache() None¶
Release all currently cached kernels.
When Dr.Jit evaluates a previously unseen computation, it compiles a kernel and then maps it into the memory of the CPU or GPU. The kernel stays resident so that it can be immediately reused when that same computation reoccurs at a later point.
In long development sessions (e.g. Jupyter notebook-based prototyping), this cache may eventually become unreasonably large, and calling flush_kernel_cache() to free it may be advisable.
Note that this does not free the disk cache that also exists to preserve compiled programs across sessions. To clear this cache as well, delete the directory $HOME/.drjit on Linux/macOS, and %AppData%\Local\Temp\drjit on Windows. (The AppData folder is typically found in C:\Users\<your username>).
- drjit.flush_malloc_cache() None¶
Free the memory allocation cache maintained by Dr.Jit.
Allocating and releasing large chunks of memory tends to be relatively expensive, and Dr.Jit programs often need to do so at high rates.
Like most other array programming frameworks, Dr.Jit implements an internal cache to reduce such allocation-related costs. This cache starts out empty and grows on demand. Allocated memory is never released by default, which can be problematic when using multiple array programming frameworks within the same Python session, or when running multiple processes in parallel.
The drjit.flush_malloc_cache() function releases all currently unused memory back to the operating system. This is a relatively expensive step: you likely don’t want to use it within a performance-sensitive program region (e.g. an optimization loop).
- drjit.expand_threshold() int¶
Query the threshold for performing scatter-reductions via expansion.
Getter for the quantity set in drjit.set_expand_threshold().
- drjit.set_expand_threshold(arg: int, /) None¶
Set the threshold for performing scatter-reductions via expansion.
The documentation of drjit.ReduceOp explains the cost of atomic scatter-reductions and introduces various optimization strategies.
One particularly effective optimization (see the section on optimizations for plots) named drjit.ReduceOp.Expand is specific to the LLVM backend. It replicates the target array to avoid write conflicts altogether, which enables the use of non-atomic memory operations. This is significantly faster but also very memory-intensive. The storage cost of a 1MB array targeted by a drjit.scatter_reduce() operation now grows to N megabytes, where N is the number of cores.
For this reason, Dr.Jit implements a user-controllable threshold exposed via the functions drjit.expand_threshold() and drjit.set_expand_threshold(). When the array has more entries than the value specified here, the drjit.ReduceOp.Expand strategy will not be used unless specifically requested via the mode= parameter of operations like drjit.scatter_reduce(), drjit.scatter_add(), and drjit.gather().
The default value of this parameter is 1000000 (1 million entries).
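The memory trade-off described above can be made concrete with a little arithmetic. The following plain-Python sketch does not call Dr.Jit; `expanded_bytes` and `uses_expansion` are hypothetical helper names, and the check only models the rule stated above under the assumption that expansion would otherwise be selected.

```python
def expanded_bytes(array_bytes: int, num_cores: int) -> int:
    """Storage needed when the scatter target is replicated once per core."""
    return array_bytes * num_cores


def uses_expansion(num_entries: int, threshold: int = 1_000_000) -> bool:
    """Hypothetical model of the rule above: arrays with more entries than
    the threshold are not expanded (unless requested via mode=)."""
    return num_entries <= threshold


one_mb = 1024 * 1024
# Storage (in MB) of an expanded 1 MB target on a 16-core machine
print(expanded_bytes(one_mb, 16) // one_mb)
# A small array is expanded, a large one is not
print(uses_expansion(250_000), uses_expansion(4_000_000))
```

Raising the threshold trades memory for fewer atomic operations; lowering it does the opposite.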
- drjit.kernel_history(types: collections.abc.Sequence[drjit.KernelType] = []) list¶
Return the history of captured kernel launches.
Dr.Jit can optionally capture performance-related metadata. To do so, set the drjit.JitFlag.KernelHistory flag as follows:
    with dr.scoped_set_flag(dr.JitFlag.KernelHistory):
        # .. computation to be analyzed ..

    hist = dr.kernel_history()
The drjit.kernel_history() function returns a list of dictionaries characterizing each major operation performed by the analyzed region. Each dictionary has the following entries:
backend: The JIT backend that was used.
execution_time: The time (in milliseconds) used by this operation. On the CUDA backend, this value is captured via CUDA events. On the LLVM backend, this involves querying CLOCK_MONOTONIC (Linux/macOS) or QueryPerformanceCounter (Windows).
type: The type of computation expressed by an enumeration value of type drjit.KernelType. The most interesting workloads generated by Dr.Jit are just-in-time compiled kernels, which are identified by drjit.KernelType.JIT.
These have several additional entries:
hash: The hash code identifying the kernel. (The same hash code is also shown when increasing the log level via drjit.set_log_level()).
ir: A capture of the intermediate representation used in this kernel.
operation_count: The number of low-level IR operations. (A rough proxy for the complexity of the operation.)
cache_hit: Was this kernel present in Dr.Jit’s in-memory cache? Otherwise, it was either loaded from disk or had to be recompiled from scratch.
cache_disk: Was this kernel present in Dr.Jit’s on-disk cache? Otherwise, it had to be recompiled from scratch.
codegen_time: The time (in milliseconds) which Dr.Jit needed to generate the textual low-level IR representation of the kernel. This step is always needed even if the resulting kernel is already cached.
backend_time: The time (in milliseconds) which the backend (either the LLVM compiler framework or the CUDA PTX just-in-time compiler) required to compile and link the low-level IR into machine code. This step is only needed when the kernel did not already exist in the in-memory or on-disk cache.
uses_optix: Was this kernel compiled by the NVIDIA OptiX ray tracing engine?
recording_mode: Indicates if this kernel was executed in the context of a frozen function (see @dr.freeze) and if so, if it was recorded or replayed by one.
Note that drjit.kernel_history() clears the history while extracting this information. A related operation drjit.kernel_history_clear() only clears the history without returning any information.
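Because the result is an ordinary list of dictionaries, it can be post-processed with plain Python. The sketch below aggregates a few of the timing entries listed above. `summarize` is a hypothetical helper, and the `sample` entries are placeholders that only mimic the documented layout; they are not real measurements and omit most keys.

```python
def summarize(history: list) -> dict:
    """Aggregate timings (in milliseconds) over a list of history entries."""
    # Only JIT-compiled kernels carry the additional compile-related entries
    jit = [e for e in history if "codegen_time" in e]
    # Kernels found in neither cache had to be compiled by the backend
    recompiled = [e for e in jit if not e["cache_hit"] and not e["cache_disk"]]
    return {
        "launches": len(history),
        "execution_ms": sum(e["execution_time"] for e in history),
        "codegen_ms": sum(e["codegen_time"] for e in jit),
        "backend_ms": sum(e["backend_time"] for e in recompiled),
        "recompiled": len(recompiled),
    }


# Placeholder entries (illustrative values, not measured data)
sample = [
    {"execution_time": 0.5, "codegen_time": 0.25, "backend_time": 8.0,
     "cache_hit": False, "cache_disk": False},
    {"execution_time": 0.25, "codegen_time": 0.25, "backend_time": 0.0,
     "cache_hit": True, "cache_disk": False},
    {"execution_time": 0.125},  # an operation without the JIT-specific entries
]
summary = summarize(sample)
print(summary)
```

In a real session, the list returned by `dr.kernel_history()` would take the place of `sample`.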
- drjit.kernel_history_clear() None¶
Clear the kernel history.
This operation clears the kernel history without returning any information about it. See drjit.kernel_history() for details.
- drjit.detail.set_leak_warnings(arg: bool, /) None¶
Dr.Jit tracks and can report leaks of various types (Python instance leaks, Dr.Jit-Core variable leaks, AD variable leaks). Since benign warnings can sometimes occur, they are disabled by default for PyPI release builds. Use this function to enable/disable them explicitly.
- drjit.detail.leak_warnings() bool¶
Query whether leak warnings are enabled. See
drjit.detail.set_leak_warnings().
- drjit.detail.llvm_version() tuple¶
- drjit.detail.cuda_version() tuple¶
Typing¶
Local memory¶
- drjit.alloc_local(dtype: type[T], size: int, value: T | None = None) Local[T]¶
Allocate a local memory buffer with type dtype and size size.
See the separate documentation section on local memory for details on the role of local memory and situations where it is useful.
- Parameters:
dtype (type) – Desired Dr.Jit array type or PyTree.
size (int) – Number of buffer elements. This value must be statically known.
value – If desired, an instance of type dtype/T can be provided here to default-initialize all entries of the buffer. Otherwise, it is left uninitialized.
- Returns:
The allocated local memory buffer
- Return type:
Local[T]
- class drjit.Local(*args, **kwargs)¶
This generic class (parameterized by an extra type T) represents a local memory buffer, that is, a temporary scratch space with support for indexed reads and writes.
See the separate documentation section on local memory for details on the role of local memory and situations where it is useful.
- __init__(self, arg: drjit.Local) None¶
Copy-constructor, creates a copy of a given local memory buffer.
- read(self, index: int | AnyArray, active: bool | AnyArray = True) T¶
Read the local memory buffer at index index and return a result of type T. An optional mask can be provided as well. Masked reads evaluate to zero.
Danger
The indices provided to this operation are unchecked by default. Attempting to read beyond the end of the buffer is undefined behavior and may crash the application, unless such reads are explicitly disabled via the active parameter. Negative indices are not permitted.
If debug mode is enabled via the drjit.JitFlag.Debug flag, Dr.Jit will insert range checks into the program. These checks disable out-of-bounds reads and furthermore report warnings to identify problematic source locations.
- write(self, value: T, index: int | AnyArray, active: bool | AnyArray = True) None¶
Store the value value at index index. An optional mask can be provided as well. Masked writes are no-ops.
Danger
The indices provided to this operation are unchecked by default. Attempting to write beyond the end of the buffer is undefined behavior and may crash the application, unless such writes are explicitly disabled via the active parameter. Negative indices are not permitted.
If debug mode is enabled via the drjit.JitFlag.Debug flag, Dr.Jit will insert range checks into the program. These checks disable out-of-bounds writes and furthermore report warnings to identify problematic source locations.
- __getitem__(self, arg: int | AnyArray, /) T¶
Perform a normal read at the given index. This is equivalent to read(.., active=True).
- __setitem__(self, arg0: int | AnyArray, arg1: T, /) None¶
Perform a normal write at the given index. This is equivalent to write(.., active=True).
- __len__(self) int¶
Return the length (number of entries) of the local memory buffer. This corresponds to the size value passed to drjit.alloc_local().
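The masking rules of read() and write() can be summarized with a small scalar model. The class below is a plain-Python stand-in (the name `LocalModel` is hypothetical, and nothing here calls Dr.Jit). It initializes entries to zero and raises an exception for invalid indices, whereas the real buffer may be left uninitialized and treats such accesses as undefined behavior.

```python
class LocalModel:
    """Scalar stand-in for a local memory buffer (illustration only)."""

    def __init__(self, size: int, value: float = 0.0):
        self._data = [value] * size

    def _check(self, index: int) -> None:
        # Dr.Jit leaves out-of-bounds and negative indices undefined;
        # this model raises an exception instead.
        if not 0 <= index < len(self._data):
            raise IndexError(index)

    def read(self, index: int, active: bool = True) -> float:
        if not active:
            return 0.0  # masked reads evaluate to zero
        self._check(index)
        return self._data[index]

    def write(self, value: float, index: int, active: bool = True) -> None:
        if not active:
            return  # masked writes are no-ops
        self._check(index)
        self._data[index] = value

    def __getitem__(self, index: int) -> float:
        return self.read(index, active=True)

    def __setitem__(self, index: int, value: float) -> None:
        self.write(value, index, active=True)

    def __len__(self) -> int:
        return len(self._data)


buf = LocalModel(4)
buf[1] = 2.5
buf.write(9.0, 100, active=False)  # out of bounds, but masked: ignored
print(buf[1], buf.read(100, active=False), len(buf))
```

The masked out-of-bounds accesses in the last lines are harmless, which mirrors the statement that such accesses must be disabled via the active parameter.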
Digital Differential Analyzer¶
The drjit.dda module provides a general implementation of a digital
differential analyzer (DDA) that steps through the intersection of a ray
segment and an N-dimensional grid, performing a custom computation at every
cell.
The drjit.dda.integrate() function builds on this functionality to compute
differentiable line integrals of bi- or trilinear interpolants stored on a
grid.
- drjit.dda.dda(ray_o: ArrayNfT, ray_d: ArrayNfT, ray_max: object, grid_res: ArrayNuT, grid_min: ArrayNfT, grid_max: ArrayNfT, func: Callable[[StateT, ArrayNuT, ArrayNfT, ArrayNfT, BoolT], Tuple[StateT, BoolT]], state: StateT, active: BoolT, mode: Literal['scalar', 'symbolic', 'evaluated', None] = None, max_iterations: int | None = None) StateT¶
N-dimensional digital differential analyzer (DDA).
This function traverses the intersection of a Cartesian coordinate grid and a specified ray or ray segment. The following snippet shows how to use it to enumerate the intersection of a grid with a single ray.
    from drjit.scalar import Array3f, Array3u, Float, Bool
    from drjit.dda import dda

    def dda_fun(state: list, index: Array3u,
                pt_in: Array3f, pt_out: Array3f,
                active: Bool) -> tuple[list, bool]:
        # Entered a grid cell, stash it in the 'state' variable
        state.append(Array3f(index))
        return state, Bool(True)

    result = dda(
        ray_o = Array3f(-.1),
        ray_d = Array3f(.1, .2, .3),
        ray_max = Float(float('inf')),
        grid_res = Array3u(10),
        grid_min = Array3f(0),
        grid_max = Array3f(1),
        func = dda_fun,
        state = [],
        active = Bool(True)
    )

    print(result)
Since all input elements are Dr.Jit arrays, everything works analogously when processing N rays and N potentially different grid configurations. The entire process can be captured symbolically.
The function takes the following arguments. Note that many of them are generic type variables (signaled by ending with a capital T). To support different dimensions and precisions, the implementation must be able to deal with various input types, which is communicated by these type variables.
- Parameters:
ray_o (ArrayNfT) – the ray origin, where the ArrayNfT type variable refers to an n-dimensional scalar or Jit-compiled floating point array.
ray_d (ArrayNfT) – the ray direction. Does not need to be normalized.
ray_max (object) – the maximum extent along the ray, which is permitted to be infinite. The value is specified as a multiple of the norm of ray_d, which is not necessarily unit-length. Must be of type dr.value_t(ArrayNfT).
grid_res (ArrayNuT) – the grid resolution, where the ArrayNuT type variable refers to a matched 32-bit integer array (i.e., ArrayNuT = dr.int32_array_t(ArrayNfT)).
grid_min (ArrayNfT) – the minimum position of the grid bounds.
grid_max (ArrayNfT) – the maximum position of the grid bounds.
func (Callable[[StateT, ArrayNuT, ArrayNfT, ArrayNfT, BoolT], tuple[StateT, BoolT]]) –
a callback that will be invoked when the DDA traverses a grid cell. It must take the following five positional arguments:
    arg0: StateT: An arbitrary state value.
    arg1: ArrayNuT: An integer array specifying the cell index along each dimension.
    arg2: ArrayNfT: The fractional position (\(\in [0, 1]^n\)) where the ray enters the current cell.
    arg3: ArrayNfT: The fractional position (\(\in [0, 1]^n\)) where the ray leaves the current cell.
    arg4: BoolT: A boolean array specifying which elements are active.
The callback should then return a tuple of type tuple[StateT, BoolT] containing
    An updated state value.
    A boolean array that can be used to exit the loop prematurely for some or all rays. The iteration stops if the associated entry of the return value equals False.
state (StateT) – an arbitrary initial state that will be passed to the callback.
active (BoolT) – an array specifying which elements of the input are active, where the BoolT type variable refers to a matched boolean array (i.e., BoolT = dr.mask_t(ray_o.x)).
mode (str | None) – The operation can operate in scalar, symbolic, or evaluated modes; see the mode argument and the documentation of drjit.while_loop() for details.
max_iterations (int | None) – Bound on the iteration count that is needed for reverse-mode differentiation. Forwarded to the max_iterations parameter of drjit.while_loop().
- Returns:
The function returns the final state value of the callback upon termination.
- Return type:
StateT
Note
Just like the Dr.Jit texture interface, the implementation uses the convention that voxel sizes and positions are specified from last to first component (e.g. (Z, Y, X)), while regular 3D positions use the opposite (X, Y, Z) order.
In particular, all ArrayNuT-typed parameters of the function and the callback use the ZYX convention, while ArrayNfT-typed parameters use the XYZ convention.
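To make the traversal itself concrete, the following plain-Python sketch enumerates the cells crossed by a single ray using a classic incremental grid walk. It is a scalar illustration under simplifying assumptions, not the Dr.Jit implementation. The generator `dda_cells` is a hypothetical name, it yields ray parameters `t_in`/`t_out` rather than fractional cell positions, and (unlike the ZYX convention described above) component k of the cell index refers to the same axis as component k of the positions.

```python
import math


def dda_cells(ray_o, ray_d, ray_max, grid_res, grid_min, grid_max):
    """Yield (cell_index, t_in, t_out) for every grid cell crossed by the ray."""
    n = len(ray_o)
    # Rescale the ray so that every grid cell becomes a unit cube
    scale = [grid_res[k] / (grid_max[k] - grid_min[k]) for k in range(n)]
    o = [(ray_o[k] - grid_min[k]) * scale[k] for k in range(n)]
    d = [ray_d[k] * scale[k] for k in range(n)]

    # Clip the ray interval [0, ray_max] against the grid bounds (slab test)
    t0, t1 = 0.0, ray_max
    for k in range(n):
        if d[k] == 0.0:
            if not 0.0 <= o[k] <= grid_res[k]:
                return  # parallel to this axis and outside of the grid
            continue
        ta, tb = -o[k] / d[k], (grid_res[k] - o[k]) / d[k]
        t0, t1 = max(t0, min(ta, tb)), min(t1, max(ta, tb))
    if t0 >= t1:
        return  # the ray misses the grid

    # Per-axis state: current cell, step direction, next crossing, spacing
    idx, step, t_next, t_delta = [], [], [], []
    for k in range(n):
        i = min(max(math.floor(o[k] + t0 * d[k]), 0), grid_res[k] - 1)
        idx.append(i)
        if d[k] > 0:
            step.append(1)
            t_next.append((i + 1 - o[k]) / d[k])
        elif d[k] < 0:
            step.append(-1)
            t_next.append((i - o[k]) / d[k])
        else:
            step.append(0)
            t_next.append(math.inf)
        t_delta.append(math.inf if d[k] == 0 else 1.0 / abs(d[k]))

    t = t0
    while True:
        # Advance along the axis whose cell boundary is reached first
        k = min(range(n), key=t_next.__getitem__)
        t_out = min(t_next[k], t1)
        if t_out > t:
            yield tuple(idx), t, t_out
            t = t_out
        if t_next[k] >= t1:
            return  # reached the end of the clipped ray interval
        idx[k] += step[k]
        if not 0 <= idx[k] < grid_res[k]:
            return  # left the grid
        t_next[k] += t_delta[k]


cells = [c for c, _, _ in dda_cells((0.25, 0.25), (1.0, 0.5), math.inf,
                                    (2, 2), (0.0, 0.0), (2.0, 2.0))]
print(cells)  # → [(0, 0), (1, 0), (1, 1)]
```

A callback-based interface like the one documented above can be layered on top by invoking a user function for each yielded cell and stopping once it returns `False`.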
- drjit.dda.integrate(ray_o: ArrayNfT, ray_d: ArrayNfT, ray_max: FloatT, grid_min: ArrayNfT, grid_max: ArrayNfT, vol: ArrayBase[Any, Any, Any, Any, Any, Any, Any], active: object = None, mode: Literal['scalar', 'symbolic', 'evaluated', None] = None) FloatT¶
Compute an analytic definite integral of a bi- or trilinear interpolant.
This function uses DDA (drjit.dda.dda()) to step along the voxels of a 2D/3D volume traversed by a finite segment or an infinite-length ray. It analytically computes and accumulates the definite integral of the interpolant in each voxel.
The input 2D/3D volume is provided using a tensor vol (e.g., of type drjit.cuda.ad.TensorXf) with an implicitly specified grid resolution vol.shape. This data volume is placed into an axis-aligned region with bounds (grid_min, grid_max).
The operation provides efficient forward and backward derivatives.
Note
Just like the Dr.Jit texture interface, the implementation uses the convention that voxel sizes and positions are specified from last to first component (e.g. (Z, Y, X)), while regular 3D positions use the opposite (X, Y, Z) order.
In particular, vol.shape uses the ZYX convention, while the ArrayNfT-typed parameters use the XYZ convention.
One important difference to the texture classes is that the interpolant is sampled at integer grid positions, whereas the Dr.Jit texture classes place values at cell centers, i.e. with a .5 fractional offset.
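The per-voxel integral can be illustrated in 2D with plain Python. Restricted to a straight line, a bilinear interpolant is a quadratic polynomial of the line parameter, so Simpson's rule reproduces its integral exactly. The sketch below relies on that observation. The helper names `bilerp` and `integrate_cell` are hypothetical, values are assumed to sit at the cell corners, and Dr.Jit may well use a different closed form internally.

```python
import math


def bilerp(v00, v10, v01, v11, x, y):
    """Bilinear interpolant on the unit cell; v10 is the value at (x=1, y=0)."""
    return (v00 * (1 - x) * (1 - y) + v10 * x * (1 - y)
            + v01 * (1 - x) * y + v11 * x * y)


def integrate_cell(values, p_in, p_out):
    """Integrate the interpolant along the straight segment p_in -> p_out."""
    def f(t):
        x = p_in[0] + t * (p_out[0] - p_in[0])
        y = p_in[1] + t * (p_out[1] - p_in[1])
        return bilerp(*values, x, y)

    length = math.dist(p_in, p_out)
    # Along a line the interpolant is quadratic in t, so Simpson's rule is exact
    return length / 6.0 * (f(0.0) + 4.0 * f(0.5) + f(1.0))


# f(x, y) = x * y integrated along the cell diagonal: sqrt(2) / 3
diag = integrate_cell((0.0, 0.0, 0.0, 1.0), (0.0, 0.0), (1.0, 1.0))
print(diag, math.sqrt(2) / 3)
```

Summing such per-cell contributions over the cells visited by a grid traversal yields the line integral over the whole volume. The same quadrature remains exact in 3D, where a trilinear interpolant becomes a cubic along the ray.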
Optimizers¶
The drjit.opt module implements basic infrastructure for
gradient-based optimization and adaptive mixed-precision training.
- class drjit.opt.Optimizer(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, mask_updates: bool = False, promote_fp16: bool = True)¶
Gradient-based optimizer base class
This class resembles a Python dictionary enabling retrieval and assignment of parameter values by name. It furthermore implements common functionality used by gradient-based optimizers, such as stochastic gradient descent (SGD) and Adam (Adam).
Typically, optimizers are used as follows:
    # Create an optimizer object
    opt = Adam(lr=1e-3)

    # Register one or more parameters
    opt["x"] = Float(1, 2, 3)

    # Alternative syntax
    # opt = Adam(
    #     lr=1e-3,
    #     params={'x': Float(1, 2, 3)}
    # )

    # For some number of iterations..
    for i in range(1000):
        # Fetch the current parameter value
        x = opt["x"]

        # Compute a scalar loss value and backpropagate
        loss = my_optimization_task(x)
        dr.backward(loss)

        # Take a gradient step (details depend on the choice of optimizer)
        opt.step()
Following opt.step(), some applications may need to project the parameter values onto a valid subset, e.g.:
    # Values outside of the range [0, 1] are not permitted
    opt['x'] = dr.clip(opt['x'], 0, 1)
You may register additional parameters or delete existing ones during an optimization.
Note
There are several notable differences compared to optimizers in PyTorch:
PyTorch keeps references to the original parameters and manipulates them in-place. The user must call opt.zero_grad() to clear out remaining gradients from the last iteration.
Dr.Jit optimizers own the parameters being optimized. The function opt.step() only updates this internal set of parameters without causing changes elsewhere.
Compared to PyTorch, an optimization loop therefore involves some boilerplate code (e.g., my_param = opt["my_param"]) to fetch parameter values and use them to compute a differentiable objective.
In general, it is recommended that you optimize Dr.Jit code using Dr.Jit optimizers, rather than combining frameworks with differentiable bridges (e.g., via the @dr.wrap decorator), which can add significant inter-framework communication and bookkeeping overheads.
Note
It is tempting to print or plot the loss decay during the optimization. However, doing so forces the CPU to wait for asynchronous execution to finish on the device (e.g., GPU), impeding the system’s ability to schedule new work during this time. In other words, such a seemingly minor detail can actually have a detrimental impact on device utilization. As a workaround, consider printing asynchronously using dr.print(loss, symbolic=True).
- step(*, eval: bool = True, grad_scale: float | ArrayBase | None = None, active: ArrayBase | None = None) None¶
Take a gradient step.
This function visits each registered parameter in turn and
extracts the associated gradient (
.grad),takes a step using an optimizer-dependent update rule, and
reattaches the resulting new state with the AD graph.
- Parameters:
eval (bool) – If eval=True (the default), the system will explicitly evaluate the resulting parameter state, causing a kernel launch. Set eval=False if you wish to perform further computation that should be fused into the optimizer step.
grad_scale (float | drjit.ArrayBase) – Use this parameter to scale all gradients by a custom amount. Dr.Jit uses this parameter for automatic mixed-precision training.
active (drjit.ArrayBase | None) – This parameter can be used to pass a 1-element boolean mask. A value of Bool(False) disables the optimizer state update. Dr.Jit uses this parameter for automatic mixed-precision training.
- reset(key: str | None = None) None¶
Reset the internal state (e.g., momentum, adaptive learning rate, etc.) associated with the parameter key. When key=None, the implementation resets the state of all parameters.
- update(params: Mapping[str, ArrayBase] | None = None, **args: ArrayBase) None¶
Overwrite multiple parameter values at once.
This function simply calls __setitem__() multiple times. Like dict.update(), update() supports two calling conventions:
    # Update using a dictionary
    opt.update({'key_1': value_1, 'key_2': value_2})

    # Update using variable keyword arguments
    opt.update(key_1=value_1, key_2=value_2)
- learning_rate(key: str | None = None) float | ArrayBase | None¶
Return the learning rate (globally, or of a specific parameter).
When key is provided, the function returns the associated parameter-specific learning rate (or None, if no learning rate was set for this parameter).
When key is not provided, the function returns the default learning rate.
- set_learning_rate(value: float | ArrayBase | Mapping[str, float | ArrayBase | None] | None = None, /, **kwargs: float | ArrayBase | None) None¶
Set the learning rate (globally, or of a specific parameter).
This function can be used as follows:
To modify the default learning rate of the optimizer:
    opt = Adam(lr=1e-3)
    # ... some time later:
    opt.set_learning_rate(1e-4)
To modify the learning rate of a specific parameter:
    opt = Adam(lr=1e-3, params={'x': x, 'y': y})
    opt.set_learning_rate({'y': 1e-4})

    # Alternative calling convention
    opt.set_learning_rate(y=1e-4)
Note that once the learning rate of a specific parameter is set, it always takes precedence over the global setting. You must remove the parameter-specific setting to return to the global default:
    opt.set_learning_rate(y=None)
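The precedence rule can be modeled in a few lines of plain Python. The class below is a hypothetical stand-in (`LearningRates` and its `effective` method are not part of Dr.Jit). It only mimics the documented lookup behavior: a parameter-specific rate wins over the default until it is removed by passing `None`.

```python
class LearningRates:
    """Hypothetical model of the documented learning rate lookup rules."""

    def __init__(self, lr):
        self.default = lr
        self.per_param = {}

    def set_learning_rate(self, value=None, /, **kwargs):
        if isinstance(value, dict):
            kwargs = {**value, **kwargs}  # dictionary calling convention
        elif value is not None:
            self.default = value          # change the global default
        for key, lr in kwargs.items():
            if lr is None:
                self.per_param.pop(key, None)  # back to the global default
            else:
                self.per_param[key] = lr

    def learning_rate(self, key=None):
        if key is None:
            return self.default
        return self.per_param.get(key)  # None if nothing specific was set

    def effective(self, key):
        """Rate applied to a parameter: a specific setting wins if present."""
        return self.per_param.get(key, self.default)


lrs = LearningRates(1e-3)
lrs.set_learning_rate(y=1e-4)
lrs.set_learning_rate(1e-2)       # the new default does not affect 'y'
print(lrs.effective("x"), lrs.effective("y"))
lrs.set_learning_rate(y=None)     # remove the parameter-specific setting
print(lrs.effective("y"))
```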
- __setitem__(key: str, value: ArrayBase, /)¶
Overwrite a parameter value or register a new parameter.
Supported parameter types include:
Differentiable Dr.Jit arrays and nested arrays
Differentiable Dr.Jit tensors
Special array types (matrices, quaternions, complex numbers). These will be optimized component-wise.
In contrast to assignment in a regular dictionary, this function conceptually creates a copy of the input parameter (conceptual because the Copy-On-Write optimization avoids an actual on-device copy).
When key refers to a known parameter, the optimizer will overwrite it with value. In doing so, it will preserve any associated optimizer state, such as momentum, adaptive step size, etc. When the new parameter value is substantially different (e.g., as part of a different optimization run), the previous momentum value may be meaningless, in which case a call to reset() is advisable.
When the new parameter value’s .shape differs from the current setting, the implementation automatically calls reset() to discard the associated optimizer state.
When key does not refer to a known parameter, the optimizer will register it. Note that only differentiable parameters are supported; incompatible types will raise a TypeError.
- __delitem__(key: str, /) None¶
Remove a parameter from the optimizer.
- __contains__(key: object, /) bool¶
Check whether the optimizer contains a parameter with the name
key.
- __len__() int¶
Return the number of registered parameters.
- keys() Iterator[str]¶
Return an iterator traversing the names of registered parameters.
- class drjit.opt.SGD(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, momentum: float = 0.0, nesterov: bool = False, mask_updates: bool = False, promote_fp16: bool = True)¶
Implements basic stochastic gradient descent (SGD) with a fixed learning rate and, optionally, momentum (0.9 is a typical value for the momentum parameter).
The default initialization (momentum=0) uses the following update equation:
\[\begin{align*} \mathbf{p}_{i+1} &= \mathbf{p}_i - \eta\cdot\mathbf{g}_{i+1}, \end{align*}\]
where \(\mathbf{p}_i\) is the parameter value at iteration \(i\), \(\mathbf{g}_i\) denotes the associated gradient, and \(\eta\) is the learning rate.
Momentum-based SGD (with momentum>0, nesterov=False) uses the update equation:
\[\begin{align*} \mathbf{v}_{i+1} &= \mu\cdot\mathbf{v}_i + \mathbf{g}_{i+1}\\ \mathbf{p}_{i+1} &= \mathbf{p}_i - \eta \cdot \mathbf{v}_{i+1}, \end{align*}\]
where \(\mathbf{v}\) is the velocity and \(\mu\) is the momentum parameter. Nesterov-style SGD (nesterov=True) switches to the following update rule:
\[\begin{align*} \mathbf{v}_{i+1} &= \mu \cdot \mathbf{v}_i + \mathbf{g}_{i+1}\\ \mathbf{p}_{i+1} &= \mathbf{p}_i - \eta \cdot (\mathbf{g}_{i+1} + \mu \mathbf{v}_{i+1}). \end{align*}\]
Some frameworks implement variations of the above equations. The code in Dr.Jit was designed to reproduce the behavior of torch.optim.SGD.
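The three update rules above translate directly into code. The following plain-Python sketch applies them to a single scalar entry. `sgd_step` is a hypothetical helper written for illustration and does not use Dr.Jit arrays or the optimizer classes.

```python
def sgd_step(p, g, v, lr, momentum=0.0, nesterov=False):
    """Return the updated (parameter, velocity) pair for one scalar entry."""
    if momentum == 0.0:
        return p - lr * g, v                    # plain gradient descent
    v = momentum * v + g                        # velocity update
    if nesterov:
        return p - lr * (g + momentum * v), v   # Nesterov-style step
    return p - lr * v, v                        # classic momentum step


# Two momentum steps with a constant gradient of 0.5
p, v = 1.0, 0.0
for _ in range(2):
    p, v = sgd_step(p, 0.5, v, lr=0.5, momentum=0.5)
print(p, v)  # → 0.375 0.75
```

With a constant gradient, the velocity grows from one step to the next, so the second step is larger than the first.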
- __init__(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, momentum: float = 0.0, nesterov: bool = False, mask_updates: bool = False, promote_fp16: bool = True)¶
- Parameters:
lr (float | drjit.ArrayBase) – Learning rate parameter. You may want to try different values (e.g. powers of two) to find the best setting for a specific problem. Use Optimizer.set_learning_rate() to later adjust this value globally, or for specific parameters.
momentum (float) – The momentum factor as described above. Larger values will cause past gradients to persist for a longer amount of time.
mask_updates (bool) – Mask updates to zero-valued gradient components? See Optimizer.__init__() for details on this parameter.
promote_fp16 (bool) – Promote half-precision variables to single-precision internal storage? See Optimizer.__init__() for details on this parameter.
params (Mapping[str, drjit.ArrayBase] | None) – Optional dictionary-like object containing an initial set of parameters.
- class drjit.opt.RMSProp(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, alpha: float = 0.99, epsilon: float = 1e-08, mask_updates: bool = False, promote_fp16: bool = True)¶
Implements the RMSProp optimizer explained in lecture notes by G. Hinton.
RMSProp scales the learning rate by the reciprocal of a running average of the magnitude of past gradients:
\[\begin{align*} \mathbf{m}_{i+1} &= \alpha \cdot \mathbf{m}_i + (1-\alpha)\cdot\mathbf{g}_{i+1}^2\\ \mathbf{p}_{i+1} &= \mathbf{p}_i - \frac{\eta}{\sqrt{\mathbf{m}_{i+1}+\varepsilon}}\, \mathbf{g}_{i+1}, \end{align*}\]
where \(\mathbf{p}_i\) is the parameter value at iteration \(i\), \(\mathbf{g}_i\) denotes the associated gradient, \(\mathbf{m}_i\) accumulates the second moment, \(\eta\) is the learning rate, and \(\varepsilon\) is a small number to avoid division by zero.
The implementation reproduces the behavior of torch.optim.RMSprop.
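As a scalar illustration of the update equations written above, the following plain-Python sketch performs one RMSProp step. `rmsprop_step` is a hypothetical helper that follows the formulas exactly as they appear in this section. It is not the Dr.Jit implementation.

```python
import math


def rmsprop_step(p, g, m, lr, alpha=0.99, epsilon=1e-8):
    """Return the updated (parameter, second moment) pair for one entry."""
    m = alpha * m + (1 - alpha) * g * g        # running average of g^2
    p = p - lr / math.sqrt(m + epsilon) * g    # gradient scaled by 1/sqrt(m)
    return p, m


# Gradients of very different magnitude lead to first steps of similar size
for g in (1e-3, 1.0, 1e3):
    p, m = rmsprop_step(0.0, g, 0.0, lr=0.01)
    print(g, p)
```

Because the first step divides the gradient by roughly `sqrt(1 - alpha) * |g|`, its size is largely independent of the gradient magnitude.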
- __init__(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, alpha: float = 0.99, epsilon: float = 1e-08, mask_updates: bool = False, promote_fp16: bool = True)¶
Construct an RMSProp optimizer instance.
- Parameters:
lr (float | drjit.ArrayBase) – Learning rate parameter. You may want to try different values (e.g. powers of two) to find the best setting for a specific problem. Use Optimizer.set_learning_rate() to later adjust this value globally, or for specific parameters.
alpha (float) – Weight of the second-order moment exponential moving average (EMA). Values approaching 1 will cause past gradients to persist for a longer amount of time.
mask_updates (bool) – Mask updates to zero-valued gradient components? See Optimizer.__init__() for details on this parameter.
promote_fp16 (bool) – Promote half-precision variables to single-precision internal storage? See Optimizer.__init__() for details on this parameter.
params (Mapping[str, drjit.ArrayBase] | None) – Optional dictionary-like object containing an initial set of parameters.
- class drjit.opt.Adam(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, beta_1: float = 0.9, beta_2: float = 0.999, epsilon: float = 1e-08, mask_updates: bool = False, promote_fp16: bool = True, uniform: bool = False, amsgrad: bool = False)¶
This class implements the Adam optimizer as presented in the paper Adam: A Method for Stochastic Optimization by Kingma and Ba, ICLR 2015.
Adam effectively combines momentum (as in SGD with momentum>0) with the adaptive magnitude-based scale factor from RMSProp. To do so, it maintains two exponential moving averages (EMAs) per parameter: \(\mathbf{m}_i\) for the first moment, and \(\mathbf{v}_i\) for the second moment. This triples the memory usage and should be considered when optimizing very large representations.
The method uses the following update equation:
\[\begin{align*} \mathbf{m}_{i+1} &= \beta_1 \cdot \mathbf{m}_i + (1-\beta_1)\cdot\mathbf{g}_{i+1}\\ \mathbf{v}_{i+1} &= \beta_2 \cdot \mathbf{v}_i + (1-\beta_2)\cdot\mathbf{g}_{i+1}^2\\ \mathbf{p}_{i+1} &= \mathbf{p}_i - \eta \frac{\sqrt{1-\beta_2^{i+1}}}{1-\beta_1^{i+1}} \frac{\mathbf{m}_{i+1}}{\sqrt{\mathbf{v}_{i+1}}+\varepsilon}, \end{align*}\]
where \(\mathbf{p}_i\) is the parameter value at iteration \(i\), \(\mathbf{g}_i\) denotes the associated gradient, \(\eta\) is the learning rate, and \(\varepsilon\) is a small number to avoid division by zero.
The scale factor \(\frac{\sqrt{1-\beta_2^{i+1}}}{1-\beta_1^{i+1}}\) corrects for the zero-valued initialization of the moment accumulators \(\mathbf{m}_i\) and \(\mathbf{v}_i\) at \(i=0\).
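The effect of the bias-correcting scale factor is easiest to see in a scalar example. The following plain-Python sketch implements the update equations above for a single entry. `adam_step` is a hypothetical helper, and `i` denotes the zero-based iteration index used in the subscripts.

```python
import math


def adam_step(p, g, m, v, i, lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8):
    """Return the updated (parameter, first moment, second moment) triple."""
    m = beta_1 * m + (1 - beta_1) * g          # first moment EMA
    v = beta_2 * v + (1 - beta_2) * g * g      # second moment EMA
    # Correct for the zero-valued initialization of both accumulators
    scale = math.sqrt(1 - beta_2 ** (i + 1)) / (1 - beta_1 ** (i + 1))
    p = p - lr * scale * m / (math.sqrt(v) + epsilon)
    return p, m, v


# Thanks to the correction, the very first step has magnitude ~lr
p, m, v = adam_step(1.0, 3.0, 0.0, 0.0, i=0, lr=0.1)
print(p)
```

Without the scale factor, the first update would be several times larger than `lr`, since with the default parameters the zero-initialized accumulators bias \(\mathbf{m}\) toward zero far more strongly than \(\sqrt{\mathbf{v}}\).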
This class also implements several extensions that are turned off by default. See the descriptions of the mask_updates, uniform, and amsgrad parameters below.
- __init__(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, beta_1: float = 0.9, beta_2: float = 0.999, epsilon: float = 1e-08, mask_updates: bool = False, promote_fp16: bool = True, uniform: bool = False, amsgrad: bool = False)¶
Construct a new Adam optimizer object. The default parameters replicate the behavior of the original method.
- Parameters:
lr (float | drjit.ArrayBase) – Learning rate parameter. You may want to try different values (e.g. powers of two) to find the best setting for a specific problem. Use Optimizer.set_learning_rate() to later adjust this value globally, or for specific parameters.
beta_1 (float) – Weight of the first-order moment exponential moving average (EMA). Values approaching 1 will cause past gradients to persist for a longer amount of time.
beta_2 (float) – Weight of the second-order moment EMA. Values approaching 1 will cause past gradients to persist for a longer amount of time.
uniform (bool) – If enabled, the optimizer will use the UniformAdam variant of Adam [Nicolet et al. 2021], where the update rule uses the maximum of the second moment estimates at the current step instead of the per-element second moments.
amsgrad (bool) – If enabled, the optimizer will use the AMSGrad variant [Reddi et al. 2018], which maintains the maximum of all past squared gradient values and uses that maximum instead of the exponential moving average to normalize the gradient. This can help with convergence in some cases where Adam fails.
mask_updates (bool) – Mask updates to zero-valued gradient components? See Optimizer.__init__() for details on this parameter.
promote_fp16 (bool) – Promote half-precision variables to single-precision internal storage? See Optimizer.__init__() for details on this parameter.
params (Mapping[str, drjit.ArrayBase] | None) – Optional dictionary-like object containing an initial set of parameters.
- class drjit.opt.AdamW(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, beta_1: float = 0.9, beta_2: float = 0.999, epsilon: float = 1e-08, weight_decay: float = 0.01, mask_updates: bool = False, promote_fp16: bool = True, uniform: bool = False, amsgrad: bool = False)¶
This class implements the AdamW optimizer as presented in the paper Decoupled Weight Decay Regularization by Loshchilov and Hutter, ICLR 2019.
AdamW improves upon Adam by decoupling weight decay from gradient-based optimization. Instead of applying weight decay through L2 regularization in the loss function (which interacts poorly with adaptive learning rates), AdamW applies weight decay directly to the parameters.
The method uses the following update equations:
\[\begin{align*} \mathbf{m}_{i+1} &= \beta_1 \cdot \mathbf{m}_i + (1-\beta_1)\cdot\mathbf{g}_{i+1}\\ \mathbf{v}_{i+1} &= \beta_2 \cdot \mathbf{v}_i + (1-\beta_2)\cdot\mathbf{g}_{i+1}^2\\ \mathbf{p}_{i+1} &= \mathbf{p}_i - \eta \frac{\sqrt{1-\beta_2^{i+1}}}{1-\beta_1^{i+1}} \frac{\mathbf{m}_{i+1}}{\sqrt{\mathbf{v}_{i+1}}+\varepsilon} - \eta \lambda \mathbf{p}_i, \end{align*}\]
where \(\mathbf{p}_i\) is the parameter value at iteration \(i\), \(\mathbf{g}_i\) denotes the associated gradient, \(\eta\) is the learning rate, \(\varepsilon\) is a small number to avoid division by zero, and \(\lambda\) is the weight decay coefficient.
The key difference from Adam is the final term \(- \eta \lambda \mathbf{p}_i\), which applies weight decay directly to parameters independently of gradients. This leads to better generalization, especially when using adaptive learning rates.
The scale factor \(\frac{\sqrt{1-\beta_2^{i+1}}}{1-\beta_1^{i+1}}\) corrects for the zero-valued initialization of the moment accumulators.
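The update equations above can be traced by hand for a single scalar parameter. The following is a minimal sketch in plain Python, not Dr.Jit's implementation; the function name `adamw_step` and its argument layout are hypothetical and simply mirror the symbols in the formulas.

```python
import math

def adamw_step(p, g, m, v, i, lr=1e-3, beta_1=0.9, beta_2=0.999,
               epsilon=1e-8, weight_decay=0.01):
    """Apply one AdamW update to a scalar parameter.

    `i` is the zero-based iteration index, so the bias correction uses i + 1.
    Returns the updated (p, m, v).
    """
    # Exponential moving averages of the gradient and the squared gradient
    m = beta_1 * m + (1 - beta_1) * g
    v = beta_2 * v + (1 - beta_2) * g * g
    # Correction for the zero-valued initialization of m and v
    scale = math.sqrt(1 - beta_2 ** (i + 1)) / (1 - beta_1 ** (i + 1))
    # Adaptive gradient step plus decoupled weight decay on the old parameter
    p_new = p - lr * scale * m / (math.sqrt(v) + epsilon) - lr * weight_decay * p
    return p_new, m, v

p, m, v = 1.0, 0.0, 0.0
p, m, v = adamw_step(p, g=0.5, m=m, v=v, i=0)
```

In the first iteration the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly `lr` plus the decay term `lr * weight_decay * p`.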
- __init__(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, beta_1: float = 0.9, beta_2: float = 0.999, epsilon: float = 1e-08, weight_decay: float = 0.01, mask_updates: bool = False, promote_fp16: bool = True, uniform: bool = False, amsgrad: bool = False)¶
Construct a new AdamW optimizer object with decoupled weight decay.
- Parameters:
lr (float | drjit.ArrayBase) – Learning rate parameter. You may want to try different values (e.g. powers of two) to find the best setting for a specific problem. Use Optimizer.set_learning_rate() to later adjust this value globally, or for specific parameters.
beta_1 (float) – Weight of the first-order moment exponential moving average (EMA). Values approaching 1 will cause past gradients to persist for a longer amount of time.
beta_2 (float) – Weight of the second-order moment EMA. Values approaching 1 will cause past gradients to persist for a longer amount of time.
weight_decay (float) – Weight decay coefficient for L2 regularization. Unlike Adam, this is applied directly to parameters rather than gradients, providing better regularization with adaptive learning rates.
uniform (bool) – If enabled, the optimizer will use the UniformAdam variant of Adam [Nicolet et al. 2021], where the update rule uses the maximum of the second moment estimates at the current step instead of the per-element second moments.
amsgrad (bool) – If enabled, the optimizer will use the AMSGrad variant [Reddi et al. 2018], which maintains the maximum of all past squared gradient values and uses that maximum instead of the exponential moving average to normalize the gradient. This can help with convergence in some cases where Adam fails.
mask_updates (bool) – Mask updates to zero-valued gradient components? See Optimizer.__init__() for details on this parameter.
promote_fp16 (bool) – Promote half-precision variables to single-precision internal storage? See Optimizer.__init__() for details on this parameter.
params (Mapping[str, drjit.ArrayBase] | None) – Optional dictionary-like object containing an initial set of parameters.
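The `uniform` and `amsgrad` options described above both change which second-moment value normalizes the update. The sketch below illustrates the two ideas on plain Python lists; the helper names are hypothetical, and how Dr.Jit combines these options internally is not shown here.

```python
def amsgrad_second_moment(v_max, v):
    """AMSGrad: keep the element-wise maximum of all second moments seen so far."""
    return [max(a, b) for a, b in zip(v_max, v)]

def uniform_second_moment(v):
    """UniformAdam: replace every entry by the largest second moment of this step."""
    peak = max(v)
    return [peak] * len(v)

v_history_max = [0.4, 0.1, 0.3]   # running maximum from earlier iterations
v_current = [0.2, 0.5, 0.3]       # second-moment EMA at the current iteration

v_ams = amsgrad_second_moment(v_history_max, v_current)
v_uni = uniform_second_moment(v_current)
```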
- class drjit.opt.Muon(lr, params=None, *, momentum=0.95, nesterov=True, ns_steps=5, weight_decay=0.0, mask_updates=False, promote_fp16=True)¶
Muon - MomentUm Orthogonalized by Newton-Schulz.
Muon is specifically designed for the hidden weights of a neural network. Other parameters (input embeddings, output layers, scalars, biases) should be optimized with a standard method such as AdamW.
The optimizer runs SGD with momentum and then post-processes the (optionally Nesterov-mixed) update by replacing each 2D parameter’s update with (approximately) the nearest orthogonal matrix. The orthogonalization uses a quintic Newton-Schulz iteration whose coefficients maximize the slope at zero.
- Reference:
Keller Jordan, “Muon: An optimizer for hidden layers in neural networks,” https://kellerjordan.github.io/posts/muon/, 2024.
This class is a Dr.Jit port of the reference PyTorch implementation by Keller Jordan et al.
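To make the orthogonalization step more concrete, here is a small sketch of the quintic Newton-Schulz iteration on nested Python lists. It follows the structure and coefficients of the reference PyTorch implementation (normalize by the Frobenius norm, then iterate a fixed odd polynomial); the helper names `matmul`, `transpose`, and `newton_schulz` are hypothetical, and this is not the Dr.Jit code.

```python
import math

# Quintic coefficients (a, b, c) used by the reference Muon implementation
NS_COEFFS = (3.4445, -4.7750, 2.0315)

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def newton_schulz(g, steps=5, eps=1e-7):
    """Push the singular values of `g` toward 1 (approximate orthogonalization)."""
    a, b, c = NS_COEFFS
    flip = len(g) > len(g[0])              # iterate on the wide orientation
    x = transpose(g) if flip else [row[:] for row in g]
    norm = math.sqrt(sum(v * v for row in x for v in row))
    x = [[v / (norm + eps) for v in row] for row in x]
    n = len(x)
    for _ in range(steps):
        s = matmul(x, transpose(x))        # S = X X^T
        s2 = matmul(s, s)                  # S^2
        poly = [[b * s[i][j] + c * s2[i][j] for j in range(n)] for i in range(n)]
        px = matmul(poly, x)
        # X <- a X + (b S + c S^2) X
        x = [[a * x[i][j] + px[i][j] for j in range(len(x[0]))] for i in range(n)]
    return transpose(x) if flip else x

update = newton_schulz([[3.0, 0.0], [0.0, 0.5]])
```

The iteration does not converge to exactly orthogonal output: the singular values end up in a band around 1, which is the intended trade-off for a fast ramp-up of small singular values.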
When handed a drjit.nn.Module via update(), Muon silently drops every non-2D entry (biases, scalars, etc.); these are meant to be optimized by a standard rule such as AdamW attached to the same module. The two optimizers then cover disjoint subsets of net.keys(), and a single net.update(opt) call per optimizer at the top of the training loop keeps the network parameters synchronized:
    muon = Muon(lr=0.02)
    muon.update(net)
    adamw = AdamW(lr=1e-3)
    adamw.update({k: net[k] for k in net if len(net[k].shape) == 1})
    for i in range(n_iter):
        net.update(muon)
        net.update(adamw)
        y = net(x_tensor)
        loss = ...
        dr.backward(loss)
        muon.step()
        adamw.step()
For drjit.nn.pack()-based cooperative-vector networks, hand Muon the unpacked module and call drjit.nn.pack() inside the training loop. The matrix-pack path is differentiable, so gradients on the packed buffer flow back through the layout transform to the per-layer 2D weight tensors and from there into Muon’s single-precision state. See the neural network documentation for the full side-by-side flows.
- __init__(lr, params=None, *, momentum=0.95, nesterov=True, ns_steps=5, weight_decay=0.0, mask_updates=False, promote_fp16=True)¶
- Parameters:
lr (float | drjit.ArrayBase) – Learning rate, interpreted as the target spectral norm per update. Use Optimizer.set_learning_rate() to later adjust this value globally, or for specific parameters.
momentum (float) – Momentum factor used to compute the EMA of past gradients. Must lie in [0, 1). A value of 0.95 is usually fine.
nesterov (bool) – If enabled, the Nesterov-mixed combination of the gradient and the momentum EMA is orthogonalized instead of the raw momentum EMA.
ns_steps (int) – Number of Newton-Schulz iterations used to orthogonalize the update. Five steps are sufficient in practice.
weight_decay (float) – Decoupled AdamW-style weight decay coefficient. The update rule applies -lr * weight_decay * p in addition to the orthogonalized step.
mask_updates (bool) – Mask updates to zero-valued gradient components? See Optimizer.__init__() for details on this parameter.
promote_fp16 (bool) – Promote half-precision variables to single-precision internal storage? See Optimizer.__init__() for details on this parameter.
params (Mapping[str, drjit.ArrayBase] | None) – Optional dictionary-like object containing an initial set of parameters. Each entry must be a 2D tensor.
- class drjit.opt.GradScaler(init_scale: float = 65536.0, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, debug: bool = False)¶
Gradient scaler for automatic mixed-precision training.
It is sometimes necessary to perform some part of a computation using lower precision (e.g., drjit.auto.Float16) to improve storage and runtime efficiency. One issue with such lower-precision arithmetic is that gradients tend to underflow to zero, which can break the optimization.
The GradScaler class implements a strategy for automatic mixed precision (AMP) training to prevent such numerical issues. A comprehensive overview of AMP can be found here.
AMP in Dr.Jit works as follows:
Construct a GradScaler instance prior to the optimization loop. Suppose it is called scaler.
Invoke scaler.scale(loss) to scale the optimization loss by a suitable value to prevent gradient underflow, and then propagate derivatives using drjit.backward() or a similar AD function.
Replace the call to opt.step() with scaler.step(opt), which removes the scaling prior to the gradient step.
Concretely, this might look as follows:
    opt = Adam(lr=1e-3)
    opt['my_param'] = 0
    scaler = GradScaler()
    for i in range(1000):
        my_param = opt['my_param']
        loss = my_func(my_param)
        dr.backward(scaler.scale(loss))
        scaler.step(opt)
A large scale factor can also cause the opposite problem: gradient overflow, which manifests in the form of infinity and NaN-valued gradient components. GradScaler automatically detects this, skips the optimizer step, and decreases the scale factor.
The implementation starts with a relatively aggressive scale factor that is likely to cause overflows, hence it may appear that the optimizer is initially stagnant for a few iterations. This is expected.
- __init__(init_scale: float = 65536.0, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, debug: bool = False)¶
The Dr.Jit GradScaler class follows the high-level API of pytorch.amp.GradScaler (https://pytorch.org/docs/stable/notes/amp_examples.html).
- Parameters:
init_scale (float) – The initial scale factor.
growth_factor (float) – When growth_interval optimization steps have taken place without overflows, GradScaler will begin to progressively increase the scale by multiplying it with growth_factor at every iteration until an overflow is again detected.
backoff_factor (float) – When an overflow issue is detected, GradScaler will decrease the scale factor by multiplying it with backoff_factor.
growth_interval (int) – A large iteration count, following which it can be helpful to begin exploring larger scale factors.
debug (bool) – Print a debug message whenever the scale changes. This synchronizes with the device after each step, which will have a negative effect on optimization performance.
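The interplay of `init_scale`, `growth_factor`, `backoff_factor`, and `growth_interval` can be sketched with a toy scaler that works on plain Python floats. The class `ToyGradScaler` and its `apply_update` callback are hypothetical stand-ins; the growth rule follows one reading of the parameter descriptions above (grow at every step once `growth_interval` overflow-free steps have passed) and may differ from the actual implementation.

```python
import math

class ToyGradScaler:
    """Dynamic loss scaling on plain Python floats (illustrative only)."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale_factor = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive steps without overflow

    def scale(self, loss):
        """Multiply the loss by the current scale factor."""
        return loss * self.scale_factor

    def step(self, scaled_grads, apply_update):
        """Unscale the gradients and apply them, or skip the step on overflow."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # Overflow: skip the update and back off
            self.scale_factor *= self.backoff_factor
            self.good_steps = 0
            return False
        apply_update([g / self.scale_factor for g in scaled_grads])
        self.good_steps += 1
        if self.good_steps >= self.growth_interval:
            self.scale_factor *= self.growth_factor
        return True

scaler = ToyGradScaler(growth_interval=2)
applied = []
scaler.step([scaler.scale(0.5)], applied.append)    # regular step
scaler.step([scaler.scale(0.25)], applied.append)   # regular step, scale grows
scaler.step([float("inf")], applied.append)         # overflow, step skipped
```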
Cooperative Vectors¶
The drjit.nn module provides infrastructure to implement small
neural networks and revolves around the notion of cooperative vectors that
facilitate code generation of matrix-vector products. Please see the separate
documentation section for an introduction.
- class drjit.nn.CoopVec(*args, **kwargs)¶
- __init__(self, *args: Unpack[Tuple[Union[drjit.ArrayBase[SelfT, SelfCpT, ValT, ValCpT, T, PlainT, MaskT], float, int], ...]]) None¶
The constructor accepts a variable number of arguments including Dr.Jit arrays, scalar Python integers and floating point values, and PyTrees. It flattens this input into a list of vector components.
At least one Jit-compiled array must be provided as input so that Dr.Jit can infer the cooperative vector’s element type. An exception will be raised if the input contains Dr.Jit arrays of inconsistent scalar types (e.g., drjit.cuda.Array2f and drjit.cuda.UInt).
- __len__(self) int¶
- __repr__(self) str¶
- property index: int¶
Stores the Dr.Jit variable index of the cooperative vector.
- property type: type[drjit.ArrayBase]¶
Stores the element type
- class drjit.nn.MatrixView(*args, **kwargs)¶
The drjit.nn.MatrixView class provides a pointer into a buffer along with shape and type metadata.
Dr.Jit uses views to tightly pack sequences of matrices and bias vectors into a joint buffer, and to preserve information about the underlying data type and layout. The __getitem__() function can be used to slice a view into smaller sub-blocks.
The typical process is to pack a PyTree of weight matrices and bias vectors via drjit.nn.pack() into an inference or training-optimal representation. The returned views can then be passed to drjit.nn.matvec().
- __getitem__(self, arg: Union[int, slice, Tuple[Union[int, slice], Union[int, slice]]]) MatrixView¶
- property dtype: drjit.VarType¶
Scalar type underlying the view.
- property shape: tuple[int, int]¶
Number of rows/columns. Vectors are stored as matrices with one column.
- property layout: MatrixLayout¶
One of several possible matrix layouts (training/inference-optimal and row-major).
- property stride: int¶
Row stride (in # of elements)
- property size: int¶
Total number of elements
- property offset: int¶
Offset of the matrix data within buffer (counted in # of elements)
- property transpose: bool¶
The MatrixView.T property flips this flag (all other values stay unchanged).
- property buffer: drjit.ArrayBase¶
The underlying buffer, which may contain additional matrices/vectors besides the data referenced by the MatrixView.
- property T: MatrixView¶
Return a transposed view.
- property grad: MatrixView¶
Return an analogous view of the gradient.
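The relationship between `shape`, `stride`, `offset`, and `size` is easiest to see for a row-major buffer. The sketch below models only that bookkeeping with a hypothetical `RowMajorView` dataclass; it is not the `MatrixView` class, and the optimal layouts used for training and inference are opaque and do not follow this simple addressing.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class RowMajorView:
    """Metadata of a row-major matrix stored inside a larger flat buffer."""
    shape: tuple   # (rows, cols)
    stride: int    # distance between consecutive rows, in elements
    offset: int    # position of element (0, 0) in the buffer, in elements

    @property
    def size(self):
        return self.shape[0] * self.shape[1]

    def block(self, rows, cols=slice(None)):
        """Return the metadata of a contiguous sub-block (unit steps only)."""
        r0, r1, r_step = rows.indices(self.shape[0])
        c0, c1, c_step = cols.indices(self.shape[1])
        if r_step != 1 or c_step != 1:
            raise ValueError("only unit-step slices are supported in this sketch")
        new_shape = (max(r1 - r0, 0), max(c1 - c0, 0))
        # The sub-block starts r0 rows down and c0 columns to the right
        return RowMajorView(new_shape, self.stride, self.offset + r0 * self.stride + c0)

full = RowMajorView(shape=(64, 16), stride=16, offset=0)
lower = full.block(slice(32, 64))
```

Slicing changes only the metadata; the underlying buffer is shared, which is why views can be re-packed without intermediate copies.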
- drjit.nn.view(arg: drjit.ArrayBase, /) drjit.nn.MatrixView¶
Convert a Dr.Jit array or tensor into a view.
This function simply returns a view of the original tensor without transforming the underlying representation. This is useful to:
Use drjit.nn.matvec() with a row-major matrix layout (which, however, is not recommended, since this can be significantly slower compared to matrices in inference/training-optimal layouts).
Slice a larger matrix into sub-blocks before passing them to drjit.nn.pack() (which also accepts views as inputs). This is useful when several matrices are already packed into a single matrix (which is, however, still in row-major layout). They can then be directly re-packed into optimal layouts without performing further unnecessary copies.
- drjit.nn.pack(arg: MatrixView | drjit.AnyArray, *, layout: Literal['inference', 'training'] = 'inference') MatrixView¶
- drjit.nn.pack(*args: PyTree, layout: Literal['inference', 'training'] = 'inference') Tuple[PyTree, ...]
Overloaded function.
pack(arg: MatrixView | drjit.AnyArray, *, layout: typing.Literal['inference', 'training'] = 'inference') -> MatrixView
Re-pack a set of matrices and bias vectors into a single contiguous buffer with an inference- or training-optimal layout for use with drjit.nn.matvec().
When the argument is an nn.Module, the function returns (buffer, packed_module); see nn.Module for the training-loop pattern.
A training-optimal layout must be used if the program backpropagates (as in dr.backward*()) gradients through matrix-vector products. Inference (primal evaluation) and forward derivative propagation (as in dr.forward*()) do not require a training-optimal layout.
If the input matrices are already packed in a row-major layout, call dr.nn.view() to create an efficient reference and then pass slices of the view to dr.nn.pack(). This avoids additional copies.
    mat: TensorXf = ...
    mat_view = dr.nn.view(mat)
    A1_view, A2_view = dr.nn.pack(
        mat_view[0:32, :],
        mat_view[32:64, :]
    )
pack(*args: PyTree, layout: typing.Literal['inference', 'training'] = 'inference') -> typing.Tuple[PyTree, ...]
- drjit.nn.unpack(arg: MatrixView | drjit.AnyArray, /) MatrixView¶
- drjit.nn.unpack(*args: PyTree) Tuple[PyTree, ...]
Overloaded function.
unpack(arg: MatrixView | drjit.AnyArray, /) -> MatrixView
The function dr.nn.unpack() transforms a sequence (or PyTree) of vectors and optimal-layout matrices back into row-major layout.
    A_out, b_out = dr.nn.unpack(A_opt, b_opt)
Note that the outputs of this function are (row-major) views into a shared buffer. Each view holds a reference to the shared buffer. Views can be converted back into regular tensors:
    A = TensorXf16(A_out)
unpack(*args: PyTree) -> typing.Tuple[PyTree, ...]
- drjit.nn.matvec(A: MatrixView, x: drjit.nn.CoopVec[T], b: Optional[MatrixView] = None, /, transpose: bool = False) drjit.nn.CoopVec[T]¶
Evaluate a matrix-vector multiplication involving a cooperative vector.
This function takes a matrix view A (see drjit.nn.pack() and drjit.nn.view() for details on views) and a cooperative vector x. It then computes the associated matrix-vector product and returns it in the form of a new cooperative vector (potentially with a different size).
The function can optionally apply an additive bias (i.e., to evaluate A@x + b). This bias vector b should also be specified as a view.
Specify transpose=True to multiply by the transpose of the matrix A. On the CUDA/OptiX backend, this feature requires that A is in inference or training-optimal layout.
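As a reference for the arithmetic only, the sketch below evaluates the same expression on plain Python lists. The function `matvec_reference` is a hypothetical helper; it says nothing about how Dr.Jit schedules the operation on the device, and it assumes that the bias is added after the optional transposition.

```python
def matvec_reference(A, x, b=None, transpose=False):
    """Compute A @ x + b (or A^T @ x + b) for a matrix given as a list of rows."""
    if transpose:
        A = [list(col) for col in zip(*A)]
    y = [sum(a * v for a, v in zip(row, x)) for row in A]
    if b is not None:
        y = [yi + bi for yi, bi in zip(y, b)]
    return y

A = [[1.0, 2.0],
     [3.0, 4.0],
     [5.0, 6.0]]

y = matvec_reference(A, [1.0, 1.0], b=[0.5, 0.5, 0.5])        # 2 inputs -> 3 outputs
yt = matvec_reference(A, [1.0, 0.0, 1.0], transpose=True)      # 3 inputs -> 2 outputs
```

The output length equals the number of rows of the (possibly transposed) matrix, which is why the returned cooperative vector can differ in size from the input.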
Neural Networks¶
Besides cooperative vector classes, the
drjit.nn module also provides convenient abstractions to declare,
evaluate, and train networks. Please see the separate documentation
section for an introduction.
- class drjit.nn.Module¶
This is the base class of a modular set of operations that make the specification of neural network architectures more convenient.
Module subclasses are PyTrees, which means that various Dr.Jit operations can automatically traverse them.
Every allocated Module additionally behaves as a collections.abc.MutableMapping keyed by dotted parameter paths (e.g., 'layers.0.weights'). This mirrors the drjit.opt.Optimizer interface and enables symmetric parameter transfer:
    opt = Adam(lr=1e-3)
    opt.update(net)       # pull every parameter into the optimizer (once)
    for i in range(n):
        net.update(opt)   # push optimizer state back into the net
        ...
After attaching, the optimizer is the source of truth: opt.update(net) must not be called again. On a packed module the mapping exposes a single 'weights' entry whose underlying buffer is referenced by the per-layer MatrixView instances; writes to that entry are in-place so the views remain valid.
Constructing a neural network generally involves the following pattern:
    # 1. Establish the network structure
    net = nn.Sequential(
        nn.Linear(-1, 32, bias=False),
        nn.ReLU(),
        nn.Linear(-1, 3)
    )
    # 2. Instantiate the network for a specific backend + input size
    net = net.alloc(TensorXf16, 2)
    # 3. Pack coefficients into a training-optimal layout
    net = nn.pack(net, layout='training')
Network evaluation expects a cooperative vector as input (i.e., net(nn.CoopVec(...))) and returns another cooperative vector.
- __call__(arg: CoopVec, /) CoopVec¶
Evaluate the model with an input cooperative vector and return the result.
- alloc(dtype: Type[ArrayBase], size: int = -1, rng: Generator | None = None) Module¶
Returns a new instance of the model with allocated weights.
This function expects a suitable tensor dtype (e.g. drjit.cuda.ad.TensorXf16 or drjit.llvm.ad.TensorXf) that will be used to store the weights on the device.
If the model or one of its sub-models is automatically sized (e.g., input_features=-1 in drjit.nn.Linear), the final network configuration may be ambiguous and an exception will be raised. Specify the optional size parameter in such cases to inform the allocation about the size of the input cooperative vector.
Layer weights are initialized using pseudorandom values obtained from the specified generator object rng.
Specifying a newly seeded random number generator with the same seed ensures that weights will be consistent across runs (i.e., calling alloc() twice will produce the same initialization).
If rng=None (the default), a generator is constructed on the fly via dr.rng(seed=0x100000000). This particular seed value is used to de-correlate the network weights with respect to any potential future network evaluations that might be produced by a random number generator with the default seed (0). (Please ignore this paragraph if it is unclear, it explains a protection against a subtle/niche issue.)
- class drjit.nn.Sequential(*args: Module, prefix: str = '')¶
This model evaluates provided arguments arg[0], arg[1], …, in sequence.
The optional prefix keyword is prepended to every key exposed through the MutableMapping interface (e.g., 'mlp.layers.0.weights' instead of 'layers.0.weights'). This is useful when sharing a single optimizer across multiple networks. The prefix is retained through pack().
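The dotted key scheme can be illustrated by flattening a nested container into path-like keys. The helper `flatten_params` below is hypothetical and operates on ordinary dicts and lists rather than on `Module` instances; in particular, joining the prefix with a dot is an assumption inferred from the `'mlp.layers.0.weights'` example above.

```python
def flatten_params(tree, prefix=""):
    """Flatten nested dicts/lists into a dict keyed by dotted paths."""
    flat = {}
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        flat[prefix] = tree   # leaf: store under the accumulated path
        return flat
    for key, value in items:
        path = f"{prefix}.{key}" if prefix else str(key)
        flat.update(flatten_params(value, path))
    return flat

# Stand-in for a three-layer network whose middle layer has no parameters
net = {"layers": [{"weights": "W0"}, {}, {"weights": "W2", "bias": "b2"}]}
keys = list(flatten_params(net, prefix="mlp"))
```

With distinct prefixes, the parameters of several networks can share one flat mapping without key collisions.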
- class drjit.nn.Linear(in_features: int = -1, out_features: int = -1, bias=True)¶
This layer represents a learnable affine linear transformation of the input data following the expression \(\mathbf{y} = \mathbf{A}\mathbf{x} + \mathbf{b}\).
It takes in_features inputs and returns a cooperative vector with out_features dimensions. The following parameter values have a special meaning:
in_features=-1: set the input size to match the previous model’s output (or the input of the network, if there is no previous model).
out_features=-1: set the output size to match the input size.
The bias (\(\textbf{b}\)) term is optional and can be disabled by specifying bias=False.
The method Module.alloc() initializes the underlying coefficient storage with random weights following a uniform Xavier initialization, i.e., uniform variates on the interval \([-k,k]\) where \(k=1/\sqrt{\texttt{out\_features}}\).
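The initialization rule can be sketched with the standard library's random number generator. The function `init_linear` is hypothetical, uses Python's `random.Random` instead of a Dr.Jit generator, and assumes a zero-initialized bias, which the text above does not specify.

```python
import math
import random

def init_linear(in_features, out_features, bias=True, seed=0):
    """Draw weights uniformly from [-k, k] with k = 1 / sqrt(out_features)."""
    rng = random.Random(seed)
    k = 1.0 / math.sqrt(out_features)
    weights = [[rng.uniform(-k, k) for _ in range(in_features)]
               for _ in range(out_features)]
    # Assumption: the bias starts at zero
    b = [0.0] * out_features if bias else None
    return weights, b

weights, b = init_linear(in_features=4, out_features=16)
```

Re-running with the same seed reproduces the same weights, which mirrors the reproducibility remark made for `alloc()`.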
- class drjit.nn.ReLU¶
ReLU (rectified linear unit) activation function.
This model evaluates the following expression:
\[\mathrm{ReLU}(x) = \mathrm{max}\{x, 0\}.\]
Accepts both CoopVec and tensor inputs.
- class drjit.nn.LeakyReLU(negative_slope: float | ArrayBase = 0.01)¶
“Leaky” ReLU (rectified linear unit) activation function.
This model evaluates the following expression:
\[\mathrm{LeakyReLU}(x) = \begin{cases} x,&\mathrm{if}\ x\ge 0,\\ \texttt{negative\_slope}\cdot x,&\mathrm{otherwise}. \end{cases}\]
Accepts both CoopVec and tensor inputs.
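For reference, the two activation functions above reduce to the following scalar expressions. The functions `relu` and `leaky_relu` are hypothetical stand-ins that operate on Python floats rather than on cooperative vectors or tensors.

```python
def relu(x):
    """max{x, 0}"""
    return max(x, 0.0)

def leaky_relu(x, negative_slope=0.01):
    """x for x >= 0, negative_slope * x otherwise"""
    return x if x >= 0 else negative_slope * x

samples = [-2.0, 0.0, 3.0]
relu_out = [relu(s) for s in samples]
leaky_out = [leaky_relu(s) for s in samples]
```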
- class drjit.nn.SinEncode(octaves: int = 0, shift: float = 0)¶
Map an input onto a higher-dimensional space by transforming it using sines and cosines of an increasing frequency.
\[x\mapsto \begin{bmatrix} \sin\bigl(2\pi(2^0\,x + 0\cdot s)\bigr)\\ \cos\bigl(2\pi(2^0\,x + 0\cdot s)\bigr)\\ \vdots\\ \sin\bigl(2\pi(2^{n-1}\,x + (n-1)\cdot s)\bigr)\\ \cos\bigl(2\pi(2^{n-1}\,x + (n-1)\cdot s)\bigr) \end{bmatrix}\]The value \(n\) refers to the number of octaves. This layer increases the dimension by a factor of \(2n\).
Note that this encoding has period 1. If your input exceeds the interval \([0, 1]\), it is advisable that you reduce it to this range to avoid losing information.
Minima/maxima of higher frequency components coincide on a regular lattice, which can lead to reduced fitting performance at those locations. Specify the optional shift parameter \(s\) (in fractional periods, so that shift=0.25 is a quarter period) to phase-shift the \(i\)-th octave by \(i\cdot s\) and avoid this.
Accepts both CoopVec and 2D tensor inputs (batched evaluation). For a tensor of shape (C, N) with N independent samples, the output has shape (2*n*C, N).
[Plot: the first two octaves applied to the linear function on \([0, 1]\) (without shift).]
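The encoding formula can be evaluated directly for a single scalar input. The function `sin_encode` below is a hypothetical pure-Python sketch of that formula; it handles one input channel only and makes no claim about how multiple channels are ordered in the actual layer.

```python
import math

def sin_encode(x, octaves, shift=0.0):
    """Return [sin, cos] pairs for octaves 0 .. octaves-1 of a scalar input."""
    out = []
    for i in range(octaves):
        # The i-th octave doubles the frequency i times and is shifted by i * s
        phase = 2.0 * math.pi * ((2 ** i) * x + i * shift)
        out.append(math.sin(phase))
        out.append(math.cos(phase))
    return out

encoded = sin_encode(0.25, octaves=2)
```

A scalar input therefore produces `2 * octaves` outputs, in line with the stated growth of the dimension by a factor of \(2n\).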
- class drjit.nn.TriEncode(octaves: int = 0, shift: float = 0)¶
Map an input onto a higher-dimensional space by transforming it using triangular sine and cosine approximations of an increasing frequency.
\[x\mapsto \begin{bmatrix} \sin_\triangle(2^0\,x + 0\cdot s)\\ \cos_\triangle(2^0\,x + 0\cdot s)\\ \vdots\\ \sin_\triangle(2^{n-1}\, x + (n-1)\cdot s)\\ \cos_\triangle(2^{n-1}\, x + (n-1)\cdot s) \end{bmatrix}\]where
\[\cos_\triangle(x) = 1-4\left|x-\mathrm{round}(x)\right|\]and
\[\sin_\triangle(x) = \cos_\triangle(x-1/4).\]The value \(n\) refers to the number of octaves. This layer increases the dimension by a factor of \(2n\).
Note that this encoding has period 1. If your input exceeds the interval \([0, 1]\), it is advisable that you reduce it to this range to avoid losing information.
Minima/maxima of higher frequency components coincide on a regular lattice, which can lead to reduced fitting performance at those locations. Specify the optional shift parameter \(s\) (in fractional periods, so that shift=0.25 is a quarter period) to phase-shift the \(i\)-th octave by \(i\cdot s\) and avoid this.
Accepts both CoopVec and 2D tensor inputs (batched evaluation). For a tensor of shape (C, N) with N independent samples, the output has shape (2*n*C, N).
[Plot: the first two octaves applied to the linear function on \([0, 1]\) (without shift).]
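The triangular approximations defined above are simple enough to write out in a few lines. The following sketch uses hypothetical helper names (`cos_tri`, `sin_tri`, `tri_encode`) and, like the formulas, covers a single scalar input.

```python
def cos_tri(x):
    """Triangle wave with period 1: equals 1 at integers and -1 at half-integers."""
    return 1.0 - 4.0 * abs(x - round(x))

def sin_tri(x):
    """Quarter-period shift of cos_tri."""
    return cos_tri(x - 0.25)

def tri_encode(x, octaves, shift=0.0):
    """Return [sin_tri, cos_tri] pairs for octaves 0 .. octaves-1."""
    out = []
    for i in range(octaves):
        t = (2 ** i) * x + i * shift
        out += [sin_tri(t), cos_tri(t)]
    return out

encoded = tri_encode(0.25, octaves=2)
```

At the sample point used here, the triangular functions coincide with the true sine and cosine; in between they interpolate linearly, which avoids transcendental function evaluations.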
- class drjit.nn.Exp¶
Applies the exponential function to each component.
\[\mathrm{Exp}(x) = e^x\]
Accepts both CoopVec and tensor inputs.
- class drjit.nn.Exp2¶
Applies the base-2 exponential function to each component.
\[\mathrm{Exp2}(x) = 2^x\]
On the CUDA backend, this function directly maps to an efficient native GPU instruction. Accepts both CoopVec and tensor inputs.
- class drjit.nn.Tanh¶
Applies the hyperbolic tangent function to each component.
\[\mathrm{Tanh}(x) = \frac{\exp(x)-\exp(-x)}{\exp(x)+\exp(-x)}\]
On the CUDA backend, this function directly maps to an efficient native GPU instruction. Accepts both CoopVec and tensor inputs.
- class drjit.nn.Cast(dtype: Type[ArrayBase] | None = None)¶
Cast the input to a different precision. Should be instantiated with the desired element type, e.g. Cast(drjit.cuda.ad.Float32). Accepts both CoopVec and tensor inputs.
- class drjit.nn.ScaleAdd(scale: float | int | ArrayBase | None = None, offset: float | int | ArrayBase | None = None)¶
Scale the input by a fixed scale and apply an offset.
Note that scale and offset are assumed to be constant (i.e., not trainable).
\[\mathrm{ScaleAdd}(x) = x\cdot\texttt{scale} + \texttt{offset}\]
Accepts both CoopVec and tensor inputs.
- class drjit.nn.HashEncodingLayer(encoding: HashEncoding)¶
Simple layer wrapping a hash encoding like drjit.nn.HashGridEncoding or drjit.nn.PermutoEncoding.
Note that the parameters of the encoding will not be included when packing the network, as the data representations are generally incompatible. You must initialize the encoding parameters separately.
- class drjit.nn.HashGridEncoding(dtype: Type[ArrayBase], dimension: int, *, n_levels: int = 16, n_features_per_level: int = 2, hashmap_size: int = 2**19, base_resolution: int = 16, per_level_scale: float = 2, align_corners: bool = False, torchngp_compat: bool = False, smooth_weight_gradients: bool = False, smooth_weight_lambda: float = 1.0, init_scale: float = 0.0001, rng: Generator | None = None)¶
This encoding implements a Multiresolution Hash Grid. For every resolution level, this encoding looks up the \(2^D\) vertices of the cell in which the input point is located, performs multilinear interpolation, and concatenates the features across all resolution levels.
- Parameters:
dimension – The dimensionality of the hash encoding. This corresponds to the number of input features the encoding can take.
n_levels – Hash encodings generally make use of multiple levels of the same encoding with different scales. This parameter specifies the number of levels used by this encoding.
n_features_per_level – Specifies how many features are stored at each vertex and at each level. The number of output features of the hash encoding is given by n_levels * n_features_per_level. In order to ensure efficient gradient backpropagation, this value should be a multiple of two.
hashmap_size – Specifies the maximal number of parameters per level of the hash encoding. HashGrids will use a dense grid lookup for layers with a low enough scale, and use fewer than hashmap_size parameters per level.
base_resolution – The scale factor of the 0th layer in the hash encoding.
per_level_scale – To calculate the scale of a layer, the scale of the previous layer is multiplied by this value.
align_corners – If this value is True, the simplex vertices are aligned with the domain of the encoding [0, 1].
smooth_weight_gradients – Whether to smooth the gradients of the weights by using a straight-through estimator.
smooth_weight_lambda – The value of lambda used for the straight-through estimator.
init_scale – The parameters of the hashgrid are initialized with a uniform distribution, ranging from -init_scale to +init_scale.
rng – Random number generator, used to initialize the parameters.
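Two quantities follow directly from the parameter descriptions above: the scale of each level and the width of the encoded output. The sketch below computes both in plain Python; the helper names are hypothetical, and the threshold at which a level switches from a dense grid to hashing is deliberately left out because the text does not state it.

```python
def level_scales(n_levels=16, base_resolution=16, per_level_scale=2.0):
    """Scale of every level: each level multiplies the previous one."""
    scales = [float(base_resolution)]
    for _ in range(n_levels - 1):
        scales.append(scales[-1] * per_level_scale)
    return scales

def output_features(n_levels=16, n_features_per_level=2):
    """Number of features produced by the encoding."""
    return n_levels * n_features_per_level

scales = level_scales(n_levels=4)
width = output_features()
```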
- class drjit.nn.PermutoEncoding(dtype: Type[ArrayBase], dimension: int, *, n_levels: int = 16, n_features_per_level: int = 2, hashmap_size: int = 2**19, base_resolution: int = 16, per_level_scale: float = 2, align_corners: bool = False, smooth_weight_gradients: bool = False, smooth_weight_lambda: float = 1.0, init_scale: float = 0.0001, rng: Generator | None = None)¶
Permutohedral lattice-based encoding inspired by the paper PermutoSDF: Fast Multi-View Reconstruction with Implicit Surfaces using Permutohedral Lattices.
Unlike hash grid encodings that use regular grid lattices, this encoding employs a permutohedral lattice structure where simplices consist of triangles, tetrahedra, and higher-dimensional analogs. The key advantage is linear scaling: the number of vertices per simplex (and thus memory lookups per sample per level) grows linearly with dimensionality, compared to exponential growth in grid-based approaches.
This implementation by Tobias Zirr simplifies the original method by performing sorting and interpolation directly in \(d\)-dimensional space, avoiding the elevation to a hyperplane in \((d+1)\)-dimensional space used in the reference implementation.
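The scaling argument can be made concrete by counting the vertices that must be fetched per sample and level: a grid cell in \(d\) dimensions has \(2^d\) corners, whereas a simplex has \(d+1\) vertices. The helper `lookups_per_level` is a hypothetical illustration of that count only.

```python
def lookups_per_level(dimension):
    """Vertices fetched per sample and level for both lattice types."""
    return {
        "grid": 2 ** dimension,            # corners of a hypercube cell
        "permutohedral": dimension + 1,    # vertices of a simplex
    }

for d in (2, 3, 6):
    counts = lookups_per_level(d)
    print(f"d={d}: grid={counts['grid']}, permutohedral={counts['permutohedral']}")
```

The gap widens quickly with the input dimension, which is where the permutohedral variant is most attractive.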
- Parameters:
dimension – The dimensionality of the hash encoding. This corresponds to the number of input features the encoding can take.
n_levels – Hash encodings generally make use of multiple levels of the same encoding with different scales. This parameter specifies the number of levels used by this encoding.
n_features_per_level – Specifies how many features are stored at each vertex and at each level. The number of output features of the hash encoding is given by n_levels * n_features_per_level. In order to ensure efficient gradient backpropagation, this value should be a multiple of two.
hashmap_size – Specifies the maximal number of parameters per level of the hash encoding. HashGrids will use a dense grid lookup for layers with a low enough scale, and use fewer than hashmap_size parameters per level.
base_resolution – The scale factor of the 0th layer in the hash encoding.
per_level_scale – To calculate the scale of a layer, the scale of the previous layer is multiplied by this value.
align_corners – If this value is True, the simplex vertices are aligned with the domain of the encoding [0, 1].
smooth_weight_gradients – Whether to smooth the gradients of the weights by using a straight-through estimator.
smooth_weight_lambda – The value of lambda used for the straight-through estimator.
init_scale – The parameters of the hashgrid are initialized with a uniform distribution, ranging from -init_scale to +init_scale.
rng – Random number generator, used to initialize the parameters.
CUDA contexts¶
- class drjit.cuda.green_context(*args, **kwargs)¶
Context manager for creating CUDA green contexts with isolated GPU resources.
A green context allows partitioning the GPU into smaller units with a specific number of streaming multiprocessors (SMs). Dr.Jit can launch kernels into this green context, which will be isolated from other computations running on the remaining SMs. This is useful for concurrent kernel execution and resource management on GPUs.
Green contexts can be created once and entered/exited multiple times via the Python context manager interface.
Note
CUDA may not always satisfy the requested SM count exactly. The actual number of SMs allocated may be larger due to hardware alignment or minimum requirements. If fewer SMs than requested are available, an exception is raised.
- Parameters:
sm_count (int) – The number of streaming multiprocessors to allocate for this green context.
- Raises:
RuntimeError – If the green context cannot be created or if CUDA cannot provide at least the requested number of SMs.
Example
    from drjit.cuda import green_context

    # Simple usage - isolate 16 SMs for computation
    with green_context(16):
        # Kernels launched here use only 16 SMs
        result = some_computation()

    # Access context information
    with green_context(16) as ctx:
        print(f"Requested: {ctx.requested_sm_count} SMs")
        print(f"Actual: {ctx.sm_count} SMs")

        # Use remaining_ctx for concurrent computation
        other_ctx = ctx.remaining_ctx
        if other_ctx is not None:
            # Launch work on remaining SMs
            pass
This class is also available as drjit.cuda.ad.green_context.
- property sm_count¶
The actual number of SMs allocated to this green context.
The actual SM count may be larger than the requested count due to hardware alignment constraints or minimum requirements.
- Type:
int
- property requested_sm_count¶
The requested number of SMs for this green context.
- Type:
int
- property remaining_ctx¶
A CUDA context capsule for the remaining SMs.
This property provides a Python capsule containing the CUDA context (CUcontext) for the set of streaming multiprocessors that are not part of this green context. This can be used to launch separate computations on the remaining GPU resources in isolation from this green context.
Important
The green context object must remain alive while the remaining context is being used by other computations. Destroying the green context will invalidate the remaining context.
- Type:
typing.CapsuleType | None: A capsule with type "CUcontext" containing the CUDA context for the remaining SMs, or None if no such context exists.
CUDA / GL interoperability¶
High-level interface¶
- class drjit.cuda.GLInterop¶
Abstraction for efficient CUDA/OpenGL interoperability.
This class provides a high-level interface for mapping OpenGL buffers and textures to CUDA, allowing direct GPU-to-GPU data transfers without going through host memory.
The class manages the lifecycle of CUDA graphics resources and ensures proper resource mapping/unmapping through its context manager interface.
Example
    # For OpenGL buffers
    interop = dr.cuda.GLInterop.from_buffer(gl_buffer_id)
    interop.map().upload(drjit_tensor).unmap()

    # For OpenGL textures
    interop = dr.cuda.GLInterop.from_texture(gl_texture_id)
    interop.map().upload(drjit_tensor).unmap()
Low-level interface¶
- drjit.cuda.register_gl_buffer(gl_buffer: int) typing_extensions.CapsuleType¶
Register a GL buffer as a CUDA graphics resource.
The created CUgraphicsResource (pointer) is returned as an opaque pointer. The resource must be unregistered with drjit.cuda.unregister_cuda_resource() when done.
- Parameters:
gl_buffer (int) – Integer identifying the GL buffer.
- Returns:
An opaque pointer to the CUDA graphics resource.
- Return type:
capsule
- drjit.cuda.register_gl_texture(gl_texture: int) typing_extensions.CapsuleType¶
Register a GL texture as a CUDA graphics resource.
The created CUgraphicsResource (pointer) is returned as an opaque pointer. The resource must be unregistered with drjit.cuda.unregister_cuda_resource() when done.
- Parameters:
gl_texture (int) – Integer identifying the GL texture.
- Returns:
An opaque pointer to the CUDA graphics resource.
- Return type:
capsule
- drjit.cuda.unregister_cuda_resource(cuda_resource: typing_extensions.CapsuleType) None¶
Unregister a CUDA graphics resource that was previously registered with drjit.cuda.register_gl_buffer() or drjit.cuda.register_gl_texture(). This frees the associated CUDA graphics resource.
- Parameters:
cuda_resource (capsule) – The CUDA graphics resource to unregister.
- drjit.cuda.map_graphics_resource_ptr(cuda_resource: typing_extensions.CapsuleType) tuple[int, int]¶
Map a CUDA graphics resource and return a device pointer to the mapped resource. The resource must be unmapped with drjit.cuda.unmap_graphics_resource() when done.
- Parameters:
cuda_resource (capsule) – The CUDA graphics resource to map
n_bytes (int) – Output parameter that receives the
- Returns:
A tuple (ptr, n_bytes): the device pointer to the mapped resource, as an integer, and the size of the mapped resource, in bytes.
- Return type:
tuple[int, int]
- drjit.cuda.map_graphics_resource_array(cuda_resource: typing_extensions.CapsuleType, mip_level: int = 0) int¶
Map a CUDA graphics resource and return a CUDA array handle to the specified sub-resource. The resource must be unmapped with drjit.cuda.unmap_graphics_resource() when done.
- Parameters:
cuda_resource (capsule) – The CUDA graphics resource to map.
mip_level (int) – The mip level of the sub-resource (default: 0)
- Returns:
CUDA array handle to the mapped sub-resource, as an integer.
- Return type:
int
- drjit.cuda.unmap_graphics_resource(cuda_resource: typing_extensions.CapsuleType) None¶
Unmap a CUDA graphics resource that was previously mapped with drjit.cuda.map_graphics_resource_ptr() or drjit.cuda.map_graphics_resource_array().
- Parameters:
cuda_resource (capsule) – The CUDA graphics resource to unmap
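The register / map / unmap / unregister sequence described above pairs naturally with try/finally so that cleanup runs even if the work in between fails. The following is a minimal sketch, not part of Dr.Jit: the helper name mapped_gl_buffer is hypothetical, and the cuda argument is assumed to be the drjit.cuda module (it is passed in explicitly so the helper can be exercised with a stand-in object on machines without a GPU or GL context).

```python
from contextlib import contextmanager


@contextmanager
def mapped_gl_buffer(cuda, gl_buffer_id):
    """Yield (device_ptr, n_bytes) for a GL buffer and clean up afterwards.

    Hypothetical helper. `cuda` is assumed to expose the low-level
    functions documented in this section (e.g. the `drjit.cuda` module).
    """
    # Step 1: register the GL buffer as a CUDA graphics resource.
    resource = cuda.register_gl_buffer(gl_buffer_id)
    try:
        # Step 2: map it to obtain a device pointer and its size in bytes.
        ptr, n_bytes = cuda.map_graphics_resource_ptr(resource)
        try:
            yield ptr, n_bytes
        finally:
            # Step 3: always unmap, even if the caller raised.
            cuda.unmap_graphics_resource(resource)
    finally:
        # Step 4: always release the registration.
        cuda.unregister_cuda_resource(resource)


# Intended usage (requires a live GL context and a CUDA device):
#
#   import drjit as dr
#   with mapped_gl_buffer(dr.cuda, gl_buffer_id) as (ptr, n_bytes):
#       ...  # read from or write to the mapped device memory
```

Registering on every use keeps the sketch short; an application that maps the same buffer each frame would likely register once and only repeat the map/unmap steps.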
- drjit.cuda.memcpy_2d_to_array_async(dst: int, src: int, src_pitch: int, height: int, from_host: bool = False) None¶
Perform an asynchronous 2D memory copy from a source buffer to a CUDA array.
- Parameters:
dst (int) – Destination CUDA array pointer, as an int.
src (int) – Source buffer (host or device memory) pointer, as an int.
src_pitch (int) – Pitch (bytes per row) of the source buffer
height (int) – Height of the region to copy (in elements)
from_host (bool) – True if copying from host memory, False for device memory
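To illustrate how src_pitch and height relate to an image, here is a rough sketch that copies a tightly packed image into a GL texture using the low-level functions documented above. The helpers tight_pitch and upload_to_gl_texture are hypothetical, the default of four 4-byte channels per pixel is an assumption, and cuda is assumed to be the drjit.cuda module (passed in so a stand-in can be used for testing without a GPU).

```python
def tight_pitch(width, channels, itemsize):
    """Bytes per row of an image stored without padding between rows."""
    return width * channels * itemsize


def upload_to_gl_texture(cuda, gl_texture_id, src_ptr, width, height,
                         channels=4, itemsize=4, from_host=False):
    """Copy a tightly packed image at `src_ptr` into a GL texture.

    Hypothetical helper. `cuda` is assumed to expose the low-level
    functions documented in this section (e.g. the `drjit.cuda` module).
    """
    resource = cuda.register_gl_texture(gl_texture_id)
    try:
        # Mip level 0 is the full-resolution image.
        array = cuda.map_graphics_resource_array(resource, 0)
        try:
            cuda.memcpy_2d_to_array_async(
                array,                                   # dst
                src_ptr,                                 # src
                tight_pitch(width, channels, itemsize),  # src_pitch
                height,                                  # height
                from_host,                               # from_host
            )
        finally:
            cuda.unmap_graphics_resource(resource)
    finally:
        cuda.unregister_cuda_resource(resource)


# A 640-pixel-wide RGBA float32 row occupies 640 * 4 * 4 bytes.
pitch = tight_pitch(640, 4, 4)
print(pitch)  # 10240
```

If the source rows carry padding, the real pitch is larger than the value returned by tight_pitch and should be passed instead.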