mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-06-29 23:36:18 +00:00
backend: add asr provider diagnostics api
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
Run ID: `ASRP`
|
||||
|
||||
Status: completed through ASRP-Sprint 10.
|
||||
Status: completed through ASRP-Sprint 11.
|
||||
|
||||
This tracker belongs to `devnotes/v2.0.0/sprint-plans/asr-provider-backend-sprint-plan.md` and the design spec in `devnotes/v2.0.0/specs/asr-provider-backend-architecture.md`.
|
||||
|
||||
@@ -492,30 +492,42 @@ Commit:
|
||||
|
||||
## ASRP-Sprint 11: Provider Admin/Diagnostics API
|
||||
|
||||
Status: pending
|
||||
Status: completed
|
||||
|
||||
Planned tasks:
|
||||
Completed tasks:
|
||||
|
||||
- [ ] Add service methods for provider list/status/model load/unload.
|
||||
- [ ] Add authenticated/admin-gated diagnostics endpoints as appropriate.
|
||||
- [ ] Keep `/api/v1/models/transcription` user-readable.
|
||||
- [ ] Use bounded timeouts for load/unload commands.
|
||||
- [ ] Sanitize provider status messages.
|
||||
- [x] Added registry provider enumeration for diagnostics.
|
||||
- [x] Added admin-only ASR provider list/detail endpoints.
|
||||
- [x] Added admin-only provider model load/unload endpoints.
|
||||
- [x] Kept `/api/v1/models/transcription` user-readable and unchanged.
|
||||
- [x] Added bounded timeouts around provider diagnostics and model commands.
|
||||
- [x] Sanitized provider diagnostics and provider error messages before API output.
|
||||
|
||||
Acceptance checks:
|
||||
|
||||
- [ ] Users can still list selectable transcription models.
|
||||
- [ ] Admin diagnostics show provider state, active operation, and loaded models without paths/secrets.
|
||||
- [ ] Load/unload failures return typed safe errors.
|
||||
- [ ] Route contract and security regression tests cover new endpoints.
|
||||
- [x] Users can still list selectable transcription models.
|
||||
- [x] Admin diagnostics show provider state, active operation, and loaded models without paths/secrets.
|
||||
- [x] Load/unload failures return typed safe errors.
|
||||
- [x] Route contract and security regression tests cover new endpoints.
|
||||
|
||||
Verification:
|
||||
|
||||
- [ ] Not run yet.
|
||||
- [x] `GOCACHE=/tmp/scriberr-go-cache go test ./internal/transcription/engineprovider ./internal/api -run 'TestAdminASRProvider|TestEndpointContractSmoke|TestRouteContract|TestListTranscriptionModels|TestTranscriptExecutionsLogsModelsAndStatsUseEngineServices'`
|
||||
- [x] `GOCACHE=/tmp/scriberr-go-cache go test ./internal/transcription/... ./internal/profile ./internal/recording`
|
||||
- [x] `GOCACHE=/tmp/scriberr-go-cache go test ./internal/api -run 'Test(AdminASRProvider|EndpointContractSmoke|RouteContract|ListTranscriptionModels|TranscriptExecutionsLogsModelsAndStatsUseEngineServices|ASREngineImportInventory|ProductionCodeDoesNotUseOldASRParameterIdentifiers|ASRProvidersDoNotDependOnAPIOrRepositories|BackendDependencyDirection)'`
|
||||
- [x] `GOCACHE=/tmp/scriberr-go-cache go test ./internal/api` with localhost binding allowed for existing `httptest` LLM provider tests.
|
||||
- [x] `GOCACHE=/tmp/scriberr-go-cache go vet ./internal/api ./internal/transcription/... ./internal/profile ./internal/recording`
|
||||
|
||||
Artifacts:
|
||||
|
||||
- To be filled during implementation.
|
||||
- `internal/api/asr_provider_admin_handlers.go`
|
||||
- `internal/api/router.go`
|
||||
- `internal/api/types.go`
|
||||
- `internal/api/engine_worker_api_test.go`
|
||||
- `internal/api/route_contract_test.go`
|
||||
- `internal/transcription/engineprovider/types.go`
|
||||
- `internal/transcription/engineprovider/registry.go`
|
||||
- `devnotes/v2.0.0/sprint-trackers/asr-provider-backend-sprint-tracker.md`
|
||||
|
||||
Commit:
|
||||
|
||||
|
||||
268
internal/api/asr_provider_admin_handlers.go
Normal file
268
internal/api/asr_provider_admin_handlers.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"scriberr/internal/transcription/asrcontract"
|
||||
"scriberr/internal/transcription/engineprovider"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// asrProviderAdminTimeout is the timeout the admin ASR provider handlers put on
// the contexts they hand to provider calls.
const asrProviderAdminTimeout = 15 * time.Second
|
||||
|
||||
func (h *Handler) listASRProviders(c *gin.Context) {
|
||||
if h.modelRegistry == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"items": []gin.H{}})
|
||||
return
|
||||
}
|
||||
providers := h.modelRegistry.Providers()
|
||||
items := make([]gin.H, 0, len(providers))
|
||||
for _, provider := range providers {
|
||||
summary, err := h.asrProviderSummary(c.Request.Context(), provider)
|
||||
if err != nil {
|
||||
writeASRProviderError(c, err)
|
||||
return
|
||||
}
|
||||
items = append(items, summary)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"items": items})
|
||||
}
|
||||
|
||||
func (h *Handler) getASRProvider(c *gin.Context) {
|
||||
provider, ok := h.asrProviderByID(c, c.Param("provider_id"))
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), asrProviderAdminTimeout)
|
||||
defer cancel()
|
||||
|
||||
info, err := provider.Inspect(ctx)
|
||||
if err != nil {
|
||||
writeASRProviderError(c, err)
|
||||
return
|
||||
}
|
||||
status, err := provider.Status(ctx)
|
||||
if err != nil {
|
||||
writeASRProviderError(c, err)
|
||||
return
|
||||
}
|
||||
models, err := provider.Models(ctx)
|
||||
if err != nil {
|
||||
writeASRProviderError(c, err)
|
||||
return
|
||||
}
|
||||
loaded, err := provider.LoadedModels(ctx)
|
||||
if err != nil {
|
||||
writeASRProviderError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": provider.ID(),
|
||||
"info": sanitizeProviderInfo(info),
|
||||
"status": sanitizeProviderStatus(status),
|
||||
"models": sanitizeModelCards(models),
|
||||
"loaded_models": sanitizeLoadedModels(loaded),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) loadASRProviderModel(c *gin.Context) {
|
||||
provider, ok := h.asrProviderByID(c, c.Param("provider_id"))
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var req loadASRProviderModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
writeError(c, http.StatusBadRequest, "BAD_REQUEST", "invalid request body", nil)
|
||||
return
|
||||
}
|
||||
model := strings.TrimSpace(req.Model)
|
||||
if model == "" {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "model is required", stringPtr("model"))
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), asrProviderAdminTimeout)
|
||||
defer cancel()
|
||||
err := provider.LoadModel(ctx, asrcontract.LoadModelRequest{
|
||||
Model: model,
|
||||
Operation: asrcontract.Operation(strings.TrimSpace(req.Operation)),
|
||||
LoadPolicy: asrcontract.LoadPolicy(strings.TrimSpace(req.LoadPolicy)),
|
||||
Options: req.Options,
|
||||
})
|
||||
if err != nil {
|
||||
writeASRProviderError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusAccepted, gin.H{"provider": provider.ID(), "model": model, "status": "loading"})
|
||||
}
|
||||
|
||||
func (h *Handler) unloadASRProviderModel(c *gin.Context) {
|
||||
provider, ok := h.asrProviderByID(c, c.Param("provider_id"))
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var req unloadASRProviderModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
writeError(c, http.StatusBadRequest, "BAD_REQUEST", "invalid request body", nil)
|
||||
return
|
||||
}
|
||||
model := strings.TrimSpace(req.Model)
|
||||
if model == "" {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "model is required", stringPtr("model"))
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), asrProviderAdminTimeout)
|
||||
defer cancel()
|
||||
err := provider.UnloadModel(ctx, asrcontract.UnloadModelRequest{
|
||||
Model: model,
|
||||
Force: req.Force,
|
||||
Options: req.Options,
|
||||
})
|
||||
if err != nil {
|
||||
writeASRProviderError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusAccepted, gin.H{"provider": provider.ID(), "model": model, "status": "unloading"})
|
||||
}
|
||||
|
||||
func (h *Handler) asrProviderByID(c *gin.Context, id string) (engineprovider.Provider, bool) {
|
||||
if h.modelRegistry == nil {
|
||||
writeError(c, http.StatusNotFound, "NOT_FOUND", "ASR provider not found", nil)
|
||||
return nil, false
|
||||
}
|
||||
provider, ok := h.modelRegistry.Provider(strings.TrimSpace(id))
|
||||
if !ok {
|
||||
writeError(c, http.StatusNotFound, "NOT_FOUND", "ASR provider not found", nil)
|
||||
return nil, false
|
||||
}
|
||||
return provider, true
|
||||
}
|
||||
|
||||
func (h *Handler) asrProviderSummary(parent context.Context, provider engineprovider.Provider) (gin.H, error) {
|
||||
ctx, cancel := context.WithTimeout(parent, asrProviderAdminTimeout)
|
||||
defer cancel()
|
||||
info, err := provider.Inspect(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status, err := provider.Status(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loaded, err := provider.LoadedModels(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return gin.H{
|
||||
"id": provider.ID(),
|
||||
"info": sanitizeProviderInfo(info),
|
||||
"status": sanitizeProviderStatus(status),
|
||||
"loaded_models": sanitizeLoadedModels(loaded),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func writeASRProviderError(c *gin.Context, err error) {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
writeError(c, http.StatusGatewayTimeout, "PROVIDER_TIMEOUT", "ASR provider timed out", nil)
|
||||
return
|
||||
}
|
||||
var providerErr *asrcontract.ProviderError
|
||||
if errors.As(err, &providerErr) {
|
||||
status := http.StatusBadGateway
|
||||
switch providerErr.Code {
|
||||
case asrcontract.CodeInvalidRequest:
|
||||
status = http.StatusUnprocessableEntity
|
||||
case asrcontract.CodeUnsupportedModel, asrcontract.CodeModelNotInstalled:
|
||||
status = http.StatusNotFound
|
||||
case asrcontract.CodeProviderBusy:
|
||||
status = http.StatusConflict
|
||||
case asrcontract.CodeProviderUnhealthy, asrcontract.CodeInsufficientResources:
|
||||
status = http.StatusServiceUnavailable
|
||||
case asrcontract.CodeTimeout:
|
||||
status = http.StatusGatewayTimeout
|
||||
}
|
||||
message := sanitizePublicText(providerErr.Message)
|
||||
if strings.TrimSpace(message) == "" {
|
||||
message = "ASR provider request failed"
|
||||
}
|
||||
writeError(c, status, string(providerErr.Code), message, nil)
|
||||
return
|
||||
}
|
||||
writeError(c, http.StatusBadGateway, "PROVIDER_ERROR", "ASR provider request failed", nil)
|
||||
}
|
||||
|
||||
func sanitizeProviderInfo(info *asrcontract.ProviderInfo) any {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
return gin.H{
|
||||
"contract_version": info.ContractVersion,
|
||||
"provider": info.Provider,
|
||||
"runtime": info.Runtime,
|
||||
"audio_input": info.AudioInput,
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeProviderStatus(status *asrcontract.ProviderStatus) any {
|
||||
if status == nil {
|
||||
return nil
|
||||
}
|
||||
return gin.H{
|
||||
"state": status.State,
|
||||
"active_job": sanitizeActiveJob(status.ActiveJob),
|
||||
"loaded_models": sanitizeLoadedModels(status.LoadedModels),
|
||||
"capacity": status.Capacity,
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeActiveJob(job *asrcontract.ActiveJob) any {
|
||||
if job == nil {
|
||||
return nil
|
||||
}
|
||||
return gin.H{
|
||||
"id": sanitizePublicText(job.ID),
|
||||
"operation": job.Operation,
|
||||
"model": sanitizePublicText(job.Model),
|
||||
"stage": job.Stage,
|
||||
"progress": job.Progress,
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeLoadedModels(models []asrcontract.LoadedModel) []gin.H {
|
||||
out := make([]gin.H, 0, len(models))
|
||||
for _, model := range models {
|
||||
out = append(out, gin.H{
|
||||
"id": sanitizePublicText(model.ID),
|
||||
"loaded_at": model.LoadedAt,
|
||||
"memory_mb": model.MemoryMB,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeModelCards(models []asrcontract.ModelCard) []gin.H {
|
||||
out := make([]gin.H, 0, len(models))
|
||||
for _, model := range models {
|
||||
out = append(out, gin.H{
|
||||
"id": sanitizePublicText(model.ID),
|
||||
"display_name": sanitizePublicText(model.DisplayName),
|
||||
"provider": sanitizePublicText(model.Provider),
|
||||
"family": sanitizePublicText(model.Family),
|
||||
"version": sanitizePublicText(model.Version),
|
||||
"installed": model.Installed,
|
||||
"loaded": model.Loaded,
|
||||
"default": model.Default,
|
||||
"tasks": model.Tasks,
|
||||
"languages": model.Languages,
|
||||
"capabilities": model.Capabilities,
|
||||
"limits": model.Limits,
|
||||
"resource_requirements": model.ResourceRequirements,
|
||||
"license": sanitizePublicText(model.License),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"scriberr/internal/auth"
|
||||
"scriberr/internal/database"
|
||||
"scriberr/internal/models"
|
||||
"scriberr/internal/transcription/asrcontract"
|
||||
@@ -97,6 +98,82 @@ func (p fakeCapabilityProvider) IdentifySpeakers(context.Context, asrcontract.Sp
|
||||
}
|
||||
func (p fakeCapabilityProvider) Close() error { return nil }
|
||||
|
||||
type fakeAdminASRProvider struct {
|
||||
mu sync.Mutex
|
||||
id string
|
||||
status asrcontract.ProviderStatus
|
||||
models []asrcontract.ModelCard
|
||||
loaded []asrcontract.LoadedModel
|
||||
loads []asrcontract.LoadModelRequest
|
||||
unloads []asrcontract.UnloadModelRequest
|
||||
loadErr error
|
||||
unloadErr error
|
||||
statusErr error
|
||||
inspectErr error
|
||||
modelsErr error
|
||||
loadedErr error
|
||||
}
|
||||
|
||||
func (p *fakeAdminASRProvider) ID() string { return p.id }
|
||||
func (p *fakeAdminASRProvider) Inspect(context.Context) (*asrcontract.ProviderInfo, error) {
|
||||
if p.inspectErr != nil {
|
||||
return nil, p.inspectErr
|
||||
}
|
||||
return &asrcontract.ProviderInfo{
|
||||
ContractVersion: asrcontract.ContractVersionV1,
|
||||
Provider: asrcontract.ProviderIdentity{ID: p.id, Name: "Fake ASR"},
|
||||
Runtime: asrcontract.RuntimeInfo{DeviceBackends: []string{"cpu"}, MaxConcurrentJobs: 1},
|
||||
}, nil
|
||||
}
|
||||
func (p *fakeAdminASRProvider) Models(context.Context) ([]asrcontract.ModelCard, error) {
|
||||
if p.modelsErr != nil {
|
||||
return nil, p.modelsErr
|
||||
}
|
||||
return p.models, nil
|
||||
}
|
||||
func (p *fakeAdminASRProvider) Status(context.Context) (*asrcontract.ProviderStatus, error) {
|
||||
if p.statusErr != nil {
|
||||
return nil, p.statusErr
|
||||
}
|
||||
status := p.status
|
||||
if status.State == "" {
|
||||
status.State = asrcontract.ProviderStateIdle
|
||||
}
|
||||
return &status, nil
|
||||
}
|
||||
func (p *fakeAdminASRProvider) LoadModel(_ context.Context, req asrcontract.LoadModelRequest) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.loads = append(p.loads, req)
|
||||
return p.loadErr
|
||||
}
|
||||
func (p *fakeAdminASRProvider) UnloadModel(_ context.Context, req asrcontract.UnloadModelRequest) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.unloads = append(p.unloads, req)
|
||||
return p.unloadErr
|
||||
}
|
||||
func (p *fakeAdminASRProvider) LoadedModels(context.Context) ([]asrcontract.LoadedModel, error) {
|
||||
if p.loadedErr != nil {
|
||||
return nil, p.loadedErr
|
||||
}
|
||||
return p.loaded, nil
|
||||
}
|
||||
func (p *fakeAdminASRProvider) Capabilities(context.Context) ([]engineprovider.ModelCapability, error) {
|
||||
return []engineprovider.ModelCapability{{ID: "whisper-base", Provider: p.id, Installed: true, Capabilities: []string{"transcription"}}}, nil
|
||||
}
|
||||
func (p *fakeAdminASRProvider) Prepare(context.Context) error { return nil }
|
||||
func (p *fakeAdminASRProvider) Transcribe(context.Context, engineprovider.TranscriptionRequest) (*engineprovider.TranscriptionResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (p *fakeAdminASRProvider) Diarize(context.Context, engineprovider.DiarizationRequest) (*engineprovider.DiarizationResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (p *fakeAdminASRProvider) IdentifySpeakers(context.Context, asrcontract.SpeakerIDRequest) (*asrcontract.SpeakerIDResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (p *fakeAdminASRProvider) Close() error { return nil }
|
||||
|
||||
func TestCreateSubmitRetryUseQueueService(t *testing.T) {
|
||||
s := newAuthTestServer(t)
|
||||
queue := &fakeQueueService{}
|
||||
@@ -282,3 +359,111 @@ func TestQueueServiceErrorDoesNotLeakInternals(t *testing.T) {
|
||||
require.NotContains(t, message, "/tmp/private")
|
||||
require.NotContains(t, message, "secret")
|
||||
}
|
||||
|
||||
func TestAdminASRProviderDiagnosticsAndModelCommands(t *testing.T) {
|
||||
s := newAuthTestServer(t)
|
||||
adminToken := registerForFileTests(t, s)
|
||||
loadedAt := time.Now().UTC().Truncate(time.Millisecond)
|
||||
memory := 512
|
||||
progress := 0.42
|
||||
provider := &fakeAdminASRProvider{
|
||||
id: "local",
|
||||
status: asrcontract.ProviderStatus{
|
||||
State: asrcontract.ProviderStateBusy,
|
||||
ActiveJob: &asrcontract.ActiveJob{
|
||||
ID: "job-/tmp/private/audio.wav",
|
||||
Operation: asrcontract.OperationTranscription,
|
||||
Model: "whisper-base api_key=secret",
|
||||
Stage: asrcontract.StageTranscribing,
|
||||
Progress: &progress,
|
||||
},
|
||||
Capacity: asrcontract.ProviderCapacity{MaxConcurrentJobs: 1},
|
||||
},
|
||||
models: []asrcontract.ModelCard{{
|
||||
ID: "whisper-base",
|
||||
DisplayName: "Whisper Base",
|
||||
Provider: "local",
|
||||
Family: "whisper",
|
||||
Installed: true,
|
||||
Loaded: true,
|
||||
Default: true,
|
||||
SourceURL: "file:///tmp/private/model.bin?api_key=secret",
|
||||
Capabilities: asrcontract.Capabilities{
|
||||
Transcription: true,
|
||||
WordTimestamps: true,
|
||||
},
|
||||
}},
|
||||
loaded: []asrcontract.LoadedModel{{ID: "whisper-base", LoadedAt: &loadedAt, MemoryMB: &memory}},
|
||||
}
|
||||
registry, err := engineprovider.NewRegistry("local", provider)
|
||||
require.NoError(t, err)
|
||||
s.handler.modelRegistry = registry
|
||||
|
||||
resp, body := s.request(t, http.MethodGet, "/api/v1/admin/asr-providers", nil, adminToken, "")
|
||||
require.Equal(t, http.StatusOK, resp.Code)
|
||||
item := body["items"].([]any)[0].(map[string]any)
|
||||
require.Equal(t, "local", item["id"])
|
||||
status := item["status"].(map[string]any)
|
||||
activeJob := status["active_job"].(map[string]any)
|
||||
require.NotContains(t, activeJob["id"], "/tmp/private")
|
||||
require.NotContains(t, activeJob["model"], "secret")
|
||||
|
||||
resp, body = s.request(t, http.MethodGet, "/api/v1/admin/asr-providers/local", nil, adminToken, "")
|
||||
require.Equal(t, http.StatusOK, resp.Code)
|
||||
model := body["models"].([]any)[0].(map[string]any)
|
||||
require.Equal(t, "whisper-base", model["id"])
|
||||
require.NotContains(t, model, "source_url")
|
||||
require.Len(t, body["loaded_models"].([]any), 1)
|
||||
|
||||
resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/load", map[string]any{
|
||||
"model": "whisper-base",
|
||||
"operation": "transcription",
|
||||
"load_policy": "require",
|
||||
}, adminToken, "")
|
||||
require.Equal(t, http.StatusAccepted, resp.Code)
|
||||
require.Equal(t, "loading", body["status"])
|
||||
require.Len(t, provider.loads, 1)
|
||||
require.Equal(t, asrcontract.OperationTranscription, provider.loads[0].Operation)
|
||||
require.Equal(t, asrcontract.LoadPolicyRequire, provider.loads[0].LoadPolicy)
|
||||
|
||||
resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/unload", map[string]any{
|
||||
"model": "whisper-base",
|
||||
"force": true,
|
||||
}, adminToken, "")
|
||||
require.Equal(t, http.StatusAccepted, resp.Code)
|
||||
require.Equal(t, "unloading", body["status"])
|
||||
require.Len(t, provider.unloads, 1)
|
||||
require.True(t, provider.unloads[0].Force)
|
||||
}
|
||||
|
||||
func TestAdminASRProviderRoutesAreAdminOnlyAndMapProviderErrors(t *testing.T) {
|
||||
s := newAuthTestServer(t)
|
||||
adminToken := registerForFileTests(t, s)
|
||||
nonAdmin := models.User{Username: "asr-member", Password: "pw", Role: "user"}
|
||||
require.NoError(t, database.DB.Create(&nonAdmin).Error)
|
||||
nonAdminToken, err := auth.NewAuthService("test-secret").GenerateToken(&nonAdmin)
|
||||
require.NoError(t, err)
|
||||
provider := &fakeAdminASRProvider{
|
||||
id: "local",
|
||||
loadErr: asrcontract.NewProviderError(asrcontract.CodeProviderBusy, "busy at /tmp/private/model.bin token=secret", true),
|
||||
}
|
||||
registry, err := engineprovider.NewRegistry("local", provider)
|
||||
require.NoError(t, err)
|
||||
s.handler.modelRegistry = registry
|
||||
|
||||
resp, body := s.request(t, http.MethodGet, "/api/v1/admin/asr-providers", nil, nonAdminToken, "")
|
||||
require.Equal(t, http.StatusForbidden, resp.Code)
|
||||
require.Equal(t, "FORBIDDEN", body["error"].(map[string]any)["code"])
|
||||
|
||||
resp, body = s.request(t, http.MethodGet, "/api/v1/admin/asr-providers/missing", nil, adminToken, "")
|
||||
require.Equal(t, http.StatusNotFound, resp.Code)
|
||||
|
||||
resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/load", map[string]any{
|
||||
"model": "whisper-base",
|
||||
}, adminToken, "")
|
||||
require.Equal(t, http.StatusConflict, resp.Code)
|
||||
errBody := body["error"].(map[string]any)
|
||||
require.Equal(t, string(asrcontract.CodeProviderBusy), errBody["code"])
|
||||
require.NotContains(t, errBody["message"], "/tmp/private")
|
||||
require.NotContains(t, errBody["message"], "secret")
|
||||
}
|
||||
|
||||
@@ -103,6 +103,10 @@ func TestCanonicalRouteRegistration(t *testing.T) {
|
||||
"GET /api/v1/admin/queue",
|
||||
"GET /api/v1/admin/queue/scheduler",
|
||||
"PUT /api/v1/admin/queue/scheduler",
|
||||
"GET /api/v1/admin/asr-providers",
|
||||
"GET /api/v1/admin/asr-providers/:provider_id",
|
||||
"POST /api/v1/admin/asr-providers/:provider_id/models/load",
|
||||
"POST /api/v1/admin/asr-providers/:provider_id/models/unload",
|
||||
"GET /api/v1/admin/users",
|
||||
"POST /api/v1/admin/users",
|
||||
"GET /api/v1/admin/users/:user_id",
|
||||
@@ -154,6 +158,8 @@ func TestEndpointContractSmoke(t *testing.T) {
|
||||
{name: "queue stats", method: http.MethodGet, path: "/api/v1/admin/queue", token: token, want: http.StatusOK},
|
||||
{name: "queue scheduler get", method: http.MethodGet, path: "/api/v1/admin/queue/scheduler", token: token, want: http.StatusOK},
|
||||
{name: "queue scheduler invalid update", method: http.MethodPut, path: "/api/v1/admin/queue/scheduler", body: map[string]any{"policy": "random"}, token: token, want: http.StatusUnprocessableEntity},
|
||||
{name: "asr provider diagnostics", method: http.MethodGet, path: "/api/v1/admin/asr-providers", token: token, want: http.StatusOK},
|
||||
{name: "asr provider missing", method: http.MethodGet, path: "/api/v1/admin/asr-providers/missing", token: token, want: http.StatusNotFound},
|
||||
{name: "admin users list", method: http.MethodGet, path: "/api/v1/admin/users", token: token, want: http.StatusOK},
|
||||
{name: "admin users invalid create", method: http.MethodPost, path: "/api/v1/admin/users", body: map[string]any{"username": "u"}, token: token, want: http.StatusUnprocessableEntity},
|
||||
{name: "youtube import", method: http.MethodPost, path: "/api/v1/files:import-youtube", body: map[string]any{"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}, token: token, want: http.StatusAccepted},
|
||||
|
||||
@@ -376,6 +376,10 @@ func SetupRoutes(handler *Handler, _ *auth.AuthService) *gin.Engine {
|
||||
adminRoutes.GET("/queue", handler.queueStats)
|
||||
adminRoutes.GET("/queue/scheduler", handler.getAdminQueueScheduler)
|
||||
adminRoutes.PUT("/queue/scheduler", handler.updateAdminQueueScheduler)
|
||||
adminRoutes.GET("/asr-providers", handler.listASRProviders)
|
||||
adminRoutes.GET("/asr-providers/:provider_id", handler.getASRProvider)
|
||||
adminRoutes.POST("/asr-providers/:provider_id/models/load", handler.loadASRProviderModel)
|
||||
adminRoutes.POST("/asr-providers/:provider_id/models/unload", handler.unloadASRProviderModel)
|
||||
adminRoutes.GET("/users", handler.listAdminUsers)
|
||||
adminRoutes.POST("/users", handler.idempotencyMiddleware(), handler.createAdminUser)
|
||||
adminRoutes.GET("/users/:user_id", handler.getAdminUser)
|
||||
|
||||
@@ -106,6 +106,17 @@ type updateSettingsRequest struct {
|
||||
AutoRenameEnabled *bool `json:"auto_rename_enabled"`
|
||||
DefaultProfileID *string `json:"default_profile_id"`
|
||||
}
|
||||
// loadASRProviderModelRequest is the JSON body accepted by the admin endpoint
// that asks an ASR provider to load a model.
type loadASRProviderModelRequest struct {
	Model      string         `json:"model"`
	Operation  string         `json:"operation"`
	LoadPolicy string         `json:"load_policy"`
	Options    map[string]any `json:"options"`
}
|
||||
// unloadASRProviderModelRequest is the JSON body accepted by the admin endpoint
// that asks an ASR provider to unload a model.
type unloadASRProviderModelRequest struct {
	Model   string         `json:"model"`
	Force   bool           `json:"force"`
	Options map[string]any `json:"options"`
}
|
||||
type adminCreateUserRequest struct {
|
||||
Username string `json:"username"`
|
||||
Email *string `json:"email"`
|
||||
|
||||
@@ -50,6 +50,19 @@ func (r *StaticRegistry) Provider(id string) (Provider, bool) {
|
||||
return provider, ok
|
||||
}
|
||||
|
||||
func (r *StaticRegistry) Providers() []Provider {
|
||||
ids := make([]string, 0, len(r.providers))
|
||||
for id := range r.providers {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
sort.Strings(ids)
|
||||
out := make([]Provider, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
out = append(out, r.providers[id])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *StaticRegistry) Models(ctx context.Context) ([]asrcontract.ModelCard, error) {
|
||||
ids := make([]string, 0, len(r.providers))
|
||||
for id := range r.providers {
|
||||
|
||||
@@ -35,6 +35,7 @@ type ProgressSink interface {
|
||||
type Registry interface {
|
||||
DefaultProvider() Provider
|
||||
Provider(id string) (Provider, bool)
|
||||
Providers() []Provider
|
||||
Models(ctx context.Context) ([]asrcontract.ModelCard, error)
|
||||
Capabilities(ctx context.Context) ([]ModelCapability, error)
|
||||
Select(ctx context.Context, req SelectionRequest) (Provider, *ModelCapability, error)
|
||||
|
||||
Reference in New Issue
Block a user