From f7253df24bcb2c986b57249e30a6fe0d449d77cc Mon Sep 17 00:00:00 2001 From: root Date: Sat, 31 Jan 2026 16:24:13 -0800 Subject: [PATCH] Ignore local sanity files and enable Sortformer GPU --- .gitignore | 2 ++ .../scriberr-diariz-torch/src/diar_engine/model_manager.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index 5107aada..47debfd4 100644 --- a/.gitignore +++ b/.gitignore @@ -75,3 +75,5 @@ __pycache__/ pr linus_60s.wav ref/ +sanity-checks/ +sample.wav diff --git a/asr-engines/scriberr-diariz-torch/src/diar_engine/model_manager.py b/asr-engines/scriberr-diariz-torch/src/diar_engine/model_manager.py index 917ddb40..7eb4bf98 100644 --- a/asr-engines/scriberr-diariz-torch/src/diar_engine/model_manager.py +++ b/asr-engines/scriberr-diariz-torch/src/diar_engine/model_manager.py @@ -119,6 +119,8 @@ def _load_sortformer(spec: ModelSpec) -> LoadedModel: model_path = _resolve_sortformer_model_path(spec) device = _resolve_device(spec.providers) + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" if device == "cuda" and torch.cuda.is_available(): map_location = torch.device("cuda") else: @@ -130,6 +132,8 @@ def _load_sortformer(spec: ModelSpec) -> LoadedModel: strict=False, ) diar_model.eval() + if device == "cuda" and torch.cuda.is_available(): + diar_model = diar_model.to(torch.device("cuda")) return LoadedModel(spec=spec, kind="sortformer", model=diar_model, loaded_at=time.time())