mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-03-03 02:57:01 +00:00
Extract python adapter scripts to proper files
This commit is contained in:
committed by
Rishikanth Chandrasekaran
parent
edb65339b8
commit
50dd4130ff
1
.gitignore
vendored
1
.gitignore
vendored
@@ -70,3 +70,4 @@ tmp/
|
|||||||
# *.svg
|
# *.svg
|
||||||
# *.png
|
# *.png
|
||||||
dhl.txt
|
dhl.txt
|
||||||
|
__pycache__/
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package adapters
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -15,6 +16,9 @@ import (
|
|||||||
"scriberr/pkg/logger"
|
"scriberr/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//go:embed py/nvidia/*
|
||||||
|
var nvidiaScripts embed.FS
|
||||||
|
|
||||||
// CanaryAdapter implements the TranscriptionAdapter interface for NVIDIA Canary
|
// CanaryAdapter implements the TranscriptionAdapter interface for NVIDIA Canary
|
||||||
type CanaryAdapter struct {
|
type CanaryAdapter struct {
|
||||||
*BaseAdapter
|
*BaseAdapter
|
||||||
@@ -185,7 +189,7 @@ func (c *CanaryAdapter) PrepareEnvironment(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create transcription script
|
// Create transcription script
|
||||||
if err := c.createTranscriptionScript(); err != nil {
|
if err := c.copyTranscriptionScript(); err != nil {
|
||||||
return fmt.Errorf("failed to create transcription script: %w", err)
|
return fmt.Errorf("failed to create transcription script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,49 +211,23 @@ func (c *CanaryAdapter) setupCanaryEnvironment() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create pyproject.toml with configurable PyTorch CUDA version
|
// Read pyproject.toml
|
||||||
pyprojectContent := fmt.Sprintf(`[project]
|
pyprojectContent, err := nvidiaScripts.ReadFile("py/nvidia/pyproject.toml")
|
||||||
name = "parakeet-transcription"
|
if err != nil {
|
||||||
version = "0.1.0"
|
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
|
||||||
description = "Audio transcription using NVIDIA Parakeet models"
|
}
|
||||||
requires-python = ">=3.11"
|
|
||||||
dependencies = [
|
|
||||||
"nemo-toolkit[asr]",
|
|
||||||
"torch",
|
|
||||||
"torchaudio",
|
|
||||||
"librosa",
|
|
||||||
"soundfile",
|
|
||||||
"ml-dtypes>=0.3.1,<0.5.0",
|
|
||||||
"onnx>=1.15.0,<1.18.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
|
||||||
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
|
// The static file contains the default cu126 URL
|
||||||
torch = [
|
contentStr := strings.Replace(
|
||||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
string(pyprojectContent),
|
||||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
"https://download.pytorch.org/whl/cu126",
|
||||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
GetPyTorchWheelURL(),
|
||||||
]
|
1,
|
||||||
torchaudio = [
|
)
|
||||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
|
||||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
|
||||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
|
||||||
]
|
|
||||||
triton = [
|
|
||||||
{ index = "pytorch", marker = "sys_platform == 'linux'" }
|
|
||||||
]
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
pyprojectPath = filepath.Join(c.envPath, "pyproject.toml")
|
||||||
name = "pytorch"
|
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
|
||||||
url = "%s"
|
|
||||||
explicit = true
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "pytorch-cpu"
|
|
||||||
url = "https://download.pytorch.org/whl/cpu"
|
|
||||||
explicit = true
|
|
||||||
`, GetPyTorchWheelURL())
|
|
||||||
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
|
|
||||||
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -299,213 +277,15 @@ func (c *CanaryAdapter) downloadCanaryModel() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createTranscriptionScript creates the Python script for Canary transcription
|
// copyTranscriptionScript creates the Python script for Canary transcription
|
||||||
func (c *CanaryAdapter) createTranscriptionScript() error {
|
func (c *CanaryAdapter) copyTranscriptionScript() error {
|
||||||
scriptPath := filepath.Join(c.envPath, "canary_transcribe.py")
|
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/canary_transcribe.py")
|
||||||
|
if err != nil {
|
||||||
// Check if script already exists
|
return fmt.Errorf("failed to read embedded canary_transcribe.py: %w", err)
|
||||||
if _, err := os.Stat(scriptPath); err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
scriptContent := `#!/usr/bin/env python3
|
scriptPath := filepath.Join(c.envPath, "canary_transcribe.py")
|
||||||
"""
|
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
|
||||||
NVIDIA Canary multilingual transcription and translation script.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import nemo.collections.asr as nemo_asr
|
|
||||||
|
|
||||||
|
|
||||||
def transcribe_audio(
|
|
||||||
audio_path: str,
|
|
||||||
source_lang: str = "en",
|
|
||||||
target_lang: str = "en",
|
|
||||||
task: str = "transcribe",
|
|
||||||
timestamps: bool = True,
|
|
||||||
output_file: str = None,
|
|
||||||
include_confidence: bool = True,
|
|
||||||
preserve_formatting: bool = True,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Transcribe or translate audio using NVIDIA Canary model.
|
|
||||||
"""
|
|
||||||
# Get the directory where this script is located
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
model_path = os.path.join(script_dir, "canary-1b-v2.nemo")
|
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
print(f"Error: Model file not found: {model_path}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
print(f"Loading NVIDIA Canary model from: {model_path}")
|
|
||||||
asr_model = nemo_asr.models.ASRModel.restore_from(model_path)
|
|
||||||
|
|
||||||
print(f"Processing: {audio_path}")
|
|
||||||
print(f"Task: {task}")
|
|
||||||
print(f"Source language: {source_lang}")
|
|
||||||
print(f"Target language: {target_lang}")
|
|
||||||
|
|
||||||
if timestamps:
|
|
||||||
if task == "translate" and source_lang != target_lang:
|
|
||||||
# Translation with timestamps
|
|
||||||
output = asr_model.transcribe(
|
|
||||||
[audio_path],
|
|
||||||
source_lang=source_lang,
|
|
||||||
target_lang=target_lang,
|
|
||||||
timestamps=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Transcription with timestamps
|
|
||||||
output = asr_model.transcribe(
|
|
||||||
[audio_path],
|
|
||||||
source_lang=source_lang,
|
|
||||||
target_lang=target_lang,
|
|
||||||
timestamps=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract text and timestamps
|
|
||||||
result_data = output[0]
|
|
||||||
text = result_data.text
|
|
||||||
word_timestamps = result_data.timestamp.get("word", [])
|
|
||||||
segment_timestamps = result_data.timestamp.get("segment", [])
|
|
||||||
|
|
||||||
print(f"Result: {text}")
|
|
||||||
|
|
||||||
# Prepare output data
|
|
||||||
output_data = {
|
|
||||||
"transcription": text,
|
|
||||||
"source_language": source_lang,
|
|
||||||
"target_language": target_lang,
|
|
||||||
"task": task,
|
|
||||||
"word_timestamps": word_timestamps,
|
|
||||||
"segment_timestamps": segment_timestamps,
|
|
||||||
"audio_file": audio_path,
|
|
||||||
"model": "canary-1b-v2"
|
|
||||||
}
|
|
||||||
|
|
||||||
if include_confidence:
|
|
||||||
# Add confidence scores if available
|
|
||||||
if hasattr(result_data, 'confidence') and result_data.confidence:
|
|
||||||
output_data["confidence"] = result_data.confidence
|
|
||||||
|
|
||||||
# Save to file
|
|
||||||
if output_file:
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
|
||||||
print(f"Results saved to: {output_file}")
|
|
||||||
else:
|
|
||||||
print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Simple transcription/translation without timestamps
|
|
||||||
if task == "translate" and source_lang != target_lang:
|
|
||||||
output = asr_model.transcribe(
|
|
||||||
[audio_path],
|
|
||||||
source_lang=source_lang,
|
|
||||||
target_lang=target_lang
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
output = asr_model.transcribe(
|
|
||||||
[audio_path],
|
|
||||||
source_lang=source_lang,
|
|
||||||
target_lang=target_lang
|
|
||||||
)
|
|
||||||
|
|
||||||
text = output[0].text
|
|
||||||
|
|
||||||
output_data = {
|
|
||||||
"transcription": text,
|
|
||||||
"source_language": source_lang,
|
|
||||||
"target_language": target_lang,
|
|
||||||
"task": task,
|
|
||||||
"audio_file": audio_path,
|
|
||||||
"model": "canary-1b-v2"
|
|
||||||
}
|
|
||||||
|
|
||||||
if output_file:
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
|
||||||
print(f"Results saved to: {output_file}")
|
|
||||||
else:
|
|
||||||
print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Transcribe or translate audio using NVIDIA Canary model"
|
|
||||||
)
|
|
||||||
parser.add_argument("audio_file", help="Path to audio file")
|
|
||||||
parser.add_argument(
|
|
||||||
"--source-lang", default="en",
|
|
||||||
choices=["en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh"],
|
|
||||||
help="Source language (default: en)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--target-lang", default="en",
|
|
||||||
choices=["en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh"],
|
|
||||||
help="Target language (default: en)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--task", choices=["transcribe", "translate"], default="transcribe",
|
|
||||||
help="Task to perform (default: transcribe)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--timestamps", action="store_true", default=True,
|
|
||||||
help="Include word and segment level timestamps"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-timestamps", dest="timestamps", action="store_false",
|
|
||||||
help="Disable timestamps"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output", "-o", help="Output file path"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--include-confidence", action="store_true", default=True,
|
|
||||||
help="Include confidence scores"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-confidence", dest="include_confidence", action="store_false",
|
|
||||||
help="Exclude confidence scores"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--preserve-formatting", action="store_true", default=True,
|
|
||||||
help="Preserve punctuation and capitalization"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Validate input file
|
|
||||||
if not os.path.exists(args.audio_file):
|
|
||||||
print(f"Error: Audio file not found: {args.audio_file}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
transcribe_audio(
|
|
||||||
audio_path=args.audio_file,
|
|
||||||
source_lang=args.source_lang,
|
|
||||||
target_lang=args.target_lang,
|
|
||||||
task=args.task,
|
|
||||||
timestamps=args.timestamps,
|
|
||||||
output_file=args.output,
|
|
||||||
include_confidence=args.include_confidence,
|
|
||||||
preserve_formatting=args.preserve_formatting,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during transcription: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
`
|
|
||||||
|
|
||||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to write transcription script: %w", err)
|
return fmt.Errorf("failed to write transcription script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -128,8 +128,8 @@ func (p *ParakeetAdapter) PrepareEnvironment(ctx context.Context) error {
|
|||||||
// Check if environment is already ready (using cache to speed up repeated checks)
|
// Check if environment is already ready (using cache to speed up repeated checks)
|
||||||
if CheckEnvironmentReady(p.envPath, "import nemo.collections.asr") {
|
if CheckEnvironmentReady(p.envPath, "import nemo.collections.asr") {
|
||||||
modelPath := filepath.Join(p.envPath, "parakeet-tdt-0.6b-v3.nemo")
|
modelPath := filepath.Join(p.envPath, "parakeet-tdt-0.6b-v3.nemo")
|
||||||
scriptPath := filepath.Join(p.envPath, "transcribe.py")
|
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe.py")
|
||||||
bufferedScriptPath := filepath.Join(p.envPath, "transcribe_buffered.py")
|
bufferedScriptPath := filepath.Join(p.envPath, "parakeet_transcribe_buffered.py")
|
||||||
|
|
||||||
// Check model, standard script, and buffered script all exist
|
// Check model, standard script, and buffered script all exist
|
||||||
if stat, err := os.Stat(modelPath); err == nil && stat.Size() > 1024*1024 {
|
if stat, err := os.Stat(modelPath); err == nil && stat.Size() > 1024*1024 {
|
||||||
@@ -160,7 +160,7 @@ func (p *ParakeetAdapter) PrepareEnvironment(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create transcription scripts (standard and buffered)
|
// Create transcription scripts (standard and buffered)
|
||||||
if err := p.createTranscriptionScript(); err != nil {
|
if err := p.copyTranscriptionScript(); err != nil {
|
||||||
return fmt.Errorf("failed to create transcription script: %w", err)
|
return fmt.Errorf("failed to create transcription script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,50 +179,23 @@ func (p *ParakeetAdapter) setupParakeetEnvironment() error {
|
|||||||
return fmt.Errorf("failed to create parakeet directory: %w", err)
|
return fmt.Errorf("failed to create parakeet directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create pyproject.toml with configurable PyTorch CUDA version
|
// Read pyproject.toml
|
||||||
pyprojectContent := fmt.Sprintf(`[project]
|
pyprojectContent, err := nvidiaScripts.ReadFile("py/nvidia/pyproject.toml")
|
||||||
name = "parakeet-transcription"
|
if err != nil {
|
||||||
version = "0.1.0"
|
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
|
||||||
description = "Audio transcription using NVIDIA Parakeet models"
|
}
|
||||||
requires-python = ">=3.11"
|
|
||||||
dependencies = [
|
|
||||||
"nemo-toolkit[asr]",
|
|
||||||
"torch",
|
|
||||||
"torchaudio",
|
|
||||||
"librosa",
|
|
||||||
"soundfile",
|
|
||||||
"ml-dtypes>=0.3.1,<0.5.0",
|
|
||||||
"onnx>=1.15.0,<1.18.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
|
||||||
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
|
// The static file contains the default cu126 URL
|
||||||
torch = [
|
contentStr := strings.Replace(
|
||||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
string(pyprojectContent),
|
||||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
"https://download.pytorch.org/whl/cu126",
|
||||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
GetPyTorchWheelURL(),
|
||||||
]
|
1,
|
||||||
torchaudio = [
|
)
|
||||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
|
||||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
|
||||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
|
||||||
]
|
|
||||||
triton = [
|
|
||||||
{ index = "pytorch", marker = "sys_platform == 'linux'" }
|
|
||||||
]
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "pytorch"
|
|
||||||
url = "%s"
|
|
||||||
explicit = true
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "pytorch-cpu"
|
|
||||||
url = "https://download.pytorch.org/whl/cpu"
|
|
||||||
explicit = true
|
|
||||||
`, GetPyTorchWheelURL())
|
|
||||||
pyprojectPath := filepath.Join(p.envPath, "pyproject.toml")
|
pyprojectPath := filepath.Join(p.envPath, "pyproject.toml")
|
||||||
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
|
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
|
||||||
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -272,200 +245,15 @@ func (p *ParakeetAdapter) downloadParakeetModel() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createTranscriptionScript creates the Python script for Parakeet transcription
|
// copyTranscriptionScript creates the Python script for Parakeet transcription
|
||||||
func (p *ParakeetAdapter) createTranscriptionScript() error {
|
func (p *ParakeetAdapter) copyTranscriptionScript() error {
|
||||||
scriptContent := `#!/usr/bin/env python3
|
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/parakeet_transcribe.py")
|
||||||
"""
|
if err != nil {
|
||||||
NVIDIA Parakeet transcription script with timestamp support.
|
return fmt.Errorf("failed to read embedded transcribe.py: %w", err)
|
||||||
"""
|
}
|
||||||
|
|
||||||
import argparse
|
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe.py")
|
||||||
import json
|
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import nemo.collections.asr as nemo_asr
|
|
||||||
|
|
||||||
|
|
||||||
def transcribe_audio(
|
|
||||||
audio_path: str,
|
|
||||||
timestamps: bool = True,
|
|
||||||
output_file: str = None,
|
|
||||||
context_left: int = 256,
|
|
||||||
context_right: int = 256,
|
|
||||||
include_confidence: bool = True,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Transcribe audio using NVIDIA Parakeet model.
|
|
||||||
"""
|
|
||||||
# Get the directory where this script is located
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
model_path = os.path.join(script_dir, "parakeet-tdt-0.6b-v3.nemo")
|
|
||||||
|
|
||||||
print(f"Script directory: {script_dir}")
|
|
||||||
print(f"Looking for model at: {model_path}")
|
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
print(f"Error during transcription: Can't find {model_path}")
|
|
||||||
# List files in the directory to help debug
|
|
||||||
try:
|
|
||||||
files = os.listdir(script_dir)
|
|
||||||
print(f"Files in {script_dir}: {files}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Could not list directory: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
print(f"Loading NVIDIA Parakeet model from: {model_path}")
|
|
||||||
asr_model = nemo_asr.models.ASRModel.restore_from(model_path)
|
|
||||||
|
|
||||||
# Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU
|
|
||||||
# Uses change_decoding_strategy() to properly reconfigure the TDT decoder
|
|
||||||
from omegaconf import OmegaConf, open_dict
|
|
||||||
|
|
||||||
print("Disabling CUDA graphs in TDT decoder...")
|
|
||||||
dec_cfg = asr_model.cfg.decoding
|
|
||||||
|
|
||||||
# Add use_cuda_graph_decoder parameter to greedy config
|
|
||||||
with open_dict(dec_cfg.greedy):
|
|
||||||
dec_cfg.greedy['use_cuda_graph_decoder'] = False
|
|
||||||
|
|
||||||
# Apply the new decoding strategy (this rebuilds the decoder with our config)
|
|
||||||
asr_model.change_decoding_strategy(dec_cfg)
|
|
||||||
print("✓ CUDA graphs disabled successfully")
|
|
||||||
|
|
||||||
# Configure for long-form audio if context sizes are not default
|
|
||||||
if context_left != 256 or context_right != 256:
|
|
||||||
print(f"Configuring attention context: left={context_left}, right={context_right}")
|
|
||||||
try:
|
|
||||||
asr_model.change_attention_model(
|
|
||||||
self_attention_model="rel_pos_local_attn",
|
|
||||||
att_context_size=[context_left, context_right]
|
|
||||||
)
|
|
||||||
print("Long-form audio mode enabled")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Failed to configure attention model: {e}")
|
|
||||||
print("Continuing with default attention settings")
|
|
||||||
|
|
||||||
print(f"Transcribing: {audio_path}")
|
|
||||||
|
|
||||||
if timestamps:
|
|
||||||
output = asr_model.transcribe([audio_path], timestamps=True)
|
|
||||||
|
|
||||||
# Extract text and timestamps
|
|
||||||
result_data = output[0]
|
|
||||||
text = result_data.text
|
|
||||||
word_timestamps = result_data.timestamp.get("word", [])
|
|
||||||
segment_timestamps = result_data.timestamp.get("segment", [])
|
|
||||||
|
|
||||||
print(f"Transcription: {text}")
|
|
||||||
|
|
||||||
# Prepare output data
|
|
||||||
output_data = {
|
|
||||||
"transcription": text,
|
|
||||||
"language": "en",
|
|
||||||
"word_timestamps": word_timestamps,
|
|
||||||
"segment_timestamps": segment_timestamps,
|
|
||||||
"audio_file": audio_path,
|
|
||||||
"model": "parakeet-tdt-0.6b-v3",
|
|
||||||
"context": {
|
|
||||||
"left": context_left,
|
|
||||||
"right": context_right
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if include_confidence:
|
|
||||||
# Add confidence scores if available
|
|
||||||
if hasattr(result_data, 'confidence') and result_data.confidence:
|
|
||||||
output_data["confidence"] = result_data.confidence
|
|
||||||
|
|
||||||
# Save to file
|
|
||||||
if output_file:
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
|
||||||
print(f"Results saved to: {output_file}")
|
|
||||||
else:
|
|
||||||
print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Simple transcription without timestamps
|
|
||||||
output = asr_model.transcribe([audio_path])
|
|
||||||
text = output[0].text
|
|
||||||
|
|
||||||
output_data = {
|
|
||||||
"transcription": text,
|
|
||||||
"language": "en",
|
|
||||||
"audio_file": audio_path,
|
|
||||||
"model": "parakeet-tdt-0.6b-v3"
|
|
||||||
}
|
|
||||||
|
|
||||||
if output_file:
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
|
||||||
print(f"Results saved to: {output_file}")
|
|
||||||
else:
|
|
||||||
print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Transcribe audio using NVIDIA Parakeet model"
|
|
||||||
)
|
|
||||||
parser.add_argument("audio_file", help="Path to audio file")
|
|
||||||
parser.add_argument(
|
|
||||||
"--timestamps", action="store_true", default=True,
|
|
||||||
help="Include word and segment level timestamps"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-timestamps", dest="timestamps", action="store_false",
|
|
||||||
help="Disable timestamps"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output", "-o", help="Output file path"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--context-left", type=int, default=256,
|
|
||||||
help="Left attention context size (default: 256)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--context-right", type=int, default=256,
|
|
||||||
help="Right attention context size (default: 256)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--include-confidence", action="store_true", default=True,
|
|
||||||
help="Include confidence scores"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-confidence", dest="include_confidence", action="store_false",
|
|
||||||
help="Exclude confidence scores"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Validate input file
|
|
||||||
if not os.path.exists(args.audio_file):
|
|
||||||
print(f"Error: Audio file not found: {args.audio_file}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
transcribe_audio(
|
|
||||||
audio_path=args.audio_file,
|
|
||||||
timestamps=args.timestamps,
|
|
||||||
output_file=args.output,
|
|
||||||
context_left=args.context_left,
|
|
||||||
context_right=args.context_right,
|
|
||||||
include_confidence=args.include_confidence,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during transcription: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
`
|
|
||||||
|
|
||||||
scriptPath := filepath.Join(p.envPath, "transcribe.py")
|
|
||||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to write transcription script: %w", err)
|
return fmt.Errorf("failed to write transcription script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -684,7 +472,7 @@ func (p *ParakeetAdapter) transcribeBuffered(ctx context.Context, input interfac
|
|||||||
func (p *ParakeetAdapter) buildParakeetArgs(input interfaces.AudioInput, params map[string]interface{}, tempDir string) ([]string, error) {
|
func (p *ParakeetAdapter) buildParakeetArgs(input interfaces.AudioInput, params map[string]interface{}, tempDir string) ([]string, error) {
|
||||||
outputFile := filepath.Join(tempDir, "result.json")
|
outputFile := filepath.Join(tempDir, "result.json")
|
||||||
|
|
||||||
scriptPath := filepath.Join(p.envPath, "transcribe.py")
|
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe.py")
|
||||||
args := []string{
|
args := []string{
|
||||||
"run", "--native-tls", "--project", p.envPath, "python", scriptPath,
|
"run", "--native-tls", "--project", p.envPath, "python", scriptPath,
|
||||||
input.FilePath,
|
input.FilePath,
|
||||||
@@ -771,181 +559,13 @@ func (p *ParakeetAdapter) parseResult(tempDir string, input interfaces.AudioInpu
|
|||||||
|
|
||||||
// createBufferedScript creates the Python script for NeMo buffered inference
|
// createBufferedScript creates the Python script for NeMo buffered inference
|
||||||
func (p *ParakeetAdapter) createBufferedScript() error {
|
func (p *ParakeetAdapter) createBufferedScript() error {
|
||||||
scriptContent := `#!/usr/bin/env python3
|
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/parakeet_transcribe_buffered.py")
|
||||||
"""
|
if err != nil {
|
||||||
NVIDIA Parakeet buffered inference for long audio files.
|
return fmt.Errorf("failed to read embedded transcribe_buffered.py: %w", err)
|
||||||
Splits audio into chunks to avoid GPU memory issues.
|
}
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe_buffered.py")
|
||||||
import json
|
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import librosa
|
|
||||||
import soundfile as sf
|
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
import nemo.collections.asr as nemo_asr
|
|
||||||
|
|
||||||
|
|
||||||
def split_audio_file(audio_path, chunk_duration_secs=300):
|
|
||||||
"""Split audio file into chunks of specified duration."""
|
|
||||||
audio, sr = librosa.load(audio_path, sr=None, mono=True)
|
|
||||||
total_duration = len(audio) / sr
|
|
||||||
chunk_samples = int(chunk_duration_secs * sr)
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
for start_sample in range(0, len(audio), chunk_samples):
|
|
||||||
end_sample = min(start_sample + chunk_samples, len(audio))
|
|
||||||
chunk_audio = audio[start_sample:end_sample]
|
|
||||||
start_time = start_sample / sr
|
|
||||||
chunks.append({
|
|
||||||
'audio': chunk_audio,
|
|
||||||
'start_time': start_time,
|
|
||||||
'duration': len(chunk_audio) / sr
|
|
||||||
})
|
|
||||||
|
|
||||||
return chunks, sr
|
|
||||||
|
|
||||||
|
|
||||||
def transcribe_buffered(
|
|
||||||
audio_path: str,
|
|
||||||
output_file: str = None,
|
|
||||||
chunk_duration_secs: float = 300, # 5 minutes default
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Transcribe long audio by splitting into chunks and merging results.
|
|
||||||
"""
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
model_path = os.path.join(script_dir, "parakeet-tdt-0.6b-v3.nemo")
|
|
||||||
|
|
||||||
print(f"Loading NVIDIA Parakeet model from: {model_path}")
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
print(f"Error: Model not found at {model_path}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
asr_model = nemo_asr.models.ASRModel.restore_from(model_path)
|
|
||||||
|
|
||||||
# Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU
|
|
||||||
# Uses change_decoding_strategy() to properly reconfigure the TDT decoder
|
|
||||||
from omegaconf import OmegaConf, open_dict
|
|
||||||
|
|
||||||
print("Disabling CUDA graphs in TDT decoder...")
|
|
||||||
dec_cfg = asr_model.cfg.decoding
|
|
||||||
|
|
||||||
# Add use_cuda_graph_decoder parameter to greedy config
|
|
||||||
with open_dict(dec_cfg.greedy):
|
|
||||||
dec_cfg.greedy['use_cuda_graph_decoder'] = False
|
|
||||||
|
|
||||||
# Apply the new decoding strategy (this rebuilds the decoder with our config)
|
|
||||||
asr_model.change_decoding_strategy(dec_cfg)
|
|
||||||
print("✓ CUDA graphs disabled successfully")
|
|
||||||
|
|
||||||
print(f"Splitting audio into {chunk_duration_secs}s chunks...")
|
|
||||||
chunks, sr = split_audio_file(audio_path, chunk_duration_secs)
|
|
||||||
print(f"Created {len(chunks)} chunks")
|
|
||||||
|
|
||||||
all_words = []
|
|
||||||
all_segments = []
|
|
||||||
full_text = []
|
|
||||||
|
|
||||||
for i, chunk_info in enumerate(chunks):
|
|
||||||
print(f"Transcribing chunk {i+1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...")
|
|
||||||
|
|
||||||
# Save chunk to temporary file
|
|
||||||
chunk_path = f"/tmp/chunk_{i}.wav"
|
|
||||||
sf.write(chunk_path, chunk_info['audio'], sr)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Transcribe chunk
|
|
||||||
output = asr_model.transcribe(
|
|
||||||
[chunk_path],
|
|
||||||
batch_size=1,
|
|
||||||
timestamps=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
result_data = output[0]
|
|
||||||
chunk_text = result_data.text
|
|
||||||
full_text.append(chunk_text)
|
|
||||||
|
|
||||||
# Extract and adjust timestamps
|
|
||||||
if hasattr(result_data, 'timestamp') and result_data.timestamp:
|
|
||||||
chunk_words = result_data.timestamp.get("word", [])
|
|
||||||
chunk_segments = result_data.timestamp.get("segment", [])
|
|
||||||
|
|
||||||
# Adjust timestamps by chunk start time
|
|
||||||
for word in chunk_words:
|
|
||||||
word_copy = dict(word)
|
|
||||||
word_copy['start'] += chunk_info['start_time']
|
|
||||||
word_copy['end'] += chunk_info['start_time']
|
|
||||||
all_words.append(word_copy)
|
|
||||||
|
|
||||||
for segment in chunk_segments:
|
|
||||||
seg_copy = dict(segment)
|
|
||||||
seg_copy['start'] += chunk_info['start_time']
|
|
||||||
seg_copy['end'] += chunk_info['start_time']
|
|
||||||
all_segments.append(seg_copy)
|
|
||||||
|
|
||||||
print(f"Chunk {i+1} complete: {len(chunk_text)} characters")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# Clean up temp file
|
|
||||||
if os.path.exists(chunk_path):
|
|
||||||
os.remove(chunk_path)
|
|
||||||
|
|
||||||
final_text = " ".join(full_text)
|
|
||||||
print(f"Transcription complete: {len(final_text)} characters total")
|
|
||||||
|
|
||||||
output_data = {
|
|
||||||
"transcription": final_text,
|
|
||||||
"language": "en",
|
|
||||||
"word_timestamps": all_words,
|
|
||||||
"segment_timestamps": all_segments,
|
|
||||||
"audio_file": audio_path,
|
|
||||||
"model": "parakeet-tdt-0.6b-v3",
|
|
||||||
"buffered": True,
|
|
||||||
"chunk_duration_secs": chunk_duration_secs,
|
|
||||||
"num_chunks": len(chunks),
|
|
||||||
}
|
|
||||||
|
|
||||||
if output_file:
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
|
||||||
print(f"Results saved to: {output_file}")
|
|
||||||
else:
|
|
||||||
print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Transcribe long audio using NVIDIA Parakeet with chunking"
|
|
||||||
)
|
|
||||||
parser.add_argument("audio_file", help="Path to audio file")
|
|
||||||
parser.add_argument("--output", "-o", help="Output file path", required=True)
|
|
||||||
parser.add_argument(
|
|
||||||
"--chunk-len", type=float, default=300,
|
|
||||||
help="Chunk duration in seconds (default: 300 = 5 minutes)"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if not os.path.exists(args.audio_file):
|
|
||||||
print(f"Error: Audio file not found: {args.audio_file}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
transcribe_buffered(
|
|
||||||
audio_path=args.audio_file,
|
|
||||||
output_file=args.output,
|
|
||||||
chunk_duration_secs=args.chunk_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
`
|
|
||||||
|
|
||||||
scriptPath := filepath.Join(p.envPath, "transcribe_buffered.py")
|
|
||||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to write buffered script: %w", err)
|
return fmt.Errorf("failed to write buffered script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -963,7 +583,7 @@ func (p *ParakeetAdapter) buildBufferedArgs(input interfaces.AudioInput, params
|
|||||||
chunkDuration = thresholdStr
|
chunkDuration = thresholdStr
|
||||||
}
|
}
|
||||||
|
|
||||||
scriptPath := filepath.Join(p.envPath, "transcribe_buffered.py")
|
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe_buffered.py")
|
||||||
args := []string{
|
args := []string{
|
||||||
"run", "--native-tls", "--project", p.envPath, "python", scriptPath,
|
"run", "--native-tls", "--project", p.envPath, "python", scriptPath,
|
||||||
input.FilePath,
|
input.FilePath,
|
||||||
|
|||||||
204
internal/transcription/adapters/py/nvidia/canary_transcribe.py
Normal file
204
internal/transcription/adapters/py/nvidia/canary_transcribe.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
NVIDIA Canary multilingual transcription and translation script.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import nemo.collections.asr as nemo_asr
|
||||||
|
|
||||||
|
|
||||||
|
def transcribe_audio(
    audio_path: str,
    source_lang: str = "en",
    target_lang: str = "en",
    task: str = "transcribe",
    timestamps: bool = True,
    output_file: str = None,
    include_confidence: bool = True,
    preserve_formatting: bool = True,
):
    """Transcribe or translate audio using the NVIDIA Canary model.

    Args:
        audio_path: Path to the input audio file.
        source_lang: Language spoken in the audio.
        target_lang: Language of the produced text; differs from
            source_lang when translating.
        task: Either "transcribe" or "translate" (recorded in the output).
        timestamps: When True, word- and segment-level timestamps are
            requested from the model and included in the result.
        output_file: Optional path; when set, results are written there as
            JSON, otherwise they are printed to stdout.
        include_confidence: When True, a model-provided confidence score is
            copied into the result if the model exposes one.
        preserve_formatting: Accepted for interface compatibility; the
            current model call does not use it.

    Exits the process with status 1 when VIRTUAL_ENV is unset or the model
    file is missing.
    """
    model_filename = "canary-1b-v2.nemo"

    # Locate project root: derived from VIRTUAL_ENV, which is set by
    # `uv run` to <project>/.venv
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)

    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)

    if not os.path.exists(model_path):
        print(f"Error during transcription: Can't find {model_filename} in project root: {project_root}")
        sys.exit(1)

    print(f"Loading NVIDIA Canary model from: {model_path}")
    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    print(f"Processing: {audio_path}")
    print(f"Task: {task}")
    print(f"Source language: {source_lang}")
    print(f"Target language: {target_lang}")

    # NOTE: the model call is the same for transcription and translation —
    # Canary selects the behavior from source_lang/target_lang. The original
    # code duplicated an identical transcribe() call in both arms of an
    # if task == "translate" check; that redundancy is collapsed here.
    if timestamps:
        output = asr_model.transcribe(
            [audio_path],
            source_lang=source_lang,
            target_lang=target_lang,
            timestamps=True
        )

        # Extract text and timestamps
        result_data = output[0]
        text = result_data.text
        word_timestamps = result_data.timestamp.get("word", [])
        segment_timestamps = result_data.timestamp.get("segment", [])

        print(f"Result: {text}")

        output_data = {
            "transcription": text,
            "source_language": source_lang,
            "target_language": target_lang,
            "task": task,
            "word_timestamps": word_timestamps,
            "segment_timestamps": segment_timestamps,
            "audio_file": audio_path,
            "model": "canary-1b-v2"
        }

        if include_confidence:
            # Add confidence scores if available
            if hasattr(result_data, 'confidence') and result_data.confidence:
                output_data["confidence"] = result_data.confidence
    else:
        # Simple transcription/translation without timestamps
        output = asr_model.transcribe(
            [audio_path],
            source_lang=source_lang,
            target_lang=target_lang
        )

        text = output[0].text

        output_data = {
            "transcription": text,
            "source_language": source_lang,
            "target_language": target_lang,
            "task": task,
            "audio_file": audio_path,
            "model": "canary-1b-v2"
        }

    # Save to file or print to stdout (shared by both paths above)
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"Results saved to: {output_file}")
    else:
        print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: parse arguments and run Canary inference."""
    # Languages supported by Canary for both source and target selection.
    supported_langs = ["en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh"]

    parser = argparse.ArgumentParser(
        description="Transcribe or translate audio using NVIDIA Canary model"
    )
    parser.add_argument("audio_file", help="Path to audio file")
    parser.add_argument("--source-lang", default="en", choices=supported_langs,
                        help="Source language (default: en)")
    parser.add_argument("--target-lang", default="en", choices=supported_langs,
                        help="Target language (default: en)")
    parser.add_argument("--task", choices=["transcribe", "translate"], default="transcribe",
                        help="Task to perform (default: transcribe)")
    parser.add_argument("--timestamps", action="store_true", default=True,
                        help="Include word and segment level timestamps")
    parser.add_argument("--no-timestamps", dest="timestamps", action="store_false",
                        help="Disable timestamps")
    parser.add_argument("--output", "-o", help="Output file path")
    parser.add_argument("--include-confidence", action="store_true", default=True,
                        help="Include confidence scores")
    parser.add_argument("--no-confidence", dest="include_confidence", action="store_false",
                        help="Exclude confidence scores")
    parser.add_argument("--preserve-formatting", action="store_true", default=True,
                        help="Preserve punctuation and capitalization")

    args = parser.parse_args()

    # Refuse to run on a missing input file.
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file not found: {args.audio_file}")
        sys.exit(1)

    try:
        transcribe_audio(
            audio_path=args.audio_file,
            source_lang=args.source_lang,
            target_lang=args.target_lang,
            task=args.task,
            timestamps=args.timestamps,
            output_file=args.output,
            include_confidence=args.include_confidence,
            preserve_formatting=args.preserve_formatting,
        )
    except Exception as e:
        print(f"Error during transcription: {e}")
        sys.exit(1)
||||||
188
internal/transcription/adapters/py/nvidia/parakeet_transcribe.py
Normal file
188
internal/transcription/adapters/py/nvidia/parakeet_transcribe.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
NVIDIA Parakeet transcription script with timestamp support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import nemo.collections.asr as nemo_asr
|
||||||
|
|
||||||
|
|
||||||
|
def transcribe_audio(
    audio_path: str,
    timestamps: bool = True,
    output_file: str = None,
    context_left: int = 256,
    context_right: int = 256,
    include_confidence: bool = True,
):
    """Transcribe audio using the NVIDIA Parakeet model.

    Args:
        audio_path: Path to the input audio file.
        timestamps: When True, request word/segment timestamps from the model.
        output_file: Optional path; when set, results are written there as
            JSON, otherwise printed to stdout.
        context_left: Left attention context size; any non-default value
            switches the model to local attention for long-form audio.
        context_right: Right attention context size (see context_left).
        include_confidence: Copy a model-provided confidence score into the
            result when the model exposes one.

    Exits the process with status 1 when VIRTUAL_ENV is unset or the model
    file is missing.
    """
    model_filename = "parakeet-tdt-0.6b-v3.nemo"

    # Locate project root: derived from VIRTUAL_ENV, which is set by
    # `uv run` to <project>/.venv
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)

    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)

    if not os.path.exists(model_path):
        print(f"Error during transcription: Can't find {model_filename} in project root: {project_root}")
        sys.exit(1)

    print(f"Loading NVIDIA Parakeet model from: {model_path}")
    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    # Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU.
    # Uses change_decoding_strategy() to properly reconfigure the TDT decoder.
    # Only open_dict is needed here (OmegaConf was imported but unused).
    from omegaconf import open_dict

    print("Disabling CUDA graphs in TDT decoder...")
    dec_cfg = asr_model.cfg.decoding

    # Add use_cuda_graph_decoder parameter to greedy config
    with open_dict(dec_cfg.greedy):
        dec_cfg.greedy['use_cuda_graph_decoder'] = False

    # Apply the new decoding strategy (this rebuilds the decoder with our config)
    asr_model.change_decoding_strategy(dec_cfg)
    print("✓ CUDA graphs disabled successfully")

    # Configure for long-form audio if context sizes are not default
    if context_left != 256 or context_right != 256:
        print(f"Configuring attention context: left={context_left}, right={context_right}")
        try:
            asr_model.change_attention_model(
                self_attention_model="rel_pos_local_attn",
                att_context_size=[context_left, context_right]
            )
            print("Long-form audio mode enabled")
        except Exception as e:
            # Best-effort: continue with the model's default attention.
            print(f"Warning: Failed to configure attention model: {e}")
            print("Continuing with default attention settings")

    print(f"Transcribing: {audio_path}")

    if timestamps:
        output = asr_model.transcribe([audio_path], timestamps=True)

        # Extract text and timestamps
        result_data = output[0]
        text = result_data.text
        word_timestamps = result_data.timestamp.get("word", [])
        segment_timestamps = result_data.timestamp.get("segment", [])

        print(f"Transcription: {text}")

        output_data = {
            "transcription": text,
            "language": "en",
            "word_timestamps": word_timestamps,
            "segment_timestamps": segment_timestamps,
            "audio_file": audio_path,
            "model": "parakeet-tdt-0.6b-v3",
            "context": {
                "left": context_left,
                "right": context_right
            }
        }

        if include_confidence:
            # Add confidence scores if available
            if hasattr(result_data, 'confidence') and result_data.confidence:
                output_data["confidence"] = result_data.confidence
    else:
        # Simple transcription without timestamps
        output = asr_model.transcribe([audio_path])
        text = output[0].text

        output_data = {
            "transcription": text,
            "language": "en",
            "audio_file": audio_path,
            "model": "parakeet-tdt-0.6b-v3"
        }

    # Save to file or print to stdout (shared by both paths above)
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"Results saved to: {output_file}")
    else:
        print(json.dumps(output_data, indent=2, ensure_ascii=False))
||||||
|
|
||||||
|
def main():
    """Command-line entry point for Parakeet transcription."""
    parser = argparse.ArgumentParser(
        description="Transcribe audio using NVIDIA Parakeet model"
    )
    parser.add_argument("audio_file", help="Path to audio file")
    parser.add_argument("--timestamps", action="store_true", default=True,
                        help="Include word and segment level timestamps")
    parser.add_argument("--no-timestamps", dest="timestamps", action="store_false",
                        help="Disable timestamps")
    parser.add_argument("--output", "-o", help="Output file path")
    parser.add_argument("--context-left", type=int, default=256,
                        help="Left attention context size (default: 256)")
    parser.add_argument("--context-right", type=int, default=256,
                        help="Right attention context size (default: 256)")
    parser.add_argument("--include-confidence", action="store_true", default=True,
                        help="Include confidence scores")
    parser.add_argument("--no-confidence", dest="include_confidence", action="store_false",
                        help="Exclude confidence scores")

    args = parser.parse_args()

    # Refuse to run on a missing input file.
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file not found: {args.audio_file}")
        sys.exit(1)

    try:
        transcribe_audio(
            audio_path=args.audio_file,
            timestamps=args.timestamps,
            output_file=args.output,
            context_left=args.context_left,
            context_right=args.context_right,
            include_confidence=args.include_confidence,
        )
    except Exception as e:
        print(f"Error during transcription: {e}")
        sys.exit(1)
||||||
@@ -0,0 +1,182 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
NVIDIA Parakeet buffered inference for long audio files.
|
||||||
|
Splits audio into chunks to avoid GPU memory issues.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import librosa
|
||||||
|
import soundfile as sf
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
import nemo.collections.asr as nemo_asr
|
||||||
|
|
||||||
|
|
||||||
|
def split_audio_file(audio_path, chunk_duration_secs=300):
    """Split an audio file into consecutive mono chunks.

    Args:
        audio_path: Path to the audio file; loaded at its native sample
            rate and downmixed to mono.
        chunk_duration_secs: Target duration of each chunk in seconds; the
            final chunk may be shorter.

    Returns:
        A tuple (chunks, sr) where sr is the sample rate and chunks is a
        list of dicts with keys 'audio' (sample array), 'start_time'
        (offset in seconds from the start of the file) and 'duration'
        (chunk length in seconds).
    """
    audio, sr = librosa.load(audio_path, sr=None, mono=True)
    chunk_samples = int(chunk_duration_secs * sr)

    chunks = []
    for start_sample in range(0, len(audio), chunk_samples):
        end_sample = min(start_sample + chunk_samples, len(audio))
        chunk_audio = audio[start_sample:end_sample]
        chunks.append({
            'audio': chunk_audio,
            'start_time': start_sample / sr,
            'duration': len(chunk_audio) / sr
        })

    return chunks, sr
|
||||||
|
def transcribe_buffered(
    audio_path: str,
    output_file: str = None,
    chunk_duration_secs: float = 300,  # 5 minutes default
):
    """Transcribe long audio by splitting into chunks and merging results.

    Each chunk is transcribed independently; its word/segment timestamps are
    shifted by the chunk's offset so the merged output uses absolute times.

    Args:
        audio_path: Path to the input audio file.
        output_file: Optional path; when set, results are written there as
            JSON, otherwise printed to stdout.
        chunk_duration_secs: Length of each chunk in seconds.

    Exits the process with status 1 when VIRTUAL_ENV is unset or the model
    file is missing.
    """
    import tempfile

    model_filename = "parakeet-tdt-0.6b-v3.nemo"

    # Locate project root: derived from VIRTUAL_ENV, which is set by
    # `uv run` to <project>/.venv
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)

    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)

    if not os.path.exists(model_path):
        print(f"Error during transcription: Can't find {model_filename} in project root: {project_root}")
        sys.exit(1)

    print(f"Loading NVIDIA Parakeet model from: {model_path}")

    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    # Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU.
    # Uses change_decoding_strategy() to properly reconfigure the TDT decoder.
    # Only open_dict is needed here (OmegaConf was imported but unused).
    from omegaconf import open_dict

    print("Disabling CUDA graphs in TDT decoder...")
    dec_cfg = asr_model.cfg.decoding

    # Add use_cuda_graph_decoder parameter to greedy config
    with open_dict(dec_cfg.greedy):
        dec_cfg.greedy['use_cuda_graph_decoder'] = False

    # Apply the new decoding strategy (this rebuilds the decoder with our config)
    asr_model.change_decoding_strategy(dec_cfg)
    print("✓ CUDA graphs disabled successfully")

    print(f"Splitting audio into {chunk_duration_secs}s chunks...")
    chunks, sr = split_audio_file(audio_path, chunk_duration_secs)
    print(f"Created {len(chunks)} chunks")

    all_words = []
    all_segments = []
    full_text = []

    for i, chunk_info in enumerate(chunks):
        print(f"Transcribing chunk {i+1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...")

        # Write the chunk to a unique temporary file. The previous fixed
        # /tmp/chunk_<i>.wav path was not portable and collided when two
        # jobs ran concurrently on the same host.
        fd, chunk_path = tempfile.mkstemp(prefix=f"chunk_{i}_", suffix=".wav")
        os.close(fd)
        sf.write(chunk_path, chunk_info['audio'], sr)

        try:
            # Transcribe chunk
            output = asr_model.transcribe(
                [chunk_path],
                batch_size=1,
                timestamps=True,
            )

            result_data = output[0]
            chunk_text = result_data.text
            full_text.append(chunk_text)

            # Extract and adjust timestamps
            if hasattr(result_data, 'timestamp') and result_data.timestamp:
                chunk_words = result_data.timestamp.get("word", [])
                chunk_segments = result_data.timestamp.get("segment", [])

                # Shift timestamps by the chunk's offset into the full audio
                for word in chunk_words:
                    word_copy = dict(word)
                    word_copy['start'] += chunk_info['start_time']
                    word_copy['end'] += chunk_info['start_time']
                    all_words.append(word_copy)

                for segment in chunk_segments:
                    seg_copy = dict(segment)
                    seg_copy['start'] += chunk_info['start_time']
                    seg_copy['end'] += chunk_info['start_time']
                    all_segments.append(seg_copy)

            print(f"Chunk {i+1} complete: {len(chunk_text)} characters")

        finally:
            # Clean up temp file
            if os.path.exists(chunk_path):
                os.remove(chunk_path)

    final_text = " ".join(full_text)
    print(f"Transcription complete: {len(final_text)} characters total")

    output_data = {
        "transcription": final_text,
        "language": "en",
        "word_timestamps": all_words,
        "segment_timestamps": all_segments,
        "audio_file": audio_path,
        "model": "parakeet-tdt-0.6b-v3",
        "buffered": True,
        "chunk_duration_secs": chunk_duration_secs,
        "num_chunks": len(chunks),
    }

    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"Results saved to: {output_file}")
    else:
        print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
|
||||||
|
def main():
    """Command-line entry point for buffered (chunked) Parakeet transcription."""
    parser = argparse.ArgumentParser(
        description="Transcribe long audio using NVIDIA Parakeet with chunking"
    )
    parser.add_argument("audio_file", help="Path to audio file")
    parser.add_argument("--output", "-o", help="Output file path", required=True)
    parser.add_argument("--chunk-len", type=float, default=300,
                        help="Chunk duration in seconds (default: 300 = 5 minutes)")

    args = parser.parse_args()

    # Validate the input path before loading any models.
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file not found: {args.audio_file}")
        sys.exit(1)

    transcribe_buffered(
        audio_path=args.audio_file,
        output_file=args.output,
        chunk_duration_secs=args.chunk_len,
    )
||||||
47
internal/transcription/adapters/py/nvidia/pyproject.toml
Normal file
47
internal/transcription/adapters/py/nvidia/pyproject.toml
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
[project]
name = "nvidia-transcription"
version = "0.1.0"
description = "Audio transcription and diarization using NVIDIA models (Canary, Parakeet, Sortformer)"
requires-python = ">=3.11"
dependencies = [
    "nemo-toolkit[asr]",
    "torch",
    "torchaudio",
    "librosa",
    "soundfile",
    "ml-dtypes>=0.3.1,<0.5.0",
    "onnx>=1.15.0,<1.18.0",
    # "pyannote.audio" # needed for sortformer or no?
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0.0",
    "pytest-mock>=3.12.0",
]

[tool.uv.sources]
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
torch = [
    { index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
    { index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
    { index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
torchaudio = [
    { index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
    { index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
    { index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
triton = [
    { index = "pytorch", marker = "sys_platform == 'linux'" }
]

[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu126"
explicit = true

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
||||||
335
internal/transcription/adapters/py/nvidia/sortformer_diarize.py
Normal file
335
internal/transcription/adapters/py/nvidia/sortformer_diarize.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
NVIDIA Sortformer speaker diarization script.
|
||||||
|
Uses diar_streaming_sortformer_4spk-v2 for optimized 4-speaker diarization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||||
|
except ImportError:
|
||||||
|
print("Error: NeMo not found. Please install nemo_toolkit[asr]")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def diarize_audio(
    audio_path: str,
    output_file: str,
    batch_size: int = 1,
    device: str = None,
    max_speakers: int = 4,
    output_format: str = "rttm",
    streaming_mode: bool = False,
    chunk_length_s: float = 30.0,
):
    """Run speaker diarization on one audio file with NVIDIA's Sortformer model.

    The model file is resolved relative to the uv-managed project root
    (parent directory of VIRTUAL_ENV). Results are written to output_file by
    save_results() in either RTTM or JSON form. Exits with status 1 on any
    missing file or model/diarization error.
    """
    # Resolve the compute device when not explicitly chosen.
    if device is None or device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Using device: {device}")
    print(f"Loading NVIDIA Sortformer diarization model...")

    model_filename = "diar_streaming_sortformer_4spk-v2.nemo"

    # Locate project root: derived from VIRTUAL_ENV, which is set by
    # `uv run` to <project>/.venv
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)

    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)

    try:
        if not os.path.exists(model_path):
            print(f"Error: Model file not found: {model_filename} in project root: {project_root}")
            sys.exit(1)

        # Load from local file
        print(f"Loading model from path: {model_path}")
        diar_model = SortformerEncLabelModel.restore_from(
            restore_path=model_path,
            map_location=device,
            strict=False,
        )

        # Switch to inference mode
        diar_model.eval()
        print("Model loaded successfully")

    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

    print(f"Processing audio file: {audio_path}")

    # Verify audio file exists
    if not os.path.exists(audio_path):
        print(f"Error: Audio file not found: {audio_path}")
        sys.exit(1)

    try:
        print(f"Running diarization with batch_size={batch_size}, max_speakers={max_speakers}")

        if streaming_mode:
            print(f"Using streaming mode with chunk_length_s={chunk_length_s}")
            # Note: Streaming mode implementation would go here.
            # Both modes currently fall through to the same standard call.
        predicted_segments = diar_model.diarize(audio=audio_path, batch_size=batch_size)

        print(f"Diarization completed. Found segments: {len(predicted_segments)}")

        # Process and save results
        save_results(predicted_segments, output_file, audio_path, output_format)

    except Exception as e:
        print(f"Error during diarization: {e}")
        sys.exit(1)
|
|
||||||
|
def save_results(segments, output_file: str, audio_path: str, output_format: str):
    """Dispatch diarization results to the requested serializer.

    Args:
        segments: Diarization segments as returned by the model.
        output_file: Destination path for the serialized results.
        audio_path: Source audio path, recorded in the output.
        output_format: "rttm" selects RTTM output; any other value
            selects JSON.
    """
    # The previous version built an unused Path(output_file) local;
    # this function only dispatches, so it was removed.
    if output_format == "rttm":
        save_rttm_format(segments, output_file, audio_path)
    else:
        save_json_format(segments, output_file, audio_path)
||||||
|
|
||||||
|
def save_json_format(segments, output_file: str, audio_path: str):
|
||||||
|
"""Save results in JSON format."""
|
||||||
|
results = {
|
||||||
|
"audio_file": audio_path,
|
||||||
|
"model": "nvidia/diar_streaming_sortformer_4spk-v2",
|
||||||
|
"segments": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle the case where segments is a list containing a single list of string entries
|
||||||
|
if len(segments) == 1 and isinstance(segments[0], list):
|
||||||
|
segments = segments[0]
|
||||||
|
|
||||||
|
# Convert segments to JSON format
|
||||||
|
speakers = set()
|
||||||
|
for i, segment in enumerate(segments):
|
||||||
|
try:
|
||||||
|
# Handle different possible segment formats
|
||||||
|
if isinstance(segment, str):
|
||||||
|
# String format: "start end speaker_id"
|
||||||
|
parts = segment.strip().split()
|
||||||
|
if len(parts) >= 3:
|
||||||
|
segment_data = {
|
||||||
|
"start": float(parts[0]),
|
||||||
|
"end": float(parts[1]),
|
||||||
|
"speaker": str(parts[2]),
|
||||||
|
"duration": float(parts[1]) - float(parts[0]),
|
||||||
|
"confidence": 1.0,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
print(f"Warning: Invalid string segment format: {segment}")
|
||||||
|
continue
|
||||||
|
elif hasattr(segment, 'start') and hasattr(segment, 'end') and hasattr(segment, 'label'):
|
||||||
|
# Standard pyannote-like format
|
||||||
|
segment_data = {
|
||||||
|
"start": float(segment.start),
|
||||||
|
"end": float(segment.end),
|
||||||
|
"speaker": str(segment.label),
|
||||||
|
"duration": float(segment.end - segment.start),
|
||||||
|
"confidence": getattr(segment, 'confidence', 1.0),
|
||||||
|
}
|
||||||
|
elif isinstance(segment, (list, tuple)) and len(segment) >= 3:
|
||||||
|
# List/tuple format: [start, end, speaker]
|
||||||
|
segment_data = {
|
||||||
|
"start": float(segment[0]),
|
||||||
|
"end": float(segment[1]),
|
||||||
|
"speaker": str(segment[2]),
|
||||||
|
"duration": float(segment[1] - segment[0]),
|
||||||
|
"confidence": 1.0,
|
||||||
|
}
|
||||||
|
elif isinstance(segment, dict):
|
||||||
|
# Dictionary format
|
||||||
|
segment_data = {
|
||||||
|
"start": float(segment.get('start', 0)),
|
||||||
|
"end": float(segment.get('end', 0)),
|
||||||
|
"speaker": str(segment.get('speaker', segment.get('label', f'speaker_{i}'))),
|
||||||
|
"duration": float(segment.get('end', 0) - segment.get('start', 0)),
|
||||||
|
"confidence": float(segment.get('confidence', 1.0)),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Fallback: try to extract attributes dynamically
|
||||||
|
segment_data = {
|
||||||
|
"start": float(getattr(segment, 'start', 0)),
|
||||||
|
"end": float(getattr(segment, 'end', 0)),
|
||||||
|
"speaker": str(getattr(segment, 'label', getattr(segment, 'speaker', f'speaker_{i}'))),
|
||||||
|
"duration": float(getattr(segment, 'end', 0) - getattr(segment, 'start', 0)),
|
||||||
|
"confidence": float(getattr(segment, 'confidence', 1.0)),
|
||||||
|
}
|
||||||
|
|
||||||
|
results["segments"].append(segment_data)
|
||||||
|
speakers.add(segment_data["speaker"])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not process segment {i}: {e}")
|
||||||
|
print(f"Segment: {segment}")
|
||||||
|
|
||||||
|
# Sort by start time
|
||||||
|
if results["segments"]:
|
||||||
|
results["segments"].sort(key=lambda x: x["start"])
|
||||||
|
|
||||||
|
# Add summary statistics
|
||||||
|
results["speakers"] = sorted(speakers)
|
||||||
|
results["speaker_count"] = len(speakers)
|
||||||
|
results["total_segments"] = len(results["segments"])
|
||||||
|
results["total_duration"] = max(seg["end"] for seg in results["segments"]) if results["segments"] else 0
|
||||||
|
|
||||||
|
with open(output_file, "w") as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
|
||||||
|
print(f"Results saved to: {output_file}")
|
||||||
|
print(f"Found {len(speakers)} speakers: {', '.join(sorted(speakers))}")
|
||||||
|
|
||||||
|
|
||||||
|
def save_rttm_format(segments, output_file: str, audio_path: str):
|
||||||
|
"""Save results in RTTM (Rich Transcription Time Marked) format."""
|
||||||
|
audio_filename = Path(audio_path).stem
|
||||||
|
speakers = set()
|
||||||
|
|
||||||
|
# Handle the case where segments is a list containing a single list of string entries
|
||||||
|
if len(segments) == 1 and isinstance(segments[0], list):
|
||||||
|
segments = segments[0]
|
||||||
|
|
||||||
|
with open(output_file, "w") as f:
|
||||||
|
for i, segment in enumerate(segments):
|
||||||
|
try:
|
||||||
|
# Handle different possible segment formats
|
||||||
|
if isinstance(segment, str):
|
||||||
|
# String format: "start end speaker_id"
|
||||||
|
parts = segment.strip().split()
|
||||||
|
if len(parts) >= 3:
|
||||||
|
start = float(parts[0])
|
||||||
|
end = float(parts[1])
|
||||||
|
speaker = str(parts[2])
|
||||||
|
else:
|
||||||
|
print(f"Warning: Invalid string segment format: {segment}")
|
||||||
|
continue
|
||||||
|
elif hasattr(segment, 'start') and hasattr(segment, 'end') and hasattr(segment, 'label'):
|
||||||
|
# Standard pyannote-like format
|
||||||
|
start = float(segment.start)
|
||||||
|
end = float(segment.end)
|
||||||
|
speaker = str(segment.label)
|
||||||
|
elif isinstance(segment, (list, tuple)) and len(segment) >= 3:
|
||||||
|
# List/tuple format: [start, end, speaker]
|
||||||
|
start = float(segment[0])
|
||||||
|
end = float(segment[1])
|
||||||
|
speaker = str(segment[2])
|
||||||
|
elif isinstance(segment, dict):
|
||||||
|
# Dictionary format
|
||||||
|
start = float(segment.get('start', 0))
|
||||||
|
end = float(segment.get('end', 0))
|
||||||
|
speaker = str(segment.get('speaker', segment.get('label', f'speaker_{i}')))
|
||||||
|
else:
|
||||||
|
# Fallback: try to extract attributes dynamically
|
||||||
|
start = float(getattr(segment, 'start', 0))
|
||||||
|
end = float(getattr(segment, 'end', 0))
|
||||||
|
speaker = str(getattr(segment, 'label', getattr(segment, 'speaker', f'speaker_{i}')))
|
||||||
|
|
||||||
|
duration = end - start
|
||||||
|
speakers.add(speaker)
|
||||||
|
|
||||||
|
# RTTM format: SPEAKER <filename> <channel> <start> <duration> <NA> <NA> <speaker_id> <NA> <NA>
|
||||||
|
line = f"SPEAKER {audio_filename} 1 {start:.3f} {duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n"
|
||||||
|
f.write(line)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not process segment {i} for RTTM: {e}")
|
||||||
|
print(f"Segment: {segment}")
|
||||||
|
|
||||||
|
print(f"RTTM results saved to: {output_file}")
|
||||||
|
print(f"Found {len(speakers)} speakers: {', '.join(sorted(speakers))}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Speaker diarization using NVIDIA Sortformer model (local model only)",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
Examples:
|
||||||
|
# Basic diarization with JSON output
|
||||||
|
python sortformer_diarize.py samples/sample.wav output.json
|
||||||
|
|
||||||
|
# Generate RTTM format output
|
||||||
|
python sortformer_diarize.py samples/sample.wav output.rttm
|
||||||
|
|
||||||
|
# Specify device and batch size
|
||||||
|
python sortformer_diarize.py --device cuda --batch-size 2 samples/sample.wav output.json
|
||||||
|
|
||||||
|
Note: This script requires diar_streaming_sortformer_4spk-v2.nemo to be in the same directory.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("audio_file", help="Path to input audio file (WAV, FLAC, etc.)")
|
||||||
|
parser.add_argument("output_file", help="Path to output file (.json for JSON format, .rttm for RTTM format)")
|
||||||
|
parser.add_argument("--batch-size", type=int, default=1, help="Batch size for processing (default: 1)")
|
||||||
|
parser.add_argument("--device", choices=["cuda", "cpu", "auto"], default="auto", help="Device to use for inference (default: auto-detect)")
|
||||||
|
parser.add_argument("--max-speakers", type=int, default=4, help="Maximum number of speakers (default: 4, optimized for this model)")
|
||||||
|
parser.add_argument("--output-format", choices=["json", "rttm"], help="Output format (auto-detected from file extension if not specified)")
|
||||||
|
parser.add_argument("--streaming", action="store_true", help="Enable streaming mode")
|
||||||
|
parser.add_argument("--chunk-length-s", type=float, default=30.0, help="Chunk length in seconds for streaming mode (default: 30.0)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Validate inputs
|
||||||
|
if not os.path.exists(args.audio_file):
|
||||||
|
print(f"Error: Audio file not found: {args.audio_file}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Auto-detect output format from file extension if not specified
|
||||||
|
if args.output_format is None:
|
||||||
|
if args.output_file.lower().endswith('.rttm'):
|
||||||
|
output_format = "rttm"
|
||||||
|
else:
|
||||||
|
output_format = "json"
|
||||||
|
else:
|
||||||
|
output_format = args.output_format
|
||||||
|
|
||||||
|
# Create output directory if it doesn't exist
|
||||||
|
output_dir = Path(args.output_file).parent
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
device = None if args.device == "auto" else args.device
|
||||||
|
|
||||||
|
# Run diarization
|
||||||
|
diarize_audio(
|
||||||
|
audio_path=args.audio_file,
|
||||||
|
output_file=args.output_file,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
device=device,
|
||||||
|
max_speakers=args.max_speakers,
|
||||||
|
output_format=output_format,
|
||||||
|
streaming_mode=args.streaming,
|
||||||
|
chunk_length_s=args.chunk_length_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
266
internal/transcription/adapters/py/pyannote/pyannote_diarize.py
Normal file
266
internal/transcription/adapters/py/pyannote/pyannote_diarize.py
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
PyAnnote speaker diarization script.
|
||||||
|
Processes audio files to identify and separate different speakers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Fix for PyTorch 2.6+ which defaults weights_only=True
|
||||||
|
# We need to allowlist PyAnnote's custom classes
|
||||||
|
try:
|
||||||
|
from pyannote.audio.core.task import Specifications, Problem, Resolution
|
||||||
|
if hasattr(torch.serialization, "add_safe_globals"):
|
||||||
|
torch.serialization.add_safe_globals([Specifications, Problem, Resolution])
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not add safe globals: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def diarize_audio(
|
||||||
|
audio_path: str,
|
||||||
|
output_file: str,
|
||||||
|
hf_token: str,
|
||||||
|
model: str = "pyannote/speaker-diarization-community-1",
|
||||||
|
min_speakers: int = None,
|
||||||
|
max_speakers: int = None,
|
||||||
|
output_format: str = "rttm",
|
||||||
|
|
||||||
|
device: str = "auto"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Perform speaker diarization on audio file using PyAnnote.
|
||||||
|
"""
|
||||||
|
print(f"Loading PyAnnote speaker diarization pipeline: {model}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize the diarization pipeline
|
||||||
|
pipeline = Pipeline.from_pretrained(
|
||||||
|
model,
|
||||||
|
token=hf_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move to specified device
|
||||||
|
# if device == "auto" or device == "cuda":
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
pipeline = pipeline.to(torch.device("cuda"))
|
||||||
|
print("Using CUDA for diarization")
|
||||||
|
elif device == "cuda":
|
||||||
|
print("CUDA requested but not available, falling back to CPU")
|
||||||
|
else:
|
||||||
|
print("CUDA not available, using CPU")
|
||||||
|
except ImportError:
|
||||||
|
print("PyTorch not available for CUDA, using CPU")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error moving to device: {e}, using CPU")
|
||||||
|
|
||||||
|
print("Pipeline loaded successfully")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading pipeline: {e}")
|
||||||
|
print("Make sure you have a valid Hugging Face token and have accepted the model's license")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print(f"Processing audio file: {audio_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run diarization
|
||||||
|
diarization_params = {}
|
||||||
|
if min_speakers is not None:
|
||||||
|
diarization_params["min_speakers"] = min_speakers
|
||||||
|
if max_speakers is not None:
|
||||||
|
diarization_params["max_speakers"] = max_speakers
|
||||||
|
|
||||||
|
if diarization_params:
|
||||||
|
print(f"Using speaker constraints: {diarization_params}")
|
||||||
|
diarization = pipeline(audio_path, **diarization_params)
|
||||||
|
else:
|
||||||
|
print("Using automatic speaker detection")
|
||||||
|
diarization = pipeline(audio_path)
|
||||||
|
|
||||||
|
print(f"Diarization completed. Saving results to: {output_file}")
|
||||||
|
|
||||||
|
if output_format == "rttm":
|
||||||
|
# Save the diarization output to RTTM format
|
||||||
|
with open(output_file, "w") as rttm:
|
||||||
|
diarization.write_rttm(rttm)
|
||||||
|
else:
|
||||||
|
# Save as JSON format
|
||||||
|
save_json_format(diarization, output_file, audio_path)
|
||||||
|
|
||||||
|
# Print summary
|
||||||
|
speakers = set()
|
||||||
|
total_speech_time = 0.0
|
||||||
|
|
||||||
|
# Iterate over speaker diarization
|
||||||
|
# PyAnnote 4.x returns a DiarizeOutput object with a speaker_diarization attribute
|
||||||
|
if hasattr(diarization, "speaker_diarization"):
|
||||||
|
for turn, speaker in diarization.speaker_diarization:
|
||||||
|
speakers.add(speaker)
|
||||||
|
total_speech_time += turn.duration
|
||||||
|
elif hasattr(diarization, "itertracks"):
|
||||||
|
# Fallback for older versions
|
||||||
|
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
||||||
|
speakers.add(speaker)
|
||||||
|
total_speech_time += segment.duration
|
||||||
|
else:
|
||||||
|
# Try iterating directly (some versions return Annotation directly)
|
||||||
|
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
||||||
|
speakers.add(speaker)
|
||||||
|
total_speech_time += segment.duration
|
||||||
|
|
||||||
|
print(f"\nDiarization Summary:")
|
||||||
|
print(f" Speakers detected: {len(speakers)}")
|
||||||
|
print(f" Speaker labels: {sorted(speakers)}")
|
||||||
|
print(f" Total speech time: {total_speech_time:.2f} seconds")
|
||||||
|
print(f" Output file saved: {output_file}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during diarization: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def save_json_format(diarization, output_file: str, audio_path: str):
|
||||||
|
"""Save diarization results in JSON format."""
|
||||||
|
segments = []
|
||||||
|
speakers = set()
|
||||||
|
|
||||||
|
# PyAnnote 4.x
|
||||||
|
if hasattr(diarization, "speaker_diarization"):
|
||||||
|
for turn, speaker in diarization.speaker_diarization:
|
||||||
|
segments.append({
|
||||||
|
"start": turn.start,
|
||||||
|
"end": turn.end,
|
||||||
|
"speaker": speaker,
|
||||||
|
"confidence": 1.0,
|
||||||
|
"duration": turn.duration
|
||||||
|
})
|
||||||
|
speakers.add(speaker)
|
||||||
|
# Older versions
|
||||||
|
elif hasattr(diarization, "itertracks"):
|
||||||
|
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
||||||
|
segments.append({
|
||||||
|
"start": segment.start,
|
||||||
|
"end": segment.end,
|
||||||
|
"speaker": speaker,
|
||||||
|
"confidence": 1.0,
|
||||||
|
"duration": segment.duration
|
||||||
|
})
|
||||||
|
speakers.add(speaker)
|
||||||
|
|
||||||
|
# Sort segments by start time
|
||||||
|
segments.sort(key=lambda x: x["start"])
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"audio_file": audio_path,
|
||||||
|
"model": "pyannote/speaker-diarization-community-1",
|
||||||
|
"segments": segments,
|
||||||
|
"speakers": sorted(speakers),
|
||||||
|
"speaker_count": len(speakers),
|
||||||
|
"total_duration": max(seg["end"] for seg in segments) if segments else 0,
|
||||||
|
"processing_info": {
|
||||||
|
"total_segments": len(segments),
|
||||||
|
"total_speech_time": sum(seg["duration"] for seg in segments)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(output_file, "w") as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Perform speaker diarization using PyAnnote.audio"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"audio_file",
|
||||||
|
help="Path to audio file"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", "-o",
|
||||||
|
required=True,
|
||||||
|
help="Output file path"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-token",
|
||||||
|
required=True,
|
||||||
|
help="Hugging Face access token"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
default="pyannote/speaker-diarization-community-1",
|
||||||
|
help="PyAnnote model to use"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min-speakers",
|
||||||
|
type=int,
|
||||||
|
help="Minimum number of speakers"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-speakers",
|
||||||
|
type=int,
|
||||||
|
help="Maximum number of speakers"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-format",
|
||||||
|
choices=["rttm", "json"],
|
||||||
|
default="rttm",
|
||||||
|
help="Output format"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
choices=["cpu", "cuda", "auto"],
|
||||||
|
default="auto",
|
||||||
|
help="Device to use for computation"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Validate input file
|
||||||
|
if not os.path.exists(args.audio_file):
|
||||||
|
print(f"Error: Audio file not found: {args.audio_file}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Validate speaker constraints
|
||||||
|
if args.min_speakers is not None and args.min_speakers < 1:
|
||||||
|
print("Error: min_speakers must be at least 1")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if args.max_speakers is not None and args.max_speakers < 1:
|
||||||
|
print("Error: max_speakers must be at least 1")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if (args.min_speakers is not None and args.max_speakers is not None and
|
||||||
|
args.min_speakers > args.max_speakers):
|
||||||
|
print("Error: min_speakers cannot be greater than max_speakers")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Create output directory if it doesn't exist
|
||||||
|
output_path = Path(args.output)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
diarize_audio(
|
||||||
|
audio_path=args.audio_file,
|
||||||
|
output_file=args.output,
|
||||||
|
hf_token=args.hf_token,
|
||||||
|
model=args.model,
|
||||||
|
min_speakers=args.min_speakers,
|
||||||
|
max_speakers=args.max_speakers,
|
||||||
|
output_format=args.output_format,
|
||||||
|
device=args.device
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during diarization: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
39
internal/transcription/adapters/py/pyannote/pyproject.toml
Normal file
39
internal/transcription/adapters/py/pyannote/pyproject.toml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
[project]
|
||||||
|
name = "pyannote-diarization"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Audio diarization using PyAnnote"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"torch>=2.5.0",
|
||||||
|
"torchaudio>=2.5.0",
|
||||||
|
"huggingface-hub>=0.28.1",
|
||||||
|
"pyannote.audio==4.0.2"
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
"pytest-mock>=3.12.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
torch = [
|
||||||
|
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||||
|
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||||
|
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||||
|
]
|
||||||
|
torchaudio = [
|
||||||
|
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||||
|
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||||
|
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch"
|
||||||
|
url = "https://download.pytorch.org/whl/cu126"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cpu"
|
||||||
|
url = "https://download.pytorch.org/whl/cpu"
|
||||||
|
explicit = true
|
||||||
@@ -2,6 +2,7 @@ package adapters
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -15,6 +16,9 @@ import (
|
|||||||
"scriberr/pkg/logger"
|
"scriberr/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//go:embed py/pyannote/*
|
||||||
|
var pyannoteScripts embed.FS
|
||||||
|
|
||||||
const OutputFormatJSON = "json"
|
const OutputFormatJSON = "json"
|
||||||
|
|
||||||
// PyAnnoteAdapter implements the DiarizationAdapter interface for PyAnnote
|
// PyAnnoteAdapter implements the DiarizationAdapter interface for PyAnnote
|
||||||
@@ -182,7 +186,7 @@ func (p *PyAnnoteAdapter) PrepareEnvironment(ctx context.Context) error {
|
|||||||
if CheckEnvironmentReady(p.envPath, "from pyannote.audio import Pipeline") {
|
if CheckEnvironmentReady(p.envPath, "from pyannote.audio import Pipeline") {
|
||||||
logger.Info("PyAnnote already available in environment")
|
logger.Info("PyAnnote already available in environment")
|
||||||
// Still ensure script exists
|
// Still ensure script exists
|
||||||
if err := p.createDiarizationScript(); err != nil {
|
if err := p.copyDiarizationScript(); err != nil {
|
||||||
return fmt.Errorf("failed to create diarization script: %w", err)
|
return fmt.Errorf("failed to create diarization script: %w", err)
|
||||||
}
|
}
|
||||||
p.initialized = true
|
p.initialized = true
|
||||||
@@ -195,7 +199,7 @@ func (p *PyAnnoteAdapter) PrepareEnvironment(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Always ensure diarization script exists
|
// Always ensure diarization script exists
|
||||||
if err := p.createDiarizationScript(); err != nil {
|
if err := p.copyDiarizationScript(); err != nil {
|
||||||
return fmt.Errorf("failed to create diarization script: %w", err)
|
return fmt.Errorf("failed to create diarization script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,45 +220,23 @@ func (p *PyAnnoteAdapter) setupPyAnnoteEnvironment() error {
|
|||||||
return fmt.Errorf("failed to create pyannote directory: %w", err)
|
return fmt.Errorf("failed to create pyannote directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create pyproject.toml with configurable PyTorch CUDA version
|
// Read pyproject.toml for PyAnnote
|
||||||
// Note: We explicitly pin torch and torchaudio to 2.1.2 to ensure compatibility with pyannote.audio 3.1
|
pyprojectContent, err := pyannoteScripts.ReadFile("py/pyannote/pyproject.toml")
|
||||||
// Newer versions of torchaudio (2.2+) removed AudioMetaData which causes crashes
|
if err != nil {
|
||||||
pyprojectContent := fmt.Sprintf(`[project]
|
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
|
||||||
name = "pyannote-diarization"
|
}
|
||||||
version = "0.1.0"
|
|
||||||
description = "Audio diarization using PyAnnote"
|
|
||||||
requires-python = ">=3.10"
|
|
||||||
dependencies = [
|
|
||||||
"torch>=2.5.0",
|
|
||||||
"torchaudio>=2.5.0",
|
|
||||||
"huggingface-hub>=0.28.1",
|
|
||||||
"pyannote.audio==4.0.2"
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
|
||||||
torch = [
|
// The static file contains the default cu126 URL
|
||||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
contentStr := strings.Replace(
|
||||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
string(pyprojectContent),
|
||||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
"https://download.pytorch.org/whl/cu126",
|
||||||
]
|
GetPyTorchWheelURL(),
|
||||||
torchaudio = [
|
1,
|
||||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
)
|
||||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
|
||||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "pytorch"
|
|
||||||
url = "%s"
|
|
||||||
explicit = true
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "pytorch-cpu"
|
|
||||||
url = "https://download.pytorch.org/whl/cpu"
|
|
||||||
explicit = true
|
|
||||||
`, GetPyTorchWheelURL())
|
|
||||||
pyprojectPath := filepath.Join(p.envPath, "pyproject.toml")
|
pyprojectPath := filepath.Join(p.envPath, "pyproject.toml")
|
||||||
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
|
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
|
||||||
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,289 +252,20 @@ explicit = true
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createDiarizationScript creates the Python script for PyAnnote diarization
|
// copyDiarizationScript creates the Python script for PyAnnote diarization
|
||||||
func (p *PyAnnoteAdapter) createDiarizationScript() error {
|
func (p *PyAnnoteAdapter) copyDiarizationScript() error {
|
||||||
// Ensure the directory exists first
|
// Ensure the directory exists first
|
||||||
if err := os.MkdirAll(p.envPath, 0755); err != nil {
|
if err := os.MkdirAll(p.envPath, 0755); err != nil {
|
||||||
return fmt.Errorf("failed to create pyannote directory: %w", err)
|
return fmt.Errorf("failed to create pyannote directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
scriptContent, err := pyannoteScripts.ReadFile("py/pyannote/pyannote_diarize.py")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read embedded pyannote_diarize.py: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
scriptPath := filepath.Join(p.envPath, "pyannote_diarize.py")
|
scriptPath := filepath.Join(p.envPath, "pyannote_diarize.py")
|
||||||
|
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
|
||||||
// Always recreate the script to ensure it's up to date with the adapter code
|
|
||||||
// if _, err := os.Stat(scriptPath); err == nil {
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
|
|
||||||
scriptContent := `#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
PyAnnote speaker diarization script.
|
|
||||||
Processes audio files to identify and separate different speakers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from pyannote.audio import Pipeline
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# Fix for PyTorch 2.6+ which defaults weights_only=True
|
|
||||||
# We need to allowlist PyAnnote's custom classes
|
|
||||||
try:
|
|
||||||
from pyannote.audio.core.task import Specifications, Problem, Resolution
|
|
||||||
if hasattr(torch.serialization, "add_safe_globals"):
|
|
||||||
torch.serialization.add_safe_globals([Specifications, Problem, Resolution])
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Could not add safe globals: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def diarize_audio(
|
|
||||||
audio_path: str,
|
|
||||||
output_file: str,
|
|
||||||
hf_token: str,
|
|
||||||
model: str = "pyannote/speaker-diarization-community-1",
|
|
||||||
min_speakers: int = None,
|
|
||||||
max_speakers: int = None,
|
|
||||||
output_format: str = "rttm",
|
|
||||||
|
|
||||||
device: str = "auto"
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Perform speaker diarization on audio file using PyAnnote.
|
|
||||||
"""
|
|
||||||
print(f"Loading PyAnnote speaker diarization pipeline: {model}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Initialize the diarization pipeline
|
|
||||||
pipeline = Pipeline.from_pretrained(
|
|
||||||
model,
|
|
||||||
token=hf_token
|
|
||||||
)
|
|
||||||
|
|
||||||
# Move to specified device
|
|
||||||
# if device == "auto" or device == "cuda":
|
|
||||||
try:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
pipeline = pipeline.to(torch.device("cuda"))
|
|
||||||
print("Using CUDA for diarization")
|
|
||||||
elif device == "cuda":
|
|
||||||
print("CUDA requested but not available, falling back to CPU")
|
|
||||||
else:
|
|
||||||
print("CUDA not available, using CPU")
|
|
||||||
except ImportError:
|
|
||||||
print("PyTorch not available for CUDA, using CPU")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error moving to device: {e}, using CPU")
|
|
||||||
|
|
||||||
print("Pipeline loaded successfully")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading pipeline: {e}")
|
|
||||||
print("Make sure you have a valid Hugging Face token and have accepted the model's license")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
print(f"Processing audio file: {audio_path}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Run diarization
|
|
||||||
diarization_params = {}
|
|
||||||
if min_speakers is not None:
|
|
||||||
diarization_params["min_speakers"] = min_speakers
|
|
||||||
if max_speakers is not None:
|
|
||||||
diarization_params["max_speakers"] = max_speakers
|
|
||||||
|
|
||||||
if diarization_params:
|
|
||||||
print(f"Using speaker constraints: {diarization_params}")
|
|
||||||
diarization = pipeline(audio_path, **diarization_params)
|
|
||||||
else:
|
|
||||||
print("Using automatic speaker detection")
|
|
||||||
diarization = pipeline(audio_path)
|
|
||||||
|
|
||||||
print(f"Diarization completed. Saving results to: {output_file}")
|
|
||||||
|
|
||||||
if output_format == "rttm":
|
|
||||||
# Save the diarization output to RTTM format
|
|
||||||
with open(output_file, "w") as rttm:
|
|
||||||
diarization.write_rttm(rttm)
|
|
||||||
else:
|
|
||||||
# Save as JSON format
|
|
||||||
save_json_format(diarization, output_file, audio_path)
|
|
||||||
|
|
||||||
# Print summary
|
|
||||||
speakers = set()
|
|
||||||
total_speech_time = 0.0
|
|
||||||
|
|
||||||
# Iterate over speaker diarization
|
|
||||||
# PyAnnote 4.x returns a DiarizeOutput object with a speaker_diarization attribute
|
|
||||||
if hasattr(diarization, "speaker_diarization"):
|
|
||||||
for turn, speaker in diarization.speaker_diarization:
|
|
||||||
speakers.add(speaker)
|
|
||||||
total_speech_time += turn.duration
|
|
||||||
elif hasattr(diarization, "itertracks"):
|
|
||||||
# Fallback for older versions
|
|
||||||
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
|
||||||
speakers.add(speaker)
|
|
||||||
total_speech_time += segment.duration
|
|
||||||
else:
|
|
||||||
# Try iterating directly (some versions return Annotation directly)
|
|
||||||
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
|
||||||
speakers.add(speaker)
|
|
||||||
total_speech_time += segment.duration
|
|
||||||
|
|
||||||
print(f"\nDiarization Summary:")
|
|
||||||
print(f" Speakers detected: {len(speakers)}")
|
|
||||||
print(f" Speaker labels: {sorted(speakers)}")
|
|
||||||
print(f" Total speech time: {total_speech_time:.2f} seconds")
|
|
||||||
print(f" Output file saved: {output_file}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during diarization: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def save_json_format(diarization, output_file: str, audio_path: str):
|
|
||||||
"""Save diarization results in JSON format."""
|
|
||||||
segments = []
|
|
||||||
speakers = set()
|
|
||||||
|
|
||||||
# PyAnnote 4.x
|
|
||||||
if hasattr(diarization, "speaker_diarization"):
|
|
||||||
for turn, speaker in diarization.speaker_diarization:
|
|
||||||
segments.append({
|
|
||||||
"start": turn.start,
|
|
||||||
"end": turn.end,
|
|
||||||
"speaker": speaker,
|
|
||||||
"confidence": 1.0,
|
|
||||||
"duration": turn.duration
|
|
||||||
})
|
|
||||||
speakers.add(speaker)
|
|
||||||
# Older versions
|
|
||||||
elif hasattr(diarization, "itertracks"):
|
|
||||||
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
|
||||||
segments.append({
|
|
||||||
"start": segment.start,
|
|
||||||
"end": segment.end,
|
|
||||||
"speaker": speaker,
|
|
||||||
"confidence": 1.0,
|
|
||||||
"duration": segment.duration
|
|
||||||
})
|
|
||||||
speakers.add(speaker)
|
|
||||||
|
|
||||||
# Sort segments by start time
|
|
||||||
segments.sort(key=lambda x: x["start"])
|
|
||||||
|
|
||||||
results = {
|
|
||||||
"audio_file": audio_path,
|
|
||||||
"model": "pyannote/speaker-diarization-community-1",
|
|
||||||
"segments": segments,
|
|
||||||
"speakers": sorted(speakers),
|
|
||||||
"speaker_count": len(speakers),
|
|
||||||
"total_duration": max(seg["end"] for seg in segments) if segments else 0,
|
|
||||||
"processing_info": {
|
|
||||||
"total_segments": len(segments),
|
|
||||||
"total_speech_time": sum(seg["duration"] for seg in segments)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
with open(output_file, "w") as f:
|
|
||||||
json.dump(results, f, indent=2)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Perform speaker diarization using PyAnnote.audio"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"audio_file",
|
|
||||||
help="Path to audio file"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output", "-o",
|
|
||||||
required=True,
|
|
||||||
help="Output file path"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--hf-token",
|
|
||||||
required=True,
|
|
||||||
help="Hugging Face access token"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
|
||||||
default="pyannote/speaker-diarization-community-1",
|
|
||||||
help="PyAnnote model to use"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--min-speakers",
|
|
||||||
type=int,
|
|
||||||
help="Minimum number of speakers"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-speakers",
|
|
||||||
type=int,
|
|
||||||
help="Maximum number of speakers"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-format",
|
|
||||||
choices=["rttm", "json"],
|
|
||||||
default="rttm",
|
|
||||||
help="Output format"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--device",
|
|
||||||
choices=["cpu", "cuda", "auto"],
|
|
||||||
default="auto",
|
|
||||||
help="Device to use for computation"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Validate input file
|
|
||||||
if not os.path.exists(args.audio_file):
|
|
||||||
print(f"Error: Audio file not found: {args.audio_file}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Validate speaker constraints
|
|
||||||
if args.min_speakers is not None and args.min_speakers < 1:
|
|
||||||
print("Error: min_speakers must be at least 1")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if args.max_speakers is not None and args.max_speakers < 1:
|
|
||||||
print("Error: max_speakers must be at least 1")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if (args.min_speakers is not None and args.max_speakers is not None and
|
|
||||||
args.min_speakers > args.max_speakers):
|
|
||||||
print("Error: min_speakers cannot be greater than max_speakers")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Create output directory if it doesn't exist
|
|
||||||
output_path = Path(args.output)
|
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
diarize_audio(
|
|
||||||
audio_path=args.audio_file,
|
|
||||||
output_file=args.output,
|
|
||||||
hf_token=args.hf_token,
|
|
||||||
model=args.model,
|
|
||||||
min_speakers=args.min_speakers,
|
|
||||||
max_speakers=args.max_speakers,
|
|
||||||
output_format=args.output_format,
|
|
||||||
device=args.device
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during diarization: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
`
|
|
||||||
|
|
||||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to write diarization script: %w", err)
|
return fmt.Errorf("failed to write diarization script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ func (s *SortformerAdapter) PrepareEnvironment(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create diarization script
|
// Create diarization script
|
||||||
if err := s.createDiarizationScript(); err != nil {
|
if err := s.copyDiarizationScript(); err != nil {
|
||||||
return fmt.Errorf("failed to create diarization script: %w", err)
|
return fmt.Errorf("failed to create diarization script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,51 +195,23 @@ func (s *SortformerAdapter) setupSortformerEnvironment() error {
|
|||||||
return fmt.Errorf("failed to create sortformer directory: %w", err)
|
return fmt.Errorf("failed to create sortformer directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create pyproject.toml with configurable PyTorch CUDA version
|
// Read pyproject.toml
|
||||||
pyprojectContent := fmt.Sprintf(`[project]
|
pyprojectContent, err := nvidiaScripts.ReadFile("py/nvidia/pyproject.toml")
|
||||||
name = "parakeet-transcription"
|
if err != nil {
|
||||||
version = "0.1.0"
|
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
|
||||||
description = "Audio transcription using NVIDIA Parakeet models"
|
}
|
||||||
requires-python = ">=3.11"
|
|
||||||
dependencies = [
|
|
||||||
"nemo-toolkit[asr]",
|
|
||||||
"torch",
|
|
||||||
"torchaudio",
|
|
||||||
"librosa",
|
|
||||||
"soundfile",
|
|
||||||
"ml-dtypes>=0.3.1,<0.5.0",
|
|
||||||
"onnx>=1.15.0,<1.18.0",
|
|
||||||
"pyannote.audio"
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
|
||||||
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
|
// The static file contains the default cu126 URL
|
||||||
torch = [
|
contentStr := strings.Replace(
|
||||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
string(pyprojectContent),
|
||||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
"https://download.pytorch.org/whl/cu126",
|
||||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
GetPyTorchWheelURL(),
|
||||||
]
|
1,
|
||||||
torchaudio = [
|
)
|
||||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
|
||||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
|
||||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
|
||||||
]
|
|
||||||
triton = [
|
|
||||||
{ index = "pytorch", marker = "sys_platform == 'linux'" }
|
|
||||||
]
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "pytorch"
|
|
||||||
url = "%s"
|
|
||||||
explicit = true
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "pytorch-cpu"
|
|
||||||
url = "https://download.pytorch.org/whl/cpu"
|
|
||||||
explicit = true
|
|
||||||
`, GetPyTorchWheelURL())
|
|
||||||
pyprojectPath := filepath.Join(s.envPath, "pyproject.toml")
|
pyprojectPath := filepath.Join(s.envPath, "pyproject.toml")
|
||||||
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
|
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
|
||||||
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -289,345 +261,15 @@ func (s *SortformerAdapter) downloadSortformerModel() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createDiarizationScript creates the Python script for Sortformer diarization
|
// copyDiarizationScript creates the Python script for Sortformer diarization
|
||||||
func (s *SortformerAdapter) createDiarizationScript() error {
|
func (s *SortformerAdapter) copyDiarizationScript() error {
|
||||||
scriptPath := filepath.Join(s.envPath, "sortformer_diarize.py")
|
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/sortformer_diarize.py")
|
||||||
|
if err != nil {
|
||||||
// Check if script already exists
|
return fmt.Errorf("failed to read embedded sortformer_diarize.py: %w", err)
|
||||||
if _, err := os.Stat(scriptPath); err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
scriptContent := `#!/usr/bin/env python3
|
scriptPath := filepath.Join(s.envPath, "sortformer_diarize.py")
|
||||||
"""
|
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
|
||||||
NVIDIA Sortformer speaker diarization script.
|
|
||||||
Uses diar_streaming_sortformer_4spk-v2 for optimized 4-speaker diarization.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import torch
|
|
||||||
|
|
||||||
try:
|
|
||||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
|
||||||
except ImportError:
|
|
||||||
print("Error: NeMo not found. Please install nemo_toolkit[asr]")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def diarize_audio(
|
|
||||||
audio_path: str,
|
|
||||||
output_file: str,
|
|
||||||
batch_size: int = 1,
|
|
||||||
device: str = None,
|
|
||||||
max_speakers: int = 4,
|
|
||||||
output_format: str = "rttm",
|
|
||||||
streaming_mode: bool = False,
|
|
||||||
chunk_length_s: float = 30.0,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Perform speaker diarization using NVIDIA's Sortformer model.
|
|
||||||
"""
|
|
||||||
if device is None or device == "auto":
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = "cuda"
|
|
||||||
|
|
||||||
else:
|
|
||||||
device = "cpu"
|
|
||||||
|
|
||||||
print(f"Using device: {device}")
|
|
||||||
print(f"Loading NVIDIA Sortformer diarization model...")
|
|
||||||
|
|
||||||
# Get the directory where this script is located
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
model_path = os.path.join(script_dir, "diar_streaming_sortformer_4spk-v2.nemo")
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
print(f"Error: Model file not found: {model_path}")
|
|
||||||
print("Please ensure diar_streaming_sortformer_4spk-v2.nemo is in the same directory as this script")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Load from local file
|
|
||||||
print(f"Loading model from local path: {model_path}")
|
|
||||||
diar_model = SortformerEncLabelModel.restore_from(
|
|
||||||
restore_path=model_path,
|
|
||||||
map_location=device,
|
|
||||||
strict=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Switch to inference mode
|
|
||||||
diar_model.eval()
|
|
||||||
print("Model loaded successfully")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading model: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
print(f"Processing audio file: {audio_path}")
|
|
||||||
|
|
||||||
# Verify audio file exists
|
|
||||||
if not os.path.exists(audio_path):
|
|
||||||
print(f"Error: Audio file not found: {audio_path}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Run diarization
|
|
||||||
print(f"Running diarization with batch_size={batch_size}, max_speakers={max_speakers}")
|
|
||||||
|
|
||||||
if streaming_mode:
|
|
||||||
print(f"Using streaming mode with chunk_length_s={chunk_length_s}")
|
|
||||||
# Note: Streaming mode implementation would go here
|
|
||||||
# For now, use standard diarization
|
|
||||||
predicted_segments = diar_model.diarize(audio=audio_path, batch_size=batch_size)
|
|
||||||
else:
|
|
||||||
predicted_segments = diar_model.diarize(audio=audio_path, batch_size=batch_size)
|
|
||||||
|
|
||||||
print(f"Diarization completed. Found segments: {len(predicted_segments)}")
|
|
||||||
|
|
||||||
# Process and save results
|
|
||||||
save_results(predicted_segments, output_file, audio_path, output_format)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during diarization: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def save_results(segments, output_file: str, audio_path: str, output_format: str):
|
|
||||||
"""
|
|
||||||
Save diarization results to output file.
|
|
||||||
Supports both JSON and RTTM formats based on output_format parameter.
|
|
||||||
"""
|
|
||||||
output_path = Path(output_file)
|
|
||||||
|
|
||||||
if output_format == "rttm":
|
|
||||||
save_rttm_format(segments, output_file, audio_path)
|
|
||||||
else:
|
|
||||||
save_json_format(segments, output_file, audio_path)
|
|
||||||
|
|
||||||
|
|
||||||
def save_json_format(segments, output_file: str, audio_path: str):
|
|
||||||
"""Save results in JSON format."""
|
|
||||||
results = {
|
|
||||||
"audio_file": audio_path,
|
|
||||||
"model": "nvidia/diar_streaming_sortformer_4spk-v2",
|
|
||||||
"segments": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Handle the case where segments is a list containing a single list of string entries
|
|
||||||
if len(segments) == 1 and isinstance(segments[0], list):
|
|
||||||
segments = segments[0]
|
|
||||||
|
|
||||||
# Convert segments to JSON format
|
|
||||||
speakers = set()
|
|
||||||
for i, segment in enumerate(segments):
|
|
||||||
try:
|
|
||||||
# Handle different possible segment formats
|
|
||||||
if isinstance(segment, str):
|
|
||||||
# String format: "start end speaker_id"
|
|
||||||
parts = segment.strip().split()
|
|
||||||
if len(parts) >= 3:
|
|
||||||
segment_data = {
|
|
||||||
"start": float(parts[0]),
|
|
||||||
"end": float(parts[1]),
|
|
||||||
"speaker": str(parts[2]),
|
|
||||||
"duration": float(parts[1]) - float(parts[0]),
|
|
||||||
"confidence": 1.0,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
print(f"Warning: Invalid string segment format: {segment}")
|
|
||||||
continue
|
|
||||||
elif hasattr(segment, 'start') and hasattr(segment, 'end') and hasattr(segment, 'label'):
|
|
||||||
# Standard pyannote-like format
|
|
||||||
segment_data = {
|
|
||||||
"start": float(segment.start),
|
|
||||||
"end": float(segment.end),
|
|
||||||
"speaker": str(segment.label),
|
|
||||||
"duration": float(segment.end - segment.start),
|
|
||||||
"confidence": getattr(segment, 'confidence', 1.0),
|
|
||||||
}
|
|
||||||
elif isinstance(segment, (list, tuple)) and len(segment) >= 3:
|
|
||||||
# List/tuple format: [start, end, speaker]
|
|
||||||
segment_data = {
|
|
||||||
"start": float(segment[0]),
|
|
||||||
"end": float(segment[1]),
|
|
||||||
"speaker": str(segment[2]),
|
|
||||||
"duration": float(segment[1] - segment[0]),
|
|
||||||
"confidence": 1.0,
|
|
||||||
}
|
|
||||||
elif isinstance(segment, dict):
|
|
||||||
# Dictionary format
|
|
||||||
segment_data = {
|
|
||||||
"start": float(segment.get('start', 0)),
|
|
||||||
"end": float(segment.get('end', 0)),
|
|
||||||
"speaker": str(segment.get('speaker', segment.get('label', f'speaker_{i}'))),
|
|
||||||
"duration": float(segment.get('end', 0) - segment.get('start', 0)),
|
|
||||||
"confidence": float(segment.get('confidence', 1.0)),
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# Fallback: try to extract attributes dynamically
|
|
||||||
segment_data = {
|
|
||||||
"start": float(getattr(segment, 'start', 0)),
|
|
||||||
"end": float(getattr(segment, 'end', 0)),
|
|
||||||
"speaker": str(getattr(segment, 'label', getattr(segment, 'speaker', f'speaker_{i}'))),
|
|
||||||
"duration": float(getattr(segment, 'end', 0) - getattr(segment, 'start', 0)),
|
|
||||||
"confidence": float(getattr(segment, 'confidence', 1.0)),
|
|
||||||
}
|
|
||||||
|
|
||||||
results["segments"].append(segment_data)
|
|
||||||
speakers.add(segment_data["speaker"])
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Could not process segment {i}: {e}")
|
|
||||||
print(f"Segment: {segment}")
|
|
||||||
|
|
||||||
# Sort by start time
|
|
||||||
if results["segments"]:
|
|
||||||
results["segments"].sort(key=lambda x: x["start"])
|
|
||||||
|
|
||||||
# Add summary statistics
|
|
||||||
results["speakers"] = sorted(speakers)
|
|
||||||
results["speaker_count"] = len(speakers)
|
|
||||||
results["total_segments"] = len(results["segments"])
|
|
||||||
results["total_duration"] = max(seg["end"] for seg in results["segments"]) if results["segments"] else 0
|
|
||||||
|
|
||||||
with open(output_file, "w") as f:
|
|
||||||
json.dump(results, f, indent=2)
|
|
||||||
|
|
||||||
print(f"Results saved to: {output_file}")
|
|
||||||
print(f"Found {len(speakers)} speakers: {', '.join(sorted(speakers))}")
|
|
||||||
|
|
||||||
|
|
||||||
def save_rttm_format(segments, output_file: str, audio_path: str):
|
|
||||||
"""Save results in RTTM (Rich Transcription Time Marked) format."""
|
|
||||||
audio_filename = Path(audio_path).stem
|
|
||||||
speakers = set()
|
|
||||||
|
|
||||||
# Handle the case where segments is a list containing a single list of string entries
|
|
||||||
if len(segments) == 1 and isinstance(segments[0], list):
|
|
||||||
segments = segments[0]
|
|
||||||
|
|
||||||
with open(output_file, "w") as f:
|
|
||||||
for i, segment in enumerate(segments):
|
|
||||||
try:
|
|
||||||
# Handle different possible segment formats
|
|
||||||
if isinstance(segment, str):
|
|
||||||
# String format: "start end speaker_id"
|
|
||||||
parts = segment.strip().split()
|
|
||||||
if len(parts) >= 3:
|
|
||||||
start = float(parts[0])
|
|
||||||
end = float(parts[1])
|
|
||||||
speaker = str(parts[2])
|
|
||||||
else:
|
|
||||||
print(f"Warning: Invalid string segment format: {segment}")
|
|
||||||
continue
|
|
||||||
elif hasattr(segment, 'start') and hasattr(segment, 'end') and hasattr(segment, 'label'):
|
|
||||||
# Standard pyannote-like format
|
|
||||||
start = float(segment.start)
|
|
||||||
end = float(segment.end)
|
|
||||||
speaker = str(segment.label)
|
|
||||||
elif isinstance(segment, (list, tuple)) and len(segment) >= 3:
|
|
||||||
# List/tuple format: [start, end, speaker]
|
|
||||||
start = float(segment[0])
|
|
||||||
end = float(segment[1])
|
|
||||||
speaker = str(segment[2])
|
|
||||||
elif isinstance(segment, dict):
|
|
||||||
# Dictionary format
|
|
||||||
start = float(segment.get('start', 0))
|
|
||||||
end = float(segment.get('end', 0))
|
|
||||||
speaker = str(segment.get('speaker', segment.get('label', f'speaker_{i}')))
|
|
||||||
else:
|
|
||||||
# Fallback: try to extract attributes dynamically
|
|
||||||
start = float(getattr(segment, 'start', 0))
|
|
||||||
end = float(getattr(segment, 'end', 0))
|
|
||||||
speaker = str(getattr(segment, 'label', getattr(segment, 'speaker', f'speaker_{i}')))
|
|
||||||
|
|
||||||
duration = end - start
|
|
||||||
speakers.add(speaker)
|
|
||||||
|
|
||||||
# RTTM format: SPEAKER <filename> <channel> <start> <duration> <NA> <NA> <speaker_id> <NA> <NA>
|
|
||||||
line = f"SPEAKER {audio_filename} 1 {start:.3f} {duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n"
|
|
||||||
f.write(line)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Could not process segment {i} for RTTM: {e}")
|
|
||||||
print(f"Segment: {segment}")
|
|
||||||
|
|
||||||
print(f"RTTM results saved to: {output_file}")
|
|
||||||
print(f"Found {len(speakers)} speakers: {', '.join(sorted(speakers))}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Speaker diarization using NVIDIA Sortformer model (local model only)",
|
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
||||||
epilog="""
|
|
||||||
Examples:
|
|
||||||
# Basic diarization with JSON output
|
|
||||||
python sortformer_diarize.py samples/sample.wav output.json
|
|
||||||
|
|
||||||
# Generate RTTM format output
|
|
||||||
python sortformer_diarize.py samples/sample.wav output.rttm
|
|
||||||
|
|
||||||
# Specify device and batch size
|
|
||||||
python sortformer_diarize.py --device cuda --batch-size 2 samples/sample.wav output.json
|
|
||||||
|
|
||||||
Note: This script requires diar_streaming_sortformer_4spk-v2.nemo to be in the same directory.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("audio_file", help="Path to input audio file (WAV, FLAC, etc.)")
|
|
||||||
parser.add_argument("output_file", help="Path to output file (.json for JSON format, .rttm for RTTM format)")
|
|
||||||
parser.add_argument("--batch-size", type=int, default=1, help="Batch size for processing (default: 1)")
|
|
||||||
parser.add_argument("--device", choices=["cuda", "cpu", "auto"], default="auto", help="Device to use for inference (default: auto-detect)")
|
|
||||||
parser.add_argument("--max-speakers", type=int, default=4, help="Maximum number of speakers (default: 4, optimized for this model)")
|
|
||||||
parser.add_argument("--output-format", choices=["json", "rttm"], help="Output format (auto-detected from file extension if not specified)")
|
|
||||||
parser.add_argument("--streaming", action="store_true", help="Enable streaming mode")
|
|
||||||
parser.add_argument("--chunk-length-s", type=float, default=30.0, help="Chunk length in seconds for streaming mode (default: 30.0)")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Validate inputs
|
|
||||||
if not os.path.exists(args.audio_file):
|
|
||||||
print(f"Error: Audio file not found: {args.audio_file}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Auto-detect output format from file extension if not specified
|
|
||||||
if args.output_format is None:
|
|
||||||
if args.output_file.lower().endswith('.rttm'):
|
|
||||||
output_format = "rttm"
|
|
||||||
else:
|
|
||||||
output_format = "json"
|
|
||||||
else:
|
|
||||||
output_format = args.output_format
|
|
||||||
|
|
||||||
# Create output directory if it doesn't exist
|
|
||||||
output_dir = Path(args.output_file).parent
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
device = None if args.device == "auto" else args.device
|
|
||||||
|
|
||||||
# Run diarization
|
|
||||||
diarize_audio(
|
|
||||||
audio_path=args.audio_file,
|
|
||||||
output_file=args.output_file,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
device=device,
|
|
||||||
max_speakers=args.max_speakers,
|
|
||||||
output_format=output_format,
|
|
||||||
streaming_mode=args.streaming,
|
|
||||||
chunk_length_s=args.chunk_length_s,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
`
|
|
||||||
|
|
||||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to write diarization script: %w", err)
|
return fmt.Errorf("failed to write diarization script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user