
Arolla Dynamic Evaluation Details

Alexander G. Pronchenkov edited this page May 16, 2024 · 2 revisions


This document describes how Arolla Dynamic Evaluation works. The intended audience is users who are interested in a deeper understanding of the system, as well as Arolla team members.

Memory management

arolla::FrameLayout describes a data structure that occupies a contiguous piece of memory. It is similar to a C++ struct definition, but is created at runtime rather than at compile time. Each field in the FrameLayout is described by an arolla::FrameLayout::Slot<T>, which is essentially an offset from the beginning of the frame. (There is also a type-erased version of Slot<T> called arolla::FrameLayout::TypedSlot, essentially a pair of an offset and an arolla::QType.) FrameLayouts can be nested, similarly to C++ structs nested into each other.
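
To make the struct analogy concrete, here is a minimal Python sketch of a layout that is built at runtime. It is not Arolla's API (which is C++): `Slot`, `FrameLayout`, and `FrameLayoutBuilder` are hypothetical names, fields are typed by `struct` format codes instead of QTypes, nesting is omitted, and every scalar field is assumed to be naturally aligned.

```python
import struct
from dataclasses import dataclass


@dataclass(frozen=True)
class Slot:
    """A field of a frame: an offset from the frame start plus a type code."""
    offset: int
    fmt: str  # struct format code: 'f' = 32-bit float, 'd' = 64-bit float

    def get(self, frame: bytearray):
        return struct.unpack_from('<' + self.fmt, frame, self.offset)[0]

    def set(self, frame: bytearray, value) -> None:
        struct.pack_into('<' + self.fmt, frame, self.offset, value)


@dataclass(frozen=True)
class FrameLayout:
    size: int

    def allocate_frame(self) -> bytearray:
        # A frame is a single contiguous, zero-filled piece of memory.
        return bytearray(self.size)


def _round_up(value: int, alignment: int) -> int:
    return -(-value // alignment) * alignment


class FrameLayoutBuilder:
    """Assigns field offsets at runtime, like a C++ compiler does for a struct."""

    def __init__(self) -> None:
        self._size = 0
        self._alignment = 1

    def add_slot(self, fmt: str) -> Slot:
        size = struct.calcsize('<' + fmt)
        alignment = size  # assumption: scalar fields are naturally aligned
        offset = _round_up(self._size, alignment)
        self._size = offset + size
        self._alignment = max(self._alignment, alignment)
        return Slot(offset, fmt)

    def build(self) -> FrameLayout:
        # Pad the total size to the strictest field alignment, like sizeof().
        return FrameLayout(_round_up(self._size, self._alignment))


builder = FrameLayoutBuilder()
x = builder.add_slot('f')    # FLOAT32 at offset 0x00
out = builder.add_slot('d')  # FLOAT64 at offset 0x08 (4 bytes of padding)
layout = builder.build()     # size 16
frame = layout.allocate_frame()
x.set(frame, 1.5)
```

The builder plays the role of the struct definition and each `bytearray` returned by `allocate_frame` plays the role of one instance.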

Once we have a FrameLayout, we can allocate one or several "frames" corresponding to it (see arolla::FramePtr and arolla::MemoryAllocation). The relation between a FrameLayout and a FramePtr is similar to the relation between a struct Foo definition and a Foo* pointer to an instance.

The Arolla Dynamic Evaluator creates a single FrameLayout for all the inputs, outputs, literal values, and temporaries during Expr compilation. Depending on the settings, it may allocate a new frame for each evaluation, or keep a cache of pre-initialized frames and reuse them.
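
The cache of pre-initialized frames can be pictured as a simple pool. This is an illustration under stated assumptions, not Arolla's implementation: `FramePool` is a hypothetical class, a frame is a plain `bytearray`, `init_fn` stands in for the init operations, and reuse is only safe if evaluation never overwrites the slots that the init operations wrote.

```python
from typing import Callable, List


class FramePool:
    """Hypothetical cache of frames whose literals are already initialized."""

    def __init__(self, frame_size: int, init_fn: Callable[[bytearray], None]):
        self._frame_size = frame_size
        self._init_fn = init_fn
        self._free: List[bytearray] = []
        self.initializations = 0  # how many fresh frames were initialized

    def acquire(self) -> bytearray:
        if self._free:
            # A recycled frame still holds its literal values.
            return self._free.pop()
        frame = bytearray(self._frame_size)
        self._init_fn(frame)  # run the init operations once per new frame
        self.initializations += 1
        return frame

    def release(self, frame: bytearray) -> None:
        # Assumption: eval operations never write to literal slots,
        # so the frame can be reused without re-running init_fn.
        self._free.append(frame)
```

The trade-off is memory for speed: idle frames stay allocated, but repeated evaluations skip both allocation and literal initialization.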

[Figure: Example of FrameLayout and several frames]

Evaluation

During compilation, the Arolla Dynamic Evaluator creates a BoundExpr: two linear sequences of instructions (aka arolla::BoundOperators), one to initialize a new frame and one to evaluate the expression. Each BoundOperator stores its input and output slots (i.e. offsets from the beginning of the frame) and is evaluated by passing it a FramePtr.

Here is an example of a compiled expression evaluating math.add(core.to_float64(L.x), arolla.float64(1)):

Inputs:

x: FLOAT32 [0x00]

Output:

FLOAT64 [0x08]

Init operations:

FLOAT64 [0x18] = float64{1}

Eval operations:

FLOAT64 [0x10] = core.to_float64(FLOAT32 [0x00])
FLOAT64 [0x08] = math.add(FLOAT64 [0x10], FLOAT64 [0x18])

[Figure: Example of compiled expression]
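
The listing above can be mirrored by a toy Python model in which bound operators are closures over slot offsets and a frame is a `bytearray`. `BoundExpr`, `literal_op`, `to_float64_op`, `add_op`, `load`, and `store` are hypothetical names chosen for illustration; the real BoundOperators are C++ objects.

```python
import struct
from typing import Callable, List

Frame = bytearray
BoundOperator = Callable[[Frame], None]


def load(fmt: str, frame: Frame, offset: int) -> float:
    return struct.unpack_from('<' + fmt, frame, offset)[0]


def store(fmt: str, frame: Frame, offset: int, value: float) -> None:
    struct.pack_into('<' + fmt, frame, offset, value)


def literal_op(offset: int, value: float) -> BoundOperator:
    # Init operation: FLOAT64 [offset] = float64{value}
    return lambda frame: store('d', frame, offset, value)


def to_float64_op(out: int, inp: int) -> BoundOperator:
    # FLOAT64 [out] = core.to_float64(FLOAT32 [inp])
    return lambda frame: store('d', frame, out, load('f', frame, inp))


def add_op(out: int, lhs: int, rhs: int) -> BoundOperator:
    # FLOAT64 [out] = math.add(FLOAT64 [lhs], FLOAT64 [rhs])
    return lambda frame: store(
        'd', frame, out, load('d', frame, lhs) + load('d', frame, rhs))


class BoundExpr:
    """Two linear instruction sequences that only know slot offsets."""

    def __init__(self, init_ops: List[BoundOperator],
                 eval_ops: List[BoundOperator]):
        self.init_ops = init_ops
        self.eval_ops = eval_ops

    def initialize_literals(self, frame: Frame) -> None:
        for op in self.init_ops:
            op(frame)

    def execute(self, frame: Frame) -> None:
        for op in self.eval_ops:
            op(frame)


# math.add(core.to_float64(L.x), arolla.float64(1)) with the slots listed above.
bound = BoundExpr(
    init_ops=[literal_op(0x18, 1.0)],
    eval_ops=[to_float64_op(0x10, 0x00), add_op(0x08, 0x10, 0x18)],
)
```

To evaluate it, allocate a 0x20-byte frame, call `bound.initialize_literals(frame)`, store the FLOAT32 input at offset 0x00, call `bound.execute(frame)`, and read the FLOAT64 result from offset 0x08.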

End-to-end evaluation

The BoundExpr described above handles only the evaluation on a frame. Evaluating an expression end-to-end requires a few more steps:

  1. Allocate frame(s) and manage their lifetime. For each new frame BoundExpr::InitializeLiterals must be called.
  2. Load input fields from a user-provided data source into the frame (using InputLoader).
  3. Call BoundExpr::Execute to perform the actual evaluation.
  4. Optionally, copy side output fields into the user-provided data sink (using SlotListener).
  5. Read the output value from the frame and return it to the user.

All these steps are performed by the arolla::expr::ModelExecutor class. Note that this logic is shared between Dynamic Eval and Codegen.
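
The five steps can be illustrated with a toy executor. `ToyModelExecutor` is hypothetical and far simpler than arolla::expr::ModelExecutor: a frame is a Python list, slots are list indices, the "input loader" and "slot listener" are plain dictionaries of slot indices, and a fresh frame is allocated on every call.

```python
from typing import Any, Callable, Dict, List, Optional

Frame = List[Any]
Op = Callable[[Frame], None]


class ToyModelExecutor:
    """Runs the five end-to-end steps on a list-based frame (illustration only)."""

    def __init__(self, frame_size: int, init_ops: List[Op], eval_ops: List[Op],
                 input_slots: Dict[str, int], output_slot: int,
                 side_output_slots: Optional[Dict[str, int]] = None):
        self._frame_size = frame_size
        self._init_ops = init_ops
        self._eval_ops = eval_ops
        self._input_slots = input_slots
        self._output_slot = output_slot
        self._side_output_slots = side_output_slots or {}

    def execute(self, inputs: Dict[str, Any],
                side_outputs: Optional[Dict[str, Any]] = None) -> Any:
        # 1. Allocate a frame and initialize its literals.
        frame: Frame = [None] * self._frame_size
        for op in self._init_ops:
            op(frame)
        # 2. "Input loader": copy user-provided fields into their slots.
        for name, slot in self._input_slots.items():
            frame[slot] = inputs[name]
        # 3. The actual evaluation.
        for op in self._eval_ops:
            op(frame)
        # 4. "Slot listener" (optional): copy side outputs into the user's sink.
        if side_outputs is not None:
            for name, slot in self._side_output_slots.items():
                side_outputs[name] = frame[slot]
        # 5. Read the output value and return it.
        return frame[self._output_slot]


# Slots: 0 = input x, 1 = literal 1.0, 2 = x * 2 (side output), 3 = output.
def init_one(frame: Frame) -> None:
    frame[1] = 1.0


def double_x(frame: Frame) -> None:
    frame[2] = frame[0] * 2


def add_one(frame: Frame) -> None:
    frame[3] = frame[2] + frame[1]


executor = ToyModelExecutor(4, [init_one], [double_x, add_one],
                            input_slots={'x': 0}, output_slot=3,
                            side_output_slots={'doubled': 2})
sink: Dict[str, Any] = {}
result = executor.execute({'x': 5.0}, sink)  # 11.0; sink == {'doubled': 10.0}
```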

Compilation overview

The process of transforming an Expr into a BoundExpr (aka "compilation") consists of two major stages:

  1. Expr preparation.
  2. Transforming the prepared Expr into a sequence of BoundOperators.

Expr preparation

To simplify the logic, we try to stay in the "Expr" world as long as possible, i.e. to keep most of the compiler logic in the expr preparation stage. It consists of the following steps:

  1. If side outputs are requested, join the main Expr with all the requested side outputs using a single top-level InternalRootOperator.
  2. Annotate all the input leaves with their QTypes. For example, L.x becomes M.annotation.qtype(L.x, FLOAT32). Types are automatically propagated in expressions, so after this step the type of each node is known.
  3. Literal folding: if a subexpression does not depend on inputs, it gets replaced with its value. For example, 1 + 2 + L.x becomes 3 + L.x.
  4. Lowering all the operators. For example, M.core.cast(x, INT64) becomes either M.core.to_int64(x) or just x, depending on the type of x.
  5. Strip unnecessary annotations.
  6. Backend compatibility casting: if there is no backend (QExpr) operator with an exactly matching signature, the compiler can insert casting operators. For example, optional_int64_value + int32{1} becomes optional_int64_value + M.core.to_optional(M.core.to_int64(int32{1})) (and then optional_int64_value + optional_int64{1} after another round of literal folding).
  7. Apply optimizations. For example, M.core.get_second(M.core.make_tuple(x, y, z)) becomes just y.
  8. Apply custom preprocessing using compiler extensions (see the section below). For example, M.core.map(M.math.add, L.x, L.y) becomes packed_core_map[M.math.add](L.x, L.y), which hides the M.math.add literal from the expression and avoids allocating a useless slot for it.
  9. Apply heuristics to the whole expression to enable short-circuit evaluation. For example, M.core.where(condition, x + y**2, x - y**2) becomes something like internal_packed_where[$1 + $2, $1 - $2](condition, x, y**2).
  10. Extract type annotations into a separate map, stripping them from the expression.
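
Literal folding (step 3) can be sketched as a post-order rewrite over a toy expression tree. The tuple encoding and the `OPS` table are assumptions made for this illustration; Arolla itself rewrites ExprNode objects.

```python
import operator

# Toy encoding (an assumption for this sketch):
#   ('leaf', name) | ('literal', value) | (op_name, arg, ...)
OPS = {'add': operator.add, 'subtract': operator.sub, 'multiply': operator.mul}


def fold_literals(expr: tuple) -> tuple:
    """Replaces every input-independent subexpression with its value."""
    kind = expr[0]
    if kind in ('leaf', 'literal'):
        return expr
    # Post-order: fold the arguments first.
    args = tuple(fold_literals(arg) for arg in expr[1:])
    if all(arg[0] == 'literal' for arg in args):
        # No argument depends on inputs, so evaluate the node now.
        return ('literal', OPS[kind](*(arg[1] for arg in args)))
    return (kind, *args)


# 1 + 2 + L.x parses as (1 + 2) + L.x.
expr = ('add', ('add', ('literal', 1), ('literal', 2)), ('leaf', 'x'))
folded = fold_literals(expr)  # ('add', ('literal', 3), ('leaf', 'x'))
```

Because only input-independent subexpressions are folded, L.x + 1 + 2, which parses as (L.x + 1) + 2, is left unchanged by this sketch.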

Prepared expr compilation

We expect the following invariants to hold for the prepared expression:

  1. All the node types are known and stored in a separate map.
  2. Expression does not contain any placeholders.
  3. Expression does not contain any annotations.
  4. All the operators are at the lowest level, i.e. they have either BackendExprOperatorTag or BuiltinExprOperatorTag.
  5. If any of the side outputs is requested, the root operator is InternalRootOperator, with the first argument being the main output and the rest being the side outputs in the specific order.

Once the expression is prepared, we perform a single post-order traversal, collecting a set of literals to initialize and building the sequence of BoundOperators that will evaluate the expression:

  • For most of the operators (with BackendExprOperatorTag), the compiler just searches for a QExpr operator with the same name and input / output QTypes and calls ::Bind on it.
  • Some special cases (with BuiltinExprOperatorTag) are handled separately. For example, for the short-circuit where operator, we compile each branch on its own and insert jump operators between them. Some of these special cases are handled using compiler extensions.
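
Short-circuit evaluation with jump operators can be sketched as follows. The instruction tuples, `run`, and `compile_where` are hypothetical; the condition is tested for Python truthiness rather than for the presence of an OPTIONAL_UNIT value, and `jump<+k>` is read as "skip the next k instructions", which is consistent with the offsets in the larger example at the end of this page.

```python
from typing import Any, Dict, List, Tuple

Frame = Dict[str, Any]  # simplification: slots are named entries in a dict


def run(program: List[Tuple], frame: Frame) -> None:
    """Executes a linear program that may contain relative jumps."""
    pc = 0
    while pc < len(program):
        kind, *rest = program[pc]
        if kind == 'jump_if_not':
            skip, cond_slot = rest
            # Fall through into the "true" branch, or skip over it.
            pc += 1 if frame[cond_slot] else 1 + skip
        elif kind == 'jump':
            (skip,) = rest
            pc += 1 + skip
        else:  # ('op', out_slot, fn, *arg_slots)
            out, fn, *arg_slots = rest
            frame[out] = fn(*(frame[slot] for slot in arg_slots))
            pc += 1


def compile_where(cond_slot: str, true_ops: List[Tuple],
                  false_ops: List[Tuple]) -> List[Tuple]:
    """Lays out both branches linearly, so only one of them is executed."""
    return ([('jump_if_not', len(true_ops) + 1, cond_slot)]  # +1 skips the jump
            + true_ops
            + [('jump', len(false_ops))]
            + false_ops)


# where(cond, x / y, 0.0): the division is skipped when cond is false.
safe_divide = compile_where(
    'cond',
    [('op', 'out', lambda a, b: a / b, 'x', 'y')],
    [('op', 'out', lambda: 0.0)],
)
frame = {'x': 1.0, 'y': 0.0, 'cond': False}
run(safe_divide, frame)  # no ZeroDivisionError; frame['out'] == 0.0
```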

We may add slots for intermediate values to the FrameLayout as we go. Once a slot is no longer needed for further evaluation, it may be reused for another intermediate value.
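
Slot reuse during the post-order traversal can be sketched with a free list. This allocator is illustrative, not Arolla's: `SlotAllocator` and `compile_expr` are hypothetical, slots are untyped integers (a real allocator would presumably only reuse a slot for a value of the same QType), and the expression is assumed to be a tree, so each intermediate value has exactly one consumer. The output slot is allocated before the argument slots are released, so an operator never writes into one of its own inputs.

```python
from typing import Dict, List, Tuple


class SlotAllocator:
    """Hands out slots for intermediate values, recycling released ones."""

    def __init__(self, first_free: int):
        self._next = first_free
        self._free: List[int] = []

    def allocate(self) -> int:
        if self._free:
            return self._free.pop()
        slot = self._next
        self._next += 1
        return slot

    def release(self, slot: int) -> None:
        self._free.append(slot)


def compile_expr(expr: tuple, input_slots: Dict[str, int],
                 alloc: SlotAllocator, ops: List[Tuple]) -> Tuple[int, bool]:
    """Post-order traversal; returns (result slot, whether it is a temporary)."""
    if expr[0] == 'leaf':
        return input_slots[expr[1]], False
    op_name, *args = expr
    compiled = [compile_expr(arg, input_slots, alloc, ops) for arg in args]
    # Allocate the output first so it never aliases one of its own inputs.
    out = alloc.allocate()
    # Argument temporaries have their only consumer here, so free them now.
    for slot, is_temporary in compiled:
        if is_temporary:
            alloc.release(slot)
    ops.append((out, op_name, [slot for slot, _ in compiled]))
    return out, True


# (a * b - c * d) + e, with inputs a..e in slots 0..4.
expr = ('math.add',
        ('math.subtract',
         ('math.multiply', ('leaf', 'a'), ('leaf', 'b')),
         ('math.multiply', ('leaf', 'c'), ('leaf', 'd'))),
        ('leaf', 'e'))
ops: List[Tuple] = []
inputs = {name: i for i, name in enumerate('abcde')}
result_slot, _ = compile_expr(expr, inputs, SlotAllocator(first_free=5), ops)
# The final math.add writes to slot 6, recycled from the second multiply.
```

With reuse, the four intermediate values fit into three slots.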

Compiler extensions

To customize operator compilation, one can register a compiler extension. Both compilation stages can be customized using CompilerExtensionRegistry:

  • RegisterNodeTransformationFn specifies a function applied during the preparation stage.
  • RegisterCompileOperatorFn specifies a function applied during the compilation stage.

For example, the extension for the core.map operator consists of two parts. In the preparation stage (see arolla/qexpr/eval_extensions/prepare_core_map_operator.cc), it transforms M.core.map(op, *args) into packed_core_map[op](*args), where packed_core_map[op] is a stateful operator. In the compilation stage (see arolla/qexpr/eval_extensions/compile_core_map_operator.cc), it extracts the op value from packed_core_map and manually creates a BoundOperator that applies op to the arguments. Because the op literal was removed from the expression during the preparation stage, we do not need to allocate a useless slot for it in the FrameLayout or pay for its initialization.
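
The preparation half of this mechanism can be sketched with a toy registry. `ToyExtensionRegistry`, `register_node_transformation_fn`, and `pack_core_map` are hypothetical stand-ins for CompilerExtensionRegistry and RegisterNodeTransformationFn, expressions use an ad-hoc tuple encoding, and the compilation half (RegisterCompileOperatorFn) is not modelled.

```python
from typing import Callable, List, Optional

# Toy encoding (an assumption for this sketch):
#   ('leaf', name) | ('literal', value) | (operator, arg, ...)
NodeTransformationFn = Callable[[tuple], Optional[tuple]]


class ToyExtensionRegistry:
    """Applies registered node transformations during preparation."""

    def __init__(self) -> None:
        self._node_transformations: List[NodeTransformationFn] = []

    def register_node_transformation_fn(self, fn: NodeTransformationFn) -> None:
        self._node_transformations.append(fn)

    def transform(self, expr: tuple) -> tuple:
        # Bottom-up: rewrite the arguments first, then the node itself.
        if expr[0] not in ('leaf', 'literal'):
            expr = (expr[0], *(self.transform(arg) for arg in expr[1:]))
        for fn in self._node_transformations:
            new_expr = fn(expr)
            if new_expr is not None:  # None means "not applicable"
                return new_expr
        return expr


def pack_core_map(expr: tuple) -> Optional[tuple]:
    """core.map(literal op, *args) -> packed_core_map[op](*args)."""
    if expr[0] == 'core.map' and expr[1][0] == 'literal':
        op = expr[1][1]
        # The op moves into the operator itself and is no longer an argument.
        return (('packed_core_map', op), *expr[2:])
    return None


registry = ToyExtensionRegistry()
registry.register_node_transformation_fn(pack_core_map)
prepared = registry.transform(
    ('core.map', ('literal', 'math.add'), ('leaf', 'x'), ('leaf', 'y')))
# prepared == (('packed_core_map', 'math.add'), ('leaf', 'x'), ('leaf', 'y'))
```

After the rewrite, math.add is part of the node's operator rather than a literal argument, so a later compilation step would have no literal slot to allocate for it.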

Example

Here is a bigger example of the compilation result:

from arolla import arolla
from arolla.experimental import debug

M, L = arolla.M, arolla.L

# The current version of debug.describe_dynamic_eval requires all the leaves to
# have a type annotation.
a = M.annotation.qtype(L.a, arolla.FLOAT32)
# Different type for the sake of example.
b = M.annotation.qtype(L.b, arolla.FLOAT64)
c = M.annotation.qtype(L.c, arolla.FLOAT64)

d = b * b - 4. * a * c
sqrt_d = M.math.pow(d, 0.5)
first_root = (-b + sqrt_d) / 2. / a
second_root = (-b - sqrt_d) / 2. / a
roots = M.core.make_tuple(
    M.core.where(d >= 0., first_root, M.core.empty_like(first_root)),
    M.core.where(d >= 0., second_root, M.core.empty_like(second_root)))

print(debug.describe_dynamic_eval(roots))

Prints:

Inputs:

a: FLOAT32 [0x00]
b: FLOAT64 [0x08]
c: FLOAT64 [0x10]

Output:

tuple<OPTIONAL_FLOAT64,OPTIONAL_FLOAT64> [0x18]

Init operations:

FLOAT64 [0x38] = float64{0}
FLOAT32 [0x48] = 4.
FLOAT64 [0x68] = float64{0.5}
FLOAT64 [0x70] = float64{2}
OPTIONAL_FLOAT64 [0x78] = optional_float64{NA}

Eval operations:

FLOAT64 [0x40] = math.multiply(FLOAT64 [0x08], FLOAT64 [0x08])
FLOAT32 [0x4C] = math.multiply(FLOAT32 [0x48], FLOAT32 [0x00])
FLOAT64 [0x50] = core.to_float64(FLOAT32 [0x4C])
FLOAT64 [0x58] = math.multiply(FLOAT64 [0x50], FLOAT64 [0x10])
FLOAT64 [0x50] = math.subtract(FLOAT64 [0x40], FLOAT64 [0x58])
OPTIONAL_UNIT [0x60] = core.less_equal(FLOAT64 [0x38], FLOAT64 [0x50])
FLOAT64 [0x58] = math.neg(FLOAT64 [0x08])
FLOAT64 [0x40] = math._pow(FLOAT64 [0x50], FLOAT64 [0x68])
FLOAT64 [0x50] = core.to_float64(FLOAT32 [0x00])
jump_if_not<+5>(OPTIONAL_UNIT [0x60])
FLOAT64 [0x98] = math.add(FLOAT64 [0x58], FLOAT64 [0x40])
FLOAT64 [0xA0] = math.divide(FLOAT64 [0x98], FLOAT64 [0x70])
FLOAT64 [0x98] = math.divide(FLOAT64 [0xA0], FLOAT64 [0x50])
OPTIONAL_FLOAT64 [0x88] = core.to_optional._scalar(FLOAT64 [0x98])
jump<+1>()
OPTIONAL_FLOAT64 [0x88] = core._copy(OPTIONAL_FLOAT64 [0x78])
jump_if_not<+5>(OPTIONAL_UNIT [0x60])
FLOAT64 [0xB8] = math.subtract(FLOAT64 [0x58], FLOAT64 [0x40])
FLOAT64 [0xC0] = math.divide(FLOAT64 [0xB8], FLOAT64 [0x70])
FLOAT64 [0xB8] = math.divide(FLOAT64 [0xC0], FLOAT64 [0x50])
OPTIONAL_FLOAT64 [0xA8] = core.to_optional._scalar(FLOAT64 [0xB8])
jump<+1>()
OPTIONAL_FLOAT64 [0xA8] = core._copy(OPTIONAL_FLOAT64 [0x78])
tuple<OPTIONAL_FLOAT64,OPTIONAL_FLOAT64> [0x18] = core.make_tuple(OPTIONAL_FLOAT64 [0x88], OPTIONAL_FLOAT64 [0xA8])

Note the following in the output above:

  1. Some operators got lowered, e.g. math.pow became math._pow.
  2. Some intermediate slots, e.g. FLOAT64 [0x50], are reused.
  3. Literal folding is applied, e.g. the optional_float64{NA} literal comes from the M.core.empty_like(first_root) expression.
  4. The core.to_float64 operator is inserted for type casting.
  5. Short-circuiting is implemented using jump_if_not and jump instructions.