Extract python adapter scripts to proper files

This commit is contained in:
Paul Irish
2025-12-26 17:12:21 -08:00
committed by Rishikanth Chandrasekaran
parent edb65339b8
commit 50dd4130ff
12 changed files with 1372 additions and 1355 deletions

1
.gitignore vendored
View File

@@ -70,3 +70,4 @@ tmp/
# *.svg
# *.png
dhl.txt
__pycache__/

View File

@@ -2,6 +2,7 @@ package adapters
import (
"context"
"embed"
"encoding/json"
"fmt"
"os"
@@ -15,6 +16,9 @@ import (
"scriberr/pkg/logger"
)
//go:embed py/nvidia/*
var nvidiaScripts embed.FS
// CanaryAdapter implements the TranscriptionAdapter interface for NVIDIA Canary
type CanaryAdapter struct {
*BaseAdapter
@@ -185,7 +189,7 @@ func (c *CanaryAdapter) PrepareEnvironment(ctx context.Context) error {
}
// Create transcription script
if err := c.createTranscriptionScript(); err != nil {
if err := c.copyTranscriptionScript(); err != nil {
return fmt.Errorf("failed to create transcription script: %w", err)
}
@@ -207,49 +211,23 @@ func (c *CanaryAdapter) setupCanaryEnvironment() error {
return nil
}
// Create pyproject.toml with configurable PyTorch CUDA version
pyprojectContent := fmt.Sprintf(`[project]
name = "parakeet-transcription"
version = "0.1.0"
description = "Audio transcription using NVIDIA Parakeet models"
requires-python = ">=3.11"
dependencies = [
"nemo-toolkit[asr]",
"torch",
"torchaudio",
"librosa",
"soundfile",
"ml-dtypes>=0.3.1,<0.5.0",
"onnx>=1.15.0,<1.18.0",
]
// Read pyproject.toml
pyprojectContent, err := nvidiaScripts.ReadFile("py/nvidia/pyproject.toml")
if err != nil {
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
}
[tool.uv.sources]
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
torch = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
torchaudio = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
triton = [
{ index = "pytorch", marker = "sys_platform == 'linux'" }
]
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
// The static file contains the default cu126 URL
contentStr := strings.Replace(
string(pyprojectContent),
"https://download.pytorch.org/whl/cu126",
GetPyTorchWheelURL(),
1,
)
[[tool.uv.index]]
name = "pytorch"
url = "%s"
explicit = true
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
`, GetPyTorchWheelURL())
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
pyprojectPath = filepath.Join(c.envPath, "pyproject.toml")
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
return fmt.Errorf("failed to write pyproject.toml: %w", err)
}
@@ -299,213 +277,15 @@ func (c *CanaryAdapter) downloadCanaryModel() error {
return nil
}
// createTranscriptionScript creates the Python script for Canary transcription
func (c *CanaryAdapter) createTranscriptionScript() error {
scriptPath := filepath.Join(c.envPath, "canary_transcribe.py")
// Check if script already exists
if _, err := os.Stat(scriptPath); err == nil {
return nil
// copyTranscriptionScript writes the embedded Python script for Canary transcription into the environment
func (c *CanaryAdapter) copyTranscriptionScript() error {
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/canary_transcribe.py")
if err != nil {
return fmt.Errorf("failed to read embedded canary_transcribe.py: %w", err)
}
scriptContent := `#!/usr/bin/env python3
"""
NVIDIA Canary multilingual transcription and translation script.
"""
import argparse
import json
import sys
import os
from pathlib import Path
import nemo.collections.asr as nemo_asr
def transcribe_audio(
audio_path: str,
source_lang: str = "en",
target_lang: str = "en",
task: str = "transcribe",
timestamps: bool = True,
output_file: str = None,
include_confidence: bool = True,
preserve_formatting: bool = True,
):
"""
Transcribe or translate audio using NVIDIA Canary model.
"""
# Get the directory where this script is located
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(script_dir, "canary-1b-v2.nemo")
if not os.path.exists(model_path):
print(f"Error: Model file not found: {model_path}")
sys.exit(1)
print(f"Loading NVIDIA Canary model from: {model_path}")
asr_model = nemo_asr.models.ASRModel.restore_from(model_path)
print(f"Processing: {audio_path}")
print(f"Task: {task}")
print(f"Source language: {source_lang}")
print(f"Target language: {target_lang}")
if timestamps:
if task == "translate" and source_lang != target_lang:
# Translation with timestamps
output = asr_model.transcribe(
[audio_path],
source_lang=source_lang,
target_lang=target_lang,
timestamps=True
)
else:
# Transcription with timestamps
output = asr_model.transcribe(
[audio_path],
source_lang=source_lang,
target_lang=target_lang,
timestamps=True
)
# Extract text and timestamps
result_data = output[0]
text = result_data.text
word_timestamps = result_data.timestamp.get("word", [])
segment_timestamps = result_data.timestamp.get("segment", [])
print(f"Result: {text}")
# Prepare output data
output_data = {
"transcription": text,
"source_language": source_lang,
"target_language": target_lang,
"task": task,
"word_timestamps": word_timestamps,
"segment_timestamps": segment_timestamps,
"audio_file": audio_path,
"model": "canary-1b-v2"
}
if include_confidence:
# Add confidence scores if available
if hasattr(result_data, 'confidence') and result_data.confidence:
output_data["confidence"] = result_data.confidence
# Save to file
if output_file:
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(output_data, f, indent=2, ensure_ascii=False)
print(f"Results saved to: {output_file}")
else:
print(json.dumps(output_data, indent=2, ensure_ascii=False))
else:
# Simple transcription/translation without timestamps
if task == "translate" and source_lang != target_lang:
output = asr_model.transcribe(
[audio_path],
source_lang=source_lang,
target_lang=target_lang
)
else:
output = asr_model.transcribe(
[audio_path],
source_lang=source_lang,
target_lang=target_lang
)
text = output[0].text
output_data = {
"transcription": text,
"source_language": source_lang,
"target_language": target_lang,
"task": task,
"audio_file": audio_path,
"model": "canary-1b-v2"
}
if output_file:
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(output_data, f, indent=2, ensure_ascii=False)
print(f"Results saved to: {output_file}")
else:
print(json.dumps(output_data, indent=2, ensure_ascii=False))
def main():
parser = argparse.ArgumentParser(
description="Transcribe or translate audio using NVIDIA Canary model"
)
parser.add_argument("audio_file", help="Path to audio file")
parser.add_argument(
"--source-lang", default="en",
choices=["en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh"],
help="Source language (default: en)"
)
parser.add_argument(
"--target-lang", default="en",
choices=["en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh"],
help="Target language (default: en)"
)
parser.add_argument(
"--task", choices=["transcribe", "translate"], default="transcribe",
help="Task to perform (default: transcribe)"
)
parser.add_argument(
"--timestamps", action="store_true", default=True,
help="Include word and segment level timestamps"
)
parser.add_argument(
"--no-timestamps", dest="timestamps", action="store_false",
help="Disable timestamps"
)
parser.add_argument(
"--output", "-o", help="Output file path"
)
parser.add_argument(
"--include-confidence", action="store_true", default=True,
help="Include confidence scores"
)
parser.add_argument(
"--no-confidence", dest="include_confidence", action="store_false",
help="Exclude confidence scores"
)
parser.add_argument(
"--preserve-formatting", action="store_true", default=True,
help="Preserve punctuation and capitalization"
)
args = parser.parse_args()
# Validate input file
if not os.path.exists(args.audio_file):
print(f"Error: Audio file not found: {args.audio_file}")
sys.exit(1)
try:
transcribe_audio(
audio_path=args.audio_file,
source_lang=args.source_lang,
target_lang=args.target_lang,
task=args.task,
timestamps=args.timestamps,
output_file=args.output,
include_confidence=args.include_confidence,
preserve_formatting=args.preserve_formatting,
)
except Exception as e:
print(f"Error during transcription: {e}")
sys.exit(1)
if __name__ == "__main__":
main()
`
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
scriptPath := filepath.Join(c.envPath, "canary_transcribe.py")
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
return fmt.Errorf("failed to write transcription script: %w", err)
}

View File

@@ -128,8 +128,8 @@ func (p *ParakeetAdapter) PrepareEnvironment(ctx context.Context) error {
// Check if environment is already ready (using cache to speed up repeated checks)
if CheckEnvironmentReady(p.envPath, "import nemo.collections.asr") {
modelPath := filepath.Join(p.envPath, "parakeet-tdt-0.6b-v3.nemo")
scriptPath := filepath.Join(p.envPath, "transcribe.py")
bufferedScriptPath := filepath.Join(p.envPath, "transcribe_buffered.py")
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe.py")
bufferedScriptPath := filepath.Join(p.envPath, "parakeet_transcribe_buffered.py")
// Check model, standard script, and buffered script all exist
if stat, err := os.Stat(modelPath); err == nil && stat.Size() > 1024*1024 {
@@ -160,7 +160,7 @@ func (p *ParakeetAdapter) PrepareEnvironment(ctx context.Context) error {
}
// Create transcription scripts (standard and buffered)
if err := p.createTranscriptionScript(); err != nil {
if err := p.copyTranscriptionScript(); err != nil {
return fmt.Errorf("failed to create transcription script: %w", err)
}
@@ -179,50 +179,23 @@ func (p *ParakeetAdapter) setupParakeetEnvironment() error {
return fmt.Errorf("failed to create parakeet directory: %w", err)
}
// Create pyproject.toml with configurable PyTorch CUDA version
pyprojectContent := fmt.Sprintf(`[project]
name = "parakeet-transcription"
version = "0.1.0"
description = "Audio transcription using NVIDIA Parakeet models"
requires-python = ">=3.11"
dependencies = [
"nemo-toolkit[asr]",
"torch",
"torchaudio",
"librosa",
"soundfile",
"ml-dtypes>=0.3.1,<0.5.0",
"onnx>=1.15.0,<1.18.0",
]
// Read pyproject.toml
pyprojectContent, err := nvidiaScripts.ReadFile("py/nvidia/pyproject.toml")
if err != nil {
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
}
[tool.uv.sources]
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
torch = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
torchaudio = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
triton = [
{ index = "pytorch", marker = "sys_platform == 'linux'" }
]
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
// The static file contains the default cu126 URL
contentStr := strings.Replace(
string(pyprojectContent),
"https://download.pytorch.org/whl/cu126",
GetPyTorchWheelURL(),
1,
)
[[tool.uv.index]]
name = "pytorch"
url = "%s"
explicit = true
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
`, GetPyTorchWheelURL())
pyprojectPath := filepath.Join(p.envPath, "pyproject.toml")
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
return fmt.Errorf("failed to write pyproject.toml: %w", err)
}
@@ -272,200 +245,15 @@ func (p *ParakeetAdapter) downloadParakeetModel() error {
return nil
}
// createTranscriptionScript creates the Python script for Parakeet transcription
func (p *ParakeetAdapter) createTranscriptionScript() error {
scriptContent := `#!/usr/bin/env python3
"""
NVIDIA Parakeet transcription script with timestamp support.
"""
// copyTranscriptionScript writes the embedded Python script for Parakeet transcription into the environment
func (p *ParakeetAdapter) copyTranscriptionScript() error {
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/parakeet_transcribe.py")
if err != nil {
return fmt.Errorf("failed to read embedded transcribe.py: %w", err)
}
import argparse
import json
import sys
import os
from pathlib import Path
import nemo.collections.asr as nemo_asr
def transcribe_audio(
audio_path: str,
timestamps: bool = True,
output_file: str = None,
context_left: int = 256,
context_right: int = 256,
include_confidence: bool = True,
):
"""
Transcribe audio using NVIDIA Parakeet model.
"""
# Get the directory where this script is located
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(script_dir, "parakeet-tdt-0.6b-v3.nemo")
print(f"Script directory: {script_dir}")
print(f"Looking for model at: {model_path}")
if not os.path.exists(model_path):
print(f"Error during transcription: Can't find {model_path}")
# List files in the directory to help debug
try:
files = os.listdir(script_dir)
print(f"Files in {script_dir}: {files}")
except Exception as e:
print(f"Could not list directory: {e}")
sys.exit(1)
print(f"Loading NVIDIA Parakeet model from: {model_path}")
asr_model = nemo_asr.models.ASRModel.restore_from(model_path)
# Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU
# Uses change_decoding_strategy() to properly reconfigure the TDT decoder
from omegaconf import OmegaConf, open_dict
print("Disabling CUDA graphs in TDT decoder...")
dec_cfg = asr_model.cfg.decoding
# Add use_cuda_graph_decoder parameter to greedy config
with open_dict(dec_cfg.greedy):
dec_cfg.greedy['use_cuda_graph_decoder'] = False
# Apply the new decoding strategy (this rebuilds the decoder with our config)
asr_model.change_decoding_strategy(dec_cfg)
print("✓ CUDA graphs disabled successfully")
# Configure for long-form audio if context sizes are not default
if context_left != 256 or context_right != 256:
print(f"Configuring attention context: left={context_left}, right={context_right}")
try:
asr_model.change_attention_model(
self_attention_model="rel_pos_local_attn",
att_context_size=[context_left, context_right]
)
print("Long-form audio mode enabled")
except Exception as e:
print(f"Warning: Failed to configure attention model: {e}")
print("Continuing with default attention settings")
print(f"Transcribing: {audio_path}")
if timestamps:
output = asr_model.transcribe([audio_path], timestamps=True)
# Extract text and timestamps
result_data = output[0]
text = result_data.text
word_timestamps = result_data.timestamp.get("word", [])
segment_timestamps = result_data.timestamp.get("segment", [])
print(f"Transcription: {text}")
# Prepare output data
output_data = {
"transcription": text,
"language": "en",
"word_timestamps": word_timestamps,
"segment_timestamps": segment_timestamps,
"audio_file": audio_path,
"model": "parakeet-tdt-0.6b-v3",
"context": {
"left": context_left,
"right": context_right
}
}
if include_confidence:
# Add confidence scores if available
if hasattr(result_data, 'confidence') and result_data.confidence:
output_data["confidence"] = result_data.confidence
# Save to file
if output_file:
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(output_data, f, indent=2, ensure_ascii=False)
print(f"Results saved to: {output_file}")
else:
print(json.dumps(output_data, indent=2, ensure_ascii=False))
else:
# Simple transcription without timestamps
output = asr_model.transcribe([audio_path])
text = output[0].text
output_data = {
"transcription": text,
"language": "en",
"audio_file": audio_path,
"model": "parakeet-tdt-0.6b-v3"
}
if output_file:
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(output_data, f, indent=2, ensure_ascii=False)
print(f"Results saved to: {output_file}")
else:
print(json.dumps(output_data, indent=2, ensure_ascii=False))
def main():
parser = argparse.ArgumentParser(
description="Transcribe audio using NVIDIA Parakeet model"
)
parser.add_argument("audio_file", help="Path to audio file")
parser.add_argument(
"--timestamps", action="store_true", default=True,
help="Include word and segment level timestamps"
)
parser.add_argument(
"--no-timestamps", dest="timestamps", action="store_false",
help="Disable timestamps"
)
parser.add_argument(
"--output", "-o", help="Output file path"
)
parser.add_argument(
"--context-left", type=int, default=256,
help="Left attention context size (default: 256)"
)
parser.add_argument(
"--context-right", type=int, default=256,
help="Right attention context size (default: 256)"
)
parser.add_argument(
"--include-confidence", action="store_true", default=True,
help="Include confidence scores"
)
parser.add_argument(
"--no-confidence", dest="include_confidence", action="store_false",
help="Exclude confidence scores"
)
args = parser.parse_args()
# Validate input file
if not os.path.exists(args.audio_file):
print(f"Error: Audio file not found: {args.audio_file}")
sys.exit(1)
try:
transcribe_audio(
audio_path=args.audio_file,
timestamps=args.timestamps,
output_file=args.output,
context_left=args.context_left,
context_right=args.context_right,
include_confidence=args.include_confidence,
)
except Exception as e:
print(f"Error during transcription: {e}")
sys.exit(1)
if __name__ == "__main__":
main()
`
scriptPath := filepath.Join(p.envPath, "transcribe.py")
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe.py")
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
return fmt.Errorf("failed to write transcription script: %w", err)
}
@@ -684,7 +472,7 @@ func (p *ParakeetAdapter) transcribeBuffered(ctx context.Context, input interfac
func (p *ParakeetAdapter) buildParakeetArgs(input interfaces.AudioInput, params map[string]interface{}, tempDir string) ([]string, error) {
outputFile := filepath.Join(tempDir, "result.json")
scriptPath := filepath.Join(p.envPath, "transcribe.py")
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe.py")
args := []string{
"run", "--native-tls", "--project", p.envPath, "python", scriptPath,
input.FilePath,
@@ -771,181 +559,13 @@ func (p *ParakeetAdapter) parseResult(tempDir string, input interfaces.AudioInpu
// createBufferedScript creates the Python script for NeMo buffered inference
func (p *ParakeetAdapter) createBufferedScript() error {
scriptContent := `#!/usr/bin/env python3
"""
NVIDIA Parakeet buffered inference for long audio files.
Splits audio into chunks to avoid GPU memory issues.
"""
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/parakeet_transcribe_buffered.py")
if err != nil {
return fmt.Errorf("failed to read embedded transcribe_buffered.py: %w", err)
}
import argparse
import json
import sys
import os
import librosa
import soundfile as sf
import numpy as np
from pathlib import Path
import nemo.collections.asr as nemo_asr
def split_audio_file(audio_path, chunk_duration_secs=300):
"""Split audio file into chunks of specified duration."""
audio, sr = librosa.load(audio_path, sr=None, mono=True)
total_duration = len(audio) / sr
chunk_samples = int(chunk_duration_secs * sr)
chunks = []
for start_sample in range(0, len(audio), chunk_samples):
end_sample = min(start_sample + chunk_samples, len(audio))
chunk_audio = audio[start_sample:end_sample]
start_time = start_sample / sr
chunks.append({
'audio': chunk_audio,
'start_time': start_time,
'duration': len(chunk_audio) / sr
})
return chunks, sr
def transcribe_buffered(
audio_path: str,
output_file: str = None,
chunk_duration_secs: float = 300, # 5 minutes default
):
"""
Transcribe long audio by splitting into chunks and merging results.
"""
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(script_dir, "parakeet-tdt-0.6b-v3.nemo")
print(f"Loading NVIDIA Parakeet model from: {model_path}")
if not os.path.exists(model_path):
print(f"Error: Model not found at {model_path}")
sys.exit(1)
asr_model = nemo_asr.models.ASRModel.restore_from(model_path)
# Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU
# Uses change_decoding_strategy() to properly reconfigure the TDT decoder
from omegaconf import OmegaConf, open_dict
print("Disabling CUDA graphs in TDT decoder...")
dec_cfg = asr_model.cfg.decoding
# Add use_cuda_graph_decoder parameter to greedy config
with open_dict(dec_cfg.greedy):
dec_cfg.greedy['use_cuda_graph_decoder'] = False
# Apply the new decoding strategy (this rebuilds the decoder with our config)
asr_model.change_decoding_strategy(dec_cfg)
print("✓ CUDA graphs disabled successfully")
print(f"Splitting audio into {chunk_duration_secs}s chunks...")
chunks, sr = split_audio_file(audio_path, chunk_duration_secs)
print(f"Created {len(chunks)} chunks")
all_words = []
all_segments = []
full_text = []
for i, chunk_info in enumerate(chunks):
print(f"Transcribing chunk {i+1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...")
# Save chunk to temporary file
chunk_path = f"/tmp/chunk_{i}.wav"
sf.write(chunk_path, chunk_info['audio'], sr)
try:
# Transcribe chunk
output = asr_model.transcribe(
[chunk_path],
batch_size=1,
timestamps=True,
)
result_data = output[0]
chunk_text = result_data.text
full_text.append(chunk_text)
# Extract and adjust timestamps
if hasattr(result_data, 'timestamp') and result_data.timestamp:
chunk_words = result_data.timestamp.get("word", [])
chunk_segments = result_data.timestamp.get("segment", [])
# Adjust timestamps by chunk start time
for word in chunk_words:
word_copy = dict(word)
word_copy['start'] += chunk_info['start_time']
word_copy['end'] += chunk_info['start_time']
all_words.append(word_copy)
for segment in chunk_segments:
seg_copy = dict(segment)
seg_copy['start'] += chunk_info['start_time']
seg_copy['end'] += chunk_info['start_time']
all_segments.append(seg_copy)
print(f"Chunk {i+1} complete: {len(chunk_text)} characters")
finally:
# Clean up temp file
if os.path.exists(chunk_path):
os.remove(chunk_path)
final_text = " ".join(full_text)
print(f"Transcription complete: {len(final_text)} characters total")
output_data = {
"transcription": final_text,
"language": "en",
"word_timestamps": all_words,
"segment_timestamps": all_segments,
"audio_file": audio_path,
"model": "parakeet-tdt-0.6b-v3",
"buffered": True,
"chunk_duration_secs": chunk_duration_secs,
"num_chunks": len(chunks),
}
if output_file:
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(output_data, f, indent=2, ensure_ascii=False)
print(f"Results saved to: {output_file}")
else:
print(json.dumps(output_data, indent=2, ensure_ascii=False))
def main():
parser = argparse.ArgumentParser(
description="Transcribe long audio using NVIDIA Parakeet with chunking"
)
parser.add_argument("audio_file", help="Path to audio file")
parser.add_argument("--output", "-o", help="Output file path", required=True)
parser.add_argument(
"--chunk-len", type=float, default=300,
help="Chunk duration in seconds (default: 300 = 5 minutes)"
)
args = parser.parse_args()
if not os.path.exists(args.audio_file):
print(f"Error: Audio file not found: {args.audio_file}")
sys.exit(1)
transcribe_buffered(
audio_path=args.audio_file,
output_file=args.output,
chunk_duration_secs=args.chunk_len,
)
if __name__ == "__main__":
main()
`
scriptPath := filepath.Join(p.envPath, "transcribe_buffered.py")
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe_buffered.py")
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
return fmt.Errorf("failed to write buffered script: %w", err)
}
@@ -963,7 +583,7 @@ func (p *ParakeetAdapter) buildBufferedArgs(input interfaces.AudioInput, params
chunkDuration = thresholdStr
}
scriptPath := filepath.Join(p.envPath, "transcribe_buffered.py")
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe_buffered.py")
args := []string{
"run", "--native-tls", "--project", p.envPath, "python", scriptPath,
input.FilePath,

View File

@@ -0,0 +1,204 @@
#!/usr/bin/env python3
"""
NVIDIA Canary multilingual transcription and translation script.
"""
import argparse
import json
import sys
import os
from pathlib import Path
import nemo.collections.asr as nemo_asr
def transcribe_audio(
    audio_path: str,
    source_lang: str = "en",
    target_lang: str = "en",
    task: str = "transcribe",
    timestamps: bool = True,
    output_file: str = None,
    include_confidence: bool = True,
    preserve_formatting: bool = True,
):
    """
    Transcribe or translate audio using the NVIDIA Canary model.

    Args:
        audio_path: Path to the input audio file.
        source_lang: Language spoken in the audio.
        target_lang: Language of the output text (differs from
            source_lang only when translating).
        task: "transcribe" or "translate". Recorded in the output JSON;
            the model call itself is driven by the language pair.
        timestamps: Include word- and segment-level timestamps.
        output_file: If set, write the JSON result to this path;
            otherwise print it to stdout.
        include_confidence: Attach model confidence scores when the
            result object exposes them.
        preserve_formatting: Accepted for CLI compatibility; the
            transcribe() call does not currently expose a formatting
            toggle, so this flag has no effect here.

    Exits the process with status 1 when the model file is missing.
    """
    model_filename = "canary-1b-v2.nemo"

    # Locate the project root via VIRTUAL_ENV, which `uv run` sets to
    # <project>/.venv; the model file sits next to the venv.
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)
    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)
    if not os.path.exists(model_path):
        print(f"Error during transcription: Can't find {model_filename} in project root: {project_root}")
        sys.exit(1)

    print(f"Loading NVIDIA Canary model from: {model_path}")
    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    print(f"Processing: {audio_path}")
    print(f"Task: {task}")
    print(f"Source language: {source_lang}")
    print(f"Target language: {target_lang}")

    # The original code branched on task == "translate" here, but both
    # branches issued byte-identical transcribe() calls: the model infers
    # translation vs. transcription from the language pair, so a single
    # call per timestamp mode suffices.
    if timestamps:
        output = asr_model.transcribe(
            [audio_path],
            source_lang=source_lang,
            target_lang=target_lang,
            timestamps=True
        )
    else:
        output = asr_model.transcribe(
            [audio_path],
            source_lang=source_lang,
            target_lang=target_lang
        )

    result_data = output[0]
    text = result_data.text

    # Build the result payload, keeping the historical key order so the
    # serialized JSON matches previous versions of this script.
    output_data = {
        "transcription": text,
        "source_language": source_lang,
        "target_language": target_lang,
        "task": task,
    }
    if timestamps:
        output_data["word_timestamps"] = result_data.timestamp.get("word", [])
        output_data["segment_timestamps"] = result_data.timestamp.get("segment", [])
    output_data["audio_file"] = audio_path
    output_data["model"] = "canary-1b-v2"

    if timestamps:
        print(f"Result: {text}")
        # Add confidence scores if available
        if include_confidence and hasattr(result_data, 'confidence') and result_data.confidence:
            output_data["confidence"] = result_data.confidence

    # Emit the result exactly once (file or stdout) instead of
    # duplicating this block in every branch.
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"Results saved to: {output_file}")
    else:
        print(json.dumps(output_data, indent=2, ensure_ascii=False))
def main():
    """CLI entry point: parse arguments and run Canary transcription."""
    # Both language options accept the same set of Canary languages.
    lang_choices = ["en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh"]

    parser = argparse.ArgumentParser(
        description="Transcribe or translate audio using NVIDIA Canary model"
    )
    parser.add_argument("audio_file", help="Path to audio file")
    parser.add_argument(
        "--source-lang", default="en", choices=lang_choices,
        help="Source language (default: en)",
    )
    parser.add_argument(
        "--target-lang", default="en", choices=lang_choices,
        help="Target language (default: en)",
    )
    parser.add_argument(
        "--task", choices=["transcribe", "translate"], default="transcribe",
        help="Task to perform (default: transcribe)",
    )
    # Paired on/off flags sharing one destination: --no-* wins when both given.
    parser.add_argument(
        "--timestamps", action="store_true", default=True,
        help="Include word and segment level timestamps",
    )
    parser.add_argument(
        "--no-timestamps", dest="timestamps", action="store_false",
        help="Disable timestamps",
    )
    parser.add_argument("--output", "-o", help="Output file path")
    parser.add_argument(
        "--include-confidence", action="store_true", default=True,
        help="Include confidence scores",
    )
    parser.add_argument(
        "--no-confidence", dest="include_confidence", action="store_false",
        help="Exclude confidence scores",
    )
    parser.add_argument(
        "--preserve-formatting", action="store_true", default=True,
        help="Preserve punctuation and capitalization",
    )

    opts = parser.parse_args()

    # Fail fast when the input file does not exist.
    if not os.path.exists(opts.audio_file):
        print(f"Error: Audio file not found: {opts.audio_file}")
        sys.exit(1)

    try:
        transcribe_audio(
            audio_path=opts.audio_file,
            source_lang=opts.source_lang,
            target_lang=opts.target_lang,
            task=opts.task,
            timestamps=opts.timestamps,
            output_file=opts.output,
            include_confidence=opts.include_confidence,
            preserve_formatting=opts.preserve_formatting,
        )
    except Exception as e:
        print(f"Error during transcription: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,188 @@
#!/usr/bin/env python3
"""
NVIDIA Parakeet transcription script with timestamp support.
"""
import argparse
import json
import sys
import os
from pathlib import Path
import nemo.collections.asr as nemo_asr
def transcribe_audio(
    audio_path: str,
    timestamps: bool = True,
    output_file: str = None,
    context_left: int = 256,
    context_right: int = 256,
    include_confidence: bool = True,
):
    """Transcribe a single audio file with the NVIDIA Parakeet TDT model.

    Args:
        audio_path: Path to the input audio file.
        timestamps: When True, request word- and segment-level timestamps.
        output_file: Optional path; when given, results are written as JSON,
            otherwise they are printed to stdout.
        context_left: Left attention context size for long-form audio.
        context_right: Right attention context size for long-form audio.
        include_confidence: When True, attach confidence scores if the model
            exposes them on the result object.

    Exits the process with status 1 when the model file cannot be located.
    """
    model_filename = "parakeet-tdt-0.6b-v3.nemo"

    # Locate project root: derived from VIRTUAL_ENV, which `uv run` sets to <root>/.venv.
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)
    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)
    if not os.path.exists(model_path):
        print(f"Error during transcription: Can't find {model_filename} in project root: {project_root}")
        sys.exit(1)

    print(f"Loading NVIDIA Parakeet model from: {model_path}")
    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    # Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU.
    # change_decoding_strategy() properly rebuilds the TDT decoder with the new config.
    from omegaconf import open_dict

    print("Disabling CUDA graphs in TDT decoder...")
    dec_cfg = asr_model.cfg.decoding
    with open_dict(dec_cfg.greedy):
        dec_cfg.greedy['use_cuda_graph_decoder'] = False
    asr_model.change_decoding_strategy(dec_cfg)
    print("✓ CUDA graphs disabled successfully")

    # Switch to local attention for long-form audio when non-default context sizes are requested.
    if context_left != 256 or context_right != 256:
        print(f"Configuring attention context: left={context_left}, right={context_right}")
        try:
            asr_model.change_attention_model(
                self_attention_model="rel_pos_local_attn",
                att_context_size=[context_left, context_right]
            )
            print("Long-form audio mode enabled")
        except Exception as e:
            print(f"Warning: Failed to configure attention model: {e}")
            print("Continuing with default attention settings")

    print(f"Transcribing: {audio_path}")
    if timestamps:
        output = asr_model.transcribe([audio_path], timestamps=True)
        result_data = output[0]
        text = result_data.text
        word_timestamps = result_data.timestamp.get("word", [])
        segment_timestamps = result_data.timestamp.get("segment", [])
        print(f"Transcription: {text}")
        output_data = {
            "transcription": text,
            "language": "en",
            "word_timestamps": word_timestamps,
            "segment_timestamps": segment_timestamps,
            "audio_file": audio_path,
            "model": "parakeet-tdt-0.6b-v3",
            "context": {
                "left": context_left,
                "right": context_right
            }
        }
        # Confidence is only present for some decoding configurations.
        if include_confidence:
            if hasattr(result_data, 'confidence') and result_data.confidence:
                output_data["confidence"] = result_data.confidence
    else:
        # Simple transcription without timestamps.
        output = asr_model.transcribe([audio_path])
        output_data = {
            "transcription": output[0].text,
            "language": "en",
            "audio_file": audio_path,
            "model": "parakeet-tdt-0.6b-v3"
        }

    # Single save path for both branches (was duplicated in the original).
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"Results saved to: {output_file}")
    else:
        print(json.dumps(output_data, indent=2, ensure_ascii=False))
def main():
    """Parse command-line options and run Parakeet transcription on one file."""
    ap = argparse.ArgumentParser(
        description="Transcribe audio using NVIDIA Parakeet model"
    )
    ap.add_argument("audio_file", help="Path to audio file")
    ap.add_argument(
        "--timestamps", action="store_true", default=True,
        help="Include word and segment level timestamps"
    )
    ap.add_argument(
        "--no-timestamps", dest="timestamps", action="store_false",
        help="Disable timestamps"
    )
    ap.add_argument("--output", "-o", help="Output file path")
    ap.add_argument(
        "--context-left", type=int, default=256,
        help="Left attention context size (default: 256)"
    )
    ap.add_argument(
        "--context-right", type=int, default=256,
        help="Right attention context size (default: 256)"
    )
    ap.add_argument(
        "--include-confidence", action="store_true", default=True,
        help="Include confidence scores"
    )
    ap.add_argument(
        "--no-confidence", dest="include_confidence", action="store_false",
        help="Exclude confidence scores"
    )
    opts = ap.parse_args()

    # Refuse to start when the input path does not exist.
    if not os.path.exists(opts.audio_file):
        print(f"Error: Audio file not found: {opts.audio_file}")
        sys.exit(1)

    try:
        transcribe_audio(
            audio_path=opts.audio_file,
            timestamps=opts.timestamps,
            output_file=opts.output,
            context_left=opts.context_left,
            context_right=opts.context_right,
            include_confidence=opts.include_confidence,
        )
    except Exception as e:
        print(f"Error during transcription: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,182 @@
#!/usr/bin/env python3
"""
NVIDIA Parakeet buffered inference for long audio files.
Splits audio into chunks to avoid GPU memory issues.
"""
import argparse
import json
import sys
import os
import librosa
import soundfile as sf
import numpy as np
from pathlib import Path
import nemo.collections.asr as nemo_asr
def split_audio_file(audio_path, chunk_duration_secs=300):
    """Load an audio file and split it into fixed-duration mono chunks.

    Args:
        audio_path: Path to the input audio file.
        chunk_duration_secs: Target duration of each chunk in seconds; the
            final chunk may be shorter.

    Returns:
        A tuple (chunks, sr) where sr is the file's native sample rate and
        chunks is a list of dicts with keys 'audio' (sample array),
        'start_time' (seconds from file start) and 'duration' (seconds).
    """
    # sr=None preserves the native sample rate; mono=True downmixes channels.
    audio, sr = librosa.load(audio_path, sr=None, mono=True)
    chunk_samples = int(chunk_duration_secs * sr)

    chunks = []
    for start_sample in range(0, len(audio), chunk_samples):
        end_sample = min(start_sample + chunk_samples, len(audio))
        chunk_audio = audio[start_sample:end_sample]
        chunks.append({
            'audio': chunk_audio,
            'start_time': start_sample / sr,
            'duration': len(chunk_audio) / sr
        })
    return chunks, sr
def transcribe_buffered(
    audio_path: str,
    output_file: str = None,
    chunk_duration_secs: float = 300,  # 5 minutes default
):
    """Transcribe long audio by chunking it and merging per-chunk results.

    Each chunk is written to a temporary WAV file, transcribed with
    timestamps, and the word/segment timestamps are shifted by the chunk's
    start offset so the merged output uses absolute file times.

    Args:
        audio_path: Path to the input audio file.
        output_file: Optional JSON output path; when omitted, results are
            printed to stdout.
        chunk_duration_secs: Duration of each chunk in seconds.

    Exits the process with status 1 when the model file cannot be located.
    """
    import tempfile

    model_filename = "parakeet-tdt-0.6b-v3.nemo"

    # Locate project root: derived from VIRTUAL_ENV, which `uv run` sets to <root>/.venv.
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)
    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)
    if not os.path.exists(model_path):
        print(f"Error during transcription: Can't find {model_filename} in project root: {project_root}")
        sys.exit(1)

    print(f"Loading NVIDIA Parakeet model from: {model_path}")
    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    # Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU.
    # change_decoding_strategy() properly rebuilds the TDT decoder with the new config.
    from omegaconf import open_dict

    print("Disabling CUDA graphs in TDT decoder...")
    dec_cfg = asr_model.cfg.decoding
    with open_dict(dec_cfg.greedy):
        dec_cfg.greedy['use_cuda_graph_decoder'] = False
    asr_model.change_decoding_strategy(dec_cfg)
    print("✓ CUDA graphs disabled successfully")

    print(f"Splitting audio into {chunk_duration_secs}s chunks...")
    chunks, sr = split_audio_file(audio_path, chunk_duration_secs)
    print(f"Created {len(chunks)} chunks")

    all_words = []
    all_segments = []
    full_text = []
    for i, chunk_info in enumerate(chunks):
        print(f"Transcribing chunk {i+1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...")
        # Write the chunk to a unique temporary file (fixes the original
        # hard-coded /tmp/chunk_{i}.wav, which was non-portable and collided
        # across concurrent runs) and always clean it up afterwards.
        fd, chunk_path = tempfile.mkstemp(prefix=f"chunk_{i}_", suffix=".wav")
        os.close(fd)
        sf.write(chunk_path, chunk_info['audio'], sr)
        try:
            output = asr_model.transcribe(
                [chunk_path],
                batch_size=1,
                timestamps=True,
            )
            result_data = output[0]
            chunk_text = result_data.text
            full_text.append(chunk_text)
            # Shift per-chunk timestamps into absolute file time.
            if hasattr(result_data, 'timestamp') and result_data.timestamp:
                for word in result_data.timestamp.get("word", []):
                    word_copy = dict(word)
                    word_copy['start'] += chunk_info['start_time']
                    word_copy['end'] += chunk_info['start_time']
                    all_words.append(word_copy)
                for segment in result_data.timestamp.get("segment", []):
                    seg_copy = dict(segment)
                    seg_copy['start'] += chunk_info['start_time']
                    seg_copy['end'] += chunk_info['start_time']
                    all_segments.append(seg_copy)
            print(f"Chunk {i+1} complete: {len(chunk_text)} characters")
        finally:
            # Clean up temp file
            if os.path.exists(chunk_path):
                os.remove(chunk_path)

    final_text = " ".join(full_text)
    print(f"Transcription complete: {len(final_text)} characters total")
    output_data = {
        "transcription": final_text,
        "language": "en",
        "word_timestamps": all_words,
        "segment_timestamps": all_segments,
        "audio_file": audio_path,
        "model": "parakeet-tdt-0.6b-v3",
        "buffered": True,
        "chunk_duration_secs": chunk_duration_secs,
        "num_chunks": len(chunks),
    }
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"Results saved to: {output_file}")
    else:
        print(json.dumps(output_data, indent=2, ensure_ascii=False))
def main():
    """Parse CLI options and run buffered (chunked) Parakeet transcription."""
    ap = argparse.ArgumentParser(
        description="Transcribe long audio using NVIDIA Parakeet with chunking"
    )
    ap.add_argument("audio_file", help="Path to audio file")
    ap.add_argument("--output", "-o", help="Output file path", required=True)
    ap.add_argument(
        "--chunk-len", type=float, default=300,
        help="Chunk duration in seconds (default: 300 = 5 minutes)"
    )
    opts = ap.parse_args()

    # Refuse to start when the input path is missing.
    if not os.path.exists(opts.audio_file):
        print(f"Error: Audio file not found: {opts.audio_file}")
        sys.exit(1)

    transcribe_buffered(
        audio_path=opts.audio_file,
        output_file=opts.output,
        chunk_duration_secs=opts.chunk_len,
    )


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,47 @@
[project]
name = "nvidia-transcription"
version = "0.1.0"
description = "Audio transcription and diarization using NVIDIA models (Canary, Parakeet, Sortformer)"
requires-python = ">=3.11"
dependencies = [
"nemo-toolkit[asr]",
"torch",
"torchaudio",
"librosa",
"soundfile",
"ml-dtypes>=0.3.1,<0.5.0",
"onnx>=1.15.0,<1.18.0",
# "pyannote.audio" # needed for sortformer or no?
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-mock>=3.12.0",
]
[tool.uv.sources]
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
torch = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
torchaudio = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
triton = [
{ index = "pytorch", marker = "sys_platform == 'linux'" }
]
[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu126"
explicit = true
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

View File

@@ -0,0 +1,335 @@
#!/usr/bin/env python3
"""
NVIDIA Sortformer speaker diarization script.
Uses diar_streaming_sortformer_4spk-v2 for optimized 4-speaker diarization.
"""
import argparse
import json
import sys
import os
from pathlib import Path
import torch
try:
from nemo.collections.asr.models import SortformerEncLabelModel
except ImportError:
print("Error: NeMo not found. Please install nemo_toolkit[asr]")
sys.exit(1)
def diarize_audio(
    audio_path: str,
    output_file: str,
    batch_size: int = 1,
    device: str = None,
    max_speakers: int = 4,
    output_format: str = "rttm",
    streaming_mode: bool = False,
    chunk_length_s: float = 30.0,
):
    """
    Perform speaker diarization using NVIDIA's Sortformer model.

    Loads diar_streaming_sortformer_4spk-v2.nemo from the project root
    (derived from VIRTUAL_ENV) and writes results through save_results().

    Args:
        audio_path: Path to the input audio file.
        output_file: Destination path for the results.
        batch_size: Batch size passed to the model's diarize() call.
        device: "cuda", "cpu", or None/"auto" to auto-detect.
        max_speakers: Printed for information only; not passed to the model
            (the checkpoint is a fixed 4-speaker model).
        output_format: "rttm" or "json"; forwarded to save_results().
        streaming_mode: Currently a no-op — both branches run the same
            standard diarize() call (streaming not yet implemented).
        chunk_length_s: Chunk length for the (unimplemented) streaming mode.

    Exits the process with status 1 on any load or inference failure.
    """
    # Resolve the compute device when not explicitly requested.
    if device is None or device == "auto":
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
    print(f"Using device: {device}")
    print(f"Loading NVIDIA Sortformer diarization model...")
    # Determine model path
    model_filename = "diar_streaming_sortformer_4spk-v2.nemo"
    model_path = None
    # Locate project root: derived from VIRTUAL_ENV, which is set by `uv run` to path/.venv
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)
    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)
    try:
        if not os.path.exists(model_path):
            print(f"Error: Model file not found: {model_filename} in project root: {project_root}")
            sys.exit(1)
        # Load from local file
        print(f"Loading model from path: {model_path}")
        diar_model = SortformerEncLabelModel.restore_from(
            restore_path=model_path,
            map_location=device,
            strict=False,  # tolerate minor checkpoint/config mismatches
        )
        # Switch to inference mode
        diar_model.eval()
        print("Model loaded successfully")
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)
    print(f"Processing audio file: {audio_path}")
    # Verify audio file exists
    if not os.path.exists(audio_path):
        print(f"Error: Audio file not found: {audio_path}")
        sys.exit(1)
    try:
        # Run diarization
        print(f"Running diarization with batch_size={batch_size}, max_speakers={max_speakers}")
        if streaming_mode:
            print(f"Using streaming mode with chunk_length_s={chunk_length_s}")
            # Note: Streaming mode implementation would go here
            # For now, use standard diarization
            predicted_segments = diar_model.diarize(audio=audio_path, batch_size=batch_size)
        else:
            predicted_segments = diar_model.diarize(audio=audio_path, batch_size=batch_size)
        print(f"Diarization completed. Found segments: {len(predicted_segments)}")
        # Process and save results
        save_results(predicted_segments, output_file, audio_path, output_format)
    except Exception as e:
        print(f"Error during diarization: {e}")
        sys.exit(1)
def save_results(segments, output_file: str, audio_path: str, output_format: str):
    """
    Save diarization results to output file.

    Dispatches on output_format: "rttm" writes RTTM, anything else falls
    back to JSON. (Removed an unused Path(output_file) local from the
    original.)
    """
    if output_format == "rttm":
        save_rttm_format(segments, output_file, audio_path)
    else:
        save_json_format(segments, output_file, audio_path)
def save_json_format(segments, output_file: str, audio_path: str):
    """Serialize diarization segments to a JSON file with summary statistics.

    Accepts segments in several shapes (strings, pyannote-like objects,
    sequences, dicts, or arbitrary objects with start/end attributes) and
    normalizes each to {start, end, speaker, duration, confidence}.
    """
    # The model sometimes wraps all entries in a single nested list.
    if len(segments) == 1 and isinstance(segments[0], list):
        segments = segments[0]

    def _to_entry(segment, idx):
        """Normalize one segment of any supported shape; None when unparseable."""
        if isinstance(segment, str):
            # String format: "start end speaker_id"
            parts = segment.strip().split()
            if len(parts) < 3:
                print(f"Warning: Invalid string segment format: {segment}")
                return None
            begin, finish = float(parts[0]), float(parts[1])
            return {
                "start": begin,
                "end": finish,
                "speaker": str(parts[2]),
                "duration": finish - begin,
                "confidence": 1.0,
            }
        if hasattr(segment, 'start') and hasattr(segment, 'end') and hasattr(segment, 'label'):
            # Standard pyannote-like object
            return {
                "start": float(segment.start),
                "end": float(segment.end),
                "speaker": str(segment.label),
                "duration": float(segment.end - segment.start),
                "confidence": getattr(segment, 'confidence', 1.0),
            }
        if isinstance(segment, (list, tuple)) and len(segment) >= 3:
            # Sequence format: [start, end, speaker]
            return {
                "start": float(segment[0]),
                "end": float(segment[1]),
                "speaker": str(segment[2]),
                "duration": float(segment[1] - segment[0]),
                "confidence": 1.0,
            }
        if isinstance(segment, dict):
            return {
                "start": float(segment.get('start', 0)),
                "end": float(segment.get('end', 0)),
                "speaker": str(segment.get('speaker', segment.get('label', f'speaker_{idx}'))),
                "duration": float(segment.get('end', 0) - segment.get('start', 0)),
                "confidence": float(segment.get('confidence', 1.0)),
            }
        # Last resort: probe attributes dynamically.
        return {
            "start": float(getattr(segment, 'start', 0)),
            "end": float(getattr(segment, 'end', 0)),
            "speaker": str(getattr(segment, 'label', getattr(segment, 'speaker', f'speaker_{idx}'))),
            "duration": float(getattr(segment, 'end', 0) - getattr(segment, 'start', 0)),
            "confidence": float(getattr(segment, 'confidence', 1.0)),
        }

    results = {
        "audio_file": audio_path,
        "model": "nvidia/diar_streaming_sortformer_4spk-v2",
        "segments": [],
    }
    speakers = set()
    for i, segment in enumerate(segments):
        try:
            entry = _to_entry(segment, i)
            if entry is None:
                continue
            results["segments"].append(entry)
            speakers.add(entry["speaker"])
        except Exception as e:
            print(f"Warning: Could not process segment {i}: {e}")
            print(f"Segment: {segment}")

    # Chronological order for downstream consumers.
    if results["segments"]:
        results["segments"].sort(key=lambda x: x["start"])

    # Summary statistics.
    results["speakers"] = sorted(speakers)
    results["speaker_count"] = len(speakers)
    results["total_segments"] = len(results["segments"])
    results["total_duration"] = max(seg["end"] for seg in results["segments"]) if results["segments"] else 0

    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)

    print(f"Results saved to: {output_file}")
    print(f"Found {len(speakers)} speakers: {', '.join(sorted(speakers))}")
def save_rttm_format(segments, output_file: str, audio_path: str):
    """Write diarization segments as RTTM SPEAKER lines.

    Accepts the same segment shapes as save_json_format (strings, objects,
    sequences, dicts) and emits one RTTM line per parseable segment.
    """
    audio_filename = Path(audio_path).stem
    speakers = set()

    # The model sometimes wraps all entries in a single nested list.
    if len(segments) == 1 and isinstance(segments[0], list):
        segments = segments[0]

    def _bounds(segment, idx):
        """Extract (start, end, speaker) from any supported shape; None when unparseable."""
        if isinstance(segment, str):
            # String format: "start end speaker_id"
            parts = segment.strip().split()
            if len(parts) < 3:
                print(f"Warning: Invalid string segment format: {segment}")
                return None
            return float(parts[0]), float(parts[1]), str(parts[2])
        if hasattr(segment, 'start') and hasattr(segment, 'end') and hasattr(segment, 'label'):
            # Standard pyannote-like object
            return float(segment.start), float(segment.end), str(segment.label)
        if isinstance(segment, (list, tuple)) and len(segment) >= 3:
            # Sequence format: [start, end, speaker]
            return float(segment[0]), float(segment[1]), str(segment[2])
        if isinstance(segment, dict):
            return (
                float(segment.get('start', 0)),
                float(segment.get('end', 0)),
                str(segment.get('speaker', segment.get('label', f'speaker_{idx}'))),
            )
        # Last resort: probe attributes dynamically.
        return (
            float(getattr(segment, 'start', 0)),
            float(getattr(segment, 'end', 0)),
            str(getattr(segment, 'label', getattr(segment, 'speaker', f'speaker_{idx}'))),
        )

    with open(output_file, "w") as f:
        for i, segment in enumerate(segments):
            try:
                parsed = _bounds(segment, i)
                if parsed is None:
                    continue
                start, end, speaker = parsed
                duration = end - start
                speakers.add(speaker)
                # RTTM format: SPEAKER <filename> <channel> <start> <duration> <NA> <NA> <speaker_id> <NA> <NA>
                f.write(f"SPEAKER {audio_filename} 1 {start:.3f} {duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n")
            except Exception as e:
                print(f"Warning: Could not process segment {i} for RTTM: {e}")
                print(f"Segment: {segment}")

    print(f"RTTM results saved to: {output_file}")
    print(f"Found {len(speakers)} speakers: {', '.join(sorted(speakers))}")
def main():
    """CLI entry point: parse arguments, resolve the output format from the
    file extension when not given explicitly, and run Sortformer diarization."""
    parser = argparse.ArgumentParser(
        description="Speaker diarization using NVIDIA Sortformer model (local model only)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Basic diarization with JSON output
python sortformer_diarize.py samples/sample.wav output.json
# Generate RTTM format output
python sortformer_diarize.py samples/sample.wav output.rttm
# Specify device and batch size
python sortformer_diarize.py --device cuda --batch-size 2 samples/sample.wav output.json
Note: This script requires diar_streaming_sortformer_4spk-v2.nemo to be in the same directory.
""",
    )
    parser.add_argument("audio_file", help="Path to input audio file (WAV, FLAC, etc.)")
    parser.add_argument("output_file", help="Path to output file (.json for JSON format, .rttm for RTTM format)")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size for processing (default: 1)")
    parser.add_argument("--device", choices=["cuda", "cpu", "auto"], default="auto", help="Device to use for inference (default: auto-detect)")
    parser.add_argument("--max-speakers", type=int, default=4, help="Maximum number of speakers (default: 4, optimized for this model)")
    parser.add_argument("--output-format", choices=["json", "rttm"], help="Output format (auto-detected from file extension if not specified)")
    parser.add_argument("--streaming", action="store_true", help="Enable streaming mode")
    parser.add_argument("--chunk-length-s", type=float, default=30.0, help="Chunk length in seconds for streaming mode (default: 30.0)")
    args = parser.parse_args()
    # Validate inputs
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file not found: {args.audio_file}")
        sys.exit(1)
    # Auto-detect output format from file extension if not specified:
    # a .rttm extension selects RTTM, everything else defaults to JSON.
    if args.output_format is None:
        if args.output_file.lower().endswith('.rttm'):
            output_format = "rttm"
        else:
            output_format = "json"
    else:
        output_format = args.output_format
    # Create output directory if it doesn't exist
    output_dir = Path(args.output_file).parent
    output_dir.mkdir(parents=True, exist_ok=True)
    # "auto" is translated to None so diarize_audio performs its own detection.
    device = None if args.device == "auto" else args.device
    # Run diarization
    diarize_audio(
        audio_path=args.audio_file,
        output_file=args.output_file,
        batch_size=args.batch_size,
        device=device,
        max_speakers=args.max_speakers,
        output_format=output_format,
        streaming_mode=args.streaming,
        chunk_length_s=args.chunk_length_s,
    )


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,266 @@
#!/usr/bin/env python3
"""
PyAnnote speaker diarization script.
Processes audio files to identify and separate different speakers.
"""
import argparse
import json
import sys
import os
from pathlib import Path
from pyannote.audio import Pipeline
import torch
# Fix for PyTorch 2.6+ which defaults weights_only=True
# We need to allowlist PyAnnote's custom classes
try:
from pyannote.audio.core.task import Specifications, Problem, Resolution
if hasattr(torch.serialization, "add_safe_globals"):
torch.serialization.add_safe_globals([Specifications, Problem, Resolution])
except ImportError:
pass
except Exception as e:
print(f"Warning: Could not add safe globals: {e}")
def diarize_audio(
    audio_path: str,
    output_file: str,
    hf_token: str,
    model: str = "pyannote/speaker-diarization-community-1",
    min_speakers: int = None,
    max_speakers: int = None,
    output_format: str = "rttm",
    device: str = "auto"
):
    """
    Perform speaker diarization on audio file using PyAnnote.

    Args:
        audio_path: Path to the input audio file.
        output_file: Destination path for RTTM or JSON results.
        hf_token: Hugging Face access token used to fetch the pipeline.
        model: PyAnnote pipeline identifier.
        min_speakers: Optional lower bound on the number of speakers.
        max_speakers: Optional upper bound on the number of speakers.
        output_format: "rttm" writes RTTM; anything else writes JSON.
        device: Requested device ("cpu", "cuda" or "auto").

    Exits the process with status 1 on pipeline-load or diarization failure.
    """
    print(f"Loading PyAnnote speaker diarization pipeline: {model}")
    try:
        # Initialize the diarization pipeline
        pipeline = Pipeline.from_pretrained(
            model,
            token=hf_token
        )
        # NOTE(review): CUDA is used whenever available, even when device="cpu"
        # was requested — confirm whether "cpu" should force CPU execution.
        try:
            if torch.cuda.is_available():
                pipeline = pipeline.to(torch.device("cuda"))
                print("Using CUDA for diarization")
            elif device == "cuda":
                print("CUDA requested but not available, falling back to CPU")
            else:
                print("CUDA not available, using CPU")
        except ImportError:
            print("PyTorch not available for CUDA, using CPU")
        except Exception as e:
            print(f"Error moving to device: {e}, using CPU")
        print("Pipeline loaded successfully")
    except Exception as e:
        print(f"Error loading pipeline: {e}")
        print("Make sure you have a valid Hugging Face token and have accepted the model's license")
        sys.exit(1)

    print(f"Processing audio file: {audio_path}")
    try:
        # Apply speaker-count constraints only when explicitly provided.
        diarization_params = {}
        if min_speakers is not None:
            diarization_params["min_speakers"] = min_speakers
        if max_speakers is not None:
            diarization_params["max_speakers"] = max_speakers
        if diarization_params:
            print(f"Using speaker constraints: {diarization_params}")
            diarization = pipeline(audio_path, **diarization_params)
        else:
            print("Using automatic speaker detection")
            diarization = pipeline(audio_path)

        print(f"Diarization completed. Saving results to: {output_file}")
        if output_format == "rttm":
            # Save the diarization output to RTTM format
            with open(output_file, "w") as rttm:
                diarization.write_rttm(rttm)
        else:
            # Save as JSON format
            save_json_format(diarization, output_file, audio_path)

        # Summarize speakers and total speech time.
        # PyAnnote 4.x returns an object with a .speaker_diarization attribute;
        # older versions expose .itertracks directly.
        speakers = set()
        total_speech_time = 0.0
        if hasattr(diarization, "speaker_diarization"):
            for turn, speaker in diarization.speaker_diarization:
                speakers.add(speaker)
                total_speech_time += turn.duration
        elif hasattr(diarization, "itertracks"):
            for segment, track, speaker in diarization.itertracks(yield_label=True):
                speakers.add(speaker)
                total_speech_time += segment.duration
        else:
            # FIX: the original fallback called diarization.itertracks() even
            # though the hasattr check above had just proven the attribute is
            # absent — it could only raise AttributeError. Raise a clear error
            # instead; it is caught below, preserving the exit(1) behavior.
            raise TypeError(f"Unsupported diarization result type: {type(diarization).__name__}")

        print(f"\nDiarization Summary:")
        print(f"  Speakers detected: {len(speakers)}")
        print(f"  Speaker labels: {sorted(speakers)}")
        print(f"  Total speech time: {total_speech_time:.2f} seconds")
        print(f"  Output file saved: {output_file}")
    except Exception as e:
        print(f"Error during diarization: {e}")
        sys.exit(1)
def save_json_format(diarization, output_file: str, audio_path: str):
    """Write diarization results to output_file as JSON with summary stats.

    Supports both PyAnnote 4.x results (a .speaker_diarization attribute
    yielding (turn, speaker) pairs) and older Annotation objects exposing
    .itertracks(yield_label=True).
    """
    segments = []
    speakers = set()

    def _collect(begin, finish, speaker, span):
        """Record one normalized segment and its speaker label."""
        segments.append({
            "start": begin,
            "end": finish,
            "speaker": speaker,
            "confidence": 1.0,
            "duration": span
        })
        speakers.add(speaker)

    if hasattr(diarization, "speaker_diarization"):
        # PyAnnote 4.x
        for turn, speaker in diarization.speaker_diarization:
            _collect(turn.start, turn.end, speaker, turn.duration)
    elif hasattr(diarization, "itertracks"):
        # Older versions
        for segment, track, speaker in diarization.itertracks(yield_label=True):
            _collect(segment.start, segment.end, speaker, segment.duration)

    # Chronological order for downstream consumers.
    segments.sort(key=lambda x: x["start"])

    results = {
        "audio_file": audio_path,
        "model": "pyannote/speaker-diarization-community-1",
        "segments": segments,
        "speakers": sorted(speakers),
        "speaker_count": len(speakers),
        "total_duration": max(seg["end"] for seg in segments) if segments else 0,
        "processing_info": {
            "total_segments": len(segments),
            "total_speech_time": sum(seg["duration"] for seg in segments)
        }
    }
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)
def main():
    """CLI entry point: validate arguments and run PyAnnote diarization."""
    ap = argparse.ArgumentParser(
        description="Perform speaker diarization using PyAnnote.audio"
    )
    ap.add_argument("audio_file", help="Path to audio file")
    ap.add_argument("--output", "-o", required=True, help="Output file path")
    ap.add_argument("--hf-token", required=True, help="Hugging Face access token")
    ap.add_argument("--model", default="pyannote/speaker-diarization-community-1",
                    help="PyAnnote model to use")
    ap.add_argument("--min-speakers", type=int, help="Minimum number of speakers")
    ap.add_argument("--max-speakers", type=int, help="Maximum number of speakers")
    ap.add_argument("--output-format", choices=["rttm", "json"], default="rttm",
                    help="Output format")
    ap.add_argument("--device", choices=["cpu", "cuda", "auto"], default="auto",
                    help="Device to use for computation")
    opts = ap.parse_args()

    # Input file must exist before any model loading happens.
    if not os.path.exists(opts.audio_file):
        print(f"Error: Audio file not found: {opts.audio_file}")
        sys.exit(1)

    # Speaker-count constraints must be positive and consistent.
    if opts.min_speakers is not None and opts.min_speakers < 1:
        print("Error: min_speakers must be at least 1")
        sys.exit(1)
    if opts.max_speakers is not None and opts.max_speakers < 1:
        print("Error: max_speakers must be at least 1")
        sys.exit(1)
    if (opts.min_speakers is not None and opts.max_speakers is not None and
            opts.min_speakers > opts.max_speakers):
        print("Error: min_speakers cannot be greater than max_speakers")
        sys.exit(1)

    # Create the output directory up front so the writers cannot fail on it.
    Path(opts.output).parent.mkdir(parents=True, exist_ok=True)

    try:
        diarize_audio(
            audio_path=opts.audio_file,
            output_file=opts.output,
            hf_token=opts.hf_token,
            model=opts.model,
            min_speakers=opts.min_speakers,
            max_speakers=opts.max_speakers,
            output_format=opts.output_format,
            device=opts.device
        )
    except Exception as e:
        print(f"Error during diarization: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,39 @@
[project]
name = "pyannote-diarization"
version = "0.1.0"
description = "Audio diarization using PyAnnote"
requires-python = ">=3.10"
dependencies = [
"torch>=2.5.0",
"torchaudio>=2.5.0",
"huggingface-hub>=0.28.1",
"pyannote.audio==4.0.2"
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-mock>=3.12.0",
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
torchaudio = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu126"
explicit = true
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

View File

@@ -2,6 +2,7 @@ package adapters
import (
"context"
"embed"
"encoding/json"
"fmt"
"os"
@@ -15,6 +16,9 @@ import (
"scriberr/pkg/logger"
)
//go:embed py/pyannote/*
var pyannoteScripts embed.FS
const OutputFormatJSON = "json"
// PyAnnoteAdapter implements the DiarizationAdapter interface for PyAnnote
@@ -182,7 +186,7 @@ func (p *PyAnnoteAdapter) PrepareEnvironment(ctx context.Context) error {
if CheckEnvironmentReady(p.envPath, "from pyannote.audio import Pipeline") {
logger.Info("PyAnnote already available in environment")
// Still ensure script exists
if err := p.createDiarizationScript(); err != nil {
if err := p.copyDiarizationScript(); err != nil {
return fmt.Errorf("failed to create diarization script: %w", err)
}
p.initialized = true
@@ -195,7 +199,7 @@ func (p *PyAnnoteAdapter) PrepareEnvironment(ctx context.Context) error {
}
// Always ensure diarization script exists
if err := p.createDiarizationScript(); err != nil {
if err := p.copyDiarizationScript(); err != nil {
return fmt.Errorf("failed to create diarization script: %w", err)
}
@@ -216,45 +220,23 @@ func (p *PyAnnoteAdapter) setupPyAnnoteEnvironment() error {
return fmt.Errorf("failed to create pyannote directory: %w", err)
}
// Create pyproject.toml with configurable PyTorch CUDA version
// Note: We explicitly pin torch and torchaudio to 2.1.2 to ensure compatibility with pyannote.audio 3.1
// Newer versions of torchaudio (2.2+) removed AudioMetaData which causes crashes
pyprojectContent := fmt.Sprintf(`[project]
name = "pyannote-diarization"
version = "0.1.0"
description = "Audio diarization using PyAnnote"
requires-python = ">=3.10"
dependencies = [
"torch>=2.5.0",
"torchaudio>=2.5.0",
"huggingface-hub>=0.28.1",
"pyannote.audio==4.0.2"
]
// Read pyproject.toml for PyAnnote
pyprojectContent, err := pyannoteScripts.ReadFile("py/pyannote/pyproject.toml")
if err != nil {
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
}
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
torchaudio = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
// The static file contains the default cu126 URL
contentStr := strings.Replace(
string(pyprojectContent),
"https://download.pytorch.org/whl/cu126",
GetPyTorchWheelURL(),
1,
)
[[tool.uv.index]]
name = "pytorch"
url = "%s"
explicit = true
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
`, GetPyTorchWheelURL())
pyprojectPath := filepath.Join(p.envPath, "pyproject.toml")
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
return fmt.Errorf("failed to write pyproject.toml: %w", err)
}
@@ -270,289 +252,20 @@ explicit = true
return nil
}
// createDiarizationScript creates the Python script for PyAnnote diarization
func (p *PyAnnoteAdapter) createDiarizationScript() error {
// copyDiarizationScript creates the Python script for PyAnnote diarization
func (p *PyAnnoteAdapter) copyDiarizationScript() error {
// Ensure the directory exists first
if err := os.MkdirAll(p.envPath, 0755); err != nil {
return fmt.Errorf("failed to create pyannote directory: %w", err)
}
scriptContent, err := pyannoteScripts.ReadFile("py/pyannote/pyannote_diarize.py")
if err != nil {
return fmt.Errorf("failed to read embedded pyannote_diarize.py: %w", err)
}
scriptPath := filepath.Join(p.envPath, "pyannote_diarize.py")
// Always recreate the script to ensure it's up to date with the adapter code
// if _, err := os.Stat(scriptPath); err == nil {
// return nil
// }
scriptContent := `#!/usr/bin/env python3
"""
PyAnnote speaker diarization script.
Processes audio files to identify and separate different speakers.
"""
import argparse
import json
import sys
import os
from pathlib import Path
from pyannote.audio import Pipeline
import torch
# Fix for PyTorch 2.6+ which defaults weights_only=True
# We need to allowlist PyAnnote's custom classes
try:
from pyannote.audio.core.task import Specifications, Problem, Resolution
if hasattr(torch.serialization, "add_safe_globals"):
torch.serialization.add_safe_globals([Specifications, Problem, Resolution])
except ImportError:
pass
except Exception as e:
print(f"Warning: Could not add safe globals: {e}")
def diarize_audio(
audio_path: str,
output_file: str,
hf_token: str,
model: str = "pyannote/speaker-diarization-community-1",
min_speakers: int = None,
max_speakers: int = None,
output_format: str = "rttm",
device: str = "auto"
):
"""
Perform speaker diarization on audio file using PyAnnote.
"""
print(f"Loading PyAnnote speaker diarization pipeline: {model}")
try:
# Initialize the diarization pipeline
pipeline = Pipeline.from_pretrained(
model,
token=hf_token
)
# Move to specified device
# if device == "auto" or device == "cuda":
try:
if torch.cuda.is_available():
pipeline = pipeline.to(torch.device("cuda"))
print("Using CUDA for diarization")
elif device == "cuda":
print("CUDA requested but not available, falling back to CPU")
else:
print("CUDA not available, using CPU")
except ImportError:
print("PyTorch not available for CUDA, using CPU")
except Exception as e:
print(f"Error moving to device: {e}, using CPU")
print("Pipeline loaded successfully")
except Exception as e:
print(f"Error loading pipeline: {e}")
print("Make sure you have a valid Hugging Face token and have accepted the model's license")
sys.exit(1)
print(f"Processing audio file: {audio_path}")
try:
# Run diarization
diarization_params = {}
if min_speakers is not None:
diarization_params["min_speakers"] = min_speakers
if max_speakers is not None:
diarization_params["max_speakers"] = max_speakers
if diarization_params:
print(f"Using speaker constraints: {diarization_params}")
diarization = pipeline(audio_path, **diarization_params)
else:
print("Using automatic speaker detection")
diarization = pipeline(audio_path)
print(f"Diarization completed. Saving results to: {output_file}")
if output_format == "rttm":
# Save the diarization output to RTTM format
with open(output_file, "w") as rttm:
diarization.write_rttm(rttm)
else:
# Save as JSON format
save_json_format(diarization, output_file, audio_path)
# Print summary
speakers = set()
total_speech_time = 0.0
# Iterate over speaker diarization
# PyAnnote 4.x returns a DiarizeOutput object with a speaker_diarization attribute
if hasattr(diarization, "speaker_diarization"):
for turn, speaker in diarization.speaker_diarization:
speakers.add(speaker)
total_speech_time += turn.duration
elif hasattr(diarization, "itertracks"):
# Fallback for older versions
for segment, track, speaker in diarization.itertracks(yield_label=True):
speakers.add(speaker)
total_speech_time += segment.duration
else:
# Try iterating directly (some versions return Annotation directly)
for segment, track, speaker in diarization.itertracks(yield_label=True):
speakers.add(speaker)
total_speech_time += segment.duration
print(f"\nDiarization Summary:")
print(f" Speakers detected: {len(speakers)}")
print(f" Speaker labels: {sorted(speakers)}")
print(f" Total speech time: {total_speech_time:.2f} seconds")
print(f" Output file saved: {output_file}")
except Exception as e:
print(f"Error during diarization: {e}")
sys.exit(1)
def save_json_format(diarization, output_file: str, audio_path: str):
"""Save diarization results in JSON format."""
segments = []
speakers = set()
# PyAnnote 4.x
if hasattr(diarization, "speaker_diarization"):
for turn, speaker in diarization.speaker_diarization:
segments.append({
"start": turn.start,
"end": turn.end,
"speaker": speaker,
"confidence": 1.0,
"duration": turn.duration
})
speakers.add(speaker)
# Older versions
elif hasattr(diarization, "itertracks"):
for segment, track, speaker in diarization.itertracks(yield_label=True):
segments.append({
"start": segment.start,
"end": segment.end,
"speaker": speaker,
"confidence": 1.0,
"duration": segment.duration
})
speakers.add(speaker)
# Sort segments by start time
segments.sort(key=lambda x: x["start"])
results = {
"audio_file": audio_path,
"model": "pyannote/speaker-diarization-community-1",
"segments": segments,
"speakers": sorted(speakers),
"speaker_count": len(speakers),
"total_duration": max(seg["end"] for seg in segments) if segments else 0,
"processing_info": {
"total_segments": len(segments),
"total_speech_time": sum(seg["duration"] for seg in segments)
}
}
with open(output_file, "w") as f:
json.dump(results, f, indent=2)
def main():
parser = argparse.ArgumentParser(
description="Perform speaker diarization using PyAnnote.audio"
)
parser.add_argument(
"audio_file",
help="Path to audio file"
)
parser.add_argument(
"--output", "-o",
required=True,
help="Output file path"
)
parser.add_argument(
"--hf-token",
required=True,
help="Hugging Face access token"
)
parser.add_argument(
"--model",
default="pyannote/speaker-diarization-community-1",
help="PyAnnote model to use"
)
parser.add_argument(
"--min-speakers",
type=int,
help="Minimum number of speakers"
)
parser.add_argument(
"--max-speakers",
type=int,
help="Maximum number of speakers"
)
parser.add_argument(
"--output-format",
choices=["rttm", "json"],
default="rttm",
help="Output format"
)
parser.add_argument(
"--device",
choices=["cpu", "cuda", "auto"],
default="auto",
help="Device to use for computation"
)
args = parser.parse_args()
# Validate input file
if not os.path.exists(args.audio_file):
print(f"Error: Audio file not found: {args.audio_file}")
sys.exit(1)
# Validate speaker constraints
if args.min_speakers is not None and args.min_speakers < 1:
print("Error: min_speakers must be at least 1")
sys.exit(1)
if args.max_speakers is not None and args.max_speakers < 1:
print("Error: max_speakers must be at least 1")
sys.exit(1)
if (args.min_speakers is not None and args.max_speakers is not None and
args.min_speakers > args.max_speakers):
print("Error: min_speakers cannot be greater than max_speakers")
sys.exit(1)
# Create output directory if it doesn't exist
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
try:
diarize_audio(
audio_path=args.audio_file,
output_file=args.output,
hf_token=args.hf_token,
model=args.model,
min_speakers=args.min_speakers,
max_speakers=args.max_speakers,
output_format=args.output_format,
device=args.device
)
except Exception as e:
print(f"Error during diarization: {e}")
sys.exit(1)
if __name__ == "__main__":
main()
`
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
return fmt.Errorf("failed to write diarization script: %w", err)
}

View File

@@ -180,7 +180,7 @@ func (s *SortformerAdapter) PrepareEnvironment(ctx context.Context) error {
}
// Create diarization script
if err := s.createDiarizationScript(); err != nil {
if err := s.copyDiarizationScript(); err != nil {
return fmt.Errorf("failed to create diarization script: %w", err)
}
@@ -195,51 +195,23 @@ func (s *SortformerAdapter) setupSortformerEnvironment() error {
return fmt.Errorf("failed to create sortformer directory: %w", err)
}
// Create pyproject.toml with configurable PyTorch CUDA version
pyprojectContent := fmt.Sprintf(`[project]
name = "parakeet-transcription"
version = "0.1.0"
description = "Audio transcription using NVIDIA Parakeet models"
requires-python = ">=3.11"
dependencies = [
"nemo-toolkit[asr]",
"torch",
"torchaudio",
"librosa",
"soundfile",
"ml-dtypes>=0.3.1,<0.5.0",
"onnx>=1.15.0,<1.18.0",
"pyannote.audio"
]
// Read pyproject.toml
pyprojectContent, err := nvidiaScripts.ReadFile("py/nvidia/pyproject.toml")
if err != nil {
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
}
[tool.uv.sources]
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
torch = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
torchaudio = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
triton = [
{ index = "pytorch", marker = "sys_platform == 'linux'" }
]
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
// The static file contains the default cu126 URL
contentStr := strings.Replace(
string(pyprojectContent),
"https://download.pytorch.org/whl/cu126",
GetPyTorchWheelURL(),
1,
)
[[tool.uv.index]]
name = "pytorch"
url = "%s"
explicit = true
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
`, GetPyTorchWheelURL())
pyprojectPath := filepath.Join(s.envPath, "pyproject.toml")
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
return fmt.Errorf("failed to write pyproject.toml: %w", err)
}
@@ -289,345 +261,15 @@ func (s *SortformerAdapter) downloadSortformerModel() error {
return nil
}
// createDiarizationScript creates the Python script for Sortformer diarization
func (s *SortformerAdapter) createDiarizationScript() error {
scriptPath := filepath.Join(s.envPath, "sortformer_diarize.py")
// Check if script already exists
if _, err := os.Stat(scriptPath); err == nil {
return nil
// copyDiarizationScript creates the Python script for Sortformer diarization
func (s *SortformerAdapter) copyDiarizationScript() error {
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/sortformer_diarize.py")
if err != nil {
return fmt.Errorf("failed to read embedded sortformer_diarize.py: %w", err)
}
scriptContent := `#!/usr/bin/env python3
"""
NVIDIA Sortformer speaker diarization script.
Uses diar_streaming_sortformer_4spk-v2 for optimized 4-speaker diarization.
"""
import argparse
import json
import sys
import os
from pathlib import Path
import torch
try:
from nemo.collections.asr.models import SortformerEncLabelModel
except ImportError:
print("Error: NeMo not found. Please install nemo_toolkit[asr]")
sys.exit(1)
def diarize_audio(
audio_path: str,
output_file: str,
batch_size: int = 1,
device: str = None,
max_speakers: int = 4,
output_format: str = "rttm",
streaming_mode: bool = False,
chunk_length_s: float = 30.0,
):
"""
Perform speaker diarization using NVIDIA's Sortformer model.
"""
if device is None or device == "auto":
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
print(f"Using device: {device}")
print(f"Loading NVIDIA Sortformer diarization model...")
# Get the directory where this script is located
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(script_dir, "diar_streaming_sortformer_4spk-v2.nemo")
try:
if not os.path.exists(model_path):
print(f"Error: Model file not found: {model_path}")
print("Please ensure diar_streaming_sortformer_4spk-v2.nemo is in the same directory as this script")
sys.exit(1)
# Load from local file
print(f"Loading model from local path: {model_path}")
diar_model = SortformerEncLabelModel.restore_from(
restore_path=model_path,
map_location=device,
strict=False,
)
# Switch to inference mode
diar_model.eval()
print("Model loaded successfully")
except Exception as e:
print(f"Error loading model: {e}")
sys.exit(1)
print(f"Processing audio file: {audio_path}")
# Verify audio file exists
if not os.path.exists(audio_path):
print(f"Error: Audio file not found: {audio_path}")
sys.exit(1)
try:
# Run diarization
print(f"Running diarization with batch_size={batch_size}, max_speakers={max_speakers}")
if streaming_mode:
print(f"Using streaming mode with chunk_length_s={chunk_length_s}")
# Note: Streaming mode implementation would go here
# For now, use standard diarization
predicted_segments = diar_model.diarize(audio=audio_path, batch_size=batch_size)
else:
predicted_segments = diar_model.diarize(audio=audio_path, batch_size=batch_size)
print(f"Diarization completed. Found segments: {len(predicted_segments)}")
# Process and save results
save_results(predicted_segments, output_file, audio_path, output_format)
except Exception as e:
print(f"Error during diarization: {e}")
sys.exit(1)
def save_results(segments, output_file: str, audio_path: str, output_format: str):
"""
Save diarization results to output file.
Supports both JSON and RTTM formats based on output_format parameter.
"""
output_path = Path(output_file)
if output_format == "rttm":
save_rttm_format(segments, output_file, audio_path)
else:
save_json_format(segments, output_file, audio_path)
def save_json_format(segments, output_file: str, audio_path: str):
"""Save results in JSON format."""
results = {
"audio_file": audio_path,
"model": "nvidia/diar_streaming_sortformer_4spk-v2",
"segments": [],
}
# Handle the case where segments is a list containing a single list of string entries
if len(segments) == 1 and isinstance(segments[0], list):
segments = segments[0]
# Convert segments to JSON format
speakers = set()
for i, segment in enumerate(segments):
try:
# Handle different possible segment formats
if isinstance(segment, str):
# String format: "start end speaker_id"
parts = segment.strip().split()
if len(parts) >= 3:
segment_data = {
"start": float(parts[0]),
"end": float(parts[1]),
"speaker": str(parts[2]),
"duration": float(parts[1]) - float(parts[0]),
"confidence": 1.0,
}
else:
print(f"Warning: Invalid string segment format: {segment}")
continue
elif hasattr(segment, 'start') and hasattr(segment, 'end') and hasattr(segment, 'label'):
# Standard pyannote-like format
segment_data = {
"start": float(segment.start),
"end": float(segment.end),
"speaker": str(segment.label),
"duration": float(segment.end - segment.start),
"confidence": getattr(segment, 'confidence', 1.0),
}
elif isinstance(segment, (list, tuple)) and len(segment) >= 3:
# List/tuple format: [start, end, speaker]
segment_data = {
"start": float(segment[0]),
"end": float(segment[1]),
"speaker": str(segment[2]),
"duration": float(segment[1] - segment[0]),
"confidence": 1.0,
}
elif isinstance(segment, dict):
# Dictionary format
segment_data = {
"start": float(segment.get('start', 0)),
"end": float(segment.get('end', 0)),
"speaker": str(segment.get('speaker', segment.get('label', f'speaker_{i}'))),
"duration": float(segment.get('end', 0) - segment.get('start', 0)),
"confidence": float(segment.get('confidence', 1.0)),
}
else:
# Fallback: try to extract attributes dynamically
segment_data = {
"start": float(getattr(segment, 'start', 0)),
"end": float(getattr(segment, 'end', 0)),
"speaker": str(getattr(segment, 'label', getattr(segment, 'speaker', f'speaker_{i}'))),
"duration": float(getattr(segment, 'end', 0) - getattr(segment, 'start', 0)),
"confidence": float(getattr(segment, 'confidence', 1.0)),
}
results["segments"].append(segment_data)
speakers.add(segment_data["speaker"])
except Exception as e:
print(f"Warning: Could not process segment {i}: {e}")
print(f"Segment: {segment}")
# Sort by start time
if results["segments"]:
results["segments"].sort(key=lambda x: x["start"])
# Add summary statistics
results["speakers"] = sorted(speakers)
results["speaker_count"] = len(speakers)
results["total_segments"] = len(results["segments"])
results["total_duration"] = max(seg["end"] for seg in results["segments"]) if results["segments"] else 0
with open(output_file, "w") as f:
json.dump(results, f, indent=2)
print(f"Results saved to: {output_file}")
print(f"Found {len(speakers)} speakers: {', '.join(sorted(speakers))}")
def save_rttm_format(segments, output_file: str, audio_path: str):
"""Save results in RTTM (Rich Transcription Time Marked) format."""
audio_filename = Path(audio_path).stem
speakers = set()
# Handle the case where segments is a list containing a single list of string entries
if len(segments) == 1 and isinstance(segments[0], list):
segments = segments[0]
with open(output_file, "w") as f:
for i, segment in enumerate(segments):
try:
# Handle different possible segment formats
if isinstance(segment, str):
# String format: "start end speaker_id"
parts = segment.strip().split()
if len(parts) >= 3:
start = float(parts[0])
end = float(parts[1])
speaker = str(parts[2])
else:
print(f"Warning: Invalid string segment format: {segment}")
continue
elif hasattr(segment, 'start') and hasattr(segment, 'end') and hasattr(segment, 'label'):
# Standard pyannote-like format
start = float(segment.start)
end = float(segment.end)
speaker = str(segment.label)
elif isinstance(segment, (list, tuple)) and len(segment) >= 3:
# List/tuple format: [start, end, speaker]
start = float(segment[0])
end = float(segment[1])
speaker = str(segment[2])
elif isinstance(segment, dict):
# Dictionary format
start = float(segment.get('start', 0))
end = float(segment.get('end', 0))
speaker = str(segment.get('speaker', segment.get('label', f'speaker_{i}')))
else:
# Fallback: try to extract attributes dynamically
start = float(getattr(segment, 'start', 0))
end = float(getattr(segment, 'end', 0))
speaker = str(getattr(segment, 'label', getattr(segment, 'speaker', f'speaker_{i}')))
duration = end - start
speakers.add(speaker)
# RTTM format: SPEAKER <filename> <channel> <start> <duration> <NA> <NA> <speaker_id> <NA> <NA>
line = f"SPEAKER {audio_filename} 1 {start:.3f} {duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n"
f.write(line)
except Exception as e:
print(f"Warning: Could not process segment {i} for RTTM: {e}")
print(f"Segment: {segment}")
print(f"RTTM results saved to: {output_file}")
print(f"Found {len(speakers)} speakers: {', '.join(sorted(speakers))}")
def main():
parser = argparse.ArgumentParser(
description="Speaker diarization using NVIDIA Sortformer model (local model only)",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Basic diarization with JSON output
python sortformer_diarize.py samples/sample.wav output.json
# Generate RTTM format output
python sortformer_diarize.py samples/sample.wav output.rttm
# Specify device and batch size
python sortformer_diarize.py --device cuda --batch-size 2 samples/sample.wav output.json
Note: This script requires diar_streaming_sortformer_4spk-v2.nemo to be in the same directory.
""",
)
parser.add_argument("audio_file", help="Path to input audio file (WAV, FLAC, etc.)")
parser.add_argument("output_file", help="Path to output file (.json for JSON format, .rttm for RTTM format)")
parser.add_argument("--batch-size", type=int, default=1, help="Batch size for processing (default: 1)")
parser.add_argument("--device", choices=["cuda", "cpu", "auto"], default="auto", help="Device to use for inference (default: auto-detect)")
parser.add_argument("--max-speakers", type=int, default=4, help="Maximum number of speakers (default: 4, optimized for this model)")
parser.add_argument("--output-format", choices=["json", "rttm"], help="Output format (auto-detected from file extension if not specified)")
parser.add_argument("--streaming", action="store_true", help="Enable streaming mode")
parser.add_argument("--chunk-length-s", type=float, default=30.0, help="Chunk length in seconds for streaming mode (default: 30.0)")
args = parser.parse_args()
# Validate inputs
if not os.path.exists(args.audio_file):
print(f"Error: Audio file not found: {args.audio_file}")
sys.exit(1)
# Auto-detect output format from file extension if not specified
if args.output_format is None:
if args.output_file.lower().endswith('.rttm'):
output_format = "rttm"
else:
output_format = "json"
else:
output_format = args.output_format
# Create output directory if it doesn't exist
output_dir = Path(args.output_file).parent
output_dir.mkdir(parents=True, exist_ok=True)
device = None if args.device == "auto" else args.device
# Run diarization
diarize_audio(
audio_path=args.audio_file,
output_file=args.output_file,
batch_size=args.batch_size,
device=device,
max_speakers=args.max_speakers,
output_format=output_format,
streaming_mode=args.streaming,
chunk_length_s=args.chunk_length_s,
)
if __name__ == "__main__":
main()
`
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
scriptPath := filepath.Join(s.envPath, "sortformer_diarize.py")
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
return fmt.Errorf("failed to write diarization script: %w", err)
}