Trainer Core#
MLX::DSL::Trainer provides high-level training and reporting workflows.
trainer = model.trainer(optimizer: optimizer) do |x:, y:|
MLX::NN.cross_entropy(model.call(x), y, reduction: "mean")
end
Core methods#
fit(dataset, **kwargs)fit_report(dataset, **kwargs)
trainer.fit(train_data, epochs: 5)
report = trainer.fit_report(train_data, epochs: 5, validation_data: validation_data)
puts report[:epochs_ran]
Core fit options#
epochs,limit,reducevalidation_data,validation_limit,validation_reducemonitor,metric,monitor_modepatience,min_deltakeep_losses,strict_data_reusecompileandsync(via trainer construction)
trainer.fit_report(
train_data,
epochs: 20,
limit: 200,
reduce: :mean,
validation_data: validation_data,
validation_limit: 50,
monitor: :loss,
monitor_mode: :min,
patience: 3,
min_delta: 0.001
)
Lifecycle hooks#
before_fitbefore_epochafter_batchbefore_validationafter_validation_batchafter_validationafter_epochcheckpointafter_fit
All hook registration routes through on(event, ...) and supports
priority, every, once, and conditional if.
trainer.before_epoch { |ctx| puts "epoch=#{ctx[:epoch]}" }
trainer.after_batch(every: 50) { |ctx| puts "step=#{ctx[:step]}" }
trainer.after_fit { |ctx| puts "best=#{ctx[:best_monitor]}" }
Error diagnostics#
Train/validation batch failures include kind, epoch, and
batch_index context.
begin
trainer.fit(train_data, epochs: 1)
rescue => e
warn("#{e.class}: #{e.message}")
warn("kind=#{e.respond_to?(:kind) ? e.kind : :unknown}")
end
See implementation:
lib/mlx/dsl/trainer.rb