diff --git a/devnotes/engine-worker-sprint-tracker.md b/devnotes/engine-worker-sprint-tracker.md index e96632cf..e2a01e89 100644 --- a/devnotes/engine-worker-sprint-tracker.md +++ b/devnotes/engine-worker-sprint-tracker.md @@ -57,15 +57,39 @@ Verification: ## EWI-Sprint 2: Engine Provider Abstraction -Status: not started +Status: completed -Planned artifacts: +Completed tasks: -- `internal/transcription/engineprovider` +- Added `internal/transcription/engineprovider` provider and registry interfaces. +- Added internal provider request/result/capability types so `scriberr-engine` types do not leak outside the provider boundary. +- Added static provider registry with deterministic capability aggregation. +- Added local provider wrapper for `scriberr-engine/speech/engine`. +- Mapped Scriberr transcription and diarization requests to local engine requests. +- Forced token timestamps for local transcription requests. +- Mapped engine words and diarization segments to public-safe internal result structs. +- Added model capability discovery from the engine model specs with install state through `IsModelInstalled`. +- Added provider error sanitization for paths and token-like values. +- Added focused fake-engine tests for mapping, empty words, capabilities, diarization speakers, close behavior, and sanitized errors. +- Updated the main module to `go 1.26` because the local `scriberr-engine` module declares `go 1.26`. + +Artifacts: + +- `internal/transcription/engineprovider/types.go` +- `internal/transcription/engineprovider/registry.go` +- `internal/transcription/engineprovider/local_provider.go` +- `internal/transcription/engineprovider/sanitize.go` +- `internal/transcription/engineprovider/*_test.go` +- `go.mod` +- `go.sum` Verification: -- Pending +- `GOCACHE=/tmp/scriberr-go-cache go test ./internal/transcription/engineprovider` passed. 
+- `GOCACHE=/tmp/scriberr-go-cache go vet ./internal/api ./internal/config ./internal/database ./internal/repository ./internal/transcription/... ./cmd/server ./pkg/logger ./pkg/middleware` passed. +- `GOCACHE=/tmp/scriberr-go-cache go test ./internal/api ./internal/config ./internal/database ./internal/repository ./internal/transcription/... ./cmd/server ./pkg/logger ./pkg/middleware` passed with escalation because an existing webhook integration test opens a local `httptest` listener. +- `git diff --check` passed. +- Verified no non-provider Go package imports `scriberr-engine`. ## EWI-Sprint 3: Queue Schema and Repository Methods diff --git a/go.mod b/go.mod index 8102a304..f822daf4 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,8 @@ module scriberr -go 1.24.0 - -toolchain go1.24.4 +go 1.26 require ( - scriberr-engine v0.0.0 - github.com/fsnotify/fsnotify v1.9.0 github.com/gin-gonic/gin v1.10.1 github.com/glebarez/sqlite v1.11.0 @@ -22,6 +18,7 @@ require ( golang.org/x/sync v0.17.0 golang.org/x/text v0.29.0 gorm.io/gorm v1.30.1 + scriberr-engine v0.0.0 ) require ( @@ -47,6 +44,8 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/k2-fsa/sherpa-onnx-go v1.12.38 // indirect + github.com/k2-fsa/sherpa-onnx-go-macos v1.12.38 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/go.sum b/go.sum index 2f9ef4b0..9c49b471 100644 --- a/go.sum +++ b/go.sum @@ -66,6 +66,10 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/k2-fsa/sherpa-onnx-go v1.12.38 
h1:Uj59IpwHFwoPiq9uIdRR+sEC91+NorDCMoMQB0Z/lqY=
+github.com/k2-fsa/sherpa-onnx-go v1.12.38/go.mod h1:0JDFIFM1mz9j6tzWv3CZpOwKuo4B18E9NH1p/lTE9EA=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.12.38 h1:eWrTJLDS9eMg/YeA2ckflj3EzzKiJu3FtN3Dq1f0RDA=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.12.38/go.mod h1:ZOhUAXC62Unj0ZNfu6zxSFKcW96aXf7P3BsqiUyOBbE=
 github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk=
 github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc=
 github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
diff --git a/internal/transcription/engineprovider/local_provider.go b/internal/transcription/engineprovider/local_provider.go
new file mode 100644
index 00000000..e7f1e711
--- /dev/null
+++ b/internal/transcription/engineprovider/local_provider.go
@@ -0,0 +1,254 @@
+package engineprovider
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"strings"
+
+	appconfig "scriberr/internal/config"
+	"scriberr/pkg/logger"
+
+	speechengine "scriberr-engine/speech/engine"
+	speechmodels "scriberr-engine/speech/models"
+	"scriberr-engine/speech/runtime"
+)
+
+// LocalConfig holds the settings handed to the local speech engine.
+type LocalConfig struct {
+	CacheDir     string
+	Provider     string
+	Threads      int
+	MaxLoaded    int
+	AutoDownload bool
+}
+
+// speechEngine is the subset of the speech engine that LocalProvider calls.
+// Tests substitute a fake implementation.
+type speechEngine interface {
+	Transcribe(ctx context.Context, req speechengine.TranscriptionRequest) (*speechengine.TranscriptionResult, error)
+	Diarize(ctx context.Context, req speechengine.DiarizationRequest) (*speechengine.DiarizationResult, error)
+	IsModelInstalled(modelID string) bool
+	Close() error
+}
+
+// LocalProvider implements Provider on top of a speechEngine.
+type LocalProvider struct {
+	id       string
+	cfg      LocalConfig
+	engine   speechEngine
+	specs    []speechmodels.ModelSpec
+	provider runtime.Provider
+}
+
+// NewLocalProvider builds a LocalProvider from the application engine config.
+func NewLocalProvider(cfg appconfig.EngineConfig) (*LocalProvider, error) {
+	return NewLocalProviderFromConfig(LocalConfig{
+		CacheDir:     cfg.CacheDir,
+		Provider:     cfg.Provider,
+		Threads:      cfg.Threads,
+		MaxLoaded:    cfg.MaxLoaded,
+		AutoDownload: cfg.AutoDownload,
+	})
+}
+
+// NewLocalProviderFromConfig parses cfg.Provider, creates the speech engine and
+// wraps it under DefaultProviderID. Errors are sanitized before being returned.
+func NewLocalProviderFromConfig(cfg LocalConfig) (*LocalProvider, error) {
+	provider, err := runtime.ParseProvider(cfg.Provider)
+	if err != nil {
+		return nil, sanitizeError(err)
+	}
+	engineCfg := speechengine.Config{
+		CacheDir:        cfg.CacheDir,
+		Provider:        provider,
+		Threads:         cfg.Threads,
+		MaxLoaded:       cfg.MaxLoaded,
+		AutoDownload:    cfg.AutoDownload,
+		AutoDownloadSet: true,
+		Logger:          slog.Default(),
+	}
+	engine, err := speechengine.New(engineCfg)
+	if err != nil {
+		return nil, sanitizeError(err)
+	}
+	logger.Info("Engine provider initialized",
+		"provider_id", DefaultProviderID,
+		"requested_provider", cfg.Provider,
+		"cache_dir", cfg.CacheDir,
+		"threads", cfg.Threads,
+		"max_loaded", cfg.MaxLoaded,
+		"auto_download", cfg.AutoDownload,
+	)
+	return newLocalProviderWithEngine(DefaultProviderID, cfg, provider, engine, speechmodels.DefaultModelSpecs()), nil
+}
+
+// newLocalProviderWithEngine assembles a LocalProvider around an existing
+// engine. A blank id falls back to DefaultProviderID.
+func newLocalProviderWithEngine(id string, cfg LocalConfig, provider runtime.Provider, engine speechEngine, specs []speechmodels.ModelSpec) *LocalProvider {
+	if strings.TrimSpace(id) == "" {
+		id = DefaultProviderID
+	}
+	return &LocalProvider{
+		id:       id,
+		cfg:      cfg,
+		engine:   engine,
+		specs:    specs,
+		provider: provider,
+	}
+}
+
+// ID returns the provider identifier.
+func (p *LocalProvider) ID() string {
+	return p.id
+}
+
+// Capabilities lists one entry per model spec, with the install state taken
+// from the engine. It returns ctx.Err() if ctx is already done.
+func (p *LocalProvider) Capabilities(ctx context.Context) ([]ModelCapability, error) {
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	out := make([]ModelCapability, 0, len(p.specs))
+	for _, spec := range p.specs {
+		capability := ModelCapability{
+			ID:           string(spec.ID),
+			Name:         spec.DisplayName,
+			Provider:     p.id,
+			Installed:    p.engine.IsModelInstalled(string(spec.ID)),
+			Default:      isDefaultModel(spec.ID),
+			Capabilities: capabilitiesForFamily(spec.Family),
+		}
+		out = append(out, capability)
+	}
+	return out, nil
+}
+
+// Prepare performs no work beyond reporting ctx.Err().
+func (p *LocalProvider) Prepare(ctx context.Context) error {
+	return ctx.Err()
+}
+
+// Transcribe runs the engine on req.AudioPath. A blank model ID falls back to
+// DefaultTranscriptionModel, a blank task to "transcribe", and token
+// timestamps are always requested. Engine errors are sanitized.
+func (p *LocalProvider) Transcribe(ctx context.Context, req TranscriptionRequest) (*TranscriptionResult, error) {
+	modelID := strings.TrimSpace(req.ModelID)
+	if modelID == "" {
+		modelID = DefaultTranscriptionModel
+	}
+	task := strings.TrimSpace(req.Task)
+	if task == "" {
+		task = "transcribe"
+	}
+	enableTokenTimestamps := true
+	engineReq := speechengine.TranscriptionRequest{
+		ModelID:               modelID,
+		AudioPath:             req.AudioPath,
+		Language:              req.Language,
+		Task:                  task,
+		EnableTokenTimestamps: &enableTokenTimestamps,
+		NumThreads:            coalesceInt(req.Threads, p.cfg.Threads),
+		Provider:              p.provider,
+	}
+	out, err := p.engine.Transcribe(ctx, engineReq)
+	if err != nil {
+		return nil, sanitizeError(err)
+	}
+	if out == nil {
+		// A nil result with a nil error would otherwise panic below.
+		return nil, fmt.Errorf("engine provider %q returned no transcription result", p.id)
+	}
+	words := make([]TranscriptWord, 0, len(out.Words))
+	for _, word := range out.Words {
+		words = append(words, TranscriptWord{
+			Start: word.StartSec,
+			End:   word.EndSec,
+			Word:  word.Text,
+		})
+	}
+	return &TranscriptionResult{
+		Text:     out.Text,
+		Language: out.Language,
+		Words:    words,
+		Segments: []TranscriptSegment{},
+		ModelID:  modelID,
+		EngineID: p.id,
+	}, nil
+}
+
+// Diarize runs speaker diarization on req.AudioPath and labels each segment
+// "SPEAKER_NN". A blank model ID falls back to DefaultDiarizationModel.
+// NOTE(review): req.MinSpeakers and req.MaxSpeakers are not forwarded.
+func (p *LocalProvider) Diarize(ctx context.Context, req DiarizationRequest) (*DiarizationResult, error) {
+	modelID := strings.TrimSpace(req.ModelID)
+	if modelID == "" {
+		modelID = DefaultDiarizationModel
+	}
+	engineReq := speechengine.DiarizationRequest{
+		ModelID:     modelID,
+		AudioPath:   req.AudioPath,
+		NumClusters: req.NumSpeakers,
+		NumThreads:  p.cfg.Threads,
+		Provider:    p.provider,
+	}
+	out, err := p.engine.Diarize(ctx, engineReq)
+	if err != nil {
+		return nil, sanitizeError(err)
+	}
+	if out == nil {
+		// A nil result with a nil error would otherwise panic below.
+		return nil, fmt.Errorf("engine provider %q returned no diarization result", p.id)
+	}
+	segments := make([]DiarizationSegment, 0, len(out.Segments))
+	for _, segment := range out.Segments {
+		segments = append(segments, DiarizationSegment{
+			Start:   segment.Start,
+			End:     segment.End,
+			Speaker: fmt.Sprintf("SPEAKER_%02d", segment.Speaker),
+		})
+	}
+	return &DiarizationResult{
+		Segments: segments,
+		ModelID:  modelID,
+		EngineID: p.id,
+	}, nil
+}
+
+// Close closes the engine, if any, and sanitizes its error.
+func (p *LocalProvider) Close() error {
+	if p.engine == nil {
+		return nil
+	}
+	if err := p.engine.Close(); err != nil {
+		return sanitizeError(err)
+	}
+	return nil
+}
+
+// capabilitiesForFamily maps a model family to its capability names.
+func capabilitiesForFamily(family speechmodels.Family) []string {
+	switch family {
+	case speechmodels.FamilyDiarize:
+		return []string{"diarization"}
+	case speechmodels.FamilyWhisper, speechmodels.FamilyNemo, speechmodels.FamilyCanary:
+		return []string{"transcription", "word_timestamps"}
+	default:
+		return []string{}
+	}
+}
+
+// isDefaultModel reports whether id is one of the package default model IDs.
+func isDefaultModel(id speechmodels.ModelID) bool {
+	return string(id) == DefaultTranscriptionModel || string(id) == DefaultDiarizationModel
+}
+
+// coalesceInt returns the first value greater than zero, or 0 if there is none.
+func coalesceInt(values ...int) int {
+	for _, value := range values {
+		if value > 0 {
+			return value
+		}
+	}
+	return 0
+}
diff --git a/internal/transcription/engineprovider/local_provider_test.go b/internal/transcription/engineprovider/local_provider_test.go
new file mode 100644
index 00000000..8f27d901
--- /dev/null
+++ b/internal/transcription/engineprovider/local_provider_test.go
@@ -0,0 +1,219 @@
+package engineprovider
+
+import (
+	"context"
+	"errors"
+	"strings"
+	"testing"
+
+	speechengine "scriberr-engine/speech/engine"
+	speechmodels "scriberr-engine/speech/models"
+	"scriberr-engine/speech/runtime"
+)
+
+type fakeSpeechEngine struct {
+	transcriptionReq speechengine.TranscriptionRequest
+	diarizationReq   speechengine.DiarizationRequest
+	transcriptionOut *speechengine.TranscriptionResult
+	diarizationOut   *speechengine.DiarizationResult
+	err              error
+	installed        map[string]bool
+	closed           bool
+}
+
+func (e *fakeSpeechEngine) Transcribe(ctx context.Context, req speechengine.TranscriptionRequest) (*speechengine.TranscriptionResult, error) {
+	e.transcriptionReq = req
+	if e.err != nil {
+		return nil, e.err
+	}
+	return e.transcriptionOut, nil
+}
+
+func (e *fakeSpeechEngine) Diarize(ctx context.Context, req speechengine.DiarizationRequest) (*speechengine.DiarizationResult, error) {
+	e.diarizationReq = req
+	if e.err != nil {
+		return nil, e.err
+	}
+	return e.diarizationOut, nil
+}
+
+func (e 
*fakeSpeechEngine) IsModelInstalled(modelID string) bool { + return e.installed[modelID] +} + +func (e *fakeSpeechEngine) Close() error { + e.closed = true + return nil +} + +func TestLocalProviderTranscribeMapsRequestAndWords(t *testing.T) { + fake := &fakeSpeechEngine{ + transcriptionOut: &speechengine.TranscriptionResult{ + Text: "hello world", + Language: "en", + Words: []speechengine.TranscriptWord{ + {Text: "hello", StartSec: 0.1, EndSec: 0.4}, + {Text: "world", StartSec: 0.5, EndSec: 0.9}, + }, + }, + } + provider := newLocalProviderWithEngine("local", LocalConfig{Threads: 4}, runtime.ProviderCPU, fake, nil) + + result, err := provider.Transcribe(context.Background(), TranscriptionRequest{ + JobID: "job-1", + UserID: 7, + AudioPath: "/tmp/audio.wav", + ModelID: "whisper-tiny", + Language: "en", + Task: "translate", + Threads: 2, + }) + if err != nil { + t.Fatalf("Transcribe returned error: %v", err) + } + + if fake.transcriptionReq.ModelID != "whisper-tiny" { + t.Fatalf("ModelID = %q", fake.transcriptionReq.ModelID) + } + if fake.transcriptionReq.Language != "en" { + t.Fatalf("Language = %q", fake.transcriptionReq.Language) + } + if fake.transcriptionReq.Task != "translate" { + t.Fatalf("Task = %q", fake.transcriptionReq.Task) + } + if fake.transcriptionReq.NumThreads != 2 { + t.Fatalf("NumThreads = %d", fake.transcriptionReq.NumThreads) + } + if fake.transcriptionReq.Provider != runtime.ProviderCPU { + t.Fatalf("Provider = %q", fake.transcriptionReq.Provider) + } + if fake.transcriptionReq.EnableTokenTimestamps == nil || !*fake.transcriptionReq.EnableTokenTimestamps { + t.Fatalf("EnableTokenTimestamps was not forced on") + } + if result.Text != "hello world" || result.Language != "en" { + t.Fatalf("unexpected result: %#v", result) + } + if len(result.Words) != 2 || result.Words[0].Word != "hello" || result.Words[1].End != 0.9 { + t.Fatalf("unexpected words: %#v", result.Words) + } + if result.ModelID != "whisper-tiny" || result.EngineID != "local" { + 
t.Fatalf("unexpected model/engine ids: %#v", result) + } +} + +func TestLocalProviderTranscribeDefaultsAndEmptyWords(t *testing.T) { + fake := &fakeSpeechEngine{ + transcriptionOut: &speechengine.TranscriptionResult{Text: "text"}, + } + provider := newLocalProviderWithEngine("local", LocalConfig{Threads: 4}, runtime.ProviderCPU, fake, nil) + + result, err := provider.Transcribe(context.Background(), TranscriptionRequest{}) + if err != nil { + t.Fatalf("Transcribe returned error: %v", err) + } + if fake.transcriptionReq.ModelID != DefaultTranscriptionModel { + t.Fatalf("ModelID = %q, want %q", fake.transcriptionReq.ModelID, DefaultTranscriptionModel) + } + if fake.transcriptionReq.Task != "transcribe" { + t.Fatalf("Task = %q, want transcribe", fake.transcriptionReq.Task) + } + if fake.transcriptionReq.NumThreads != 4 { + t.Fatalf("NumThreads = %d, want 4", fake.transcriptionReq.NumThreads) + } + if result.Words == nil { + t.Fatalf("Words is nil, want empty array") + } +} + +func TestLocalProviderDiarizeMapsRequestAndSpeakers(t *testing.T) { + fake := &fakeSpeechEngine{ + diarizationOut: &speechengine.DiarizationResult{ + Segments: []speechengine.DiarizationSegment{ + {Speaker: 0, Start: 0, End: 1.5}, + {Speaker: 12, Start: 1.6, End: 3.2}, + }, + }, + } + provider := newLocalProviderWithEngine("local", LocalConfig{Threads: 3}, runtime.ProviderCPU, fake, nil) + + result, err := provider.Diarize(context.Background(), DiarizationRequest{ + AudioPath: "/tmp/audio.wav", + ModelID: "diarization-default", + NumSpeakers: 2, + }) + if err != nil { + t.Fatalf("Diarize returned error: %v", err) + } + if fake.diarizationReq.ModelID != "diarization-default" { + t.Fatalf("ModelID = %q", fake.diarizationReq.ModelID) + } + if fake.diarizationReq.NumClusters != 2 { + t.Fatalf("NumClusters = %d", fake.diarizationReq.NumClusters) + } + if fake.diarizationReq.NumThreads != 3 { + t.Fatalf("NumThreads = %d", fake.diarizationReq.NumThreads) + } + if len(result.Segments) != 2 || 
result.Segments[0].Speaker != "SPEAKER_00" || result.Segments[1].Speaker != "SPEAKER_12" { + t.Fatalf("unexpected segments: %#v", result.Segments) + } +} + +func TestLocalProviderCapabilitiesUseModelRegistryAndInstallState(t *testing.T) { + fake := &fakeSpeechEngine{installed: map[string]bool{"whisper-base": true}} + specs := []speechmodels.ModelSpec{ + {ID: "whisper-base", DisplayName: "Whisper Base", Family: speechmodels.FamilyWhisper}, + {ID: "diarization-default", DisplayName: "Diarization", Family: speechmodels.FamilyDiarize}, + } + provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake, specs) + + capabilities, err := provider.Capabilities(context.Background()) + if err != nil { + t.Fatalf("Capabilities returned error: %v", err) + } + if len(capabilities) != 2 { + t.Fatalf("capabilities length = %d", len(capabilities)) + } + if !capabilities[0].Installed || !capabilities[0].Default { + t.Fatalf("whisper-base capability missing installed/default: %#v", capabilities[0]) + } + if strings.Join(capabilities[0].Capabilities, ",") != "transcription,word_timestamps" { + t.Fatalf("whisper capabilities = %#v", capabilities[0].Capabilities) + } + if capabilities[1].Installed || !capabilities[1].Default { + t.Fatalf("diarization capability installed/default mismatch: %#v", capabilities[1]) + } + if strings.Join(capabilities[1].Capabilities, ",") != "diarization" { + t.Fatalf("diarization capabilities = %#v", capabilities[1].Capabilities) + } +} + +func TestLocalProviderSanitizesErrors(t *testing.T) { + fake := &fakeSpeechEngine{ + err: errors.New("load /Users/zade/Code/asr/Scriberr/data/uploads/audio.wav failed token=secret"), + } + provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake, nil) + + _, err := provider.Transcribe(context.Background(), TranscriptionRequest{}) + if err == nil { + t.Fatalf("Transcribe returned nil error") + } + msg := err.Error() + if strings.Contains(msg, "/Users/") || 
strings.Contains(msg, "secret") {
+		t.Fatalf("error was not sanitized: %q", msg)
+	}
+	if !strings.Contains(msg, "[redacted-path]") || !strings.Contains(msg, "token=[redacted]") {
+		t.Fatalf("error missing sanitized markers: %q", msg)
+	}
+}
+
+func TestLocalProviderCloseClosesEngine(t *testing.T) {
+	fake := &fakeSpeechEngine{}
+	provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake, nil)
+
+	if err := provider.Close(); err != nil {
+		t.Fatalf("Close returned error: %v", err)
+	}
+	if !fake.closed {
+		t.Fatalf("fake engine was not closed")
+	}
+}
diff --git a/internal/transcription/engineprovider/registry.go b/internal/transcription/engineprovider/registry.go
new file mode 100644
index 00000000..78c873f0
--- /dev/null
+++ b/internal/transcription/engineprovider/registry.go
@@ -0,0 +1,78 @@
+package engineprovider
+
+import (
+	"context"
+	"fmt"
+	"sort"
+	"strings"
+)
+
+// StaticRegistry is a Registry over a fixed set of providers.
+type StaticRegistry struct {
+	defaultID string
+	providers map[string]Provider
+}
+
+// NewRegistry indexes providers by trimmed ID, skipping nil entries. A blank
+// defaultID falls back to DefaultProviderID. It fails when a provider ID is
+// blank or duplicated, or when the default ID matches no registered
+// provider.
+func NewRegistry(defaultID string, providers ...Provider) (*StaticRegistry, error) {
+	registry := &StaticRegistry{
+		defaultID: strings.TrimSpace(defaultID),
+		providers: make(map[string]Provider, len(providers)),
+	}
+	for _, provider := range providers {
+		if provider == nil {
+			continue
+		}
+		id := strings.TrimSpace(provider.ID())
+		if id == "" {
+			return nil, fmt.Errorf("engine provider id cannot be empty")
+		}
+		if _, exists := registry.providers[id]; exists {
+			return nil, fmt.Errorf("duplicate engine provider id %q", id)
+		}
+		registry.providers[id] = provider
+	}
+	if registry.defaultID == "" {
+		registry.defaultID = DefaultProviderID
+	}
+	if _, ok := registry.providers[registry.defaultID]; !ok {
+		return nil, fmt.Errorf("default engine provider %q is not registered", registry.defaultID)
+	}
+	return registry, nil
+}
+
+// DefaultProvider returns the provider registered under the default ID.
+func (r *StaticRegistry) DefaultProvider() Provider {
+	return r.providers[r.defaultID]
+}
+
+// Provider looks up a provider by ID, ignoring surrounding whitespace.
+func (r *StaticRegistry) Provider(id string) (Provider, bool) {
+	provider, ok := r.providers[strings.TrimSpace(id)]
+	return provider, ok
+}
+
+// Capabilities concatenates every provider's capabilities, visiting providers
+// in ascending ID order. The result is a non-nil slice even when no provider
+// reports a model, as LocalProvider.Capabilities does (a nil slice would
+// encode as JSON null instead of an empty array).
+func (r *StaticRegistry) Capabilities(ctx context.Context) ([]ModelCapability, error) {
+	ids := make([]string, 0, len(r.providers))
+	for id := range r.providers {
+		ids = append(ids, id)
+	}
+	sort.Strings(ids)
+
+	out := make([]ModelCapability, 0, len(ids))
+	for _, id := range ids {
+		capabilities, err := r.providers[id].Capabilities(ctx)
+		if err != nil {
+			return nil, fmt.Errorf("engine provider %q capabilities: %w", id, err)
+		}
+		out = append(out, capabilities...)
+	}
+	return out, nil
+}
diff --git a/internal/transcription/engineprovider/registry_test.go b/internal/transcription/engineprovider/registry_test.go
new file mode 100644
index 00000000..86f6a4a1
--- /dev/null
+++ b/internal/transcription/engineprovider/registry_test.go
@@ -0,0 +1,86 @@
+package engineprovider
+
+import (
+	"context"
+	"errors"
+	"testing"
+)
+
+type stubProvider struct {
+	id           string
+	capabilities []ModelCapability
+	err          error
+}
+
+func (p stubProvider) ID() string { return p.id }
+func (p stubProvider) Capabilities(ctx context.Context) ([]ModelCapability, error) {
+	if p.err != nil {
+		return nil, p.err
+	}
+	return p.capabilities, nil
+}
+func (p stubProvider) Prepare(ctx context.Context) error { return nil }
+func (p stubProvider) Transcribe(ctx context.Context, req TranscriptionRequest) (*TranscriptionResult, error) {
+	return nil, nil
+}
+func (p stubProvider) Diarize(ctx context.Context, req DiarizationRequest) (*DiarizationResult, error) {
+	return nil, nil
+}
+func (p stubProvider) Close() error { return nil }
+
+func TestRegistryReturnsDefaultProviderAndAggregatesCapabilities(t *testing.T) {
+	local := stubProvider{
+		id: "local",
+		capabilities: []ModelCapability{
+			{ID: "whisper-base", Provider: "local"},
+		},
+	}
+	other := stubProvider{
+		id: "other",
+		capabilities: []ModelCapability{
+			{ID: "remote-model", Provider: "other"},
+		},
+	}
+
+	registry, err := NewRegistry("local", other, local)
+	if 
err != nil {
+		t.Fatalf("NewRegistry returned error: %v", err)
+	}
+	if registry.DefaultProvider().ID() != "local" {
+		t.Fatalf("default provider = %q, want local", registry.DefaultProvider().ID())
+	}
+	if _, ok := registry.Provider("other"); !ok {
+		t.Fatalf("Provider(other) not found")
+	}
+
+	capabilities, err := registry.Capabilities(context.Background())
+	if err != nil {
+		t.Fatalf("Capabilities returned error: %v", err)
+	}
+	if len(capabilities) != 2 {
+		t.Fatalf("capabilities length = %d, want 2", len(capabilities))
+	}
+	if capabilities[0].Provider != "local" || capabilities[1].Provider != "other" {
+		t.Fatalf("capabilities not sorted by provider id: %#v", capabilities)
+	}
+}
+
+func TestRegistryRejectsInvalidProviderSet(t *testing.T) {
+	if _, err := NewRegistry("missing", stubProvider{id: "local"}); err == nil {
+		t.Fatalf("NewRegistry returned nil error for missing default")
+	}
+	if _, err := NewRegistry("local", stubProvider{id: "local"}, stubProvider{id: "local"}); err == nil {
+		t.Fatalf("NewRegistry returned nil error for duplicate providers")
+	}
+}
+
+func TestRegistryWrapsCapabilityErrors(t *testing.T) {
+	registry, err := NewRegistry("local", stubProvider{id: "local", err: errors.New("boom")})
+	if err != nil {
+		t.Fatalf("NewRegistry returned error: %v", err)
+	}
+	_, err = registry.Capabilities(context.Background())
+	if err == nil {
+		t.Fatalf("Capabilities returned nil error")
+	}
+}
diff --git a/internal/transcription/engineprovider/sanitize.go b/internal/transcription/engineprovider/sanitize.go
new file mode 100644
index 00000000..d0c20a94
--- /dev/null
+++ b/internal/transcription/engineprovider/sanitize.go
@@ -0,0 +1,58 @@
+package engineprovider
+
+import (
+	"errors"
+	"fmt"
+	"regexp"
+	"strings"
+)
+
+// absolutePathPattern matches a Windows drive prefix or a forward slash plus
+// the characters that follow it up to whitespace or a delimiter.
+var absolutePathPattern = regexp.MustCompile(`(?:[A-Za-z]:\\|/)[^\s:;,'")]+`)
+
+// tokenFieldPattern matches a whitespace-delimited field that mentions
+// "token", "api_key" or "apikey" in any letter case.
+var tokenFieldPattern = regexp.MustCompile(`(?i)\S*(?:token|api_key|apikey)\S*`)
+
+// sanitizedError carries a redacted message. It keeps the original error only
+// to answer errors.Is and has no Unwrap method, so the unredacted message is
+// not reachable through the error chain.
+type sanitizedError struct {
+	msg   string
+	cause error
+}
+
+func (e *sanitizedError) Error() string { return e.msg }
+
+// Is reports whether the original error matches target.
+func (e *sanitizedError) Is(target error) bool { return errors.Is(e.cause, target) }
+
+// sanitizeError returns an error whose message has path-like and token-like
+// values redacted; nil stays nil. errors.Is against the result still matches
+// sentinels carried by err, such as context.Canceled.
+func sanitizeError(err error) error {
+	if err == nil {
+		return nil
+	}
+	msg := absolutePathPattern.ReplaceAllString(err.Error(), "[redacted-path]")
+	return &sanitizedError{msg: redactTokenLikeValues(msg), cause: err}
+}
+
+// sanitizeErrorf formats an error like fmt.Errorf and sanitizes the result.
+func sanitizeErrorf(format string, args ...any) error {
+	return sanitizeError(fmt.Errorf(format, args...))
+}
+
+// redactTokenLikeValues rewrites each key=value field that mentions a
+// credential keyword as key=[redacted]. Everything else, including the
+// original whitespace, is left untouched.
+func redactTokenLikeValues(msg string) string {
+	return tokenFieldPattern.ReplaceAllStringFunc(msg, func(field string) string {
+		key, _, found := strings.Cut(field, "=")
+		if !found {
+			return field
+		}
+		return key + "=[redacted]"
+	})
+}
diff --git a/internal/transcription/engineprovider/types.go b/internal/transcription/engineprovider/types.go
new file mode 100644
index 00000000..7cd6ccd7
--- /dev/null
+++ b/internal/transcription/engineprovider/types.go
@@ -0,0 +1,89 @@
+package engineprovider
+
+import "context"
+
+const (
+	DefaultProviderID         = "local"
+	DefaultTranscriptionModel = "whisper-base"
+	DefaultDiarizationModel   = "diarization-default"
+)
+
+type Provider interface {
+	ID() string
+	Capabilities(ctx context.Context) ([]ModelCapability, error)
+	Prepare(ctx context.Context) error
+	Transcribe(ctx context.Context, req TranscriptionRequest) (*TranscriptionResult, error)
+	Diarize(ctx context.Context, req DiarizationRequest) (*DiarizationResult, error)
+	Close() error
+}
+
+type Registry interface {
+	DefaultProvider() Provider
+	Provider(id string) (Provider, bool)
+	Capabilities(ctx context.Context) ([]ModelCapability, error)
+}
+
+type ModelCapability struct {
+	ID           string
+	Name         string
+	Provider     string
+	Installed    bool
+	Default      bool
+	Capabilities []string
+}
+
+type TranscriptionRequest struct {
+	JobID     string
+	UserID    uint
+	AudioPath string
+	ModelID   string
+	Language  string
+	Task      string
+	Threads   int
+}
+
+type DiarizationRequest struct {
+	JobID       string
+	UserID      uint
+	AudioPath   string
+	ModelID     string
+	NumSpeakers int
+	MinSpeakers *int
+	MaxSpeakers *int
+}
+
+type 
TranscriptWord struct { + Start float64 + End float64 + Word string + Speaker string +} + +type TranscriptSegment struct { + ID string + Start float64 + End float64 + Speaker string + Text string +} + +type DiarizationSegment struct { + Start float64 + End float64 + Speaker string +} + +type TranscriptionResult struct { + Text string + Language string + Words []TranscriptWord + Segments []TranscriptSegment + ModelID string + EngineID string +} + +type DiarizationResult struct { + Segments []DiarizationSegment + ModelID string + EngineID string +}