Use local audio artifacts for in-process ASR engine

This commit is contained in:
rishikanthc
2026-05-09 12:20:48 -07:00
parent 6a381548e8
commit 7c3ba4dc34
5 changed files with 86 additions and 16 deletions

View File

@@ -183,7 +183,7 @@ func (p *LocalProvider) executeTranscription(ctx context.Context, req TaskReques
RequestID: req.JobID,
Task: speechproviders.TaskTranscription,
ModelID: modelID,
AudioPath: req.AudioPath,
AudioPath: localEngineAudioPath(req),
Parameters: copyParameters(req.Parameters),
Progress: localProgressSink{downstream: req.Progress},
}
@@ -248,6 +248,13 @@ func localTranscriptionMetadata(modelID string, out *speechengine.TranscriptionR
return metadata
}
// localEngineAudioPath selects the audio path handed to the in-process engine.
// It returns req.LocalAudioPath unchanged when that field contains anything
// other than whitespace; otherwise it returns req.AudioPath.
func localEngineAudioPath(req TaskRequest) string {
	local := req.LocalAudioPath
	// A blank or whitespace-only local path counts as "not provided".
	if strings.TrimSpace(local) == "" {
		return req.AudioPath
	}
	return local
}
func (p *LocalProvider) defaultModelID(ctx context.Context, capability asrcontract.Capability) string {
models, err := p.Models(ctx)
if err != nil {
@@ -278,7 +285,7 @@ func (p *LocalProvider) executeDiarization(ctx context.Context, req TaskRequest)
RequestID: req.JobID,
Task: speechproviders.TaskDiarization,
ModelID: modelID,
AudioPath: req.AudioPath,
AudioPath: localEngineAudioPath(req),
Parameters: copyParameters(req.Parameters),
Progress: localProgressSink{downstream: req.Progress},
}

View File

@@ -0,0 +1,59 @@
package engineprovider
import (
"context"
"testing"
speechengine "scriberr-engine/speech/engine"
"scriberr-engine/speech/runtime"
)
// TestLocalProviderUsesLocalAudioPathForInProcessEngine checks that, when a
// request sets both AudioPath and LocalAudioPath, the request observed by the
// fake engine carries LocalAudioPath for transcription and for diarization.
func TestLocalProviderUsesLocalAudioPathForInProcessEngine(t *testing.T) {
	const (
		providerPath = "/provider-input/audio/audio.wav"
		localPath    = "/tmp/scriberr-normalized/audio.wav"
	)

	engine := &fakeSpeechEngine{
		transcriptionOut: &speechengine.TranscriptionResult{Text: "hello"},
		diarizationOut:   &speechengine.DiarizationResult{},
	}
	provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, engine)

	// Transcription step.
	transcribeReq := TaskRequest{
		JobID:          "job-local-path",
		AudioPath:      providerPath,
		LocalAudioPath: localPath,
		ModelID:        "parakeet-v3",
	}
	_, err := transcribeForTest(context.Background(), provider, transcribeReq)
	if err != nil {
		t.Fatalf("Transcribe returned error: %v", err)
	}
	if got := engine.transcriptionReq.AudioPath; got != localPath {
		t.Fatalf("transcription AudioPath = %q", got)
	}

	// Diarization step, same paths with a different model.
	diarizeReq := TaskRequest{
		JobID:          "job-local-path",
		AudioPath:      providerPath,
		LocalAudioPath: localPath,
		ModelID:        "diarization-default",
	}
	_, err = diarizeForTest(context.Background(), provider, diarizeReq)
	if err != nil {
		t.Fatalf("Diarize returned error: %v", err)
	}
	if got := engine.diarizationReq.AudioPath; got != localPath {
		t.Fatalf("diarization AudioPath = %q", got)
	}
}
// TestLocalProviderFallsBackToProviderAudioPath checks that a request without
// LocalAudioPath reaches the fake engine with its AudioPath value.
func TestLocalProviderFallsBackToProviderAudioPath(t *testing.T) {
	const providerPath = "/provider-input/audio/audio.wav"

	engine := &fakeSpeechEngine{
		transcriptionOut: &speechengine.TranscriptionResult{Text: "hello"},
	}
	provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, engine)

	// LocalAudioPath is deliberately left at its zero value.
	req := TaskRequest{
		JobID:     "job-provider-path",
		AudioPath: providerPath,
		ModelID:   "parakeet-v3",
	}
	_, err := transcribeForTest(context.Background(), provider, req)
	if err != nil {
		t.Fatalf("Transcribe returned error: %v", err)
	}
	if got := engine.transcriptionReq.AudioPath; got != providerPath {
		t.Fatalf("transcription AudioPath = %q", got)
	}
}

View File

@@ -42,13 +42,14 @@ type SelectionRequest struct {
}
// TaskRequest is the input passed to a provider's ExecuteTask for a single
// pipeline step: the job and user it belongs to, the operation to run, the
// audio to read, a progress sink, and the model and parameters to use.
type TaskRequest struct {
JobID string
UserID uint
Operation asrcontract.Operation
AudioPath string
Progress ProgressSink
ModelID string
Parameters map[string]any
JobID string
UserID uint
Operation asrcontract.Operation
AudioPath string
LocalAudioPath string // used instead of AudioPath by the local in-process engine when not blank
Progress ProgressSink
ModelID string
Parameters map[string]any
}
type TaskResult struct {

View File

@@ -148,13 +148,14 @@ func (p *Processor) Process(ctx context.Context, job *models.TranscriptionJob) (
var providerMetadata map[string]any
for _, step := range pipeline {
task, err := step.Provider.ExecuteTask(ctx, engineprovider.TaskRequest{
JobID: job.ID,
UserID: job.UserID,
Operation: step.Operation,
AudioPath: audio.ProviderPath,
Progress: progressSink,
ModelID: step.Model,
Parameters: providerParametersForStep(step),
JobID: job.ID,
UserID: job.UserID,
Operation: step.Operation,
AudioPath: audio.ProviderPath,
LocalAudioPath: audio.Path,
Progress: progressSink,
ModelID: step.Model,
Parameters: providerParametersForStep(step),
})
if err != nil {
return withExecution(p.errorResult(ctx, err))

View File

@@ -517,6 +517,8 @@ func TestProcessorPassesPreprocessedAudioToProvider(t *testing.T) {
require.Equal(t, models.StatusCompleted, result.Status)
assert.Equal(t, "/provider-input/audio/file-orchestrator.wav", provider.transReq.AudioPath)
assert.Equal(t, provider.transReq.AudioPath, provider.diarizeReq.AudioPath)
assert.NotEmpty(t, provider.transReq.LocalAudioPath)
assert.Equal(t, provider.transReq.LocalAudioPath, provider.diarizeReq.LocalAudioPath)
assert.NotEqual(t, sourcePath, provider.transReq.AudioPath)
}