Skip to content

Python API

Auto-generated from docstrings via mkdocstrings[python]. For the hand-curated quick reference, see the README on GitHub.


Inference

demucs_onnx.separate

separate(
    input: PathLike,
    output_dir: PathLike | None = None,
    *,
    model: str = DEFAULT_BAG_MODEL,
    stems: Iterable[str] | None = None,
    providers: str | Sequence[str] | None = "auto",
    precision: Precision = "fp32",
    cache_dir: PathLike | None = None,
    token: str | None = None,
    verbose: bool = False,
    progress: bool = True,
    output_format: Literal["wav", "mp3"] = "wav",
    bitrate_kbps: int = 192,
    mix_stems: Sequence[str] | None = None,
    mix_output_name: str = "mix"
) -> dict[str, ndarray]

Run htdemucs ONNX separation on an audio file.

Parameters:

Name Type Description Default
input PathLike

Path to an audio file. Any rate soundfile can decode (WAV/FLAC/OGG/MP3 via libsndfile 1.1+). Mono input is upmixed to stereo; non-44.1 kHz rates are transparently resampled in and back out.

required
output_dir PathLike | None

If given, write each stem under here. If None, results are returned in memory only.

None
model str

Which model to run. Supported (v0.3.0):

  • "htdemucs_ft" (default) — full 4-stem specialist bag.
  • "htdemucs_ft_<stem>" | "<stem>" — single FT specialist.
  • "htdemucs" — single-file 4-stem model. ~30% faster than the FT bag on the same hardware, slightly lower SDR.
  • "htdemucs_6s" — single-file 6-stem model. Same 4 stems plus "guitar" and "piano". The only ONNX export of the 6-stem variant on the Hub.
DEFAULT_BAG_MODEL
stems Iterable[str] | None

Subset of stems to materialize. For htdemucs_ft this also skips the specialists you don't need (saves time). For htdemucs and htdemucs_6s the model always computes all stems internally; this just filters the returned dict.

None
providers str | Sequence[str] | None

ONNX Runtime execution providers. "auto" (default) picks the best available EP for the host (CoreML on macOS arm64, CUDA on Linux+NVIDIA, DML on Windows, CPU otherwise). Pass a short alias ("cpu", "coreml", "cuda", "dml"), an explicit ORT provider name, or a list of either to override.

'auto'
precision Precision

"fp32" (default) or "fp16weights". The latter downloads a smaller variant (~1.9x smaller) but is otherwise identical at runtime — same RAM, same latency, max abs diff vs fp32 is ~6e-5.

'fp32'
cache_dir PathLike | None

Override the huggingface_hub model cache location.

None
token str | None

Hugging Face access token (only needed if you've made the model repos private).

None
verbose bool

Print chunk-by-chunk progress to stdout.

False
progress bool

Render a tqdm progress bar (only when verbose is False and stdout is a TTY).

True
output_format Literal['wav', 'mp3']

"wav" (default) or "mp3". MP3 requires pip install 'demucs-onnx[mp3]'.

'wav'
bitrate_kbps int

MP3 bitrate when output_format='mp3'. 32-320, default 192.

192
mix_stems Sequence[str] | None

Optional list of stem names to additionally sum into a single output file (alongside the individual stem files). For htdemucs_6s you can mix any of the 6 stems, e.g. ("drums","bass","guitar","piano") for a full backing track with guitar and piano kept in.

None
mix_output_name str

Filename stem for the mixed output (default "mix", becomes mix.wav / mix.mp3).

'mix'

Returns:

Type Description
dict[str, ndarray]

{stem_name: numpy.ndarray of shape (channels=2, samples)} in float32, at the input file's native sample rate (auto-resampled).

Source code in src/demucs_onnx/inference.py
def separate(input: PathLike,
             output_dir: PathLike | None = None,
             *,
             model: str = DEFAULT_BAG_MODEL,
             stems: Iterable[str] | None = None,
             providers: str | Sequence[str] | None = "auto",
             precision: Precision = "fp32",
             cache_dir: PathLike | None = None,
             token: str | None = None,
             verbose: bool = False,
             progress: bool = True,
             output_format: Literal["wav", "mp3"] = "wav",
             bitrate_kbps: int = 192,
             mix_stems: Sequence[str] | None = None,
             mix_output_name: str = "mix",
             ) -> dict[str, np.ndarray]:
    """Run htdemucs ONNX separation on an audio file.

    Args:
        input: Path to an audio file. Any rate ``soundfile`` can decode
            (WAV/FLAC/OGG/MP3 via libsndfile 1.1+). Mono input is
            upmixed to stereo; non-44.1 kHz rates are transparently
            resampled in *and* back out.
        output_dir: If given, write each stem under here. If ``None``,
            results are returned in memory only.
        model: Which model to run. Supported (v0.3.0):

            - ``"htdemucs_ft"`` (default) — full 4-stem specialist bag.
            - ``"htdemucs_ft_<stem>" | "<stem>"`` — single FT specialist.
            - ``"htdemucs"`` — single-file 4-stem model. ~30% faster
              than the FT bag on the same hardware, slightly lower SDR.
            - ``"htdemucs_6s"`` — single-file 6-stem model. Same 4 stems
              plus ``"guitar"`` and ``"piano"``. The only ONNX export of
              the 6-stem variant on the Hub.

        stems: Subset of stems to materialize. For ``htdemucs_ft`` this
            also skips the specialists you don't need (saves time). For
            ``htdemucs`` and ``htdemucs_6s`` the model always computes
            all stems internally; this just filters the returned dict.
            Stem names are validated before any model is downloaded.
        providers: ONNX Runtime execution providers. ``"auto"`` (default)
            picks the best available EP for the host (CoreML on macOS
            arm64, CUDA on Linux+NVIDIA, DML on Windows, CPU otherwise).
            Pass a short alias (``"cpu"``, ``"coreml"``, ``"cuda"``,
            ``"dml"``), an explicit ORT provider name, or a list of
            either to override.
        precision: ``"fp32"`` (default) or ``"fp16weights"``. The latter
            downloads a smaller variant (~1.9x smaller) but is otherwise
            identical at runtime — same RAM, same latency, max abs diff
            vs fp32 is ~6e-5.
        cache_dir: Override the huggingface_hub model cache location.
        token: Hugging Face access token (only needed if you've made the
            model repos private).
        verbose: Print chunk-by-chunk progress to stdout.
        progress: Render a ``tqdm`` progress bar (only when ``verbose``
            is False and stdout is a TTY).
        output_format: ``"wav"`` (default) or ``"mp3"``. MP3 requires
            ``pip install 'demucs-onnx[mp3]'``.
        bitrate_kbps: MP3 bitrate when ``output_format='mp3'``. 32-320,
            default 192.
        mix_stems: Optional list of stem names to additionally sum into
            a single output file (alongside the individual stem files).
            For ``htdemucs_6s`` you can mix any of the 6 stems, e.g.
            ``("drums","bass","guitar","piano")`` for a full backing
            track with guitar and piano kept in. Only used when
            ``output_dir`` is given (a warning is logged otherwise), and
            checked before any file is written.
        mix_output_name: Filename stem for the mixed output (default
            ``"mix"``, becomes ``mix.wav`` / ``mix.mp3``).

    Returns:
        ``{stem_name: numpy.ndarray of shape (channels=2, samples)}`` in
        float32, at the **input file's native sample rate** (auto-resampled).

    Raises:
        ValueError: If ``model`` is unknown, a requested stem is not
            predicted by the chosen model, or ``mix_stems`` names a stem
            that was not produced.
    """
    audio, native_sr = load_audio(input, target_sr=SAMPLE_RATE)
    onnx_providers = resolve_providers(providers)
    pool = session_pool()

    canonical = resolve_model_name(model)
    ft_prefix = "htdemucs_ft_"

    if canonical == "htdemucs_ft":
        wanted = list(stems) if stems is not None else list(BAG_STEMS)
        for s in wanted:
            if s not in BAG_STEMS:
                raise ValueError(
                    f"unknown stem {s!r}; htdemucs_ft predicts {BAG_STEMS}",
                )
        model_paths = {
            stem: download_stem_model(
                stem, cache_dir=cache_dir, token=token, precision=precision,
            )
            for stem in wanted
        }
        sessions = {
            stem: pool.get(path, onnx_providers)
            for stem, path in model_paths.items()
        }
        out_model_sr = _chunked_separate_specialists(
            sessions, audio, verbose=verbose, progress=progress,
        )

    elif canonical in MODEL_REGISTRY and MODEL_REGISTRY[canonical].kind == "single":
        info = MODEL_REGISTRY[canonical]
        # Validate the requested stems *before* downloading (as the
        # htdemucs_ft branch does) so a typo doesn't cost a large fetch.
        wanted = list(stems) if stems is not None else list(info.sources)
        for s in wanted:
            if s not in info.sources:
                raise ValueError(
                    f"unknown stem {s!r}; {canonical} predicts "
                    f"{tuple(info.sources)}",
                )
        path = download_single_model(
            canonical, cache_dir=cache_dir, token=token, precision=precision,
        )
        session = pool.get(path, onnx_providers)
        out_model_sr = _chunked_separate_single(
            session, info.sources, audio, wanted=wanted,
            verbose=verbose, progress=progress,
        )

    elif canonical.startswith(ft_prefix):
        # Strip only the leading tag; str.replace would also rewrite any
        # later occurrence of the substring.
        target_stem = canonical[len(ft_prefix):]
        if target_stem not in BAG_STEMS:
            raise ValueError(
                f"unknown model {model!r}. Known specialists: "
                f"htdemucs_ft_{{drums,bass,other,vocals}}.",
            )
        if stems is not None and list(stems) != [target_stem]:
            raise ValueError(
                f"model {model!r} only predicts {target_stem!r}; pass "
                f"`model='htdemucs_ft'` or `model='htdemucs'` to get other stems.",
            )
        path = download_stem_model(
            target_stem, cache_dir=cache_dir, token=token, precision=precision,
        )
        sessions = {target_stem: pool.get(path, onnx_providers)}
        out_model_sr = _chunked_separate_specialists(
            sessions, audio, verbose=verbose, progress=progress,
        )

    else:
        raise ValueError(
            f"unknown model {model!r}. Known: {sorted(MODEL_REGISTRY)} "
            f"or any htdemucs_ft specialist (drums/bass/other/vocals).",
        )

    if native_sr != SAMPLE_RATE:
        log.info("demucs-onnx: resampling output %d Hz -> %d Hz", SAMPLE_RATE, native_sr)
        out_native = {
            stem: resample_to_native(arr, SAMPLE_RATE, native_sr)
            for stem, arr in out_model_sr.items()
        }
    else:
        out_native = out_model_sr

    if output_dir is None:
        # mix_stems only controls an extra *file*; with nowhere to write it
        # the request would otherwise vanish without a trace.
        if mix_stems:
            log.warning(
                "demucs-onnx: mix_stems=%r ignored because output_dir is None",
                list(mix_stems),
            )
        return out_native

    # Check mix_stems up front so a bad name fails before any stem file is
    # written, rather than leaving a half-populated output directory.
    if mix_stems:
        missing = [s for s in mix_stems if s not in out_native]
        if missing:
            raise ValueError(
                f"mix_stems references stems we did not produce: {missing}. "
                f"Available: {sorted(out_native)}",
            )

    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    ext = "mp3" if output_format == "mp3" else "wav"
    for stem, audio_out in out_native.items():
        target = out_path / f"{stem}.{ext}"
        write_audio(target, audio_out, native_sr, bitrate_kbps=bitrate_kbps)
        if verbose:
            print(f"  wrote {target}")

    if mix_stems:
        mixed = np.sum(
            np.stack([out_native[s] for s in mix_stems], axis=0), axis=0,
        ).astype(np.float32)
        target = out_path / f"{mix_output_name}.{ext}"
        write_audio(target, mixed, native_sr, bitrate_kbps=bitrate_kbps)
        if verbose:
            print(f"  wrote {target} (sum of {list(mix_stems)})")

    return out_native

demucs_onnx.separate_all

separate_all(
    input: PathLike,
    output_dir: PathLike | None = None,
    **kwargs: object
) -> dict[str, ndarray]

Shorthand for separate() with the full htdemucs_ft bag.

Equivalent to separate(input, output_dir, model="htdemucs_ft", ...). For the faster single-file htdemucs flavor, pass model="htdemucs" directly to separate().

Source code in src/demucs_onnx/inference.py
def separate_all(input: PathLike,
                 output_dir: PathLike | None = None,
                 **kwargs: object) -> dict[str, np.ndarray]:
    """Separate all four stems with the ``htdemucs_ft`` specialist bag.

    A thin convenience wrapper around :func:`separate` with
    ``model="htdemucs_ft"``. All other keyword arguments are forwarded
    unchanged. Supplying ``model`` in ``kwargs`` is a ``TypeError``
    (duplicate keyword). For the faster single-file ``htdemucs`` flavor,
    call :func:`separate` with ``model="htdemucs"`` instead.
    """
    bag_model = "htdemucs_ft"
    return separate(input, output_dir, model=bag_model, **kwargs)  # type: ignore[arg-type]

demucs_onnx.separate_stem

separate_stem(
    input: PathLike,
    stem: str,
    output_dir: PathLike | None = None,
    **kwargs: object
) -> ndarray

Run one specialist and return only that stem as a numpy array.

For stem ∈ {drums, bass, other, vocals} this picks the htdemucs_ft specialist (4x faster than the bag). For stem ∈ {guitar, piano} this transparently switches to htdemucs_6s and returns the requested row.

Source code in src/demucs_onnx/inference.py
def separate_stem(input: PathLike,
                  stem: str,
                  output_dir: PathLike | None = None,
                  **kwargs: object) -> np.ndarray:
    """Separate a single stem and return just that array.

    Stems in ``BAG_STEMS`` (drums, bass, other, vocals) run only the
    matching ``htdemucs_ft_<stem>`` specialist (4x faster than the bag).
    ``"guitar"`` and ``"piano"`` are routed to ``htdemucs_6s``, with the
    result filtered to the requested stem. Any other name raises
    ``ValueError``.
    """
    six_stem_extras = ("guitar", "piano")
    is_specialist = stem in BAG_STEMS

    if not is_specialist and stem not in six_stem_extras:
        raise ValueError(
            f"stem must be one of {(*BAG_STEMS, *six_stem_extras)}, got {stem!r}",
        )

    if is_specialist:
        result = separate(input, output_dir, model=f"htdemucs_ft_{stem}", **kwargs)  # type: ignore[arg-type]
    else:
        result = separate(
            input, output_dir, model="htdemucs_6s", stems=[stem], **kwargs,  # type: ignore[arg-type]
        )
    return result[stem]

demucs_onnx.prewarm

prewarm(
    models: Iterable[str] | None = None,
    *,
    precision: Precision = "fp32",
    providers: str | Sequence[str] | None = "auto",
    cache_dir: PathLike | None = None,
    token: str | None = None
) -> dict[str, Path]

Pre-download + compile sessions for one or more models.

Useful at app startup or in a server context where the first separate() call would otherwise trigger a multi-minute CoreML graph-compile or a 300 MB download. models defaults to ["htdemucs_ft"] (the full 4-stem specialist bag).

Returns {model_or_stem: local_path}.

Source code in src/demucs_onnx/inference.py
def prewarm(models: Iterable[str] | None = None, *,
            precision: Precision = "fp32",
            providers: str | Sequence[str] | None = "auto",
            cache_dir: PathLike | None = None,
            token: str | None = None) -> dict[str, Path]:
    """Pre-download + compile sessions for one or more models.

    Useful at app startup or in a server context where the first
    :func:`separate` call would otherwise trigger a multi-minute CoreML
    graph-compile or a 300 MB download. ``models`` defaults to
    ``["htdemucs_ft"]`` (the full 4-stem specialist bag).

    Args:
        models: Model names/aliases to warm, in the same vocabulary as
            :func:`separate`'s ``model`` argument. A single name may be
            passed as a plain string.
        precision: ``"fp32"`` or ``"fp16weights"`` variant to download.
        providers: Execution providers, resolved exactly as in
            :func:`separate` (``"auto"`` by default).
        cache_dir: Override the huggingface_hub model cache location.
        token: Hugging Face access token, if the repos are private.

    Returns:
        ``{model_or_stem: local_path}``. Specialists (members of the
        ``htdemucs_ft`` bag or a single ``htdemucs_ft_<stem>``) are keyed
        by stem name; single-file models by their canonical name.

    Raises:
        ValueError: If a model name is not recognised.
    """
    ft_prefix = "htdemucs_ft_"

    if models is None:
        targets: list[str] = [DEFAULT_BAG_MODEL]
    elif isinstance(models, str):
        # A str is itself an Iterable[str]; without this guard
        # prewarm("htdemucs") would iterate over its characters.
        targets = [resolve_model_name(models)]
    else:
        targets = [resolve_model_name(m) for m in models]

    onnx_providers = resolve_providers(providers)
    pool = session_pool()
    paths: dict[str, Path] = {}
    for canonical in targets:
        if canonical == "htdemucs_ft":
            for stem in BAG_STEMS:
                p = download_stem_model(
                    stem, cache_dir=cache_dir, token=token, precision=precision,
                )
                pool.get(p, onnx_providers)
                paths[stem] = p
        elif canonical in MODEL_REGISTRY and MODEL_REGISTRY[canonical].kind == "single":
            p = download_single_model(
                canonical, cache_dir=cache_dir, token=token, precision=precision,
            )
            pool.get(p, onnx_providers)
            paths[canonical] = p
        elif canonical.startswith(ft_prefix):
            stem = canonical[len(ft_prefix):]
            # Same check separate() performs, so an unknown specialist
            # fails with a clear message instead of a failed download.
            if stem not in BAG_STEMS:
                raise ValueError(
                    f"unknown model {canonical!r} for prewarm. "
                    f"Known specialists: htdemucs_ft_{{drums,bass,other,vocals}}.",
                )
            p = download_stem_model(
                stem, cache_dir=cache_dir, token=token, precision=precision,
            )
            pool.get(p, onnx_providers)
            paths[stem] = p
        else:
            raise ValueError(
                f"unknown model {canonical!r} for prewarm. "
                f"Known: {sorted(MODEL_REGISTRY)}",
            )
    return paths

demucs_onnx.session_pool

session_pool() -> SessionPool

Return the process-wide default SessionPool.

Source code in src/demucs_onnx/inference.py
def session_pool() -> SessionPool:
    """Return the single module-level :class:`SessionPool`.

    :func:`separate` and :func:`prewarm` both obtain their pool here, so
    sessions warmed by one are reused by the other within this process.
    """
    shared = _DEFAULT_POOL
    return shared

demucs_onnx.SessionPool

SessionPool()

Cache ort.InferenceSession objects keyed by (ONNX file path, execution providers).

Inference sessions are expensive to create — particularly the first time on the CoreML EP, which compiles the graph (~5+ min for htdemucs on M-series macs). The pool keeps sessions alive across successive separate() calls, so subsequent runs reuse the same compiled graph.

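The pool is process-local. Under CPython's GIL, concurrent get() calls will not corrupt the cache. However, two threads racing on the same not-yet-cached key may each build a session, and whichever is stored last wins. Eviction is manual via clear().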

Source code in src/demucs_onnx/inference.py
def __init__(self) -> None:
    """Start with an empty cache; :meth:`get` fills it lazily."""
    # key: (onnx path as str, reserved slot, providers tuple) -> session
    empty: dict[tuple[str, str, tuple[str, ...]], ort.InferenceSession] = {}
    self._sessions = empty

get

get(
    onnx_path: PathLike, providers: Sequence[str]
) -> InferenceSession

Return a session for onnx_path, creating one if absent.

Source code in src/demucs_onnx/inference.py
def get(self, onnx_path: PathLike, providers: Sequence[str]) -> ort.InferenceSession:
    """Fetch the session for ``onnx_path`` under ``providers``, building it on first use.

    The cache key is ``(str(onnx_path), "", tuple(providers))``, so the
    same file opened with a different provider list gets its own session.
    """
    cache_key = (str(onnx_path), "", tuple(providers))
    cached = self._sessions.get(cache_key)
    if cached is not None:
        return cached
    fresh = _make_session(onnx_path, providers)
    self._sessions[cache_key] = fresh
    return fresh

clear

clear() -> None

Drop every cached session. Frees the underlying ORT memory.

Source code in src/demucs_onnx/inference.py
def clear(self) -> None:
    """Empty the cache in place, releasing the pool's references to every session.

    ORT memory is reclaimed once no other references to a session remain.
    """
    sessions = self._sessions
    sessions.clear()

demucs_onnx.list_models

list_models() -> dict[str, dict[str, str]]

Public re-export so users don't need to dig into _hub.

Returns {alias: {"repo": url, "fp32": url, "fp16weights": url, "sources": comma_joined, "kind": "specialist_bag"|"single"|"specialist"}}.

Source code in src/demucs_onnx/inference.py
def list_models() -> dict[str, dict[str, str]]:
    """Expose the model catalogue without importing the private ``_hub`` module.

    Returns ``{alias: {"repo": url, "fp32": url, "fp16weights": url,
    "sources": comma_joined, "kind": "specialist_bag"|"single"|"specialist"}}``.
    """
    catalogue = _list_models()
    return catalogue

Providers

demucs_onnx.auto_select_providers

auto_select_providers() -> list[str]

Return the best ORT provider list for this host.

Decision tree (first match wins):

  1. Browser (pyodide / js) — WASM. Browser support is not wired through the inference path yet; this is a forward-looking hook so callers can branch off it. Real browser support lands in v0.3.
  2. macOS arm64 with CoreML EP available — CoreMLExecutionProvider. If CoreML EP is missing we warn once with the upgrade hint.
  3. Linux with CUDA EP available — CUDAExecutionProvider. If we're on Linux x86_64 but the GPU EP is missing we don't warn (it's normal to run inference on a CPU-only Linux box).
  4. Windows with DML EP available — DmlExecutionProvider; if DML is missing but the CUDA EP is available — CUDAExecutionProvider.
  5. Fallback — CPUExecutionProvider.
Source code in src/demucs_onnx/providers.py
def auto_select_providers() -> list[str]:
    """Pick an ordered ONNX Runtime provider list suited to this host.

    Rules are checked in order and the first match wins:

    1. **Browser** (``in_browser()`` is true) — ``WasmExecutionProvider``
       then CPU. This is a forward-looking hook; browser support is not
       wired through the inference path yet.
    2. **macOS arm64/aarch64** — ``CoreMLExecutionProvider`` then CPU if
       ORT exposes CoreML. Otherwise a one-time upgrade warning is issued
       and the result is CPU only. Other Mac architectures always get CPU only.
    3. **Linux** — ``CUDAExecutionProvider`` then CPU if available,
       otherwise CPU only, without a warning (CPU-only Linux boxes are normal).
    4. **Windows** — ``DmlExecutionProvider`` then CPU if available; else
       ``CUDAExecutionProvider`` then CPU if available; else CPU only.
    5. **Anything else** — ``CPUExecutionProvider`` alone.
    """
    cpu = "CPUExecutionProvider"
    if in_browser():
        return ["WasmExecutionProvider", cpu]

    available = _available_providers()
    os_name = platform.system()
    arch = platform.machine().lower()

    if os_name == "Darwin":
        apple_silicon = arch in ("arm64", "aarch64")
        if apple_silicon and "CoreMLExecutionProvider" in available:
            return ["CoreMLExecutionProvider", cpu]
        if apple_silicon:
            _warn_no_gpu_once(
                "running on Apple Silicon but CoreMLExecutionProvider is not "
                "exposed by onnxruntime.",
                "Upgrade with `pip install -U 'onnxruntime>=1.17'` — modern "
                "wheels include CoreML EP by default.",
            )
        return [cpu]

    # Accelerated EPs to try for each remaining OS, best first.
    preference = {
        "Linux": ("CUDAExecutionProvider",),
        "Windows": ("DmlExecutionProvider", "CUDAExecutionProvider"),
    }.get(os_name, ())
    for accelerated in preference:
        if accelerated in available:
            return [accelerated, cpu]
    return [cpu]

demucs_onnx.describe_runtime

describe_runtime() -> dict[str, object]

Return a flat dict describing the runtime environment.

Useful for debugging — print it from your code if auto selects something surprising.

Source code in src/demucs_onnx/providers.py
def describe_runtime() -> dict[str, object]:
    """Summarise the host and ONNX Runtime setup as a flat dict.

    Handy for bug reports, or when ``providers="auto"`` picks something
    unexpected. The keys are ``system``, ``machine``, ``python``,
    ``onnxruntime``, ``available_providers`` and ``in_browser``.
    """
    return dict(
        system=platform.system(),
        machine=platform.machine(),
        python=platform.python_version(),
        onnxruntime=ort.__version__,
        available_providers=list(_available_providers()),
        in_browser=in_browser(),
    )

Audio I/O

demucs_onnx.load_audio

load_audio(
    path: PathLike, target_sr: int = MODEL_SAMPLE_RATE
) -> tuple[ndarray, int]

Load audio as float32 stereo at target_sr, return (audio, native_sr).

Returns the audio at target_sr (the model's sample rate) for use by the inference loop, and the file's native sample rate alongside so the caller can resample the output back to the original rate before writing.

  • Mono inputs are duplicated to L/R.
  • Inputs with more than 2 channels keep only the first two channels; the rest are dropped, not mixed down.
  • Sample rates < 8000 Hz log a quality warning but are still processed (the model output will sound rough).
  • Any rate other than target_sr (e.g. > 44.1 kHz) is resampled to target_sr. This is logged at info level rather than as a warning.
Source code in src/demucs_onnx/_audio.py
def load_audio(path: PathLike, target_sr: int = MODEL_SAMPLE_RATE,
               ) -> tuple[np.ndarray, int]:
    """Decode ``path`` into channels-first float32 stereo at ``target_sr``.

    Returns ``(audio, native_sr)``. ``audio`` is a C-contiguous float32
    array at ``target_sr`` (the model rate) for the inference loop.
    ``native_sr`` is the file's own rate, so callers can resample the
    separated output back before writing.

    Channel handling:
        - mono is duplicated into both channels;
        - more than two channels: only the first two are kept (the rest
          are dropped, not mixed down).

    Sample-rate handling:
        - any rate different from ``target_sr`` is resampled (up or down),
          with an info-level log line;
        - rates below 8000 Hz additionally log a warning (output quality
          will be poor) but are still processed.
    """
    frames, native_sr = sf.read(str(path), dtype="float32", always_2d=True)
    # soundfile yields (frames, channels); the model wants channels first.
    planar = frames.T
    n_channels = planar.shape[0]
    if n_channels == 1:
        planar = np.tile(planar, (2, 1))
        log.info("demucs-onnx: mono input upmixed to stereo")
    elif n_channels > 2:
        planar = planar[:2]
        log.info("demucs-onnx: >2-channel input downmixed to first 2 channels")

    if native_sr == target_sr:
        at_model_rate = planar
    else:
        if native_sr < 8000:
            log.warning(
                "demucs-onnx: very low input sample rate (%d Hz) — "
                "output quality will be poor.", native_sr,
            )
        log.info("demucs-onnx: resampling %d Hz -> %d Hz for inference", native_sr, target_sr)
        at_model_rate = _resample_audio(planar, native_sr, target_sr)
    return np.ascontiguousarray(at_model_rate, dtype=np.float32), native_sr

demucs_onnx.write_audio

write_audio(
    path: PathLike,
    audio: ndarray,
    sample_rate: int,
    *,
    bitrate_kbps: int = 192
) -> None

Write audio, dispatching by path suffix. Defaults to WAV.

Source code in src/demucs_onnx/_audio.py
def write_audio(path: PathLike, audio: np.ndarray, sample_rate: int, *,
                bitrate_kbps: int = 192) -> None:
    """Write ``audio`` to ``path``, choosing the encoder from the file extension.

    A ``.mp3`` suffix (case-insensitive) goes to :func:`write_mp3` with
    ``bitrate_kbps``. Every other suffix, including none, goes to
    :func:`write_wav`, and ``bitrate_kbps`` is then unused.
    """
    if Path(path).suffix.lower() != ".mp3":
        write_wav(path, audio, sample_rate)
        return
    write_mp3(path, audio, sample_rate, bitrate_kbps=bitrate_kbps)

demucs_onnx.write_wav

write_wav(
    path: PathLike, audio: ndarray, sample_rate: int
) -> None

Write a (channels, samples) float32 array as a 16-bit PCM WAV.

Source code in src/demucs_onnx/_audio.py
def write_wav(path: PathLike, audio: np.ndarray, sample_rate: int) -> None:
    """Write a ``(channels, samples)`` float32 array as a 16-bit PCM WAV.

    Samples are clipped to ``[-1.0, 1.0]`` before encoding, the same
    treatment :func:`write_mp3` applies. libsndfile does not clip
    float→integer conversions unless clipping is explicitly enabled, so
    loud or summed stems (e.g. a ``mix_stems`` output) could otherwise wrap
    around instead of saturating. Parent directories are created as needed.

    Raises:
        ValueError: If ``audio`` is not 2-D.
    """
    if audio.ndim != 2:
        raise ValueError(f"expected (channels, samples), got {audio.shape}")
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    # soundfile expects (frames, channels), hence the transpose.
    clipped = np.clip(audio, -1.0, 1.0)
    sf.write(str(path), clipped.T, sample_rate, subtype="PCM_16")

demucs_onnx.write_mp3

write_mp3(
    path: PathLike,
    audio: ndarray,
    sample_rate: int,
    bitrate_kbps: int = 192,
) -> None

Write a (channels, samples) float32 array as a CBR MP3.

Requires the lameenc extra (pip install demucs-onnx[mp3]).

Source code in src/demucs_onnx/_audio.py
def write_mp3(path: PathLike, audio: np.ndarray, sample_rate: int,
              bitrate_kbps: int = 192) -> None:
    """Encode a ``(channels, samples)`` float array (1 or 2 channels) to a CBR MP3.

    Samples are clipped to ``[-1.0, 1.0]``, scaled to int16 and handed to
    LAME (quality setting 2) at ``bitrate_kbps``. Parent directories are
    created as needed. Needs the ``lameenc`` extra
    (``pip install demucs-onnx[mp3]``).

    Raises:
        ImportError: If ``lameenc`` is not installed (checked first).
        ValueError: On a bad array shape or a bitrate outside [32, 320].
    """
    try:
        import lameenc
    except ImportError as exc:
        raise ImportError(
            "MP3 output requires the 'mp3' extra. "
            "Install with: pip install 'demucs-onnx[mp3]'",
        ) from exc

    shape_ok = audio.ndim == 2 and audio.shape[0] in (1, 2)
    if not shape_ok:
        raise ValueError(f"expected (channels=1|2, samples), got {audio.shape}")
    if not 32 <= bitrate_kbps <= 320:
        raise ValueError(f"bitrate must be in [32, 320] kbps, got {bitrate_kbps}")

    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    n_channels = audio.shape[0]
    lame = lameenc.Encoder()
    lame.set_bit_rate(bitrate_kbps)
    lame.set_in_sample_rate(sample_rate)
    lame.set_channels(n_channels)
    lame.set_quality(2)

    # (samples, channels) view; tobytes() serialises in C order, so
    # stereo becomes interleaved L R L R ... frames.
    frames_i16 = (np.clip(audio.T, -1.0, 1.0) * 32767.0).astype(np.int16)
    encoded = lame.encode(frames_i16.tobytes()) + lame.flush()
    destination.write_bytes(encoded)

Browser helpers

demucs_onnx.browser.wasm_config

wasm_config(bundler: Bundler = 'vite') -> str

Return a copy-pasteable bundler config snippet for onnxruntime-web.

Supported bundlers: "vite" (default), "webpack", "esbuild", "next", "rollup".

Source code in src/demucs_onnx/browser.py
def wasm_config(bundler: Bundler = "vite") -> str:
    """Build a copy-pasteable ``onnxruntime-web`` config snippet for ``bundler``.

    Accepted values: ``"vite"`` (default), ``"webpack"``, ``"esbuild"``,
    ``"next"``, ``"rollup"``. Anything else raises ``ValueError``.
    """
    if bundler in _SNIPPETS:
        render = _SNIPPETS[bundler]
        return render()
    raise ValueError(
        f"unknown bundler {bundler!r}; expected one of {list(_SNIPPETS)}",
    )

demucs_onnx.browser.print_wasm_config

print_wasm_config(bundler: Bundler = 'vite') -> None

Print the output of wasm_config() to stdout.

Source code in src/demucs_onnx/browser.py
def print_wasm_config(bundler: Bundler = "vite") -> None:
    """Echo the :func:`wasm_config` snippet for ``bundler`` to stdout."""
    snippet = wasm_config(bundler)
    print(snippet)

demucs_onnx.browser.write_demo_dir

write_demo_dir(
    target: Path,
    *,
    model_url: str | None = None,
    react: bool = False
) -> list[Path]

Materialize the in-tree browser demo files under target.

Parameters:

Name Type Description Default
target Path

Directory to write into. Created if missing. Must be empty or non-existent (this helper refuses to overwrite existing files).

required
model_url str | None

URL the demo should fetch the ONNX model from. Defaults to DEFAULT_DEMO_URL (htdemucs-ft-vocals fp16weights from the StemSplitio HF org).

None
react bool

When True, emit a Vite + React + TS project. When False (default), emit the zero-build vanilla HTML/JS demo.

False

Returns:

Type Description
list[Path]

The list of files written, in the order they were created.

Source code in src/demucs_onnx/browser.py
def write_demo_dir(target: Path, *,
                   model_url: str | None = None,
                   react: bool = False) -> list[Path]:
    """Materialize the in-tree browser demo files under ``target``.

    Args:
        target: Directory to write into. Created if missing. Must be
            empty or non-existent (this helper refuses to overwrite
            existing files).
        model_url: URL the demo should fetch the ONNX model from.
            Defaults to ``DEFAULT_DEMO_URL`` (htdemucs-ft-vocals
            fp16weights from the StemSplitio HF org). It is substituted
            for the ``__MODEL_URL__`` placeholder in the demo script.
        react: When True, emit a Vite + React + TS project. When False
            (default), emit the zero-build vanilla HTML/JS demo.

    Returns:
        The list of files written, in the order they were created.

    Raises:
        FileExistsError: If ``target`` already contains any entries.
    """
    model_url = model_url or DEFAULT_DEMO_URL
    target = Path(target)
    target.mkdir(parents=True, exist_ok=True)
    if any(target.iterdir()):
        raise FileExistsError(
            f"refusing to write demo files into non-empty {target}",
        )
    written: list[Path] = []
    if react:
        files = {
            "index.html":      _REACT_INDEX_HTML,
            "vite.config.ts":  _REACT_VITE_CONFIG,
            "package.json":    _REACT_PACKAGE_JSON,
            "tsconfig.json":   _REACT_TSCONFIG,
            "README.md":       _REACT_README,
            "src/App.tsx":     _REACT_TS.replace("__MODEL_URL__", model_url),
            "src/main.tsx":    _REACT_MAIN_TSX,
        }
    else:
        files = {
            "index.html": _VANILLA_HTML,
            "demo.js":    _VANILLA_JS.replace("__MODEL_URL__", model_url),
            "README.md":  _VANILLA_README,
        }
    for rel, content in files.items():
        path = target / rel
        path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8: write_text otherwise uses the locale's preferred
        # encoding (e.g. a legacy code page on Windows), which can fail on
        # or mangle non-ASCII characters in the templates. Web tooling
        # expects UTF-8 sources.
        path.write_text(content, encoding="utf-8")
        written.append(path)
    return written

demucs_onnx.browser.model_browser_url

model_browser_url(
    model: str = DEFAULT_DEMO_MODEL,
    precision: Literal[
        "fp32", "fp16weights"
    ] = DEFAULT_DEMO_PRECISION,
) -> str

Return the direct download URL for model at precision.

Useful when emitting JS that wants to call ort.InferenceSession.create(url, ...) against the HF Hub.

Source code in src/demucs_onnx/browser.py
def model_browser_url(model: str = DEFAULT_DEMO_MODEL,
                      precision: Literal["fp32", "fp16weights"] = DEFAULT_DEMO_PRECISION,
                      ) -> str:
    """Build the Hugging Face ``resolve/main`` download URL for a model file.

    Useful when emitting JS that wants to call
    ``ort.InferenceSession.create(url, ...)`` against the HF Hub.

    Args:
        model: Model id. ``"htdemucs_ft_vocals"`` is resolved through the
            per-stem repo table in ``._hub``; every other id must be present
            in ``MODEL_REGISTRY``.
        precision: Weight precision variant to point at.

    Returns:
        ``https://huggingface.co/<repo>/resolve/main/<filename>``.

    Raises:
        ValueError: If ``model`` is neither the vocals specialist nor a
            registered model.
    """
    if model != "htdemucs_ft_vocals":
        entry = MODEL_REGISTRY.get(model)
        if entry is None:
            raise ValueError(f"unknown model {model!r}")
        repo_id = entry.repo
        file_name = model_filename(model, precision)
    else:
        # The vocals specialist lives in its own repo with a stem-specific
        # filename scheme, so it bypasses MODEL_REGISTRY.
        from ._hub import MODEL_REPOS, stem_model_filename
        repo_id = MODEL_REPOS["htdemucs_ft_vocals"]
        file_name = stem_model_filename("vocals", precision)
    return f"https://huggingface.co/{repo_id}/resolve/main/{file_name}"

Export pipeline

Requires the [export] extra (pip install 'demucs-onnx[export]').

demucs_onnx.export.exporter.export_to_onnx

export_to_onnx(
    checkpoint: str | Path,
    output: str | Path,
    *,
    stem: str | None = None,
    stems: Iterable[str] | None = None,
    opset: int = 17,
    parity_check: bool = True,
    parity_tolerance: float = DEFAULT_PARITY_TOLERANCE,
    sample_rate: int = SAMPLE_RATE,
    segment_seconds: float = SEGMENT_S,
    verbose: bool = True
) -> dict[str, Path]

Export a demucs/htdemucs checkpoint to one or more ONNX files.

Parameters:

Name Type Description Default
checkpoint str | Path

Either a name accepted by demucs.pretrained.get_model (e.g. "htdemucs_ft", "htdemucs", "mdx_extra") or a path to a local .th checkpoint file.

Bag models (htdemucs_ft) are 4-model ensembles. By default we export every specialist; pass stem or stems to filter.

required
output str | Path

For a single model, a .onnx file path. For a bag, a directory that will be populated with one file per stem.

required
stem str | None

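Optional single stem name (drums / bass / other / vocals). Only valid when checkpoint is a multi-model specialist bag such as htdemucs_ft; single-model checkpoints (including htdemucs and htdemucs_6s, which load as one-model bags) reject it.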

None
stems Iterable[str] | None

Optional iterable of stem names to export from a bag. Default: all 4.

None
opset int

ONNX opset to target. 17 is the lowest with native STFT support; we don't actually use ONNX's STFT op (we replace it with conv1d) but staying ≥17 is a good baseline for ORT EPs.

17
parity_check bool

When True (default), run the patched and the original PyTorch models on the same dummy input and abort if their outputs differ by more than parity_tolerance. Set to False only if you know what you're doing — exporting an unverified model is the most common way to ship a silently broken pipeline.

True
parity_tolerance float

Max allowed abs diff between the patched PyTorch model and the original. Default 1e-3 matches the tolerance the StemSplit team uses for the published HF models.

DEFAULT_PARITY_TOLERANCE
sample_rate int

Sample rate the exported model expects. Changing this requires a matching change to the host inference code; demucs is trained at 44100 Hz.

SAMPLE_RATE
segment_seconds float

Segment length baked into the exported graph.

SEGMENT_S
verbose bool

Print progress to stdout.

True

Returns:

Type Description
dict[str, Path]

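{stem_name: output_path} for a specialist bag, or for a single model

dict[str, Path]

{checkpoint_name: output_path}, keyed by the pretrained name (e.g. "htdemucs") or the local checkpoint file's stem.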

Source code in src/demucs_onnx/export/exporter.py
def export_to_onnx(checkpoint: str | Path,
                   output: str | Path,
                   *,
                   stem: str | None = None,
                   stems: Iterable[str] | None = None,
                   opset: int = 17,
                   parity_check: bool = True,
                   parity_tolerance: float = DEFAULT_PARITY_TOLERANCE,
                   sample_rate: int = SAMPLE_RATE,
                   segment_seconds: float = SEGMENT_S,
                   verbose: bool = True,
                   ) -> dict[str, Path]:
    """Export a demucs/htdemucs checkpoint to one or more ONNX files.

    Args:
        checkpoint: Either a name accepted by ``demucs.pretrained.get_model``
            (e.g. ``"htdemucs_ft"``, ``"htdemucs"``, ``"mdx_extra"``) or a
            path to a local ``.th`` checkpoint file.

            Bag models (``htdemucs_ft``) are 4-model ensembles. By default we
            export every specialist; pass ``stem`` or ``stems`` to filter.

        output: Where to write. If exactly one file is produced and
            ``output`` ends in ``.onnx``, it is used as that file's path (its
            parent directory is created). Otherwise it is treated as a
            directory, created if missing, that receives
            ``htdemucs_ft_<stem>.onnx`` files for a specialist bag or
            ``<checkpoint_name>.onnx`` for a single model.
        stem: Optional single stem name (drums / bass / other / vocals).
            Only valid when ``checkpoint`` is a multi-model specialist bag
            such as ``htdemucs_ft``.
        stems: Optional non-empty iterable of stem names to export from a
            specialist bag. Default: all 4.
        opset: ONNX opset to target. 17 is the lowest with native STFT
            support; we don't actually use ONNX's STFT op (we replace it
            with conv1d) but staying ≥17 is a good baseline for ORT EPs.
        parity_check: When True (default), run the patched and the original
            PyTorch models on the same dummy input and abort if their
            outputs differ by more than ``parity_tolerance``. Set to False
            only if you know what you're doing — exporting an unverified
            model is the most common way to ship a silently broken pipeline.
        parity_tolerance: Max allowed abs diff between the patched
            PyTorch model and the original. Default 1e-3 matches the
            tolerance the StemSplit team uses for the published HF models.
        sample_rate: Sample rate the exported model expects. Changing this
            requires a matching change to the host inference code; demucs
            is trained at 44100 Hz.
        segment_seconds: Segment length baked into the exported graph.
        verbose: Print progress to stdout.

    Returns:
        ``{stem_name: output_path}`` for a specialist bag. For a single model
        (including a bag that wraps one sub-model, e.g. ``htdemucs`` or
        ``htdemucs_6s``), ``{checkpoint_name: output_path}``, where
        ``checkpoint_name`` is the pretrained name or the local checkpoint
        file's stem.

    Raises:
        ValueError: If ``stem``/``stems`` is passed for a single model, both
            are passed, ``stems`` is empty, or a stem name is unknown.
    """
    bag, sub_models, sources = _load_checkpoint(checkpoint, verbose=verbose)

    # The htdemucs_ft "specialist bag" has 4 sub-models (one per stem); each
    # sub-model nominally predicts all 4 stems but only one row is the real
    # prediction. By contrast, the plain `htdemucs` and `htdemucs_6s` models
    # arrive wrapped in a BagOfModels with a single sub-model that predicts
    # every stem row meaningfully. We treat that single-submodel case as a
    # single-file export, not as a 4-specialist export.
    is_bag = sub_models is not None
    is_specialist_bag = is_bag and len(sub_models) > 1

    targets: list[tuple[str, int]]
    if not is_specialist_bag:
        if stem is not None or stems is not None:
            raise ValueError(
                f"checkpoint {checkpoint!r} is a single model "
                f"(sources={sources!r}); cannot pass `stem` or `stems`. "
                "Export the whole model and pick the row at inference time.",
            )
        # Treat as one model with multi-stem output. Use the checkpoint name
        # as the filename stem (e.g. "htdemucs.onnx", "htdemucs_6s.onnx").
        # Local checkpoints contribute only their file stem. That covers Path
        # objects, strings with either separator, and bare "*.th" names.
        # Otherwise a path would be joined into the output location verbatim.
        ckpt_str = str(checkpoint)
        is_path_like = (
            isinstance(checkpoint, Path)
            or "/" in ckpt_str
            or "\\" in ckpt_str
            or ckpt_str.endswith(".th")
        )
        ckpt_name = Path(ckpt_str).stem if is_path_like else ckpt_str
        targets = [(ckpt_name, 0)]
    else:
        wanted: list[str]
        if stem is not None and stems is not None:
            raise ValueError("pass either `stem` OR `stems`, not both.")
        if stem is not None:
            wanted = [stem]
        elif stems is not None:
            wanted = list(stems)
        else:
            wanted = list(STEM_TO_INDEX)
        if not wanted:
            # An empty selection would otherwise "succeed" and return {}.
            raise ValueError("`stems` must name at least one stem.")
        for s in wanted:
            if s not in STEM_TO_INDEX:
                raise ValueError(
                    f"unknown stem {s!r}; expected one of {list(STEM_TO_INDEX)}",
                )
        targets = [(s, STEM_TO_INDEX[s]) for s in wanted]

    n_samples = int(segment_seconds * sample_rate)
    out_paths: dict[str, Path] = {}

    out_root = Path(output)
    # `output` names a single file only when exactly one file is written and
    # it ends in .onnx; in every other case it is a directory.
    single_file = len(targets) == 1 and out_root.suffix.lower() == ".onnx"
    if single_file:
        out_root.parent.mkdir(parents=True, exist_ok=True)
    else:
        # Create the directory for every model kind (not only specialist
        # bags) so the per-target write below always has an existing parent.
        out_root.mkdir(parents=True, exist_ok=True)

    for stem_name, idx in targets:
        if single_file:
            file_path = out_root
        elif is_specialist_bag:
            file_path = out_root / f"htdemucs_ft_{stem_name}.onnx"
        else:
            file_path = out_root / f"{stem_name}.onnx"

        if verbose:
            print(f"\n=== Exporting {stem_name} (index {idx}) -> {file_path} ===")

        # Specialists are compared only on their own output row; single
        # models (bag-wrapped or not) are compared on every row (None).
        if is_specialist_bag:
            original = sub_models[idx].eval().to("cpu")
            parity_idx: int | None = idx
        elif is_bag:
            original = sub_models[0].eval().to("cpu")
            parity_idx = None
        else:
            original = bag.eval().to("cpu")
            parity_idx = None
        if parity_check:
            _verify_pytorch_parity(
                original, n_samples=n_samples, bag_index=parity_idx, stem=stem_name,
                tolerance=parity_tolerance, verbose=verbose,
            )
        # Patch a copy so `original` stays untouched for the ONNX parity check.
        patched = patch_htdemucs_for_onnx(copy.deepcopy(original))

        _export_one(patched, file_path, n_samples=n_samples,
                    opset=opset, verbose=verbose)
        _onnx_check(file_path, verbose=verbose)

        if parity_check:
            _verify_onnx_parity(original, file_path, n_samples=n_samples,
                                bag_index=parity_idx, stem=stem_name,
                                tolerance=parity_tolerance, verbose=verbose)

        out_paths[stem_name] = file_path

    return out_paths

demucs_onnx.export.exporter.verify_onnx_parity

verify_onnx_parity(
    checkpoint: str | Path,
    onnx_path: str | Path,
    *,
    stem: str | None = None,
    tolerance: float = DEFAULT_PARITY_TOLERANCE,
    sample_rate: int = SAMPLE_RATE,
    segment_seconds: float = SEGMENT_S,
    verbose: bool = True
) -> float

Compare an exported ONNX model against the original PyTorch checkpoint.

Returns the max abs diff. Raises AssertionError if it exceeds tolerance.

Source code in src/demucs_onnx/export/exporter.py
def verify_onnx_parity(checkpoint: str | Path, onnx_path: str | Path, *,
                       stem: str | None = None,
                       tolerance: float = DEFAULT_PARITY_TOLERANCE,
                       sample_rate: int = SAMPLE_RATE,
                       segment_seconds: float = SEGMENT_S,
                       verbose: bool = True) -> float:
    """Compare an exported ONNX model against the original PyTorch checkpoint.

    Sub-model selection mirrors ``export_to_onnx``.

    - A multi-model specialist bag (``htdemucs_ft``) requires ``stem=`` to
      pick the specialist, which is compared on that stem's row.
    - A plain model, or a bag wrapping a single sub-model (``htdemucs``,
      ``htdemucs_6s``), is compared on every output row. If ``stem`` is
      given, only the row of that stem in the model's ``sources`` is used.

    Args:
        checkpoint: Pretrained name or local checkpoint path.
        onnx_path: Exported ONNX file to check.
        stem: Stem to compare (required for a specialist bag).
        tolerance: Max allowed abs diff.
        sample_rate: Sample rate used to size the dummy input.
        segment_seconds: Segment length used to size the dummy input.
        verbose: Print progress to stdout.

    Returns:
        The max abs diff.

    Raises:
        ValueError: If ``stem`` is missing for a specialist bag or is not a
            known stem for the checkpoint.
        AssertionError: If the diff exceeds ``tolerance``.
    """
    bag, sub_models, sources = _load_checkpoint(checkpoint, verbose=verbose)
    bag_index: int | None
    if sub_models is not None and len(sub_models) > 1:
        if stem is None:
            raise ValueError(
                f"checkpoint {checkpoint!r} is a bag; pass `stem=` to pick a sub-model.",
            )
        if stem not in STEM_TO_INDEX:
            raise ValueError(
                f"unknown stem {stem!r}; expected one of {list(STEM_TO_INDEX)}",
            )
        bag_index = STEM_TO_INDEX[stem]
        original = sub_models[bag_index].eval().to("cpu")
        stem_name = stem
    else:
        # Single-sub-model bags are unwrapped, exactly as export_to_onnx does.
        model = sub_models[0] if sub_models is not None else bag
        original = model.eval().to("cpu")
        if stem is None:
            bag_index = None  # compare every output row
            stem_name = "model"
        else:
            # Use the model's own source order. This matters for htdemucs_6s,
            # whose extra stems are absent from STEM_TO_INDEX.
            names = list(sources) if sources is not None else list(STEM_TO_INDEX)
            if stem not in names:
                raise ValueError(f"unknown stem {stem!r}; expected one of {names}")
            bag_index = names.index(stem)
            stem_name = stem
    n_samples = int(segment_seconds * sample_rate)
    return _verify_onnx_parity(
        original, Path(onnx_path), n_samples=n_samples, bag_index=bag_index,
        stem=stem_name, tolerance=tolerance, verbose=verbose,
    )

demucs_onnx.export.patch.patch_htdemucs_for_onnx

patch_htdemucs_for_onnx(model: Module) -> Module

Mutate an htdemucs sub-model in place so it has no complex tensors and no Python-dynamic code paths.

Returns the same model so calls can be chained.

The four patches applied:

  1. model.segment: Fraction → float (demucs_onnx.export.coerce_segment_to_float)
  2. CrossTransformerEncoder._get_pos_embedding: drops random.randrange (demucs_onnx.export.disable_random_pos_shift)
  3. nn.MultiheadAttention.forward: replaced with a primitive-only implementation (demucs_onnx.export.onnx_friendly_mha_forward)
  4. model._spec / _ispec / _magnitude / _mask: rewritten to thread a real (B, C, 2, F, T) tensor instead of complex tensors, backed by demucs_onnx.export.RealSTFT and demucs_onnx.export.RealISTFT.
Source code in src/demucs_onnx/export/patch.py
def patch_htdemucs_for_onnx(model: nn.Module) -> nn.Module:
    """Mutate an htdemucs sub-model in place so it has no complex tensors and
    no Python-dynamic code paths.

    Returns the same model so calls can be chained.

    The four patches applied:

    1. ``model.segment``: ``Fraction`` → ``float``
       (:func:`demucs_onnx.export.coerce_segment_to_float`)
    2. ``CrossTransformerEncoder._get_pos_embedding``: drops ``random.randrange``
       (:func:`demucs_onnx.export.disable_random_pos_shift`)
    3. ``nn.MultiheadAttention.forward``: replaced with a primitive-only impl
       (:func:`demucs_onnx.export.onnx_friendly_mha_forward`)
    4. ``model._spec`` / ``_ispec`` / ``_magnitude`` / ``_mask``: rewritten to
       thread a real ``(B, C, 2, F, T)`` tensor instead of complex tensors,
       backed by :class:`~demucs_onnx.export.RealSTFT` and
       :class:`~demucs_onnx.export.RealISTFT`.

    The STFT kernels are sized from the model's own ``nfft`` and
    ``hop_length`` attributes. The defaults are 4096 and ``nfft // 4`` when
    the attributes are absent.
    """
    coerce_segment_to_float(model)
    disable_random_pos_shift(model)

    # Install the manual MHA forward on every nn.MultiheadAttention instance.
    for m in model.modules():
        if isinstance(m, nn.MultiheadAttention):
            m.forward = types.MethodType(onnx_friendly_mha_forward, m)

    # Real STFT/iSTFT replacements. The originals are the methods _spec /
    # _ispec / _magnitude / _mask on the model itself; we swap all four.
    # Kernel sizes come from the same attributes `_spec_real` reads below, so
    # the conv kernels can never disagree with the padding arithmetic.
    n_fft = int(getattr(model, "nfft", 4096))
    hop_length = int(getattr(model, "hop_length", n_fft // 4))
    real_stft = RealSTFT(n_fft, hop_length)
    real_istft = RealISTFT(n_fft, hop_length)
    # Move kernels onto the model so they migrate cleanly with .to(device).
    model.real_stft = real_stft
    model.real_istft = real_istft

    def _spec_real(self_: Any, x: torch.Tensor) -> torch.Tensor:
        # Mirrors HTDemucs._spec but emits (B, C, 2, F, T) real tensors.
        hl = self_.hop_length
        nfft = self_.nfft
        if hl != nfft // 4:
            raise AssertionError(f"unexpected hop {hl} for nfft {nfft}")
        le = math.ceil(x.shape[-1] / hl)
        pad = hl // 2 * 3
        x = F.pad(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
        z = self_.real_stft(x)[..., :-1, :]  # drop the Nyquist bin
        if z.shape[-1] != le + 4:
            raise AssertionError((z.shape, x.shape, le))
        return z[..., 2: 2 + le]

    def _ispec_real(self_: Any, z: torch.Tensor, length: int = 0,
                    scale: int = 0) -> torch.Tensor:
        # real_istft was built with a fixed hop; any other scale would shrink
        # `hl` here without shrinking the iSTFT hop and silently misalign.
        if scale != 0:
            raise NotImplementedError(
                f"_ispec_real only supports scale=0, got scale={scale}",
            )
        hl = self_.hop_length // (4 ** scale)
        z = F.pad(z, (0, 0, 0, 1))   # restore the Nyquist bin we dropped
        z = F.pad(z, (2, 2))         # symmetric pad on time axis
        pad = hl // 2 * 3
        le = hl * math.ceil(length / hl) + 2 * pad
        x = self_.real_istft(z, length=le)
        return x[..., pad: pad + length]

    def _magnitude_real(self_: Any, z: torch.Tensor) -> torch.Tensor:
        # cac=True path. Original cac=True flow does:
        #   m = view_as_real(z_complex).permute(0,1,4,2,3) -> (B, C, 2, F, T)
        #   m = m.reshape(B, C*2, F, T)
        # Our z is already (B, C, 2, F, T) real, so just reshape.
        B, C, two, Fr, T = z.shape
        if two != 2:
            raise AssertionError(f"expected 2 real channels, got {two}")
        return z.reshape(B, C * two, Fr, T)

    def _mask_real(self_: Any, z: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        # cac=True path. Original:
        #   B,S,C,Fr,T = m.shape
        #   out = m.view(B,S,-1,2,Fr,T).permute(0,1,2,4,5,3)  -> (..,Fr,T,2)
        #   out = view_as_complex(out)                         -> (..,Fr,T) complex
        # Our equivalent stays real: (B, S, C', 2, F, T).
        B, S, C, Fr, T = m.shape
        return m.view(B, S, C // 2, 2, Fr, T)

    model._spec = types.MethodType(_spec_real, model)
    model._ispec = types.MethodType(_ispec_real, model)
    model._magnitude = types.MethodType(_magnitude_real, model)
    model._mask = types.MethodType(_mask_real, model)

    # Force eval + cpu (avoids some MPS-roundtrip branches in demucs).
    model.eval()
    model.to("cpu")
    return model

Individual patches

demucs_onnx.export.segment.coerce_segment_to_float

coerce_segment_to_float(model: Module) -> Module

Convert model.segment from a Fraction to a float, if needed.

Idempotent and safe to call on already-coerced models.

Source code in src/demucs_onnx/export/segment.py
def coerce_segment_to_float(model: nn.Module) -> nn.Module:
    """Replace a ``Fraction``-valued ``model.segment`` with its ``float`` value.

    Models without a ``segment`` attribute, or whose segment is already a
    non-``Fraction`` value, are left untouched, so repeated calls are safe.

    Returns:
        The same ``model`` object, for chaining.
    """
    segment = getattr(model, "segment", None)
    if not isinstance(segment, Fraction):
        return model
    model.segment = float(segment)
    return model

demucs_onnx.export.pos_embed.disable_random_pos_shift

disable_random_pos_shift(model: Module) -> Module

Replace CrossTransformerEncoder._get_pos_embedding with a deterministic version that hardcodes shift = 0.

Also sets sin_random_shift = 0 on every module that has the attribute, as a belt-and-braces measure for code paths that read it directly.

Source code in src/demucs_onnx/export/pos_embed.py
def disable_random_pos_shift(model: nn.Module) -> nn.Module:
    """Make the cross-transformer positional embeddings deterministic.

    The function walks every submodule of ``model`` once.

    - Any module exposing ``sin_random_shift`` gets it forced to ``0``. This
      is a belt-and-braces measure for code that reads the attribute
      directly.
    - Every ``demucs.transformer.CrossTransformerEncoder`` has
      ``_get_pos_embedding`` rebound to a version that never calls
      ``random.randrange``. The ``"sin"`` embedding always uses ``shift=0``,
      and ``"cape"`` is built with augmentation disabled.

    Returns:
        The same ``model`` (mutated in place), for chaining.
    """
    # demucs is imported lazily so the inference path stays torch-free.
    import demucs.transformer as tr

    def _deterministic_pos_embedding(self_: Any, T: int, B: int, C: int,
                                     device: torch.device) -> torch.Tensor:
        # Same branches as tr.CrossTransformerEncoder._get_pos_embedding,
        # minus the random shift. An unknown embedding kind is an error.
        kind = self_.emb
        if kind == "scaled":
            positions = torch.arange(T, device=device)
            return self_.position_embeddings(positions)[:, None]
        if kind == "sin":
            return tr.create_sin_embedding(
                T, C, shift=0, device=device, max_period=self_.max_period,
            )
        if kind == "cape":
            return tr.create_sin_embedding_cape(
                T, C, B, device=device, max_period=self_.max_period,
                mean_normalize=self_.cape_mean_normalize,
                augment=False,  # eval mode never augments
                max_global_shift=0.0, max_local_shift=0.0, max_scale=1.0,
            )
        raise RuntimeError(f"unknown emb {self_.emb!r}")

    # Both mutations only set plain attributes, so the module tree being
    # iterated never changes and a single pass suffices.
    for module in model.modules():
        if hasattr(module, "sin_random_shift"):
            module.sin_random_shift = 0
        if isinstance(module, tr.CrossTransformerEncoder):
            module._get_pos_embedding = types.MethodType(
                _deterministic_pos_embedding, module,
            )

    return model

demucs_onnx.export.mha.onnx_friendly_mha_forward

onnx_friendly_mha_forward(
    self_: MultiheadAttention,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_padding_mask: Tensor | None = None,
    need_weights: bool = True,
    attn_mask: Tensor | None = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> tuple[Tensor, Tensor | None]

Drop-in replacement for nn.MultiheadAttention.forward that uses only ops with stable ONNX symbolics.

Supports the call shapes htdemucs actually uses:

  • batch_first ∈ {True, False}
  • need_weights = False (no weight tensor returned)
  • attn_mask = None
  • key_padding_mask = None
  • cross-attention (query != key == value, or all three differ)

The signature matches nn.MultiheadAttention.forward so we can install it via types.MethodType without further wrapping.

Source code in src/demucs_onnx/export/mha.py
def onnx_friendly_mha_forward(self_: nn.MultiheadAttention,
                              query: torch.Tensor,
                              key: torch.Tensor,
                              value: torch.Tensor,
                              key_padding_mask: torch.Tensor | None = None,
                              need_weights: bool = True,
                              attn_mask: torch.Tensor | None = None,
                              average_attn_weights: bool = True,
                              is_causal: bool = False,
                              ) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Drop-in replacement for ``nn.MultiheadAttention.forward`` that uses
    only ops with stable ONNX symbolics.

    Supports the call shapes htdemucs actually uses:

    - ``batch_first ∈ {True, False}``
    - ``need_weights = False`` (no weight tensor returned)
    - ``attn_mask = None``
    - ``key_padding_mask = None``
    - cross-attention (query != key == value, or all three differ)

    The signature matches ``nn.MultiheadAttention.forward`` so we can
    install it via ``types.MethodType`` without further wrapping.
    """
    if self_.batch_first:
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)

    tgt_len, bsz, embed_dim = query.shape
    src_len = key.shape[0]
    num_heads = self_.num_heads
    head_dim = embed_dim // num_heads
    scaling = head_dim ** -0.5

    # Apply input projections. nn.MultiheadAttention uses a single fused
    # in_proj_weight when q/k/v have the same dim, OR separate proj_weight
    # tensors otherwise.
    if self_._qkv_same_embed_dim:
        w = self_.in_proj_weight
        b = self_.in_proj_bias
        if torch.equal(query, key) and torch.equal(key, value):
            qkv = F.linear(query, w, b)
            q, k, v = qkv.chunk(3, dim=-1)
        else:
            w_q, w_k, w_v = w.chunk(3, dim=0)
            if b is not None:
                b_q, b_k, b_v = b.chunk(3, dim=0)
            else:
                b_q = b_k = b_v = None
            q = F.linear(query, w_q, b_q)
            k = F.linear(key, w_k, b_k)
            v = F.linear(value, w_v, b_v)
    else:
        bias = self_.in_proj_bias
        q = F.linear(
            query, self_.q_proj_weight,
            bias[:embed_dim] if bias is not None else None,
        )
        k = F.linear(
            key, self_.k_proj_weight,
            bias[embed_dim:2 * embed_dim] if bias is not None else None,
        )
        v = F.linear(
            value, self_.v_proj_weight,
            bias[2 * embed_dim:] if bias is not None else None,
        )

    # Reshape into (B*H, T, head_dim) for batched matmul.
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = k.contiguous().view(src_len, bsz * num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(src_len, bsz * num_heads, head_dim).transpose(0, 1)

    # Scaled dot-product attention, manually.
    q = q * scaling
    attn_weights = torch.bmm(q, k.transpose(1, 2))  # (B*H, T_q, T_k)
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask
    attn_weights = F.softmax(attn_weights, dim=-1)
    # No dropout at inference.
    attn_output = torch.bmm(attn_weights, v)  # (B*H, T_q, head_dim)
    attn_output = (
        attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    )
    attn_output = self_.out_proj(attn_output)

    if self_.batch_first:
        attn_output = attn_output.transpose(0, 1)

    if not need_weights:
        return attn_output, None
    attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len)
    if average_attn_weights:
        attn_weights = attn_weights.mean(dim=1)
    return attn_output, attn_weights

demucs_onnx.export.stft.RealSTFT

RealSTFT(n_fft: int = 4096, hop_length: int | None = None)

Bases: Module

ONNX-exportable STFT that emits real/imag as two channel groups.

Matches the demucs spectro(x, n_fft, hop_length, pad=0) API but returns a real tensor of shape (..., 2, F, T) instead of a complex (..., F, T) tensor.

Layout matches what view_as_real would produce: channel 0 = real, channel 1 = imag.

Source code in src/demucs_onnx/export/stft.py
def __init__(self, n_fft: int = 4096, hop_length: int | None = None) -> None:
    """Store the STFT geometry and precompute the DFT kernels.

    Args:
        n_fft: FFT size (window length).
        hop_length: Frame hop in samples. ``None`` or ``0`` selects
            ``n_fft // 4``.
    """
    super().__init__()
    self.n_fft = n_fft
    self.hop_length = n_fft // 4 if not hop_length else hop_length
    # Non-persistent buffers follow the module through .to(device) but are
    # left out of state_dict; they depend only on n_fft.
    cos_kernel, sin_kernel = make_stft_kernels(n_fft)
    for buffer_name, kernel in (("cos_kernel", cos_kernel), ("sin_kernel", sin_kernel)):
        self.register_buffer(buffer_name, kernel, persistent=False)

demucs_onnx.export.stft.RealISTFT

RealISTFT(n_fft: int = 4096, hop_length: int | None = None)

Bases: Module

ONNX-exportable inverse STFT, the inverse of RealSTFT.

Input shape: (..., 2, F, T). Output shape: (..., L).

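Reconstruction formula for sample n of a single frame, before overlap-add (OLA). Here R_k and I_k are the real and imaginary parts of bin k, w is the periodic Hann window and N = n_fft:

$$
x[n] = \frac{1}{\sqrt{N}}\, w[n] \Big( R_0 + (-1)^n R_{N/2}
      + 2 \sum_{k=1}^{N/2-1} \Big[ R_k \cos\frac{2\pi k n}{N}
      - I_k \sin\frac{2\pi k n}{N} \Big] \Big)
$$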

The kernels apply a factor of 2 to every bin to account for the dropped negative frequencies. DC and Nyquist have no negative-frequency counterpart and appear only once, so their kernel rows are halved back to a factor of 1. After OLA we divide by the OLA(window^2) envelope to undo the analysis windowing.

Source code in src/demucs_onnx/export/stft.py
def __init__(self, n_fft: int = 4096, hop_length: int | None = None) -> None:
    """Store the iSTFT geometry and precompute the synthesis kernels.

    Args:
        n_fft: FFT size (window length).
        hop_length: Frame hop in samples. ``None`` or ``0`` selects
            ``n_fft // 4``.
    """
    super().__init__()
    self.n_fft = n_fft
    self.hop_length = n_fft // 4 if not hop_length else hop_length

    # Build in float64 for accuracy; cast to float32 only when registering.
    f64 = torch.float64
    samples = torch.arange(n_fft, dtype=f64)
    hann = torch.hann_window(n_fft, periodic=True, dtype=f64)
    scale = 1.0 / math.sqrt(n_fft)
    bins = torch.arange(n_fft // 2 + 1, dtype=f64)[:, None]
    phase = 2 * math.pi * bins * samples[None, :] / n_fft

    # Every bin gets a factor of 2 for its dropped negative-frequency twin.
    cos_rows = 2.0 * (hann * torch.cos(phase)) * scale
    sin_rows = -2.0 * (hann * torch.sin(phase)) * scale
    # DC (row 0) and Nyquist (row -1) have no twin: halve the cosine rows
    # back to a factor of 1 and zero the sine rows (their imag part is zero).
    for edge in (0, -1):
        cos_rows[edge] *= 0.5
        sin_rows[edge] *= 0.0

    self.register_buffer("inv_cos", cos_rows.float().unsqueeze(1), persistent=False)
    self.register_buffer("inv_sin", sin_rows.float().unsqueeze(1), persistent=False)
    # The OLA(window^2) envelope is shape-dependent; cache by (frames, length).
    self._envelope_cache: dict[tuple[int, int], torch.Tensor] = {}

demucs_onnx.export.stft.make_stft_kernels

make_stft_kernels(n_fft: int) -> tuple[Tensor, Tensor]

Build (cos, sin) DFT kernels of shape (n_fft//2 + 1, 1, n_fft).

The kernels reproduce torch.stft(window=hann, win_length=n_fft, n_fft=n_fft, normalized=True, center=True) with hop = n_fft // 4.

Computation is done in float64 to keep the high-frequency bins precise, then cast to float32 for the actual conv weights.

Source code in src/demucs_onnx/export/stft.py
def make_stft_kernels(n_fft: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Return Hann-windowed, orthonormally scaled DFT kernels for conv1d.

    Each kernel has shape ``(n_fft//2 + 1, 1, n_fft)``. The first holds the
    real (cosine) part and the second the imaginary (negated sine) part, so
    that a ``conv1d`` with them reproduces
    ``torch.stft(window=hann, win_length=n_fft, n_fft=n_fft, normalized=True)``
    frame by frame. The stride is chosen by the caller; demucs uses
    ``n_fft // 4``.

    Everything is computed in float64 so the high-frequency bins stay
    precise, then cast to float32 for use as conv weights.
    """
    f64 = torch.float64
    bins = torch.arange(n_fft // 2 + 1, dtype=f64).unsqueeze(1)  # (F, 1)
    taps = torch.arange(n_fft, dtype=f64).unsqueeze(0)           # (1, N)
    hann = torch.hann_window(n_fft, periodic=True, dtype=f64)
    scale = 1.0 / math.sqrt(n_fft)  # normalized=True
    phase = 2 * math.pi * bins * taps / n_fft  # (F, N)

    real_part = hann * torch.cos(phase) * scale
    imag_part = hann * -torch.sin(phase) * scale  # e^{-i phase} for the forward transform
    return real_part.float().unsqueeze(1), imag_part.float().unsqueeze(1)