diff --git a/internal/transcription/engineprovider/local_provider.go b/internal/transcription/engineprovider/local_provider.go index b6d971cf..86b93d16 100644 --- a/internal/transcription/engineprovider/local_provider.go +++ b/internal/transcription/engineprovider/local_provider.go @@ -183,7 +183,7 @@ func (p *LocalProvider) executeTranscription(ctx context.Context, req TaskReques RequestID: req.JobID, Task: speechproviders.TaskTranscription, ModelID: modelID, - AudioPath: req.AudioPath, + AudioPath: localEngineAudioPath(req), Parameters: copyParameters(req.Parameters), Progress: localProgressSink{downstream: req.Progress}, } @@ -248,6 +248,13 @@ func localTranscriptionMetadata(modelID string, out *speechengine.TranscriptionR return metadata } +func localEngineAudioPath(req TaskRequest) string { + if strings.TrimSpace(req.LocalAudioPath) != "" { + return req.LocalAudioPath + } + return req.AudioPath +} + func (p *LocalProvider) defaultModelID(ctx context.Context, capability asrcontract.Capability) string { models, err := p.Models(ctx) if err != nil { @@ -278,7 +285,7 @@ func (p *LocalProvider) executeDiarization(ctx context.Context, req TaskRequest) RequestID: req.JobID, Task: speechproviders.TaskDiarization, ModelID: modelID, - AudioPath: req.AudioPath, + AudioPath: localEngineAudioPath(req), Parameters: copyParameters(req.Parameters), Progress: localProgressSink{downstream: req.Progress}, } diff --git a/internal/transcription/engineprovider/local_provider_audio_path_test.go b/internal/transcription/engineprovider/local_provider_audio_path_test.go new file mode 100644 index 00000000..27fcdeae --- /dev/null +++ b/internal/transcription/engineprovider/local_provider_audio_path_test.go @@ -0,0 +1,59 @@ +package engineprovider + +import ( + "context" + "testing" + + speechengine "scriberr-engine/speech/engine" + "scriberr-engine/speech/runtime" +) + +func TestLocalProviderUsesLocalAudioPathForInProcessEngine(t *testing.T) { + fake := &fakeSpeechEngine{ + transcriptionOut: &speechengine.TranscriptionResult{Text: "hello"}, + diarizationOut: &speechengine.DiarizationResult{}, + } + provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake) + + if _, err := transcribeForTest(context.Background(), provider, TaskRequest{ + JobID: "job-local-path", + AudioPath: "/provider-input/audio/audio.wav", + LocalAudioPath: "/tmp/scriberr-normalized/audio.wav", + ModelID: "parakeet-v3", + }); err != nil { + t.Fatalf("Transcribe returned error: %v", err) + } + if fake.transcriptionReq.AudioPath != "/tmp/scriberr-normalized/audio.wav" { + t.Fatalf("transcription AudioPath = %q", fake.transcriptionReq.AudioPath) + } + + if _, err := diarizeForTest(context.Background(), provider, TaskRequest{ + JobID: "job-local-path", + AudioPath: "/provider-input/audio/audio.wav", + LocalAudioPath: "/tmp/scriberr-normalized/audio.wav", + ModelID: "diarization-default", + }); err != nil { + t.Fatalf("Diarize returned error: %v", err) + } + if fake.diarizationReq.AudioPath != "/tmp/scriberr-normalized/audio.wav" { + t.Fatalf("diarization AudioPath = %q", fake.diarizationReq.AudioPath) + } +} + +func TestLocalProviderFallsBackToProviderAudioPath(t *testing.T) { + fake := &fakeSpeechEngine{ + transcriptionOut: &speechengine.TranscriptionResult{Text: "hello"}, + } + provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake) + + if _, err := transcribeForTest(context.Background(), provider, TaskRequest{ + JobID: "job-provider-path", + AudioPath: "/provider-input/audio/audio.wav", + ModelID: "parakeet-v3", + }); err != nil { + t.Fatalf("Transcribe returned error: %v", err) + } + if fake.transcriptionReq.AudioPath != "/provider-input/audio/audio.wav" { + t.Fatalf("transcription AudioPath = %q", fake.transcriptionReq.AudioPath) + } +} diff --git a/internal/transcription/engineprovider/types.go b/internal/transcription/engineprovider/types.go index 8288444d..e205a8d2 100644 --- a/internal/transcription/engineprovider/types.go +++ b/internal/transcription/engineprovider/types.go @@ -42,13 +42,14 @@ type SelectionRequest struct { } type TaskRequest struct { - JobID string - UserID uint - Operation asrcontract.Operation - AudioPath string - Progress ProgressSink - ModelID string - Parameters map[string]any + JobID string + UserID uint + Operation asrcontract.Operation + AudioPath string + LocalAudioPath string + Progress ProgressSink + ModelID string + Parameters map[string]any } type TaskResult struct { diff --git a/internal/transcription/orchestrator/processor.go b/internal/transcription/orchestrator/processor.go index 4e999fd5..b4603f21 100644 --- a/internal/transcription/orchestrator/processor.go +++ b/internal/transcription/orchestrator/processor.go @@ -148,13 +148,14 @@ func (p *Processor) Process(ctx context.Context, job *models.TranscriptionJob) ( var providerMetadata map[string]any for _, step := range pipeline { task, err := step.Provider.ExecuteTask(ctx, engineprovider.TaskRequest{ - JobID: job.ID, - UserID: job.UserID, - Operation: step.Operation, - AudioPath: audio.ProviderPath, - Progress: progressSink, - ModelID: step.Model, - Parameters: providerParametersForStep(step), + JobID: job.ID, + UserID: job.UserID, + Operation: step.Operation, + AudioPath: audio.ProviderPath, + LocalAudioPath: audio.Path, + Progress: progressSink, + ModelID: step.Model, + Parameters: providerParametersForStep(step), }) if err != nil { return withExecution(p.errorResult(ctx, err)) diff --git a/internal/transcription/orchestrator/processor_test.go b/internal/transcription/orchestrator/processor_test.go index 05e9f58d..076deb95 100644 --- a/internal/transcription/orchestrator/processor_test.go +++ b/internal/transcription/orchestrator/processor_test.go @@ -517,6 +517,8 @@ func TestProcessorPassesPreprocessedAudioToProvider(t *testing.T) { require.Equal(t, models.StatusCompleted, result.Status) assert.Equal(t, "/provider-input/audio/file-orchestrator.wav", provider.transReq.AudioPath) assert.Equal(t, provider.transReq.AudioPath, provider.diarizeReq.AudioPath) + assert.NotEmpty(t, provider.transReq.LocalAudioPath) + assert.Equal(t, provider.transReq.LocalAudioPath, provider.diarizeReq.LocalAudioPath) assert.NotEqual(t, sourcePath, provider.transReq.AudioPath) }