Compilation

MLX has a compile() function transformation which compiles computation graphs. Function compilation results in smaller graphs by merging common work and fusing certain operations. In many cases this leads to significant improvements in runtime and memory use.

Getting started with compile() is simple, but there are some edge cases that are good to be aware of for more complex graphs and advanced usage.

Basics of Compile

Let’s start with a simple example:

def fun(x, y)
  mx.exp(mx.negative(x)) + y
end

x = mx.array(1.0)
y = mx.array(2.0)

# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
puts fun(x, y)

# Compile the function
compiled_fun = mx.compile(method(:fun))

# Prints: array(2.36788, dtype=float32)
puts compiled_fun.call(x, y)

The output of both the regular function and the compiled function is the same up to numerical precision.

The first time you call a compiled function, MLX will build the compute graph, optimize it, and generate and compile code. This can be relatively slow. However, MLX will cache compiled functions, so calling a compiled function multiple times will not initiate a new compilation. This means you should typically compile functions that you plan to use more than once.

def fun(x, y)
  mx.exp(mx.negative(x)) + y
end

x = mx.array(1.0)
y = mx.array(2.0)

compiled_fun = mx.compile(method(:fun))

# Compiled here
compiled_fun.call(x, y)

# Not compiled again
compiled_fun.call(x, y)

# Not compiled again
compiled_fun.call(x, y)

There are some important cases to be aware of that can cause a function to be recompiled:

  • Changing the shape or number of dimensions

  • Changing the type of any of the inputs

  • Changing the number of inputs to the function

In certain cases only part of the compilation stack is rerun (for example, when the shapes change), and in other cases the full compilation stack is rerun (for example, when the types change). In general, you should avoid patterns that cause functions to be recompiled frequently.
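One way to picture this caching is a toy wrapper that keeps one "compiled" entry per input signature. The sketch below is plain Ruby, not MLX's implementation: inputs are modeled as hashes with :shape and :dtype keys, and ToyCompiled is a hypothetical class that only counts compilations. It does not model the difference between partial and full recompilation.

```ruby
# Toy model (not MLX internals): "compile" once per distinct input
# signature, then reuse the cached result. Each input is modeled as a
# hash with :shape and :dtype keys instead of a real array.
class ToyCompiled
  attr_reader :compilations

  def initialize(&fn)
    @fn = fn
    @cache = {}
    @compilations = 0
  end

  def call(*inputs)
    # Key on the triggers listed above: input count, shapes, and dtypes.
    key = [inputs.size, inputs.map { |t| [t[:shape], t[:dtype]] }]
    @cache[key] ||= compile
    @cache[key].call(*inputs)
  end

  private

  # Stand-in for graph construction and code generation.
  def compile
    @compilations += 1
    @fn
  end
end

f = ToyCompiled.new { |*inputs| inputs.size }
a = { shape: [2, 3], dtype: :float32 }
f.call(a, a)                                  # first call: compiles
f.call(a, a)                                  # same signature: cache hit
f.call({ shape: [4, 3], dtype: :float32 }, a) # new shape: recompiles
f.call({ shape: [2, 3], dtype: :float16 }, a) # new dtype: recompiles
f.call(a)                                     # new input count: recompiles
puts f.compilations # => 4
```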

Another idiom to watch out for is compiling functions that are created and destroyed frequently. This can happen, for example, when compiling an anonymous function in a loop:

a = mx.array(1.0)
# Don't do this: it compiles a new lambda on every iteration
5.times do
  mx.compile(->(x) { mx.exp(mx.abs(x)) }).call(a)
end
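A plain-Ruby toy shows why this pattern is costly. It assumes, for illustration, a compile cache keyed on the function object; the hypothetical ToyCompiler module only counts how often it "compiles". Because each iteration creates a new lambda, every call misses the cache, while a function created and compiled once outside the loop compiles only once.

```ruby
# Toy compile cache keyed on the function object itself (an assumption
# made for illustration, not MLX's actual cache).
module ToyCompiler
  @cache = {}.compare_by_identity
  @compilations = 0

  class << self
    attr_accessor :compilations

    def compile(fn)
      @cache[fn] ||= begin
        @compilations += 1
        fn # a real compiler would return an optimized version of fn
      end
    end
  end
end

# Anti-pattern: a new lambda, and therefore a new compilation, per iteration.
5.times { ToyCompiler.compile(->(x) { x.abs }).call(-1.0) }
in_loop = ToyCompiler.compilations
puts in_loop # => 5

# Fix: create and compile the function once, outside the loop.
ToyCompiler.compilations = 0
abs_fn = ToyCompiler.compile(->(x) { x.abs })
5.times { abs_fn.call(-1.0) }
hoisted = ToyCompiler.compilations
puts hoisted # => 1
```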

Example Speedup

mlx.nn.gelu() is a nonlinear activation function commonly used in Transformer-based models. Its implementation involves several unary and binary element-wise operations:

def gelu(x)
  x * (mx.erf(x / Math.sqrt(2.0)) + 1.0) / 2.0
end

If you use this function with small arrays, it will be overhead-bound. If you use it with large arrays, it will be memory-bandwidth-bound. However, all of the operations in gelu can be fused into a single kernel with compile(). This can speed up both cases considerably.
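As a rough analogy in plain Ruby (Math.erf is from Ruby's standard library; gelu_unfused and gelu_fused are hypothetical names), the unfused version makes one pass per operation and materializes an intermediate array after each step, which is roughly what running each operation as its own kernel costs. The fused version computes the whole expression per element in a single pass, which is the effect compile() achieves with a generated kernel. Both return identical values because they perform the same arithmetic in the same order.

```ruby
# Toy illustration of fusion with plain Ruby arrays (not MLX).
SQRT2 = Math.sqrt(2.0)

# One pass per operation, each producing an intermediate array.
def gelu_unfused(xs)
  t1 = xs.map { |v| v / SQRT2 }          # intermediate 1
  t2 = t1.map { |v| Math.erf(v) }        # intermediate 2
  t3 = t2.map { |v| v + 1.0 }            # intermediate 3
  t4 = xs.zip(t3).map { |a, b| a * b }   # intermediate 4
  t4.map { |v| v / 2.0 }                 # output
end

# A single pass that evaluates the whole expression per element.
def gelu_fused(xs)
  xs.map { |v| v * (Math.erf(v / SQRT2) + 1.0) / 2.0 }
end

xs = [-1.0, 0.0, 1.0]
puts gelu_fused(xs).map { |v| v.round(4) }.inspect # => [-0.1587, 0.0, 0.8413]
puts gelu_unfused(xs) == gelu_fused(xs)            # => true
```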

Let’s compare the runtime of the regular function versus the compiled function. We’ll use the following timing helper, which runs a couple of warm-up iterations and calls mx.eval() so that the lazily built graphs are actually computed inside the timed region:

# Timing helper: returns the average time per call in milliseconds
require "benchmark"

def timeit(fun, x)
  # warm up
  2.times { mx.eval(fun.call(x)) }

  tpi = Benchmark.realtime do
    10.times { mx.eval(fun.call(x)) }
  end
  1000.0 * tpi / 10.0
end

Now create an array and benchmark both functions:

x = mx.random_uniform([8, 256, 256], 0.0, 1.0, mx.float32)
puts timeit(->(t) { gelu(t) }, x)
puts timeit(mx.compile(->(t) { gelu(t) }), x)

On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled gelu is five times faster.

Debugging

When a compiled function is first called, it is traced with placeholder inputs. This means you can’t evaluate arrays (for example to print their contents) inside compiled functions.

fun = mx.compile(->(x) do
  z = mx.negative(x)
  puts z # Crash
  mx.exp(z)
end)

fun.call(mx.array(5.0))
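A toy tracer makes the failure concrete. The hypothetical Placeholder class below records the operations applied to it instead of computing values; since it holds no data, converting it to a string raises. This is a plain-Ruby illustration, not how MLX's tracer is implemented.

```ruby
# Toy tracer (not MLX's implementation): a Placeholder records the ops
# applied to it instead of computing values, and it has no data to print.
class Placeholder
  attr_reader :ops

  def initialize(ops = [])
    @ops = ops
  end

  def negative
    Placeholder.new(ops + [:negative])
  end

  def exp
    Placeholder.new(ops + [:exp])
  end

  def to_s
    raise "cannot print a placeholder: it holds no data during tracing"
  end
end

# "Tracing" the example function records its graph without computing it.
traced = Placeholder.new.negative.exp
puts traced.ops.inspect # => [:negative, :exp]

# Printing an intermediate during tracing fails, like `puts z` above.
error = begin
  puts Placeholder.new.negative
  nil
rescue RuntimeError => e
  e.message
end
puts error
```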

Inspecting arrays is often helpful when debugging. To make that possible, disable compilation globally with the disable_compile() function or the MLX_DISABLE_COMPILE environment variable. For example, the following works even though fun is compiled:

fun = mx.compile(->(x) do
  z = mx.negative(x)
  puts z # Okay
  mx.exp(z)
end)

mx.disable_compile()
fun.call(mx.array(5.0))

Pure Functions

Compiled functions are intended to be pure; that is, they should not have side effects. For example:

state = []

fun = mx.compile(->(x, y) do
  z = x + y
  state << z
  mx.exp(z)
end)

fun.call(mx.array(1.0), mx.array(2.0))
# Crash!
puts state

After the first call of fun, the state array holds a placeholder array. The placeholder has no data; it is only used to build the computation graph. Printing such an array results in a crash.

You have two options to deal with this. The first option is to simply return state as an output:

state = []

fun = mx.compile(->(x, y) do
  z = x + y
  state << z
  [mx.exp(z), state]
end)

_, state = fun.call(mx.array(1.0), mx.array(2.0))
# Prints array(3, dtype=float32)
puts state

In some cases returning updated state can be pretty inconvenient. Hence, compile() has a parameter to capture implicit outputs:

state = []

# Tell compile to capture state as an output
fun = mx.compile(
  ->(x, y) do
    z = x + y
    state << z
    mx.exp(z)
  end,
  outputs: state
)

fun.call(mx.array(1.0), mx.array(2.0))
# Prints array(3, dtype=float32)
puts state

This is particularly useful for compiling a function which includes an update to a container of arrays, as is commonly done when training the parameters of an mlx.nn.Module.
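The following toy trace-and-replay model, in plain Ruby, sketches both behaviors. Sym and toy_compile are hypothetical names: Sym records an expression instead of a number, and toy_compile runs the function body once on Sym inputs and replays the recorded graph afterwards. As a simplification, the body only executes during tracing, so its side effects happen once and leave placeholders behind unless the container is passed as outputs.

```ruby
# Sym records an expression tree (a toy stand-in for a placeholder array).
class Sym
  def initialize(op, *args)
    @op = op
    @args = args
  end

  def +(other)
    Sym.new(:add, self, other)
  end

  # Replays the recorded expression on concrete input values.
  def evaluate(inputs)
    case @op
    when :input then inputs.fetch(@args[0])
    when :add then @args.sum { |a| a.is_a?(Sym) ? a.evaluate(inputs) : a }
    end
  end
end

# Runs the body once on Sym inputs, then replays the recorded graph. With
# outputs:, values the body left in that container are recomputed and
# written back after every call.
def toy_compile(fn, outputs: nil)
  graph = nil
  recorded = []
  lambda do |*args|
    if graph.nil?
      graph = fn.call(*args.each_index.map { |i| Sym.new(:input, i) })
      recorded = outputs.dup if outputs
    end
    outputs&.replace(recorded.map { |s| s.is_a?(Sym) ? s.evaluate(args) : s })
    graph.evaluate(args)
  end
end

# Without output capture, the side effect leaks a placeholder.
leaky = []
f = toy_compile(->(x, y) do
  z = x + y
  leaky << z
  z
end)
f.call(1.0, 2.0)
f.call(1.0, 2.0)
puts leaky.size        # => 1 (the body ran only while tracing)
puts leaky.first.class # => Sym (a placeholder, not 3.0)

# With output capture, the container receives concrete values.
state = []
g = toy_compile(->(x, y) do
  z = x + y
  state << z
  z
end, outputs: state)
g.call(1.0, 2.0)
puts state.inspect # => [3.0]
```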

Compiled functions also treat any arrays they read that are not passed as arguments (for example, arrays captured from the surrounding scope) as constants. For example:

state = [mx.array(1.0)]

fun = mx.compile(->(x) { x + state[0] })

# Prints array(2, dtype=float32)
puts fun.call(mx.array(1.0))

# Update state
state[0] = mx.array(5.0)

# Still prints array(2, dtype=float32)
puts fun.call(mx.array(1.0))

To have changes to state reflected in the outputs of fun, you again have two options. The first option is to pass state as an input to the function:

state = [mx.array(1.0)]

fun = mx.compile(->(x, current_state) { x + current_state[0] })

# Prints array(2, dtype=float32)
puts fun.call(mx.array(1.0), state)

# Update state
state[0] = mx.array(5.0)

# Prints array(6, dtype=float32)
puts fun.call(mx.array(1.0), state)

In some cases this can be pretty inconvenient. Hence, compile() also has a parameter to capture implicit inputs:

state = [mx.array(1.0)]

# Tell compile to capture state as an input
fun = mx.compile(->(x) { x + state[0] }, inputs: state)

# Prints array(2, dtype=float32)
puts fun.call(mx.array(1.0))

# Update state
state[0] = mx.array(5.0)

# Prints array(6, dtype=float32)
puts fun.call(mx.array(1.0))
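The same toy trace-and-replay idea (Sym and toy_compile are hypothetical, plain-Ruby stand-ins, not MLX code) shows why captured values behave as constants and how inputs: changes that. During tracing, toy_compile swaps the elements of the inputs: container for symbolic inputs, so later calls read their current values; any other value the closure reads is folded into the recorded graph as a constant.

```ruby
# Sym records an expression tree instead of computing a number.
class Sym
  def initialize(op, *args)
    @op = op
    @args = args
  end

  def +(other)
    Sym.new(:add, self, other)
  end

  def evaluate(inputs)
    case @op
    when :input then inputs.fetch(@args[0])
    when :add then @args.sum { |a| a.is_a?(Sym) ? a.evaluate(inputs) : a }
    end
  end
end

# Traces fn on its first call. Elements of the inputs: container are
# swapped for Sym inputs while tracing, so replays read their current
# values; anything else the closure reads is folded in as a constant.
def toy_compile(fn, inputs: [])
  graph = nil
  lambda do |*args|
    if graph.nil?
      saved = inputs.dup
      begin
        inputs.each_index { |i| inputs[i] = Sym.new(:input, args.size + i) }
        graph = fn.call(*args.each_index.map { |i| Sym.new(:input, i) })
      ensure
        inputs.replace(saved) # restore the real values after tracing
      end
    end
    graph.evaluate(args + inputs)
  end
end

state = [1.0]
const_fn = toy_compile(->(x) { x + state[0] })
captured_fn = toy_compile(->(x) { x + state[0] }, inputs: state)

before = [const_fn.call(1.0), captured_fn.call(1.0)]
state[0] = 5.0
after = [const_fn.call(1.0), captured_fn.call(1.0)]
puts before.inspect # => [2.0, 2.0]
puts after.inspect  # => [2.0, 6.0]
```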

Compiling Training Graphs

This section steps through how to use compile() in a common setup: training a model with mlx.nn.Module using an mlx.optimizers.Optimizer that has state. We show how to compile the full forward pass, backward pass, and parameter update with compile().

To start, here is the simple example without any compilation:

require "mlx"
mx = MLX::Core
nn = MLX::NN
optim = MLX::Optimizers

# 2 examples with 2 features each
x = mx.random_uniform([2, 2], 0.0, 1.0, mx.float32)

# 0, 1 targets
y = mx.array([0, 1])

# Simple linear model
model = nn::Linear.new(2, 1)

# SGD with momentum
optimizer = optim::SGD.new(learning_rate: 0.1, momentum: 0.8)

loss_fn = ->(model, x, y) do
  logits = model.call(x).squeeze
  nn.binary_cross_entropy(logits, y)
end

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Perform a few steps of gradient descent
3.times do
  loss, grads = loss_and_grad_fn.call(model, x, y)
  optimizer.update(model, grads)
  mx.eval(model.parameters, optimizer.state)
end

To compile the update, put the whole training step in a function and compile it with the appropriate input and output captures. Here is the example again, compiled (this version uses 4 examples with 10 features each):

require "mlx"
mx = MLX::Core
nn = MLX::NN
optim = MLX::Optimizers

# 4 examples with 10 features each
x = mx.random_uniform([4, 10], 0.0, 1.0, mx.float32)

# 0, 1 targets
y = mx.array([0, 1, 0, 1])

# Simple linear model
model = nn::Linear.new(10, 1)

# SGD with momentum
optimizer = optim::SGD.new(learning_rate: 0.1, momentum: 0.8)

loss_fn = ->(model, x, y) do
  logits = model.call(x).squeeze
  nn.binary_cross_entropy(logits, y)
end

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# The state that will be captured as input and output
state = [model.state, optimizer.state]

step = mx.compile(
  ->(x, y) do
    loss, grads = loss_and_grad_fn.call(model, x, y)
    optimizer.update(model, grads)
    loss
  end,
  inputs: state,
  outputs: state
)

# Perform one step of gradient descent
loss = step.call(x, y)
mx.eval(model.parameters, optimizer.state)
puts loss
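Both model.state and optimizer.state are captured as inputs and outputs because every step reads and rewrites them. A plain-Ruby sketch of SGD with momentum on one scalar weight makes this visible. It assumes the common update rule v <- momentum * v + grad and w <- w - learning_rate * v, without dampening or weight decay; sgd_momentum_step is a hypothetical helper, not the MLX optimizer.

```ruby
# One SGD-with-momentum step on a scalar weight. The velocity v is
# optimizer state: each step reads the previous v and returns a new one.
def sgd_momentum_step(w, v, grad, lr: 0.1, momentum: 0.8)
  v = momentum * v + grad
  [w - lr * v, v]
end

# Minimize f(w) = w**2, whose gradient is 2 * w.
w = 1.0
v = 0.0
3.times do
  w, v = sgd_momentum_step(w, v, 2 * w)
end
puts w.round(4) # => 0.128
puts v.round(4) # => 3.52
```

If v were treated as a constant baked in at trace time, every step would reuse the initial velocity, which is why the optimizer state has to flow through the compiled function.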

Note

If you are using a module which performs random sampling such as mlx.nn.Dropout(), make sure you also include mx.random.state in the state captured by compile(), i.e. state = [model.state, optimizer.state, mx.random.state].

Note

For more examples of compiling full training graphs, check out the MLX Examples GitHub repo.

Transformations with Compile

In MLX, function transformations are composable. You can apply any function transformation to the output of any other function transformation. For more on this, see the documentation on function transforms.

Compiling transformed functions works as expected:

grad_fn = mx.grad(->(x) { mx.exp(x) })

compiled_grad_fn = mx.compile(grad_fn)

# Prints: array(2.71828, dtype=float32)
puts grad_fn.call(mx.array(1.0))

# Also prints: array(2.71828, dtype=float32)
puts compiled_grad_fn.call(mx.array(1.0))
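Transformations compose because each one maps a function to a new function. The plain-Ruby toy below illustrates that shape of the API; toy_grad and toy_compile are hypothetical. toy_grad uses a central finite difference as a numerical stand-in for automatic differentiation, and toy_compile merely wraps the function. Since d/dx exp(x) = exp(x), both calls approximate e at x = 1, matching the printed values above.

```ruby
# Each toy transformation takes a function and returns a function.
def toy_grad(f, h = 1e-5)
  ->(x) { (f.call(x + h) - f.call(x - h)) / (2 * h) }
end

# Stand-in for compile: a real compiler would trace and optimize f.
def toy_compile(f)
  ->(*args) { f.call(*args) }
end

grad_fn = toy_grad(->(x) { Math.exp(x) })
compiled_grad_fn = toy_compile(grad_fn)

puts grad_fn.call(1.0).round(5)          # => 2.71828
puts compiled_grad_fn.call(1.0).round(5) # => 2.71828
```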

Note

A transformation of a compiled function is not compiled by default. To compile as much as possible, pass the transformed function through compile() as well.

You can also compile functions that themselves call compiled functions. A good practice is to compile the outermost function to give compile() the most opportunity to optimize the computation graph:

inner = mx.compile(->(x) { mx.exp(mx.negative(mx.abs(x))) })

outer = ->(x) do
  inner.call(inner.call(x))
end

# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)

Shapeless Compilation

When the shape of an input to a compiled function changes, the function is recompiled. You can compile a function once and run it on inputs with varying shapes by passing shapeless: true to compile(). In this case, changes to the shapes of the inputs do not cause the function to be recompiled.

def fun(x, y)
  mx.abs(x + y)
end

compiled_fun = mx.compile(method(:fun), shapeless: true)

x = mx.array(1.0)
y = mx.array(-2.0)

# First call compiles the function
puts compiled_fun.call(x, y)

# Second call with different shapes
# does not recompile the function
x = mx.array([1.0, -6.0])
y = mx.array([-2.0, 3.0])
puts compiled_fun.call(x, y)
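A toy comparison of cache keys shows the difference. This is an illustration in plain Ruby, not MLX's actual cache logic; cache_key is a hypothetical helper, and inputs are modeled as hashes with :shape and :dtype, mirroring the two calls above (scalars first, then 1-D arrays of length 2).

```ruby
# Regular compilation keys on shapes and dtypes; shapeless on dtypes only.
def cache_key(inputs, shapeless:)
  inputs.map { |t| shapeless ? t[:dtype] : [t[:shape], t[:dtype]] }
end

first_call = [{ shape: [], dtype: :float32 }] * 2
second_call = [{ shape: [2], dtype: :float32 }] * 2

same_regular = cache_key(first_call, shapeless: false) == cache_key(second_call, shapeless: false)
same_shapeless = cache_key(first_call, shapeless: true) == cache_key(second_call, shapeless: true)
puts same_regular   # => false (a regular compile would recompile)
puts same_shapeless # => true (the shapeless compile reuses its graph)
```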

Use shapeless compilation carefully. Since changing shapes does not trigger recompilation, any computation that depends on the input shapes will not work as expected. Shape-dependent computations are common and sometimes subtle to detect. For example:

def fun(x)
  x.reshape(x.shape[0] * x.shape[1], -1)
end

compiled_fun = mx.compile(method(:fun), shapeless: true)

x = mx.random_uniform([2, 3, 4], 0.0, 1.0, mx.float32)

out = compiled_fun.call(x)

x = mx.random_uniform([5, 5, 3], 0.0, 1.0, mx.float32)

# Error, can't reshape (5, 5, 3) to (6, -1)
out = compiled_fun.call(x)

The second call to compiled_fun fails because reshape() uses the static shape of x from the first call. We can fix this by using flatten() to avoid hardcoding the shape of x:

def fun(x)
  x.flatten(0, 1)
end

compiled_fun = mx.compile(method(:fun), shapeless: true)

x = mx.random_uniform([2, 3, 4], 0.0, 1.0, mx.float32)

out = compiled_fun.call(x)

x = mx.random_uniform([5, 5, 3], 0.0, 1.0, mx.float32)

# Ok
out = compiled_fun.call(x)
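To see why the reshape version breaks while flatten() works, the plain-Ruby toy below tracks only shapes (element_count, reshape_to, and flatten01 are hypothetical helpers, not MLX functions). The reshape target is computed once, from the first input, mirroring a value baked in at trace time; the flatten-style target is recomputed from each input's actual shape.

```ruby
def element_count(shape)
  shape.reduce(1, :*)
end

# Like reshape(leading, -1): infer the last dimension or raise.
def reshape_to(shape, leading)
  total = element_count(shape)
  unless (total % leading).zero?
    raise ArgumentError, "can't reshape #{shape.inspect} to (#{leading}, -1)"
  end
  [leading, total / leading]
end

# Like flatten(0, 1): merge the first two axes of the current shape.
def flatten01(shape)
  [shape[0] * shape[1], *shape[2..]]
end

# The reshape version computes its target once, from the first input.
first = [2, 3, 4]
baked = first[0] * first[1] # 6, fixed at trace time
puts reshape_to(first, baked).inspect # => [6, 4]

second = [5, 5, 3]
puts flatten01(second).inspect # => [25, 3] (recomputed per call)
error = begin
  reshape_to(second, baked)
  nil
rescue ArgumentError => e
  e.message
end
puts error # => can't reshape [5, 5, 3] to (6, -1)
```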