Index A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | Z A abs (in module mlx.core) (in module mlx.core.array) AdaDelta (class in mlx.optimizers) Adafactor (class in mlx.optimizers) Adagrad (class in mlx.optimizers) Adam (class in mlx.optimizers) Adamax (class in mlx.optimizers) AdamW (class in mlx.optimizers) add (in module mlx.core) addmm (in module mlx.core) ALiBi (class in mlx.nn) all (in module mlx.core) (in module mlx.core.array) all_gather (in module mlx.core.distributed) all_sum (in module mlx.core.distributed) allclose (in module mlx.core) AllToShardedLinear (class in mlx.nn) any (in module mlx.core) (in module mlx.core.array) apply (in module mlx.nn) apply_gradients (in module mlx.optimizers) apply_to_modules (in module mlx.nn) arange (in module mlx.core) arccos (in module mlx.core) arccosh (in module mlx.core) arcsin (in module mlx.core) arcsinh (in module mlx.core) arctan (in module mlx.core) arctan2 (in module mlx.core) arctanh (in module mlx.core) argmax (in module mlx.core) (in module mlx.core.array) argmin (in module mlx.core) (in module mlx.core.array) argpartition (in module mlx.core) argsort (in module mlx.core) array (in module mlx.core) array_equal (in module mlx.core) as_strided (in module mlx.core) astype (in module mlx.core.array) async_eval (in module mlx.core) at (in module mlx.core.array) atleast_1d (in module mlx.core) atleast_2d (in module mlx.core) atleast_3d (in module mlx.core) average_gradients (in module mlx.nn) AvgPool1d (class in mlx.nn) AvgPool2d (class in mlx.nn) AvgPool3d (class in mlx.nn) B BatchNorm (class in mlx.nn) binary_cross_entropy (class in mlx.nn.losses) bitwise_and (in module mlx.core) bitwise_invert (in module mlx.core) bitwise_or (in module mlx.core) bitwise_xor (in module mlx.core) block_masked_mm (in module mlx.core) broadcast_arrays (in module mlx.core) broadcast_to (in module mlx.core) C ceil (in module mlx.core) CELU (class in mlx.nn) celu (class in mlx.nn) checkpoint (in module mlx.core) children (in module mlx.nn) cholesky (in module mlx.core.linalg) cholesky_inv (in module mlx.core.linalg) clear_cache (in module mlx.core) clip (in module mlx.core) clip_grad_norm (in module mlx.optimizers) compile (in module mlx.core) concatenate (in module mlx.core) conj (in module mlx.core) (in module mlx.core.array) conjugate (in module mlx.core) constant (in module mlx.nn.init) contiguous (in module mlx.core) Conv1d (class in mlx.nn) conv1d (in module mlx.core) Conv2d (class in mlx.nn) conv2d (in module mlx.core) Conv3d (class in mlx.nn) conv3d (in module mlx.core) conv_general (in module mlx.core) conv_transpose1d (in module mlx.core) conv_transpose2d (in module mlx.core) conv_transpose3d (in module mlx.core) convolve (in module mlx.core) ConvTranspose1d (class in mlx.nn) ConvTranspose2d (class in mlx.nn) ConvTranspose3d (class in mlx.nn) cos (in module mlx.core) (in module mlx.core.array) cosh (in module mlx.core) cosine_decay (in module mlx.optimizers) cosine_similarity_loss (class in mlx.nn.losses) cross (in module mlx.core.linalg) cross_entropy (class in mlx.nn.losses) cuda_kernel (in module mlx.core.fast) cummax (in module mlx.core) (in module mlx.core.array) cummin (in module mlx.core) (in module mlx.core.array) cumprod (in module mlx.core) (in module mlx.core.array) cumsum (in module mlx.core) (in module mlx.core.array) custom_function (in module mlx.core) D default_device (in module mlx.core) default_stream (in module mlx.core) degrees (in module mlx.core) dequantize (in module mlx.core) Device (in module mlx.core) device_count (in module mlx.core) device_info (in module mlx.core) (in module mlx.core.metal) diag (in module mlx.core) (in module mlx.core.array) diagonal (in module mlx.core) (in module mlx.core.array) disable_compile (in module mlx.core) divide (in module mlx.core) divmod (in module mlx.core) Dropout (class in mlx.nn) Dropout2d (class in mlx.nn) Dropout3d (class in mlx.nn) Dtype (in module mlx.core) dtype (in module mlx.core.array) DtypeCategory (in module mlx.core) E eig (in module mlx.core.linalg) eigh (in module mlx.core.linalg) eigvals (in module mlx.core.linalg) eigvalsh (in module mlx.core.linalg) einsum (in module mlx.core) einsum_path (in module mlx.core) ELU (class in mlx.nn) elu (class in mlx.nn) Embedding (class in mlx.nn) enable_compile (in module mlx.core) equal (in module mlx.core) erf (in module mlx.core) erfinv (in module mlx.core) eval (in module mlx.core) (in module mlx.nn) exp (in module mlx.core) (in module mlx.core.array) expand_dims (in module mlx.core) expm1 (in module mlx.core) exponential_decay (in module mlx.optimizers) export_function (in module mlx.core) export_to_dot (in module mlx.core) exporter (in module mlx.core) eye (in module mlx.core) F fft (in module mlx.core.fft) fft2 (in module mlx.core.fft) fftn (in module mlx.core.fft) fftshift (in module mlx.core.fft) filter_and_map (in module mlx.nn) finfo (in module mlx.core) flatten (in module mlx.core) (in module mlx.core.array) floor (in module mlx.core) floor_divide (in module mlx.core) freeze (in module mlx.nn) full (in module mlx.core) G gather_mm (in module mlx.core) gather_qmm (in module mlx.core) gaussian_nll_loss (class in mlx.nn.losses) GELU (class in mlx.nn) gelu (class in mlx.nn) gelu_approx (class in mlx.nn) gelu_fast_approx (class in mlx.nn) get_active_memory (in module mlx.core) get_cache_memory (in module mlx.core) get_peak_memory (in module mlx.core) glorot_normal (in module mlx.nn.init) glorot_uniform (in module mlx.nn.init) GLU (class in mlx.nn) glu (class in mlx.nn) grad (in module mlx.core) greater (in module mlx.core) greater_equal (in module mlx.core) Group (in module mlx.core.distributed) GroupNorm (class in mlx.nn) GRU (class in mlx.nn) H hadamard_transform (in module mlx.core) hard_shrink (class in mlx.nn) hard_tanh (class in mlx.nn) HardShrink (class in mlx.nn) Hardswish (class in mlx.nn) hardswish (class in mlx.nn) HardTanh (class in mlx.nn) he_normal (in module mlx.nn.init) he_uniform (in module mlx.nn.init) hinge_loss (class in mlx.nn.losses) huber_loss (class in mlx.nn.losses) I identity (in module mlx.core) (in module mlx.nn.init) ifft (in module mlx.core.fft) ifft2 (in module mlx.core.fft) ifftn (in module mlx.core.fft) ifftshift (in module mlx.core.fft) imag (in module mlx.core) (in module mlx.core.array) import_function (in module mlx.core) init (in module mlx.core.distributed) (in module mlx.optimizers) inner (in module mlx.core) InstanceNorm (class in mlx.nn) inv (in module mlx.core.linalg) irfft (in module mlx.core.fft) irfft2 (in module mlx.core.fft) irfftn (in module mlx.core.fft) is_available (in module mlx.core.cuda) (in module mlx.core.distributed) (in module mlx.core.metal) isclose (in module mlx.core) isfinite (in module mlx.core) isinf (in module mlx.core) isnan (in module mlx.core) isneginf (in module mlx.core) isposinf (in module mlx.core) issubdtype (in module mlx.core) item (in module mlx.core.array) itemsize (in module mlx.core.array) J join_schedules (in module mlx.optimizers) jvp (in module mlx.core) K kl_div_loss (class in mlx.nn.losses) kron (in module mlx.core) L l1_loss (class in mlx.nn.losses) layer_norm (in module mlx.core.fast) LayerNorm (class in mlx.nn) leaf_modules (in module mlx.nn) leaky_relu (class in mlx.nn) LeakyReLU (class in mlx.nn) left_shift (in module mlx.core) less (in module mlx.core) less_equal (in module mlx.core) Linear (class in mlx.nn) linear_schedule (in module mlx.optimizers) linspace (in module mlx.core) Lion (class in mlx.optimizers) load (in module mlx.core) load_weights (in module mlx.nn) log (in module mlx.core) (in module mlx.core.array) log10 (in module mlx.core) (in module mlx.core.array) log1p (in module mlx.core) (in module mlx.core.array) log2 (in module mlx.core) (in module mlx.core.array) log_cosh_loss (class in mlx.nn.losses) log_sigmoid (class in mlx.nn) log_softmax (class in mlx.nn) logaddexp (in module mlx.core) logcumsumexp (in module mlx.core) (in module mlx.core.array) logical_and (in module mlx.core) logical_not (in module mlx.core) logical_or (in module mlx.core) LogSigmoid (class in mlx.nn) LogSoftmax (class in mlx.nn) logsumexp (in module mlx.core) (in module mlx.core.array) LSTM (class in mlx.nn) lu (in module mlx.core.linalg) lu_factor (in module mlx.core.linalg) M margin_ranking_loss (class in mlx.nn.losses) matmul (in module mlx.core) max (in module mlx.core) (in module mlx.core.array) maximum (in module mlx.core) MaxPool1d (class in mlx.nn) MaxPool2d (class in mlx.nn) MaxPool3d (class in mlx.nn) mean (in module mlx.core) (in module mlx.core.array) median (in module mlx.core) meshgrid (in module mlx.core) metal_kernel (in module mlx.core.fast) min (in module mlx.core) (in module mlx.core.array) minimum (in module mlx.core) Mish (class in mlx.nn) mish (class in mlx.nn) Module (class in mlx.nn) modules (in module mlx.nn) moveaxis (in module mlx.core) (in module mlx.core.array) mse_loss (class in mlx.nn.losses) MultiHeadAttention (class in mlx.nn) MultiOptimizer (class in mlx.optimizers) multiply (in module mlx.core) Muon (class in mlx.optimizers) N named_modules (in module mlx.nn) nan_to_num (in module mlx.core) nbytes (in module mlx.core.array) ndim (in module mlx.core.array) negative (in module mlx.core) new_stream (in module mlx.core) nll_loss (class in mlx.nn.losses) norm (in module mlx.core.linalg) normal (in module mlx.nn.init) not_equal (in module mlx.core) O ones (in module mlx.core) ones_like (in module mlx.core) Optimizer (class in mlx.optimizers) outer (in module mlx.core) P pad (in module mlx.core) parameters (in module mlx.nn) partition (in module mlx.core) pinv (in module mlx.core.linalg) power (in module mlx.core) PReLU (class in mlx.nn) prelu (class in mlx.nn) prod (in module mlx.core) (in module mlx.core.array) put_along_axis (in module mlx.core) Q qr (in module mlx.core.linalg) quantize (in module mlx.core) (in module mlx.nn) quantized_matmul (in module mlx.core) QuantizedAllToShardedLinear (class in mlx.nn) QuantizedEmbedding (class in mlx.nn) QuantizedLinear (class in mlx.nn) QuantizedShardedToAllLinear (class in mlx.nn) R radians (in module mlx.core) random_seed (in module mlx.core) random_split (in module mlx.core) random_uniform (in module mlx.core) real (in module mlx.core) (in module mlx.core.array) reciprocal (in module mlx.core) (in module mlx.core.array) recv (in module mlx.core.distributed) recv_like (in module mlx.core.distributed) ReLU (class in mlx.nn) relu (class in mlx.nn) ReLU2 (class in mlx.nn) relu2 (class in mlx.nn) ReLU6 (class in mlx.nn) relu6 (class in mlx.nn) remainder (in module mlx.core) repeat (in module mlx.core) reset_peak_memory (in module mlx.core) reshape (in module mlx.core) (in module mlx.core.array) rfft (in module mlx.core.fft) rfft2 (in module mlx.core.fft) rfftn (in module mlx.core.fft) right_shift (in module mlx.core) rms_norm (in module mlx.core.fast) RMSNorm (class in mlx.nn) RMSprop (class in mlx.optimizers) RNN (class in mlx.nn) roll (in module mlx.core) RoPE (class in mlx.nn) rope (in module mlx.core.fast) round (in module mlx.core) (in module mlx.core.array) rsqrt (in module mlx.core) (in module mlx.core.array) S save (in module mlx.core) save_gguf (in module mlx.core) save_safetensors (in module mlx.core) save_weights (in module mlx.nn) savez (in module mlx.core) savez_compressed (in module mlx.core) scaled_dot_product_attention (in module mlx.core.fast) SELU (class in mlx.nn) selu (class in mlx.nn) send (in module mlx.core.distributed) Sequential (class in mlx.nn) set_cache_limit (in module mlx.core) set_default_device (in module mlx.core) set_default_stream (in module mlx.core) set_dtype (in module mlx.nn) set_memory_limit (in module mlx.core) set_wired_limit (in module mlx.core) SGD (class in mlx.optimizers) shape (in module mlx.core.array) shard_inplace (in module mlx.nn.layers.distributed) shard_linear (in module mlx.nn.layers.distributed) ShardedToAllLinear (class in mlx.nn) Sigmoid (class in mlx.nn) sigmoid (class in mlx.nn) (in module mlx.core) sign (in module mlx.core) SiLU (class in mlx.nn) silu (class in mlx.nn) sin (in module mlx.core) (in module mlx.core.array) sinh (in module mlx.core) SinusoidalPositionalEncoding (class in mlx.nn) size (in module mlx.core.array) slice (in module mlx.core) slice_update (in module mlx.core) smooth_l1_loss (class in mlx.nn.losses) Softmax (class in mlx.nn) softmax (class in mlx.nn) (in module mlx.core) Softmin (class in mlx.nn) softmin (class in mlx.nn) Softplus (class in mlx.nn) softplus (class in mlx.nn) Softshrink (class in mlx.nn) softshrink (class in mlx.nn) Softsign (class in mlx.nn) solve (in module mlx.core.linalg) solve_triangular (in module mlx.core.linalg) sort (in module mlx.core) split (in module mlx.core) (in module mlx.core.array) sqrt (in module mlx.core) (in module mlx.core.array) square (in module mlx.core) (in module mlx.core.array) squeeze (in module mlx.core) (in module mlx.core.array) stack (in module mlx.core) start_capture (in module mlx.core.metal) state (in module mlx.nn) (in module mlx.optimizers) std (in module mlx.core) (in module mlx.core.array) Step (class in mlx.nn) step (class in mlx.nn) step_decay (in module mlx.optimizers) stop_capture (in module mlx.core.metal) stop_gradient (in module mlx.core) Stream (in module mlx.core) stream (in module mlx.core) subtract (in module mlx.core) sum (in module mlx.core) (in module mlx.core.array) svd (in module mlx.core.linalg) swapaxes (in module mlx.core) (in module mlx.core.array) synchronize (in module mlx.core) T T (in module mlx.core.array) take (in module mlx.core) take_along_axis (in module mlx.core) tan (in module mlx.core) Tanh (class in mlx.nn) tanh (class in mlx.nn) (in module mlx.core) tensordot (in module mlx.core) tile (in module mlx.core) tolist (in module mlx.core.array) topk (in module mlx.core) trace (in module mlx.core) train (in module mlx.nn) trainable_parameters (in module mlx.nn) training (in module mlx.nn) Transformer (class in mlx.nn) transpose (in module mlx.core) (in module mlx.core.array) tree_flatten (in module mlx.utils) tree_map (in module mlx.utils) tree_map_with_path (in module mlx.utils) tree_reduce (in module mlx.utils) tree_unflatten (in module mlx.utils) tri (in module mlx.core) tri_inv (in module mlx.core.linalg) tril (in module mlx.core) triplet_loss (class in mlx.nn.losses) triu (in module mlx.core) U unflatten (in module mlx.core) unfreeze (in module mlx.nn) uniform (in module mlx.nn.init) update (in module mlx.nn) (in module mlx.optimizers) update_modules (in module mlx.nn) Upsample (class in mlx.nn) V value_and_grad (in module mlx.core) (in module mlx.nn) var (in module mlx.core) (in module mlx.core.array) view (in module mlx.core) (in module mlx.core.array) vjp (in module mlx.core) vmap (in module mlx.core) W where (in module mlx.core) Z zeros (in module mlx.core) zeros_like (in module mlx.core)