Files
Scriberr/internal/api/engine_worker_api_test.go
2026-05-04 00:48:24 -07:00

470 lines
18 KiB
Go

package api
import (
"context"
"errors"
"net/http"
"strings"
"sync"
"testing"
"time"
"scriberr/internal/auth"
"scriberr/internal/database"
"scriberr/internal/models"
"scriberr/internal/transcription/asrcontract"
"scriberr/internal/transcription/engineprovider"
"scriberr/internal/transcription/worker"
"github.com/stretchr/testify/require"
)
// fakeQueueService is an in-memory stand-in for the queue service used by the
// API handlers under test. It records the job IDs passed to Enqueue and Cancel
// and returns the canned stats and errors configured on its fields.
//
// Every method takes mu, so recorded calls and configured values are read and
// written under the same lock.
type fakeQueueService struct {
	mu sync.Mutex

	// enqueued and canceled record job IDs passed to Enqueue and Cancel, in call order.
	enqueued []string
	canceled []string
	// stats is returned by Stats; AdminStats wraps it when adminStats.ByUser is nil.
	stats worker.QueueStats
	// adminStats is returned by AdminStats when its ByUser field is non-nil.
	adminStats worker.AdminQueueStats
	// err is returned by Enqueue, Stats, and AdminStats; cancelErr by Cancel.
	err       error
	cancelErr error
}

// Enqueue records jobID and returns the configured err.
func (q *fakeQueueService) Enqueue(ctx context.Context, jobID string) error {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.enqueued = append(q.enqueued, jobID)
	return q.err
}

// Cancel records jobID (userID is ignored) and returns the configured cancelErr.
func (q *fakeQueueService) Cancel(ctx context.Context, userID uint, jobID string) error {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.canceled = append(q.canceled, jobID)
	return q.cancelErr
}

// Start does nothing and always succeeds.
func (q *fakeQueueService) Start(context.Context) error { return nil }

// Stop does nothing and always succeeds.
func (q *fakeQueueService) Stop(context.Context) error { return nil }

// Stats returns the configured stats and err; the user ID is ignored.
func (q *fakeQueueService) Stats(context.Context, uint) (worker.QueueStats, error) {
	q.mu.Lock()
	defer q.mu.Unlock()
	return q.stats, q.err
}

// AdminStats returns adminStats when it carries per-user data (ByUser non-nil);
// otherwise it returns an AdminQueueStats wrapping only stats. Either way the
// configured err is returned alongside.
func (q *fakeQueueService) AdminStats(context.Context) (worker.AdminQueueStats, error) {
	q.mu.Lock()
	defer q.mu.Unlock()
	if q.adminStats.ByUser == nil {
		return worker.AdminQueueStats{QueueStats: q.stats}, q.err
	}
	return q.adminStats, q.err
}
// setTestQueueService wires queue into the test server: it always replaces the
// handler's queueService, and additionally calls SetQueue on the handler's
// transcription service when one is present.
func setTestQueueService(s *authTestServer, queue *fakeQueueService) {
	s.handler.queueService = queue
	if s.handler.transcriptions == nil {
		// No transcription service configured; nothing else to rewire.
		return
	}
	s.handler.transcriptions.SetQueue(queue)
}
// fakeCapabilityProvider is a stateless engine provider whose only configurable
// output is the capability list in caps. Every other operation succeeds with
// empty (nil) results, except IdentifySpeakers, which returns a provider error
// with code CodeUnsupportedOperation.
type fakeCapabilityProvider struct {
	caps []engineprovider.ModelCapability
}

// ID always reports "local".
func (fakeCapabilityProvider) ID() string {
	return "local"
}

// Inspect reports only the v1 contract version.
func (fakeCapabilityProvider) Inspect(context.Context) (*asrcontract.ProviderInfo, error) {
	info := asrcontract.ProviderInfo{ContractVersion: asrcontract.ContractVersionV1}
	return &info, nil
}

// Models returns no model cards.
func (fakeCapabilityProvider) Models(context.Context) ([]asrcontract.ModelCard, error) {
	return nil, nil
}

// Status reports the provider as idle.
func (fakeCapabilityProvider) Status(context.Context) (*asrcontract.ProviderStatus, error) {
	st := asrcontract.ProviderStatus{State: asrcontract.ProviderStateIdle}
	return &st, nil
}

// LoadModel accepts any request and does nothing.
func (fakeCapabilityProvider) LoadModel(context.Context, asrcontract.LoadModelRequest) error {
	return nil
}

// UnloadModel accepts any request and does nothing.
func (fakeCapabilityProvider) UnloadModel(context.Context, asrcontract.UnloadModelRequest) error {
	return nil
}

// LoadedModels reports no loaded models.
func (fakeCapabilityProvider) LoadedModels(context.Context) ([]asrcontract.LoadedModel, error) {
	return nil, nil
}

// Capabilities returns the configured caps slice as-is.
func (c fakeCapabilityProvider) Capabilities(context.Context) ([]engineprovider.ModelCapability, error) {
	return c.caps, nil
}

// Prepare does nothing and always succeeds.
func (fakeCapabilityProvider) Prepare(context.Context) error {
	return nil
}

// Transcribe returns a nil result and nil error.
func (fakeCapabilityProvider) Transcribe(context.Context, engineprovider.TranscriptionRequest) (*engineprovider.TranscriptionResult, error) {
	return nil, nil
}

// Diarize returns a nil result and nil error.
func (fakeCapabilityProvider) Diarize(context.Context, engineprovider.DiarizationRequest) (*engineprovider.DiarizationResult, error) {
	return nil, nil
}

// IdentifySpeakers always fails with a non-retryable CodeUnsupportedOperation error.
func (fakeCapabilityProvider) IdentifySpeakers(context.Context, asrcontract.SpeakerIDRequest) (*asrcontract.SpeakerIDResult, error) {
	unsupported := asrcontract.NewProviderError(asrcontract.CodeUnsupportedOperation, "speaker identification is not supported", false)
	return nil, unsupported
}

// Close does nothing and always succeeds.
func (fakeCapabilityProvider) Close() error {
	return nil
}
// fakeAdminASRProvider is a configurable engine provider used by the admin ASR
// endpoint tests. Diagnostic methods return the canned values on the struct (or
// the matching *Err field when it is set). LoadModel and UnloadModel record
// their requests. The inference methods return nil results.
//
// Every method takes mu, so fields are read and written under the same lock.
type fakeAdminASRProvider struct {
	mu sync.Mutex

	// id is reported by ID and embedded in Inspect and Capabilities output.
	id string
	// status is returned by Status; an empty State is reported as ProviderStateIdle.
	status asrcontract.ProviderStatus
	// models and loaded are returned by Models and LoadedModels respectively.
	models []asrcontract.ModelCard
	loaded []asrcontract.LoadedModel
	// loads and unloads record every request passed to LoadModel/UnloadModel, in order.
	loads   []asrcontract.LoadModelRequest
	unloads []asrcontract.UnloadModelRequest
	// Each *Err field, when non-nil, is returned by the matching method:
	// LoadModel, UnloadModel, Status, Inspect, Models, LoadedModels.
	loadErr    error
	unloadErr  error
	statusErr  error
	inspectErr error
	modelsErr  error
	loadedErr  error
}

// ID returns the configured provider ID.
func (p *fakeAdminASRProvider) ID() string {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.id
}

// Inspect returns inspectErr when set. Otherwise it returns a v1 ProviderInfo
// named "Fake ASR" with a single "cpu" backend and MaxConcurrentJobs of 1.
func (p *fakeAdminASRProvider) Inspect(context.Context) (*asrcontract.ProviderInfo, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.inspectErr != nil {
		return nil, p.inspectErr
	}
	return &asrcontract.ProviderInfo{
		ContractVersion: asrcontract.ContractVersionV1,
		Provider:        asrcontract.ProviderIdentity{ID: p.id, Name: "Fake ASR"},
		Runtime:         asrcontract.RuntimeInfo{DeviceBackends: []string{"cpu"}, MaxConcurrentJobs: 1},
	}, nil
}

// Models returns modelsErr when set, otherwise the configured models slice.
func (p *fakeAdminASRProvider) Models(context.Context) ([]asrcontract.ModelCard, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.modelsErr != nil {
		return nil, p.modelsErr
	}
	return p.models, nil
}

// Status returns statusErr when set. Otherwise it returns a copy of status,
// defaulting an empty State to ProviderStateIdle.
func (p *fakeAdminASRProvider) Status(context.Context) (*asrcontract.ProviderStatus, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.statusErr != nil {
		return nil, p.statusErr
	}
	status := p.status
	if status.State == "" {
		status.State = asrcontract.ProviderStateIdle
	}
	return &status, nil
}

// LoadModel records req and returns loadErr.
func (p *fakeAdminASRProvider) LoadModel(_ context.Context, req asrcontract.LoadModelRequest) error {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.loads = append(p.loads, req)
	return p.loadErr
}

// UnloadModel records req and returns unloadErr.
func (p *fakeAdminASRProvider) UnloadModel(_ context.Context, req asrcontract.UnloadModelRequest) error {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.unloads = append(p.unloads, req)
	return p.unloadErr
}

// LoadedModels returns loadedErr when set, otherwise the configured loaded slice.
func (p *fakeAdminASRProvider) LoadedModels(context.Context) ([]asrcontract.LoadedModel, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.loadedErr != nil {
		return nil, p.loadedErr
	}
	return p.loaded, nil
}

// Capabilities always reports a single installed "whisper-base" transcription
// model attributed to this provider's ID.
func (p *fakeAdminASRProvider) Capabilities(context.Context) ([]engineprovider.ModelCapability, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	return []engineprovider.ModelCapability{{ID: "whisper-base", Provider: p.id, Installed: true, Capabilities: []string{"transcription"}}}, nil
}

// Prepare does nothing and always succeeds.
func (p *fakeAdminASRProvider) Prepare(context.Context) error { return nil }

// Transcribe returns a nil result and nil error.
func (p *fakeAdminASRProvider) Transcribe(context.Context, engineprovider.TranscriptionRequest) (*engineprovider.TranscriptionResult, error) {
	return nil, nil
}

// Diarize returns a nil result and nil error.
func (p *fakeAdminASRProvider) Diarize(context.Context, engineprovider.DiarizationRequest) (*engineprovider.DiarizationResult, error) {
	return nil, nil
}

// IdentifySpeakers returns a nil result and nil error.
// NOTE(review): fakeCapabilityProvider reports CodeUnsupportedOperation here
// instead. The nil/nil behavior is kept in case other tests depend on it.
func (p *fakeAdminASRProvider) IdentifySpeakers(context.Context, asrcontract.SpeakerIDRequest) (*asrcontract.SpeakerIDResult, error) {
	return nil, nil
}

// Close does nothing and always succeeds.
func (p *fakeAdminASRProvider) Close() error { return nil }
// TestCreateSubmitRetryUseQueueService verifies that creating a transcription
// and then retrying it both hand their job IDs to the configured queue service,
// in request order. The API's IDs are compared after stripping the "tr_" prefix.
func TestCreateSubmitRetryUseQueueService(t *testing.T) {
	s := newAuthTestServer(t)
	queue := &fakeQueueService{}
	setTestQueueService(s, queue)
	token := registerForFileTests(t, s)
	fileID, _ := createUploadedFileForTranscription(t, s, token)

	resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{
		"file_id": fileID,
		"title":   "Queued by service",
	}, token, "")
	require.Equal(t, http.StatusAccepted, resp.Code)
	// Checked assertion: a missing or mistyped id fails this test instead of
	// panicking and aborting the whole test binary.
	createdID, ok := body["id"].(string)
	require.True(t, ok, "create response has no string id: %v", body)

	resp, body = s.request(t, http.MethodPost, "/api/v1/transcriptions/"+createdID+":retry", nil, token, "")
	require.Equal(t, http.StatusAccepted, resp.Code)
	retriedID, ok := body["id"].(string)
	require.True(t, ok, "retry response has no string id: %v", body)

	// Exactly two enqueues: the original job, then the retry.
	require.Equal(t, []string{
		strings.TrimPrefix(createdID, "tr_"),
		strings.TrimPrefix(retriedID, "tr_"),
	}, queue.enqueued)
}
// TestCreateReturnsServiceUnavailableWhenQueueStopped verifies that a create
// request is answered with 503 SERVICE_UNAVAILABLE when enqueueing fails with
// worker.ErrQueueStopped. It also checks that exactly one TranscriptionJob row
// with a non-NULL source_file_hash exists afterwards. That row is presumably the
// job the request created, kept despite the enqueue failure; confirm against
// the handler.
func TestCreateReturnsServiceUnavailableWhenQueueStopped(t *testing.T) {
	s := newAuthTestServer(t)
	setTestQueueService(s, &fakeQueueService{err: worker.ErrQueueStopped})
	token := registerForFileTests(t, s)
	fileID, _ := createUploadedFileForTranscription(t, s, token)

	resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{"file_id": fileID}, token, "")
	require.Equal(t, http.StatusServiceUnavailable, resp.Code)
	// Checked assertion: a missing error object fails this test instead of
	// panicking and aborting the whole test binary.
	errBody, ok := body["error"].(map[string]any)
	require.True(t, ok, "response has no error object: %v", body)
	require.Equal(t, "SERVICE_UNAVAILABLE", errBody["code"])

	var count int64
	require.NoError(t, database.DB.Model(&models.TranscriptionJob{}).
		Where("source_file_hash IS NOT NULL").
		Count(&count).Error)
	require.Equal(t, int64(1), count)
}
// TestRetryPreservesNewJobWhenQueueStopped creates a transcription while the
// queue accepts work, then stops the queue and retries it. The retry must be
// answered with 503 SERVICE_UNAVAILABLE, and two TranscriptionJob rows with a
// non-NULL source_file_hash must exist: the original job and the retry's job.
func TestRetryPreservesNewJobWhenQueueStopped(t *testing.T) {
	s := newAuthTestServer(t)
	queue := &fakeQueueService{}
	setTestQueueService(s, queue)
	token := registerForFileTests(t, s)
	fileID, _ := createUploadedFileForTranscription(t, s, token)

	resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{"file_id": fileID}, token, "")
	require.Equal(t, http.StatusAccepted, resp.Code)
	transcriptionID, ok := body["id"].(string)
	require.True(t, ok, "create response has no string id: %v", body)

	// Stop the queue only now, so that only the retry's enqueue fails. Take the
	// fake's lock because Enqueue reads err while holding it.
	queue.mu.Lock()
	queue.err = worker.ErrQueueStopped
	queue.mu.Unlock()

	resp, body = s.request(t, http.MethodPost, "/api/v1/transcriptions/"+transcriptionID+":retry", nil, token, "")
	require.Equal(t, http.StatusServiceUnavailable, resp.Code)
	errBody, ok := body["error"].(map[string]any)
	require.True(t, ok, "retry response has no error object: %v", body)
	require.Equal(t, "SERVICE_UNAVAILABLE", errBody["code"])

	var count int64
	require.NoError(t, database.DB.Model(&models.TranscriptionJob{}).
		Where("source_file_hash IS NOT NULL").
		Count(&count).Error)
	require.Equal(t, int64(2), count)
}
// TestCancelUsesQueueServiceAndMapsConflict verifies that cancelling a
// transcription calls the queue service's Cancel with the bare job ID, and that
// a worker.ErrStateConflict from the queue is reported as 409 CONFLICT.
func TestCancelUsesQueueServiceAndMapsConflict(t *testing.T) {
	s := newAuthTestServer(t)
	queue := &fakeQueueService{cancelErr: worker.ErrStateConflict}
	setTestQueueService(s, queue)
	token := registerForFileTests(t, s)
	fileID, _ := createUploadedFileForTranscription(t, s, token)

	resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{"file_id": fileID}, token, "")
	require.Equal(t, http.StatusAccepted, resp.Code)
	transcriptionID, ok := body["id"].(string)
	require.True(t, ok, "create response has no string id: %v", body)

	resp, body = s.request(t, http.MethodPost, "/api/v1/transcriptions/"+transcriptionID+":cancel", nil, token, "")
	require.Equal(t, http.StatusConflict, resp.Code)
	errBody, ok := body["error"].(map[string]any)
	require.True(t, ok, "cancel response has no error object: %v", body)
	require.Equal(t, "CONFLICT", errBody["code"])

	// Guard before indexing: if Cancel was never invoked, fail here rather than
	// panic with index out of range.
	require.NotEmpty(t, queue.canceled)
	require.Equal(t, strings.TrimPrefix(transcriptionID, "tr_"), queue.canceled[0])
}
// TestTranscriptExecutionsLogsModelsAndStatsUseEngineServices plants engine-side
// state for one transcription directly in the database and in fakes. It then
// checks that these read endpoints reflect that state:
//   - GET /transcriptions/{id}: progress/timing fields, and a redacted error once failed
//   - GET /transcriptions/{id}/transcript: the stored transcript JSON
//   - GET /transcriptions/{id}/executions and /logs: the execution row, secrets redacted
//   - GET /models/transcription: the fake registry's capability
//   - GET /admin/queue: the fake queue's stats
func TestTranscriptExecutionsLogsModelsAndStatsUseEngineServices(t *testing.T) {
	s := newAuthTestServer(t)
	queue := &fakeQueueService{stats: worker.QueueStats{Queued: 2, Processing: 1, Completed: 3, Failed: 4, Canceled: 5, Running: 1}}
	setTestQueueService(s, queue)
	registry, err := engineprovider.NewRegistry("local", fakeCapabilityProvider{caps: []engineprovider.ModelCapability{
		{ID: "whisper-base", Name: "Whisper Base", Provider: "local", Installed: true, Default: true, Capabilities: []string{"transcription", "word_timestamps"}},
	}})
	require.NoError(t, err)
	s.handler.modelRegistry = registry

	token := registerForFileTests(t, s)
	userID := firstUserID(t)
	fileID, _ := createUploadedFileForTranscription(t, s, token)
	resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{"file_id": fileID}, token, "")
	require.Equal(t, http.StatusAccepted, resp.Code)
	// Checked type assertions throughout: a malformed response fails this test
	// instead of panicking and aborting the whole test binary.
	transcriptionID, ok := body["id"].(string)
	require.True(t, ok, "create response has no string id: %v", body)
	jobID := strings.TrimPrefix(transcriptionID, "tr_")

	// Mark the job completed with a one-segment transcript.
	now := time.Now().UTC().Truncate(time.Millisecond)
	transcript := `{"text":"hello","segments":[{"id":"seg_000001","start":0,"end":1,"speaker":"SPEAKER_00","text":"hello"}]}`
	require.NoError(t, database.DB.Model(&models.TranscriptionJob{}).Where("id = ?", jobID).Updates(map[string]any{
		"status":          models.StatusCompleted,
		"transcript_text": transcript,
		"progress":        1.0,
		"progress_stage":  "completed",
		"started_at":      now.Add(-time.Minute),
		"completed_at":    now,
	}).Error)

	// The error message embeds a filesystem path and a credential-like value.
	// The assertions below require the API to withhold both.
	errorMessage := "failed at /tmp/private/model.bin api_key=secret-value"
	require.NoError(t, database.DB.Create(&models.TranscriptionJobExecution{
		TranscriptionJobID: jobID,
		UserID:             userID,
		Status:             models.StatusFailed,
		Provider:           "local",
		ModelName:          "whisper-base",
		StartedAt:          now.Add(-time.Minute),
		FailedAt:           &now,
		ErrorMessage:       &errorMessage,
	}).Error)

	// Completed job: progress and timing fields are exposed.
	resp, body = s.request(t, http.MethodGet, "/api/v1/transcriptions/"+transcriptionID, nil, token, "")
	require.Equal(t, http.StatusOK, resp.Code)
	require.Equal(t, float64(1), body["progress"])
	require.Equal(t, "completed", body["progress_stage"])
	require.NotNil(t, body["started_at"])
	require.NotNil(t, body["completed_at"])

	// Failed job: last_error is surfaced with the path and secret redacted.
	require.NoError(t, database.DB.Model(&models.TranscriptionJob{}).Where("id = ?", jobID).Updates(map[string]any{
		"status":     models.StatusFailed,
		"last_error": errorMessage,
	}).Error)
	resp, body = s.request(t, http.MethodGet, "/api/v1/transcriptions/"+transcriptionID, nil, token, "")
	require.Equal(t, http.StatusOK, resp.Code)
	require.NotContains(t, body["error"], "/tmp/private")
	require.NotContains(t, body["error"], "secret-value")
	require.Contains(t, body["error"], "[redacted-path]")

	// Transcript endpoint returns the stored text and segments, with no words.
	resp, body = s.request(t, http.MethodGet, "/api/v1/transcriptions/"+transcriptionID+"/transcript", nil, token, "")
	require.Equal(t, http.StatusOK, resp.Code)
	require.Equal(t, "hello", body["text"])
	require.Empty(t, body["words"])
	segments, ok := body["segments"].([]any)
	require.True(t, ok, "transcript response has no segments array: %v", body)
	require.Len(t, segments, 1)

	// Executions endpoint exposes provider/model but not the raw error details.
	resp, body = s.request(t, http.MethodGet, "/api/v1/transcriptions/"+transcriptionID+"/executions", nil, token, "")
	require.Equal(t, http.StatusOK, resp.Code)
	executionItems, ok := body["items"].([]any)
	require.True(t, ok, "executions response has no items array: %v", body)
	require.NotEmpty(t, executionItems)
	execution, ok := executionItems[0].(map[string]any)
	require.True(t, ok, "execution item is not an object: %v", executionItems[0])
	require.Equal(t, "local", execution["provider"])
	require.Equal(t, "whisper-base", execution["model"])
	require.NotContains(t, execution["error"], "/tmp/private")
	require.NotContains(t, execution["error"], "secret-value")

	// Raw logs are redacted too and include a failed_at line.
	resp, rawLogs := s.rawRequest(t, http.MethodGet, "/api/v1/transcriptions/"+transcriptionID+"/logs", nil, token, "")
	require.Equal(t, http.StatusOK, resp.Code)
	require.NotContains(t, rawLogs, "/tmp/private")
	require.NotContains(t, rawLogs, "secret-value")
	require.Contains(t, rawLogs, "[redacted-path]")
	require.Contains(t, rawLogs, "\nfailed_at=")

	// Model listing comes from the fake capability provider.
	resp, body = s.request(t, http.MethodGet, "/api/v1/models/transcription", nil, token, "")
	require.Equal(t, http.StatusOK, resp.Code)
	modelItems, ok := body["items"].([]any)
	require.True(t, ok, "models response has no items array: %v", body)
	require.NotEmpty(t, modelItems)
	model, ok := modelItems[0].(map[string]any)
	require.True(t, ok, "model item is not an object: %v", modelItems[0])
	require.Equal(t, "whisper-base", model["id"])
	require.Equal(t, true, model["installed"])
	require.Equal(t, true, model["default"])

	// Admin queue stats come from the fake queue service.
	resp, body = s.request(t, http.MethodGet, "/api/v1/admin/queue", nil, token, "")
	require.Equal(t, http.StatusOK, resp.Code)
	require.Equal(t, float64(2), body["queued"])
	require.Equal(t, float64(1), body["running"])
}
// TestQueueServiceErrorDoesNotLeakInternals verifies that an unexpected queue
// error is reported as 500, and that the error message sent to the client
// contains neither the path nor the secret embedded in the underlying error.
func TestQueueServiceErrorDoesNotLeakInternals(t *testing.T) {
	s := newAuthTestServer(t)
	setTestQueueService(s, &fakeQueueService{err: errors.New("open /tmp/private/socket token=secret failed")})
	token := registerForFileTests(t, s)
	fileID, _ := createUploadedFileForTranscription(t, s, token)

	resp, body := s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{"file_id": fileID}, token, "")
	require.Equal(t, http.StatusInternalServerError, resp.Code)
	// Checked assertions: a malformed error body fails this test instead of
	// panicking and aborting the whole test binary.
	errBody, ok := body["error"].(map[string]any)
	require.True(t, ok, "response has no error object: %v", body)
	message, ok := errBody["message"].(string)
	require.True(t, ok, "error object has no string message: %v", errBody)
	require.NotContains(t, message, "/tmp/private")
	require.NotContains(t, message, "secret")
}
// TestAdminASRProviderDiagnosticsAndModelCommands exercises the admin ASR
// provider endpoints against fakeAdminASRProvider:
//   - the provider list reports the active job without its path/secret,
//   - provider detail omits source_url from model cards and lists loaded models,
//   - load/unload commands return 202 and forward their options to the provider.
func TestAdminASRProviderDiagnosticsAndModelCommands(t *testing.T) {
	s := newAuthTestServer(t)
	adminToken := registerForFileTests(t, s)
	loadedAt := time.Now().UTC().Truncate(time.Millisecond)
	memory := 512
	progress := 0.42
	// The active job ID, its model name, and the model card's source URL each
	// embed a path or credential. The assertions below require the API to
	// withhold these.
	provider := &fakeAdminASRProvider{
		id: "local",
		status: asrcontract.ProviderStatus{
			State: asrcontract.ProviderStateBusy,
			ActiveJob: &asrcontract.ActiveJob{
				ID:        "job-/tmp/private/audio.wav",
				Operation: asrcontract.OperationTranscription,
				Model:     "whisper-base api_key=secret",
				Stage:     asrcontract.StageTranscribing,
				Progress:  &progress,
			},
			Capacity: asrcontract.ProviderCapacity{MaxConcurrentJobs: 1},
		},
		models: []asrcontract.ModelCard{{
			ID:          "whisper-base",
			DisplayName: "Whisper Base",
			Provider:    "local",
			Family:      "whisper",
			Installed:   true,
			Loaded:      true,
			Default:     true,
			SourceURL:   "file:///tmp/private/model.bin?api_key=secret",
			Capabilities: asrcontract.Capabilities{
				Transcription:  true,
				WordTimestamps: true,
			},
		}},
		loaded: []asrcontract.LoadedModel{{ID: "whisper-base", LoadedAt: &loadedAt, MemoryMB: &memory}},
	}
	registry, err := engineprovider.NewRegistry("local", provider)
	require.NoError(t, err)
	s.handler.modelRegistry = registry

	// Provider list: the active job is present but sanitized. Checked assertions
	// throughout, so a malformed response fails instead of panicking.
	resp, body := s.request(t, http.MethodGet, "/api/v1/admin/asr-providers", nil, adminToken, "")
	require.Equal(t, http.StatusOK, resp.Code)
	items, ok := body["items"].([]any)
	require.True(t, ok, "provider list has no items array: %v", body)
	require.NotEmpty(t, items)
	item, ok := items[0].(map[string]any)
	require.True(t, ok, "provider item is not an object: %v", items[0])
	require.Equal(t, "local", item["id"])
	status, ok := item["status"].(map[string]any)
	require.True(t, ok, "provider item has no status object: %v", item)
	activeJob, ok := status["active_job"].(map[string]any)
	require.True(t, ok, "status has no active_job object: %v", status)
	require.NotContains(t, activeJob["id"], "/tmp/private")
	require.NotContains(t, activeJob["model"], "secret")

	// Provider detail: model cards omit source_url; one loaded model is listed.
	resp, body = s.request(t, http.MethodGet, "/api/v1/admin/asr-providers/local", nil, adminToken, "")
	require.Equal(t, http.StatusOK, resp.Code)
	modelCards, ok := body["models"].([]any)
	require.True(t, ok, "provider detail has no models array: %v", body)
	require.NotEmpty(t, modelCards)
	model, ok := modelCards[0].(map[string]any)
	require.True(t, ok, "model card is not an object: %v", modelCards[0])
	require.Equal(t, "whisper-base", model["id"])
	require.NotContains(t, model, "source_url")
	loadedModels, ok := body["loaded_models"].([]any)
	require.True(t, ok, "provider detail has no loaded_models array: %v", body)
	require.Len(t, loadedModels, 1)

	// Load command: accepted, with operation and load policy forwarded.
	resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/load", map[string]any{
		"model":       "whisper-base",
		"operation":   "transcription",
		"load_policy": "require",
	}, adminToken, "")
	require.Equal(t, http.StatusAccepted, resp.Code)
	require.Equal(t, "loading", body["status"])
	require.Len(t, provider.loads, 1)
	require.Equal(t, asrcontract.OperationTranscription, provider.loads[0].Operation)
	require.Equal(t, asrcontract.LoadPolicyRequire, provider.loads[0].LoadPolicy)

	// Unload command: accepted, with the force flag forwarded.
	resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/unload", map[string]any{
		"model": "whisper-base",
		"force": true,
	}, adminToken, "")
	require.Equal(t, http.StatusAccepted, resp.Code)
	require.Equal(t, "unloading", body["status"])
	require.Len(t, provider.unloads, 1)
	require.True(t, provider.unloads[0].Force)
}
// TestAdminASRProviderRoutesAreAdminOnlyAndMapProviderErrors verifies that:
//   - a non-admin user gets 403 FORBIDDEN on the admin ASR provider list,
//   - an unknown provider ID yields 404,
//   - a CodeProviderBusy error from LoadModel is reported as 409 with that code,
//     and its message contains neither the path nor the secret.
func TestAdminASRProviderRoutesAreAdminOnlyAndMapProviderErrors(t *testing.T) {
	s := newAuthTestServer(t)
	adminToken := registerForFileTests(t, s)
	nonAdmin := models.User{Username: "asr-member", Password: "pw", Role: "user"}
	require.NoError(t, database.DB.Create(&nonAdmin).Error)
	// NOTE(review): assumes newAuthTestServer signs tokens with "test-secret";
	// confirm if this test starts returning 401.
	nonAdminToken, err := auth.NewAuthService("test-secret").GenerateToken(&nonAdmin)
	require.NoError(t, err)
	provider := &fakeAdminASRProvider{
		id:      "local",
		loadErr: asrcontract.NewProviderError(asrcontract.CodeProviderBusy, "busy at /tmp/private/model.bin token=secret", true),
	}
	registry, err := engineprovider.NewRegistry("local", provider)
	require.NoError(t, err)
	s.handler.modelRegistry = registry

	resp, body := s.request(t, http.MethodGet, "/api/v1/admin/asr-providers", nil, nonAdminToken, "")
	require.Equal(t, http.StatusForbidden, resp.Code)
	// Checked assertions: a malformed error body fails this test instead of
	// panicking and aborting the whole test binary.
	forbidden, ok := body["error"].(map[string]any)
	require.True(t, ok, "forbidden response has no error object: %v", body)
	require.Equal(t, "FORBIDDEN", forbidden["code"])

	// Only the status code is asserted here, so the body is discarded.
	resp, _ = s.request(t, http.MethodGet, "/api/v1/admin/asr-providers/missing", nil, adminToken, "")
	require.Equal(t, http.StatusNotFound, resp.Code)

	resp, body = s.request(t, http.MethodPost, "/api/v1/admin/asr-providers/local/models/load", map[string]any{
		"model": "whisper-base",
	}, adminToken, "")
	require.Equal(t, http.StatusConflict, resp.Code)
	busy, ok := body["error"].(map[string]any)
	require.True(t, ok, "load response has no error object: %v", body)
	require.Equal(t, string(asrcontract.CodeProviderBusy), busy["code"])
	require.NotContains(t, busy["message"], "/tmp/private")
	require.NotContains(t, busy["message"], "secret")
}