Consume descriptor-first engine inventory

This commit is contained in:
rishikanthc
2026-05-06 12:23:50 -07:00
parent 77ba86040f
commit 4dcb23cc41
2 changed files with 89 additions and 193 deletions

View File

@@ -25,7 +25,7 @@ type LocalConfig struct {
type speechEngine interface {
Inspect(ctx context.Context) (*speechengine.ProviderInfo, error)
Models(ctx context.Context) ([]speechengine.ModelCard, error)
Models(ctx context.Context) ([]speechproviders.ModelDescriptor, error)
Status(ctx context.Context) (*speechengine.ProviderStatus, error)
LoadedModels() []speechengine.LoadedModel
Transcribe(ctx context.Context, req speechengine.TranscriptionRequest) (*speechengine.TranscriptionResult, error)
@@ -115,7 +115,7 @@ func (p *LocalProvider) Models(ctx context.Context) ([]asrcontract.ModelCard, er
}
out := make([]asrcontract.ModelCard, 0, len(models))
for _, model := range models {
out = append(out, modelCardFromEngine(model, p.id, p.cfg, p.provider))
out = append(out, modelCardFromEngine(model, p.id))
}
return out, nil
}
@@ -330,26 +330,22 @@ func providerInfoFromEngine(info *speechengine.ProviderInfo, providerID string)
}
}
func modelCardFromEngine(model speechengine.ModelCard, providerID string, cfg LocalConfig, provider runtime.Provider) asrcontract.ModelCard {
capabilities := capabilitiesFromEngine(model.Capabilities)
descriptor := model.Descriptor
if strings.TrimSpace(descriptor.ID) == "" {
return fallbackModelCardFromEngine(model, providerID, capabilities)
}
func modelCardFromEngine(descriptor speechproviders.ModelDescriptor, providerID string) asrcontract.ModelCard {
capabilities := capabilitiesFromDescriptor(descriptor)
return asrcontract.ModelCard{
ID: firstNonEmpty(descriptor.ID, model.ID),
DisplayName: firstNonEmpty(descriptor.DisplayName, model.DisplayName),
ID: descriptor.ID,
DisplayName: descriptor.DisplayName,
Provider: providerID,
Family: firstNonEmpty(descriptor.Family, model.Family),
Version: firstNonEmpty(descriptor.Version, model.Version),
Installed: model.Installed,
Loaded: model.Loaded,
Default: model.Default,
Tasks: tasksFromDescriptor(descriptor.Tasks, model.Tasks),
Languages: languageIDsFromDescriptor(descriptor.Languages, model.Languages),
LanguageSupport: languageSupportFromDescriptor(descriptor.Languages, model.Languages),
Family: descriptor.Family,
Version: descriptor.Version,
Installed: descriptor.Installed,
Loaded: descriptor.Loaded,
Default: descriptor.Default,
Tasks: tasksFromDescriptor(descriptor.Tasks),
Languages: languageIDsFromDescriptor(descriptor.Languages),
LanguageSupport: languageSupportFromDescriptor(descriptor.Languages),
Capabilities: capabilities,
ResourceRequirements: resourceRequirementsFromDescriptor(descriptor.Runtime, model.ResourceRequirements),
ResourceRequirements: resourceRequirementsFromDescriptor(descriptor.Runtime),
Chunking: chunkingCapabilitiesFromDescriptor(descriptor.Chunking, descriptor.Runtime, capabilities),
ParameterSchema: parameterSchemaFromDescriptor(descriptor.Parameters),
RecommendedDefaults: copyRecommendedDefaults(descriptor.RecommendedDefaults),
@@ -357,39 +353,7 @@ func modelCardFromEngine(model speechengine.ModelCard, providerID string, cfg Lo
}
}
// fallbackModelCardFromEngine builds a contract model card using only the
// engine model card's own fields; no descriptor data is consulted.
//
// providerID is recorded as the card's Provider, and capabilities is stored
// as-is (the caller has already converted it to the contract type). The
// Languages slice is copied so the returned card does not share its backing
// array with model.Languages.
func fallbackModelCardFromEngine(model speechengine.ModelCard, providerID string, capabilities asrcontract.Capabilities) asrcontract.ModelCard {
	var card asrcontract.ModelCard

	// Identity.
	card.ID = model.ID
	card.DisplayName = model.DisplayName
	card.Provider = providerID
	card.Family = model.Family
	card.Version = model.Version

	// Inventory state flags.
	card.Installed = model.Installed
	card.Loaded = model.Loaded
	card.Default = model.Default

	// Converted / copied collections.
	card.Tasks = tasksFromEngine(model.Tasks)
	card.Languages = append([]string(nil), model.Languages...)
	card.LanguageSupport = fallbackLanguageSupport(model.Languages)
	card.Capabilities = capabilities
	card.ResourceRequirements = resourceRequirementsFromEngine(model.ResourceRequirements)

	return card
}
// fallbackLanguageSupport derives a LanguageSupport value from a plain list of
// language IDs.
//
// It returns nil for an empty list. Otherwise the result carries a copy of
// languages and a Mode of "fixed" when exactly one language is listed, or
// "configurable" when more than one is listed.
func fallbackLanguageSupport(languages []string) *asrcontract.LanguageSupport {
	count := len(languages)
	if count == 0 {
		return nil
	}
	support := &asrcontract.LanguageSupport{
		Languages: append([]string(nil), languages...),
		Mode:      "configurable",
	}
	if count == 1 {
		support.Mode = "fixed"
	}
	return support
}
func languageSupportFromDescriptor(languages []speechproviders.LanguageSupport, fallback []string) *asrcontract.LanguageSupport {
if len(languages) == 0 {
return fallbackLanguageSupport(fallback)
}
func languageSupportFromDescriptor(languages []speechproviders.LanguageSupport) *asrcontract.LanguageSupport {
ids := make([]string, 0, len(languages))
mode := ""
for _, language := range languages {
@@ -409,10 +373,7 @@ func languageSupportFromDescriptor(languages []speechproviders.LanguageSupport,
return &asrcontract.LanguageSupport{Languages: ids, Mode: mode}
}
func languageIDsFromDescriptor(languages []speechproviders.LanguageSupport, fallback []string) []string {
if len(languages) == 0 {
return append([]string(nil), fallback...)
}
func languageIDsFromDescriptor(languages []speechproviders.LanguageSupport) []string {
out := make([]string, 0, len(languages))
for _, language := range languages {
if strings.TrimSpace(language.ID) != "" {
@@ -422,10 +383,7 @@ func languageIDsFromDescriptor(languages []speechproviders.LanguageSupport, fall
return out
}
func tasksFromDescriptor(tasks []speechproviders.TaskDescriptor, fallback []speechengine.Task) []asrcontract.Task {
if len(tasks) == 0 {
return tasksFromEngine(fallback)
}
func tasksFromDescriptor(tasks []speechproviders.TaskDescriptor) []asrcontract.Task {
out := make([]asrcontract.Task, 0, len(tasks))
for _, task := range tasks {
switch task.Kind {
@@ -440,10 +398,7 @@ func tasksFromDescriptor(tasks []speechproviders.TaskDescriptor, fallback []spee
return out
}
func resourceRequirementsFromDescriptor(runtime speechproviders.RuntimeCapabilities, fallback speechengine.ResourceRequirements) asrcontract.ResourceRequirements {
if len(runtime.Backends) == 0 {
return resourceRequirementsFromEngine(fallback)
}
func resourceRequirementsFromDescriptor(runtime speechproviders.RuntimeCapabilities) asrcontract.ResourceRequirements {
backends := make([]string, 0, len(runtime.Backends))
for _, backend := range runtime.Backends {
backends = append(backends, string(backend))
@@ -547,15 +502,6 @@ func descriptorExtensionsFromEngine(descriptor speechproviders.ModelDescriptor)
return extensions
}
// firstNonEmpty returns the first argument that is not blank once surrounding
// whitespace is trimmed. The returned string is the trimmed form. It returns
// "" when every argument is blank or when no arguments are given.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if trimmed := strings.TrimSpace(values[i]); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
func cloneFloat64(value *float64) *float64 {
if value == nil {
return nil
@@ -607,36 +553,45 @@ func loadedModelFromEngine(model speechengine.LoadedModel) asrcontract.LoadedMod
}
}
func providerCapabilitiesFromEngine(capabilities []speechengine.Capability) []asrcontract.Capability {
func providerCapabilitiesFromEngine(capabilities []speechproviders.TaskKind) []asrcontract.Capability {
out := make([]asrcontract.Capability, 0, len(capabilities))
for _, capability := range capabilities {
out = append(out, asrcontract.Capability(capability))
switch capability {
case speechproviders.TaskTranscription:
out = append(out, asrcontract.CapabilityTranscription)
case speechproviders.TaskDiarization:
out = append(out, asrcontract.CapabilityDiarization)
case speechproviders.TaskSpeakerIdentification:
out = append(out, asrcontract.CapabilitySpeakerIdentification)
case speechproviders.TaskTranslation:
out = append(out, asrcontract.CapabilityTranslation)
default:
out = append(out, asrcontract.Capability(capability))
}
}
return out
}
func tasksFromEngine(tasks []speechengine.Task) []asrcontract.Task {
out := make([]asrcontract.Task, 0, len(tasks))
for _, task := range tasks {
out = append(out, asrcontract.Task(task))
func capabilitiesFromDescriptor(descriptor speechproviders.ModelDescriptor) asrcontract.Capabilities {
capabilities := asrcontract.Capabilities{
WordTimestamps: descriptor.Output.WordTimestamps,
SegmentTimestamps: descriptor.Output.SegmentTimestamps,
TokenTimestamps: descriptor.Output.TokenTimestamps,
LanguageDetection: descriptor.Output.LanguageSpans,
SpeakerEmbeddings: descriptor.Output.SpeakerLabels,
Translation: descriptor.Output.Translation,
}
return out
}
func capabilitiesFromEngine(capabilities speechengine.Capabilities) asrcontract.Capabilities {
return asrcontract.Capabilities{
Transcription: capabilities.Transcription,
Diarization: capabilities.Diarization,
WordTimestamps: capabilities.WordTimestamps,
SegmentTimestamps: capabilities.SegmentTimestamps,
TokenTimestamps: capabilities.TokenTimestamps,
LanguageDetection: capabilities.LanguageDetection,
SpeakerEmbeddings: capabilities.SpeakerEmbeddings,
}
}
func resourceRequirementsFromEngine(requirements speechengine.ResourceRequirements) asrcontract.ResourceRequirements {
return asrcontract.ResourceRequirements{
Backends: append([]string(nil), requirements.Backends...),
for _, task := range descriptor.Tasks {
switch task.Kind {
case speechproviders.TaskTranscription:
capabilities.Transcription = true
case speechproviders.TaskDiarization:
capabilities.Diarization = true
case speechproviders.TaskSpeakerIdentification:
capabilities.SpeakerIdentification = true
case speechproviders.TaskTranslation:
capabilities.Translation = true
}
}
return capabilities
}

View File

@@ -13,6 +13,7 @@ import (
speechengine "scriberr-engine/speech/engine"
speechmodels "scriberr-engine/speech/models"
speechproviders "scriberr-engine/speech/providers"
engresults "scriberr-engine/speech/results"
"scriberr-engine/speech/runtime"
)
@@ -24,7 +25,7 @@ type fakeSpeechEngine struct {
diarizationOut *speechengine.DiarizationResult
err error
info *speechengine.ProviderInfo
models []speechengine.ModelCard
models []speechproviders.ModelDescriptor
status *speechengine.ProviderStatus
loaded []speechengine.LoadedModel
loadedRequested string
@@ -59,7 +60,7 @@ func (e *fakeSpeechEngine) Inspect(ctx context.Context) (*speechengine.ProviderI
ActiveBackend: "cpu",
SupportsConcurrent: false,
MaxConcurrentJobs: 1,
ProviderCapabilities: []speechengine.Capability{speechengine.CapabilityTranscription, speechengine.CapabilityDiarization},
ProviderCapabilities: []speechproviders.TaskKind{speechproviders.TaskTranscription, speechproviders.TaskDiarization},
},
AudioInput: speechengine.AudioInputSpec{
RequiredSampleRate: 16000,
@@ -70,7 +71,7 @@ func (e *fakeSpeechEngine) Inspect(ctx context.Context) (*speechengine.ProviderI
}, nil
}
func (e *fakeSpeechEngine) Models(ctx context.Context) ([]speechengine.ModelCard, error) {
func (e *fakeSpeechEngine) Models(ctx context.Context) ([]speechproviders.ModelDescriptor, error) {
if e.err != nil {
return nil, e.err
}
@@ -367,31 +368,15 @@ func TestLocalProviderDiarizeRejectsNilEngineResult(t *testing.T) {
}
}
func TestLocalProviderCapabilitiesUseEngineModelCards(t *testing.T) {
fake := &fakeSpeechEngine{models: []speechengine.ModelCard{
{
ID: "whisper-base",
DisplayName: "Whisper Base",
Provider: "local",
Family: "whisper",
Installed: true,
Default: true,
Tasks: []speechengine.Task{speechengine.TaskTranscribe},
Capabilities: speechengine.Capabilities{
Transcription: true,
WordTimestamps: true,
},
},
{
ID: "diarization-default",
DisplayName: "Diarization",
Provider: "local",
Family: "pyannote",
Default: true,
Capabilities: speechengine.Capabilities{
Diarization: true,
},
},
func TestLocalProviderCapabilitiesUseEngineDescriptors(t *testing.T) {
fake := &fakeSpeechEngine{models: []speechproviders.ModelDescriptor{
descriptorForModelWith(t, speechmodels.ModelWhisperBase, func(desc *speechproviders.ModelDescriptor) {
desc.Installed = true
desc.Default = true
}),
descriptorForModelWith(t, speechmodels.ModelDiarizationDefault, func(desc *speechproviders.ModelDescriptor) {
desc.Default = true
}),
}}
provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake)
registry, err := NewRegistry("local", provider)
@@ -409,7 +394,7 @@ func TestLocalProviderCapabilitiesUseEngineModelCards(t *testing.T) {
if !capabilities[0].Installed || !capabilities[0].Default {
t.Fatalf("whisper-base capability missing installed/default: %#v", capabilities[0])
}
if strings.Join(capabilities[0].Capabilities, ",") != "transcription,word_timestamps" {
if strings.Join(capabilities[0].Capabilities, ",") != "transcription,translation,word_timestamps,segment_timestamps,token_timestamps" {
t.Fatalf("whisper capabilities = %#v", capabilities[0].Capabilities)
}
if capabilities[1].Installed || !capabilities[1].Default {
@@ -423,21 +408,12 @@ func TestLocalProviderCapabilitiesUseEngineModelCards(t *testing.T) {
func TestLocalProviderModelsStatusAndLifecycle(t *testing.T) {
fake := &fakeSpeechEngine{
loaded: []speechengine.LoadedModel{{ID: "whisper-base"}},
models: []speechengine.ModelCard{
{
ID: "whisper-base",
DisplayName: "Whisper Base",
Provider: "local",
Family: "whisper",
Version: "whisper",
Installed: true,
Loaded: true,
Default: true,
Tasks: []speechengine.Task{speechengine.TaskTranscribe},
Capabilities: speechengine.Capabilities{
Transcription: true,
},
},
models: []speechproviders.ModelDescriptor{
descriptorForModelWith(t, speechmodels.ModelWhisperBase, func(desc *speechproviders.ModelDescriptor) {
desc.Installed = true
desc.Loaded = true
desc.Default = true
}),
},
}
provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake)
@@ -482,41 +458,13 @@ func TestLocalProviderModelsStatusAndLifecycle(t *testing.T) {
func TestLocalProviderModelDescriptorsDistinguishWhisperAndParakeet(t *testing.T) {
fake := &fakeSpeechEngine{
models: []speechengine.ModelCard{
{
ID: "whisper-base",
DisplayName: "Whisper Base",
Provider: "local",
Family: "whisper",
Version: "base",
Installed: true,
Tasks: []speechengine.Task{speechengine.TaskTranscribe, speechengine.Task("translate")},
Languages: []string{"auto", "en", "es"},
Descriptor: descriptorForModel(t, speechmodels.ModelWhisperBase),
Capabilities: speechengine.Capabilities{
Transcription: true,
WordTimestamps: true,
SegmentTimestamps: true,
TokenTimestamps: true,
LanguageDetection: true,
},
},
{
ID: "parakeet-v3",
DisplayName: "Parakeet V3",
Provider: "local",
Family: "nemo_transducer",
Version: "v3",
Installed: true,
Tasks: []speechengine.Task{speechengine.TaskTranscribe},
Languages: []string{"en"},
Descriptor: descriptorForModel(t, speechmodels.ModelParakeetV3),
Capabilities: speechengine.Capabilities{
Transcription: true,
WordTimestamps: true,
SegmentTimestamps: true,
},
},
models: []speechproviders.ModelDescriptor{
descriptorForModelWith(t, speechmodels.ModelWhisperBase, func(desc *speechproviders.ModelDescriptor) {
desc.Installed = true
}),
descriptorForModelWith(t, speechmodels.ModelParakeetV3, func(desc *speechproviders.ModelDescriptor) {
desc.Installed = true
}),
},
}
provider := newLocalProviderWithEngine("local", LocalConfig{Provider: "cpu", Threads: 4, CacheDir: "/Users/zade/private/cache"}, runtime.ProviderCPU, fake)
@@ -568,23 +516,9 @@ func TestLocalProviderModelDescriptorsDistinguishWhisperAndParakeet(t *testing.T
func TestLocalProviderModelDescriptorParameterSchemasValidate(t *testing.T) {
fake := &fakeSpeechEngine{
models: []speechengine.ModelCard{
{
ID: "whisper-base",
Family: "whisper",
Descriptor: descriptorForModel(t, speechmodels.ModelWhisperBase),
Capabilities: speechengine.Capabilities{
Transcription: true,
},
},
{
ID: "parakeet-v3",
Family: "nemo_transducer",
Descriptor: descriptorForModel(t, speechmodels.ModelParakeetV3),
Capabilities: speechengine.Capabilities{
Transcription: true,
},
},
models: []speechproviders.ModelDescriptor{
descriptorForModel(t, speechmodels.ModelWhisperBase),
descriptorForModel(t, speechmodels.ModelParakeetV3),
},
}
provider := newLocalProviderWithEngine("local", LocalConfig{Provider: "cpu", Threads: 4}, runtime.ProviderCPU, fake)
@@ -691,6 +625,13 @@ func descriptorForModel(t *testing.T, id speechmodels.ModelID) speechmodels.Desc
return descriptor
}
// descriptorForModelWith returns the descriptor that descriptorForModel yields
// for id after handing a pointer to it to edit, so a test can override
// individual fields (for example Installed, Loaded or Default) before use.
//
// edit is called unconditionally and therefore must be non-nil. The helper is
// marked with t.Helper so failures are attributed to the calling test.
func descriptorForModelWith(t *testing.T, id speechmodels.ModelID, edit func(*speechproviders.ModelDescriptor)) speechmodels.Descriptor {
	t.Helper()
	edited := descriptorForModel(t, id)
	edit(&edited)
	return edited
}
func hasParameter(schema asrcontract.ParameterSchema, key string) bool {
for _, parameter := range schema {
if parameter.Key == key {