mlx-lm Model Bringup Process
This is a note — quick thoughts, possibly AI-assisted. Not a fully fleshed article.
Tags: mlx, apple-silicon, llm, inference, model-support
How new model architectures get added to mlx-lm.
Model Loading Flow
- Download model from HuggingFace (weights + `config.json`)
- Read `model_type` from `config.json` (e.g., `"llama"`, `"qwen3_5"`, `"gemma4"`)
- `importlib.import_module(f"mlx_lm.models.{model_type}")` to find the architecture
- Module must export `Model` and `ModelArgs` classes
- No matching module -> `ValueError: Model type X not supported`; a `MODEL_REMAPPING` dict handles aliases
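The resolution flow above can be sketched in a few lines. This is a simplified illustration, not mlx-lm's exact code, and the `MODEL_REMAPPING` entry shown is a made-up example:

```python
import importlib

# Illustrative alias entry; the real MODEL_REMAPPING table in mlx-lm is larger
MODEL_REMAPPING = {"mistral": "llama"}

def resolve_module_name(model_type: str) -> str:
    """Map a config.json model_type to an mlx_lm.models module path."""
    return f"mlx_lm.models.{MODEL_REMAPPING.get(model_type, model_type)}"

def load_model_classes(config: dict):
    """Import the architecture module and return its Model/ModelArgs exports."""
    try:
        module = importlib.import_module(resolve_module_name(config["model_type"]))
    except ImportError:
        raise ValueError(f"Model type {config['model_type']} not supported")
    return module.Model, module.ModelArgs
```

The alias table is what lets `mistral` checkpoints load through the `llama` implementation without a duplicate module.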
Required Exports
ModelArgs (dataclass)
- Subclass of `BaseModelArgs` (provides `from_dict` for parsing `config.json`)
- All architecture hyperparameters: hidden size, layers, heads, vocab size, RoPE config
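A minimal sketch of the pattern, with a stand-in `from_dict` (the real `BaseModelArgs` lives in `base.py`; the field names below are typical HF config keys chosen for illustration, not guaranteed for any particular model):

```python
from dataclasses import dataclass, fields

@dataclass
class BaseModelArgs:
    """Stand-in for mlx-lm's base class: build args from config.json, ignoring unknown keys."""
    @classmethod
    def from_dict(cls, params: dict):
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in params.items() if k in known})

@dataclass
class ModelArgs(BaseModelArgs):
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    vocab_size: int
    rope_theta: float = 10000.0  # RoPE base frequency, with a common default
```

Filtering to known fields is what makes loading robust: HF configs carry extra keys the architecture doesn't care about.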
Model (nn.Module)
- `__call__(self, inputs, cache=None, input_embeddings=None) -> logits`
- `sanitize(self, weights)` — clean up weight names, drop unused keys
- `make_cache()` — return the correct KV cache type per layer
- Optional: `shard()` for multi-device inference
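What `sanitize` does can be shown with plain dicts. The specific rules below are hypothetical (each architecture needs its own), but the shape of the function is representative:

```python
def sanitize(weights: dict) -> dict:
    """Illustrative weight cleanup: drop unused buffers, rename HF keys to match modules."""
    out = {}
    for name, value in weights.items():
        # Drop precomputed rotary buffers; MLX recomputes these itself
        if name.endswith("rotary_emb.inv_freq"):
            continue
        # Hypothetical rename: HF "transformer." prefix -> this model's "model." prefix
        out[name.replace("transformer.", "model.", 1)] = value
    return out
```

A key to rename or drop that slips through doesn't necessarily crash the load, which is exactly why wrong mappings show up as silent correctness bugs rather than errors.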
Internal pattern:
`Embedding -> [TransformerBlock x N] -> RMSNorm -> LM Head`

Each block: `Input -> LayerNorm -> Attention -> Residual -> LayerNorm -> MLP -> Residual`
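The per-block wiring above, sketched with NumPy in place of MLX (same arithmetic, different array library); `attn` and `mlp` are passed in as plain callables, and the norm omits its learned scale for brevity:

```python
import numpy as np

def rms_norm(x, eps=1e-5):
    # RMSNorm without the learned weight, for brevity
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def transformer_block(x, attn, mlp):
    # Pre-norm residual wiring: x + Attn(Norm(x)), then x + MLP(Norm(x))
    x = x + attn(rms_norm(x))
    x = x + mlp(rms_norm(x))
    return x
```

The residual additions are what let `x` flow through untouched when the sublayers contribute nothing, which also makes the structure easy to sanity-check.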
Complexity Range
| Architecture | Lines | Why |
|---|---|---|
| Llama | ~274 | Standard dense transformer, baseline |
| Qwen3.5 | ~524 | Hybrid attention, MoE routing, vision, gated delta updates |
| DeepSeek V3 | ~600+ | MoE with shared experts, multi-latent attention |
Llama-like architectures (Mistral, Yi) can reuse components or be thin wrappers. Novel architectures need full forward pass from scratch.
What Makes Bringup Non-Trivial
- Weight mapping — HF weight names don't always match the MLX module structure. `sanitize()` handles renames, drops, and reshapes. Wrong mapping = silent correctness bugs.
- Attention variants — GQA, MQA, sliding window, linear, and sparse attention all need different implementations. `mx.fast.scaled_dot_product_attention` covers standard SDPA only.
- RoPE variants — standard, NTK-aware, YaRN, dynamic. `rope_utils.py` handles the common ones.
- KV cache types — standard `KVCache` vs `RotatingKVCache` (sliding window) vs `ArraysCache` (SSM). Hybrid models use different types per layer.
- Quantization — must work with MLX's quantization. Quantized SDPA has its own codepath requiring specific tensor layouts.
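As one concrete taste of the attention-variant work: GQA shares each KV head across several query heads, so a manual attention path has to expand the KV heads to match. A NumPy sketch (MQA is just the `n_kv_heads == 1` case):

```python
import numpy as np

def repeat_kv(kv, n_rep):
    """Expand (batch, n_kv_heads, seq, head_dim) to (batch, n_kv_heads * n_rep, seq, head_dim).

    Each KV head is shared by n_rep consecutive query heads (GQA).
    """
    if n_rep == 1:
        return kv  # already MHA-shaped, nothing to do
    return np.repeat(kv, n_rep, axis=1)
```

`np.repeat` along the head axis keeps the grouping contiguous, so query head `i` attends against KV head `i // n_rep`.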
Shared Infrastructure
- `base.py` — `BaseModelArgs`, causal mask, SDPA (standard + quantized)
- `cache.py` — `KVCache`, `RotatingKVCache`, `ArraysCache`
- `rope_utils.py` — RoPE initialization for common scaling schemes
- `activations.py` — SwiGLU etc.
- Models can import from each other (e.g., Qwen3.5 imports from `qwen3_next`)
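The causal mask that `base.py` provides is conceptually simple; here is a NumPy sketch (the real helper builds MLX arrays, but the offset idea, accounting for tokens already in the KV cache, is the same):

```python
import numpy as np

def causal_mask(seq_len, offset=0):
    """Additive mask: 0 where key position j is visible from query position i, -inf otherwise.

    `offset` shifts query positions forward by the number of cached tokens,
    so incremental decoding can attend to the whole cache.
    """
    q = np.arange(seq_len)[:, None] + offset   # absolute query positions
    k = np.arange(seq_len + offset)[None, :]   # absolute key positions
    return np.where(k <= q, 0.0, -np.inf)
```

Adding this mask to the attention scores before softmax zeroes out the probability of attending to future tokens.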
Example Bringup PRs
Straightforward (follows existing pattern):
Non-trivial (new concepts):
- #940 — Mamba — SSM, custom state management
- #1336 — Gemma3 — sliding window + global attention hybrid
- #1191 — DeepSeek V3 — MoE, multi-latent attention, pipeline parallelism
Follow-up fixes (bringup isn't done at merge):
Bottom Line
- Known architecture (Llama/Mistral/Qwen-family) -> likely already supported or trivial to add
- New mechanism (novel attention, novel MoE, hybrid SSM) -> 300-600 lines of new MLX code + weight mapping
- ~117 architectures currently supported — check `mlx_lm/models/` before assuming unsupported
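Since module filenames double as `model_type` names, the check can be scripted. A sketch under the assumption that helper modules sit alongside the architectures (the `NON_MODELS` set below is illustrative and may be incomplete):

```python
from pathlib import Path

# Helper modules in mlx_lm/models/ that are not architectures (illustrative list)
NON_MODELS = {"base", "cache", "rope_utils", "activations"}

def supported_model_types(models_dir) -> list:
    """List candidate model_type names from the models package directory."""
    return sorted(
        p.stem
        for p in Path(models_dir).glob("*.py")
        if not p.stem.startswith("_") and p.stem not in NON_MODELS
    )
```

Comparing a checkpoint's `model_type` from `config.json` against this list answers "is it supported?" without attempting a full load.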