mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-03-03 00:27:01 +00:00
Extract python adapter scripts to proper files
This commit is contained in:
committed by
Rishikanth Chandrasekaran
parent
edb65339b8
commit
50dd4130ff
1
.gitignore
vendored
1
.gitignore
vendored
@@ -70,3 +70,4 @@ tmp/
|
||||
# *.svg
|
||||
# *.png
|
||||
dhl.txt
|
||||
__pycache__/
|
||||
|
||||
@@ -2,6 +2,7 @@ package adapters
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -15,6 +16,9 @@ import (
|
||||
"scriberr/pkg/logger"
|
||||
)
|
||||
|
||||
//go:embed py/nvidia/*
|
||||
var nvidiaScripts embed.FS
|
||||
|
||||
// CanaryAdapter implements the TranscriptionAdapter interface for NVIDIA Canary
|
||||
type CanaryAdapter struct {
|
||||
*BaseAdapter
|
||||
@@ -185,7 +189,7 @@ func (c *CanaryAdapter) PrepareEnvironment(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Create transcription script
|
||||
if err := c.createTranscriptionScript(); err != nil {
|
||||
if err := c.copyTranscriptionScript(); err != nil {
|
||||
return fmt.Errorf("failed to create transcription script: %w", err)
|
||||
}
|
||||
|
||||
@@ -207,49 +211,23 @@ func (c *CanaryAdapter) setupCanaryEnvironment() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create pyproject.toml with configurable PyTorch CUDA version
|
||||
pyprojectContent := fmt.Sprintf(`[project]
|
||||
name = "parakeet-transcription"
|
||||
version = "0.1.0"
|
||||
description = "Audio transcription using NVIDIA Parakeet models"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"nemo-toolkit[asr]",
|
||||
"torch",
|
||||
"torchaudio",
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"ml-dtypes>=0.3.1,<0.5.0",
|
||||
"onnx>=1.15.0,<1.18.0",
|
||||
]
|
||||
// Read pyproject.toml
|
||||
pyprojectContent, err := nvidiaScripts.ReadFile("py/nvidia/pyproject.toml")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
|
||||
}
|
||||
|
||||
[tool.uv.sources]
|
||||
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
triton = [
|
||||
{ index = "pytorch", marker = "sys_platform == 'linux'" }
|
||||
]
|
||||
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
|
||||
// The static file contains the default cu126 URL
|
||||
contentStr := strings.Replace(
|
||||
string(pyprojectContent),
|
||||
"https://download.pytorch.org/whl/cu126",
|
||||
GetPyTorchWheelURL(),
|
||||
1,
|
||||
)
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch"
|
||||
url = "%s"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
`, GetPyTorchWheelURL())
|
||||
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
|
||||
pyprojectPath = filepath.Join(c.envPath, "pyproject.toml")
|
||||
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
||||
}
|
||||
|
||||
@@ -299,213 +277,15 @@ func (c *CanaryAdapter) downloadCanaryModel() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// createTranscriptionScript creates the Python script for Canary transcription
|
||||
func (c *CanaryAdapter) createTranscriptionScript() error {
|
||||
scriptPath := filepath.Join(c.envPath, "canary_transcribe.py")
|
||||
|
||||
// Check if script already exists
|
||||
if _, err := os.Stat(scriptPath); err == nil {
|
||||
return nil
|
||||
// copyTranscriptionScript creates the Python script for Canary transcription
|
||||
func (c *CanaryAdapter) copyTranscriptionScript() error {
|
||||
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/canary_transcribe.py")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read embedded canary_transcribe.py: %w", err)
|
||||
}
|
||||
|
||||
scriptContent := `#!/usr/bin/env python3
|
||||
"""
|
||||
NVIDIA Canary multilingual transcription and translation script.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
|
||||
def transcribe_audio(
|
||||
audio_path: str,
|
||||
source_lang: str = "en",
|
||||
target_lang: str = "en",
|
||||
task: str = "transcribe",
|
||||
timestamps: bool = True,
|
||||
output_file: str = None,
|
||||
include_confidence: bool = True,
|
||||
preserve_formatting: bool = True,
|
||||
):
|
||||
"""
|
||||
Transcribe or translate audio using NVIDIA Canary model.
|
||||
"""
|
||||
# Get the directory where this script is located
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(script_dir, "canary-1b-v2.nemo")
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Error: Model file not found: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Loading NVIDIA Canary model from: {model_path}")
|
||||
asr_model = nemo_asr.models.ASRModel.restore_from(model_path)
|
||||
|
||||
print(f"Processing: {audio_path}")
|
||||
print(f"Task: {task}")
|
||||
print(f"Source language: {source_lang}")
|
||||
print(f"Target language: {target_lang}")
|
||||
|
||||
if timestamps:
|
||||
if task == "translate" and source_lang != target_lang:
|
||||
# Translation with timestamps
|
||||
output = asr_model.transcribe(
|
||||
[audio_path],
|
||||
source_lang=source_lang,
|
||||
target_lang=target_lang,
|
||||
timestamps=True
|
||||
)
|
||||
else:
|
||||
# Transcription with timestamps
|
||||
output = asr_model.transcribe(
|
||||
[audio_path],
|
||||
source_lang=source_lang,
|
||||
target_lang=target_lang,
|
||||
timestamps=True
|
||||
)
|
||||
|
||||
# Extract text and timestamps
|
||||
result_data = output[0]
|
||||
text = result_data.text
|
||||
word_timestamps = result_data.timestamp.get("word", [])
|
||||
segment_timestamps = result_data.timestamp.get("segment", [])
|
||||
|
||||
print(f"Result: {text}")
|
||||
|
||||
# Prepare output data
|
||||
output_data = {
|
||||
"transcription": text,
|
||||
"source_language": source_lang,
|
||||
"target_language": target_lang,
|
||||
"task": task,
|
||||
"word_timestamps": word_timestamps,
|
||||
"segment_timestamps": segment_timestamps,
|
||||
"audio_file": audio_path,
|
||||
"model": "canary-1b-v2"
|
||||
}
|
||||
|
||||
if include_confidence:
|
||||
# Add confidence scores if available
|
||||
if hasattr(result_data, 'confidence') and result_data.confidence:
|
||||
output_data["confidence"] = result_data.confidence
|
||||
|
||||
# Save to file
|
||||
if output_file:
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
||||
print(f"Results saved to: {output_file}")
|
||||
else:
|
||||
print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
||||
|
||||
else:
|
||||
# Simple transcription/translation without timestamps
|
||||
if task == "translate" and source_lang != target_lang:
|
||||
output = asr_model.transcribe(
|
||||
[audio_path],
|
||||
source_lang=source_lang,
|
||||
target_lang=target_lang
|
||||
)
|
||||
else:
|
||||
output = asr_model.transcribe(
|
||||
[audio_path],
|
||||
source_lang=source_lang,
|
||||
target_lang=target_lang
|
||||
)
|
||||
|
||||
text = output[0].text
|
||||
|
||||
output_data = {
|
||||
"transcription": text,
|
||||
"source_language": source_lang,
|
||||
"target_language": target_lang,
|
||||
"task": task,
|
||||
"audio_file": audio_path,
|
||||
"model": "canary-1b-v2"
|
||||
}
|
||||
|
||||
if output_file:
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
||||
print(f"Results saved to: {output_file}")
|
||||
else:
|
||||
print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Transcribe or translate audio using NVIDIA Canary model"
|
||||
)
|
||||
parser.add_argument("audio_file", help="Path to audio file")
|
||||
parser.add_argument(
|
||||
"--source-lang", default="en",
|
||||
choices=["en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh"],
|
||||
help="Source language (default: en)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-lang", default="en",
|
||||
choices=["en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh"],
|
||||
help="Target language (default: en)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task", choices=["transcribe", "translate"], default="transcribe",
|
||||
help="Task to perform (default: transcribe)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timestamps", action="store_true", default=True,
|
||||
help="Include word and segment level timestamps"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-timestamps", dest="timestamps", action="store_false",
|
||||
help="Disable timestamps"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", "-o", help="Output file path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-confidence", action="store_true", default=True,
|
||||
help="Include confidence scores"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-confidence", dest="include_confidence", action="store_false",
|
||||
help="Exclude confidence scores"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preserve-formatting", action="store_true", default=True,
|
||||
help="Preserve punctuation and capitalization"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate input file
|
||||
if not os.path.exists(args.audio_file):
|
||||
print(f"Error: Audio file not found: {args.audio_file}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
transcribe_audio(
|
||||
audio_path=args.audio_file,
|
||||
source_lang=args.source_lang,
|
||||
target_lang=args.target_lang,
|
||||
task=args.task,
|
||||
timestamps=args.timestamps,
|
||||
output_file=args.output,
|
||||
include_confidence=args.include_confidence,
|
||||
preserve_formatting=args.preserve_formatting,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during transcription: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
`
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
||||
scriptPath := filepath.Join(c.envPath, "canary_transcribe.py")
|
||||
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
|
||||
return fmt.Errorf("failed to write transcription script: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -128,8 +128,8 @@ func (p *ParakeetAdapter) PrepareEnvironment(ctx context.Context) error {
|
||||
// Check if environment is already ready (using cache to speed up repeated checks)
|
||||
if CheckEnvironmentReady(p.envPath, "import nemo.collections.asr") {
|
||||
modelPath := filepath.Join(p.envPath, "parakeet-tdt-0.6b-v3.nemo")
|
||||
scriptPath := filepath.Join(p.envPath, "transcribe.py")
|
||||
bufferedScriptPath := filepath.Join(p.envPath, "transcribe_buffered.py")
|
||||
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe.py")
|
||||
bufferedScriptPath := filepath.Join(p.envPath, "parakeet_transcribe_buffered.py")
|
||||
|
||||
// Check model, standard script, and buffered script all exist
|
||||
if stat, err := os.Stat(modelPath); err == nil && stat.Size() > 1024*1024 {
|
||||
@@ -160,7 +160,7 @@ func (p *ParakeetAdapter) PrepareEnvironment(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Create transcription scripts (standard and buffered)
|
||||
if err := p.createTranscriptionScript(); err != nil {
|
||||
if err := p.copyTranscriptionScript(); err != nil {
|
||||
return fmt.Errorf("failed to create transcription script: %w", err)
|
||||
}
|
||||
|
||||
@@ -179,50 +179,23 @@ func (p *ParakeetAdapter) setupParakeetEnvironment() error {
|
||||
return fmt.Errorf("failed to create parakeet directory: %w", err)
|
||||
}
|
||||
|
||||
// Create pyproject.toml with configurable PyTorch CUDA version
|
||||
pyprojectContent := fmt.Sprintf(`[project]
|
||||
name = "parakeet-transcription"
|
||||
version = "0.1.0"
|
||||
description = "Audio transcription using NVIDIA Parakeet models"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"nemo-toolkit[asr]",
|
||||
"torch",
|
||||
"torchaudio",
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"ml-dtypes>=0.3.1,<0.5.0",
|
||||
"onnx>=1.15.0,<1.18.0",
|
||||
]
|
||||
// Read pyproject.toml
|
||||
pyprojectContent, err := nvidiaScripts.ReadFile("py/nvidia/pyproject.toml")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
|
||||
}
|
||||
|
||||
[tool.uv.sources]
|
||||
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
triton = [
|
||||
{ index = "pytorch", marker = "sys_platform == 'linux'" }
|
||||
]
|
||||
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
|
||||
// The static file contains the default cu126 URL
|
||||
contentStr := strings.Replace(
|
||||
string(pyprojectContent),
|
||||
"https://download.pytorch.org/whl/cu126",
|
||||
GetPyTorchWheelURL(),
|
||||
1,
|
||||
)
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch"
|
||||
url = "%s"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
`, GetPyTorchWheelURL())
|
||||
pyprojectPath := filepath.Join(p.envPath, "pyproject.toml")
|
||||
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
|
||||
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
||||
}
|
||||
|
||||
@@ -272,200 +245,15 @@ func (p *ParakeetAdapter) downloadParakeetModel() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// createTranscriptionScript creates the Python script for Parakeet transcription
|
||||
func (p *ParakeetAdapter) createTranscriptionScript() error {
|
||||
scriptContent := `#!/usr/bin/env python3
|
||||
"""
|
||||
NVIDIA Parakeet transcription script with timestamp support.
|
||||
"""
|
||||
// copyTranscriptionScript creates the Python script for Parakeet transcription
|
||||
func (p *ParakeetAdapter) copyTranscriptionScript() error {
|
||||
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/parakeet_transcribe.py")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read embedded transcribe.py: %w", err)
|
||||
}
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
|
||||
def transcribe_audio(
|
||||
audio_path: str,
|
||||
timestamps: bool = True,
|
||||
output_file: str = None,
|
||||
context_left: int = 256,
|
||||
context_right: int = 256,
|
||||
include_confidence: bool = True,
|
||||
):
|
||||
"""
|
||||
Transcribe audio using NVIDIA Parakeet model.
|
||||
"""
|
||||
# Get the directory where this script is located
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(script_dir, "parakeet-tdt-0.6b-v3.nemo")
|
||||
|
||||
print(f"Script directory: {script_dir}")
|
||||
print(f"Looking for model at: {model_path}")
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Error during transcription: Can't find {model_path}")
|
||||
# List files in the directory to help debug
|
||||
try:
|
||||
files = os.listdir(script_dir)
|
||||
print(f"Files in {script_dir}: {files}")
|
||||
except Exception as e:
|
||||
print(f"Could not list directory: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Loading NVIDIA Parakeet model from: {model_path}")
|
||||
asr_model = nemo_asr.models.ASRModel.restore_from(model_path)
|
||||
|
||||
# Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU
|
||||
# Uses change_decoding_strategy() to properly reconfigure the TDT decoder
|
||||
from omegaconf import OmegaConf, open_dict
|
||||
|
||||
print("Disabling CUDA graphs in TDT decoder...")
|
||||
dec_cfg = asr_model.cfg.decoding
|
||||
|
||||
# Add use_cuda_graph_decoder parameter to greedy config
|
||||
with open_dict(dec_cfg.greedy):
|
||||
dec_cfg.greedy['use_cuda_graph_decoder'] = False
|
||||
|
||||
# Apply the new decoding strategy (this rebuilds the decoder with our config)
|
||||
asr_model.change_decoding_strategy(dec_cfg)
|
||||
print("✓ CUDA graphs disabled successfully")
|
||||
|
||||
# Configure for long-form audio if context sizes are not default
|
||||
if context_left != 256 or context_right != 256:
|
||||
print(f"Configuring attention context: left={context_left}, right={context_right}")
|
||||
try:
|
||||
asr_model.change_attention_model(
|
||||
self_attention_model="rel_pos_local_attn",
|
||||
att_context_size=[context_left, context_right]
|
||||
)
|
||||
print("Long-form audio mode enabled")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to configure attention model: {e}")
|
||||
print("Continuing with default attention settings")
|
||||
|
||||
print(f"Transcribing: {audio_path}")
|
||||
|
||||
if timestamps:
|
||||
output = asr_model.transcribe([audio_path], timestamps=True)
|
||||
|
||||
# Extract text and timestamps
|
||||
result_data = output[0]
|
||||
text = result_data.text
|
||||
word_timestamps = result_data.timestamp.get("word", [])
|
||||
segment_timestamps = result_data.timestamp.get("segment", [])
|
||||
|
||||
print(f"Transcription: {text}")
|
||||
|
||||
# Prepare output data
|
||||
output_data = {
|
||||
"transcription": text,
|
||||
"language": "en",
|
||||
"word_timestamps": word_timestamps,
|
||||
"segment_timestamps": segment_timestamps,
|
||||
"audio_file": audio_path,
|
||||
"model": "parakeet-tdt-0.6b-v3",
|
||||
"context": {
|
||||
"left": context_left,
|
||||
"right": context_right
|
||||
}
|
||||
}
|
||||
|
||||
if include_confidence:
|
||||
# Add confidence scores if available
|
||||
if hasattr(result_data, 'confidence') and result_data.confidence:
|
||||
output_data["confidence"] = result_data.confidence
|
||||
|
||||
# Save to file
|
||||
if output_file:
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
||||
print(f"Results saved to: {output_file}")
|
||||
else:
|
||||
print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
||||
|
||||
else:
|
||||
# Simple transcription without timestamps
|
||||
output = asr_model.transcribe([audio_path])
|
||||
text = output[0].text
|
||||
|
||||
output_data = {
|
||||
"transcription": text,
|
||||
"language": "en",
|
||||
"audio_file": audio_path,
|
||||
"model": "parakeet-tdt-0.6b-v3"
|
||||
}
|
||||
|
||||
if output_file:
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
||||
print(f"Results saved to: {output_file}")
|
||||
else:
|
||||
print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Transcribe audio using NVIDIA Parakeet model"
|
||||
)
|
||||
parser.add_argument("audio_file", help="Path to audio file")
|
||||
parser.add_argument(
|
||||
"--timestamps", action="store_true", default=True,
|
||||
help="Include word and segment level timestamps"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-timestamps", dest="timestamps", action="store_false",
|
||||
help="Disable timestamps"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", "-o", help="Output file path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context-left", type=int, default=256,
|
||||
help="Left attention context size (default: 256)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context-right", type=int, default=256,
|
||||
help="Right attention context size (default: 256)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-confidence", action="store_true", default=True,
|
||||
help="Include confidence scores"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-confidence", dest="include_confidence", action="store_false",
|
||||
help="Exclude confidence scores"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate input file
|
||||
if not os.path.exists(args.audio_file):
|
||||
print(f"Error: Audio file not found: {args.audio_file}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
transcribe_audio(
|
||||
audio_path=args.audio_file,
|
||||
timestamps=args.timestamps,
|
||||
output_file=args.output,
|
||||
context_left=args.context_left,
|
||||
context_right=args.context_right,
|
||||
include_confidence=args.include_confidence,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during transcription: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
`
|
||||
|
||||
scriptPath := filepath.Join(p.envPath, "transcribe.py")
|
||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
||||
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe.py")
|
||||
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
|
||||
return fmt.Errorf("failed to write transcription script: %w", err)
|
||||
}
|
||||
|
||||
@@ -684,7 +472,7 @@ func (p *ParakeetAdapter) transcribeBuffered(ctx context.Context, input interfac
|
||||
func (p *ParakeetAdapter) buildParakeetArgs(input interfaces.AudioInput, params map[string]interface{}, tempDir string) ([]string, error) {
|
||||
outputFile := filepath.Join(tempDir, "result.json")
|
||||
|
||||
scriptPath := filepath.Join(p.envPath, "transcribe.py")
|
||||
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe.py")
|
||||
args := []string{
|
||||
"run", "--native-tls", "--project", p.envPath, "python", scriptPath,
|
||||
input.FilePath,
|
||||
@@ -771,181 +559,13 @@ func (p *ParakeetAdapter) parseResult(tempDir string, input interfaces.AudioInpu
|
||||
|
||||
// createBufferedScript creates the Python script for NeMo buffered inference
|
||||
func (p *ParakeetAdapter) createBufferedScript() error {
|
||||
scriptContent := `#!/usr/bin/env python3
|
||||
"""
|
||||
NVIDIA Parakeet buffered inference for long audio files.
|
||||
Splits audio into chunks to avoid GPU memory issues.
|
||||
"""
|
||||
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/parakeet_transcribe_buffered.py")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read embedded transcribe_buffered.py: %w", err)
|
||||
}
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
|
||||
def split_audio_file(audio_path, chunk_duration_secs=300):
|
||||
"""Split audio file into chunks of specified duration."""
|
||||
audio, sr = librosa.load(audio_path, sr=None, mono=True)
|
||||
total_duration = len(audio) / sr
|
||||
chunk_samples = int(chunk_duration_secs * sr)
|
||||
|
||||
chunks = []
|
||||
for start_sample in range(0, len(audio), chunk_samples):
|
||||
end_sample = min(start_sample + chunk_samples, len(audio))
|
||||
chunk_audio = audio[start_sample:end_sample]
|
||||
start_time = start_sample / sr
|
||||
chunks.append({
|
||||
'audio': chunk_audio,
|
||||
'start_time': start_time,
|
||||
'duration': len(chunk_audio) / sr
|
||||
})
|
||||
|
||||
return chunks, sr
|
||||
|
||||
|
||||
def transcribe_buffered(
|
||||
audio_path: str,
|
||||
output_file: str = None,
|
||||
chunk_duration_secs: float = 300, # 5 minutes default
|
||||
):
|
||||
"""
|
||||
Transcribe long audio by splitting into chunks and merging results.
|
||||
"""
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(script_dir, "parakeet-tdt-0.6b-v3.nemo")
|
||||
|
||||
print(f"Loading NVIDIA Parakeet model from: {model_path}")
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Error: Model not found at {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
asr_model = nemo_asr.models.ASRModel.restore_from(model_path)
|
||||
|
||||
# Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU
|
||||
# Uses change_decoding_strategy() to properly reconfigure the TDT decoder
|
||||
from omegaconf import OmegaConf, open_dict
|
||||
|
||||
print("Disabling CUDA graphs in TDT decoder...")
|
||||
dec_cfg = asr_model.cfg.decoding
|
||||
|
||||
# Add use_cuda_graph_decoder parameter to greedy config
|
||||
with open_dict(dec_cfg.greedy):
|
||||
dec_cfg.greedy['use_cuda_graph_decoder'] = False
|
||||
|
||||
# Apply the new decoding strategy (this rebuilds the decoder with our config)
|
||||
asr_model.change_decoding_strategy(dec_cfg)
|
||||
print("✓ CUDA graphs disabled successfully")
|
||||
|
||||
print(f"Splitting audio into {chunk_duration_secs}s chunks...")
|
||||
chunks, sr = split_audio_file(audio_path, chunk_duration_secs)
|
||||
print(f"Created {len(chunks)} chunks")
|
||||
|
||||
all_words = []
|
||||
all_segments = []
|
||||
full_text = []
|
||||
|
||||
for i, chunk_info in enumerate(chunks):
|
||||
print(f"Transcribing chunk {i+1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...")
|
||||
|
||||
# Save chunk to temporary file
|
||||
chunk_path = f"/tmp/chunk_{i}.wav"
|
||||
sf.write(chunk_path, chunk_info['audio'], sr)
|
||||
|
||||
try:
|
||||
# Transcribe chunk
|
||||
output = asr_model.transcribe(
|
||||
[chunk_path],
|
||||
batch_size=1,
|
||||
timestamps=True,
|
||||
)
|
||||
|
||||
result_data = output[0]
|
||||
chunk_text = result_data.text
|
||||
full_text.append(chunk_text)
|
||||
|
||||
# Extract and adjust timestamps
|
||||
if hasattr(result_data, 'timestamp') and result_data.timestamp:
|
||||
chunk_words = result_data.timestamp.get("word", [])
|
||||
chunk_segments = result_data.timestamp.get("segment", [])
|
||||
|
||||
# Adjust timestamps by chunk start time
|
||||
for word in chunk_words:
|
||||
word_copy = dict(word)
|
||||
word_copy['start'] += chunk_info['start_time']
|
||||
word_copy['end'] += chunk_info['start_time']
|
||||
all_words.append(word_copy)
|
||||
|
||||
for segment in chunk_segments:
|
||||
seg_copy = dict(segment)
|
||||
seg_copy['start'] += chunk_info['start_time']
|
||||
seg_copy['end'] += chunk_info['start_time']
|
||||
all_segments.append(seg_copy)
|
||||
|
||||
print(f"Chunk {i+1} complete: {len(chunk_text)} characters")
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if os.path.exists(chunk_path):
|
||||
os.remove(chunk_path)
|
||||
|
||||
final_text = " ".join(full_text)
|
||||
print(f"Transcription complete: {len(final_text)} characters total")
|
||||
|
||||
output_data = {
|
||||
"transcription": final_text,
|
||||
"language": "en",
|
||||
"word_timestamps": all_words,
|
||||
"segment_timestamps": all_segments,
|
||||
"audio_file": audio_path,
|
||||
"model": "parakeet-tdt-0.6b-v3",
|
||||
"buffered": True,
|
||||
"chunk_duration_secs": chunk_duration_secs,
|
||||
"num_chunks": len(chunks),
|
||||
}
|
||||
|
||||
if output_file:
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
||||
print(f"Results saved to: {output_file}")
|
||||
else:
|
||||
print(json.dumps(output_data, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Transcribe long audio using NVIDIA Parakeet with chunking"
|
||||
)
|
||||
parser.add_argument("audio_file", help="Path to audio file")
|
||||
parser.add_argument("--output", "-o", help="Output file path", required=True)
|
||||
parser.add_argument(
|
||||
"--chunk-len", type=float, default=300,
|
||||
help="Chunk duration in seconds (default: 300 = 5 minutes)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.audio_file):
|
||||
print(f"Error: Audio file not found: {args.audio_file}")
|
||||
sys.exit(1)
|
||||
|
||||
transcribe_buffered(
|
||||
audio_path=args.audio_file,
|
||||
output_file=args.output,
|
||||
chunk_duration_secs=args.chunk_len,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
`
|
||||
|
||||
scriptPath := filepath.Join(p.envPath, "transcribe_buffered.py")
|
||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
||||
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe_buffered.py")
|
||||
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
|
||||
return fmt.Errorf("failed to write buffered script: %w", err)
|
||||
}
|
||||
|
||||
@@ -963,7 +583,7 @@ func (p *ParakeetAdapter) buildBufferedArgs(input interfaces.AudioInput, params
|
||||
chunkDuration = thresholdStr
|
||||
}
|
||||
|
||||
scriptPath := filepath.Join(p.envPath, "transcribe_buffered.py")
|
||||
scriptPath := filepath.Join(p.envPath, "parakeet_transcribe_buffered.py")
|
||||
args := []string{
|
||||
"run", "--native-tls", "--project", p.envPath, "python", scriptPath,
|
||||
input.FilePath,
|
||||
|
||||
204
internal/transcription/adapters/py/nvidia/canary_transcribe.py
Normal file
204
internal/transcription/adapters/py/nvidia/canary_transcribe.py
Normal file
@@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
NVIDIA Canary multilingual transcription and translation script.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
|
||||
def transcribe_audio(
    audio_path: str,
    source_lang: str = "en",
    target_lang: str = "en",
    task: str = "transcribe",
    timestamps: bool = True,
    output_file: str = None,
    include_confidence: bool = True,
    preserve_formatting: bool = True,
):
    """Transcribe or translate audio using the NVIDIA Canary model.

    Args:
        audio_path: Path to the input audio file.
        source_lang: Language spoken in the audio.
        target_lang: Language of the produced text; differs from
            source_lang only when translating.
        task: "transcribe" or "translate". Recorded in the output metadata;
            the model call itself is driven purely by the language pair.
        timestamps: When True, request word/segment-level timestamps.
        output_file: Optional path; results are written there as JSON,
            otherwise printed to stdout.
        include_confidence: Attach the model's confidence score when the
            result exposes one.
        preserve_formatting: Accepted for CLI compatibility; the current
            model call does not use it.

    Exits the process with status 1 when the model file cannot be located.
    """
    model_filename = "canary-1b-v2.nemo"

    # Locate project root: derived from VIRTUAL_ENV, which is set by `uv run` to path/.venv
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)

    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)

    if not os.path.exists(model_path):
        print(f"Error during transcription: Can't find {model_filename} in project root: {project_root}")
        sys.exit(1)

    print(f"Loading NVIDIA Canary model from: {model_path}")
    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    print(f"Processing: {audio_path}")
    print(f"Task: {task}")
    print(f"Source language: {source_lang}")
    print(f"Target language: {target_lang}")

    # NOTE: the original code branched on task/source!=target here, but both
    # branches issued the identical transcribe() call — the language pair
    # alone selects transcription vs. translation, so one call suffices.
    if timestamps:
        output = asr_model.transcribe(
            [audio_path],
            source_lang=source_lang,
            target_lang=target_lang,
            timestamps=True
        )

        result_data = output[0]
        text = result_data.text
        print(f"Result: {text}")

        output_data = {
            "transcription": text,
            "source_language": source_lang,
            "target_language": target_lang,
            "task": task,
            "word_timestamps": result_data.timestamp.get("word", []),
            "segment_timestamps": result_data.timestamp.get("segment", []),
            "audio_file": audio_path,
            "model": "canary-1b-v2"
        }

        # Confidence is optional on the result object; attach it only when present.
        if include_confidence and hasattr(result_data, 'confidence') and result_data.confidence:
            output_data["confidence"] = result_data.confidence
    else:
        output = asr_model.transcribe(
            [audio_path],
            source_lang=source_lang,
            target_lang=target_lang
        )

        output_data = {
            "transcription": output[0].text,
            "source_language": source_lang,
            "target_language": target_lang,
            "task": task,
            "audio_file": audio_path,
            "model": "canary-1b-v2"
        }

    # Single emit path replaces the duplicated save/print blocks.
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"Results saved to: {output_file}")
    else:
        print(json.dumps(output_data, indent=2, ensure_ascii=False))
def main():
    """Command-line entry point for Canary transcription/translation."""
    langs = ["en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh"]

    parser = argparse.ArgumentParser(
        description="Transcribe or translate audio using NVIDIA Canary model"
    )
    parser.add_argument("audio_file", help="Path to audio file")
    parser.add_argument("--source-lang", default="en", choices=langs,
                        help="Source language (default: en)")
    parser.add_argument("--target-lang", default="en", choices=langs,
                        help="Target language (default: en)")
    parser.add_argument("--task", choices=["transcribe", "translate"], default="transcribe",
                        help="Task to perform (default: transcribe)")
    parser.add_argument("--timestamps", action="store_true", default=True,
                        help="Include word and segment level timestamps")
    parser.add_argument("--no-timestamps", dest="timestamps", action="store_false",
                        help="Disable timestamps")
    parser.add_argument("--output", "-o", help="Output file path")
    parser.add_argument("--include-confidence", action="store_true", default=True,
                        help="Include confidence scores")
    parser.add_argument("--no-confidence", dest="include_confidence", action="store_false",
                        help="Exclude confidence scores")
    parser.add_argument("--preserve-formatting", action="store_true", default=True,
                        help="Preserve punctuation and capitalization")

    args = parser.parse_args()

    # Fail fast on a missing input instead of erroring deep inside NeMo.
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file not found: {args.audio_file}")
        sys.exit(1)

    try:
        transcribe_audio(
            audio_path=args.audio_file,
            source_lang=args.source_lang,
            target_lang=args.target_lang,
            task=args.task,
            timestamps=args.timestamps,
            output_file=args.output,
            include_confidence=args.include_confidence,
            preserve_formatting=args.preserve_formatting,
        )
    except Exception as exc:
        print(f"Error during transcription: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    main()
188
internal/transcription/adapters/py/nvidia/parakeet_transcribe.py
Normal file
188
internal/transcription/adapters/py/nvidia/parakeet_transcribe.py
Normal file
@@ -0,0 +1,188 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
NVIDIA Parakeet transcription script with timestamp support.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
|
||||
def transcribe_audio(
    audio_path: str,
    timestamps: bool = True,
    output_file: str = None,
    context_left: int = 256,
    context_right: int = 256,
    include_confidence: bool = True,
):
    """Transcribe audio using the NVIDIA Parakeet TDT model.

    Args:
        audio_path: Path to the input audio file.
        timestamps: When True, request word/segment-level timestamps.
        output_file: Optional path; results are written there as JSON,
            otherwise printed to stdout.
        context_left: Left attention context size; any non-default value
            switches the model into long-form (local attention) mode.
        context_right: Right attention context size (see context_left).
        include_confidence: Attach the model's confidence score when the
            result exposes one.

    Exits the process with status 1 when the model file cannot be located.
    """
    model_filename = "parakeet-tdt-0.6b-v3.nemo"

    # Locate project root: derived from VIRTUAL_ENV, which is set by `uv run` to path/.venv
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)

    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)

    if not os.path.exists(model_path):
        print(f"Error during transcription: Can't find {model_filename} in project root: {project_root}")
        sys.exit(1)

    print(f"Loading NVIDIA Parakeet model from: {model_path}")
    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    # Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU.
    # change_decoding_strategy() rebuilds the TDT decoder with the new config.
    from omegaconf import open_dict

    print("Disabling CUDA graphs in TDT decoder...")
    dec_cfg = asr_model.cfg.decoding

    # Add use_cuda_graph_decoder parameter to greedy config
    with open_dict(dec_cfg.greedy):
        dec_cfg.greedy['use_cuda_graph_decoder'] = False

    asr_model.change_decoding_strategy(dec_cfg)
    print("✓ CUDA graphs disabled successfully")

    # Configure for long-form audio when non-default context sizes are requested.
    if context_left != 256 or context_right != 256:
        print(f"Configuring attention context: left={context_left}, right={context_right}")
        try:
            asr_model.change_attention_model(
                self_attention_model="rel_pos_local_attn",
                att_context_size=[context_left, context_right]
            )
            print("Long-form audio mode enabled")
        except Exception as e:
            # Best-effort: fall back to default attention rather than aborting.
            print(f"Warning: Failed to configure attention model: {e}")
            print("Continuing with default attention settings")

    print(f"Transcribing: {audio_path}")

    if timestamps:
        output = asr_model.transcribe([audio_path], timestamps=True)

        result_data = output[0]
        text = result_data.text
        print(f"Transcription: {text}")

        output_data = {
            "transcription": text,
            "language": "en",
            "word_timestamps": result_data.timestamp.get("word", []),
            "segment_timestamps": result_data.timestamp.get("segment", []),
            "audio_file": audio_path,
            "model": "parakeet-tdt-0.6b-v3",
            "context": {
                "left": context_left,
                "right": context_right
            }
        }

        # Confidence is optional on the result object; attach it only when present.
        if include_confidence and hasattr(result_data, 'confidence') and result_data.confidence:
            output_data["confidence"] = result_data.confidence
    else:
        # Simple transcription without timestamps
        output = asr_model.transcribe([audio_path])

        output_data = {
            "transcription": output[0].text,
            "language": "en",
            "audio_file": audio_path,
            "model": "parakeet-tdt-0.6b-v3"
        }

    # Single emit path replaces the duplicated save/print blocks.
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"Results saved to: {output_file}")
    else:
        print(json.dumps(output_data, indent=2, ensure_ascii=False))
def main():
    """Command-line entry point for single-pass Parakeet transcription."""
    parser = argparse.ArgumentParser(
        description="Transcribe audio using NVIDIA Parakeet model"
    )
    parser.add_argument("audio_file", help="Path to audio file")
    parser.add_argument("--timestamps", action="store_true", default=True,
                        help="Include word and segment level timestamps")
    parser.add_argument("--no-timestamps", dest="timestamps", action="store_false",
                        help="Disable timestamps")
    parser.add_argument("--output", "-o", help="Output file path")
    parser.add_argument("--context-left", type=int, default=256,
                        help="Left attention context size (default: 256)")
    parser.add_argument("--context-right", type=int, default=256,
                        help="Right attention context size (default: 256)")
    parser.add_argument("--include-confidence", action="store_true", default=True,
                        help="Include confidence scores")
    parser.add_argument("--no-confidence", dest="include_confidence", action="store_false",
                        help="Exclude confidence scores")

    args = parser.parse_args()

    # Fail fast on a missing input instead of erroring deep inside NeMo.
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file not found: {args.audio_file}")
        sys.exit(1)

    try:
        transcribe_audio(
            audio_path=args.audio_file,
            timestamps=args.timestamps,
            output_file=args.output,
            context_left=args.context_left,
            context_right=args.context_right,
            include_confidence=args.include_confidence,
        )
    except Exception as exc:
        print(f"Error during transcription: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    main()
@@ -0,0 +1,182 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
NVIDIA Parakeet buffered inference for long audio files.
|
||||
Splits audio into chunks to avoid GPU memory issues.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
|
||||
def split_audio_file(audio_path, chunk_duration_secs=300):
    """Split an audio file into fixed-length mono chunks.

    Returns a tuple ``(chunks, sample_rate)`` where each chunk is a dict
    with keys 'audio' (sample array), 'start_time' (seconds from the start
    of the file) and 'duration' (seconds). The final chunk may be shorter.
    """
    samples, rate = librosa.load(audio_path, sr=None, mono=True)
    step = int(chunk_duration_secs * rate)

    pieces = []
    for offset in range(0, len(samples), step):
        # Slicing clamps automatically at the end of the array.
        piece = samples[offset:offset + step]
        pieces.append({
            'audio': piece,
            'start_time': offset / rate,
            'duration': len(piece) / rate,
        })

    return pieces, rate
def transcribe_buffered(
    audio_path: str,
    output_file: str = None,
    chunk_duration_secs: float = 300,  # 5 minutes default
):
    """Transcribe a long audio file chunk-by-chunk and merge the results.

    Splits the input into ``chunk_duration_secs`` pieces to bound GPU memory,
    transcribes each piece, shifts chunk-local timestamps to absolute file
    positions, and emits one merged JSON result.

    Args:
        audio_path: Path to the input audio file.
        output_file: Optional path; results are written there as JSON,
            otherwise printed to stdout.
        chunk_duration_secs: Target duration of each chunk in seconds.

    Exits the process with status 1 when the model file cannot be located.
    """
    import tempfile

    model_filename = "parakeet-tdt-0.6b-v3.nemo"

    # Locate project root: derived from VIRTUAL_ENV, which is set by `uv run` to path/.venv
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)

    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)

    if not os.path.exists(model_path):
        print(f"Error during transcription: Can't find {model_filename} in project root: {project_root}")
        sys.exit(1)

    print(f"Loading NVIDIA Parakeet model from: {model_path}")

    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    # Disable CUDA graphs to fix Error 35 on RTX 2000e Ada GPU.
    # change_decoding_strategy() rebuilds the TDT decoder with the new config.
    from omegaconf import open_dict

    print("Disabling CUDA graphs in TDT decoder...")
    dec_cfg = asr_model.cfg.decoding

    # Add use_cuda_graph_decoder parameter to greedy config
    with open_dict(dec_cfg.greedy):
        dec_cfg.greedy['use_cuda_graph_decoder'] = False

    asr_model.change_decoding_strategy(dec_cfg)
    print("✓ CUDA graphs disabled successfully")

    print(f"Splitting audio into {chunk_duration_secs}s chunks...")
    chunks, sr = split_audio_file(audio_path, chunk_duration_secs)
    print(f"Created {len(chunks)} chunks")

    all_words = []
    all_segments = []
    full_text = []

    for i, chunk_info in enumerate(chunks):
        print(f"Transcribing chunk {i+1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...")

        # Write the chunk to a unique temp file. The original used a fixed
        # /tmp/chunk_{i}.wav, which is not portable and collides when two
        # jobs run concurrently.
        fd, chunk_path = tempfile.mkstemp(prefix=f"chunk_{i}_", suffix=".wav")
        os.close(fd)
        sf.write(chunk_path, chunk_info['audio'], sr)

        try:
            # Transcribe chunk
            output = asr_model.transcribe(
                [chunk_path],
                batch_size=1,
                timestamps=True,
            )

            result_data = output[0]
            chunk_text = result_data.text
            full_text.append(chunk_text)

            # Shift chunk-local timestamps to absolute positions in the file.
            if hasattr(result_data, 'timestamp') and result_data.timestamp:
                offset = chunk_info['start_time']

                for word in result_data.timestamp.get("word", []):
                    word_copy = dict(word)
                    word_copy['start'] += offset
                    word_copy['end'] += offset
                    all_words.append(word_copy)

                for segment in result_data.timestamp.get("segment", []):
                    seg_copy = dict(segment)
                    seg_copy['start'] += offset
                    seg_copy['end'] += offset
                    all_segments.append(seg_copy)

            print(f"Chunk {i+1} complete: {len(chunk_text)} characters")

        finally:
            # Clean up temp file
            if os.path.exists(chunk_path):
                os.remove(chunk_path)

    final_text = " ".join(full_text)
    print(f"Transcription complete: {len(final_text)} characters total")

    output_data = {
        "transcription": final_text,
        "language": "en",
        "word_timestamps": all_words,
        "segment_timestamps": all_segments,
        "audio_file": audio_path,
        "model": "parakeet-tdt-0.6b-v3",
        "buffered": True,
        "chunk_duration_secs": chunk_duration_secs,
        "num_chunks": len(chunks),
    }

    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"Results saved to: {output_file}")
    else:
        print(json.dumps(output_data, indent=2, ensure_ascii=False))
def main():
    """Command-line entry point for chunked long-audio transcription."""
    parser = argparse.ArgumentParser(
        description="Transcribe long audio using NVIDIA Parakeet with chunking"
    )
    parser.add_argument("audio_file", help="Path to audio file")
    parser.add_argument("--output", "-o", help="Output file path", required=True)
    parser.add_argument("--chunk-len", type=float, default=300,
                        help="Chunk duration in seconds (default: 300 = 5 minutes)")

    args = parser.parse_args()

    # Fail fast on a missing input instead of erroring deep inside NeMo.
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file not found: {args.audio_file}")
        sys.exit(1)

    transcribe_buffered(
        audio_path=args.audio_file,
        output_file=args.output,
        chunk_duration_secs=args.chunk_len,
    )


if __name__ == "__main__":
    main()
47
internal/transcription/adapters/py/nvidia/pyproject.toml
Normal file
47
internal/transcription/adapters/py/nvidia/pyproject.toml
Normal file
@@ -0,0 +1,47 @@
|
||||
[project]
|
||||
name = "nvidia-transcription"
|
||||
version = "0.1.0"
|
||||
description = "Audio transcription and diarization using NVIDIA models (Canary, Parakeet, Sortformer)"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"nemo-toolkit[asr]",
|
||||
"torch",
|
||||
"torchaudio",
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"ml-dtypes>=0.3.1,<0.5.0",
|
||||
"onnx>=1.15.0,<1.18.0",
|
||||
    # "pyannote.audio",  # TODO: verify whether Sortformer diarization requires pyannote.audio before enabling
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-mock>=3.12.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
triton = [
|
||||
{ index = "pytorch", marker = "sys_platform == 'linux'" }
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch"
|
||||
url = "https://download.pytorch.org/whl/cu126"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
335
internal/transcription/adapters/py/nvidia/sortformer_diarize.py
Normal file
335
internal/transcription/adapters/py/nvidia/sortformer_diarize.py
Normal file
@@ -0,0 +1,335 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
NVIDIA Sortformer speaker diarization script.
|
||||
Uses diar_streaming_sortformer_4spk-v2 for optimized 4-speaker diarization.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
try:
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||
except ImportError:
|
||||
print("Error: NeMo not found. Please install nemo_toolkit[asr]")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def diarize_audio(
    audio_path: str,
    output_file: str,
    batch_size: int = 1,
    device: str = None,
    max_speakers: int = 4,
    output_format: str = "rttm",
    streaming_mode: bool = False,
    chunk_length_s: float = 30.0,
):
    """Perform speaker diarization using NVIDIA's Sortformer model.

    Args:
        audio_path: Input audio file.
        output_file: Destination file; interpreted per ``output_format``.
        batch_size: Inference batch size.
        device: "cuda", "cpu", or None/"auto" for auto-detection.
        max_speakers: Kept for CLI compatibility; the bundled model is fixed
            at 4 speakers and this value only affects logging.
        output_format: "rttm" or "json" (dispatched by save_results).
        streaming_mode: Logged only — streaming inference is not implemented
            yet and standard diarization is always used.
        chunk_length_s: Streaming chunk length; unused until streaming mode
            is implemented.

    Exits the process with status 1 on any failure.
    """
    if device is None or device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Using device: {device}")
    print(f"Loading NVIDIA Sortformer diarization model...")

    model_filename = "diar_streaming_sortformer_4spk-v2.nemo"

    # Locate project root: derived from VIRTUAL_ENV, which is set by `uv run` to path/.venv
    virtual_env = os.environ.get("VIRTUAL_ENV")
    if not virtual_env:
        print("Error: VIRTUAL_ENV environment variable not set. Script must be run with 'uv run'.")
        sys.exit(1)

    project_root = os.path.dirname(virtual_env)
    model_path = os.path.join(project_root, model_filename)

    try:
        if not os.path.exists(model_path):
            print(f"Error: Model file not found: {model_filename} in project root: {project_root}")
            sys.exit(1)

        # Load from local file
        print(f"Loading model from path: {model_path}")
        diar_model = SortformerEncLabelModel.restore_from(
            restore_path=model_path,
            map_location=device,
            strict=False,
        )

        # Switch to inference mode
        diar_model.eval()
        print("Model loaded successfully")

    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

    print(f"Processing audio file: {audio_path}")

    # Verify audio file exists
    if not os.path.exists(audio_path):
        print(f"Error: Audio file not found: {audio_path}")
        sys.exit(1)

    try:
        print(f"Running diarization with batch_size={batch_size}, max_speakers={max_speakers}")

        if streaming_mode:
            print(f"Using streaming mode with chunk_length_s={chunk_length_s}")
            # Streaming inference is not implemented; fall through to the
            # standard call below (the original issued the identical call
            # in both branches).

        predicted_segments = diar_model.diarize(audio=audio_path, batch_size=batch_size)

        print(f"Diarization completed. Found segments: {len(predicted_segments)}")

        # Process and save results
        save_results(predicted_segments, output_file, audio_path, output_format)

    except Exception as e:
        print(f"Error during diarization: {e}")
        sys.exit(1)
def save_results(segments, output_file: str, audio_path: str, output_format: str):
    """Dispatch diarization results to the RTTM or JSON writer.

    ``output_format`` == "rttm" selects the RTTM writer; any other value
    falls back to JSON. (The dead ``output_path`` local from the original
    has been removed — it was computed but never used.)
    """
    if output_format == "rttm":
        save_rttm_format(segments, output_file, audio_path)
    else:
        save_json_format(segments, output_file, audio_path)
def save_json_format(segments, output_file: str, audio_path: str):
    """Write diarization segments to ``output_file`` as JSON.

    Accepts heterogeneous segment shapes — "start end speaker" strings,
    objects with start/end/label attributes, [start, end, speaker]
    sequences, and dicts — normalizing each to a record with start, end,
    speaker, duration and confidence. Unparseable entries are reported
    and skipped. Summary statistics are appended before writing.
    """
    results = {
        "audio_file": audio_path,
        "model": "nvidia/diar_streaming_sortformer_4spk-v2",
        "segments": [],
    }

    # diarize() may wrap every entry in a single inner list — unwrap it.
    if len(segments) == 1 and isinstance(segments[0], list):
        segments = segments[0]

    speakers = set()
    for idx, entry in enumerate(segments):
        try:
            if isinstance(entry, str):
                # String format: "start end speaker_id"
                fields = entry.strip().split()
                if len(fields) < 3:
                    print(f"Warning: Invalid string segment format: {entry}")
                    continue
                begin, finish = float(fields[0]), float(fields[1])
                record = {
                    "start": begin,
                    "end": finish,
                    "speaker": str(fields[2]),
                    "duration": finish - begin,
                    "confidence": 1.0,
                }
            elif hasattr(entry, 'start') and hasattr(entry, 'end') and hasattr(entry, 'label'):
                # pyannote-like object
                record = {
                    "start": float(entry.start),
                    "end": float(entry.end),
                    "speaker": str(entry.label),
                    "duration": float(entry.end - entry.start),
                    "confidence": getattr(entry, 'confidence', 1.0),
                }
            elif isinstance(entry, (list, tuple)) and len(entry) >= 3:
                # Sequence format: [start, end, speaker]
                record = {
                    "start": float(entry[0]),
                    "end": float(entry[1]),
                    "speaker": str(entry[2]),
                    "duration": float(entry[1] - entry[0]),
                    "confidence": 1.0,
                }
            elif isinstance(entry, dict):
                record = {
                    "start": float(entry.get('start', 0)),
                    "end": float(entry.get('end', 0)),
                    "speaker": str(entry.get('speaker', entry.get('label', f'speaker_{idx}'))),
                    "duration": float(entry.get('end', 0) - entry.get('start', 0)),
                    "confidence": float(entry.get('confidence', 1.0)),
                }
            else:
                # Last resort: pull attributes dynamically.
                record = {
                    "start": float(getattr(entry, 'start', 0)),
                    "end": float(getattr(entry, 'end', 0)),
                    "speaker": str(getattr(entry, 'label', getattr(entry, 'speaker', f'speaker_{idx}'))),
                    "duration": float(getattr(entry, 'end', 0) - getattr(entry, 'start', 0)),
                    "confidence": float(getattr(entry, 'confidence', 1.0)),
                }

            results["segments"].append(record)
            speakers.add(record["speaker"])

        except Exception as e:
            print(f"Warning: Could not process segment {idx}: {e}")
            print(f"Segment: {entry}")

    # Chronological order for readers.
    if results["segments"]:
        results["segments"].sort(key=lambda s: s["start"])

    # Summary statistics.
    results["speakers"] = sorted(speakers)
    results["speaker_count"] = len(speakers)
    results["total_segments"] = len(results["segments"])
    results["total_duration"] = max(s["end"] for s in results["segments"]) if results["segments"] else 0

    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)

    print(f"Results saved to: {output_file}")
    print(f"Found {len(speakers)} speakers: {', '.join(sorted(speakers))}")
def save_rttm_format(segments, output_file: str, audio_path: str):
    """Write diarization segments as RTTM (Rich Transcription Time Marked).

    Accepts the same heterogeneous segment shapes as the JSON writer:
    "start end speaker" strings, objects with start/end/label, 3-element
    sequences, and dicts. Unparseable entries are reported and skipped.
    """
    recording = Path(audio_path).stem
    seen_speakers = set()

    # diarize() may wrap every entry in a single inner list — unwrap it.
    if len(segments) == 1 and isinstance(segments[0], list):
        segments = segments[0]

    with open(output_file, "w") as out:
        for idx, entry in enumerate(segments):
            try:
                if isinstance(entry, str):
                    # String format: "start end speaker_id"
                    fields = entry.strip().split()
                    if len(fields) < 3:
                        print(f"Warning: Invalid string segment format: {entry}")
                        continue
                    begin, finish, who = float(fields[0]), float(fields[1]), str(fields[2])
                elif hasattr(entry, 'start') and hasattr(entry, 'end') and hasattr(entry, 'label'):
                    # pyannote-like object
                    begin, finish, who = float(entry.start), float(entry.end), str(entry.label)
                elif isinstance(entry, (list, tuple)) and len(entry) >= 3:
                    # Sequence format: [start, end, speaker]
                    begin, finish, who = float(entry[0]), float(entry[1]), str(entry[2])
                elif isinstance(entry, dict):
                    begin = float(entry.get('start', 0))
                    finish = float(entry.get('end', 0))
                    who = str(entry.get('speaker', entry.get('label', f'speaker_{idx}')))
                else:
                    # Last resort: pull attributes dynamically.
                    begin = float(getattr(entry, 'start', 0))
                    finish = float(getattr(entry, 'end', 0))
                    who = str(getattr(entry, 'label', getattr(entry, 'speaker', f'speaker_{idx}')))

                seen_speakers.add(who)

                # RTTM: SPEAKER <file> <chan> <start> <dur> <NA> <NA> <speaker> <NA> <NA>
                out.write(f"SPEAKER {recording} 1 {begin:.3f} {finish - begin:.3f} <NA> <NA> {who} <NA> <NA>\n")

            except Exception as e:
                print(f"Warning: Could not process segment {idx} for RTTM: {e}")
                print(f"Segment: {entry}")

    print(f"RTTM results saved to: {output_file}")
    print(f"Found {len(seen_speakers)} speakers: {', '.join(sorted(seen_speakers))}")
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Speaker diarization using NVIDIA Sortformer model (local model only)",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Basic diarization with JSON output
|
||||
python sortformer_diarize.py samples/sample.wav output.json
|
||||
|
||||
# Generate RTTM format output
|
||||
python sortformer_diarize.py samples/sample.wav output.rttm
|
||||
|
||||
# Specify device and batch size
|
||||
python sortformer_diarize.py --device cuda --batch-size 2 samples/sample.wav output.json
|
||||
|
||||
Note: This script requires diar_streaming_sortformer_4spk-v2.nemo to be in the same directory.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument("audio_file", help="Path to input audio file (WAV, FLAC, etc.)")
|
||||
parser.add_argument("output_file", help="Path to output file (.json for JSON format, .rttm for RTTM format)")
|
||||
parser.add_argument("--batch-size", type=int, default=1, help="Batch size for processing (default: 1)")
|
||||
parser.add_argument("--device", choices=["cuda", "cpu", "auto"], default="auto", help="Device to use for inference (default: auto-detect)")
|
||||
parser.add_argument("--max-speakers", type=int, default=4, help="Maximum number of speakers (default: 4, optimized for this model)")
|
||||
parser.add_argument("--output-format", choices=["json", "rttm"], help="Output format (auto-detected from file extension if not specified)")
|
||||
parser.add_argument("--streaming", action="store_true", help="Enable streaming mode")
|
||||
parser.add_argument("--chunk-length-s", type=float, default=30.0, help="Chunk length in seconds for streaming mode (default: 30.0)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate inputs
|
||||
if not os.path.exists(args.audio_file):
|
||||
print(f"Error: Audio file not found: {args.audio_file}")
|
||||
sys.exit(1)
|
||||
|
||||
# Auto-detect output format from file extension if not specified
|
||||
if args.output_format is None:
|
||||
if args.output_file.lower().endswith('.rttm'):
|
||||
output_format = "rttm"
|
||||
else:
|
||||
output_format = "json"
|
||||
else:
|
||||
output_format = args.output_format
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
output_dir = Path(args.output_file).parent
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
device = None if args.device == "auto" else args.device
|
||||
|
||||
# Run diarization
|
||||
diarize_audio(
|
||||
audio_path=args.audio_file,
|
||||
output_file=args.output_file,
|
||||
batch_size=args.batch_size,
|
||||
device=device,
|
||||
max_speakers=args.max_speakers,
|
||||
output_format=output_format,
|
||||
streaming_mode=args.streaming,
|
||||
chunk_length_s=args.chunk_length_s,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
266
internal/transcription/adapters/py/pyannote/pyannote_diarize.py
Normal file
266
internal/transcription/adapters/py/pyannote/pyannote_diarize.py
Normal file
@@ -0,0 +1,266 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
PyAnnote speaker diarization script.
|
||||
Processes audio files to identify and separate different speakers.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pyannote.audio import Pipeline
|
||||
import torch
|
||||
|
||||
# Fix for PyTorch 2.6+ which defaults weights_only=True
|
||||
# We need to allowlist PyAnnote's custom classes
|
||||
try:
|
||||
from pyannote.audio.core.task import Specifications, Problem, Resolution
|
||||
if hasattr(torch.serialization, "add_safe_globals"):
|
||||
torch.serialization.add_safe_globals([Specifications, Problem, Resolution])
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not add safe globals: {e}")
|
||||
|
||||
|
||||
def diarize_audio(
|
||||
audio_path: str,
|
||||
output_file: str,
|
||||
hf_token: str,
|
||||
model: str = "pyannote/speaker-diarization-community-1",
|
||||
min_speakers: int = None,
|
||||
max_speakers: int = None,
|
||||
output_format: str = "rttm",
|
||||
|
||||
device: str = "auto"
|
||||
):
|
||||
"""
|
||||
Perform speaker diarization on audio file using PyAnnote.
|
||||
"""
|
||||
print(f"Loading PyAnnote speaker diarization pipeline: {model}")
|
||||
|
||||
try:
|
||||
# Initialize the diarization pipeline
|
||||
pipeline = Pipeline.from_pretrained(
|
||||
model,
|
||||
token=hf_token
|
||||
)
|
||||
|
||||
# Move to specified device
|
||||
# if device == "auto" or device == "cuda":
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
pipeline = pipeline.to(torch.device("cuda"))
|
||||
print("Using CUDA for diarization")
|
||||
elif device == "cuda":
|
||||
print("CUDA requested but not available, falling back to CPU")
|
||||
else:
|
||||
print("CUDA not available, using CPU")
|
||||
except ImportError:
|
||||
print("PyTorch not available for CUDA, using CPU")
|
||||
except Exception as e:
|
||||
print(f"Error moving to device: {e}, using CPU")
|
||||
|
||||
print("Pipeline loaded successfully")
|
||||
except Exception as e:
|
||||
print(f"Error loading pipeline: {e}")
|
||||
print("Make sure you have a valid Hugging Face token and have accepted the model's license")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Processing audio file: {audio_path}")
|
||||
|
||||
try:
|
||||
# Run diarization
|
||||
diarization_params = {}
|
||||
if min_speakers is not None:
|
||||
diarization_params["min_speakers"] = min_speakers
|
||||
if max_speakers is not None:
|
||||
diarization_params["max_speakers"] = max_speakers
|
||||
|
||||
if diarization_params:
|
||||
print(f"Using speaker constraints: {diarization_params}")
|
||||
diarization = pipeline(audio_path, **diarization_params)
|
||||
else:
|
||||
print("Using automatic speaker detection")
|
||||
diarization = pipeline(audio_path)
|
||||
|
||||
print(f"Diarization completed. Saving results to: {output_file}")
|
||||
|
||||
if output_format == "rttm":
|
||||
# Save the diarization output to RTTM format
|
||||
with open(output_file, "w") as rttm:
|
||||
diarization.write_rttm(rttm)
|
||||
else:
|
||||
# Save as JSON format
|
||||
save_json_format(diarization, output_file, audio_path)
|
||||
|
||||
# Print summary
|
||||
speakers = set()
|
||||
total_speech_time = 0.0
|
||||
|
||||
# Iterate over speaker diarization
|
||||
# PyAnnote 4.x returns a DiarizeOutput object with a speaker_diarization attribute
|
||||
if hasattr(diarization, "speaker_diarization"):
|
||||
for turn, speaker in diarization.speaker_diarization:
|
||||
speakers.add(speaker)
|
||||
total_speech_time += turn.duration
|
||||
elif hasattr(diarization, "itertracks"):
|
||||
# Fallback for older versions
|
||||
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
||||
speakers.add(speaker)
|
||||
total_speech_time += segment.duration
|
||||
else:
|
||||
# Try iterating directly (some versions return Annotation directly)
|
||||
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
||||
speakers.add(speaker)
|
||||
total_speech_time += segment.duration
|
||||
|
||||
print(f"\nDiarization Summary:")
|
||||
print(f" Speakers detected: {len(speakers)}")
|
||||
print(f" Speaker labels: {sorted(speakers)}")
|
||||
print(f" Total speech time: {total_speech_time:.2f} seconds")
|
||||
print(f" Output file saved: {output_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during diarization: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def save_json_format(diarization, output_file: str, audio_path: str):
|
||||
"""Save diarization results in JSON format."""
|
||||
segments = []
|
||||
speakers = set()
|
||||
|
||||
# PyAnnote 4.x
|
||||
if hasattr(diarization, "speaker_diarization"):
|
||||
for turn, speaker in diarization.speaker_diarization:
|
||||
segments.append({
|
||||
"start": turn.start,
|
||||
"end": turn.end,
|
||||
"speaker": speaker,
|
||||
"confidence": 1.0,
|
||||
"duration": turn.duration
|
||||
})
|
||||
speakers.add(speaker)
|
||||
# Older versions
|
||||
elif hasattr(diarization, "itertracks"):
|
||||
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
||||
segments.append({
|
||||
"start": segment.start,
|
||||
"end": segment.end,
|
||||
"speaker": speaker,
|
||||
"confidence": 1.0,
|
||||
"duration": segment.duration
|
||||
})
|
||||
speakers.add(speaker)
|
||||
|
||||
# Sort segments by start time
|
||||
segments.sort(key=lambda x: x["start"])
|
||||
|
||||
results = {
|
||||
"audio_file": audio_path,
|
||||
"model": "pyannote/speaker-diarization-community-1",
|
||||
"segments": segments,
|
||||
"speakers": sorted(speakers),
|
||||
"speaker_count": len(speakers),
|
||||
"total_duration": max(seg["end"] for seg in segments) if segments else 0,
|
||||
"processing_info": {
|
||||
"total_segments": len(segments),
|
||||
"total_speech_time": sum(seg["duration"] for seg in segments)
|
||||
}
|
||||
}
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Perform speaker diarization using PyAnnote.audio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"audio_file",
|
||||
help="Path to audio file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", "-o",
|
||||
required=True,
|
||||
help="Output file path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-token",
|
||||
required=True,
|
||||
help="Hugging Face access token"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="pyannote/speaker-diarization-community-1",
|
||||
help="PyAnnote model to use"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-speakers",
|
||||
type=int,
|
||||
help="Minimum number of speakers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-speakers",
|
||||
type=int,
|
||||
help="Maximum number of speakers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-format",
|
||||
choices=["rttm", "json"],
|
||||
default="rttm",
|
||||
help="Output format"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
choices=["cpu", "cuda", "auto"],
|
||||
default="auto",
|
||||
help="Device to use for computation"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate input file
|
||||
if not os.path.exists(args.audio_file):
|
||||
print(f"Error: Audio file not found: {args.audio_file}")
|
||||
sys.exit(1)
|
||||
|
||||
# Validate speaker constraints
|
||||
if args.min_speakers is not None and args.min_speakers < 1:
|
||||
print("Error: min_speakers must be at least 1")
|
||||
sys.exit(1)
|
||||
|
||||
if args.max_speakers is not None and args.max_speakers < 1:
|
||||
print("Error: max_speakers must be at least 1")
|
||||
sys.exit(1)
|
||||
|
||||
if (args.min_speakers is not None and args.max_speakers is not None and
|
||||
args.min_speakers > args.max_speakers):
|
||||
print("Error: min_speakers cannot be greater than max_speakers")
|
||||
sys.exit(1)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
output_path = Path(args.output)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
diarize_audio(
|
||||
audio_path=args.audio_file,
|
||||
output_file=args.output,
|
||||
hf_token=args.hf_token,
|
||||
model=args.model,
|
||||
min_speakers=args.min_speakers,
|
||||
max_speakers=args.max_speakers,
|
||||
output_format=args.output_format,
|
||||
device=args.device
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during diarization: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
39
internal/transcription/adapters/py/pyannote/pyproject.toml
Normal file
39
internal/transcription/adapters/py/pyannote/pyproject.toml
Normal file
@@ -0,0 +1,39 @@
|
||||
[project]
|
||||
name = "pyannote-diarization"
|
||||
version = "0.1.0"
|
||||
description = "Audio diarization using PyAnnote"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch>=2.5.0",
|
||||
"torchaudio>=2.5.0",
|
||||
"huggingface-hub>=0.28.1",
|
||||
"pyannote.audio==4.0.2"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-mock>=3.12.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch"
|
||||
url = "https://download.pytorch.org/whl/cu126"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
@@ -2,6 +2,7 @@ package adapters
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -15,6 +16,9 @@ import (
|
||||
"scriberr/pkg/logger"
|
||||
)
|
||||
|
||||
//go:embed py/pyannote/*
|
||||
var pyannoteScripts embed.FS
|
||||
|
||||
const OutputFormatJSON = "json"
|
||||
|
||||
// PyAnnoteAdapter implements the DiarizationAdapter interface for PyAnnote
|
||||
@@ -182,7 +186,7 @@ func (p *PyAnnoteAdapter) PrepareEnvironment(ctx context.Context) error {
|
||||
if CheckEnvironmentReady(p.envPath, "from pyannote.audio import Pipeline") {
|
||||
logger.Info("PyAnnote already available in environment")
|
||||
// Still ensure script exists
|
||||
if err := p.createDiarizationScript(); err != nil {
|
||||
if err := p.copyDiarizationScript(); err != nil {
|
||||
return fmt.Errorf("failed to create diarization script: %w", err)
|
||||
}
|
||||
p.initialized = true
|
||||
@@ -195,7 +199,7 @@ func (p *PyAnnoteAdapter) PrepareEnvironment(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Always ensure diarization script exists
|
||||
if err := p.createDiarizationScript(); err != nil {
|
||||
if err := p.copyDiarizationScript(); err != nil {
|
||||
return fmt.Errorf("failed to create diarization script: %w", err)
|
||||
}
|
||||
|
||||
@@ -216,45 +220,23 @@ func (p *PyAnnoteAdapter) setupPyAnnoteEnvironment() error {
|
||||
return fmt.Errorf("failed to create pyannote directory: %w", err)
|
||||
}
|
||||
|
||||
// Create pyproject.toml with configurable PyTorch CUDA version
|
||||
// Note: We explicitly pin torch and torchaudio to 2.1.2 to ensure compatibility with pyannote.audio 3.1
|
||||
// Newer versions of torchaudio (2.2+) removed AudioMetaData which causes crashes
|
||||
pyprojectContent := fmt.Sprintf(`[project]
|
||||
name = "pyannote-diarization"
|
||||
version = "0.1.0"
|
||||
description = "Audio diarization using PyAnnote"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch>=2.5.0",
|
||||
"torchaudio>=2.5.0",
|
||||
"huggingface-hub>=0.28.1",
|
||||
"pyannote.audio==4.0.2"
|
||||
]
|
||||
// Read pyproject.toml for PyAnnote
|
||||
pyprojectContent, err := pyannoteScripts.ReadFile("py/pyannote/pyproject.toml")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
|
||||
}
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
|
||||
// The static file contains the default cu126 URL
|
||||
contentStr := strings.Replace(
|
||||
string(pyprojectContent),
|
||||
"https://download.pytorch.org/whl/cu126",
|
||||
GetPyTorchWheelURL(),
|
||||
1,
|
||||
)
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch"
|
||||
url = "%s"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
`, GetPyTorchWheelURL())
|
||||
pyprojectPath := filepath.Join(p.envPath, "pyproject.toml")
|
||||
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
|
||||
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
||||
}
|
||||
|
||||
@@ -270,289 +252,20 @@ explicit = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDiarizationScript creates the Python script for PyAnnote diarization
|
||||
func (p *PyAnnoteAdapter) createDiarizationScript() error {
|
||||
// copyDiarizationScript creates the Python script for PyAnnote diarization
|
||||
func (p *PyAnnoteAdapter) copyDiarizationScript() error {
|
||||
// Ensure the directory exists first
|
||||
if err := os.MkdirAll(p.envPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create pyannote directory: %w", err)
|
||||
}
|
||||
|
||||
scriptContent, err := pyannoteScripts.ReadFile("py/pyannote/pyannote_diarize.py")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read embedded pyannote_diarize.py: %w", err)
|
||||
}
|
||||
|
||||
scriptPath := filepath.Join(p.envPath, "pyannote_diarize.py")
|
||||
|
||||
// Always recreate the script to ensure it's up to date with the adapter code
|
||||
// if _, err := os.Stat(scriptPath); err == nil {
|
||||
// return nil
|
||||
// }
|
||||
|
||||
scriptContent := `#!/usr/bin/env python3
|
||||
"""
|
||||
PyAnnote speaker diarization script.
|
||||
Processes audio files to identify and separate different speakers.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pyannote.audio import Pipeline
|
||||
import torch
|
||||
|
||||
# Fix for PyTorch 2.6+ which defaults weights_only=True
|
||||
# We need to allowlist PyAnnote's custom classes
|
||||
try:
|
||||
from pyannote.audio.core.task import Specifications, Problem, Resolution
|
||||
if hasattr(torch.serialization, "add_safe_globals"):
|
||||
torch.serialization.add_safe_globals([Specifications, Problem, Resolution])
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not add safe globals: {e}")
|
||||
|
||||
|
||||
def diarize_audio(
|
||||
audio_path: str,
|
||||
output_file: str,
|
||||
hf_token: str,
|
||||
model: str = "pyannote/speaker-diarization-community-1",
|
||||
min_speakers: int = None,
|
||||
max_speakers: int = None,
|
||||
output_format: str = "rttm",
|
||||
|
||||
device: str = "auto"
|
||||
):
|
||||
"""
|
||||
Perform speaker diarization on audio file using PyAnnote.
|
||||
"""
|
||||
print(f"Loading PyAnnote speaker diarization pipeline: {model}")
|
||||
|
||||
try:
|
||||
# Initialize the diarization pipeline
|
||||
pipeline = Pipeline.from_pretrained(
|
||||
model,
|
||||
token=hf_token
|
||||
)
|
||||
|
||||
# Move to specified device
|
||||
# if device == "auto" or device == "cuda":
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
pipeline = pipeline.to(torch.device("cuda"))
|
||||
print("Using CUDA for diarization")
|
||||
elif device == "cuda":
|
||||
print("CUDA requested but not available, falling back to CPU")
|
||||
else:
|
||||
print("CUDA not available, using CPU")
|
||||
except ImportError:
|
||||
print("PyTorch not available for CUDA, using CPU")
|
||||
except Exception as e:
|
||||
print(f"Error moving to device: {e}, using CPU")
|
||||
|
||||
print("Pipeline loaded successfully")
|
||||
except Exception as e:
|
||||
print(f"Error loading pipeline: {e}")
|
||||
print("Make sure you have a valid Hugging Face token and have accepted the model's license")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Processing audio file: {audio_path}")
|
||||
|
||||
try:
|
||||
# Run diarization
|
||||
diarization_params = {}
|
||||
if min_speakers is not None:
|
||||
diarization_params["min_speakers"] = min_speakers
|
||||
if max_speakers is not None:
|
||||
diarization_params["max_speakers"] = max_speakers
|
||||
|
||||
if diarization_params:
|
||||
print(f"Using speaker constraints: {diarization_params}")
|
||||
diarization = pipeline(audio_path, **diarization_params)
|
||||
else:
|
||||
print("Using automatic speaker detection")
|
||||
diarization = pipeline(audio_path)
|
||||
|
||||
print(f"Diarization completed. Saving results to: {output_file}")
|
||||
|
||||
if output_format == "rttm":
|
||||
# Save the diarization output to RTTM format
|
||||
with open(output_file, "w") as rttm:
|
||||
diarization.write_rttm(rttm)
|
||||
else:
|
||||
# Save as JSON format
|
||||
save_json_format(diarization, output_file, audio_path)
|
||||
|
||||
# Print summary
|
||||
speakers = set()
|
||||
total_speech_time = 0.0
|
||||
|
||||
# Iterate over speaker diarization
|
||||
# PyAnnote 4.x returns a DiarizeOutput object with a speaker_diarization attribute
|
||||
if hasattr(diarization, "speaker_diarization"):
|
||||
for turn, speaker in diarization.speaker_diarization:
|
||||
speakers.add(speaker)
|
||||
total_speech_time += turn.duration
|
||||
elif hasattr(diarization, "itertracks"):
|
||||
# Fallback for older versions
|
||||
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
||||
speakers.add(speaker)
|
||||
total_speech_time += segment.duration
|
||||
else:
|
||||
# Try iterating directly (some versions return Annotation directly)
|
||||
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
||||
speakers.add(speaker)
|
||||
total_speech_time += segment.duration
|
||||
|
||||
print(f"\nDiarization Summary:")
|
||||
print(f" Speakers detected: {len(speakers)}")
|
||||
print(f" Speaker labels: {sorted(speakers)}")
|
||||
print(f" Total speech time: {total_speech_time:.2f} seconds")
|
||||
print(f" Output file saved: {output_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during diarization: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def save_json_format(diarization, output_file: str, audio_path: str):
|
||||
"""Save diarization results in JSON format."""
|
||||
segments = []
|
||||
speakers = set()
|
||||
|
||||
# PyAnnote 4.x
|
||||
if hasattr(diarization, "speaker_diarization"):
|
||||
for turn, speaker in diarization.speaker_diarization:
|
||||
segments.append({
|
||||
"start": turn.start,
|
||||
"end": turn.end,
|
||||
"speaker": speaker,
|
||||
"confidence": 1.0,
|
||||
"duration": turn.duration
|
||||
})
|
||||
speakers.add(speaker)
|
||||
# Older versions
|
||||
elif hasattr(diarization, "itertracks"):
|
||||
for segment, track, speaker in diarization.itertracks(yield_label=True):
|
||||
segments.append({
|
||||
"start": segment.start,
|
||||
"end": segment.end,
|
||||
"speaker": speaker,
|
||||
"confidence": 1.0,
|
||||
"duration": segment.duration
|
||||
})
|
||||
speakers.add(speaker)
|
||||
|
||||
# Sort segments by start time
|
||||
segments.sort(key=lambda x: x["start"])
|
||||
|
||||
results = {
|
||||
"audio_file": audio_path,
|
||||
"model": "pyannote/speaker-diarization-community-1",
|
||||
"segments": segments,
|
||||
"speakers": sorted(speakers),
|
||||
"speaker_count": len(speakers),
|
||||
"total_duration": max(seg["end"] for seg in segments) if segments else 0,
|
||||
"processing_info": {
|
||||
"total_segments": len(segments),
|
||||
"total_speech_time": sum(seg["duration"] for seg in segments)
|
||||
}
|
||||
}
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Perform speaker diarization using PyAnnote.audio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"audio_file",
|
||||
help="Path to audio file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", "-o",
|
||||
required=True,
|
||||
help="Output file path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-token",
|
||||
required=True,
|
||||
help="Hugging Face access token"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="pyannote/speaker-diarization-community-1",
|
||||
help="PyAnnote model to use"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-speakers",
|
||||
type=int,
|
||||
help="Minimum number of speakers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-speakers",
|
||||
type=int,
|
||||
help="Maximum number of speakers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-format",
|
||||
choices=["rttm", "json"],
|
||||
default="rttm",
|
||||
help="Output format"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
choices=["cpu", "cuda", "auto"],
|
||||
default="auto",
|
||||
help="Device to use for computation"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate input file
|
||||
if not os.path.exists(args.audio_file):
|
||||
print(f"Error: Audio file not found: {args.audio_file}")
|
||||
sys.exit(1)
|
||||
|
||||
# Validate speaker constraints
|
||||
if args.min_speakers is not None and args.min_speakers < 1:
|
||||
print("Error: min_speakers must be at least 1")
|
||||
sys.exit(1)
|
||||
|
||||
if args.max_speakers is not None and args.max_speakers < 1:
|
||||
print("Error: max_speakers must be at least 1")
|
||||
sys.exit(1)
|
||||
|
||||
if (args.min_speakers is not None and args.max_speakers is not None and
|
||||
args.min_speakers > args.max_speakers):
|
||||
print("Error: min_speakers cannot be greater than max_speakers")
|
||||
sys.exit(1)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
output_path = Path(args.output)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
diarize_audio(
|
||||
audio_path=args.audio_file,
|
||||
output_file=args.output,
|
||||
hf_token=args.hf_token,
|
||||
model=args.model,
|
||||
min_speakers=args.min_speakers,
|
||||
max_speakers=args.max_speakers,
|
||||
output_format=args.output_format,
|
||||
device=args.device
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during diarization: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
`
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
||||
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
|
||||
return fmt.Errorf("failed to write diarization script: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -180,7 +180,7 @@ func (s *SortformerAdapter) PrepareEnvironment(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Create diarization script
|
||||
if err := s.createDiarizationScript(); err != nil {
|
||||
if err := s.copyDiarizationScript(); err != nil {
|
||||
return fmt.Errorf("failed to create diarization script: %w", err)
|
||||
}
|
||||
|
||||
@@ -195,51 +195,23 @@ func (s *SortformerAdapter) setupSortformerEnvironment() error {
|
||||
return fmt.Errorf("failed to create sortformer directory: %w", err)
|
||||
}
|
||||
|
||||
// Create pyproject.toml with configurable PyTorch CUDA version
|
||||
pyprojectContent := fmt.Sprintf(`[project]
|
||||
name = "parakeet-transcription"
|
||||
version = "0.1.0"
|
||||
description = "Audio transcription using NVIDIA Parakeet models"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"nemo-toolkit[asr]",
|
||||
"torch",
|
||||
"torchaudio",
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"ml-dtypes>=0.3.1,<0.5.0",
|
||||
"onnx>=1.15.0,<1.18.0",
|
||||
"pyannote.audio"
|
||||
]
|
||||
// Read pyproject.toml
|
||||
pyprojectContent, err := nvidiaScripts.ReadFile("py/nvidia/pyproject.toml")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
|
||||
}
|
||||
|
||||
[tool.uv.sources]
|
||||
nemo-toolkit = { git = "https://github.com/NVIDIA/NeMo.git", tag = "v2.5.3" }
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
|
||||
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
|
||||
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
]
|
||||
triton = [
|
||||
{ index = "pytorch", marker = "sys_platform == 'linux'" }
|
||||
]
|
||||
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
|
||||
// The static file contains the default cu126 URL
|
||||
contentStr := strings.Replace(
|
||||
string(pyprojectContent),
|
||||
"https://download.pytorch.org/whl/cu126",
|
||||
GetPyTorchWheelURL(),
|
||||
1,
|
||||
)
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch"
|
||||
url = "%s"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
`, GetPyTorchWheelURL())
|
||||
pyprojectPath := filepath.Join(s.envPath, "pyproject.toml")
|
||||
if err := os.WriteFile(pyprojectPath, []byte(pyprojectContent), 0644); err != nil {
|
||||
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write pyproject.toml: %w", err)
|
||||
}
|
||||
|
||||
@@ -289,345 +261,15 @@ func (s *SortformerAdapter) downloadSortformerModel() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDiarizationScript creates the Python script for Sortformer diarization
|
||||
func (s *SortformerAdapter) createDiarizationScript() error {
|
||||
scriptPath := filepath.Join(s.envPath, "sortformer_diarize.py")
|
||||
|
||||
// Check if script already exists
|
||||
if _, err := os.Stat(scriptPath); err == nil {
|
||||
return nil
|
||||
// copyDiarizationScript creates the Python script for Sortformer diarization
|
||||
func (s *SortformerAdapter) copyDiarizationScript() error {
|
||||
scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/sortformer_diarize.py")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read embedded sortformer_diarize.py: %w", err)
|
||||
}
|
||||
|
||||
scriptContent := `#!/usr/bin/env python3
|
||||
"""
|
||||
NVIDIA Sortformer speaker diarization script.
|
||||
Uses diar_streaming_sortformer_4spk-v2 for optimized 4-speaker diarization.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
try:
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||
except ImportError:
|
||||
print("Error: NeMo not found. Please install nemo_toolkit[asr]")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def diarize_audio(
|
||||
audio_path: str,
|
||||
output_file: str,
|
||||
batch_size: int = 1,
|
||||
device: str = None,
|
||||
max_speakers: int = 4,
|
||||
output_format: str = "rttm",
|
||||
streaming_mode: bool = False,
|
||||
chunk_length_s: float = 30.0,
|
||||
):
|
||||
"""
|
||||
Perform speaker diarization using NVIDIA's Sortformer model.
|
||||
"""
|
||||
if device is None or device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
print(f"Using device: {device}")
|
||||
print(f"Loading NVIDIA Sortformer diarization model...")
|
||||
|
||||
# Get the directory where this script is located
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(script_dir, "diar_streaming_sortformer_4spk-v2.nemo")
|
||||
|
||||
try:
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Error: Model file not found: {model_path}")
|
||||
print("Please ensure diar_streaming_sortformer_4spk-v2.nemo is in the same directory as this script")
|
||||
sys.exit(1)
|
||||
|
||||
# Load from local file
|
||||
print(f"Loading model from local path: {model_path}")
|
||||
diar_model = SortformerEncLabelModel.restore_from(
|
||||
restore_path=model_path,
|
||||
map_location=device,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
# Switch to inference mode
|
||||
diar_model.eval()
|
||||
print("Model loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading model: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Processing audio file: {audio_path}")
|
||||
|
||||
# Verify audio file exists
|
||||
if not os.path.exists(audio_path):
|
||||
print(f"Error: Audio file not found: {audio_path}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
# Run diarization
|
||||
print(f"Running diarization with batch_size={batch_size}, max_speakers={max_speakers}")
|
||||
|
||||
if streaming_mode:
|
||||
print(f"Using streaming mode with chunk_length_s={chunk_length_s}")
|
||||
# Note: Streaming mode implementation would go here
|
||||
# For now, use standard diarization
|
||||
predicted_segments = diar_model.diarize(audio=audio_path, batch_size=batch_size)
|
||||
else:
|
||||
predicted_segments = diar_model.diarize(audio=audio_path, batch_size=batch_size)
|
||||
|
||||
print(f"Diarization completed. Found segments: {len(predicted_segments)}")
|
||||
|
||||
# Process and save results
|
||||
save_results(predicted_segments, output_file, audio_path, output_format)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during diarization: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def save_results(segments, output_file: str, audio_path: str, output_format: str):
|
||||
"""
|
||||
Save diarization results to output file.
|
||||
Supports both JSON and RTTM formats based on output_format parameter.
|
||||
"""
|
||||
output_path = Path(output_file)
|
||||
|
||||
if output_format == "rttm":
|
||||
save_rttm_format(segments, output_file, audio_path)
|
||||
else:
|
||||
save_json_format(segments, output_file, audio_path)
|
||||
|
||||
|
||||
def save_json_format(segments, output_file: str, audio_path: str):
|
||||
"""Save results in JSON format."""
|
||||
results = {
|
||||
"audio_file": audio_path,
|
||||
"model": "nvidia/diar_streaming_sortformer_4spk-v2",
|
||||
"segments": [],
|
||||
}
|
||||
|
||||
# Handle the case where segments is a list containing a single list of string entries
|
||||
if len(segments) == 1 and isinstance(segments[0], list):
|
||||
segments = segments[0]
|
||||
|
||||
# Convert segments to JSON format
|
||||
speakers = set()
|
||||
for i, segment in enumerate(segments):
|
||||
try:
|
||||
# Handle different possible segment formats
|
||||
if isinstance(segment, str):
|
||||
# String format: "start end speaker_id"
|
||||
parts = segment.strip().split()
|
||||
if len(parts) >= 3:
|
||||
segment_data = {
|
||||
"start": float(parts[0]),
|
||||
"end": float(parts[1]),
|
||||
"speaker": str(parts[2]),
|
||||
"duration": float(parts[1]) - float(parts[0]),
|
||||
"confidence": 1.0,
|
||||
}
|
||||
else:
|
||||
print(f"Warning: Invalid string segment format: {segment}")
|
||||
continue
|
||||
elif hasattr(segment, 'start') and hasattr(segment, 'end') and hasattr(segment, 'label'):
|
||||
# Standard pyannote-like format
|
||||
segment_data = {
|
||||
"start": float(segment.start),
|
||||
"end": float(segment.end),
|
||||
"speaker": str(segment.label),
|
||||
"duration": float(segment.end - segment.start),
|
||||
"confidence": getattr(segment, 'confidence', 1.0),
|
||||
}
|
||||
elif isinstance(segment, (list, tuple)) and len(segment) >= 3:
|
||||
# List/tuple format: [start, end, speaker]
|
||||
segment_data = {
|
||||
"start": float(segment[0]),
|
||||
"end": float(segment[1]),
|
||||
"speaker": str(segment[2]),
|
||||
"duration": float(segment[1] - segment[0]),
|
||||
"confidence": 1.0,
|
||||
}
|
||||
elif isinstance(segment, dict):
|
||||
# Dictionary format
|
||||
segment_data = {
|
||||
"start": float(segment.get('start', 0)),
|
||||
"end": float(segment.get('end', 0)),
|
||||
"speaker": str(segment.get('speaker', segment.get('label', f'speaker_{i}'))),
|
||||
"duration": float(segment.get('end', 0) - segment.get('start', 0)),
|
||||
"confidence": float(segment.get('confidence', 1.0)),
|
||||
}
|
||||
else:
|
||||
# Fallback: try to extract attributes dynamically
|
||||
segment_data = {
|
||||
"start": float(getattr(segment, 'start', 0)),
|
||||
"end": float(getattr(segment, 'end', 0)),
|
||||
"speaker": str(getattr(segment, 'label', getattr(segment, 'speaker', f'speaker_{i}'))),
|
||||
"duration": float(getattr(segment, 'end', 0) - getattr(segment, 'start', 0)),
|
||||
"confidence": float(getattr(segment, 'confidence', 1.0)),
|
||||
}
|
||||
|
||||
results["segments"].append(segment_data)
|
||||
speakers.add(segment_data["speaker"])
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not process segment {i}: {e}")
|
||||
print(f"Segment: {segment}")
|
||||
|
||||
# Sort by start time
|
||||
if results["segments"]:
|
||||
results["segments"].sort(key=lambda x: x["start"])
|
||||
|
||||
# Add summary statistics
|
||||
results["speakers"] = sorted(speakers)
|
||||
results["speaker_count"] = len(speakers)
|
||||
results["total_segments"] = len(results["segments"])
|
||||
results["total_duration"] = max(seg["end"] for seg in results["segments"]) if results["segments"] else 0
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
print(f"Results saved to: {output_file}")
|
||||
print(f"Found {len(speakers)} speakers: {', '.join(sorted(speakers))}")
|
||||
|
||||
|
||||
def save_rttm_format(segments, output_file: str, audio_path: str):
|
||||
"""Save results in RTTM (Rich Transcription Time Marked) format."""
|
||||
audio_filename = Path(audio_path).stem
|
||||
speakers = set()
|
||||
|
||||
# Handle the case where segments is a list containing a single list of string entries
|
||||
if len(segments) == 1 and isinstance(segments[0], list):
|
||||
segments = segments[0]
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
for i, segment in enumerate(segments):
|
||||
try:
|
||||
# Handle different possible segment formats
|
||||
if isinstance(segment, str):
|
||||
# String format: "start end speaker_id"
|
||||
parts = segment.strip().split()
|
||||
if len(parts) >= 3:
|
||||
start = float(parts[0])
|
||||
end = float(parts[1])
|
||||
speaker = str(parts[2])
|
||||
else:
|
||||
print(f"Warning: Invalid string segment format: {segment}")
|
||||
continue
|
||||
elif hasattr(segment, 'start') and hasattr(segment, 'end') and hasattr(segment, 'label'):
|
||||
# Standard pyannote-like format
|
||||
start = float(segment.start)
|
||||
end = float(segment.end)
|
||||
speaker = str(segment.label)
|
||||
elif isinstance(segment, (list, tuple)) and len(segment) >= 3:
|
||||
# List/tuple format: [start, end, speaker]
|
||||
start = float(segment[0])
|
||||
end = float(segment[1])
|
||||
speaker = str(segment[2])
|
||||
elif isinstance(segment, dict):
|
||||
# Dictionary format
|
||||
start = float(segment.get('start', 0))
|
||||
end = float(segment.get('end', 0))
|
||||
speaker = str(segment.get('speaker', segment.get('label', f'speaker_{i}')))
|
||||
else:
|
||||
# Fallback: try to extract attributes dynamically
|
||||
start = float(getattr(segment, 'start', 0))
|
||||
end = float(getattr(segment, 'end', 0))
|
||||
speaker = str(getattr(segment, 'label', getattr(segment, 'speaker', f'speaker_{i}')))
|
||||
|
||||
duration = end - start
|
||||
speakers.add(speaker)
|
||||
|
||||
# RTTM format: SPEAKER <filename> <channel> <start> <duration> <NA> <NA> <speaker_id> <NA> <NA>
|
||||
line = f"SPEAKER {audio_filename} 1 {start:.3f} {duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n"
|
||||
f.write(line)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not process segment {i} for RTTM: {e}")
|
||||
print(f"Segment: {segment}")
|
||||
|
||||
print(f"RTTM results saved to: {output_file}")
|
||||
print(f"Found {len(speakers)} speakers: {', '.join(sorted(speakers))}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Speaker diarization using NVIDIA Sortformer model (local model only)",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Basic diarization with JSON output
|
||||
python sortformer_diarize.py samples/sample.wav output.json
|
||||
|
||||
# Generate RTTM format output
|
||||
python sortformer_diarize.py samples/sample.wav output.rttm
|
||||
|
||||
# Specify device and batch size
|
||||
python sortformer_diarize.py --device cuda --batch-size 2 samples/sample.wav output.json
|
||||
|
||||
Note: This script requires diar_streaming_sortformer_4spk-v2.nemo to be in the same directory.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument("audio_file", help="Path to input audio file (WAV, FLAC, etc.)")
|
||||
parser.add_argument("output_file", help="Path to output file (.json for JSON format, .rttm for RTTM format)")
|
||||
parser.add_argument("--batch-size", type=int, default=1, help="Batch size for processing (default: 1)")
|
||||
parser.add_argument("--device", choices=["cuda", "cpu", "auto"], default="auto", help="Device to use for inference (default: auto-detect)")
|
||||
parser.add_argument("--max-speakers", type=int, default=4, help="Maximum number of speakers (default: 4, optimized for this model)")
|
||||
parser.add_argument("--output-format", choices=["json", "rttm"], help="Output format (auto-detected from file extension if not specified)")
|
||||
parser.add_argument("--streaming", action="store_true", help="Enable streaming mode")
|
||||
parser.add_argument("--chunk-length-s", type=float, default=30.0, help="Chunk length in seconds for streaming mode (default: 30.0)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate inputs
|
||||
if not os.path.exists(args.audio_file):
|
||||
print(f"Error: Audio file not found: {args.audio_file}")
|
||||
sys.exit(1)
|
||||
|
||||
# Auto-detect output format from file extension if not specified
|
||||
if args.output_format is None:
|
||||
if args.output_file.lower().endswith('.rttm'):
|
||||
output_format = "rttm"
|
||||
else:
|
||||
output_format = "json"
|
||||
else:
|
||||
output_format = args.output_format
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
output_dir = Path(args.output_file).parent
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
device = None if args.device == "auto" else args.device
|
||||
|
||||
# Run diarization
|
||||
diarize_audio(
|
||||
audio_path=args.audio_file,
|
||||
output_file=args.output_file,
|
||||
batch_size=args.batch_size,
|
||||
device=device,
|
||||
max_speakers=args.max_speakers,
|
||||
output_format=output_format,
|
||||
streaming_mode=args.streaming,
|
||||
chunk_length_s=args.chunk_length_s,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
`
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
||||
scriptPath := filepath.Join(s.envPath, "sortformer_diarize.py")
|
||||
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
|
||||
return fmt.Errorf("failed to write diarization script: %w", err)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user