mlx.MlxEngine

Runs batched MLX inference on a dedicated worker thread.

Usage

mlx.MlxEngine()

The constructor starts a daemon worker thread that loads the model from fused_dir, applies MLXPatches, and precomputes a prompt cache for prefix_messages. Await ensure_loaded to surface load failures before generating.
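The load-then-surface behavior can be pictured with a small stand-alone sketch. It assumes the worker records the outcome of loading in a `concurrent.futures.Future` that `ensure_loaded` awaits. `LoadingWorker` and `load_fn` are hypothetical names, not this class's internals, and nothing here touches MLX.

```python
import asyncio
import concurrent.futures
import threading
from typing import Any, Callable


class LoadingWorker:
    """Hypothetical stand-in: loads a resource on a daemon thread."""

    def __init__(self, load_fn: Callable[[], Any], name: str = "mlx") -> None:
        # Records either successful completion or the load error.
        self._loaded: concurrent.futures.Future = concurrent.futures.Future()
        self.model: Any = None
        self._thread = threading.Thread(
            target=self._run, args=(load_fn,), name=name, daemon=True
        )
        self._thread.start()

    def _run(self, load_fn: Callable[[], Any]) -> None:
        try:
            self.model = load_fn()
        except Exception as exc:
            # Keep the error so the awaiting coroutine can re-raise it.
            self._loaded.set_exception(exc)
            return
        self._loaded.set_result(None)

    async def ensure_loaded(self) -> None:
        # Bridge the thread-side future into the event loop; re-raises the
        # load error if loading failed.
        await asyncio.wrap_future(self._loaded)
```

Because the thread is a daemon, a hung load does not keep the interpreter alive at exit; awaiting the future is what turns a silent background failure into an exception at the call site.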

Parameters

fused_dir: Path

Directory of the MLX model to load, such as the path returned by AdapterFuser.ensure_fused.

logits_processor_factory: Callable[[Any], Callable[..., Any]]

Builds the logits processor from the loaded tokenizer; the processor is applied to every generation call.

prefix_messages: list[dict[str, str]]

Chat messages shared by every conversation; their prompt cache is computed once and reused across batches. Each conversation passed to generate must start with these messages.

batch_size: int

Maximum number of conversations generated together in one batch; the final batch may be smaller.

worker_name: str = "mlx"

Name of the worker thread.
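A factory for logits_processor_factory might look like the sketch below, which restricts a single-token completion to a fixed set of answer tokens. It assumes the processor follows the `(tokens, logits) -> logits` convention of mlx-lm's `logits_processors`; the type above only says `Callable[..., Any]`, so treat that signature as an assumption. `ToyTokenizer` and `make_allowed_tokens_factory` are hypothetical, and plain Python lists stand in for MLX arrays.

```python
import math
from typing import Callable, Dict, List, Sequence

Processor = Callable[[List[int], List[float]], List[float]]


class ToyTokenizer:
    """Hypothetical tokenizer exposing a token-to-id vocabulary."""

    def __init__(self, vocab: Dict[str, int]) -> None:
        self.vocab = vocab


def make_allowed_tokens_factory(
    allowed: Sequence[str],
) -> Callable[[ToyTokenizer], Processor]:
    def factory(tokenizer: ToyTokenizer) -> Processor:
        # Resolve token ids once, when the tokenizer is available.
        allowed_ids = {tokenizer.vocab[t] for t in allowed}

        def processor(tokens: List[int], logits: List[float]) -> List[float]:
            # Leave allowed ids untouched; push every other id to -inf so
            # greedy decoding can only emit an allowed token.
            return [x if i in allowed_ids else -math.inf for i, x in enumerate(logits)]

        return processor

    return factory
```

Passing a factory rather than a ready-made processor lets token ids be resolved against the tokenizer that the worker thread actually loads.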

Raises

RuntimeError

When not running on macOS with Apple Silicon.
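A check of this kind can be written with the standard `platform` module. This is a sketch assuming the engine tests `platform.system()` and `platform.machine()`; the helper names and the error message are hypothetical. Note that an x86_64 Python running under Rosetta reports `x86_64` and would fail the check.

```python
import platform
from typing import Optional


def is_apple_silicon_mac(system: Optional[str] = None, machine: Optional[str] = None) -> bool:
    """Hypothetical check: True for macOS ("Darwin") on arm64 hardware."""
    # Arguments override the detected values, which keeps the check testable.
    system = platform.system() if system is None else system
    machine = platform.machine() if machine is None else machine
    return system == "Darwin" and machine == "arm64"


def require_apple_silicon() -> None:
    # Mirrors the documented failure mode: refuse to start elsewhere.
    if not is_apple_silicon_mac():
        raise RuntimeError("MLX requires macOS on Apple Silicon")
```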

Methods

Name              Description
close()           Signal the worker thread to exit after completing already-queued jobs.
ensure_loaded()   Wait for the worker thread to finish loading, re-raising the error if loading failed.
generate()        Generate one single-token completion per conversation.
peak_memory_gb()  Return the process’s peak resident set size in GiB.
submit()          Run fn(*args) on the worker thread and await its result.

close()

Signal the worker thread to exit after completing already-queued jobs.

Usage

close()
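"Exit after completing already-queued jobs" is the behavior of a FIFO job queue drained until a sentinel arrives. The sketch below assumes that design; `SentinelWorker` and `_STOP` are hypothetical names, not this class's internals.

```python
import queue
import threading
from typing import Any, Callable

_STOP = object()  # sentinel telling the worker loop to exit


class SentinelWorker:
    """Hypothetical worker that drains its job queue in FIFO order."""

    def __init__(self) -> None:
        self._jobs: "queue.Queue[Any]" = queue.Queue()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def _loop(self) -> None:
        while True:
            job = self._jobs.get()
            if job is _STOP:
                return  # every job queued before close() has already run
            fn, args = job
            fn(*args)

    def put(self, fn: Callable[..., Any], *args: Any) -> None:
        self._jobs.put((fn, args))

    def close(self) -> None:
        # Queue the sentinel behind existing jobs instead of stopping at once.
        self._jobs.put(_STOP)
```

close() returns immediately; joining the thread afterwards is how a caller would wait for the queued work to finish.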

ensure_loaded()

Wait for the worker thread to finish loading, re-raising the error if loading failed.

Usage

ensure_loaded()

generate()

Generate one single-token completion per conversation.

Usage

generate(message_lists, on_progress)

Conversations are processed in chunks of batch_size, ordered by the length of each conversation’s final message so similar-length prompts batch together. Every conversation must start with the engine’s prefix_messages; the shared prefix is served from the precomputed prompt cache.
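The sort, chunk, and reorder steps can be sketched without MLX. Here the final message's length is measured in characters and sorted ascending; both are assumptions, since the engine may count tokens instead. `generate_in_chunks` and `run_batch` are hypothetical names, with `run_batch` standing in for the actual batched generation.

```python
from typing import Callable, Dict, List, Optional

Conversation = List[Dict[str, str]]


def generate_in_chunks(
    message_lists: List[Conversation],
    batch_size: int,
    run_batch: Callable[[List[Conversation]], List[str]],
    on_progress: Callable[[int], None],
) -> List[str]:
    # Group similar-length prompts by sorting on the final message's length.
    order = sorted(
        range(len(message_lists)),
        key=lambda i: len(message_lists[i][-1]["content"]),
    )
    results: List[Optional[str]] = [None] * len(message_lists)
    for start in range(0, len(order), batch_size):
        chunk = order[start : start + batch_size]
        texts = run_batch([message_lists[i] for i in chunk])
        for i, text in zip(chunk, texts):
            results[i] = text  # write back at the caller's original index
        on_progress(len(chunk))  # report conversations finished in this chunk
    return results  # type: ignore[return-value]
```

Sorting by length reduces padding inside each batch, while writing results back by original index keeps the returned list aligned with message_lists.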

Parameters

message_lists: list[list[dict[str, str]]]

Conversations, each a list of chat messages with role and content keys.

on_progress: Callable[[int], None]

Called after each chunk with the number of conversations completed in that chunk.

Returns

list[str]

Generated texts, in the same order as message_lists.
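Because every conversation must begin with prefix_messages, callers typically build each one by copying the shared prefix and appending the per-item message. The helpers below are a hypothetical sketch of that, plus a pre-flight check; whether the engine itself validates the prefix is not documented here.

```python
from typing import Dict, List

Message = Dict[str, str]


def build_conversation(prefix: List[Message], user_text: str) -> List[Message]:
    # Copy the shared prefix so callers cannot mutate it, then add the item.
    return [dict(m) for m in prefix] + [{"role": "user", "content": user_text}]


def check_prefix(conversations: List[List[Message]], prefix: List[Message]) -> None:
    # Compare the leading messages exactly (role and content).
    for n, conv in enumerate(conversations):
        if conv[: len(prefix)] != prefix:
            raise ValueError(f"conversation {n} does not start with prefix_messages")
```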

peak_memory_gb()

Return the process’s peak resident set size in GiB.

Usage

peak_memory_gb()
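A peak resident set size figure is usually read from `resource.getrusage`. This sketch assumes that source; `peak_rss_gib` is a hypothetical name. `ru_maxrss` is reported in bytes on macOS but in kilobytes on Linux, so the branch matters if the snippet is reused off macOS. The `resource` module is Unix-only.

```python
import resource
import sys


def peak_rss_gib() -> float:
    """Hypothetical helper: peak RSS of the current process in GiB."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # macOS reports bytes; Linux reports kilobytes.
    peak_bytes = peak if sys.platform == "darwin" else peak * 1024
    return peak_bytes / 2**30
```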

submit()

Run fn(*args) on the worker thread and await its result.

Usage

submit(fn, *args)

Parameters

fn: Callable[..., R]

Callable to execute on the worker thread.

*args: Any

Positional arguments passed to fn.

Returns

R

The value returned by fn; an exception raised by fn propagates to the awaiter.
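Awaiting a result computed on another thread is commonly done by pairing each queued job with a `concurrent.futures.Future` and awaiting it through `asyncio.wrap_future`. The sketch assumes that design; `SubmitWorker` is a hypothetical name. Running every job on one thread is what keeps thread-affine state, such as a loaded model, on the thread that created it.

```python
import asyncio
import concurrent.futures
import queue
import threading
from typing import Any, Callable, TypeVar

R = TypeVar("R")


class SubmitWorker:
    """Hypothetical single-thread worker; every job runs on the same thread."""

    def __init__(self) -> None:
        self._jobs: "queue.Queue[Any]" = queue.Queue()
        self.thread = threading.Thread(target=self._loop, daemon=True)
        self.thread.start()

    def _loop(self) -> None:
        while True:
            fn, args, fut = self._jobs.get()
            if not fut.set_running_or_notify_cancel():
                continue  # the awaiter cancelled before the job started
            try:
                fut.set_result(fn(*args))
            except Exception as exc:
                fut.set_exception(exc)  # re-raised in the awaiting coroutine

    async def submit(self, fn: Callable[..., R], *args: Any) -> R:
        fut: concurrent.futures.Future = concurrent.futures.Future()
        self._jobs.put((fn, args, fut))
        return await asyncio.wrap_future(fut)
```

For example, `await worker.submit(pow, 2, 10)` returns `1024`, and `await worker.submit(int, "x")` raises the `ValueError` that `int` raised on the worker thread.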