API Reference (Main)¶
This document explains the public API (behaviors, signatures) in exhaustive detail. If you’re new to Dr.Jit, it may be easier to start by reading the other sections first.
The reference documentation is also exposed through docstrings, which many visual editors (e.g., VS Code, neovim with LSP) will show during code completion, or when hovering over an expression.
The reference extensively uses type variables, which can be recognized because their name equals or ends with a capital T (e.g., T, ArrayT, MaskT, etc.). Type variables serve as placeholders that show how types propagate through function calls. For example, a function with signature def f(arg: T, /) -> tuple[T, T]: ... will return a pair of int instances when called with an int-typed arg value.
Array creation¶
- drjit.zeros(dtype: type, shape: int = 1) object ¶
- drjit.zeros(dtype: type, shape: collections.abc.Sequence[int]) object
Overloaded function.
zeros(dtype: type, shape: int = 1) -> object
Return a zero-initialized instance of the desired type and shape.
This function can create zero-initialized instances of various types. In particular, dtype can be:
- A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.zeros(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
- A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
- A PyTree. In this case, drjit.zeros() will invoke itself recursively to zero-initialize each field of the data structure.
- A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.
Note that when dtype refers to a scalar mask or a mask array, it will be initialized to False as opposed to zero.
The function returns a literal constant array that consumes no device memory.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
shape (Sequence[int] | int) – Shape of the desired array
- Returns:
A zero-initialized instance of type dtype.
- Return type:
object
zeros(dtype: type, shape: collections.abc.Sequence[int]) -> object
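The following is a brief illustrative sketch of these cases (assuming the drjit.llvm backend; any backend with the corresponding types should behave the same way):

```python
import drjit as dr
from drjit.llvm import Float, Array3f, TensorXf

x = dr.zeros(Float, 100)        # dynamic 1D array with 100 zero entries
v = dr.zeros(Array3f, 100)      # 3xN nested array; the integer sets the dynamic dimension
t = dr.zeros(TensorXf, (4, 5))  # rank-2 tensor of shape (4, 5)
s = dr.zeros(float)             # scalar Python type; 'shape' is ignored, result is 0.0
```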
- drjit.empty(dtype: type, shape: int = 1) object ¶
- drjit.empty(dtype: type, shape: collections.abc.Sequence[int]) object
Overloaded function.
empty(dtype: type, shape: int = 1) -> object
Return an uninitialized Dr.Jit array of the desired type and shape.
This function can create uninitialized buffers of various types. It should only be used in combination with a subsequent call to an operation like drjit.scatter() that fills the array contents with valid data.
The dtype parameter can be used to request:
- A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.empty(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
- A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
- A PyTree. In this case, drjit.empty() will invoke itself recursively to allocate memory for each field of the data structure.
- A scalar Python type like int, float, or bool. The shape parameter is ignored in this case, and the function returns a zero-initialized result (there is little point in instantiating uninitialized versions of scalar Python types).
drjit.empty() delays allocation of the underlying buffer until an operation tries to read/write the actual array contents.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
shape (Sequence[int] | int) – Shape of the desired array
- Returns:
An instance of type dtype with arbitrary/undefined contents.
- Return type:
object
empty(dtype: type, shape: collections.abc.Sequence[int]) -> object
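A minimal sketch of the intended usage pattern, assuming the drjit.llvm backend: allocate an uninitialized buffer and fill it via drjit.scatter() before reading it.

```python
import drjit as dr
from drjit.llvm import Float, UInt32

buf = dr.empty(Float, 4)   # contents are undefined at this point
dr.scatter(buf, value=Float(1, 2, 3, 4), index=UInt32(0, 1, 2, 3))
# 'buf' now holds valid data and may be read/evaluated
```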
- drjit.ones(dtype: type, shape: int = 1) object ¶
- drjit.ones(dtype: type, shape: collections.abc.Sequence[int]) object
Overloaded function.
ones(dtype: type, shape: int = 1) -> object
Return an instance of the desired type and shape filled with ones.
This function can create one-initialized instances of various types. In particular, dtype can be:
- A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.ones(dr.cuda.Array2f, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
- A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
- A PyTree. In this case, drjit.ones() will invoke itself recursively to initialize each field of the data structure.
- A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.
Note that when dtype refers to a scalar mask or a mask array, it will be initialized to True as opposed to one.
The function returns a literal constant array that consumes no device memory.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
shape (Sequence[int] | int) – Shape of the desired array
- Returns:
An instance of type dtype filled with ones.
- Return type:
object
ones(dtype: type, shape: collections.abc.Sequence[int]) -> object
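For illustration (a hedged sketch assuming the drjit.llvm backend), note in particular the mask behavior mentioned above:

```python
import drjit as dr
from drjit.llvm import Float, Bool

x = dr.ones(Float, 8)   # eight entries, all equal to 1.0
m = dr.ones(Bool, 8)    # mask types are initialized to True rather than 1
```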
- drjit.full(dtype: type, value: object, shape: int = 1) object ¶
- drjit.full(dtype: type, value: object, shape: collections.abc.Sequence[int]) object
Overloaded function.
full(dtype: type, value: object, shape: int = 1) -> object
Return a constant-valued instance of the desired type and shape.
This function can create constant-valued instances of various types. In particular, dtype can be:
- A Dr.Jit array type like drjit.cuda.Array2f. When shape specifies a sequence, it must be compatible with static dimensions of the dtype. For example, dr.full(dr.cuda.Array2f, value=1.0, shape=(3, 100)) fails, since the leading dimension is incompatible with drjit.cuda.Array2f. When shape is an integer, it specifies the size of the last (dynamic) dimension, if available.
- A tensorial type like drjit.scalar.TensorXf. When shape specifies a sequence (list/tuple/..), it determines the tensor rank and shape. When shape is an integer, the function creates a rank-1 tensor of the specified size.
- A PyTree. In this case, drjit.full() will invoke itself recursively to initialize each field of the data structure.
- A scalar Python type like int, float, or bool. The shape parameter is ignored in this case.
The function returns a literal constant array that consumes no device memory.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
value (object) – An instance of the underlying scalar type (float/int/bool, etc.) that will be used to initialize the array contents.
shape (Sequence[int] | int) – Shape of the desired array
- Returns:
An instance of type dtype filled with value
- Return type:
object
full(dtype: type, value: object, shape: collections.abc.Sequence[int]) -> object
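A short sketch of drjit.full() under the same assumptions (drjit.llvm backend available):

```python
import drjit as dr
from drjit.llvm import Float, Array3f

x = dr.full(Float, 0.5, 100)         # 100 entries, all equal to 0.5
v = dr.full(Array3f, 1.0, (3, 100))  # leading dimension must match the static size 3
```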
- drjit.opaque(dtype: type, value: object, shape: int = 1) object ¶
- drjit.opaque(dtype: type, value: object, shape: collections.abc.Sequence[int]) object
Overloaded function.
opaque(dtype: type, value: object, shape: int = 1) -> object
Return an opaque constant-valued instance of the desired type and shape.
This function is very similar to drjit.full() in that it creates constant-valued instances of various types including (potentially nested) Dr.Jit arrays, tensors, and PyTrees. Please refer to the documentation of drjit.full() for details on the function signature. However, drjit.full() creates literal constant arrays, which means that Dr.Jit is fully aware of the array contents. In contrast, drjit.opaque() produces an opaque array backed by a representation in device memory.
Why is this useful?
Consider the following snippet, where a complex calculation is parameterized by the constant 1.

    from drjit.llvm import Float

    result = complex_function(Float(1), ...)  # Float(1) is equivalent to dr.full(Float, 1)
    print(result)

The print() statement will cause Dr.Jit to evaluate the queued computation, which likely also requires compilation of a new kernel (if that exact pattern of steps hasn't been observed before). Kernel compilation is costly and may be much slower than the actual computation that needs to be done.
Suppose we later wish to evaluate the function with a different parameter:

    result = complex_function(Float(2), ...)
    print(result)

The constant 2 is essentially copy-pasted into the generated program, causing a mismatch with the previously compiled kernel, which therefore cannot be reused. This unfortunately means that we must once more wait a few tens or even hundreds of milliseconds until a new kernel has been compiled and uploaded to the device.
This motivates the existence of drjit.opaque(). By making a variable opaque to Dr.Jit's tracing mechanism, we can keep constants out of the generated program and improve the effectiveness of the kernel cache:

    # The following lines reuse the compiled kernel regardless of the constant
    value = dr.opaque(Float, 2)
    result = complex_function(value, ...)
    print(result)

This function is related to drjit.make_opaque(), which can turn an already existing Dr.Jit array, tensor, or PyTree into an opaque representation.
- Parameters:
dtype (type) – Desired Dr.Jit array type, Python scalar type, or PyTree.
value (object) – An instance of the underlying scalar type (float/int/bool, etc.) that will be used to initialize the array contents.
shape (Sequence[int] | int) – Shape of the desired array
- Returns:
An instance of type dtype filled with value
- Return type:
object
opaque(dtype: type, value: object, shape: collections.abc.Sequence[int]) -> object
- drjit.arange(dtype: type[T], size: int) T ¶
- drjit.arange(dtype: type[T], start: int, stop: int, step: int = 1) T
Overloaded function.
arange(dtype: type[T], size: int) -> T
This function generates an integer sequence on the interval [start, stop) with step size step, where start = 0 and step = 1 if not specified.
- Parameters:
dtype (type) – Desired Dr.Jit array type. The dtype must refer to a dynamically sized 1D Dr.Jit array such as drjit.scalar.ArrayXu or drjit.cuda.Float.
start (int) – Start of the interval. The default value is 0.
stop/size (int) – End of the interval (not included). The name of this parameter differs between the two provided overloads.
step (int) – Spacing between values. The default value is 1.
- Returns:
The computed sequence of type dtype.
- Return type:
object
arange(dtype: type[T], start: int, stop: int, step: int = 1) -> T
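Both overloads in action (a minimal sketch, assuming the drjit.llvm backend):

```python
import drjit as dr
from drjit.llvm import UInt32, Int

a = dr.arange(UInt32, 5)      # [0, 1, 2, 3, 4]
b = dr.arange(Int, 2, 10, 2)  # start=2, stop=10 (excluded), step=2 -> [2, 4, 6, 8]
```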
- drjit.linspace(dtype: type[T], start: float, stop: float, num: int, endpoint: bool = True) T ¶
This function generates an evenly spaced floating point sequence of size num covering the interval [start, stop].
- Parameters:
dtype (type) – Desired Dr.Jit array type. The dtype must refer to a dynamically sized 1D Dr.Jit floating point array, such as drjit.scalar.ArrayXf or drjit.cuda.Float.
start (float) – Start of the interval.
stop (float) – End of the interval.
num (int) – Number of samples to generate.
endpoint (bool) – Should the interval endpoint be included? The default is True.
- Returns:
The computed sequence of type dtype.
- Return type:
object
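A small sketch of the endpoint parameter (assuming the drjit.llvm backend):

```python
import drjit as dr
from drjit.llvm import Float

x = dr.linspace(Float, 0.0, 1.0, 5)                  # [0, 0.25, 0.5, 0.75, 1]
y = dr.linspace(Float, 0.0, 1.0, 4, endpoint=False)  # [0, 0.25, 0.5, 0.75]
```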
- drjit.zeros_like(arg: T, /) T ¶
Return an array of zeros with the same shape and type as a given array.
- drjit.empty_like(arg: T, /) T ¶
Return an empty array with the same shape and type as a given array.
- drjit.ones_like(arg: T, /) T ¶
Return an array of ones with the same shape and type as a given array.
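These *_like helpers mirror the shape and type of an existing input, e.g. (a hedged sketch assuming the drjit.llvm backend):

```python
import drjit as dr
from drjit.llvm import Array3f

x = Array3f([1, 2], [3, 4], [5, 6])  # shape (3, 2)
z = dr.zeros_like(x)                 # same shape/type, filled with zeros
o = dr.ones_like(x)                  # same shape/type, filled with ones
e = dr.empty_like(x)                 # same shape/type, undefined contents
```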
Control flow¶
- drjit.syntax(f: None = None, *, recursive: bool = False, print_ast: bool = False, print_code: bool = False) Callable[[T], T] ¶
- drjit.syntax(f: T, *, recursive: bool = False, print_ast: bool = False, print_code: bool = False) T
- drjit.hint(arg: T, /, *, mode: Literal['scalar', 'evaluated', 'symbolic', None] = None, max_iterations: int | None = None, label: str | None = None, include: List[object] | None = None, exclude: List[object] | None = None, strict: bool = True) T ¶
Within ordinary Python code, this function is unremarkable: it returns the positional-only argument arg while ignoring any specified keyword arguments.
The main purpose of drjit.hint() is to provide hints that influence the transformation performed by the @drjit.syntax decorator. The following kinds of hints are supported:
- mode overrides the compilation mode of a while loop or if statement. The following choices are available:
  - mode='scalar' disables code transformations, which is permitted when the predicate of a loop or if statement is a scalar Python bool.

        i: int = 0
        while dr.hint(i < 10, mode='scalar'):
            # ...

    Routing such code through drjit.while_loop() or drjit.if_stmt() still works but may add small overheads, which motivates the existence of this flag. Note that this annotation does not cause mode='scalar' to be passed to drjit.while_loop() and drjit.if_stmt() (which happens to be a valid input of both). Instead, it disables the code transformation altogether so that the above example translates into ordinary Python code:

        i: int = 0
        while i < 10:
            # ...

  - mode='evaluated' forces execution in evaluated mode and causes the code transformation to forward this argument to the relevant drjit.while_loop() or drjit.if_stmt() call.
    Refer to the discussion of drjit.while_loop(), drjit.JitFlag.SymbolicLoops, drjit.if_stmt(), and drjit.JitFlag.SymbolicConditionals for details.
  - mode='symbolic' forces execution in symbolic mode and causes the code transformation to forward this argument to the relevant drjit.while_loop() or drjit.if_stmt() call.
    Refer to the discussion of drjit.while_loop(), drjit.JitFlag.SymbolicLoops, drjit.if_stmt(), and drjit.JitFlag.SymbolicConditionals for details.
- The optional strict=False reduces the strictness of variable consistency checks.
  Consider the following snippet:

      from drjit.llvm import UInt32

      @dr.syntax
      def f(x: UInt32):
          if x < 4:
              y = 3
          else:
              y = 5
          return y

  This code will raise an exception.

      >>> f(UInt32(1))
      RuntimeError: drjit.if_stmt(): the non-array state variable 'y' of type 'int'
      changed from '3' to '5'. Please review the interface and assumptions of
      'drjit.while_loop()' as explained in the documentation
      (https://drjit.readthedocs.io/en/latest/reference.html#drjit.while_loop).

  This is because the computed variable y of type int has an inconsistent value depending on the taken branch. Furthermore, y is a scalar Python type that isn't tracked by Dr.Jit. The fix here is to initialize y with UInt32(<integer value>).
  However, there may also be legitimate situations where such an inconsistency is needed by the implementation. This can be fine as long as y is not used below the if statement. In this case, you can annotate the conditional or loop with dr.hint(..., strict=False), which disables the check.
- max_iterations specifies a maximum number of loop iterations for reverse-mode automatic differentiation.
  Naive reverse-mode differentiation of loops (unless replaced by a smarter problem-specific strategy via drjit.custom and drjit.CustomOp) requires allocation of large buffers that hold loop state for all iterations.
  Dr.Jit requires an upper bound on the maximum number of loop iterations so that it can allocate such buffers, which can be provided via this hint. Otherwise, reverse-mode differentiation of loops will fail with an error message. (A short sketch combining this hint with label follows this list.)
- label provides a descriptive label.
  Dr.Jit will include this label as a comment in the generated intermediate representation, which can be helpful when debugging the compilation of large programs.
- include and exclude indicate to the @drjit.syntax decorator that a local variable should or should not be considered to be part of the set of state variables passed to drjit.while_loop() or drjit.if_stmt().
  While transforming a function, the @drjit.syntax decorator sequentially steps through a program to identify the set of read and written variables. It then forwards referenced variables to recursive drjit.while_loop() and drjit.if_stmt() calls. In rare cases, it may be useful to manually include or exclude a local variable from this process; specify a list of such variables to the drjit.hint() annotation to do so.
- drjit.while_loop(state: tuple[*Ts], cond: typing.Callable[[*Ts], AnyArray | bool], body: typing.Callable[[*Ts], tuple[*Ts]], labels: typing.Sequence[str] = (), label: str | None = None, mode: typing.Literal['scalar', 'symbolic', 'evaluated', None] = None, strict: bool = True, compress: bool | None = None, max_iterations: int | None = None) tuple[*Ts] ¶
Repeatedly execute a function while a loop condition holds.
Motivation
This function provides a vectorized generalization of a standard Python while loop. For example, consider the following Python snippet:

    i: int = 1
    while i < 10:
        x *= x
        i += 1
This code would fail when i is replaced by an array with multiple entries (e.g., of type drjit.llvm.Int). In that case, the loop condition evaluates to a boolean array of per-component comparisons that are not necessarily consistent with each other. In other words, each entry of the array may need to run the loop for a different number of iterations. A standard Python while loop is not able to do so.
The drjit.while_loop() function realizes such a fine-grained looping mechanism. It takes three main input arguments:
- state, a tuple of state variables that are modified by the loop iteration.
  Dr.Jit optimizes away superfluous state variables, so there isn't any harm in specifying variables that aren't actually modified by the loop.
- cond, a function that takes the state variables as input and uses them to evaluate and return the loop condition in the form of a boolean array.
- body, a function that also takes the state variables as input and runs one loop iteration. It must return an updated set of state variables.
The function calls cond and body to execute the loop. It then returns a tuple containing the final version of the state variables. With this functionality, a vectorized version of the above loop can be written as follows:

    i, x = dr.while_loop(
        state=(i, x),
        cond=lambda i, x: i < 10,
        body=lambda i, x: (i+1, x*x)
    )

Lambda functions are convenient when the condition and body are simple enough to fit onto a single line. In general you may prefer to define local functions (def loop_cond(i, x): ...) and pass them to the cond and body arguments.
Dr.Jit also provides the @drjit.syntax decorator, which automatically rewrites standard Python control flow constructs into the form shown above. It combines vectorization with the readability of natural Python syntax and is the recommended way of (indirectly) using drjit.while_loop(). With this decorator, the above example would be written as follows:

    @dr.syntax
    def f(i, x):
        while i < 10:
            x *= x
            i += 1
        return i, x
Evaluation modes
Dr.Jit uses one of three different modes to compile this operation depending on the inputs and active compilation flags (the text below this overview will explain how this mode is automatically selected).
- Scalar mode: Scalar loops that don't need any vectorization can fall back to a simple Python loop construct.

      while cond(state):
          state = body(*state)

  This is the default strategy when cond(state) returns a scalar Python bool.
  The loop body may still use Dr.Jit types, but note that this effectively unrolls the loop, generating a potentially long sequence of instructions that may take a long time to compile. Symbolic mode (discussed next) may be advantageous in such cases.
- Symbolic mode: Here, Dr.Jit runs a single loop iteration to capture its effect on the state variables. It embeds this captured computation into the generated machine code. The loop will eventually run on the device (e.g., the GPU) but unlike a Python while statement, the loop does not run on the host CPU (besides the mentioned tentative evaluation for symbolic tracing).
  When loop optimizations are enabled (drjit.JitFlag.OptimizeLoops), Dr.Jit may re-trace the loop body so that it runs twice in total. This happens transparently and has no influence on the semantics of this operation.
- Evaluated mode: In this mode, Dr.Jit will repeatedly evaluate the loop state variables and update active elements using the loop body function until all of them are done. Conceptually, this is equivalent to the following Python code:

      active = True
      while True:
          dr.eval(state)
          active &= cond(state)
          if not dr.any(active):
              break
          state = dr.select(active, body(state), state)

  (In practice, the implementation does a few additional things like suppressing side effects associated with inactive entries.)
  Dr.Jit will typically compile a kernel when it runs the first loop iteration. Subsequent iterations can then reuse this cached kernel since they perform the same exact sequence of operations. Kernel caching tends to be crucial to achieve good performance, and it is good to be aware of pitfalls that can effectively disable it.
  For example, when you update a scalar (e.g. a Python int) in each loop iteration, this changing counter might be merged into the generated program, forcing the system to re-generate and re-compile code at every iteration, and this can ultimately dominate the execution time. If in doubt, increase the log level of Dr.Jit (drjit.set_log_level() to drjit.LogLevel.Info) and check if the kernels being launched contain the term cache miss. You can also inspect the Kernels launched line in the output of drjit.whos(). If you observe soft or hard misses at every loop iteration, then kernel caching isn't working and you should carefully inspect your code to ensure that the computation stays consistent across iterations.
  When the loop processes many elements, and when each element requires a different number of loop iterations, there is a question of what should be done with inactive elements. The default implementation keeps them around and does redundant calculations that are, however, masked out. Consequently, later loop iterations don't run faster despite fewer elements being active.
  Alternatively, you may specify the parameter compress=True or set the flag drjit.JitFlag.CompressLoops, which causes the removal of inactive elements after every iteration. This reorganization is not for free and does not benefit all use cases, which is why it isn't enabled by default.
A separate section about symbolic and evaluated modes discusses these two options in further detail.
The drjit.while_loop() function chooses the evaluation mode as follows:
- When the mode argument is set to None (the default), the function examines the loop condition. It uses scalar mode when this produces a Python bool, otherwise it inspects the drjit.JitFlag.SymbolicLoops flag to switch between symbolic (the default) and evaluated mode.
  To change this automatic choice for a region of code, you may specify the mode= keyword argument, nest code into a drjit.scoped_set_flag() block, or change the behavior globally via drjit.set_flag():

      with dr.scoped_set_flag(dr.JitFlag.SymbolicLoops, False):
          # .. nested code will use evaluated loops ..

- When mode is set to "scalar", "symbolic", or "evaluated", it directly uses that method without inspecting the compilation flags or loop condition type.
When using the @drjit.syntax decorator to automatically convert Python while loops into drjit.while_loop() calls, you can also use the drjit.hint() function to pass keyword arguments including mode, label, or max_iterations to the generated looping construct:

    while dr.hint(i < 10, label='My loop', mode='evaluated'):
        # ...
Assumptions
The loop condition function must be pure (i.e., it should never modify the state variables or any other kind of program state). The loop body should not write to variables besides the officially declared state variables:
    y = ..
    def loop_body(x):
        y[0] += x # <-- don't do this. 'y' is not a loop state variable

    dr.while_loop(body=loop_body, ...)

Dr.Jit automatically tracks dependencies of indirect reads (done via drjit.gather()) and indirect writes (done via drjit.scatter(), drjit.scatter_reduce(), drjit.scatter_add(), drjit.scatter_inc(), etc.). Such operations create implicit inputs and outputs of a loop, and these do not need to be specified as loop state variables (however, doing so causes no harm). This auto-discovery mechanism is helpful when performing vectorized method calls (within loops), where the set of implicit inputs and outputs can often be difficult to know a priori (in principle, any public/private field in any instance could be accessed in this way).

    y = ..
    def loop_body(x):
        # Scattering to 'y' is okay even if it is not declared as loop state
        dr.scatter(target=y, value=x, index=0)
Another important assumption is that the loop state remains consistent across iterations, which means:
1. The type of state variables is not allowed to change. You may not declare a Python float before a loop and then overwrite it with a drjit.cuda.Float (or vice versa).
2. Their structure/size must be consistent. The loop body may not turn a variable with 3 entries into one that has 5.
3. Analogously, state variables must always be initialized prior to the loop. This is the case even if you know that the loop body is guaranteed to overwrite the variable with a well-defined result. An initial value of None would violate condition 1 (type invariance), while an empty array would violate condition 2 (shape compatibility).
The implementation will check for violations and, if applicable, raise an exception identifying problematic state variables.
Potential pitfalls
Long compilation times.
In the example below, i < 100000 is scalar, causing drjit.while_loop() to use the scalar evaluation strategy that effectively copy-pastes the loop body 100000 times to produce a giant program. Code written in this way will be bottlenecked by the CUDA/LLVM compilation stage.

    @dr.syntax
    def f():
        i = 0
        while i < 100000:
            # .. costly computation
            i += 1

Incorrect behavior in symbolic mode.
Let's fix the above program by casting the loop condition into a Dr.Jit type to ensure that a symbolic loop is used. Problem solved, right?

    from drjit.cuda import Bool

    @dr.syntax
    def f():
        i = 0
        while Bool(i < 100000):
            # .. costly computation
            i += 1

Unfortunately, no: this loop never terminates when run in symbolic mode. Symbolic mode does not track modifications of scalar/non-Dr.Jit types across loop iterations such as the int-valued loop counter i. It's as if we had written while Bool(0 < 100000), which of course never finishes.
Evaluated mode does not have this problem; if your loop behaves differently in symbolic and evaluated modes, then some variation of this mistake is likely to blame. To fix this, we must declare the loop counter as a vector type before the loop and then modify it as follows:

    from drjit.cuda import Int

    @dr.syntax
    def f():
        i = Int(0)
        while i < 100000:
            # .. costly computation
            i += 1
Warning
This new implementation of the drjit.while_loop() abstraction still lacks the functionality to break or return from the loop, or to continue to the next loop iteration. We plan to add these capabilities in the near future.
Interface
- Parameters:
state (tuple) – A tuple containing the initial values of the loop’s state variables. This tuple normally consists of Dr.Jit arrays or PyTrees. Other values are permissible as well and will be forwarded to the loop body. However, such variables will not be captured by the symbolic tracing process.
cond (Callable) – a function/callable that will be invoked with *state (i.e., the state variables will be unpacked and turned into function arguments). It should return a scalar Python bool or a boolean-typed Dr.Jit array representing the loop condition.
body (Callable) – a function/callable that will be invoked with *state (i.e., the state variables will be unpacked and turned into function arguments). It should update the loop state and then return a new tuple of state variables that are compatible with the previous state (see the earlier description regarding what such compatibility entails).
mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides None are: "scalar", "symbolic", "evaluated". If not specified, the function first checks if the loop is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flag drjit.JitFlag.SymbolicLoops and then either performs a symbolic or an evaluated loop.
compress (Optional[bool]) – Set this parameter to True or False to enable or disable loop state compression in evaluated loops (see the text above for a description of this feature). The function queries the value of drjit.JitFlag.CompressLoops when the parameter is not specified. Symbolic loops ignore this parameter.
labels (list[str]) – An optional list of labels associated with each state entry. Dr.Jit uses this to provide better error messages in case of a detected inconsistency. The @drjit.syntax decorator automatically provides these labels based on the transformed code.
label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
max_iterations (int) – The maximum number of loop iterations (default: -1). You must specify a correct upper bound here if you wish to differentiate the loop in reverse mode. In that case, the maximum iteration count is used to reserve memory to store intermediate loop state.
strict (bool) – You can specify this parameter to reduce the strictness of variable consistency checks performed by the implementation. See the documentation of drjit.hint() for an example. The default is strict=True.
- Returns:
The function returns the final state of the loop variables following termination of the loop.
- Return type:
tuple
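For reference, a self-contained variant of the squaring-loop example above (a minimal sketch, assuming the drjit.llvm backend):

```python
import drjit as dr
from drjit.llvm import Int, Float

i = Int(1, 5, 9)    # each entry runs the loop a different number of times
x = Float(2, 2, 2)

i, x = dr.while_loop(
    state=(i, x),
    cond=lambda i, x: i < 10,
    body=lambda i, x: (i + 1, x * x)
)
```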
- drjit.if_stmt(args: tuple[*Ts], cond: AnyArray | bool, true_fn: typing.Callable[[*Ts], T], false_fn: typing.Callable[[*Ts], T], arg_labels: typing.Sequence[str] = (), rv_labels: typing.Sequence[str] = (), label: str | None = None, mode: typing.Literal['scalar', 'symbolic', 'evaluated', None] = None, strict: bool = True) T ¶
Conditionally execute code.
Motivation
This function provides a vectorized generalization of a standard Python if statement. For example, consider the following Python snippet:

    i: int = .. some expression ..
    if i > 0:
        x = f(i) # <-- some costly function 'f' that depends on 'i'
    else:
        y += 1

This code would fail if i is replaced by an array containing multiple entries (e.g., of type drjit.llvm.Int). In that case, the conditional expression produces a boolean array of per-component comparisons that are not necessarily consistent with each other. In other words, some of the entries may want to run the body of the if statement, while others must skip to the else block. This is not compatible with the semantics of a standard Python if statement.
The drjit.if_stmt() function realizes a more fine-grained conditional operation that accommodates these requirements, while avoiding execution of the costly branch unless this is truly needed. It takes the following input arguments:
- cond, a boolean array that specifies whether the body of the if statement should execute.
- A tuple of input arguments (args) that will be forwarded to true_fn and false_fn. It is important to specify all inputs to ensure correct derivative tracking of the operation.
- true_fn, a callable that implements the body of the if block.
- false_fn, a callable that implements the body of the else block.
The implementation will invoke true_fn(*args) and false_fn(*args) to trace their contents. The return values of these functions must be compatible with each other (a precise definition of compatibility is described below). A vectorized version of the earlier example can then be written as follows:

    x, y = dr.if_stmt(
        args=(i, x, y),
        cond=i > 0,
        true_fn=lambda i, x, y: (f(i), y),
        false_fn=lambda i, x, y: (x, y + 1)
    )

Lambda functions are convenient when true_fn and false_fn are simple enough to fit onto a single line. In general you may prefer to define local functions (def true_fn(i, x, y): ...) and pass them to the true_fn and false_fn arguments.
Dr.Jit later optimizes away superfluous inputs/outputs of drjit.if_stmt(), so there isn't any harm in, e.g., specifying an identical element of a return value in both true_fn and false_fn.
Dr.Jit also provides the @drjit.syntax decorator, which automatically rewrites standard Python control flow constructs into the form shown above. It combines vectorization with the readability of natural Python syntax and is the recommended way of (indirectly) using drjit.if_stmt(). With this decorator, the above example would be written as follows:

    @dr.syntax
    def f(i, x, y):
        if i > 0:
            x = f(i)
        else:
            y += 1
        return x, y
Evaluation modes
Dr.Jit uses one of three different modes to realize this operation depending on the inputs and active compilation flags (the text below this overview will explain how this mode is automatically selected).
- Scalar mode: Scalar if statements that don't need any vectorization can be reduced to normal Python branching constructs:

      if cond:
          state = true_fn(*args)
      else:
          state = false_fn(*args)

  This strategy is the default when cond is a scalar Python bool.
- Symbolic mode: Dr.Jit runs true_fn and false_fn to capture the computation performed by each function, which allows it to generate an equivalent branch in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit's intermediate representation.
- Evaluated mode: In this mode, Dr.Jit runs both branches of the if statement and then combines the results via drjit.select(). This is nearly equivalent to the following Python code:

      true_state = true_fn(*state)
      false_state = false_fn(*state) if false_fn else state
      state = dr.select(cond, true_state, false_state)
(In practice, the implementation does a few additional things like suppressing side effects associated with inactive entries.)
Evaluated mode is conceptually simpler but also slower, since the device executes both sides of a branch when only one of them is actually needed.
The mode is chosen as follows:
- When the mode argument is set to None (the default), the function examines the type of the cond input and uses scalar mode if the type is a builtin Python bool.
- Otherwise, it chooses between symbolic and evaluated mode based on the drjit.JitFlag.SymbolicConditionals flag, which is set by default. To change this choice for a region of code, you may specify the mode= keyword argument, nest it into a drjit.scoped_set_flag() block, or change the behavior globally via drjit.set_flag():

      with dr.scoped_set_flag(dr.JitFlag.SymbolicConditionals, False):
          # .. nested code will use evaluated mode ..

- When mode is set to "scalar", "symbolic", or "evaluated", it directly uses that mode without inspecting the compilation flags or condition type.
When using the @drjit.syntax decorator to automatically convert Python if statements into drjit.if_stmt() calls, you can also use the drjit.hint() function to pass keyword arguments including the mode and label parameters.

    if dr.hint(i < 10, mode='evaluated'):
        # ...
Assumptions
- The return values of true_fn and false_fn must be of the same type. This requirement applies recursively if the return value is a PyTree.
  Dr.Jit will refuse to compile vectorized conditionals in which true_fn and false_fn return a scalar that is inconsistent between the branches.

      >>> @dr.syntax
      ... def f(x):
      ...    if x > 0:
      ...        y = 1
      ...    else:
      ...        y = 0
      ...    return y
      ...
      >>> print(f(dr.llvm.Float(-1, 2)))
      RuntimeError: dr.if_stmt(): detected an inconsistency when comparing
      the return values of 'true_fn' and 'false_fn': drjit.detail.check_compatibility():
      inconsistent scalar Python object of type 'int' for field 'y'.
      Please review the interface and assumptions of dr.if_stmt() as
      explained in the Dr.Jit documentation.

  The problem can be solved by assigning an instance of a capitalized Dr.Jit type (e.g., y=Int(1)) so that the operation can be tracked.
- The functions true_fn and false_fn should not write to variables besides the explicitly declared return value(s):

      vec = drjit.cuda.Array3f(1, 2, 3)
      def true_fn(x):
          vec.x += x # <-- don't do this. 'vec' is not a declared output

      dr.if_stmt(args=(x,), true_fn=true_fn, ...)

  This example can be fixed as follows:

      def true_fn(x, vec):
          vec.x += x
          return vec

      vec = dr.if_stmt(args=(x, vec), true_fn=true_fn, ...)

- drjit.if_stmt() is differentiable in both forward and reverse modes. Correct derivative tracking requires that regular differentiable inputs are specified via the args parameter. The @drjit.syntax decorator ensures that these assumptions are satisfied.
- Dr.Jit also tracks dependencies of indirect reads (done via drjit.gather()) and indirect writes (done via drjit.scatter(), drjit.scatter_reduce(), drjit.scatter_add(), drjit.scatter_inc(), etc.). Such operations create implicit inputs and outputs, and these do not need to be specified as part of args or the return value of true_fn and false_fn (however, doing so causes no harm). This auto-discovery mechanism is helpful when performing vectorized method calls (within conditional statements), where the set of implicit inputs and outputs can often be difficult to know a priori (in principle, any public/private field in any instance could be accessed in this way).

      y = ..
      def true_fn(x):
          # 'y' is neither declared as input nor output of 'f', which is fine
          dr.scatter(target=y, value=x, index=0)

      dr.if_stmt(args=(x,), true_fn=true_fn, ...)
Interface
- Parameters:
cond (bool|drjit.ArrayBase) – a scalar Python bool or a boolean-valued Dr.Jit array.
args (tuple) – A list of positional arguments that will be forwarded to true_fn and false_fn.
true_fn (Callable) – a callable that implements the body of the if block.
false_fn (Callable) – a callable that implements the body of the else block.
mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides None are: "scalar", "symbolic", "evaluated".
arg_labels (list[str]) – An optional list of labels associated with each input argument. Dr.Jit uses this feature in combination with the @drjit.syntax decorator to provide better error messages in case of detected inconsistencies.
rv_labels (list[str]) – An optional list of labels associated with each element of the return value. This parameter should only be specified when the return value is a tuple. Dr.Jit uses this feature in combination with the @drjit.syntax decorator to provide better error messages in case of detected inconsistencies.
label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
strict (bool) – You can specify this parameter to reduce the strictness of variable consistency checks performed by the implementation. See the documentation of drjit.hint() for an example. The default is strict=True.
- Returns:
Combined return value mixing the results of true_fn and false_fn.
- Return type:
object
- drjit.switch(index: object, targets: collections.abc.Sequence, *args, **kwargs) object ¶
Selectively invoke functions based on a provided index array.
When called with a scalar index (of type int), this function is equivalent to the following Python expression:

    targets[index](*args, **kwargs)

When called with a Dr.Jit index array (specifically, 32-bit unsigned integers), it performs the vectorized equivalent of the above and assembles an array of return values containing the result of all referenced functions. It does so efficiently using at most a single invocation of each function in targets.

    from drjit.llvm import UInt32

    res = dr.switch(
        UInt32(0, 0, 1, 1),   # <-- selects the function
        [                     # <-- arbitrary list of callables
            lambda x: x,
            lambda x: x*10
        ],
        UInt32(1, 2, 3, 4)    # <-- argument passed to function
    )

    # res now contains [1, 2, 30, 40]

The function traverses the set of positional (*args) and keyword arguments (**kwargs) to find all Dr.Jit arrays including arrays contained within PyTrees. It routes a subset of array entries to each function as specified by the index argument.
Dr.Jit will use one of two possible strategies to compile this operation depending on the active compilation flags (see drjit.set_flag(), drjit.scoped_set_flag()):
- Symbolic mode: Dr.Jit transcribes every function into a counterpart in the generated low-level intermediate representation (LLVM IR or PTX) and targets them via an indirect jump instruction.
  This mode is used when drjit.JitFlag.SymbolicCalls is set, which is the default.
- Evaluated mode: Dr.Jit evaluates the inputs index, args, kwargs via drjit.eval(), groups them by index, and invokes each function with the subset of inputs that reference it. Callables that are not referenced by any element of index are ignored.
  In this mode, a drjit.switch() statement will cause Dr.Jit to launch a series of kernels processing subsets of the input data (one per function).
A separate section about symbolic and evaluated modes discusses these two options in detail.
To switch the compilation mode locally, use drjit.scoped_set_flag() as shown below:

    with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, False):
        result = dr.switch(..)

When a boolean Dr.Jit array (e.g., drjit.llvm.Bool, drjit.cuda.ad.Bool, etc.) is specified as the last positional argument or as a keyword argument named active, that argument is treated specially: entries of the input arrays associated with a False mask entry are ignored and never passed to the functions. Associated entries of the return value will be zero-initialized. The function will still receive the mask argument as input, but it will always be set to True (see the sketch after the warning below).
Danger
The indices provided to this operation are unchecked by default. Attempting to call functions beyond the end of the targets array is undefined behavior and may crash the application, unless such calls are explicitly disabled via the active parameter. Negative indices are not permitted.
If debug mode is enabled via the drjit.JitFlag.Debug flag, Dr.Jit will insert range checks into the program. These checks disable out-of-bound calls and furthermore report warnings to identify problematic source locations:

    >>> print(dr.switch(UInt32(0, 100), [lambda x:x], UInt32(1)))
    Attempted to invoke callable with index 100, but this↵
    value must be smaller than 1. (<stdin>:2)
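As a sketch of the active-mask behavior described above (a hedged example assuming the drjit.llvm backend; the callables and values are hypothetical):

```python
import drjit as dr
from drjit.llvm import UInt32, Bool

index  = UInt32(0, 1, 0, 1)
active = Bool(True, True, False, False)

res = dr.switch(
    index,
    [lambda x, active: x + 1,     # the callables still receive the mask,
     lambda x, active: x * 10],   # but it is always set to True
    UInt32(1, 2, 3, 4),
    active=active
)
# Inactive entries are zero-initialized: res == [2, 20, 0, 0]
```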
- Parameters:
index (int|drjit.ArrayBase) – a list of indices to choose the functions
targets (Sequence[Callable]) – a list of callables to which calls will be dispatched based on the index argument.
mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides None are: "symbolic", "evaluated". If not specified, the function first checks if the index is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flag drjit.JitFlag.SymbolicCalls and then either performs a symbolic or an evaluated call.
label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
*args (tuple) – a variable-length list of positional arguments passed to the functions. PyTrees are supported.
**kwargs (dict) – a variable-length list of keyword arguments passed to the functions. PyTrees are supported.
- Returns:
When index is a scalar Python integer, the return value simply forwards the return value of the selected function. Otherwise, the function returns a Dr.Jit array or PyTree combining the results from each referenced callable.
- Return type:
object
- drjit.dispatch(inst: drjit.ArrayBase, target: collections.abc.Callable, *args, **kwargs) object ¶
Invoke a provided Python function for each instance in an instance array.
This function invokes the provided target for each instance in the instance array inst and assembles the return values into a result array. Conceptually, it does the following:

    def dispatch(inst, target, *args, **kwargs):
        result = []
        for inst_ in inst:
            result.append(target(inst_, *args, **kwargs))

However, the implementation accomplishes this more efficiently using only a single call per unique instance. Instead of a Python list, it returns a Dr.Jit array or PyTree.
In practice, this function is mainly good for two things:
- Dr.Jit instance arrays contain C++ instances, and these will typically expose a set of methods. Adding further methods requires re-compiling C++ code and adding bindings, which may impede quick prototyping. With drjit.dispatch(), a developer can quickly implement additional vectorized method calls within Python (with the caveat that these can only access public members of the underlying type).
- Dynamic dispatch is a relatively costly operation. When multiple calls are performed on the same set of instances, it may be preferable to merge them into a single and potentially significantly faster use of drjit.dispatch(). An example is shown below:

      inst = # .. Array of C++ instances ..
      result_1 = inst.func_1(arg1)
      result_2 = inst.func_2(arg2)

  The following alternative implementation instead uses drjit.dispatch():

      def my_func(self, arg1, arg2):
          return (self.func_1(arg1), self.func_2(arg2))

      result_1, result_2 = dr.dispatch(inst, my_func, arg1, arg2)

This function is otherwise very similar to drjit.switch() and similarly provides two different compilation modes, differentiability, and special handling of mask arguments. Please review the documentation of drjit.switch() for details.
- Parameters:
inst (drjit.ArrayBase) – a Dr.Jit instance array.
target (Callable) – function to dispatch on all instances
mode (Optional[str]) – Specify this parameter to override the evaluation mode. Possible values besides None are: "symbolic", "evaluated". If not specified, the function first checks if the index is potentially scalar, in which case it uses a trivial fallback implementation. Otherwise, it queries the state of the Jit flag drjit.JitFlag.SymbolicCalls and then either performs a symbolic or an evaluated call.
label (Optional[str]) – An optional descriptive name. If specified, Dr.Jit will include this label in generated low-level IR, which can be helpful when debugging the compilation of large programs.
*args (tuple) – a variable-length list of positional arguments passed to the function. PyTrees are supported.
**kwargs (dict) – a variable-length list of keyword arguments passed to the function. PyTrees are supported.
- Returns:
A Dr.Jit array or PyTree containing the result of each performed function call.
- Return type:
object
Scatter/gather operations¶
- drjit.gather(dtype: type[T], source: object, index: AnyArray | Sequence[int] | int, active: AnyArray | Sequence[bool] | bool = True, mode: drjit.ReduceMode = drjit.ReduceMode.Auto, shape: tuple[int, ...] | None = None) T ¶
Gather values from a flat array or nested data structure.
This function performs a gather (i.e., indirect memory read) from source at position index. It expects a dtype argument and will return an instance of this type. The optional active argument can be used to disable some of the components, which is useful when not all indices are valid; the corresponding output will be zero in this case.
This operation can be used in the following different ways:
- When dtype is a 1D Dr.Jit array like drjit.llvm.ad.Float, this operation implements a parallelized version of the Python array indexing expression source[index] with optional masking. Example:

      source = dr.cuda.Float([...])
      index = dr.cuda.UInt([...]) # Note: negative indices are not permitted
      result = dr.gather(dtype=type(source), source=source, index=index)

- When dtype is a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:
  - When type(source) matches dtype, the gather operation threads through entries and invokes itself recursively. For example, the gather operation in

        source = dr.cuda.Array3f(...)
        index = dr.cuda.UInt([...])
        result = dr.gather(dr.cuda.Array3f, source, index)

    is equivalent to

        result = dr.cuda.Array3f(
            dr.gather(dr.cuda.Float, source.x, index),
            dr.gather(dr.cuda.Float, source.y, index),
            dr.gather(dr.cuda.Float, source.z, index)
        )

    A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.
  - Otherwise, the operation reconstructs the requested dtype from a flat source array, using C-style ordering with a suitably modified index. For example, the gather below reads 3D vectors from a 1D array.

        source = dr.cuda.Float([...])
        index = dr.cuda.UInt([...])
        result = dr.gather(dr.cuda.Array3f, source, index)

    and is equivalent to

        result = dr.cuda.Array3f(
            dr.gather(dr.cuda.Float, source, index*3 + 0),
            dr.gather(dr.cuda.Float, source, index*3 + 1),
            dr.gather(dr.cuda.Float, source, index*3 + 2))

By default, Dr.Jit bundles sequences of gather operations that access a power-of-two number of contiguous elements (e.g. dr.gather(Array4f, ...) or dr.gather(ArrayXf, ..., shape=(16, N))) into one or more packet loads (the precise number depending on the hardware's capabilities) that are potentially significantly faster. This optimization can be controlled via the drjit.JitFlag.PacketOps flag.
Danger
The indices provided to this operation are unchecked by default. Attempting to read beyond the end of the source array is undefined behavior and may crash the application, unless such reads are explicitly disabled via the active parameter. Negative indices are not permitted.
If debug mode is enabled via the drjit.JitFlag.Debug flag, Dr.Jit will insert range checks into the program. These checks disable out-of-bound reads and furthermore report warnings to identify problematic source locations:

    >>> dr.gather(dtype=UInt, source=UInt(1, 2, 3), index=UInt(0, 1, 100))
    drjit.gather(): out-of-bounds read from position 100 in an array↵
    of size 3. (<stdin>:2)
- Parameters:
dtype (type) – The desired output type (typically equal to type(source), but other variations are possible as well, see the description above.)
source (object) – The object from which data should be read (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)
index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g., drjit.scalar.ArrayXu or drjit.cuda.UInt) specifying gather indices. Dr.Jit will attempt an implicit conversion if another type is provided.
active (object) – an optional 1D dynamic Dr.Jit mask array (e.g., drjit.scalar.ArrayXb or drjit.cuda.Bool) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default is True.
mode (drjit.ReduceMode) – The reverse-mode derivative of a gather is an atomic scatter-reduction. The execution of such atomics can be rather performance-sensitive (see the discussion of drjit.ReduceMode for details), hence Dr.Jit offers a few different compilation strategies to realize them. Specifying this parameter selects a strategy for the derivative of a particular gather operation. The default is drjit.ReduceMode.Auto.
shape (tuple[int, ...] | None) – When gathering into a dynamically sized array type (e.g. drjit.cuda.ArrayXf), this parameter can be used to specify the shape of the unknown dimensions. Otherwise, it is not needed. The default is None.
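A small, self-contained sketch tying the pieces together (assuming the drjit.llvm backend):

```python
import drjit as dr
from drjit.llvm import Float, UInt32, Bool

source = Float(10, 20, 30, 40)
index  = UInt32(3, 0, 1)
active = Bool(True, True, False)

res = dr.gather(Float, source, index, active)
# res == [40, 10, 0]; the masked-out entry is zero-initialized
```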
- drjit.scatter(target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None ¶
Scatter values into a flat array or nested data structure.
This operation performs a scatter (i.e., indirect memory write) of the value parameter to the target array at position index. The optional active argument can be used to disable some of the individual write operations, which is useful when not all provided values or indices are valid.
This operation can be used in the following different ways:
- When target is a 1D Dr.Jit array like drjit.llvm.ad.Float, this operation implements a parallelized version of the Python array indexing expression target[index] = value with optional masking. Example:

      target = dr.empty(dr.cuda.Float, 1024*1024)
      value = dr.cuda.Float([...])
      index = dr.cuda.UInt([...]) # Note: negative indices are not permitted
      dr.scatter(target, value=value, index=index)

- When target is a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:
  - When target and value are of the same type, the scatter operation threads through entries and invokes itself recursively. For example, the scatter operation in

        target = dr.cuda.Array3f(...)
        value = dr.cuda.Array3f(...)
        index = dr.cuda.UInt([...])
        dr.scatter(target, value, index)

    is equivalent to

        dr.scatter(target.x, value.x, index)
        dr.scatter(target.y, value.y, index)
        dr.scatter(target.z, value.z, index)

    A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.
  - Otherwise, the operation flattens the value array and writes it using C-style ordering with a suitably modified index. For example, the scatter below writes 3D vectors into a 1D array.

        target = dr.cuda.Float(...)
        value = dr.cuda.Array3f(...)
        index = dr.cuda.UInt([...])
        dr.scatter(target, value, index)

    and is equivalent to

        dr.scatter(target, value.x, index*3 + 0)
        dr.scatter(target, value.y, index*3 + 1)
        dr.scatter(target, value.z, index*3 + 2)

By default, Dr.Jit bundles sequences of scatter operations that write a power-of-two number of contiguous elements into one or more packet stores (the precise number depending on the hardware's capabilities) that are potentially significantly faster. This optimization can be controlled via the drjit.JitFlag.PacketOps flag.
Danger
The indices provided to this operation are unchecked by default. Out-of-bound writes are considered undefined behavior and may crash the application (unless they are disabled via the active parameter). Negative indices are not permitted.
If debug mode is enabled via the drjit.JitFlag.Debug flag, Dr.Jit will insert range checks into the program. These will catch out-of-bound writes and print an error message identifying the responsible line of code.
Dr.Jit makes no guarantees about the expected behavior when a scatter operation has conflicts, i.e., when a specific position is written multiple times by a single drjit.scatter() operation.
- Parameters:
target (object) – The object into which data should be written (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)
value (object) – The values to be written (typically of type type(target), but other variations are possible as well, see the description above.) Dr.Jit will attempt an implicit conversion if the input is not an array type.
index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g., drjit.scalar.ArrayXu or drjit.cuda.UInt) specifying scatter indices. Dr.Jit will attempt an implicit conversion if another type is provided.
active (object) – an optional 1D dynamic Dr.Jit mask array (e.g., drjit.scalar.ArrayXb or drjit.cuda.Bool) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default is True.
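A small, self-contained sketch (assuming the drjit.llvm backend):

```python
import drjit as dr
from drjit.llvm import Float, UInt32

target = dr.zeros(Float, 4)
dr.scatter(target, value=Float(1, 2), index=UInt32(2, 0))
# target == [2, 0, 1, 0]
```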
- drjit.scatter_reduce(op: drjit.ReduceOp, target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None ¶
Atomically update values in a flat array or nested data structure.
This function performs an atomic scatter-reduction, which is a read-modify-write operation that applies one of several possible mathematical functions to selected entries of an array. The following are supported:
- drjit.ReduceOp.Add: a = a + b.
- drjit.ReduceOp.Max: a = max(a, b).
- drjit.ReduceOp.Min: a = min(a, b).
- drjit.ReduceOp.Or: a = a | b (integer arrays only).
- drjit.ReduceOp.And: a = a & b (integer arrays only).
Here, a refers to an entry of target selected by index, and b denotes the associated element of value. The operation resolves potential conflicts arising due to the parallel execution of this operation.
The optional active argument can be used to disable some of the updates, e.g., when not all provided values or indices are valid.
Atomic additions are subject to non-deterministic rounding errors. The reason for this is that IEEE-754 addition is non-associative. The execution order is scheduling-dependent, which can lead to small variations across program runs.
Atomic scatter-reductions can have a significant detrimental impact on performance. When many threads in a parallel computation attempt to modify the same element, this can lead to contention—essentially a fight over which part of the processor owns the associated memory region, which can slow down a computation by many orders of magnitude. Dr.Jit provides several different compilation strategies to reduce these costs, which can be selected via the
mode
parameter. The documentation of drjit.ReduceMode
provides more detail and performance plots.
Backend support
Many combinations of reductions and variable types are not supported. Some combinations depend on the compute capability (CC) of the underlying CUDA device or on the LLVM version (LV) and the host architecture (AMD64, x86_64). The following matrices display the level of support.
For CUDA:
Reduction   Bool   [U]Int{32,64}   Float16           Float32   Float64
Identity    ✅     ✅              ✅                ✅        ✅
Add         ❌     ✅              ⚠️ CC≥60          ✅        ⚠️ CC≥60
Mul         ❌     ❌              ❌                ❌        ❌
Min         ❌     ✅              ⚠️ CC≥90          ❌        ❌
Max         ❌     ✅              ⚠️ CC≥90          ❌        ❌
And         ❌     ✅              ❌                ❌        ❌
Or          ❌     ✅              ❌                ❌        ❌
For LLVM:
Reduction   Bool   [U]Int{32,64}   Float16            Float32    Float64
Identity    ✅     ✅              ✅                 ✅         ✅
Add         ❌     ✅              ⚠️ LV≥16           ✅         ✅
Mul         ❌     ❌              ❌                 ❌         ❌
Min         ❌     ⚠️ LV≥15        ⚠️ LV≥16, ARM64    ⚠️ LV≥15   ⚠️ LV≥15
Max         ❌     ⚠️ LV≥15        ⚠️ LV≥16, ARM64    ⚠️ LV≥15   ⚠️ LV≥15
And         ❌     ✅              ❌                 ❌         ❌
Or          ❌     ✅              ❌                 ❌         ❌
The function raises an exception when the operation is not supported by the backend.
Scatter-reducing nested types
This operation can be used in the following different ways:
When
target
is a 1D Dr.Jit array likedrjit.llvm.ad.Float
, this operation implements a parallelized version of the Python array indexing expression target[index] = op(target[index], value)
with optional masking. Example:
target = dr.zeros(dr.cuda.Float, 1024*1024)
value = dr.cuda.Float([...])
index = dr.cuda.UInt([...])   # Note: negative indices are not permitted
dr.scatter_reduce(dr.ReduceOp.Add, target, value=value, index=index)
When
target
is a more complex type (e.g. a nested Dr.Jit array or PyTree), the behavior depends:
When target and value are of the same type, the scatter-reduction threads through entries and invokes itself recursively. For example, the scatter operation in
op = dr.ReduceOp.Add
target = dr.cuda.Array3f(...)
value = dr.cuda.Array3f(...)
index = dr.cuda.UInt([...])
dr.scatter_reduce(op, target, value, index)
is equivalent to
dr.scatter_reduce(op, target.x, value.x, index)
dr.scatter_reduce(op, target.y, value.y, index)
dr.scatter_reduce(op, target.z, value.z, index)
A similar recursive traversal is used for other kinds of sequences, mappings, and custom data structures.
Otherwise, the operation flattens the
value
array and writes it using C-style ordering with a suitably modified
index
. For example, the scatter-reduction below writes 3D vectors into a 1D array.
op = dr.ReduceOp.Add
target = dr.cuda.Float(...)
value = dr.cuda.Array3f(...)
index = dr.cuda.UInt([...])
dr.scatter_reduce(op, target, value, index)
and is equivalent to
dr.scatter_reduce(op, target, value.x, index*3 + 0)
dr.scatter_reduce(op, target, value.y, index*3 + 1)
dr.scatter_reduce(op, target, value.z, index*3 + 2)
Dr.Jit may be able to bundle sequences of scatter-reductions that write a power-of-two number of contiguous elements into one or more packet scatter-updates that are potentially significantly faster. This optimization can be controlled via the
drjit.JitFlag.PacketOps
flag. Currently, this is only possible in a narrow set of circumstances requiring:
1. use of the LLVM backend,
2. integer or floating point scatter-additions, and
3. use of the
drjit.ReduceMode.Expand
reduction strategy.
Danger
The indices provided to this operation are unchecked by default. Out-of-bound writes are considered undefined behavior and may crash the application (unless they are disabled via the
active
parameter). Negative indices are not permitted. If debug mode is enabled via the
drjit.JitFlag.Debug
flag, Dr.Jit will insert range checks into the program. These will catch out-of-bound writes and print an error message identifying the responsible line of code. Dr.Jit makes no guarantees about the relative ordering of atomic operations when a
drjit.scatter_reduce()
writes to the same element multiple times. Combined with the non-associative nature of floating point operations, concurrent writes will generally introduce non-deterministic rounding error.
op (drjit.ReduceOp) – Specifies the type of update that should be performed.
target (object) – The object into which data should be written (typically a 1D Dr.Jit array, but other variations are possible as well, see the description above.)
value (object) – The values to be used in the RMW operation (typically of type
type(target)
, but other variations are possible as well, see the description above.) Dr.Jit will attempt an implicit conversion if the input is not an array type.
index (object) – a 1D dynamic unsigned 32-bit Dr.Jit array (e.g.,
drjit.scalar.ArrayXu
or drjit.cuda.UInt
) specifying gather indices. Dr.Jit will attempt an implicit conversion if another type is provided.
active (object) – an optional 1D dynamic Dr.Jit mask array (e.g.,
drjit.scalar.ArrayXb
or drjit.cuda.Bool
) specifying active components. Dr.Jit will attempt an implicit conversion if another type is provided. The default is True.
mode (drjit.ReduceMode) – Dr.Jit offers several different strategies to implement atomic scatter-reductions that can be selected via this parameter. They achieve different best/worst case performance and, in the case of
drjit.ReduceMode.Expand
, involve additional memory storage overheads. The default is drjit.ReduceMode.Auto
.
- drjit.scatter_add(target: object, value: object, index: object, active: object = True, mode: drjit.ReduceMode = ReduceMode.Auto) None ¶
Atomically add values to a flat array or nested data structure.
This function is equivalent to
drjit.scatter_reduce(drjit.ReduceOp.Add, ...)
and exists for reasons of convenience. Please refer todrjit.scatter_reduce()
for details on atomic scatter-reductions.
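For instance, the two calls below are interchangeable (a small sketch with hypothetical values, using the drjit.cuda backend):
import drjit as dr
from drjit.cuda import Float, UInt

target = dr.zeros(Float, 4)
value = Float(1, 2, 3)
index = UInt(0, 1, 1)

dr.scatter_add(target, value, index)                       # target: [1, 5, 0, 0]
dr.scatter_reduce(dr.ReduceOp.Add, target, value, index)   # adds again: [2, 10, 0, 0]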
- drjit.scatter_add_kahan(target_1: drjit.ArrayBase, target_2: drjit.ArrayBase, value: object, index: object, active: object = True) None ¶
Perform a Kahan-compensated atomic scatter-addition.
Atomic floating point accumulation can incur significant rounding error when many values contribute to a single element. This function implements an error-compensating version of
drjit.scatter_add()
based on the Kahan-Babuška-Neumaier algorithm that simultaneously accumulates into two target buffers.
In particular, the operation first accumulates values into entries of a dynamic 1D array
target_1
. It tracks the round-off error caused by this operation and then accumulates this error into a second 1D array namedtarget_2
. At the end, the buffers can simply be added together to obtain the error-compensated result.
This function has a few limitations: in contrast to
drjit.scatter_reduce()
and drjit.scatter_add()
, it does not perform a local reduction (see flag
JitFlag.ScatterReduceLocal
), which can be an important optimization when atomic accumulation is a performance bottleneck.
Furthermore, the function currently works with flat 1D arrays. This is an implementation limitation that could in principle be removed in the future.
Finally, the function is differentiable, but derivatives currently only propagate into
target_1
. This means that forward derivatives don’t enjoy the error compensation of the primal computation. This limitation is of no relevance for backward derivatives.
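A minimal sketch of the two-buffer usage pattern (hypothetical values, drjit.llvm backend):
import drjit as dr
from drjit.llvm import Float, UInt

target_1 = dr.zeros(Float, 16)    # primary accumulation buffer
target_2 = dr.zeros(Float, 16)    # buffer receiving the accumulated round-off error
value = Float(0.1, 0.2, 0.3, 0.4)
index = UInt(0, 0, 0, 0)          # all values accumulate into entry 0

dr.scatter_add_kahan(target_1, target_2, value, index)

# Adding both buffers yields the error-compensated result
result = target_1 + target_2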
- drjit.scatter_inc(target: drjit.ArrayBase, index: object, active: object = True) object ¶
Atomically increment a value within an unsigned 32-bit integer array and return the value prior to the update.
This operation works just like the
drjit.scatter_reduce()
operation for 32-bit unsigned integer operands, but with a fixed value=1
parameter and op=ReduceOp.Add
. The main difference is that this variant additionally returns the old value of the target array prior to the atomic update in contrast to the more general scatter-reduction, which just returns
None
. The operation also supports masking—the return value in the masked case is undefined. Both
target
and
index
parameters must be 1D unsigned 32-bit arrays.
This operation is a building block for stream compaction: threads can scatter-increment a global counter to request a spot in an array and then write their result there. The recipe for this looks as follows:
data_1 = ...
data_2 = ...
active = dr.ones(Bool, len(data_1))  # .. or a more complex condition

# This will hold the counter
ctr = UInt32(0)

# Allocate output buffers
max_size = 1024
data_compact_1 = dr.empty(Float, max_size)
data_compact_2 = dr.empty(Float, max_size)

my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)

# Disable dr.scatter() operations below in case of a buffer overflow
active &= my_index < max_size

dr.scatter(
    target=data_compact_1,
    value=data_1,
    index=my_index,
    active=active
)

dr.scatter(
    target=data_compact_2,
    value=data_2,
    index=my_index,
    active=active
)
When following this approach, be sure to provide the same mask value to the
drjit.scatter_inc()
and subsequentdrjit.scatter()
operations.The function
drjit.reshape()
can be used to resize the resulting arrays to their compacted size. Please refer to the documentation of this function, specifically the code example illustrating the use of theshrink=True
argument.The function
drjit.scatter_inc()
exhibits the following unusual behavior compared to regular Dr.Jit operations: the return value references the instantaneous state during a potentially large sequence of atomic operations. This instantaneous state is not reproducible in later kernel evaluations, and Dr.Jit will refuse to do so when the computed index is reused. In essence, the variable is “consumed” by the process of evaluation.
my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)

dr.scatter(
    target=data_compact_1,
    value=data_1,
    index=my_index,
    active=active
)

dr.eval(data_compact_1)   # Run Kernel #1

dr.scatter(
    target=data_compact_2,
    value=data_2,
    index=my_index,  # <-- oops, reusing my_index in another kernel.
    active=active    #     This raises an exception.
)
To get the above code to work, you will need to evaluate
my_index
at the same time to materialize it into a stored (and therefore trivially reproducible) representation. For this, ensure that the size of theactive
mask matcheslen(data_*)
and that it is not the trivialTrue
default mask (otherwise, the evaluatedmy_index
will be scalar).dr.eval(data_compact_1, my_index)
Such multi-stage evaluation is potentially inefficient and may defeat the purpose of performing stream compaction in the first place. In general, prefer keeping all scatter operations involving the computed index in the same kernel, and then this issue does not arise.
The implementation of
drjit.scatter_inc()
performs a local reduction first, followed by a single atomic write per SIMD packet/warp. This is done to reduce contention from a potentially very large number of atomic operations targeting the same memory address. Fully masked updates do not cause memory traffic.There is some conceptual overlap between this function and
drjit.compress()
, which can likewise be used to reduce a stream to a smaller subset of active items. The downside ofdrjit.compress()
is that it requires evaluating the variables to be reduced, which can be very costly in terms of memory traffic and storage footprint. Reducing throughdrjit.scatter_inc()
does not have this limitation: it can operate on symbolic arrays that greatly exceed the available device memory. One advantage ofdrjit.compress()
is that it essentially boils down to a relatively simple prefix sum, which does not require atomic memory operations (these can be slow in some cases).
- drjit.slice(value: object, index: object = 0) object ¶
Select a subset of the input array or PyTree along the trailing dynamic dimension.
Given a Dr.Jit array
value
with shape(..., N)
(whereN
represents a dynamically sized dimension), this operation effectively evaluates the expression value[..., index]
. It recursively traverses PyTrees and transforms each compatible array element. Other values are returned unchanged.
The following properties of
index
determine the return type:When
index
is a 1D integer array, the operation reduces to one or more calls to drjit.gather()
, and slice()
returns a reduced output object of the same type and structure.
When
index
is a scalar Python int
, the trailing dimension is entirely removed, and the operation returns an array from the drjit.scalar
namespace containing the extracted values.
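A short sketch of both cases (hypothetical values, drjit.cuda backend):
import drjit as dr
from drjit.cuda import Array3f, UInt

v = Array3f([1, 2], [3, 4], [5, 6])    # shape (3, 2)

# Scalar index: the trailing dimension is removed; the result is a
# drjit.scalar.Array3f holding the entries at position 1
s = dr.slice(v, 1)                     # -> [2, 4, 6]

# 1D integer index array: reduces to gathers; the result keeps the input's type
g = dr.slice(v, UInt(0, 1, 1))         # -> drjit.cuda.Array3f with 3 entries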
Reductions¶
- enum drjit.ReduceOp(value)¶
List of different atomic read-modify-write (RMW) operations supported by
drjit.scatter_reduce()
.Valid values are as follows:
- Identity = ReduceOp.Identity¶
Perform an ordinary scatter operation that ignores the current entry.
- Add = ReduceOp.Add¶
Addition.
- Mul = ReduceOp.Mul¶
Multiplication.
- Min = ReduceOp.Min¶
Minimum.
- Max = ReduceOp.Max¶
Maximum.
- And = ReduceOp.And¶
Binary AND operation.
- Or = ReduceOp.Or¶
Binary OR operation.
- enum drjit.ReduceMode(value)¶
Compilation strategy for atomic scatter-reductions.
Elements of this enumeration determine how Dr.Jit executes atomic scatter-reductions, which refers to indirect writes that update an existing element in an array, while avoiding problems arising due to concurrency.
Atomic scatter-reductions can have a significant detrimental impact on performance. When many threads in a parallel computation attempt to modify the same element, this can lead to contention—essentially a fight over which part of the processor owns the associated memory region, which can slow down a computation by many orders of magnitude.
This parameter also plays an important role for
drjit.gather()
, which is nominally a read-only operation. This is because the reverse-mode derivative of a gather turns it into an atomic scatter-addition, where further context on how to compile the operation is needed.
Dr.Jit implements several strategies to address contention, which can be selected by passing the optional
mode
parameter to drjit.scatter_reduce()
, drjit.scatter_add()
, and drjit.gather()
.
If you find that a part of your program is bottlenecked by atomic writes, then it may be worth explicitly specifying some of the strategies below to see which one performs best.
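For instance, a heavily contended scatter-addition could explicitly request a strategy as follows (a sketch with hypothetical values, LLVM backend):
import drjit as dr
from drjit.llvm import Float, UInt

target = dr.zeros(Float, 4)             # only four distinct target entries
value = dr.ones(Float, 1000000)
index = dr.arange(UInt, 1000000) % 4    # -> heavy write contention

# Explicitly select the 'Expand' strategy (only honored on the LLVM backend)
dr.scatter_add(target, value, index, mode=dr.ReduceMode.Expand)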
Valid values are as follows:
- Auto = ReduceMode.Auto¶
Select the first valid option from the following list:
use
drjit.ReduceMode.Expand
if the computation uses the LLVM backend and the target array storage size is smaller than or equal to the value given by drjit.expand_threshold()
. This threshold can be changed using the drjit.set_expand_threshold()
function.
use
drjit.ReduceMode.Local
if drjit.JitFlag.ScatterReduceLocal
is set.
fall back to
drjit.ReduceMode.Direct
.
- Direct = ReduceMode.Direct¶
Insert an ordinary atomic reduction operation into the program.
This mode is ideal when no or little contention is expected, for example because the target indices of scatters are well spread throughout the target array. This mode generates a minimal amount of code, which can help improve performance especially on GPU backends.
- Local = ReduceMode.Local¶
Locally pre-reduce operands.
In this mode, Dr.Jit adds extra code to the compiled program to examine the target indices of atomic updates. For example, CUDA programs run with an instruction granularity referred to as a warp, which is a group of 32 threads. When some of these threads want to write to the same location, then those operands can be pre-processed to reduce the total number of necessary atomic memory transactions (potentially to just a single one!)
On the CPU/LLVM backend, the same process works at the granularity of packets. The details depend on the underlying instruction set—for example, there are 16 threads per packet on a machine with AVX512, so there is a potential for reducing atomic write traffic by that factor.
- NoConflicts = ReduceMode.NoConflicts¶
Perform a non-atomic read-modify-write operation.
This mode is only safe in specific situations. The caller must guarantee that there are no conflicts (i.e., scatters targeting the same elements). If specified, Dr.Jit will generate a non-atomic read-modify-update operation that potentially runs significantly faster, especially on the LLVM backend.
- Permute = ReduceMode.Permute¶
In contrast to prior enumeration entries, this one modifies plain (non-reductive) scatters and gathers. It exists to enable internal optimizations that Dr.Jit uses when differentiating vectorized function calls and compressed loops. You likely should not use it in your own code.
When setting this mode, the caller guarantees that there will be no conflicts, and that every entry is written exactly once using an index vector representing a permutation (it’s fine if this permutation is accomplished by multiple separate write operations, but there should be no more than 1 write to each element).
Giving ‘Permute’ as an argument to a (nominally read-only) gather operation is helpful because we then know that the reverse-mode derivative of this operation can be a plain scatter instead of a more costly atomic scatter-add.
Giving ‘Permute’ as an argument to a scatter operation is helpful because we then know that the forward-mode derivative does not depend on any prior derivative values associated with that array, as all current entries will be overwritten.
- Expand = ReduceMode.Expand¶
Expand the target array to avoid write conflicts, then scatter non-atomically.
This feature is only supported on the LLVM backend. Other backends interpret this flag as if
drjit.ReduceMode.Auto
had been specified.
This mode internally expands the storage underlying the target array to a much larger size that is proportional to the number of CPU cores. Scalar (length-1) target arrays are expanded even further to ensure that each CPU gets an entirely separate cache line.
Following this one-time expansion step, the array can then accommodate an arbitrary sequence of scatter-reduction operations that the system will internally perform using non-atomic read-modify-write operations (i.e., analogous to the
NoConflicts
mode). Dr.Jit automatically re-compresses the array into the ordinary representation.
On bigger arrays and on machines with many cores, the storage costs resulting from this mode can be prohibitive.
- drjit.reduce(op: ReduceOp, value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None) object ¶
Reduce the input array, tensor, or iterable along the specified axis/axes.
This function reduces arrays, tensors and other iterable Python types along one or multiple axes, where
op
selects the operation to be performed:drjit.ReduceOp.Add
:a[0] + a[1] + ...
.drjit.ReduceOp.Mul
:a[0] * a[1] * ...
.drjit.ReduceOp.Min
:min(a[0], a[1], ...)
.drjit.ReduceOp.Max
:max(a[0], a[1], ...)
.drjit.ReduceOp.Or
:a[0] | a[1] | ...
(integer arrays only).drjit.ReduceOp.And
:a[0] & a[1] & ...
(integer arrays only).
The functions
drjit.sum()
,drjit.prod()
,drjit.min()
, anddrjit.max()
are convenience aliases that calldrjit.reduce()
with specific values ofop
.By default, the reduction for tensor types is performed over all axes (
axis=None
), while for all other types, the default isaxis=0
(i.e., the outermost one), returning an instance of the array’s element type. For instance, sum-reducing an arraya
of typedrjit.cuda.Array3f
is equivalent to writinga[0] + a[1] + a[2]
and produces a result of typedrjit.cuda.Float
. Dr.Jit can trace this operation and include it in the generated kernel.Negative indices (e.g.
axis=-1
) count backward from the innermost axis. Multiple axes can be specified as a tuple. The valueaxis=None
requests a simultaneous reduction over all axes.When reducing axes of a tensor, or when reducing the trailing dimension of a Jit-compiled array, some special precautions apply: these axes correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. This can be done using the following strategies:
mode="evaluated"
first evaluates the input array viadrjit.eval()
and then launches a specialized reduction kernel.On the CUDA backend, this kernel makes efficient use of shared memory and cooperative warp instructions. The LLVM backend parallelizes the reduction via the built-in thread pool.
mode="symbolic"
usesdrjit.scatter_reduce()
to atomically scatter-reduce values into the output array. This strategy can be advantageous when the input is symbolic (making evaluation impossible) or both unevaluated and extremely large (making evaluation costly or impossible if there isn’t enough memory).Disadvantages of this mode are that
Atomic scatters can suffer from memory contention (though the
drjit.scatter_reduce()
function takes steps to reduce contention, see its documentation for details).
Atomic floating point scatter-addition is subject to non-deterministic rounding errors that arise from its non-associative nature. Coupled with the scheduling-dependent execution order, this can lead to small variations across program runs. Integer reductions and floating point min/max reductions are unaffected by this.
mode=None
(default) automatically picks a reasonable strategy according to the following logic:Use evaluated mode when the input array is already evaluated, or when evaluating it would consume less than 1 GiB of memory.
Use evaluated mode when the necessary atomic reduction operation is not supported by the backend.
Otherwise, use symbolic mode.
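The following sketch illustrates reductions along different axes (hypothetical values, drjit.cuda backend):
import drjit as dr
from drjit.cuda import Array3f, Float, TensorXf

a = Array3f([1, 2], [3, 4], [5, 6])

# Default for non-tensor types: reduce the outermost axis,
# i.e., compute a[0] + a[1] + a[2]
dr.reduce(dr.ReduceOp.Add, a)             # -> drjit.cuda.Float([9, 12])

# Tensors reduce over all axes by default
t = TensorXf(dr.arange(Float, 6), shape=(2, 3))
dr.reduce(dr.ReduceOp.Add, t)             # -> 15 (sum of all entries)
dr.reduce(dr.ReduceOp.Add, t, axis=1)     # -> per-row sums: [3, 12]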
This function generally strips away reduced axes, but there is one notable exception: it will never remove a trailing dynamic dimension, if present in the input array.
For example, reducing an instance of type
drjit.cuda.Float
along axis0
does not produce a scalar Pythonfloat
. Instead, the operation returns another array of the same type with a single element. This is intentional–unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]
) to obtain a value with the underlying element type.- Parameters:
op (ReduceOp) – The operation that should be applied along the reduced axis/axes.
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array or tensor.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument
axis=None
causes a simultaneous reduction over all axes. The defaultaxis=...
applies a reduction over all axes for tensor types and index0
otherwise.mode (str | None) – optional parameter to force an evaluation strategy. Must equal
"evaluated"
,"symbolic"
, orNone
.
- Returns:
The reduced array or tensor as specified above.
- drjit.sum(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None) object ¶
Sum-reduce the input array, tensor, or iterable along the specified axis/axes.
This function sum-reduces arrays, tensors and other iterable Python types along one or multiple axes. It is equivalent to
dr.reduce(dr.ReduceOp.Add, ...)
. See the documentation of this function for further information.- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument
axis=None
causes a simultaneous reduction over all axes. The defaultaxis=...
applies a reduction over all axes for tensor types and index0
otherwise.mode (str | None) – optional parameter to force an evaluation strategy. Must equal
"evaluated"
,"symbolic"
, orNone
.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.prod(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None) object ¶
Multiplicatively reduce the input array, tensor, or iterable along the specified axis/axes.
This function performs a multiplicative reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to
dr.reduce(dr.ReduceOp.Mul, ...)
. See the documentation of this function for further information.- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument
axis=None
causes a simultaneous reduction over all axes. The defaultaxis=...
applies a reduction over all axes for tensor types and index0
otherwise.mode (str | None) – optional parameter to force an evaluation strategy. Must equal
"evaluated"
,"symbolic"
, orNone
.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.min(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None) object ¶
Perform a minimum reduction of the input array, tensor, or iterable along the specified axis/axes.
(Not to be confused with
drjit.minimum()
, which computes the smaller of two values).This function performs a minimum reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to
dr.reduce(dr.ReduceOp.Min, ...)
. See the documentation of this function for further information.- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument
axis=None
causes a simultaneous reduction over all axes. The defaultaxis=...
applies a reduction over all axes for tensor types and index0
otherwise.mode (str | None) – optional parameter to force an evaluation strategy. Must equal
"evaluated"
,"symbolic"
, orNone
.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.max(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None) object ¶
Perform a maximum reduction of the input array, tensor, or iterable along the specified axis/axes.
(Not to be confused with
drjit.maximum()
, which computes the larger of two values).This function performs a maximum reduction along one or multiple axes of the provided Dr.Jit array, tensor, or iterable Python types. It is equivalent to
dr.reduce(dr.ReduceOp.Max, ...)
. See the documentation of this function for further information.- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument
axis=None
causes a simultaneous reduction over all axes. The defaultaxis=...
applies a reduction over all axes for tensor types and index0
otherwise.mode (str | None) – optional parameter to force an evaluation strategy. Must equal
"evaluated"
,"symbolic"
, orNone
.
- Returns:
The reduced array or tensor as specified above.
- drjit.mean(value: object, axis: int | tuple[int, ...] | ... | None = ..., mode: str | None = None) object ¶
Compute the mean of the input array or tensor along one or multiple axes.
This function performs a horizontal sum reduction by adding values of the input array, tensor, or Python sequence along one or multiple axes and then dividing by the number of entries. The mean of an empty array is considered to be zero.
See the discussion of
dr.reduce()
for important general information about the properties of horizontal reductions.- Parameters:
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument
axis=None
causes a simultaneous reduction over all axes. The defaultaxis=...
applies a reduction over all axes for tensor types and index0
otherwise.mode (str | None) – optional parameter to force an evaluation strategy. Must equal
"evaluated"
,"symbolic"
, orNone
.
- Returns:
The reduced array or tensor as specified above.
- drjit.all(value: object, axis: int | tuple[int, ...] | ... | None = ...) object ¶
Check if all elements along the specified axis are active.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the
&
(AND) operator.Reductions along index
0
refer to the outermost axis and negative indices (e.g.-1
) count backwards from the innermost axis. The special argumentaxis=None
causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to beTrue
.The function is internally based on
dr.reduce()
. See the documentation of this function for further information.Like
dr.reduce()
, this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducingdrjit.cuda.Bool
does not produce a scalar Pythonbool
. Instead, the operation returns another array of the same type with a single element. This is intentional–unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]
) to obtain a value with the underlying element type.Boolean 1D arrays automatically convert to
bool
if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:
from drjit.cuda import Float

x = Float(...)

if dr.all(x < 0):
    # ...
A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves
drjit.eval()
to evaluate and store the array in memory and then launch a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.To avoid Boolean reductions, one can often use symbolic operations such as
if_stmt()
,while_loop()
, etc. The@dr.syntax
decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...
) based on the condition.
@dr.syntax
def f(x: Float):
    if x < 0:
        # ...
- Parameters:
value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument
axis=None
causes a simultaneous reduction over all axes. The defaultaxis=...
applies a reduction over all axes for tensor types and index0
otherwise.
- Returns:
The reduced array or tensor as specified above.
- Return type:
object
- drjit.any(value: object, axis: int | tuple[int, ...] | ... | None = ...) object ¶
Check if any elements along the specified axis are active.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the
|
(OR) operator.Reductions along index
0
refer to the outermost axis and negative indices (e.g.-1
) count backwards from the innermost axis. The special argumentaxis=None
causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to beFalse
.The function is internally based on
dr.reduce()
. See the documentation of this function for further information.Like
dr.reduce()
, this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducingdrjit.cuda.Bool
does not produce a scalar Pythonbool
. Instead, the operation returns another array of the same type with a single element. This is intentional–unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]
) to obtain a value with the underlying element type.Boolean 1D arrays automatically convert to
bool
if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:
from drjit.cuda import Float

x = Float(...)

if dr.any(x < 0):
    # ...
A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves
drjit.eval()
to evaluate and store the array in memory and then launch a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.To avoid Boolean reductions, one can often use symbolic operations such as
if_stmt()
,while_loop()
, etc. The@dr.syntax
decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...
) based on the condition.
@dr.syntax
def f(x: Float):
    if x < 0:
        # ...
- Parameters:
value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument
axis=None
causes a simultaneous reduction over all axes. The defaultaxis=...
applies a reduction over all axes for tensor types and index0
otherwise.
- Returns:
Result of the reduction operation
- Return type:
bool | drjit.ArrayBase
- drjit.none(value: object, axis: int | tuple[int, ...] | ... | None = ...) object ¶
Check if none of the elements along the specified axis are active.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the
|
(OR) operator and finally returns the bit-wise inverse of the result.The function is internally based on
dr.reduce()
. See the documentation of this function for further information.Like
dr.reduce()
, this function does not strip away trailing dynamic dimensions if present in the input array. This means that reducingdrjit.cuda.Bool
does not produce a scalar Pythonbool
. Instead, the operation returns another array of the same type with a single element. This is intentional–unboxing the array into a Python scalar would require transferring the value from the GPU, which would incur costly synchronization overheads. You must explicitly index into the result (result[0]
) to obtain a value with the underlying element type.Boolean 1D arrays automatically convert to
bool
if they only contain a single element. This means that the aforementioned indexing operation happens implicitly in the following fragment:
from drjit.cuda import Float

x = Float(...)

if dr.none(x < 0):
    # ...
A last point to consider is that reductions along the last / trailing dynamic axis of an array are generally expensive. Its entries correspond to computational threads of a large parallel program that now have to coordinate to determine the reduced value. Normally, this involves
drjit.eval()
to evaluate and store the array in memory and then launch a device-specific reduction kernel. All of these steps interfere with Dr.Jit’s regular mode of operation, which is to capture a maximally large program without intermediate evaluation.To avoid Boolean reductions, one can often use symbolic operations such as
if_stmt()
,while_loop()
, etc. The@dr.syntax
decorator can generate these automatically. For example, the following fragment predicates the execution of the body (# ...
) based on the condition.
@dr.syntax
def f(x: Float):
    if x < 0:
        # ...
- Parameters:
value (ArrayBase | Iterable | bool) – An input Dr.Jit array, tensor, iterable, or scalar Python type.
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument
axis=None
causes a simultaneous reduction over all axes. The defaultaxis=...
applies a reduction over all axes for tensor types and index0
otherwise.
- Returns:
Result of the reduction operation
- Return type:
bool | drjit.ArrayBase
- drjit.count(value: object, axis: int | tuple[int, ...] | ... | None = ...) object ¶
Compute the number of active entries along the given axis.
Given a boolean-valued input array, tensor, or Python sequence, this function reduces elements using the
+
operator (interpretingTrue
elements as1
andFalse
elements as0
). It returns an unsigned 32-bit version of the input array.Reductions along index
0
refer to the outermost axis and negative indices (e.g.-1
) count backwards from the innermost axis. The special argumentaxis=None
causes a simultaneous reduction over all axes. Note that the reduced form of an empty array is considered to be zero.See the discussion of
dr.reduce()
for important general information about the properties of horizontal operations.- Parameters:
value (bool | Sequence | drjit.ArrayBase) – A Python or Dr.Jit mask type
axis (int | tuple[int, ...] | ... | None) – The axis/axes along which to reduce. The special argument
axis=None
causes a simultaneous reduction over all axes. The defaultaxis=...
applies a reduction over all axes for tensor types and index0
otherwise.
- Returns:
Result of the reduction operation
- Return type:
int | drjit.ArrayBase
- drjit.dot(arg0: object, arg1: object, /) object ¶
Compute the dot product of two arrays.
Whenever possible, the implementation uses a sequence of
fma()
(fused multiply-add) operations to evaluate the dot product. When the input is a 1D JIT array likedrjit.cuda.Float
, the function evaluates the product of the input arrays viadrjit.eval()
and then performs a sum reduction viadrjit.sum()
.See the discussion of
dr.reduce()
for important general information about the properties of horizontal reductions.- Parameters:
arg0 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Dot product of inputs
- Return type:
float | int | drjit.ArrayBase
- drjit.abs_dot(arg0: object, arg1: object, /) object ¶
Compute the absolute value of the dot product of two arrays.
This function implements a convenience short-hand for
abs(dot(arg0, arg1))
.See the discussion of
dr.reduce()
for important general information about the properties of horizontal reductions.- Parameters:
arg0 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (list | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Absolute value of the dot product of inputs
- Return type:
float | int | drjit.ArrayBase
- drjit.squared_norm(arg: object, /) object ¶
Computes the squared 2-norm of a Dr.Jit array, tensor, or Python sequence.
The operation is equivalent to
dr.dot(arg, arg)
The
squared_norm()
operation performs a horizontal reduction. See the discussion ofdr.reduce()
for important general information about their properties.- Parameters:
arg (Sequence | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
squared 2-norm of the input
- Return type:
float | int | drjit.ArrayBase
- drjit.norm(arg: object, /) object ¶
Computes the 2-norm of a Dr.Jit array, tensor, or Python sequence.
The operation is equivalent to
dr.sqrt(dr.dot(arg, arg))
The
norm()
operation performs a horizontal reduction. See the discussion ofdr.reduce()
for important general information about their properties.- Parameters:
arg (Sequence | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
2-norm of the input
- Return type:
float | int | drjit.ArrayBase
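For instance (a small sketch with hypothetical values, drjit.cuda backend):
import drjit as dr
from drjit.cuda import Array3f

v = Array3f(1, 2, 2)
dr.dot(v, v)           # -> 9
dr.squared_norm(v)     # -> 9
dr.norm(v)             # -> 3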
Prefix reductions¶
- drjit.prefix_reduce(op: ReduceOp, value: T, axis: int | tuple[int, ...] = 0, exclusive: bool = True, reverse: bool = False) T ¶
Compute an exclusive or inclusive prefix reduction of the input array, tensor, or iterable along the specified axis/axes.
The function returns an output array of the same shape as the input. The
op
parameter selects the operation to be performed.
For example, when reducing a 1D array using
exclusive=True
(the default), this produces the following outputdrjit.ReduceOp.Add
:[0, a[0], a[0] + a[1], ...]
.drjit.ReduceOp.Mul
:[1, a[0], a[0] * a[1], ...]
.drjit.ReduceOp.Min
:[inf, a[0], min(a[0], a[1]), ...]
.drjit.ReduceOp.Max
:[-inf, a[0], max(a[0], a[1]), ...]
.drjit.ReduceOp.Or
:[0, a[0], a[0] | a[1], ...]
(integer arrays only).drjit.ReduceOp.And
:[-1, a[0], a[0] & a[1], ...]
(integer arrays only).
With
exclusive=False
, the function instead performs an inclusive prefix reduction, which effectively shifts the output by one entry:drjit.ReduceOp.Add
:[a[0], a[0] + a[1], ...]
.drjit.ReduceOp.Mul
:[a[0], a[0] * a[1], ...]
.drjit.ReduceOp.Min
:[a[0], min(a[0], a[1]), ...]
.drjit.ReduceOp.Max
:[a[0], max(a[0], a[1]), ...]
.drjit.ReduceOp.Or
:[a[0], a[0] | a[1], ...]
(integer arrays only).drjit.ReduceOp.And
:[a[0], a[0] & a[1], ...]
(integer arrays only).
By default, the reduction is along axis
0
(i.e., the outermost one). Negative indices (e.g.axis=-1
) count backward from the innermost axis. Multiple axes can be specified as a tuple and are handled iteratively.- Parameters:
op (ReduceOp) – The operation that should be applied along the specified axis/axes.
value (ArrayBase | Iterable | float | int) – An input Dr.Jit array or tensor.
axis (int | tuple[int, ...]) – The axis/axes along which to reduce. The default value is
0
.exclusive (bool) – Whether to perform an exclusive (the default) or inclusive prefix reduction.
reverse (bool) – if set to
True
, the prefix reduction is done from the end of the selected axis.
- Returns:
The prefix-reduced array or tensor as specified above. It has the same shape and type as the input.
- drjit.prefix_sum(value: T, axis: int | tuple[int, ...] = 0, reverse: bool = False) T ¶
Compute an exclusive prefix sum of the input array.
This function is a convenience wrapper that internally calls
drjit.prefix_reduce(dr.ReduceOp.Add, arg, exclusive=True, ...)
. Please refer to this function for further detail.
- drjit.cumsum(value: T, axis: int | tuple[int, ...] = 0, reverse: bool = False) T ¶
Compute a cumulative sum (a.k.a. inclusive prefix sum) of the input array.
This function is a convenience wrapper that internally calls
drjit.prefix_reduce(dr.ReduceOp.Add, arg, exclusive=False, ...)
. Please refer to this function for further detail.
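For example (a sketch with hypothetical values, drjit.cuda backend):
import drjit as dr
from drjit.cuda import Float

x = Float(1, 2, 3, 4)
dr.prefix_sum(x)   # exclusive prefix sum: [0, 1, 3, 6]
dr.cumsum(x)       # inclusive prefix sum: [1, 3, 6, 10]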
Block reductions¶
- drjit.block_reduce(op: ReduceOp, value: T, block_size: int, mode: Literal['evaluated', 'symbolic', None] = None) T ¶
Reduce elements within blocks.
This function reduces all elements within contiguous blocks of size
block_size
along the trailing dimension of the input arrayvalue
, returning a correspondingly smaller output array. Various types of reductions are supported (seedrjit.ReduceOp
for details).For example, a sum reduction of a hypothetical array
[a, b, c, d, e, f]
withblock_size=2
produces the output[a+b, c+d, e+f]
.The function raises an exception when the length of the trailing dimension is not a multiple of the block size. It recursively threads through nested arrays and PyTrees.
Dr.Jit uses one of two strategies to realize this operation, which can be optionally forced by specifying the
mode
parameter.mode="evaluated"
first evaluates the input array viadrjit.eval()
and then launches a specialized reduction kernel.On the CUDA backend, this kernel makes efficient use of shared memory and cooperative warp instructions. The LLVM backend parallelizes the operation via the built-in thread pool.
This strategy uses an increased intermediate precision (single precision) when reducing half precision arrays.
mode="symbolic"
usesdrjit.scatter_reduce()
to atomically scatter-reduce values into the output array. This strategy can be advantageous when the input array is symbolic (making evaluation impossible) or both unevaluated and extremely large (making evaluation costly or impossible if there isn’t enough memory).Disadvantages of this mode are that
Atomic scatters can suffer from memory contention (though
drjit.scatter_reduce()
takes steps to reduce contention, see its documentation for details).
Atomic floating point scatter-addition is subject to non-deterministic rounding errors that arise from its non-associative nature. Coupled with the scheduling-dependent execution order, this can lead to small variations across program runs. Integer and floating point min/max reductions are unaffected by this.
mode=None
(default) automatically picks a reasonable strategy according to the following logic:Symbolic mode is admissible when the necessary atomic reduction is supported by the backend.
Evaluated mode is admissible when the input does not involve symbolic variables.
If only one strategy remains, then pick that one. Raise an exception when no strategy works out.
Otherwise, use evaluated mode when the input array is already evaluated, or when evaluating it would consume less than 1 GiB of memory.
Use symbolic mode in all other cases.
On the CUDA backend, you may observe speedups if you design your program so that it invokes
drjit.block_reduce()
with power-of-two block sizes, as the underlying kernel optimizes this case.Note
This operation traverses PyTrees and transforms any dynamically sized Dr.Jit arrays it encounters. Everything else is left as-is.
Tensors are not supported and will cause an exception to be raised. While
drjit.block_reduce()
is internally used by Dr.Jit when reducing tensors, the function on its own does not support tensor-valued inputs.To reduce blocks within a tensor, reshape it and call the regular tensor-compatible reduction operations (e.g.,
drjit.sum()
,drjit.prod()
,drjit.min()
,drjit.max()
, or the generaldrjit.reduce()
).For example, to sum-reduce a
(16, 16)
tensor by a factor of(4, 2)
(i.e., to a(4, 8)
-sized tensor), writeresult = dr.sum(dr.reshape(value, shape=(4, 4, 8, 2)), axis=(1, 3))
- Parameters:
value (object) – A Dr.Jit array or PyTree
block_size (int) – The size of contiguous blocks to be reduced.
mode (str | None) – optional parameter to force an evaluation strategy.
- Returns:
The block-reduced array or PyTree as specified above.
- drjit.block_sum(value: T, block_size: int, mode: Literal['evaluated', 'symbolic', None] = None) T ¶
Sum elements within blocks.
This is a convenience alias for
drjit.block_reduce()
withop
set todrjit.ReduceOp.Add
.
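For example (a sketch with hypothetical values, drjit.cuda backend):
import drjit as dr
from drjit.cuda import Float

x = Float(1, 2, 3, 4, 5, 6)
dr.block_sum(x, block_size=2)   # -> [3, 7, 11]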
- drjit.block_prefix_reduce(op: ReduceOp, value: ArrayT, block_size: int, exclusive: bool = True, reverse: bool = False) ArrayT ¶
Compute a blocked exclusive or inclusive prefix reduction of the input array.
Starting from the identity element of the specified reduction
op
, this function reduces along increasingly large prefixes, returning an output array of the same size and type.For example, when reducing a 1D array using
exclusive=True
(the default), this produces the following outputdrjit.ReduceOp.Add
:[0, a[0], a[0] + a[1], ...]
.drjit.ReduceOp.Mul
:[1, a[0], a[0] * a[1], ...]
.drjit.ReduceOp.Min
:[inf, a[0], min(a[0], a[1]), ...]
.drjit.ReduceOp.Max
:[-inf, a[0], max(a[0], a[1]), ...]
.drjit.ReduceOp.Or
:[0, a[0], a[0] | a[1], ...]
(integer arrays only).drjit.ReduceOp.And
:[-1, a[0], a[0] & a[1], ...]
(integer arrays only).
With
exclusive=False
, the function instead performs an inclusive prefix reduction, which effectively shifts the output by one entry:drjit.ReduceOp.Add
:[a[0], a[0] + a[1], ...]
.drjit.ReduceOp.Mul
:[a[0], a[0] * a[1], ...]
.drjit.ReduceOp.Min
:[a[0], min(a[0], a[1]), ...]
.drjit.ReduceOp.Max
:[a[0], max(a[0], a[1]), ...]
.drjit.ReduceOp.Or
:[a[0], a[0] | a[1], ...]
(integer arrays only).drjit.ReduceOp.And
:[a[0], a[0] & a[1], ...]
(integer arrays only).
The reduction is furthermore blocked, which means that it restarts after
block_size
entries. To reduce the entire array, simply setblock_size=len(value)
.Finally, the reduction can optionally be done from the end of each block by specifying
reverse=True
.This operation traverses PyTrees and transforms any dynamically sized Dr.Jit arrays it encounters. Everything else is left as-is.
- Parameters:
value (object) – A Dr.Jit array or PyTree
block_size (int) – The size of contiguous blocks to be reduced.
exclusive (bool) – Specifies whether or not the prefix sum should be exclusive (the default) or inclusive.
reverse (bool) – if set to
True
, the reduction is done from the end of each block.
- Returns:
The block-reduced array or PyTree as specified above.
- drjit.block_prefix_sum(value: T, block_size: int, exclusive: bool = True, reverse: bool = False) T ¶
Convenience wrapper around
dr.block_prefix_reduce(dr.ReduceOp.Add, ...)
.
Rearranging array contents¶
- drjit.concat(arr: Sequence[ArrayT], /, axis: int | None = 0) ArrayT ¶
Concatenate a sequence of arrays or tensors along a given axis.
The inputs must all be of the same type, and they must have the same shape except for the axis being concatenated. Negative
axis
values count backwards from the last dimension.When
axis=None
, the function ravels the input arrays or tensors prior to concatenating them.
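A short sketch (hypothetical values, drjit.cuda backend):
import drjit as dr
from drjit.cuda import Float

a = Float(1, 2)
b = Float(3, 4, 5)
dr.concat((a, b), axis=0)   # -> [1, 2, 3, 4, 5]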
- drjit.reverse(value, axis: int = 0)¶
Reverses the given Dr.Jit array or Python sequence along the specified axis.
- Parameters:
value (ArrayBase|Sequence) – Dr.Jit array or Python sequence type
axis (int) – Axis along which the reversal should be performed. Only
axis==0
is supported for now.
- Returns:
An output of the same type as value containing a copy of the reversed array.
- Return type:
object
- drjit.moveaxis(arg: ArrayBase, /, source: int | Tuple[int, ...], destination: int | Tuple[int, ...])¶
Move one or more axes of an input tensor to another position.
Dimensions that are not explicitly moved remain in their original order and appear at the positions not specified in the destination. Negative axis values count backwards from the end.
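A short sketch (assuming a drjit.cuda.TensorXf input; the axis semantics mirror numpy.moveaxis):
import drjit as dr
from drjit.cuda import TensorXf

t = dr.zeros(TensorXf, (2, 3, 4))
dr.moveaxis(t, 0, -1).shape   # -> (3, 4, 2)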
- drjit.compress(arg: drjit.ArrayBase, /) object ¶
Compress a mask into an array of nonzero indices.
This function takes a boolean array as input and then returns an unsigned 32-bit integer array containing the indices of nonzero entries.
It can be used to reduce a stream to a subset of active entries via the following recipe:
# Input: an active mask and several arrays data_1, data_2, ...
dr.schedule(active, data_1, data_2, ...)
indices = dr.compress(active)
data_1 = dr.gather(type(data_1), data_1, indices)
data_2 = dr.gather(type(data_2), data_2, indices)
# ...
There is some conceptual overlap between this function and
drjit.scatter_inc()
, which can likewise be used to reduce a stream to a smaller subset of active items. Please see the documentation ofdrjit.scatter_inc()
for details.Danger
This function internally performs a synchronization step.
- Parameters:
arg (bool | drjit.ArrayBase) – A Python or Dr.Jit boolean type
- Returns:
Array of nonzero indices
- drjit.ravel(array: object, order: str = 'A') object ¶
Convert the input into a contiguous flat array.
This operation takes a Dr.Jit array, typically with some static and some dynamic dimensions (e.g.,
drjit.cuda.Array3f
with shape 3xN), and converts it into a flattened 1D dynamically sized array (e.g.,drjit.cuda.Float
) using either a C or Fortran-style ordering convention.It can also convert Dr.Jit tensors into a flat representation, though only C-style ordering is supported in this case.
Internally,
drjit.ravel()
performs a series of calls todrjit.scatter()
to suitably reorganize the array contents.For example,
x = dr.cuda.Array3f([1, 2], [3, 4], [5, 6])
y = dr.ravel(x, order=...)
will produce
[1, 3, 5, 2, 4, 6]
with order='F'
(the default for Dr.Jit arrays), which means that X/Y/Z components alternate.
[1, 2, 3, 4, 5, 6]
with order='C'
, in which case all X coordinates are written as a contiguous block followed by the Y- and then Z-coordinates.
- Parameters:
array (drjit.ArrayBase) – An arbitrary Dr.Jit array or tensor
order (str) – A single character indicating the index order.
'F'
indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency. The alternative'C'
specifies row-major/C-style ordering, in which case the last index changes at the highest frequency. The default value'A'
(automatic) will use F-style ordering for arrays and C-style ordering for tensors.
- Returns:
A dynamic 1D array containing the flattened representation of
array
with the desired ordering. The type of the return value depends on the type of the input. Whenarray
is already contiguous/flattened, this function returns it without making a copy.- Return type:
object
- drjit.unravel(dtype: type[ArrayT], array: AnyArray, order: Literal['A', 'C', 'F'] = 'A') ArrayT ¶
Load a sequence of Dr.Jit vectors/matrices/etc. from a contiguous flat array.
This operation implements the inverse of
drjit.ravel()
. In contrast todrjit.ravel()
, it requires one additional parameter (dtype
) specifying the type of the return value. For example,
x = dr.cuda.Float([1, 2, 3, 4, 5, 6])
y = dr.unravel(dr.cuda.Array3f, x, order=...)
will produce an array of two 3D vectors with different contents depending on the indexing convention:
[1, 2, 3]
and[4, 5, 6]
when unraveled withorder='F'
(the default for Dr.Jit arrays), and[1, 3, 5]
and[2, 4, 6]
when unraveled withorder='C'
Internally,
drjit.unravel()
performs a series of calls todrjit.gather()
to suitably reorganize the array contents.- Parameters:
dtype (type) – An arbitrary Dr.Jit array type
array (drjit.ArrayBase) – A dynamically sized 1D Dr.Jit array instance that is compatible with
dtype
. In other words, both must have the same underlying scalar type and be located in the same package (e.g., drjit.llvm.ad
).order (str) – A single character indicating the index order.
'F'
(the default) indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency. The alternative'C'
specifies row-major/C-style ordering, in which case the last index changes at the highest frequency.
- Returns:
An instance of type
dtype
containing the result of the unravel operation.- Return type:
object
- drjit.reshape(dtype: type, value: object, shape: collections.abc.Sequence[int], order: str = 'A', shrink: bool = False) object ¶
- drjit.reshape(dtype: type, value: object, shape: int, order: str = 'A', shrink: bool = False) object
Overloaded function.
reshape(dtype: type, value: object, shape: collections.abc.Sequence[int], order: str = 'A', shrink: bool = False) -> object
Converts
value
into an array of typedtype
by rearranging the contents according to the specified shape.The parameter
shape
may contain a single-1
-valued target dimension, in which case its value is inferred from the remaining shape entries and the size of the input. Whenshape
is of typeint
, it is interpreted as a 1-tuple(shape,)
.This function supports the following behaviors:
Reshaping tensors: Dr.Jit tensors admit arbitrary shapes. The
drjit.reshape()
can convert between them as long as the total number of elements remains unchanged.
>>> from drjit.llvm.ad import TensorXf
>>> value = dr.arange(TensorXf, 6)
>>> dr.reshape(dtype=TensorXf, value=value, shape=(3, -1))
[[0, 1]
 [2, 3]
 [4, 5]]
Reshaping nested arrays: The function can ravel and unravel nested arrays (which have some static dimensions). This provides a high-level interface that subsumes the functions
drjit.ravel()
anddrjit.unravel()
.
>>> from drjit.llvm.ad import Array2f, Array3f
>>> value = Array2f([1, 2, 3], [4, 5, 6])
>>> dr.reshape(dtype=Array3f, value=value, shape=(3, -1), order='C')
[[1, 4, 2],
 [5, 3, 6]]
>>> dr.reshape(dtype=Array3f, value=value, shape=(3, -1), order='F')
[[1, 3, 5],
 [2, 4, 6]]
(By convention, Dr.Jit nested arrays are always printed in transposed form, which explains the difference in output compared to the identically shaped Tensor example just above.)
The
order
argument can be used to specify C ("C"
) or Fortran ("F"
)-style ordering when rearranging the array. The default value"A"
corresponds to Fortran-style ordering.PyTrees: When
value
is a PyTree, the operation recursively threads through the tree’s elements.
Stream compression and loops that fork recursive work. When called with
shrink=True
, the function creates a view of the original data that potentially has a smaller number of elements.The main use of this feature is to implement loops that process large numbers of elements in parallel, and which need to occasionally “fork” some recursive work. On modern compute accelerators, an efficient way to handle this requirement is to append this work into a queue that is processed in a subsequent pass until no work is left. The reshape operation with
shrink=True
then resizes the preallocated queue to the actual number of collected items, which are the input of the next iteration.Please refer to the following example that illustrates how
drjit.scatter_inc()
,drjit.scatter()
, anddrjit.reshape(..., shrink=True)
can be combined to realize a parallel loop with a fork condition
@drjit.syntax
def f():
    # Loop state variables (an arbitrary array or PyTree)
    state = ...

    # Determine how many elements should be processed
    size = dr.width(loop_state)

    # Run the following loop until no work is left
    while size > 0:
        # 1-element array used as an atomic counter
        queue_index = UInt(0)

        # Preallocate memory for the queue. The necessary
        # amount of memory is task-dependent
        queue_size = size
        queue = dr.empty(dtype=type(state), shape=queue_size)

        # Create an opaque variable representing the number 'loop_state'.
        # This keeps this changing value from being baked into the program,
        # which is needed for proper kernel caching
        queue_size_o = dr.opaque(UInt32, queue_size)

        while not stopping_criterion(state):
            # This line represents the loop body that processes work
            state = loop_body(state)

            # if the condition 'fork' is True, spawn a new work item that
            # will be handled in a future iteration of the parent loop.
            if fork(state):
                # Atomically reserve a slot in 'queue'
                slot = dr.scatter_inc(target=queue_index, index=0)

                # Work item for the next iteration, task dependent
                todo = state

                # Be careful not to write beyond the end of the queue
                valid = slot < queue_size_o

                # Write 'todo' into the reserved slot
                dr.scatter(target=queue, index=slot, value=todo, active=valid)

        # Determine how many fork operations took place
        size = queue_index[0]
        if size > queue_size:
            raise RuntimeError('Preallocated queue was too small: tried to store '
                               f'{size} elements in a queue of size {queue_size}')

        # Reshape the queue and re-run the loop
        state = dr.reshape(dtype=type(state), value=queue, shape=size, shrink=True)
- Parameters:
dtype (type) – Desired output type of the reshaped array. This could equal
type(value)
or refer to an entirely different array type.value (object) – An arbitrary Dr.Jit array, tensor, or PyTree. The function returns unknown objects of other types unchanged.
shape (int|tuple[int, ...]) – The target shape.
order (str) – A single character indicating the index order used to reinterpret the input.
'F'
indicates column-major/Fortran-style ordering, in which case the first index changes at the highest frequency. The alternative'C'
specifies row-major/C-style ordering, in which case the last index changes at the highest frequency. The default value'A'
(automatic) will use F-style ordering for arrays and C-style ordering for tensors.shrink (bool) – Cheaply construct a view of the input that potentially has a smaller number of elements. The main use case of this method is explained above.
- Returns:
The reshaped array or PyTree.
- Return type:
object
reshape(dtype: type, value: object, shape: int, order: str = 'A', shrink: bool = False) -> object
- drjit.tile(value: T, count: int) T ¶
Tile the input array
count
times along the trailing dimension.This function replicates the input
count
times along the trailing dynamic dimension. It recursively threads through nested arrays and PyTrees. Static arrays and tensors currently aren’t supported. When count==1
, the function returns the input without changes.An example is shown below:
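(The sketch below assumes the drjit.llvm.Int array type.)

>>> from drjit.llvm import Int
>>> dr.tile(Int(1, 2, 3), 2)
[1, 2, 3, 1, 2, 3]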
- Parameters:
value (drjit.ArrayBase) – A Dr.Jit type or PyTree.
count (int) – Number of repetitions
- Returns:
The tiled input as described above. The return type matches that of
value
.- Return type:
object
- drjit.repeat(value: T, count: int) T ¶
Repeat each successive entry of the input
count
times along the trailing dimension. This function repeats each successive entry of the input
count
times along the trailing dynamic dimension. It recursively threads through nested arrays and PyTrees. Static arrays and tensors currently aren’t supported. When count==1
, the function returns the input without changes.An example is shown below:
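(The sketch below assumes the drjit.llvm.Int array type.)

>>> from drjit.llvm import Int
>>> dr.repeat(Int(1, 2, 3), 2)
[1, 1, 2, 2, 3, 3]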
- Parameters:
value (drjit.ArrayBase) – A Dr.Jit type or PyTree.
count (int) – Number of repetitions
- Returns:
The repeated input as described above. The return type matches that of
value
.- Return type:
object
- drjit.resample(source: ArrayT, shape: Sequence[int], filter: Literal['box', 'linear', 'hamming', 'cubic', 'lanczos', 'gaussian'] | Callable[[float], float] = 'cubic', filter_radius: float | None = None) ArrayT ¶
Resample an input array/tensor to increase or decrease its resolution along a set of axes.
This function up- and/or downsamples a given array or tensor along a specified set of axes. Given an input array (
source
) and target shape (shape
), it returns a compatible array of the specified configuration. This is implemented using a sequence of successive 1D resampling steps for each mismatched axis.Example usage:
image: TensorXf = ...  # an RGB image
width, height, channels = image.shape

scaled_image = dr.resample(
    image,
    (width // 2, height // 2, channels)
)
Resampling uses a reconstruction filter. The following options are available, where \(n\) refers to the number of dimensions being resampled:
"box"
: use nearest-neighbor interpolation/averaging. This is very efficient but generally produces sub-par output that is either pixelated (when upsampling) or aliased (when downsampling)."linear"
: use linear ramp / tent filter that uses \(2^n\) neighbors to reconstruct each output sample when upsampling. Tends to produce relatively blurry results."hamming"
: uses the same number of input samples as"linear"
but better preserves sharpness when downscaling. Do not use for upscaling."cubic"
: use cubic filter kernel that queries \(4^n\) neighbors to reconstruct each output sample when upsampling. Produces high-quality results. This is the default."lanczos"
: use a windowed Lanczos filter that queries \(6^n\) neighbors to reconstruct each output sample when upsampling. This is the best filter for smooth signals, but also the costliest. The Lanczos filter is susceptible to ringing when the input array contains discontinuities."gaussian"
: use a Gaussian filter that queries \(4^n\) neighbors to reconstruct each output sample when upsampling. The kernel has a standard deviation of 0.5 and is truncated after 4 standard deviations. This filter is mainly useful when intending to blur a signal.Besides the above choices, it is also possible to specify a custom filter. To do so, use the
filter
argument to pass a Python callable with signatureCallable[[float], float]
. In this case, you must also specify a filter radius via thefilter_radius
parameter.
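As a sketch of the custom-filter path described above (the image tensor and its shape are placeholders), the following upsamples a tensor using a simple tent filter passed as a Python callable:

image: TensorXf = ...  # input tensor
width, height, channels = image.shape

def tent(x: float) -> float:
    # Triangle filter: weight decreases linearly and vanishes at radius 1
    return max(1.0 - abs(x), 0.0)

upsampled = dr.resample(
    image,
    (width * 2, height * 2, channels),
    filter=tent,
    filter_radius=1.0
)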
The implementation was extensively tested against Image.resize() from the Pillow library and should be a drop-in replacement with added support for JIT tracing / GPU evaluation, differentiability, and compatibility with higher-dimensional tensors.
Warning
When using
filter="hamming"
,"cubic"
, or"lanczos"
, the range of the output array can exceed that of the input array. For example, positive-valued data may contain negative values following resampling. Clamp the output in case it is important that array values remain within a fixed range (e.g., \([0,1]\)).- Parameters:
source (dr.ArrayBase) – The Dr.Jit tensor or 1D array to be resampled.
shape (Sequence[int]) – The desired output shape.
filter (str | Callable[[float], float]) – The desired reconstruction filter, see the above text for an overview. Alternatively, a custom reconstruction filter function can also be specified.
filter_radius (float | None) – The radius of the pixel filter in the output sample space. Should only be specified when using a custom reconstruction filter.
- Returns:
The resampled output array. Its type matches
source
, and its shape matchesshape
.- Return type:
- drjit.convolve(source: ArrayT, filter: Literal['box', 'linear', 'hamming', 'cubic', 'lanczos', 'gaussian'] | Callable[[float], float], filter_radius: float, axis: int | Tuple[int, ...] | None = None) ArrayT ¶
Convolve one or more axes of an input array/tensor with a 1D filter
This function filters one or more axes of a Dr.Jit array or tensor, for example to convolve an image with a 2D Gaussian filter to blur spatial detail.
image: TensorXf = ...  # an RGB image

blurred_image = dr.convolve(
    image,
    filter='gaussian',
    filter_radius=10
)
The filter weights are renormalized to reduce edge effects near the boundary of the array.
The function supports a set of provided filters, and custom filters can also be specified. This works analogously to the
resample()
function, please refer to its documentation for detail.- Parameters:
source (dr.ArrayBase) – The Dr.Jit tensor or 1D array to be convolved.
filter (str | Callable[[float], float]) – The desired reconstruction filter, see the above text for an overview. Alternatively, a custom reconstruction filter function can also be specified.
filter_radius (float) – The radius of the continuous function to be used in the convolution.
axis (int | tuple[int, ...] | None) – The axis or set of axes along which to convolve. The default argument
axis=None
causes all axes to be convolved. Negative values count from the last dimension.
- Returns:
The convolved output array. Its type matches
source
.- Return type:
Random number generation¶
- drjit.rand(dtype: Type[ArrayT], shape: int | Tuple[int, ...], *, seed: int | ArrayBase[Any, Any, Any, Any, Any, Any, Any] | None = None, version: int = 1, _func_name='next_float') ArrayT ¶
Return a Dr.Jit array or tensor containing uniformly distributed pseudorandom variates.
This function supports floating point arrays/tensors of various configurations and precisions, e.g.:
from drjit.cuda import Float, TensorXf16, Array3f, Matrix4f64

# Example usage
rand_array = dr.rand(Float, 128)
rand_tensor = dr.rand(TensorXf16, (128, 128))
rand_vec = dr.rand(Array3f, (3, 128))
rand_mat = dr.rand(Matrix4f64, (4, 4, 128))
The output is uniformly distributed on the interval \([0, 1)\). Integer arrays are not supported.
Successive calls to
drjit.rand()
produce independent random variates. To override this behavior and obtain reproducible output, you can manually specify a 64-bit integer via the seed
parameter. Use the drjit.seed()
function to reset the global default seed value.Warning
This function is still considered experimental, and the algorithm used to generate random variates may change in future versions of Dr.Jit. Specify
version=1
to ensure that your program remains unaffected by such future changes.Note
When this function is used within a symbolic operation (e.g.
drjit.while_loop()
), you must provide theseed
parameter.In the non-symbolic case, the seed parameter is internally made opaque via
drjit.make_opaque()
so that the use of this function does not interfere with kernel caching.In applications that require repeated generation of random variates (e.g., in a symbolic loop), is more efficient to directly work with the underlying random number generator (e.g.,
drjit.cuda.PCG32
) instead of using the high-leveldrjit.rand()
interface.- Parameters:
dtype (type[ArrayT]) – A Dr.Jit tensor or array type.
shape (int | tuple[int, ...]) – The target shape
seed (int | None) – A seed value used to initialize the random number generator. If no value is provided, a global seed value is used (and then subsequently incremented). Refer to
drjit.seed()
.version (int) – Optional parameter to target a specific implementation of this function in the case of future changes.
- Returns:
The generated array of random variates.
- Return type:
ArrayT
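As a small usage sketch (assuming the drjit.llvm.Float type): passing the same seed twice yields identical variates, whereas omitting it draws fresh ones on every call.

from drjit.llvm import Float

a = dr.rand(Float, 16, seed=42)
b = dr.rand(Float, 16, seed=42)
assert dr.all(a == b)   # same seed -> identical output

c = dr.rand(Float, 16)  # uses (and increments) the global default seed
d = dr.rand(Float, 16)  # independent of 'c'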
- drjit.normal(dtype: Type[ArrayT], shape: int | Tuple[int, ...], *, seed: int | ArrayBase[Any, Any, Any, Any, Any, Any, Any] | None = None, version: int = 1) ArrayT ¶
Return a Dr.Jit array or tensor containing pseudorandom variates following a standard normal distribution
Please refer to
drjit.rand()
; the interfaces of these two functions are identical.
- drjit.seed(value: int)¶
Reset the seed value that is used for pseudorandom number generation.
Every successive call to
rand()
andnormal()
(without manually specifiedseed
) increments an internal counter that is used to initialize the random number generator to ensure independent output.This function can be used to reset this counter to a specific value.
Mask operations¶
Also relevant here are any()
, all()
, none()
, and count()
.
- drjit.select(arg0: object, arg1: object | None, arg2: object | None) object ¶
- drjit.select(arg0: bool, arg1: object | None, arg2: object | None) object
Overloaded function.
select(arg0: object, arg1: object | None, arg2: object | None) -> object
Select elements from inputs based on a condition
This function uses a first mask argument to select between the subsequent two arguments. It implements the following component-wise operation:
\[\mathrm{result}_i = \begin{cases} \texttt{arg1}_i,\quad&\text{if }\texttt{arg0}_i,\\ \texttt{arg2}_i,\quad&\text{otherwise.} \end{cases}\]- Parameters:
arg0 (bool | drjit.ArrayBase) – A Python or Dr.Jit mask type
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit type, whose entries should be returned for
True
-valued mask entries.arg2 (int | float | drjit.ArrayBase) – A Python or Dr.Jit type, whose entries should be returned for
False
-valued mask entries.
- Returns:
Component-wise result of the selection operation
- Return type:
float | int | drjit.ArrayBase
select(arg0: bool, arg1: object | None, arg2: object | None) -> object
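For instance, assuming the drjit.llvm.Int type, the following clamps negative entries to zero:

>>> from drjit.llvm import Int
>>> x = Int(-1, 2, -3, 4)
>>> dr.select(x >= 0, x, 0)
[0, 2, 0, 4]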
- drjit.isinf(arg, /)¶
Performs an elementwise test for positive or negative infinity
- Parameters:
arg (object) – A Dr.Jit array or other kind of numeric sequence type.
- Returns:
A mask value describing the result of the test
- Return type:
- drjit.isnan(arg, /)¶
Performs an elementwise test for NaN (Not a Number) values
- Parameters:
arg (object) – A Dr.Jit array or other kind of numeric sequence type.
- Returns:
A mask value describing the result of the test.
- Return type:
- drjit.isfinite(arg, /)¶
Performs an elementwise test that checks whether values are finite and not equal to NaN (Not a Number)
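A short sketch of this test together with drjit.isinf() and drjit.isnan() above, assuming the drjit.llvm.Float type:

>>> from drjit.llvm import Float
>>> x = Float(1.0, float('inf'), float('nan'))
>>> dr.isinf(x)
[False, True, False]
>>> dr.isnan(x)
[False, False, True]
>>> dr.isfinite(x)
[True, False, False]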
- Parameters:
arg (object) – A Dr.Jit array or other kind of numeric sequence type.
- Returns:
A mask value describing the result of the test
- Return type:
- drjit.allclose(a: object, b: object, rtol: float | None = None, atol: float | None = None, equal_nan: bool = False) bool ¶
Returns
True
if two arrays are element-wise equal within a given error tolerance.The function considers both absolute and relative error thresholds. In particular, a and b are considered equal if all elements satisfy
\[|a - b| \le |b| \cdot \texttt{rtol} + \texttt{atol}. \]If not specified, the constants
atol
andrtol
are chosen depending on the precision of the input arrays:Precision
rtol
atol
float64
1e-5
1e-8
float32
1e-3
1e-5
float16
1e-1
1e-2
Note that these constants are fairly loose and far larger than the roundoff error of the underlying floating point representation. The double precision parameters were chosen to match the behavior of
numpy.allclose()
.- Parameters:
a (object) – A Dr.Jit array or other kind of numeric sequence type.
b (object) – A Dr.Jit array or other kind of numeric sequence type.
rtol (float) – A relative error threshold chosen according to the above table.
atol (float) – An absolute error threshold according to the above table.
equal_nan (bool) – If a and b both contain a NaN (Not a Number) entry, should they be considered equal? The default is
False
.
- Returns:
The result of the comparison.
- Return type:
bool
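A short drjit.allclose() usage sketch (assuming the drjit.llvm.Float type and the default single-precision thresholds listed above):

>>> from drjit.llvm import Float
>>> a = Float(1.0, 2.0, 3.0)
>>> b = Float(1.0, 2.0, 3.0001)
>>> dr.allclose(a, b)
True
>>> dr.allclose(a, b, rtol=1e-6)
False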
Miscellaneous operations¶
- drjit.shape(arg: object) tuple[int, ...] ¶
Return a tuple describing dimension and shape of the provided Dr.Jit array, tensor, or standard sequence type.
When the input array is ragged the function raises a
RuntimeError
. The term ragged refers to an array, whose components have mismatched sizes, such as[[1, 2], [3, 4, 5]]
. Note that scalar entries (e.g.[[1, 2], [3]]
) are acceptable, since broadcasting can effectively convert them to any size.The expressions
drjit.shape(arg)
andarg.shape
are equivalent.- Parameters:
arg (drjit.ArrayBase) – an arbitrary Dr.Jit array or tensor
- Returns:
A tuple describing the dimension and shape of the provided Dr.Jit input array or tensor.
- Return type:
tuple[int, …]
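For instance (assuming the drjit.llvm namespace), drjit.shape() behaves as follows:

>>> from drjit.llvm import Array3f, TensorXf
>>> dr.shape(Array3f([1, 2], [3, 4], [5, 6]))
(3, 2)
>>> dr.shape(dr.zeros(TensorXf, (4, 5)))
(4, 5)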
- drjit.width(arg: object, /) int ¶
- drjit.width(*args) int
Overloaded function.
width(arg: object, /) -> int
Returns the vectorization width of the provided input(s), which is defined as the length of the last dynamic dimension.
When working with Jit-compiled CUDA or LLVM-based arrays, this corresponds to the number of items being processed in parallel.
The function raises an exception when the input(s) is ragged, i.e., when it contains arrays with incompatible sizes. It returns
1
if the input is scalar and/or does not contain any Dr.Jit arrays.- Parameters:
arg (object) – An arbitrary Dr.Jit array or PyTree.
- Returns:
The width of the provided input(s).
- Return type:
int
width(*args) -> int
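A short drjit.width() sketch (assuming the drjit.llvm namespace):

>>> from drjit.llvm import Array3f, Float
>>> dr.width(Array3f([1, 2], [3, 4], [5, 6]))
2
>>> dr.width(Float(1, 2, 3), Float(4, 5, 6))
3
>>> dr.width(1.0)
1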
- drjit.slice_index(dtype: type[drjit.ArrayBase], shape: tuple, indices: tuple) tuple[tuple, object] ¶
Computes an index array that can be used to slice a tensor. It is used internally by Dr.Jit to implement complex cases of the
__getitem__
operation.It must be called with the desired output
dtype
, which must be a dynamically sized 1D array of 32-bit integers. Theshape
parameter specifies the dimensions of a hypothetical input tensor, andindices
contains the entries that would appear in a complex slicing operation, but as a tuple. For example,[5:10:2, ..., None]
would be specified as(slice(5, 10, 2), Ellipsis, None)
.An example is shown below:
>>> dr.slice_index(dtype=dr.scalar.ArrayXu, shape=(10, 1), indices=(slice(0, 10, 2), 0)) [0, 2, 4, 6, 8]
- Parameters:
dtype (type) – A dynamic 32-bit unsigned integer Dr.Jit array type, such as
drjit.scalar.ArrayXu
ordrjit.cuda.UInt
.shape (tuple[int, ...]) – The shape of the tensor to be sliced.
indices (tuple[int|slice|ellipsis|None|dr.ArrayBase, ...]) – A set of indices used to slice the tensor. Its entries can be
slice
instances, integers, integer arrays,...
(ellipsis) orNone
.
- Returns:
Tuple consisting of the output array shape and a flattened unsigned integer array of type
dtype
containing element indices.- Return type:
tuple[tuple[int, …], drjit.ArrayBase]
- drjit.meshgrid(*args, indexing='xy') tuple ¶
Return flattened N-D coordinate arrays from a sequence of 1D coordinate vectors.
This function constructs flattened coordinate arrays that are convenient for evaluating and plotting functions on a regular grid. An example is shown below:
import drjit as dr x, y = dr.meshgrid( dr.arange(dr.llvm.UInt, 4), dr.arange(dr.llvm.UInt, 4) ) # x = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] # y = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
This function carefully reproduces the behavior of
numpy.meshgrid
except for one major difference: the output coordinates are returned in flattened/raveled form. Like the NumPy version, theindexing=='xy'
case internally reorders the first two elements of*args
.- Parameters:
*args – A sequence of 1D coordinate arrays
indexing (str) – Specifies the indexing convention. Must be set to either 'xy' (the default) or 'ij'.
- Returns:
A tuple of flattened coordinate arrays (one per input)
- Return type:
tuple
- drjit.binary_search(start, end, pred)¶
Perform a binary search over a range given a predicate
pred
, which monotonically decreases over this range (i.e. max oneTrue
->False
transition).Given a (scalar)
start
andend
index of a range, this function evaluates a predicatefloor(log2(end-start) + 1)
times with index values on the interval [start, end] (inclusive) to find the first index that no longer satisfies it. Note that the template parameterIndex
is automatically inferred from the supplied predicate. Specifically, the predicate takes an index array as input argument. Whenpred
isFalse
for all entries, the function returnsstart
, and when it isTrue
for all cases, it returnsend
.The following code example shows a typical use case:
data
contains a sorted list of floating point numbers, and the goal is to map floating point entries ofx
to the first indexj
such thatdata[j] >= threshold
(and all of this of course in parallel for each vector element).dtype = dr.llvm.Float data = dtype(...) threshold = dtype(...) index = dr.binary_search( 0, len(data) - 1, lambda index: dr.gather(dtype, data, index) < threshold )
- Parameters:
start (int) – Starting index for the search range
end (int) – Ending index for the search range
pred (function) – The predicate function to be evaluated
- Returns:
Index array resulting from the binary search
- drjit.make_opaque(arg: object, /) None ¶
- drjit.make_opaque(*args) None
Overloaded function.
make_opaque(arg: object, /) -> None
Forcefully evaluate arrays (including literal constants).
This function implements a more drastic version of
drjit.eval()
that additionally converts literal constant arrays into evaluated (device memory-based) representations.It is related to the function
drjit.opaque()
that can be used to directly construct such opaque arrays. Please see the documentation of this function regarding the rationale of making array contents opaque to Dr.Jit’s symbolic tracing mechanism.- Parameters:
*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)
make_opaque(*args) -> None
- drjit.copy(arg: T, /) T ¶
Create a deep copy of a PyTree
This function recursively traverses PyTrees and replaces Dr.Jit arrays with copies created via the ordinary copy constructor. It also rebuilds tuples, lists, dictionaries, and other custom data structures.
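For instance, assuming the drjit.llvm.Float type, modifying the original afterwards leaves the copy untouched:

from drjit.llvm import Float

state = {'x': Float(1, 2, 3)}
backup = dr.copy(state)

state['x'][0] = 100   # element-wise write modifies 'state' only
# backup['x'] still holds [1, 2, 3]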
Just-in-time compilation¶
- enum drjit.JitBackend(value)¶
List of just-in-time compilation backends supported by Dr.Jit. See also
drjit.backend_v()
.Valid values are as follows:
- Invalid = JitBackend.Invalid¶
Indicates that a type is not handled by a Dr.Jit backend (e.g., a scalar type)
- CUDA = JitBackend.CUDA¶
Dr.Jit backend targeting NVIDIA GPUs using PTX (“Parallel Thread Execution”) IR.
- LLVM = JitBackend.LLVM¶
Dr.Jit backend targeting various processors via the LLVM compiler infrastructure.
- enum drjit.VarType(value)¶
List of possible scalar array types (not all of them are supported).
Valid values are as follows:
- Void = VarType.Void¶
Unknown/unspecified type.
- Bool = VarType.Bool¶
Boolean/mask type.
- Int8 = VarType.Int8¶
Signed 8-bit integer.
- UInt8 = VarType.UInt8¶
Unsigned 8-bit integer.
- Int16 = VarType.Int16¶
Signed 16-bit integer.
- UInt16 = VarType.UInt16¶
Unsigned 16-bit integer.
- Int32 = VarType.Int32¶
Signed 32-bit integer.
- UInt32 = VarType.UInt32¶
Unsigned 32-bit integer.
- Int64 = VarType.Int64¶
Signed 64-bit integer.
- UInt64 = VarType.UInt64¶
Unsigned 64-bit integer.
- Pointer = VarType.Pointer¶
Pointer to a memory address.
- Float16 = VarType.Float16¶
16-bit floating point format (IEEE 754).
- Float32 = VarType.Float32¶
32-bit floating point format (IEEE 754).
- Float64 = VarType.Float64¶
64-bit floating point format (IEEE 754).
- enum drjit.VarState(value)¶
The
drjit.ArrayBase.state
property returns one of the following enumeration values describing possible evaluation states of a Dr.Jit variable.Valid values are as follows:
- Invalid = VarState.Invalid¶
The variable has length 0 and effectively does not exist.
- Literal = VarState.Literal¶
A literal constant. Does not consume device memory.
- Undefined = VarState.Undefined¶
An undefined memory region. Does not (yet) consume device memory.
- Unevaluated = VarState.Unevaluated¶
An ordinary unevaluated variable that is neither a literal constant nor symbolic.
- Evaluated = VarState.Evaluated¶
Evaluated variable backed by a device memory region.
- Dirty = VarState.Dirty¶
An evaluated variable backed by a device memory region. The variable furthermore has pending side effects (i.e. the user has performed a drjit.scatter(), drjit.scatter_reduce(), drjit.scatter_inc(), drjit.scatter_add(), or drjit.scatter_add_kahan() operation, and the effect of this operation has not been realized yet). The array’s status will automatically change to Evaluated
the next time that Dr.Jit evaluates computation, e.g. viadrjit.eval()
.
- Symbolic = VarState.Symbolic¶
A symbolic variable that could take on various inputs. Cannot be evaluated.
- Mixed = VarState.Mixed¶
This is a nested array, and the components have mixed states.
- enum drjit.JitFlag(value)¶
Flags that control how Dr.Jit compiles and optimizes programs.
This enumeration lists various flags that control how Dr.Jit compiles and optimizes programs, most of which are enabled by default. The status of each flag can be queried via
drjit.flag()
and enabled/disabled via thedrjit.set_flag()
or the recommendeddrjit.scoped_set_flag()
functions, e.g.:with dr.scoped_set_flag(dr.JitFlag.SymbolicLoops, False): # code that has this flag disabled goes here
The most common reason to update the flags is to switch between symbolic and evaluated execution of loops and functions. The former eagerly executes programs by breaking them into many smaller kernels, while the latter records computation symbolically to assemble large megakernels. See explanations below along with the documentation of
drjit.switch()
anddrjit.while_loop
for more details on these two modes.Dr.Jit flags are a thread-local property. This means that multiple independent threads using Dr.Jit can set them independently without interfering with each other.
- Member Type:
int
Valid values are as follows:
- Debug = JitFlag.Debug¶
Debug mode: Enable functionality to uncover errors in application code.
When debug mode is enabled, Dr.Jit inserts a number of additional runtime checks to locate sources of undefined behavior.
Debug mode comes at a significant cost: it interferes with kernel caching, reduces tracing performance, and produces kernels that run slower. We recommend that you only use it periodically before a release, or when encountering a serious problem like a crashing kernel.
First, debug mode enables assertion checks in user code such as those performed by
drjit.assert_true()
,drjit.assert_false()
, anddrjit.assert_equal()
.Second, Dr.Jit inserts additional checks to intercept out-of-bound reads and writes performed by operations such as
drjit.scatter()
,drjit.gather()
,drjit.scatter_reduce()
,drjit.scatter_inc()
, etc. It also detects calls to invalid callables performed viadrjit.switch()
,drjit.dispatch()
. Such invalid operations are masked, and they generate a warning message on the console, e.g.:>>> dr.gather(dtype=UInt, source=UInt(1, 2, 3), index=UInt(0, 1, 100)) RuntimeWarning: drjit.gather(): out-of-bounds read from position 100 in an array↵ of size 3. (<stdin>:2)
Finally, Dr.Jit also installs a Python tracing hook that associates all JIT variables with their Python source code location, and this information is propagated all the way to the final intermediate representation (PTX, LLVM IR). This is useful for low-level debugging and development of Dr.Jit itself. You can query the source location information of a variable
x
by writingx.label
.Due to limitations of the Python tracing interface, this handler becomes active within the next called function (or Jupyter notebook cell) following activation of the
drjit.JitFlag.Debug
flag. It does not apply to code within the same scope/function.C++ code using Dr.Jit also benefits from debug mode but will lack accurate source code location information. In mixed-language projects, the reported file and line number information will reflect that of the last operation on the Python side of the interface.
- ReuseIndices = JitFlag.ReuseIndices¶
Index reuse: Dr.Jit consists of two main parts: the just-in-time compiler, and the automatic differentiation layer. Both maintain an internal data structure representing captured computation, in which each variable is associated with an index (e.g.,
r1234
in the JIT compiler, anda1234
in the AD graph).The index of a Dr.Jit array in these graphs can be queried via the
drjit.index
anddrjit.index_ad
variables, and they are also visible in debug messages (ifdrjit.set_log_level()
is set to a more verbose debug level).Dr.Jit aggressively reuses the indices of expired variables by default, but this can make debug output difficult to interpret. When debugging Dr.Jit itself, it is often helpful to investigate the history of a particular variable. In such cases, set this flag to
False
to disable variable reuse both at the JIT and AD levels. This comes at a cost: the internal data structures keep on growing, so it is not suitable for long-running computations.Index reuse is enabled by default.
- ConstantPropagation = JitFlag.ConstantPropagation¶
Constant propagation: immediately evaluate arithmetic involving literal constants on the host and don’t generate any device-specific code for them.
For example, the following assertion holds when constant propagation is enabled in Dr.Jit.
from drjit.llvm import Int # Create two literal constant arrays a, b = Int(4), Int(5) # This addition operation can be immediately performed and does not need to be recorded c1 = a + b # Double-check that c1 and c2 refer to the same Dr.Jit variable c2 = Int(9) assert c1.index == c2.index
Constant propagation is enabled by default.
- ValueNumbering = JitFlag.ValueNumbering¶
Local value numbering: a simple variant of common subexpression elimination that collapses identical expressions within basic blocks. For example, the following assertion holds when value numbering is enabled in Dr.Jit.
from drjit.llvm import Int # Create two non-literal arrays stored in device memory a, b = Int(1, 2, 3), Int(4, 5, 6) # Perform the same arithmetic operation twice c1 = a + b c2 = a + b # Verify that c1 and c2 reference the same Dr.Jit variable assert c1.index == c2.index
Local value numbering is enabled by default.
- FastMath = JitFlag.FastMath¶
Fast Math: this flag is analogous to the
-ffast-math
flag in C compilers. When set, the system may use approximations and simplifications that sacrifice strict IEEE-754 compatibility.Currently, it changes two behaviors:
expressions of the form
a * 0
will be simplified to0
(which is technically not correct whena
is infinite or NaN-valued).Dr.Jit will use slightly approximate division and square root operations in CUDA mode. Note that disabling fast math mode is costly on CUDA devices, as the strict IEEE-754 compliant version of these operations uses software-based emulation.
Fast math mode is enabled by default.
- SymbolicLoops = JitFlag.SymbolicLoops¶
Dr.Jit provides two main ways of compiling loops involving Dr.Jit arrays.
Symbolic mode (the default): Dr.Jit executes the loop a single time regardless of how many iterations it requires in practice. It does so with symbolic (abstract) arguments to capture the loop condition and body and then turns it into an equivalent loop in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.
Evaluated mode: Dr.Jit evaluates the loop’s state variables and reduces the loop condition to a single element (
bool
) that expresses whether any elements are still alive. If so, it runs the loop body and the process repeats.
A separate section about symbolic and evaluated modes discusses these two options in detail.
Symbolic loops are enabled by default.
- OptimizeLoops = JitFlag.OptimizeLoops¶
Perform basic optimizations for loops involving Dr.Jit arrays.
This flag enables two optimizations:
Constant arrays: loop state variables that aren’t modified by the loop are automatically removed. This shortens the generated code, which can be helpful especially in combination with the automatic transformations performed by
@drjit.syntax
that can be somewhat conservative in classifying too many local variables as potential loop state.Literal constant arrays: In addition to the above point, constant loop state variables that are literal constants are propagated into the loop body, where this may unlock further optimization opportunities.
This is useful in combination with automatic differentiation, where it helps to detect code that does not influence the computed derivatives.
A practical implication of this optimization flag is that it may cause
drjit.while_loop()
to run the loop body twice instead of just once.This flag is enabled by default. Note that it is only meaningful in combination with
SymbolicLoops
.
- CompressLoops = JitFlag.CompressLoops¶
Compress the loop state of evaluated loops after every iteration.
When an evaluated loop processes many elements, and when each element requires a different number of loop iterations, there is a question of what should be done with inactive elements. The default implementation keeps them around and performs redundant calculations that are, however, masked out. Consequently, later loop iterations don’t run faster despite fewer elements being active.
Setting this flag causes the removal of inactive elements after every iteration. This reorganization is not for free and does not benefit all use cases.
This flag is disabled by default. Note that it only applies to evaluated loops (i.e., when
SymbolicLoops
is disabled, or the mode='evaluated'
parameter was passed to the loop in question).
- SymbolicCalls = JitFlag.SymbolicCalls¶
Dr.Jit provides two main ways of compiling function calls targeting instance arrays.
Symbolic mode (the default): Dr.Jit invokes each callable with symbolic (abstract) arguments. It does this to capture a transcript of the computation that it can turn into a function in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.
Evaluated mode: Dr.Jit evaluates all inputs and groups them by instance ID. Following this, it launches a kernel per instance to process the rearranged inputs and assemble the function return value.
A separate section about symbolic and evaluated modes discusses these two options in detail.
Besides calls to instance arrays, this flag also controls the behavior of the functions
drjit.switch()
anddrjit.dispatch()
.Symbolic calls are enabled by default.
- OptimizeCalls = JitFlag.OptimizeCalls¶
Perform basic optimizations for function calls on instance arrays.
This flag enables two optimizations:
Constant propagation: Dr.Jit will propagate literal constants across function boundaries while tracing, which can unlock simplifications within. This is especially useful in combination with automatic differentiation, where it helps to detect code that does not influence the computed derivatives.
Devirtualization: When an element of the return value has the same computation graph in all instances, it is removed from the function call interface and moved to the caller.
The flag is enabled by default. Note that it is only meaningful in combination with
SymbolicCalls
. Besides calls to instance arrays, this flag also controls the behavior of the functionsdrjit.switch()
anddrjit.dispatch()
.
- MergeFunctions = JitFlag.MergeFunctions¶
Deduplicate code generated by function calls on instance arrays.
When
arr
is an instance array (potentially with thousands of instances), a function call likearr.f(inputs...)
can potentially generate vast numbers of different functions in the generated code. At the same time, many of these functions may contain identical code (or code that is identical except for data references).
Dr.Jit can exploit such redundancy and merge such functions during code generation. Besides generating shorter programs, this also helps to reduce thread divergence.
This flag is enabled by default. Note that it is only meaningful in combination with
SymbolicCalls
. Besides calls to instance arrays, this flag also controls the behavior of the functionsdrjit.switch()
anddrjit.dispatch()
.
- PacketOps = JitFlag.PacketOps¶
Turn sequences of contiguous gather/scatter operations into packet loads/stores.
When indirect memory accesses (gathers/scatters) access multiple subsequent elements, it should be possible to exploit this to perform the operation more efficiently.
If
drjit.JitFlag.PacketOps
is set, Dr.Jit will realize this optimization opportunity when: The size of the leading dimension of the source/target array is a power of two.
The array is read/written via
drjit.gather()
,drjit.scatter()
,drjit.ravel()
, ordrjit.unravel()
.For example, the following operation gathers 4D vectors from a flat array:
from drjit.auto import Array4f, Float, UInt32 source = Float(...) result = dr.gather(Array4f, source, index=UInt32(...))
Wider cases are supported as well by providing a
shape
argumentfrom drjit.auto import Array4f, Float, UInt32 source = Float(...) result = dr.gather(ArrayXf, source, index=index, shape=(16, len(index)))
This optimization also applies to atomic scatter-adds performed on the LLVM backend when using the
drjit.ReduceMode.Expand
reduction strategy.
Packet gathers yield a modest performance improvement on the CUDA backend, where they produce fewer and larger memory transactions. For example, a single 128 bit load can fetch 8 half-precision values at once. Speedups of 5-30% have been observed.
On the LLVM backend, the operation replaces gather/scatter instructions with a combination of aligned packet loads/stores and a matrix transpose (implemented for 2, 4, and 8-D inputs; larger vectors perform several 8D transposes). Speedups here are rather dramatic (up to >20× for scatters and 1.5-2× for gathers have been measured).
This optimization is expected to make an even bigger difference following microcode-based security mitigations for a recent side-channel attack that effectively break the performance characteristics of the native gather operation on Intel CPUs. Packet gathers don’t use the regular gather instruction and thus aren’t affected by the mitigation.
This flag is enabled by default.
- ForceOptiX = JitFlag.ForceOptiX¶
Force execution through OptiX even if a kernel doesn’t use ray tracing. This only applies to the CUDA backend and is mainly helpful for automated tests done by the Dr.Jit team.
This flag is disabled by default.
- PrintIR = JitFlag.PrintIR¶
Print the low-level IR representation when launching a kernel.
If enabled, this flag causes Dr.Jit to print the low-level IR (LLVM IR, NVIDIA PTX) representation of the generated code onto the console (or Jupyter notebook).
This flag is disabled by default.
- KernelHistory = JitFlag.KernelHistory¶
Maintain a history of kernel launches to profile/debug programs.
Programs written on top of Dr.Jit execute in an extremely asynchronous manner. By default, the system postpones the computation to build large fused kernels. Even when this computation eventually runs, it does so asynchronously with respect to the host, which can make benchmarking difficult.
In general, beware of the following benchmarking anti-pattern:
import time a = time.time() # Some Dr.Jit computation b = time.time() print("took %.2f ms" % ((b-a) * 1000))
In the worst case, the measured time interval may only capture the tracing time, without any actual computation having taken place. Another common mistake with this pattern is that Dr.Jit or the target device may still be busy with computation that started prior to the
a = time.time()
line, which is now incorrectly added to the measured period.Dr.Jit provides a kernel history feature, where it creates an entry in a list whenever it launches a kernel or related operation (memory copies, etc.). This not only gives accurate and isolated timings (measured with counters on the CPU/GPU) but also reveals if a kernel was launched at all. To capture the kernel history, set this flag just before the region to be benchmarked and call
drjit.kernel_history()
at the end.Capturing the history has a (very) small cost and is therefore disabled by default.
- LaunchBlocking = JitFlag.LaunchBlocking¶
Force synchronization after every kernel launch. This is useful to isolate severe problems (e.g. crashes) to a specific kernel.
This flag has a severe performance impact and is disabled by default.
- ForbidSynchronization = JitFlag.ForbidSynchronization¶
Treat any kind of synchronization as an error and raise an exception when it is encountered.
Operations like
drjit.sync_thread()
are costly because they prevent the system from overlapping CPU/GPU work. Enable this flag to find places in a larger codebase that are responsible for this.This flag is disabled by default.
- ScatterReduceLocal = JitFlag.ScatterReduceLocal¶
Reduce locally before performing atomic scatter-reductions.
Atomic memory operations are expensive when many writes target the same region of memory. This leads to a phenomenon called contention that is normally associated with significant slowdowns (10-100x aren’t unusual).
This issue is particularly common when automatically differentiating computation in reverse mode (e.g.
drjit.backward()
), since this transformation turns differentiable global memory reads into atomic scatter-additions. A differentiable scalar read is all it takes to create such an atomic memory bottleneck.To reduce this cost, Dr.Jit can perform a local reduction that uses cooperation between SIMD/warp lanes to resolve all requests targeting the same address and then only issuing a single atomic memory transaction per unique target. This can reduce atomic memory traffic 32-fold on the GPU (CUDA) and 16-fold on the CPU (AVX512). On the CUDA backend, local reduction is currently only supported for 32-bit operands (signed/unsigned integers and single precision variables).
The section on optimizations presents plots that demonstrate the impact of this optimization.
The JIT flag
drjit.JitFlag.ScatterReduceLocal
affects the behavior ofscatter_add()
,scatter_reduce()
along with the reverse-mode derivative ofgather()
. Setting the flag toTrue
will usually cause amode=
argument value ofdrjit.ReduceOp.Auto
to be interpreted asdrjit.ReduceOp.Local
. Another LLVM-specific optimization takes precedence in certain situations, refer to the discussion of this flag for details.This flag is enabled by default.
- SymbolicConditionals = JitFlag.SymbolicConditionals¶
Dr.Jit provides two main ways of compiling conditionals involving Dr.Jit arrays.
Symbolic mode (the default): Dr.Jit captures the computation performed by the
True
andFalse
branches and generates an equivalent branch in the generated kernel. Symbolic mode preserves the control flow structure of the original program by replicating it within Dr.Jit’s intermediate representation.Evaluated mode: Dr.Jit always executes both branches and blends their outputs.
A separate section about symbolic and evaluated modes discusses these two options in detail.
Symbolic conditionals are enabled by default.
- SymbolicScope = JitFlag.SymbolicScope¶
This flag is set to
True
when Dr.Jit is currently capturing symbolic computation. The flag is automatically managed and should not be updated by application code.User code may query this flag to check if it is legal to perform certain operations (e.g., evaluating array contents).
Note that this information can also be queried in a more fine-grained manner (per variable) using the
drjit.ArrayBase.state
field.
- Default = JitFlag.Default¶
The default set of optimization flags, consisting of all flags that are enabled by default according to the individual flag descriptions above.
- drjit.has_backend(arg: drjit.JitBackend, /) int ¶
Check if the specified Dr.Jit backend was successfully initialized.
- drjit.schedule(arg: object, /) bool ¶
- drjit.schedule(*args) bool
Overloaded function.
schedule(arg: object, /) -> bool
Schedule the provided JIT variable(s) for later evaluation
This function causes
args
to be evaluated by the next kernel launch. In other words, the effect of this operation is deferred: the next time that Dr.Jit’s LLVM or CUDA backends compile and execute code, they will include the trace of the specified variables in the generated kernel and turn them into an explicit memory-based representation.Scheduling and evaluation of traced computation happens automatically, hence it is rare that a user would need to call this function explicitly. Explicit scheduling can improve performance in certain cases—for example, consider the following code:
# Computation that produces Dr.Jit arrays a, b = ... # The following line launches a kernel that computes 'a' print(a) # The following line launches a kernel that computes 'b' print(b)
If the traces of
a
andb
overlap (perhaps they reference computation from an earlier step not shown here), then this is inefficient as these steps will be executed twice. It is preferable to launch bigger kernels that leverage common subexpressions, which is whatdrjit.schedule()
enables:a, b = ... # Computation that produces Dr.Jit arrays # Schedule both arrays for deferred evaluation, but don't evaluate yet dr.schedule(a, b) # The following line launches a kernel that computes both 'a' and 'b' print(a) # References the stored array, no kernel launch print(b)
Note that
drjit.eval()
would also have been a suitable alternative in the above example; the main difference todrjit.schedule()
is that it does the evaluation immediately without deferring the kernel launch.This function accepts a variable-length list of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During recursion, the function gathers all unevaluated Dr.Jit arrays. Evaluated arrays and incompatible types are ignored. Multiple variables can be equivalently scheduled with a single
drjit.schedule()
call or a sequence of calls todrjit.schedule()
. Variables that are garbage collected between the originaldrjit.schedule()
call and the next kernel launch are ignored and will not be stored in memory.- Parameters:
*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)
- Returns:
True
if a variable was scheduled,False
if the operation did not do anything.- Return type:
bool
schedule(*args) -> bool
- drjit.eval(arg: object, /) bool ¶
- drjit.eval(*args) bool
Overloaded function.
eval(arg: object, /) -> bool
Evaluate the provided JIT variable(s)
Dr.Jit automatically evaluates variables as needed, hence it is usually not necessary to call this function explicitly. That said, explicit evaluation may sometimes improve performance—refer to the documentation of
drjit.schedule()
for an example of such a use case.drjit.eval()
invokes Dr.Jit’s LLVM or CUDA backends to compile and then execute a kernel containing all the steps that are needed to evaluate the specified variables, which will turn them into a memory-based representation. The generated kernel(s) will also include computation that was previously scheduled via
. In fact,drjit.eval()
internally callsdrjit.schedule()
, asdr.eval(arg_1, arg_2, ...)
is equivalent to
dr.schedule(arg_1, arg_2, ...) dr.eval()
This function accepts a variable-length list of positional arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During this recursive traversal, the function collects all unevaluated Dr.Jit arrays, while ignoring previously evaluated arrays along with non-array types. The function also does not evaluate literal constant arrays (this refers to potentially large arrays that are entirely uniform), as this is generally not wanted. Use the function
drjit.make_opaque()
if you wish to evaluate literal constant arrays as well.- Parameters:
*args (tuple) – A variable-length list of Dr.Jit array instances or PyTrees (they will be recursively traversed to discover all Dr.Jit arrays.)
- Returns:
True
if a variable was evaluated,False
if the operation did not do anything.- Return type:
bool
eval(*args) -> bool
- drjit.set_flag(arg0: drjit.JitFlag, arg1: bool, /) None ¶
Set the value of the given Dr.Jit compilation flag.
- drjit.flag(arg: drjit.JitFlag, /) bool ¶
Query whether the given Dr.Jit compilation flag is active.
- class drjit.scoped_set_flag(*args, **kwargs)¶
Context manager, which sets or unsets a Dr.Jit compilation flag in a local execution scope.
For example, the following snippet shows how to temporarily disable a flag:
with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, False): # Code affected by the change should be placed here # Flag is returned to its original status
- __init__(self, flag: drjit.JitFlag, value: bool = True) None ¶
- __enter__(self) None ¶
- __exit__(self, arg0: object | None, arg1: object | None, arg2: object | None) None ¶
Type traits¶
The functions in this section can be used to infer properties or types of Dr.Jit arrays.
The naming convention with a trailing _v
or _t
indicates whether a
function returns a value or a type. Evaluation takes place at runtime within
Python. In C++, these expressions are all constexpr
(i.e., evaluated at
compile time).
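For instance, assuming the drjit.llvm namespace, the value-returning traits documented below behave as follows:

>>> from drjit.llvm import Array3f, Float
>>> dr.is_array_v(Array3f), dr.is_array_v(1.0)
(True, False)
>>> dr.is_float_v(Float), dr.is_integral_v(Float)
(True, False)
>>> dr.size_v(Array3f), dr.depth_v(Array3f)
(3, 2)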
Array type tests¶
- drjit.is_array_v(arg: object | None) bool ¶
Check if the input is a Dr.Jit array instance or type
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
ifarg
or type(arg
) is a Dr.Jit array type, andFalse
otherwise
- Return type:
bool
- drjit.is_mask_v(arg: object, /) bool ¶
Check whether the input array instance or type is a Dr.Jit mask array or a Python
bool
value/type.- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
ifarg
represents a Dr.Jit mask array or Pythonbool
instance or type.- Return type:
bool
- drjit.is_half_v(arg: object, /) bool ¶
Check whether the input array instance or type is a Dr.Jit half-precision floating point array or a Python
half
value/type.- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
ifarg
represents a Dr.Jit half-precision floating point array or Pythonhalf
instance or type.- Return type:
bool
- drjit.is_float_v(arg: object, /) bool ¶
Check whether the input array instance or type is a Dr.Jit floating point array or a Python
float
value/type.- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
ifarg
represents a Dr.Jit floating point array or Pythonfloat
instance or type.- Return type:
bool
- drjit.is_integral_v(arg: object, /) bool ¶
Check whether the input array instance or type is an integral Dr.Jit array or a Python
int
value/type.Note that a mask array is not considered to be integral.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
ifarg
represents an integral Dr.Jit array or Pythonint
instance or type.- Return type:
bool
- drjit.is_arithmetic_v(arg: object, /) bool ¶
Check whether the input array instance or type is an arithmetic Dr.Jit array or a Python
int
orfloat
value/type.Note that a mask type (e.g.
bool
,drjit.scalar.Array2b
, etc.) is not considered to be arithmetic.- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
ifarg
represents an arithmetic Dr.Jit array or Pythonint
orfloat
instance or type.- Return type:
bool
- drjit.is_signed_v(arg: object, /) bool ¶
Check whether the input array instance or type is a signed Dr.Jit array or a Python
int
orfloat
value/type.- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
ifarg
represents a signed Dr.Jit array or Python
orfloat
instance or type.- Return type:
bool
- drjit.is_unsigned_v(arg: object, /) bool ¶
Check whether the input array instance or type is an unsigned integer Dr.Jit array or a Python
bool
value/type (masks and boolean values are also considered to be unsigned).- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
ifarg
represents an unsigned Dr.Jit array or Pythonbool
instance or type.- Return type:
bool
- drjit.is_dynamic_v(arg: object, /) bool ¶
Check whether the input instance or type represents a dynamically sized Dr.Jit array type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, andFalse
otherwise.- Return type:
bool
- drjit.is_jit_v(arg: object, /) bool ¶
Check whether the input array instance or type represents a type that undergoes just-in-time compilation.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
ifarg
represents an array type from thedrjit.cuda.*
ordrjit.llvm.*
namespaces, andFalse
otherwise.- Return type:
bool
- drjit.is_diff_v(arg: object, /) bool ¶
Check whether the input is a differentiable Dr.Jit array instance or type.
Note that this is a type-based statement that is unrelated to mathematical differentiability. For example, the integral type
drjit.cuda.ad.Int
from the CUDA AD namespace satisfiesis_diff_v(..) = 1
.- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
ifarg
represents an array type from thedrjit.[cuda/llvm].ad.*
namespace, andFalse
otherwise.- Return type:
bool
- drjit.is_vector_v(arg: object, /) bool ¶
Check whether the input is a Dr.Jit array instance or type representing a vectorial array type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, andFalse
otherwise.- Return type:
bool
- drjit.is_complex_v(arg: object, /) bool ¶
Check whether the input is a Dr.Jit array instance or type representing a complex number.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, andFalse
otherwise.- Return type:
bool
- drjit.is_matrix_v(arg: object, /) bool ¶
Check whether the input is a Dr.Jit array instance or type representing a matrix.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, andFalse
otherwise.- Return type:
bool
- drjit.is_quaternion_v(arg: object, /) bool ¶
Check whether the input is a Dr.Jit array instance or type representing a quaternion.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, andFalse
otherwise.- Return type:
bool
- drjit.is_tensor_v(arg: object, /) bool ¶
Check whether the input is a Dr.Jit array instance or type representing a tensor.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, andFalse
otherwise.- Return type:
bool
- drjit.is_special_v(arg: object, /) bool ¶
Check whether the input is a special Dr.Jit array instance or type.
A special array type requires precautions when performing arithmetic operations like multiplications (complex numbers, quaternions, matrices).
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
if the test was successful, andFalse
otherwise.- Return type:
bool
- drjit.is_struct_v(arg: object, /) bool ¶
Check if the input is a Dr.Jit-compatible data structure
Custom data structures can be made compatible with various Dr.Jit operations by specifying a
DRJIT_STRUCT
member. See the section on PyTrees for details. This type trait can be used to check for the existence of such a field.- Parameters:
arg (object) – An arbitrary Python object
- Returns:
True
ifarg
has aDRJIT_STRUCT
member- Return type:
bool
Array properties (shape, type, etc.)¶
- drjit.type_v(arg: object, /) drjit.VarType ¶
Returns the scalar type associated with the given Dr.Jit array instance or type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
The associated scalar type, or drjit.VarType.Void when the input is not a Dr.Jit array.- Return type:
- drjit.backend_v(arg: object, /) drjit.JitBackend ¶
Returns the backend responsible for the given Dr.Jit array instance or type.
- Parameters:
arg (object) – An arbitrary Python object
- Returns:
The associated Jit backend or
drjit.JitBackend.Invalid
.- Return type:
- drjit.size_v(arg: object, /) int ¶
Return the (static) size of the outermost dimension of the provided Dr.Jit array instance or type
Note that this function mainly exists to query type-level information. Use the Python
len()
function to query the size in a way that does not distinguish between static and dynamic arrays.- Parameters:
arg (object) – An arbitrary Python object
- Returns:
Returns either the static size or
drjit.Dynamic
whenarg
is a dynamic Dr.Jit array. Returns1
for all other types.- Return type:
int
- drjit.depth_v(arg: object, /) int ¶
Return the depth of the provided Dr.Jit array instance or type
For example, an array consisting of floating point values (for example,
drjit.scalar.Array3f
) has depth1
, while an array consisting of sub-arrays (e.g.,drjit.cuda.Array3f
) has depth2
.- Parameters:
arg (object) – An arbitrary Python object
- Returns:
Returns the depth of the input, if it is a Dr.Jit array instance or type. Returns
0
for all other inputs.- Return type:
int
- drjit.itemsize_v(arg: object, /) int ¶
Return the per-item size (in bytes) of the scalar type underlying a Dr.Jit array
- Parameters:
arg (object) – A Dr.Jit array instance or array type.
- Returns:
Returns the size of a single array element in bytes.
- Return type:
int
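A short sketch of these property queries (assuming the drjit.llvm namespace):

from drjit.llvm import Float, Int64

assert dr.type_v(Float) == dr.VarType.Float32
assert dr.backend_v(Float) == dr.JitBackend.LLVM
assert dr.itemsize_v(Float) == 4   # 32-bit floats occupy 4 bytes
assert dr.itemsize_v(Int64) == 8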
Bit-level operations¶
- drjit.reinterpret_array(arg0: type[drjit.ArrayBase], arg1: drjit.ArrayBase, /) object ¶
Reinterpret the provided Dr.Jit array or tensor as a different type.
This operation reinterprets the input type as another type provided that it has a compatible in-memory layout (this operation is also known as a bit-cast).
- Parameters:
dtype (type) – Target type.
value (object) – A compatible Dr.Jit input array or tensor.
- Returns:
Result of the conversion as described above.
- Return type:
object
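For instance, bit-casting a single-precision value to its raw 32-bit integer representation (assuming the drjit.llvm namespace):

>>> from drjit.llvm import Float, UInt32
>>> dr.reinterpret_array(UInt32, Float(1.0))
[1065353216]
>>> hex(1065353216)   # the IEEE 754 bit pattern of 1.0
'0x3f800000'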
- drjit.popcnt(arg: ArrayT, /) ArrayT ¶
- drjit.popcnt(arg: int, /) int
Overloaded function.
popcnt(arg: ArrayT, /) -> ArrayT
Return the number of nonzero bits.
This function evaluates the component-wise population count of the input scalar, array, or tensor. This function assumes that
arg
is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
number of nonzero bits in
arg
- Return type:
int | drjit.ArrayBase
popcnt(arg: int, /) -> int
- drjit.lzcnt(arg: ArrayT, /) ArrayT ¶
- drjit.lzcnt(arg: int, /) int
Overloaded function.
lzcnt(arg: ArrayT, /) -> ArrayT
Return the number of leading zero bits.
This function evaluates the component-wise leading zero count of the input scalar, array, or tensor. This function assumes that
arg
is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.The operation is well-defined when
arg
is zero.- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
number of leading zero bits in
arg
- Return type:
int | drjit.ArrayBase
lzcnt(arg: int, /) -> int
- drjit.tzcnt(arg: ArrayT, /) ArrayT ¶
- drjit.tzcnt(arg: int, /) int
Overloaded function.
tzcnt(arg: ArrayT, /) -> ArrayT
Return the number of trailing zero bits.
This function evaluates the component-wise trailing zero count of the input scalar, array, or tensor. This function assumes that
arg
is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.The operation is well-defined when
arg
is zero.- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
number of trailing zero bits in
arg
- Return type:
int | drjit.ArrayBase
tzcnt(arg: int, /) -> int
- drjit.brev(arg: ArrayT, /) ArrayT ¶
- drjit.brev(arg: int, /) int
Overloaded function.
brev(arg: ArrayT, /) -> ArrayT
Reverse the bit representation of an integer value or array.
This function assumes that
arg
is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.- Parameters:
arg (int | drjit.ArrayBase) – A Python
int
or Dr.Jit integer array.- Returns:
the bit-reversed version of
arg
.- Return type:
int | drjit.ArrayBase
brev(arg: int, /) -> int
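A small usage sketch with plain Python integers (which, per the notes above, are treated as 32-bit values; the same calls apply component-wise to Dr.Jit integer arrays):

import drjit as dr

print(dr.popcnt(0b1011))            # 3  -> three bits are set
print(dr.lzcnt(1))                  # 31 -> leading zeros in a 32-bit word
print(dr.tzcnt(8))                  # 3  -> 0b1000 has three trailing zeros
print(dr.brev(dr.brev(42)) == 42)   # True -> reversing the bit pattern twice is the identity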
- drjit.log2i(arg: T, /) T ¶
Return the floor of the base-2 logarithm.
This function evaluates the component-wise floor of the base-2 logarithm of the input scalar, array, or tensor. This function assumes that
arg
is either an arbitrary Dr.Jit integer array or a 32 bit-sized scalar integer value.The operation overflows when
arg
is zero.- Parameters:
arg (int | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
floor of the base-2 logarithm of the input
- Return type:
int | drjit.ArrayBase
Standard mathematical functions¶
- drjit.fma(arg0: object, arg1: object, arg2: object, /) object ¶
- drjit.fma(arg0: int, arg1: int, arg2: int, /) int
- drjit.fma(arg0: float, arg1: float, arg2: float, /) float
Overloaded function.
fma(arg0: object, arg1: object, arg2: object, /) -> object
fma(arg0: int, arg1: int, arg2: int, /) -> int
Perform a fused multiply-addition (FMA) of the inputs.
Given arguments
arg0
,arg1
, andarg2
, this operation computesarg0
*arg1
+arg2
using only one final rounding step. The operation is not only more accurate, but also more efficient, since FMA maps to a native machine instruction on all platforms targeted by Dr.Jit.When the input is complex- or quaternion-valued, the function internally uses a complex or quaternion product. In this case, it reduces the number of internal rounding steps instead of avoiding them altogether.
While FMA is traditionally a floating point operation, Dr.Jit also implements FMA for integer arrays and maps it onto dedicated instructions provided by the backend if possible (e.g.
mad.lo.*
for CUDA/PTX).- Parameters:
arg0 (float | drjit.ArrayBase) – First multiplication operand
arg1 (float | drjit.ArrayBase) – Second multiplication operand
arg2 (float | drjit.ArrayBase) – Additive operand
- Returns:
Result of the FMA operation
- Return type:
float | drjit.ArrayBase
fma(arg0: float, arg1: float, arg2: float, /) -> float
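For instance (a minimal sketch with scalar inputs):

import drjit as dr

print(dr.fma(2.0, 3.0, 1.0))  # 7.0 -> 2*3 + 1 with a single rounding step
print(dr.fma(2, 3, 1))        # 7   -> integer FMA (see the note above)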
- drjit.abs(arg: ArrayT, /) ArrayT ¶
- drjit.abs(arg: int, /) int
- drjit.abs(arg: float, /) float
Overloaded function.
abs(arg: ArrayT, /) -> ArrayT
Compute the absolute value of the provided input.
This function evaluates the component-wise absolute value of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.
- Parameters:
arg (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Absolute value of the input
- Return type:
int | float | drjit.ArrayBase
abs(arg: int, /) -> int
abs(arg: float, /) -> float
- drjit.minimum(arg0: int, arg1: int, /) int ¶
- drjit.minimum(arg0: object, arg1: object, /) object
- drjit.minimum(arg0: float, arg1: float, /) float
Overloaded function.
minimum(arg0: int, arg1: int, /) -> int
minimum(arg0: object, arg1: object, /) -> object
Compute the element-wise minimum value of the provided inputs.
(Not to be confused with
drjit.min()
, which reduces the input along the specified axes to determine the minimum)- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Minimum of the input(s)
- Return type:
int | float | drjit.ArrayBase
minimum(arg0: float, arg1: float, /) -> float
- drjit.maximum(arg0: int, arg1: int, /) int ¶
- drjit.maximum(arg0: object, arg1: object, /) object
- drjit.maximum(arg0: float, arg1: float, /) float
Overloaded function.
maximum(arg0: int, arg1: int, /) -> int
maximum(arg0: object, arg1: object, /) -> object
Compute the element-wise maximum value of the provided inputs.
(Not to be confused with
drjit.max()
, which reduces the input along the specified axes to determine the maximum)- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
Maximum of the input(s)
- Return type:
int | float | drjit.ArrayBase
maximum(arg0: float, arg1: float, /) -> float
- drjit.sqrt(arg: ArrayT, /) ArrayT ¶
- drjit.sqrt(arg: float, /) float
Overloaded function.
sqrt(arg: ArrayT, /) -> ArrayT
Evaluate the square root of the provided input.
This function evaluates the component-wise square root of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.
Negative inputs produce a NaN output value. Consider using the
safe_sqrt()
function to work around issues where the input might occasionally be negative due to prior round-off errors.Another noteworthy behavior of the square root function is that it has an infinite derivative at
arg=0
, which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. Thesafe_sqrt()
function contains a workaround to ensure a finite derivative in this case.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Square root of the input
- Return type:
float | drjit.ArrayBase
sqrt(arg: float, /) -> float
- drjit.cbrt(arg: ArrayT, /) ArrayT ¶
- drjit.cbrt(arg: float, /) float
Overloaded function.
cbrt(arg: ArrayT, /) -> ArrayT
Evaluate the cube root of the provided input.
This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Cube root of the input
- Return type:
float | drjit.ArrayBase
cbrt(arg: float, /) -> float
- drjit.rcp(arg: ArrayT, /) ArrayT ¶
- drjit.rcp(arg: float, /) float
Overloaded function.
rcp(arg: ArrayT, /) -> ArrayT
Evaluate the reciprocal (1 / arg) of the provided input.
When
arg
is a CUDA single precision array, the operation is implemented slightly approximately—see the documentation of the instructionrcp.approx.ftz.f32
in the NVIDIA PTX manual for details. For full IEEE-754 compatibility, unsetdrjit.JitFlag.FastMath
.When called with a matrix-, complex- or quaternion-valued array, this function uses the matrix, complex, or quaternion multiplicative inverse to evaluate the reciprocal.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Reciprocal of the input
- Return type:
float | drjit.ArrayBase
rcp(arg: float, /) -> float
- drjit.rsqrt(arg: ArrayT, /) ArrayT ¶
- drjit.rsqrt(arg: float, /) float
Overloaded function.
rsqrt(arg: ArrayT, /) -> ArrayT
Evaluate the reciprocal square root (1 / sqrt(arg)) of the provided input.
This function evaluates the component-wise reciprocal square root of the input scalar, array, or tensor. When called with a complex or quaternion-valued array, it uses a suitable generalization of the operation.
When
arg
is a single precision array, the operation is implemented slightly approximately—see the documentation of the instructionrsqrt.approx.ftz.f32
in the NVIDIA PTX manual for details on how this works on the CUDA backend. For full IEEE-754 compatibility, unsetdrjit.JitFlag.FastMath
.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Reciprocal square root of the input
- Return type:
float | drjit.ArrayBase
rsqrt(arg: float, /) -> float
- drjit.clip(value, min, max)¶
Clip the provided input to the given interval.
This function is equivalent to
dr.maximum(dr.minimum(value, max), min)
- Parameters:
value (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
min (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
max (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
- Returns:
Clipped input
- Return type:
float | drjit.ArrayBase
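A brief example (drjit.scalar.Array3f is used here purely for illustration):

import drjit as dr

x = dr.scalar.Array3f(-0.5, 0.25, 2.0)
print(dr.clip(x, 0.0, 1.0))   # [0, 0.25, 1]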
- drjit.ceil(arg: ArrayT, /) ArrayT ¶
- drjit.ceil(arg: float, /) float
Overloaded function.
ceil(arg: ArrayT, /) -> ArrayT
Evaluate the ceiling, i.e. the smallest integer >= arg.
The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Ceiling of the input
- Return type:
float | drjit.ArrayBase
ceil(arg: float, /) -> float
- drjit.floor(arg: ArrayT, /) ArrayT ¶
- drjit.floor(arg: float, /) float
Overloaded function.
floor(arg: ArrayT, /) -> ArrayT
Evaluate the floor, i.e. the largest integer <= arg.
The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Floor of the input
- Return type:
float | drjit.ArrayBase
floor(arg: float, /) -> float
- drjit.trunc(arg: ArrayT, /) ArrayT ¶
- drjit.trunc(arg: float, /) float
Overloaded function.
trunc(arg: ArrayT, /) -> ArrayT
Truncates arg towards zero to the nearest integer.
The function does not convert the type of the input array. A separate cast is necessary when integer output is desired.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Truncated result
- Return type:
float | drjit.ArrayBase
trunc(arg: float, /) -> float
- drjit.round(arg: ArrayT, /) ArrayT ¶
- drjit.round(arg: float, /) float
Overloaded function.
round(arg: ArrayT, /) -> ArrayT
Rounds the input to the nearest integer using Banker’s rounding for half-way values.
This function is equivalent to
std::rint
in C++. It does not convert the type of the input array. A separate cast is necessary when integer output is desired.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Rounded result
- Return type:
float | drjit.ArrayBase
round(arg: float, /) -> float
- drjit.sign(arg, /)¶
Return the element-wise sign of the provided array.
The function returns
\[\mathrm{sign}(\texttt{arg}) = \begin{cases} 1&\texttt{arg}\geq 0,\\ -1&\mathrm{otherwise}. \end{cases}\]- Parameters:
arg (int | float | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Sign of the input array
- Return type:
float | int | drjit.ArrayBase
- drjit.copysign(arg0, arg1, /)¶
Copy the sign of
arg1
toarg0
element-wise.- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to change the sign of
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to copy the sign from
- Returns:
The values of
arg0
with the sign ofarg1
- Return type:
float | int | drjit.ArrayBase
- drjit.mulsign(arg0, arg1, /)¶
Multiply
arg0
by the sign ofarg1
element-wise.This function is equivalent to
arg0 * dr.sign(arg1)
- Parameters:
arg0 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to multiply the sign of
arg1 (int | float | drjit.ArrayBase) – A Python or Dr.Jit array to take the sign from
- Returns:
The values of
arg0
multiplied with the sign ofarg1
- Return type:
float | int | drjit.ArrayBase
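A small scalar sketch of the three sign-related helpers above:

import drjit as dr

print(dr.sign(-3.0))           # -1.0
print(dr.copysign(2.0, -5.0))  # -2.0 -> magnitude of arg0, sign of arg1
print(dr.mulsign(2.0, -5.0))   # -2.0 -> arg0 * sign(arg1)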
Operations for vectors and matrices¶
- drjit.cross(arg0: ArrayT, arg1: ArrayT, /) ArrayT ¶
Returns the cross-product of the two input 3D arrays
- Parameters:
arg0 (drjit.ArrayBase) – A Dr.Jit 3D array
arg1 (drjit.ArrayBase) – A Dr.Jit 3D array
- Returns:
Cross-product of the two input 3D arrays
- Return type:
- drjit.det(arg, /)¶
Compute the determinant of the provided Dr.Jit matrix.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The determinant value of the input matrix
- Return type:
- drjit.diag(arg, /)¶
This function either returns the diagonal entries of the provided Dr.Jit matrix, or it constructs a new matrix from the diagonal entries.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The diagonal of the input matrix, or a diagonal matrix constructed from the input entries
- Return type:
- drjit.trace(arg, /)¶
Returns the trace of the provided Dr.Jit matrix.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The trace of the input matrix
- Return type:
drjit.value_t(arg)
- drjit.matmul(arg0: object, arg1: object, /) object ¶
Compute a matrix-matrix, matrix-vector, vector-matrix, or inner product.
This function implements the semantics of the
@
operator introduced in Python’s PEP 465. There is no practical difference between using drjit.matmul()
or@
in Dr.Jit-based code. Multiplication of matrix types (e.g.,drjit.scalar.Matrix2f
) using the standard multiplication operator (*
) is also based on matrix multiplication.This function takes two Dr.Jit arrays and picks one of the following 5 cases based on their leading fixed-size dimensions.
Matrix-matrix product: If both arrays have leading static dimensions
(n, n)
, they are multiplied like conventional matrices.Matrix-vector product: If
arg0
has leading static dimensions(n, n)
andarg1
has leading static dimension(n,)
, the operation conceptually appends a trailing 1-sized dimension toarg1
, multiplies, and then removes the extra dimension from the result.Vector-matrix product: If
arg0
has leading static dimensions(n,)
andarg1
has leading static dimension(n, n)
, the operation conceptually prepends a leading 1-sized dimension toarg0
, multiplies, and then removes the extra dimension from the result.Inner product: If
arg0
andarg1
have leading static dimensions(n,)
, the operation returns the sum of the elements ofarg0*arg1
.Scalar product: If
arg0
orarg1
is a scalar, the operation scales the elements of the other argument.
It is legal to combine vectorized and non-vectorized types, e.g.
dr.matmul(dr.scalar.Matrix4f(...), dr.cuda.Matrix4f(...))
Also, note that it doesn’t matter whether an input is an instance of a matrix type or a similarly-shaped nested array—for example,
drjit.scalar.Matrix3f()
anddrjit.scalar.Array33f()
have the same shape and are treated identically.Note
This operation only handles fixed-size arrays. A different approach is needed for multiplications involving potentially large dynamic arrays/tensors. Other tools like PyTorch, JAX, or TensorFlow are preferable in such situations (e.g., to train neural networks).
- Parameters:
arg0 (dr.ArrayBase) – Dr.Jit array type
arg1 (dr.ArrayBase) – Dr.Jit array type
- Returns:
The result of the operation as defined above
- Return type:
object
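A minimal sketch of the matrix-vector and matrix-matrix cases (the entries of drjit.scalar.Matrix2f are assumed here to be specified in row-major order):

import drjit as dr

m = dr.scalar.Matrix2f(1, 2,
                       3, 4)
v = dr.scalar.Array2f(1, 1)

print(m @ v)            # matrix-vector product: [3, 7]
print(dr.matmul(m, m))  # matrix-matrix product: [[7, 10], [15, 22]]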
- drjit.hypot(a, b)¶
Computes \(\sqrt{x^2+y^2}\) while avoiding overflow and underflow.
- Parameters:
a (float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
b (float | drjit.ArrayBase) – A Python or Dr.Jit arithmetic type
- Returns:
The computed hypotenuse.
- Return type:
float | drjit.ArrayBase
- drjit.normalize(arg: T, /) T ¶
Normalize the input vector so that it has unit length and return the result.
This operation is equivalent to
arg * dr.rsqrt(dr.squared_norm(arg))
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit array type
- Returns:
Unit-norm version of the input
- Return type:
- drjit.lerp(a, b, t)¶
Linearly blend between two values.
This function computes
\[\mathrm{lerp}(a, b, t) = (1-t) a + t b\]In other words, it linearly blends between \(a\) and \(b\) based on the value \(t\) that is typically on the interval \([0, 1]\).
It does so using two fused multiply-additions (
drjit.fma()
) to improve performance and avoid numerical errors.- Parameters:
a (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
b (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
t (int | float | drjit.ArrayBase) – A Python or Dr.Jit type
- Returns:
Interpolated result
- Return type:
float | drjit.ArrayBase
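For example (a minimal sketch using drjit.scalar.Array3f):

import drjit as dr

a = dr.scalar.Array3f(0, 0, 0)
b = dr.scalar.Array3f(1, 2, 4)
print(dr.lerp(a, b, 0.25))    # [0.25, 0.5, 1]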
- drjit.sh_eval(d: ArrayBase, order: int) list ¶
Evaluate the real spherical harmonics basis functions up to a specified order.
The input
d
must be a normalized 3D Cartesian coordinate vector. The function returns a list containing all spherical harmonic basis functions evaluated with respect tod
up to the desired order, for a total of(order+1)**2
output values.The implementation relies on efficient pre-generated branch-free code with aggressive constant folding and common subexpression elimination. It admits scalar and Jit-compiled input arrays. Evaluation routines are included for orders
0
to10
. Requesting higher orders triggers a runtime exception.This automatically generated code is based on the paper Efficient Spherical Harmonic Evaluation, Journal of Computer Graphics Techniques (JCGT), vol. 2, no. 2, 84-90, 2013 by Peter-Pike Sloan.
The SciPy equivalent of this function is given by
def sh_eval(d, order: int):
    import numpy as np
    from scipy.special import sph_harm

    theta, phi = np.arccos(d.z), np.arctan2(d.y, d.x)
    r = []
    for l in range(order + 1):
        for m in range(-l, l + 1):
            Y = sph_harm(abs(m), l, phi, theta)
            if m > 0:
                Y = np.sqrt(2) * Y.real
            elif m < 0:
                Y = np.sqrt(2) * Y.imag
            r.append(Y.real)
    return r
The Mathematica equivalent of a specific entry is given by:
SphericalHarmonicQ[l_, m_, d_] := Block[{θ, ϕ},
  θ = ArcCos[d[[3]]];
  ϕ = ArcTan[d[[1]], d[[2]]];
  Piecewise[{
    {SphericalHarmonicY[l, m, θ, ϕ], m == 0},
    {Sqrt[2] * Re[SphericalHarmonicY[l, m, θ, ϕ]], m > 0},
    {Sqrt[2] * Im[SphericalHarmonicY[l, -m, θ, ϕ]], m < 0}
  }]
]
- drjit.frob(arg, /)¶
Returns the squared Frobenius norm of the provided Dr.Jit matrix.
The squared Frobenius norm is defined as the sum of the squares of its elements:
\[\sum_{i=1}^m \sum_{j=1}^n a_{i j}^2\]- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The squared Frobenius norm of the input matrix
- Return type:
- drjit.rotate(dtype, axis, angle)¶
Constructs a rotation quaternion, which rotates by
angle
radians aroundaxis
.The function requires
axis
to be normalized.- Parameters:
dtype (type) – Desired Dr.Jit quaternion type.
axis (drjit.ArrayBase) – A 3-dimensional Dr.Jit array representing the rotation axis
angle (float | drjit.ArrayBase) – Rotation angle.
- Returns:
The rotation quaternion
- Return type:
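A usage sketch (the quaternion type name drjit.scalar.Quaternion4f and the constant drjit.pi are assumptions; substitute the corresponding types of your chosen backend):

import drjit as dr

# Rotation by 90 degrees (pi/2 radians) around the z axis; the axis must be normalized
q = dr.rotate(dr.scalar.Quaternion4f, dr.scalar.Array3f(0, 0, 1), dr.pi / 2)
print(q)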
- drjit.polar_decomp(arg, it=10)¶
Returns the polar decomposition of the provided Dr.Jit matrix.
The polar decomposition separates the matrix into a rotation followed by a scaling along each of its eigenvectors. This decomposition always exists for square matrices.
The implementation relies on an iterative algorithm, where the number of iterations can be controlled by the argument
it
(tradeoff between precision and computational cost).- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
it (int) – Number of iterations to be taken by the algorithm.
- Returns:
A tuple containing the rotation matrix and the scaling matrix resulting from the decomposition.
- Return type:
tuple
- drjit.matrix_to_quat(arg, /)¶
Converts a 3x3 or 4x4 homogeneous matrix containing a pure rotation into a rotation quaternion.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
- Returns:
The Dr.Jit quaternion corresponding to the input matrix.
- Return type:
- drjit.quat_to_matrix(arg, size=4)¶
Converts a quaternion into its matrix representation.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit quaternion type
size (int) – Controls whether to construct a 3x3 or 4x4 matrix.
- Returns:
The Dr.Jit matrix corresponding to the input quaternion.
- Return type:
- drjit.transform_decompose(arg, it=10)¶
Performs a polar decomposition of a non-perspective 4x4 homogeneous coordinate matrix and returns a tuple of
A positive definite 3x3 matrix containing an inhomogeneous scaling operation
A rotation quaternion
A 3D translation vector
This representation is helpful for keyframe animation.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit matrix type
it (int) – Number of iterations to be taken by the polar decomposition algorithm.
- Returns:
The tuple containing the scaling matrix, rotation quaternion and 3D translation vector.
- Return type:
tuple
- drjit.transform_compose(S, Q, T, /)¶
This function composes a 4x4 homogeneous coordinate transformation from the given scale, rotation, and translation. It performs the reverse of
transform_decompose()
.- Parameters:
S (drjit.ArrayBase) – A Dr.Jit matrix type representing the scaling part
Q (drjit.ArrayBase) – A Dr.Jit quaternion type representing the rotation part
T (drjit.ArrayBase) – A 3D Dr.Jit array type representing the translation part
- Returns:
The Dr.Jit matrix resulting from the composition described above.
- Return type:
Operations for complex values and quaternions¶
- drjit.conj(arg, /)¶
Returns the conjugate of the provided complex or quaternion-valued array. For all other types, it returns the input unchanged.
- Parameters:
arg (drjit.ArrayBase) – A complex- or quaternion-valued Dr.Jit array (other inputs are returned unchanged)
- Returns:
Conjugate form of the input
- Return type:
- drjit.arg(z, /)¶
Return the argument of a complex Dr.Jit array.
The argument refers to the angle (in radians) between the positive real axis and a vector towards
z
in the complex plane. When the input isn’t complex-valued, the function returns \(0\) or \(\pi\) depending on the sign ofz
.- Parameters:
z (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Argument of the complex input array
- Return type:
float | drjit.ArrayBase
- drjit.real(arg, /)¶
Return the real part of a complex or quaternion-valued input.
When the input isn’t complex- or quaternion-valued, the function returns the input unchanged.
- Parameters:
arg (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Real part of the input array
- Return type:
float | drjit.ArrayBase
- drjit.imag(arg, /)¶
Return the imaginary part of a complex or quaternion-valued input.
When the input isn’t complex- or quaternion-valued, the function returns zero.
- Parameters:
arg (int | float | complex | drjit.ArrayBase) – A Python or Dr.Jit array
- Returns:
Imaginary part of the input array
- Return type:
float | drjit.ArrayBase
- drjit.quat_to_euler(arg, /)¶
Converts a quaternion into its Euler angles representation.
The order for Euler angles is XYZ.
- Parameters:
arg (drjit.ArrayBase) – A Dr.Jit quaternion type
- Returns:
A 3D Dr.Jit array containing the Euler angles.
- Return type:
- drjit.euler_to_quat(arg, /)¶
Converts Euler angles into a Dr.Jit quaternion.
The order for input Euler angles must be XYZ.
- Parameters:
arg (drjit.ArrayBase) – A 3D Dr.Jit array type
- Returns:
A Dr.Jit quaternion representing the input Euler angles.
- Return type:
Transcendental functions¶
Dr.Jit implements the most common transcendental functions using methods that are based on the CEPHES math library. The accuracy of these approximations is documented in a set of tables below.
Trigonometric functions¶
- drjit.sin(arg: ArrayT, /) ArrayT ¶
- drjit.sin(arg: float, /) float
Overloaded function.
sin(arg: ArrayT, /) -> ArrayT
Evaluate the sine function.
This function evaluates the component-wise sine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The default implementation of this function is based on the CEPHES library and is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When
arg
is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Sine of the input
- Return type:
float | drjit.ArrayBase
sin(arg: float, /) -> float
- drjit.cos(arg: ArrayT, /) ArrayT ¶
- drjit.cos(arg: float, /) float
Overloaded function.
cos(arg: ArrayT, /) -> ArrayT
Evaluate the cosine function.
This function evaluates the component-wise cosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When
arg
is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Cosine of the input
- Return type:
float | drjit.ArrayBase
cos(arg: float, /) -> float
- drjit.sincos(arg: ArrayT, /) tuple[ArrayT, ArrayT] ¶
- drjit.sincos(arg: float, /) tuple[float, float]
Overloaded function.
sincos(arg: ArrayT, /) -> tuple[ArrayT, ArrayT]
Evaluate both sine and cosine functions at the same time.
This function simultaneously evaluates the component-wise sine and cosine of the input scalar, array, or tensor. This is more efficient than two separate calls to
drjit.sin()
anddrjit.cos()
when both are required. The function uses a suitable generalization of the operation when the input is complex-valued.The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When
arg
is a CUDA single precision array, the operation instead uses the hardware’s built-in multi-function (“MUFU”) unit.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Sine and cosine of the input
- Return type:
(float, float) | (drjit.ArrayBase, drjit.ArrayBase)
sincos(arg: float, /) -> tuple[float, float]
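For example (drjit.pi is assumed to be available as a module-level constant):

import drjit as dr

s, c = dr.sincos(dr.pi / 4)
print(s, c)    # both approximately 0.7071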
- drjit.tan(arg: ArrayT, /) ArrayT ¶
- drjit.tan(arg: float, /) float
Overloaded function.
tan(arg: ArrayT, /) -> ArrayT
Evaluate the tangent function.
This function evaluates the component-wise tangent function associated with each entry of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.
The default implementation of this function is based on the CEPHES library. It is designed to achieve low error on the domain \(|x| < 8192\) and will not perform as well beyond this range. See the section on transcendental function approximations for details regarding accuracy.
When
arg
is a CUDA single precision array, the operation instead uses the GPU’s built-in multi-function (“MUFU”) unit.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Tangent of the input
- Return type:
float | drjit.ArrayBase
tan(arg: float, /) -> float
- drjit.asin(arg: ArrayT, /) ArrayT ¶
- drjit.asin(arg: float, /) float
Overloaded function.
asin(arg: ArrayT, /) -> ArrayT
Evaluate the arcsine function.
This function evaluates the component-wise arcsine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when called with a complex-valued input.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
Real-valued inputs outside of the domain \((-1, 1)\) produce a NaN output value. Consider using the
safe_asin()
function to work around issues where the input might occasionally lie outside of this range due to prior round-off errors.Another noteworthy behavior of the arcsine function is that it has an infinite derivative at \(\texttt{arg}=\pm 1\), which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. The
safe_asin()
function contains a workaround to ensure a finite derivative in this case.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arcsine of the input
- Return type:
float | drjit.ArrayBase
asin(arg: float, /) -> float
- drjit.acos(arg: ArrayT, /) ArrayT ¶
- drjit.acos(arg: float, /) float
Overloaded function.
acos(arg: ArrayT, /) -> ArrayT
Evaluate the arccosine function.
This function evaluates the component-wise arccosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
Real-valued inputs outside of the domain \((-1, 1)\) produce a NaN output value. Consider using the
safe_acos()
function to work around issues where the input might occasionally lie outside of this range due to prior round-off errors.Another noteworthy behavior of the arcsine function is that it has an infinite derivative at \(\texttt{arg}=\pm 1\), which can cause infinities/NaNs in gradients computed via forward/reverse-mode AD. The
safe_acos()
function contains a workaround to ensure a finite derivative in this case.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arccosine of the input
- Return type:
float | drjit.ArrayBase
acos(arg: float, /) -> float
- drjit.atan(arg: ArrayT, /) ArrayT ¶
- drjit.atan(arg: float, /) float
Overloaded function.
atan(arg: ArrayT, /) -> ArrayT
Evaluate the arctangent function.
This function evaluates the component-wise arctangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arctangent of the input
- Return type:
float | drjit.ArrayBase
atan(arg: float, /) -> float
- drjit.atan2(arg0: object, arg1: object, /) object ¶
- drjit.atan2(arg0: float, arg1: float, /) float
Overloaded function.
atan2(arg0: object, arg1: object, /) -> object
Evaluate the four-quadrant arctangent function.
This function is currently only implemented for real-valued inputs.
See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
y (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
x (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arctangent of
y
/x
, using the argument signs to determine the quadrant of the return value- Return type:
float | drjit.ArrayBase
atan2(arg0: float, arg1: float, /) -> float
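A scalar sketch showing how the argument signs select the quadrant:

import drjit as dr

print(dr.atan2(1.0, 1.0))    # ~0.7854 ( pi/4, first quadrant)
print(dr.atan2(1.0, -1.0))   # ~2.3562 (3*pi/4, second quadrant)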
Hyperbolic functions¶
- drjit.sinh(arg: ArrayT, /) ArrayT ¶
- drjit.sinh(arg: float, /) float
Overloaded function.
sinh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic sine function.
This function evaluates the component-wise hyperbolic sine of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic sine of the input
- Return type:
float | drjit.ArrayBase
sinh(arg: float, /) -> float
- drjit.cosh(arg: ArrayT, /) ArrayT ¶
- drjit.cosh(arg: float, /) float
Overloaded function.
cosh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic cosine function.
This function evaluates the component-wise hyperbolic cosine of the input scalar, array, or tensor. The function uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic cosine of the input
- Return type:
float | drjit.ArrayBase
cosh(arg: float, /) -> float
- drjit.sincosh(arg: ArrayT, /) tuple[ArrayT, ArrayT] ¶
- drjit.sincosh(arg: float, /) tuple[float, float]
Overloaded function.
sincosh(arg: ArrayT, /) -> tuple[ArrayT, ArrayT]
Evaluate both hyperbolic sine and cosine functions at the same time.
This function simultaneously evaluates the component-wise hyperbolic sine and cosine of the input scalar, array, or tensor. This is more efficient than two separate calls to
drjit.sinh()
anddrjit.cosh()
when both are required. The function uses a suitable generalization of the operation when the input is complex-valued.The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic sine and cosine of the input
- Return type:
(float, float) | (drjit.ArrayBase, drjit.ArrayBase)
sincosh(arg: float, /) -> tuple[float, float]
- drjit.tanh(arg: ArrayT, /) ArrayT ¶
- drjit.tanh(arg: float, /) float
Overloaded function.
tanh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic tangent function.
This function evaluates the component-wise hyperbolic tangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic tangent of the input
- Return type:
float | drjit.ArrayBase
tanh(arg: float, /) -> float
- drjit.asinh(arg: ArrayT, /) ArrayT ¶
- drjit.asinh(arg: float, /) float
Overloaded function.
asinh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic arcsine function.
This function evaluates the component-wise hyperbolic arcsine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic arcsine of the input
- Return type:
float | drjit.ArrayBase
asinh(arg: float, /) -> float
- drjit.acosh(arg: ArrayT, /) ArrayT ¶
- drjit.acosh(arg: float, /) float
Overloaded function.
acosh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic arccosine function.
This function evaluates the component-wise hyperbolic arccosine of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic arccosine of the input
- Return type:
float | drjit.ArrayBase
acosh(arg: float, /) -> float
- drjit.atanh(arg: ArrayT, /) ArrayT ¶
- drjit.atanh(arg: float, /) float
Overloaded function.
atanh(arg: ArrayT, /) -> ArrayT
Evaluate the hyperbolic arctangent function.
This function evaluates the component-wise hyperbolic arctangent of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex-valued.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Hyperbolic arctangent of the input
- Return type:
float | drjit.ArrayBase
atanh(arg: float, /) -> float
Exponentials, logarithms, power function¶
- drjit.log2(arg: ArrayT, /) ArrayT ¶
- drjit.log2(arg: float, /) float
Overloaded function.
log2(arg: ArrayT, /) -> ArrayT
Evaluate the base-2 logarithm.
This function evaluates the component-wise base-2 logarithm of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When
arg
is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Base-2 logarithm of the input
- Return type:
float | drjit.ArrayBase
log2(arg: float, /) -> float
- drjit.log(arg: ArrayT, /) ArrayT ¶
- drjit.log(arg: float, /) float
Overloaded function.
log(arg: ArrayT, /) -> ArrayT
Evaluate the natural logarithm.
This function evaluates the component-wise natural logarithm of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When
arg
is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Natural logarithm of the input
- Return type:
float | drjit.ArrayBase
log(arg: float, /) -> float
- drjit.exp2(arg: ArrayT, /) ArrayT ¶
- drjit.exp2(arg: float, /) float
Overloaded function.
exp2(arg: ArrayT, /) -> ArrayT
Evaluate
2
raised to a given power.This function evaluates the component-wise base-2 exponential function of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When
arg
is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Base-2 exponential of the input
- Return type:
float | drjit.ArrayBase
exp2(arg: float, /) -> float
- drjit.exp(arg: ArrayT, /) ArrayT ¶
- drjit.exp(arg: float, /) float
Overloaded function.
exp(arg: ArrayT, /) -> ArrayT
Evaluate the natural exponential function.
This function evaluates the component-wise natural exponential function of the input scalar, array, or tensor. It uses a suitable generalization of the operation when the input is complex- or quaternion-valued.
See the section on transcendental function approximations for details regarding accuracy.
When
arg
is a CUDA single precision array, the operation is implemented using the native multi-function (“MUFU”) unit.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Natural exponential of the input
- Return type:
float | drjit.ArrayBase
exp(arg: float, /) -> float
- drjit.power(arg0: int, arg1: int, /) float ¶
- drjit.power(arg0: float, arg1: float, /) float
- drjit.power(arg0: object, arg1: object, /) object
Overloaded function.
power(arg0: int, arg1: int, /) -> float
Raise the first argument to a power specified via the second argument.
This function evaluates the component-wise power of the input scalar, array, or tensor arguments. When called with a complex or quaternion-valued inputs, it uses a suitable generalization of the operation.
When
arg1
is a Pythonint
or integralfloat
value, the function reduces the operation to a sequence of multiplies and adds (potentially followed by a reciprocation operation when arg1
is negative).The general case involves recursive use of the identity
pow(arg0, arg1) = exp2(log2(arg0) * arg1)
.There is no difference between using
drjit.power()
and the builtin Python**
operator.- Parameters:
arg (object) – A Python or Dr.Jit arithmetic type
- Returns:
The result of the operation
arg0**arg1
- Return type:
object
power(arg0: float, arg1: float, /) -> float
power(arg0: object, arg1: object, /) -> object
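For example (a minimal scalar sketch):

import drjit as dr

print(dr.power(2, 10))     # 1024.0 -> integer exponent: multiply/add sequence
print(dr.power(2.0, 0.5))  # ~1.4142 -> general case via exp2/log2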
Other¶
- drjit.erf(arg: ArrayT, /) ArrayT ¶
- drjit.erf(arg: float, /) float
Overloaded function.
erf(arg: ArrayT, /) -> ArrayT
Evaluate the error function.
The error function (https://en.wikipedia.org/wiki/Error_function) is defined as
\[\operatorname{erf}(z) = \frac{2}{\sqrt\pi}\int_0^z e^{-t^2}\,\mathrm{d}t.\]See the section on transcendental function approximations for details regarding accuracy.
This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
\(\mathrm{erf}(\texttt{arg})\)
- Return type:
float | drjit.ArrayBase
erf(arg: float, /) -> float
- drjit.erfinv(arg: ArrayT, /) ArrayT ¶
- drjit.erfinv(arg: float, /) float
Overloaded function.
erfinv(arg: ArrayT, /) -> ArrayT
Evaluate the inverse error function.
This function evaluates the inverse of
drjit.erf()
. Its implementation is based on the paper Approximating the erfinv function by Mike Giles.This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
\(\mathrm{erf}^{-1}(\texttt{arg})\)
- Return type:
float | drjit.ArrayBase
erfinv(arg: float, /) -> float
- drjit.lgamma(arg: ArrayT, /) ArrayT ¶
- drjit.lgamma(arg: float, /) float
Overloaded function.
lgamma(arg: ArrayT, /) -> ArrayT
Evaluate the natural logarithm of the absolute value of the gamma function.
The implementation of this function is based on the CEPHES library. See the section on transcendental function approximations for details regarding accuracy.
This function is currently only implemented for real-valued inputs.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
\(\log|\Gamma(\texttt{arg})|\)
- Return type:
float | drjit.ArrayBase
lgamma(arg: float, /) -> float
- drjit.rad2deg(arg: T, /) T ¶
Convert angles from radians to degrees.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
The equivalent angle in degrees.
- Return type:
float | drjit.ArrayBase
- drjit.deg2rad(arg: T, /) T ¶
Convert angles from degrees to radians.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
The equivalent angle in radians.
- Return type:
float | drjit.ArrayBase
Safe mathematical functions¶
Dr.Jit provides “safe” variants of a few standard mathematical operations that are prone to out-of-domain errors in calculations with floating point rounding errors. Such errors could, e.g., cause the argument of a square root to become negative, which would ordinarily require complex arithmetic. At zero, the derivative of the square root function is infinite. The following operations clamp the input to a safe range to avoid these extremes.
- drjit.safe_sqrt(arg: T, /) T ¶
Safely evaluate the square root of the provided input avoiding domain errors.
Negative inputs produce zero-valued output. When differentiated via AD, this function also avoids generating infinite derivatives at
x=0
.- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Square root of the input
- Return type:
float | drjit.ArrayBase
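For instance, a slightly negative input caused by round-off is clamped rather than producing a NaN:

import drjit as dr

print(dr.sqrt(-1e-8))       # nan
print(dr.safe_sqrt(-1e-8))  # 0.0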
- drjit.safe_asin(arg: T, /) T ¶
Safe wrapper around
drjit.asin()
that avoids domain errors.Input values are clipped to the \((-1, 1)\) domain. When differentiated via AD, this function also avoids generating infinite derivatives at the boundaries of the domain.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arcsine approximation
- Return type:
float | drjit.ArrayBase
- drjit.safe_acos(arg: T, /) T ¶
Safe wrapper around
drjit.acos()
that avoids domain errors.Input values are clipped to the \((-1, 1)\) domain. When differentiated via AD, this function also avoids generating infinite derivatives at the boundaries of the domain.
- Parameters:
arg (float | drjit.ArrayBase) – A Python or Dr.Jit floating point type
- Returns:
Arccosine approximation
- Return type:
float | drjit.ArrayBase
Automatic differentiation¶
- enum drjit.ADMode(value)¶
Enumeration to distinguish different types of primal/derivative computation.
See also
drjit.enqueue()
,drjit.traverse()
.Valid values are as follows:
- Primal = ADMode.Primal¶
Primal/original computation without derivative tracking. Note that this is not a valid input to Dr.Jit AD routines, but it is sometimes useful to have this entry to indicate to a computation that derivative propagation should not be performed.
- Forward = ADMode.Forward¶
Propagate derivatives in forward mode (from inputs to outputs)
- Backward = ADMode.Backward¶
Propagate derivatives in backward/reverse mode (from outputs to inputs)
- enum drjit.ADFlag(value)¶
By default, Dr.Jit’s AD system destructs the enqueued input graph during forward/backward mode traversal. This frees up resources, which is useful when working with large wavefronts or very complex computation graphs. However, this also prevents repeated propagation of gradients through a shared subgraph that is being differentiated multiple times.
To support more fine-grained use cases that require this, the following flags can be used to control what should and should not be destructed.
Members of this enumeration can be combined using the
|
operator.- Member Type:
int
Valid values are as follows:
- ClearNone = ADFlag.ClearNone¶
Clear nothing.
- ClearEdges = ADFlag.ClearEdges¶
Delete all traversed edges from the computation graph
- ClearInput = ADFlag.ClearInput¶
Clear the gradients of processed input vertices (in-degree == 0)
- ClearInterior = ADFlag.ClearInterior¶
Clear the gradients of processed interior vertices (out-degree != 0)
- ClearVertices = ADFlag.ClearVertices¶
Clear gradients of processed vertices only, but leave edges intact. Equal to
ClearInput | ClearInterior
.
- AllowNoGrad = ADFlag.AllowNoGrad¶
Don’t fail when the input to a
drjit.forward
orbackward
operation is not a differentiable array.
- Default = ADFlag.Default¶
Default: clear everything (edges, gradients of processed vertices). Equal to
ClearEdges | ClearVertices
.
- drjit.detach(arg: T, preserve_type: bool = True) T ¶
Transforms the input variable into its non-differentiable version (detaches it from the AD computational graph).
This function supports arbitrary Dr.Jit arrays/tensors and PyTrees as input. In the latter case, it applies the transformation recursively. When the input variable is not a PyTree or Dr.Jit array, it is returned as it is.
While the type of the returned array is preserved by default, it is possible to set the
preserve_type
argument to false to force the returned type to be non-differentiable. For example, this will convert an array of typedrjit.llvm.ad.Float
into one of typedrjit.llvm.Float
.- Parameters:
arg (object) – An arbitrary Dr.Jit array, tensor, or PyTree.
preserve_type (bool) – Defines whether the returned variable should preserve the type of the input variable.
- Returns:
The detached variable.
- Return type:
object
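A brief sketch (this assumes the LLVM backend is available; any differentiable array type behaves the same way, and it uses drjit.enable_grad() and drjit.grad_enabled(), documented below):

import drjit as dr

x = dr.llvm.ad.Float(1.0, 2.0)
dr.enable_grad(x)

y = dr.detach(x)                       # same type, detached from the AD graph
z = dr.detach(x, preserve_type=False)  # non-differentiable counterpart

print(dr.grad_enabled(y))  # False
print(type(z))             # drjit.llvm.Float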
- drjit.enable_grad(arg: object, /) None ¶
- drjit.enable_grad(*args) None
Overloaded function.
enable_grad(arg: object, /) -> None
Enable gradient tracking for the provided variables.
This function accepts a variable-length list of arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During this recursive traversal, the function enables gradient tracking for all encountered Dr.Jit arrays. Variables of other types are ignored.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit arrays/tensors or PyTrees.
enable_grad(*args) -> None
- drjit.disable_grad(arg: object, /) None ¶
- drjit.disable_grad(*args) None
Overloaded function.
disable_grad(arg: object, /) -> None
Disable gradient tracking for the provided variables.
This function accepts a variable-length list of arguments and processes all of them. It recursively traverses PyTrees (sequences, mappings, custom data structures, etc.).
During this recursive traversal, the function disables gradient tracking for all encountered Dr.Jit arrays. Variables of other types are ignored.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit arrays/tensors or PyTrees.
disable_grad(*args) -> None
- drjit.set_grad_enabled(arg0: object, arg1: bool, /) None ¶
Enable or disable gradient tracking on the provided variables.
- Parameters:
arg (object) – An arbitrary Dr.Jit array, tensor, PyTree, sequence, or mapping.
value (bool) – Defines whether gradient tracking should be enabled or disabled.
- drjit.grad_enabled(arg: object, /) bool ¶
- drjit.grad_enabled(*args) bool
Overloaded function.
grad_enabled(arg: object, /) -> bool
Return whether gradient tracking is enabled on any of the given variables.
- Parameters:
*args (tuple) – A variable-length list of Dr.Jit arrays/tensors instances or PyTrees. The function recursively traverses them to all differentiable variables.
- Returns:
True
if any of the input variables has gradient tracking enabled,False
otherwise.- Return type:
bool
grad_enabled(*args) -> bool
- drjit.grad(arg: T, preserve_type: bool = True) T ¶
Return the gradient value associated to a given variable.
When the variable doesn’t have gradient tracking enabled, this function returns
0
.- Parameters:
arg (object) – An arbitrary Dr.Jit array, tensor or PyTree.
preserve_type (bool) – Should the operation preserve the input type in the return value? (This is the default). Otherwise, Dr.Jit will, e.g., return a type of drjit.cuda.Float for an input of type drjit.cuda.ad.Float.
- Returns:
the gradient value associated to the input variable.
- Return type:
object
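An end-to-end sketch of the typical workflow (assuming the LLVM backend; drjit.backward_from() is documented further below):

import drjit as dr

x = dr.llvm.ad.Float(1.0, 2.0, 3.0)
dr.enable_grad(x)

y = x * x             # differentiable computation
dr.backward_from(y)   # reverse-mode propagation

print(dr.grad(x))     # [2, 4, 6]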
- drjit.set_grad(target: T, source: T) None ¶
Set the gradient associated with the provided variable.
This operation internally decomposes into two sub-steps:
dr.clear_grad(target)
dr.accum_grad(target, source)
When
source
is not of the same type astarget
, Dr.Jit will try to broadcast its contents into the right shape.
- drjit.accum_grad(target: T, source: T) None ¶
Accumulate the contents of one variable into the gradient of another variable.
When
source
is not of the same type astarget
, Dr.Jit will try to broadcast its contents into the right shape.
- drjit.replace_grad(arg0: T, arg1: T, /) None ¶
Replace the gradient value of
arg0
with the one ofarg1
.This is a relatively specialized operation to be used with care when implementing advanced automatic differentiation-related features.
One example use would be to inform Dr.Jit that there is a better way to compute the gradient of a particular expression than what the normal AD traversal of the primal computation graph would yield.
The function promotes and broadcasts
arg0
andarg1
if they are not of the same type.
- drjit.clear_grad(arg: object, /) None ¶
Clear the gradient of the given variable.
- Parameters:
arg (object) – An arbitrary Dr.Jit array, tensor, or PyTree.
- drjit.traverse(mode: drjit.ADMode, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None ¶
Propagate gradients along the enqueued set of AD graph edges.
Given prior use of drjit.enqueue() to enqueue AD nodes for gradient propagation, this function performs the actual gradient propagation in either the forward or reverse direction (as specified by the
mode
parameter). By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. The term leaf node is defined as follows:
In forward AD, leaf nodes have no forward edges. They are outputs of a computation, and no other differentiable variable depends on them.
In backward AD, leaf nodes have no backward edges. They are inputs to a computation.
By default, the traversal also removes the edges of visited nodes to isolate them. These defaults are usually good ones: cleaning up the graph frees up resources, which is useful when working with large wavefronts or very complex computation graphs. It also avoids potentially undesired derivative contributions that can arise when the AD graphs of two unrelated computations are connected by an edge and subsequently separately differentiated.
In advanced applications that require multiple AD traversals of the same graph, specify different combinations of the enumeration
drjit.ADFlag
via theflags
parameter.- Parameters:
mode (drjit.ADMode) – Specifies the direction in which gradients should be propagated.
drjit.ADMode.Forward
and drjit.ADMode.Backward refer to forward and backward traversal.flags (drjit.ADFlag | int) – Controls what parts of the AD graph are cleared during traversal. The default value is
drjit.ADFlag.Default
.
- drjit.enqueue(mode: drjit.ADMode, arg: object) None ¶
- drjit.enqueue(mode: drjit.ADMode, *args) None
Overloaded function.
enqueue(mode: drjit.ADMode, arg: object) -> None
Enqueues the input variable(s) for subsequent gradient propagation
Dr.Jit splits the process of automatic differentiation into three parts:
Initializing the gradient of one or more input or output variables. The most common initialization entails setting the gradient of an output (e.g., an optimization loss) to
1.0
Enqueuing nodes that should partake in the gradient propagation pass. Dr.Jit will follow variable dependencies (edges in the AD graph) to find variables that are reachable from the enqueued variable.
Finally propagating gradients to all of the enqueued variables.
This function is responsible for step 2 of the above list and works differently depending on the specified
mode
:- -
drjit.ADMode.Forward
: Dr.Jit will recursively enqueue all variables that are reachable along forward edges. That is, given a differentiable operation
a = b+c
, enqueuingc
will also enqueuea
for later traversal.- -
drjit.ADMode.Backward
: Dr.Jit will recursively enqueue all variables that are reachable along backward edges. That is, given a differentiable operation
a = b+c
, enqueuinga
will also enqueueb
andc
for later traversal.
For example, a typical chain of operations to forward propagate the gradients from
a
tob
might look as follow:a = dr.llvm.ad.Float(1.0) dr.enable_grad(a) b = f(a) # some computation involving 'a' # The below three operations can also be written more compactly as dr.forward(a) dr.set_gradient(a, 1.0) dr.enqueue(dr.ADMode.Forward, a) dr.traverse(dr.ADMode.Forward) grad = dr.grad(b)
One interesting aspect of this design is that enqueuing and traversal don’t necessarily need to follow the same direction.
For example, we may only be interested in forward gradients reaching a specific output node
b
, which can be expressed as follows:

a = dr.llvm.ad.Float(1.0)
dr.enable_grad(a)
b, c, d, e = f(a)

dr.set_grad(a, 1.0)
dr.enqueue(dr.ADMode.Backward, b)
dr.traverse(dr.ADMode.Forward)

grad = dr.grad(b)
The same naturally also works in the reverse direction. Dr.Jit provides a higher level API that encapsulate such logic in a few different flavors:
drjit.forward_from()
,drjit.forward_to()
, anddrjit.forward()
.drjit.backward_from()
,drjit.backward_to()
, anddrjit.backward()
.
- Parameters:
mode (drjit.ADMode) –
- Specifies the set of edges which Dr.Jit should follow to enqueue variables to be visited by a later gradient propagation phase.
drjit.ADMode.Forward
and drjit.ADMode.Backward refer to forward and backward edges, respectively.
value (object) – An arbitrary Dr.Jit array, tensor or PyTree.
enqueue(mode: drjit.ADMode, *args) -> None
- drjit.forward_from(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None ¶
Forward-propagate gradients from the provided Dr.Jit array or tensor.
This function checks if a derivative has already been assigned to the provided Dr.Jit array or tensor
arg
. If not, it assigns the value1
.Following this, it forward-propagates derivatives through forward-connected components of the computation graph (i.e., reaching all variables that directly or indirectly depend on
arg
).The operation is equivalent to
if arg.grad has not been set yet:
    dr.set_grad(arg, 1)

dr.enqueue(dr.ADMode.Forward, arg)
dr.traverse(dr.ADMode.Forward, flags=flags)
Refer to the documentation functions
drjit.set_grad()
,drjit.enqueue()
, anddrjit.traverse()
for further details on the nuances of forward derivative propagation.By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of
drjit.enqueue()
and the meaning of theflags
parameter.When
drjit.JitFlag.SymbolicCalls
is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled viadrjit.enable_grad()
, as this generally indicates the presence of a bug. Specify thedrjit.ADFlag.AllowNoGrad
flag (e.g. by passingflags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad
) to the function.- Parameters:
args (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is
drjit.ADFlag.Default
.
- drjit.forward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) ArrayT ¶
- drjit.forward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) tuple[*Ts]
Overloaded function.
forward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) -> ArrayT
Forward-propagate gradients to the provided set of Dr.Jit arrays/tensors.
The operation is equivalent to
dr.enqueue(dr.ADMode.Backward, *args)
dr.traverse(dr.ADMode.Forward, flags=flags)
return dr.grad(*args)
Internally, the operation first traverses the computation graph backwards from
args
to find potential paths along which gradients can flow to the given set of arrays. Then, it performs a gradient propagation pass along the detected variables. For this to work, you must have previously enabled and specified input gradients for the inputs of the computation (see
drjit.enable_grad()
and drjit.set_grad()
). Refer to the documentation functions
drjit.enqueue()
anddrjit.traverse()
for further details on the nuances of forward derivative propagation.By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of
drjit.enqueue()
and the meaning of theflags
parameter.When
drjit.JitFlag.SymbolicCalls
is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad()
, as this generally indicates the presence of a bug. Specify the drjit.ADFlag.AllowNoGrad
flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad
) to the function to disable this check.
- Parameters:
*args (tuple) – A variable-length list of differentiable Dr.Jit arrays, tensors, or PyTrees.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is
drjit.ADFlag.Default
.
- Returns:
the gradient value(s) associated with
*args
following the traversal.- Return type:
object
forward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) -> tuple[*Ts]
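For reference, here is a minimal end-to-end sketch that combines drjit.enable_grad(), drjit.set_grad(), and drjit.forward_to(). The function f() and the choice of the LLVM backend are illustrative assumptions rather than part of the API.
import drjit as dr
from drjit.llvm.ad import Float

def f(x):
    return x * x + 1   # placeholder differentiable computation

a = Float(1.0)
dr.enable_grad(a)      # track derivatives with respect to 'a'
dr.set_grad(a, 1.0)    # seed the input derivative
b = f(a)

db_da = dr.forward_to(b)   # propagate the seed to 'b' and return its gradient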
- drjit.forward(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None ¶
Forward-propagate gradients from the provided Dr.Jit array or tensor.
This operation is equivalent to
dr.set_grad(arg, 1) dr.forward_from(arg)
In other words, it assigns an initial gradient of
1
to arg
and then forward-propagates it through the rest of the computation. Please refer to the function
drjit.forward_from()
for further detail.- Parameters:
args (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is
drjit.ADFlag.Default
.
- drjit.backward_from(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None ¶
Backpropagate gradients from the provided Dr.Jit array or tensor.
This function checks if a derivative has already been assigned to the provided Dr.Jit array or tensor
arg
. If not, it assigns the value 1
. Following this, it backpropagates derivatives through backward-connected components of the computation graph (i.e., reaching differentiable variables that potentially influence the value of
arg
).The operation is conceptually equivalent to
if arg.grad has not been set yet:
    dr.set_grad(arg, 1)
dr.enqueue(dr.ADMode.Backward, arg)
dr.traverse(dr.ADMode.Backward, flags=flags)
Refer to the documentation functions
drjit.set_grad()
,drjit.enqueue()
, anddrjit.traverse()
for further details on the nuances of derivative backpropagation.By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of
drjit.enqueue()
and the meaning of theflags
parameter.When
drjit.JitFlag.SymbolicCalls
is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad()
, as this generally indicates the presence of a bug. Specify the drjit.ADFlag.AllowNoGrad
flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad
) to the function to disable this check.
- Parameters:
args (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is
drjit.ADFlag.Default
.
- drjit.backward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) ArrayT ¶
- drjit.backward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) tuple[*Ts]
Overloaded function.
backward_to(arg: ArrayT, flags: drjit.ADFlag | int = drjit.ADFlag.Default) -> ArrayT
Backpropagate gradients to the provided set of Dr.Jit arrays/tensors.
The operation is equivalent to
dr.enqueue(dr.ADMode.Forward, *args)
dr.traverse(dr.ADMode.Backward, flags=flags)
return dr.grad(*args)
Internally, the operation first traverses the computation graph forwards from
args
to find potential paths along which reverse-mode gradients can flow to the given set of input variables. Then, it performs a backpropagation pass along the detected variables. For this to work, you must have previously enabled and specified input gradients for the outputs of the computation (see
drjit.enable_grad()
and drjit.set_grad()
). Refer to the documentation functions
drjit.enqueue()
anddrjit.traverse()
for further details on the nuances of derivative backpropagation.By default, the operation is destructive: it clears the gradients of visited interior nodes and only retains gradients at leaf nodes. For details on this, refer to the documentation of
drjit.enqueue()
and the meaning of theflags
parameter.When
drjit.JitFlag.SymbolicCalls
is set, the implementation raises an exception when the provided array does not support gradient tracking, or when gradient tracking was not previously enabled via drjit.enable_grad()
, as this generally indicates the presence of a bug. Specify the drjit.ADFlag.AllowNoGrad
flag (e.g. by passing flags=dr.ADFlag.Default | dr.ADFlag.AllowNoGrad
) to the function to disable this check.
- Parameters:
*args (tuple) – A variable-length list of differentiable Dr.Jit arrays, tensors, or PyTrees.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is
drjit.ADFlag.Default
.
- Returns:
the gradient value(s) associated with
*args
following the traversal.- Return type:
object
backward_to(*args: *Ts, flags: drjit.ADFlag | int = drjit.ADFlag.Default) -> tuple[*Ts]
- drjit.backward(arg: drjit.AnyArray, flags: drjit.ADFlag | int = drjit.ADFlag.Default) None ¶
Backpropagate gradients from the provided Dr.Jit array or tensor.
This operation is equivalent to
dr.set_grad(arg, 1) dr.backward_from(arg)
In other words, it assigns an initial gradient of
1
to arg
and then backpropagates it through the rest of the computation. Please refer to the function
drjit.backward_from()
for further detail.- Parameters:
args (object) – A Dr.Jit array, tensor, or PyTree.
flags (drjit.ADFlag | int) – Controls what parts of the AD graph to clear during traversal, and whether or not to fail when the input is not differentiable. The default value is
drjit.ADFlag.Default
.
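The following minimal sketch shows the typical reverse-mode pattern built on these functions; the objective below is a placeholder.
import drjit as dr
from drjit.llvm.ad import Float

x = Float(1.0, 2.0, 3.0)
dr.enable_grad(x)

loss = dr.sum(x * x)   # placeholder objective

# Equivalent to dr.set_grad(loss, 1) followed by dr.backward_from(loss)
dr.backward(loss)

dx = dr.grad(x)        # derivative of 'loss' with respect to 'x'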
- drjit.suspend_grad(*args, when=True)¶
Python context manager to temporarily disable gradient tracking globally, or for a specific set of variables.
This context manager can be used as follows to completely disable all gradient tracking. Newly created variables will be detached from Dr.Jit’s AD graph.
with dr.suspend_grad():
    # .. code goes here ..
You may also specify any number of Dr.Jit arrays, tensors, or PyTrees. In this case, the context manager behaves differently by disabling gradient tracking more selectively for the specified variables.
with dr.suspend_grad(x):
    z = x + y  # 'z' will not track any gradients arising from 'x'
The
suspend_grad()
and resume_grad()
context managers can be arbitrarily nested and suitably update the set of tracked variables. A note about the interaction with
drjit.enable_grad()
: it is legal to register further AD variables within a scope that disables gradient tracking for specific variables.
with dr.suspend_grad(x):
    y = Float(1)
    dr.enable_grad(y)

    # The following condition holds
    assert not dr.grad_enabled(x) and dr.grad_enabled(y)
In contrast, a
suspend_grad()
environment without arguments that completely disables AD does not allow further variables to be registered:
with dr.suspend_grad():
    y = Float(1)
    dr.enable_grad(y)  # ignored

    # The following condition holds
    assert not dr.grad_enabled(x) and not dr.grad_enabled(y)
- Parameters:
*args (tuple) – Arbitrary list of Dr.Jit arrays, tuples, or PyTrees. Elements of data structures that could not possibly be attached to the AD graph (e.g., Python scalars) are ignored.
when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via
when=False
. The default value iswhen=True
.
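A common pattern is to detach only part of a computation while keeping the rest differentiable. The sketch below assumes a simple differentiable input x; the gradient value in the comment follows from the fact that the suspended branch behaves like a constant.
import drjit as dr
from drjit.llvm.ad import Float

x = Float(1.0)
dr.enable_grad(x)

y = x * 2                # tracked
with dr.suspend_grad():
    z = x * 3            # not tracked: behaves like a constant

dr.backward(y + z)
print(dr.grad(x))        # only the contribution of 'y' remains, i.e. 2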
- drjit.resume_grad(*args, when=True)¶
Python context manager to temporarily resume gradient tracking globally, or for a specific set of variables.
This context manager can be used as follows to fully re-enable all gradient tracking following a previous call to
drjit.suspend_grad()
. Newly created variables will then again be attached to Dr.Jit’s AD graph.
with dr.suspend_grad():
    # ..
    with dr.resume_grad():
        # In this scope, the effect of the outer context
        # manager is effectively disabled
You may also specify any number of Dr.Jit arrays, tensors, or PyTrees. In this case, the context manager behaves differently by enabling gradient tracking more selectively for the specified variables.
with dr.suspend_grad():
    with dr.resume_grad(x):
        z = x + y  # 'z' will only track gradients arising from 'x'
The
suspend_grad()
and resume_grad()
context managers can be arbitrarily nested and suitably update the set of tracked variables.
- Parameters:
*args (tuple) – Arbitrary list of Dr.Jit arrays, tuples, or PyTrees. Elements of data structures that could not possibly be attached to the AD graph (e.g., Python scalars) are ignored.
when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via
when=False
. The default value iswhen=True
.
- drjit.isolate_grad(when=True)¶
Python context manager to isolate and partition AD traversals into multiple distinct phases.
Consider a sequence of steps being differentiated in reverse mode, like so:
x = ..
dr.enable_grad(x)
y = f(x)
z = g(y)
dr.backward(z)
The
drjit.backward()
call would automatically traverse the AD graph nodes created during the execution of the functionf()
andg()
.However, sometimes this is undesirable and more control is needed. For example, Dr.Jit may be in an execution context (a symbolic loop or call) that temporarily disallows differentiation of the
f()
part. Thedrjit.isolate_grad()
context manager addresses this need:
dr.enable_grad(x)
y = f(x)

with dr.isolate_grad():
    z = g(y)
    dr.backward(z)
Any reverse-mode AD traversal of an edge that crosses the isolation boundary is postponed until leaving the scope. This is mathematically equivalent but produces two smaller separate AD graph traversals.
Dr.Jit operations like symbolic loops and calls internally create such an isolation boundary, hence it is rare that you would need to do so yourself.
isolate_grad()
is not useful for forward mode AD.- Parameters:
when (bool) – Optional keyword argument that can be specified to turn the context manager into a no-op via
when=False
. The default value iswhen=True
.
- class drjit.CustomOp(*args, **kwargs)¶
Base class for implementing custom differentiable operations.
Dr.Jit can compute derivatives of builtin operations in both forward and reverse mode. In some cases, it may be useful or even necessary to control how a particular operation should be differentiated.
To do so, you may extend this class to provide three callback functions:
CustomOp.eval()
: Implements the primal evaluation of the function with detached inputs.CustomOp.forward()
: Implements the forward derivative that propagates derivatives from input arguments to the return valueCustomOp.backward()
: Implements the backward derivative that propagates derivatives from the return value to the input arguments.
An example for a hypothetical custom addition operation is shown below
class Addition(dr.CustomOp):
    def eval(self, x, y):
        # Primal calculation without derivative tracking
        return x + y

    def forward(self):
        # Compute forward derivatives
        self.set_grad_out(self.grad_in('x') + self.grad_in('y'))

    def backward(self):
        # .. compute backward derivatives ..
        self.set_grad_in('x', self.grad_out())
        self.set_grad_in('y', self.grad_out())

    def name(self):
        # Optional: a descriptive name shown in GraphViz visualizations
        return "Addition"
You should never need to call these functions yourself—Dr.Jit will do so when appropriate. To weave such a custom operation into the AD graph, use the
drjit.custom()
function, which expects a subclass ofdrjit.CustomOp
as first argument, followed by arguments to the actual operation that are directly forwarded to the.eval()
callback.
# Add two numbers 'x' and 'y'. Calls our '.eval()' callback with detached arguments
result = dr.custom(Addition, x, y)
Forward or backward derivatives are then automatically handled through the standard operations. For example,
dr.backward(result)
will invoke the
.backward()
callback from above.Many derivatives are more complex than the above examples and require access to inputs or intermediate steps of the primal evaluation routine. You can simply stash them in the instance (
self.field = ...
), which is shown below for a differentiable multiplication operation that implements the product rule:
class Multiplication(dr.CustomOp):
    def eval(self, x, y):
        # Stash input arguments
        self.x = x
        self.y = y
        return x * y

    def forward(self):
        self.set_grad_out(self.y * self.grad_in('x') + self.x * self.grad_in('y'))

    def backward(self):
        self.set_grad_in('x', self.y * self.grad_out())
        self.set_grad_in('y', self.x * self.grad_out())

    def name(self):
        return "Multiplication"
- eval(self, *args, **kwargs) object ¶
Evaluate the custom operation in primal mode.
You must implement this method when subclassing
CustomOp
, since the default implementation raises an exception. It should realize the original (non-derivative-aware) form of a computation and may take an arbitrary sequence of positional, keyword, and variable-length positional/keyword arguments.You should not need to call this function yourself—Dr.Jit will automatically do so when performing custom operations through the
drjit.custom()
interface.Note that the input arguments passed to
.eval()
will be detached (i.e. they don’t have derivative tracking enabled). This is intentional, since derivative tracking is handled by the custom operation along with the other callbacksforward()
andbackward()
.
- forward(self) None ¶
Evaluate the forward derivative of the custom operation.
You must implement this method when subclassing
CustomOp
, since the default implementation raises an exception. It takes no arguments and has no return value.An implementation will generally perform repeated calls to
grad_in()
to query the gradients of all function inputs, followed by a single call to set_grad_out()
to set the gradient of the return value.For example, this is how one would implement the product rule of the primal calculation
x*y
, assuming that the.eval()
routine stashed the inputs in the custom operation object.
def forward(self):
    self.set_grad_out(self.y * self.grad_in('x') + self.x * self.grad_in('y'))
- backward(self) None ¶
Evaluate the backward derivative of the custom operation.
You must implement this method when subclassing
CustomOp
, since the default implementation raises an exception. It takes no arguments and has no return value.An implementation will generally perform a single call to
grad_out()
to query the gradient of the function return value, followed by a sequence of calls to set_grad_in()
to assign the gradients of the function inputs.For example, this is how one would implement the product rule of the primal calculation
x*y
, assuming that the.eval()
routine stashed the inputs in the custom operation object.
def backward(self):
    self.set_grad_in('x', self.y * self.grad_out())
    self.set_grad_in('y', self.x * self.grad_out())
- name(self) str ¶
Return a descriptive name of the
CustomOp
instance.Amongst other things, this name is used to document the presence of the custom operation in GraphViz debug output. (See
graphviz_ad()
.)
- grad_out(self) object ¶
Query the gradient of the return value.
Returns an object, whose type matches the original return value produced in
eval()
. This function should only be used within thebackward()
callback.
- set_grad_out(self, arg: object, /) None ¶
Accumulate a gradient into the return value.
This function should only be used within the
forward()
callback.
- grad_in(self, arg: object, /) object ¶
Query the gradient of a specified input parameter.
The second argument specifies the parameter name as string. Gradients of variable-length positional arguments (
*args
) can be queried by providing an integer index instead.This function should only be used within the
forward()
callback.
- set_grad_in(self, arg0: object, arg1: object, /) None ¶
Accumulate a gradient into the specified input parameter.
The second argument specifies the parameter name as string. Gradients of variable-length positional arguments (
*args
) can be assigned by providing an integer index instead.This function should only be used within the
backward()
callback.
- add_input(self, arg: object, /) None ¶
Register an implicit input dependency of the operation on an AD variable.
This function should be called by the
eval()
implementation when an operation has a differentiable dependence on an input that is not an ordinary input argument of the function (e.g., a global program variable or a field of a class).
- add_output(self, arg: object, /) None ¶
Register an implicit output dependency of the operation on an AD variable.
This function should be called by the
eval()
implementation when an operation has a differentiable dependence on an output that is not part of the function return value (e.g., a global program variable or a field of a class).
- drjit.custom(arg0: type[drjit.CustomOp], /, *args, **kwargs) object ¶
Evaluate a custom differentiable operation.
It can be useful or even necessary to control how a particular operation should be differentiated by Dr.Jit’s automatic differentiation (AD) layer. The
drjit.custom()
function enables such use cases by stitching an opaque operation with user-defined primal and forward/backward derivative implementations into the AD graph. The function expects a subclass of the
CustomOp
interface as first argument. The remaining positional and keyword arguments are forwarded to theCustomOp.eval()
callback.See the documentation of
CustomOp
for examples on how to realize such a custom operation.
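As a usage sketch, the Multiplication class from the CustomOp documentation above could be invoked as follows; the gradient values in the comment follow from the product rule.
import drjit as dr
from drjit.llvm.ad import Float

x, y = Float(2.0), Float(3.0)
dr.enable_grad(x, y)

# Invokes Multiplication.eval(x, y) with detached arguments and splices
# the operation into the AD graph
z = dr.custom(Multiplication, x, y)

dr.backward(z)
# dr.grad(x) == 3, dr.grad(y) == 2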
- drjit.wrap(source: str | ModuleType, target: str | ModuleType) Callable[[T], T] ¶
Differentiable bridge between Dr.Jit and other array programming frameworks.
This function wraps computation performed using one array programming framework to expose it in another. Currently, PyTorch, TensorFlow, and JAX are supported, though other frameworks may be added in the future.
Annotating a function with
@drjit.wrap
adds code that suitably converts arguments and return values. Furthermore, it stitches the operation into the automatic differentiation (AD) graph of the other framework to ensure correct gradient propagation.When exposing code written using another framework, the wrapped function can take and return any PyTree including flat or nested Dr.Jit arrays, tensors, and arbitrary nested lists/tuples, dictionaries, and custom data structures. The arguments don’t need to be differentiable—for example, integer/boolean arrays that don’t carry derivative information can be passed as well.
The wrapped function should be pure: in other words, it should read its input(s) and compute an associated output so that re-evaluating the function again produces the same answer. Multi-framework derivative tracking of impure computation will likely not behave as expected.
The following table lists the currently supported conversions:
Direction
Primal
Forward-mode AD
Reverse-mode AD
Remarks
drjit
→torch
✅
✅
✅
Everything just works.
torch
→drjit
✅
✅
✅
Limitation: The passed/returned PyTrees can contain arbitrary arrays or tensors, but other types (e.g., a custom Python object not understood by PyTorch) will raise errors when differentiating in forward mode (backward mode works fine).
An issue was filed on the PyTorch bugtracker.
drjit
→tf
✅
✅
✅
You may want to further annotate the wrapped function with
tf.function
to trace and just-in-time compile it in the TensorFlow environment, i.e.,
@dr.wrap(source='drjit', target='tf')
@tf.function(jit_compile=False)  # Set to True for XLA mode
Limitation: There is an issue for tf.int32 tensors which are wrongly placed on CPU by DLPack. This can lead to inconsistent device placement of tensors.
An issue was filed on the TensorFlow bugtracker.
tf
→drjit
✅
❌
✅
TensorFlow has some limitations with respect to custom gradients in forward-mode AD.
Limitation: TensorFlow does not allow for non-tensor input structures in functions with custom gradients.
TensorFlow has a bug for functions with custom gradients and keyword arguments.
An issue was filed on the TensorFlow bugtracker.
drjit
→jax
✅
✅
✅
You may want to further annotate the wrapped function with
jax.jit
to trace and just-in-time compile it in the JAX environment, i.e.,
@dr.wrap(source='drjit', target='jax')
@jax.jit
Limitation: The passed/returned PyTrees can contain arbitrary arrays or Python scalar types, but other types (e.g., a custom Python object not understood by JAX) will raise errors.
jax
→drjit
❌
❌
❌
This direction is currently unsupported. We plan to add it in the future.
Please also refer to the documentation sections on multi-framework differentiation and its associated caveats.
Note
Types that have no equivalent on the other side (e.g. a quaternion array) will convert to generic tensors.
Data exchange is limited to representations that exist on both sides. There are a few limitations:
PyTorch lacks support for most unsigned integer types (
uint16
,uint32
, oruint64
-typed arrays). Use signed integer types to work around this issue.
TensorFlow has limitations with respect to forward-mode AD for functions with custom gradients. There is also an issue for functions with keyword arguments.
Dr.Jit currently lacks support for most 8- and 16-bit numeric types (besides half precision floats).
JAX refuses to exchange boolean-valued tensors with other frameworks.
- Parameters:
source (str | module) – The framework used outside of the wrapped function. The argument is currently limited to either
'drjit'
,'torch'
,'tf'
, or 'jax'
. For convenience, the associated Python module can be specified as well.target (str | module) – The framework used inside of the wrapped function. The argument is currently limited to either
'drjit'
,'torch'
,'tf'
, or'jax'
. For convenience, the associated Python module can be specified as well.
- Returns:
The decorated function.
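A hedged usage sketch for the torch → drjit direction is shown below. It assumes that PyTorch and a CUDA-capable setup are available; the wrapped function body is purely illustrative.
import drjit as dr
import torch

@dr.wrap(source='torch', target='drjit')
def dr_func(x):
    # Arbitrary Dr.Jit computation; derivatives flow across the framework boundary
    return dr.sqrt(x) * 2

a = torch.linspace(1, 4, 4, device='cuda', requires_grad=True)
b = dr_func(a)        # called with (and returning) PyTorch tensors
b.sum().backward()    # gradients propagate back into 'a'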
Constants¶
- drjit.e¶
The exponential constant \(e\) represented as a Python
float
.
- drjit.log_two¶
The value \(\log(2)\) represented as a Python
float
.
- drjit.inv_log_two¶
The value \(\frac{1}{\log(2)}\) represented as a Python
float
.
- drjit.pi¶
The value \(\pi\) represented as a Python
float
.
- drjit.inv_pi¶
The value \(\frac{1}{\pi}\) represented as a Python
float
.
- drjit.sqrt_pi¶
The value \(\sqrt{\pi}\) represented as a Python
float
.
- drjit.inv_sqrt_pi¶
The value \(\frac{1}{\sqrt{\pi}}\) represented as a Python
float
.
- drjit.two_pi¶
The value \(2\pi\) represented as a Python
float
.
- drjit.inv_two_pi¶
The value \(\frac{1}{2\pi}\) represented as a Python
float
.
- drjit.sqrt_two_pi¶
The value \(\sqrt{2\pi}\) represented as a Python
float
.
- drjit.inv_sqrt_two_pi¶
The value \(\frac{1}{\sqrt{2\pi}}\) represented as a Python
float
.
- drjit.four_pi¶
The value \(4\pi\) represented as a Python
float
.
- drjit.inv_four_pi¶
The value \(\frac{1}{4\pi}\) represented as a Python
float
.
- drjit.sqrt_four_pi¶
The value \(\sqrt{4\pi}\) represented as a Python
float
.
- drjit.sqrt_two¶
The value \(\sqrt{2}\) represented as a Python
float
.
- drjit.inv_sqrt_two¶
The value \(\frac{1}{\sqrt{2}}\) represented as a Python
float
.
- drjit.inf¶
The value
float('inf')
represented as a Pythonfloat
.
- drjit.nan¶
The value
float('nan')
represented as a Pythonfloat
.
- drjit.epsilon(arg, /)¶
Returns the machine epsilon.
The machine epsilon gives an upper bound on the relative approximation error due to rounding in floating point arithmetic.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The machine epsilon.
- Return type:
float
- drjit.one_minus_epsilon(arg, /)¶
Returns one minus the machine epsilon value.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
One minus the machine epsilon.
- Return type:
float
- drjit.recip_overflow(arg, /)¶
Returns the reciprocal overflow threshold value.
Any numbers equal to this threshold or a smaller value will overflow to infinity when reciprocated.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The reciprocal overflow threshold value.
- Return type:
float
- drjit.smallest(arg, /)¶
Returns the smallest representable normalized floating point value.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The smallest representable normalized floating point value.
- Return type:
float
- drjit.largest(arg, /)¶
Returns the largest representable finite floating point value for the given floating point type.
- Parameters:
arg (object) – Dr.Jit array or array type used to choose between an appropriate constant for half, single, or double precision.
- Returns:
The largest representable finite floating point value.
- Return type:
float
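The snippet below illustrates how these helpers select a constant based on the precision of the supplied type; the LLVM backend is an arbitrary choice.
import drjit as dr
from drjit.llvm import Float, Float64

print(dr.epsilon(Float))     # machine epsilon for single precision
print(dr.epsilon(Float64))   # machine epsilon for double precision
print(dr.smallest(Float))    # smallest normalized single precision value
print(dr.largest(Float))     # largest finite single precision value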
Array base class¶
- class drjit.ArrayBase(*args, **kwargs)¶
This is the base class of all Dr.Jit arrays and tensors. It provides an abstract version of the array API that becomes usable when the type is extended by a concrete specialization.
ArrayBase
itself cannot be instantiated.See the section on Dr.Jit type signatures <type_signatures> to learn about the type parameters of
ArrayBase
.- property array¶
This member plays multiple roles:
When
self
is a tensor, this property returns the storage representation of the tensor in the form of a linearized dynamic 1D array.When
self
is a special arithmetic object (matrix, quaternion, or complex number), array
provides a copy of the same data with ordinary array semantics. In all other cases,
array
is simply a reference toself
.
- Type:
- property ndim¶
This property represents the dimension of the provided Dr.Jit array or tensor.
- Type:
int
- property shape¶
This property provides a tuple describing dimension and shape of the provided Dr.Jit array or tensor.
When the input array is ragged the function raises a
RuntimeError
. The term ragged refers to an array whose components have mismatched sizes, such as [[1, 2], [3, 4, 5]]
. Note that scalar entries (e.g.[[1, 2], [3]]
) are acceptable, since broadcasting can effectively convert them to any size.The expressions
drjit.shape(arg)
andarg.shape
are equivalent.- Type:
tuple[int, …]
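For instance, the following sketch (LLVM backend chosen arbitrarily) shows how ndim and shape report the combination of static and dynamic dimensions:
import drjit as dr
from drjit.llvm import Array3f, TensorXf

a = Array3f([1, 2], [3, 4], [5, 6])   # 3 static components, dynamic width 2
print(a.ndim, a.shape)                # 2 (3, 2)

t = dr.zeros(TensorXf, shape=(4, 5))
print(t.ndim, t.shape)                # 2 (4, 5)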
- property state¶
This read-only property returns an enumeration value describing the evaluation state of this Dr.Jit array.
- Type:
- property x¶
If
self
is a static Dr.Jit array of size 1 (or larger), the propertyself.x
can be used synonymously withself[0]
. Otherwise, accessing this field will generate aRuntimeError
.- Type:
- property y¶
If
self
is a static Dr.Jit array of size 2 (or larger), the propertyself.y
can be used synonymously withself[1]
. Otherwise, accessing this field will generate aRuntimeError
.- Type:
- property z¶
If
self
is a static Dr.Jit array of size 3 (or larger), the propertyself.z
can be used synonymously withself[2]
. Otherwise, accessing this field will generate aRuntimeError
.- Type:
- property w¶
If
self
is a static Dr.Jit array of size 4 (or larger), the propertyself.w
can be used synonymously withself[3]
. Otherwise, accessing this field will generate aRuntimeError
.- Type:
- property T¶
This property returns the transpose of
self
. When the underlying array is not a matrix type, it raises aTypeError
.
- property index¶
If
self
is a leaf Dr.Jit array managed by a just-in-time compiled backend (i.e, CUDA or LLVM), this property contains the associated variable index in the graph data structure storing the computation trace. This graph can be visualized usingdrjit.graphviz()
. Otherwise, the value of this property equals zero. A non-leaf array (e.g.drjit.cuda.Array2i
) consists of several JIT variables, whose indices must be queried separately.Note that Dr.Jit maintains two computation traces at the same time: one capturing the raw computation, and a higher-level graph for automatic differentiation (AD). The index
index_ad
keeps track of the variable index within the AD computation graph, if applicable.- Type:
int
- property index_ad¶
If
self
is a leaf Dr.Jit array represented by an AD backend, this property contains the variable index in the graph data structure storing the computation trace for later differentiation (this graph can be visualized usingdrjit.graphviz_ad()
). A non-leaf array (e.g.drjit.cuda.ad.Array2f
) consists of several AD variables, whose indices must be queried separately.Note that Dr.Jit maintains two computation traces at the same time: one capturing the raw computation, and a higher-level graph for automatic differentiation (AD). The index
index
keeps track of the variable index within the raw computation graph, if applicable.- Type:
int
- property grad¶
This property can be used to retrieve or set the gradient associated with the Dr.Jit array or tensor.
The expressions
drjit.grad(arg)
andarg.grad
are equivalent whenarg
is a Dr.Jit array/tensor.- Type:
- __len__()¶
Return len(self).
- __iter__()¶
Implement iter(self).
- __repr__()¶
Return repr(self).
- __bool__()¶
True if self else False
Casts the array to a Python
bool
type. This is only permissible when self
represents a boolean array of both depth and size 1.
- __add__(value, /)¶
Return self+value.
- __radd__(value, /)¶
Return value+self.
- __iadd__(value, /)¶
Return self+=value.
- __sub__(value, /)¶
Return self-value.
- __rsub__(value, /)¶
Return value-self.
- __isub__(value, /)¶
Return self-=value.
- __mul__(value, /)¶
Return self*value.
- __rmul__(value, /)¶
Return value*self.
- __imul__(value, /)¶
Return self*=value.
- __matmul__(value, /)¶
Return self@value.
- __rmatmul__(value, /)¶
Return value@self.
- __imatmul__(value, /)¶
Return self@=value.
- __truediv__(value, /)¶
Return self/value.
- __rtruediv__(value, /)¶
Return value/self.
- __itruediv__(value, /)¶
Return self/=value.
- __floordiv__(value, /)¶
Return self//value.
- __rfloordiv__(value, /)¶
Return value//self.
- __ifloordiv__(value, /)¶
Return self//=value.
- __mod__(value, /)¶
Return self%value.
- __rmod__(value, /)¶
Return value%self.
- __imod__(value, /)¶
Return self%=value.
- __rshift__(value, /)¶
Return self>>value.
- __rrshift__(value, /)¶
Return value>>self.
- __irshift__(value, /)¶
Return self>>=value.
- __lshift__(value, /)¶
Return self<<value.
- __rlshift__(value, /)¶
Return value<<self.
- __ilshift__(value, /)¶
Return self<<=value.
- __and__(value, /)¶
Return self&value.
- __rand__(value, /)¶
Return value&self.
- __iand__(value, /)¶
Return self&=value.
- __or__(value, /)¶
Return self|value.
- __ror__(value, /)¶
Return value|self.
- __ior__(value, /)¶
Return self|=value.
- __xor__(value, /)¶
Return self^value.
- __rxor__(value, /)¶
Return value^self.
- __ixor__(value, /)¶
Return self^=value.
- __abs__()¶
abs(self)
- __le__(value, /)¶
Return self<=value.
- __lt__(value, /)¶
Return self<value.
- __ge__(value, /)¶
Return self>=value.
- __gt__(value, /)¶
Return self>value.
- __ne__(value, /)¶
Return self!=value.
- __eq__(value, /)¶
Return self==value.
- __dlpack__(self, stream: object | None = None) ndarray[] ¶
Returns a DLPack capsule representing the data in this array.
This operation may potentially perform a copy. For example, nested arrays like
drjit.llvm.Array3f
ordrjit.cuda.Matrix4f
need to be rearranged into a contiguous memory representation before they can be exposed. In other cases, e.g. for
drjit.llvm.Float
,drjit.scalar.Array3f
, ordrjit.scalar.ArrayXf
, the data is already contiguous and a zero-copy approach is used instead.
- __array__(self, dtype: object | None = None, copy: object | None = None) object ¶
Returns a NumPy array representing the data in this array.
This operation may potentially perform a copy. For example, nested arrays like
drjit.llvm.Array3f
ordrjit.cuda.Matrix4f
need to be rearranged into a contiguous memory representation before they can be wrapped. In other cases, e.g. for
drjit.llvm.Float
,drjit.scalar.Array3f
, ordrjit.scalar.ArrayXf
, the data is already contiguous and a zero-copy approach is used instead.
- numpy(self) numpy.ndarray[] ¶
Returns a NumPy array representing the data in this array.
This operation may potentially perform a copy. For example, nested arrays like
drjit.llvm.Array3f
ordrjit.cuda.Matrix4f
need to be rearranged into a contiguous memory representation before they can be wrapped. In other cases, e.g. for
drjit.llvm.Float
,drjit.scalar.Array3f
, ordrjit.scalar.ArrayXf
, the data is already contiguous and a zero-copy approach is used instead.
- torch(self) object ¶
Returns a PyTorch tensor representing the data in this array.
For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.
Warning
This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function
drjit.wrap()
for further information on how to accomplish this.
- jax(self) object ¶
Returns a JAX tensor representing the data in this array.
For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.
Warning
This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function
drjit.wrap()
for further information on how to accomplish this.
- tf(self) object ¶
Returns a TensorFlow tensor representing the data in this array.
For flat arrays and tensors, Dr.Jit performs a zero-copy conversion, which means that the created tensor provides a view of the same data that will reflect later modifications to the Dr.Jit array. Nested arrays require a temporary copy to rearrange data into a compatible form.
Warning
This operation converts the numerical representation but does not embed the resulting tensor into the automatic differentiation graph of the other framework. This means that gradients won’t correctly propagate through programs combining multiple frameworks. Take a look at the function
drjit.wrap()
for further information on how to accomplish this.
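A brief sketch of these conversion routines is shown below (assuming NumPy and PyTorch are installed). As noted in the warnings above, such conversions do not connect the AD graphs of the two frameworks—use drjit.wrap() for that.
import drjit as dr
from drjit.llvm import Array3f

x = Array3f(1, 2, 3)

x_np = x.numpy()          # NumPy array (requires a copy for nested arrays)
x_torch = x.torch()       # PyTorch tensor
capsule = x.__dlpack__()  # DLPack capsule for zero-copy exchange where possible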
Computation graph analysis¶
The following operations visualize the contents of Dr.Jit’s computation graphs (of which there are two: one for Jit compilation, and one for automatic differentiation).
- drjit.graphviz(as_string: bool = False) object ¶
Return a GraphViz diagram describing registered JIT variables and their connectivity.
This function returns a representation of the computation graph underlying the Dr.Jit just-in-time compiler, which is separate from the automatic differentiation layer. See the
graphviz_ad()
function to visualize the computation graph of the latter.Run
dr.graphviz().view()
to open up a PDF viewer that shows the resulting output in a separate window.The function depends on the
graphviz
Python package whenas_string=False
(the default).- Parameters:
as_string (bool) – if set to
True
, the function will return raw GraphViz markup as a string. (Default:False
)- Returns:
GraphViz object or raw markup.
- Return type:
object
- drjit.graphviz_ad(as_string: bool = False) object ¶
Return a GraphViz diagram describing variables registered with the automatic differentiation layer, as well as their connectivity.
This function returns a representation of the computation graph underlying the Dr.Jit AD layer, which is one architectural layer above the just-in-time compiler. See the
graphviz()
function to visualize the computation graph of the latter.Run
dr.graphviz_ad().view()
to open up a PDF viewer that shows the resulting output in a separate window.The function depends on the
graphviz
Python package whenas_string=False
(the default).- Parameters:
as_string (bool) – if set to
True
, the function will return raw GraphViz markup as a string. (Default:False
)- Returns:
GraphViz object or raw markup.
- Return type:
object
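For example, the hedged snippet below traces a small computation and writes the raw GraphViz markup of both graphs to disk; the file names are arbitrary.
import drjit as dr
from drjit.llvm.ad import Float

x = Float(1, 2, 3)
dr.enable_grad(x)
y = dr.sin(x) * x

with open('jit_graph.dot', 'w') as f:
    f.write(dr.graphviz(as_string=True))      # JIT (raw computation) graph
with open('ad_graph.dot', 'w') as f:
    f.write(dr.graphviz_ad(as_string=True))   # AD graph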
- drjit.whos(as_string: bool) str | None ¶
Return/print a list of live JIT variables.
This function provides information about the set of variables that are currently registered with the Dr.Jit just-in-time compiler, which is separate from the automatic differentiation layer. See the
whos_ad()
function for the latter.- Parameters:
as_string (bool) – if set to
True
, the function will return the list in string form. Otherwise, it will print directly onto the console and returnNone
. (Default:False
)- Returns:
a human-readable list (if requested).
- Return type:
None | str
- drjit.whos_ad(as_string: bool) str | None ¶
Return/print a list of live variables registered with the automatic differentiation layer.
This function provides information about the set of variables that are currently registered with the Dr.Jit automatic differentiation layer, which is one architectural layer above the just-in-time compiler. See the
whos()
function to obtain information about the latter.- Parameters:
as_string (bool) – if set to
True
, the function will return the list in string form. Otherwise, it will print directly onto the console and returnNone
. (Default:False
)- Returns:
a human-readable list (if requested).
- Return type:
None | str
- drjit.set_label(arg0: object, arg1: str, /) None ¶
- drjit.set_label(**kwargs) None
Overloaded function.
set_label(arg0: object, arg1: str, /) -> None
Assign a label to the provided Dr.Jit array.
This can be helpful to identify computation in GraphViz output (see
drjit.graphviz()
,graphviz_ad()
The operation assumes that the array is tracked by the just-in-time compiler. It has no effect on unsupported inputs (e.g., arrays from the
drjit.scalar
package). It recurses through PyTrees (tuples, lists, dictionaries, custom data structures) and appends names (indices, dictionary keys, field names) separated by underscores to uniquely identify each element.The following
**kwargs
-based shorthand notation can be used to assign multiple labels at once:
set_label(x=x, y=y)
- Parameters:
*arg (tuple) – a Dr.Jit array instance and its corresponding label
str
value.**kwarg (dict) – A set of (keyword, object) pairs.
set_label(**kwargs) -> None
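A short sketch combining set_label() with the GraphViz output discussed above (variable names are arbitrary):
import drjit as dr
from drjit.llvm.ad import Float

x = Float(1, 2, 3)
y = dr.sqrt(x) + 1

# Label the variables so that they are easy to locate in the GraphViz output
dr.set_label(x=x, y=y)
print(dr.graphviz(as_string=True))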
Debugging¶
- enum drjit.LogLevel(value)¶
Valid values are as follows:
- Disable = LogLevel.Disable¶
- Error = LogLevel.Error¶
- Warn = LogLevel.Warn¶
- Info = LogLevel.Info¶
- InfoSym = LogLevel.InfoSym¶
- Debug = LogLevel.Debug¶
- Trace = LogLevel.Trace¶
- drjit.assert_true(cond, fmt: str | None = None, *args, tb_depth: int = 3, tb_skip: int = 0, **kwargs)¶
Generate an assertion failure message when any of the entries in
cond
areFalse
.This function resembles the built-in
assert
keyword in that it raises anAssertionError
when the conditioncond
isFalse
.In contrast to the built-in keyword, it also works when
cond
is an array of boolean values. In this case, the function raises an exception when any entry ofcond
isFalse
.The function accepts an optional format string along with positional and keyword arguments, and it processes them like
drjit.print()
. When only a subset of the entries ofcond
isFalse
, the function reduces the generated output to only include the associated entries.
>>> x = Float(1, -4, -2, 3)
>>> dr.assert_true(x >= 0, 'Found negative values: {}', x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "drjit/__init__.py", line 1327, in assert_true
    raise AssertionError(msg)
AssertionError: Assertion failure: Found negative values: [-4, -2]
This function also works when some of the function inputs are symbolic. In this case, the check is delayed and potential failures will be reported asynchronously. In this case,
drjit.assert_true()
generates output onsys.stderr
instead of raising an exception, as the original execution context no longer exists at that point.Assertion checks carry a performance cost, hence they are disabled by default. To enable them, set the JIT flag
dr.JitFlag.Debug
.- Parameters:
cond (bool | drjit.ArrayBase) – The condition used to trigger the assertion. This should be a scalar Python boolean or a 1D boolean array.
fmt (str) – An optional format string that will be appended to the error message. It can reference positional or keyword arguments specified via
*args
and**kwargs
.*args (tuple) – Optional variable-length positional arguments referenced by
fmt
, seedrjit.print()
for details on this.tb_depth (int) – Depth of the backtrace that should be appended to the assertion message. This only applies to cases some of the inputs are symbolic, and printing of the error message must be delayed.
tb_skip (int) – The first
tb_skip
entries of the backtrace will be removed. This only applies to cases where some of the inputs are symbolic and printing of the error message must be delayed. This is helpful when the assertion check is called from a helper function that should not be shown.
**kwargs (dict) – Optional variable-length keyword arguments referenced by
fmt
, seedrjit.print()
for details on this.
- drjit.assert_false(cond, fmt: str | None = None, *args, tb_depth: int = 3, tb_skip: int = 0, **kwargs)¶
Equivalent to
assert_true()
with a flipped conditioncond
. Please refer to the documentation of this function for further details.
- drjit.assert_equal(arg0, arg1, fmt: str | None = None, *args, limit: int = 3, tb_skip: int = 0, **kwargs)¶
Equivalent to
assert_true()
with the conditionarg0==arg1
. Please refer to the documentation of this function for further details.
- drjit.print(fmt: str, *args, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) None ¶
- drjit.print(value: object, /, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) None
Overloaded function.
print(fmt: str, *args, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) -> None
Generate a formatted string representation and print it immediately or in a delayed fashion (if any of the inputs are symbolic).
This function combines the behavior of the built-in Python
format()
andprint()
functions: it generates a formatted string representation as specified by a format stringfmt
and then outputs it on the console. The operation fetches referenced positional and keyword arguments and pretty-prints Dr.Jit arrays, tensors, and PyTrees with indentation, field names, etc.
>>> from drjit.cuda import Array3f
>>> dr.print("{}:\n{foo}",
...          "A PyTree containing an array",
...          foo={ 'a' : Array3f(1, 2, 3) })
A PyTree containing an array:
{ 'a': [[1, 2, 3]] }
The key advance of
drjit.print()
compared to the built-in Pythonprint()
statement is that it can run asynchronously, which allows it to print symbolic variables without requiring their evaluation. Dr.Jit uses symbolic variables to trace loops (drjit.while_loop()
), conditionals (drjit.if_stmt()
), and calls (drjit.switch()
,drjit.dispatch()
). Such symbolic variables represent values that are unknown at trace time, and which cannot be printed using the built-in Pythonprint()
function (attempting to do so will raise an exception).When the print statement does not reference any symbolic arrays or tensors, it will execute immediately. Otherwise, the output will appear after the next
drjit.eval()
statement, or whenever any subsequent computation is evaluated. Here is an example from an interactive Python session demonstrating printing from a symbolic call performed viadrjit.switch()
:
>>> from drjit.llvm import UInt, Float
>>> def f1(x):
...     dr.print("in f1: {x=}", x=x)
...
>>> def f2(x):
...     dr.print("in f2: {x=}", x=x)
...
>>> dr.switch(index=UInt(0, 0, 0, 1, 1, 1),
...           targets=[f1, f2],
...           x=Float(1, 2, 3, 4, 5, 6))
>>> # No output (yet)
>>> dr.eval()
in f1: x=[1, 2, 3]
in f2: x=[4, 5, 6]
Dynamic arrays with more than 20 entries will be abbreviated. Specify the
limit=..
argument to reveal the contents of larger arrays.
>>> dr.print(dr.arange(dr.llvm.Int, 30))
[0, 1, 2, .. 24 skipped .., 27, 28, 29]
>>> dr.print(dr.arange(dr.llvm.Int, 30), limit=30)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,↵
23, 24, 25, 26, 27, 28, 29]
This function lacks many features of Python’s (rather intricate) format string mini language and f-string interpolation. However, a subset of the functionality is supported:
Positional arguments (in
*args
) can be referenced implicitly ({}
), or using indices ({0}
,{1}
, etc.). Those conventions should not be mixed. Unreferenced positional arguments will be silently ignored.Keyword arguments (in
**kwargs
) can be referenced via their keyword name ({foo}
). Unreferenced keywords will be silently ignored.A trailing
=
in a brace expression repeats the string within the braces followed by the output:
>>> dr.print('{foo=}', foo=1)
foo=1
When the format string
fmt
is omitted, it is implicitly set to{}
, and the function formats a single positional argument.The function implicitly appends
end
to the format string, which is set to a newline by default. The final result is sent tosys.stdout
(by default) orfile
. When afile
argument is given, it must implement the methodwrite(arg: str)
.A related operation
drjit.format()
admits the same format string syntax but returns a Pythonstr
instead of printing to the console. This operation, however, does not support symbolic inputs—usedrjit.print()
with a customfile
argument to stringify symbolic inputs asynchronously.Note
This operation is not suitable for extracting large amounts of data from Dr.Jit kernels, as the conversion to a string representation incurs a nontrivial runtime cost.
Note
Technical details on symbolic printing
When Dr.Jit compiles and executes queued computation on the target device, it includes additional code for symbolic print operations that captures referenced arguments and copies them back to the host (CPU). The information is then printed following the end of that process.
Only a limited amount of memory is set aside to capture the output of symbolic print operations. This is because the amount of data produced within long-running symbolic loops can often exceed the total device memory. Also, printing gigabytes of ASCII text into a Python console or Jupyter notebook is likely not a good idea.
For the electronically inclined, the operation is best thought of as hooking up an oscilloscope to a high-frequency circuit. The oscilloscope provides a limited view into a vast torrent of data to assist the user, who would be overwhelmed if the oscilloscope worked by capturing and showing everything.
The operation warns when the size of the buffers was insufficient. In this case, the output is still printed in the correct order, but chunks of the data are missing. The position of the resulting holes is unspecified and non-deterministic.
>>> dr.print(dr.arange(Float, 10000000), mode='symbolic')
>>> dr.eval()
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
RuntimeWarning: dr.print(): symbolic print statement only captured 20 of 10000000 available outputs. The above is a non-deterministic sample, in which entries are in the right order but not necessarily contiguous. Specify `limit=..` to capture more information and/or add the special format field `{thread_id}` to show the thread ID/array index associated with each entry of the captured output.
This is because the (many) parallel threads of the program all try to append their state to the output buffer, but only the first
limit
(20 by default) succeed. The host subsequently re-sorts the captured data by thread ID. This means that the output[5, 6, 102, 188, 1026, ..]
would also be a valid result of the prior command. When a print statement references multiple arrays, then the operations either shows all array entries associated with a particular execution thread, or none of them.To refine what is captured, you can specify the
active
argument to disable the print statement for a subset of the entries (a “trigger” in the oscilloscope analogy). Printing from an inactive thread within a symbolic loop (drjit.while_loop()
), conditional (drjit.if_stmt()
), or call (drjit.switch()
,drjit.dispatch()
) will likewise not generate any output.A potential gotcha of the current design is that a symbolic print within a symbolic loop counts as one print statement and will only generate a single combined output string. The output of each thread is arranged in one contiguous block. You can add the special format string keyword
{thread_id}
to reveal the mapping between output values and the execution thread that generated them:
>>> from drjit.llvm import Int
>>> @dr.syntax
>>> def f(j: Int):
...     i = Int(0)
...     while i < j:
...         dr.print('{thread_id=} {i=}', i=i)
...         i += 1
...
>>> f(Int(2, 3))
>>> dr.eval();
thread_id=[0, 0, 1, 1, 1], i=[0, 1, 0, 1, 2]
The example above runs a symbolic loop twice in parallel: the first thread runs for 2 iterations, and the second runs for 3 iterations. The loop prints the iteration counter
i
, which then leads to the output[0, 1, 0, 1, 2]
where the first two entries are produced by the first thread, and the trailing three belong to the second thread. Thethread_id
output clarifies this mapping.- Parameters:
fmt (str) – A format string that potentially references input arguments from
*args
and**kwargs
.active (drjit.ArrayBase | bool) – A mask argument that can be used to disable a subset of the entries. The print statement will be completely suppressed when there is no output. (default:
True
).end (str) – This string will be appended to the format string. It is set to a newline character (
"\n"
) by default.file (object) – The print operation will eventually invoke
file.write(arg:str)
to print the formatted string. Specify this argument to route the output somewhere other than the default output streamsys.stdout
.mode (str) – Specify this parameter to override the evaluation mode. Possible values are:
"symbolic"
,"evaluated"
, or"auto"
. The default value of"auto"
causes the function to use evaluated mode (which prints immediately) unless a symbolic input is detected, in which case printing takes place symbolically (i.e., in a delayed fashion).limit (int) – The operation will abbreviate dynamic arrays with more than
limit
(default: 20) entries.
print(value: object, /, active: drjit.ArrayBase | bool = True, end: str = '\n', file: object = None, limit: int = 20, mode='auto', **kwargs) -> None
- drjit.format(fmt: str, *args, limit: int = 20, **kwargs)¶
- drjit.format(value: object, *, limit: int = 20, **kwargs) None
Overloaded function.
format(fmt: str, *args, limit: int = 20, **kwargs)
Return a formatted string representation.
This function generates a formatted string representation as specified by a format string
fmt
and then returns it as a Pythonstr
object. The operation fetches referenced positional and keyword arguments and pretty-prints Dr.Jit arrays, tensors, and PyTrees with indentation, field names, etc.
>>> from drjit.cuda import Array3f
>>> s = dr.format("{}:\n{foo}",
...               "A PyTree containing an array",
...               foo={ 'a' : Array3f(1, 2, 3) })
>>> print(s)
A PyTree containing an array:
{ 'a': [[1, 2, 3]] }
Dynamic arrays with more than 20 entries will be abbreviated. Specify the
limit=..
argument to reveal the contents of larger arrays.
>>> dr.format(dr.arange(dr.llvm.Int, 30))
[0, 1, 2, .. 24 skipped .., 27, 28, 29]
>>> dr.format(dr.arange(dr.llvm.Int, 30), limit=30)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,↵
23, 24, 25, 26, 27, 28, 29]
This function lacks many features of Python’s (rather intricate) format string mini language and f-string interpolation. However, a subset of the functionality is supported:
Positional arguments (in
*args
) can be referenced implicitly ({}
), or using indices ({0}
,{1}
, etc.). Those conventions should not be mixed. Unreferenced positional arguments will be silently ignored.Keyword arguments (in
**kwargs
) can be referenced via their keyword name ({foo}
). Unreferenced keywords will be silently ignored.A trailing
=
in a brace expression repeats the string within the braces followed by the output:
>>> dr.format('{foo=}', foo=1)
foo=1
When the format string
fmt
is omitted, it is implicitly set to{}
, and the function formats a single positional argument.In contrast to the related
drjit.print()
, this function does not output the result on the console, and it cannot support symbolic inputs. This is because returning a string right away is incompatible with the requirement of evaluating/formatting symbolic inputs in a delayed fashion. If you wish to format symbolic arrays, you must calldrjit.print()
with a customfile
object that implements the.write()
function. Dr.Jit will call this function with the generated string when it is ready.- Parameters:
fmt (str) – A format string that potentially references input arguments from
*args
and**kwargs
.limit (int) – The operation will abbreviate dynamic arrays with more than
limit
(default: 20) entries.
- Returns:
The formatted string representation created as specified above.
- Return type:
str
format(value: object, *, limit: int = 20, **kwargs)
- drjit.log_level() drjit.LogLevel ¶
- drjit.set_log_level(arg: drjit.LogLevel, /) None ¶
- drjit.set_log_level(arg: int, /) None
Profiling¶
- class drjit.profile_enable(*args, **kwargs)¶
Context manager to selectively activate profiling for a region of a program.
Some profiling tools (e.g., NSight Compute) support targeted profiling of smaller parts of a program. Use this context manager to locally enable profiling.
Note the difference between this context manager and dr.profile_range(), which annotates a profiled region with a label.
- drjit.profile_mark(arg: str, /) None ¶
Mark an event on the timeline of profiling tools.
Currently, this function uses NVTX to report events that can be captured using NVIDIA Nsight Systems. The operation is a no-op when no profile collection tool is attached.
Note that this event will be recorded on the CPU timeline.
- class drjit.profile_range(*args, **kwargs)¶
Context manager to mark a region (e.g. a function call) on the timeline of profiling tools.
You can use this context manager to wrap parts of your code and track when and for how long it runs. Regions can be arbitrarily nested, which profiling tools visualize as a stack.
Note that this function is intended to track activity on the CPU timeline. If the wrapped region launches asynchronous GPU kernels, then those won’t generally be included in the length of the range unless
drjit.sync_thread()
or some other type of synchronization operation waits for their completion (which is generally not advisable, since keeping CPU and GPU asynchronous with respect to each other improves performance).Currently, this function uses NVTX to report events that can be captured using NVIDIA Nsight Systems. The operation is a no-op when no profile collection tool is attached.
Note the difference between this context manager and dr.profile_enable(), which enables targeted profiling of a smaller region of code (as opposed to profiling the entire program).
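A short sketch combining these profiling helpers. The label strings and the my_render() function are illustrative placeholders, and the constructor arguments are assumptions based on the descriptions above.

import drjit as dr

def my_render():
    pass  # placeholder for the workload being profiled

# Annotate a region on the CPU timeline and mark a point event within it
with dr.profile_range("render pass"):
    dr.profile_mark("begin shading")
    my_render()

# Restrict capture by tools such as NVIDIA Nsight Compute to this region
with dr.profile_enable():
    my_render()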
Textures¶
The texture implementations are defined in the various backends
(e.g., drjit.llvm.ad.Texture3f16
). However, they reference
enumerations provided here.
Low-level bits¶
- drjit.set_backend(arg: Literal['cuda', 'llvm', 'scalar'], /)¶
- drjit.set_backend(arg: drjit.JitBackend, /) None
Overloaded function.
set_backend(arg: Literal['cuda', 'llvm', 'scalar'], /)
Adjust the
drjit.auto.*
module so that it refers to types from the specified backend.set_backend(arg: drjit.JitBackend, /) -> None
- drjit.thread_count() int ¶
Return the number of threads that Dr.Jit uses to parallelize computation on the CPU.
- drjit.set_thread_count(arg: int, /) None ¶
Adjust the number of threads that Dr.Jit uses to parallelize computation on the CPU.
The thread pool is primarily used by Dr.Jit’s LLVM backend. Other projects using the underlying nanothread thread pool library will also be affected by changes performed using this function. It is legal to call it even while parallel computation is currently ongoing.
- drjit.sync_thread() None ¶
Wait for all currently running computation to finish.
This function synchronizes the device (e.g. the GPU) with the host (CPU) by waiting for the termination of all computation enqueued by the current host thread.
One potential use of this function is to measure the runtime of a kernel launched by Dr.Jit. We instead recommend the use of
drjit.kernel_history()
, which exposes more accurate device timers.In general, calling this function in user code is considered bad practice. Dr.Jit programs “run ahead” of the device to keep it fed with work. This is important for performance, and
drjit.sync_thread()
breaks this optimization.All operations sent to a device (including reads) are strictly ordered, so there is generally no reason to wait for this queue to empty. If you find that
drjit.sync_thread()
is needed for your program to run correctly, then you have found a bug. Please report it on the project’s GitHub issue tracker.
- drjit.flush_kernel_cache() None ¶
Release all currently cached kernels.
When Dr.Jit evaluates a previously unseen computation, it compiles a kernel and then maps it into the memory of the CPU or GPU. The kernel stays resident so that it can be immediately reused when that same computation reoccurs at a later point.
In long development sessions (e.g., Jupyter notebook-based prototyping), this cache may eventually become unreasonably large, and calling
flush_kernel_cache()
to free it may be advisable.Note that this does not free the disk cache that also exists to preserve compiled programs across sessions. To clear this cache as well, delete the directory
$HOME/.drjit
on Linux/macOS, and%AppData%\Local\Temp\drjit
on Windows. (TheAppData
folder is typically found inC:\Users\<your username>
).
- drjit.flush_malloc_cache() None ¶
Free the memory allocation cache maintained by Dr.Jit.
Allocating and releasing large chunks of memory tends to be relatively expensive, and Dr.Jit programs often need to do so at high rates.
Like most other array programming frameworks, Dr.Jit implements an internal cache to reduce such allocation-related costs. This cache starts out empty and grows on demand. Allocated memory is never released by default, which can be problematic when using multiple array programming frameworks within the same Python session, or when running multiple processes in parallel.
The
drjit.flush_malloc_cache()
function releases all currently unused memory back to the operating system. This is a relatively expensive step: you likely don’t want to use it within a performance-sensitive program region (e.g. an optimization loop).
- drjit.expand_threshold() int ¶
Query the threshold for performing scatter-reductions via expansion.
Getter for the quantity set in
drjit.set_expand_threshold()
- drjit.set_expand_threshold(arg: int, /) None ¶
Set the threshold for performing scatter-reductions via expansion.
The documentation of
drjit.ReduceOp
explains the cost of atomic scatter-reductions and introduces various optimization strategies.One particularly effective optimization (see the section on optimizations) named
drjit.ReduceOp.Expand
is specific to the LLVM backend. It replicates the target array to avoid write conflicts altogether, which enables the use of non-atomic memory operations. This is significantly faster but also very memory-intensive. The storage cost of a 1 MB array targeted by a drjit.scatter_reduce()
operation now grows toN
megabytes, whereN
is the number of cores.For this reason, Dr.Jit implements a user-controllable threshold exposed via the functions
drjit.expand_threshold()
anddrjit.set_expand_threshold()
. When the array has more entries than the value specified here, thedrjit.ReduceOp.Expand
strategy will not be used unless specifically requested via themode=
parameter of operations likedrjit.scatter_reduce()
,drjit.scatter_add()
, anddrjit.gather()
.The default value of this parameter is 1000000 (1 million entries).
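A minimal usage sketch (the threshold value shown is illustrative):

import drjit as dr

print(dr.expand_threshold())         # default: 1000000
dr.set_expand_threshold(10_000_000)  # permit expansion for larger arrays
assert dr.expand_threshold() == 10_000_000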
- drjit.kernel_history(types: collections.abc.Sequence[drjit.KernelType] = []) list ¶
Return the history of captured kernel launches.
Dr.Jit can optionally capture performance-related metadata. To do so, set the
drjit.JitFlag.KernelHistory
flag as follows:

with dr.scoped_set_flag(dr.JitFlag.KernelHistory):
    # .. computation to be analyzed ..
    hist = dr.kernel_history()
The
drjit.kernel_history()
function returns a list of dictionaries characterizing each major operation performed by the analyzed region. Each dictionary has the following entries:
backend
: The used JIT backend.execution_time
: The time (in milliseconds) used by this operation.On the CUDA backend, this value is captured via CUDA events. On the LLVM backend, this involves querying
CLOCK_MONOTONIC
(Linux/macOS) orQueryPerformanceCounter
(Windows).type
: The type of computation expressed by an enumeration value of typedrjit.KernelType
. The most interesting workloads generated by Dr.Jit are just-in-time compiled kernels, which are identified by drjit.KernelType.JIT
.These have several additional entries:
hash
: The hash code identifying the kernel. (This is the same hash code that is also shown when increasing the log level via drjit.set_log_level()
).ir
: A capture of the intermediate representation used in this kernel.operation_count
: The number of low-level IR operations. (A rough proxy for the complexity of the operation.)cache_hit
: Was this kernel present in Dr.Jit’s in-memory cache? Otherwise, it was either loaded from disk or had to be recompiled from scratch.cache_disk
: Was this kernel present in Dr.Jit’s on-disk cache? Otherwise, it had to be recompiled from scratch.codegen_time
: The time (in milliseconds) which Dr.Jit needed to generate the textual low-level IR representation of the kernel. This step is always needed even if the resulting kernel is already cached.backend_time
: The time (in milliseconds) which the backend (either the LLVM compiler framework or the CUDA PTX just-in-time compiler) required to compile and link the low-level IR into machine code. This step is only needed when the kernel did not already exist in the in-memory or on-disk cache.uses_optix
: Was this kernel compiled by the NVIDIA OptiX ray tracing engine?
Note that
drjit.kernel_history()
clears the history while extracting this information. A related operationdrjit.kernel_history_clear()
only clears the history without returning any information.
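As a sketch, the captured entries documented above might be summarized as follows:

import drjit as dr

with dr.scoped_set_flag(dr.JitFlag.KernelHistory):
    # .. computation to be analyzed ..
    hist = dr.kernel_history()

for entry in hist:
    if entry['type'] == dr.KernelType.JIT:
        print(f"kernel {entry['hash']}: {entry['execution_time']:.3f} ms "
              f"(cache hit: {entry['cache_hit']})")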
- drjit.kernel_history_clear() None ¶
Clear the kernel history.
This operation clears the kernel history without returning any information about it. See
drjit.kernel_history()
for details.
- drjit.detail.set_leak_warnings(arg: bool, /) None ¶
Dr.Jit tracks and can report leaks of various types (Python instance leaks, Dr.Jit-Core variable leaks, AD variable leaks). Since benign warnings can sometimes occur, they are disabled by default for PyPI release builds. Use this function to enable/disable them explicitly.
- drjit.detail.leak_warnings() bool ¶
Query whether leak warnings are enabled. See
drjit.detail.set_leak_warnings()
.
Typing¶
Local memory¶
- drjit.alloc_local(dtype: type[T], size: int, value: T | None = None) Local[T] ¶
Allocate a local memory buffer with type
dtype
and sizesize
.See the separate documentation section on local memory for details on the role of local memory and situations where it is useful.
- Parameters:
dtype (type) – Desired Dr.Jit array type or PyTree.
size (int) – Number of buffer elements. This value must be statically known.
value – If desired, an instance of type
dtype
/T
can be provided here to default-initialize all entries of the buffer. Otherwise, it is left uninitialized.
- Returns:
The allocated local memory buffer
- Return type:
Local[T]
- class drjit.Local(*args, **kwargs)¶
This generic class (parameterized by an extra type
T
) represents a local memory buffer—that is, a temporary scratch space with support for indexed reads and writes.See the separate documentation section on local memory for details on the role of local memory and situations where it is useful.
- __init__(self, arg: drjit.Local) None ¶
Copy-constructor, creates a copy of a given local memory buffer.
- read(self, index: int | AnyArray, active: bool | AnyArray = True) T ¶
Read the local memory buffer at index
index
and return a result of typeT
. An optional mask can be provided as well. Masked reads evaluate to zero.Danger
The indices provided to this operation are unchecked by default. Attempting to read beyond the end of the buffer is undefined behavior and may crash the application, unless such reads are explicitly disabled via the
active
parameter. Negative indices are not permitted.If debug mode is enabled via the
drjit.JitFlag.Debug
flag, Dr.Jit will insert range checks into the program. These checks disable out-of-bound reads and furthermore report warnings to identify problematic source locations.
- write(self, value: T, index: int | AnyArray, active: bool | AnyArray = True) None ¶
Store the value
value
at indexindex
. An optional mask can be provided as well. Masked writes are no-ops.Danger
The indices provided to this operation are unchecked by default. Attempting to write beyond the end of the buffer is undefined behavior and may crash the application, unless such writes are explicitly disabled via the
active
parameter. Negative indices are not permitted.If debug mode is enabled via the
drjit.JitFlag.Debug
flag, Dr.Jit will insert range checks into the program. These checks disable out-of-bound writes and furthermore report warnings to identify problematic source locations.
- __getitem__(self, arg: int | AnyArray, /) T ¶
Perform a normal read at the given index. This is equivalent to
read(.., active=True)
.
- __setitem__(self, arg0: int | AnyArray, arg1: T, /) None ¶
Perform a normal write at the given index. This is equivalent to
write(.., active=True)
.
- __len__(self) int ¶
Return the length (number of entries) of the local memory buffer. This corresponds to the
size
value passed todrjit.alloc_local()
.
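The following sketch illustrates the basic access pattern; the drjit.auto array types and the specific indices are chosen for illustration only.

import drjit as dr
from drjit.auto import Float, UInt32

# 4-entry scratch buffer, default-initialized to zero
buf = dr.alloc_local(Float, 4, value=Float(0))

i = UInt32(2)
buf[0] = Float(1)             # __setitem__ (unmasked write)
buf.write(Float(3), i)        # explicit write at a dynamic index
total = buf[0] + buf.read(i)  # __getitem__ and read()
print(len(buf))               # 4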
Digital Differential Analyzer¶
The drjit.dda
module provides a general implementation of a digital
differential analyzer (DDA) that steps through the intersection of a ray
segment and an N-dimensional grid, performing a custom computation at every
cell.
The drjit.integrate()
function builds on this functionality to compute
differentiable line integrals of bi- or trilinear interpolants stored on a
grid.
- drjit.dda.dda(ray_o: ArrayNfT, ray_d: ArrayNfT, ray_max: object, grid_res: ArrayNuT, grid_min: ArrayNfT, grid_max: ArrayNfT, func: Callable[[StateT, ArrayNuT, ArrayNfT, ArrayNfT, BoolT], Tuple[StateT, BoolT]], state: StateT, active: BoolT, mode: Literal['scalar', 'symbolic', 'evaluated', None] = None, max_iterations: int | None = None) StateT ¶
N-dimensional digital differential analyzer (DDA).
This function traverses the intersection of a Cartesian coordinate grid and a specified ray or ray segment. The following snippet shows how to use it to enumerate the intersection of a grid with a single ray.
from drjit.scalar import Array3f, Array3u, Float, Bool

def dda_fun(state: list, index: Array3u, pt_in: Array3f,
            pt_out: Array3f, active: Bool) -> tuple[list, bool]:
    # Entered a grid cell, stash it in the 'state' variable
    state.append(Array3f(index))
    return state, Bool(True)

result = dda(
    ray_o    = Array3f(-.1),
    ray_d    = Array3f(.1, .2, .3),
    ray_max  = Float(float('inf')),
    grid_res = Array3u(10),
    grid_min = Array3f(0),
    grid_max = Array3f(1),
    func     = dda_fun,
    state    = [],
    active   = Bool(True)
)

print(result)
Since all input elements are Dr.Jit arrays, everything works analogously when processing
N
rays andN
potentially different grid configurations. The entire process can be captured symbolically.The function takes the following arguments. Note that many of them are generic type variables (signaled by ending with a capital
T
). To support different dimensions and precisions, the implementation must be able to deal with various input types, which is communicated by these type variables.- Parameters:
ray_o (ArrayNfT) – the ray origin, where the
ArrayNfT
type variable refers to an n-dimensional scalar or Jit-compiled floating point array.ray_d (ArrayNfT) – the ray direction. Does not need to be normalized.
ray_max (object) – the maximum extent along the ray, which is permitted to be infinite. The value is specified as a multiple of the norm of
ray_d
, which is not necessarily unit-length. Must be of typedr.value_t(ArrayNfT)
.grid_res (ArrayNuT) – the grid resolution, where the
ArrayNuT
type variable refers to a matched 32-bit integer array (i.e.,ArrayNuT = dr.int32_array_t(ArrayNfT)
).grid_min (ArrayNfT) – the minimum position of the grid bounds.
grid_max (ArrayNfT) – the maximum position of the grid bounds.
func (Callable[[StateT, ArrayNuT, ArrayNfT, ArrayNfT, BoolT], tuple[StateT, BoolT]]) –
a callback that will be invoked when the DDA traverses a grid cell. It must take the following five positional arguments:
arg0: StateT
: An arbitrary state value.arg1: ArrayNuT
: An integer array specifying the cell index along each dimension.arg2: ArrayNfT
: The fractional position (\(\in [0, 1]^n\)) where the ray enters the current cell.arg3: ArrayNfT
: The fractional position (\(\in [0, 1]^n\)) where the ray leaves the current cell.arg4: BoolT
: A boolean array specifying which elements are active.
The callback should then return a tuple of type
tuple[StateT, BoolT]
containingAn updated state value.
A boolean array that can be used to exit the loop prematurely for some or all rays. The iteration stops if the associated entry of the return value equals
False
.
state (StateT) – an arbitrary initial state that will be passed to the callback.
active (BoolT) – an array specifying which elements of the input are active, where the
BoolT
type variable refers to a matched boolean array (i.e.,BoolT = dr.mask_t(ray_o.x)
).mode – (str | None): The operation can operate in scalar, symbolic, or evaluated modes—see the
mode
argument and the documentation ofdrjit.while_loop()
for details.max_iterations – int | None: Bound on the iteration count that is needed for reverse-mode differentiation. Forwarded to the
max_iterations
parameter ofdrjit.while_loop()
.
- Returns:
The function returns the final state value of the callback upon termination.
- Return type:
StateT
Note
Just like the Dr.Jit texture interface, the implementation uses the convention that voxel sizes and positions are specified from last to first component (e.g.
(Z, Y, X)
), while regular 3D positions use the opposite(X, Y, Z)
order.In particular, all
ArrayNuT
-typed parameters of the function and the callback use the ZYX convention, whileArrayNfT
-typed parameters use theXYZ
convention.
- drjit.dda.integrate(ray_o: ArrayNfT, ray_d: ArrayNfT, ray_max: FloatT, grid_min: ArrayNfT, grid_max: ArrayNfT, vol: ArrayBase[Any, Any, Any, Any, Any, Any, Any], active: object = None, mode: Literal['scalar', 'symbolic', 'evaluated', None] = None) FloatT ¶
Compute an analytic definite integral of a bi- or trilinear interpolant.
This function uses DDA (
drjit.dda.dda()
) to step along the voxels of a 2D/3D volume traversed by a finite segment or an infinite-length ray. It analytically computes and accumulates the definite integral of the interpolant in each voxel.The input 2D/3D volume is provided using a tensor
vol
(e.g., of typedrjit.cuda.ad.TensorXf
) with an implicitly specified grid resolutionvol.shape
. This data volume is placed into an axis-aligned region with bounds (grid_min
,grid_max
).The operation provides an efficient forward and backward derivative.
Note
Just like the Dr.Jit texture interface, the implementation uses the convention that voxel sizes and positions are specified from last to first component (e.g.
(Z, Y, X)
), while regular 3D positions use the opposite(X, Y, Z)
order.In particular,
vol.shape
uses the ZYX convention, while theArrayNfT
-typed parameters use theXYZ
convention.One important difference to the texture classes is that the interpolant is sampled at integer grid positions, whereas the Dr.Jit texture classes places values at cell centers, i.e. with a
.5
fractional offset.
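A brief usage sketch under the conventions above. The grid contents, resolution, tensor construction, and the chosen backend are illustrative assumptions; note the ZYX ordering of vol.shape.

import drjit as dr
from drjit.llvm.ad import Array3f, Float, TensorXf
from drjit.dda import integrate

# 8x8x8 volume of constant values (shape given in ZYX order)
vol = TensorXf(dr.full(Float, 0.5, 8 * 8 * 8), shape=(8, 8, 8))

result = integrate(
    ray_o=Array3f(0.5, 0.5, -1.0),
    ray_d=Array3f(0.0, 0.0, 1.0),
    ray_max=Float(float('inf')),
    grid_min=Array3f(0.0),
    grid_max=Array3f(1.0),
    vol=vol,
)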
Optimizers¶
The drjit.opt
module implements basic infrastructure for
gradient-based optimization and adaptive mixed-precision training.
- class drjit.opt.Optimizer(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, mask_updates: bool = False)¶
Gradient-based optimizer base class
This class resembles a Python dictionary enabling retrieval and assignment of parameter values by name. It furthermore implements common functionality used by gradient-based optimizers, such as stochastic gradient descent (
SGD
) and Adam (Adam
).Typically, optimizers are used as follows:
# Create an optimizer object
opt = Adam(lr=1e-3)

# Register one or more parameters
opt["x"] = Float(1, 2, 3)

# Alternative syntax
# opt = Adam(
#     lr=1e-3,
#     params={'x': Float(1, 2, 3)}
# )

# For some number of iterations..
for i in range(1000):
    # Fetch the current parameter value
    x = opt["x"]

    # Compute a scalar loss value and backpropagate
    loss = my_optimization_task(x)
    dr.backward(loss)

    # Take a gradient step (details depend on the choice of optimizer)
    opt.step()
Following
opt.step()
, some applications may need to project the parameter values onto a valid subset, e.g.:

# Values outside of the range [0, 1] are not permitted
opt['x'] = dr.clip(opt['x'], 0, 1)
You may register additional parameters or delete existing ones during an optimization.
Note
There are several notable differences compared to optimizers in PyTorch:
PyTorch keeps references to the original parameters and manipulates them in-place. The user must call
opt.zero_grad()
to clear out remaining gradients from the last iteration.Dr.Jit optimizers own the parameters being optimized. The function
opt.step()
only updates this internal set of parameters without causing changes elsewhere.Compared to PyTorch, an optimization loop therefore involves some boilerplate code (e.g.,
my_param = opt["my_param"]
) to fetch parameter values and use them to compute a differentiable objective.
In general, it is recommended that you optimize Dr.Jit code using Dr.Jit optimizers, rather than combining frameworks with differentiable bridges (e.g., via the
@dr.wrap
decorator), which can add significant inter-framework communication and bookkeeping overheads.Note
It is tempting to print or plot the loss decay during the optimization. However, doing so forces the CPU to wait for asynchronous execution to finish on the device (e.g., GPU), impeding the system’s ability to schedule new work during this time. In other words, such a seemingly minor detail can actually have a detrimental impact on device utilization. As a workaround, consider printing asynchronously using
dr.print(loss, symbolic=True)
.- step(*, eval: bool = True, grad_scale: float | ArrayBase | None = None, active: ArrayBase | None = None) None ¶
Take a gradient step.
This function visits each registered parameter in turn and
extracts the associated gradient (
.grad
),takes a step using an optimizer-dependent update rule, and
reattaches the resulting new state to the AD graph.
- Parameters:
eval (bool) – If
eval=True
(the default), the system will explicitly evaluate the resulting parameter state, causing a kernel launch. Set eval=False
if you wish to perform further computation that should be fused into the optimizer step.grad_scale (float | drjit.ArrayBase) – Use this parameter to scale all gradients by a custom amount. Dr.Jit uses this parameter for automatic mixed-precision training.
active (drjit.ArrayBase | None) – This parameter can be used to pass a 1-element boolean mask. A value of
Bool(False)
disables the optimizer state update. Dr.Jit uses this parameter for automatic mixed-precision training.
- reset(key: str | None = None) None ¶
Reset the internal state (e.g., momentum, adaptive learning rate, etc.) associated with the parameter
key
. Whenkey=None
, the implementation resets the state of all parameters.
- update(params: Mapping[str, ArrayBase] | None = None, **args: ArrayBase) None ¶
Overwrite multiple parameter values at once.
This function simply calls
__setitem__()
multiple times. Likedict.update()
,update()
supports two calling conventions:# Update using a dictionary opt.update({'key_1': value_1, 'key_2': value_2}) # Update using a variable keyword arguments opt.update(key_1=value_1, key_2=value_2)
- learning_rate(key: str | None = None) float | ArrayBase | None ¶
Return the learning rate (globally, or of a specific parameter).
When
key
is provided, the function returns the associated parameter-specific learning rate (orNone
, if no learning rate was set for this parameter).When
key
is not provided, the function returns the default learning rate.
- set_learning_rate(value: float | ArrayBase | Mapping[str, float | ArrayBase | None] | None = None, /, **kwargs: float | ArrayBase | None) None ¶
Set the learning rate (globally, or of a specific parameter).
This function can be used as follows:
To modify the default learning rate of the optimizer:
opt = Adam(lr=1e-3)

# ... some time later:
opt.set_learning_rate(1e-4)
To modify the learning rate of a specific parameter:
opt = Adam(lr=1e-3, params={'x': x, 'y': y})
opt.set_learning_rate({'y': 1e-4})

# Alternative calling convention
opt.set_learning_rate(y=1e-4)
Note that once the learning rate of a specific parameter is set, it always takes precedence over the global setting. You must remove the parameter-specific setting to return to the global default:
opt.set_learning_rate(y=None)
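Because the learning rate can be adjusted at any time, a simple decay schedule can be sketched as follows. This reuses the names from the earlier optimization loop example, and the decay constant is illustrative.

opt = Adam(lr=1e-2)
opt["x"] = Float(1, 2, 3)

for i in range(1000):
    loss = my_optimization_task(opt["x"])
    dr.backward(loss)
    opt.step()

    # Exponentially decay the global learning rate
    opt.set_learning_rate(1e-2 * 0.999 ** i)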
- __setitem__(key: str, value: ArrayBase, /)¶
Overwrite a parameter value or register a new parameter.
Supported parameter types include:
Differentiable Dr.Jit arrays and nested arrays
Differentiable Dr.Jit tensors
Special array types (matrices, quaternions, complex numbers). These will be optimized component-wise.
In contrast to assignment in a regular dictionary, this function conceptually creates a copy of the input parameter (conceptual because the Copy-On-Write optimization avoids an actual on-device copy).
When
key
refers to a known parameter, the optimizer will overwrite it withvalue
. In doing so, it will preserve any associated optimizer state, such as momentum, adaptive step size, etc. When the new parameter value is substantially different (e.g., as part of a different optimization run), the previous momentum value may be meaningless, in which case a call toreset()
is advisable.When the new parameter value’s
.shape
differs from the current setting, the implementation automatically callsreset()
to discard the associated optimizer state.When
key
does not refer to a known parameter, the optimizer will register it. Note that only differentiable parameters are supported—incompatible types will raise aTypeError
.
- __delitem__(key: str, /) None ¶
Remove a parameter from the optimizer.
- __contains__(key: object, /) bool ¶
Check whether the optimizer contains a parameter with the name
key
.
- __len__() int ¶
Return the number of registered parameters.
- keys() Iterator[str] ¶
Return an iterator traversing the names of registered parameters.
- class drjit.opt.SGD(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, momentum: float = 0.0, nesterov: bool = False, mask_updates: bool = False)¶
Implements basic stochastic gradient descent (SGD) with a fixed learning rate and, optionally, momentum (0.9 is a typical value for the
momentum
parameter).The default initialization (
momentum=0
) uses the following update equation:\[\begin{align*} \mathbf{p}_{i+1} &= \mathbf{p}_i - \eta\cdot\mathbf{g}_{i+1}, \end{align*}\]where \(\mathbf{p}_i\) is the parameter value at iteration \(i\), \(\mathbf{g}_i\) denotes the associated gradient, and \(\eta\) is the learning rate.
Momentum-based SGD (with
momentum>0
,nesterov=False
) uses the update equation:\[\begin{align*} \mathbf{v}_{i+1} &= \mu\cdot\mathbf{v}_i + \mathbf{g}_{i+1}\\ \mathbf{p}_{i+1} &= \mathbf{p}_i - \eta \cdot \mathbf{v}_{i+1}, \end{align*}\]where \(\mathbf{v}\) is the velocity and \(\mu\) is the momentum parameter. Nesterov-style SGD (
nesterov=True
) switches to the following update rule:\[\begin{align*} \mathbf{v}_{i+1} &= \mu \cdot \mathbf{v}_i + \mathbf{g}_{i+1}\\ \mathbf{p}_{i+1} &= \mathbf{p}_i - \eta \cdot (\mathbf{g}_{i+1} + \mu \mathbf{v}_{i+1}). \end{align*}\]Some frameworks implement variations of the above equations. The code in Dr.Jit was designed to reproduce the behavior of torch.optim.SGD.
- __init__(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, momentum: float = 0.0, nesterov: bool = False, mask_updates: bool = False)¶
- Parameters:
lr (float | drjit.ArrayBase) – Learning rate parameter. You may want to try different values (e.g. powers of two) to find the best setting for a specific problem. Use
Optimizer.set_learning_rate()
to later adjust this value globally, or for specific parameters.momentum (float) – The momentum factor as described above. Larger values will cause past gradients to persist for a longer amount of time.
mask_updates (bool) – See
Optimizer.__init__()
for details on this parameter.params (Mapping[str, drjit.ArrayBase] | None) – Optional dictionary-like object containing an initial set of parameters.
- class drjit.opt.Adam(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, beta_1: float = 0.9, beta_2: float = 0.999, epsilon: float = 1e-08, mask_updates: bool = False, uniform: bool = False)¶
This class implements the Adam optimizer as presented in the paper Adam: A Method for Stochastic Optimization by Kingma and Ba, ICLR 2015.
Adam effectively combines momentum (as in
SGD
withmomentum>0
) with the adaptive magnitude-based scale factor fromRMSProp
. To do so, it maintains two exponential moving averages (EMAs) per parameter: \(\mathbf{m}_i\) for the first moment, and \(\mathbf{v}_i\) for the second moment. This triples the memory usage and should be considered when optimizing very large representations.The method uses the following update equation:
\[\begin{align*} \mathbf{m}_{i+1} &= \beta_1 \cdot \mathbf{m}_i + (1-\beta_1)\cdot\mathbf{g}_{i+1}\\ \mathbf{v}_{i+1} &= \beta_2 \cdot \mathbf{v}_i + (1-\beta_2)\cdot\mathbf{g}_{i+1}^2\\ \mathbf{p}_{i+1} &= \mathbf{p}_i - \eta\, \frac{\sqrt{1-\beta_2^{i+1}}}{1-\beta_1^{i+1}}\, \frac{\mathbf{m}_{i+1}}{\sqrt{\mathbf{v}_{i+1}}+\varepsilon}, \end{align*}\]where \(\mathbf{p}_i\) is the parameter value at iteration \(i\), \(\mathbf{g}_i\) denotes the associated gradient, \(\eta\) is the learning rate, and \(\varepsilon\) is a small number to avoid division by zero.
The scale factor \(\frac{\sqrt{1-\beta_2^{i+1}}}{1-\beta_1^{i+1}}\) corrects for the zero-valued initialization of the moment accumulators \(\mathbf{m}_i\) and \(\mathbf{v}_i\) at \(i=0\).
This class also implements two extensions that are turned off by default. See the descriptions of the
mask_updates
anduniform
parameters below.- __init__(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, beta_1: float = 0.9, beta_2: float = 0.999, epsilon: float = 1e-08, mask_updates: bool = False, uniform: bool = False)¶
Construct a new Adam optimizer object. The default parameters replicate the behavior of the original method.
- Parameters:
lr (float | drjit.ArrayBase) – Learning rate parameter. You may want to try different values (e.g. powers of two) to find the best setting for a specific problem. Use
Optimizer.set_learning_rate()
to later adjust this value globally, or for specific parameters.beta_1 (float) – Weight of the first-order moment exponential moving average (EMA). Values approaching
1
will cause past gradients to persist for a longer amount of time.beta_2 (float) – Weight of the second-order moment EMA. Values approaching
1
will cause past gradients to persist for a longer amount of time.uniform (bool) – If enabled, the optimizer will use the UniformAdam variant of Adam [Nicolet et al. 2021], where the update rule uses the maximum of the second moment estimates at the current step instead of the per-element second moments.
mask_updates (bool) – See
Optimizer.__init__()
for details on this parameter.params (Mapping[str, drjit.ArrayBase] | None) – Optional dictionary-like object containing an initial set of parameters.
- class drjit.opt.RMSProp(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, alpha: float = 0.99, epsilon: float = 1e-08, mask_updates: bool = False)¶
Implements the RMSProp optimizer explained in lecture notes by G. Hinton.
RMSProp scales the learning rate by the reciprocal of a running average of the magnitude of past gradients:
\[\begin{align*} \mathbf{m}_{i+1} &= \alpha \cdot \mathbf{m}_i + (1-\alpha)\cdot\mathbf{g}_{i+1}^2\\ \mathbf{p}_{i+1} &= \mathbf{p}_i - \frac{\eta}{\sqrt{\mathbf{m}_{i+1}+\varepsilon}}\, \mathbf{g}_{i+1}, \end{align*}\]where \(\mathbf{p}_i\) is the parameter value at iteration \(i\), \(\mathbf{g}_i\) denotes the associated gradient, \(\mathbf{m}_i\) accumulates the second moment, \(\eta\) is the learning rate, and \(\varepsilon\) is a small number to avoid division by zero.
The implementation reproduces the behavior of torch.optim.RMSprop.
- __init__(lr: float | ArrayBase, params: Mapping[str, ArrayBase] | None = None, *, alpha: float = 0.99, epsilon: float = 1e-08, mask_updates: bool = False)¶
Construct a RMSProp optimizer instance.
- Parameters:
lr (float | drjit.ArrayBase) – Learning rate parameter. You may want to try different values (e.g. powers of two) to find the best setting for a specific problem. Use
Optimizer.set_learning_rate()
to later adjust this value globally, or for specific parameters.alpha (float) – Weight of the second-order moment exponential moving average (EMA). Values approaching
1
will cause past gradients to persist for a longer amount of time.mask_updates (bool) – See
Optimizer.__init__()
for details on this parameter.params (Mapping[str, drjit.ArrayBase] | None) – Optional dictionary-like object containing an initial set of parameters.
- class drjit.opt.GradScaler(init_scale: float = 65536.0, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, debug: bool = False)¶
Gradient scaler for automatic mixed-precision training.
It is sometimes necessary to perform some part of a computation using lower precision (e.g.,
drjit.auto.Float16
) to improve storage and runtime efficiency. One issue with such lower-precision arithmetic is that gradients tend to underflow to zero, which can break the optimization.The
GradScaler
class implements a strategy for automatic mixed precision (AMP) training to prevent such numerical issues. A comprehensive overview of AMP can be found here.
AMP in Dr.Jit works as follows:
Construct a
GradScaler
instance prior to the optimization loop. Suppose it is calledscaler
.
Invoke
scaler.scale(loss)()
function to scale the optimization loss by a suitable value to prevent gradient underflow, and then propagate derivatives usingdrjit.backward()
or a similar AD function.Replace the call to
opt.step()
withscale.step(opt)()
function, which removes the scaling prior to the gradient step.
Concretely, this might look as follows:
opt = Adam(lr=1e-3)
opt['my_param'] = 0
scaler = GradScaler()

for i in range(1000):
    my_param = opt['my_param']
    loss = my_func(my_param)
    dr.backward(scaler.scale(loss))
    scaler.step(opt)
A large scale factor can also cause the opposite problem: gradient overflow, which manifests in the form of infinity and NaN-valued gradient components.
GradScaler
automatically detects this, skips the optimizer step, and decreases the scale factor.The implementation starts with a relatively aggressive scale factor that is likely to cause overflows, hence it may appear that the optimizer is initially stagnant for a few iterations. This is expected.
- __init__(init_scale: float = 65536.0, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, debug: bool = False)¶
The Dr.Jit
GradScaler
class follows the high-level API of pytorch.amp.GradScaler (https://pytorch.org/docs/stable/notes/amp_examples.html).- Parameters:
init_scale (float) – The initial scale factor.
growth_factor (float) – When
growth_interval
optimization steps have taken place without overflows,GradScaler
will begin to progressively increase the scale by multiplying it withgrowth_factor
at every iteration until an overflow is again detected.backoff_factor (float) – When an overflow issue is detected,
GradScaler
will decrease the scale factor by multiplying it withbackoff_factor
growth_interval (int) – The number of consecutive overflow-free optimizer steps after which GradScaler begins to increase the scale factor again.
debug (bool) – Print a debug message whenever the scale changes. This synchronizes with the device after each step, which will have a negative effect on optimization performance.