Python API¶
Auto-generated from docstrings via mkdocstrings[python]. For the
hand-curated quick reference, see the README on GitHub.
Inference¶
demucs_onnx.separate ¶
separate(
input: PathLike,
output_dir: PathLike | None = None,
*,
model: str = DEFAULT_BAG_MODEL,
stems: Iterable[str] | None = None,
providers: str | Sequence[str] | None = "auto",
precision: Precision = "fp32",
cache_dir: PathLike | None = None,
token: str | None = None,
verbose: bool = False,
progress: bool = True,
output_format: Literal["wav", "mp3"] = "wav",
bitrate_kbps: int = 192,
mix_stems: Sequence[str] | None = None,
mix_output_name: str = "mix"
) -> dict[str, ndarray]
Run htdemucs ONNX separation on an audio file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | PathLike | Path to an audio file. Any rate | required |
| output_dir | PathLike \| None | If given, write each stem under here. If | None |
| model | str | Which model to run. Supported (v0.3.0): | DEFAULT_BAG_MODEL |
| stems | Iterable[str] \| None | Subset of stems to materialize. For | None |
| providers | str \| Sequence[str] \| None | ONNX Runtime execution providers. | 'auto' |
| precision | Precision | | 'fp32' |
| cache_dir | PathLike \| None | Override the huggingface_hub model cache location. | None |
| token | str \| None | Hugging Face access token (only needed if you've made the model repos private). | None |
| verbose | bool | Print chunk-by-chunk progress to stdout. | False |
| progress | bool | Render a | True |
| output_format | Literal['wav', 'mp3'] | | 'wav' |
| bitrate_kbps | int | MP3 bitrate when | 192 |
| mix_stems | Sequence[str] \| None | Optional list of stem names to additionally sum into a single output file (alongside the individual stem files). For | None |
| mix_output_name | str | Filename stem for the mixed output (default | 'mix' |
Returns:
| Type | Description |
|---|---|
| dict[str, ndarray] | |
| dict[str, ndarray] | float32, at the input file's native sample rate (auto-resampled). |
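A minimal usage sketch of the signature above. The wrapper name `extract_vocals` and the `"stems"` output directory are hypothetical, and it assumes the returned dict is keyed by stem name, which the truncated description suggests but does not state.

```python
from pathlib import Path


def extract_vocals(track: str, out_dir: str = "stems"):
    """Separate one file and return the vocals array (needs demucs-onnx installed)."""
    # Imported lazily so this snippet still loads where the package is absent.
    from demucs_onnx import separate

    result = separate(
        track,
        Path(out_dir),          # stem files are written under this directory
        model="htdemucs_ft",    # alias named elsewhere on this page
        stems=["vocals"],       # only materialize the stem we need
        providers="auto",
        progress=False,
    )
    # Assumption: keys of the returned dict are stem names.
    return result["vocals"]
```

Calling `extract_vocals("song.wav")` would trigger a model download on first use, so it is best tried on a machine with network access.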
Source code in src/demucs_onnx/inference.py
demucs_onnx.separate_all ¶
separate_all(
input: PathLike,
output_dir: PathLike | None = None,
**kwargs: object
) -> dict[str, ndarray]
Shorthand for separate() with the full htdemucs_ft bag.
Equivalent to separate(input, output_dir, model="htdemucs_ft", ...).
For the faster single-file htdemucs flavor, pass model="htdemucs"
directly to separate().
Source code in src/demucs_onnx/inference.py
demucs_onnx.separate_stem ¶
separate_stem(
input: PathLike,
stem: str,
output_dir: PathLike | None = None,
**kwargs: object
) -> ndarray
Run one specialist and return only that stem as a numpy array.
For stem ∈ {drums, bass, other, vocals} this picks the
htdemucs_ft specialist (4x faster than the bag). For stem ∈
{guitar, piano} this transparently switches to htdemucs_6s and
returns the requested row.
Source code in src/demucs_onnx/inference.py
demucs_onnx.prewarm ¶
prewarm(
models: Iterable[str] | None = None,
*,
precision: Precision = "fp32",
providers: str | Sequence[str] | None = "auto",
cache_dir: PathLike | None = None,
token: str | None = None
) -> dict[str, Path]
Pre-download and compile sessions for one or more models.
Useful at app startup or in a server context where the first
separate() call would otherwise trigger a multi-minute CoreML
graph compile or a 300 MB download. models defaults to
["htdemucs_ft"] (the full 4-stem specialist bag).
Returns {model_or_stem: local_path}.
Source code in src/demucs_onnx/inference.py
demucs_onnx.session_pool ¶
session_pool() -> SessionPool
demucs_onnx.SessionPool ¶
Cache ort.InferenceSession objects keyed by (repo, precision, providers).
Inference sessions are expensive to create, particularly the first
time on the CoreML EP, which compiles the graph (~5+ min for
htdemucs on M-series Macs). The pool keeps sessions alive across
successive separate() calls so subsequent runs reuse the same
compiled graph.
The pool is process-local and thread-safe under CPython's GIL for the
common case (get-or-create). Eviction is manual via clear().
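The get-or-create behaviour described above can be sketched as follows. `TinySessionPool` and its `factory` argument are hypothetical stand-ins, not the library's SessionPool, and the sketch uses an explicit lock rather than relying on the GIL.

```python
import threading
from typing import Any, Callable, Sequence


class TinySessionPool:
    """Hypothetical get-or-create cache keyed by (repo, precision, providers)."""

    def __init__(self, factory: Callable[[str, str, tuple], Any]) -> None:
        self._factory = factory      # builds the expensive session object
        self._sessions: dict = {}
        self._lock = threading.Lock()

    def get(self, repo: str, precision: str, providers: Sequence[str]) -> Any:
        # Lists are unhashable, so the provider list is frozen into a tuple.
        key = (repo, precision, tuple(providers))
        with self._lock:
            if key not in self._sessions:
                self._sessions[key] = self._factory(*key)
            return self._sessions[key]

    def clear(self) -> None:
        """Manual eviction: drop every cached session."""
        with self._lock:
            self._sessions.clear()
```

In real use the factory would wrap the creation of an ort.InferenceSession; here it is injected so the caching logic can be exercised on its own.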
Source code in src/demucs_onnx/inference.py
get ¶
Return a session for onnx_path, creating one if absent.
Source code in src/demucs_onnx/inference.py
demucs_onnx.list_models ¶
Public re-export so users don't need to dig into _hub.
Returns {alias: {"repo": url, "fp32": url, "fp16weights": url,
"sources": comma_joined, "kind": "specialist_bag"|"single"|"specialist"}}.
Source code in src/demucs_onnx/inference.py
Providers¶
demucs_onnx.auto_select_providers ¶
Return the best ORT provider list for this host.
Decision tree (first match wins):
- Browser (pyodide/js): WASM. Browser support is not wired through the inference path yet; this is a forward-looking hook so callers can branch off it. Real browser support lands in v0.3.
- macOS arm64 with CoreML EP available: CoreMLExecutionProvider. If CoreML EP is missing we warn once with the upgrade hint.
- Linux with CUDA EP available: CUDAExecutionProvider. If we're on Linux x86_64 but the GPU EP is missing we don't warn (it's normal to run inference on a CPU-only Linux box).
- Windows with DML EP available: DmlExecutionProvider.
- Fallback: CPUExecutionProvider.
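The decision tree above can be sketched as a pure function. `pick_providers` is a hypothetical name; the browser branch and the warnings are left out, and appending the CPU provider after an accelerated one is an assumption, not something this page confirms.

```python
from __future__ import annotations

import platform
from typing import Sequence

CPU = "CPUExecutionProvider"


def pick_providers(
    available: Sequence[str],
    system: str | None = None,
    machine: str | None = None,
) -> list[str]:
    """Hypothetical sketch of the documented order; first match wins."""
    system = system if system is not None else platform.system()
    machine = machine if machine is not None else platform.machine()
    if system == "Darwin" and machine == "arm64" and "CoreMLExecutionProvider" in available:
        return ["CoreMLExecutionProvider", CPU]  # CPU appended as a fallback (assumption)
    if system == "Linux" and "CUDAExecutionProvider" in available:
        return ["CUDAExecutionProvider", CPU]
    if system == "Windows" and "DmlExecutionProvider" in available:
        return ["DmlExecutionProvider", CPU]
    return [CPU]
```

With ONNX Runtime installed, `available` would typically come from `onnxruntime.get_available_providers()`.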
Source code in src/demucs_onnx/providers.py
demucs_onnx.describe_runtime ¶
Return a flat dict describing the runtime environment.
Useful for debugging — print it from your code if auto selects
something surprising.
Source code in src/demucs_onnx/providers.py
Audio I/O¶
demucs_onnx.load_audio ¶
Load audio as float32 stereo at target_sr, return (audio, native_sr).
Returns the audio at target_sr (the model's sample rate) for use
by the inference loop, and the file's native sample rate alongside
so the caller can resample the output back to the original rate
before writing.
- Mono inputs are duplicated to L/R.
- Inputs with more than 2 channels are downmixed to the first two channels.
- Sample rates < 8000 Hz log a quality warning but are still processed (the model output will sound rough).
- Sample rates > 44.1 kHz are silently down-sampled.
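The channel rules in the list above can be illustrated with plain lists. `to_stereo` is a hypothetical helper, and it assumes that inputs wider than stereo simply keep their first two channels; resampling is not shown.

```python
from typing import List


def to_stereo(channels: List[List[float]]) -> List[List[float]]:
    """Return exactly two channels from a (channels, samples) nested list."""
    if not channels:
        raise ValueError("need at least one channel")
    if len(channels) == 1:
        # Mono: duplicate the single channel to left and right.
        return [list(channels[0]), list(channels[0])]
    # Stereo or wider: keep the first two channels (assumption for wider inputs).
    return [list(channels[0]), list(channels[1])]
```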
Source code in src/demucs_onnx/_audio.py
demucs_onnx.write_audio ¶
Write audio, dispatching by path suffix. Defaults to WAV.
Source code in src/demucs_onnx/_audio.py
demucs_onnx.write_wav ¶
Write a (channels, samples) float32 array as a 16-bit PCM WAV.
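For readers unfamiliar with the conversion, here is a standard-library sketch of writing a (channels, samples) float array as 16-bit PCM. `write_pcm16_wav` is a hypothetical name, and clipping to [-1, 1] with a 32767 scale factor is an assumption about the scaling, not a statement of what write_wav does internally.

```python
import struct
import wave
from typing import Sequence


def write_pcm16_wav(dest, audio: Sequence[Sequence[float]], sample_rate: int) -> None:
    """Write a (channels, samples) float sequence as interleaved 16-bit PCM."""
    n_channels = len(audio)
    n_samples = len(audio[0])
    frames = bytearray()
    for i in range(n_samples):
        for ch in range(n_channels):
            # Clip to the valid range, then scale to signed 16-bit integers.
            clipped = max(-1.0, min(1.0, audio[ch][i]))
            frames += struct.pack("<h", int(round(clipped * 32767)))
    with wave.open(dest, "wb") as wav:
        wav.setnchannels(n_channels)
        wav.setsampwidth(2)  # 2 bytes per sample = 16 bits
        wav.setframerate(sample_rate)
        wav.writeframes(bytes(frames))
```

`dest` can be a filename or a writable, seekable binary file object such as `io.BytesIO`.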
Source code in src/demucs_onnx/_audio.py
demucs_onnx.write_mp3 ¶
Write a (channels, samples) float32 array as a CBR MP3.
Requires the lameenc extra (pip install demucs-onnx[mp3]).
Source code in src/demucs_onnx/_audio.py
Browser helpers¶
demucs_onnx.browser.wasm_config ¶
Return a copy-pasteable bundler config snippet for onnxruntime-web.
Supported bundlers: "vite" (default), "webpack",
"esbuild", "next", "rollup".
Source code in src/demucs_onnx/browser.py
demucs_onnx.browser.print_wasm_config ¶
demucs_onnx.browser.write_demo_dir ¶
Materialize the in-tree browser demo files under target.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| target | Path | Directory to write into. Created if missing. Must be empty or non-existent (this helper refuses to overwrite existing files). | required |
| model_url | str \| None | URL the demo should fetch the ONNX model from. Defaults to | None |
| react | bool | When True, emit a Vite + React + TS project. When False (default), emit the zero-build vanilla HTML/JS demo. | False |
Returns:
| Type | Description |
|---|---|
| list[Path] | The list of files written, in the order they were created. |
Source code in src/demucs_onnx/browser.py
demucs_onnx.browser.model_browser_url ¶
model_browser_url(
model: str = DEFAULT_DEMO_MODEL,
precision: Literal[
"fp32", "fp16weights"
] = DEFAULT_DEMO_PRECISION,
) -> str
Return the direct download URL for model at precision.
Useful when emitting JS that wants to call
ort.InferenceSession.create(url, ...) against the HF Hub.
Source code in src/demucs_onnx/browser.py
Export pipeline¶
Requires the [export] extra (pip install 'demucs-onnx[export]').
demucs_onnx.export.exporter.export_to_onnx ¶
export_to_onnx(
checkpoint: str | Path,
output: str | Path,
*,
stem: str | None = None,
stems: Iterable[str] | None = None,
opset: int = 17,
parity_check: bool = True,
parity_tolerance: float = DEFAULT_PARITY_TOLERANCE,
sample_rate: int = SAMPLE_RATE,
segment_seconds: float = SEGMENT_S,
verbose: bool = True
) -> dict[str, Path]
Export a demucs/htdemucs checkpoint to one or more ONNX files.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| checkpoint | str \| Path | Either a name accepted by Bag models ( | required |
| output | str \| Path | For a single model, a | required |
| stem | str \| None | Optional single stem name (drums / bass / other / vocals). Only valid when | None |
| stems | Iterable[str] \| None | Optional iterable of stem names to export from a bag. Default: all 4. | None |
| opset | int | ONNX opset to target. 17 is the lowest with native STFT support; we don't actually use ONNX's STFT op (we replace it with conv1d) but staying ≥17 is a good baseline for ORT EPs. | 17 |
| parity_check | bool | When True (default), run the patched and the original PyTorch models on the same dummy input and abort if their outputs differ by more than | True |
| parity_tolerance | float | Max allowed abs diff between the patched PyTorch model and the original. Default 1e-3 matches the tolerance the StemSplit team uses for the published HF models. | DEFAULT_PARITY_TOLERANCE |
| sample_rate | int | Sample rate the exported model expects. Changing this requires a matching change to the host inference code; demucs is trained at 44100 Hz. | SAMPLE_RATE |
| segment_seconds | float | Segment length baked into the exported graph. | SEGMENT_S |
| verbose | bool | Print progress to stdout. | True |
Returns:
| Type | Description |
|---|---|
| dict[str, Path] | |
| dict[str, Path] | |
Source code in src/demucs_onnx/export/exporter.py
demucs_onnx.export.exporter.verify_onnx_parity ¶
verify_onnx_parity(
checkpoint: str | Path,
onnx_path: str | Path,
*,
stem: str | None = None,
tolerance: float = DEFAULT_PARITY_TOLERANCE,
sample_rate: int = SAMPLE_RATE,
segment_seconds: float = SEGMENT_S,
verbose: bool = True
) -> float
Compare an exported ONNX model against the original PyTorch checkpoint.
Returns the max abs diff. Raises AssertionError if it exceeds
tolerance.
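The contract described here (return the max abs diff, raise AssertionError above the tolerance) can be sketched on plain sequences. `check_parity` is a hypothetical helper; the 1e-3 default mirrors the value quoted for parity_tolerance on this page, and the real function compares model outputs rather than lists.

```python
from typing import Sequence


def check_parity(
    reference: Sequence[float],
    candidate: Sequence[float],
    tolerance: float = 1e-3,
) -> float:
    """Return the max absolute difference; raise AssertionError if it is too large."""
    if len(reference) != len(candidate):
        raise ValueError("outputs must have the same length")
    diff = max(abs(a - b) for a, b in zip(reference, candidate))
    if diff > tolerance:
        # Raised explicitly so the check survives `python -O`.
        raise AssertionError(f"max abs diff {diff:.3g} exceeds tolerance {tolerance:.3g}")
    return diff
```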
Source code in src/demucs_onnx/export/exporter.py
demucs_onnx.export.patch.patch_htdemucs_for_onnx ¶
Mutate an htdemucs sub-model in place so it has no complex tensors and no Python-dynamic code paths.
Returns the same model so calls can be chained.
The four patches applied:
- model.segment: Fraction → float (demucs_onnx.export.coerce_segment_to_float)
- CrossTransformerEncoder._get_pos_embedding: drops random.randrange (demucs_onnx.export.disable_random_pos_shift)
- nn.MultiheadAttention.forward: replaced with a primitive-only impl (demucs_onnx.export.onnx_friendly_mha_forward)
- model._spec / _ispec / _magnitude / _mask: rewritten to thread a real (B, C, 2, F, T) tensor instead of complex tensors, backed by demucs_onnx.export.RealSTFT and demucs_onnx.export.RealISTFT.
Source code in src/demucs_onnx/export/patch.py
Individual patches¶
demucs_onnx.export.segment.coerce_segment_to_float ¶
Convert model.segment from a Fraction to a float, if needed.
Idempotent and safe to call on already-coerced models.
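A minimal sketch of the idempotent coercion, using a stand-in object instead of a real htdemucs model. `coerce_segment` is a hypothetical name, the segment value is arbitrary, and returning the model is an illustrative choice.

```python
from fractions import Fraction
from types import SimpleNamespace


def coerce_segment(model):
    """Turn model.segment into a float; a second call is a no-op."""
    if isinstance(model.segment, Fraction):
        model.segment = float(model.segment)
    return model


# Stand-in for a model whose segment length is stored as a Fraction.
model = SimpleNamespace(segment=Fraction(39, 5))
coerce_segment(model)
coerce_segment(model)  # safe to repeat on an already-coerced model
```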
Source code in src/demucs_onnx/export/segment.py
demucs_onnx.export.pos_embed.disable_random_pos_shift ¶
Replace CrossTransformerEncoder._get_pos_embedding with a
deterministic version that hardcodes shift = 0.
Also sets sin_random_shift = 0 on every module that has the attr,
for belt-and-braces when other code paths read it directly.
Source code in src/demucs_onnx/export/pos_embed.py
demucs_onnx.export.mha.onnx_friendly_mha_forward ¶
onnx_friendly_mha_forward(
self_: MultiheadAttention,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Tensor | None = None,
need_weights: bool = True,
attn_mask: Tensor | None = None,
average_attn_weights: bool = True,
is_causal: bool = False,
) -> tuple[Tensor, Tensor | None]
Drop-in replacement for nn.MultiheadAttention.forward that uses
only ops with stable ONNX symbolics.
Supports the call shapes htdemucs actually uses:
- batch_first ∈ {True, False}
- need_weights = False (no weight tensor returned)
- attn_mask = None
- key_padding_mask = None
- cross-attention (query != key == value, or all three differ)
The signature matches nn.MultiheadAttention.forward so we can
install it via types.MethodType without further wrapping.
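The types.MethodType installation mentioned above can be shown with a toy class. `Attention` and `replacement_forward` are hypothetical stand-ins; whether the library binds per instance or patches the class is not stated here, and this sketch binds per instance.

```python
import types


class Attention:
    """Toy stand-in for an attention module."""

    def forward(self, query, key, value):
        return "original"


def replacement_forward(self_, query, key, value):
    # Same positional signature as the method it replaces,
    # so no wrapper is needed when binding it.
    return "replacement"


layer = Attention()
other = Attention()
# Bind the plain function to one instance; `other` keeps the class method.
layer.forward = types.MethodType(replacement_forward, layer)
```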
Source code in src/demucs_onnx/export/mha.py
demucs_onnx.export.stft.RealSTFT ¶
Bases: Module
ONNX-exportable STFT that emits real/imag as two channel groups.
Matches the demucs spectro(x, n_fft, hop_length, pad=0) API but
returns a real tensor of shape (..., 2, F, T) instead of a complex
(..., F, T) tensor.
Layout matches what view_as_real would produce: channel 0 = real,
channel 1 = imag.
Source code in src/demucs_onnx/export/stft.py
demucs_onnx.export.stft.RealISTFT ¶
Bases: Module
ONNX-exportable inverse STFT, the inverse of RealSTFT.
Input shape: (..., 2, F, T). Output shape: (..., L).
Reconstruction formula (for sample n of each frame, before overlap-add), where R_k and I_k are the real and imaginary parts of bin k, N is the FFT size, and w is the window:
$$
x[n] = \frac{1}{\sqrt{N}}\, w[n] \Big( R_0 + (-1)^n R_{N/2} + 2 \sum_{k=1}^{N/2-1} \big( R_k \cos\frac{2\pi k n}{N} - I_k \sin\frac{2\pi k n}{N} \big) \Big)
$$
The factor of 2 doubles the contribution from bins 1..N/2-1 to
account for the dropped negative frequencies. DC and Nyquist appear
once, so we halve their entries. After OLA we divide by the
OLA(window^2) envelope to undo the analysis windowing.
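The one-sided reconstruction can be checked numerically on a single frame. This is a pure-Python sketch with w[n] = 1 and no overlap-add; the helper names are hypothetical, and it only confirms that doubling bins 1..N/2-1 while counting DC and Nyquist once recovers a real signal.

```python
import cmath
import math


def one_sided_dft(x):
    """Normalized DFT bins k = 0..N/2 of a real frame (1/sqrt(N) scaling)."""
    n_fft = len(x)
    scale = 1.0 / math.sqrt(n_fft)
    return [
        scale * sum(x[n] * cmath.exp(-2j * math.pi * k * n / n_fft) for n in range(n_fft))
        for k in range(n_fft // 2 + 1)
    ]


def reconstruct(bins, n_fft):
    """Rebuild the frame from real/imag parts, with the window fixed to 1."""
    half = n_fft // 2
    out = []
    for n in range(n_fft):
        # DC and Nyquist contribute once; their imaginary parts vanish for real input.
        acc = bins[0].real + (-1) ** n * bins[half].real
        for k in range(1, half):
            angle = 2 * math.pi * k * n / n_fft
            # Factor 2 stands in for the dropped negative-frequency bin.
            acc += 2 * (bins[k].real * math.cos(angle) - bins[k].imag * math.sin(angle))
        out.append(acc / math.sqrt(n_fft))
    return out


frame = [0.5, -1.0, 0.25, 0.75, 0.0, -0.5, 1.0, 0.125]
rebuilt = reconstruct(one_sided_dft(frame), len(frame))
max_error = max(abs(a - b) for a, b in zip(frame, rebuilt))
```

`max_error` stays at floating-point rounding level, which is what the formula predicts for an unwindowed frame.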
Source code in src/demucs_onnx/export/stft.py
demucs_onnx.export.stft.make_stft_kernels ¶
Build (cos, sin) DFT kernels of shape (n_fft//2 + 1, 1, n_fft).
The kernels reproduce torch.stft(window=hann, win_length=n_fft,
n_fft=n_fft, normalized=True, center=True) with hop = n_fft // 4.
Computation is done in float64 to keep the high-frequency bins precise, then cast to float32 for the actual conv weights.
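A pure-Python sketch of kernels with the stated shape, built in double precision. `make_kernels` and `analyze_frame` are hypothetical names, folding the minus sign into the sin kernel is an assumption, and centering, hopping, and the float32 cast are not shown.

```python
import math


def make_kernels(n_fft):
    """Return (cos, sin) kernels, each shaped (n_fft // 2 + 1, 1, n_fft)."""
    # Periodic Hann window, the default form produced by torch.hann_window.
    window = [0.5 - 0.5 * math.cos(2 * math.pi * n / n_fft) for n in range(n_fft)]
    scale = 1.0 / math.sqrt(n_fft)  # corresponds to normalized=True
    cos_k, sin_k = [], []
    for k in range(n_fft // 2 + 1):
        angles = [2 * math.pi * k * n / n_fft for n in range(n_fft)]
        cos_k.append([[scale * w * math.cos(a) for w, a in zip(window, angles)]])
        # Assumption: the minus sign of exp(-i * angle) lives in the sin kernel.
        sin_k.append([[-scale * w * math.sin(a) for w, a in zip(window, angles)]])
    return cos_k, sin_k


def analyze_frame(frame, cos_k, sin_k):
    """Dot one frame with every kernel row, as a conv1d does at a single position."""
    real = [sum(c * x for c, x in zip(row[0], frame)) for row in cos_k]
    imag = [sum(s * x for s, x in zip(row[0], frame)) for row in sin_k]
    return real, imag
```

Sliding `analyze_frame` along the signal in steps of n_fft // 4 would correspond to the strided convolution that the documented hop implies.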