mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-06-29 07:15:54 +00:00
Consume descriptor-first engine inventory
This commit is contained in:
@@ -25,7 +25,7 @@ type LocalConfig struct {
|
||||
|
||||
type speechEngine interface {
|
||||
Inspect(ctx context.Context) (*speechengine.ProviderInfo, error)
|
||||
Models(ctx context.Context) ([]speechengine.ModelCard, error)
|
||||
Models(ctx context.Context) ([]speechproviders.ModelDescriptor, error)
|
||||
Status(ctx context.Context) (*speechengine.ProviderStatus, error)
|
||||
LoadedModels() []speechengine.LoadedModel
|
||||
Transcribe(ctx context.Context, req speechengine.TranscriptionRequest) (*speechengine.TranscriptionResult, error)
|
||||
@@ -115,7 +115,7 @@ func (p *LocalProvider) Models(ctx context.Context) ([]asrcontract.ModelCard, er
|
||||
}
|
||||
out := make([]asrcontract.ModelCard, 0, len(models))
|
||||
for _, model := range models {
|
||||
out = append(out, modelCardFromEngine(model, p.id, p.cfg, p.provider))
|
||||
out = append(out, modelCardFromEngine(model, p.id))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -330,26 +330,22 @@ func providerInfoFromEngine(info *speechengine.ProviderInfo, providerID string)
|
||||
}
|
||||
}
|
||||
|
||||
func modelCardFromEngine(model speechengine.ModelCard, providerID string, cfg LocalConfig, provider runtime.Provider) asrcontract.ModelCard {
|
||||
capabilities := capabilitiesFromEngine(model.Capabilities)
|
||||
descriptor := model.Descriptor
|
||||
if strings.TrimSpace(descriptor.ID) == "" {
|
||||
return fallbackModelCardFromEngine(model, providerID, capabilities)
|
||||
}
|
||||
func modelCardFromEngine(descriptor speechproviders.ModelDescriptor, providerID string) asrcontract.ModelCard {
|
||||
capabilities := capabilitiesFromDescriptor(descriptor)
|
||||
return asrcontract.ModelCard{
|
||||
ID: firstNonEmpty(descriptor.ID, model.ID),
|
||||
DisplayName: firstNonEmpty(descriptor.DisplayName, model.DisplayName),
|
||||
ID: descriptor.ID,
|
||||
DisplayName: descriptor.DisplayName,
|
||||
Provider: providerID,
|
||||
Family: firstNonEmpty(descriptor.Family, model.Family),
|
||||
Version: firstNonEmpty(descriptor.Version, model.Version),
|
||||
Installed: model.Installed,
|
||||
Loaded: model.Loaded,
|
||||
Default: model.Default,
|
||||
Tasks: tasksFromDescriptor(descriptor.Tasks, model.Tasks),
|
||||
Languages: languageIDsFromDescriptor(descriptor.Languages, model.Languages),
|
||||
LanguageSupport: languageSupportFromDescriptor(descriptor.Languages, model.Languages),
|
||||
Family: descriptor.Family,
|
||||
Version: descriptor.Version,
|
||||
Installed: descriptor.Installed,
|
||||
Loaded: descriptor.Loaded,
|
||||
Default: descriptor.Default,
|
||||
Tasks: tasksFromDescriptor(descriptor.Tasks),
|
||||
Languages: languageIDsFromDescriptor(descriptor.Languages),
|
||||
LanguageSupport: languageSupportFromDescriptor(descriptor.Languages),
|
||||
Capabilities: capabilities,
|
||||
ResourceRequirements: resourceRequirementsFromDescriptor(descriptor.Runtime, model.ResourceRequirements),
|
||||
ResourceRequirements: resourceRequirementsFromDescriptor(descriptor.Runtime),
|
||||
Chunking: chunkingCapabilitiesFromDescriptor(descriptor.Chunking, descriptor.Runtime, capabilities),
|
||||
ParameterSchema: parameterSchemaFromDescriptor(descriptor.Parameters),
|
||||
RecommendedDefaults: copyRecommendedDefaults(descriptor.RecommendedDefaults),
|
||||
@@ -357,39 +353,7 @@ func modelCardFromEngine(model speechengine.ModelCard, providerID string, cfg Lo
|
||||
}
|
||||
}
|
||||
|
||||
func fallbackModelCardFromEngine(model speechengine.ModelCard, providerID string, capabilities asrcontract.Capabilities) asrcontract.ModelCard {
|
||||
return asrcontract.ModelCard{
|
||||
ID: model.ID,
|
||||
DisplayName: model.DisplayName,
|
||||
Provider: providerID,
|
||||
Family: model.Family,
|
||||
Version: model.Version,
|
||||
Installed: model.Installed,
|
||||
Loaded: model.Loaded,
|
||||
Default: model.Default,
|
||||
Tasks: tasksFromEngine(model.Tasks),
|
||||
Languages: append([]string(nil), model.Languages...),
|
||||
LanguageSupport: fallbackLanguageSupport(model.Languages),
|
||||
Capabilities: capabilities,
|
||||
ResourceRequirements: resourceRequirementsFromEngine(model.ResourceRequirements),
|
||||
}
|
||||
}
|
||||
|
||||
func fallbackLanguageSupport(languages []string) *asrcontract.LanguageSupport {
|
||||
if len(languages) == 0 {
|
||||
return nil
|
||||
}
|
||||
mode := "fixed"
|
||||
if len(languages) > 1 {
|
||||
mode = "configurable"
|
||||
}
|
||||
return &asrcontract.LanguageSupport{Languages: append([]string(nil), languages...), Mode: mode}
|
||||
}
|
||||
|
||||
func languageSupportFromDescriptor(languages []speechproviders.LanguageSupport, fallback []string) *asrcontract.LanguageSupport {
|
||||
if len(languages) == 0 {
|
||||
return fallbackLanguageSupport(fallback)
|
||||
}
|
||||
func languageSupportFromDescriptor(languages []speechproviders.LanguageSupport) *asrcontract.LanguageSupport {
|
||||
ids := make([]string, 0, len(languages))
|
||||
mode := ""
|
||||
for _, language := range languages {
|
||||
@@ -409,10 +373,7 @@ func languageSupportFromDescriptor(languages []speechproviders.LanguageSupport,
|
||||
return &asrcontract.LanguageSupport{Languages: ids, Mode: mode}
|
||||
}
|
||||
|
||||
func languageIDsFromDescriptor(languages []speechproviders.LanguageSupport, fallback []string) []string {
|
||||
if len(languages) == 0 {
|
||||
return append([]string(nil), fallback...)
|
||||
}
|
||||
func languageIDsFromDescriptor(languages []speechproviders.LanguageSupport) []string {
|
||||
out := make([]string, 0, len(languages))
|
||||
for _, language := range languages {
|
||||
if strings.TrimSpace(language.ID) != "" {
|
||||
@@ -422,10 +383,7 @@ func languageIDsFromDescriptor(languages []speechproviders.LanguageSupport, fall
|
||||
return out
|
||||
}
|
||||
|
||||
func tasksFromDescriptor(tasks []speechproviders.TaskDescriptor, fallback []speechengine.Task) []asrcontract.Task {
|
||||
if len(tasks) == 0 {
|
||||
return tasksFromEngine(fallback)
|
||||
}
|
||||
func tasksFromDescriptor(tasks []speechproviders.TaskDescriptor) []asrcontract.Task {
|
||||
out := make([]asrcontract.Task, 0, len(tasks))
|
||||
for _, task := range tasks {
|
||||
switch task.Kind {
|
||||
@@ -440,10 +398,7 @@ func tasksFromDescriptor(tasks []speechproviders.TaskDescriptor, fallback []spee
|
||||
return out
|
||||
}
|
||||
|
||||
func resourceRequirementsFromDescriptor(runtime speechproviders.RuntimeCapabilities, fallback speechengine.ResourceRequirements) asrcontract.ResourceRequirements {
|
||||
if len(runtime.Backends) == 0 {
|
||||
return resourceRequirementsFromEngine(fallback)
|
||||
}
|
||||
func resourceRequirementsFromDescriptor(runtime speechproviders.RuntimeCapabilities) asrcontract.ResourceRequirements {
|
||||
backends := make([]string, 0, len(runtime.Backends))
|
||||
for _, backend := range runtime.Backends {
|
||||
backends = append(backends, string(backend))
|
||||
@@ -547,15 +502,6 @@ func descriptorExtensionsFromEngine(descriptor speechproviders.ModelDescriptor)
|
||||
return extensions
|
||||
}
|
||||
|
||||
// firstNonEmpty returns the first argument that is not blank once surrounding
// whitespace is removed, in its trimmed form. It returns "" when every
// argument is blank or when no arguments are given.
func firstNonEmpty(values ...string) string {
	for _, value := range values {
		// Trim once and reuse the result for both the emptiness check and
		// the return value.
		if trimmed := strings.TrimSpace(value); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|
||||
|
||||
func cloneFloat64(value *float64) *float64 {
|
||||
if value == nil {
|
||||
return nil
|
||||
@@ -607,36 +553,45 @@ func loadedModelFromEngine(model speechengine.LoadedModel) asrcontract.LoadedMod
|
||||
}
|
||||
}
|
||||
|
||||
func providerCapabilitiesFromEngine(capabilities []speechengine.Capability) []asrcontract.Capability {
|
||||
func providerCapabilitiesFromEngine(capabilities []speechproviders.TaskKind) []asrcontract.Capability {
|
||||
out := make([]asrcontract.Capability, 0, len(capabilities))
|
||||
for _, capability := range capabilities {
|
||||
out = append(out, asrcontract.Capability(capability))
|
||||
switch capability {
|
||||
case speechproviders.TaskTranscription:
|
||||
out = append(out, asrcontract.CapabilityTranscription)
|
||||
case speechproviders.TaskDiarization:
|
||||
out = append(out, asrcontract.CapabilityDiarization)
|
||||
case speechproviders.TaskSpeakerIdentification:
|
||||
out = append(out, asrcontract.CapabilitySpeakerIdentification)
|
||||
case speechproviders.TaskTranslation:
|
||||
out = append(out, asrcontract.CapabilityTranslation)
|
||||
default:
|
||||
out = append(out, asrcontract.Capability(capability))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func tasksFromEngine(tasks []speechengine.Task) []asrcontract.Task {
|
||||
out := make([]asrcontract.Task, 0, len(tasks))
|
||||
for _, task := range tasks {
|
||||
out = append(out, asrcontract.Task(task))
|
||||
func capabilitiesFromDescriptor(descriptor speechproviders.ModelDescriptor) asrcontract.Capabilities {
|
||||
capabilities := asrcontract.Capabilities{
|
||||
WordTimestamps: descriptor.Output.WordTimestamps,
|
||||
SegmentTimestamps: descriptor.Output.SegmentTimestamps,
|
||||
TokenTimestamps: descriptor.Output.TokenTimestamps,
|
||||
LanguageDetection: descriptor.Output.LanguageSpans,
|
||||
SpeakerEmbeddings: descriptor.Output.SpeakerLabels,
|
||||
Translation: descriptor.Output.Translation,
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func capabilitiesFromEngine(capabilities speechengine.Capabilities) asrcontract.Capabilities {
|
||||
return asrcontract.Capabilities{
|
||||
Transcription: capabilities.Transcription,
|
||||
Diarization: capabilities.Diarization,
|
||||
WordTimestamps: capabilities.WordTimestamps,
|
||||
SegmentTimestamps: capabilities.SegmentTimestamps,
|
||||
TokenTimestamps: capabilities.TokenTimestamps,
|
||||
LanguageDetection: capabilities.LanguageDetection,
|
||||
SpeakerEmbeddings: capabilities.SpeakerEmbeddings,
|
||||
}
|
||||
}
|
||||
|
||||
func resourceRequirementsFromEngine(requirements speechengine.ResourceRequirements) asrcontract.ResourceRequirements {
|
||||
return asrcontract.ResourceRequirements{
|
||||
Backends: append([]string(nil), requirements.Backends...),
|
||||
for _, task := range descriptor.Tasks {
|
||||
switch task.Kind {
|
||||
case speechproviders.TaskTranscription:
|
||||
capabilities.Transcription = true
|
||||
case speechproviders.TaskDiarization:
|
||||
capabilities.Diarization = true
|
||||
case speechproviders.TaskSpeakerIdentification:
|
||||
capabilities.SpeakerIdentification = true
|
||||
case speechproviders.TaskTranslation:
|
||||
capabilities.Translation = true
|
||||
}
|
||||
}
|
||||
return capabilities
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
speechengine "scriberr-engine/speech/engine"
|
||||
speechmodels "scriberr-engine/speech/models"
|
||||
speechproviders "scriberr-engine/speech/providers"
|
||||
engresults "scriberr-engine/speech/results"
|
||||
"scriberr-engine/speech/runtime"
|
||||
)
|
||||
@@ -24,7 +25,7 @@ type fakeSpeechEngine struct {
|
||||
diarizationOut *speechengine.DiarizationResult
|
||||
err error
|
||||
info *speechengine.ProviderInfo
|
||||
models []speechengine.ModelCard
|
||||
models []speechproviders.ModelDescriptor
|
||||
status *speechengine.ProviderStatus
|
||||
loaded []speechengine.LoadedModel
|
||||
loadedRequested string
|
||||
@@ -59,7 +60,7 @@ func (e *fakeSpeechEngine) Inspect(ctx context.Context) (*speechengine.ProviderI
|
||||
ActiveBackend: "cpu",
|
||||
SupportsConcurrent: false,
|
||||
MaxConcurrentJobs: 1,
|
||||
ProviderCapabilities: []speechengine.Capability{speechengine.CapabilityTranscription, speechengine.CapabilityDiarization},
|
||||
ProviderCapabilities: []speechproviders.TaskKind{speechproviders.TaskTranscription, speechproviders.TaskDiarization},
|
||||
},
|
||||
AudioInput: speechengine.AudioInputSpec{
|
||||
RequiredSampleRate: 16000,
|
||||
@@ -70,7 +71,7 @@ func (e *fakeSpeechEngine) Inspect(ctx context.Context) (*speechengine.ProviderI
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *fakeSpeechEngine) Models(ctx context.Context) ([]speechengine.ModelCard, error) {
|
||||
func (e *fakeSpeechEngine) Models(ctx context.Context) ([]speechproviders.ModelDescriptor, error) {
|
||||
if e.err != nil {
|
||||
return nil, e.err
|
||||
}
|
||||
@@ -367,31 +368,15 @@ func TestLocalProviderDiarizeRejectsNilEngineResult(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalProviderCapabilitiesUseEngineModelCards(t *testing.T) {
|
||||
fake := &fakeSpeechEngine{models: []speechengine.ModelCard{
|
||||
{
|
||||
ID: "whisper-base",
|
||||
DisplayName: "Whisper Base",
|
||||
Provider: "local",
|
||||
Family: "whisper",
|
||||
Installed: true,
|
||||
Default: true,
|
||||
Tasks: []speechengine.Task{speechengine.TaskTranscribe},
|
||||
Capabilities: speechengine.Capabilities{
|
||||
Transcription: true,
|
||||
WordTimestamps: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "diarization-default",
|
||||
DisplayName: "Diarization",
|
||||
Provider: "local",
|
||||
Family: "pyannote",
|
||||
Default: true,
|
||||
Capabilities: speechengine.Capabilities{
|
||||
Diarization: true,
|
||||
},
|
||||
},
|
||||
func TestLocalProviderCapabilitiesUseEngineDescriptors(t *testing.T) {
|
||||
fake := &fakeSpeechEngine{models: []speechproviders.ModelDescriptor{
|
||||
descriptorForModelWith(t, speechmodels.ModelWhisperBase, func(desc *speechproviders.ModelDescriptor) {
|
||||
desc.Installed = true
|
||||
desc.Default = true
|
||||
}),
|
||||
descriptorForModelWith(t, speechmodels.ModelDiarizationDefault, func(desc *speechproviders.ModelDescriptor) {
|
||||
desc.Default = true
|
||||
}),
|
||||
}}
|
||||
provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake)
|
||||
registry, err := NewRegistry("local", provider)
|
||||
@@ -409,7 +394,7 @@ func TestLocalProviderCapabilitiesUseEngineModelCards(t *testing.T) {
|
||||
if !capabilities[0].Installed || !capabilities[0].Default {
|
||||
t.Fatalf("whisper-base capability missing installed/default: %#v", capabilities[0])
|
||||
}
|
||||
if strings.Join(capabilities[0].Capabilities, ",") != "transcription,word_timestamps" {
|
||||
if strings.Join(capabilities[0].Capabilities, ",") != "transcription,translation,word_timestamps,segment_timestamps,token_timestamps" {
|
||||
t.Fatalf("whisper capabilities = %#v", capabilities[0].Capabilities)
|
||||
}
|
||||
if capabilities[1].Installed || !capabilities[1].Default {
|
||||
@@ -423,21 +408,12 @@ func TestLocalProviderCapabilitiesUseEngineModelCards(t *testing.T) {
|
||||
func TestLocalProviderModelsStatusAndLifecycle(t *testing.T) {
|
||||
fake := &fakeSpeechEngine{
|
||||
loaded: []speechengine.LoadedModel{{ID: "whisper-base"}},
|
||||
models: []speechengine.ModelCard{
|
||||
{
|
||||
ID: "whisper-base",
|
||||
DisplayName: "Whisper Base",
|
||||
Provider: "local",
|
||||
Family: "whisper",
|
||||
Version: "whisper",
|
||||
Installed: true,
|
||||
Loaded: true,
|
||||
Default: true,
|
||||
Tasks: []speechengine.Task{speechengine.TaskTranscribe},
|
||||
Capabilities: speechengine.Capabilities{
|
||||
Transcription: true,
|
||||
},
|
||||
},
|
||||
models: []speechproviders.ModelDescriptor{
|
||||
descriptorForModelWith(t, speechmodels.ModelWhisperBase, func(desc *speechproviders.ModelDescriptor) {
|
||||
desc.Installed = true
|
||||
desc.Loaded = true
|
||||
desc.Default = true
|
||||
}),
|
||||
},
|
||||
}
|
||||
provider := newLocalProviderWithEngine("local", LocalConfig{}, runtime.ProviderCPU, fake)
|
||||
@@ -482,41 +458,13 @@ func TestLocalProviderModelsStatusAndLifecycle(t *testing.T) {
|
||||
|
||||
func TestLocalProviderModelDescriptorsDistinguishWhisperAndParakeet(t *testing.T) {
|
||||
fake := &fakeSpeechEngine{
|
||||
models: []speechengine.ModelCard{
|
||||
{
|
||||
ID: "whisper-base",
|
||||
DisplayName: "Whisper Base",
|
||||
Provider: "local",
|
||||
Family: "whisper",
|
||||
Version: "base",
|
||||
Installed: true,
|
||||
Tasks: []speechengine.Task{speechengine.TaskTranscribe, speechengine.Task("translate")},
|
||||
Languages: []string{"auto", "en", "es"},
|
||||
Descriptor: descriptorForModel(t, speechmodels.ModelWhisperBase),
|
||||
Capabilities: speechengine.Capabilities{
|
||||
Transcription: true,
|
||||
WordTimestamps: true,
|
||||
SegmentTimestamps: true,
|
||||
TokenTimestamps: true,
|
||||
LanguageDetection: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "parakeet-v3",
|
||||
DisplayName: "Parakeet V3",
|
||||
Provider: "local",
|
||||
Family: "nemo_transducer",
|
||||
Version: "v3",
|
||||
Installed: true,
|
||||
Tasks: []speechengine.Task{speechengine.TaskTranscribe},
|
||||
Languages: []string{"en"},
|
||||
Descriptor: descriptorForModel(t, speechmodels.ModelParakeetV3),
|
||||
Capabilities: speechengine.Capabilities{
|
||||
Transcription: true,
|
||||
WordTimestamps: true,
|
||||
SegmentTimestamps: true,
|
||||
},
|
||||
},
|
||||
models: []speechproviders.ModelDescriptor{
|
||||
descriptorForModelWith(t, speechmodels.ModelWhisperBase, func(desc *speechproviders.ModelDescriptor) {
|
||||
desc.Installed = true
|
||||
}),
|
||||
descriptorForModelWith(t, speechmodels.ModelParakeetV3, func(desc *speechproviders.ModelDescriptor) {
|
||||
desc.Installed = true
|
||||
}),
|
||||
},
|
||||
}
|
||||
provider := newLocalProviderWithEngine("local", LocalConfig{Provider: "cpu", Threads: 4, CacheDir: "/Users/zade/private/cache"}, runtime.ProviderCPU, fake)
|
||||
@@ -568,23 +516,9 @@ func TestLocalProviderModelDescriptorsDistinguishWhisperAndParakeet(t *testing.T
|
||||
|
||||
func TestLocalProviderModelDescriptorParameterSchemasValidate(t *testing.T) {
|
||||
fake := &fakeSpeechEngine{
|
||||
models: []speechengine.ModelCard{
|
||||
{
|
||||
ID: "whisper-base",
|
||||
Family: "whisper",
|
||||
Descriptor: descriptorForModel(t, speechmodels.ModelWhisperBase),
|
||||
Capabilities: speechengine.Capabilities{
|
||||
Transcription: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "parakeet-v3",
|
||||
Family: "nemo_transducer",
|
||||
Descriptor: descriptorForModel(t, speechmodels.ModelParakeetV3),
|
||||
Capabilities: speechengine.Capabilities{
|
||||
Transcription: true,
|
||||
},
|
||||
},
|
||||
models: []speechproviders.ModelDescriptor{
|
||||
descriptorForModel(t, speechmodels.ModelWhisperBase),
|
||||
descriptorForModel(t, speechmodels.ModelParakeetV3),
|
||||
},
|
||||
}
|
||||
provider := newLocalProviderWithEngine("local", LocalConfig{Provider: "cpu", Threads: 4}, runtime.ProviderCPU, fake)
|
||||
@@ -691,6 +625,13 @@ func descriptorForModel(t *testing.T, id speechmodels.ModelID) speechmodels.Desc
|
||||
return descriptor
|
||||
}
|
||||
|
||||
func descriptorForModelWith(t *testing.T, id speechmodels.ModelID, edit func(*speechproviders.ModelDescriptor)) speechmodels.Descriptor {
|
||||
t.Helper()
|
||||
descriptor := descriptorForModel(t, id)
|
||||
edit(&descriptor)
|
||||
return descriptor
|
||||
}
|
||||
|
||||
func hasParameter(schema asrcontract.ParameterSchema, key string) bool {
|
||||
for _, parameter := range schema {
|
||||
if parameter.Key == key {
|
||||
|
||||
Reference in New Issue
Block a user