Control flow¶
Dr.Jit can trace control flow statements like loops, conditionals, and indirect jumps if they are expressed in a compatible manner.
First, let’s see what can go wrong when doing this naively. The Python snippet below is meant to compute the population count (i.e., the number of bits set to 1) per element of an integer sequence:
import drjit as dr
from drjit.auto import Int

def popcnt(i: Int):
    '''Count the number of active bits in ``i``'''
    j = Int(0)
    while i != 0:     # While there are remaining active bits
        j += i & 1    # Increment counter 'j' if current bit active
        i = i // 2    # Shift bits of 'i' to the right
    return j

print(popcnt(dr.arange(Int, 1024)))
However, running it fails with an error message:
Traceback (most recent call last):
  File "popcnt.py", line 12, in <module>
    print(popcnt(dr.arange(Int, 1024)))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "popcnt.py", line 7, in popcnt
    while i != 0:
          ^^^^^^
RuntimeError: drjit.llvm.Bool.__bool__(): implicit conversion to 'bool' ↵
requires an array with at most 1 element (this one has 1024 elements).
Why does this happen? In the example above, i is an array of integers, hence i != 0 produces a Boolean array of component-wise comparisons. If these values aren’t all identical, it implies that elements of i require different numbers of loop iterations. Unfortunately, this is something that regular Python simply does not support. Dr.Jit raises the alarm when it notices the user’s attempt to interpret the condition i != 0 as a Python bool.
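The mechanism behind this error can be sketched in plain Python. The class below is a hypothetical stand-in (BoolArraySketch and run_naive_loop are names invented for this illustration, not part of Dr.Jit); it only assumes that Python calls __bool__ on whatever object is used as a loop condition.

```python
class BoolArraySketch:
    """Hypothetical stand-in for a Boolean array type (not Dr.Jit's class)."""

    def __init__(self, values):
        self.values = list(values)

    def __bool__(self):
        # Python calls __bool__ when the object is used as a 'while'/'if' condition
        if len(self.values) > 1:
            raise RuntimeError(
                "implicit conversion to 'bool' requires an array with at most "
                f"1 element (this one has {len(self.values)} elements).")
        return bool(self.values[0]) if self.values else False


def run_naive_loop(mask):
    """Use 'mask' as a plain Python loop condition and report what happens."""
    try:
        while mask:          # triggers mask.__bool__()
            return "entered loop"
        return "skipped loop"
    except RuntimeError as exc:
        return f"error: {exc}"


single = run_naive_loop(BoolArraySketch([True]))
many = run_naive_loop(BoolArraySketch([True, False, True]))
```

A single-element mask behaves like an ordinary bool, while a mask with several entries has no unambiguous truth value, so the conversion is rejected.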
Annotating the function with the @dr.syntax decorator fixes this problem:
@dr.syntax
def popcnt(i: Int):
    ...
The script now terminates and prints the following correct output:
[0, 1, 1, .. 1018 skipped .., 9, 9, 10]
How does this work?¶
The @dr.syntax decorator provides syntax sugar: it takes a function as input and returns a slightly modified version of it. Most code passes through unchanged, with two exceptions: the decorator rewrites the declarations of loops and conditionals to make them compatible with tracing.
In the example above, it does this by identifying the variables modified by the loop (i and j), encapsulating the loop condition and body in separate functions, and performing a call to dr.while_loop() with all of this information.
This produces code equivalent to:
def popcnt(i: Int):
    j = Int(0)
    i, j = dr.while_loop(
        state=(i, j),
        cond=lambda i, j: i != 0,
        body=lambda i, j: (i // 2, j + (i & 1))
    )
    return j
The function dr.while_loop() generalizes the built-in Python while loop: when the condition is a Python bool, it doesn’t do anything special and just reproduces the normal behavior. When the loop condition is an array, it runs the loop separately for each element, potentially for different numbers of iterations.
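To make the per-element semantics concrete, here is a minimal pure-Python emulation. It is not how Dr.Jit implements dr.while_loop(); the names while_loop_sketch and popcnt_sketch are hypothetical, plain lists stand in for arrays, and the cond and body callables operate on one element at a time.

```python
def while_loop_sketch(state, cond, body):
    """Emulate an array-valued 'while' loop one element (lane) at a time.

    'state' is a tuple of equally long lists; 'cond' and 'body' take and
    return per-element scalars. This is an illustration only.
    """
    columns = [list(column) for column in state]
    for lane in range(len(columns[0])):
        values = tuple(column[lane] for column in columns)
        while cond(*values):              # each lane stops on its own schedule
            values = body(*values)
        for column, value in zip(columns, values):
            column[lane] = value
    return tuple(columns)


def popcnt_sketch(numbers):
    i, j = while_loop_sketch(
        state=(numbers, [0] * len(numbers)),
        cond=lambda i, j: i != 0,
        body=lambda i, j: (i // 2, j + (i & 1)),
    )
    return j


counts = popcnt_sketch(list(range(1024)))
```

Element 0 runs zero iterations and element 1023 runs ten, which is exactly the situation a single Python bool cannot express.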
In the same manner, if statements are turned into calls to dr.if_stmt(), which serves the same purpose for conditionals.
The main feature of @dr.syntax is to free users from having to perform this transformation themselves.
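The analogous idea for conditionals can be sketched as follows. This is one simple way to emulate an array-valued if in plain Python (evaluate both branches, then select per element); the function if_stmt_sketch and its parameter names are assumptions made for this illustration and should not be read as the signature or implementation of dr.if_stmt().

```python
def if_stmt_sketch(args, cond, true_fn, false_fn):
    """Emulate an array-valued 'if': run both branches, then pick per element."""
    if isinstance(cond, bool):
        # Scalar condition: behave exactly like a regular Python 'if'
        return true_fn(*args) if cond else false_fn(*args)
    when_true = true_fn(*args)
    when_false = false_fn(*args)
    return [t if c else f for c, t, f in zip(cond, when_true, when_false)]


def abs_sketch(x):
    # Equivalent in spirit to:  if x < 0: x = -x
    return if_stmt_sketch(
        args=(x,),
        cond=[v < 0 for v in x],
        true_fn=lambda x: [-v for v in x],
        false_fn=lambda x: x,
    )


result = abs_sketch([3, -1, 0, -7])
```

With a Python bool as condition only one branch runs, mirroring the fallback to ordinary Python behavior described above.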
Symbolic mode¶
The default way in which Dr.Jit handles control flow is called symbolic mode, which has certain limitations. Let’s make a small change to the code from before to illustrate one of them.
@dr.syntax
def popcnt(i: Int):
    '''Count the number of active bits'''
    j = Int(0)
    while i != 0:
        print(f"{i=}")
        j += i & 1
        i = i // 2
    return j
(The added print() statement is meant to show the state of variables at intermediate steps.)
Running this modified code produces a long error message:
Traceback (most recent call last):
  File "popcnt.py", line 9, in _loop_body
    print(f"{i=}")
RuntimeError: You performed an operation that tried to evalute a *symbolic*↵
variable, which is not permitted.
[lots of explanation text omitted here]
The message explains that i and j are considered symbolic while inside the loop. Certain operations are not allowed in this context, and printing their contents is one of them.
To understand why this is forbidden, recall that Dr.Jit embraces the idea of tracing, i.e., postponing computation for later evaluation. In the case of popcnt(), this means that Dr.Jit will execute the loop body only once to understand how it modifies the variables i and j, but without doing any actual computation. Even the number of loop iterations is unknown at this point. All of these details are postponed to when the traced computation actually runs on the target device (e.g., the GPU).
The implication of this design is that i and j are symbols that don’t have explicit values within the loop body, which is why the print() operation failed.
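The following toy tracer illustrates the idea in plain Python. It is a deliberately simplified sketch, not Dr.Jit's tracing machinery: SymbolSketch is a hypothetical placeholder type that records operations as text and refuses to reveal a value.

```python
class SymbolSketch:
    """Hypothetical placeholder for a value that is only known on the device."""

    def __init__(self, expr):
        self.expr = expr                      # textual record of the computation

    def __and__(self, other):
        return SymbolSketch(f"({self.expr} & {other})")

    def __floordiv__(self, other):
        return SymbolSketch(f"({self.expr} // {other})")

    def __add__(self, other):
        other = other.expr if isinstance(other, SymbolSketch) else other
        return SymbolSketch(f"({self.expr} + {other})")

    def __repr__(self):
        # print(), str() and f-strings all end up here: there is nothing to show
        raise RuntimeError("cannot evaluate a symbolic variable while tracing")


def loop_body(i, j):
    j = j + (i & 1)
    i = i // 2
    return i, j


# Trace the body once: operations are recorded, nothing is computed
i_next, j_next = loop_body(SymbolSketch("i"), SymbolSketch("j"))

try:
    message = f"{i_next=}"
except RuntimeError as exc:
    message = str(exc)
```

After a single pass through loop_body, the recorded expressions describe how one iteration transforms i and j, yet no concrete numbers exist that a print() could display.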
This way of capturing control flow is the default behavior of Dr.Jit and is called symbolic mode. Dr.Jit also supports a second approach called evaluated mode that we will examine next.
Evaluated mode¶
The inability to access the contents of symbolic variables can be inconvenient. We might need to print or plot intermediate steps, or to step through a program using a visual debugger.
To do so, let’s switch the loop to evaluated mode. We can do so at a statement level by annotating the loop condition with dr.hint(..., mode='evaluated').
@dr.syntax
def popcnt(i: Int):
    '''Count the number of active bits'''
    j = Int(0)
    while dr.hint(i != 0, mode='evaluated'):
        print(f"{i=}")
        j += i & 1
        i = i // 2
    return j

print(popcnt(dr.arange(Int, 1024)))
With this change, Dr.Jit now executes all loop iterations explicitly. Accessing the contents of i also works without problems, and the script produces the following output:
i=[0, 1, 2, .. 1018 skipped .., 1021, 1022, 1023]
i=[0, 0, 1, .. 1018 skipped .., 510, 511, 511]
i=[0, 0, 0, .. 1018 skipped .., 255, 255, 255]
i=[0, 0, 0, .. 1018 skipped .., 127, 127, 127]
i=[0, 0, 0, .. 1018 skipped .., 63, 63, 63]
i=[0, 0, 0, .. 1018 skipped .., 31, 31, 31]
i=[0, 0, 0, .. 1018 skipped .., 15, 15, 15]
i=[0, 0, 0, .. 1018 skipped .., 7, 7, 7]
i=[0, 0, 0, .. 1018 skipped .., 3, 3, 3]
i=[0, 0, 0, .. 1018 skipped .., 1, 1, 1]
[0, 1, 1, .. 1018 skipped .., 9, 9, 10]
Evaluated mode can also be enabled globally by disabling the flags dr.JitFlag.SymbolicLoops and dr.JitFlag.SymbolicConditionals via dr.set_flag() or dr.scoped_set_flag().
Discussion¶
Let’s take a step back and compare the properties of these two different modes.
Evaluated mode¶
As the name suggests, this mode evaluates loop variables to store them in memory. Each loop iteration then loads variable state and writes out new state at the end. The host (i.e., the CPU) is in charge of all control flow, which makes this mode simple to understand:
Debugging programs is straightforward. The user can step through the program line by line and examine variable contents via Python’s built-in print() statement or more advanced graphical plotting tools to construct visualizations from within loops, conditionals, and calls (tracing calls is described at the bottom of this section).
The program can freely mix Dr.Jit computation with other array programming frameworks like PyTorch, TensorFlow, JAX, etc.
The main disadvantage of evaluated mode is the overhead of constantly reading from and writing to device memory. The resulting memory bandwidth and storage costs can be prohibitive.
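A host-driven loop of this kind can be emulated in a few lines of plain Python. The sketch below is an assumption-laden illustration rather than Dr.Jit's implementation: evaluated_loop_sketch and its on_iteration callback are hypothetical names, and Python lists play the role of arrays stored in device memory.

```python
def evaluated_loop_sketch(state, cond, body, on_iteration=None):
    """Host-driven loop: every iteration reads and rewrites the full state."""
    state = tuple(list(column) for column in state)
    iterations = 0
    while True:
        active = [cond(*values) for values in zip(*state)]
        if not any(active):               # the host decides when to stop
            break
        if on_iteration is not None:
            on_iteration(state)           # contents are real values here
        updated = [body(*values) if on else values
                   for on, values in zip(active, zip(*state))]
        state = tuple(list(column) for column in zip(*updated))
        iterations += 1
    return state, iterations


snapshots = []
(i, j), steps = evaluated_loop_sketch(
    state=([0, 1, 2, 1023], [0, 0, 0, 0]),
    cond=lambda i, j: i != 0,
    body=lambda i, j: (i // 2, j + (i & 1)),
    on_iteration=lambda state: snapshots.append(list(state[0])),
)
```

Every pass materializes the whole state, which is what makes inspection easy and also what causes the memory traffic mentioned above. Elements whose condition is already false are carried along unchanged until the slowest element finishes.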
Symbolic mode¶
Symbolic mode moves the control flow onto the target device. This is a natural choice: Dr.Jit already traces computation to generate fused kernels, and this simply extends that idea to include control flow as well. For this, Dr.Jit must trace loops that run for an unknown number of iterations, which it does by introducing symbolic variables to capture the change from one iteration to the next. Symbolic variables represent unknown information that will only become available later when the generated code runs on the device.
The advantage of symbolic mode is that it can keep variable state in fast CPU/GPU registers, which improves performance and reduces storage costs.
The main disadvantage is that symbolic variables cannot be evaluated while tracing. Likewise, they cannot be passed to other frameworks like PyTorch or Tensorflow. Indeed, any attempt to reveal the content of symbolic variables is doomed to fail since it literally does not exist (yet). The upcoming section on variable evaluation clarifies what operations require evaluation. Symbolic mode is the default, since the performance benefits usually outweigh these disadvantages.
Note
Here are a few more detailed notes about symbolic and evaluated loops for advanced users. Feel free to skip these if you are new to Dr.Jit.
Loops (drjit.while_loop()), conditionals (drjit.if_stmt()), and dynamic dispatch (drjit.switch(), drjit.dispatch()) may be arbitrarily nested. However, it is not legal to nest evaluated operations within symbolic ones, as this would require the evaluation of symbolic variables.
Printing array contents is not permitted in symbolic mode, but Dr.Jit also provides a symbolic print statement implemented by dr.print() that prints in a delayed manner (i.e., asynchronously from the device) to avoid this problem.
Symbolic mode tends to create much larger kernels. Indeed, the idea is to preserve the entire program and generate one giant output kernel (a megakernel). Such large kernels can be costly to compile, though this cost is usually offset by kernel caching discussed in the next section.
Large kernels produced by symbolic mode also tend to use a large number of registers, and this may impede the latency-hiding capabilities of GPUs. Similarly, Dr.Jit always vectorizes computation (SIMD-style). Divergence in highly branching code produced by symbolic tracing may reduce performance.
Indirect calls¶
Dr.Jit provides the functions dr.switch() and dr.dispatch() to capture indirect function calls that may branch to one of several possible targets. Here is an example:
# A sequence of functions with the same argument and return value signature
def f1(a, b, c):
    # ...
    return x, y

def f2(a, b, c):
    # ...
    return x, y

x, y = dr.switch(
    index,       # <-- choose based on the integer array 'index' (indices must be < 2 in this example)
    [f1, f2],    # <-- call functions from the provided list ('f1' or 'f2')
    a, b, c      # <-- function parameters to forward to 'f1' and 'f2'
)
The reference documentation of dr.switch() and dr.dispatch() explains these two operations in full detail. As with the previous control flow operations, they support compilation in either symbolic or evaluated modes.
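The per-element meaning of such an indirect call can be emulated in plain Python. The helper switch_sketch and the two targets below are hypothetical and exist only for this illustration; lists stand in for arrays, and the emulation says nothing about how Dr.Jit groups or compiles the calls.

```python
def switch_sketch(index, targets, *args):
    """Emulate an indirect call: element k is processed by targets[index[k]]."""
    results = []
    for lane, target_id in enumerate(index):
        lane_args = [arg[lane] for arg in args]   # forward this lane's arguments
        results.append(targets[target_id](*lane_args))
    return results


def double(a):        # hypothetical target 0
    return 2 * a


def negate(a):        # hypothetical target 1
    return -a


out = switch_sketch([0, 1, 1, 0], [double, negate], [10, 20, 30, 40])
```

Each element of the index array selects the function applied to the matching element of the arguments, so different elements of one array may take entirely different code paths.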
Pitfalls¶
Please be aware of the following potential issues involving tracing of control flow.
Unrolling loops. Consider a function f(x), which calls another expensive function g(x) many times in a loop.

@dr.syntax
def f(x):
    for i in range(1000):
        x = g(x)
    return x

This will likely not yield the expected behavior: first, Dr.Jit’s @dr.syntax decorator ignores for loops and only considers while loops. Furthermore, it only processes loops with array-valued loop stopping conditions, which is not the case here. Therefore, this function actually unrolls the computation graph of g 1000 times and is equivalent to

def f(x):
    x = g(x)
    x = g(x)
    # .. (998 repetitions) ..
    return x

Compiling the resulting giant kernel can be very inefficient. Instead, consider rewriting the function as follows so that the loop can be traced:

import drjit as dr
from drjit.auto import Int

@dr.syntax
def f(x):
    i = Int(0)
    while i < 1000:
        x = g(x)
        i += 1
    return x
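The difference in graph size can be illustrated with a toy recorder. This is a rough sketch under simplifying assumptions, not Dr.Jit's graph representation: TracedSketch is a hypothetical value type that appends one entry to a list per traced operation, and a traced loop is modeled as a begin marker, one copy of the body, and an end marker.

```python
class TracedSketch:
    """Hypothetical traced value that logs every operation into a shared list."""

    def __init__(self, graph):
        self.graph = graph

    def apply(self, name):
        self.graph.append(name)           # one graph node per traced operation
        return TracedSketch(self.graph)


def g(x):
    return x.apply("g")


def f_unrolled(x):
    for _ in range(1000):                 # plain Python loop: runs while tracing
        x = g(x)
    return x


def f_looped(x):
    # A traced loop records its body once, bracketed by loop markers
    x.graph.append("loop_begin")
    x = g(x)
    x.graph.append("loop_end")
    return x


unrolled_graph, looped_graph = [], []
f_unrolled(TracedSketch(unrolled_graph))
f_looped(TracedSketch(looped_graph))
```

The unrolled variant records 1000 copies of g, whereas the looped variant records the body a single time regardless of the iteration count.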
Type constancy. Tracing control flow requires the type of state variables to remain consistent. For example, the following fails with an error message because the body of the if statement changes x from drjit.*.Int (a traced Dr.Jit type) to a lowercase int (a built-in Python type).

@dr.syntax
def f(x: Int):
    if x < 0:
        x = 0
    # ...

The problem is easily fixed by casting the assigned value to the expected type:

@dr.syntax
def f(x: Int):
    if x < 0:
        x = Int(0)
    # ...
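The kind of consistency check involved can be sketched in plain Python. Everything here is hypothetical (check_state_types and IntArraySketch are invented for the illustration, and Dr.Jit's actual checks and error messages differ); the point is only that the state before and after a branch is compared type by type.

```python
def check_state_types(before, after, names):
    """Raise if a state variable changes type across a conditional or loop body."""
    for name, old, new in zip(names, before, after):
        if type(old) is not type(new):
            raise TypeError(
                f"state variable '{name}' changed type from "
                f"{type(old).__name__} to {type(new).__name__}")


class IntArraySketch(list):
    """Hypothetical stand-in for an integer array type."""


x = IntArraySketch([-1, 2])

try:
    check_state_types((x,), (0,), ("x",))                  # models 'x = 0'
    bad = None
except TypeError as exc:
    bad = str(exc)

check_state_types((x,), (IntArraySketch([0]),), ("x",))    # models 'x = Int(0)'
```

Assigning a value of the original array type passes the check, while the plain integer is rejected.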
Traversal of nested objects. The @dr.syntax decorator transforms loops and conditionals into calls to dr.while_loop() and dr.if_stmt(). This involves traversing local variables to detect potential changes during the loop or conditional statement. In the Accum.add_positive() example function below, both x and self are automatically identified as such local variables.

import drjit as dr
from drjit.auto import Int

class Accum:
    def __init__(self):
        """Create a zero-initialized accumulator"""
        self.value = Int(0)

    @dr.syntax
    def add_positive(self, x: Int):
        """Accumulate 'x', but only if it is positive"""
        if x > 0:
            self.value += x

a = Accum()
a.add_positive(Int(1, -1))
print(a.value) # Prints: [1, -1] :-(
Unfortunately, there is a subtle bug in the above code: symbolic control flow operations only traverse PyTrees, and self (which is of type Accum) is not a PyTree. The implementation therefore misses the conditional nature of the change of self.value and produces the incorrect output [1, -1] instead of the expected [1, 0].
So what is a PyTree? Besides Dr.Jit arrays, they can consist of arbitrarily nested Python containers (list, tuple, dict), data classes, and custom classes with a DRJIT_STRUCT annotation. To fix the problem, we can, e.g., add a DRJIT_STRUCT annotation to Accum to explain its sub-elements:

class Accum:
    DRJIT_STRUCT = { 'value' : Int }
Alternatively, we can switch the implementation of Accum to a data class:

from dataclasses import dataclass

@dataclass
class Accum:
    value: Int = Int(0)

    @dr.syntax
    def add_positive(self, x: Int):
        ...
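Why a plain class is invisible to the traversal can be illustrated with a small stand-alone walker. This is a simplified sketch, not Dr.Jit's PyTree code: collect_leaves and the three example classes are hypothetical, and Python numbers stand in for array leaves.

```python
from dataclasses import dataclass, fields, is_dataclass


def collect_leaves(value, path="root"):
    """Return paths of leaves reachable through containers this sketch understands."""
    if isinstance(value, (list, tuple)):
        return [p for k, v in enumerate(value)
                for p in collect_leaves(v, f"{path}[{k}]")]
    if isinstance(value, dict):
        return [p for k, v in value.items()
                for p in collect_leaves(v, f"{path}[{k!r}]")]
    if is_dataclass(value) and not isinstance(value, type):
        return [p for f in fields(value)
                for p in collect_leaves(getattr(value, f.name), f"{path}.{f.name}")]
    if isinstance(getattr(type(value), "DRJIT_STRUCT", None), dict):
        return [p for name in type(value).DRJIT_STRUCT
                for p in collect_leaves(getattr(value, name), f"{path}.{name}")]
    if isinstance(value, (int, float)):
        return [path]                     # numbers stand in for array leaves here
    return []                             # anything else is opaque: not traversed


class Opaque:                             # plain class: its fields are invisible
    def __init__(self):
        self.value = 0


class Annotated:                          # fields declared via DRJIT_STRUCT
    DRJIT_STRUCT = {"value": int}

    def __init__(self):
        self.value = 0


@dataclass
class Record:
    value: int = 0


opaque_leaves = collect_leaves(Opaque())
annotated_leaves = collect_leaves(Annotated())
record_leaves = collect_leaves(Record())
nested_leaves = collect_leaves({"a": [1, 2], "b": (3,)})
```

The walker finds no leaves in the plain class, so a change to its value attribute would go unnoticed, while the annotated class and the data class both expose the attribute as a tracked leaf.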