feat: add Voxtral-mini transcription support

- Add VoxtralAdapter using transformers library with direct model loading
- Add Python transcription script with apply_transcription_request() method
- Register Voxtral adapter in main.go with dedicated environment
- Add UI configuration in TranscriptionConfigDialog with warning banner
- Support multilingual transcription without word-level timestamps
- Auto GPU/CPU detection, no device parameter needed
- Graceful degradation for missing timestamp features

Voxtral provides high-quality text-only transcription but does not
support word-level timestamps. UI warns users that synchronized
playback and seek features won't be available.
This commit is contained in:
rishikanthc
2025-12-30 12:03:45 -08:00
committed by Rishikanth Chandrasekaran
parent 923b39e415
commit 1ae7b2bf71
5 changed files with 610 additions and 0 deletions

View File

@@ -227,6 +227,9 @@ func registerAdapters(cfg *config.Config) {
// Dedicated environment path for PyAnnote (to avoid dependency conflicts)
pyannoteEnvPath := filepath.Join(cfg.WhisperXEnv, "pyannote")
// Dedicated environment path for Voxtral (Mistral AI model)
voxtralEnvPath := filepath.Join(cfg.WhisperXEnv, "voxtral")
// Register transcription adapters
registry.RegisterTranscriptionAdapter("whisperx",
adapters.NewWhisperXAdapter(cfg.WhisperXEnv))
@@ -234,6 +237,8 @@ func registerAdapters(cfg *config.Config) {
adapters.NewParakeetAdapter(nvidiaEnvPath))
registry.RegisterTranscriptionAdapter("canary",
adapters.NewCanaryAdapter(nvidiaEnvPath)) // Shares with Parakeet
registry.RegisterTranscriptionAdapter("voxtral",
adapters.NewVoxtralAdapter(voxtralEnvPath))
registry.RegisterTranscriptionAdapter("openai_whisper",
adapters.NewOpenAIAdapter(cfg.OpenAIAPIKey))

View File

@@ -0,0 +1,35 @@
[project]
name = "voxtral-transcription"
version = "0.1.0"
description = "Audio transcription using Mistral Voxtral-mini model"
requires-python = ">=3.11"
dependencies = [
"transformers>=4.45.0",
"torch",
"torchaudio",
"accelerate",
"librosa",
"soundfile",
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
torchaudio = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
{ index = "pytorch-cpu", marker = "platform_machine != 'x86_64' and sys_platform != 'darwin'" },
{ index = "pytorch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu126"
explicit = true
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

View File

@@ -0,0 +1,174 @@
#!/usr/bin/env python3
"""
Voxtral-mini transcription script for Scriberr
Transcribes audio using Mistral's Voxtral-mini model
"""
import argparse
import json
import sys
import torch
from pathlib import Path
from transformers import VoxtralForConditionalGeneration, AutoProcessor
def transcribe_audio(
audio_path: str,
output_path: str,
language: str = "en",
model_id: str = "mistralai/Voxtral-mini",
device: str = "auto",
max_new_tokens: int = 500,
) -> dict:
"""
Transcribe audio using Voxtral-mini model.
Args:
audio_path: Path to input audio file
output_path: Path to output JSON file
language: Language code (e.g., 'en', 'es', 'fr')
model_id: HuggingFace model ID
device: Device to use ('cpu', 'cuda', or 'auto')
max_new_tokens: Maximum number of tokens to generate
Returns:
Dictionary containing transcription results
"""
# Determine device
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading Voxtral model on {device}...", file=sys.stderr)
# Load processor and model
processor = AutoProcessor.from_pretrained(model_id)
# Use appropriate dtype based on device
dtype = torch.bfloat16 if device == "cuda" else torch.float32
model = VoxtralForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=dtype,
device_map=device,
)
print(f"Model loaded successfully", file=sys.stderr)
print(f"Processing audio: {audio_path}", file=sys.stderr)
# Prepare transcription request using the proper method
inputs = processor.apply_transcription_request(
language=language,
audio=audio_path,
model_id=model_id
)
# Move inputs to device with correct dtype
inputs = inputs.to(device, dtype=dtype)
print(f"Generating transcription...", file=sys.stderr)
# Generate transcription
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
)
# Decode only the newly generated tokens (skip the input prompt)
decoded_outputs = processor.batch_decode(
outputs[:, inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
transcription_text = decoded_outputs[0]
print(f"Transcription completed ({len(transcription_text)} chars)", file=sys.stderr)
# Prepare output in Scriberr format
# Note: Voxtral doesn't provide word-level timestamps, so we create a single segment
result = {
"text": transcription_text,
"segments": [
{
"id": 0,
"start": 0.0,
"end": 0.0, # Duration unknown without audio analysis
"text": transcription_text,
"words": [] # Voxtral doesn't provide word-level timestamps
}
],
"language": language,
"model": model_id,
"has_word_timestamps": False, # Important: Voxtral doesn't support timestamps
}
# Write output
output_file = Path(output_path)
with output_file.open('w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"Results written to {output_path}", file=sys.stderr)
return result
def main():
parser = argparse.ArgumentParser(
description="Transcribe audio using Voxtral-mini model"
)
parser.add_argument(
"audio_path",
type=str,
help="Path to input audio file"
)
parser.add_argument(
"output_path",
type=str,
help="Path to output JSON file"
)
parser.add_argument(
"--language",
type=str,
default="en",
help="Language code (default: en)"
)
parser.add_argument(
"--model-id",
type=str,
default="mistralai/Voxtral-mini",
help="HuggingFace model ID (default: mistralai/Voxtral-mini)"
)
parser.add_argument(
"--device",
type=str,
default="auto",
choices=["cpu", "cuda", "auto"],
help="Device to use (default: auto)"
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=500,
help="Maximum number of tokens to generate (default: 500)"
)
args = parser.parse_args()
try:
transcribe_audio(
audio_path=args.audio_path,
output_path=args.output_path,
language=args.language,
model_id=args.model_id,
device=args.device,
max_new_tokens=args.max_new_tokens,
)
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
import traceback
traceback.print_exc(file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,339 @@
package adapters
import (
"context"
"embed"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"scriberr/internal/transcription/interfaces"
"scriberr/pkg/logger"
)
//go:embed py/voxtral/*
var voxtralScripts embed.FS
// VoxtralAdapter implements the TranscriptionAdapter interface for Mistral Voxtral-mini
type VoxtralAdapter struct {
*BaseAdapter
envPath string
}
// NewVoxtralAdapter creates a new Voxtral adapter
func NewVoxtralAdapter(envPath string) *VoxtralAdapter {
capabilities := interfaces.ModelCapabilities{
ModelID: "voxtral",
ModelFamily: "mistral_voxtral",
DisplayName: "Mistral Voxtral-mini",
Description: "Mistral's multilingual audio transcription model",
Version: "1.0.0",
SupportedLanguages: []string{
"en", "es", "fr", "de", "it", "pt", "nl", "pl", "ru", "zh", "ja", "ko",
// Voxtral supports many languages
},
SupportedFormats: []string{"wav", "mp3", "flac", "m4a", "ogg"},
RequiresGPU: false, // Can run on CPU but GPU recommended
MemoryRequirement: 4096, // 4GB recommended
Features: map[string]bool{
"timestamps": false, // Voxtral doesn't provide word-level timestamps
"word_level": false,
"multilingual": true,
"high_quality": true,
"fast_inference": true,
"transformers_based": true,
},
Metadata: map[string]string{
"engine": "mistral_ai",
"framework": "transformers",
"license": "Apache-2.0",
"model_id": "mistralai/Voxtral-mini",
"no_word_timestamps": "true", // Important metadata for frontend
},
}
schema := []interfaces.ParameterSchema{
// Language selection
{
Name: "language",
Type: "string",
Required: false,
Default: "en",
Options: []string{"en", "es", "fr", "de", "it", "pt", "nl", "pl", "ru", "zh", "ja", "ko"},
Description: "Language of the audio",
Group: "basic",
},
// Generation settings
{
Name: "max_new_tokens",
Type: "int",
Required: false,
Default: 500,
Min: &[]float64{100}[0],
Max: &[]float64{2000}[0],
Description: "Maximum number of tokens to generate",
Group: "advanced",
},
}
baseAdapter := NewBaseAdapter("voxtral", envPath, capabilities, schema)
adapter := &VoxtralAdapter{
BaseAdapter: baseAdapter,
envPath: envPath,
}
return adapter
}
// GetSupportedModels returns the available Voxtral models
func (v *VoxtralAdapter) GetSupportedModels() []string {
return []string{"mistralai/Voxtral-mini"}
}
// PrepareEnvironment sets up the Voxtral environment
func (v *VoxtralAdapter) PrepareEnvironment(ctx context.Context) error {
logger.Info("Preparing Voxtral environment", "env_path", v.envPath)
// Copy transcription script
if err := v.copyTranscriptionScript(); err != nil {
return fmt.Errorf("failed to copy transcription script: %w", err)
}
// Check if environment is already ready
if CheckEnvironmentReady(v.envPath, "from transformers import VoxtralForConditionalGeneration") {
logger.Info("Voxtral environment already ready")
v.initialized = true
return nil
}
// Setup environment
if err := v.setupVoxtralEnvironment(); err != nil {
return fmt.Errorf("failed to setup Voxtral environment: %w", err)
}
v.initialized = true
logger.Info("Voxtral environment prepared successfully")
return nil
}
// setupVoxtralEnvironment creates the Python environment for Voxtral
func (v *VoxtralAdapter) setupVoxtralEnvironment() error {
if err := os.MkdirAll(v.envPath, 0755); err != nil {
return fmt.Errorf("failed to create voxtral directory: %w", err)
}
// Read pyproject.toml
pyprojectContent, err := voxtralScripts.ReadFile("py/voxtral/pyproject.toml")
if err != nil {
return fmt.Errorf("failed to read embedded pyproject.toml: %w", err)
}
// Replace the hardcoded PyTorch URL with the dynamic one based on environment
contentStr := strings.Replace(
string(pyprojectContent),
"https://download.pytorch.org/whl/cu126",
GetPyTorchWheelURL(),
1,
)
pyprojectPath := filepath.Join(v.envPath, "pyproject.toml")
if err := os.WriteFile(pyprojectPath, []byte(contentStr), 0644); err != nil {
return fmt.Errorf("failed to write pyproject.toml: %w", err)
}
// Run uv sync
logger.Info("Installing Voxtral dependencies")
cmd := exec.Command("uv", "sync", "--native-tls")
cmd.Dir = v.envPath
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("uv sync failed: %w: %s", err, strings.TrimSpace(string(out)))
}
return nil
}
// copyTranscriptionScript creates the Python script for Voxtral transcription
func (v *VoxtralAdapter) copyTranscriptionScript() error {
// Ensure directory exists before writing script
if err := os.MkdirAll(v.envPath, 0755); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
scriptContent, err := voxtralScripts.ReadFile("py/voxtral/voxtral_transcribe.py")
if err != nil {
return fmt.Errorf("failed to read embedded voxtral_transcribe.py: %w", err)
}
scriptPath := filepath.Join(v.envPath, "voxtral_transcribe.py")
if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil {
return fmt.Errorf("failed to write transcription script: %w", err)
}
return nil
}
// Transcribe processes audio using Voxtral
func (v *VoxtralAdapter) Transcribe(ctx context.Context, input interfaces.AudioInput, params map[string]interface{}, procCtx interfaces.ProcessingContext) (*interfaces.TranscriptResult, error) {
startTime := time.Now()
v.LogProcessingStart(input, procCtx)
defer func() {
v.LogProcessingEnd(procCtx, time.Since(startTime), nil)
}()
// Validate input
if err := v.ValidateAudioInput(input); err != nil {
return nil, fmt.Errorf("invalid audio input: %w", err)
}
// Validate parameters
if err := v.ValidateParameters(params); err != nil {
return nil, fmt.Errorf("invalid parameters: %w", err)
}
// Create temporary directory
tempDir, err := v.CreateTempDirectory(procCtx)
if err != nil {
return nil, fmt.Errorf("failed to create temp directory: %w", err)
}
defer v.CleanupTempDirectory(tempDir)
// Build command arguments
args, err := v.buildVoxtralArgs(input, params, tempDir)
if err != nil {
return nil, fmt.Errorf("failed to build command: %w", err)
}
// Execute Voxtral
cmd := exec.CommandContext(ctx, "uv", args...)
cmd.Env = append(os.Environ(), "PYTHONUNBUFFERED=1")
// Setup log file
logFile, err := os.OpenFile(filepath.Join(procCtx.OutputDirectory, "transcription.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
logger.Warn("Failed to create log file", "error", err)
} else {
defer logFile.Close()
cmd.Stdout = logFile
cmd.Stderr = logFile
}
logger.Info("Executing Voxtral command", "args", strings.Join(args, " "))
if err := cmd.Run(); err != nil {
if ctx.Err() == context.Canceled {
return nil, fmt.Errorf("transcription was cancelled")
}
// Read tail of log file for context
logPath := filepath.Join(procCtx.OutputDirectory, "transcription.log")
logTail, readErr := v.ReadLogTail(logPath, 2048)
if readErr != nil {
logger.Warn("Failed to read log tail", "error", readErr)
}
logger.Error("Voxtral execution failed", "error", err)
return nil, fmt.Errorf("Voxtral execution failed: %w\nLogs:\n%s", err, logTail)
}
// Parse result
result, err := v.parseResult(tempDir)
if err != nil {
return nil, fmt.Errorf("failed to parse result: %w", err)
}
result.ProcessingTime = time.Since(startTime)
result.ModelUsed = "mistralai/Voxtral-mini"
logger.Info("Voxtral transcription completed",
"text_length", len(result.Text),
"processing_time", result.ProcessingTime)
return result, nil
}
// buildVoxtralArgs builds the command arguments for Voxtral
func (v *VoxtralAdapter) buildVoxtralArgs(input interfaces.AudioInput, params map[string]interface{}, tempDir string) ([]string, error) {
outputFile := filepath.Join(tempDir, "result.json")
scriptPath := filepath.Join(v.envPath, "voxtral_transcribe.py")
args := []string{
"run", "--native-tls", "--project", v.envPath, "python", scriptPath,
input.FilePath,
outputFile,
}
// Add language
if language := v.GetStringParameter(params, "language"); language != "" {
args = append(args, "--language", language)
}
// Device auto-detection (like Parakeet/Canary) - no device parameter needed
// Python script will auto-detect and use GPU if available
// Add max tokens
if maxTokens := v.GetIntParameter(params, "max_new_tokens"); maxTokens > 0 {
args = append(args, "--max-new-tokens", fmt.Sprintf("%d", maxTokens))
}
return args, nil
}
// parseResult parses the Voxtral output
func (v *VoxtralAdapter) parseResult(tempDir string) (*interfaces.TranscriptResult, error) {
resultFile := filepath.Join(tempDir, "result.json")
data, err := os.ReadFile(resultFile)
if err != nil {
return nil, fmt.Errorf("failed to read result file: %w", err)
}
var voxtralResult struct {
Text string `json:"text"`
Language string `json:"language"`
Model string `json:"model"`
HasWordTimestamps bool `json:"has_word_timestamps"`
Segments []struct {
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
} `json:"segments"`
}
if err := json.Unmarshal(data, &voxtralResult); err != nil {
return nil, fmt.Errorf("failed to parse JSON result: %w", err)
}
// Convert to standard format
// Note: Voxtral doesn't provide word-level timestamps, so we create segments without words
result := &interfaces.TranscriptResult{
Text: voxtralResult.Text,
Language: voxtralResult.Language,
Segments: make([]interfaces.TranscriptSegment, len(voxtralResult.Segments)),
}
for i, seg := range voxtralResult.Segments {
result.Segments[i] = interfaces.TranscriptSegment{
Start: seg.Start,
End: seg.End,
Text: seg.Text,
}
}
return result, nil
}
// GetEstimatedProcessingTime provides Voxtral-specific time estimation
func (v *VoxtralAdapter) GetEstimatedProcessingTime(input interfaces.AudioInput) time.Duration {
// Voxtral is relatively fast
baseTime := v.BaseAdapter.GetEstimatedProcessingTime(input)
// Voxtral typically processes at about 10-20% of audio duration
return time.Duration(float64(baseTime) * 0.15)
}

View File

@@ -383,6 +383,9 @@ export const TranscriptionConfigDialog = memo(function TranscriptionConfigDialog
<SelectItem value="nvidia_canary" className={selectItemClassName}>
NVIDIA Canary
</SelectItem>
<SelectItem value="mistral_voxtral" className={selectItemClassName}>
Mistral Voxtral
</SelectItem>
<SelectItem value="openai" className={selectItemClassName}>
OpenAI
</SelectItem>
@@ -433,6 +436,13 @@ export const TranscriptionConfigDialog = memo(function TranscriptionConfigDialog
onValidate={validateAPIKey}
/>
)}
{params.model_family === "mistral_voxtral" && (
<VoxtralConfig
params={params}
updateParam={updateParam}
/>
)}
</div>
{/* Footer */}
@@ -993,3 +1003,50 @@ function OpenAIConfig({
</div>
);
}
function VoxtralConfig({ params, updateParam }: ConfigProps) {
return (
<div className="space-y-6">
{/* Voxtral Warning Banner */}
<InfoBanner variant="warning" title="Limited Features">
Voxtral does not support word-level timestamps. Synchronized playback, audio seeking, and timestamp-based features won't be available.
</InfoBanner>
<Section title="Language Settings">
<FormField label="Language" description="Source language for transcription">
<Select value={params.language || "en"} onValueChange={(v) => updateParam('language', v)}>
<SelectTrigger className={selectTriggerClassName}>
<SelectValue />
</SelectTrigger>
<SelectContent className={selectContentClassName}>
{LANGUAGES.map((l) => (
<SelectItem key={l.value} value={l.value} className={selectItemClassName}>{l.label}</SelectItem>
))}
</SelectContent>
</Select>
</FormField>
</Section>
{/* Advanced Settings */}
<Accordion type="single" collapsible className="w-full">
<AccordionItem value="advanced" className="border border-[var(--border-subtle)] rounded-xl px-4">
<AccordionTrigger className="text-sm font-medium text-[var(--text-primary)] hover:no-underline py-4">
Advanced Settings
</AccordionTrigger>
<AccordionContent className="pb-4 space-y-4">
<FormField label="Max Tokens" description="Maximum number of tokens to generate. Higher values allow longer transcriptions.">
<Input
type="number"
min={100}
max={2000}
value={params.max_line_width || 500}
onChange={(e) => updateParam('max_line_width', parseInt(e.target.value) || 500)}
className={inputClassName}
/>
</FormField>
</AccordionContent>
</AccordionItem>
</Accordion>
</div>
);
}