Files
Scriberr/internal/api/auth_test.go
2026-05-08 10:22:37 -07:00

463 lines
17 KiB
Go

package api
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"net/http/httptest"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"scriberr/internal/account"
admindomain "scriberr/internal/admin"
"scriberr/internal/annotations"
"scriberr/internal/auth"
"scriberr/internal/automation"
chatdomain "scriberr/internal/chat"
"scriberr/internal/config"
"scriberr/internal/database"
filesdomain "scriberr/internal/files"
"scriberr/internal/llmprovider"
"scriberr/internal/mediaimport"
"scriberr/internal/models"
profiledomain "scriberr/internal/profile"
recordingdomain "scriberr/internal/recording"
"scriberr/internal/repository"
"scriberr/internal/summarization"
"scriberr/internal/tags"
transcriptiondomain "scriberr/internal/transcription"
"scriberr/internal/transcription/asrcontract"
"scriberr/internal/transcription/engineprovider"
"scriberr/internal/transcription/orchestrator"
"scriberr/pkg/logger"
"github.com/stretchr/testify/require"
gormlogger "gorm.io/gorm/logger"
)
type authTestServer struct {
router http.Handler
auth *auth.AuthService
uploadDir string
handler *Handler
}
func newAuthTestServer(t *testing.T) *authTestServer {
t.Helper()
logger.Init("silent")
require.NoError(t, database.Initialize(filepath.Join(t.TempDir(), "scriberr.db")))
database.DB.Logger = gormlogger.Default.LogMode(gormlogger.Silent)
t.Cleanup(func() { _ = database.Close() })
authService := auth.NewAuthService("test-secret")
uploadDir := filepath.Join(t.TempDir(), "uploads")
cfg := &config.Config{
Environment: "test",
UploadDir: uploadDir,
Recordings: config.RecordingConfig{
Dir: filepath.Join(t.TempDir(), "recordings"),
MaxChunkBytes: 8,
MaxDuration: time.Hour,
SessionTTL: time.Hour,
AllowedMimeTypes: []string{"audio/webm;codecs=opus", "audio/webm"},
},
}
jobRepo := repository.NewJobRepository(database.DB)
profileRepo := repository.NewProfileRepository(database.DB)
rawLLMConfigRepo := repository.NewLLMConfigRepository(database.DB)
llmConfigRepo, err := llmprovider.NewProtectedRepository(rawLLMConfigRepo, "test-llm-credential-secret")
require.NoError(t, err)
accountService := account.NewService(
repository.NewUserRepository(database.DB),
repository.NewUserSettingsRepository(database.DB),
repository.NewRefreshTokenRepository(database.DB),
repository.NewAPIKeyRepository(database.DB),
profileRepo,
llmConfigRepo,
authService,
)
adminService := admindomain.NewService(
repository.NewUserRepository(database.DB),
repository.NewRefreshTokenRepository(database.DB),
repository.NewAPIKeyRepository(database.DB),
repository.NewSystemSettingsRepository(database.DB),
)
providerRegistry, err := engineprovider.NewRegistry("local", fakeCapabilityProvider{models: []asrcontract.ModelCard{
testASRModel("local", "whisper-base", "Whisper Base", "whisper", true, asrcontract.CapabilityTranscription, asrcontract.CapabilityWordTimestamps),
testASRModel("local", "whisper-small", "Whisper Small", "whisper", false, asrcontract.CapabilityTranscription, asrcontract.CapabilityWordTimestamps),
testASRModel("local", "parakeet-v2", "Parakeet v2", "nemo_transducer", false, asrcontract.CapabilityTranscription, asrcontract.CapabilityWordTimestamps),
testASRModel("local", "diarization-default", "Diarization", "diarization", true, asrcontract.CapabilityDiarization),
}})
require.NoError(t, err)
profileService := profiledomain.NewService(profileRepo, profiledomain.NewProviderModelCatalog(providerRegistry))
llmProviderService := llmprovider.NewService(llmConfigRepo, llmprovider.HTTPConnectionTester{})
fileService := filesdomain.NewService(jobRepo, filesdomain.Config{UploadDir: cfg.UploadDir})
mediaImportService := mediaimport.NewService(mediaimport.ServiceOptions{
Repository: jobRepo,
UploadDir: cfg.UploadDir,
})
transcriptionService := transcriptiondomain.NewService(jobRepo, profileRepo, nil)
transcriptionService.SetAudioStore(fileService)
transcriptionService.SetExecutionLogStore(orchestrator.NewLocalExecutionLogStore(cfg.TranscriptsDir))
summaryService := summarization.NewService(repository.NewSummaryRepository(database.DB), llmConfigRepo, jobRepo, summarization.Config{})
chatService := chatdomain.NewService(repository.NewChatRepository(database.DB), llmConfigRepo)
postFileAutomation := automation.NewService(jobRepo, repository.NewUserRepository(database.DB), profileRepo, llmConfigRepo, transcriptionService)
fileService.SetReadyObserver(postFileAutomation)
annotationService := annotations.NewService(repository.NewAnnotationRepository(database.DB), jobRepo)
tagService := tags.NewService(repository.NewTagRepository(database.DB), jobRepo)
recordingStorage, err := recordingdomain.NewStorage(cfg.Recordings.Dir)
require.NoError(t, err)
recordingService := recordingdomain.NewService(repository.NewRecordingRepository(database.DB), recordingStorage, recordingdomain.Config{
MaxChunkBytes: cfg.Recordings.MaxChunkBytes,
MaxDuration: cfg.Recordings.MaxDuration,
SessionTTL: cfg.Recordings.SessionTTL,
AllowedMimeTypes: cfg.Recordings.AllowedMimeTypes,
})
handler := NewHandler(cfg, authService, HandlerDependencies{
ReadinessCheck: func() error { return nil },
Account: accountService,
Admin: adminService,
Profiles: profileService,
LLMProvider: llmProviderService,
Files: fileService,
MediaImport: mediaImportService,
Annotations: annotationService,
Tags: tagService,
Recordings: recordingService,
Transcriptions: transcriptionService,
Summaries: summaryService,
Chat: chatService,
})
handler.modelRegistry = providerRegistry
postFileAutomation.SetEventPublisher(handler)
youtubeImporter := &fakeYouTubeImporter{block: make(chan struct{})}
mediaImportService.SetImporter(youtubeImporter)
t.Cleanup(func() {
youtubeImporter.unblock()
handler.asyncJobs.Wait()
})
return &authTestServer{router: SetupRoutes(handler, authService), auth: authService, uploadDir: uploadDir, handler: handler}
}
func (s *authTestServer) request(t *testing.T, method, path string, body any, token string, apiKey string) (*httptest.ResponseRecorder, map[string]any) {
t.Helper()
var payload bytes.Buffer
if body != nil {
require.NoError(t, json.NewEncoder(&payload).Encode(body))
}
req, err := http.NewRequest(method, path, &payload)
require.NoError(t, err)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
if apiKey != "" {
req.Header.Set("X-API-Key", apiKey)
}
recorder := httptest.NewRecorder()
s.router.ServeHTTP(recorder, req)
var response map[string]any
if recorder.Code != http.StatusNoContent {
require.NoError(t, json.NewDecoder(recorder.Body).Decode(&response))
}
return recorder, response
}
func (s *authTestServer) rawRequest(t *testing.T, method, path string, body any, token string, apiKey string) (*httptest.ResponseRecorder, string) {
t.Helper()
var payload bytes.Buffer
if body != nil {
require.NoError(t, json.NewEncoder(&payload).Encode(body))
}
req, err := http.NewRequest(method, path, &payload)
require.NoError(t, err)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
if apiKey != "" {
req.Header.Set("X-API-Key", apiKey)
}
recorder := httptest.NewRecorder()
s.router.ServeHTTP(recorder, req)
return recorder, recorder.Body.String()
}
func tokenForTestUser(t *testing.T, username, role string) string {
t.Helper()
user := models.User{Username: username, Password: "pw", Role: role}
require.NoError(t, database.DB.Create(&user).Error)
token, err := auth.NewAuthService("test-secret").GenerateToken(&user)
require.NoError(t, err)
return token
}
func currentTestUserID(t *testing.T, username string) uint {
t.Helper()
var user models.User
require.NoError(t, database.DB.Where("username = ?", username).First(&user).Error)
return user.ID
}
func TestAuthRegisterLoginRefreshMeLogout(t *testing.T) {
s := newAuthTestServer(t)
resp, body := s.request(t, http.MethodGet, "/api/v1/auth/registration-status", nil, "", "")
require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, true, body["registration_enabled"])
resp, body = s.request(t, http.MethodPost, "/api/v1/auth/register", map[string]any{
"username": "admin",
"password": "password123",
"confirm_password": "password123",
}, "", "")
require.Equal(t, http.StatusOK, resp.Code)
require.NotEmpty(t, body["access_token"])
require.NotEmpty(t, body["refresh_token"])
user := body["user"].(map[string]any)
require.Equal(t, "user_self", user["id"])
require.Equal(t, "admin", user["username"])
var stored models.User
require.NoError(t, database.DB.Where("username = ?", "admin").First(&stored).Error)
require.Equal(t, models.UserStatusActive, stored.Status)
resp, body = s.request(t, http.MethodGet, "/api/v1/auth/registration-status", nil, "", "")
require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, false, body["registration_enabled"])
resp, body = s.request(t, http.MethodPost, "/api/v1/auth/login", map[string]any{
"username": "admin",
"password": "password123",
}, "", "")
require.Equal(t, http.StatusOK, resp.Code)
accessToken := body["access_token"].(string)
refreshToken := body["refresh_token"].(string)
resp, body = s.request(t, http.MethodGet, "/api/v1/auth/me", nil, accessToken, "")
require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, "user_self", body["id"])
require.Equal(t, "admin", body["username"])
resp, body = s.request(t, http.MethodPost, "/api/v1/auth/refresh", map[string]any{
"refresh_token": refreshToken,
}, "", "")
require.Equal(t, http.StatusOK, resp.Code)
require.NotEmpty(t, body["access_token"])
rotatedRefresh := body["refresh_token"].(string)
require.NotEqual(t, refreshToken, rotatedRefresh)
resp, _ = s.request(t, http.MethodPost, "/api/v1/auth/refresh", map[string]any{
"refresh_token": refreshToken,
}, "", "")
require.Equal(t, http.StatusUnauthorized, resp.Code)
resp, body = s.request(t, http.MethodPost, "/api/v1/auth/logout", map[string]any{
"refresh_token": rotatedRefresh,
}, "", "")
require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, true, body["ok"])
resp, _ = s.request(t, http.MethodPost, "/api/v1/auth/refresh", map[string]any{
"refresh_token": rotatedRefresh,
}, "", "")
require.Equal(t, http.StatusUnauthorized, resp.Code)
}
func TestAuthValidationAndPasswordChanges(t *testing.T) {
s := newAuthTestServer(t)
resp, _ := s.request(t, http.MethodPost, "/api/v1/auth/register", map[string]any{
"username": "ad",
"password": "password123",
"confirm_password": "different",
}, "", "")
require.Equal(t, http.StatusUnprocessableEntity, resp.Code)
resp, body := s.request(t, http.MethodPost, "/api/v1/auth/register", map[string]any{
"username": "admin",
"password": "password123",
"confirm_password": "password123",
}, "", "")
require.Equal(t, http.StatusOK, resp.Code)
accessToken := body["access_token"].(string)
refreshToken := body["refresh_token"].(string)
resp, _ = s.request(t, http.MethodPost, "/api/v1/auth/change-password", map[string]any{
"current_password": "wrong",
"new_password": "newpassword123",
"confirm_password": "newpassword123",
}, accessToken, "")
require.Equal(t, http.StatusUnauthorized, resp.Code)
resp, body = s.request(t, http.MethodPost, "/api/v1/auth/change-password", map[string]any{
"current_password": "password123",
"new_password": "newpassword123",
"confirm_password": "newpassword123",
}, accessToken, "")
require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, true, body["ok"])
resp, _ = s.request(t, http.MethodPost, "/api/v1/auth/refresh", map[string]any{
"refresh_token": refreshToken,
}, "", "")
require.Equal(t, http.StatusUnauthorized, resp.Code)
var changed models.User
require.NoError(t, database.DB.Where("username = ?", "admin").First(&changed).Error)
require.NotNil(t, changed.PasswordChangedAt)
resp, _ = s.request(t, http.MethodPost, "/api/v1/auth/login", map[string]any{
"username": "admin",
"password": "password123",
}, "", "")
require.Equal(t, http.StatusUnauthorized, resp.Code)
resp, body = s.request(t, http.MethodPost, "/api/v1/auth/login", map[string]any{
"username": "admin",
"password": "newpassword123",
}, "", "")
require.Equal(t, http.StatusOK, resp.Code)
accessToken = body["access_token"].(string)
resp, body = s.request(t, http.MethodPost, "/api/v1/auth/change-username", map[string]any{
"new_username": "owner",
"password": "newpassword123",
}, accessToken, "")
require.Equal(t, http.StatusOK, resp.Code)
require.Equal(t, "user_self", body["id"])
require.Equal(t, "owner", body["username"])
}
func TestAPIKeyCreateListDeleteAndRedaction(t *testing.T) {
s := newAuthTestServer(t)
resp, body := s.request(t, http.MethodPost, "/api/v1/auth/register", map[string]any{
"username": "admin",
"password": "password123",
"confirm_password": "password123",
}, "", "")
require.Equal(t, http.StatusOK, resp.Code)
accessToken := body["access_token"].(string)
resp, body = s.request(t, http.MethodPost, "/api/v1/api-keys", map[string]any{
"name": "CLI",
"description": "Local scripts",
}, accessToken, "")
require.Equal(t, http.StatusCreated, resp.Code)
rawKey := body["key"].(string)
require.NotEmpty(t, rawKey)
require.Contains(t, rawKey, "sk_")
keyID := body["id"].(string)
var stored models.APIKey
require.NoError(t, database.DB.First(&stored).Error)
require.NotEqual(t, rawKey, stored.KeyHash)
require.Equal(t, sha256String(rawKey), stored.KeyHash)
resp, body = s.request(t, http.MethodGet, "/api/v1/api-keys", nil, accessToken, "")
require.Equal(t, http.StatusOK, resp.Code)
items := body["items"].([]any)
require.Len(t, items, 1)
item := items[0].(map[string]any)
require.Equal(t, keyID, item["id"])
require.NotContains(t, item, "key")
require.NotContains(t, item, "key_hash")
require.NotEmpty(t, item["key_preview"])
resp, _ = s.request(t, http.MethodGet, "/api/v1/files", nil, "", rawKey)
require.Equal(t, http.StatusOK, resp.Code)
idNumber, err := strconv.Atoi(strings.TrimPrefix(keyID, "key_"))
require.NoError(t, err)
resp, _ = s.request(t, http.MethodDelete, "/api/v1/api-keys/"+strconv.Itoa(idNumber), nil, accessToken, "")
require.Equal(t, http.StatusNoContent, resp.Code)
resp, _ = s.request(t, http.MethodGet, "/api/v1/files", nil, "", rawKey)
require.Equal(t, http.StatusUnauthorized, resp.Code)
}
func TestAPIKeyManagementRequiresJWT(t *testing.T) {
s := newAuthTestServer(t)
resp, body := s.request(t, http.MethodPost, "/api/v1/auth/register", map[string]any{
"username": "admin",
"password": "password123",
"confirm_password": "password123",
}, "", "")
require.Equal(t, http.StatusOK, resp.Code)
accessToken := body["access_token"].(string)
resp, body = s.request(t, http.MethodPost, "/api/v1/api-keys", map[string]any{"name": "CLI"}, accessToken, "")
require.Equal(t, http.StatusCreated, resp.Code)
rawKey := body["key"].(string)
resp, _ = s.request(t, http.MethodGet, "/api/v1/api-keys", nil, "", rawKey)
require.Equal(t, http.StatusUnauthorized, resp.Code)
}
func TestDisabledUserCannotAuthenticateOrUseExistingCredentials(t *testing.T) {
s := newAuthTestServer(t)
resp, body := s.request(t, http.MethodPost, "/api/v1/auth/register", map[string]any{
"username": "admin",
"password": "password123",
"confirm_password": "password123",
}, "", "")
require.Equal(t, http.StatusOK, resp.Code)
accessToken := body["access_token"].(string)
refreshToken := body["refresh_token"].(string)
resp, body = s.request(t, http.MethodPost, "/api/v1/api-keys", map[string]any{
"name": "CLI",
}, accessToken, "")
require.Equal(t, http.StatusCreated, resp.Code)
rawKey := body["key"].(string)
require.NoError(t, database.DB.Model(&models.User{}).
Where("username = ?", "admin").
Update("status", models.UserStatusDisabled).Error)
resp, _ = s.request(t, http.MethodPost, "/api/v1/auth/login", map[string]any{
"username": "admin",
"password": "password123",
}, "", "")
require.Equal(t, http.StatusUnauthorized, resp.Code)
resp, _ = s.request(t, http.MethodPost, "/api/v1/auth/refresh", map[string]any{
"refresh_token": refreshToken,
}, "", "")
require.Equal(t, http.StatusUnauthorized, resp.Code)
resp, _ = s.request(t, http.MethodGet, "/api/v1/files", nil, "", rawKey)
require.Equal(t, http.StatusUnauthorized, resp.Code)
resp, _ = s.request(t, http.MethodGet, "/api/v1/events", nil, accessToken, "")
require.Equal(t, http.StatusUnauthorized, resp.Code)
resp, _ = s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{
"file_id": "file_missing",
}, accessToken, "")
require.Equal(t, http.StatusUnauthorized, resp.Code)
}
func sha256String(value string) string {
sum := sha256.Sum256([]byte(value))
return hex.EncodeToString(sum[:])
}