diff --git a/internal/transcription/adapters/canary_adapter.go b/internal/transcription/adapters/canary_adapter.go index 17e2bfe8..42669ac8 100644 --- a/internal/transcription/adapters/canary_adapter.go +++ b/internal/transcription/adapters/canary_adapter.go @@ -11,6 +11,7 @@ import ( "time" "scriberr/internal/transcription/interfaces" + "scriberr/pkg/downloader" "scriberr/pkg/logger" ) @@ -32,9 +33,9 @@ func NewCanaryAdapter(envPath string) *CanaryAdapter { "en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh", // Canary supports many more languages }, - SupportedFormats: []string{"wav", "flac"}, - RequiresGPU: false, // Can run on CPU but GPU strongly recommended - MemoryRequirement: 8192, // 8GB+ recommended for Canary + SupportedFormats: []string{"wav", "flac"}, + RequiresGPU: false, // Can run on CPU but GPU strongly recommended + MemoryRequirement: 8192, // 8GB+ recommended for Canary Features: map[string]bool{ "timestamps": true, "word_level": true, @@ -67,7 +68,7 @@ func NewCanaryAdapter(envPath string) *CanaryAdapter { }, { Name: "target_lang", - Type: "string", + Type: "string", Required: false, Default: "en", Options: []string{"en", "de", "es", "fr", "hi", "it", "ja", "ko", "pl", "pt", "ru", "zh"}, @@ -145,7 +146,7 @@ func NewCanaryAdapter(envPath string) *CanaryAdapter { } baseAdapter := NewBaseAdapter("canary", envPath, capabilities, schema) - + adapter := &CanaryAdapter{ BaseAdapter: baseAdapter, envPath: envPath, @@ -253,28 +254,14 @@ func (c *CanaryAdapter) downloadCanaryModel() error { } logger.Info("Downloading Canary model", "path", modelPath) - + modelURL := "https://huggingface.co/nvidia/canary-1b-v2/resolve/main/canary-1b-v2.nemo?download=true" - + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) defer cancel() - tempPath := modelPath + ".tmp" - os.Remove(tempPath) - - cmd := exec.CommandContext(ctx, "curl", - "-L", "--progress-bar", "--create-dirs", - "-o", tempPath, modelURL) - - out, err := 
cmd.CombinedOutput() - if err != nil { - os.Remove(tempPath) - return fmt.Errorf("failed to download Canary model: %w: %s", err, strings.TrimSpace(string(out))) - } - - if err := os.Rename(tempPath, modelPath); err != nil { - os.Remove(tempPath) - return fmt.Errorf("failed to move downloaded model: %w", err) + if err := downloader.DownloadFile(ctx, modelURL, modelPath); err != nil { + return fmt.Errorf("failed to download Canary model: %w", err) } stat, err := os.Stat(modelPath) @@ -292,7 +279,7 @@ func (c *CanaryAdapter) downloadCanaryModel() error { // createTranscriptionScript creates the Python script for Canary transcription func (c *CanaryAdapter) createTranscriptionScript() error { scriptPath := filepath.Join(c.envPath, "canary_transcribe.py") - + // Check if script already exists if _, err := os.Stat(scriptPath); err == nil { return nil @@ -551,7 +538,7 @@ func (c *CanaryAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True") logger.Info("Executing Canary command", "args", strings.Join(args, " ")) - + output, err := cmd.CombinedOutput() if ctx.Err() == context.Canceled { return nil, fmt.Errorf("transcription was cancelled") @@ -571,7 +558,7 @@ func (c *CanaryAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn result.ModelUsed = "canary-1b-v2" result.Metadata = c.CreateDefaultMetadata(params) - logger.Info("Canary transcription completed", + logger.Info("Canary transcription completed", "segments", len(result.Segments), "words", len(result.WordSegments), "processing_time", result.ProcessingTime, @@ -583,7 +570,7 @@ func (c *CanaryAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn // buildCanaryArgs builds the command arguments for Canary func (c *CanaryAdapter) buildCanaryArgs(input interfaces.AudioInput, params map[string]interface{}, tempDir string) ([]string, error) { outputFile := filepath.Join(tempDir, "result.json") - + scriptPath := filepath.Join(c.envPath, 
"canary_transcribe.py") args := []string{ "run", "--native-tls", "--project", c.envPath, "python", scriptPath, @@ -621,18 +608,18 @@ func (c *CanaryAdapter) buildCanaryArgs(input interfaces.AudioInput, params map[ // parseResult parses the Canary output func (c *CanaryAdapter) parseResult(tempDir string, input interfaces.AudioInput, params map[string]interface{}) (*interfaces.TranscriptResult, error) { resultFile := filepath.Join(tempDir, "result.json") - + data, err := os.ReadFile(resultFile) if err != nil { return nil, fmt.Errorf("failed to read result file: %w", err) } var canaryResult struct { - Transcription string `json:"transcription"` - SourceLanguage string `json:"source_language"` - TargetLanguage string `json:"target_language"` - Task string `json:"task"` - WordTimestamps []struct { + Transcription string `json:"transcription"` + SourceLanguage string `json:"source_language"` + TargetLanguage string `json:"target_language"` + Task string `json:"task"` + WordTimestamps []struct { Word string `json:"word"` StartOffset int `json:"start_offset"` EndOffset int `json:"end_offset"` @@ -661,11 +648,11 @@ func (c *CanaryAdapter) parseResult(tempDir string, input interfaces.AudioInput, // Convert to standard format result := &interfaces.TranscriptResult{ - Text: canaryResult.Transcription, - Language: resultLanguage, - Segments: make([]interfaces.TranscriptSegment, len(canaryResult.SegmentTimestamps)), + Text: canaryResult.Transcription, + Language: resultLanguage, + Segments: make([]interfaces.TranscriptSegment, len(canaryResult.SegmentTimestamps)), WordSegments: make([]interfaces.TranscriptWord, len(canaryResult.WordTimestamps)), - Confidence: 0.0, // Default confidence + Confidence: 0.0, // Default confidence } // Convert segments @@ -695,7 +682,7 @@ func (c *CanaryAdapter) parseResult(tempDir string, input interfaces.AudioInput, func (c *CanaryAdapter) GetEstimatedProcessingTime(input interfaces.AudioInput) time.Duration { // Canary is typically slower than 
Parakeet due to its multilingual capabilities baseTime := c.BaseAdapter.GetEstimatedProcessingTime(input) - + // Canary typically processes at about 40-50% of audio duration return time.Duration(float64(baseTime) * 2.0) -} \ No newline at end of file +} diff --git a/internal/transcription/adapters/parakeet_adapter.go b/internal/transcription/adapters/parakeet_adapter.go index aa8a2fb2..157d3547 100644 --- a/internal/transcription/adapters/parakeet_adapter.go +++ b/internal/transcription/adapters/parakeet_adapter.go @@ -12,6 +12,7 @@ import ( "time" "scriberr/internal/transcription/interfaces" + "scriberr/pkg/downloader" "scriberr/pkg/logger" ) @@ -233,22 +234,8 @@ func (p *ParakeetAdapter) downloadParakeetModel() error { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) defer cancel() - tempPath := modelPath + ".tmp" - os.Remove(tempPath) - - cmd := exec.CommandContext(ctx, "curl", - "-L", "--progress-bar", "--create-dirs", - "-o", tempPath, modelURL) - - out, err := cmd.CombinedOutput() - if err != nil { - os.Remove(tempPath) - return fmt.Errorf("failed to download Parakeet model: %w: %s", err, strings.TrimSpace(string(out))) - } - - if err := os.Rename(tempPath, modelPath); err != nil { - os.Remove(tempPath) - return fmt.Errorf("failed to move downloaded model: %w", err) + if err := downloader.DownloadFile(ctx, modelURL, modelPath); err != nil { + return fmt.Errorf("failed to download Parakeet model: %w", err) } stat, err := os.Stat(modelPath) diff --git a/internal/transcription/adapters/sortformer_adapter.go b/internal/transcription/adapters/sortformer_adapter.go index d2e0a7f5..396fb210 100644 --- a/internal/transcription/adapters/sortformer_adapter.go +++ b/internal/transcription/adapters/sortformer_adapter.go @@ -12,6 +12,7 @@ import ( "time" "scriberr/internal/transcription/interfaces" + "scriberr/pkg/downloader" "scriberr/pkg/logger" ) @@ -41,13 +42,13 @@ func NewSortformerAdapter(envPath string) *SortformerAdapter { 
"no_token_required": true, }, Metadata: map[string]string{ - "engine": "nvidia_nemo", - "framework": "nemo_toolkit", - "license": "CC-BY-4.0", - "optimization": "4_speakers", - "sample_rate": "16000", - "format": "16khz_mono_wav", - "no_auth": "true", + "engine": "nvidia_nemo", + "framework": "nemo_toolkit", + "license": "CC-BY-4.0", + "optimization": "4_speakers", + "sample_rate": "16000", + "format": "16khz_mono_wav", + "no_auth": "true", }, } @@ -128,7 +129,7 @@ func NewSortformerAdapter(envPath string) *SortformerAdapter { } baseAdapter := NewBaseAdapter("sortformer", envPath, capabilities, schema) - + adapter := &SortformerAdapter{ BaseAdapter: baseAdapter, envPath: envPath, @@ -243,28 +244,14 @@ func (s *SortformerAdapter) downloadSortformerModel() error { } logger.Info("Downloading Sortformer model", "path", modelPath) - + modelURL := "https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2/resolve/main/diar_streaming_sortformer_4spk-v2.nemo?download=true" - + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) defer cancel() - tempPath := modelPath + ".tmp" - os.Remove(tempPath) - - cmd := exec.CommandContext(ctx, "curl", - "-L", "-#", "--max-time", "1800", - "-o", tempPath, modelURL) - - out, err := cmd.CombinedOutput() - if err != nil { - os.Remove(tempPath) - return fmt.Errorf("failed to download Sortformer model: %w: %s", err, strings.TrimSpace(string(out))) - } - - if err := os.Rename(tempPath, modelPath); err != nil { - os.Remove(tempPath) - return fmt.Errorf("failed to move downloaded model: %w", err) + if err := downloader.DownloadFile(ctx, modelURL, modelPath); err != nil { + return fmt.Errorf("failed to download Sortformer model: %w", err) } stat, err := os.Stat(modelPath) @@ -282,7 +269,7 @@ func (s *SortformerAdapter) downloadSortformerModel() error { // createDiarizationScript creates the Python script for Sortformer diarization func (s *SortformerAdapter) createDiarizationScript() error { scriptPath := 
filepath.Join(s.envPath, "sortformer_diarize.py") - + // Check if script already exists if _, err := os.Stat(scriptPath); err == nil { return nil @@ -672,7 +659,7 @@ func (s *SortformerAdapter) Diarize(ctx context.Context, input interfaces.AudioI cmd.Env = append(os.Environ(), "PYTHONUNBUFFERED=1") logger.Info("Executing Sortformer command", "args", strings.Join(args, " ")) - + output, err := cmd.CombinedOutput() if ctx.Err() == context.Canceled { return nil, fmt.Errorf("diarization was cancelled") @@ -692,7 +679,7 @@ func (s *SortformerAdapter) Diarize(ctx context.Context, input interfaces.AudioI result.ModelUsed = "diar_streaming_sortformer_4spk-v2" result.Metadata = s.CreateDefaultMetadata(params) - logger.Info("Sortformer diarization completed", + logger.Info("Sortformer diarization completed", "segments", len(result.Segments), "speakers", result.SpeakerCount, "processing_time", result.ProcessingTime) @@ -709,7 +696,7 @@ func (s *SortformerAdapter) buildSortformerArgs(input interfaces.AudioInput, par } else { outputFile = filepath.Join(tempDir, "result.rttm") } - + scriptPath := filepath.Join(s.envPath, "sortformer_diarize.py") args := []string{ "run", "--native-tls", "--project", s.envPath, "python", scriptPath, @@ -749,7 +736,7 @@ func (s *SortformerAdapter) buildSortformerArgs(input interfaces.AudioInput, par // parseResult parses the Sortformer output func (s *SortformerAdapter) parseResult(tempDir string, input interfaces.AudioInput, params map[string]interface{}) (*interfaces.DiarizationResult, error) { outputFormat := s.GetStringParameter(params, "output_format") - + if outputFormat == "json" { return s.parseJSONResult(tempDir) } else { @@ -760,16 +747,16 @@ func (s *SortformerAdapter) parseResult(tempDir string, input interfaces.AudioIn // parseJSONResult parses JSON format output func (s *SortformerAdapter) parseJSONResult(tempDir string) (*interfaces.DiarizationResult, error) { resultFile := filepath.Join(tempDir, "result.json") - + data, err := 
// DownloadFile fetches url with an HTTP GET and stores the response body at
// dest, printing a single-line progress indicator to standard output while the
// transfer runs.
//
// The body is streamed into dest+".tmp" and renamed onto dest only after the
// whole body has been written and the file closed successfully. On any failure
// the temporary file is removed, so no partial download is left behind. Parent
// directories of dest are created as needed. The request is bound to ctx.
//
// It returns an error if the directory or temporary file cannot be created,
// the request fails, the server answers with anything other than 200 OK, or
// the data cannot be written, closed, or renamed into place.
func DownloadFile(ctx context.Context, url, dest string) error {
	// Create parent directory if it doesn't exist
	if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	// Create temporary file (truncates a stale one from an earlier attempt)
	tempDest := dest + ".tmp"
	out, err := os.Create(tempDest)
	if err != nil {
		return fmt.Errorf("failed to create file: %w", err)
	}

	// Until the rename succeeds, every exit path must discard the temp file.
	renamed := false
	defer func() {
		if renamed {
			return
		}
		// Best effort: the file may already be closed, and there is nothing
		// useful to do if removal fails while another error is being returned.
		_ = out.Close()
		_ = os.Remove(tempDest)
	}()

	// Create request with context
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}

	// Execute request
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to download file: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("bad status: %s", resp.Status)
	}

	// Create progress tracker. ContentLength is -1 when the server does not
	// announce a size; the tracker then reports bytes instead of a percentage.
	tracker := &progressTracker{
		Total:    resp.ContentLength,
		Filename: filepath.Base(dest),
		LastLog:  time.Now(),
	}

	// Copy with progress
	_, err = io.Copy(out, io.TeeReader(resp.Body, tracker))

	// Terminate the in-place progress line whether or not the copy succeeded,
	// so that whatever is printed next starts on a fresh line.
	fmt.Println()

	if err != nil {
		return fmt.Errorf("failed to save file: %w", err)
	}

	// Close before renaming and check the result: a failed close can mean the
	// data never reached the disk, and such a file must not become dest.
	if err := out.Close(); err != nil {
		return fmt.Errorf("failed to save file: %w", err)
	}

	// Rename temp file to final destination
	if err := os.Rename(tempDest, dest); err != nil {
		return fmt.Errorf("failed to rename file: %w", err)
	}
	renamed = true

	return nil
}

// progressTracker is an io.Writer that counts the bytes written to it and
// prints a throttled, single-line progress report to standard output.
type progressTracker struct {
	Total       int64     // expected size in bytes; <= 0 means the size is unknown
	Current     int64     // bytes seen so far
	Filename    string    // name shown in the progress line
	LastLog     time.Time // when the progress line was last printed
	LastPercent int       // percentage shown by the last printed line
}

// Write records len(p) more bytes and refreshes the progress line when due.
// It never fails and always reports the whole slice as consumed.
func (pt *progressTracker) Write(p []byte) (int, error) {
	n := len(p)
	pt.Current += int64(n)
	pt.printProgress()
	return n, nil
}

// printProgress rewrites the progress line in place.
//
// With a known total it prints when the whole-number percentage has changed
// and is either a multiple of 5 or more than a second has passed since the
// last print. With an unknown total (Total <= 0) no percentage can be
// computed, so it prints the byte count alone, at most once per second.
func (pt *progressTracker) printProgress() {
	if pt.Total <= 0 {
		if time.Since(pt.LastLog) > time.Second {
			pt.LastLog = time.Now()
			// \r moves cursor to start of line, \033[K clears the line
			fmt.Printf("\r\033[KDownloading %s: %s", pt.Filename, formatBytes(pt.Current))
		}
		return
	}

	// Integer arithmetic avoids float rounding (e.g. 29/100 showing as 28%).
	percent := int(pt.Current * 100 / pt.Total)

	// Update only if percentage changed significantly or enough time passed
	if percent == pt.LastPercent {
		return
	}
	if percent%5 != 0 && time.Since(pt.LastLog) <= time.Second {
		return
	}

	pt.LastPercent = percent
	pt.LastLog = time.Now()

	// \r moves cursor to start of line, \033[K clears the line
	fmt.Printf("\r\033[KDownloading %s: %d%% (%s / %s)",
		pt.Filename,
		percent,
		formatBytes(pt.Current),
		formatBytes(pt.Total))
}

// formatBytes renders a byte count using 1024-based units with one decimal
// place, e.g. 1536 -> "1.5 KB". Values below 1024 (including negative ones)
// are printed as a plain integer followed by " B".
func formatBytes(b int64) string {
	const unit = 1024
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}
	div, exp := int64(unit), 0
	for n := b / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp])
}
HTTPRequest(method, path string, status int, duration time.Duration, userAg return } } - + // Log all requests at DEBUG level if currentLevel <= LevelDebug { - Debug("API request", + Debug("API request", "method", method, "path", path, "status", status, @@ -209,24 +211,24 @@ func HTTPRequest(method, path string, status int, duration time.Duration, userAg func AuthEvent(event, username, ip string, success bool, details ...any) { if success { Info("User login successful", "username", username) - Debug("Auth event details", + Debug("Auth event details", append([]any{"event", event, "username", username, "ip", ip, "success", success}, details...)...) } else { Info("User login failed", "username", username, "reason", "invalid_credentials") - Debug("Auth event details", + Debug("Auth event details", append([]any{"event", event, "username", username, "ip", ip, "success", success}, details...)...) } } // Worker operation logger func WorkerOperation(workerID int, jobID string, operation string, args ...any) { - Debug("Worker operation", + Debug("Worker operation", append([]any{"worker_id", workerID, "job_id", jobID, "operation", operation}, args...)...) } // Performance logging for debugging func Performance(operation string, duration time.Duration, details ...any) { - Debug("Performance", + Debug("Performance", append([]any{"operation", operation, "duration", duration.String()}, details...)...) } @@ -262,7 +264,7 @@ func GinLogger() gin.HandlerFunc { // Log request status := c.Writer.Status() statusColor := getStatusColor(status) - + if currentLevel <= LevelDebug { // Detailed logging for DEBUG Debug("API request", @@ -306,4 +308,4 @@ func getStatusColor(status int) string { func SetGinOutput() { // Set GIN to use a discard writer to suppress default logging gin.DefaultWriter = io.Discard -} \ No newline at end of file +}