mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-06-29 07:15:54 +00:00
Use local audio artifacts for in-process ASR engine
This commit is contained in:
@@ -183,7 +183,7 @@ func (p *LocalProvider) executeTranscription(ctx context.Context, req TaskReques
|
||||
RequestID: req.JobID,
|
||||
Task: speechproviders.TaskTranscription,
|
||||
ModelID: modelID,
|
||||
AudioPath: req.AudioPath,
|
||||
AudioPath: localEngineAudioPath(req),
|
||||
Parameters: copyParameters(req.Parameters),
|
||||
Progress: localProgressSink{downstream: req.Progress},
|
||||
}
|
||||
@@ -248,6 +248,13 @@ func localTranscriptionMetadata(modelID string, out *speechengine.TranscriptionR
|
||||
return metadata
|
||||
}
|
||||
|
||||
func localEngineAudioPath(req TaskRequest) string {
|
||||
if strings.TrimSpace(req.LocalAudioPath) != "" {
|
||||
return req.LocalAudioPath
|
||||
}
|
||||
return req.AudioPath
|
||||
}
|
||||
|
||||
func (p *LocalProvider) defaultModelID(ctx context.Context, capability asrcontract.Capability) string {
|
||||
models, err := p.Models(ctx)
|
||||
if err != nil {
|
||||
@@ -278,7 +285,7 @@ func (p *LocalProvider) executeDiarization(ctx context.Context, req TaskRequest)
|
||||
RequestID: req.JobID,
|
||||
Task: speechproviders.TaskDiarization,
|
||||
ModelID: modelID,
|
||||
AudioPath: req.AudioPath,
|
||||
AudioPath: localEngineAudioPath(req),
|
||||
Parameters: copyParameters(req.Parameters),
|
||||
Progress: localProgressSink{downstream: req.Progress},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
package engineprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
speechengine "scriberr-engine/speech/engine"
|
||||
"scriberr-engine/speech/runtime"
|
||||
)
|
||||
|
||||
func TestLocalProviderUsesLocalAudioPathForInProcessEngine(t *testing.T) {
|
||||
fake := &fakeSpeechEngine{
|
||||
transcriptionOut: &speechengine.TranscriptionResult{Text: "hello"},
|
||||
diarizationOut: &speechengine.DiarizationResult{},
|
||||
}
|
||||
provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake)
|
||||
|
||||
if _, err := transcribeForTest(context.Background(), provider, TaskRequest{
|
||||
JobID: "job-local-path",
|
||||
AudioPath: "/provider-input/audio/audio.wav",
|
||||
LocalAudioPath: "/tmp/scriberr-normalized/audio.wav",
|
||||
ModelID: "parakeet-v3",
|
||||
}); err != nil {
|
||||
t.Fatalf("Transcribe returned error: %v", err)
|
||||
}
|
||||
if fake.transcriptionReq.AudioPath != "/tmp/scriberr-normalized/audio.wav" {
|
||||
t.Fatalf("transcription AudioPath = %q", fake.transcriptionReq.AudioPath)
|
||||
}
|
||||
|
||||
if _, err := diarizeForTest(context.Background(), provider, TaskRequest{
|
||||
JobID: "job-local-path",
|
||||
AudioPath: "/provider-input/audio/audio.wav",
|
||||
LocalAudioPath: "/tmp/scriberr-normalized/audio.wav",
|
||||
ModelID: "diarization-default",
|
||||
}); err != nil {
|
||||
t.Fatalf("Diarize returned error: %v", err)
|
||||
}
|
||||
if fake.diarizationReq.AudioPath != "/tmp/scriberr-normalized/audio.wav" {
|
||||
t.Fatalf("diarization AudioPath = %q", fake.diarizationReq.AudioPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalProviderFallsBackToProviderAudioPath(t *testing.T) {
|
||||
fake := &fakeSpeechEngine{
|
||||
transcriptionOut: &speechengine.TranscriptionResult{Text: "hello"},
|
||||
}
|
||||
provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake)
|
||||
|
||||
if _, err := transcribeForTest(context.Background(), provider, TaskRequest{
|
||||
JobID: "job-provider-path",
|
||||
AudioPath: "/provider-input/audio/audio.wav",
|
||||
ModelID: "parakeet-v3",
|
||||
}); err != nil {
|
||||
t.Fatalf("Transcribe returned error: %v", err)
|
||||
}
|
||||
if fake.transcriptionReq.AudioPath != "/provider-input/audio/audio.wav" {
|
||||
t.Fatalf("transcription AudioPath = %q", fake.transcriptionReq.AudioPath)
|
||||
}
|
||||
}
|
||||
@@ -42,13 +42,14 @@ type SelectionRequest struct {
|
||||
}
|
||||
|
||||
type TaskRequest struct {
|
||||
JobID string
|
||||
UserID uint
|
||||
Operation asrcontract.Operation
|
||||
AudioPath string
|
||||
Progress ProgressSink
|
||||
ModelID string
|
||||
Parameters map[string]any
|
||||
JobID string
|
||||
UserID uint
|
||||
Operation asrcontract.Operation
|
||||
AudioPath string
|
||||
LocalAudioPath string
|
||||
Progress ProgressSink
|
||||
ModelID string
|
||||
Parameters map[string]any
|
||||
}
|
||||
|
||||
type TaskResult struct {
|
||||
|
||||
@@ -148,13 +148,14 @@ func (p *Processor) Process(ctx context.Context, job *models.TranscriptionJob) (
|
||||
var providerMetadata map[string]any
|
||||
for _, step := range pipeline {
|
||||
task, err := step.Provider.ExecuteTask(ctx, engineprovider.TaskRequest{
|
||||
JobID: job.ID,
|
||||
UserID: job.UserID,
|
||||
Operation: step.Operation,
|
||||
AudioPath: audio.ProviderPath,
|
||||
Progress: progressSink,
|
||||
ModelID: step.Model,
|
||||
Parameters: providerParametersForStep(step),
|
||||
JobID: job.ID,
|
||||
UserID: job.UserID,
|
||||
Operation: step.Operation,
|
||||
AudioPath: audio.ProviderPath,
|
||||
LocalAudioPath: audio.Path,
|
||||
Progress: progressSink,
|
||||
ModelID: step.Model,
|
||||
Parameters: providerParametersForStep(step),
|
||||
})
|
||||
if err != nil {
|
||||
return withExecution(p.errorResult(ctx, err))
|
||||
|
||||
@@ -517,6 +517,8 @@ func TestProcessorPassesPreprocessedAudioToProvider(t *testing.T) {
|
||||
require.Equal(t, models.StatusCompleted, result.Status)
|
||||
assert.Equal(t, "/provider-input/audio/file-orchestrator.wav", provider.transReq.AudioPath)
|
||||
assert.Equal(t, provider.transReq.AudioPath, provider.diarizeReq.AudioPath)
|
||||
assert.NotEmpty(t, provider.transReq.LocalAudioPath)
|
||||
assert.Equal(t, provider.transReq.LocalAudioPath, provider.diarizeReq.LocalAudioPath)
|
||||
assert.NotEqual(t, sourcePath, provider.transReq.AudioPath)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user