Control flow

Dr.Jit can trace control flow statements like loops, conditionals, and indirect jumps if they are expressed in a compatible manner.

First, let’s see what can go wrong when doing this naively. The Python snippet below is meant to compute the population count (i.e., the number of bits set to 1) per element of an integer sequence:

import drjit as dr
from drjit.auto import Int

def popcnt(i: Int):
    '''Count the number of active bits in ``i``'''
    j = Int(0)
    while i != 0:  # While there are remaining active bits
        j += i & 1 # Increment counter 'j' if current bit active
        i = i // 2 # Shift bits of 'i' to the right
    return j

print(popcnt(dr.arange(Int, 1024)))

However, running it fails with an error message:

Traceback (most recent call last):
  File "popcnt.py", line 12, in <module>
    print(popcnt(dr.arange(Int, 1024)))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "popcnt.py", line 7, in popcnt
    while i != 0:
          ^^^^^^
RuntimeError: drjit.llvm.Bool.__bool__(): implicit conversion to 'bool' ↵
requires an array with at most 1 element (this one has 1024 elements).

Why does this happen? In the example above, i is an array of integers, hence i != 0 produces a Boolean array of component-wise comparisons.

If these values aren’t all identical, it implies that elements of i require different numbers of loop iterations. Unfortunately, this is something that regular Python simply does not support. Dr.Jit raises the alarm when it notices the user’s attempt to interpret the condition i != 0 as a Python bool.
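The mismatch is easy to see without Dr.Jit. The plain-Python sketch below (the helper `iterations_needed` is hypothetical and exists only for illustration) counts how often the loop body would have to run for each element if it were processed on its own:

```python
def iterations_needed(n: int) -> int:
    """Number of times the body of 'while n != 0: n //= 2' runs for one integer."""
    count = 0
    while n != 0:
        n //= 2
        count += 1
    return count

values = list(range(8))
per_element = [iterations_needed(v) for v in values]
print(per_element)  # [0, 1, 2, 2, 3, 3, 3, 3]
```

Since the counts differ, no single Python `bool` can answer the question "should the loop continue?" for the whole array at once.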

Annotating the function with the @dr.syntax decorator fixes this problem:

@dr.syntax
def popcnt(i: Int):
    ...

The script now terminates and prints the following correct output:

[0, 1, 1, .. 1018 skipped .., 9, 9, 10]

How does this work?

The @dr.syntax decorator provides syntax sugar: it takes a function as input and returns a slightly modified version of it. Most code passes through unchanged, with two exceptions: the decorator rewrites the declarations of loops and conditionals to make them compatible with tracing.

In the example above, it does this by

  1. identifying the variables modified by the loop (i and j)

  2. encapsulating the loop condition and body in separate functions, and

  3. performing a call to dr.while_loop() with all of this information.

This produces code equivalent to:

def popcnt(i: Int):
    j = Int(0)
    i, j = dr.while_loop(
        state=(i, j),
        cond=lambda i, j: i != 0,
        body=lambda i, j: (i // 2, j + (i & 1))
    )
    return j

The function dr.while_loop() generalizes the built-in Python while loop: when the condition is a Python bool, it doesn’t do anything special and just reproduces the normal behavior. When the loop condition is an array, it runs the loop separately for each element, potentially for different numbers of iterations.
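The semantics described above can be modeled in a few lines of plain Python. This is a toy sketch operating on Python lists, not Dr.Jit's implementation: `toy_while_loop` is a hypothetical stand-in, and the per-element masking strategy it uses is an assumption chosen to reproduce the observable behavior.

```python
def toy_while_loop(state, cond, body):
    """Toy model of a loop whose condition is a bool or a list of bools."""
    while True:
        active = cond(*state)
        if isinstance(active, bool):
            # Scalar condition: behave exactly like an ordinary Python loop
            if not active:
                return state
            state = body(*state)
            continue
        if not any(active):
            return state
        proposed = body(*state)
        # Commit the update only for elements whose condition still holds
        state = tuple(
            [new if keep else old for old, new, keep in zip(olds, news, active)]
            for olds, news in zip(state, proposed)
        )

def popcnt(values):
    i = list(values)
    j = [0] * len(i)
    i, j = toy_while_loop(
        state=(i, j),
        cond=lambda i, j: [x != 0 for x in i],
        body=lambda i, j: ([x // 2 for x in i],
                           [y + (x & 1) for x, y in zip(i, j)]),
    )
    return j

result = popcnt(range(1024))
print(result[:3], result[-3:])  # [0, 1, 1] [9, 9, 10]

# A scalar condition degenerates to a regular loop
scalar = toy_while_loop((5, 0), lambda i, j: i != 0,
                        lambda i, j: (i // 2, j + (i & 1)))
print(scalar)  # (0, 2)
```

Elements that have finished keep their old state while the remaining ones continue, which is how different elements can effectively run for different numbers of iterations.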

In the same manner, if statements are turned into calls to dr.if_stmt(), which serves the same purpose for conditionals.
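A conditional with an array-valued condition can be modeled in the same spirit. The sketch below is a toy in plain Python: `toy_if_stmt` and its argument names are chosen for illustration and make no claim about how Dr.Jit implements conditionals internally.

```python
def toy_if_stmt(args, cond, true_fn, false_fn):
    """Toy model: evaluate both branches, then select per element via 'cond'."""
    if isinstance(cond, bool):
        # Scalar condition: behave like an ordinary Python 'if'
        return true_fn(*args) if cond else false_fn(*args)
    if_true = true_fn(*args)
    if_false = false_fn(*args)
    return [t if c else f for t, f, c in zip(if_true, if_false, cond)]

x = [3, -1, 0, -7]

# Models:  if x < 0: x = 0   (the implicit 'else' leaves x unchanged)
clamped = toy_if_stmt(
    args=(x,),
    cond=[v < 0 for v in x],
    true_fn=lambda x: [0 for _ in x],
    false_fn=lambda x: x,
)
print(clamped)  # [3, 0, 0, 0]
```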

The main feature of @dr.syntax is to free users from having to perform this transformation themselves.

Symbolic mode

The default way in which Dr.Jit handles control flow is called symbolic mode, which has certain limitations. Let’s make a small change to the code from before to illustrate one of them.

@dr.syntax
def popcnt(i: Int):
    '''Count the number of active bits'''
    j = Int(0)
    while i != 0:
        print(f"{i=}")
        j += i & 1
        i = i // 2
    return j

(The added print() statement is meant to show the state of variables at intermediate steps.)

Running this modified code produces a long error message:

Traceback (most recent call last):
  File "popcnt.py", line 9, in _loop_body
    print(f"{i=}")

RuntimeError: You performed an operation that tried to evaluate a *symbolic*↵
variable, which is not permitted.

[lots of explanation text omitted here]

The message explains that i and j are considered symbolic while inside the loop. Certain operations are not allowed in this context, and printing their contents is one of them.

To understand why this is forbidden, recall that Dr.Jit embraces the idea of tracing, i.e., postponing computation for later evaluation. In the case of popcnt(), this means that Dr.Jit will execute the loop body only once to understand how it modifies the variables i and j, but without doing any actual computation. Even the number of loop iterations is unknown at this point. All of these details are postponed to when the traced computation actually runs on the target device (e.g., the GPU).

The implication of this design is that i and j are symbols that don’t have explicit values within the loop body, which is why the print() operation failed.
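A toy tracer makes this tangible. In the plain-Python sketch below, the hypothetical class `Sym` records operations as text instead of computing anything. It is a deliberately simplified illustration of tracing and not how Dr.Jit represents symbolic variables.

```python
class Sym:
    """Toy symbolic value: remembers *how* it was computed, not *what* it is."""
    def __init__(self, expr: str):
        self.expr = expr

    @staticmethod
    def _name(v):
        return v.expr if isinstance(v, Sym) else repr(v)

    def __and__(self, other):
        return Sym(f"({self.expr} & {Sym._name(other)})")

    def __floordiv__(self, other):
        return Sym(f"({self.expr} // {Sym._name(other)})")

    def __add__(self, other):
        return Sym(f"({self.expr} + {Sym._name(other)})")

    def __repr__(self):
        # There is no concrete value to show while tracing
        raise RuntimeError("symbolic variable has no value yet")

def body(i, j):
    j = j + (i & 1)
    i = i // 2
    return i, j

# The body runs exactly once, on placeholders rather than on data
new_i, new_j = body(Sym("i"), Sym("j"))
print(new_i.expr)  # (i // 2)
print(new_j.expr)  # (j + (i & 1))

try:
    print(f"{new_i=}")
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True
```

The single call to `body` is enough to learn how one iteration transforms `i` and `j`, but any attempt to look at a value fails because none has been computed.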

This way of capturing control flow is the default behavior of Dr.Jit and is called symbolic mode. Dr.Jit also supports a second approach called evaluated mode that we will examine next.

Evaluated mode

The inability to access the contents of symbolic variables can be inconvenient. We might need to print or plot intermediate steps, or to step through a program using a visual debugger.

To make this possible, let’s switch the loop to evaluated mode. This can be done for an individual statement by wrapping the loop condition in dr.hint(..., mode='evaluated'):

@dr.syntax
def popcnt(i: Int):
    '''Count the number of active bits'''
    j = Int(0)
    while dr.hint(i != 0, mode='evaluated'):
        print(f"{i=}")
        j += i & 1
        i = i // 2
    return j

print(popcnt(dr.arange(Int, 1024)))

With this change, Dr.Jit now executes all loop iterations explicitly. Accessing the contents of i also works without problems, and the script produces the following output:

i=[0, 1, 2, .. 1018 skipped .., 1021, 1022, 1023]
i=[0, 0, 1, .. 1018 skipped .., 510, 511, 511]
i=[0, 0, 0, .. 1018 skipped .., 255, 255, 255]
i=[0, 0, 0, .. 1018 skipped .., 127, 127, 127]
i=[0, 0, 0, .. 1018 skipped .., 63, 63, 63]
i=[0, 0, 0, .. 1018 skipped .., 31, 31, 31]
i=[0, 0, 0, .. 1018 skipped .., 15, 15, 15]
i=[0, 0, 0, .. 1018 skipped .., 7, 7, 7]
i=[0, 0, 0, .. 1018 skipped .., 3, 3, 3]
i=[0, 0, 0, .. 1018 skipped .., 1, 1, 1]
[0, 1, 1, .. 1018 skipped .., 9, 9, 10]

Evaluated mode can also be enabled globally by disabling the flags dr.JitFlag.SymbolicLoops and dr.JitFlag.SymbolicConditionals via dr.set_flag() or dr.scoped_set_flag().

Discussion

Let’s take a step back and compare the properties of these two different modes.

Evaluated mode

As the name suggests, this mode evaluates loop variables to store them in memory. Each loop iteration then loads variable state and writes out new state at the end. The host (i.e., the CPU) is in charge of all control flow, which makes this mode simple to understand:

  • Debugging programs is straightforward. The user can step through the program line by line and examine variable contents via Python’s built-in print() function, or use graphical plotting tools to construct visualizations from within loops, conditionals, and calls (tracing calls is described at the bottom of this section).

  • The program can freely mix Dr.Jit computation with other array programming frameworks like PyTorch, TensorFlow, JAX, etc.

The main disadvantage of evaluated mode is the overhead of constantly reading from and writing to device memory. The resulting memory bandwidth and storage costs can be prohibitive.

Symbolic mode

Symbolic mode moves the control flow onto the target device. This is a natural choice: Dr.Jit already traces computation to generate fused kernels, and this simply extends that idea to include control flow as well. For this, Dr.Jit must trace loops that run for an unknown number of iterations, which it does by introducing symbolic variables to capture the change from one iteration to the next. Symbolic variables represent unknown information that will only become available later when the generated code runs on the device.

The advantage of symbolic mode is that it can keep variable state in fast CPU/GPU registers, which improves performance and reduces storage costs.

The main disadvantage is that symbolic variables cannot be evaluated while tracing. Likewise, they cannot be passed to other frameworks like PyTorch or TensorFlow. Indeed, any attempt to reveal the content of symbolic variables is doomed to fail since it literally does not exist (yet). The upcoming section on variable evaluation clarifies what operations require evaluation. Symbolic mode is the default, since the performance benefits usually outweigh these disadvantages.

Note

Here are a few more detailed notes about symbolic and evaluated loops for advanced users. Feel free to skip these if you are new to Dr.Jit.

  • Loops (drjit.while_loop()), conditionals (drjit.if_stmt()), and dynamic dispatch (drjit.switch(), drjit.dispatch()) may be arbitrarily nested. However, it is not legal to nest evaluated operations within symbolic ones, as this would require the evaluation of symbolic variables.

  • Printing array contents with Python’s built-in print() is not permitted in symbolic mode. To avoid this problem, Dr.Jit provides a symbolic print statement, dr.print(), that prints in a delayed manner (i.e., asynchronously from the device).

  • Symbolic mode tends to create much larger kernels. Indeed, the idea is to preserve the entire program and generate one giant output kernel (a megakernel). Such large kernels can be costly to compile, though this cost is usually offset by kernel caching discussed in the next section.

  • Large kernels produced by symbolic mode also tend to use a large number of registers, and this may impede the latency-hiding capabilities of GPUs. Similarly, Dr.Jit always vectorizes computation (SIMD-style). Divergence in highly branching code produced by symbolic tracing may reduce performance.

Indirect calls

Dr.Jit provides the functions dr.switch() and dr.dispatch() to capture indirect function calls, i.e., calls that can resolve to one of several possible targets. Here is an example:

# A sequence of functions with the same argument and return value signature
def f1(a, b, c):
   # ...
   return x, y

def f2(a, b, c):
   # ...
   return x, y

x, y = dr.switch(
   index,      # <-- choose based on the integer array 'index' (indices must be < 2 in this example)
   [f1, f2],   # <-- call functions from the provided list ('f1' or 'f2')
   a, b, c     # <-- function parameters to forward to 'f1' and 'f2'
)

The reference of dr.switch() and dr.dispatch() explains these two operations in full detail. As with the previous control flow operations, they support compilation in either symbolic or evaluated modes.
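To make the semantics concrete, here is a toy model in plain Python. The function `toy_switch` and the bodies of `f1` and `f2` are invented for illustration. The sketch processes one element at a time for clarity, which is an assumption of this example and not a description of how Dr.Jit schedules the calls.

```python
def toy_switch(index, targets, *args):
    """Toy indirect call: element k is handled by targets[index[k]]."""
    results = []
    for k, target_id in enumerate(index):
        element_args = [a[k] for a in args]
        results.append(targets[target_id](*element_args))
    # Transpose the per-element result tuples into one list per return value
    return tuple(list(column) for column in zip(*results))

def f1(a, b, c):
    return a + b, c

def f2(a, b, c):
    return a * b, -c

index = [0, 1, 1, 0]
a, b, c = [1, 2, 3, 4], [10, 20, 30, 40], [5, 6, 7, 8]

x, y = toy_switch(index, [f1, f2], a, b, c)
print(x)  # [11, 40, 90, 44]
print(y)  # [5, -6, -7, 8]
```

Every element of the result comes from whichever function its entry in `index` selected, while the caller sees a single call with array-valued inputs and outputs.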

Pitfalls

Please be aware of the following potential issues involving tracing of control flow.

  1. Unrolling loops. Consider a function f(x), which calls another expensive function g(x) many times in a loop.

    @dr.syntax
    def f(x):
        for i in range(1000):
            x = g(x)
        return x
    

    This will likely not yield the expected behavior: first, Dr.Jit’s @dr.syntax decorator ignores for loops and only considers while loops. Furthermore, it only processes loops with array-valued loop stopping conditions, which is not the case here. Therefore, this function actually unrolls the computation graph of g 1000 times and is equivalent to

    def f(x):
        x = g(x)
        x = g(x)
        # .. (998 repetitions) ..
        return x
    

    Compiling the resulting giant kernel can be very inefficient. Instead, consider rewriting the function as follows so that the loop can be traced:

    from drjit.auto import Int
    
    @dr.syntax
    def f(x):
        i = Int(0)
        while i < 1000:
            x = g(x)
            i += 1
        return x
    
  2. Type constancy. Tracing control flow requires the type of state variables to remain consistent. For example, the following fails with an error message because the body of the if statement changes x from drjit.*.Int (a traced Dr.Jit type) to a lower case int (a built-in Python type).

    @dr.syntax
    def f(x: Int):
        if x < 0:
            x = 0
        # ...
    

    The problem is easily fixed by casting the assigned value to the expected type:

    @dr.syntax
    def f(x: Int):
        if x < 0:
            x = Int(0)
        # ...
    
  3. Traversal of nested objects. The @dr.syntax decorator transforms loops and conditionals into calls to dr.while_loop() and dr.if_stmt().

    This involves traversing local variables to detect potential changes during the loop or conditional statement. In the Accum.add_positive() example function below, both x and self are automatically identified as such local variables.

    from drjit.auto import Int
    
    class Accum:
        def __init__(self):
            """Create a zero-initialized accumulator"""
            self.value = Int(0)
    
        @dr.syntax
        def add_positive(self, x: Int):
            """Accumulate 'x', but only if it is positive"""
            if x > 0:
                self.value += x
    
    a = Accum()
    a.add_positive(Int(1, -1))
    print(a.value) # Prints: [1, -1]    :-(
    

    Unfortunately, there is a subtle bug in the above code: symbolic control flow operations only traverse PyTrees, and self (which is of type Accum) is not a PyTree. The implementation therefore misses the conditional nature of the change of self.value and produces the incorrect output [1, -1] instead of the expected [1, 0].

    So what is a PyTree? Besides Dr.Jit arrays, they can consist of arbitrarily nested Python containers (list, tuple, dict), data classes, and custom classes with a DRJIT_STRUCT annotation. To fix the problem, we can, e.g., add a DRJIT_STRUCT annotation to Accum to explain its sub-elements:

    class Accum:
        DRJIT_STRUCT = { 'value' : Int }
    

    Alternatively, we can switch the implementation of Accum to a data class:

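    from dataclasses import dataclass, field
    
    @dataclass
    class Accum:
        # default_factory gives each instance its own zero-initialized array
        value: Int = field(default_factory=lambda: Int(0))
    
        @dr.syntax
        def add_positive(self, x: Int):
            ...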
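The difference between the three variants of Accum in the last pitfall can be illustrated with a toy traversal in plain Python. The function `leaves` is a hypothetical stand-in for the kind of PyTree traversal described above, plain `int` values stand in for Dr.Jit arrays, and the set of containers it recognizes is taken from the list given in the text. It is a sketch, not Dr.Jit's traversal code.

```python
from dataclasses import dataclass, fields, is_dataclass

def leaves(obj, path="root"):
    """Toy traversal: yield (path, value) for every leaf reachable through
    containers that this sketch treats as PyTree nodes."""
    if isinstance(obj, (list, tuple)):
        for k, item in enumerate(obj):
            yield from leaves(item, f"{path}[{k}]")
    elif isinstance(obj, dict):
        for key, item in obj.items():
            yield from leaves(item, f"{path}[{key!r}]")
    elif is_dataclass(obj) and not isinstance(obj, type):
        for f in fields(obj):
            yield from leaves(getattr(obj, f.name), f"{path}.{f.name}")
    elif hasattr(type(obj), "DRJIT_STRUCT"):
        for name in type(obj).DRJIT_STRUCT:
            yield from leaves(getattr(obj, name), f"{path}.{name}")
    else:
        yield path, obj  # opaque: the traversal cannot look inside

class Opaque:                      # like the first, buggy Accum
    def __init__(self):
        self.value = 0

class Annotated:                   # like Accum with DRJIT_STRUCT
    DRJIT_STRUCT = {"value": int}
    def __init__(self):
        self.value = 0

@dataclass
class Data:                        # like the data class variant
    value: int = 0

opaque_paths = [p for p, _ in leaves(Opaque())]
annotated_paths = [p for p, _ in leaves(Annotated())]
data_paths = [p for p, _ in leaves(Data())]
print(opaque_paths)     # ['root']
print(annotated_paths)  # ['root.value']
print(data_paths)       # ['root.value']
```

For the opaque class the traversal stops at the object itself and never reaches `value`, which mirrors why the conditional update of self.value goes unnoticed in the buggy version.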