mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-06-28 14:55:46 +00:00
Ignore local sanity files and enable Sortformer GPU
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -75,3 +75,5 @@ __pycache__/
|
||||
pr
|
||||
linus_60s.wav
|
||||
ref/
|
||||
sanity-checks/
|
||||
sample.wav
|
||||
|
||||
@@ -119,6 +119,8 @@ def _load_sortformer(spec: ModelSpec) -> LoadedModel:
|
||||
|
||||
model_path = _resolve_sortformer_model_path(spec)
|
||||
device = _resolve_device(spec.providers)
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if device == "cuda" and torch.cuda.is_available():
|
||||
map_location = torch.device("cuda")
|
||||
else:
|
||||
@@ -130,6 +132,8 @@ def _load_sortformer(spec: ModelSpec) -> LoadedModel:
|
||||
strict=False,
|
||||
)
|
||||
diar_model.eval()
|
||||
if device == "cuda" and torch.cuda.is_available():
|
||||
diar_model = diar_model.to(torch.device("cuda"))
|
||||
return LoadedModel(spec=spec, kind="sortformer", model=diar_model, loaded_at=time.time())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user