This commit is contained in:
rishikanthc
2025-11-24 19:42:08 -08:00
parent 747f2c04f6
commit 5d96ef56fe
5 changed files with 191 additions and 109 deletions

View File

@@ -11,6 +11,7 @@ import (
"time"
"scriberr/internal/transcription/interfaces"
"scriberr/pkg/downloader"
"scriberr/pkg/logger"
)
@@ -32,9 +33,9 @@ func NewCanaryAdapter(envPath string) *CanaryAdapter {
"en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh",
// Canary supports many more languages
},
SupportedFormats: []string{"wav", "flac"},
RequiresGPU: false, // Can run on CPU but GPU strongly recommended
MemoryRequirement: 8192, // 8GB+ recommended for Canary
SupportedFormats: []string{"wav", "flac"},
RequiresGPU: false, // Can run on CPU but GPU strongly recommended
MemoryRequirement: 8192, // 8GB+ recommended for Canary
Features: map[string]bool{
"timestamps": true,
"word_level": true,
@@ -67,7 +68,7 @@ func NewCanaryAdapter(envPath string) *CanaryAdapter {
},
{
Name: "target_lang",
Type: "string",
Type: "string",
Required: false,
Default: "en",
Options: []string{"en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh"},
@@ -145,7 +146,7 @@ func NewCanaryAdapter(envPath string) *CanaryAdapter {
}
baseAdapter := NewBaseAdapter("canary", envPath, capabilities, schema)
adapter := &CanaryAdapter{
BaseAdapter: baseAdapter,
envPath: envPath,
@@ -253,28 +254,14 @@ func (c *CanaryAdapter) downloadCanaryModel() error {
}
logger.Info("Downloading Canary model", "path", modelPath)
modelURL := "https://huggingface.co/nvidia/canary-1b-v2/resolve/main/canary-1b-v2.nemo?download=true"
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
tempPath := modelPath + ".tmp"
os.Remove(tempPath)
cmd := exec.CommandContext(ctx, "curl",
"-L", "--progress-bar", "--create-dirs",
"-o", tempPath, modelURL)
out, err := cmd.CombinedOutput()
if err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to download Canary model: %w: %s", err, strings.TrimSpace(string(out)))
}
if err := os.Rename(tempPath, modelPath); err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to move downloaded model: %w", err)
if err := downloader.DownloadFile(ctx, modelURL, modelPath); err != nil {
return fmt.Errorf("failed to download Canary model: %w", err)
}
stat, err := os.Stat(modelPath)
@@ -292,7 +279,7 @@ func (c *CanaryAdapter) downloadCanaryModel() error {
// createTranscriptionScript creates the Python script for Canary transcription
func (c *CanaryAdapter) createTranscriptionScript() error {
scriptPath := filepath.Join(c.envPath, "canary_transcribe.py")
// Check if script already exists
if _, err := os.Stat(scriptPath); err == nil {
return nil
@@ -551,7 +538,7 @@ func (c *CanaryAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
"PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")
logger.Info("Executing Canary command", "args", strings.Join(args, " "))
output, err := cmd.CombinedOutput()
if ctx.Err() == context.Canceled {
return nil, fmt.Errorf("transcription was cancelled")
@@ -571,7 +558,7 @@ func (c *CanaryAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
result.ModelUsed = "canary-1b-v2"
result.Metadata = c.CreateDefaultMetadata(params)
logger.Info("Canary transcription completed",
logger.Info("Canary transcription completed",
"segments", len(result.Segments),
"words", len(result.WordSegments),
"processing_time", result.ProcessingTime,
@@ -583,7 +570,7 @@ func (c *CanaryAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
// buildCanaryArgs builds the command arguments for Canary
func (c *CanaryAdapter) buildCanaryArgs(input interfaces.AudioInput, params map[string]interface{}, tempDir string) ([]string, error) {
outputFile := filepath.Join(tempDir, "result.json")
scriptPath := filepath.Join(c.envPath, "canary_transcribe.py")
args := []string{
"run", "--native-tls", "--project", c.envPath, "python", scriptPath,
@@ -621,18 +608,18 @@ func (c *CanaryAdapter) buildCanaryArgs(input interfaces.AudioInput, params map[
// parseResult parses the Canary output
func (c *CanaryAdapter) parseResult(tempDir string, input interfaces.AudioInput, params map[string]interface{}) (*interfaces.TranscriptResult, error) {
resultFile := filepath.Join(tempDir, "result.json")
data, err := os.ReadFile(resultFile)
if err != nil {
return nil, fmt.Errorf("failed to read result file: %w", err)
}
var canaryResult struct {
Transcription string `json:"transcription"`
SourceLanguage string `json:"source_language"`
TargetLanguage string `json:"target_language"`
Task string `json:"task"`
WordTimestamps []struct {
Transcription string `json:"transcription"`
SourceLanguage string `json:"source_language"`
TargetLanguage string `json:"target_language"`
Task string `json:"task"`
WordTimestamps []struct {
Word string `json:"word"`
StartOffset int `json:"start_offset"`
EndOffset int `json:"end_offset"`
@@ -661,11 +648,11 @@ func (c *CanaryAdapter) parseResult(tempDir string, input interfaces.AudioInput,
// Convert to standard format
result := &interfaces.TranscriptResult{
Text: canaryResult.Transcription,
Language: resultLanguage,
Segments: make([]interfaces.TranscriptSegment, len(canaryResult.SegmentTimestamps)),
Text: canaryResult.Transcription,
Language: resultLanguage,
Segments: make([]interfaces.TranscriptSegment, len(canaryResult.SegmentTimestamps)),
WordSegments: make([]interfaces.TranscriptWord, len(canaryResult.WordTimestamps)),
Confidence: 0.0, // Default confidence
Confidence: 0.0, // Default confidence
}
// Convert segments
@@ -695,7 +682,7 @@ func (c *CanaryAdapter) parseResult(tempDir string, input interfaces.AudioInput,
func (c *CanaryAdapter) GetEstimatedProcessingTime(input interfaces.AudioInput) time.Duration {
// Canary is typically slower than Parakeet due to its multilingual capabilities
baseTime := c.BaseAdapter.GetEstimatedProcessingTime(input)
// Canary typically processes at about 40-50% of audio duration
return time.Duration(float64(baseTime) * 2.0)
}
}

View File

@@ -12,6 +12,7 @@ import (
"time"
"scriberr/internal/transcription/interfaces"
"scriberr/pkg/downloader"
"scriberr/pkg/logger"
)
@@ -233,22 +234,8 @@ func (p *ParakeetAdapter) downloadParakeetModel() error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
tempPath := modelPath + ".tmp"
os.Remove(tempPath)
cmd := exec.CommandContext(ctx, "curl",
"-L", "--progress-bar", "--create-dirs",
"-o", tempPath, modelURL)
out, err := cmd.CombinedOutput()
if err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to download Parakeet model: %w: %s", err, strings.TrimSpace(string(out)))
}
if err := os.Rename(tempPath, modelPath); err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to move downloaded model: %w", err)
if err := downloader.DownloadFile(ctx, modelURL, modelPath); err != nil {
return fmt.Errorf("failed to download Parakeet model: %w", err)
}
stat, err := os.Stat(modelPath)

View File

@@ -12,6 +12,7 @@ import (
"time"
"scriberr/internal/transcription/interfaces"
"scriberr/pkg/downloader"
"scriberr/pkg/logger"
)
@@ -41,13 +42,13 @@ func NewSortformerAdapter(envPath string) *SortformerAdapter {
"no_token_required": true,
},
Metadata: map[string]string{
"engine": "nvidia_nemo",
"framework": "nemo_toolkit",
"license": "CC-BY-4.0",
"optimization": "4_speakers",
"sample_rate": "16000",
"format": "16khz_mono_wav",
"no_auth": "true",
"engine": "nvidia_nemo",
"framework": "nemo_toolkit",
"license": "CC-BY-4.0",
"optimization": "4_speakers",
"sample_rate": "16000",
"format": "16khz_mono_wav",
"no_auth": "true",
},
}
@@ -128,7 +129,7 @@ func NewSortformerAdapter(envPath string) *SortformerAdapter {
}
baseAdapter := NewBaseAdapter("sortformer", envPath, capabilities, schema)
adapter := &SortformerAdapter{
BaseAdapter: baseAdapter,
envPath: envPath,
@@ -243,28 +244,14 @@ func (s *SortformerAdapter) downloadSortformerModel() error {
}
logger.Info("Downloading Sortformer model", "path", modelPath)
modelURL := "https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2/resolve/main/diar_streaming_sortformer_4spk-v2.nemo?download=true"
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
tempPath := modelPath + ".tmp"
os.Remove(tempPath)
cmd := exec.CommandContext(ctx, "curl",
"-L", "-#", "--max-time", "1800",
"-o", tempPath, modelURL)
out, err := cmd.CombinedOutput()
if err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to download Sortformer model: %w: %s", err, strings.TrimSpace(string(out)))
}
if err := os.Rename(tempPath, modelPath); err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to move downloaded model: %w", err)
if err := downloader.DownloadFile(ctx, modelURL, modelPath); err != nil {
return fmt.Errorf("failed to download Sortformer model: %w", err)
}
stat, err := os.Stat(modelPath)
@@ -282,7 +269,7 @@ func (s *SortformerAdapter) downloadSortformerModel() error {
// createDiarizationScript creates the Python script for Sortformer diarization
func (s *SortformerAdapter) createDiarizationScript() error {
scriptPath := filepath.Join(s.envPath, "sortformer_diarize.py")
// Check if script already exists
if _, err := os.Stat(scriptPath); err == nil {
return nil
@@ -672,7 +659,7 @@ func (s *SortformerAdapter) Diarize(ctx context.Context, input interfaces.AudioI
cmd.Env = append(os.Environ(), "PYTHONUNBUFFERED=1")
logger.Info("Executing Sortformer command", "args", strings.Join(args, " "))
output, err := cmd.CombinedOutput()
if ctx.Err() == context.Canceled {
return nil, fmt.Errorf("diarization was cancelled")
@@ -692,7 +679,7 @@ func (s *SortformerAdapter) Diarize(ctx context.Context, input interfaces.AudioI
result.ModelUsed = "diar_streaming_sortformer_4spk-v2"
result.Metadata = s.CreateDefaultMetadata(params)
logger.Info("Sortformer diarization completed",
logger.Info("Sortformer diarization completed",
"segments", len(result.Segments),
"speakers", result.SpeakerCount,
"processing_time", result.ProcessingTime)
@@ -709,7 +696,7 @@ func (s *SortformerAdapter) buildSortformerArgs(input interfaces.AudioInput, par
} else {
outputFile = filepath.Join(tempDir, "result.rttm")
}
scriptPath := filepath.Join(s.envPath, "sortformer_diarize.py")
args := []string{
"run", "--native-tls", "--project", s.envPath, "python", scriptPath,
@@ -749,7 +736,7 @@ func (s *SortformerAdapter) buildSortformerArgs(input interfaces.AudioInput, par
// parseResult parses the Sortformer output
func (s *SortformerAdapter) parseResult(tempDir string, input interfaces.AudioInput, params map[string]interface{}) (*interfaces.DiarizationResult, error) {
outputFormat := s.GetStringParameter(params, "output_format")
if outputFormat == "json" {
return s.parseJSONResult(tempDir)
} else {
@@ -760,16 +747,16 @@ func (s *SortformerAdapter) parseResult(tempDir string, input interfaces.AudioIn
// parseJSONResult parses JSON format output
func (s *SortformerAdapter) parseJSONResult(tempDir string) (*interfaces.DiarizationResult, error) {
resultFile := filepath.Join(tempDir, "result.json")
data, err := os.ReadFile(resultFile)
if err != nil {
return nil, fmt.Errorf("failed to read result file: %w", err)
}
var sortformerResult struct {
AudioFile string `json:"audio_file"`
Model string `json:"model"`
Segments []struct {
AudioFile string `json:"audio_file"`
Model string `json:"model"`
Segments []struct {
Start float64 `json:"start"`
End float64 `json:"end"`
Speaker string `json:"speaker"`
@@ -808,7 +795,7 @@ func (s *SortformerAdapter) parseJSONResult(tempDir string) (*interfaces.Diariza
// parseRTTMResult parses RTTM format output
func (s *SortformerAdapter) parseRTTMResult(tempDir string, input interfaces.AudioInput) (*interfaces.DiarizationResult, error) {
resultFile := filepath.Join(tempDir, "result.rttm")
data, err := os.ReadFile(resultFile)
if err != nil {
return nil, fmt.Errorf("failed to read result file: %w", err)
@@ -870,7 +857,7 @@ func (s *SortformerAdapter) parseRTTMResult(tempDir string, input interfaces.Aud
func (s *SortformerAdapter) GetEstimatedProcessingTime(input interfaces.AudioInput) time.Duration {
// Sortformer is typically very fast, often faster than real-time
baseTime := s.BaseAdapter.GetEstimatedProcessingTime(input)
// Sortformer typically processes at about 5-10% of audio duration
return time.Duration(float64(baseTime) * 0.3)
}
}

View File

@@ -0,0 +1,119 @@
package downloader
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
)
// DownloadFile downloads a file from a URL to a destination path with progress tracking
func DownloadFile(ctx context.Context, url, dest string) error {
// Create parent directory if it doesn't exist
if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
// Create temporary file
tempDest := dest + ".tmp"
out, err := os.Create(tempDest)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer out.Close()
// Create request with context
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
// Execute request
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download file: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status: %s", resp.Status)
}
// Create progress tracker
size := resp.ContentLength
tracker := &progressTracker{
Total: size,
Filename: filepath.Base(dest),
LastLog: time.Now(),
}
// Copy with progress
_, err = io.Copy(out, io.TeeReader(resp.Body, tracker))
if err != nil {
return fmt.Errorf("failed to save file: %w", err)
}
// Close file before renaming
out.Close()
// Rename temp file to final destination
if err := os.Rename(tempDest, dest); err != nil {
return fmt.Errorf("failed to rename file: %w", err)
}
// Print final newline
fmt.Println()
return nil
}
type progressTracker struct {
Total int64
Current int64
Filename string
LastLog time.Time
LastPercent int
}
func (pt *progressTracker) Write(p []byte) (int, error) {
n := len(p)
pt.Current += int64(n)
pt.printProgress()
return n, nil
}
func (pt *progressTracker) printProgress() {
// Calculate percentage
percent := int(float64(pt.Current) / float64(pt.Total) * 100)
// Update only if percentage changed significantly or enough time passed
if percent != pt.LastPercent && (percent%5 == 0 || time.Since(pt.LastLog) > 1*time.Second) {
pt.LastPercent = percent
pt.LastLog = time.Now()
// Clear line and print progress
// \r moves cursor to start of line
// \033[K clears the line
fmt.Printf("\r\033[KDownloading %s: %d%% (%s / %s)",
pt.Filename,
percent,
formatBytes(pt.Current),
formatBytes(pt.Total))
}
}
func formatBytes(b int64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp])
}

View File

@@ -145,7 +145,9 @@ func WithContext(key string, value any) *Logger {
func Startup(step, message string, args ...any) {
// Simple message at INFO level, technical details at DEBUG
if currentLevel <= LevelInfo {
Info(message)
// Clean, user-friendly startup message
// \033[36m is Cyan color for the [+] prefix
fmt.Printf("\033[36m[+]\033[0m %s\n", message)
}
if currentLevel <= LevelDebug {
Debug("Startup step", append([]any{"step", step, "message", message}, args...)...)
@@ -156,25 +158,25 @@ func Startup(step, message string, args ...any) {
func JobStarted(jobID, filename, model string, params map[string]any) {
// Simple message at INFO, details at DEBUG
Info("Transcription started", "file", filename)
Debug("Job started with details",
"job_id", jobID,
"file", filename,
Debug("Job started with details",
"job_id", jobID,
"file", filename,
"model", model,
"params", params)
}
func JobCompleted(jobID string, duration time.Duration, result any) {
Info("Transcription completed", "duration", duration.String())
Debug("Job completed with details",
"job_id", jobID,
Debug("Job completed with details",
"job_id", jobID,
"duration", duration.String(),
"result", result)
}
func JobFailed(jobID string, duration time.Duration, err error) {
Error("Transcription failed", "error", err.Error())
Debug("Job failed with details",
"job_id", jobID,
Debug("Job failed with details",
"job_id", jobID,
"duration", duration.String(),
"error", err.Error())
}
@@ -193,10 +195,10 @@ func HTTPRequest(method, path string, status int, duration time.Duration, userAg
return
}
}
// Log all requests at DEBUG level
if currentLevel <= LevelDebug {
Debug("API request",
Debug("API request",
"method", method,
"path", path,
"status", status,
@@ -209,24 +211,24 @@ func HTTPRequest(method, path string, status int, duration time.Duration, userAg
func AuthEvent(event, username, ip string, success bool, details ...any) {
if success {
Info("User login successful", "username", username)
Debug("Auth event details",
Debug("Auth event details",
append([]any{"event", event, "username", username, "ip", ip, "success", success}, details...)...)
} else {
Info("User login failed", "username", username, "reason", "invalid_credentials")
Debug("Auth event details",
Debug("Auth event details",
append([]any{"event", event, "username", username, "ip", ip, "success", success}, details...)...)
}
}
// Worker operation logger
func WorkerOperation(workerID int, jobID string, operation string, args ...any) {
Debug("Worker operation",
Debug("Worker operation",
append([]any{"worker_id", workerID, "job_id", jobID, "operation", operation}, args...)...)
}
// Performance logging for debugging
func Performance(operation string, duration time.Duration, details ...any) {
Debug("Performance",
Debug("Performance",
append([]any{"operation", operation, "duration", duration.String()}, details...)...)
}
@@ -262,7 +264,7 @@ func GinLogger() gin.HandlerFunc {
// Log request
status := c.Writer.Status()
statusColor := getStatusColor(status)
if currentLevel <= LevelDebug {
// Detailed logging for DEBUG
Debug("API request",
@@ -306,4 +308,4 @@ func getStatusColor(status int) string {
func SetGinOutput() {
// Set GIN to use a discard writer to suppress default logging
gin.DefaultWriter = io.Discard
}
}