mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-07-01 08:15:46 +00:00
451 lines
16 KiB
Go
451 lines
16 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/textproto"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"scriberr/internal/database"
|
|
"scriberr/internal/models"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type fakeYouTubeImporter struct {
|
|
mu sync.Mutex
|
|
once sync.Once
|
|
doneOnce sync.Once
|
|
calls []youtubeImportJob
|
|
content []byte
|
|
filename string
|
|
mimeType string
|
|
err error
|
|
block chan struct{}
|
|
completed chan struct{}
|
|
}
|
|
|
|
func (f *fakeYouTubeImporter) Import(ctx context.Context, job youtubeImportJob) (youtubeImportResult, error) {
|
|
f.mu.Lock()
|
|
f.calls = append(f.calls, job)
|
|
f.mu.Unlock()
|
|
if f.block != nil {
|
|
select {
|
|
case <-f.block:
|
|
case <-ctx.Done():
|
|
return youtubeImportResult{}, ctx.Err()
|
|
}
|
|
}
|
|
defer func() {
|
|
if f.completed != nil {
|
|
f.doneOnce.Do(func() { close(f.completed) })
|
|
}
|
|
}()
|
|
if f.err != nil {
|
|
return youtubeImportResult{}, f.err
|
|
}
|
|
content := f.content
|
|
if content == nil {
|
|
content = []byte("youtube audio")
|
|
}
|
|
if err := os.MkdirAll(filepath.Dir(job.OutputPath), 0755); err != nil {
|
|
return youtubeImportResult{}, err
|
|
}
|
|
if err := os.WriteFile(job.OutputPath, content, 0600); err != nil {
|
|
return youtubeImportResult{}, err
|
|
}
|
|
filename := f.filename
|
|
if filename == "" {
|
|
filename = "download.mp3"
|
|
}
|
|
mimeType := f.mimeType
|
|
if mimeType == "" {
|
|
mimeType = "audio/mpeg"
|
|
}
|
|
return youtubeImportResult{Filename: filename, MimeType: mimeType}, nil
|
|
}
|
|
|
|
func (f *fakeYouTubeImporter) unblock() {
|
|
if f.block != nil {
|
|
f.once.Do(func() { close(f.block) })
|
|
}
|
|
}
|
|
|
|
func (f *fakeYouTubeImporter) callCount() int {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
return len(f.calls)
|
|
}
|
|
|
|
func registerForFileTests(t *testing.T, s *authTestServer) string {
|
|
t.Helper()
|
|
|
|
resp, body := s.request(t, http.MethodPost, "/api/v1/auth/register", map[string]any{
|
|
"username": "admin",
|
|
"password": "password123",
|
|
"confirm_password": "password123",
|
|
}, "", "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
return body["access_token"].(string)
|
|
}
|
|
|
|
func uploadMultipart(t *testing.T, s *authTestServer, token, fieldName, filename, contentType string, content []byte, title string) (*httptest.ResponseRecorder, map[string]any) {
|
|
t.Helper()
|
|
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
partHeader := make(textproto.MIMEHeader)
|
|
partHeader.Set("Content-Disposition", `form-data; name="`+fieldName+`"; filename="`+filename+`"`)
|
|
partHeader.Set("Content-Type", contentType)
|
|
part, err := writer.CreatePart(partHeader)
|
|
require.NoError(t, err)
|
|
_, err = part.Write(content)
|
|
require.NoError(t, err)
|
|
if title != "" {
|
|
require.NoError(t, writer.WriteField("title", title))
|
|
}
|
|
require.NoError(t, writer.Close())
|
|
|
|
req, err := http.NewRequest(http.MethodPost, "/api/v1/files", &body)
|
|
require.NoError(t, err)
|
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
s.router.ServeHTTP(recorder, req)
|
|
|
|
var response map[string]any
|
|
if recorder.Body.Len() > 0 {
|
|
require.NoError(t, json.NewDecoder(recorder.Body).Decode(&response))
|
|
}
|
|
return recorder, response
|
|
}
|
|
|
|
func TestFileUploadListGetPatchDelete(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
token := registerForFileTests(t, s)
|
|
|
|
resp, body := uploadMultipart(t, s, token, "file", "meeting.wav", "audio/wav", []byte("RIFF----WAVEfmt data"), "Team sync")
|
|
require.Equal(t, http.StatusCreated, resp.Code)
|
|
fileID := body["id"].(string)
|
|
require.True(t, strings.HasPrefix(fileID, "file_"))
|
|
require.Equal(t, "Team sync", body["title"])
|
|
require.Equal(t, "audio", body["kind"])
|
|
require.Equal(t, "ready", body["status"])
|
|
require.Equal(t, "audio/wav", body["mime_type"])
|
|
require.NotContains(t, body, "audio_path")
|
|
require.NotContains(t, body, "source_file_path")
|
|
|
|
var stored models.TranscriptionJob
|
|
require.NoError(t, database.DB.First(&stored, "id = ?", strings.TrimPrefix(fileID, "file_")).Error)
|
|
require.NotEmpty(t, stored.AudioPath)
|
|
require.NotContains(t, fileID, filepath.Base(stored.AudioPath))
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
items := body["items"].([]any)
|
|
require.Len(t, items, 1)
|
|
|
|
resp, body = s.request(t, http.MethodPost, "/api/v1/transcriptions", map[string]any{
|
|
"file_id": fileID,
|
|
"title": "Team sync transcript",
|
|
}, token, "")
|
|
require.Equal(t, http.StatusAccepted, resp.Code)
|
|
transcriptionID := body["id"].(string)
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
items = body["items"].([]any)
|
|
require.Len(t, items, 1)
|
|
|
|
resp, _ = s.request(t, http.MethodGet, "/api/v1/files/"+strings.Replace(transcriptionID, "tr_", "file_", 1), nil, token, "")
|
|
require.Equal(t, http.StatusNotFound, resp.Code)
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files/"+fileID, nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.Equal(t, fileID, body["id"])
|
|
require.NotContains(t, body, "source_file_path")
|
|
|
|
resp, body = s.request(t, http.MethodPatch, "/api/v1/files/"+fileID, map[string]any{"title": "Renamed"}, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.Equal(t, "Renamed", body["title"])
|
|
|
|
resp, _ = s.request(t, http.MethodDelete, "/api/v1/files/"+fileID, nil, token, "")
|
|
require.Equal(t, http.StatusNoContent, resp.Code)
|
|
|
|
resp, _ = s.request(t, http.MethodGet, "/api/v1/files/"+fileID, nil, token, "")
|
|
require.Equal(t, http.StatusNotFound, resp.Code)
|
|
}
|
|
|
|
func TestFileUploadValidationAndSecurity(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
token := registerForFileTests(t, s)
|
|
|
|
resp, _ := uploadMultipart(t, s, token, "wrong", "meeting.wav", "audio/wav", []byte("RIFF----WAVEfmt data"), "")
|
|
require.Equal(t, http.StatusBadRequest, resp.Code)
|
|
|
|
resp, body := uploadMultipart(t, s, token, "file", "../secret.wav", "audio/wav", []byte("RIFF----WAVEfmt data"), "")
|
|
require.Equal(t, http.StatusCreated, resp.Code)
|
|
require.NotContains(t, body["title"], "..")
|
|
|
|
resp, body = uploadMultipart(t, s, token, "file", "notes.txt", "text/plain", []byte("plain text"), "")
|
|
require.Equal(t, http.StatusUnsupportedMediaType, resp.Code)
|
|
errBody := body["error"].(map[string]any)
|
|
require.NotContains(t, errBody["message"], "/")
|
|
}
|
|
|
|
func TestFileUploadSizeLimit(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
s.handler.maxUploadBytes = 128
|
|
token := registerForFileTests(t, s)
|
|
|
|
resp, body := uploadMultipart(t, s, token, "file", "large.wav", "audio/wav", bytes.Repeat([]byte("x"), 1024), "")
|
|
require.Equal(t, http.StatusRequestEntityTooLarge, resp.Code)
|
|
errBody := body["error"].(map[string]any)
|
|
require.Equal(t, "PAYLOAD_TOO_LARGE", errBody["code"])
|
|
require.Equal(t, "file", errBody["field"])
|
|
require.NotContains(t, errBody["message"], s.uploadDir)
|
|
}
|
|
|
|
func TestYouTubeImportDownloadsWithFakeImporterAndStreamsResult(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
importer := &fakeYouTubeImporter{content: []byte("ID3 youtube audio"), completed: make(chan struct{})}
|
|
s.handler.youtubeImporter = importer
|
|
token := registerForFileTests(t, s)
|
|
|
|
resp, body := s.request(t, http.MethodPost, "/api/v1/files:import-youtube", map[string]any{
|
|
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
|
"title": "Talk",
|
|
}, token, "")
|
|
require.Equal(t, http.StatusAccepted, resp.Code)
|
|
require.True(t, strings.HasPrefix(body["id"].(string), "file_"))
|
|
require.Equal(t, "Talk", body["title"])
|
|
require.Equal(t, "youtube", body["kind"])
|
|
require.Equal(t, "processing", body["status"])
|
|
require.NotContains(t, body, "source_file_path")
|
|
fileID := body["id"].(string)
|
|
|
|
select {
|
|
case <-importer.completed:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("youtube import did not complete")
|
|
}
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files/"+fileID, nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.Equal(t, "ready", body["status"])
|
|
require.Equal(t, "youtube", body["kind"])
|
|
require.Equal(t, "audio/mpeg", body["mime_type"])
|
|
require.Equal(t, float64(len("ID3 youtube audio")), body["size_bytes"])
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
items := body["items"].([]any)
|
|
require.Len(t, items, 1)
|
|
require.Equal(t, "youtube", items[0].(map[string]any)["kind"])
|
|
|
|
req, err := http.NewRequest(http.MethodGet, "/api/v1/files/"+fileID+"/audio", nil)
|
|
require.NoError(t, err)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
stream := httptest.NewRecorder()
|
|
s.router.ServeHTTP(stream, req)
|
|
require.Equal(t, http.StatusOK, stream.Code)
|
|
require.Equal(t, []byte("ID3 youtube audio"), stream.Body.Bytes())
|
|
require.Equal(t, 1, importer.callCount())
|
|
}
|
|
|
|
func TestYouTubeImportFailureIsSanitizedAndPublishesFailedEvent(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
importer := &fakeYouTubeImporter{err: errors.New("yt-dlp failed /tmp/private/raw-url"), completed: make(chan struct{})}
|
|
s.handler.youtubeImporter = importer
|
|
token := registerForFileTests(t, s)
|
|
|
|
recorder, cancel, done := startEventStream(t, s, token, "/api/v1/events")
|
|
resp, body := s.request(t, http.MethodPost, "/api/v1/files:import-youtube", map[string]any{
|
|
"url": "https://youtu.be/dQw4w9WgXcQ",
|
|
"title": "Broken import",
|
|
}, token, "")
|
|
require.Equal(t, http.StatusAccepted, resp.Code)
|
|
fileID := body["id"].(string)
|
|
|
|
select {
|
|
case <-importer.completed:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("youtube import did not fail")
|
|
}
|
|
|
|
deadline := time.Now().Add(time.Second)
|
|
for time.Now().Before(deadline) {
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files/"+fileID, nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
if body["status"] == "failed" {
|
|
break
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
require.Equal(t, "failed", body["status"])
|
|
require.NotContains(t, body, "/tmp/private")
|
|
stopEventStream(t, cancel, done)
|
|
|
|
stream := recorder.Body.String()
|
|
require.Contains(t, stream, "event: file.failed")
|
|
require.Contains(t, stream, `"id":"`+fileID+`"`)
|
|
require.NotContains(t, stream, "/tmp/private")
|
|
require.NotContains(t, stream, "raw-url")
|
|
}
|
|
|
|
func TestYouTubeImportURLValidation(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
s.handler.youtubeImporter = &fakeYouTubeImporter{completed: make(chan struct{})}
|
|
token := registerForFileTests(t, s)
|
|
|
|
for _, rawURL := range []string{
|
|
"file:///etc/passwd",
|
|
"https://example.com/video",
|
|
"https://youtube.evil.test/watch?v=dQw4w9WgXcQ",
|
|
} {
|
|
resp, body := s.request(t, http.MethodPost, "/api/v1/files:import-youtube", map[string]any{
|
|
"url": rawURL,
|
|
}, token, "")
|
|
require.Equal(t, http.StatusUnprocessableEntity, resp.Code, rawURL)
|
|
errBody := body["error"].(map[string]any)
|
|
require.Equal(t, "url", errBody["field"])
|
|
}
|
|
}
|
|
|
|
func TestFileListFiltersSortingPaginationAndValidation(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
token := registerForFileTests(t, s)
|
|
|
|
uploads := []struct {
|
|
filename string
|
|
title string
|
|
}{
|
|
{filename: "alpha.wav", title: "Alpha meeting"},
|
|
{filename: "bravo.mp3", title: "Bravo notes"},
|
|
{filename: "charlie.wav", title: "Charlie sync"},
|
|
}
|
|
for _, upload := range uploads {
|
|
resp, _ := uploadMultipart(t, s, token, "file", upload.filename, "audio/wav", []byte("RIFF----WAVEfmt data"), upload.title)
|
|
require.Equal(t, http.StatusCreated, resp.Code)
|
|
}
|
|
resp, _ := s.request(t, http.MethodPost, "/api/v1/files:import-youtube", map[string]any{
|
|
"url": "https://www.youtube.com/watch?v=abc123",
|
|
"title": "YouTube talk",
|
|
}, token, "")
|
|
require.Equal(t, http.StatusAccepted, resp.Code)
|
|
|
|
resp, body := s.request(t, http.MethodGet, "/api/v1/files?kind=audio&q=bravo&sort=title", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
items := body["items"].([]any)
|
|
require.Len(t, items, 1)
|
|
require.Equal(t, "Bravo notes", items[0].(map[string]any)["title"])
|
|
require.Equal(t, "audio", items[0].(map[string]any)["kind"])
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files?kind=youtube", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
items = body["items"].([]any)
|
|
require.Len(t, items, 1)
|
|
require.Equal(t, "youtube", items[0].(map[string]any)["kind"])
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files?status=processing", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
items = body["items"].([]any)
|
|
require.Len(t, items, 1)
|
|
require.Equal(t, "processing", items[0].(map[string]any)["status"])
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files?status=ready", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
items = body["items"].([]any)
|
|
require.Len(t, items, 3)
|
|
for _, raw := range items {
|
|
require.Equal(t, "ready", raw.(map[string]any)["status"])
|
|
}
|
|
|
|
future := time.Now().Add(time.Hour).Format(time.RFC3339)
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files?updated_after="+future, nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
require.Empty(t, body["items"].([]any))
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files?limit=2&sort=title", nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
firstPage := body["items"].([]any)
|
|
require.Len(t, firstPage, 2)
|
|
require.Equal(t, "Alpha meeting", firstPage[0].(map[string]any)["title"])
|
|
require.Equal(t, "Bravo notes", firstPage[1].(map[string]any)["title"])
|
|
nextCursor, ok := body["next_cursor"].(string)
|
|
require.True(t, ok)
|
|
require.NotEmpty(t, nextCursor)
|
|
|
|
resp, body = s.request(t, http.MethodGet, "/api/v1/files?limit=2&sort=title&cursor="+nextCursor, nil, token, "")
|
|
require.Equal(t, http.StatusOK, resp.Code)
|
|
secondPage := body["items"].([]any)
|
|
require.Len(t, secondPage, 2)
|
|
require.Equal(t, "Charlie sync", secondPage[0].(map[string]any)["title"])
|
|
require.Equal(t, "YouTube talk", secondPage[1].(map[string]any)["title"])
|
|
require.Nil(t, body["next_cursor"])
|
|
|
|
validationCases := []string{
|
|
"/api/v1/files?limit=0",
|
|
"/api/v1/files?kind=document",
|
|
"/api/v1/files?status=completed",
|
|
"/api/v1/files?sort=size",
|
|
"/api/v1/files?updated_after=not-a-time",
|
|
"/api/v1/files?cursor=not-a-cursor",
|
|
}
|
|
for _, path := range validationCases {
|
|
resp, body := s.request(t, http.MethodGet, path, nil, token, "")
|
|
require.Equal(t, http.StatusUnprocessableEntity, resp.Code, path)
|
|
errBody := body["error"].(map[string]any)
|
|
require.NotEmpty(t, errBody["field"])
|
|
}
|
|
}
|
|
|
|
func TestFileAudioRangeStreaming(t *testing.T) {
|
|
s := newAuthTestServer(t)
|
|
token := registerForFileTests(t, s)
|
|
content := []byte("RIFF----WAVEfmt 0123456789abcdef")
|
|
|
|
resp, body := uploadMultipart(t, s, token, "file", "meeting.wav", "audio/wav", content, "Team sync")
|
|
require.Equal(t, http.StatusCreated, resp.Code)
|
|
fileID := body["id"].(string)
|
|
|
|
req, err := http.NewRequest(http.MethodGet, "/api/v1/files/"+fileID+"/audio", nil)
|
|
require.NoError(t, err)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
full := httptest.NewRecorder()
|
|
s.router.ServeHTTP(full, req)
|
|
require.Equal(t, http.StatusOK, full.Code)
|
|
require.Equal(t, "bytes", full.Header().Get("Accept-Ranges"))
|
|
require.Equal(t, content, full.Body.Bytes())
|
|
|
|
req, err = http.NewRequest(http.MethodGet, "/api/v1/files/"+fileID+"/audio", nil)
|
|
require.NoError(t, err)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
req.Header.Set("Range", "bytes=5-9")
|
|
partial := httptest.NewRecorder()
|
|
s.router.ServeHTTP(partial, req)
|
|
require.Equal(t, http.StatusPartialContent, partial.Code)
|
|
require.Equal(t, "bytes 5-9/32", partial.Header().Get("Content-Range"))
|
|
require.Equal(t, content[5:10], partial.Body.Bytes())
|
|
|
|
req, err = http.NewRequest(http.MethodGet, "/api/v1/files/"+fileID+"/audio", nil)
|
|
require.NoError(t, err)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
req.Header.Set("Range", "bytes=99-100")
|
|
invalid := httptest.NewRecorder()
|
|
s.router.ServeHTTP(invalid, req)
|
|
require.Equal(t, http.StatusRequestedRangeNotSatisfiable, invalid.Code)
|
|
require.Equal(t, "bytes */32", invalid.Header().Get("Content-Range"))
|
|
}
|