Files
Scriberr/tests/security_test.go
rishikanthc 913063eb49 refactor: Switch yt-dlp to standalone binary & cleanup UV config
- Dockerfiles: Install yt-dlp binary from GitHub releases to /usr/local/bin
- Go: Execute yt-dlp binary directly, removing uv python wrapper
- Config: Remove unused UVPath configuration and findUVPath function
- Entrypoint: Remove yt-dlp init logic (still initializes whisperx env if needed)
2025-12-16 19:07:29 -08:00

562 lines
18 KiB
Go

package tests
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"scriberr/internal/api"
"scriberr/internal/auth"
"scriberr/internal/config"
"scriberr/internal/database"
"scriberr/internal/processing"
"scriberr/internal/queue"
"scriberr/internal/repository"
"scriberr/internal/service"
"scriberr/internal/sse"
"scriberr/internal/transcription"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
// SecurityTestSuite bundles the router and the collaborators that SetupSuite
// wires together so that each test method can send requests through the full
// route table and inspect the recorded response.
type SecurityTestSuite struct {
	suite.Suite

	// HTTP layer: the handler built by SetupSuite and the router that serves it.
	handler *api.Handler
	router  *gin.Engine

	// Configuration and authentication used to build the handler and routes.
	config      *config.Config
	authService *auth.AuthService

	// Transcription collaborators passed to the handler.
	unifiedProcessor          *transcription.UnifiedJobProcessor
	quickTranscriptionService *transcription.QuickTranscriptionService
	taskQueue                 *queue.TaskQueue
}
// SetupSuite builds the complete API stack once for the whole suite: it puts
// Gin in test mode, initializes the test database, constructs repositories,
// services and the API handler, registers the routes, and creates the upload
// directory named in the test configuration.
//
// Any failure during setup aborts the suite, since none of the tests can
// produce a meaningful result without a working router.
func (suite *SecurityTestSuite) SetupSuite() {
	// Set Gin to test mode
	gin.SetMode(gin.TestMode)

	// Create test configuration
	suite.config = &config.Config{
		Port:         "8080",
		Host:         "localhost",
		DatabasePath: "security_test.db",
		JWTSecret:    "test-secret",
		UploadDir:    "security_test_uploads",
		WhisperXEnv:  "test_whisperx_env",
	}

	// Initialize test database
	if err := database.Initialize(suite.config.DatabasePath); err != nil {
		suite.T().Fatal("Failed to initialize test database:", err)
	}

	// Authentication service shared by the handler and the router.
	suite.authService = auth.NewAuthService(suite.config.JWTSecret)

	// Initialize repositories
	jobRepo := repository.NewJobRepository(database.DB)
	userRepo := repository.NewUserRepository(database.DB)
	apiKeyRepo := repository.NewAPIKeyRepository(database.DB)
	profileRepo := repository.NewProfileRepository(database.DB)
	llmConfigRepo := repository.NewLLMConfigRepository(database.DB)
	summaryRepo := repository.NewSummaryRepository(database.DB)
	chatRepo := repository.NewChatRepository(database.DB)
	noteRepo := repository.NewNoteRepository(database.DB)
	speakerMappingRepo := repository.NewSpeakerMappingRepository(database.DB)
	refreshTokenRepo := repository.NewRefreshTokenRepository(database.DB)

	// User and file services
	userService := service.NewUserService(userRepo, suite.authService)
	fileService := service.NewFileService()

	// Transcription processing, queue and supporting components
	suite.unifiedProcessor = transcription.NewUnifiedJobProcessor(jobRepo)
	var err error
	suite.quickTranscriptionService, err = transcription.NewQuickTranscriptionService(suite.config, suite.unifiedProcessor, jobRepo)
	if err != nil {
		suite.T().Fatal("Failed to initialize quick transcription service:", err)
	}
	suite.taskQueue = queue.NewTaskQueue(1, suite.unifiedProcessor, jobRepo)
	broadcaster := sse.NewBroadcaster()
	multiTrackProcessor := processing.NewMultiTrackProcessor(database.DB, jobRepo)

	suite.handler = api.NewHandler(
		suite.config,
		suite.authService,
		userService,
		fileService,
		jobRepo,
		apiKeyRepo,
		profileRepo,
		userRepo,
		llmConfigRepo,
		summaryRepo,
		chatRepo,
		noteRepo,
		speakerMappingRepo,
		refreshTokenRepo,
		suite.taskQueue,
		suite.unifiedProcessor,
		suite.quickTranscriptionService,
		multiTrackProcessor,
		broadcaster,
	)

	// Set up router
	suite.router = api.SetupRoutes(suite.handler, suite.authService)

	// Create upload directory; fail loudly rather than letting later requests
	// run against a directory that does not exist.
	if err := os.MkdirAll(suite.config.UploadDir, 0755); err != nil {
		suite.T().Fatal("Failed to create upload directory:", err)
	}
}
// TearDownSuite closes the database and removes the database file and upload
// directory that SetupSuite created.
//
// Cleanup is best-effort: a failure to delete an artifact is logged so that
// leftovers are visible in the test output, but it never fails the suite.
func (suite *SecurityTestSuite) TearDownSuite() {
	// Close the database first; the file is deleted immediately afterwards.
	database.Close()

	// A missing database file is not worth reporting; anything else is.
	if err := os.Remove(suite.config.DatabasePath); err != nil && !os.IsNotExist(err) {
		suite.T().Logf("cleanup: could not remove test database %q: %v", suite.config.DatabasePath, err)
	}

	// Clean up upload directory
	if err := os.RemoveAll(suite.config.UploadDir); err != nil {
		suite.T().Logf("cleanup: could not remove upload directory %q: %v", suite.config.UploadDir, err)
	}
}
// makeUnauthenticatedRequest sends a request with no credentials through the
// suite router and returns the recorded response.
//
// The body argument selects the payload:
//   - nil: the request carries no body
//   - string, []byte, *bytes.Buffer: sent verbatim, no Content-Type is set
//   - any other value: marshalled to JSON and sent as application/json
//
// If the body cannot be marshalled or the request cannot be constructed, the
// failure is recorded on the suite's test and an untouched recorder is
// returned. The request is never served in that case, so a nil request is
// never dereferenced.
func (suite *SecurityTestSuite) makeUnauthenticatedRequest(method, path string, body interface{}) *httptest.ResponseRecorder {
	var (
		payload io.Reader // stays nil when there is no body
		isJSON  bool
	)

	switch v := body.(type) {
	case nil:
		// No body to send.
	case string:
		payload = strings.NewReader(v)
	case []byte:
		payload = bytes.NewBuffer(v)
	case *bytes.Buffer:
		payload = v
	default:
		jsonBody, err := json.Marshal(v)
		if !assert.NoError(suite.T(), err, "marshalling request body for %s %s", method, path) {
			return httptest.NewRecorder()
		}
		payload = bytes.NewBuffer(jsonBody)
		isJSON = true
	}

	req, err := http.NewRequest(method, path, payload)
	if !assert.NoError(suite.T(), err, "building request %s %s", method, path) {
		return httptest.NewRecorder()
	}
	// Only touch the headers once the request is known to be non-nil.
	if isJSON {
		req.Header.Set("Content-Type", "application/json")
	}

	w := httptest.NewRecorder()
	suite.router.ServeHTTP(w, req)
	return w
}
// makeMultipartRequest POSTs a multipart/form-data request with no credentials
// to path and returns the recorded response.
//
// Every entry of fields is added as a form field. When filename is non-empty,
// a file part named "audio" is attached under that filename, containing a
// small fixed placeholder payload. The payload is written straight into the
// multipart body, so no temporary file is created.
//
// If the body or the request cannot be built, the failure is recorded on the
// suite's test and an untouched recorder is returned without serving anything.
func (suite *SecurityTestSuite) makeMultipartRequest(path string, fields map[string]string, filename string) *httptest.ResponseRecorder {
	t := suite.T()
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	// Add form fields
	for key, value := range fields {
		if !assert.NoError(t, writer.WriteField(key, value), "writing form field %q", key) {
			return httptest.NewRecorder()
		}
	}

	// Add file if filename provided
	if filename != "" {
		part, err := writer.CreateFormFile("audio", filename)
		if !assert.NoError(t, err, "creating form file %q", filename) {
			return httptest.NewRecorder()
		}
		_, err = io.Copy(part, strings.NewReader("dummy audio data for security testing"))
		if !assert.NoError(t, err, "writing form file %q", filename) {
			return httptest.NewRecorder()
		}
	}

	// Close writes the terminating boundary; without it the body is malformed.
	if !assert.NoError(t, writer.Close(), "finalizing multipart body") {
		return httptest.NewRecorder()
	}

	req, err := http.NewRequest("POST", path, body)
	if !assert.NoError(t, err, "building multipart request for %s", path) {
		return httptest.NewRecorder()
	}
	req.Header.Set("Content-Type", writer.FormDataContentType())

	w := httptest.NewRecorder()
	suite.router.ServeHTTP(w, req)
	return w
}
// TestAccountManagementEndpointsUnauthorized checks that the account
// management endpoints (change password, change username) answer 401 when the
// request carries no credentials.
func (suite *SecurityTestSuite) TestAccountManagementEndpointsUnauthorized() {
	cases := []struct {
		method string
		path   string
		body   interface{}
	}{
		{http.MethodPost, "/api/v1/auth/change-password", map[string]string{"old_password": "old", "new_password": "new"}},
		{http.MethodPost, "/api/v1/auth/change-username", map[string]string{"new_username": "newuser"}},
	}

	for _, c := range cases {
		name := fmt.Sprintf("%s %s", c.method, c.path)
		suite.T().Run(name, func(t *testing.T) {
			rec := suite.makeUnauthenticatedRequest(c.method, c.path, c.body)
			assert.Equal(t, http.StatusUnauthorized, rec.Code, "Should return 401 Unauthorized for unauthenticated request to %s %s", c.method, c.path)
		})
	}
}
// TestAPIKeyManagementEndpointsUnauthorized checks that listing, creating and
// deleting API keys all answer 401 when the request carries no credentials.
func (suite *SecurityTestSuite) TestAPIKeyManagementEndpointsUnauthorized() {
	cases := []struct {
		method string
		path   string
		body   interface{}
	}{
		{http.MethodGet, "/api/v1/api-keys/", nil},
		{http.MethodPost, "/api/v1/api-keys/", map[string]string{"name": "test key"}},
		{http.MethodDelete, "/api/v1/api-keys/123", nil},
	}

	for _, c := range cases {
		name := fmt.Sprintf("%s %s", c.method, c.path)
		suite.T().Run(name, func(t *testing.T) {
			rec := suite.makeUnauthenticatedRequest(c.method, c.path, c.body)
			assert.Equal(t, http.StatusUnauthorized, rec.Code, "Should return 401 Unauthorized for unauthenticated request to %s %s", c.method, c.path)
		})
	}
}
// TestTranscriptionEndpointsUnauthorized checks that every transcription
// endpoint answers 401 when the request carries no credentials.
//
// Cases flagged isMultipart are sent as multipart/form-data uploads with a
// dummy audio file; all other cases go through the plain request helper with
// the case's body (JSON or none).
func (suite *SecurityTestSuite) TestTranscriptionEndpointsUnauthorized() {
	testCases := []struct {
		method      string
		path        string
		body        interface{}
		isMultipart bool
	}{
		{"POST", "/api/v1/transcription/upload", nil, true},
		{"POST", "/api/v1/transcription/youtube", map[string]string{"url": "https://youtube.com/watch?v=123"}, false},
		{"POST", "/api/v1/transcription/submit", nil, true},
		{"POST", "/api/v1/transcription/test-id/start", nil, false},
		{"POST", "/api/v1/transcription/test-id/kill", nil, false},
		{"GET", "/api/v1/transcription/test-id/status", nil, false},
		{"GET", "/api/v1/transcription/test-id/transcript", nil, false},
		{"GET", "/api/v1/transcription/test-id/audio", nil, false},
		{"PUT", "/api/v1/transcription/test-id/title", map[string]string{"title": "New Title"}, false},
		{"GET", "/api/v1/transcription/test-id/summary", nil, false},
		{"GET", "/api/v1/transcription/test-id", nil, false},
		{"DELETE", "/api/v1/transcription/test-id", nil, false},
		{"GET", "/api/v1/transcription/list", nil, false},
		{"GET", "/api/v1/transcription/models", nil, false},
		{"GET", "/api/v1/transcription/test-id/notes", nil, false},
		{"POST", "/api/v1/transcription/test-id/notes", map[string]string{"content": "Test note"}, false},
		{"POST", "/api/v1/transcription/quick", nil, true},
		{"GET", "/api/v1/transcription/quick/test-id", nil, false},
	}

	for _, tc := range testCases {
		suite.T().Run(fmt.Sprintf("%s %s", tc.method, tc.path), func(t *testing.T) {
			var w *httptest.ResponseRecorder
			if tc.isMultipart && tc.method == "POST" {
				// Quick and regular uploads are sent with the same form, so
				// no per-path branching is needed here.
				fields := map[string]string{
					"title": "Test Audio",
					"model": "base",
				}
				w = suite.makeMultipartRequest(tc.path, fields, "test.mp3")
			} else {
				w = suite.makeUnauthenticatedRequest(tc.method, tc.path, tc.body)
			}
			assert.Equal(t, 401, w.Code, "Should return 401 Unauthorized for unauthenticated request to %s %s", tc.method, tc.path)
		})
	}
}
// TestProfileEndpointsUnauthorized checks that the profile CRUD endpoints and
// the set-default action answer 401 when the request carries no credentials.
func (suite *SecurityTestSuite) TestProfileEndpointsUnauthorized() {
	createPayload := map[string]interface{}{
		"name": "Test Profile",
		"parameters": map[string]interface{}{
			"model":      "base",
			"batch_size": 16,
			"device":     "auto",
		},
	}
	updatePayload := map[string]interface{}{
		"name": "Updated Profile",
	}

	cases := []struct {
		method string
		path   string
		body   interface{}
	}{
		{http.MethodGet, "/api/v1/profiles/", nil},
		{http.MethodPost, "/api/v1/profiles/", createPayload},
		{http.MethodGet, "/api/v1/profiles/123", nil},
		{http.MethodPut, "/api/v1/profiles/123", updatePayload},
		{http.MethodDelete, "/api/v1/profiles/123", nil},
		{http.MethodPost, "/api/v1/profiles/123/set-default", nil},
	}

	for _, c := range cases {
		name := fmt.Sprintf("%s %s", c.method, c.path)
		suite.T().Run(name, func(t *testing.T) {
			rec := suite.makeUnauthenticatedRequest(c.method, c.path, c.body)
			assert.Equal(t, http.StatusUnauthorized, rec.Code, "Should return 401 Unauthorized for unauthenticated request to %s %s", c.method, c.path)
		})
	}
}
// TestAdminEndpointsUnauthorized checks that the admin queue statistics
// endpoint answers 401 when the request carries no credentials.
func (suite *SecurityTestSuite) TestAdminEndpointsUnauthorized() {
	cases := []struct {
		method string
		path   string
		body   interface{}
	}{
		{http.MethodGet, "/api/v1/admin/queue/stats", nil},
	}

	for _, c := range cases {
		name := fmt.Sprintf("%s %s", c.method, c.path)
		suite.T().Run(name, func(t *testing.T) {
			rec := suite.makeUnauthenticatedRequest(c.method, c.path, c.body)
			assert.Equal(t, http.StatusUnauthorized, rec.Code, "Should return 401 Unauthorized for unauthenticated request to %s %s", c.method, c.path)
		})
	}
}
// TestLLMConfigEndpointsUnauthorized checks that reading and writing the LLM
// configuration answers 401 when the request carries no credentials.
func (suite *SecurityTestSuite) TestLLMConfigEndpointsUnauthorized() {
	configPayload := map[string]interface{}{
		"provider": "openai",
		"api_key":  "test-key",
	}

	cases := []struct {
		method string
		path   string
		body   interface{}
	}{
		{http.MethodGet, "/api/v1/llm/config", nil},
		{http.MethodPost, "/api/v1/llm/config", configPayload},
	}

	for _, c := range cases {
		name := fmt.Sprintf("%s %s", c.method, c.path)
		suite.T().Run(name, func(t *testing.T) {
			rec := suite.makeUnauthenticatedRequest(c.method, c.path, c.body)
			assert.Equal(t, http.StatusUnauthorized, rec.Code, "Should return 401 Unauthorized for unauthenticated request to %s %s", c.method, c.path)
		})
	}
}
// TestSummaryTemplateEndpointsUnauthorized checks that the summary template
// CRUD endpoints and the summary settings endpoints answer 401 when the
// request carries no credentials.
func (suite *SecurityTestSuite) TestSummaryTemplateEndpointsUnauthorized() {
	createPayload := map[string]string{
		"name":     "Test Template",
		"template": "Summarize this: {{content}}",
	}
	updatePayload := map[string]string{
		"name": "Updated Template",
	}
	settingsPayload := map[string]interface{}{
		"auto_summarize": true,
	}

	cases := []struct {
		method string
		path   string
		body   interface{}
	}{
		{http.MethodGet, "/api/v1/summaries/", nil},
		{http.MethodPost, "/api/v1/summaries/", createPayload},
		{http.MethodGet, "/api/v1/summaries/123", nil},
		{http.MethodPut, "/api/v1/summaries/123", updatePayload},
		{http.MethodDelete, "/api/v1/summaries/123", nil},
		{http.MethodGet, "/api/v1/summaries/settings", nil},
		{http.MethodPost, "/api/v1/summaries/settings", settingsPayload},
	}

	for _, c := range cases {
		name := fmt.Sprintf("%s %s", c.method, c.path)
		suite.T().Run(name, func(t *testing.T) {
			rec := suite.makeUnauthenticatedRequest(c.method, c.path, c.body)
			assert.Equal(t, http.StatusUnauthorized, rec.Code, "Should return 401 Unauthorized for unauthenticated request to %s %s", c.method, c.path)
		})
	}
}
// TestChatEndpointsUnauthorized checks that the chat model listing, session
// management and messaging endpoints answer 401 when the request carries no
// credentials.
func (suite *SecurityTestSuite) TestChatEndpointsUnauthorized() {
	newSession := map[string]interface{}{
		"transcription_id": "test-id",
		"title":            "Test Session",
	}
	newMessage := map[string]string{
		"content": "Hello",
	}
	newTitle := map[string]string{
		"title": "New Title",
	}

	cases := []struct {
		method string
		path   string
		body   interface{}
	}{
		{http.MethodGet, "/api/v1/chat/models", nil},
		{http.MethodPost, "/api/v1/chat/sessions", newSession},
		{http.MethodGet, "/api/v1/chat/transcriptions/test-id/sessions", nil},
		{http.MethodGet, "/api/v1/chat/sessions/session-123", nil},
		{http.MethodPost, "/api/v1/chat/sessions/session-123/messages", newMessage},
		{http.MethodPut, "/api/v1/chat/sessions/session-123/title", newTitle},
		{http.MethodPost, "/api/v1/chat/sessions/session-123/title/auto", nil},
		{http.MethodDelete, "/api/v1/chat/sessions/session-123", nil},
	}

	for _, c := range cases {
		name := fmt.Sprintf("%s %s", c.method, c.path)
		suite.T().Run(name, func(t *testing.T) {
			rec := suite.makeUnauthenticatedRequest(c.method, c.path, c.body)
			assert.Equal(t, http.StatusUnauthorized, rec.Code, "Should return 401 Unauthorized for unauthenticated request to %s %s", c.method, c.path)
		})
	}
}
// TestNotesEndpointsUnauthorized checks that reading, updating and deleting a
// note answers 401 when the request carries no credentials.
func (suite *SecurityTestSuite) TestNotesEndpointsUnauthorized() {
	updatePayload := map[string]string{
		"content": "Updated note content",
	}

	cases := []struct {
		method string
		path   string
		body   interface{}
	}{
		{http.MethodGet, "/api/v1/notes/note-123", nil},
		{http.MethodPut, "/api/v1/notes/note-123", updatePayload},
		{http.MethodDelete, "/api/v1/notes/note-123", nil},
	}

	for _, c := range cases {
		name := fmt.Sprintf("%s %s", c.method, c.path)
		suite.T().Run(name, func(t *testing.T) {
			rec := suite.makeUnauthenticatedRequest(c.method, c.path, c.body)
			assert.Equal(t, http.StatusUnauthorized, rec.Code, "Should return 401 Unauthorized for unauthenticated request to %s %s", c.method, c.path)
		})
	}
}
// TestSummarizeEndpointUnauthorized checks that POST /api/v1/summarize/
// answers 401 when the request carries no credentials.
func (suite *SecurityTestSuite) TestSummarizeEndpointUnauthorized() {
	payload := map[string]interface{}{
		"transcription_id": "test-id",
		"template_id":      "template-123",
	}

	rec := suite.makeUnauthenticatedRequest(http.MethodPost, "/api/v1/summarize/", payload)

	assert.Equal(suite.T(), http.StatusUnauthorized, rec.Code, "Should return 401 Unauthorized for unauthenticated request to POST /api/v1/summarize/")
}
// TestPublicEndpointsAccessible checks that endpoints meant to be reachable
// without credentials respond with one of the status codes listed for them.
//
// Register and login requests are sent with a small JSON credential payload;
// every other endpoint is called without a body.
func (suite *SecurityTestSuite) TestPublicEndpointsAccessible() {
	endpoints := []struct {
		method       string
		path         string
		allowedCodes []int // status codes accepted for this endpoint
	}{
		{http.MethodGet, "/health", []int{200}},
		{http.MethodGet, "/swagger/index.html", []int{200, 301, 302, 404}}, // swagger might redirect or not exist
		{http.MethodGet, "/api/v1/auth/registration-status", []int{200}},
		{http.MethodPost, "/api/v1/auth/register", []int{200, 400, 409}}, // 400 for validation errors, 409 for user exists
		{http.MethodPost, "/api/v1/auth/login", []int{200, 400, 401}},    // 401 for invalid creds is OK for login endpoint
		{http.MethodPost, "/api/v1/auth/logout", []int{200}},
	}

	for _, ep := range endpoints {
		name := fmt.Sprintf("%s %s should be accessible", ep.method, ep.path)
		suite.T().Run(name, func(t *testing.T) {
			// Pick the request payload; it stays nil unless the endpoint is
			// the register or login POST.
			var payload interface{}
			if ep.method == http.MethodPost {
				switch {
				case strings.Contains(ep.path, "register"):
					payload = map[string]string{
						"username": "newtestuser",
						"password": "testpass",
					}
				case strings.Contains(ep.path, "login"):
					payload = map[string]string{
						"username": "nonexistentuser",
						"password": "wrongpass",
					}
				}
			}

			rec := suite.makeUnauthenticatedRequest(ep.method, ep.path, payload)

			// Scan the allow-list, stopping at the first match.
			matched := false
			for i := 0; i < len(ep.allowedCodes) && !matched; i++ {
				matched = ep.allowedCodes[i] == rec.Code
			}
			assert.True(t, matched, "Public endpoint %s %s returned %d, expected one of %v", ep.method, ep.path, rec.Code, ep.allowedCodes)
		})
	}
}
// TestMalformedAuthHeaders checks that a protected endpoint answers 401 when
// the request carries an Authorization or X-API-Key header whose value is
// empty, badly formatted, or not a valid token.
func (suite *SecurityTestSuite) TestMalformedAuthHeaders() {
	const target = "/api/v1/transcription/list"

	cases := []struct {
		name   string
		header string
		value  string
	}{
		{"Invalid Bearer format", "Authorization", "InvalidBearer token123"},
		{"Empty Bearer token", "Authorization", "Bearer "},
		{"Invalid JWT token", "Authorization", "Bearer invalid.jwt.token"},
		{"Empty API key", "X-API-Key", ""},
		{"Malformed API key", "X-API-Key", "malformed-key-123"},
	}

	for _, c := range cases {
		suite.T().Run(c.name, func(t *testing.T) {
			// The error is discarded: the method and URL are fixed, valid
			// constants.
			request, _ := http.NewRequest(http.MethodGet, target, nil)
			request.Header.Set(c.header, c.value)

			rec := httptest.NewRecorder()
			suite.router.ServeHTTP(rec, request)

			assert.Equal(t, http.StatusUnauthorized, rec.Code, "Should return 401 for malformed auth header: %s", c.name)
		})
	}
}
// TestCORSPreflightDoesNotBypassAuth checks two things against a protected
// endpoint: an OPTIONS preflight request is answered with 204, and a
// subsequent GET without credentials is still rejected with 401.
func (suite *SecurityTestSuite) TestCORSPreflightDoesNotBypassAuth() {
	const protectedEndpoint = "/api/v1/transcription/list"

	// Preflight: expected to be answered with 204.
	// The error is discarded: the method and URL are fixed, valid constants.
	preflight, _ := http.NewRequest(http.MethodOptions, protectedEndpoint, nil)
	preflight.Header.Set("Origin", "https://evil.example.com")
	preflight.Header.Set("Access-Control-Request-Method", "GET")

	preflightRec := httptest.NewRecorder()
	suite.router.ServeHTTP(preflightRec, preflight)
	assert.Equal(suite.T(), http.StatusNoContent, preflightRec.Code, "OPTIONS request should return 204")

	// The real request must still be authenticated.
	getRec := suite.makeUnauthenticatedRequest(http.MethodGet, protectedEndpoint, nil)
	assert.Equal(suite.T(), http.StatusUnauthorized, getRec.Code, "GET request after CORS preflight should still require authentication")
}
// TestSecurityHeaders checks that a response from /health carries non-empty
// CORS headers (allowed origin, methods and headers).
func (suite *SecurityTestSuite) TestSecurityHeaders() {
	rec := suite.makeUnauthenticatedRequest(http.MethodGet, "/health", nil)
	responseHeaders := rec.Header()

	corsHeaders := []string{
		"Access-Control-Allow-Origin",
		"Access-Control-Allow-Methods",
		"Access-Control-Allow-Headers",
	}
	for _, name := range corsHeaders {
		assert.NotEmpty(suite.T(), responseHeaders.Get(name))
	}
}
func TestSecurityTestSuite(t *testing.T) {
suite.Run(t, new(SecurityTestSuite))
}