Data Parallelism

MLX enables efficient data parallel distributed training through its distributed communication primitives.

Training Example

In this section, we adapt an MLX training loop to support data-parallel distributed training. Specifically, we average the gradients across a set of hosts before applying them to the model.
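
Averaging is the right reduction because, when the loss is a mean over samples and every host processes a batch of the same size, the mean of the per-host gradients equals the gradient over the combined batch. The plain-Ruby sketch below checks this for the one-weight toy model used on this page. It needs no MLX; the `grad_w` helper and the numbers are illustrative assumptions, not part of MLX.

```ruby
# Toy model from this page: pred = w * x, loss = mean((pred - y)**2),
# so d(loss)/dw = mean(2 * (w * x - y) * x).
def grad_w(w, xs, ys)
  terms = xs.zip(ys).map { |x, y| 2.0 * (w * x - y) * x }
  terms.sum / terms.size
end

w  = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Gradient a single machine would compute on the whole batch.
full = grad_w(w, xs, ys)

# The same batch split evenly across two hypothetical hosts.
local = [
  grad_w(w, xs[0, 2], ys[0, 2]), # host 0 sees samples 0..1
  grad_w(w, xs[2, 2], ys[2, 2])  # host 1 sees samples 2..3
]
averaged = local.sum / local.size

puts full     # prints -22.5
puts averaged # prints -22.5
```

If the hosts held batches of different sizes, a plain average of their gradients would over-weight the smaller batches; each local gradient would then need to be weighted by its batch size.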

With a one-weight toy model, dataset, and optimizer standing in for the real ones, our training loop looks like the following code snippet.

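require "mlx" # assumes the Ruby bindings are loaded this way

mx = MLX::Core # short alias used throughout this page

# Toy stand-ins: a single weight "w", plain SGD with learning rate 0.1,
# and a one-sample dataset.
model = Struct.new(:parameters).new({"w" => mx.array([0.0])})
optimizer = Struct.new(:learning_rate) do
  def update(model, grads)
    model.parameters["w"] = model.parameters["w"] - grads["w"] * learning_rate
  end
end.new(0.1)
dataset = [[mx.array([1.0]), mx.array([1.0])]]

# Mean squared error of pred = w * x, with its hand-derived gradient.
loss_grad_fn = ->(model, x, y) do
  pred = model.parameters["w"] * x
  loss = mx.mean((pred - y).square)
  grads = {"w" => mx.mean((pred - y) * x * 2.0)}
  [loss, grads]
end

step = ->(model, x, y) do
  loss, grads = loss_grad_fn.call(model, x, y)
  optimizer.update(model, grads)
  loss
end

dataset.each do |x, y|
  loss = step.call(model, x, y)
  mx.eval(loss, model.parameters) # evaluate the lazy computation graph
end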

To average the gradients across machines, all we have to do is perform an all_sum() and divide by the size of the Group. Specifically, we apply the following function to every gradient with MLX::Utils.tree_map().

def all_avg(x)
  mx = MLX::Core # `def` opens a new scope, so a top-level `mx` alias is not visible here
  world = mx.init
  mx.all_sum(x, world) / world.size
end
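
all_sum is a collective operation: every process contributes its array and every process receives the same elementwise sum. The pure-Ruby sketch below simulates this for three hypothetical ranks inside a single process. `simulated_all_sum` and `simulated_all_avg` are illustrative stand-ins, not MLX functions. The sketch shows that, after dividing by the group size, every rank holds identical averaged gradients, which is what keeps the model replicas in sync.

```ruby
# Hypothetical stand-in for a collective: every rank contributes an array
# and every rank receives the elementwise sum of all contributions.
def simulated_all_sum(per_rank_values)
  summed = per_rank_values.transpose.map(&:sum)
  Array.new(per_rank_values.size) { summed.dup }
end

# Sum across ranks, then divide by the number of ranks (the group size).
def simulated_all_avg(per_rank_values)
  size = per_rank_values.size
  simulated_all_sum(per_rank_values).map { |v| v.map { |e| e / size } }
end

grads_by_rank = [
  [1.0, 2.0],  # rank 0
  [3.0, 6.0],  # rank 1
  [5.0, 10.0]  # rank 2
]

avg = simulated_all_avg(grads_by_rank)
# Every rank now holds [3.0, 6.0].
```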

Putting everything together, the training step looks as follows, with everything else remaining the same. The early return skips communication entirely when only a single process is running.

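def all_reduce_grads(grads)
  mx = MLX::Core # `def` opens a new scope, so a top-level `mx` alias is not visible here
  world = mx.init
  world_size = world.size
  return grads if world_size == 1 # nothing to average on a single process

  MLX::Utils.tree_map(
    ->(x) { mx.all_sum(x, world) / world_size },
    grads
  )
end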

step = ->(model, x, y) do
  loss, grads = loss_grad_fn.call(model, x, y)
  grads = all_reduce_grads(grads) # <--- This line was added
  optimizer.update(model, grads)
  loss
end
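
MLX::Utils.tree_map applies a function to every leaf of a (possibly nested) gradient tree and returns a tree of the same shape, so all_reduce_grads issues one all_sum per gradient array. The pure-Ruby sketch below mimics that behaviour with a hypothetical `tree_map` helper (an illustration, not MLX's implementation) on made-up gradients that are assumed to be already summed over two ranks, and counts how many per-leaf reductions that takes. In this sketch the leaves are Floats, and Ruby Arrays are treated as containers.

```ruby
# Hypothetical plain-Ruby tree_map: applies fn to every leaf of nested
# Hashes/Arrays and rebuilds the same structure around the results.
def tree_map(fn, tree)
  case tree
  when Hash then tree.transform_values { |v| tree_map(fn, v) }
  when Array then tree.map { |v| tree_map(fn, v) }
  else fn.call(tree)
  end
end

# Made-up gradients for a two-layer model, already summed over 2 ranks.
summed = {
  "layer1" => {"weight" => 4.0, "bias" => 2.0},
  "layer2" => {"weight" => 6.0, "bias" => 8.0}
}
world_size = 2
calls = 0 # one call per leaf, i.e. one collective per gradient array

averaged = tree_map(->(g) { calls += 1; g / world_size }, summed)
# averaged keeps the nested shape; calls == 4
```

With hundreds of parameter arrays, this per-leaf pattern issues hundreds of separate collectives.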

Using nn.average_gradients

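Although the code example above works correctly, it performs one communication per gradient array. It is significantly more efficient to aggregate several gradients together and perform fewer communication steps, because every collective operation carries some fixed overhead regardless of how much data it moves.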

This is the purpose of MLX::NN.average_gradients. The final code looks almost identical to the example above:

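require "mlx" # assumes the Ruby bindings are loaded this way

mx = MLX::Core # short alias used throughout this page

# Toy stand-ins: a single weight "w", plain SGD with learning rate 0.1,
# and a one-sample dataset.
model = Struct.new(:parameters).new({"w" => mx.array([0.0])})
optimizer = Struct.new(:learning_rate) do
  def update(model, grads)
    model.parameters["w"] = model.parameters["w"] - grads["w"] * learning_rate
  end
end.new(0.1)
dataset = [[mx.array([1.0]), mx.array([1.0])]]

# Mean squared error of pred = w * x, with its hand-derived gradient.
loss_grad_fn = ->(model, x, y) do
  pred = model.parameters["w"] * x
  loss = mx.mean((pred - y).square)
  grads = {"w" => mx.mean((pred - y) * x * 2.0)}
  [loss, grads]
end

world = mx.init # initialize the distributed group once, outside the step

step = ->(model, x, y) do
  loss, grads = loss_grad_fn.call(model, x, y)
  grads = MLX::NN.average_gradients(grads, world) # <---- This line was added
  optimizer.update(model, grads)
  loss
end

dataset.each do |x, y|
  loss = step.call(model, x, y)
  mx.eval(loss, model.parameters) # evaluate the lazy computation graph
end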
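
The saving comes from replacing many small collectives with a few larger ones. The pure-Ruby sketch below illustrates the idea under simplifying assumptions: gradients are flat Ruby arrays, all of them fit into one buffer, and `sum_across_ranks` and `packed_average` are hypothetical helpers rather than the internals of MLX::NN.average_gradients, which may, for example, split a large set of gradients across several buffers.

```ruby
# Hypothetical sketch of batching gradients before communicating them.
# Each leaf is a flat Array of Floats standing in for a device array.

# One simulated all-reduce: elementwise sum of every rank's buffer.
def sum_across_ranks(buffers)
  buffers.transpose.map(&:sum)
end

def packed_average(grads_by_rank)
  keys  = grads_by_rank.first.keys
  sizes = keys.map { |k| grads_by_rank.first[k].size }

  # Pack: concatenate all of a rank's gradients into a single buffer.
  buffers = grads_by_rank.map { |g| keys.flat_map { |k| g[k] } }

  # A single communication step instead of one per gradient.
  n = grads_by_rank.size
  averaged = sum_across_ranks(buffers).map { |v| v / n }

  # Unpack: slice the buffer back into per-parameter pieces.
  result = {}
  offset = 0
  keys.zip(sizes).each do |k, size|
    result[k] = averaged[offset, size]
    offset += size
  end
  result
end

grads_by_rank = [
  {"w" => [1.0, 3.0], "b" => [2.0]}, # rank 0
  {"w" => [3.0, 5.0], "b" => [4.0]}  # rank 1
]
avg = packed_average(grads_by_rank)
# avg == {"w" => [2.0, 4.0], "b" => [3.0]}
```

Packing and unpacking cost extra memory copies, so batching pays off mainly when a model has many small parameter arrays and per-message overhead dominates.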