diff --git a/devnotes/v2.0.0/sprint-trackers/asr-provider-backend-sprint-tracker.md b/devnotes/v2.0.0/sprint-trackers/asr-provider-backend-sprint-tracker.md index a723966f..3d59c0c2 100644 --- a/devnotes/v2.0.0/sprint-trackers/asr-provider-backend-sprint-tracker.md +++ b/devnotes/v2.0.0/sprint-trackers/asr-provider-backend-sprint-tracker.md @@ -2,7 +2,7 @@ Run ID: `ASRP` -Status: completed through ASRP-Sprint 10. +Status: completed through ASRP-Sprint 11. This tracker belongs to `devnotes/v2.0.0/sprint-plans/asr-provider-backend-sprint-plan.md` and the design spec in `devnotes/v2.0.0/specs/asr-provider-backend-architecture.md`. @@ -492,30 +492,42 @@ Commit: ## ASRP-Sprint 11: Provider Admin/Diagnostics API -Status: pending +Status: completed -Planned tasks: +Completed tasks: -- [ ] Add service methods for provider list/status/model load/unload. -- [ ] Add authenticated/admin-gated diagnostics endpoints as appropriate. -- [ ] Keep `/api/v1/models/transcription` user-readable. -- [ ] Use bounded timeouts for load/unload commands. -- [ ] Sanitize provider status messages. +- [x] Added registry provider enumeration for diagnostics. +- [x] Added admin-only ASR provider list/detail endpoints. +- [x] Added admin-only provider model load/unload endpoints. +- [x] Kept `/api/v1/models/transcription` user-readable and unchanged. +- [x] Added bounded timeouts around provider diagnostics and model commands. +- [x] Sanitized provider diagnostics and provider error messages before API output. Acceptance checks: -- [ ] Users can still list selectable transcription models. -- [ ] Admin diagnostics show provider state, active operation, and loaded models without paths/secrets. -- [ ] Load/unload failures return typed safe errors. -- [ ] Route contract and security regression tests cover new endpoints. +- [x] Users can still list selectable transcription models. +- [x] Admin diagnostics show provider state, active operation, and loaded models without paths/secrets. +- [x] Load/unload failures return typed safe errors. +- [x] Route contract and security regression tests cover new endpoints. Verification: -- [ ] Not run yet. +- [x] `GOCACHE=/tmp/scriberr-go-cache go test ./internal/transcription/engineprovider ./internal/api -run 'TestAdminASRProvider|TestEndpointContractSmoke|TestRouteContract|TestListTranscriptionModels|TestTranscriptExecutionsLogsModelsAndStatsUseEngineServices'` +- [x] `GOCACHE=/tmp/scriberr-go-cache go test ./internal/transcription/... ./internal/profile ./internal/recording` +- [x] `GOCACHE=/tmp/scriberr-go-cache go test ./internal/api -run 'Test(AdminASRProvider|EndpointContractSmoke|RouteContract|ListTranscriptionModels|TranscriptExecutionsLogsModelsAndStatsUseEngineServices|ASREngineImportInventory|ProductionCodeDoesNotUseOldASRParameterIdentifiers|ASRProvidersDoNotDependOnAPIOrRepositories|BackendDependencyDirection)'` +- [x] `GOCACHE=/tmp/scriberr-go-cache go test ./internal/api` with localhost binding allowed for existing `httptest` LLM provider tests. +- [x] `GOCACHE=/tmp/scriberr-go-cache go vet ./internal/api ./internal/transcription/... ./internal/profile ./internal/recording` Artifacts: -- To be filled during implementation. +- `internal/api/asr_provider_admin_handlers.go` +- `internal/api/router.go` +- `internal/api/types.go` +- `internal/api/engine_worker_api_test.go` +- `internal/api/route_contract_test.go` +- `internal/transcription/engineprovider/types.go` +- `internal/transcription/engineprovider/registry.go` +- `devnotes/v2.0.0/sprint-trackers/asr-provider-backend-sprint-tracker.md` Commit: diff --git a/internal/api/asr_provider_admin_handlers.go b/internal/api/asr_provider_admin_handlers.go new file mode 100644 index 00000000..f92f1c4e --- /dev/null +++ b/internal/api/asr_provider_admin_handlers.go @@ -0,0 +1,268 @@ +package api + +import ( + "context" + "errors" + "net/http" + "strings" + "time" + + "scriberr/internal/transcription/asrcontract" + "scriberr/internal/transcription/engineprovider" + + "github.com/gin-gonic/gin" +) + +const asrProviderAdminTimeout = 15 * time.Second + +func (h *Handler) listASRProviders(c *gin.Context) { + if h.modelRegistry == nil { + c.JSON(http.StatusOK, gin.H{"items": []gin.H{}}) + return + } + providers := h.modelRegistry.Providers() + items := make([]gin.H, 0, len(providers)) + for _, provider := range providers { + summary, err := h.asrProviderSummary(c.Request.Context(), provider) + if err != nil { + writeASRProviderError(c, err) + return + } + items = append(items, summary) + } + c.JSON(http.StatusOK, gin.H{"items": items}) +} + +func (h *Handler) getASRProvider(c *gin.Context) { + provider, ok := h.asrProviderByID(c, c.Param("provider_id")) + if !ok { + return + } + ctx, cancel := context.WithTimeout(c.Request.Context(), asrProviderAdminTimeout) + defer cancel() + + info, err := provider.Inspect(ctx) + if err != nil { + writeASRProviderError(c, err) + return + } + status, err := provider.Status(ctx) + if err != nil { + writeASRProviderError(c, err) + return + } + models, err := provider.Models(ctx) + if err != nil { + writeASRProviderError(c, err) + return + } + loaded, err := provider.LoadedModels(ctx) + if err != nil { + writeASRProviderError(c, err) + return + } + c.JSON(http.StatusOK, gin.H{ + "id": provider.ID(), + "info": sanitizeProviderInfo(info), + "status": sanitizeProviderStatus(status), + "models": sanitizeModelCards(models), + "loaded_models": sanitizeLoadedModels(loaded), + }) +} + +func (h *Handler) loadASRProviderModel(c *gin.Context) { + provider, ok := h.asrProviderByID(c, c.Param("provider_id")) + if !ok { + return + } + var req loadASRProviderModelRequest + if err := c.ShouldBindJSON(&req); err != nil { + writeError(c, http.StatusBadRequest, "BAD_REQUEST", "invalid request body", nil) + return + } + model := strings.TrimSpace(req.Model) + if model == "" { + writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "model is required", stringPtr("model")) + return + } + ctx, cancel := context.WithTimeout(c.Request.Context(), asrProviderAdminTimeout) + defer cancel() + err := provider.LoadModel(ctx, asrcontract.LoadModelRequest{ + Model: model, + Operation: asrcontract.Operation(strings.TrimSpace(req.Operation)), + LoadPolicy: asrcontract.LoadPolicy(strings.TrimSpace(req.LoadPolicy)), + Options: req.Options, + }) + if err != nil { + writeASRProviderError(c, err) + return + } + c.JSON(http.StatusAccepted, gin.H{"provider": provider.ID(), "model": model, "status": "loading"}) +} + +func (h *Handler) unloadASRProviderModel(c *gin.Context) { + provider, ok := h.asrProviderByID(c, c.Param("provider_id")) + if !ok { + return + } + var req unloadASRProviderModelRequest + if err := c.ShouldBindJSON(&req); err != nil { + writeError(c, http.StatusBadRequest, "BAD_REQUEST", "invalid request body", nil) + return + } + model := strings.TrimSpace(req.Model) + if model == "" { + writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "model is required", stringPtr("model")) + return + } + ctx, cancel := context.WithTimeout(c.Request.Context(), asrProviderAdminTimeout) + defer cancel() + err := provider.UnloadModel(ctx, asrcontract.UnloadModelRequest{ + Model: model, + Force: req.Force, + Options: req.Options, + }) + if err != nil { + writeASRProviderError(c, err) + return + } + c.JSON(http.StatusAccepted, gin.H{"provider": provider.ID(), "model": model, "status": "unloading"}) +} + +func (h *Handler) asrProviderByID(c *gin.Context, id string) (engineprovider.Provider, bool) { + if h.modelRegistry == nil { + writeError(c, http.StatusNotFound, "NOT_FOUND", "ASR provider not found", nil) + return nil, false + } + provider, ok := h.modelRegistry.Provider(strings.TrimSpace(id)) + if !ok { + writeError(c, http.StatusNotFound, "NOT_FOUND", "ASR provider not found", nil) + return nil, false + } + return provider, true +} + +func (h *Handler) asrProviderSummary(parent context.Context, provider engineprovider.Provider) (gin.H, error) { + ctx, cancel := context.WithTimeout(parent, asrProviderAdminTimeout) + defer cancel() + info, err := provider.Inspect(ctx) + if err != nil { + return nil, err + } + status, err := provider.Status(ctx) + if err != nil { + return nil, err + } + loaded, err := provider.LoadedModels(ctx) + if err != nil { + return nil, err + } + return gin.H{ + "id": provider.ID(), + "info": sanitizeProviderInfo(info), + "status": sanitizeProviderStatus(status), + "loaded_models": sanitizeLoadedModels(loaded), + }, nil +} + +func writeASRProviderError(c *gin.Context, err error) { + if errors.Is(err, context.DeadlineExceeded) { + writeError(c, http.StatusGatewayTimeout, "PROVIDER_TIMEOUT", "ASR provider timed out", nil) + return + } + var providerErr *asrcontract.ProviderError + if errors.As(err, &providerErr) { + status := http.StatusBadGateway + switch providerErr.Code { + case asrcontract.CodeInvalidRequest: + status = http.StatusUnprocessableEntity + case asrcontract.CodeUnsupportedModel, asrcontract.CodeModelNotInstalled: + status = http.StatusNotFound + case asrcontract.CodeProviderBusy: + status = http.StatusConflict + case asrcontract.CodeProviderUnhealthy, asrcontract.CodeInsufficientResources: + status = http.StatusServiceUnavailable + case asrcontract.CodeTimeout: + status = http.StatusGatewayTimeout + } + message := sanitizePublicText(providerErr.Message) + if strings.TrimSpace(message) == "" { + message = "ASR provider request failed" + } + writeError(c, status, string(providerErr.Code), message, nil) + return + } + writeError(c, http.StatusBadGateway, "PROVIDER_ERROR", "ASR provider request failed", nil) +} + +func sanitizeProviderInfo(info *asrcontract.ProviderInfo) any { + if info == nil { + return nil + } + return gin.H{ + "contract_version": info.ContractVersion, + "provider": info.Provider, + "runtime": info.Runtime, + "audio_input": info.AudioInput, + } +} + +func sanitizeProviderStatus(status *asrcontract.ProviderStatus) any { + if status == nil { + return nil + } + return gin.H{ + "state": status.State, + "active_job": sanitizeActiveJob(status.ActiveJob), + "loaded_models": sanitizeLoadedModels(status.LoadedModels), + "capacity": status.Capacity, + } +} + +func sanitizeActiveJob(job *asrcontract.ActiveJob) any { + if job == nil { + return nil + } + return gin.H{ + "id": sanitizePublicText(job.ID), + "operation": job.Operation, + "model": sanitizePublicText(job.Model), + "stage": job.Stage, + "progress": job.Progress, + } +} + +func sanitizeLoadedModels(models []asrcontract.LoadedModel) []gin.H { + out := make([]gin.H, 0, len(models)) + for _, model := range models { + out = append(out, gin.H{ + "id": sanitizePublicText(model.ID), + "loaded_at": model.LoadedAt, + "memory_mb": model.MemoryMB, + }) + } + return out +} + +func sanitizeModelCards(models []asrcontract.ModelCard) []gin.H { + out := make([]gin.H, 0, len(models)) + for _, model := range models { + out = append(out, gin.H{ + "id": sanitizePublicText(model.ID), + "display_name": sanitizePublicText(model.DisplayName), + "provider": sanitizePublicText(model.Provider), + "family": sanitizePublicText(model.Family), + "version": sanitizePublicText(model.Version), + "installed": model.Installed, + "loaded": model.Loaded, + "default": model.Default, + "tasks": model.Tasks, + "languages": model.Languages, + "capabilities": model.Capabilities, + "limits": model.Limits, + "resource_requirements": model.ResourceRequirements, + "license": sanitizePublicText(model.License), + }) + } + return out +} diff --git a/internal/api/engine_worker_api_test.go b/internal/api/engine_worker_api_test.go index e08b0d0c..70c22665 100644 --- a/internal/api/engine_worker_api_test.go +++ b/internal/api/engine_worker_api_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "scriberr/internal/auth" "scriberr/internal/database" "scriberr/internal/models" "scriberr/internal/transcription/asrcontract" @@ -97,6 +98,82 @@ func (p fakeCapabilityProvider) IdentifySpeakers(context.Context, asrcontract.Sp } func (p fakeCapabilityProvider) Close() error { return nil } +type fakeAdminASRProvider struct { + mu sync.Mutex + id string + status asrcontract.ProviderStatus + models []asrcontract.ModelCard + loaded []asrcontract.LoadedModel + loads []asrcontract.LoadModelRequest + unloads []asrcontract.UnloadModelRequest + loadErr error + unloadErr error + statusErr error + inspectErr error + modelsErr error + loadedErr error +} + +func (p *fakeAdminASRProvider) ID() string { return p.id } +func (p *fakeAdminASRProvider) Inspect(context.Context) (*asrcontract.ProviderInfo, error) { + if p.inspectErr != nil { + return nil, p.inspectErr + } + return &asrcontract.ProviderInfo{ + ContractVersion: asrcontract.ContractVersionV1, + Provider: asrcontract.ProviderIdentity{ID: p.id, Name: "Fake ASR"}, + Runtime: asrcontract.RuntimeInfo{DeviceBackends: []string{"cpu"}, MaxConcurrentJobs: 1}, + }, nil +} +func (p *fakeAdminASRProvider) Models(context.Context) ([]asrcontract.ModelCard, error) { + if p.modelsErr != nil { + return nil, p.modelsErr + } + return p.models, nil +} +func (p *fakeAdminASRProvider) Status(context.Context) (*asrcontract.ProviderStatus, error) { + if p.statusErr != nil { + return nil, p.statusErr + } + status := p.status + if status.State == "" { + status.State = asrcontract.ProviderStateIdle + } + return &status, nil +} +func (p *fakeAdminASRProvider) LoadModel(_ context.Context, req asrcontract.LoadModelRequest) error { + p.mu.Lock() + defer p.mu.Unlock() + p.loads = append(p.loads, req) + return p.loadErr +} +func (p *fakeAdminASRProvider) UnloadModel(_ context.Context, req asrcontract.UnloadModelRequest) error { + p.mu.Lock() + defer p.mu.Unlock() + p.unloads = append(p.unloads, req) + return p.unloadErr +} +func (p *fakeAdminASRProvider) LoadedModels(context.Context) ([]asrcontract.LoadedModel, error) { + if p.loadedErr != nil { + return nil, p.loadedErr + } + return p.loaded, nil +} +func (p *fakeAdminASRProvider) Capabilities(context.Context) ([]engineprovider.ModelCapability, error) { + return []engineprovider.ModelCapability{{ID: "whisper-base", Provider: p.id, Installed: true, Capabilities: []string{"transcription"}}}, nil +} +func (p *fakeAdminASRProvider) Prepare(context.Context) error { return nil } +func (p *fakeAdminASRProvider) Transcribe(context.Context, engineprovider.TranscriptionRequest) (*engineprovider.TranscriptionResult, error) { + return nil, nil +} +func (p *fakeAdminASRProvider) Diarize(context.Context, engineprovider.DiarizationRequest) (*engineprovider.DiarizationResult, error) { + return nil, nil +} +func (p *fakeAdminASRProvider) IdentifySpeakers(context.Context, asrcontract.SpeakerIDRequest) (*asrcontract.SpeakerIDResult, error) { + return nil, nil +} +func (p *fakeAdminASRProvider) Close() error { return nil } + func TestCreateSubmitRetryUseQueueService(t *testing.T) { s := newAuthTestServer(t) queue := &fakeQueueService{} @@ -282,3 +359,111 @@ func TestQueueServiceErrorDoesNotLeakInternals(t *testing.T) { require.NotContains(t, message, "/tmp/private") require.NotContains(t, message, "secret") } + +func TestAdminASRProviderDiagnosticsAndModelCommands(t *testing.T) { + s := newAuthTestServer(t) + adminToken := registerForFileTests(t, s) + loadedAt := time.Now().UTC().Truncate(time.Millisecond) + memory := 512 + progress := 0.42 + provider := &fakeAdminASRProvider{ + id: "local", + status: asrcontract.ProviderStatus{ + State: asrcontract.ProviderStateBusy, + ActiveJob: &asrcontract.ActiveJob{ + ID: "job-/tmp/private/audio.wav", + Operation: asrcontract.OperationTranscription, + Model: "whisper-base api_key=secret", + Stage: asrcontract.StageTranscribing, + Progress: &progress, + }, + Capacity: asrcontract.ProviderCapacity{MaxConcurrentJobs: 1}, + }, + models: []asrcontract.ModelCard{{ + ID: "whisper-base", + DisplayName: "Whisper Base", + Provider: "local", + Family: "whisper", + Installed: true, + Loaded: true, + Default: true, + SourceURL: "file:///tmp/private/model.bin?api_key=secret", + Capabilities: asrcontract.Capabilities{ + Transcription: true, + WordTimestamps: true, + }, + }}, + loaded: []asrcontract.LoadedModel{{ID: "whisper-base", LoadedAt: &loadedAt, MemoryMB: &memory}}, + } + registry, err := engineprovider.NewRegistry("local", provider) + require.NoError(t, err) + s.handler.modelRegistry = registry + + resp, body := s.request(t, http.MethodGet, "/api/v1/admin/asr-providers", nil, adminToken, "") + require.Equal(t, http.StatusOK, resp.Code) + item := body["items"].([]any)[0].(map[string]any) + require.Equal(t, "local", item["id"]) + status := item["status"].(map[string]any) + activeJob := status["active_job"].(map[string]any) + require.NotContains(t, activeJob["id"], "/tmp/private") + require.NotContains(t, activeJob["model"], "secret") + + resp, body = s.request(t, http.MethodGet, "/api/v1/admin/asr-providers/local", nil, adminToken, "") + require.Equal(t, http.StatusOK, resp.Code) + model := body["models"].([]any)[0].(map[string]any) + require.Equal(t, "whisper-base", model["id"]) + require.NotContains(t, model, "source_url") + require.Len(t, body["loaded_models"].([]any), 1) + + resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/load", map[string]any{ + "model": "whisper-base", + "operation": "transcription", + "load_policy": "require", + }, adminToken, "") + require.Equal(t, http.StatusAccepted, resp.Code) + require.Equal(t, "loading", body["status"]) + require.Len(t, provider.loads, 1) + require.Equal(t, asrcontract.OperationTranscription, provider.loads[0].Operation) + require.Equal(t, asrcontract.LoadPolicyRequire, provider.loads[0].LoadPolicy) + + resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/unload", map[string]any{ + "model": "whisper-base", + "force": true, + }, adminToken, "") + require.Equal(t, http.StatusAccepted, resp.Code) + require.Equal(t, "unloading", body["status"]) + require.Len(t, provider.unloads, 1) + require.True(t, provider.unloads[0].Force) +} + +func TestAdminASRProviderRoutesAreAdminOnlyAndMapProviderErrors(t *testing.T) { + s := newAuthTestServer(t) + adminToken := registerForFileTests(t, s) + nonAdmin := models.User{Username: "asr-member", Password: "pw", Role: "user"} + require.NoError(t, database.DB.Create(&nonAdmin).Error) + nonAdminToken, err := auth.NewAuthService("test-secret").GenerateToken(&nonAdmin) + require.NoError(t, err) + provider := &fakeAdminASRProvider{ + id: "local", + loadErr: asrcontract.NewProviderError(asrcontract.CodeProviderBusy, "busy at /tmp/private/model.bin token=secret", true), + } + registry, err := engineprovider.NewRegistry("local", provider) + require.NoError(t, err) + s.handler.modelRegistry = registry + + resp, body := s.request(t, http.MethodGet, "/api/v1/admin/asr-providers", nil, nonAdminToken, "") + require.Equal(t, http.StatusForbidden, resp.Code) + require.Equal(t, "FORBIDDEN", body["error"].(map[string]any)["code"]) + + resp, body = s.request(t, http.MethodGet, "/api/v1/admin/asr-providers/missing", nil, adminToken, "") + require.Equal(t, http.StatusNotFound, resp.Code) + + resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/load", map[string]any{ + "model": "whisper-base", + }, adminToken, "") + require.Equal(t, http.StatusConflict, resp.Code) + errBody := body["error"].(map[string]any) + require.Equal(t, string(asrcontract.CodeProviderBusy), errBody["code"]) + require.NotContains(t, errBody["message"], "/tmp/private") + require.NotContains(t, errBody["message"], "secret") +} diff --git a/internal/api/route_contract_test.go b/internal/api/route_contract_test.go index d0898d9c..d57d38b3 100644 --- a/internal/api/route_contract_test.go +++ b/internal/api/route_contract_test.go @@ -103,6 +103,10 @@ func TestCanonicalRouteRegistration(t *testing.T) { "GET /api/v1/admin/queue", "GET /api/v1/admin/queue/scheduler", "PUT /api/v1/admin/queue/scheduler", + "GET /api/v1/admin/asr-providers", + "GET /api/v1/admin/asr-providers/:provider_id", + "POST /api/v1/admin/asr-providers/:provider_id/models/load", + "POST /api/v1/admin/asr-providers/:provider_id/models/unload", "GET /api/v1/admin/users", "POST /api/v1/admin/users", "GET /api/v1/admin/users/:user_id", @@ -154,6 +158,8 @@ func TestEndpointContractSmoke(t *testing.T) { {name: "queue stats", method: http.MethodGet, path: "/api/v1/admin/queue", token: token, want: http.StatusOK}, {name: "queue scheduler get", method: http.MethodGet, path: "/api/v1/admin/queue/scheduler", token: token, want: http.StatusOK}, {name: "queue scheduler invalid update", method: http.MethodPut, path: "/api/v1/admin/queue/scheduler", body: map[string]any{"policy": "random"}, token: token, want: http.StatusUnprocessableEntity}, + {name: "asr provider diagnostics", method: http.MethodGet, path: "/api/v1/admin/asr-providers", token: token, want: http.StatusOK}, + {name: "asr provider missing", method: http.MethodGet, path: "/api/v1/admin/asr-providers/missing", token: token, want: http.StatusNotFound}, {name: "admin users list", method: http.MethodGet, path: "/api/v1/admin/users", token: token, want: http.StatusOK}, {name: "admin users invalid create", method: http.MethodPost, path: "/api/v1/admin/users", body: map[string]any{"username": "u"}, token: token, want: http.StatusUnprocessableEntity}, {name: "youtube import", method: http.MethodPost, path: "/api/v1/files:import-youtube", body: map[string]any{"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}, token: token, want: http.StatusAccepted}, diff --git a/internal/api/router.go b/internal/api/router.go index 70f0b645..d3e9d9cf 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -376,6 +376,10 @@ func SetupRoutes(handler *Handler, _ *auth.AuthService) *gin.Engine { adminRoutes.GET("/queue", handler.queueStats) adminRoutes.GET("/queue/scheduler", handler.getAdminQueueScheduler) adminRoutes.PUT("/queue/scheduler", handler.updateAdminQueueScheduler) + adminRoutes.GET("/asr-providers", handler.listASRProviders) + adminRoutes.GET("/asr-providers/:provider_id", handler.getASRProvider) + adminRoutes.POST("/asr-providers/:provider_id/models/load", handler.loadASRProviderModel) + adminRoutes.POST("/asr-providers/:provider_id/models/unload", handler.unloadASRProviderModel) adminRoutes.GET("/users", handler.listAdminUsers) adminRoutes.POST("/users", handler.idempotencyMiddleware(), handler.createAdminUser) adminRoutes.GET("/users/:user_id", handler.getAdminUser) diff --git a/internal/api/types.go b/internal/api/types.go index ec064eaf..787f9630 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -106,6 +106,17 @@ type updateSettingsRequest struct { AutoRenameEnabled *bool `json:"auto_rename_enabled"` DefaultProfileID *string `json:"default_profile_id"` } +type loadASRProviderModelRequest struct { + Model string `json:"model"` + Operation string `json:"operation"` + LoadPolicy string `json:"load_policy"` + Options map[string]any `json:"options"` +} +type unloadASRProviderModelRequest struct { + Model string `json:"model"` + Force bool `json:"force"` + Options map[string]any `json:"options"` +} type adminCreateUserRequest struct { Username string `json:"username"` Email *string `json:"email"` diff --git a/internal/transcription/engineprovider/registry.go b/internal/transcription/engineprovider/registry.go index 4c02b330..bad24585 100644 --- a/internal/transcription/engineprovider/registry.go +++ b/internal/transcription/engineprovider/registry.go @@ -50,6 +50,19 @@ func (r *StaticRegistry) Provider(id string) (Provider, bool) { return provider, ok } +func (r *StaticRegistry) Providers() []Provider { + ids := make([]string, 0, len(r.providers)) + for id := range r.providers { + ids = append(ids, id) + } + sort.Strings(ids) + out := make([]Provider, 0, len(ids)) + for _, id := range ids { + out = append(out, r.providers[id]) + } + return out +} + func (r *StaticRegistry) Models(ctx context.Context) ([]asrcontract.ModelCard, error) { ids := make([]string, 0, len(r.providers)) for id := range r.providers { diff --git a/internal/transcription/engineprovider/types.go b/internal/transcription/engineprovider/types.go index acbf6707..306cd55f 100644 --- a/internal/transcription/engineprovider/types.go +++ b/internal/transcription/engineprovider/types.go @@ -35,6 +35,7 @@ type ProgressSink interface { type Registry interface { DefaultProvider() Provider Provider(id string) (Provider, bool) + Providers() []Provider Models(ctx context.Context) ([]asrcontract.ModelCard, error) Capabilities(ctx context.Context) ([]ModelCapability, error) Select(ctx context.Context, req SelectionRequest) (Provider, *ModelCapability, error)