From 4dcb23cc41a5f07fc9d95c1943ec10c8563e7e0c Mon Sep 17 00:00:00 2001 From: rishikanthc Date: Wed, 6 May 2026 12:23:50 -0700 Subject: [PATCH] Consume descriptor-first engine inventory --- .../engineprovider/local_provider.go | 149 ++++++------------ .../engineprovider/local_provider_test.go | 133 +++++----------- 2 files changed, 89 insertions(+), 193 deletions(-) diff --git a/internal/transcription/engineprovider/local_provider.go b/internal/transcription/engineprovider/local_provider.go index de7337c1..8b75639a 100644 --- a/internal/transcription/engineprovider/local_provider.go +++ b/internal/transcription/engineprovider/local_provider.go @@ -25,7 +25,7 @@ type LocalConfig struct { type speechEngine interface { Inspect(ctx context.Context) (*speechengine.ProviderInfo, error) - Models(ctx context.Context) ([]speechengine.ModelCard, error) + Models(ctx context.Context) ([]speechproviders.ModelDescriptor, error) Status(ctx context.Context) (*speechengine.ProviderStatus, error) LoadedModels() []speechengine.LoadedModel Transcribe(ctx context.Context, req speechengine.TranscriptionRequest) (*speechengine.TranscriptionResult, error) @@ -115,7 +115,7 @@ func (p *LocalProvider) Models(ctx context.Context) ([]asrcontract.ModelCard, er } out := make([]asrcontract.ModelCard, 0, len(models)) for _, model := range models { - out = append(out, modelCardFromEngine(model, p.id, p.cfg, p.provider)) + out = append(out, modelCardFromEngine(model, p.id)) } return out, nil } @@ -330,26 +330,22 @@ func providerInfoFromEngine(info *speechengine.ProviderInfo, providerID string) } } -func modelCardFromEngine(model speechengine.ModelCard, providerID string, cfg LocalConfig, provider runtime.Provider) asrcontract.ModelCard { - capabilities := capabilitiesFromEngine(model.Capabilities) - descriptor := model.Descriptor - if strings.TrimSpace(descriptor.ID) == "" { - return fallbackModelCardFromEngine(model, providerID, capabilities) - } +func modelCardFromEngine(descriptor speechproviders.ModelDescriptor, providerID string) asrcontract.ModelCard { + capabilities := capabilitiesFromDescriptor(descriptor) return asrcontract.ModelCard{ - ID: firstNonEmpty(descriptor.ID, model.ID), - DisplayName: firstNonEmpty(descriptor.DisplayName, model.DisplayName), + ID: descriptor.ID, + DisplayName: descriptor.DisplayName, Provider: providerID, - Family: firstNonEmpty(descriptor.Family, model.Family), - Version: firstNonEmpty(descriptor.Version, model.Version), - Installed: model.Installed, - Loaded: model.Loaded, - Default: model.Default, - Tasks: tasksFromDescriptor(descriptor.Tasks, model.Tasks), - Languages: languageIDsFromDescriptor(descriptor.Languages, model.Languages), - LanguageSupport: languageSupportFromDescriptor(descriptor.Languages, model.Languages), + Family: descriptor.Family, + Version: descriptor.Version, + Installed: descriptor.Installed, + Loaded: descriptor.Loaded, + Default: descriptor.Default, + Tasks: tasksFromDescriptor(descriptor.Tasks), + Languages: languageIDsFromDescriptor(descriptor.Languages), + LanguageSupport: languageSupportFromDescriptor(descriptor.Languages), Capabilities: capabilities, - ResourceRequirements: resourceRequirementsFromDescriptor(descriptor.Runtime, model.ResourceRequirements), + ResourceRequirements: resourceRequirementsFromDescriptor(descriptor.Runtime), Chunking: chunkingCapabilitiesFromDescriptor(descriptor.Chunking, descriptor.Runtime, capabilities), ParameterSchema: parameterSchemaFromDescriptor(descriptor.Parameters), RecommendedDefaults: copyRecommendedDefaults(descriptor.RecommendedDefaults), @@ -357,39 +353,7 @@ func modelCardFromEngine(model speechengine.ModelCard, providerID string, cfg Lo } } -func fallbackModelCardFromEngine(model speechengine.ModelCard, providerID string, capabilities asrcontract.Capabilities) asrcontract.ModelCard { - return asrcontract.ModelCard{ - ID: model.ID, - DisplayName: model.DisplayName, - Provider: providerID, - Family: model.Family, - Version: model.Version, - Installed: model.Installed, - Loaded: model.Loaded, - Default: model.Default, - Tasks: tasksFromEngine(model.Tasks), - Languages: append([]string(nil), model.Languages...), - LanguageSupport: fallbackLanguageSupport(model.Languages), - Capabilities: capabilities, - ResourceRequirements: resourceRequirementsFromEngine(model.ResourceRequirements), - } -} - -func fallbackLanguageSupport(languages []string) *asrcontract.LanguageSupport { - if len(languages) == 0 { - return nil - } - mode := "fixed" - if len(languages) > 1 { - mode = "configurable" - } - return &asrcontract.LanguageSupport{Languages: append([]string(nil), languages...), Mode: mode} -} - -func languageSupportFromDescriptor(languages []speechproviders.LanguageSupport, fallback []string) *asrcontract.LanguageSupport { - if len(languages) == 0 { - return fallbackLanguageSupport(fallback) - } +func languageSupportFromDescriptor(languages []speechproviders.LanguageSupport) *asrcontract.LanguageSupport { ids := make([]string, 0, len(languages)) mode := "" for _, language := range languages { @@ -409,10 +373,7 @@ func languageSupportFromDescriptor(languages []speechproviders.LanguageSupport, return &asrcontract.LanguageSupport{Languages: ids, Mode: mode} } -func languageIDsFromDescriptor(languages []speechproviders.LanguageSupport, fallback []string) []string { - if len(languages) == 0 { - return append([]string(nil), fallback...) - } +func languageIDsFromDescriptor(languages []speechproviders.LanguageSupport) []string { out := make([]string, 0, len(languages)) for _, language := range languages { if strings.TrimSpace(language.ID) != "" { @@ -422,10 +383,7 @@ func languageIDsFromDescriptor(languages []speechproviders.LanguageSupport, fall return out } -func tasksFromDescriptor(tasks []speechproviders.TaskDescriptor, fallback []speechengine.Task) []asrcontract.Task { - if len(tasks) == 0 { - return tasksFromEngine(fallback) - } +func tasksFromDescriptor(tasks []speechproviders.TaskDescriptor) []asrcontract.Task { out := make([]asrcontract.Task, 0, len(tasks)) for _, task := range tasks { switch task.Kind { @@ -440,10 +398,7 @@ func tasksFromDescriptor(tasks []speechproviders.TaskDescriptor, fallback []spee return out } -func resourceRequirementsFromDescriptor(runtime speechproviders.RuntimeCapabilities, fallback speechengine.ResourceRequirements) asrcontract.ResourceRequirements { - if len(runtime.Backends) == 0 { - return resourceRequirementsFromEngine(fallback) - } +func resourceRequirementsFromDescriptor(runtime speechproviders.RuntimeCapabilities) asrcontract.ResourceRequirements { backends := make([]string, 0, len(runtime.Backends)) for _, backend := range runtime.Backends { backends = append(backends, string(backend)) @@ -547,15 +502,6 @@ func descriptorExtensionsFromEngine(descriptor speechproviders.ModelDescriptor) return extensions } -func firstNonEmpty(values ...string) string { - for _, value := range values { - if strings.TrimSpace(value) != "" { - return strings.TrimSpace(value) - } - } - return "" -} - func cloneFloat64(value *float64) *float64 { if value == nil { return nil @@ -607,36 +553,45 @@ func loadedModelFromEngine(model speechengine.LoadedModel) asrcontract.LoadedMod } } -func providerCapabilitiesFromEngine(capabilities []speechengine.Capability) []asrcontract.Capability { +func providerCapabilitiesFromEngine(capabilities []speechproviders.TaskKind) []asrcontract.Capability { out := make([]asrcontract.Capability, 0, len(capabilities)) for _, capability := range capabilities { - out = append(out, asrcontract.Capability(capability)) + switch capability { + case speechproviders.TaskTranscription: + out = append(out, asrcontract.CapabilityTranscription) + case speechproviders.TaskDiarization: + out = append(out, asrcontract.CapabilityDiarization) + case speechproviders.TaskSpeakerIdentification: + out = append(out, asrcontract.CapabilitySpeakerIdentification) + case speechproviders.TaskTranslation: + out = append(out, asrcontract.CapabilityTranslation) + default: + out = append(out, asrcontract.Capability(capability)) + } } return out } -func tasksFromEngine(tasks []speechengine.Task) []asrcontract.Task { - out := make([]asrcontract.Task, 0, len(tasks)) - for _, task := range tasks { - out = append(out, asrcontract.Task(task)) +func capabilitiesFromDescriptor(descriptor speechproviders.ModelDescriptor) asrcontract.Capabilities { + capabilities := asrcontract.Capabilities{ + WordTimestamps: descriptor.Output.WordTimestamps, + SegmentTimestamps: descriptor.Output.SegmentTimestamps, + TokenTimestamps: descriptor.Output.TokenTimestamps, + LanguageDetection: descriptor.Output.LanguageSpans, + SpeakerEmbeddings: descriptor.Output.SpeakerLabels, + Translation: descriptor.Output.Translation, } - return out -} - -func capabilitiesFromEngine(capabilities speechengine.Capabilities) asrcontract.Capabilities { - return asrcontract.Capabilities{ - Transcription: capabilities.Transcription, - Diarization: capabilities.Diarization, - WordTimestamps: capabilities.WordTimestamps, - SegmentTimestamps: capabilities.SegmentTimestamps, - TokenTimestamps: capabilities.TokenTimestamps, - LanguageDetection: capabilities.LanguageDetection, - SpeakerEmbeddings: capabilities.SpeakerEmbeddings, - } -} - -func resourceRequirementsFromEngine(requirements speechengine.ResourceRequirements) asrcontract.ResourceRequirements { - return asrcontract.ResourceRequirements{ - Backends: append([]string(nil), requirements.Backends...), + for _, task := range descriptor.Tasks { + switch task.Kind { + case speechproviders.TaskTranscription: + capabilities.Transcription = true + case speechproviders.TaskDiarization: + capabilities.Diarization = true + case speechproviders.TaskSpeakerIdentification: + capabilities.SpeakerIdentification = true + case speechproviders.TaskTranslation: + capabilities.Translation = true + } } + return capabilities } diff --git a/internal/transcription/engineprovider/local_provider_test.go b/internal/transcription/engineprovider/local_provider_test.go index fb770101..2634fba7 100644 --- a/internal/transcription/engineprovider/local_provider_test.go +++ b/internal/transcription/engineprovider/local_provider_test.go @@ -13,6 +13,7 @@ import ( speechengine "scriberr-engine/speech/engine" speechmodels "scriberr-engine/speech/models" + speechproviders "scriberr-engine/speech/providers" engresults "scriberr-engine/speech/results" "scriberr-engine/speech/runtime" ) @@ -24,7 +25,7 @@ type fakeSpeechEngine struct { diarizationOut *speechengine.DiarizationResult err error info *speechengine.ProviderInfo - models []speechengine.ModelCard + models []speechproviders.ModelDescriptor status *speechengine.ProviderStatus loaded []speechengine.LoadedModel loadedRequested string @@ -59,7 +60,7 @@ func (e *fakeSpeechEngine) Inspect(ctx context.Context) (*speechengine.ProviderI ActiveBackend: "cpu", SupportsConcurrent: false, MaxConcurrentJobs: 1, - ProviderCapabilities: []speechengine.Capability{speechengine.CapabilityTranscription, speechengine.CapabilityDiarization}, + ProviderCapabilities: []speechproviders.TaskKind{speechproviders.TaskTranscription, speechproviders.TaskDiarization}, }, AudioInput: speechengine.AudioInputSpec{ RequiredSampleRate: 16000, @@ -70,7 +71,7 @@ func (e *fakeSpeechEngine) Inspect(ctx context.Context) (*speechengine.ProviderI }, nil } -func (e *fakeSpeechEngine) Models(ctx context.Context) ([]speechengine.ModelCard, error) { +func (e *fakeSpeechEngine) Models(ctx context.Context) ([]speechproviders.ModelDescriptor, error) { if e.err != nil { return nil, e.err } @@ -367,31 +368,15 @@ func TestLocalProviderDiarizeRejectsNilEngineResult(t *testing.T) { } } -func TestLocalProviderCapabilitiesUseEngineModelCards(t *testing.T) { - fake := &fakeSpeechEngine{models: []speechengine.ModelCard{ - { - ID: "whisper-base", - DisplayName: "Whisper Base", - Provider: "local", - Family: "whisper", - Installed: true, - Default: true, - Tasks: []speechengine.Task{speechengine.TaskTranscribe}, - Capabilities: speechengine.Capabilities{ - Transcription: true, - WordTimestamps: true, - }, - }, - { - ID: "diarization-default", - DisplayName: "Diarization", - Provider: "local", - Family: "pyannote", - Default: true, - Capabilities: speechengine.Capabilities{ - Diarization: true, - }, - }, +func TestLocalProviderCapabilitiesUseEngineDescriptors(t *testing.T) { + fake := &fakeSpeechEngine{models: []speechproviders.ModelDescriptor{ + descriptorForModelWith(t, speechmodels.ModelWhisperBase, func(desc *speechproviders.ModelDescriptor) { + desc.Installed = true + desc.Default = true + }), + descriptorForModelWith(t, speechmodels.ModelDiarizationDefault, func(desc *speechproviders.ModelDescriptor) { + desc.Default = true + }), }} provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake) registry, err := NewRegistry("local", provider) @@ -409,7 +394,7 @@ func TestLocalProviderCapabilitiesUseEngineModelCards(t *testing.T) { if !capabilities[0].Installed || !capabilities[0].Default { t.Fatalf("whisper-base capability missing installed/default: %#v", capabilities[0]) } - if strings.Join(capabilities[0].Capabilities, ",") != "transcription,word_timestamps" { + if strings.Join(capabilities[0].Capabilities, ",") != "transcription,translation,word_timestamps,segment_timestamps,token_timestamps" { t.Fatalf("whisper capabilities = %#v", capabilities[0].Capabilities) } if capabilities[1].Installed || !capabilities[1].Default { @@ -423,21 +408,12 @@ func TestLocalProviderCapabilitiesUseEngineModelCards(t *testing.T) { func TestLocalProviderModelsStatusAndLifecycle(t *testing.T) { fake := &fakeSpeechEngine{ loaded: []speechengine.LoadedModel{{ID: "whisper-base"}}, - models: []speechengine.ModelCard{ - { - ID: "whisper-base", - DisplayName: "Whisper Base", - Provider: "local", - Family: "whisper", - Version: "whisper", - Installed: true, - Loaded: true, - Default: true, - Tasks: []speechengine.Task{speechengine.TaskTranscribe}, - Capabilities: speechengine.Capabilities{ - Transcription: true, - }, - }, + models: []speechproviders.ModelDescriptor{ + descriptorForModelWith(t, speechmodels.ModelWhisperBase, func(desc *speechproviders.ModelDescriptor) { + desc.Installed = true + desc.Loaded = true + desc.Default = true + }), }, } provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake) @@ -482,41 +458,13 @@ func TestLocalProviderModelsStatusAndLifecycle(t *testing.T) { func TestLocalProviderModelDescriptorsDistinguishWhisperAndParakeet(t *testing.T) { fake := &fakeSpeechEngine{ - models: []speechengine.ModelCard{ - { - ID: "whisper-base", - DisplayName: "Whisper Base", - Provider: "local", - Family: "whisper", - Version: "base", - Installed: true, - Tasks: []speechengine.Task{speechengine.TaskTranscribe, speechengine.Task("translate")}, - Languages: []string{"auto", "en", "es"}, - Descriptor: descriptorForModel(t, speechmodels.ModelWhisperBase), - Capabilities: speechengine.Capabilities{ - Transcription: true, - WordTimestamps: true, - SegmentTimestamps: true, - TokenTimestamps: true, - LanguageDetection: true, - }, - }, - { - ID: "parakeet-v3", - DisplayName: "Parakeet V3", - Provider: "local", - Family: "nemo_transducer", - Version: "v3", - Installed: true, - Tasks: []speechengine.Task{speechengine.TaskTranscribe}, - Languages: []string{"en"}, - Descriptor: descriptorForModel(t, speechmodels.ModelParakeetV3), - Capabilities: speechengine.Capabilities{ - Transcription: true, - WordTimestamps: true, - SegmentTimestamps: true, - }, - }, + models: []speechproviders.ModelDescriptor{ + descriptorForModelWith(t, speechmodels.ModelWhisperBase, func(desc *speechproviders.ModelDescriptor) { + desc.Installed = true + }), + descriptorForModelWith(t, speechmodels.ModelParakeetV3, func(desc *speechproviders.ModelDescriptor) { + desc.Installed = true + }), }, } provider := newLocalProviderWithEngine("local", LocalConfig{Provider: "cpu", Threads: 4, CacheDir: "/Users/zade/private/cache"}, runtime.ProviderCPU, fake) @@ -568,23 +516,9 @@ func TestLocalProviderModelDescriptorsDistinguishWhisperAndParakeet(t *testing.T func TestLocalProviderModelDescriptorParameterSchemasValidate(t *testing.T) { fake := &fakeSpeechEngine{ - models: []speechengine.ModelCard{ - { - ID: "whisper-base", - Family: "whisper", - Descriptor: descriptorForModel(t, speechmodels.ModelWhisperBase), - Capabilities: speechengine.Capabilities{ - Transcription: true, - }, - }, - { - ID: "parakeet-v3", - Family: "nemo_transducer", - Descriptor: descriptorForModel(t, speechmodels.ModelParakeetV3), - Capabilities: speechengine.Capabilities{ - Transcription: true, - }, - }, + models: []speechproviders.ModelDescriptor{ + descriptorForModel(t, speechmodels.ModelWhisperBase), + descriptorForModel(t, speechmodels.ModelParakeetV3), }, } provider := newLocalProviderWithEngine("local", LocalConfig{Provider: "cpu", Threads: 4}, runtime.ProviderCPU, fake) @@ -691,6 +625,13 @@ func descriptorForModel(t *testing.T, id speechmodels.ModelID) speechmodels.Desc return descriptor } +func descriptorForModelWith(t *testing.T, id speechmodels.ModelID, edit func(*speechproviders.ModelDescriptor)) speechmodels.Descriptor { + t.Helper() + descriptor := descriptorForModel(t, id) + edit(&descriptor) + return descriptor +} + func hasParameter(schema asrcontract.ParameterSchema, key string) bool { for _, parameter := range schema { if parameter.Key == key {