API Reference
Array creation
- drjit.zeros(dtype: type, shape: int = 1) -> object
- drjit.zeros(dtype: type, shape: collections.abc.Sequence[int]) -> object
Return a zero-initialized instance of the desired type and shape.
This function can create zero-initialized instances of various types. In particular, dtype can be:
  - A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.zeros(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
  - A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list, tuple, etc.), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
  - A PyTree. In this case, drjit.zeros() will invoke itself recursively to zero-initialize each field of the data structure.
  - A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.
Note that when dtype refers to a scalar mask or a mask array, it will be initialized to False as opposed to zero.
The function returns a literal constant array that consumes no device memory.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
shape (Sequence[int] | int) – Shape of the desired array.
- Returns:
A zero-initialized instance of type dtype.
- Return type:
object
- drjit.empty(dtype: type, shape: int = 1) -> object
- drjit.empty(dtype: type, shape: collections.abc.Sequence[int]) -> object
Return an uninitialized Dr.Jit array of the desired type and shape.
This function can create uninitialized buffers of various types. It should only be used in combination with a subsequent call to an operation like drjit.scatter() that fills the array contents with valid data.
The dtype parameter can be used to request:
  - A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.empty(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
  - A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list, tuple, etc.), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
  - A PyTree. In this case, drjit.empty() will invoke itself recursively to allocate memory for each field of the data structure.
  - A scalar Python type like int, float, or bool. The shape parameter is ignored in this case, and the function returns a zero-initialized result (there is little point in instantiating uninitialized versions of scalar Python types).
drjit.empty() delays allocation of the underlying buffer until an operation tries to read/write the actual array contents.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
shape (Sequence[int] | int) – Shape of the desired array.
- Returns:
An instance of type dtype with arbitrary/undefined contents.
- Return type:
object
- drjit.ones(dtype: type, shape: int = 1) -> object
- drjit.ones(dtype: type, shape: collections.abc.Sequence[int]) -> object
Return an instance of the desired type and shape filled with ones.
This function can create one-initialized instances of various types. In particular, dtype can be:
  - A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.ones(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
  - A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list, tuple, etc.), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
  - A PyTree. In this case, drjit.ones() will invoke itself recursively to initialize each field of the data structure.
  - A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.
Note that when dtype refers to a scalar mask or a mask array, it will be initialized to True as opposed to one.
The function returns a literal constant array that consumes no device memory.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
shape (Sequence[int] | int) – Shape of the desired array.
- Returns:
An instance of type dtype filled with ones.
- Return type:
object
- drjit.full(dtype: type, value: object, shape: int = 1) -> object
- drjit.full(dtype: type, value: object, shape: collections.abc.Sequence[int]) -> object
Return a constant-valued instance of the desired type and shape.
This function can create constant-valued instances of various types. In particular, dtype can be:
  - A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.full(dr.cuda.Array2f, value=1.0, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
  - A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list, tuple, etc.), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
  - A PyTree. In this case, drjit.full() will invoke itself recursively to initialize each field of the data structure.
  - A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.
The function returns a literal constant array that consumes no device memory.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
value (object) – An instance of the underlying scalar type (float/int/bool, etc.) that will be used to initialize the array contents.
shape (Sequence[int] | int) – Shape of the desired array.
- Returns:
An instance of type dtype filled with value.
- Return type:
object
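The recursive treatment of PyTrees and the mask special case described for drjit.zeros(), drjit.ones(), and drjit.full() can be modeled in plain Python. The sketch below is only an illustration and does not call Dr.Jit: full_ref, Hit, and Record are hypothetical names, dataclasses stand in for PyTrees, and only Python scalar leaf types are handled.

```python
from dataclasses import dataclass, fields, is_dataclass


def full_ref(dtype, value):
    """Hypothetical reference model: fill scalar types and dataclass fields recursively."""
    if is_dataclass(dtype):
        # PyTree-like case: recurse into every field using its annotated type
        return dtype(**{f.name: full_ref(f.type, value) for f in fields(dtype)})
    if dtype in (int, float, bool):
        # Scalar case: bool(0) is False and bool(1) is True, which mirrors the
        # documented behavior of masks being set to False/True instead of 0/1
        return dtype(value)
    raise TypeError(f"unsupported dtype: {dtype!r}")


@dataclass
class Hit:          # hypothetical leaf structure
    t: float
    valid: bool


@dataclass
class Record:       # hypothetical nested structure
    index: int
    hit: Hit


zeros = full_ref(Record, 0)
ones = full_ref(Record, 1)
print(zeros)
print(ones)
```

In this model, full_ref(Record, 0) plays the role of a zeros() call and full_ref(Record, 1) that of a ones() call; the shape argument is omitted because the text states that it is ignored for Python scalar types.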
- drjit.opaque(dtype: type, value: object, shape: int = 1) -> object
- drjit.opaque(dtype: type, value: object, shape: collections.abc.Sequence[int]) -> object
Return an opaque constant-valued instance of the desired type and shape.
This function is very similar to drjit.full() in that it creates constant-valued instances of various types including (potentially nested) Dr.Jit arrays, tensors, and PyTrees. Please refer to the documentation of drjit.full() for details on the function signature. However, drjit.full() creates literal constant arrays, which means that Dr.Jit is fully aware of the array contents.
In contrast, drjit.opaque() produces an opaque array backed by a representation in device memory.
Why is this useful?
Consider the following snippet, where a complex calculation is parameterized by the constant 1.
    import drjit as dr
    from drjit.llvm import Float

    result = complex_function(Float(1), ...)  # Float(1) is equivalent to dr.full(Float, 1)
    print(result)
The print() statement will cause Dr.Jit to evaluate the queued computation, which likely also requires compilation of a new kernel (if that exact pattern of steps hasn’t been observed before). Kernel compilation is costly and may be much slower than the actual computation that needs to be done.
Suppose we later wish to evaluate the function with a different parameter:
    result = complex_function(Float(2), ...)
    print(result)
The constant 2 is essentially copy-pasted into the generated program, causing a mismatch with the previously compiled kernel, which therefore cannot be reused. This unfortunately means that we must once more wait a few tens or even hundreds of milliseconds until a new kernel has been compiled and uploaded to the device.
This motivates the existence of drjit.opaque(). By making a variable opaque to Dr.Jit’s tracing mechanism, we can keep constants out of the generated program and improve the effectiveness of the kernel cache:
    # The following lines reuse the compiled kernel regardless of the constant
    value = dr.opaque(Float, 2)
    result = complex_function(value, ...)
    print(result)
This function is related to drjit.make_opaque(), which can turn an already existing Dr.Jit array, tensor, or PyTree into an opaque representation.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
value (object) – An instance of the underlying scalar type (float/int/bool, etc.) that will be used to initialize the array contents.
shape (Sequence[int] | int) – Shape of the desired array.
- Returns:
An instance of type dtype filled with value.
- Return type:
object
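To make the caching argument concrete, the following toy model (plain Python, not Dr.Jit) keys a kernel cache on the text of a traced program. All names here (LiteralConst, OpaqueConst, ToyKernelCache) are hypothetical, and the real kernel cache is considerably more involved; the sketch only shows why a literal constant changes the cache key while an opaque one does not.

```python
class LiteralConst:
    """Constant whose value is baked into the traced program text."""
    def __init__(self, value):
        self.value = value

    def ir(self, params):
        return repr(self.value)


class OpaqueConst:
    """Constant passed as a kernel parameter; the program only sees a slot."""
    def __init__(self, value):
        self.value = value

    def ir(self, params):
        params.append(self.value)
        return f"param[{len(params) - 1}]"


class ToyKernelCache:
    """Compiles a 'kernel' only when the program text has not been seen before."""
    def __init__(self):
        self.kernels = {}
        self.compilations = 0

    def run(self, constant):
        params = []
        program = f"out = x * {constant.ir(params)} + 1"  # stand-in for tracing
        if program not in self.kernels:                   # cache miss
            self.compilations += 1
            self.kernels[program] = f"compiled<{program}>"
        return program


literal_cache = ToyKernelCache()
literal_cache.run(LiteralConst(1))
literal_cache.run(LiteralConst(2))   # different program text, compiled again

opaque_cache = ToyKernelCache()
opaque_cache.run(OpaqueConst(1))
opaque_cache.run(OpaqueConst(2))     # same program text, cached kernel reused

print(literal_cache.compilations, opaque_cache.compilations)
```

The price of the opaque variant in this model is the extra parameter slot, which loosely corresponds to the device memory that backs an opaque array.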
- drjit.arange(dtype: type[T], size: int) -> T
- drjit.arange(dtype: type[T], start: int, stop: int, step: int = 1) -> T
This function generates an integer sequence on the interval [start, stop) with step size step, where start = 0 and step = 1 if not specified.
- Parameters:
dtype (type) – Desired Dr.Jit array type. The dtype must refer to a dynamically sized 1D Dr.Jit array such as drjit.scalar.ArrayXu or drjit.cuda.Float.
start (int) – Start of the interval. The default value is 0.
stop/size (int) – End of the interval (not included). The name of this parameter differs between the two provided overloads.
step (int) – Spacing between values. The default value is 1.
- Returns:
The computed sequence of type dtype.
- Return type:
object
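The half-open interval and the two overloads can be mirrored with Python's built-in range. The helper below is a hypothetical reference model (arange_ref is not part of Dr.Jit); it omits the dtype argument and returns a plain list.

```python
def arange_ref(start, stop=None, step=1):
    """Hypothetical reference model of the documented arange() semantics."""
    if stop is None:
        # Single-argument overload: the argument is the size, start defaults to 0
        start, stop = 0, start
    return list(range(start, stop, step))


print(arange_ref(5))          # size overload
print(arange_ref(2, 10, 3))   # start/stop/step overload; stop is excluded
```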
- drjit.linspace(dtype: type[T], start: float, stop: float, num: int, endpoint: bool = True) -> T
This function generates an evenly spaced floating point sequence of size num covering the interval [start, stop].
- Parameters:
dtype (type) – Desired Dr.Jit array type. The dtype must refer to a dynamically sized 1D Dr.Jit floating point array, such as drjit.scalar.ArrayXf or drjit.cuda.Float.
start (float) – Start of the interval.
stop (float) – End of the interval.
num (int) – Number of samples to generate.
endpoint (bool) – Should the interval endpoint be included? The default is True.
- Returns:
The computed sequence of type dtype.
- Return type:
object
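The effect of the endpoint flag is easiest to see on a small example. The helper below is a hypothetical reference model in plain Python (linspace_ref is not part of Dr.Jit). It assumes the NumPy-style convention in which endpoint=False spaces the samples by (stop - start) / num; this text does not spell out that detail, so treat it as an assumption.

```python
def linspace_ref(start, stop, num, endpoint=True):
    """Hypothetical reference model: 'num' evenly spaced samples from 'start'."""
    if num <= 0:
        return []
    # With the endpoint included there are num - 1 gaps between samples;
    # otherwise there are num gaps (assumed NumPy-style convention)
    gaps = max(num - 1, 1) if endpoint else num
    step = (stop - start) / gaps
    return [start + k * step for k in range(num)]


print(linspace_ref(0.0, 1.0, 5))
print(linspace_ref(0.0, 1.0, 4, endpoint=False))
```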
Control flow
- drjit.syntax(f: None = None, *, recursive: bool = False, print_ast: bool = False, print_code: bool = False) -> Callable[[T], T]
- drjit.syntax(f: T, *, recursive: bool = False, print_ast: bool = False, print_code: bool = False) -> T
- drjit.hint(arg: T, /, *, mode: Literal['scalar', 'evaluated', 'symbolic', None] | None = None, max_iterations: int | None = None, label: str | None = None, include: List[object] | None = None, exclude: List[object] | None = None, strict: bool = True) -> T
Within ordinary Python code, this function is unremarkable: it returns the positional-only argument arg while ignoring any specified keyword arguments.
The main purpose of drjit.hint() is to provide hints that influence the transformation performed by the @drjit.syntax decorator. The following kinds of hints are supported:
  - mode overrides the compilation mode of a while loop or if statement. The following choices are available:
      - mode='scalar' disables code transformations, which is permitted when the predicate of a loop or if statement is a scalar Python bool.
            i: int = 0
            while dr.hint(i < 10, mode='scalar'):
                # ...
        Routing such code through drjit.while_loop() or drjit.if_stmt() still works but may add small overheads, which motivates the existence of this flag. Note that this annotation does not cause mode='scalar' to be passed to drjit.while_loop() or drjit.if_stmt() (it happens to be a valid input of both). Instead, it disables the code transformation altogether so that the above example translates into ordinary Python code:
            i: int = 0
            while i < 10:
                # ...
      - mode='evaluated' forces execution in evaluated mode and causes the code transformation to forward this argument to the relevant drjit.while_loop() or drjit.if_stmt() call.
        Refer to the discussion of drjit.while_loop(), drjit.JitFlag.SymbolicLoops, drjit.if_stmt(), and drjit.JitFlag.SymbolicConditionals for details.
      - mode='symbolic' forces execution in symbolic mode and causes the code transformation to forward this argument to the relevant drjit.while_loop() or drjit.if_stmt() call.
        Refer to the discussion of drjit.while_loop(), drjit.JitFlag.SymbolicLoops, drjit.if_stmt(), and drjit.JitFlag.SymbolicConditionals for details.
  - The optional strict=False reduces the strictness of variable consistency checks.
    Consider the following snippet:
        import drjit as dr
        from drjit.llvm import UInt32

        @dr.syntax
        def f(x: UInt32):
            if x < 4:
                y = 3
            else:
                y = 5
            return y
    This code will raise an exception.
        >>> f(UInt32(1))
        RuntimeError: drjit.if_stmt(): the non-array state variable 'y' of type 'int' changed from '5' to '10'. Please review the interface and assumptions of 'drjit.while_loop()' as explained in the documentation (https://drjit.readthedocs.io/en/latest/reference.html#drjit.while_loop).
    This is because the computed variable y of type int has an inconsistent value depending on the taken branch. Furthermore, y is a scalar Python type that isn’t tracked by Dr.Jit. The fix here is to initialize y with UInt32(<integer value>).
    However, there may also be legitimate situations where such an inconsistency is needed by the implementation. This can be fine when y is not used below the if statement. In this case, you can annotate the conditional or loop with dr.hint(..., strict=False), which disables the check.
  - max_iterations specifies a maximum number of loop iterations for reverse-mode automatic differentiation.
    Naive reverse-mode differentiation of loops (unless replaced by a smarter problem-specific strategy via drjit.custom and drjit.CustomOp) requires allocation of large buffers that hold loop state for all iterations.
    Dr.Jit requires an upper bound on the maximum number of loop iterations so that it can allocate such buffers, which can be provided via this hint. Otherwise, reverse-mode differentiation of loops will fail with an error message.
  - label provides a descriptive label.
    Dr.Jit will include this label as a comment in the generated intermediate representation, which can be helpful when debugging the compilation of large programs.
  - include and exclude indicate to the @drjit.syntax decorator that a local variable should or should not be considered to be part of the set of state variables passed to drjit.while_loop() or drjit.if_stmt().
    While transforming a function, the @drjit.syntax decorator sequentially steps through a program to identify the set of read and written variables. It then forwards referenced variables to recursive drjit.while_loop() and drjit.if_stmt() calls. In rare cases, it may be useful to manually include or exclude a local variable from this process. To do so, specify a list of such variables in the drjit.hint() annotation.
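Outside of a @drjit.syntax-decorated function, the documented behavior of drjit.hint() amounts to an identity function. The stand-in below (a hypothetical local hint, not the Dr.Jit implementation, and accepting arbitrary keywords for brevity) shows that an annotated loop then behaves like the unannotated one.

```python
def hint(arg, /, **kwargs):
    """Hypothetical stand-in: return 'arg' and ignore all keyword arguments."""
    return arg


i = 0
iterations = 0
while hint(i < 3, mode='scalar', label='demo loop'):
    i += 1
    iterations += 1

print(iterations)
```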
- drjit.while_loop(state: tuple[*Ts], cond: typing.Callable[[*Ts], AnyArray | bool], body: typing.Callable[[*Ts], tuple[*Ts]], labels: typing.Sequence[str] = (), label: str | None = None, mode: typing.Literal['scalar', 'symbolic', 'evaluated', None] = None, strict: bool = True, compress: bool | None = None, max_iterations: int | None = None) -> tuple[*Ts]
Repeatedly execute a function while a loop condition holds.
Motivation
This function provides a vectorized generalization of a standard Python while loop. For example, consider the following Python snippet:
    i: int = 1
    while i < 10:
        x *= x
        i += 1
This code would fail when i is replaced by an array with multiple entries (e.g., of type drjit.llvm.Int). In that case, the loop condition evaluates to a boolean array of per-component comparisons that are not necessarily consistent with each other. In other words, each entry of the array may need to run the loop for a different number of iterations. A standard Python while loop is not able to do so.
The drjit.while_loop() function realizes such a fine-grained looping mechanism. It takes three main input arguments:
  - state, a tuple of state variables that are modified by the loop iteration. Dr.Jit optimizes away superfluous state variables, so there isn’t any harm in specifying variables that aren’t actually modified by the loop.
  - cond, a function that takes the state variables as input and uses them to evaluate and return the loop condition in the form of a boolean array.
  - body, a function that also takes the state variables as input and runs one loop iteration. It must return an updated set of state variables.
The function calls cond and body to execute the loop. It then returns a tuple containing the final version of the state variables. With this functionality, a vectorized version of the above loop can be written as follows:
    i, x = dr.while_loop(
        state=(i, x),
        cond=lambda i, x: i < 10,
        body=lambda i, x: (i+1, x*x)
    )
Lambda functions are convenient when the condition and body are simple enough to fit onto a single line. In general you may prefer to define local functions (def loop_cond(i, x): ...) and pass them to the cond and body arguments.
Dr.Jit also provides the @drjit.syntax decorator, which automatically rewrites standard Python control flow constructs into the form shown above. It combines vectorization with the readability of natural Python syntax and is the recommended way of (indirectly) using drjit.while_loop(). With this decorator, the above example would be written as follows:
    @dr.syntax
    def f(i, x):
        while i < 10:
            x *= x
            i += 1
        return i, x
Evaluation modes
Dr.Jit uses one of three different modes to compile this operation depending on the inputs and active compilation flags (the text below this overview will explain how this mode is automatically selected).
  - Scalar mode: Scalar loops that don’t need any vectorization can fall back to a simple Python loop construct.
        while cond(*state):
            state = body(*state)
    This is the default strategy when cond(*state) returns a scalar Python bool.
    The loop body may still use Dr.Jit types, but note that this effectively unrolls the loop, generating a potentially long sequence of instructions that may take a long time to compile. Symbolic mode (discussed next) may be advantageous in such cases.
  - Symbolic mode: Here, Dr.Jit runs a single loop iteration to capture its effect on the state variables. It embeds this captured computation into the generated machine code. The loop will eventually run on the device (e.g., the GPU) but unlike a Python while statement, the loop does not run on the host CPU (besides the mentioned tentative evaluation for symbolic tracing).
    When loop optimizations are enabled (drjit.JitFlag.OptimizeLoops), Dr.Jit may re-trace the loop body so that it runs twice in total. This happens transparently and has no influence on the semantics of this operation.
  - Evaluated mode: In this mode, Dr.Jit will repeatedly evaluate the loop state variables and update active elements using the loop body function until all of them are done. Conceptually, this is equivalent to the following Python code:
        active = True
        while True:
            dr.eval(state)
            active &= cond(*state)
            if not dr.any(active):
                break
            state = dr.select(active, body(*state), state)
    (In practice, the implementation does a few additional things like suppressing side effects associated with inactive entries.)
    Dr.Jit will typically compile a kernel when it runs the first loop iteration. Subsequent iterations can then reuse this cached kernel since they perform the same exact sequence of operations. Kernel caching tends to be crucial to achieve good performance, and it is good to be aware of pitfalls that can effectively disable it.
    For example, when you update a scalar (e.g., a Python int) in each loop iteration, this changing counter might be merged into the generated program, forcing the system to re-generate and re-compile code at every iteration, and this can ultimately dominate the execution time. If in doubt, raise the log level of Dr.Jit to drjit.LogLevel.Info (via drjit.set_log_level()) and check if the kernels being launched contain the term cache miss. You can also inspect the Kernels launched line in the output of drjit.whos(). If you observe soft or hard misses at every loop iteration, then kernel caching isn’t working and you should carefully inspect your code to ensure that the computation stays consistent across iterations.
    When the loop processes many elements, and when each element requires a different number of loop iterations, there is the question of what should be done with inactive elements. The default implementation keeps them around and does redundant calculations that are, however, masked out. Consequently, later loop iterations don’t run faster despite fewer elements being active.
    Alternatively, you may specify the parameter compress=True or set the flag drjit.JitFlag.CompressLoops, which causes the removal of inactive elements after every iteration. This reorganization is not free and does not benefit all use cases, which is why it isn’t enabled by default.
A separate section about symbolic and evaluated modes discusses these two options in further detail.
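The masking behavior of evaluated mode, and the effect of compression on the amount of work, can be emulated with plain Python lists. The sketch below is a hypothetical model rather than Dr.Jit's implementation: evaluated_loop, cond, and body are illustrative names, each list entry represents one array element ("lane"), and cond/body operate on the scalar values of a single lane.

```python
def evaluated_loop(state, cond, body, compress=False):
    """Emulate an evaluated-mode loop; returns (final_state, body_calls)."""
    lanes = [list(values) for values in zip(*state)]  # one record per lane
    body_calls = 0
    while True:
        active = [bool(cond(*lane)) for lane in lanes]
        if not any(active):
            break
        # Without compression, inactive lanes are still processed and the
        # result is masked out; with compression they are skipped entirely.
        todo = [k for k, is_active in enumerate(active) if is_active or not compress]
        for k in todo:
            updated = body(*lanes[k])
            body_calls += 1
            if active[k]:
                lanes[k] = list(updated)
    final = tuple([lane[j] for lane in lanes] for j in range(len(state)))
    return final, body_calls


def cond(i, acc):
    return i < 5


def body(i, acc):
    return i + 1, acc + i


state = ([0, 3, 7], [0, 0, 0])  # per-lane counters and accumulators
(i_out, acc_out), calls_masked = evaluated_loop(state, cond, body)
(_, acc_compressed), calls_compressed = evaluated_loop(state, cond, body, compress=True)
print(i_out, acc_out, calls_masked, calls_compressed)
```

Both variants produce the same final state; only the number of body evaluations differs. The emulation does not account for the cost of the reorganization itself, which the text above names as the reason compression is not enabled by default.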
The drjit.while_loop() function chooses the evaluation mode as follows:
  - When the mode argument is set to None (the default), the function examines the loop condition. It uses scalar mode when this produces a Python bool; otherwise, it inspects the drjit.JitFlag.SymbolicLoops flag to switch between symbolic (the default) and evaluated mode.
    To change this automatic choice for a region of code, you may specify the mode= keyword argument, nest code into a drjit.scoped_set_flag() block, or change the behavior globally via drjit.set_flag():
        with dr.scoped_set_flag(dr.JitFlag.SymbolicLoops, False):
            # .. nested code will use evaluated loops ..
  - When mode is set to "scalar", "symbolic", or "evaluated", it directly uses that method without inspecting the compilation flags or loop condition type.
When using the @drjit.syntax decorator to automatically convert Python while loops into drjit.while_loop() calls, you can also use the drjit.hint() function to pass keyword arguments including mode, label, or max_iterations to the generated looping construct:
    while dr.hint(i < 10, label='My loop', mode='evaluated'):
        # ...
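The selection rules for the evaluation mode can be summarized as a small decision function. This is a hypothetical helper written for illustration (select_mode and FakeMaskArray are not Dr.Jit names); the symbolic_flag argument stands in for the state of drjit.JitFlag.SymbolicLoops.

```python
def select_mode(mode, cond_value, symbolic_flag=True):
    """Hypothetical helper mirroring the documented mode selection rules."""
    if mode is not None:
        if mode not in ("scalar", "symbolic", "evaluated"):
            raise ValueError(f"unknown mode: {mode!r}")
        return mode                 # an explicit request bypasses all checks
    if isinstance(cond_value, bool):
        return "scalar"             # Python bool condition: scalar fallback
    return "symbolic" if symbolic_flag else "evaluated"


class FakeMaskArray:
    """Placeholder for a boolean Dr.Jit array (not a real Dr.Jit type)."""


print(select_mode(None, True))
print(select_mode(None, FakeMaskArray()))
print(select_mode(None, FakeMaskArray(), symbolic_flag=False))
print(select_mode("evaluated", True))
```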
Assumptions
The loop condition function must be pure (i.e., it should never modify the state variables or any other kind of program state). The loop body should not write to variables besides the officially declared state variables:
    y = ..
    def loop_body(x):
        y[0] += x  # <-- don't do this. 'y' is not a loop state variable

    dr.while_loop(body=loop_body, ...)
Dr.Jit automatically tracks dependencies of indirect reads (done via drjit.gather()) and indirect writes (done via drjit.scatter(), drjit.scatter_reduce(), drjit.scatter_add(), drjit.scatter_inc(), etc.). Such operations create implicit inputs and outputs of a loop, and these do not need to be specified as loop state variables (however, doing so causes no harm). This auto-discovery mechanism is helpful when performing vectorized method calls within loops, where the set of implicit inputs and outputs can often be difficult to know a priori (in principle, any public/private field in any instance could be accessed in this way).
    y = ..
    def loop_body(x):
        # Scattering to 'y' is okay even if it is not declared as loop state
        dr.scatter(target=y, value=x, index=0)
Another important assumption is that the loop state remains consistent across iterations, which means:
  1. The type of state variables is not allowed to change. You may not declare a Python float before a loop and then overwrite it with a drjit.cuda.Float (or vice versa).
  2. Their structure/size must be consistent. The loop body may not turn a variable with 3 entries into one that has 5.
  3. Analogously, state variables must always be initialized prior to the loop. This is the case even if you know that the loop body is guaranteed to overwrite the variable with a well-defined result. An initial value of None would violate condition 1 (type invariance), while an empty array would violate condition 2 (shape compatibility).
The implementation will check for violations and, if applicable, raise an exception identifying problematic state variables.
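The consistency conditions on types and sizes can be illustrated with a small checker that compares the loop state before and after one iteration. This is a hypothetical sketch (check_state_consistency and message_for are not Dr.Jit functions), Python lists stand in for arrays, and the actual rules applied by Dr.Jit may differ in detail.

```python
def check_state_consistency(before, after, labels):
    """Hypothetical checker for type and size invariance of loop state."""
    for old, new, label in zip(before, after, labels):
        if type(old) is not type(new):
            raise RuntimeError(
                f"state variable '{label}' changed type from "
                f"'{type(old).__name__}' to '{type(new).__name__}'")
        if hasattr(old, "__len__") and len(old) != len(new):
            raise RuntimeError(
                f"state variable '{label}' changed size from "
                f"{len(old)} to {len(new)}")


def message_for(before, after, labels):
    """Return the error message produced by the check, or None if it passes."""
    try:
        check_state_consistency(before, after, labels)
        return None
    except RuntimeError as exc:
        return str(exc)


ok = message_for(([1.0, 2.0, 3.0],), ([2.0, 4.0, 6.0],), ("x",))
type_error = message_for((None,), ([0.0],), ("x",))           # None placeholder
size_error = message_for(([0.0] * 3,), ([0.0] * 5,), ("x",))  # 3 entries -> 5
print(ok, type_error, size_error, sep="\n")
```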
Potential pitfalls
  - Long compilation times.
    In the example below, i < 100000 is scalar, causing drjit.while_loop() to use the scalar evaluation strategy that effectively copy-pastes the loop body 100000 times to produce a giant program. Code written in this way will be bottlenecked by the CUDA/LLVM compilation stage.
        @dr.syntax
        def f():
            i = 0
            while i < 100000:
                # .. costly computation
                i += 1
  - Incorrect behavior in symbolic mode.
    Let’s fix the above program by casting the loop condition into a Dr.Jit type to ensure that a symbolic loop is used. Problem solved, right?
        from drjit.cuda import Bool

        @dr.syntax
        def f():
            i = 0
            while Bool(i < 100000):
                # .. costly computation
                i += 1
    Unfortunately, no: this loop never terminates when run in symbolic mode. Symbolic mode does not track modifications of scalar/non-Dr.Jit types across loop iterations such as the int-valued loop counter i. It’s as if we had written while Bool(0 < 100000), which of course never finishes.
    Evaluated mode does not have this problem. If your loop behaves differently in symbolic and evaluated modes, then some variation of this mistake is likely to blame. To fix this, we must declare the loop counter as a vector type before the loop and then modify it as follows:
        from drjit.cuda import Int

        @dr.syntax
        def f():
            i = Int(0)
            while i < 100000:
                # .. costly computation
                i += 1
Warning
This new implementation of the drjit.while_loop() abstraction still lacks the functionality to break or return from the loop, or to continue to the next loop iteration. We plan to add these capabilities in the near future.
Interface
- Parameters:
state (tuple) – A tuple containing the initial values of the loop’s state variables. This tuple normally consists of Dr.Jit arrays or PyTrees. Other values are permissible as well and will be forwarded to the loop body. However, such variables will not be captured by the symbolic tracing process.
cond (Callable) – A function/callable that will be invoked with *state (i.e., the state variables will be unpacked and turned into function arguments). It should return a scalar Python bool or a boolean-typed Dr.Jit array representing the loop condition.
body (Callable) – A function/callable that will be invoked with *state (i.e., the state variables will be unpacked and turned into function arguments). It should update the loop state and then return a new tuple of state variables that are compatible with the previous state (see the earlier description regarding what such compatibility entails).
mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides None are: "scalar", "symbolic", "evaluated". If not specified, the function first checks if the loop is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flag drjit.JitFlag.SymbolicLoops and then either performs a symbolic or an evaluated loop.
compress (Optional[bool]) – Set this parameter to True or False to enable or disable loop state compression in evaluated loops (see the text above for a description of this feature). The function queries the value of drjit.JitFlag.CompressLoops when the parameter is not specified. Symbolic loops ignore this parameter.
labels (list[str]) – An optional list of labels associated with each state entry. Dr.Jit uses this to provide better error messages in case of a detected inconsistency. The @drjit.syntax decorator automatically provides these labels based on the transformed code.
label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
max_iterations (Optional[int]) – The maximum number of loop iterations (default: None). You must specify a correct upper bound here if you wish to differentiate the loop in reverse mode. In that case, the maximum iteration count is used to reserve memory to store intermediate loop state.
strict (bool) – You can specify this parameter to reduce the strictness of variable consistency checks performed by the implementation. See the documentation of drjit.hint() for an example. The default is strict=True.
- Returns:
The function returns the final state of the loop variables following termination of the loop.
- Return type:
tuple
- drjit.if_stmt(args: tuple[*Ts], cond: AnyArray | bool, true_fn: typing.Callable[[*Ts], T], false_fn: typing.Callable[[*Ts], T], arg_labels: typing.Sequence[str] = (), rv_labels: typing.Sequence[str] = (), label: str | None = None, mode: typing.Literal['scalar', 'symbolic', 'evaluated', None] = None, strict: bool = True) T ¶
Conditionally execute code.
Motivation
This function provides a vectorized generalization of a standard Python
if
statement. For example, consider the following Python snippeti: int = .. some expression .. if i > 0: x = f(i) # <-- some costly function 'f' that depends on 'i' else: y += 1
This code would fail if
i
is replaced by an array containing multiple entries (e.g., of typedrjit.llvm.Int
). In that case, the conditional expression produces a boolean array of per-component comparisons that are not necessarily consistent with each other. In other words, some of the entries may want to run the body of theif
statement, while others must skip to theelse
block. This is not compatible with the semantics of a standard Pythonif
statement.The
drjit.if_stmt()
function realizes a more fine-grained conditional operation that accommodates these requirements, while avoiding execution of the costly branch unless this is truly needed. It takes the following input arguments:cond
, a boolean array that specifies whether the body of theif
statement should execute.A tuple of input arguments (
args
) that will be forwarded totrue_fn
andfalse_fn
. It is important to specify all inputs to ensure correct derivative tracking of the operation.true_fn
, a callable that implements the body of theif
block.false_fn
, a callable that implements the body of theelse
block.
The implementation will invoke
true_fn(*args)
andfalse_fn(*args)
to trace their contents. The return values of these functions must be compatible with each other (a precise definition of compatibility is described below). A vectorized version of the earlier example can then be written as follows:x, y = dr.if_stmt( args=(i, x, y), cond=i > 0, true_fn=lambda i, x, y: (f(i), y), false_fn=lambda i, x, y: (x, y + 1) )
Lambda functions are convenient when
true_fn
andfalse_fn
are simple enough to fit onto a single line. In general you may prefer to define local functions (def true_fn(i, x, y): ...
) and pass them to thetrue_fn
andfalse_fn
arguments.Dr.Jit later optimizes away superfluous inputs/outputs of
drjit.if_stmt()
, so there isn’t any harm in, e.g., specifying an identical element of a return value in bothtrue_fn
andfalse_fn
.Dr.Jit also provides the
@drjit.syntax
decorator, which automatically rewrites standard Python control flow constructs into the form shown above. It combines vectorization with the readability of natural Python syntax and is the recommended way of (indirectly) usingdrjit.if_stmt()
. With this decorator, the above example would be written as follows:@dr.syntax def f(i, x, y): if i > 0: x = f(i) else: y += 1 return x, y
Evaluation modes
Dr.Jit uses one of three different modes to realize this operation depending on the inputs and active compilation flags (the text below this overview will explain how this mode is automatically selected).
Scalar mode: Scalar
if
statements that don’t need any vectorization can be reduced to normal Python branching constructs:if cond: state = true_fn(*args) else: state = false_fn(*args)
This strategy is the default when
cond
is a scalar Pythonbool
.Symbolic mode: Dr.Jit runs
true_fn
andfalse_fn
to capture the computation performed by each function, which allows it to generate an equivalent branch in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.Evaluated mode: in this mode, Dr.Jit runs both branches of the
if
statement and then combines the results viadrjit.select()
. This is nearly equivalent to the following Python code: true_state = true_fn(*state) false_state = false_fn(*state) if false_fn else state state = dr.select(cond, true_state, false_state)
(In practice, the implementation does a few additional things like suppressing side effects associated with inactive entries.)
Evaluated mode is conceptually simpler but also slower, since the device executes both sides of a branch when only one of them is actually needed.
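To make the evaluated strategy concrete, the following plain-Python sketch models it on lists instead of Dr.Jit arrays. The helpers `select` and `if_stmt_evaluated` are hypothetical stand-ins written for this illustration and are not part of the Dr.Jit API; the real implementation also suppresses side effects of inactive entries, which this model ignores.

```python
def select(cond, a, b):
    """Element-wise choice: a list-based stand-in for drjit.select()."""
    return [x if c else y for c, x, y in zip(cond, a, b)]


def if_stmt_evaluated(args, cond, true_fn, false_fn):
    """Hypothetical model of evaluated mode: run both branches, then blend."""
    true_state = true_fn(*args)    # executed for every lane
    false_state = false_fn(*args)  # executed for every lane as well
    return tuple(select(cond, t, f) for t, f in zip(true_state, false_state))


i = [-1, 2, 0, 5]
x = [10, 20, 30, 40]
y = [1, 1, 1, 1]

x_out, y_out = if_stmt_evaluated(
    args=(i, x, y),
    cond=[v > 0 for v in i],
    true_fn=lambda i, x, y: ([v * v for v in i], y),
    false_fn=lambda i, x, y: (x, [v + 1 for v in y]),
)
print(x_out)  # [10, 4, 30, 25]
print(y_out)  # [2, 1, 2, 1]
```

Both lambdas process all four lanes even though each lane only keeps one of the two results, which is the extra cost mentioned above.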
The mode is chosen as follows:
When the
mode
argument is set toNone
(the default), the function examines the type of thecond
input and uses scalar mode if the type is a builtin Pythonbool
.Otherwise, it chooses between symbolic and evaluated mode based on the
drjit.JitFlag.SymbolicConditionals
flag, which is set by default. To change this choice for a region of code, you may specify themode=
keyword argument, nest it into adrjit.scoped_set_flag()
block, or change the behavior globally viadrjit.set_flag()
:with dr.scoped_set_flag(dr.JitFlag.SymbolicConditionals, False): # .. nested code will use evaluated mode ..
When
mode
is set to"scalar"
,"symbolic"
, or"evaluated"
, it directly uses that mode without inspecting the compilation flags or condition type.
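The selection rules above can be summarized in a few lines of plain Python. This is only a sketch of the documented behavior: the function name `choose_if_stmt_mode` is hypothetical, and the `symbolic_conditionals` argument stands in for the current state of `drjit.JitFlag.SymbolicConditionals`.

```python
def choose_if_stmt_mode(cond, mode=None, symbolic_conditionals=True):
    """Hypothetical model of how the evaluation mode is picked."""
    if mode is not None:
        if mode not in ("scalar", "symbolic", "evaluated"):
            raise ValueError(f"unknown mode: {mode!r}")
        return mode  # an explicit request bypasses flags and type checks
    if type(cond) is bool:
        return "scalar"  # builtin Python bool: ordinary branching
    # Any other condition type: defer to the (modelled) JIT flag.
    return "symbolic" if symbolic_conditionals else "evaluated"


print(choose_if_stmt_mode(True))                                   # scalar
print(choose_if_stmt_mode([True, False]))                          # symbolic
print(choose_if_stmt_mode([True], symbolic_conditionals=False))    # evaluated
print(choose_if_stmt_mode(True, mode="evaluated"))                 # evaluated
```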
When using the
@drjit.syntax
decorator to automatically convert Pythonif
statements intodrjit.if_stmt()
calls, you can also use thedrjit.hint()
function to pass keyword arguments including themode
andlabel
parameters.if dr.hint(i < 10, mode='evaluated'): # ...
Assumptions
The return values of
true_fn
andfalse_fn
must be of the same type. This requirement applies recursively if the return value is a PyTree.Dr.Jit will refuse to compile vectorized conditionals, in which
true_fn
andfalse_fn
return a scalar that is inconsistent between the branches.>>> @dr.syntax ... def f(x): ... if x > 0: ... y = 1 ... else: ... y = 0 ... return y ... >>> print(f(dr.llvm.Float(-1,2))) RuntimeError: dr.if_stmt(): detected an inconsistency when comparing the return values of 'true_fn' and 'false_fn': drjit.detail.check_compatibility(): inconsistent scalar Python object of type 'int' for field 'y'. Please review the interface and assumptions of dr.if_stmt() as explained in the Dr.Jit documentation.
The problem can be solved by assigning an instance of a capitalized Dr.Jit type (e.g.,
y=Int(1)
) so that the operation can be tracked.The functions
true_fn
andfalse_fn
should not write to variables besides the explicitly declared return value(s): vec = drjit.cuda.Array3f(1, 2, 3) def true_fn(x): vec.x += x # <-- don't do this. 'vec' is not a declared output dr.if_stmt(args=(x,), true_fn=true_fn, ...)
This example can be fixed as follows:
def true_fn(x, vec): vec.x += x return vec vec = dr.if_stmt(args=(x, vec), true_fn=true_fn, ...)
drjit.if_stmt()
is differentiable in both forward and reverse modes. Correct derivative tracking requires that regular differentiable inputs are specified via theargs
parameter. The@drjit.syntax
decorator ensures that these assumptions are satisfied.Dr.Jit also tracks dependencies of indirect reads (done via
drjit.gather()
) and indirect writes (done viadrjit.scatter()
,drjit.scatter_reduce()
,drjit.scatter_add()
,drjit.scatter_inc()
, etc.). Such operations create implicit inputs and outputs, and these do not need to be specified as part ofargs
or the return value oftrue_fn
andfalse_fn
(however, doing so causes no harm). This auto-discovery mechanism is helpful when performing vectorized method calls (within conditional statements), where the set of implicit inputs and outputs can often be difficult to know a priori (in principle, any public/private field in any instance could be accessed in this way).y = .. def true_fn(x): # 'y' is neither declared as input nor output of 'true_fn', which is fine dr.scatter(target=y, value=x, index=0) dr.if_stmt(args=(x,), true_fn=true_fn, ...)
Interface
- Parameters:
cond (bool|drjit.ArrayBase) – a scalar Python
bool
or a boolean-valued Dr.Jit array.args (tuple) – A list of positional arguments that will be forwarded to
true_fn
andfalse_fn
.true_fn (Callable) – a callable that implements the body of the
if
block.false_fn (Callable) – a callable that implements the body of the
else
block.mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides
None
are:"scalar"
,"symbolic"
,"evaluated"
.arg_labels (list[str]) – An optional list of labels associated with each input argument. Dr.Jit uses this feature in combination with the
@drjit.syntax
decorator to provide better error messages in case of detected inconsistencies.rv_labels (list[str]) – An optional list of labels associated with each element of the return value. This parameter should only be specified when the return value is a tuple. Dr.Jit uses this feature in combination with the
@drjit.syntax
decorator to provide better error messages in case of detected inconsistencies.label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
strict (bool) – You can specify this parameter to reduce the strictness of variable consistency checks performed by the implementation. See the documentation of
drjit.hint()
for an example. The default isstrict=True
.
- Returns:
Combined return value mixing the results of
true_fn
andfalse_fn
.- Return type:
object
- drjit.switch(index: object, targets: collections.abc.Sequence, *args, **kwargs) object ¶
Selectively invoke functions based on a provided index array.
When called with a scalar
index
(of typeint
), this function is equivalent to the following Python expression:targets[index](*args, **kwargs)
When called with a Dr.Jit index array (specifically, 32-bit unsigned integers), it performs the vectorized equivalent of the above and assembles an array of return values containing the result of all referenced functions. It does so efficiently using at most a single invocation of each function in
targets
.from drjit.llvm import UInt32 res = dr.switch( UInt32(0, 0, 1, 1), # <-- index: selects the function [ # <-- targets: arbitrary list of callables lambda x: x, lambda x: x*10 ], UInt32(1, 2, 3, 4) # <-- argument passed to function ) # res now contains [1, 2, 30, 40]
The function traverses the set of positional (
*args
) and keyword arguments (**kwargs
) to find all Dr.Jit arrays including arrays contained within PyTrees. It routes a subset of array entries to each function as specified by theindex
argument.Dr.Jit will use one of two possible strategies to compile this operation depending on the active compilation flags (see
drjit.set_flag()
,drjit.scoped_set_flag()
):Symbolic mode: Dr.Jit transcribes every function into a counterpart in the generated low-level intermediate representation (LLVM IR or PTX) and targets them via an indirect jump instruction.
This mode is used when
drjit.JitFlag.SymbolicCalls
is set, which is the default.Evaluated mode: Dr.Jit evaluates the inputs
index
,args
,kwargs
viadrjit.eval()
, groups them byindex
, and invokes each function with the subset of inputs that reference it. Callables that are not referenced by any element of index
are ignored.In this mode, a
drjit.switch()
statement will cause Dr.Jit to launch a series of kernels processing subsets of the input data (one per function).
A separate section about symbolic and evaluated modes discusses these two options in detail.
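The grouping step of the evaluated strategy can be modelled on plain Python lists. The sketch below is an illustration under simplifying assumptions, not Dr.Jit's implementation: `switch_evaluated` is a hypothetical helper, each callable receives and returns a list, and only a single positional argument is routed.

```python
def switch_evaluated(index, targets, values):
    """Hypothetical list-based model of an evaluated-mode switch."""
    # Group lane numbers by the callable they reference.
    groups = {}
    for lane, idx in enumerate(index):
        groups.setdefault(idx, []).append(lane)

    result = [None] * len(index)
    calls = 0
    for idx, lanes in groups.items():
        # One invocation per referenced callable, on its subset of the input.
        subset = [values[lane] for lane in lanes]
        out = targets[idx](subset)
        calls += 1
        # Write the partial results back to the lanes they came from.
        for lane, r in zip(lanes, out):
            result[lane] = r
    return result, calls


targets = [
    lambda xs: [x for x in xs],       # index 0: identity
    lambda xs: [x * 10 for x in xs],  # index 1: scale by 10
    lambda xs: [-x for x in xs],      # index 2: never referenced below
]
res, calls = switch_evaluated([0, 0, 1, 1], targets, [1, 2, 3, 4])
print(res, calls)  # [1, 2, 30, 40] 2
```

The third callable is never invoked because no entry of the index list refers to it.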
To switch the compilation mode locally, use
drjit.scoped_set_flag()
as shown below:with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, False): result = dr.switch(..)
When a boolean Dr.Jit array (e.g.,
drjit.llvm.Bool
,drjit.cuda.ad.Bool
, etc.) is specified as last positional argument or as a keyword argument namedactive
, that argument is treated specially: entries of the input arrays associated with aFalse
mask entry are ignored and never passed to the functions. Associated entries of the return value will be zero-initialized. The function will still receive the mask argument as input, but it will always be set toTrue
.Danger
The indices provided to this operation are unchecked by default. Attempting to call functions beyond the end of the
targets
array is undefined behavior and may crash the application, unless such calls are explicitly disabled via theactive
parameter. Negative indices are not permitted.If debug mode is enabled via the
drjit.JitFlag.Debug
flag, Dr.Jit will insert range checks into the program. These checks disable out-of-bound calls and furthermore report warnings to identify problematic source locations:>>> print(dr.switch(UInt32(0, 100), [lambda x:x], UInt32(1))) Attempted to invoke callable with index 100, but this↵ value must be smaller than 1. (<stdin>:2)
- Parameters:
index (int|drjit.ArrayBase) – a list of indices to choose the functions
targets (Sequence[Callable]) – a list of callables to which calls will be dispatched based on the
index
argument.mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides
None
are:"symbolic"
,"evaluated"
. If not specified, the function first checks if the index is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flagdrjit.JitFlag.SymbolicCalls
and then either performs a symbolic or an evaluated call.label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
*args (tuple) – a variable-length list of positional arguments passed to the functions. PyTrees are supported.
**kwargs (dict) – a variable-length list of keyword arguments passed to the functions. PyTrees are supported.
- Returns:
When
index
is a scalar Python integer, the return value simply forwards the return value of the selected function. Otherwise, the function returns a Dr.Jit array or PyTree combining the results from each referenced callable.- Return type:
object
- drjit.dispatch(inst: drjit.ArrayBase, target: collections.abc.Callable, *args, **kwargs) object ¶
Invoke a provided Python function for each instance in an instance array.
This function invokes the provided
target
for each instance in the instance arrayinst
and assembles the return values into a result array. Conceptually, it does the following: def dispatch(inst, target, *args, **kwargs): result = [] for i in inst: result.append(target(i, *args, **kwargs)) return result
However, the implementation accomplishes this more efficiently using only a single call per unique instance. Instead of a Python
list
, it returns a Dr.Jit array or PyTree.In practice, this function is mainly good for two things:
Dr.Jit instance arrays contain C++ instances, and these will typically expose a set of methods. Adding further methods requires re-compiling C++ code and adding bindings, which may impede quick prototyping. With
drjit.dispatch()
, a developer can quickly implement additional vectorized method calls within Python (with the caveat that these can only access public members of the underlying type).Dynamic dispatch is a relatively costly operation. When multiple calls are performed on the same set of instances, it may be preferable to merge them into a single and potentially significantly faster use of
drjit.dispatch()
. An example is shown below:inst = # .. Array of C++ instances .. result_1 = inst.func_1(arg1) result_2 = inst.func_2(arg2)
The following alternative implementation instead uses
drjit.dispatch()
:def my_func(self, arg1, arg2): return (self.func_1(arg1), self.func_2(arg2)) result_1, result_2 = dr.dispatch(inst, my_func, arg1, arg2)
This function is otherwise very similar to
drjit.switch()
and similarly provides two different compilation modes, differentiability, and special handling of mask arguments. Please review the documentation ofdrjit.switch()
for details.- Parameters:
inst (drjit.ArrayBase) – a Dr.Jit instance array.
target (Callable) – function to dispatch on all instances
mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides
None
are:"symbolic"
,"evaluated"
. If not specified, the function first checks if the index is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flagdrjit.JitFlag.SymbolicCalls
and then either performs a symbolic or an evaluated call.label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
*args (tuple) – a variable-length list of positional arguments passed to the function. PyTrees are supported.
**kwargs (dict) – a variable-length list of keyword arguments passed to the function. PyTrees are supported.
- Returns:
A Dr.Jit array or PyTree containing the result of each performed function call.
- Return type:
object
Horizontal operations¶
These operations are horizontal in the sense that [..]
- drjit.gather(dtype: type[T], source: object, index: AnyArray | Sequence[int] | int, active: AnyArray | Sequence[bool] | bool = True, mode: drjit.ReduceMode = drjit.ReduceMode.Auto) T ¶
Gather values from a flat array or nested data structure.
This function performs a gather (i.e., indirect memory read) from
source
at positionindex
. It expects adtype
argument and will return an instance of this type. The optionalactive
argument can be used to disable some of the components, which is useful when not all indices are valid; the corresponding output will be zero in this case.This operation can be used in the following different ways:
When
dtype
is a 1D Dr.Jit array likedrjit.llvm.ad.Float
, this operation implements a parallelized version of the Python array indexing expressionsource[index]
with optional masking. Example:source = dr.cuda.Float([...]) index = dr.cuda.UInt([...]) # Note: negative indices are not permitted result = dr.gather(dtype=type(source), source=source, index=index)
When
dtype
is a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:When
type(source)
matchesdtype
, the gather operation threads through entries and invokes itself recursively. For example, the gather operation in source = dr.cuda.Array3f(...) index = dr.cuda.UInt([...]) result = dr.gather(dr.cuda.Array3f, source, index)
is equivalent to
result = dr.cuda.Array3f( dr.gather(dr.cuda.Float, source.x, index), dr.gather(dr.cuda.Float, source.y, index), dr.gather(dr.cuda.Float, source.z, index) )
A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.
Otherwise, the operation reconstructs the requested
dtype
from a flatsource
array, using C-style ordering with a suitably modifiedindex
. For example, the gather below reads 3D vectors from a 1D array.source = dr.cuda.Float([...]) index = dr.cuda.UInt([...]) result = dr.gather(dr.cuda.Array3f, source, index)
and is equivalent to
result = dr.cuda.Array3f( dr.gather(dr.cuda.Float, source, index*3 + 0), dr.gather(dr.cuda.Float, source, index*3 + 1), dr.gather(dr.cuda.Float, source, index*3 + 2))
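The index arithmetic of the flat-source case can be checked with a small plain-Python model. The helper `gather_vec3` is hypothetical and operates on lists; it only mirrors the documented `index*3 + k` addressing and the zero-fill of masked lanes.

```python
def gather_vec3(source, index, active=None):
    """Hypothetical model: read 3D vectors from a flat list (C-style order)."""
    if active is None:
        active = [True] * len(index)
    # One output list per component, mirroring the x/y/z fields of Array3f.
    x, y, z = [], [], []
    for i, a in zip(index, active):
        if a:
            vx, vy, vz = source[i * 3 + 0], source[i * 3 + 1], source[i * 3 + 2]
        else:
            # Masked lanes never read memory and produce zeros instead.
            vx, vy, vz = 0.0, 0.0, 0.0
        x.append(vx)
        y.append(vy)
        z.append(vz)
    return x, y, z


source = [float(v) for v in range(12)]  # four packed 3D vectors
x, y, z = gather_vec3(source, index=[3, 0, 99], active=[True, True, False])
print(x, y, z)  # [9.0, 0.0, 0.0] [10.0, 1.0, 0.0] [11.0, 2.0, 0.0]
```

The last lane uses an out-of-range index but is masked out, so the model never touches `source` for it.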
Danger
The indices provided to this operation are unchecked by default. Attempting to read beyond the end of the
source
array is undefined behavior and may crash the application, unless such reads are explicitly disabled via theactive
parameter. Negative indices are not permitted.If debug mode is enabled via the
drjit.JitFlag.Debug
flag, Dr.Jit will insert range checks into the program. These checks disable out-of-bound reads and furthermore report warnings to identify problematic source locations:>>> dr.gather(dtype=UInt, source=UInt(1, 2, 3), index=UInt(0, 1, 100)) drjit.gather(): out-of-bounds read from position 100 in an array↵ of size 3. (<stdin>:2)
- Parameters:
dtype (type) – The desired output type (typically equal to
type(source)
, but other variations are possible as well, see the description above.)source (object) – The object from which data should be read (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)
index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g.,
drjit.scalar.ArrayXu
ordrjit.cuda.UInt
) specifying gather indices. Dr.Jit will attempt an implicit conversion if another type is provided.active (object) – an optional 1D dynamic Dr.Jit mask array (e.g.,
drjit.scalar.ArrayXb
ordrjit.cuda.Bool
) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default isTrue
.mode (drjit.ReduceMode) – The reverse-mode derivative of a gather is an atomic scatter-reduction. The execution of such atomics can be rather performance-sensitive (see the discussion of
drjit.ReduceMode
for details), hence Dr.Jit offers a few different compilation strategies to realize them. Specifying this parameter selects a strategy for the derivative of a particular gather operation. The default isdrjit.ReduceMode.Auto
.
- drjit.scatter(target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None ¶
Scatter values into a flat array or nested data structure.
This operation performs a scatter (i.e., indirect memory write) of the
value
parameter to thetarget
array at positionindex
. The optionalactive
argument can be used to disable some of the individual write operations, which is useful when not all provided values or indices are valid.This operation can be used in the following different ways:
When
target
is a 1D Dr.Jit array likedrjit.llvm.ad.Float
, this operation implements a parallelized version of the Python array indexing expressiontarget[index] = value
with optional masking. Example:target = dr.empty(dr.cuda.Float, 1024*1024) value = dr.cuda.Float([...]) index = dr.cuda.UInt([...]) # Note: negative indices are not permitted dr.scatter(target, value=value, index=index)
When
target
is a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:When
target
andvalue
are of the same type, the scatter operation threads through entries and invokes itself recursively. For example, the scatter operation intarget = dr.cuda.Array3f(...) value = dr.cuda.Array3f(...) index = dr.cuda.UInt([...]) dr.scatter(target, value, index)
is equivalent to
dr.scatter(target.x, value.x, index) dr.scatter(target.y, value.y, index) dr.scatter(target.z, value.z, index)
A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.
Otherwise, the operation flattens the
value
array and writes it using C-style ordering with a suitably modifiedindex
. For example, the scatter below writes 3D vectors into a 1D array.target = dr.cuda.Float(...) value = dr.cuda.Array3f(...) index = dr.cuda.UInt([...]) dr.scatter(target, value, index)
and is equivalent to
dr.scatter(target, value.x, index*3 + 0) dr.scatter(target, value.y, index*3 + 1) dr.scatter(target, value.z, index*3 + 2)
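A list-based model of the flattened write may help to verify the addressing. `scatter_vec3` is a hypothetical helper for illustration; it processes lanes sequentially, whereas Dr.Jit gives no ordering guarantee when several lanes write to the same position.

```python
def scatter_vec3(target, value, index, active=None):
    """Hypothetical model: write 3D vectors into a flat list (C-style order)."""
    x, y, z = value
    if active is None:
        active = [True] * len(index)
    for lane, (i, a) in enumerate(zip(index, active)):
        if not a:
            continue  # masked lanes perform no write at all
        target[i * 3 + 0] = x[lane]
        target[i * 3 + 1] = y[lane]
        target[i * 3 + 2] = z[lane]


target = [0.0] * 9  # room for three packed 3D vectors
value = ([1.0, 2.0, 3.0],   # x components of three lanes
         [4.0, 5.0, 6.0],   # y components
         [7.0, 8.0, 9.0])   # z components
scatter_vec3(target, value, index=[2, 0, 1], active=[True, True, False])
print(target)  # [2.0, 5.0, 8.0, 0.0, 0.0, 0.0, 1.0, 4.0, 7.0]
```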
Danger
The indices provided to this operation are unchecked by default. Out-of-bound writes are considered undefined behavior and may crash the application (unless they are disabled via the
active
parameter). Negative indices are not permitted.If debug mode is enabled via the
drjit.JitFlag.Debug
flag, Dr.Jit will insert range checks into the program. These will catch out-of-bound writes and print an error message identifying the responsible line of code.Dr.Jit makes no guarantees about the expected behavior when a scatter operation has conflicts, i.e., when a specific position is written multiple times by a single
drjit.scatter()
operation.- Parameters:
target (object) – The object into which data should be written (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)
value (object) – The values to be written (typically of type
type(target)
, but other variations are possible as well, see the description above.) Dr.Jit will attempt an implicit conversion if the input is not an array type.index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g.,
drjit.scalar.ArrayXu
ordrjit.cuda.UInt
) specifying scatter indices. Dr.Jit will attempt an implicit conversion if another type is provided.
active (object) – an optional 1D dynamic Dr.Jit mask array (e.g.,
drjit.scalar.ArrayXb
ordrjit.cuda.Bool
) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default isTrue
.
- enum drjit.ReduceOp(value)¶
List of different atomic read-modify-write (RMW) operations supported by
drjit.scatter_reduce()
.Valid values are as follows:
- Identity = ReduceOp.Identity¶
Perform an ordinary scatter operation that ignores the current entry.
- Add = ReduceOp.Add¶
Addition.
- Mul = ReduceOp.Mul¶
Multiplication.
- Min = ReduceOp.Min¶
Minimum.
- Max = ReduceOp.Max¶
Maximum.
- And = ReduceOp.And¶
Binary AND operation.
- Or = ReduceOp.Or¶
Binary OR operation.
- enum drjit.ReduceMode(value)¶
Compilation strategy for atomic scatter-reductions.
Elements of this enumeration determine how Dr.Jit executes atomic scatter-reductions, i.e., indirect writes that update an existing element in an array while avoiding problems arising due to concurrency.
Atomic scatter-reductions can have a significant detrimental impact on performance. When many threads in a parallel computation attempt to modify the same element, this can lead to contention—essentially a fight over which part of the processor owns the associated memory region, which can slow down a computation by many orders of magnitude.
This parameter also plays an important role for
drjit.gather()
, which is nominally a read-only operation. This is because the reverse-mode derivative of a gather turns it into an atomic scatter-addition, where further context on how to compile the operation is needed.Dr.Jit implements several strategies to address contention, which can be selected by passing the optional
mode
parameter todrjit.scatter_reduce()
,drjit.scatter_add()
, anddrjit.gather()
.If you find that a part of your program is bottlenecked by atomic writes, then it may be worth explicitly specifying some of the strategies below to see which one performs best.
Valid values are as follows:
- Auto = ReduceMode.Auto¶
Select the first valid option from the following list:
use
drjit.ReduceMode.Expand
if the computation uses the LLVM backend and thetarget
array storage size is smaller than or equal to the value given by drjit.expand_threshold()
. This threshold can be changed using thedrjit.set_expand_threshold()
function.use
drjit.ReduceMode.Local
ifdrjit.JitFlag.ScatterReduceLocal
is set.fall back to
drjit.ReduceMode.Direct
.
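The fallback chain can be written down as a short decision function. This is a sketch of the documented order only: `resolve_auto` and its parameters are hypothetical, the backend is modelled as a string, and the size and threshold are assumed to be expressed in the same unit.

```python
def resolve_auto(backend, target_size, expand_threshold, scatter_reduce_local):
    """Hypothetical model of how ReduceMode.Auto picks a concrete mode."""
    # 1. Expansion is only considered on the LLVM backend and for targets
    #    whose storage does not exceed the threshold.
    if backend == "llvm" and target_size <= expand_threshold:
        return "Expand"
    # 2. Otherwise, prefer local pre-reduction if the flag is set.
    if scatter_reduce_local:
        return "Local"
    # 3. Last resort: an ordinary atomic reduction.
    return "Direct"


print(resolve_auto("llvm", 1024, 4096, True))    # Expand
print(resolve_auto("llvm", 8192, 4096, True))    # Local
print(resolve_auto("cuda", 1024, 4096, True))    # Local
print(resolve_auto("cuda", 1024, 4096, False))   # Direct
```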
- Direct = ReduceMode.Direct¶
Insert an ordinary atomic reduction operation into the program.
This mode is ideal when no or little contention is expected, for example because the target indices of scatters are well spread throughout the target array. This mode generates a minimal amount of code, which can help improve performance especially on GPU backends.
- Local = ReduceMode.Local¶
Locally pre-reduce operands.
In this mode, Dr.Jit adds extra code to the compiled program to examine the target indices of atomic updates. For example, CUDA programs run with an instruction granularity referred to as a warp, which is a group of 32 threads. When some of these threads want to write to the same location, then those operands can be pre-processed to reduce the total number of necessary atomic memory transactions (potentially to just a single one!)
On the CPU/LLVM backend, the same process works at the granularity of packets. The details depend on the underlying instruction set; for example, there are 16 threads per packet on a machine with AVX512, so there is a potential for reducing atomic write traffic by that factor.
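The effect of local pre-reduction on the number of atomic transactions can be estimated with a plain-Python model. The helper `local_prereduce_add` is hypothetical; it treats every block of `width` consecutive lanes as one warp or packet and assumes an additive reduction.

```python
def local_prereduce_add(target, value, index, width=32):
    """Model of a local pre-reduction; returns the number of atomics issued."""
    atomics = 0
    for start in range(0, len(index), width):
        # Combine the operands of this group that share a destination.
        partial = {}
        group = zip(index[start:start + width], value[start:start + width])
        for i, v in group:
            partial[i] = partial.get(i, 0) + v
        # One (modelled) atomic update per distinct destination in the group.
        for i, v in partial.items():
            target[i] += v
            atomics += 1
    return atomics


target = [0, 0, 0, 0]
index = [0] * 32 + [1, 2] * 16   # 64 lanes in two groups of 32
value = [1] * 64
atomics = local_prereduce_add(target, value, index)
print(target, atomics)  # [32, 16, 16, 0] 3
```

Without pre-reduction, the same input would require one atomic per lane, i.e. 64 instead of 3.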
- NoConflicts = ReduceMode.NoConflicts¶
Perform a non-atomic read-modify-write operation.
This mode is only safe in specific situations. The caller must guarantee that there are no conflicts (i.e., scatters targeting the same elements). If specified, Dr.Jit will generate a non-atomic read-modify-write operation that potentially runs significantly faster, especially on the LLVM backend.
- Permute = ReduceMode.Permute¶
In contrast to prior enumeration entries, this one modifies plain (non-reductive) scatters and gathers. It exists to enable internal optimizations that Dr.Jit uses when differentiating vectorized function calls and compressed loops. You likely should not use it in your own code.
When setting this mode, the caller guarantees that there will be no conflicts, and that every entry is written exactly once using an index vector representing a permutation (it’s fine if this permutation is accomplished by multiple separate write operations, but there should be no more than one write to each element).
Giving ‘Permute’ as an argument to a (nominally read-only) gather operation is helpful because we then know that the reverse-mode derivative of this operation can be a plain scatter instead of a more costly atomic scatter-add.
Giving ‘Permute’ as an argument to a scatter operation is helpful because we then know that the forward-mode derivative does not depend on any prior derivative values associated with that array, as all current entries will be overwritten.
- Expand = ReduceMode.Expand¶
Expand the target array to avoid write conflicts, then scatter non-atomically.
This feature is only supported on the LLVM backend. Other backends interpret this flag as if
drjit.ReduceMode.Auto
had been specified.This mode internally expands the storage underlying the target array to a much larger size that is proportional to the number of CPU cores. Scalar (length-1) target arrays are expanded even further to ensure that each CPU gets an entirely separate cache line.
Following this one-time expansion step, the array can then accommodate an arbitrary sequence of scatter-reduction operations that the system will internally perform using non-atomic read-modify-write operations (i.e., analogous to the
NoConflicts
mode). Dr.Jit automatically re-compresses the array into the ordinary representation. On bigger arrays and on machines with many cores, the storage costs resulting from this mode can be prohibitive.
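The expand-then-compress idea can be illustrated with a sequential list-based model. Everything here is an assumption made for the sake of the sketch: `scatter_add_expanded` is a hypothetical helper, lanes are assigned to workers round-robin, and the private copies are merged by summation.

```python
def scatter_add_expanded(target, value, index, num_workers=4):
    """Hypothetical model of an expanded scatter-add; returns storage used."""
    n = len(target)
    # One private, zero-initialized copy of the target per worker.
    expanded = [[0] * n for _ in range(num_workers)]
    for lane, (i, v) in enumerate(zip(index, value)):
        worker = lane % num_workers  # arbitrary assignment for this model
        expanded[worker][i] += v     # plain read-modify-write, no atomics
    # Re-compress: fold all private copies back into the original array.
    for i in range(n):
        target[i] += sum(copy[i] for copy in expanded)
    return n * num_workers           # number of temporary entries


target = [100, 0]
storage = scatter_add_expanded(target, value=[1, 2, 3, 4, 5],
                               index=[0, 0, 0, 1, 0])
print(target, storage)  # [111, 4] 8
```

The temporary storage scales with the product of the array size and the worker count, which is why the mode becomes expensive for large targets.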
- drjit.scatter_reduce(op: drjit.ReduceOp, target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None ¶
Atomically update values in a flat array or nested data structure.
This function performs an atomic scatter-reduction, which is a read-modify-write operation that applies one of several possible mathematical functions to selected entries of an array. The following are supported:
drjit.ReduceOp.Add
:a=a+b
.drjit.ReduceOp.Max
:a=max(a, b)
.drjit.ReduceOp.Min
:a=min(a, b)
.drjit.ReduceOp.Or
:a=a | b
(integer arrays only).drjit.ReduceOp.And
:a=a & b
(integer arrays only).
Here,
a
refers to an entry oftarget
selected byindex
, andb
denotes the associated element ofvalue
. The operation resolves potential conflicts arising due to the parallel execution of this operation.The optional
active
argument can be used to disable some of the updates, e.g., when not all provided values or indices are valid. Atomic additions are subject to non-deterministic rounding errors. The reason for this is that IEEE-754 addition is non-associative. The execution order is scheduling-dependent, which can lead to small variations across program runs.
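The sensitivity of floating point sums to the order of accumulation is easy to reproduce in plain Python with double precision values. The helper `accumulate` is a hypothetical name; it simply adds the values from left to right.

```python
def accumulate(values):
    """Add the values from left to right, rounding after every step."""
    total = 0.0
    for v in values:
        total += v
    return total


# The same four numbers, accumulated in two different orders.
order_a = accumulate([1e16, 1.0, -1e16, 1.0])
order_b = accumulate([1e16, -1e16, 1.0, 1.0])
print(order_a, order_b)  # 1.0 2.0

# Regrouping three operands already changes the last digit.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False
```

In the first order, the initial `1.0` is lost when it is added to `1e16`, so the two results differ even though the inputs are identical.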
Atomic scatter-reductions can have a significant detrimental impact on performance. When many threads in a parallel computation attempt to modify the same element, this can lead to contention—essentially a fight over which part of the processor owns the associated memory region, which can slow down a computation by many orders of magnitude. Dr.Jit provides several different compilation strategies to reduce these costs, which can be selected via the
mode
parameter. The documentation ofdrjit.ReduceMode
provides more detail and performance plots.Backend support
Many combinations of reductions and variable types are not supported. Some combinations depend on the compute capability (CC) of the underlying CUDA device or on the LLVM version (LV) and the host architecture (ARM64, x86_64). The following matrices display the level of support.
For CUDA:
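Reduction | Bool | [U]Int{32,64} | Float16 | Float32 | Float64
Identity | ✅ | ✅ | ✅ | ✅ | ✅
Add | ❌ | ✅ | ⚠️ CC≥60 | ✅ | ⚠️ CC≥60
Mul | ❌ | ❌ | ❌ | ❌ | ❌
Min | ❌ | ❌ | ⚠️ CC≥90 | ❌ | ❌
Max | ❌ | ❌ | ⚠️ CC≥90 | ❌ | ❌
And | ❌ | ✅ | ❌ | ❌ | ❌
Or | ❌ | ✅ | ❌ | ❌ | ❌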
For LLVM:
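Reduction | Bool | [U]Int{32,64} | Float16 | Float32 | Float64
Identity | ✅ | ✅ | ✅ | ✅ | ✅
Add | ❌ | ✅ | ⚠️ LV≥16 | ✅ | ✅
Mul | ❌ | ❌ | ❌ | ❌ | ❌
Min | ❌ | ⚠️ LV≥15 | ⚠️ LV≥16, ARM64 | ⚠️ LV≥15 | ⚠️ LV≥15
Max | ❌ | ⚠️ LV≥15 | ⚠️ LV≥16, ARM64 | ⚠️ LV≥15 | ⚠️ LV≥15
And | ❌ | ✅ | ❌ | ❌ | ❌
Or | ❌ | ✅ | ❌ | ❌ | ❌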
The function raises an exception when the operation is not supported by the backend.
Scatter-reducing nested types
This operation can be used in the following different ways:
When
target
is a 1D Dr.Jit array likedrjit.llvm.ad.Float
, this operation implements a parallelized version of the Python array indexing expressiontarget[index] = op(target[index], value)
with optional masking. Example:target = dr.zeros(dr.cuda.Float, 1024*1024) value = dr.cuda.Float([...]) index = dr.cuda.UInt([...]) # Note: negative indices are not permitted dr.scatter_reduce(dr.ReduceOp.Add, target, value=value, index=index)
When
target
is a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:When
target
andvalue
are of the same type, the scatter-reduction threads through entries and invokes itself recursively. For example, the scatter operation inop = dr.ReduceOp.Add target = dr.cuda.Array3f(...) value = dr.cuda.Array3f(...) index = dr.cuda.UInt([...]) dr.scatter_reduce(op, target, value, index)
is equivalent to
dr.scatter_reduce(op, target.x, value.x, index) dr.scatter_reduce(op, target.y, value.y, index) dr.scatter_reduce(op, target.z, value.z, index)
A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.
Otherwise, the operation flattens the
value
array and writes it using C-style ordering with a suitably modifiedindex
. For example, the scatter-reduction below writes 3D vectors into a 1D array.op = dr.ReduceOp.Add target = dr.cuda.Float(...) value = dr.cuda.Array3f(...) index = dr.cuda.UInt([...]) dr.scatter_reduce(op, target, value, index)
and is equivalent to
dr.scatter_reduce(op, target, value.x, index*3 + 0) dr.scatter_reduce(op, target, value.y, index*3 + 1) dr.scatter_reduce(op, target, value.z, index*3 + 2)
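For the 1D case, the update rule `target[index] = op(target[index], value)` can be expressed as a sequential reference model. This sketch uses plain lists; the name `scatter_reduce_ref` and the string keys standing in for `drjit.ReduceOp` members are assumptions of this illustration.

```python
import operator

# String keys stand in for the corresponding drjit.ReduceOp members.
REDUCE_OPS = {
    "Add": operator.add,
    "Min": min,
    "Max": max,
    "And": operator.and_,
    "Or": operator.or_,
}


def scatter_reduce_ref(op, target, value, index, active=None):
    """Sequential model of target[index] = op(target[index], value)."""
    combine = REDUCE_OPS[op]
    if active is None:
        active = [True] * len(index)
    for i, v, a in zip(index, value, active):
        if a:  # masked lanes leave the target untouched
            target[i] = combine(target[i], v)


sums = [0, 0, 0]
scatter_reduce_ref("Add", sums, value=[1, 2, 3, 4], index=[0, 2, 0, 2])
print(sums)  # [4, 0, 6]

peaks = [0, 0]
scatter_reduce_ref("Max", peaks, value=[5, 3, 7], index=[0, 0, 1],
                   active=[True, True, False])
print(peaks)  # [5, 0]
```

A sequential loop fixes the update order, which a parallel scatter-reduction does not; for integer operands as used here the result is the same either way.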
Danger
The indices provided to this operation are unchecked by default. Out-of-bound writes are considered undefined behavior and may crash the application (unless they are disabled via the
active
parameter). Negative indices are not permitted.If debug mode is enabled via the
drjit.JitFlag.Debug
flag, Dr.Jit will insert range checks into the program. These will catch out-of-bound writes and print an error message identifying the responsible line of code.Dr.Jit makes no guarantees about the relative ordering of atomic operations when a
drjit.scatter_reduce()
writes to the same element multiple times. Combined with the non-associative nature of floating point operations, concurrent writes will generally introduce non-deterministic rounding error.
- Parameters:
op (drjit.ReduceOp) – Specifies the type of update that should be performed.
target (object) – The object into which data should be written (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)
value (object) – The values to be used in the RMW operation (typically of type
type(target)
, but other variations are possible as well, see the description above.) Dr.Jit will attempt an implicit conversion if the input is not an array type.
index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g.,
drjit.scalar.ArrayXu
ordrjit.cuda.UInt
) specifying scatter indices. Dr.Jit will attempt an implicit conversion if another type is provided.
active (object) – an optional 1D dynamic Dr.Jit mask array (e.g.,
drjit.scalar.ArrayXb
ordrjit.cuda.Bool
) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default isTrue
.mode (drjit.ReduceMode) – Dr.Jit offers several different strategies to implement atomic scatter-reductions that can be selected via this parameter. They achieve different best/worst case performance and, in the case of
drjit.ReduceMode.Expand
, involve additional memory storage overheads. The default isdrjit.ReduceMode.Auto
.
- drjit.scatter_add(target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None ¶
Atomically add values to a flat array or nested data structure.
This function is equivalent to drjit.scatter_reduce(drjit.ReduceOp.Add, ...) and exists for reasons of convenience. Please refer to drjit.scatter_reduce() for details on atomic scatter-reductions.
- drjit.scatter_add_kahan(target_1: drjit.ArrayBase, target_2: drjit.ArrayBase, value: object, index: object, active: object = True) None ¶
Perform a Kahan-compensated atomic scatter-addition.
Atomic floating point accumulation can incur significant rounding error when many values contribute to a single element. This function implements an error-compensating version of drjit.scatter_add() based on the Kahan-Babuška-Neumeier algorithm that simultaneously accumulates into two target buffers.
In particular, the operation first accumulates values into entries of a dynamic 1D array target_1. It tracks the round-off error caused by this operation and then accumulates this error into a second 1D array named target_2. At the end, the buffers can simply be added together to obtain the error-compensated result.
This function has a few limitations: in contrast to drjit.scatter_reduce() and drjit.scatter_add(), it does not perform a local reduction (see flag JitFlag.ScatterReduceLocal), which can be an important optimization when atomic accumulation is a performance bottleneck.
Furthermore, the function currently works only with flat 1D arrays. This is an implementation limitation that could in principle be removed in the future.
Finally, the function is differentiable, but derivatives currently only propagate into target_1. This means that forward derivatives don’t enjoy the error compensation of the primal computation. This limitation is of no relevance for backward derivatives.
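The two-buffer scheme can be illustrated with a sequential plain-Python model of Neumeier's compensated summation. This is a sketch of the textbook algorithm under the assumption that updates are applied one at a time; `scatter_add_kahan_model` is a hypothetical name, and the code says nothing about how Dr.Jit implements the atomic version.

```python
def scatter_add_kahan_model(target_1, target_2, value, index):
    """Accumulate 'value' into target_1 and the round-off error into target_2."""
    for v, i in zip(value, index):
        s = target_1[i]
        t = s + v
        # Neumeier's variant: recover the low-order bits lost in 's + v',
        # subtracting from whichever operand has the larger magnitude.
        if abs(s) >= abs(v):
            err = (s - t) + v
        else:
            err = (v - t) + s
        target_1[i] = t
        target_2[i] += err


values = [1e16, 1.0, -1e16]   # the exact sum is 1.0
index = [0, 0, 0]

naive = [0.0]
for v, i in zip(values, index):
    naive[i] += v             # the 1.0 is lost to rounding

t1, t2 = [0.0], [0.0]
scatter_add_kahan_model(t1, t2, values, index)
compensated = t1[0] + t2[0]   # add both buffers at the end

print(naive[0], compensated)  # 0.0 1.0
```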
- drjit.scatter_inc(target: drjit.ArrayBase, index: object, active: object = True) object ¶
Atomically increment a value within an unsigned 32-bit integer array and return the value prior to the update.
This operation works just like the drjit.scatter_reduce() operation for 32-bit unsigned integer operands, but with a fixed value=1 parameter and op=ReduceOp.Add.
The main difference is that this variant additionally returns the old value of the target array prior to the atomic update in contrast to the more general scatter-reduction, which just returns None. The operation also supports masking; the return value of masked-out entries is undefined. Both target and index parameters must be 1D unsigned 32-bit arrays.
This operation is a building block for stream compaction: threads can scatter-increment a global counter to request a spot in an array and then write their result there. The recipe for this looks as follows:
    data_1 = ...
    data_2 = ...
    active = dr.ones(Bool, len(data_1)) # .. or a more complex condition
    # This will hold the counter
    ctr = UInt32(0)
    # Allocate output buffers
    max_size = 1024
    data_compact_1 = dr.empty(Float, max_size)
    data_compact_2 = dr.empty(Float, max_size)
    my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
    # Disable dr.scatter() operations below in case of a buffer overflow
    active &= my_index < max_size
    dr.scatter(target=data_compact_1, value=data_1, index=my_index, active=active)
    dr.scatter(target=data_compact_2, value=data_2, index=my_index, active=active)
When following this approach, be sure to provide the same mask value to the drjit.scatter_inc() and subsequent drjit.scatter() operations.
The function drjit.reshape() can be used to resize the resulting arrays to their compacted size. Please refer to the documentation of this function, specifically the code example illustrating the use of the shrink=True argument.
The function drjit.scatter_inc() exhibits the following unusual behavior compared to regular Dr.Jit operations: the return value references the instantaneous state during a potentially large sequence of atomic operations. This instantaneous state is not reproducible in later kernel evaluations, and Dr.Jit will refuse to recompute it when the computed index is reused. In essence, the variable is “consumed” by the process of evaluation.
    my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
    dr.scatter(target=data_compact_1, value=data_1, index=my_index, active=active)
    dr.eval(data_compact_1) # Run Kernel #1
    dr.scatter(
        target=data_compact_2,
        value=data_2,
        index=my_index, # <-- oops, reusing my_index in another kernel.
        active=active   # This raises an exception.
    )
To get the above code to work, you will need to evaluate my_index at the same time to materialize it into a stored (and therefore trivially reproducible) representation. For this, ensure that the size of the active mask matches len(data_*) and that it is not the trivial True default mask (otherwise, the evaluated my_index will be scalar).
    dr.eval(data_compact_1, my_index)
Such multi-stage evaluation is potentially inefficient and may defeat the purpose of performing stream compaction in the first place. In general, prefer keeping all scatter operations involving the computed index in the same kernel, and then this issue does not arise.
The implementation of drjit.scatter_inc() performs a local reduction first, followed by a single atomic write per SIMD packet/warp. This is done to reduce contention from a potentially very large number of atomic operations targeting the same memory address. Fully masked updates do not cause memory traffic.
There is some conceptual overlap between this function and drjit.compress(), which can likewise be used to reduce a stream to a smaller subset of active items. The downside of drjit.compress() is that it requires evaluating the variables to be reduced, which can be very costly in terms of memory traffic and storage footprint. Reducing through drjit.scatter_inc() does not have this limitation: it can operate on symbolic arrays that greatly exceed the available device memory. One advantage of drjit.compress() is that it essentially boils down to a relatively simple prefix sum, which does not require atomic memory operations (these can be slow in some cases).
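The stream-compaction pattern can be sketched with a sequential plain-Python model. The helper names `scatter_inc_model` and `compact` are hypothetical. The model hands out slots in input order, whereas the real atomic operation makes no ordering promise, so only the set of compacted values should be relied upon.

```python
def scatter_inc_model(counter, active):
    """Return the counter value seen by each lane, then increment it.

    Inactive lanes receive None, standing in for the undefined return value.
    """
    slots = []
    for a in active:
        if a:
            slots.append(counter[0])
            counter[0] += 1
        else:
            slots.append(None)
    return slots


def compact(data, active, max_size):
    """Pack the active entries of 'data' into a buffer of size 'max_size'."""
    counter = [0]
    out = [None] * max_size
    slots = scatter_inc_model(counter, active)
    for v, a, s in zip(data, active, slots):
        # Guard against buffer overflow, like 'active &= my_index < max_size'
        if a and s < max_size:
            out[s] = v
    return out, min(counter[0], max_size)


data = [10, 11, 12, 13, 14]
active = [True, False, True, True, False]
out, count = compact(data, active, max_size=4)
print(out[:count])  # [10, 12, 13] in this sequential model
```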
- drjit.block_reduce(op: ReduceOp, value: T, block_size: int, mode: Literal['evaluated', 'symbolic', None] = None) T ¶
Reduce elements within blocks.
This function reduces all elements within contiguous blocks of size block_size along the trailing dimension of the input array value, returning a correspondingly smaller output array. Various types of reductions are supported (see drjit.ReduceOp for details).
For example, a sum reduction of a hypothetical array [a, b, c, d, e, f] with block_size=2 produces the output [a+b, c+d, e+f].
The function raises an exception when the length of the trailing dimension is not a multiple of the block size. It recursively threads through nested arrays and PyTrees.
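A plain-Python reference model makes the block semantics concrete. `block_reduce_model` is a hypothetical helper on Python lists that mirrors the behavior described above, including the error for lengths that are not a multiple of the block size. It is not the Dr.Jit implementation.

```python
import operator
from functools import reduce


def block_reduce_model(op, value, block_size):
    """Reduce each contiguous block of 'block_size' entries with binary 'op'."""
    if len(value) % block_size != 0:
        raise ValueError("array length must be a multiple of block_size")
    return [
        reduce(op, value[start:start + block_size])
        for start in range(0, len(value), block_size)
    ]


sums = block_reduce_model(operator.add, [1, 2, 3, 4, 5, 6], 2)
maxima = block_reduce_model(max, [1, 2, 3, 4, 5, 6], 3)
print(sums, maxima)  # [3, 7, 11] [3, 6]
```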
Dr.Jit uses one of two strategies to realize this operation, which can be optionally forced by specifying the mode parameter.
mode="evaluated" first evaluates the input array via drjit.eval() and then launches a specialized reduction kernel.
On the CUDA backend, this kernel makes efficient use of shared memory and cooperative warp instructions with the limitation that it requires block_size to be a power of two. The LLVM backend parallelizes the operation via the built-in thread pool and has no block_size limitations.
mode="symbolic" uses drjit.scatter_reduce() to atomically scatter-reduce values into the output array. This strategy can be advantageous when the input array is symbolic (making evaluation impossible) or both unevaluated and extremely large (making evaluation costly or impossible if there isn’t enough memory).
Disadvantages of this mode are that
Atomic scatters can suffer from memory contention (though drjit.scatter_reduce() takes steps to reduce contention, see its documentation for details).
Atomic floating point scatter-addition is subject to non-deterministic rounding errors that arise from its non-associative nature. Coupled with the scheduling-dependent execution order, this can lead to small variations across program runs. Integer and floating point min/max reductions are unaffected by this.
mode=None (default) automatically picks a reasonable strategy according to the following logic:
Symbolic mode is admissible when the necessary atomic reduction is supported by the backend.
Evaluated mode is admissible when the input does not involve symbolic variables. On the CUDA backend, block_size must furthermore be a power of two.
If only one strategy remains, then pick that one. Raise an exception when no strategy works out.
Otherwise, use evaluated mode when the input array is already evaluated, or when evaluating it would consume less than 1 GiB of memory.
Use symbolic mode in all other cases.
For some inputs, no strategy works out (e.g., multiplicative reduction of an array with a non-power-of-two block size on the CUDA backend). The function will raise an exception in such cases.
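The selection rules for mode=None can be written out as a small decision function. This is a sketch that transcribes the rules listed above. The function name, its boolean inputs, and the backend strings are assumptions made for illustration and do not correspond to any Dr.Jit API.

```python
GIB = 1 << 30


def is_power_of_two(n):
    return n > 0 and (n & (n - 1)) == 0


def pick_block_reduce_mode(atomic_supported, input_symbolic, backend,
                           block_size, already_evaluated, eval_bytes):
    """Return "evaluated" or "symbolic" following the documented rules."""
    symbolic_ok = atomic_supported
    evaluated_ok = not input_symbolic
    if backend == "cuda" and not is_power_of_two(block_size):
        evaluated_ok = False

    if not (symbolic_ok or evaluated_ok):
        raise RuntimeError("no admissible reduction strategy")
    if symbolic_ok != evaluated_ok:          # exactly one strategy remains
        return "symbolic" if symbolic_ok else "evaluated"
    if already_evaluated or eval_bytes < GIB:
        return "evaluated"
    return "symbolic"


# Non-power-of-two block size on CUDA forces symbolic mode
print(pick_block_reduce_mode(True, False, "cuda", 3, True, 0))
# Large unevaluated input with both strategies admissible
print(pick_block_reduce_mode(True, False, "llvm", 3, False, 2 * GIB))
```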
Since evaluated mode can be quite a bit faster and is guaranteed to be deterministic, it is recommended that you design your program so that it invokes drjit.block_reduce() with a power-of-two block_size.
Note
Tensor inputs are not supported. To reduce blocks within tensors, apply the regular axis-wide reductions (drjit.sum(), drjit.prod(), drjit.min(), drjit.max()) to reshaped tensors. For example, to sum-reduce a (16, 16) tensor by a factor of (4, 2) (i.e., to a (4, 8)-sized tensor), write dr.sum(dr.reshape(value, shape=(4, 4, 8, 2)), axis=(1, 3)).
- Parameters:
value (object) – A Dr.Jit array or PyTree
block_size (int) – size of the block
mode (str | None) – optional parameter to force an evaluation strategy.
- Returns:
The block-reduced array or PyTree as specified above.
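The reshape-based recipe for tensors mentioned in the note can be cross-checked with nested Python lists. In the sketch below, `block_sum_2d` is a hypothetical helper: reshaping a (16, 16) input to (4, 4, 8, 2) and summing over axes 1 and 3 is the same as adding entry (r, c) to output cell (r // 4, c // 2).

```python
def block_sum_2d(value, factor):
    """Sum-reduce a 2D nested list by 'factor' = (rows, cols) per block."""
    fr, fc = factor
    rows, cols = len(value), len(value[0])
    if rows % fr != 0 or cols % fc != 0:
        raise ValueError("shape must be a multiple of the block factor")
    out = [[0] * (cols // fc) for _ in range(rows // fr)]
    for r in range(rows):
        for c in range(cols):
            out[r // fr][c // fc] += value[r][c]
    return out


value = [[r * 16 + c for c in range(16)] for r in range(16)]
result = block_sum_2d(value, (4, 2))
print(len(result), len(result[0]))  # 4 8
print(result[0][0])                 # 196
```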
- drjit.block_sum(value: T, block_size: int, mode: Literal['evaluated', 'symbolic', None] = None) T ¶
Sum elements within blocks.
This is a convenience alias for drjit.block_reduce() with op set to drjit.ReduceOp.Add.
- drjit.reduce(op: ReduceOp, value: object, axis: int | tuple[int, ...] | None = 0, mode: str | None = None) object ¶
Reduce the input array, tensor, or iterable along the specified axis/axes.
This function reduces arrays, tensors and other iterable Python types along one or multiple axes, where op selects the operation to be performed:
drjit.ReduceOp.Add: a[0] + a[1] + ...
drjit.ReduceOp.Mul: a[0] * a[1] * ...
drjit.ReduceOp.Min: min(a[0], a[1], ...)
drjit.ReduceOp.Max: max(a[0], a[1], ...)
drjit.ReduceOp.Or: a[0] | a[1] | ... (integer arrays only)
drjit.ReduceOp.And: a[0] & a[1] & ... (integer arrays only)
The functions drjit.sum(), drjit.prod(), drjit.min(), and drjit.max() are convenience aliases that call drjit.reduce() with specific values of op.
By default, the reduction is along axis 0 (i.e., the outermost one), returning an instance of the array’s element type. For instance, sum-reducing an array a of type drjit.cuda.Array3f is equivalent to writing a[0] + a[1] + a[2] and produces a result of type drjit.cuda.Float. Dr.Jit can trace this operation and include it in the generated kernel.
Negative indices (e.g. axis=-1) count backward from the innermost axis. Multiple axes can be specified as a tuple. The value axis=None requests a simultaneous reduction over all axes.
When reducing axes of a tensor, or when reducing the trailing dimension of a Jit-compiled array, some special precautions apply: these axes correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. This can be done using the following strategies:
mode="evaluated" first evaluates the input array via drjit.eval() and then launches a specialized reduction kernel.
On the CUDA backend, this kernel makes efficient use of shared memory and cooperative warp instructions. The LLVM backend parallelizes the reduction via the built-in thread pool.
mode="symbolic" uses drjit.scatter_reduce() to atomically scatter-reduce values into the output array. This strategy can be advantageous when the input is symbolic (making evaluation impossible) or both unevaluated and extremely large (making evaluation costly or impossible if there isn’t enough memory).
Disadvantages of this mode are that
Atomic scatters can suffer from memory contention (though the drjit.scatter_reduce() function takes steps to reduce contention, see its documentation for details).
Atomic floating point scatter-addition is subject to non-deterministic rounding errors that arise from its non-associative nature. Coupled with the scheduling-dependent execution order, this can lead to small variations across program runs. Integer reductions and floating point min/max reductions are unaffected by this.
mode=None (default) automatically picks a reasonable strategy according to the following logic:
Use evaluated mode when the input array is already evaluated, or when evaluating it would consume less than 1 GiB of memory.
Use evaluated mode when the necessary atomic reduction operation is not supported by the backend.
Otherwise, use symbolic mode.
This function generally strips away reduced axes, but there is one notable exception: it will never remove a trailing dynamic dimension, if present in the input array.
For example, reducing an instance of type drjit.cuda.Float along axis 0 does not produce a scalar Python float. Instead, the operation returns another array of the same type with a single element. This is intentional: unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.
- Parameters:
op (ReduceOp) – The operation that should be applied along the reduced axis/axes.
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array or tensor.
axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
- Returns:
The reduced array or tensor as specified above.
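The two cases described for drjit.reduce() (stripping a static leading axis versus keeping a trailing dynamic one) can be mimicked with plain Python lists. The helpers `reduce_components` and `reduce_lanes` and the string keys of `OPS` are hypothetical names for illustration; a nested array is modeled as a list of equally long component lists.

```python
import operator
from functools import reduce as fold

# Binary operators standing in for the drjit.ReduceOp members
OPS = {
    "Add": operator.add, "Mul": operator.mul,
    "Min": min, "Max": max,
    "Or": operator.or_, "And": operator.and_,
}


def reduce_components(op, a):
    """Reduce along axis 0, e.g. a[0] + a[1] + a[2] for an Array3f-like input."""
    return [fold(OPS[op], lane) for lane in zip(*a)]


def reduce_lanes(op, a):
    """Reduce the trailing dynamic axis; the result keeps one element."""
    return [fold(OPS[op], a)]


a = [[1, 2], [3, 4], [5, 6]]            # three components, two lanes each
print(reduce_components("Add", a))      # [9, 12]
print(reduce_components("Mul", a))      # [15, 48]
print(reduce_lanes("Add", [1, 2, 3]))   # [6], index with [0] to unbox
```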
- drjit.sum(value: object, axis: int | tuple[int, ...] | None = 0, mode: str | None = None) object ¶
Sum-reduce the input array, tensor, or iterable along the specified axis/axes.
This function sum-reduces arrays, tensors and other iterable Python types along one or multiple axes. It is equivalent to dr.reduce(dr.ReduceOp.Add, ...). See the documentation of drjit.reduce() for further information.
- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.prod(value: object, axis: int | tuple[int, ...] | None = 0, mode: str | None = None) object ¶
Multiplicatively reduce the input array, tensor, or iterable along the specified axis/axes.
This function performs a multiplicative reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to dr.reduce(dr.ReduceOp.Mul, ...). See the documentation of drjit.reduce() for further information.
- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.min(value: object, axis: int | tuple[int, ...] | None = 0, mode: str | None = None) object ¶
Perform a minimum reduction of the input array, tensor, or iterable along the specified axis/axes.
(Not to be confused with drjit.minimum(), which computes the smaller of two values).
This function performs a minimum reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to dr.reduce(dr.ReduceOp.Min, ...). See the documentation of drjit.reduce() for further information.
- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.max(value: object, axis: int | tuple[int, ...] | None = 0, mode: str | None = None) object ¶
Perform a maximum reduction of the input array, tensor, or iterable along the specified axis/axes.
(Not to be confused with drjit.maximum(), which computes the larger of two values).
This function performs a maximum reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to dr.reduce(dr.ReduceOp.Max, ...). See the documentation of drjit.reduce() for further information.
- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.
mode (str | None) – optional parameter to force an evaluation strategy. Must equal "evaluated", "symbolic", or None.
- Returns:
The reduced array or tensor as specified above.
- drjit.mean(value: object, axis: int | Tuple[int, ...] | None = 0, mode: Literal['symbolic', 'evaluated', None] | None = None) object ¶
Compute the mean of the input array or tensor along one or multiple axes.
This function performs a horizontal sum reduction by adding values of the input array, tensor, or Python sequence along one or multiple axes and then dividing by the number of entries. By default, it sums along the outermost axis; specify axis=None to sum over all of them at once. The mean of an empty array is considered to be zero.
See the section on horizontal reductions for important general information about their properties.
- Parameters:
value (float | int | Sequence | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
axis (int | None) – The axis along which to reduce (Default: 0). A value of None causes a simultaneous reduction along all axes. Currently, only values of 0 and None are supported.
- Returns:
Result of the reduction operation
- Return type:
float | int | drjit.ArrayBase
- drjit.all(value: object, axis: int | tuple[int, ...] | None = 0) object ¶
Check if all elements along the specified axis are active.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the & (AND) operator.
By default, it reduces along index 0, which refers to the outermost axis. Negative indices (e.g. -1) count backwards from the innermost axis. The special argument axis=None causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to be True.
The function is internally based on dr.reduce(). See the documentation of this function for further information.
Like dr.reduce(), this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducing drjit.cuda.Bool does not produce a scalar Python bool. Instead, the operation returns another array of the same type with a single element. This is intentional: unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.
Boolean 1D arrays automatically convert to bool if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:
    from drjit.cuda import Float
    x = Float(...)
    if dr.all(x < 0):
        # ...
A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves calling drjit.eval() to evaluate and store the array in memory and then launching a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.
To avoid Boolean reductions, one can often use symbolic operations such as if_stmt(), while_loop(), etc. The @dr.syntax decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...) based on the condition.
    @dr.syntax
    def f(x: Float):
        if x < 0:
            # ...
- Parameters:
value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.any(value: object, axis: int | tuple[int, ...] | None = 0) object ¶
Check if any elements along the specified axis are active.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the | (OR) operator.
By default, it reduces along index 0, which refers to the outermost axis. Negative indices (e.g. -1) count backwards from the innermost axis. The special argument axis=None causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to be False.
The function is internally based on dr.reduce(). See the documentation of this function for further information.
Like dr.reduce(), this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducing drjit.cuda.Bool does not produce a scalar Python bool. Instead, the operation returns another array of the same type with a single element. This is intentional: unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.
Boolean 1D arrays automatically convert to bool if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:
    from drjit.cuda import Float
    x = Float(...)
    if dr.any(x < 0):
        # ...
A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves calling drjit.eval() to evaluate and store the array in memory and then launching a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.
To avoid Boolean reductions, one can often use symbolic operations such as if_stmt(), while_loop(), etc. The @dr.syntax decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...) based on the condition.
    @dr.syntax
    def f(x: Float):
        if x < 0:
            # ...
- Parameters:
value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.
- Returns:
Result of the reduction operation
- Return type:
bool | drjit.ArrayBase
- drjit.none(value: object, axis: int | tuple[int, ...] | None = 0) object ¶
Check if none of the elements along the specified axis are active.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the | (OR) operator and finally returns the bit-wise inverse of the result.
The function is internally based on dr.reduce(). See the documentation of this function for further information.
Like dr.reduce(), this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducing drjit.cuda.Bool does not produce a scalar Python bool. Instead, the operation returns another array of the same type with a single element. This is intentional: unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]) to obtain a value with the underlying element type.
Boolean 1D arrays automatically convert to bool if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:
    from drjit.cuda import Float
    x = Float(...)
    if dr.none(x < 0):
        # ...
A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves calling drjit.eval() to evaluate and store the array in memory and then launching a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.
To avoid Boolean reductions, one can often use symbolic operations such as if_stmt(), while_loop(), etc. The @dr.syntax decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...) based on the condition.
    @dr.syntax
    def f(x: Float):
        if x < 0:
            # ...
- Parameters:
value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | None) – The axis/axes along which to reduce. The default value is 0.
- Returns:
Result of the reduction operation
- Return type:
bool | drjit.ArrayBase
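The relationship between the three Boolean reductions, including their values for empty input, can be summarized with a plain-Python model. The `*_model` functions below are hypothetical stand-ins that operate on Python lists of booleans; they only mirror the documented semantics.

```python
def all_model(mask):
    """AND-reduction; the identity True is also the result for empty input."""
    result = True
    for m in mask:
        result = result and m
    return result


def any_model(mask):
    """OR-reduction; the identity False is also the result for empty input."""
    result = False
    for m in mask:
        result = result or m
    return result


def none_model(mask):
    """Inverse of the OR-reduction."""
    return not any_model(mask)


x = [-1.0, -2.0, 3.0]
mask = [v < 0 for v in x]
print(all_model(mask), any_model(mask), none_model(mask))  # False True False
print(all_model([]), any_model([]), none_model([]))        # True False True
```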
- drjit.count(value: object, axis: int | tuple[int, ...] | None = 0) object ¶
Compute the number of active entries along the given axis.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the + operator (interpreting True elements as 1 and False elements as 0). It returns an unsigned 32-bit version of the input array.
By default, it reduces along index 0, which refers to the outermost axis. Negative indices (e.g. -1) count backwards from the innermost axis. The special argument axis=None causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to be zero.
See the section on horizontal reductions for important general information about their properties.
- Parameters:
value (bool | Sequence | drjit.ArrayBase) – A Python or Dr.Jit mask type
axis (int | None) – The axis along which to reduce. The default value of 0 refers to the outermost axis. Negative values count backwards from the innermost axis. A value of None causes a simultaneous reduction along all axes.
- Returns:
Result of the reduction operation
- Return type:
int | drjit.ArrayBase
- drjit.dot(arg0: object, arg1: object, /) object ¶
Compute the dot product of two arrays.
Whenever possible, the implementation uses a sequence of fma() (fused multiply-add) operations to evaluate the dot product. When the input is a 1D JIT array like drjit.cuda.Float, the function evaluates the product of the input arrays via drjit.eval() and then performs a sum reduction via drjit.sum().
See the section on horizontal reductions for details on the properties of such horizontal reductions.
- Parameters:
arg0 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Dot product of inputs
- Return type:
float | int | drjit.ArrayBase
- drjit.abs_dot(arg0: object, arg1: object, /) object ¶
Compute the absolute value of the dot product of two arrays.
This function implements a convenience short-hand for abs(dot(arg0, arg1)).
See the section on horizontal reductions for details on the properties of such horizontal reductions.
- Parameters:
arg0 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Absolute value of the dot product of inputs
- Return type:
float | int | drjit.ArrayBase
- drjit.squared_norm(arg: object, /) object ¶
Computes the squared 2-norm of a Dr.Jit array, tensor, or Python sequence.
The operation is equivalent to
    dr.dot(arg, arg)
The squared_norm() operation performs a horizontal reduction. Please see the section on horizontal reductions for details on their properties.
- Parameters:
arg (Sequence | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
squared 2-norm of the input
- Return type:
float | int | drjit.ArrayBase
- drjit.norm(arg: object, /) object ¶
Computes the 2-norm of a Dr.Jit array, tensor, or Python sequence.
The operation is equivalent to
    dr.sqrt(dr.dot(arg, arg))
The norm() operation performs a horizontal reduction. Please see the section on horizontal reductions for details on their properties.
- Parameters:
arg (Sequence | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
2-norm of the input
- Return type:
float | int | drjit.ArrayBase
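The identities connecting dot(), abs_dot(), squared_norm(), and norm() can be verified with a plain-Python model. The `*_model` helpers are hypothetical, and the multiply-then-add in the loop is an ordinary (unfused) stand-in for the fma() sequence mentioned above, so rounding may differ from a true fused implementation.

```python
import math


def dot_model(a, b):
    """Dot product accumulated term by term (unfused multiply-add)."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = x * y + acc
    return acc


def abs_dot_model(a, b):
    return abs(dot_model(a, b))


def squared_norm_model(a):
    return dot_model(a, a)


def norm_model(a):
    return math.sqrt(dot_model(a, a))


v = [3.0, 4.0]
print(squared_norm_model(v), norm_model(v))       # 25.0 5.0
print(abs_dot_model([1.0, -2.0], [3.0, 4.0]))     # 5.0
```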
- drjit.prefix_sum(value: ArrayT, exclusive: bool = True, axis: int | None = 0) ArrayT ¶
Compute an exclusive or inclusive prefix sum of the input array.
By default, the function returns an output array \(\mathbf{y}\) of the same size as the input \(\mathbf{x}\), where
\[y_i = \sum_{j=0}^{i-1} x_j,\]
which is known as an exclusive prefix sum, as each element of the output array excludes the corresponding input in its sum. When the exclusive argument is set to False, the function instead returns an inclusive prefix sum defined as
\[y_i = \sum_{j=0}^i x_j.\]
There is also a convenience alias drjit.cumsum() that computes an inclusive sum analogous to various other nd-array frameworks.
Not all numeric data types are supported by prefix_sum(): presently, the function accepts Int32, UInt32, UInt64, Float32, and Float64-typed arrays.
The CUDA backend implementation for “large” numeric types (Float64, UInt64) has the following technical limitation: when reducing 64-bit integers, their values must be smaller than \(2^{62}\). When reducing double precision arrays, the two least significant mantissa bits are clamped to zero when forwarding the prefix from one 512-wide block to the next (at a very minor, probably negligible loss in accuracy). See the implementation for details on the rationale of this limitation.
- Parameters:
value (drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
exclusive (bool) – Specifies whether the prefix sum should be exclusive (the default) or inclusive.
- Returns:
An array of the same type containing the computed prefix sum.
- Return type:
ArrayT
- drjit.cumsum(arg, /)¶
Compute a cumulative sum (a.k.a. inclusive prefix sum) of the input array.
This function wraps drjit.prefix_sum() and is implemented as
    def cumsum(arg, /):
        return prefix_sum(arg, exclusive=False)
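The difference between the exclusive and inclusive variants is easy to see on a short list. The sketch below uses `itertools.accumulate` from the Python standard library; `prefix_sum_model` is a hypothetical reference function, not the Dr.Jit implementation.

```python
from itertools import accumulate


def prefix_sum_model(x, exclusive=True):
    """Reference prefix sum on a Python list."""
    inclusive = list(accumulate(x))          # y_i = x_0 + ... + x_i
    if not exclusive:
        return inclusive
    if not x:
        return []
    return [0] + inclusive[:-1]              # y_i = x_0 + ... + x_{i-1}


x = [1, 2, 3, 4]
print(prefix_sum_model(x))                   # [0, 1, 3, 6]
print(prefix_sum_model(x, exclusive=False))  # [1, 3, 6, 10]
```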
- drjit.reverse(value, axis: int = 0)¶
Reverses the given Dr.Jit array or Python sequence along the specified axis.
- Parameters:
value (ArrayBase|Sequence) – Dr.Jit array or Python sequence type
axis (int) – Axis along which the reversal should be performed. Only
axis==0
is supported for now.
- Returns:
An output of the same type as value containing a copy of the reversed array.
- Return type:
object
- drjit.compress(arg: drjit.ArrayBase, /) object ¶
Compress a mask into an array of nonzero indices.
This function takes a boolean array as input and then returns an unsigned 32-bit integer array containing the indices of nonzero entries.
It can be used to reduce a stream to a subset of active entries via the following recipe:
    # Input: an active mask and several arrays data_1, data_2, ...
    dr.schedule(active, data_1, data_2, ...)
    indices = dr.compress(active)
    data_1 = dr.gather(type(data_1), data_1, indices)
    data_2 = dr.gather(type(data_2), data_2, indices)
    # ...
There is some conceptual overlap between this function and drjit.scatter_inc(), which can likewise be used to reduce a stream to a smaller subset of active items. Please see the documentation of drjit.scatter_inc() for details.
Danger
This function internally performs a synchronization step.
- Parameters:
arg (bool | drjit.ArrayBase) – A Python or Dr.Jit boolean type
- Returns:
Array of nonzero indices
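The compress-then-gather recipe can be modeled on plain Python lists as follows. This is only a sketch of the semantics: `compress_model` and `gather_model` are hypothetical stand-ins, and nothing here involves Dr.Jit's scheduling or synchronization.

```python
def compress_model(mask):
    """Return the indices of True entries, mirroring the description above."""
    return [i for i, m in enumerate(mask) if m]

def gather_model(source, indices):
    """Read source[i] for every index i (a stand-in for a gather operation)."""
    return [source[i] for i in indices]

active = [True, False, False, True, True]
data_1 = [10, 11, 12, 13, 14]

indices = compress_model(active)         # [0, 3, 4]
data_1 = gather_model(data_1, indices)   # [10, 13, 14]
```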
- drjit.ravel(array: object, order: str = 'A') object ¶
Convert the input into a contiguous flat array.
This operation takes a Dr.Jit array, typically with some static and some dynamic dimensions (e.g., drjit.cuda.Array3f with shape 3xN), and converts it into a flattened 1D dynamically sized array (e.g., drjit.cuda.Float) using either a C or Fortran-style ordering convention.
It can also convert Dr.Jit tensors into a flat representation, though only C-style ordering is supported in this case.
Internally, drjit.ravel() performs a series of calls to drjit.scatter() to suitably reorganize the array contents.
For example,
    x = dr.cuda.Array3f([1, 2], [3, 4], [5, 6])
    y = dr.ravel(x, order=...)
will produce
- [1, 3, 5, 2, 4, 6] with order='F' (the default for Dr.Jit arrays), which means that X/Y/Z components alternate.
- [1, 2, 3, 4, 5, 6] with order='C', in which case all X coordinates are written as a contiguous block followed by the Y- and then Z-coordinates.
- Parameters:
array (drjit.ArrayBase) – An arbitrary Dr.Jit array or tensor
order (str) – A single character indicating the index order. 'F' indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency. The alternative 'C' specifies row-major/C-style ordering, in which case the last index changes at the highest frequency. The default value 'A' (automatic) will use F-style ordering for arrays and C-style ordering for tensors.
- Returns:
A dynamic 1D array containing the flattened representation of array with the desired ordering. The type of the return value depends on the type of the input. When array is already contiguous/flattened, this function returns it without making a copy.
- Return type:
object
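The two orderings can be reproduced with a small plain-Python sketch in which a list of component lists stands in for an array such as Array3f. `ravel_model` is a hypothetical helper that models only the element order, not the scatter-based implementation.

```python
def ravel_model(components, order="F"):
    """Flatten a list of equally sized component lists (e.g. X, Y, Z)."""
    if order == "F":
        # First (component) index changes fastest: x0, y0, z0, x1, y1, z1, ...
        return [c[i] for i in range(len(components[0])) for c in components]
    if order == "C":
        # Last (dynamic) index changes fastest: all of X, then Y, then Z
        return [v for c in components for v in c]
    raise ValueError("order must be 'C' or 'F'")

x = [[1, 2], [3, 4], [5, 6]]     # X, Y, Z components of two 3D vectors
flat_f = ravel_model(x, "F")     # [1, 3, 5, 2, 4, 6]
flat_c = ravel_model(x, "C")     # [1, 2, 3, 4, 5, 6]
```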
- drjit.unravel(dtype: type[ArrayT], array: AnyArray, order: Literal['A', 'C', 'F'] = 'A') ArrayT ¶
Load a sequence of Dr.Jit vectors/matrices/etc. from a contiguous flat array.
This operation implements the inverse of drjit.ravel(). In contrast to drjit.ravel(), it requires one additional parameter (dtype) specifying the type of the return value. For example,
    x = dr.cuda.Float([1, 2, 3, 4, 5, 6])
    y = dr.unravel(dr.cuda.Array3f, x, order=...)
will produce an array of two 3D vectors with different contents depending on the indexing convention:
- [1, 2, 3] and [4, 5, 6] when unraveled with order='F' (the default for Dr.Jit arrays), and
- [1, 3, 5] and [2, 4, 6] when unraveled with order='C'.
Internally, drjit.unravel() performs a series of calls to drjit.gather() to suitably reorganize the array contents.
- Parameters:
dtype (type) – An arbitrary Dr.Jit array type
array (drjit.ArrayBase) – A dynamically sized 1D Dr.Jit array instance that is compatible with dtype. In other words, both must have the same underlying scalar type and be imported from the same package (e.g., drjit.llvm.ad).
order (str) – A single character indicating the index order. 'F' indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency; the default value 'A' selects this ordering for Dr.Jit arrays. The alternative 'C' specifies row-major/C-style ordering, in which case the last index changes at the highest frequency.
- Returns:
An instance of type dtype containing the result of the unravel operation.
- Return type:
object
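The inverse mapping can be sketched in plain Python as well. The helper names `unravel_model` and `as_vectors` are hypothetical; lists of component lists stand in for a nested array, and the sketch reproduces only the element order from the example above.

```python
def unravel_model(flat, n_components, order="F"):
    """Split a flat list into n_components component lists (inverse of ravel)."""
    width = len(flat) // n_components
    if order == "F":
        # Consecutive entries belong to one vector: component k takes every n-th value
        return [flat[k::n_components] for k in range(n_components)]
    if order == "C":
        # Each component occupies one contiguous block
        return [flat[k * width:(k + 1) * width] for k in range(n_components)]
    raise ValueError("order must be 'C' or 'F'")

def as_vectors(components):
    """Regroup component lists into per-element vectors for display."""
    return [list(v) for v in zip(*components)]

flat = [1, 2, 3, 4, 5, 6]
vectors_f = as_vectors(unravel_model(flat, 3, "F"))   # [[1, 2, 3], [4, 5, 6]]
vectors_c = as_vectors(unravel_model(flat, 3, "C"))   # [[1, 3, 5], [2, 4, 6]]
```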
- drjit.reshape(dtype: type, value: object, shape: collections.abc.Sequence[int], order: str = 'A', shrink: bool = False) object ¶
- drjit.reshape(dtype: type, value: object, shape: int, order: str = 'A', shrink: bool = False) object
Converts value into an array of type dtype by rearranging the contents according to the specified shape.
The parameter shape may contain a single -1-valued target dimension, in which case its value is inferred from the remaining shape entries and the size of the input. When shape is of type int, it is interpreted as a 1-tuple (shape,).
This function supports the following behaviors:
Reshaping tensors: Dr.Jit tensors admit arbitrary shapes. drjit.reshape() can convert between them as long as the total number of elements remains unchanged.
    >>> from drjit.llvm.ad import TensorXf
    >>> value = dr.arange(TensorXf, 6)
    >>> dr.reshape(dtype=TensorXf, value=value, shape=(3, -1))
    [[0, 1]
     [2, 3]
     [4, 5]]
Reshaping nested arrays: The function can ravel and unravel nested arrays (which have some static dimensions). This provides a high-level interface that subsumes the functions drjit.ravel() and drjit.unravel().
    >>> from drjit.llvm.ad import Array2f, Array3f
    >>> value = Array2f([1, 2, 3], [4, 5, 6])
    >>> dr.reshape(dtype=Array3f, value=value, shape=(3, -1), order='C')
    [[1, 4, 2],
     [5, 3, 6]]
    >>> dr.reshape(dtype=Array3f, value=value, shape=(3, -1), order='F')
    [[1, 3, 5],
     [2, 4, 6]]
(By convention, Dr.Jit nested arrays are always printed in transposed form, which explains the difference in output compared to the identically shaped Tensor example just above.)
The order argument can be used to specify C ("C") or Fortran ("F")-style ordering when rearranging the array. For nested arrays, the default value "A" corresponds to Fortran-style ordering.
PyTrees: When value is a PyTree, the operation recursively threads through the tree’s elements.
Stream compression and loops that fork recursive work: When called with shrink=True, the function creates a view of the original data that potentially has a smaller number of elements.
The main use of this feature is to implement loops that process large numbers of elements in parallel, and which need to occasionally “fork” some recursive work. On modern compute accelerators, an efficient way to handle this requirement is to append this work into a queue that is processed in a subsequent pass until no work is left. The reshape operation with shrink=True then resizes the preallocated queue to the actual number of collected items, which are the input of the next iteration.
Please refer to the following example that illustrates how drjit.scatter_inc(), drjit.scatter(), and drjit.reshape(..., shrink=True) can be combined to realize a parallel loop with a fork condition:
    import drjit as dr
    from drjit.llvm import UInt32  # backend choice is an assumption; any JIT backend works

    @dr.syntax
    def f():
        # Loop state variables (an arbitrary array or PyTree)
        state = ...

        # Determine how many elements should be processed
        size = dr.width(state)

        # Run the following loop until no work is left
        while size > 0:
            # 1-element array used as an atomic counter
            queue_index = UInt32(0)

            # Preallocate memory for the queue. The necessary
            # amount of memory is task-dependent
            queue_size = size
            queue = dr.empty(dtype=type(state), shape=queue_size)

            # Create an opaque variable representing the number 'queue_size'.
            # This keeps this changing value from being baked into the program,
            # which is needed for proper kernel caching
            queue_size_o = dr.opaque(UInt32, queue_size)

            while not stopping_criterion(state):
                # This line represents the loop body that processes work
                state = loop_body(state)

                # If the condition 'fork' is True, spawn a new work item that
                # will be handled in a future iteration of the parent loop.
                if fork(state):
                    # Atomically reserve a slot in 'queue'
                    slot = dr.scatter_inc(target=queue_index, index=0)

                    # Work item for the next iteration, task dependent
                    todo = state

                    # Be careful not to write beyond the end of the queue
                    valid = slot < queue_size_o

                    # Write 'todo' into the reserved slot
                    dr.scatter(target=queue, index=slot, value=todo, active=valid)

            # Determine how many fork operations took place
            size = queue_index[0]
            if size > queue_size:
                raise RuntimeError('Preallocated queue was too small: tried to store '
                                   f'{size} elements in a queue of size {queue_size}')

            # Reshape the queue and re-run the loop
            state = dr.reshape(dtype=type(state), value=queue, shape=size, shrink=True)
- Parameters:
dtype (type) – Desired output type of the reshaped array. This could equal type(value) or refer to an entirely different array type.
value (object) – An arbitrary Dr.Jit array, tensor, or PyTree. The function returns unknown objects of other types unchanged.
shape (int|tuple[int, ...]) – The target shape.
order (str) – A single character indicating the index order used to reinterpret the input. 'F' indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency. The alternative 'C' specifies row-major/C-style ordering, in which case the last index changes at the highest frequency. The default value 'A' (automatic) will use F-style ordering for arrays and C-style ordering for tensors.
shrink (bool) – Cheaply construct a view of the input that potentially has a smaller number of elements. The main use case of this method is explained above.
- Returns:
The reshaped array or PyTree.
- Return type:
object
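The rule for a -1-valued dimension and for integer shapes can be written down as a short plain-Python sketch. `resolve_shape` is a hypothetical helper that mirrors the description above for the element-preserving case only; it does not model shrink=True, and the error messages are illustrative.

```python
from math import prod

def resolve_shape(shape, size):
    """Resolve an int or a shape containing at most one -1 against 'size' elements."""
    if isinstance(shape, int):
        shape = (shape,)                      # an int is treated as a 1-tuple
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    if -1 in shape:
        known = prod(d for d in shape if d != -1)
        if known == 0 or size % known != 0:
            raise ValueError("cannot infer the missing dimension")
        shape = tuple(size // known if d == -1 else d for d in shape)
    if prod(shape) != size:
        raise ValueError("total number of elements must remain unchanged")
    return shape

resolved = resolve_shape((3, -1), size=6)   # (3, 2)
```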
- drjit.slice(value: object, index: object = 0) object ¶
Select a subset of the input array or PyTree along the trailing dynamic dimension.
Given a Dr.Jit array value with shape (..., N) (where N represents a dynamically sized dimension), this operation effectively evaluates the expression value[..., index]. It recursively traverses PyTrees and transforms each compatible array element. Other values are returned unchanged.
The following properties of index determine the return type:
- When index is a 1D integer array, the operation reduces to one or more calls to drjit.gather(), and slice() returns a reduced output object of the same type and structure.
- When index is a scalar Python int, the trailing dimension is entirely removed, and the operation returns an array from the drjit.scalar namespace containing the extracted values.
- drjit.tile(value: T, count: int) T ¶
Tile the input array count times along the trailing dimension.
This function replicates the input count times along the trailing dynamic dimension. It recursively threads through nested arrays and PyTrees. Static arrays and tensors currently aren’t supported. When count==1, the function returns the input without changes.
An example is shown below:
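The replication order of tiling can be illustrated with a minimal plain-Python sketch, in which a list stands in for the trailing dynamic dimension. `tile_model` is a hypothetical helper, not part of Dr.Jit.

```python
def tile_model(values, count):
    """Concatenate 'count' copies of the whole sequence."""
    return list(values) * count

tiled = tile_model([1, 2], 3)   # [1, 2, 1, 2, 1, 2]
```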
- Parameters:
value (drjit.ArrayBase) – A Dr.Jit type or PyTree.
count (int) – Number of repetitions
- Returns:
The tiled input as described above. The return type matches that of value.
- Return type:
object
- drjit.repeat(value: T, count: int) T ¶
Repeat each successive entry of the input count times along the trailing dimension.
This function replicates each entry of the input count times along the trailing dynamic dimension. It recursively threads through nested arrays and PyTrees. Static arrays and tensors currently aren’t supported. When count==1, the function returns the input without changes.
An example is shown below:
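In contrast to tiling, repetition keeps copies of the same entry adjacent. The following plain-Python sketch shows the resulting order on an ordinary list; `repeat_model` is a hypothetical helper, not part of Dr.Jit.

```python
def repeat_model(values, count):
    """Emit every entry 'count' times before moving on to the next one."""
    return [v for v in values for _ in range(count)]

repeated = repeat_model([1, 2], 3)   # [1, 1, 1, 2, 2, 2]
```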
- Parameters:
value (drjit.ArrayBase) – A Dr.Jit type or PyTree.
count (int) – Number of repetitions
- Returns:
The repeated input as described above. The return type matches that of value.
- Return type:
object
Mask operations¶
Also relevant here are any(), all(), none(), and count().
- drjit.select(arg0: object, arg1: object, arg2: object, /) object ¶
- drjit.select(arg0: bool, arg1: object, arg2: object, /) object
Select elements from inputs based on a condition
This function uses a first mask argument to select between the subsequent two arguments. It implements the following component-wise operation:
\[\mathrm{result}_i = \begin{cases} \texttt{arg1}_i,\quad&\text{if }\texttt{arg0}_i,\\ \texttt{arg2}_i,\quad&\text{otherwise.} \end{cases}\]
- Parameters:
arg0 (bool | drjit.ArrayBase) – A Python or Dr.Jit mask type
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit type, whose entries should be returned for True-valued mask entries.
arg2 (int | float | drjit.ArrayBase) – A Python or Dr.Jit type, whose entries should be returned for False-valued mask entries.
- Returns:
Component-wise result of the selection operation
- Return type:
float | int | drjit.ArrayBase
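The component-wise rule can be modeled on plain Python lists. `select_model` is a hypothetical helper, and treating non-list arguments as broadcast scalars is an assumption of this sketch rather than a statement about Dr.Jit's broadcasting rules.

```python
def select_model(mask, if_true, if_false):
    """Component-wise selection; non-list arguments are broadcast."""
    def at(arg, i):
        return arg[i] if isinstance(arg, list) else arg

    sizes = [len(a) for a in (mask, if_true, if_false) if isinstance(a, list)]
    n = max(sizes, default=1)
    return [at(if_true, i) if at(mask, i) else at(if_false, i) for i in range(n)]

x = [-2.0, 0.5, 3.0]
clamped = select_model([v < 0 for v in x], 0.0, x)   # [0.0, 0.5, 3.0]
```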
- drjit.isinf(arg, /)¶
Performs an elementwise test for positive or negative infinity
- Parameters:
arg (object) – A Dr.Jit array or other kind of numeric sequence type.
- Returns:
A mask value describing the result of the test
- Return type:
- drjit.isnan(arg, /)¶
Performs an elementwise test for NaN (Not a Number) values
- Parameters:
arg (object) – A Dr.Jit array or other kind of numeric sequence type.
- Returns:
A mask value describing the result of the test.
- Return type:
- drjit.isfinite(arg, /)¶
Performs an elementwise test that checks whether values are finite and not equal to NaN (Not a Number)
- Parameters:
arg (object) – A Dr.Jit array or other kind of numeric sequence type.
- Returns:
A mask value describing the result of the test
- Return type:
- drjit.allclose(a: object, b: object, rtol: float | None = None, atol: float | None = None, equal_nan: bool = False) bool ¶
Returns True if two arrays are element-wise equal within a given error tolerance.
The function considers both absolute and relative error thresholds. In particular, a and b are considered equal if all elements satisfy
\[|a - b| \le |b| \cdot \texttt{rtol} + \texttt{atol}. \]
If not specified, the constants atol and rtol are chosen depending on the precision of the input arrays:
Precision   rtol   atol
float64     1e-5   1e-8
float32     1e-3   1e-5
float16     1e-1   1e-2
Note that these constants are fairly loose and far larger than the roundoff error of the underlying floating point representation. The double precision parameters were chosen to match the behavior of numpy.allclose().
- Parameters:
a (object) – A Dr.Jit array or other kind of numeric sequence type.
b (object) – A Dr.Jit array or other kind of numeric sequence type.
rtol (float) – A relative error threshold chosen according to the above table.
atol (float) – An absolute error threshold according to the above table.
equal_nan (bool) – If a and b both contain a NaN (Not a Number) entry, should they be considered equal? The default is
False
.
- Returns:
The result of the comparison.
- Return type:
bool
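The tolerance test can be sketched in plain Python for lists of floats. `allclose_model` is a hypothetical helper that uses the double precision defaults from the table; its NaN handling follows the description of equal_nan and is an assumption about details not spelled out above.

```python
import math

def allclose_model(a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Check |a - b| <= |b| * rtol + atol for every pair of entries."""
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            if equal_nan and math.isnan(x) and math.isnan(y):
                continue
            return False
        if not abs(x - y) <= abs(y) * rtol + atol:
            return False
    return True

close = allclose_model([1.0, 2.0], [1.0 + 1e-7, 2.0])               # True
far = allclose_model([1.0, 2.0], [1.001, 2.0])                      # False
nan_equal = allclose_model([math.nan], [math.nan], equal_nan=True)  # True
```

Note that the relative term scales with |b| only, so the test is not symmetric in its two arguments.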
Miscellaneous operations¶
- drjit.shape(arg: object) tuple[int, ...] ¶
Return a tuple describing dimension and shape of the provided Dr.Jit array, tensor, or standard sequence type.
When the input array is ragged, the function raises a RuntimeError. The term ragged refers to an array whose components have mismatched sizes, such as [[1, 2], [3, 4, 5]]. Note that scalar entries (e.g. [[1, 2], [3]]) are acceptable, since broadcasting can effectively convert them to any size.
The expressions drjit.shape(arg) and arg.shape are equivalent.
- Parameters:
arg (drjit.ArrayBase) – an arbitrary Dr.Jit array or tensor
- Returns:
A tuple describing the dimension and shape of the provided Dr.Jit input array or tensor.
- Return type:
tuple[int, …]
- drjit.width(arg: object, /) int ¶
- drjit.width(*args) int
Returns the vectorization width of the provided input(s), which is defined as the length of the last dynamic dimension.
When working with Jit-compiled CUDA or LLVM-based arrays, this corresponds to the number of items being processed in parallel.
The function raises an exception when the input(s) is ragged, i.e., when it contains arrays with incompatible sizes. It returns 1 if the input is scalar and/or does not contain any Dr.Jit arrays.
- Parameters:
arg (object) – An arbitrary Dr.Jit array or PyTree.
- Returns:
The width of the provided input(s).
- Return type:
int
- drjit.slice_index(dtype: type[drjit.ArrayBase], shape: tuple, indices: tuple) tuple[tuple, object] ¶
Computes an index array that can be used to slice a tensor. It is used internally by Dr.Jit to implement complex cases of the __getitem__ operation.
It must be called with the desired output dtype, which must be a dynamically sized 1D array of 32-bit integers. The shape parameter specifies the dimensions of a hypothetical input tensor, and indices contains the entries that would appear in a complex slicing operation, but as a tuple. For example, [5:10:2, ..., None] would be specified as (slice(5, 10, 2), Ellipsis, None).
An example is shown below:
    >>> dr.slice_index(dtype=dr.scalar.ArrayXu, shape=(10, 1), indices=(slice(0, 10, 2), 0))
    [0, 2, 4, 6, 8]
- Parameters:
dtype (type) – A dynamic 32-bit unsigned integer Dr.Jit array type, such as drjit.scalar.ArrayXu or drjit.cuda.UInt.
shape (tuple[int, ...]) – The shape of the tensor to be sliced.
indices (tuple[int|slice|ellipsis|None|dr.ArrayBase, ...]) – A set of indices used to slice the tensor. Its entries can be slice instances, integers, integer arrays, ... (ellipsis) or None.
- Returns:
Tuple consisting of the output array shape and a flattened unsigned integer array of type dtype containing element indices.
- Return type:
tuple[tuple[int, …], drjit.ArrayBase]
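How such a flattened index array comes about can be sketched in plain Python. `slice_index_model` is a hypothetical helper that handles only integers and slice objects (no ellipsis, None, or index arrays) and assumes row-major linearization of the tensor; it is a model of the idea, not Dr.Jit's implementation.

```python
from itertools import product

def slice_index_model(shape, indices):
    """Return (output shape, flat row-major element indices); ints and slices only."""
    if len(indices) != len(shape):
        raise ValueError("this sketch needs exactly one index per dimension")
    axes, out_shape = [], []
    for size, idx in zip(shape, indices):
        if isinstance(idx, slice):
            positions = list(range(*idx.indices(size)))
            out_shape.append(len(positions))               # a slice keeps its dimension
        else:
            positions = [idx + size if idx < 0 else idx]   # an int removes it
        axes.append(positions)

    flat = []
    for combo in product(*axes):
        offset = 0
        for size, pos in zip(shape, combo):
            offset = offset * size + pos                   # row-major linearization
        flat.append(offset)
    return tuple(out_shape), flat

out_shape, flat = slice_index_model((10, 1), (slice(0, 10, 2), 0))
# out_shape == (5,), flat == [0, 2, 4, 6, 8]
```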
- drjit.meshgrid(*args, indexing='xy') tuple ¶
Return flattened N-D coordinate arrays from a sequence of 1D coordinate vectors.
This function constructs flattened coordinate arrays that are convenient for evaluating and plotting functions on a regular grid. An example is shown below:
    import drjit as dr

    x, y = dr.meshgrid(
        dr.arange(dr.llvm.UInt, 4),
        dr.arange(dr.llvm.UInt, 4)
    )
    # x = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    # y = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
This function carefully reproduces the behavior of numpy.meshgrid except for one major difference: the output coordinates are returned in flattened/raveled form. Like the NumPy version, the indexing=='xy' case internally reorders the first two elements of *args.
- Parameters:
*args – A sequence of 1D coordinate arrays
indexing (str) – Specifies the indexing convention. Must be set to either 'xy' (the default) or 'ij'.
- Returns:
A tuple of flattened coordinate arrays (one per input)
- Return type:
tuple
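For two inputs and 'xy' indexing, the flattened output pattern shown above can be reproduced with a plain-Python sketch. `meshgrid_xy_model` is a hypothetical helper restricted to the 2D case.

```python
def meshgrid_xy_model(x, y):
    """Flattened 2D meshgrid with 'xy' indexing: x varies fastest."""
    xs = list(x) * len(y)                          # x tiled once per y value
    ys = [v for v in y for _ in range(len(x))]     # each y repeated len(x) times
    return xs, ys

xs, ys = meshgrid_xy_model(range(4), range(4))
# xs = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
# ys = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
```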
- drjit.make_opaque(arg: object, /) None ¶
- drjit.make_opaque(*args) None
Forcefully evaluate arrays (including literal constants).
This function implements a more drastic version of drjit.eval() that additionally converts literal constant arrays into evaluated (device memory-based) representations.
It is related to the function drjit.opaque() that can be used to directly construct such opaque arrays. Please see the documentation of this function regarding the rationale of making array contents opaque to Dr.Jit’s symbolic tracing mechanism.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)
- drjit.copy(arg: T, /) T ¶
Create a deep copy of a PyTree
This function recursively traverses PyTrees and replaces Dr.Jit arrays with copies created via the ordinary copy constructor. It also rebuilds tuples, lists, dictionaries, and other custom data structures.
Just-in-time compilation¶
- enum drjit.JitBackend(value)¶
List of just-in-time compilation backends supported by Dr.Jit. See also drjit.backend_v().
Valid values are as follows:
- Invalid = JitBackend.Invalid¶
Indicates that a type is not handled by a Dr.Jit backend (e.g., a scalar type)
- CUDA = JitBackend.CUDA¶
Dr.Jit backend targeting NVIDIA GPUs using PTX (“Parallel Thread Execution”) IR.
- LLVM = JitBackend.LLVM¶
Dr.Jit backend targeting various processors via the LLVM compiler infrastructure.
- enum drjit.VarType(value)¶
List of possible scalar array types (not all of them are supported).
Valid values are as follows:
- Void = VarType.Void¶
Unknown/unspecified type.
- Bool = VarType.Bool¶
Boolean/mask type.
- Int8 = VarType.Int8¶
Signed 8-bit integer.
- UInt8 = VarType.UInt8¶
Unsigned 8-bit integer.
- Int16 = VarType.Int16¶
Signed 16-bit integer.
- UInt16 = VarType.UInt16¶
Unsigned 16-bit integer.
- Int32 = VarType.Int32¶
Signed 32-bit integer.
- UInt32 = VarType.UInt32¶
Unsigned 32-bit integer.
- Int64 = VarType.Int64¶
Signed 64-bit integer.
- UInt64 = VarType.UInt64¶
Unsigned 64-bit integer.
- Pointer = VarType.Pointer¶
Pointer to a memory address.
- Float16 = VarType.Float16¶
16-bit floating point format (IEEE 754).
- Float32 = VarType.Float32¶
32-bit floating point format (IEEE 754).
- Float64 = VarType.Float64¶
64-bit floating point format (IEEE 754).
- enum drjit.VarState(value)¶
The drjit.ArrayBase.state property returns one of the following enumeration values describing possible evaluation states of a Dr.Jit variable.
Valid values are as follows:
- Invalid = VarState.Invalid¶
The variable has length 0 and effectively does not exist.
- Literal = VarState.Literal¶
A literal constant. Does not consume device memory.
- Undefined = VarState.Undefined¶
An undefined memory region. Does not (yet) consume device memory.
- Unevaluated = VarState.Unevaluated¶
An ordinary unevaluated variable that is neither a literal constant nor symbolic.
- Evaluated = VarState.Evaluated¶
Evaluated variable backed by a device memory region.
- Dirty = VarState.Dirty¶
An evaluated variable backed by a device memory region. The variable furthermore has pending side effects (i.e. the user has performed a drjit.scatter(), drjit.scatter_reduce(), drjit.scatter_inc(), drjit.scatter_add(), or drjit.scatter_add_kahan() operation, and the effect of this operation has not been realized yet). The array’s status will automatically change to Evaluated the next time that Dr.Jit evaluates computation, e.g. via drjit.eval().
- Symbolic = VarState.Symbolic¶
A symbolic variable that could take on various inputs. Cannot be evaluated.
- Mixed = VarState.Mixed¶
This is a nested array, and the components have mixed states.
- enum drjit.JitFlag(value)¶
Flags that control how Dr.Jit compiles and optimizes programs.
This enumeration lists various flags that control how Dr.Jit compiles and optimizes programs, most of which are enabled by default. The status of each flag can be queried via drjit.flag() and enabled/disabled via the drjit.set_flag() or the recommended drjit.scoped_set_flag() functions, e.g.:
    with dr.scoped_set_flag(dr.JitFlag.SymbolicLoops, False):
        # code that has this flag disabled goes here
The most common reason to update the flags is to switch between symbolic and evaluated execution of loops and functions. The former records computation symbolically to assemble large megakernels, while the latter eagerly executes programs by breaking them into many smaller kernels. See explanations below along with the documentation of drjit.switch() and drjit.while_loop() for more details on these two modes.
Dr.Jit flags are a thread-local property. This means that multiple independent threads using Dr.Jit can set them independently without interfering with each other.
- Member Type:
int
Valid values are as follows:
- Debug = JitFlag.Debug¶
Debug mode: Enable functionality to uncover errors in application code.
When debug mode is enabled, Dr.Jit inserts a number of additional runtime checks to locate sources of undefined behavior.
Debug mode comes at a significant cost: it interferes with kernel caching, reduces tracing performance, and produces kernels that run slower. We recommend that you only use it periodically before a release, or when encountering a serious problem like a crashing kernel.
First, debug mode enables assertion checks in user code such as those performed by drjit.assert_true(), drjit.assert_false(), and drjit.assert_equal().
Second, Dr.Jit inserts additional checks to intercept out-of-bound reads and writes performed by operations such as drjit.scatter(), drjit.gather(), drjit.scatter_reduce(), drjit.scatter_inc(), etc. It also detects calls to invalid callables performed via drjit.switch() and drjit.dispatch(). Such invalid operations are masked, and they generate a warning message on the console, e.g.:
    >>> dr.gather(dtype=UInt, source=UInt(1, 2, 3), index=UInt(0, 1, 100))
    RuntimeWarning: drjit.gather(): out-of-bounds read from position 100 in an array of size 3. (<stdin>:2)
Finally, Dr.Jit also installs a Python tracing hook that associates all Jit variables with their Python source code location, and this information is propagated all the way to the final intermediate representation (PTX, LLVM IR). This is useful for low-level debugging and development of Dr.Jit itself. You can query the source location information of a variable x by writing x.label.
Due to limitations of the Python tracing interface, this handler becomes active within the next called function (or Jupyter notebook cell) following activation of the drjit.JitFlag.Debug flag. It does not apply to code within the same scope/function.
C++ code using Dr.Jit also benefits from debug mode but will lack accurate source code location information. In mixed-language projects, the reported file and line number information will reflect that of the last operation on the Python side of the interface.
- ReuseIndices = JitFlag.ReuseIndices¶
Index reuse: Dr.Jit consists of two main parts: the just-in-time compiler, and the automatic differentiation layer. Both maintain an internal data structure representing captured computation, in which each variable is associated with an index (e.g., r1234 in the JIT compiler, and a1234 in the AD graph).
The index of a Dr.Jit array in these graphs can be queried via the drjit.index and drjit.index_ad variables, and they are also visible in debug messages (if drjit.set_log_level() is set to a more verbose debug level).
Dr.Jit aggressively reuses the indices of expired variables by default, but this can make debug output difficult to interpret. When debugging Dr.Jit itself, it is often helpful to investigate the history of a particular variable. In such cases, set this flag to False to disable variable reuse both at the JIT and AD levels. This comes at a cost: the internal data structures keep on growing, so it is not suitable for long-running computations.
Index reuse is enabled by default.
- ConstantPropagation = JitFlag.ConstantPropagation¶
Constant propagation: immediately evaluate arithmetic involving literal constants on the host and don’t generate any device-specific code for them.
For example, the following assertion holds when constant propagation is enabled in Dr.Jit.
    from drjit.llvm import Int

    # Create two literal constant arrays
    a, b = Int(4), Int(5)

    # This addition operation can be immediately performed and does not need to be recorded
    c1 = a + b

    # Double-check that c1 and c2 refer to the same Dr.Jit variable
    c2 = Int(9)
    assert c1.index == c2.index
Constant propagation is enabled by default.
- ValueNumbering = JitFlag.ValueNumbering¶
Local value numbering: a simple variant of common subexpression elimination that collapses identical expressions within basic blocks. For example, the following assertion holds when value numbering is enabled in Dr.Jit.
    from drjit.llvm import Int

    # Create two non-literal arrays stored in device memory
    a, b = Int(1, 2, 3), Int(4, 5, 6)

    # Perform the same arithmetic operation twice
    c1 = a + b
    c2 = a + b

    # Verify that c1 and c2 reference the same Dr.Jit variable
    assert c1.index == c2.index
Local value numbering is enabled by default.
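The idea behind local value numbering can be sketched with a toy expression recorder in plain Python. `TraceModel` is a hypothetical class written for this illustration; it keys each expression by its operation and operand indices and is not a description of Dr.Jit's internal data structures.

```python
class TraceModel:
    """Toy expression recorder that assigns one index per distinct expression."""

    def __init__(self):
        self._cache = {}      # (op, operand indices) -> variable index
        self._next = 1

    def record(self, op, *operands):
        key = (op, operands)
        if key not in self._cache:            # unseen expression: new variable
            self._cache[key] = self._next
            self._next += 1
        return self._cache[key]               # identical expression: reuse index

trace = TraceModel()
a = trace.record("input", "a")
b = trace.record("input", "b")
c1 = trace.record("add", a, b)
c2 = trace.record("add", a, b)
same_variable = c1 == c2      # True
```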
- FastMath = JitFlag.FastMath¶
Fast Math: this flag is analogous to the -ffast-math flag in C compilers. When set, the system may use approximations and simplifications that sacrifice strict IEEE-754 compatibility.
Currently, it changes two behaviors:
- Expressions of the form a * 0 will be simplified to 0 (which is technically not correct when a is infinite or NaN-valued).
- Dr.Jit will use slightly approximate division and square root operations in CUDA mode. Note that disabling fast math mode is costly on CUDA devices, as the strict IEEE-754 compliant version of these operations uses software-based emulation.
Fast math mode is enabled by default.
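Why the a * 0 simplification departs from IEEE 754 can be checked with ordinary Python floats, which follow the same arithmetic rules. The variable names are illustrative only.

```python
import math

inf_times_zero = math.inf * 0.0     # IEEE 754 result: NaN
nan_times_zero = math.nan * 0.0     # IEEE 754 result: NaN
finite_times_zero = 3.5 * 0.0       # 0.0

# Rewriting a * 0 to 0 is therefore only exact for finite a
rewrite_is_exact = [not math.isnan(v) for v in
                    (inf_times_zero, nan_times_zero, finite_times_zero)]
```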
- SymbolicLoops = JitFlag.SymbolicLoops¶
Dr.Jit provides two main ways of compiling loops involving Dr.Jit arrays.
Symbolic mode (the default): Dr.Jit executes the loop a single time regardless of how many iterations it requires in practice. It does so with symbolic (abstract) arguments to capture the loop condition and body and then turns it into an equivalent loop in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.
Evaluated mode: Dr.Jit evaluates the loop’s state variables and reduces the loop condition to a single element (
bool
) that expresses whether any elements are still alive. If so, it runs the loop body and the process repeats.
A separate section about symbolic and evaluated modes discusses these two options in detail.
Symbolic loops are enabled by default.
- OptimizeLoops = JitFlag.OptimizeLoops¶
Perform basic optimizations for loops involving Dr.Jit arrays.
This flag enables two optimizations:
Constant arrays: loop state variables that aren’t modified by the loop are automatically removed. This shortens the generated code, which can be helpful especially in combination with the automatic transformations performed by @drjit.syntax that can be somewhat conservative in classifying too many local variables as potential loop state.
Literal constant arrays: In addition to the above point, constant loop state variables that are literal constants are propagated into the loop body, where this may unlock further optimization opportunities.
This is useful in combination with automatic differentiation, where it helps to detect code that does not influence the computed derivatives.
A practical implication of this optimization flag is that it may cause drjit.while_loop() to run the loop body twice instead of just once.
This flag is enabled by default. Note that it is only meaningful in combination with SymbolicLoops.
- CompressLoops = JitFlag.CompressLoops¶
Compress the loop state of evaluated loops after every iteration.
When an evaluated loop processes many elements, and when each element requires a different number of loop iterations, there is the question of what should be done with inactive elements. The default implementation keeps them around and does redundant calculations that are, however, masked out. Consequently, later loop iterations don’t run faster despite fewer elements being active.
Setting this flag causes the removal of inactive elements after every iteration. This reorganization is not for free and does not benefit all use cases.
This flag is disabled by default. Note that it only applies to evaluated loops (i.e., when SymbolicLoops is disabled, or when the mode='evaluated' parameter is passed to the loop in question).
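The trade-off can be illustrated with a plain-Python model of an evaluated loop in which every element counts down from a different start value. `run_evaluated_loop` is a hypothetical helper: it only counts how many lanes each iteration touches and ignores the cost of the reorganization step itself.

```python
def run_evaluated_loop(counters, compress):
    """Decrement each counter to zero; return (results, lanes processed per iteration)."""
    state = list(counters)
    active = [i for i, c in enumerate(state) if c > 0]   # indices still alive
    lanes_per_iteration = []
    while active:                                         # "any element alive?"
        lanes = active if compress else range(len(state))
        lanes_per_iteration.append(len(lanes))
        for i in lanes:
            if state[i] > 0:                              # masked update
                state[i] -= 1
        active = [i for i in active if state[i] > 0]
    return state, lanes_per_iteration

_, plain = run_evaluated_loop([1, 4, 2, 1], compress=False)       # [4, 4, 4, 4]
_, compressed = run_evaluated_loop([1, 4, 2, 1], compress=True)   # [4, 2, 1, 1]
```

Without compression, every iteration processes all four lanes; with compression, the per-iteration work shrinks as elements finish.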
- SymbolicCalls = JitFlag.SymbolicCalls¶
Dr.Jit provides two main ways of compiling function calls targeting instance arrays.
Symbolic mode (the default): Dr.Jit invokes each callable with symbolic (abstract) arguments. It does this to capture a transcript of the computation that it can turn into a function in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.
Evaluated mode: Dr.Jit evaluates all inputs and groups them by instance ID. Following this, it launches a kernel per instance to process the rearranged inputs and assemble the function return value.
A separate section about symbolic and evaluated modes discusses these two options in detail.
Besides calls to instance arrays, this flag also controls the behavior of the functions
drjit.switch()
and drjit.dispatch()
.
Symbolic calls are enabled by default.
- OptimizeCalls = JitFlag.OptimizeCalls¶
Perform basic optimizations for function calls on instance arrays.
This flag enables two optimizations:
Constant propagation: Dr.Jit will propagate literal constants across function boundaries while tracing, which can unlock simplifications within. This is especially useful in combination with automatic differentiation, where it helps to detect code that does not influence the computed derivatives.
Devirtualization: When an element of the return value has the same computation graph in all instances, it is removed from the function call interface and moved to the caller.
The flag is enabled by default. Note that it is only meaningful in combination with
SymbolicCalls
. Besides calls to instance arrays, this flag also controls the behavior of the functions drjit.switch()
and drjit.dispatch()
.
- MergeFunctions = JitFlag.MergeFunctions¶
Deduplicate code generated by function calls on instance arrays.
When
arr
is an instance array (potentially with thousands of instances), a function call like arr.f(inputs...)
can potentially generate vast numbers of different functions in the generated code. At the same time, many of these functions may contain identical code (or code that is identical except for data references).
Dr.Jit can exploit such redundancy and merge such functions during code generation. Besides generating shorter programs, this also helps to reduce thread divergence.
This flag is enabled by default. Note that it is only meaningful in combination with
SymbolicCalls
. Besides calls to instance arrays, this flag also controls the behavior of the functions drjit.switch()
and drjit.dispatch()
.
- ForceOptiX = JitFlag.ForceOptiX¶
Force execution through OptiX even if a kernel doesn’t use ray tracing. This only applies to the CUDA backend and is mainly helpful for automated tests done by the Dr.Jit team.
This flag is disabled by default.
- PrintIR = JitFlag.PrintIR¶
Print the low-level IR representation when launching a kernel.
If enabled, this flag causes Dr.Jit to print the low-level IR (LLVM IR, NVIDIA PTX) representation of the generated code onto the console (or Jupyter notebook).
This flag is disabled by default.
- KernelHistory = JitFlag.KernelHistory¶
Maintain a history of kernel launches to profile/debug programs.
Programs written on top of Dr.Jit execute in an extremely asynchronous manner. By default, the system postpones the computation to build large fused kernels. Even when this computation eventually runs, it does so asynchronously with respect to the host, which can make benchmarking difficult.
In general, beware of the following benchmarking anti-pattern:
import time
a = time.time()
# Some Dr.Jit computation
b = time.time()
print("took %.2f ms" % ((b-a) * 1000))
In the worst case, the measured time interval may only capture the tracing time, without any actual computation having taken place. Another common mistake with this pattern is that Dr.Jit or the target device may still be busy with computation that started prior to the
a = time.time()
line, which is now incorrectly added to the measured period.
Dr.Jit provides a kernel history feature, where it creates an entry in a list whenever it launches a kernel or related operation (memory copies, etc.). This not only gives accurate and isolated timings (measured with counters on the CPU/GPU) but also reveals if a kernel was launched at all. To capture the kernel history, set this flag just before the region to be benchmarked and call
drjit.kernel_history()
at the end.
Capturing the history has a (very) small cost and is therefore disabled by default.
- LaunchBlocking = JitFlag.LaunchBlocking¶
Force synchronization after every kernel launch. This is useful to isolate severe problems (e.g. crashes) to a specific kernel.
This flag has a severe performance impact and is disabled by default.
- ScatterReduceLocal = JitFlag.ScatterReduceLocal¶
Reduce locally before performing atomic scatter-reductions.
Atomic memory operations are expensive when many writes target the same region of memory. This leads to a phenomenon called contention that is normally associated with significant slowdowns (10-100x aren’t unusual).
This issue is particularly common when automatically differentiating computation in reverse mode (e.g.
drjit.backward()
), since this transformation turns differentiable global memory reads into atomic scatter-additions. A differentiable scalar read is all it takes to create such an atomic memory bottleneck.
To reduce this cost, Dr.Jit can perform a local reduction that uses cooperation between SIMD/warp lanes to resolve all requests targeting the same address and then issue only a single atomic memory transaction per unique target. This can reduce atomic memory traffic 32-fold on the GPU (CUDA) and 16-fold on the CPU (AVX512). On the CUDA backend, local reduction is currently only supported for 32-bit operands (signed/unsigned integers and single precision variables).
The section on optimizations presents plots that demonstrate the impact of this optimization.
The JIT flag
drjit.JitFlag.ScatterReduceLocal
affects the behavior of scatter_add()
, scatter_reduce()
along with the reverse-mode derivative of gather()
. Setting the flag to True
will usually cause a mode=
argument value of drjit.ReduceOp.Auto
to be interpreted as drjit.ReduceOp.Local
. Another LLVM-specific optimization takes precedence in certain situations; refer to the discussion of this flag for details.
This flag is enabled by default.
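The idea of reducing locally before issuing atomics can be sketched in plain Python. The code below is a conceptual model only, not Dr.Jit's implementation: the function names and the `width` parameter (standing in for the number of SIMD/warp lanes) are assumptions made for illustration, and "transactions" counts simulated atomic updates.

```python
from collections import defaultdict


def scatter_add_atomic(target, index, value):
    """One simulated atomic update per lane (no local reduction)."""
    transactions = 0
    for i, v in zip(index, value):
        target[i] += v
        transactions += 1
    return transactions


def scatter_add_local(target, index, value, width=32):
    """Combine writes to the same address within each group of `width` lanes."""
    transactions = 0
    for start in range(0, len(index), width):
        partial = defaultdict(float)
        group = zip(index[start:start + width], value[start:start + width])
        for i, v in group:
            partial[i] += v  # lanes cooperate, no memory traffic yet
        for i, total in partial.items():
            target[i] += total  # one simulated atomic per unique target
            transactions += 1
    return transactions


# Worst case: 64 lanes all accumulate into the same entry
index = [0] * 64
value = [1.0] * 64
buf_atomic, buf_local = [0.0], [0.0]
n_atomic = scatter_add_atomic(buf_atomic, index, value)
n_local = scatter_add_local(buf_local, index, value)
print(n_atomic, n_local)
```

With a group width of 32, the fully contended case issues one update per group instead of one per lane, which corresponds to the 32-fold reduction mentioned above.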
- SymbolicConditionals = JitFlag.SymbolicConditionals¶
Dr.Jit provides two main ways of compiling conditionals involving Dr.Jit arrays.
Symbolic mode (the default): Dr.Jit captures the computation performed by the
True
and False
branches and generates an equivalent branch in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.
Evaluated mode: Dr.Jit always executes both branches and blends their outputs.
A separate section about symbolic and evaluated modes discusses these two options in detail.
Symbolic conditionals are enabled by default.
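The "execute both branches and blend" behavior of evaluated mode can be modeled in a few lines of plain Python. This is a sketch under the assumption that arrays are represented as Python lists; `evaluated_if` is a hypothetical helper, not part of Dr.Jit.

```python
def evaluated_if(mask, true_fn, false_fn, values):
    """Run both branches on every element, then select per element."""
    true_out = [true_fn(v) for v in values]    # computed for all elements
    false_out = [false_fn(v) for v in values]  # also computed for all elements
    return [t if m else f for m, t, f in zip(mask, true_out, false_out)]


x = [-2.0, -1.0, 1.0, 2.0]
mask = [v > 0 for v in x]
result = evaluated_if(mask, lambda v: v * 10, lambda v: -v, x)
print(result)  # [2.0, 1.0, 10.0, 20.0]
```

Because both functions run on all elements, each branch must tolerate inputs that it would never see under ordinary Python control flow.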
- SymbolicScope = JitFlag.SymbolicScope¶
This flag is set to
True
when Dr.Jit is currently capturing symbolic computation. The flag is automatically managed and should not be updated by application code.
User code may query this flag to check if it is legal to perform certain operations (e.g., evaluating array contents).
Note that this information can also be queried in a more fine-grained manner (per variable) using the
drjit.ArrayBase.state
field.
- Default = JitFlag.Default¶
The default set of optimization flags consisting of
- drjit.has_backend(arg: drjit.JitBackend, /) int ¶
Check if the specified Dr.Jit backend was successfully initialized.
- drjit.schedule(arg: object, /) bool ¶
- drjit.schedule(*args) bool
Schedule the provided JIT variable(s) for later evaluation
This function causes
args
to be evaluated by the next kernel launch. In other words, the effect of this operation is deferred: the next time that Dr.Jit’s LLVM or CUDA backends compile and execute code, they will include the trace of the specified variables in the generated kernel and turn them into an explicit memory-based representation.
Scheduling and evaluation of traced computation happens automatically, hence it is rare that a user would need to call this function explicitly. Explicit scheduling can improve performance in certain cases. For example, consider the following code:
# Computation that produces Dr.Jit arrays
a, b = ...
# The following line launches a kernel that computes 'a'
print(a)
# The following line launches a kernel that computes 'b'
print(b)
If the traces of
a
and b
overlap (perhaps they reference computation from an earlier step not shown here), then this is inefficient as these steps will be executed twice. It is preferable to launch bigger kernels that leverage common subexpressions, which is what drjit.schedule()
enables:
# Computation that produces Dr.Jit arrays
a, b = ...
# Schedule both arrays for deferred evaluation, but don't evaluate yet
dr.schedule(a, b)
# The following line launches a kernel that computes both 'a' and 'b'
print(a)
# References the stored array, no kernel launch
print(b)
Note that
drjit.eval()
would also have been a suitable alternative in the above example; the main difference to drjit.schedule()
is that it does the evaluation immediately without deferring the kernel launch.
This function accepts a variable-length list of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During recursion, the function gathers all unevaluated Dr.Jit arrays. Evaluated arrays and incompatible types are ignored. Multiple variables can be equivalently scheduled with a single
drjit.schedule()
call or a sequence of calls to drjit.schedule()
. Variables that are garbage collected between the original drjit.schedule()
call and the next kernel launch are ignored and will not be stored in memory.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)
- Returns:
True
if a variable was scheduled, False
if the operation did not do anything.
- Return type:
bool
- drjit.eval(arg: object, /) bool ¶
- drjit.eval(*args) bool
Evaluate the provided JIT variable(s)
Dr.Jit automatically evaluates variables as needed, hence it is usually not necessary to call this function explicitly. That said, explicit evaluation may sometimes improve performance—refer to the documentation of
drjit.schedule()
for an example of such a use case.
drjit.eval()
invokes Dr.Jit’s LLVM or CUDA backends to compile and then execute a kernel containing all steps that are needed to evaluate the specified variables, which will turn them into a memory-based representation. The generated kernel(s) will also include computation that was previously scheduled via drjit.schedule()
. In fact, drjit.eval()
internally calls drjit.schedule()
, as dr.eval(arg_1, arg_2, ...)
is equivalent to
dr.schedule(arg_1, arg_2, ...)
dr.eval()
This function accepts a variable-length list of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During this recursive traversal, the function collects all unevaluated Dr.Jit arrays, while ignoring previously evaluated arrays along with non-array types. The function also does not evaluate literal constant arrays (this refers to potentially large arrays that are entirely uniform), as this is generally not wanted. Use the function
drjit.make_opaque()
if you wish to evaluate literal constant arrays as well.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)
- Returns:
True
if a variable was evaluated, False
if the operation did not do anything.
- Return type:
bool
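The interplay between scheduling and evaluation can be mimicked with a toy model in plain Python. The classes `Var` and `ToyJit` below are hypothetical stand-ins invented for this sketch; they only count simulated kernel launches and say nothing about how Dr.Jit is actually implemented.

```python
class Var:
    """Stand-in for a traced array; it only tracks whether it was evaluated."""

    def __init__(self, name):
        self.name = name
        self.evaluated = False


class ToyJit:
    """Toy model: a queue of scheduled variables and a launch counter."""

    def __init__(self):
        self.queue = []
        self.launches = 0

    def schedule(self, *variables):
        added = False
        for v in variables:
            already_queued = any(v is q for q in self.queue)
            if not v.evaluated and not already_queued:
                self.queue.append(v)
                added = True
        return added

    def eval(self, *variables):
        self.schedule(*variables)
        if not self.queue:
            return False       # nothing to do, no kernel launch
        self.launches += 1     # one simulated kernel covers the whole queue
        for v in self.queue:
            v.evaluated = True
        self.queue.clear()
        return True


# Evaluating one variable at a time: two simulated launches
jit = ToyJit()
a, b = Var("a"), Var("b")
jit.eval(a)
jit.eval(b)
separate_launches = jit.launches

# Scheduling both first: a single simulated launch
jit = ToyJit()
a, b = Var("a"), Var("b")
jit.schedule(a, b)
first = jit.eval(a)
second = jit.eval(b)
fused_launches = jit.launches
print(separate_launches, fused_launches)
```

In the second variant, the call `jit.eval(b)` returns `False` in this model because `b` was already computed by the launch triggered for `a`.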
- drjit.set_flag(arg0: drjit.JitFlag, arg1: bool, /) None ¶
Set the value of the given Dr.Jit compilation flag.
- drjit.flag(arg: drjit.JitFlag, /) bool ¶
Query whether the given Dr.Jit compilation flag is active.
- class drjit.scoped_set_flag¶
Context manager, which sets or unsets a Dr.Jit compilation flag in a local execution scope.
For example, the following snippet shows how to temporarily disable a flag:
with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, False):
    # Code affected by the change should be placed here
    ...
# Flag is returned to its original status
- __init__(self, flag: drjit.JitFlag, value: bool = True) None ¶
- __enter__(self) None ¶
- __exit__(self, arg0: object | None, arg1: object | None, arg2: object | None) None ¶
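The set-and-restore pattern behind such a context manager can be sketched in plain Python. The class `scoped_flag` and the dictionary-based flag store are assumptions made for illustration; they mirror the `__init__`/`__enter__`/`__exit__` structure listed above but are not Dr.Jit code.

```python
class scoped_flag:
    """Temporarily override an entry of a flag dictionary (hypothetical model)."""

    def __init__(self, flags, name, value=True):
        self.flags = flags
        self.name = name
        self.value = value

    def __enter__(self):
        self.backup = self.flags[self.name]  # remember the previous state
        self.flags[self.name] = self.value

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the previous state, also when an exception is propagating
        self.flags[self.name] = self.backup


flags = {"SymbolicCalls": True}
with scoped_flag(flags, "SymbolicCalls", False):
    inside = flags["SymbolicCalls"]
after = flags["SymbolicCalls"]
print(inside, after)  # False True
```

Restoring the saved value rather than unconditionally re-enabling the flag keeps nested scopes well behaved.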
Type traits¶
The functions in this section can be used to infer properties or types of Dr.Jit arrays.
The naming convention with a trailing _v
or _t
indicates whether a
function returns a value or a type. Evaluation takes place at runtime within
Python. In C++, these expressions are all constexpr
(i.e., evaluated at
compile time).
Array type tests¶
- drjit.is_array_v(arg: object | None) bool ¶
Check if the input is a Dr.Jit array instance or type
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if arg
or type(arg
) is a Dr.Jit array type, and False
otherwise
- Return type:
bool
- drjit.is_mask_v(arg: object, /) bool ¶
Check whether the input array instance or type is a Dr.Jit mask array or a Python
bool
value/type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if arg
represents a Dr.Jit mask array or Python bool
instance or type.
- Return type:
bool
- drjit.is_half_v(arg: object, /) bool ¶
Check whether the input array instance or type is a Dr.Jit half-precision floating point array or a Python
half
value/type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if arg
represents a Dr.Jit half-precision floating point array or Python half
instance or type.
- Return type:
bool
- drjit.is_float_v(arg: object, /) bool ¶
Check whether the input array instance or type is a Dr.Jit floating point array or a Python
float
value/type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if arg
represents a Dr.Jit floating point array or Python float
instance or type.
- Return type:
bool
- drjit.is_integral_v(arg: object, /) bool ¶
Check whether the input array instance or type is an integral Dr.Jit array or a Python
int
value/type.
Note that a mask array is not considered to be integral.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if arg
represents an integral Dr.Jit array or Python int
instance or type.
- Return type:
bool
- drjit.is_arithmetic_v(arg: object, /) bool ¶
Check whether the input array instance or type is an arithmetic Dr.Jit array or a Python
int
or float
value/type.
Note that a mask type (e.g.
bool
, drjit.scalar.Array2b
, etc.) is not considered to be arithmetic.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if arg
represents an arithmetic Dr.Jit array or Python int
or float
instance or type.
- Return type:
bool
- drjit.is_signed_v(arg: object, /) bool ¶
Check whether the input array instance or type is a signed Dr.Jit array or a Python
int
or float
value/type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if arg
represents a signed Dr.Jit array or Python int
or float
instance or type.
- Return type:
bool
- drjit.is_unsigned_v(arg: object, /) bool ¶
Check whether the input array instance or type is an unsigned integer Dr.Jit array or a Python
bool
value/type (masks and boolean values are also considered to be unsigned).
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if arg
represents an unsigned Dr.Jit array or Python bool
instance or type.
- Return type:
bool
- drjit.is_dynamic_v(arg: object, /) bool ¶
Check whether the input instance or type represents a dynamically sized Dr.Jit array type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, and False
otherwise.
- Return type:
bool
- drjit.is_jit_v(arg: object, /) bool ¶
Check whether the input array instance or type represents a type that undergoes just-in-time compilation.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if arg
represents an array type from the drjit.cuda.*
or drjit.llvm.*
namespaces, and False
otherwise.
- Return type:
bool
- drjit.is_diff_v(arg: object, /) bool ¶
Check whether the input is a differentiable Dr.Jit array instance or type.
Note that this is a type-based statement that is unrelated to mathematical differentiability. For example, the integral type
drjit.cuda.ad.Int
from the CUDA AD namespace satisfies is_diff_v(..) = 1
.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if arg
represents an array type from the drjit.[cuda/llvm].ad.*
namespace, and False
otherwise.
- Return type:
bool
- drjit.is_vector_v(arg: object, /) bool ¶
Check whether the input is a Dr.Jit array instance or type representing a vectorial array type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, and False
otherwise.
- Return type:
bool
- drjit.is_complex_v(arg: object, /) bool ¶
Check whether the input is a Dr.Jit array instance or type representing a complex number.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, and False
otherwise.
- Return type:
bool
- drjit.is_matrix_v(arg: object, /) bool ¶
Check whether the input is a Dr.Jit array instance or type representing a matrix.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, and False
otherwise.
- Return type:
bool
- drjit.is_quaternion_v(arg: object, /) bool ¶
Check whether the input is a Dr.Jit array instance or type representing a quaternion.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, and False
otherwise.
- Return type:
bool
- drjit.is_tensor_v(arg: object, /) bool ¶
Check whether the input is a Dr.Jit array instance or type representing a tensor.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, and False
otherwise.
- Return type:
bool
- drjit.is_special_v(arg: object, /) bool ¶
Check whether the input is a special Dr.Jit array instance or type.
A special array type requires precautions when performing arithmetic operations like multiplications (complex numbers, quaternions, matrices).
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, and False
otherwise.
- Return type:
bool
- drjit.is_struct_v(arg: object, /) bool ¶
Check if the input is a Dr.Jit-compatible data structure
Custom data structures can be made compatible with various Dr.Jit operations by specifying a
DRJIT_STRUCT
member. See the section on PyTrees for details. This type trait can be used to check for the existence of such a field.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if arg
has a DRJIT_STRUCT
member
- Return type:
bool
Array properties (shape, type, etc.)¶
- drjit.type_v(arg: object, /) drjit.VarType ¶
Returns the scalar type associated with the given Dr.Jit array instance or type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
The associated type or
drjit.VarType.Void
.
- Return type:
drjit.VarType
- drjit.backend_v(arg: object, /) drjit.JitBackend ¶
Returns the backend responsible for the given Dr.Jit array instance or type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
The associated Jit backend or
drjit.JitBackend.None
.
- Return type:
drjit.JitBackend
- drjit.size_v(arg: object, /) int ¶
Return the (static) size of the outermost dimension of the provided Dr.Jit array instance or type
Note that this function mainly exists to query type-level information. Use the Python
len()
function to query the size in a way that does not distinguish between static and dynamic arrays.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
Returns either the static size or
drjit.Dynamic
when arg
is a dynamic Dr.Jit array. Returns 1
for all other types.
- Return type:
int
- drjit.depth_v(arg: object, /) int ¶
Return the depth of the provided Dr.Jit array instance or type
For example, an array consisting of floating point values (for example,
drjit.scalar.Array3f
) has depth 1
, while an array consisting of sub-arrays (e.g., drjit.cuda.Array3f
) has depth 2
.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
Returns the depth of the input, if it is a Dr.Jit array instance or type. Returns
0
for all other inputs.
- Return type:
int
- drjit.itemsize_v(arg: object, /) int ¶
Return the per-item size (in bytes) of the scalar type underlying a Dr.Jit array
- Parameters:
arg (object) – A Dr.Jit array instance or array type.
- Returns:
Returns the size of the array elements in bytes.
- Return type:
int
Bit-level operations¶
- drjit.reinterpret_array(arg0: type[drjit.ArrayBase], arg1: drjit.ArrayBase, /) object ¶
Reinterpret the provided Dr.Jit array or tensor as a different type.
This operation reinterprets the input type as another type provided that it has a compatible in-memory layout (this operation is also known as a bit-cast).
- Parameters:
dtype (type) – Target type.
value (object) – A compatible Dr.Jit input array or tensor.
- Returns:
Result of the conversion as described above.
- Return type:
object
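For readers unfamiliar with bit-casts, the following plain-Python sketch shows what reinterpreting a 32-bit float as a 32-bit unsigned integer means. It uses the standard `struct` module rather than Dr.Jit, and the helper names `float_to_bits` and `bits_to_float` are hypothetical.

```python
import struct


def float_to_bits(x):
    """Reinterpret the 4 bytes of a single precision float as a uint32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_float(u):
    """Reinterpret a uint32 bit pattern as a single precision float."""
    return struct.unpack("<f", struct.pack("<I", u))[0]


print(hex(float_to_bits(1.0)))    # 0x3f800000
print(bits_to_float(0x3F000000))  # 0.5
```

No numeric conversion takes place: the bit pattern is preserved and only its interpretation changes, which is why source and target types must have a compatible in-memory layout.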
- drjit.popcnt(arg: ArrayT, /) ArrayT ¶
- drjit.popcnt(arg: int, /) int
Return the number of nonzero bits.
This function evaluates the component-wise population count of the input scalar, array, or tensor. This function assumes that
arg
is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.
- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
number of nonzero bits in
arg
- Return type:
int | drjit.ArrayBase
- drjit.lzcnt(arg: ArrayT, /) ArrayT ¶
- drjit.lzcnt(arg: int, /) int
Return the number of leading zero bits.
This function evaluates the component-wise leading zero count of the input scalar, array, or tensor. This function assumes that
arg
is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.
The operation is well-defined when
arg
is zero.
- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
number of leading zero bits in
arg
- Return type:
int | drjit.ArrayBase
- drjit.tzcnt(arg: ArrayT, /) ArrayT ¶
- drjit.tzcnt(arg: int, /) int
Return the number of trailing zero bits.
This function evaluates the component-wise trailing zero count of the input scalar, array, or tensor. This function assumes that
arg
is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.
The operation is well-defined when
arg
is zero.
- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
number of trailing zero bits in
arg
- Return type:
int | drjit.ArrayBase
- drjit.brev(arg: ArrayT, /) ArrayT ¶
- drjit.brev(arg: int, /) int
Reverse the bit representation of an integer value or array.
This function assumes that
arg
is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.
- Parameters:
arg (int | drjit.ArrayBase) – A Python
int
or Dr.Jit integer array.
- Returns:
the bit-reversed version of
arg
.
- Return type:
int | drjit.ArrayBase
- drjit.log2i(arg: T, /) T ¶
Return the floor of the base-2 logarithm.
This function evaluates the component-wise floor of the base-2 logarithm of the input scalar, array, or tensor. This function assumes that
arg
is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.
The operation overflows when
arg
is zero.
- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
floor of the base-2 logarithm of the input
- Return type:
int | drjit.ArrayBase
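The bit-level operations in this section can be cross-checked against a plain-Python reference model for 32-bit scalar values. The functions below are a sketch that reuses the documented names for readability; they are not the Dr.Jit implementations. Returning 32 for a zero input in `lzcnt` and `tzcnt`, and raising an exception in `log2i`, are assumptions of this model.

```python
MASK32 = 0xFFFFFFFF


def popcnt(x):
    """Number of bits set to one."""
    return bin(x & MASK32).count("1")


def lzcnt(x):
    """Number of leading zero bits (32 for a zero input in this model)."""
    return 32 - (x & MASK32).bit_length()


def tzcnt(x):
    """Number of trailing zero bits (32 for a zero input in this model)."""
    x &= MASK32
    return 32 if x == 0 else (x & -x).bit_length() - 1


def brev(x):
    """Reverse the order of the 32 bits."""
    return int(format(x & MASK32, "032b")[::-1], 2)


def log2i(x):
    """Floor of the base-2 logarithm; zero is rejected in this model."""
    x &= MASK32
    if x == 0:
        raise ValueError("log2i is not defined for zero in this model")
    return x.bit_length() - 1


print(popcnt(0b1011), lzcnt(1), tzcnt(8), hex(brev(1)), log2i(1023))
```

For nonzero 32-bit inputs, the model satisfies `log2i(x) == 31 - lzcnt(x)`, which ties the logarithm to the leading zero count.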
Standard mathematical functions¶
- drjit.fma(arg0: object, arg1: object, arg2: object, /) object ¶
- drjit.fma(arg0: int, arg1: int, arg2: int, /) int
- drjit.fma(arg0: float, arg1: float, arg2: float, /) float
Perform a fused multiply-addition (FMA) of the inputs.
Given arguments
arg0
, arg1
, and arg2
, this operation computes arg0
* arg1
+ arg2
using only one final rounding step. The operation is not only more accurate, but also more efficient, since FMA maps to a native machine instruction on all platforms targeted by Dr.Jit.
When the input is complex- or quaternion-valued, the function internally uses a complex or quaternion product. In this case, it reduces the number of internal rounding steps instead of avoiding them altogether.
While FMA is traditionally a floating point operation, Dr.Jit also implements FMA for integer arrays and maps it onto dedicated instructions provided by the backend if possible (e.g.
mad.lo.*
for CUDA/PTX).
- Parameters:
arg0 (float | drjit.ArrayBase) – First multiplication operand
arg1 (float | drjit.ArrayBase) – Second multiplication operand
arg2 (float | drjit.ArrayBase) – Additive operand
- Returns:
Result of the FMA operation
- Return type:
float | drjit.ArrayBase
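The effect of "only one final rounding step" can be demonstrated without Dr.Jit. The sketch below emulates a fused multiply-add in double precision by computing the exact result with `fractions.Fraction` and rounding once; `fma_reference` is a hypothetical helper introduced only for this comparison.

```python
from fractions import Fraction


def fma_reference(a, b, c):
    """Compute a * b + c exactly, then round once to a double."""
    return float(Fraction(a) * Fraction(b) + Fraction(c))


a = 1.0 + 2.0**-30
b = 1.0 - 2.0**-30
c = -1.0

# Two rounding steps: the product 1 - 2**-60 is rounded to 1.0 first
unfused = a * b + c

# One rounding step: the tiny difference survives
fused = fma_reference(a, b, c)
print(unfused, fused)
```

The separately rounded product loses the term `2**-60`, so the unfused expression returns zero, whereas the single-rounding variant returns `-2**-60`.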
- drjit.abs(arg: ArrayT, /) ArrayT ¶
- drjit.abs(arg: int, /) int
- drjit.abs(arg: float, /) float
Compute the absolute value of the provided input.
This function evaluates the component-wise absolute value of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.
- Parameters:
arg (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Absolute value of the input
- Return type:
int | float | drjit.ArrayBase
- drjit.minimum(arg0: int, arg1: int, /) int ¶
- drjit.minimum(arg0: object, arg1: object, /) object
- drjit.minimum(arg0: float, arg1: float, /) float
Compute the element-wise minimum value of the provided inputs.
(Not to be confused with
drjit.min()
, which reduces the input along the specified axes to determine the minimum)
- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Minimum of the input(s)
- Return type:
int | float | drjit.ArrayBase
- drjit.maximum(arg0: int, arg1: int, /) int ¶
- drjit.maximum(arg0: object, arg1: object, /) object
- drjit.maximum(arg0: float, arg1: float, /) float
Compute the element-wise maximum value of the provided inputs.
(Not to be confused with
drjit.max()
, which reduces the input along the specified axes to determine the maximum)
- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Maximum of the input(s)
- Return type:
int | float | drjit.ArrayBase
- drjit.sqrt(arg: ArrayT, /) ArrayT ¶
- drjit.sqrt(arg: float, /) float
Evaluate the square root of the provided input.
This function evaluates the component-wise square root of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.
Negative inputs produce a NaN output value. Consider using the
safe_sqrt()
function to work around issues where the input might occasionally be negative due to prior round-off errors.
Another noteworthy behavior of the square root function is that it has an infinite derivative at
arg=0
, which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. The safe_sqrt()
function contains a workaround to ensure a finite derivative in this case.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Square root of the input
- Return type:
float | drjit.ArrayBase
- drjit.cbrt(arg: ArrayT, /) ArrayT ¶
- drjit.cbrt(arg: float, /) float
Evaluate the cube root of the provided input.
This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Cube root of the input
- Return type:
float | drjit.ArrayBase
- drjit.rcp(arg: ArrayT, /) ArrayT ¶
- drjit.rcp(arg: float, /) float
Evaluate the reciprocal (1 / arg) of the provided input.
When
arg
is a CUDA single precision array, the operation uses a slightly approximate implementation; see the documentation of the instruction rcp.approx.ftz.f32
in the NVIDIA PTX manual for details. For full IEEE-754 compatibility, unset drjit.JitFlag.FastMath
.
When called with a matrix-, complex- or quaternion-valued array, this function uses the matrix, complex, or quaternion multiplicative inverse to evaluate the reciprocal.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Reciprocal of the input
- Return type:
float | drjit.ArrayBase
- drjit.rsqrt(arg: ArrayT, /) ArrayT ¶
- drjit.rsqrt(arg: float, /) float
Evaluate the reciprocal square root (1 / sqrt(arg)) of the provided input.
This function evaluates the component-wise reciprocal square root of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.
When
arg
is a CUDA single precision array, the operation uses a slightly approximate implementation; see the documentation of the instruction rsqrt.approx.ftz.f32
in the NVIDIA PTX manual for details. For full IEEE-754 compatibility, unset drjit.JitFlag.FastMath
.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Reciprocal square root of the input
- Return type:
float | drjit.ArrayBase
- drjit.clip(value, min, max)¶
Clip the provided input to the given interval.
This function is equivalent to
dr.maximum(dr.minimum(value, max), min)
- Parameters:
value (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
min (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
max (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
- Returns:
Clipped input
- Return type:
float | drjit.ArrayBase
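The stated equivalence can be written out for plain Python scalars as follows. This is a minimal sketch using the built-in `min` and `max` functions instead of their Dr.Jit counterparts; the parameter names `lo` and `hi` are chosen here to avoid shadowing the built-ins.

```python
def clip(value, lo, hi):
    """Scalar model of the expression maximum(minimum(value, hi), lo)."""
    return max(min(value, hi), lo)


print(clip(5, 0, 3), clip(-1, 0, 3), clip(2, 0, 3))  # 3 0 2
```

Because the lower bound is applied last, this formulation returns `lo` when the interval is empty (`lo > hi`).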
- drjit.ceil(arg: ArrayT, /) ArrayT ¶
- drjit.ceil(arg: float, /) float
Evaluate the ceiling, i.e. the smallest integer >= arg.
The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Ceiling of the input
- Return type:
float | drjit.ArrayBase
- drjit.floor(arg: ArrayT, /) ArrayT ¶
- drjit.floor(arg: float, /) float
Evaluate the floor, i.e. the largest integer <= arg.
The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Floor of the input
- Return type:
float | drjit.ArrayBase
- drjit.trunc(arg: ArrayT, /) ArrayT ¶
- drjit.trunc(arg: float, /) float
Truncate arg to the nearest integer by rounding towards zero.
The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Truncated result
- Return type:
float | drjit.ArrayBase
- drjit.round(arg: ArrayT, /) ArrayT ¶
- drjit.round(arg: float, /) float
Rounds the input to the nearest integer using Banker’s rounding for half-way values.
This function is equivalent to
std::rint
in C++. It does not convert the type of the input array. A separate cast is necessary when integer output is desired.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Rounded result
- Return type:
float | drjit.ArrayBase
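The four rounding functions differ only in how they treat the fractional part. The comparison below uses Python's standard library as a stand-in for the Dr.Jit functions: `math.ceil`, `math.floor`, and `math.trunc` follow the definitions given above, and the built-in `round()` also applies Banker's rounding to half-way values. The results are converted back to `float` to mirror the statement that the input type is not changed.

```python
import math

samples = [-2.5, -1.5, -0.5, 0.5, 1.5, 2.5]

ceil_out = [float(math.ceil(v)) for v in samples]
floor_out = [float(math.floor(v)) for v in samples]
trunc_out = [float(math.trunc(v)) for v in samples]
round_out = [float(round(v)) for v in samples]  # ties go to the even integer

print(ceil_out)
print(floor_out)
print(trunc_out)
print(round_out)
```

Every sample is a half-way value, so the last list makes the tie-breaking rule visible: both 1.5 and 2.5 round to 2.0.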
- drjit.sign(arg, /)¶
Return the element-wise sign of the provided array.
The function returns
\[\mathrm{sign}(\texttt{arg}) = \begin{cases} 1&\texttt{arg}\geq 0,\\ -1&\mathrm{otherwise}. \end{cases}\]
- Parameters:
arg (int | float | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Sign of the input array
- Return type:
float | int | drjit.ArrayBase
- drjit.copysign(arg0, arg1, /)¶
Copy the sign of
arg1
to arg0
element-wise.
- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to change the sign of
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to copy the sign from
- Returns:
The values of
arg0
with the sign of arg1
- Return type:
float | int | drjit.ArrayBase
- drjit.mulsign(arg0, arg1, /)¶
Multiply
arg0
by the sign of arg1
element-wise.
This function is equivalent to
arg0 * dr.sign(arg1)
- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to be multiplied by the sign
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to take the sign from
- Returns:
The values of
arg0
multiplied with the sign of arg1
- Return type:
float | int | drjit.ArrayBase
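The difference between the three sign-related functions is easiest to see on scalars. The following plain-Python model is a sketch based on the definitions given above; it relies on `math.copysign` from the standard library and is not the Dr.Jit implementation.

```python
import math


def sign(x):
    """Return 1 for x >= 0 and -1 otherwise, following the formula above."""
    return 1.0 if x >= 0 else -1.0


def copysign(x, y):
    """Magnitude of x combined with the sign of y."""
    return math.copysign(x, y)


def mulsign(x, y):
    """Multiply x by the sign of y, which flips x when y is negative."""
    return x * sign(y)


print(copysign(-2.0, -5.0), mulsign(-2.0, -5.0))  # -2.0 2.0
```

`copysign` overwrites the sign of its first argument, whereas `mulsign` flips it when the second argument is negative, so the two differ whenever the first argument is itself negative.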
Operations for vectors and matrices¶
- drjit.cross(arg0: ArrayT, arg1: ArrayT, /) ArrayT ¶
Returns the cross-product of the two input 3D arrays
- Parameters:
arg0 (drjit.ArrayBase) – A Dr.Jit 3D array
arg1 (drjit.ArrayBase) – A Dr.Jit 3D array
- Returns:
Cross-product of the two input 3D arrays
- Return type:
- drjit.det(arg, /)¶
Compute the determinant of the provided Dr.Jit matrix.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The determinant value of the input matrix
- Return type:
- drjit.diag(arg, /)¶
This function either returns the diagonal entries of the provided Dr.Jit matrix, or it constructs a new matrix from the diagonal entries.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The diagonal entries of the input matrix, or a diagonal matrix constructed from the provided entries
- Return type:
- drjit.trace(arg, /)¶
Returns the trace of the provided Dr.Jit matrix.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The trace of the input matrix
- Return type:
drjit.value_t(arg)
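The two behaviors of the diagonal operation and its relation to the trace can be illustrated with nested Python lists. This is only a sketch under the assumption that a matrix is a list of rows and a vector is a flat list; `diag_sketch` and `trace_sketch` are hypothetical names, not Dr.Jit functions.

```python
def diag_sketch(x):
    if isinstance(x[0], list):
        # Matrix input: extract the diagonal entries.
        return [x[i][i] for i in range(len(x))]
    # Vector input: build a matrix with the entries on the diagonal.
    n = len(x)
    return [[x[i] if i == j else 0 for j in range(n)] for i in range(n)]


def trace_sketch(m):
    # The trace is the sum of the diagonal entries.
    return sum(diag_sketch(m))


m = [[1, 2], [3, 4]]
print(diag_sketch(m))
print(diag_sketch([1, 4]))
print(trace_sketch(m))
```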
- drjit.matmul(arg0: object, arg1: object, /) object ¶
Compute a matrix-matrix, matrix-vector, vector-matrix, or inner product.
This function implements the semantics of the @ operator introduced in Python’s PEP 465. There is no practical difference between using drjit.matmul() or @ in Dr.Jit-based code. Multiplication of matrix types (e.g., drjit.scalar.Matrix2f) using the standard multiplication operator (*) is also based on matrix multiplication.
This function takes two Dr.Jit arrays and picks one of the following 5 cases based on their leading fixed-size dimensions.
Matrix-matrix product: If both arrays have leading static dimensions (n, n), they are multiplied like conventional matrices.
Matrix-vector product: If arg0 has leading static dimensions (n, n) and arg1 has leading static dimension (n,), the operation conceptually appends a trailing 1-sized dimension to arg1, multiplies, and then removes the extra dimension from the result.
Vector-matrix product: If arg0 has leading static dimension (n,) and arg1 has leading static dimensions (n, n), the operation conceptually prepends a leading 1-sized dimension to arg0, multiplies, and then removes the extra dimension from the result.
Inner product: If arg0 and arg1 have leading static dimension (n,), the operation returns the sum of the elements of arg0*arg1.
Scalar product: If arg0 or arg1 is a scalar, the operation scales the elements of the other argument.
It is legal to combine vectorized and non-vectorized types, e.g.
dr.matmul(dr.scalar.Matrix4f(...), dr.cuda.Matrix4f(...))
Also, note that it doesn’t matter whether an input is an instance of a matrix type or a similarly-shaped nested array. For example, drjit.scalar.Matrix3f() and drjit.scalar.Array33f() have the same shape and are treated identically.
Note
This operation only handles fixed-size arrays. A different approach is needed for multiplications involving potentially large dynamic arrays/tensors. Other tools like PyTorch, JAX, or TensorFlow will be preferable in such situations (e.g., to train neural networks).
- Parameters:
arg0 (dr.ArrayBase) – Dr.Jit array type
arg1 (dr.ArrayBase) – Dr.Jit array type
- Returns:
The result of the operation as defined above
- Return type:
object
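The five cases listed above can be mimicked with nested Python lists to make the dispatch rules concrete. This is a hypothetical plain-Python sketch (`matmul_sketch` and its helpers are not part of Dr.Jit); it assumes that a matrix is a square list of rows, a vector is a flat list, and a scalar is a Python number.

```python
from numbers import Number


def is_matrix(x):
    return isinstance(x, list) and len(x) > 0 and isinstance(x[0], list)


def scale(x, s):
    if is_matrix(x):
        return [[v * s for v in row] for row in x]
    return [v * s for v in x]


def matmul_sketch(a, b):
    # Scalar product: scale the elements of the other argument.
    if isinstance(a, Number):
        return a * b if isinstance(b, Number) else scale(b, a)
    if isinstance(b, Number):
        return scale(a, b)
    n = len(a)
    # Matrix-matrix product.
    if is_matrix(a) and is_matrix(b):
        return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
                for i in range(n)]
    # Matrix-vector product: treat `b` as a column vector.
    if is_matrix(a):
        return [sum(a[i][k] * b[k] for k in range(n)) for i in range(n)]
    # Vector-matrix product: treat `a` as a row vector.
    if is_matrix(b):
        return [sum(a[k] * b[k][j] for k in range(n)) for j in range(n)]
    # Inner product: sum of the element-wise products.
    return sum(x * y for x, y in zip(a, b))


M = [[1, 2], [3, 4]]
v = [1, 1]
print(matmul_sketch(M, M), matmul_sketch(M, v), matmul_sketch(v, M),
      matmul_sketch(v, v), matmul_sketch(2, v))
```

The matrix-vector and vector-matrix results differ for a non-symmetric matrix, which is why the position of the vector argument matters.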
- drjit.hypot(a, b)¶
Computes \(\sqrt{a^2+b^2}\) while avoiding overflow and underflow.
- Parameters:
a (float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
b (float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
The computed hypotenuse.
- Return type:
float | drjit.ArrayBase
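To see why a dedicated hypotenuse routine is useful, compare the textbook formula with a rescaled variant. The sketch below shows one common way to avoid intermediate overflow and underflow (factoring out the larger magnitude); it is an assumption for illustration and not necessarily how Dr.Jit implements the operation.

```python
import math


def naive_hypot(a, b):
    # Squaring very large (or very small) inputs overflows (or underflows).
    return math.sqrt(a * a + b * b)


def hypot_sketch(a, b):
    a, b = abs(a), abs(b)
    big, small = max(a, b), min(a, b)
    if big == 0.0:
        return 0.0
    # Factor out the larger magnitude so that the squared ratio is <= 1.
    ratio = small / big
    return big * math.sqrt(1.0 + ratio * ratio)


print(naive_hypot(3e200, 4e200))   # intermediate overflow
print(hypot_sketch(3e200, 4e200))  # finite result
```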
- drjit.normalize(arg: T, /) T ¶
Normalize the input vector so that it has unit length and return the result.
This operation is equivalent to
arg * dr.rsqrt(dr.squared_norm(arg))
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit array type
- Returns:
Unit-norm version of the input
- Return type:
- drjit.lerp(a, b, t)¶
Linearly blend between two values.
This function computes
\[\mathrm{lerp}(t) = (1-t) a + t b\]
In other words, it linearly blends between \(a\) and \(b\) based on the value \(t\) that is typically on the interval \([0, 1]\).
It does so using two fused multiply-additions (drjit.fma()) to improve performance and avoid numerical errors.
- Parameters:
a (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
b (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
t (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
- Returns:
Interpolated result
- Return type:
float | drjit.ArrayBase
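The blend \((1-t) a + t b\) can be rearranged into two nested multiply-add steps, \(b t + (a - a t)\). The following plain-Python sketch shows that rearrangement; the `fma` helper here is an unfused stand-in (a hypothetical placeholder, so it does not provide the single-rounding benefit of a hardware FMA), and the exact operand order used by Dr.Jit is an assumption.

```python
def fma(x, y, z):
    # Unfused stand-in for a fused multiply-add: computes x * y + z.
    return x * y + z


def lerp_sketch(a, b, t):
    # (1 - t) * a + t * b  ==  b * t + (a - a * t)
    return fma(b, t, fma(-a, t, a))


print(lerp_sketch(2.0, 10.0, 0.25))
```

At `t = 0` the expression reduces to `a`, and at `t = 1` it reduces to `b`.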
- drjit.sh_eval(d: ArrayBase, order: int) list ¶
Evaluate real spherical harmonics basis functions up to a specified order.
The input d must be a normalized 3D Cartesian coordinate vector. The function returns a list containing all spherical harmonic basis functions evaluated with respect to d up to the desired order, for a total of (order+1)**2 output values.
The implementation relies on efficient pre-generated branch-free code with aggressive constant folding and common subexpression elimination. It admits scalar and Jit-compiled input arrays. Evaluation routines are included for orders 0 to 10. Requesting higher orders triggers a runtime exception.
This automatically generated code is based on the paper Efficient Spherical Harmonic Evaluation, Journal of Computer Graphics Techniques (JCGT), vol. 2, no. 2, 84-90, 2013 by Peter-Pike Sloan.
The SciPy equivalent of this function is given by
def sh_eval(d, order: int):
    import numpy as np
    from scipy.special import sph_harm
    # Convert the unit vector into spherical coordinates
    theta, phi = np.arccos(d.z), np.arctan2(d.y, d.x)
    r = []
    for l in range(order + 1):
        for m in range(-l, l + 1):
            Y = sph_harm(abs(m), l, phi, theta)
            if m > 0:
                Y = np.sqrt(2) * Y.real
            elif m < 0:
                Y = np.sqrt(2) * Y.imag
            r.append(Y.real)
    # Return the list of evaluated basis functions
    return r
The Mathematica equivalent of a specific entry is given by:
SphericalHarmonicQ[l_, m_, d_] := Block[{θ, ϕ},
  θ = ArcCos[d[[3]]];
  ϕ = ArcTan[d[[1]], d[[2]]];
  Piecewise[{
    {SphericalHarmonicY[l, m, θ, ϕ], m == 0},
    {Sqrt[2] * Re[SphericalHarmonicY[l, m, θ, ϕ]], m > 0},
    {Sqrt[2] * Im[SphericalHarmonicY[l, -m, θ, ϕ]], m < 0}
  }]
]
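The output size and ordering follow from the loop structure of the SciPy equivalent: degree `l` runs from `0` to `order`, and for each degree `m` runs from `-l` to `l`. The sketch below, with hypothetical helper names, counts the entries and derives the flat list position of a given `(l, m)` pair under the assumption that the returned list uses this same ordering.

```python
def sh_count(order: int) -> int:
    # Degrees 0..order contribute 2*l + 1 entries each: (order + 1)**2 total.
    return (order + 1) ** 2


def sh_index(l: int, m: int) -> int:
    # l*l entries precede degree l; m is offset by l within the degree.
    return l * (l + 1) + m


def sh_pairs(order: int):
    # Enumerate (l, m) in the same order as the nested loops above.
    return [(l, m) for l in range(order + 1) for m in range(-l, l + 1)]


print(sh_count(3), sh_index(2, -1), sh_pairs(1))
```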
Operations for complex values and quaternions¶
- drjit.conj(arg, /)¶
Returns the conjugate of the provided complex or quaternion-valued array. For all other types, it returns the input unchanged.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit array
- Returns:
Conjugate form of the input
- Return type:
- drjit.arg(z, /)¶
Return the argument of a complex Dr.Jit array.
The argument refers to the angle (in radians) between the positive real axis and a vector towards z in the complex plane. When the input isn’t complex-valued, the function returns \(0\) or \(\pi\) depending on the sign of z.
- Parameters:
z (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Argument of the complex input array
- Return type:
float | drjit.ArrayBase
- drjit.real(arg, /)¶
Return the real part of a complex or quaternion-valued input.
When the input isn’t complex- or quaternion-valued, the function returns the input unchanged.
- Parameters:
arg (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Real part of the input array
- Return type:
float | drjit.ArrayBase
- drjit.imag(arg, /)¶
Return the imaginary part of a complex or quaternion-valued input.
When the input isn’t complex- or quaternion-valued, the function returns zero.
- Parameters:
arg (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Imaginary part of the input array
- Return type:
float | drjit.ArrayBase
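The behavior of these four operations on complex and real inputs can be mirrored with Python's built-in `complex` type and the `cmath` module. This is a scalar sketch with hypothetical helper names, not Dr.Jit code, and it does not cover quaternions.

```python
import cmath
import math


def conj_sketch(z):
    return complex(z).conjugate()


def arg_sketch(z):
    # Angle between the positive real axis and z; 0 or pi for real inputs.
    return cmath.phase(z)


def real_sketch(z):
    return complex(z).real


def imag_sketch(z):
    # Real-valued inputs have a zero imaginary part.
    return complex(z).imag


z = complex(1.0, -1.0)
print(conj_sketch(z), arg_sketch(z), real_sketch(z), imag_sketch(z))
print(arg_sketch(3.0), arg_sketch(-2.0))
```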
Transcendental functions¶
Dr.Jit implements the most common transcendental functions using methods that are based on the CEPHES math library. The accuracy of these approximations is documented in a set of tables below.
Trigonometric functions¶
- drjit.sin(arg: ArrayT, /) ArrayT ¶
- drjit.sin(arg: float, /) float
Evaluate the sine function.
This function evaluates the component-wise sine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Sine of the input
- Return type:
float | drjit.ArrayBase
- drjit.cos(arg: ArrayT, /) ArrayT ¶
- drjit.cos(arg: float, /) float
Evaluate the cosine function.
This function evaluates the component-wise cosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Cosine of the input
- Return type:
float | drjit.ArrayBase
- drjit.sincos(arg: ArrayT, /) tuple[ArrayT, ArrayT] ¶
- drjit.sincos(arg: float, /) tuple[float, float]
Evaluate both sine and cosine functions at the same time.
This function simultaneously evaluates the component-wise sine and cosine of the input scalar, array, or tensor. This is more efficient than two separate calls to drjit.sin() and drjit.cos() when both are required. The function uses a suitable generalization of the operation when the input is complex-valued.
The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation instead uses the hardware’s built-in multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Sine and cosine of the input
- Return type:
(float, float) | (drjit.ArrayBase, drjit.ArrayBase)
- drjit.tan(arg: ArrayT, /) ArrayT ¶
- drjit.tan(arg: float, /) float
Evaluate the tangent function.
This function evaluates the component-wise tangent function associated with each entry of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.
The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Tangent of the input
- Return type:
float | drjit.ArrayBase
- drjit.asin(arg: ArrayT, /) ArrayT ¶
- drjit.asin(arg: float, /) float
Evaluate the arcsine function.
This function evaluates the component-wise arcsine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when called with a complex-valued input.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
Real-valued inputs outside of the domain \([-1, 1]\) produce a NaN output value. Consider using the safe_asin() function to work around issues where the input might occasionally lie outside of this range due to prior round-off errors.
Another noteworthy behavior of the arcsine function is that it has an infinite derivative at \(\texttt{arg}=\pm 1\), which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. The safe_asin() function contains a workaround to ensure a finite derivative in this case.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arcsine of the input
- Return type:
float | drjit.ArrayBase
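The infinite derivative mentioned above follows from the analytic expression \(\frac{\mathrm{d}}{\mathrm{d}x}\arcsin(x) = 1/\sqrt{1-x^2}\). The following scalar sketch in plain Python (the helper name is hypothetical and no AD system is involved) evaluates this expression to show how quickly it grows as the input approaches 1.

```python
import math


def asin_derivative(x: float) -> float:
    # d/dx asin(x) = 1 / sqrt(1 - x^2), defined for |x| < 1.
    return 1.0 / math.sqrt(1.0 - x * x)


for x in (0.0, 0.9, 0.999, 0.999999):
    print(x, asin_derivative(x))
```

The value grows without bound near the boundary, which is the reason a clamped variant can be preferable inside differentiated code.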
- drjit.acos(arg: ArrayT, /) ArrayT ¶
- drjit.acos(arg: float, /) float
Evaluate the arccosine function.
This function evaluates the component-wise arccosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
Real-valued inputs outside of the domain \([-1, 1]\) produce a NaN output value. Consider using the safe_acos() function to work around issues where the input might occasionally lie outside of this range due to prior round-off errors.
Another noteworthy behavior of the arccosine function is that it has an infinite derivative at \(\texttt{arg}=\pm 1\), which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. The safe_acos() function contains a workaround to ensure a finite derivative in this case.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arccosine of the input
- Return type:
float | drjit.ArrayBase
- drjit.atan(arg: ArrayT, /) ArrayT ¶
- drjit.atan(arg: float, /) float
Evaluate the arctangent function.
This function evaluates the component-wise arctangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arctangent of the input
- Return type:
float | drjit.ArrayBase
- drjit.atan2(arg0: object, arg1: object, /) object ¶
- drjit.atan2(arg0: float, arg1: float, /) float
Evaluate the four-quadrant arctangent function.
This function is currently only implemented for real-valued inputs.
See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
y (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
x (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arctangent of y/x, using the argument signs to determine the quadrant of the return value
- Return type:
float | drjit.ArrayBase
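The phrase "four-quadrant" refers to the fact that the signs of both arguments are used, which a single-argument arctangent of the ratio cannot do. The sketch below uses Python's `math.atan2` as a stand-in to illustrate this on scalars; `naive_angle` is a hypothetical helper.

```python
import math


def naive_angle(y, x):
    # The ratio y / x loses track of which argument carried the sign.
    return math.atan(y / x)


for y, x in ((1.0, 1.0), (1.0, -1.0), (-1.0, -1.0), (-1.0, 1.0)):
    print((y, x), math.atan2(y, x), naive_angle(y, x))
```

For `(y, x) = (1, -1)` the four-quadrant version returns an angle in the second quadrant, while the naive version returns one in the fourth.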
Hyperbolic functions¶
- drjit.sinh(arg: ArrayT, /) ArrayT ¶
- drjit.sinh(arg: float, /) float
Evaluate the hyperbolic sine function.
This function evaluates the component-wise hyperbolic sine of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic sine of the input
- Return type:
float | drjit.ArrayBase
- drjit.cosh(arg: ArrayT, /) ArrayT ¶
- drjit.cosh(arg: float, /) float
Evaluate the hyperbolic cosine function.
This function evaluates the component-wise hyperbolic cosine of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic cosine of the input
- Return type:
float | drjit.ArrayBase
- drjit.sincosh(arg: ArrayT, /) tuple[ArrayT, ArrayT] ¶
- drjit.sincosh(arg: float, /) tuple[float, float]
Evaluate both hyperbolic sine and cosine functions at the same time.
This function simultaneously evaluates the component-wise hyperbolic sine and cosine of the input scalar, array, or tensor. This is more efficient than two separate calls to drjit.sinh() and drjit.cosh() when both are required. The function uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic sine and cosine of the input
- Return type:
(float, float) | (drjit.ArrayBase, drjit.ArrayBase)
- drjit.tanh(arg: ArrayT, /) ArrayT ¶
- drjit.tanh(arg: float, /) float
Evaluate the hyperbolic tangent function.
This function evaluates the component-wise hyperbolic tangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic tangent of the input
- Return type:
float | drjit.ArrayBase
- drjit.asinh(arg: ArrayT, /) ArrayT ¶
- drjit.asinh(arg: float, /) float
Evaluate the hyperbolic arcsine function.
This function evaluates the component-wise hyperbolic arcsine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic arcsine of the input
- Return type:
float | drjit.ArrayBase
- drjit.acosh(arg: ArrayT, /) ArrayT ¶
- drjit.acosh(arg: float, /) float
Evaluate the hyperbolic arccosine function.
This function evaluates the component-wise hyperbolic arccosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic arccosine of the input
- Return type:
float | drjit.ArrayBase
- drjit.atanh(arg: ArrayT, /) ArrayT ¶
- drjit.atanh(arg: float, /) float
Evaluate the hyperbolic arctangent function.
This function evaluates the component-wise hyperbolic arctangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic arctangent of the input
- Return type:
float | drjit.ArrayBase
Exponentials, logarithms, power function¶
- drjit.log2(arg: ArrayT, /) ArrayT ¶
- drjit.log2(arg: float, /) float
Evaluate the base-2 logarithm.
This function evaluates the component-wise base-2 logarithm of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Base-2 logarithm of the input
- Return type:
float | drjit.ArrayBase
- drjit.log(arg: ArrayT, /) ArrayT ¶
- drjit.log(arg: float, /) float
Evaluate the natural logarithm.
This function evaluates the component-wise natural logarithm of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Natural logarithm of the input
- Return type:
float | drjit.ArrayBase
- drjit.exp2(arg: ArrayT, /) ArrayT ¶
- drjit.exp2(arg: float, /) float
Evaluate 2 raised to a given power.
This function evaluates the component-wise base-2 exponential function of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Base-2 exponential of the input
- Return type:
float | drjit.ArrayBase
- drjit.exp(arg: ArrayT, /) ArrayT ¶
- drjit.exp(arg: float, /) float
Evaluate the natural exponential function.
This function evaluates the component-wise natural exponential function of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When arg is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Natural exponential of the input
- Return type:
float | drjit.ArrayBase
- drjit.power(arg0: int, arg1: int, /) float ¶
- drjit.power(arg0: float, arg1: float, /) float
- drjit.power(arg0: object, arg1: object, /) object
Raise the first argument to a power specified via the second argument.
This function evaluates the component-wise power of the input scalar, array, or tensor arguments. When called with complex or quaternion-valued inputs, it uses a suitable generalization of the operation.
When arg1 is a Python int or integral float value, the function reduces the operation to a sequence of multiplies and adds (potentially followed by a reciprocation operation when arg1 is negative).
The general case involves recursive use of the identity pow(arg0, arg1) = exp2(log2(arg0) * arg1).
There is no difference between using drjit.power() and the builtin Python ** operator.
- Parameters:
arg0 (object) – A Python or Dr.Jit arithmetic type
arg1 (object) – A Python or Dr.Jit arithmetic type
- Returns:
The result of the operation arg0**arg1
- Return type:
object
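The two strategies described above can be sketched for scalars in plain Python. The integer path below uses square-and-multiply, which is one possible way to reduce an integral power to a short sequence of multiplications (the exact sequence Dr.Jit generates is not specified here), and the general path applies the stated identity directly. Both helper names are hypothetical.

```python
import math


def power_int_sketch(base: float, exponent: int) -> float:
    # Square-and-multiply over the bits of |exponent|.
    n = abs(exponent)
    result, factor = 1.0, base
    while n:
        if n & 1:
            result *= factor
        factor *= factor
        n >>= 1
    # Negative exponents require a final reciprocal.
    return 1.0 / result if exponent < 0 else result


def power_general_sketch(base: float, exponent: float) -> float:
    # pow(a, b) = exp2(log2(a) * b); only valid for base > 0 in this sketch.
    return 2.0 ** (math.log2(base) * exponent)


print(power_int_sketch(3.0, 4), power_int_sketch(2.0, -3))
print(power_general_sketch(9.0, 0.5))
```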
Other¶
- drjit.erf(arg: ArrayT, /) ArrayT ¶
- drjit.erf(arg: float, /) float
Evaluate the error function.
The error function (https://en.wikipedia.org/wiki/Error_function) is defined as
\[\operatorname{erf}(z) = \frac{2}{\sqrt\pi}\int_0^z e^{-t^2}\,\mathrm{d}t.\]
See the section on transcendental function approximations for details regarding accuracy.
This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
\(\mathrm{erf}(\texttt{arg})\)
- Return type:
float | drjit.ArrayBase
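The integral definition above can be checked numerically for real inputs. The sketch below applies a simple midpoint rule (a hypothetical helper, unrelated to the approximation that Dr.Jit actually uses) and compares the result against Python's `math.erf`.

```python
import math


def erf_midpoint(z: float, steps: int = 1000) -> float:
    # Midpoint-rule approximation of 2/sqrt(pi) * integral_0^z exp(-t^2) dt.
    h = z / steps
    total = sum(math.exp(-((i + 0.5) * h) ** 2) for i in range(steps))
    return 2.0 / math.sqrt(math.pi) * total * h


for z in (0.0, 0.5, 1.0, -1.0):
    print(z, erf_midpoint(z), math.erf(z))
```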
- drjit.erfinv(arg: ArrayT, /) ArrayT ¶
- drjit.erfinv(arg: float, /) float
Evaluate the inverse error function.
This function evaluates the inverse of drjit.erf(). Its implementation is based on the paper Approximating the erfinv function by Mike Giles.
This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
\(\mathrm{erf}^{-1}(\texttt{arg})\)
- Return type:
float | drjit.ArrayBase
- drjit.lgamma(arg: ArrayT, /) ArrayT ¶
- drjit.lgamma(arg: float, /) float
Evaluate the natural logarithm of the absolute value of the gamma function.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
\(\log|\Gamma(\texttt{arg})|\)
- Return type:
float | drjit.ArrayBase
- drjit.rad2deg(arg: T, /) T ¶
Convert angles from radians to degrees.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
The equivalent angle in degrees.
- Return type:
float | drjit.ArrayBase
- drjit.deg2rad(arg: T, /) T ¶
Convert angles from degrees to radians.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
The equivalent angle in radians.
- Return type:
float | drjit.ArrayBase
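Both conversions amount to a multiplication by a constant factor. A minimal scalar sketch in plain Python (hypothetical helper names, not Dr.Jit code):

```python
import math


def rad2deg_sketch(x: float) -> float:
    # One full turn is 2*pi radians or 360 degrees.
    return x * (180.0 / math.pi)


def deg2rad_sketch(x: float) -> float:
    return x * (math.pi / 180.0)


print(rad2deg_sketch(math.pi), deg2rad_sketch(90.0))
```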
Safe mathematical functions¶
Dr.Jit provides “safe” variants of a few standard mathematical operations that are prone to out-of-domain errors in calculations with floating point rounding errors. Such errors could, e.g., cause the argument of a square root to become negative, which would ordinarily require complex arithmetic. At zero, the derivative of the square root function is infinite. The following operations clamp the input to a safe range to avoid these extremes.
- drjit.safe_sqrt(arg: T, /) T ¶
Safely evaluate the square root of the provided input avoiding domain errors.
Negative inputs produce zero-valued output. When differentiated via AD, this function also avoids generating infinite derivatives at x=0.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Square root of the input
- Return type:
float | drjit.ArrayBase
- drjit.safe_asin(arg: T, /) T ¶
Safe wrapper around drjit.asin() that avoids domain errors.
Input values are clipped to the \([-1, 1]\) domain. When differentiated via AD, this function also avoids generating infinite derivatives at the boundaries of the domain.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arcsine approximation
- Return type:
float | drjit.ArrayBase
- drjit.safe_acos(arg: T, /) T ¶
Safe wrapper around drjit.acos() that avoids domain errors.
Input values are clipped to the \([-1, 1]\) domain. When differentiated via AD, this function also avoids generating infinite derivatives at the boundaries of the domain.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arccosine approximation
- Return type:
float | drjit.ArrayBase
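The clamping behavior of the safe variants can be sketched for scalars in plain Python. The helpers below are hypothetical and only model the primal (value) computation: they clamp the input into the valid domain before calling the standard function, and they do not reproduce the derivative-related safeguards described above.

```python
import math


def safe_sqrt_sketch(x: float) -> float:
    # Negative inputs (e.g. from round-off) are clamped to zero.
    return math.sqrt(max(x, 0.0))


def safe_asin_sketch(x: float) -> float:
    # Clamp into [-1, 1] before evaluating the arcsine.
    return math.asin(min(max(x, -1.0), 1.0))


def safe_acos_sketch(x: float) -> float:
    # Clamp into [-1, 1] before evaluating the arccosine.
    return math.acos(min(max(x, -1.0), 1.0))


print(safe_sqrt_sketch(-1e-9), safe_asin_sketch(1.0000001),
      safe_acos_sketch(1.0000001))
```

Without the clamp, Python's `math.sqrt(-1e-9)` and `math.asin(1.0000001)` raise a `ValueError`, which plays the role of the NaN outputs mentioned in this reference.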
Automatic differentiation¶
- enum drjit.ADMode(value)¶
Enumeration to distinguish different types of primal/derivative computation.
See also drjit.enqueue(), drjit.traverse().
Valid values are as follows:
- Primal = ADMode.Primal¶
Primal/original computation without derivative tracking. Note that this is not a valid input to Dr.Jit AD routines, but it is sometimes useful to have this entry to indicate to a computation that derivative propagation should not be performed.
- Forward = ADMode.Forward¶
Propagate derivatives in forward mode (from inputs to outputs)
- Backward = ADMode.Backward¶
Propagate derivatives in backward/reverse mode (from outputs to inputs)
- enum drjit.ADFlag(value)¶
By default, Dr.Jit’s AD system destructs the enqueued input graph during forward/backward mode traversal. This frees up resources, which is useful when working with large wavefronts or very complex computation graphs. However, this also prevents repeated propagation of gradients through a shared subgraph that is being differentiated multiple times.
To support more fine-grained use cases that require this, the following flags can be used to control what should and should not be destructed.
- Member Type:
int
Valid values are as follows:
- ClearNone = ADFlag.ClearNone¶
Clear nothing.
- ClearEdges = ADFlag.ClearEdges¶
Delete all traversed edges from the computation graph
- ClearInput = ADFlag.ClearInput¶
Clear the gradients of processed input vertices (in-degree == 0)
- ClearInterior = ADFlag.ClearInterior¶
Clear the gradients of processed interior vertices (out-degree != 0)
- ClearVertices = ADFlag.ClearVertices¶
Clear gradients of processed vertices only, but leave edges intact. Equal to ClearInput | ClearInterior.
- AllowNoGrad = ADFlag.AllowNoGrad¶
Don’t fail when the input to a drjit.forward or backward operation is not a differentiable array.
- Default = ADFlag.Default¶
Default: clear everything (edges, gradients of processed vertices). Equal to ClearEdges | ClearVertices.
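The relationships between these flags can be modeled with Python's `enum.IntFlag`. In the sketch below, the class name and the numeric bit values are hypothetical placeholders chosen for illustration; only the compositions `ClearVertices = ClearInput | ClearInterior` and `Default = ClearEdges | ClearVertices` are taken from the descriptions above.

```python
import enum


class ADFlagSketch(enum.IntFlag):
    # Bit values are assumptions; only the relationships are documented.
    ClearNone = 0
    ClearEdges = 1
    ClearInput = 2
    ClearInterior = 4
    ClearVertices = ClearInput | ClearInterior
    AllowNoGrad = 8
    Default = ClearEdges | ClearVertices


# Keep the edges so that a shared subgraph can be traversed again,
# while still clearing the gradients of processed vertices.
keep_edges = ADFlagSketch.Default & ~ADFlagSketch.ClearEdges
print(keep_edges == ADFlagSketch.ClearVertices)
```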
- drjit.detach(arg: T, preserve_type: bool = True) T ¶
Transforms the input variable into its non-differentiable version (detaches it from the AD computational graph).
This function supports arbitrary Dr.Jit arrays/tensors and PyTrees as input. In the latter case, it applies the transformation recursively. When the input variable is not a PyTree or Dr.Jit array, it is returned as it is.
While the type of the returned array is preserved by default, it is possible to set the preserve_type argument to False to force the returned type to be non-differentiable. For example, this will convert an array of type drjit.llvm.ad.Float into one of type drjit.llvm.Float.
- Parameters:
arg (object) – An arbitrary Dr.Jit array, tensor, or PyTree.
preserve_type (bool) – Defines whether the returned variable should preserve the type of the input variable.
- Returns:
The detached variable.
- Return type:
object
- drjit.enable_grad(arg: object, /) None ¶
- drjit.enable_grad(*args) None
Enable gradient tracking for the provided variables.
This function accepts a variable-length positional argument list and processes all input arguments. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During this recursive traversal, the function enables gradient tracking for all encountered Dr.Jit arrays. Variables of other types are ignored.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit arrays/tensors or PyTrees.
- drjit.disable_grad(arg: object, /) None ¶
- drjit.disable_grad(*args) None
Disable gradient tracking for the provided variables.
This function accepts a variable-length positional argument list and processes all input arguments. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During this recursive traversal, the function disables gradient tracking for all encountered Dr.Jit arrays. Variables of other types are ignored.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit arrays/tensors or PyTrees.
- drjit.set_grad_enabled(arg0: object, arg1: bool, /) None ¶
Enable or disable gradient tracking on the provided variables.
- Parameters:
arg0 (object) – An arbitrary Dr.Jit array, tensor, PyTree, sequence, or mapping.
arg1 (bool) – Defines whether gradient tracking should be enabled or disabled.
- drjit.grad_enabled(arg: object, /) bool ¶
- drjit.grad_enabled(*args) bool
Return whether gradient tracking is enabled on any of the given variables.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit array/tensor instances or PyTrees. The function recursively traverses them to find all differentiable variables.
- Returns:
True if any of the input variables has gradient tracking enabled, False otherwise.
- Return type:
bool
- drjit.grad(arg: T, preserve_type: bool = True) T ¶
Return the gradient value associated with a given variable.
When the variable doesn’t have gradient tracking enabled, this function returns 0.
- Parameters:
arg (object) – An arbitrary Dr.Jit array, tensor or PyTree.
preserve_type (bool) – Should the operation preserve the input type in the return value? (This is the default). Otherwise, Dr.Jit will, e.g., return a type of drjit.cuda.Float for an input of type drjit.cuda.ad.Float.
- Returns:
the gradient value associated to the input variable.
- Return type:
object
- drjit.set_grad(target: T, source: T) None ¶
Set the gradient associated with the provided variable.
This operation internally decomposes into two sub-steps:
    dr.clear_grad(target)
    dr.accum_grad(target, source)
When source is not of the same type as target, Dr.Jit will try to broadcast its contents into the right shape.
- drjit.accum_grad(target: T, source: T) None ¶
Accumulate the contents of one variable into the gradient of another variable.
When source is not of the same type as target, Dr.Jit will try to broadcast its contents into the right shape.
- drjit.replace_grad(arg0: T, arg1: T, /) None ¶
Replace the gradient value of arg0 with the one of arg1.
This is a relatively specialized operation to be used with care when implementing advanced automatic differentiation-related features.
One example use would be to inform Dr.Jit that there is a better way to compute the gradient of a particular expression than what the normal AD traversal of the primal computation graph would yield.
The function promotes and broadcasts arg0 and arg1 if they are not of the same type.
- drjit.clear_grad(arg: object, /) None ¶
Clear the gradient of the given variable.
- Parameters:
arg (object) – An arbitrary Dr.Jit array, tensor, or PyTree.
- drjit.traverse(mode: drjit.ADMode, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None ¶
Propagate gradients along the enqueued set of AD graph edges.
Given prior use of drjit.enqueue() to enqueue AD nodes for gradient propagation, this function performs the actual gradient propagation in either the forward or reverse direction (as specified by the mode parameter).
By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. The term leaf node is defined as follows:
In forward AD, leaf nodes have no forward edges. They are outputs of a computation, and no other differentiable variable depends on them.
In backward AD, leaf nodes have no backward edges. They are inputs to a computation.
By default, the traversal also removes the edges of visited nodes to isolate them. These defaults are usually good ones: cleaning up the graph frees up resources, which is useful when working with large wavefronts or very complex computation graphs. It also avoids potentially undesired derivative contributions that can arise when the AD graphs of two unrelated computations are connected by an edge and subsequently separately differentiated.
In advanced applications that require multiple AD traversals of the same graph, specify different combinations of the enumeration drjit.ADFlag via the flags parameter.
- Parameters:
mode (drjit.ADMode) – Specifies the direction in which gradients should be propagated. drjit.ADMode.Forward and drjit.ADMode.Backward refer to forward and backward traversal.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph are cleared during traversal. The default value is drjit.ADFlag.Default.
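The destructive default can be pictured with a toy reverse-mode traversal in plain Python. This is a sketch for intuition only and not Dr.Jit's implementation: the Node class and the backward_traverse() function are hypothetical names, and the graph encodes y = 2x and z = y² evaluated at x = 3.

```python
class Node:
    """Toy AD node: 'inputs' holds backward edges as (node, local derivative)."""

    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)
        self.grad = 0.0


def backward_traverse(output, destructive=True):
    """Propagate gradients from 'output' towards the inputs of the graph."""
    order, seen = [], set()

    def visit(node):
        # Depth-first post-order: every input is listed before its dependents
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.inputs:
            visit(parent)
        order.append(node)

    visit(output)

    for node in reversed(order):  # dependents first, inputs last
        is_leaf = not node.inputs  # backward AD: leaves have no backward edges
        for parent, weight in node.inputs:
            parent.grad += weight * node.grad
        if destructive and not is_leaf:
            node.grad = 0.0       # clear the gradient of the interior node
            node.inputs.clear()   # remove its edges to isolate it


x = Node("x")
y = Node("y", [(x, 2.0)])    # y = 2x     -> dy/dx = 2
z = Node("z", [(y, 12.0)])   # z = y*y    -> dz/dy = 2y = 12 at y = 6
z.grad = 1.0                 # seed the output gradient

backward_traverse(z)
print(x.grad, y.grad, z.grad)
```

After the traversal, only the leaf node x retains a gradient; the visited non-leaf nodes y and z have been cleared and disconnected, which mirrors the default behavior described above.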
- drjit.enqueue(mode: drjit.ADMode, arg: object) None ¶
- drjit.enqueue(mode: drjit.ADMode, *args) None
Enqueues the input variable(s) for subsequent gradient propagation.
Dr.Jit splits the process of automatic differentiation into three parts:
1. Initializing the gradient of one or more input or output variables. The most common initialization entails setting the gradient of an output (e.g., an optimization loss) to 1.0.
2. Enqueuing nodes that should partake in the gradient propagation pass. Dr.Jit will follow variable dependences (edges in the AD graph) to find variables that are reachable from the enqueued variable.
3. Finally propagating gradients to all of the enqueued variables.
This function is responsible for step 2 of the above list and works differently depending on the specified mode:
- drjit.ADMode.Forward: Dr.Jit will recursively enqueue all variables that are reachable along forward edges. That is, given a differentiable operation a = b+c, enqueuing c will also enqueue a for later traversal.
- drjit.ADMode.Backward: Dr.Jit will recursively enqueue all variables that are reachable along backward edges. That is, given a differentiable operation a = b+c, enqueuing a will also enqueue b and c for later traversal.
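The two enqueuing modes amount to a reachability search along forward or backward edges. The following plain-Python sketch reproduces the a = b+c example; the edge table and the reachable() helper are hypothetical and only illustrate the idea, they are not part of Dr.Jit.

```python
from collections import deque

# Toy AD graph for a = b + c. A forward edge points from an operand
# to the result that depends on it.
FORWARD_EDGES = {"b": ["a"], "c": ["a"], "a": []}


def reversed_edges(edges):
    """Derive backward edges (result -> operands) from forward edges."""
    backward = {node: [] for node in edges}
    for src, dsts in edges.items():
        for dst in dsts:
            backward[dst].append(src)
    return backward


def reachable(mode, start, forward_edges=FORWARD_EDGES):
    """Return all nodes reachable from 'start' along edges of the given mode."""
    if mode == "forward":
        edges = forward_edges
    else:
        edges = reversed_edges(forward_edges)
    seen, todo = {start}, deque([start])
    while todo:
        node = todo.popleft()
        for nxt in edges[node]:
            if nxt not in seen:
                seen.add(nxt)
                todo.append(nxt)
    return seen


print(sorted(reachable("forward", "c")))   # ['a', 'c']
print(sorted(reachable("backward", "a")))  # ['a', 'b', 'c']
```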
For example, a typical chain of operations to forward propagate the gradients from a to b might look as follows:
    a = dr.llvm.ad.Float(1.0)
    dr.enable_grad(a)
    b = f(a) # some computation involving 'a'

    # The below three operations can also be written more compactly as dr.forward_from(a)
    dr.set_grad(a, 1.0)
    dr.enqueue(dr.ADMode.Forward, a)
    dr.traverse(dr.ADMode.Forward)

    grad = dr.grad(b)
One interesting aspect of this design is that enqueuing and traversal don’t necessarily need to follow the same direction.
For example, we may only be interested in forward gradients reaching a specific output node c, which can be expressed as follows:
    a = dr.llvm.ad.Float(1.0)
    dr.enable_grad(a)
    b, c, d, e = f(a)

    dr.set_grad(a, 1.0)
    dr.enqueue(dr.ADMode.Backward, c)
    dr.traverse(dr.ADMode.Forward)

    grad = dr.grad(c)
The same naturally also works in the reverse direction. Dr.Jit provides a higher-level API that encapsulates such logic in a few different flavors:
- drjit.forward_from() (alias: drjit.forward()) and drjit.forward_to().
- drjit.backward_from() (alias: drjit.backward()) and drjit.backward_to().
- Parameters:
mode (drjit.ADMode) – Specifies the set of edges that Dr.Jit should follow to enqueue variables to be visited by a later gradient propagation phase. drjit.ADMode.Forward and drjit.ADMode.Backward refer to forward and backward edges, respectively.
value (object) – An arbitrary Dr.Jit array, tensor or PyTree.
- drjit.forward_from(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None ¶
Forward-propagate gradients from the provided Dr.Jit array or tensor.
This function sets the gradient of the provided Dr.Jit array or tensor arg to 1.0 and then forward-propagates derivatives through forward-connected components of the computation graph (i.e., reaching all variables that directly or indirectly depend on arg).
The operation is equivalent to
    dr.set_grad(arg, 1.0)
    dr.enqueue(dr.ADMode.Forward, arg)
    dr.traverse(dr.ADMode.Forward, flags=flags)
Refer to the documentation of the functions drjit.set_grad(), drjit.enqueue(), and drjit.traverse() for further details on the nuances of forward derivative propagation.
By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.traverse() and the meaning of the flags parameter.
When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. Specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad) to the function to suppress this exception.
- Parameters:
args (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is
drjit.ADFlag.Default
.
- drjit.forward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) ArrayT ¶
- drjit.forward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) tuple[*Ts]
Forward-propagate gradients to the provided set of Dr.Jit arrays/tensors.
The operation is equivalent to
    dr.enqueue(dr.ADMode.Backward, *args)
    dr.traverse(dr.ADMode.Forward, flags=flags)
    return dr.grad(*args)
Internally, the operation first traverses the computation graph backwards from args to find potential paths along which gradients can flow to the given set of arrays. Then, it performs a gradient propagation pass along the detected variables.
For this to work, you must have previously enabled gradient tracking and specified gradients for the inputs of the computation (see drjit.enable_grad() and drjit.set_grad()).
Refer to the documentation of the functions drjit.enqueue() and drjit.traverse() for further details on the nuances of forward derivative propagation.
By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.traverse() and the meaning of the flags parameter.
When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. Specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad) to the function to suppress this exception.
- Parameters:
*args (tuple) – A variable-length list of differentiable Dr.Jit arrays, tensors, or PyTrees.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is
drjit.ADFlag.Default
.
- Returns:
The gradient value(s) associated with *args following the traversal.
- Return type:
object
- drjit.forward(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None ¶
Forward-propagate gradients from the provided Dr.Jit array or tensor.
This function is an alias of drjit.forward_from(). Please refer to the documentation of this function.
- Parameters:
args (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph are cleared during traversal. The default value is
drjit.ADFlag.Default
.
- drjit.backward_from(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None ¶
Backpropagate gradients from the provided Dr.Jit array or tensor.
This function sets the gradient of the provided Dr.Jit array or tensor arg to 1.0 and then backpropagates derivatives through backward-connected components of the computation graph (i.e., reaching differentiable variables that potentially influence the value of arg).
The operation is equivalent to
    dr.set_grad(arg, 1.0)
    dr.enqueue(dr.ADMode.Backward, arg)
    dr.traverse(dr.ADMode.Backward, flags=flags)
Refer to the documentation of the functions drjit.set_grad(), drjit.enqueue(), and drjit.traverse() for further details on the nuances of derivative backpropagation.
By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.traverse() and the meaning of the flags parameter.
When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. Specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad) to the function to suppress this exception.
- Parameters:
args (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is
drjit.ADFlag.Default
.
- drjit.backward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) ArrayT ¶
- drjit.backward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) tuple[*Ts]
Backpropagate gradients to the provided set of Dr.Jit arrays/tensors.
The operation is equivalent to
    dr.enqueue(dr.ADMode.Forward, *args)
    dr.traverse(dr.ADMode.Backward, flags=flags)
    return dr.grad(*args)
Internally, the operation first traverses the computation graph forwards from args to find potential paths along which reverse-mode gradients can flow to the given set of input variables. Then, it performs a backpropagation pass along the detected variables.
For this to work, you must have previously enabled gradient tracking and specified gradients for the outputs of the computation (see drjit.enable_grad() and drjit.set_grad()).
Refer to the documentation of the functions drjit.enqueue() and drjit.traverse() for further details on the nuances of derivative backpropagation.
By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of drjit.traverse() and the meaning of the flags parameter.
When drjit.JitFlag.SymbolicCalls is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad(), as this generally indicates the presence of a bug. Specify the drjit.ADFlag.AllowNoGrad flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad) to the function to suppress this exception.
- Parameters:
*args (tuple) – A variable-length list of differentiable Dr.Jit arrays, tensors, or PyTrees.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is
drjit.ADFlag.Default
.
- Returns:
The gradient value(s) associated with *args following the traversal.
- Return type:
object
- drjit.backward(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None ¶
Backpropagate gradients from the provided Dr.Jit array or tensor.
This function is an alias of drjit.backward_from(). Please refer to the documentation of this function.
- Parameters:
args (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is
drjit.ADFlag.Default
.
- drjit.suspend_grad(*args, when=True)¶
Python context manager to temporarily disable gradient tracking globally, or for a specific set of variables.
This context manager can be used as follows to completely disable all gradient tracking. Newly created variables will be detached from Dr.Jit’s AD graph.
    with dr.suspend_grad():
        # .. code goes here ..
You may also specify any number of Dr.Jit arrays, tensors, or PyTrees. In this case, the context manager behaves differently by disabling gradient tracking more selectively for the specified variables.
    with dr.suspend_grad(x):
        z = x + y  # 'z' will not track any gradients arising from 'x'
The suspend_grad() and resume_grad() context managers can be arbitrarily nested and suitably update the set of tracked variables.
A note about the interaction with drjit.enable_grad(): it is legal to register further AD variables within a scope that disables gradient tracking for specific variables.
    with dr.suspend_grad(x):
        y = Float(1)
        dr.enable_grad(y)

        # The following condition holds
        assert not dr.grad_enabled(x) and dr.grad_enabled(y)
In contrast, a suspend_grad() environment without arguments that completely disables AD does not allow further variables to be registered:
    with dr.suspend_grad():
        y = Float(1)
        dr.enable_grad(y) # ignored

        # The following condition holds
        assert not dr.grad_enabled(x) and not dr.grad_enabled(y)
- Parameters:
*args (tuple) – Arbitrary list of Dr.Jit arrays, tuples, or PyTrees. Elements of data structures that could not possibly be attached to the AD graph (e.g., Python scalars) are ignored.
when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via when=False. The default value is when=True.
- drjit.resume_grad(*args, when=True)¶
Python context manager to temporarily resume gradient tracking globally, or for a specific set of variables.
This context manager can be used as follows to fully re-enable all gradient tracking following a previous call to drjit.suspend_grad(). Newly created variables will then again be attached to Dr.Jit’s AD graph.
    with dr.suspend_grad():
        # ..

        with dr.resume_grad():
            # In this scope, the effect of the outer context
            # manager is effectively disabled
You may also specify any number of Dr.Jit arrays, tensors, or PyTrees. In this case, the context manager behaves differently by enabling gradient tracking more selectively for the specified variables.
    with dr.suspend_grad():
        with dr.resume_grad(x):
            z = x + y  # 'z' will only track gradients arising from 'x'
The suspend_grad() and resume_grad() context managers can be arbitrarily nested and suitably update the set of tracked variables.
- Parameters:
*args (tuple) – Arbitrary list of Dr.Jit arrays, tuples, or PyTrees. Elements of data structures that could not possibly be attached to the AD graph (e.g., Python scalars) are ignored.
when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via when=False. The default value is when=True.
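The nesting behavior and the when switch of the two context managers can be modeled with a few lines of plain Python. The sketch below uses a single global flag as a stand-in for "gradient tracking is active"; the names suspend(), resume(), and tracking_enabled() are hypothetical, and the per-variable form (with arguments) is not modeled.

```python
from contextlib import contextmanager

_enabled = True  # toy stand-in for a global gradient tracking switch


@contextmanager
def _set_tracking(value, when):
    """Temporarily set the switch; leave it untouched when 'when' is False."""
    global _enabled
    previous = _enabled
    if when:
        _enabled = value
    try:
        yield
    finally:
        _enabled = previous  # restore the state of the enclosing scope


def suspend(when=True):
    return _set_tracking(False, when)


def resume(when=True):
    return _set_tracking(True, when)


def tracking_enabled():
    return _enabled


states = []
with suspend():
    states.append(tracking_enabled())      # suspended
    with resume():
        states.append(tracking_enabled())  # re-enabled in the inner scope
    states.append(tracking_enabled())      # suspended again
with suspend(when=False):
    states.append(tracking_enabled())      # no-op: tracking stays active

print(states)
```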
- drjit.isolate_grad(when=True)¶
Python context manager to isolate and partition AD traversals into multiple distinct phases.
Consider a sequence of steps being differentiated in reverse mode, like so:
    x = ..
    dr.enable_grad(x)
    y = f(x)
    z = g(y)
    dr.backward(z)
The drjit.backward() call would automatically traverse the AD graph nodes created during the execution of the functions f() and g().
However, sometimes this is undesirable and more control is needed. For example, Dr.Jit may be in an execution context (a symbolic loop or call) that temporarily disallows differentiation of the f() part. The drjit.isolate_grad() context manager addresses this need:
    dr.enable_grad(x)
    y = f(x)
    with dr.isolate_grad():
        z = g(y)
        dr.backward(z)
Any reverse-mode AD traversal of an edge that crosses the isolation boundary is postponed until leaving the scope. This is mathematically equivalent but produces two smaller separate AD graph traversals.
Dr.Jit operations like symbolic loops and calls internally create such an isolation boundary, hence it is rare that you would need to do so yourself. isolate_grad() is not useful for forward mode AD.
- Parameters:
when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via when=False. The default value is when=True.
- class drjit.CustomOp¶
Base class for implementing custom differentiable operations.
Dr.Jit can compute derivatives of builtin operations in both forward and reverse mode. In some cases, it may be useful or even necessary to control how a particular operation should be differentiated.
To do so, you may extend this class to provide three callback functions:
- CustomOp.eval(): Implements the primal evaluation of the function with detached inputs.
- CustomOp.forward(): Implements the forward derivative that propagates derivatives from input arguments to the return value.
- CustomOp.backward(): Implements the backward derivative that propagates derivatives from the return value to the input arguments.
An example for a hypothetical custom addition operation is shown below:
    class Addition(dr.CustomOp):
        def eval(self, x, y):
            # Primal calculation without derivative tracking
            return x + y

        def forward(self):
            # Compute forward derivatives
            self.set_grad_out(self.grad_in('x') + self.grad_in('y'))

        def backward(self):
            # .. compute backward derivatives ..
            self.set_grad_in('x', self.grad_out())
            self.set_grad_in('y', self.grad_out())

        def name(self):
            # Optional: a descriptive name shown in GraphViz visualizations
            return "Addition"
You should never need to call these functions yourself; Dr.Jit will do so when appropriate. To weave such a custom operation into the AD graph, use the drjit.custom() function, which expects a subclass of drjit.CustomOp as first argument, followed by arguments to the actual operation that are directly forwarded to the .eval() callback.
    # Add two numbers 'x' and 'y'. Calls our '.eval()' callback with detached arguments
    result = dr.custom(Addition, x, y)
Forward or backward derivatives are then automatically handled through the standard operations. For example,
    dr.backward(result)
will invoke the .backward() callback from above.
Many derivatives are more complex than the above examples and require access to inputs or intermediate steps of the primal evaluation routine. You can simply stash them in the instance (self.field = ...), which is shown below for a differentiable multiplication operation that implements the product rule:
    class Multiplication(dr.CustomOp):
        def eval(self, x, y):
            # Stash input arguments
            self.x = x
            self.y = y
            return x * y

        def forward(self):
            self.set_grad_out(self.y * self.grad_in('x') + self.x * self.grad_in('y'))

        def backward(self):
            self.set_grad_in('x', self.y * self.grad_out())
            self.set_grad_in('y', self.x * self.grad_out())

        def name(self):
            return "Multiplication"
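When writing custom derivatives, it can help to check the formulas against finite differences before wiring them into a CustomOp. The sketch below does this for the product rule in plain Python, without Dr.Jit; the three helper functions are hypothetical stand-ins that mirror the arithmetic of the eval(), forward(), and backward() callbacks.

```python
def product(x, y):
    """Primal evaluation, mirroring the arithmetic of eval()."""
    return x * y


def product_forward(x, y, grad_x, grad_y):
    """Forward derivative: input tangents -> output tangent."""
    return y * grad_x + x * grad_y


def product_backward(x, y, grad_out):
    """Backward derivative: output gradient -> input gradients."""
    return y * grad_out, x * grad_out


x, y, h = 3.0, 4.0, 1e-6

# Central finite differences approximate the two partial derivatives
fd_x = (product(x + h, y) - product(x - h, y)) / (2 * h)
fd_y = (product(x, y + h) - product(x, y - h)) / (2 * h)

# Reverse mode with a unit output gradient yields both partials at once
grad_x, grad_y = product_backward(x, y, 1.0)
assert abs(grad_x - fd_x) < 1e-6
assert abs(grad_y - fd_y) < 1e-6

# Forward mode with a unit tangent on 'x' recovers the partial w.r.t. 'x'
assert abs(product_forward(x, y, 1.0, 0.0) - fd_x) < 1e-6
```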
- eval(self, *args, **kwargs) object ¶
Evaluate the custom operation in primal mode.
You must implement this method when subclassing CustomOp, since the default implementation raises an exception. It should realize the original (non-derivative-aware) form of a computation and may take an arbitrary sequence of positional, keyword, and variable-length positional/keyword arguments.
You should not need to call this function yourself; Dr.Jit will automatically do so when performing custom operations through the drjit.custom() interface.
Note that the input arguments passed to .eval() will be detached (i.e. they don’t have derivative tracking enabled). This is intentional, since derivative tracking is handled by the custom operation along with the other callbacks forward() and backward().
- forward(self) None ¶
Evaluate the forward derivative of the custom operation.
You must implement this method when subclassing CustomOp, since the default implementation raises an exception. It takes no arguments and has no return value.
An implementation will generally perform repeated calls to grad_in() to query the gradients of all function inputs, followed by a single call to set_grad_out() to set the gradient of the return value.
For example, this is how one would implement the product rule of the primal calculation x*y, assuming that the .eval() routine stashed the inputs in the custom operation object.
    def forward(self):
        self.set_grad_out(self.y * self.grad_in('x') + self.x * self.grad_in('y'))
- backward(self) None ¶
Evaluate the backward derivative of the custom operation.
You must implement this method when subclassing CustomOp, since the default implementation raises an exception. It takes no arguments and has no return value.
An implementation will generally perform a single call to grad_out() to query the gradient of the function return value, followed by a sequence of calls to set_grad_in() to assign the gradients of the function inputs.
For example, this is how one would implement the product rule of the primal calculation x*y, assuming that the .eval() routine stashed the inputs in the custom operation object.
    def backward(self):
        self.set_grad_in('x', self.y * self.grad_out())
        self.set_grad_in('y', self.x * self.grad_out())
- name(self) str ¶
Return a descriptive name of the CustomOp instance.
Amongst other things, this name is used to document the presence of the custom operation in GraphViz debug output. (See graphviz_ad().)
- grad_out(self) object ¶
Query the gradient of the return value.
Returns an object whose type matches the original return value produced in eval(). This function should only be used within the backward() callback.
- set_grad_out(self, arg: object, /) None ¶
Accumulate a gradient into the return value.
This function should only be used within the forward() callback.
- grad_in(self, arg: object, /) object ¶
Query the gradient of a specified input parameter.
The second argument specifies the parameter name as string. Gradients of variable-length positional arguments (*args) can be queried by providing an integer index instead.
This function should only be used within the forward() callback.
- set_grad_in(self, arg0: object, arg1: object, /) None ¶
Accumulate a gradient into the specified input parameter.
The second argument specifies the parameter name as string. Gradients of variable-length positional arguments (*args) can be assigned by providing an integer index instead.
This function should only be used within the backward() callback.
- add_input(self, arg: object, /) None ¶
Register an implicit input dependency of the operation on an AD variable.
This function should be called by the eval() implementation when an operation has a differentiable dependence on an input that is not an ordinary input argument of the function (e.g., a global program variable or a field of a class).
- add_output(self, arg: object, /) None ¶
Register an implicit output dependency of the operation on an AD variable.
This function should be called by the eval() implementation when an operation has a differentiable dependence on an output that is not part of the function return value (e.g., a global program variable or a field of a class).
- drjit.custom(arg0: type[drjit.CustomOp], /, *args, **kwargs) object ¶
Evaluate a custom differentiable operation.
It can be useful or even necessary to control how a particular operation should be differentiated by Dr.Jit’s automatic differentiation (AD) layer. The drjit.custom() function enables such use cases by stitching an opaque operation with user-defined primal and forward/backward derivative implementations into the AD graph.
The function expects a subclass of the CustomOp interface as first argument. The remaining positional and keyword arguments are forwarded to the CustomOp.eval() callback.
See the documentation of CustomOp for examples on how to realize such a custom operation.
- drjit.wrap(source: str | ModuleType, target: str | ModuleType) Callable[[T], T] ¶
Differentiable bridge between Dr.Jit and other array programming frameworks.
This function wraps computation performed using one array programming framework to expose it in another. Currently, PyTorch and JAX are supported, though other frameworks may be added in the future.
Annotating a function with @drjit.wrap adds code that suitably converts arguments and return values. Furthermore, it stitches the operation into the automatic differentiation (AD) graph of the other framework to ensure correct gradient propagation.
When exposing code written using another framework, the wrapped function can take and return any PyTree including flat or nested Dr.Jit arrays, tensors, and arbitrary nested lists/tuples, dictionaries, and custom data structures. The arguments don’t need to be differentiable. For example, integer/boolean arrays that don’t carry derivative information can be passed as well.
The wrapped function should be pure: in other words, it should read its input(s) and compute an associated output so that re-evaluating the function again produces the same answer. Multi-framework derivative tracking of impure computation will likely not behave as expected.
The following table lists the currently supported conversions:
Direction | Primal | Forward-mode AD | Reverse-mode AD | Remarks
drjit → torch | ✅ | ✅ | ✅ | Everything just works.
torch → drjit | ✅ | ✅ | ✅ | Limitation: The passed/returned PyTrees can contain arbitrary arrays or tensors, but other types (e.g., a custom Python object not understood by PyTorch) will raise errors when differentiating in forward mode (backward mode works fine). An issue was filed on the PyTorch bugtracker.
drjit → jax | ✅ | ✅ | ✅ | You may want to further annotate the wrapped function with jax.jit to trace and just-in-time compile it in the JAX environment, i.e., place @jax.jit directly below @dr.wrap(source='drjit', target='jax'). Limitation: The passed/returned PyTrees can contain arbitrary arrays or Python scalar types, but other types (e.g., a custom Python object not understood by JAX) will raise errors.
jax → drjit | ❌ | ❌ | ❌ | This direction is currently unsupported. We plan to add it in the future.
Please also refer to the documentation sections on multi-framework differentiation and associated caveats.
Note
Types that have no equivalent on the other side (e.g. a quaternion array) will convert to generic tensors.
Data exchange is limited to representations that exist on both sides. There are a few limitations:
- PyTorch lacks support for most unsigned integer types (uint16, uint32, or uint64-typed arrays). Use signed integer types to work around this issue.
- Dr.Jit currently lacks support for most 8- and 16-bit numeric types (besides half precision floats).
- JAX refuses to exchange boolean-valued tensors with other frameworks.
- Parameters:
source (str | module) – The framework used outside of the wrapped function. The argument is currently limited to either 'drjit', 'torch', or 'jax'. For convenience, the associated Python module can be specified as well.
target (str | module) – The framework used inside of the wrapped function. The argument is currently limited to either 'drjit', 'torch', or 'jax'. For convenience, the associated Python module can be specified as well.
- Returns:
The decorated function.
Constants¶
- drjit.e¶
The exponential constant \(e\) represented as a Python
float
.
- drjit.log_two¶
The value \(\log(2)\) represented as a Python
float
.
- drjit.inv_log_two¶
The value \(\frac{1}{\log(2)}\) represented as a Python
float
.
- drjit.pi¶
The value \(\pi\) represented as a Python
float
.
- drjit.inv_pi¶
The value \(\frac{1}{\pi}\) represented as a Python
float
.
- drjit.sqrt_pi¶
The value \(\sqrt{\pi}\) represented as a Python
float
.
- drjit.inv_sqrt_pi¶
The value \(\frac{1}{\sqrt{\pi}}\) represented as a Python
float
.
- drjit.two_pi¶
The value \(2\pi\) represented as a Python
float
.
- drjit.inv_two_pi¶
The value \(\frac{1}{2\pi}\) represented as a Python
float
.
- drjit.sqrt_two_pi¶
The value \(\sqrt{2\pi}\) represented as a Python
float
.
- drjit.inv_sqrt_two_pi¶
The value \(\frac{1}{\sqrt{2\pi}}\) represented as a Python
float
.
- drjit.four_pi¶
The value \(4\pi\) represented as a Python
float
.
- drjit.inv_four_pi¶
The value \(\frac{1}{4\pi}\) represented as a Python
float
.
- drjit.sqrt_four_pi¶
The value \(\sqrt{4\pi}\) represented as a Python
float
.
- drjit.sqrt_two¶
The value \(\sqrt{2}\) represented as a Python
float
.
- drjit.inv_sqrt_two¶
The value \(\frac{1}{\sqrt{2}}\) represented as a Python
float
.
- drjit.inf¶
The value float('inf') represented as a Python float.
- drjit.nan¶
The value float('nan') represented as a Python float.
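For orientation, the constants above can be reproduced with Python's standard math module. This is only a reference sketch: it assumes that \(\log\) denotes the natural logarithm, and the values may differ from Dr.Jit's own constants in the last bit.

```python
import math

# Plain-Python reference values for the constants listed above
constants = {
    "e": math.e,
    "log_two": math.log(2.0),
    "inv_log_two": 1.0 / math.log(2.0),
    "pi": math.pi,
    "inv_pi": 1.0 / math.pi,
    "sqrt_pi": math.sqrt(math.pi),
    "inv_sqrt_pi": 1.0 / math.sqrt(math.pi),
    "two_pi": 2.0 * math.pi,
    "inv_two_pi": 1.0 / (2.0 * math.pi),
    "sqrt_two_pi": math.sqrt(2.0 * math.pi),
    "inv_sqrt_two_pi": 1.0 / math.sqrt(2.0 * math.pi),
    "four_pi": 4.0 * math.pi,
    "inv_four_pi": 1.0 / (4.0 * math.pi),
    "sqrt_four_pi": math.sqrt(4.0 * math.pi),
    "sqrt_two": math.sqrt(2.0),
    "inv_sqrt_two": 1.0 / math.sqrt(2.0),
    "inf": math.inf,
    "nan": math.nan,
}

# Each 'inv_*' entry is the reciprocal of its counterpart
for name, value in constants.items():
    if name.startswith("inv_"):
        assert math.isclose(value * constants[name[4:]], 1.0)

# NaN compares unequal to everything, including itself
assert constants["nan"] != constants["nan"]
```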
- drjit.epsilon(arg, /)¶
Returns the machine epsilon.
The machine epsilon gives an upper bound on the relative approximation error due to rounding in floating point arithmetic.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The machine epsilon.
- Return type:
float
- drjit.one_minus_epsilon(arg, /)¶
Returns one minus the machine epsilon value.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
One minus the machine epsilon.
- Return type:
float
- drjit.recip_overflow(arg, /)¶
Returns the reciprocal overflow threshold value.
Any numbers equal to this threshold or a smaller value will overflow to infinity when reciprocated.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The reciprocal overflow threshold value.
- Return type:
float
- drjit.smallest(arg, /)¶
Returns the smallest representable normalized floating point value.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The smallest representable normalized floating point value.
- Return type:
float
- drjit.largest(arg, /)¶
Returns the largest representable finite floating point value.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The largest representable finite floating point value.
- Return type:
float
Array base class¶
- class drjit.ArrayBase¶
This is the base class of all Dr.Jit arrays and tensors. It provides an abstract version of the array API that becomes usable when the type is extended by a concrete specialization. ArrayBase itself cannot be instantiated.
See the section on Dr.Jit type signatures to learn about the type parameters of ArrayBase.
- property array¶
This member plays multiple roles:
When self is a tensor, this property returns the storage representation of the tensor in the form of a linearized dynamic 1D array.
When self is a special arithmetic object (matrix, quaternion, or complex number), array provides a copy of the same data with ordinary array semantics.
In all other cases, array is simply a reference to self.
- Type:
- property ndim¶
This property represents the dimension of the provided Dr.Jit array or tensor.
- Type:
int
- property shape¶
This property provides a tuple describing dimension and shape of the provided Dr.Jit array or tensor.
When the input array is ragged, the function raises a RuntimeError. The term ragged refers to an array whose components have mismatched sizes, such as [[1, 2], [3, 4, 5]]. Note that scalar entries (e.g. [[1, 2], [3]]) are acceptable, since broadcasting can effectively convert them to any size.
The expressions drjit.shape(arg) and arg.shape are equivalent.
- Type:
tuple[int, …]
- property state¶
This read-only property returns an enumeration value describing the evaluation state of this Dr.Jit array.
- Type:
- property x¶
If self is a static Dr.Jit array of size 1 (or larger), the property self.x can be used synonymously with self[0]. Otherwise, accessing this field will generate a RuntimeError.
- Type:
- property y¶
If self is a static Dr.Jit array of size 2 (or larger), the property self.y can be used synonymously with self[1]. Otherwise, accessing this field will generate a RuntimeError.
- Type:
- property z¶
If self is a static Dr.Jit array of size 3 (or larger), the property self.z can be used synonymously with self[2]. Otherwise, accessing this field will generate a RuntimeError.
- Type:
- property w¶
If self is a static Dr.Jit array of size 4 (or larger), the property self.w can be used synonymously with self[3]. Otherwise, accessing this field will generate a RuntimeError.
- Type:
- property T¶
This property returns the transpose of self. When the underlying array is not a matrix type, it raises a TypeError.
- property index¶
If self is a leaf Dr.Jit array managed by a just-in-time compiled backend (i.e., CUDA or LLVM), this property contains the associated variable index in the graph data structure storing the computation trace. This graph can be visualized using drjit.graphviz(). Otherwise, the value of this property equals zero. A non-leaf array (e.g. drjit.cuda.Array2i) consists of several JIT variables, whose indices must be queried separately.
Note that Dr.Jit maintains two computation traces at the same time: one capturing the raw computation, and a higher-level graph for automatic differentiation (AD). The index index_ad keeps track of the variable index within the AD computation graph, if applicable.
- Type:
int
- property index_ad¶
If self is a leaf Dr.Jit array represented by an AD backend, this property contains the variable index in the graph data structure storing the computation trace for later differentiation (this graph can be visualized using drjit.graphviz_ad()). A non-leaf array (e.g. drjit.cuda.ad.Array2f) consists of several AD variables, whose indices must be queried separately.
Note that Dr.Jit maintains two computation traces at the same time: one capturing the raw computation, and a higher-level graph for automatic differentiation (AD). The index index keeps track of the variable index within the raw computation graph, if applicable.
- Type:
int
- property grad¶
This property can be used to retrieve or set the gradient associated with the Dr.Jit array or tensor.
The expressions drjit.grad(arg) and arg.grad are equivalent when arg is a Dr.Jit array/tensor.
- Type:
- __len__()¶
Return len(self).
- __iter__()¶
Implement iter(self).
- __repr__()¶
Return repr(self).
- __bool__()¶
True if self else False
Casts the array to a Python bool type. This is only permissible when self represents a boolean array of both depth and size 1.
- __add__(value, /)¶
Return self+value.
- __radd__(value, /)¶
Return value+self.
- __iadd__(value, /)¶
Return self+=value.
- __sub__(value, /)¶
Return self-value.
- __rsub__(value, /)¶
Return value-self.
- __isub__(value, /)¶
Return self-=value.
- __mul__(value, /)¶
Return self*value.
- __rmul__(value, /)¶
Return value*self.
- __imul__(value, /)¶
Return self*=value.
- __matmul__(value, /)¶
Return self@value.
- __rmatmul__(value, /)¶
Return value@self.
- __imatmul__(value, /)¶
Return self@=value.
- __truediv__(value, /)¶
Return self/value.
- __rtruediv__(value, /)¶
Return value/self.
- __itruediv__(value, /)¶
Return self/=value.
- __floordiv__(value, /)¶
Return self//value.
- __rfloordiv__(value, /)¶
Return value//self.
- __ifloordiv__(value, /)¶
Return self//=value.
- __mod__(value, /)¶
Return self%value.
- __rmod__(value, /)¶
Return value%self.
- __imod__(value, /)¶
Return self%=value.
- __rshift__(value, /)¶
Return self>>value.
- __rrshift__(value, /)¶
Return value>>self.
- __irshift__(value, /)¶
Return self>>=value.
- __lshift__(value, /)¶
Return self<<value.
- __rlshift__(value, /)¶
Return value<<self.
- __ilshift__(value, /)¶
Return self<<=value.
- __and__(value, /)¶
Return self&value.
- __rand__(value, /)¶
Return value&self.
- __iand__(value, /)¶
Return self&=value.
- __or__(value, /)¶
Return self|value.
- __ror__(value, /)¶
Return value|self.
- __ior__(value, /)¶
Return self|=value.
- __xor__(value, /)¶
Return self^value.
- __rxor__(value, /)¶
Return value^self.
- __ixor__(value, /)¶
Return self^=value.
- __abs__()¶
abs(self)
- __le__(value, /)¶
Return self<=value.
- __lt__(value, /)¶
Return self<value.
- __ge__(value, /)¶
Return self>=value.
- __gt__(value, /)¶
Return self>value.
- __ne__(value, /)¶
Return self!=value.
- __eq__(value, /)¶
Return self==value.
- __dlpack__(self, stream: object | None = None) ndarray[] ¶
Returns a DLPack capsule representing the data in this array.
This operation may potentially perform a copy. For example, nested arrays like drjit.llvm.Array3f or drjit.cuda.Matrix4f need to be rearranged into a contiguous memory representation before they can be exposed.
In other cases, e.g. for drjit.llvm.Float, drjit.scalar.Array3f, or drjit.scalar.ArrayXf, the data is already contiguous and a zero-copy approach is used instead.
- __array__(self, dtype: object | None = None) object ¶
Returns a NumPy array representing the data in this array.
This operation may potentially perform a copy. For example, nested arrays like drjit.llvm.Array3f or drjit.cuda.Matrix4f need to be rearranged into a contiguous memory representation before they can be wrapped.
In other cases, e.g. for drjit.llvm.Float, drjit.scalar.Array3f, or drjit.scalar.ArrayXf, the data is already contiguous and a zero-copy approach is used instead.
- numpy(self) numpy.ndarray[] ¶
Returns a NumPy array representing the data in this array.
This operation may potentially perform a copy. For example, nested arrays like drjit.llvm.Array3f or drjit.cuda.Matrix4f need to be rearranged into a contiguous memory representation before they can be wrapped.
In other cases, e.g. for drjit.llvm.Float, drjit.scalar.Array3f, or drjit.scalar.ArrayXf, the data is already contiguous and a zero-copy approach is used instead.
- torch(self) object ¶
Returns a PyTorch tensor representing the data in this array.
For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.
Warning
This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function
drjit.wrap()
for further information on how to accomplish this.
- jax(self) object ¶
Returns a JAX tensor representing the data in this array.
For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.
Warning
This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function
drjit.wrap()
for further information on how to accomplish this.
- tf(self) object ¶
Returns a TensorFlow tensor representing the data in this array.
For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.
Warning
This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function
drjit.wrap()
for further information on how to accomplish this.
Computation graph analysis¶
The following operations visualize the contents of Dr.Jit’s computation graphs (of which there are two: one for Jit compilation, and one for automatic differentiation).
- drjit.graphviz(as_string: bool = False) object ¶
Return a GraphViz diagram describing registered JIT variables and their connectivity.
This function returns a representation of the computation graph underlying the Dr.Jit just-in-time compiler, which is separate from the automatic differentiation layer. See the graphviz_ad() function to visualize the computation graph of the latter.
Run dr.graphviz().view() to open up a PDF viewer that shows the resulting output in a separate window.
The function depends on the graphviz Python package when as_string=False (the default).
- Parameters:
as_string (bool) – if set to True, the function will return raw GraphViz markup as a string. (Default: False)
- Returns:
GraphViz object or raw markup.
- Return type:
object
- drjit.graphviz_ad(as_string: bool = False) object ¶
Return a GraphViz diagram describing variables registered with the automatic differentiation layer, as well as their connectivity.
This function returns a representation of the computation graph underlying the Dr.Jit AD layer, which sits one architectural layer above the just-in-time compiler. See the graphviz() function to visualize the computation graph of the latter.
Run dr.graphviz_ad().view() to open up a PDF viewer that shows the resulting output in a separate window.
The function depends on the graphviz Python package when as_string=False (the default).
- Parameters:
as_string (bool) – if set to True, the function will return raw GraphViz markup as a string. (Default: False)
- Returns:
GraphViz object or raw markup.
- Return type:
object
- drjit.whos(as_string: bool) str | None ¶
Return/print a list of live JIT variables.
This function provides information about the set of variables that are currently registered with the Dr.Jit just-in-time compiler, which is separate from the automatic differentiation layer. See the whos_ad() function for the latter.
- Parameters:
as_string (bool) – if set to True, the function will return the list in string form. Otherwise, it will print directly onto the console and return None. (Default: False)
- Returns:
a human-readable list (if requested).
- Return type:
None | str
- drjit.whos_ad(as_string: bool) str | None ¶
Return/print a list of live variables registered with the automatic differentiation layer.
This function provides information about the set of variables that are currently registered with the Dr.Jit automatic differentiation layer, which sits one architectural layer above the just-in-time compiler. See the whos() function to obtain information about the latter.
- Parameters:
as_string (bool) – if set to True, the function will return the list in string form. Otherwise, it will print directly onto the console and return None. (Default: False)
- Returns:
a human-readable list (if requested).
- Return type:
None | str
- drjit.set_label(arg0: object, arg1: str, /) None ¶
- drjit.set_label(**kwargs) None
Assign a label to the provided Dr.Jit array.
This can be helpful to identify computation in GraphViz output (see drjit.graphviz(), graphviz_ad()).
The operation assumes that the array is tracked by the just-in-time compiler. It has no effect on unsupported inputs (e.g., arrays from the drjit.scalar package). It recurses through PyTrees (tuples, lists, dictionaries, custom data structures) and appends names (indices, dictionary keys, field names) separated by underscores to uniquely identify each element.
The following **kwargs-based shorthand notation can be used to assign multiple labels at once:
set_label(x=x, y=y)
- Parameters:
*arg (tuple) – a Dr.Jit array instance and its corresponding label str value.
**kwarg (dict) – A set of (keyword, object) pairs.
Debugging¶
- drjit.assert_true(cond, fmt: str | None = None, *args, tb_depth: int = 3, tb_skip: int = 0, **kwargs)¶
Generate an assertion failure message when any of the entries in cond are False.
This function resembles the built-in assert keyword in that it raises an AssertionError when the condition cond is False.
In contrast to the built-in keyword, it also works when cond is an array of boolean values. In this case, the function raises an exception when any entry of cond is False.
The function accepts an optional format string along with positional and keyword arguments, and it processes them like drjit.print(). When only a subset of the entries of cond is False, the function reduces the generated output to only include the associated entries.
>>> x = Float(1, -4, -2, 3)
>>> dr.assert_true(x >= 0, 'Found negative values: {}', x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "drjit/__init__.py", line 1327, in assert_true
    raise AssertionError(msg)
AssertionError: Assertion failure: Found negative values: [-4, -2]
This function also works when some of the function inputs are symbolic. In this case, the check is delayed and potential failures will be reported asynchronously: drjit.assert_true() then generates output on sys.stderr instead of raising an exception, as the original execution context no longer exists at that point.
Assertion checks carry a performance cost, hence they are disabled by default. To enable them, set the JIT flag dr.JitFlag.Debug.
- Parameters:
cond (bool | drjit.ArrayBase) – The condition used to trigger the assertion. This should be a scalar Python boolean or a 1D boolean array.
fmt (str) – An optional format string that will be appended to the error message. It can reference positional or keyword arguments specified via *args and **kwargs.
*args (tuple) – Optional variable-length positional arguments referenced by fmt, see drjit.print() for details on this.
tb_depth (int) – Depth of the backtrace that should be appended to the assertion message. This only applies to cases where some of the inputs are symbolic, and printing of the error message must be delayed.
tb_skip (int) – The first tb_skip entries of the backtrace will be removed. This only applies to cases where some of the inputs are symbolic, and printing of the error message must be delayed. This is helpful when the assertion check is called from a helper function that should not be shown.
**kwargs (dict) – Optional variable-length keyword arguments referenced by fmt, see drjit.print() for details on this.
- drjit.assert_false(cond, fmt: str | None = None, *args, tb_depth: int = 3, tb_skip: int = 0, **kwargs)¶
Equivalent to assert_true() with a flipped condition cond. Please refer to the documentation of this function for further details.
- drjit.assert_equal(arg0, arg1, fmt: str | None = None, *args, limit: int = 3, tb_skip: int = 0, **kwargs)¶
Equivalent to assert_true() with the condition arg0==arg1. Please refer to the documentation of this function for further details.
- drjit.print(fmt: str, *args, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) None ¶
- drjit.print(value: object, /, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) None
Generate a formatted string representation and print it immediately or in a delayed fashion (if any of the inputs are symbolic).
This function combines the behavior of the built-in Python format() and print() functions: it generates a formatted string representation as specified by a format string fmt and then outputs it on the console. The operation fetches referenced positional and keyword arguments and pretty-prints Dr.Jit arrays, tensors, and PyTrees with indentation, field names, etc.
>>> from drjit.cuda import Array3f
>>> dr.print("{}:\n{foo}",
...          "A PyTree containing an array",
...          foo={ 'a' : Array3f(1, 2, 3) })
A PyTree containing an array:
{
  'a': [[1, 2, 3]]
}
The key advantage of drjit.print() compared to the built-in Python print() statement is that it can run asynchronously, which allows it to print symbolic variables without requiring their evaluation. Dr.Jit uses symbolic variables to trace loops (drjit.while_loop()), conditionals (drjit.if_stmt()), and calls (drjit.switch(), drjit.dispatch()). Such symbolic variables represent values that are unknown at trace time, and which cannot be printed using the built-in Python print() function (attempting to do so will raise an exception).
When the print statement does not reference any symbolic arrays or tensors, it will execute immediately. Otherwise, the output will appear after the next drjit.eval() statement, or whenever any subsequent computation is evaluated. Here is an example from an interactive Python session demonstrating printing from a symbolic call performed via drjit.switch():
>>> from drjit.llvm import UInt, Float
>>> def f1(x):
...     dr.print("in f1: {x=}", x=x)
...
>>> def f2(x):
...     dr.print("in f2: {x=}", x=x)
...
>>> dr.switch(index=UInt(0, 0, 0, 1, 1, 1),
...           targets=[f1, f2],
...           x=Float(1, 2, 3, 4, 5, 6))
>>> # No output (yet)
>>> dr.eval()
in f1: x=[1, 2, 3]
in f2: x=[4, 5, 6]
Dynamic arrays with more than 20 entries will be abbreviated. Specify the limit=.. argument to reveal the contents of larger arrays.
>>> dr.print(dr.arange(dr.llvm.Int, 30))
[0, 1, 2, .. 24 skipped .., 27, 28, 29]
>>> dr.print(dr.arange(dr.llvm.Int, 30), limit=30)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,↵
 23, 24, 25, 26, 27, 28, 29]
This function lacks many features of Python’s (rather intricate) format string mini language and f-string interpolation. However, a subset of the functionality is supported:
Positional arguments (in *args) can be referenced implicitly ({}), or using indices ({0}, {1}, etc.). Those conventions should not be mixed. Unreferenced positional arguments will be silently ignored.
Keyword arguments (in **kwargs) can be referenced via their keyword name ({foo}). Unreferenced keywords will be silently ignored.
A trailing = in a brace expression repeats the string within the braces followed by the output:
>>> dr.print('{foo=}', foo=1)
foo=1
When the format string fmt is omitted, it is implicitly set to {}, and the function formats a single positional argument.
The function implicitly appends end to the format string, which is set to a newline by default. The final result is sent to sys.stdout (by default) or file. When a file argument is given, it must implement the method write(arg: str).
A related operation drjit.format() admits the same format string syntax but returns a Python str instead of printing to the console. This operation, however, does not support symbolic inputs. Use drjit.print() with a custom file argument to stringify symbolic inputs asynchronously.
Note
This operation is not suitable for extracting large amounts of data from Dr.Jit kernels, as the conversion to a string representation incurs a nontrivial runtime cost.
Note
Technical details on symbolic printing
When Dr.Jit compiles and executes queued computation on the target device, it includes additional code for symbolic print operations that captures referenced arguments and copies them back to the host (CPU). The information is then printed following the end of that process.
Only a limited amount of memory is set aside to capture the output of symbolic print operations. This is because the amount of data produced within long-running symbolic loops can often exceed the total device memory. Also, printing gigabytes of ASCII text into a Python console or Jupyter notebook is likely not a good idea.
For the electronically inclined, the operation is best thought of as hooking up an oscilloscope to a high-frequency circuit. The oscilloscope provides a limited view into a vast torrent of data to assist the user, who would be overwhelmed if the oscilloscope worked by capturing and showing everything.
The operation warns when the size of the buffers was insufficient. In this case, the output is still printed in the correct order, but chunks of the data are missing. The position of the resulting holes is unspecified and non-deterministic.
>>> dr.print(dr.arange(Float, 10000000), mode='symbolic')
>>> dr.eval()
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
RuntimeWarning: dr.print(): symbolic print statement only captured 20 of 10000000 available outputs. The above is a non-deterministic sample, in which entries are in the right order but not necessarily contiguous. Specify `limit=..` to capture more information and/or add the special format field `{thread_id}` show the thread ID/array index associated with each entry of the captured output.
This is because the (many) parallel threads of the program all try to append their state to the output buffer, but only the first limit (20 by default) succeed. The host subsequently re-sorts the captured data by thread ID. This means that the output [5, 6, 102, 188, 1026, ..] would also be a valid result of the prior command. When a print statement references multiple arrays, the operation either shows all array entries associated with a particular execution thread, or none of them.
To refine what is captured, you can specify the active argument to disable the print statement for a subset of the entries (a “trigger” in the oscilloscope analogy). Printing from an inactive thread within a symbolic loop (drjit.while_loop()), conditional (drjit.if_stmt()), or call (drjit.switch(), drjit.dispatch()) will likewise not generate any output.
A potential gotcha of the current design is that a symbolic print within a symbolic loop counts as one print statement and will only generate a single combined output string. The output of each thread is arranged in one contiguous block. You can add the special format string keyword {thread_id} to reveal the mapping between output values and the execution thread that generated them:
>>> from drjit.llvm import Int
>>> @dr.syntax
... def f(j: Int):
...     i = Int(0)
...     while i < j:
...         dr.print('{thread_id=} {i=}', i=i)
...         i += 1
...
>>> f(Int(2, 3))
>>> dr.eval()
thread_id=[0, 0, 1, 1, 1], i=[0, 1, 0, 1, 2]
The example above runs a symbolic loop twice in parallel: the first thread runs for 2 iterations, and the second runs for 3 iterations. The loop prints the iteration counter i, which then leads to the output [0, 1, 0, 1, 2] where the first two entries are produced by the first thread, and the trailing three belong to the second thread. The thread_id output clarifies this mapping.
- Parameters:
fmt (str) – A format string that potentially references input arguments from *args and **kwargs.
active (drjit.ArrayBase | bool) – A mask argument that can be used to disable a subset of the entries. The print statement will be completely suppressed when there is no output. (default: True).
end (str) – This string will be appended to the format string. It is set to a newline character ("\n") by default.
file (object) – The print operation will eventually invoke file.write(arg:str) to print the formatted string. Specify this argument to route the output somewhere other than the default output stream sys.stdout.
mode (str) – Specify this parameter to override the evaluation mode. Possible values are: "symbolic", "evaluated", or "auto". The default value of "auto" causes the function to use evaluated mode (which prints immediately) unless a symbolic input is detected, in which case printing takes place symbolically (i.e., in a delayed fashion).
limit (int) – The operation will abbreviate dynamic arrays with more than limit (default: 20) entries.
- drjit.format(fmt: str, *args, limit: int = 20, **kwargs)¶
- drjit.format(value: object, *, limit: int = 20, **kwargs) None
Return a formatted string representation.
This function generates a formatted string representation as specified by a format string fmt and then returns it as a Python str object. The operation fetches referenced positional and keyword arguments and pretty-prints Dr.Jit arrays, tensors, and PyTrees with indentation, field names, etc.
>>> from drjit.cuda import Array3f
>>> s = dr.format("{}:\n{foo}",
...               "A PyTree containing an array",
...               foo={ 'a' : Array3f(1, 2, 3) })
>>> print(s)
A PyTree containing an array:
{
  'a': [[1, 2, 3]]
}
Dynamic arrays with more than 20 entries will be abbreviated. Specify the limit=.. argument to reveal the contents of larger arrays.
>>> dr.format(dr.arange(dr.llvm.Int, 30))
[0, 1, 2, .. 24 skipped .., 27, 28, 29]
>>> dr.format(dr.arange(dr.llvm.Int, 30), limit=30)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,↵
 23, 24, 25, 26, 27, 28, 29]
This function lacks many features of Python’s (rather intricate) format string mini language and f-string interpolation. However, a subset of the functionality is supported:
Positional arguments (in *args) can be referenced implicitly ({}), or using indices ({0}, {1}, etc.). Those conventions should not be mixed. Unreferenced positional arguments will be silently ignored.
Keyword arguments (in **kwargs) can be referenced via their keyword name ({foo}). Unreferenced keywords will be silently ignored.
A trailing = in a brace expression repeats the string within the braces followed by the output:
>>> dr.format('{foo=}', foo=1)
foo=1
When the format string fmt is omitted, it is implicitly set to {}, and the function formats a single positional argument.
In contrast to the related drjit.print(), this function does not output the result on the console, and it cannot support symbolic inputs. This is because returning a string right away is incompatible with the requirement of evaluating/formatting symbolic inputs in a delayed fashion. If you wish to format symbolic arrays, you must call drjit.print() with a custom file object that implements the .write() function. Dr.Jit will call this function with the generated string when it is ready.
- Parameters:
fmt (str) – A format string that potentially references input arguments from *args and **kwargs.
limit (int) – The operation will abbreviate dynamic arrays with more than limit (default: 20) entries.
- Returns:
The formatted string representation created as specified above.
- Return type:
str
- drjit.log_level() drjit.LogLevel ¶
- drjit.set_log_level(arg: drjit.LogLevel, /) None ¶
- drjit.set_log_level(arg: int, /) None
Profiling¶
- drjit.profile_mark(arg: str, /) None ¶
Mark an event on the timeline of profiling tools.
Currently, this function uses NVTX to report events that can be captured using NVIDIA Nsight Systems. The operation is a no-op when no profile collection tool is attached.
Note that this event will be recorded on the CPU timeline.
- drjit.profile_range()¶
Context manager to mark a region (e.g. a function call) on the timeline of profiling tools.
You can use this context manager to wrap parts of your code and track when and for how long it runs. Regions can be arbitrarily nested, which profiling tools visualize as a stack.
Note that this function is intended to track activity on the CPU timeline. If the wrapped region launches asynchronous GPU kernels, then those won’t generally be included in the length of the range unless drjit.sync_thread() or some other type of synchronization operation waits for their completion (which is generally not advisable, since keeping CPU and GPU asynchronous with respect to each other improves performance).
Currently, this function uses NVTX to report events that can be captured using NVIDIA Nsight Systems. The operation is a no-op when no profile collection tool is attached.
Low-level bits¶
- drjit.set_backend(arg: Literal['cuda', 'llvm', 'scalar'], /)¶
- drjit.set_backend(arg: drjit.JitBackend, /) None
Adjust the
drjit.auto.*
module so that it refers to types from the specified backend.
- drjit.thread_count() int ¶
Return the number of threads that Dr.Jit uses to parallelize computation on the CPU.
- drjit.set_thread_count(arg: int, /) None ¶
Adjust the number of threads that Dr.Jit uses to parallelize computation on the CPU.
The thread pool is primarily used by Dr.Jit’s LLVM backend. Other projects using the underlying nanothread thread pool library will also be affected by changes performed by this function. It is legal to call it even while parallel computation is ongoing.
- drjit.sync_thread() None ¶
Wait for all currently running computation to finish.
This function synchronizes the device (e.g. the GPU) with the host (CPU) by waiting for the termination of all computation enqueued by the current host thread.
One potential use of this function is to measure the runtime of a kernel launched by Dr.Jit. We instead recommend the use of drjit.kernel_history(), which exposes more accurate device timers.
In general, calling this function in user code is considered bad practice. Dr.Jit programs “run ahead” of the device to keep it fed with work. This is important for performance, and drjit.sync_thread() breaks this optimization.
All operations sent to a device (including reads) are strictly ordered, so there is generally no reason to wait for this queue to empty. If you find that drjit.sync_thread() is needed for your program to run correctly, then you have found a bug. Please report it on the project’s GitHub issue tracker.
- drjit.flush_kernel_cache() None ¶
Release all currently cached kernels.
When Dr.Jit evaluates a previously unseen computation, it compiles a kernel and then maps it into the memory of the CPU or GPU. The kernel stays resident so that it can be immediately reused when that same computation reoccurs at a later point.
In long development sessions (e.g. Jupyter notebook-based prototyping), this cache may eventually become unreasonably large, and calling flush_kernel_cache() to free it may be advisable.
Note that this does not free the disk cache that also exists to preserve compiled programs across sessions. To clear this cache as well, delete the directory $HOME/.drjit on Linux/macOS, and %AppData%\Local\Temp\drjit on Windows. (The AppData folder is typically found in C:\Users\<your username>).
- drjit.flush_malloc_cache() None ¶
Free the memory allocation cache maintained by Dr.Jit.
Allocating and releasing large chunks of memory tends to be relatively expensive, and Dr.Jit programs often need to do so at high rates.
Like most other array programming frameworks, Dr.Jit implements an internal cache to reduce such allocation-related costs. This cache starts out empty and grows on demand. Allocated memory is never released by default, which can be problematic when using multiple array programming frameworks within the same Python session, or when running multiple processes in parallel.
The
drjit.flush_malloc_cache()
function releases all currently unused memory back to the operating system. This is a relatively expensive step: you likely don’t want to use it within a performance-sensitive program region (e.g. an optimization loop).
- drjit.expand_threshold() int ¶
Query the threshold for performing scatter-reductions via expansion.
Getter for the quantity set in
drjit.set_expand_threshold()
- drjit.set_expand_threshold(arg: int, /) None ¶
Set the threshold for performing scatter-reductions via expansion.
The documentation of drjit.ReduceOp explains the cost of atomic scatter-reductions and introduces various optimization strategies.
One particularly effective optimization named drjit.ReduceOp.Expand (see the section on optimizations for plots) is specific to the LLVM backend. It replicates the target array to avoid write conflicts altogether, which enables the use of non-atomic memory operations. This is significantly faster but also very memory-intensive. The storage cost of a 1 MB array targeted by a drjit.scatter_reduce() operation now grows to N megabytes, where N is the number of cores.
For this reason, Dr.Jit implements a user-controllable threshold exposed via the functions drjit.expand_threshold() and drjit.set_expand_threshold(). When the array has more entries than the value specified here, the drjit.ReduceOp.Expand strategy will not be used unless specifically requested via the mode= parameter of operations like drjit.scatter_reduce(), drjit.scatter_add(), and drjit.gather().
The default value of this parameter is 1000000 (1 million entries).
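The trade-off described above can be illustrated in plain Python. This is a conceptual sketch only, not Dr.Jit's implementation: the names scatter_add_expanded and use_expansion are hypothetical, and the round-robin assignment of elements to workers is an assumption made for the illustration.

```python
def scatter_add_expanded(target_size, indices, values, num_workers):
    """Scatter-add via expansion: every worker owns a private target copy."""
    # Storage grows from 1 copy to num_workers copies of the target array
    copies = [[0.0] * target_size for _ in range(num_workers)]
    for k, (i, v) in enumerate(zip(indices, values)):
        w = k % num_workers   # stand-in for "the worker processing element k"
        copies[w][i] += v     # plain update: no other worker touches this copy
    # Final pass: combine the private copies into a single result
    return [sum(c[i] for c in copies) for i in range(target_size)]


def use_expansion(target_size, threshold=1_000_000):
    """Mirror the documented rule: skip expansion above the threshold."""
    return target_size <= threshold


result = scatter_add_expanded(4, [0, 1, 1, 3, 1], [1.0, 2.0, 3.0, 4.0, 5.0], 3)
print(result)
print(use_expansion(500_000), use_expansion(2_000_000))
```

Because each worker writes only to its own copy, no atomic operations are needed. The price is the extra storage plus the final combination pass, which is why large targets fall back to other strategies by default.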
- drjit.kernel_history(types: collections.abc.Sequence[drjit.KernelType] = []) list ¶
Return the history of captured kernel launches.
Dr.Jit can optionally capture performance-related metadata. To do so, set the drjit.JitFlag.KernelHistory flag as follows:
import drjit as dr

with dr.scoped_set_flag(dr.JitFlag.KernelHistory):
    # .. computation to be analyzed ..
    hist = dr.kernel_history()
The drjit.kernel_history() function returns a list of dictionaries characterizing each major operation performed by the analyzed region. Each dictionary has the following entries:
backend: The JIT backend that was used.
execution_time: The time (in microseconds) used by this operation. On the CUDA backend, this value is captured via CUDA events. On the LLVM backend, this involves querying CLOCK_MONOTONIC (Linux/macOS) or QueryPerformanceCounter (Windows).
type: The type of computation expressed by an enumeration value of type drjit.KernelType. The most interesting workloads generated by Dr.Jit are just-in-time compiled kernels, which are identified by drjit.KernelType.JIT.
These have several additional entries:
hash: The hash code identifying the kernel. (The same hash code is also shown when increasing the log level via drjit.set_log_level().)
ir: A capture of the intermediate representation used in this kernel.
operation_count: The number of low-level IR operations. (A rough proxy for the complexity of the operation.)
cache_hit: Was this kernel present in Dr.Jit’s in-memory cache? Otherwise, it was either loaded from disk or had to be recompiled from scratch.
cache_disk: Was this kernel present in Dr.Jit’s on-disk cache? Otherwise, it had to be recompiled from scratch.
codegen_time: The time (in microseconds) which Dr.Jit needed to generate the textual low-level IR representation of the kernel. This step is always needed even if the resulting kernel is already cached.
backend_time: The time (in microseconds) which the backend (either the LLVM compiler framework or the CUDA PTX just-in-time compiler) required to compile and link the low-level IR into machine code. This step is only needed when the kernel did not already exist in the in-memory or on-disk cache.
uses_optix: Was this kernel compiled by the NVIDIA OptiX ray tracing engine?
Note that drjit.kernel_history() clears the history while extracting this information. A related operation, drjit.kernel_history_clear(), only clears the history without returning any information.
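Since the history is a plain list of dictionaries, it can be post-processed with ordinary Python. The sketch below aggregates a few of the entries documented above. The helper summarize_history and the mock records are hypothetical: the string "JIT" merely stands in for drjit.KernelType.JIT so that the example runs without Dr.Jit, and the timings are made-up placeholders rather than measurements.

```python
def summarize_history(hist, jit_type):
    """Aggregate kernel history entries (hypothetical helper, not a Dr.Jit API)."""
    # Only JIT-compiled kernels carry the cache and compilation entries
    jit = [e for e in hist if e["type"] == jit_type]
    return {
        "launches": len(hist),
        "jit_kernels": len(jit),
        "execution_us": sum(e["execution_time"] for e in hist),
        # Neither cache had the kernel, so the backend compiled it
        "compiled": sum(1 for e in jit
                        if not e["cache_hit"] and not e["cache_disk"]),
        "backend_us": sum(e["backend_time"] for e in jit),
    }


# Mock records using the documented keys; all values are placeholders
mock_hist = [
    {"type": "JIT", "execution_time": 120.0, "cache_hit": False,
     "cache_disk": False, "backend_time": 5000.0},
    {"type": "JIT", "execution_time": 80.0, "cache_hit": True,
     "cache_disk": False, "backend_time": 0.0},
    {"type": "Other", "execution_time": 15.5},
]

print(summarize_history(mock_hist, jit_type="JIT"))
```

With real data, the first argument would be the list returned by drjit.kernel_history() and jit_type would be drjit.KernelType.JIT.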
- drjit.kernel_history_clear() None ¶
Clear the kernel history.
This operation clears the kernel history without returning any information about it. See
drjit.kernel_history()
for details.
Typing¶
Digital Differential Analyzer¶
The drjit.dda module provides a general implementation of a digital
differential analyzer (DDA) that steps through the intersection of a ray
segment and an N-dimensional grid, performing a custom computation at every
cell.
The drjit.dda.integrate() function builds on this functionality to compute
differentiable line integrals of bi- or trilinear interpolants stored on a
grid.
- drjit.dda.dda(ray_o: ArrayNfT, ray_d: ArrayNfT, ray_max: object, grid_res: ArrayNuT, grid_min: ArrayNfT, grid_max: ArrayNfT, func: Callable[[StateT, ArrayNuT, ArrayNfT, ArrayNfT, BoolT], Tuple[StateT, BoolT]], state: StateT, active: BoolT, mode: Literal['scalar', 'evaluated', 'symbolic', None] | None = None, max_iterations: int | None = None) StateT ¶
N-dimensional digital differential analyzer (DDA).
This function traverses the intersection of a Cartesian coordinate grid and a specified ray or ray segment. The following snippet shows how to use it to enumerate the intersection of a grid with a single ray.
from drjit.dda import dda
from drjit.scalar import Array3f, Array3i, Float, Bool

def dda_fun(state: list, index: Array3i, pt_in: Array3f,
            pt_out: Array3f, active: Bool) -> tuple[list, bool]:
    # Entered a grid cell, stash it in the 'state' variable
    state.append(Array3f(index))
    # Returning True continues the traversal
    return state, Bool(True)

result = dda(
    ray_o = Array3f(-.1),
    ray_d = Array3f(.1, .2, .3),
    ray_max = Float(float('inf')),
    grid_res = Array3i(10),
    grid_min = Array3f(0),
    grid_max = Array3f(1),
    func = dda_fun,
    state = [],
    active = Bool(True)
)

print(result)
Since all input elements are Dr.Jit arrays, everything works analogously when processing N rays and N potentially different grid configurations. The entire process can be captured symbolically.
The function takes the following arguments. Note that many of them are generic type variables (signaled by ending with a capital T). To support different dimensions and precisions, the implementation must be able to deal with various input types, which is communicated by these type variables.
- Parameters:
ray_o (ArrayNfT) – the ray origin, where the ArrayNfT type variable refers to an n-dimensional scalar or Jit-compiled floating point array.
ray_d (ArrayNfT) – the ray direction. Does not need to be normalized.
ray_max (object) – the maximum extent along the ray, which is permitted to be infinite. The value is specified as a multiple of the norm of ray_d, which is not necessarily unit-length. Must be of type dr.value_t(ArrayNfT).
grid_res (ArrayNuT) – the grid resolution, where the ArrayNuT type variable refers to a matched 32-bit integer array (i.e., ArrayNuT = dr.int32_array_t(ArrayNfT)).
grid_min (ArrayNfT) – the minimum position of the grid bounds.
grid_max (ArrayNfT) – the maximum position of the grid bounds.
func (Callable[[StateT, ArrayNuT, ArrayNfT, ArrayNfT, BoolT], tuple[StateT, BoolT]]) – a callback that will be invoked when the DDA traverses a grid cell. It must take the following five positional arguments:
arg0: StateT: An arbitrary state value.
arg1: ArrayNuT: An integer array specifying the cell index along each dimension.
arg2: ArrayNfT: The fractional position (\(\in [0, 1]^n\)) where the ray enters the current cell.
arg3: ArrayNfT: The fractional position (\(\in [0, 1]^n\)) where the ray leaves the current cell.
arg4: BoolT: A boolean array specifying which elements are active.
The callback should then return a tuple of type tuple[StateT, BoolT] containing:
An updated state value.
A boolean array that can be used to exit the loop prematurely for some or all rays. The iteration stops if the associated entry of the return value equals False.
state (StateT) – an arbitrary initial state that will be passed to the callback.
active (BoolT) – an array specifying which elements of the input are active, where the BoolT type variable refers to a matched boolean array (i.e., BoolT = dr.mask_t(ray_o.x)).
mode (str | None) – The operation can operate in scalar, symbolic, or evaluated modes. See the mode argument and the documentation of drjit.while_loop() for details.
max_iterations (int | None) – Bound on the iteration count that is needed for reverse-mode differentiation. Forwarded to the max_iterations parameter of drjit.while_loop().
- Returns:
The function returns the final state value of the callback upon termination.
- Return type:
StateT
Note
Just like the Dr.Jit texture interface, the implementation uses the convention that voxel sizes and positions are specified from last to first component (e.g. (Z, Y, X)), while regular 3D positions use the opposite (X, Y, Z) order.
In particular, all ArrayNuT-typed parameters of the function and the callback use the ZYX convention, while ArrayNfT-typed parameters use the XYZ convention.
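To make the traversal concrete, the following pure-Python sketch steps a single ray through an N-dimensional grid and invokes a callback with the same five arguments described above. It is an illustration of the general DDA idea under simplifying assumptions, not the Dr.Jit implementation: the name dda_sketch is hypothetical, it handles one scalar ray only, it assumes a nonzero ray direction, and it uses the same component order for cell indices and positions (it does not reproduce the ZYX index convention).

```python
import math


def dda_sketch(ray_o, ray_d, ray_max, grid_res, grid_min, grid_max, func, state):
    n = len(ray_o)
    # Express the ray in grid coordinates, where cell i spans [i, i + 1]
    scale = [grid_res[k] / (grid_max[k] - grid_min[k]) for k in range(n)]
    o = [(ray_o[k] - grid_min[k]) * scale[k] for k in range(n)]
    d = [ray_d[k] * scale[k] for k in range(n)]

    # Clip the ray interval [0, ray_max] against the grid bounds (slab test)
    t0, t1 = 0.0, ray_max
    for k in range(n):
        if d[k] == 0.0:
            if not 0.0 <= o[k] <= grid_res[k]:
                return state  # parallel to this slab and outside of it
            continue
        ta, tb = -o[k] / d[k], (grid_res[k] - o[k]) / d[k]
        t0, t1 = max(t0, min(ta, tb)), min(t1, max(ta, tb))
    if t0 >= t1:
        return state  # the ray misses the grid

    # Cell containing the entry point, clamped to valid indices
    idx = [min(max(math.floor(o[k] + t0 * d[k]), 0), grid_res[k] - 1)
           for k in range(n)]
    t = t0
    while True:
        # Find the nearest cell boundary along the ray (or the ray's end)
        t_next, axis = t1, -1
        for k in range(n):
            if d[k] != 0.0:
                boundary = idx[k] + (1 if d[k] > 0 else 0)
                tk = (boundary - o[k]) / d[k]
                if tk < t_next:
                    t_next, axis = tk, k
        # Fractional entry/exit positions relative to the current cell
        pt_in = [o[k] + t * d[k] - idx[k] for k in range(n)]
        pt_out = [o[k] + t_next * d[k] - idx[k] for k in range(n)]
        state, keep_going = func(state, tuple(idx), pt_in, pt_out, True)
        if not keep_going or axis < 0:
            break  # callback requested a stop, or the ray ended in this cell
        idx[axis] += 1 if d[axis] > 0 else -1
        if not 0 <= idx[axis] < grid_res[axis]:
            break  # stepped outside of the grid
        t = t_next
    return state


def record(state, index, pt_in, pt_out, active):
    state.append(index)
    return state, True


cells = dda_sketch((-1.0, 0.25), (1.0, 0.5), math.inf,
                   (4, 4), (0.0, 0.0), (4.0, 4.0), record, [])
print(cells)
```

The callback returns a flag alongside the state, so a traversal can stop early (for example, once a hit has been found) by returning False.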
- drjit.dda.integrate(ray_o: ArrayNfT, ray_d: ArrayNfT, ray_max: FloatT, grid_min: ArrayNfT, grid_max: ArrayNfT, vol: ArrayBase[Any, Any, Any, Any, Any, Any, Any], active: object | None = None, mode: Literal['scalar', 'evaluated', 'symbolic', None] | None = None) FloatT ¶
Compute an analytic definite integral of a bi- or trilinear interpolant.
This function uses DDA (drjit.dda.dda()) to step along the voxels of a 2D/3D volume traversed by a finite segment or an infinite-length ray. It analytically computes and accumulates the definite integral of the interpolant in each voxel.
The input 2D/3D volume is provided using a tensor vol (e.g., of type drjit.cuda.ad.TensorXf) with an implicitly specified grid resolution vol.shape. This data volume is placed into an axis-aligned region with bounds (grid_min, grid_max).
The operation provides an efficient forward and backward derivative.
Note
Just like the Dr.Jit texture interface, the implementation uses the convention that voxel sizes and positions are specified from last to first component (e.g. (Z, Y, X)), while regular 3D positions use the opposite (X, Y, Z) order.
In particular, vol.shape uses the ZYX convention, while the ArrayNfT-typed parameters use the XYZ convention.
One important difference to the texture classes is that the interpolant is sampled at integer grid positions, whereas the Dr.Jit texture classes place values at cell centers, i.e. with a .5 fractional offset.
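An exact per-voxel integral is possible because a bilinear interpolant restricted to a straight segment is a quadratic polynomial in the segment parameter (a cubic in the trilinear case), and Simpson's rule integrates polynomials up to degree three exactly. The sketch below demonstrates this for one 2D cell. It is an assumption-laden illustration rather than the closed form used by Dr.Jit: the names bilerp and cell_integral are hypothetical, and the corner ordering (c00, c10, c01, c11) is chosen for this example only.

```python
import math


def bilerp(c, x, y):
    """Bilinear interpolant on the unit cell with corner values at integer positions."""
    c00, c10, c01, c11 = c  # values at (0,0), (1,0), (0,1), (1,1)
    return (c00 * (1 - x) * (1 - y) + c10 * x * (1 - y)
            + c01 * (1 - x) * y + c11 * x * y)


def cell_integral(c, pt_in, pt_out, length):
    """Integrate the interpolant along the segment pt_in -> pt_out.

    pt_in/pt_out are fractional positions inside the cell and 'length' is
    the segment length in world units. Simpson's rule is exact here since
    the integrand is a quadratic polynomial along the segment.
    """
    mid = [(a + b) / 2 for a, b in zip(pt_in, pt_out)]
    return length * (bilerp(c, *pt_in) + 4 * bilerp(c, *mid)
                     + bilerp(c, *pt_out)) / 6


# f(x, y) = x * y along the cell diagonal: the integrand is s^2 with s in [0, 1]
value = cell_integral((0.0, 0.0, 0.0, 1.0), (0.0, 0.0), (1.0, 1.0), math.sqrt(2))
print(value, math.sqrt(2) / 3)
```

Summing such per-cell contributions over the cells visited by a DDA traversal yields the line integral over the whole grid.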