mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-07-01 08:15:46 +00:00
602 lines
23 KiB
Go
602 lines
23 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"scriberr/internal/auth"
|
|
"scriberr/internal/database"
|
|
"scriberr/internal/models"
|
|
"scriberr/internal/transcription/asrcontract"
|
|
"scriberr/internal/transcription/engineprovider"
|
|
"scriberr/internal/transcription/worker"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type fakeQueueService struct {
|
|
mu sync.Mutex
|
|
enqueued []string
|
|
canceled []string
|
|
stats worker.QueueStats
|
|
adminStats worker.AdminQueueStats
|
|
err error
|
|
cancelErr error
|
|
}
|
|
|
|
func (q *fakeQueueService) Enqueue(ctx context.Context, jobID string) error {
|
|
q.mu.Lock()
|
|
defer q.mu.Unlock()
|
|
q.enqueued = append(q.enqueued, jobID)
|
|
return q.err
|
|
}
|
|
func (q *fakeQueueService) Cancel(ctx context.Context, userID uint, jobID string) error {
|
|
q.mu.Lock()
|
|
defer q.mu.Unlock()
|
|
q.canceled = append(q.canceled, jobID)
|
|
return q.cancelErr
|
|
}
|
|
func (q *fakeQueueService) Start(context.Context) error { return nil }
|
|
func (q *fakeQueueService) Stop(context.Context) error { return nil }
|
|
func (q *fakeQueueService) Stats(context.Context, uint) (worker.QueueStats, error) {
|
|
return q.stats, q.err
|
|
}
|
|
func (q *fakeQueueService) AdminStats(context.Context) (worker.AdminQueueStats, error) {
|
|
if q.adminStats.ByUser == nil {
|
|
return worker.AdminQueueStats{QueueStats: q.stats}, q.err
|
|
}
|
|
return q.adminStats, q.err
|
|
}
|
|
|
|
func setTestQueueService(s *authTestServer, queue *fakeQueueService) {
|
|
s.handler.queueService = queue
|
|
if s.handler.transcriptions != nil {
|
|
s.handler.transcriptions.SetQueue(queue)
|
|
}
|
|
}
|
|
|
|
type fakeCapabilityProvider struct {
|
|
models []asrcontract.ModelCard
|
|
}
|
|
|
|
func (p fakeCapabilityProvider) ID() string { return "local" }
|
|
func (p fakeCapabilityProvider) Inspect(context.Context) (*asrcontract.ProviderInfo, error) {
|
|
return &asrcontract.ProviderInfo{ContractVersion: asrcontract.ContractVersionV1}, nil
|
|
}
|
|
func (p fakeCapabilityProvider) Models(context.Context) ([]asrcontract.ModelCard, error) {
|
|
return p.models, nil
|
|
}
|
|
func (p fakeCapabilityProvider) Status(context.Context) (*asrcontract.ProviderStatus, error) {
|
|
return &asrcontract.ProviderStatus{State: asrcontract.ProviderStateIdle}, nil
|
|
}
|
|
func (p fakeCapabilityProvider) LoadModel(context.Context, asrcontract.LoadModelRequest) error {
|
|
return nil
|
|
}
|
|
func (p fakeCapabilityProvider) UnloadModel(context.Context, asrcontract.UnloadModelRequest) error {
|
|
return nil
|
|
}
|
|
func (p fakeCapabilityProvider) LoadedModels(context.Context) ([]asrcontract.LoadedModel, error) {
|
|
return nil, nil
|
|
}
|
|
func (p fakeCapabilityProvider) ExecuteTask(context.Context, engineprovider.TaskRequest) (*engineprovider.TaskResult, error) {
|
|
return nil, nil
|
|
}
|
|
func (p fakeCapabilityProvider) Close() error { return nil }
|
|
|
|
func testASRModel(provider, id, name, modelType string, isDefault bool, capabilities ...asrcontract.Capability) asrcontract.ModelCard {
|
|
return asrcontract.ModelCard{
|
|
ID: id,
|
|
DisplayName: name,
|
|
Provider: provider,
|
|
ModelType: modelType,
|
|
Installed: true,
|
|
Default: isDefault,
|
|
Capabilities: testASRCapabilities(capabilities...),
|
|
}
|
|
}
|
|
|
|
func testASRCapabilities(capabilities ...asrcontract.Capability) asrcontract.Capabilities {
|
|
out := asrcontract.Capabilities{Extensions: map[string]bool{}}
|
|
for _, capability := range capabilities {
|
|
switch capability {
|
|
case asrcontract.CapabilityTranscription:
|
|
out.Transcription = true
|
|
case asrcontract.CapabilityDiarization:
|
|
out.Diarization = true
|
|
case asrcontract.CapabilitySpeakerIdentification:
|
|
out.SpeakerIdentification = true
|
|
case asrcontract.CapabilityWordTimestamps:
|
|
out.WordTimestamps = true
|
|
default:
|
|
out.Extensions[string(capability)] = true
|
|
}
|
|
}
|
|
if len(out.Extensions) == 0 {
|
|
out.Extensions = nil
|
|
}
|
|
return out
|
|
}
|
|
|
|
func TestASRModelCatalogEndpointFiltersCapabilities(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
registry, err := engineprovider.NewRegistry("local", fakeCapabilityProvider{models: []asrcontract.ModelCard{
|
|
{
|
|
ID: "parakeet-v2",
|
|
DisplayName: "NVIDIA Parakeet TDT v2",
|
|
Provider: "local",
|
|
ModelType: "nemo_transducer",
|
|
Installed: true,
|
|
Capabilities: testASRCapabilities(
|
|
asrcontract.CapabilityTranscription,
|
|
asrcontract.CapabilityWordTimestamps,
|
|
asrcontract.CapabilitySegmentTimestamps,
|
|
),
|
|
ParameterSchema: asrcontract.ParameterSchema{
|
|
{Key: asrcontract.CommonParameterChunkingMode, Type: asrcontract.ParameterTypeEnum, Scope: asrcontract.ParameterScopeChunking, Options: []asrcontract.ParameterOption{{Value: "fixed", Label: "Fixed"}}},
|
|
{Key: "sherpa.model_type", Type: asrcontract.ParameterTypeString, Scope: asrcontract.ParameterScopeModel, Default: "nemo_transducer"},
|
|
},
|
|
RecommendedDefaults: map[string]any{"sherpa.model_type": "nemo_transducer"},
|
|
},
|
|
{
|
|
ID: "parakeet-v3",
|
|
DisplayName: "NVIDIA Parakeet TDT v3",
|
|
Provider: "local",
|
|
ModelType: "nemo_transducer",
|
|
Installed: true,
|
|
Capabilities: testASRCapabilities(
|
|
asrcontract.CapabilityTranscription,
|
|
asrcontract.CapabilityWordTimestamps,
|
|
asrcontract.CapabilitySegmentTimestamps,
|
|
),
|
|
ParameterSchema: asrcontract.ParameterSchema{
|
|
{Key: asrcontract.CommonParameterChunkingMode, Type: asrcontract.ParameterTypeEnum, Scope: asrcontract.ParameterScopeChunking, Options: []asrcontract.ParameterOption{{Value: "fixed", Label: "Fixed"}}},
|
|
{Key: "sherpa.model_type", Type: asrcontract.ParameterTypeString, Scope: asrcontract.ParameterScopeModel, Default: "nemo_transducer"},
|
|
},
|
|
RecommendedDefaults: map[string]any{"sherpa.model_type": "nemo_transducer"},
|
|
},
|
|
{
|
|
ID: "diarization-default",
|
|
DisplayName: "Pyannote + 3D-Speaker Diarization",
|
|
Provider: "local",
|
|
ModelType: "diarization",
|
|
Installed: true,
|
|
Capabilities: testASRCapabilities(
|
|
asrcontract.CapabilityDiarization,
|
|
),
|
|
ParameterSchema: asrcontract.ParameterSchema{
|
|
{Key: "diarization.num_clusters", Type: asrcontract.ParameterTypeInteger, Scope: asrcontract.ParameterScopeModel, Default: float64(0)},
|
|
},
|
|
},
|
|
{
|
|
ID: "speaker-id-default",
|
|
DisplayName: "Speaker ID",
|
|
Provider: "local",
|
|
ModelType: "speaker_identification",
|
|
Installed: true,
|
|
Capabilities: testASRCapabilities(asrcontract.CapabilitySpeakerIdentification),
|
|
},
|
|
}})
|
|
require.NoError(t, err)
|
|
s.handler.modelRegistry = registry
|
|
token := registerForFileTests(t, s)
|
|
|
|
resp, body := s.request(t, http.MethodGet, "/api/v1/models", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.ElementsMatch(t, []string{"parakeet-v2", "parakeet-v3", "diarization-default", "speaker-id-default"}, modelIDsFromResponse(body))
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/models?capability=transcription", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.ElementsMatch(t, []string{"parakeet-v2", "parakeet-v3"}, modelIDsFromResponse(body))
|
|
requireModelSchemaKey(t, body, "parakeet-v2", "sherpa.model_type")
|
|
requireModelSchemaKey(t, body, "parakeet-v3", "sherpa.model_type")
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/models?capability=diarization", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.ElementsMatch(t, []string{"diarization-default"}, modelIDsFromResponse(body))
|
|
requireModelSchemaKey(t, body, "diarization-default", "diarization.num_clusters")
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/models?capability=transcription,diarization", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.ElementsMatch(t, []string{"parakeet-v2", "parakeet-v3", "diarization-default"}, modelIDsFromResponse(body))
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/models?capability=unknown", nil, token, "")
|
|
require.Equal(t, http.StatusUnprocessableEntity, resp.Code)
|
|
require.Equal(t, "VALIDATION_ERROR", body["error"].(map[string]any)["code"])
|
|
}
|
|
|
|
func modelIDsFromResponse(body map[string]any) []string {
|
|
items := body["items"].([]any)
|
|
out := make([]string, 0, len(items))
|
|
for _, item := range items {
|
|
model := item.(map[string]any)
|
|
out = append(out, model["id"].(string))
|
|
}
|
|
return out
|
|
}
|
|
|
|
func requireModelSchemaKey(t *testing.T, body map[string]any, modelID string, key string) {
|
|
t.Helper()
|
|
items := body["items"].([]any)
|
|
for _, item := range items {
|
|
model := item.(map[string]any)
|
|
if model["id"] != modelID {
|
|
continue
|
|
}
|
|
schema := model["parameter_schema"].([]any)
|
|
for _, rawParameter := range schema {
|
|
parameter := rawParameter.(map[string]any)
|
|
if parameter["key"] == key {
|
|
return
|
|
}
|
|
}
|
|
require.Failf(t, "missing schema key", "model %q missing parameter %q", modelID, key)
|
|
}
|
|
require.Failf(t, "missing model", "model %q not found", modelID)
|
|
}
|
|
|
|
type fakeAdminASRProvider struct {
|
|
mu sync.Mutex
|
|
id string
|
|
status asrcontract.ProviderStatus
|
|
models []asrcontract.ModelCard
|
|
loaded []asrcontract.LoadedModel
|
|
loads []asrcontract.LoadModelRequest
|
|
unloads []asrcontract.UnloadModelRequest
|
|
loadErr error
|
|
unloadErr error
|
|
statusErr error
|
|
inspectErr error
|
|
modelsErr error
|
|
loadedErr error
|
|
}
|
|
|
|
func (p *fakeAdminASRProvider) ID() string { return p.id }
|
|
func (p *fakeAdminASRProvider) Inspect(context.Context) (*asrcontract.ProviderInfo, error) {
|
|
if p.inspectErr != nil {
|
|
return nil, p.inspectErr
|
|
}
|
|
return &asrcontract.ProviderInfo{
|
|
ContractVersion: asrcontract.ContractVersionV1,
|
|
Provider: asrcontract.ProviderIdentity{ID: p.id, Name: "Fake ASR"},
|
|
Runtime: asrcontract.RuntimeInfo{DeviceBackends: []string{"cpu"}, MaxConcurrentJobs: 1},
|
|
}, nil
|
|
}
|
|
func (p *fakeAdminASRProvider) Models(context.Context) ([]asrcontract.ModelCard, error) {
|
|
if p.modelsErr != nil {
|
|
return nil, p.modelsErr
|
|
}
|
|
return p.models, nil
|
|
}
|
|
func (p *fakeAdminASRProvider) Status(context.Context) (*asrcontract.ProviderStatus, error) {
|
|
if p.statusErr != nil {
|
|
return nil, p.statusErr
|
|
}
|
|
status := p.status
|
|
if status.State == "" {
|
|
status.State = asrcontract.ProviderStateIdle
|
|
}
|
|
return &status, nil
|
|
}
|
|
func (p *fakeAdminASRProvider) LoadModel(_ context.Context, req asrcontract.LoadModelRequest) error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
p.loads = append(p.loads, req)
|
|
return p.loadErr
|
|
}
|
|
func (p *fakeAdminASRProvider) UnloadModel(_ context.Context, req asrcontract.UnloadModelRequest) error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
p.unloads = append(p.unloads, req)
|
|
return p.unloadErr
|
|
}
|
|
func (p *fakeAdminASRProvider) LoadedModels(context.Context) ([]asrcontract.LoadedModel, error) {
|
|
if p.loadedErr != nil {
|
|
return nil, p.loadedErr
|
|
}
|
|
return p.loaded, nil
|
|
}
|
|
func (p *fakeAdminASRProvider) ExecuteTask(context.Context, engineprovider.TaskRequest) (*engineprovider.TaskResult, error) {
|
|
return nil, nil
|
|
}
|
|
func (p *fakeAdminASRProvider) Close() error { return nil }
|
|
|
|
func TestCreateSubmitRetryUseQueueService(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
queue := &fakeQueueService{}
|
|
setTestQueueService(s, queue)
|
|
token := registerForFileTests(t, s)
|
|
fileID, _ := createUploadedFileForTranscription(t, s, token)
|
|
|
|
resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{
|
|
"file_id": fileID,
|
|
"title": "Queued by service",
|
|
}, token, "")
|
|
require.Equal(t, http.StatusAccepted, resp.Code)
|
|
firstID := strings.TrimPrefix(body["id"].(string), "tr_")
|
|
|
|
resp, body = s.request(t, http.MethodPost, "/api/v1/transcriptions/"+body["id"].(string)+":retry", nil, token, "")
|
|
require.Equal(t, http.StatusAccepted, resp.Code)
|
|
retryID := strings.TrimPrefix(body["id"].(string), "tr_")
|
|
|
|
require.Len(t, queue.enqueued, 2)
|
|
require.Equal(t, firstID, queue.enqueued[0])
|
|
require.Equal(t, retryID, queue.enqueued[1])
|
|
}
|
|
|
|
func TestCreateReturnsServiceUnavailableWhenQueueStopped(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
setTestQueueService(s, &fakeQueueService{err: worker.ErrQueueStopped})
|
|
token := registerForFileTests(t, s)
|
|
fileID, _ := createUploadedFileForTranscription(t, s, token)
|
|
|
|
resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{"file_id": fileID}, token, "")
|
|
|
|
require.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
|
require.Equal(t, "SERVICE_UNAVAILABLE", body["error"].(map[string]any)["code"])
|
|
|
|
var count int64
|
|
require.NoError(t, database.DB.Model(&models.TranscriptionJob{}).
|
|
Where("source_file_hash IS NOT NULL").
|
|
Count(&count).Error)
|
|
require.Equal(t, int64(1), count)
|
|
}
|
|
|
|
func TestRetryPreservesNewJobWhenQueueStopped(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
queue := &fakeQueueService{}
|
|
setTestQueueService(s, queue)
|
|
token := registerForFileTests(t, s)
|
|
fileID, _ := createUploadedFileForTranscription(t, s, token)
|
|
|
|
resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{"file_id": fileID}, token, "")
|
|
require.Equal(t, http.StatusAccepted, resp.Code)
|
|
transcriptionID := body["id"].(string)
|
|
|
|
queue.err = worker.ErrQueueStopped
|
|
resp, body = s.request(t, http.MethodPost, "/api/v1/transcriptions/"+transcriptionID+":retry", nil, token, "")
|
|
require.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
|
require.Equal(t, "SERVICE_UNAVAILABLE", body["error"].(map[string]any)["code"])
|
|
|
|
var count int64
|
|
require.NoError(t, database.DB.Model(&models.TranscriptionJob{}).
|
|
Where("source_file_hash IS NOT NULL").
|
|
Count(&count).Error)
|
|
require.Equal(t, int64(2), count)
|
|
}
|
|
|
|
func TestCancelUsesQueueServiceAndMapsConflict(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
queue := &fakeQueueService{cancelErr: worker.ErrStateConflict}
|
|
setTestQueueService(s, queue)
|
|
token := registerForFileTests(t, s)
|
|
fileID, _ := createUploadedFileForTranscription(t, s, token)
|
|
resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{"file_id": fileID}, token, "")
|
|
require.Equal(t, http.StatusAccepted, resp.Code)
|
|
transcriptionID := body["id"].(string)
|
|
|
|
resp, body = s.request(t, http.MethodPost, "/api/v1/transcriptions/"+transcriptionID+":cancel", nil, token, "")
|
|
|
|
require.Equal(t, http.StatusConflict, resp.Code)
|
|
require.Equal(t, "CONFLICT", body["error"].(map[string]any)["code"])
|
|
require.Equal(t, strings.TrimPrefix(transcriptionID, "tr_"), queue.canceled[0])
|
|
}
|
|
|
|
func TestTranscriptExecutionsLogsModelsAndStatsUseEngineServices(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
queue := &fakeQueueService{stats: worker.QueueStats{Queued: 2, Processing: 1, Completed: 3, Failed: 4, Canceled: 5, Running: 1}}
|
|
setTestQueueService(s, queue)
|
|
registry, err := engineprovider.NewRegistry("local", fakeCapabilityProvider{models: []asrcontract.ModelCard{
|
|
testASRModel("local", "whisper-base", "Whisper Base", "whisper", true, asrcontract.CapabilityTranscription, asrcontract.CapabilityWordTimestamps),
|
|
}})
|
|
require.NoError(t, err)
|
|
s.handler.modelRegistry = registry
|
|
|
|
token := registerForFileTests(t, s)
|
|
userID := firstUserID(t)
|
|
fileID, _ := createUploadedFileForTranscription(t, s, token)
|
|
resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{"file_id": fileID}, token, "")
|
|
require.Equal(t, http.StatusAccepted, resp.Code)
|
|
transcriptionID := body["id"].(string)
|
|
jobID := strings.TrimPrefix(transcriptionID, "tr_")
|
|
|
|
now := time.Now().UTC().Truncate(time.Millisecond)
|
|
transcript := `{"text":"hello","segments":[{"id":"seg_000001","start":0,"end":1,"speaker":"SPEAKER_00","text":"hello"}]}`
|
|
require.NoError(t, database.DB.Model(&models.TranscriptionJob{}).Where("id = ?", jobID).Updates(map[string]any{
|
|
"status": models.StatusCompleted,
|
|
"transcript_text": transcript,
|
|
"progress": 1.0,
|
|
"progress_stage": "completed",
|
|
"started_at": now.Add(-time.Minute),
|
|
"completed_at": now,
|
|
}).Error)
|
|
errorMessage := "failed at /tmp/private/model.bin api_key=secret-value"
|
|
require.NoError(t, database.DB.Create(&models.TranscriptionJobExecution{
|
|
TranscriptionJobID: jobID,
|
|
UserID: userID,
|
|
Status: models.StatusFailed,
|
|
Provider: "local",
|
|
ModelName: "whisper-base",
|
|
StartedAt: now.Add(-time.Minute),
|
|
FailedAt: &now,
|
|
ErrorMessage: &errorMessage,
|
|
}).Error)
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/transcriptions/"+transcriptionID, nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.Equal(t, float64(1), body["progress"])
|
|
require.Equal(t, "completed", body["progress_stage"])
|
|
require.NotNil(t, body["started_at"])
|
|
require.NotNil(t, body["completed_at"])
|
|
|
|
require.NoError(t, database.DB.Model(&models.TranscriptionJob{}).Where("id = ?", jobID).Updates(map[string]any{
|
|
"status": models.StatusFailed,
|
|
"last_error": errorMessage,
|
|
}).Error)
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/transcriptions/"+transcriptionID, nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.NotContains(t, body["error"], "/tmp/private")
|
|
require.NotContains(t, body["error"], "secret-value")
|
|
require.Contains(t, body["error"], "[redacted-path]")
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/transcriptions/"+transcriptionID+"/transcript", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.Equal(t, "hello", body["text"])
|
|
require.Empty(t, body["words"])
|
|
require.Len(t, body["segments"].([]any), 1)
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/transcriptions/"+transcriptionID+"/executions", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
execution := body["items"].([]any)[0].(map[string]any)
|
|
require.Equal(t, "local", execution["provider"])
|
|
require.Equal(t, "whisper-base", execution["model"])
|
|
require.NotContains(t, execution["error"], "/tmp/private")
|
|
require.NotContains(t, execution["error"], "secret-value")
|
|
|
|
resp, rawLogs := s.rawRequest(t, http.MethodGet, "/api/v1/transcriptions/"+transcriptionID+"/logs", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.NotContains(t, rawLogs, "/tmp/private")
|
|
require.NotContains(t, rawLogs, "secret-value")
|
|
require.Contains(t, rawLogs, "[redacted-path]")
|
|
require.Contains(t, rawLogs, "\nfailed_at=")
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/models/transcription", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
model := body["items"].([]any)[0].(map[string]any)
|
|
require.Equal(t, "whisper-base", model["id"])
|
|
require.Equal(t, true, model["installed"])
|
|
require.Equal(t, true, model["default"])
|
|
require.Equal(t, true, model["capabilities"].(map[string]any)["transcription"])
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/admin/queue", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.Equal(t, float64(2), body["queued"])
|
|
require.Equal(t, float64(1), body["running"])
|
|
}
|
|
|
|
func TestQueueServiceErrorDoesNotLeakInternals(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
setTestQueueService(s, &fakeQueueService{err: errors.New("open /tmp/private/socket token=secret failed")})
|
|
token := registerForFileTests(t, s)
|
|
fileID, _ := createUploadedFileForTranscription(t, s, token)
|
|
|
|
resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{"file_id": fileID}, token, "")
|
|
|
|
require.Equal(t, http.StatusInternalServerError, resp.Code)
|
|
message := body["error"].(map[string]any)["message"].(string)
|
|
require.NotContains(t, message, "/tmp/private")
|
|
require.NotContains(t, message, "secret")
|
|
}
|
|
|
|
func TestAdminASRProviderDiagnosticsAndModelCommands(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
adminToken := registerForFileTests(t, s)
|
|
loadedAt := time.Now().UTC().Truncate(time.Millisecond)
|
|
memory := 512
|
|
progress := 0.42
|
|
provider := &fakeAdminASRProvider{
|
|
id: "local",
|
|
status: asrcontract.ProviderStatus{
|
|
State: asrcontract.ProviderStateBusy,
|
|
ActiveJob: &asrcontract.ActiveJob{
|
|
ID: "job-/tmp/private/audio.wav",
|
|
Operation: asrcontract.OperationTranscription,
|
|
Model: "whisper-base api_key=secret",
|
|
Stage: asrcontract.StageRunning,
|
|
Progress: &progress,
|
|
},
|
|
Capacity: asrcontract.ProviderCapacity{MaxConcurrentJobs: 1},
|
|
},
|
|
models: []asrcontract.ModelCard{{
|
|
ID: "whisper-base",
|
|
DisplayName: "Whisper Base",
|
|
Provider: "local",
|
|
ModelType: "whisper",
|
|
Installed: true,
|
|
Loaded: true,
|
|
Default: true,
|
|
SourceURL: "file:///tmp/private/model.bin?api_key=secret",
|
|
Capabilities: asrcontract.Capabilities{
|
|
Transcription: true,
|
|
WordTimestamps: true,
|
|
},
|
|
}},
|
|
loaded: []asrcontract.LoadedModel{{ID: "whisper-base", LoadedAt: &loadedAt, MemoryMB: &memory}},
|
|
}
|
|
registry, err := engineprovider.NewRegistry("local", provider)
|
|
require.NoError(t, err)
|
|
s.handler.modelRegistry = registry
|
|
|
|
resp, body := s.request(t, http.MethodGet, "/api/v1/admin/asr-providers", nil, adminToken, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
item := body["items"].([]any)[0].(map[string]any)
|
|
require.Equal(t, "local", item["id"])
|
|
status := item["status"].(map[string]any)
|
|
activeJob := status["active_job"].(map[string]any)
|
|
require.NotContains(t, activeJob["id"], "/tmp/private")
|
|
require.NotContains(t, activeJob["model"], "secret")
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/admin/asr-providers/local", nil, adminToken, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
model := body["models"].([]any)[0].(map[string]any)
|
|
require.Equal(t, "whisper-base", model["id"])
|
|
require.NotContains(t, model, "source_url")
|
|
require.Len(t, body["loaded_models"].([]any), 1)
|
|
|
|
resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/load", map[string]any{
|
|
"model": "whisper-base",
|
|
"operation": "transcription",
|
|
"load_policy": "require",
|
|
}, adminToken, "")
|
|
require.Equal(t, http.StatusAccepted, resp.Code)
|
|
require.Equal(t, "loading", body["status"])
|
|
require.Len(t, provider.loads, 1)
|
|
require.Equal(t, asrcontract.OperationTranscription, provider.loads[0].Operation)
|
|
require.Equal(t, asrcontract.LoadPolicyRequire, provider.loads[0].LoadPolicy)
|
|
|
|
resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/unload", map[string]any{
|
|
"model": "whisper-base",
|
|
"force": true,
|
|
}, adminToken, "")
|
|
require.Equal(t, http.StatusAccepted, resp.Code)
|
|
require.Equal(t, "unloading", body["status"])
|
|
require.Len(t, provider.unloads, 1)
|
|
require.True(t, provider.unloads[0].Force)
|
|
}
|
|
|
|
func TestAdminASRProviderRoutesAreAdminOnlyAndMapProviderErrors(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
adminToken := registerForFileTests(t, s)
|
|
nonAdmin := models.User{Username: "asr-member", Password: "pw", Role: "user"}
|
|
require.NoError(t, database.DB.Create(&nonAdmin).Error)
|
|
nonAdminToken, err := auth.NewAuthService("test-secret").GenerateToken(&nonAdmin)
|
|
require.NoError(t, err)
|
|
provider := &fakeAdminASRProvider{
|
|
id: "local",
|
|
loadErr: asrcontract.NewProviderError(asrcontract.CodeProviderBusy, "busy at /tmp/private/model.bin token=secret", true),
|
|
}
|
|
registry, err := engineprovider.NewRegistry("local", provider)
|
|
require.NoError(t, err)
|
|
s.handler.modelRegistry = registry
|
|
|
|
resp, body := s.request(t, http.MethodGet, "/api/v1/admin/asr-providers", nil, nonAdminToken, "")
|
|
require.Equal(t, http.StatusForbidden, resp.Code)
|
|
require.Equal(t, "FORBIDDEN", body["error"].(map[string]any)["code"])
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/admin/asr-providers/missing", nil, adminToken, "")
|
|
require.Equal(t, http.StatusNotFound, resp.Code)
|
|
|
|
resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/load", map[string]any{
|
|
"model": "whisper-base",
|
|
}, adminToken, "")
|
|
require.Equal(t, http.StatusConflict, resp.Code)
|
|
errBody := body["error"].(map[string]any)
|
|
require.Equal(t, string(asrcontract.CodeProviderBusy), errBody["code"])
|
|
require.NotContains(t, errBody["message"], "/tmp/private")
|
|
require.NotContains(t, errBody["message"], "secret")
|
|
}
|