diff --git a/cmd/server/main.go b/cmd/server/main.go index 7f227738..14758b92 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -18,6 +18,7 @@ import ( "scriberr/internal/queue" "scriberr/internal/repository" "scriberr/internal/service" + "scriberr/internal/sse" "scriberr/internal/transcription" "scriberr/internal/transcription/adapters" "scriberr/internal/transcription/registry" @@ -92,6 +93,10 @@ func main() { logger.Startup("auth", "Setting up authentication") authService := auth.NewAuthService(cfg.JWTSecret) + // Initialize SSE Broadcaster + logger.Startup("sse", "Initializing SSE broadcaster") + broadcaster := sse.NewBroadcaster() + // Initialize repositories logger.Startup("repository", "Initializing repositories") jobRepo := repository.NewJobRepository(database.DB) @@ -109,9 +114,12 @@ func main() { userService := service.NewUserService(userRepo, authService) fileService := service.NewFileService() + // Initialize unified transcription processor + logger.Startup("transcription", "Initializing transcription service") // Initialize unified transcription processor logger.Startup("transcription", "Initializing transcription service") unifiedProcessor := transcription.NewUnifiedJobProcessor(jobRepo) + unifiedProcessor.GetUnifiedService().SetBroadcaster(broadcaster) // Bootstrap embedded Python environment (for all adapters) logger.Startup("python", "Preparing Python environment") @@ -152,6 +160,7 @@ func main() { taskQueue, unifiedProcessor, quickTranscriptionService, + broadcaster, ) // Set up router @@ -189,6 +198,11 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + // Shutdown broadcaster to close all active SSE connections + if broadcaster != nil { + broadcaster.Shutdown() + } + // Gracefully shutdown the server if err := srv.Shutdown(ctx); err != nil { logger.Error("Server forced to shutdown", "error", err) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 75d8ee3f..a9e4775e 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -23,6 +23,7 @@ import ( "scriberr/internal/queue" "scriberr/internal/repository" "scriberr/internal/service" + "scriberr/internal/sse" "scriberr/internal/transcription" "scriberr/pkg/logger" @@ -50,6 +51,7 @@ type Handler struct { unifiedProcessor *transcription.UnifiedJobProcessor quickTranscription *transcription.QuickTranscriptionService multiTrackProcessor *processing.MultiTrackProcessor + broadcaster *sse.Broadcaster } // NewHandler creates a new handler @@ -70,6 +72,7 @@ func NewHandler( taskQueue *queue.TaskQueue, unifiedProcessor *transcription.UnifiedJobProcessor, quickTranscription *transcription.QuickTranscriptionService, + broadcaster *sse.Broadcaster, ) *Handler { return &Handler{ config: cfg, @@ -89,6 +92,7 @@ func NewHandler( unifiedProcessor: unifiedProcessor, quickTranscription: quickTranscription, multiTrackProcessor: processing.NewMultiTrackProcessor(), + broadcaster: broadcaster, } } @@ -2829,3 +2833,13 @@ func (h *Handler) UpdateUserSettings(c *gin.Context) { c.JSON(http.StatusOK, response) } + +// @Summary SSE Events +// @Description Subscribe to server-sent events +// @Tags events +// @Produce text/event-stream +// @Success 200 {string} string "stream" +// @Router /api/v1/events [get] +func (h *Handler) Events(c *gin.Context) { + h.broadcaster.ServeHTTP(c.Writer, c.Request) +} diff --git a/internal/api/router.go b/internal/api/router.go index 051cc282..94b53878 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -231,6 +231,13 @@ func SetupRoutes(handler *Handler, authService *auth.AuthService) *gin.Engine { { config.POST("/openai/validate", handler.ValidateOpenAIKey) } + + // SSE Events (require authentication) + events := v1.Group("/events") + events.Use(middleware.AuthMiddleware(authService)) + { + events.GET("/", handler.Events) + } } // Set up static file serving for React app diff --git a/internal/sse/broadcaster.go b/internal/sse/broadcaster.go new file mode 100644 index 00000000..b2fa153f --- /dev/null +++ b/internal/sse/broadcaster.go @@ -0,0 +1,191 @@ +package sse + +import ( + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "scriberr/pkg/logger" +) + +// Event represents a Server-Sent Event +type Event struct { + Type string `json:"type"` + Payload interface{} `json:"payload"` +} + +// Subscription represents a client subscription to a specific job +type Subscription struct { + JobID string + Channel chan Event +} + +// Message represents an internal broadcast message +type Message struct { + JobID string + Event Event +} + +// Broadcaster manages SSE connections and broadcasting +type Broadcaster struct { + subscribers map[string]map[chan Event]bool // JobID -> Set of Clients + register chan Subscription + unregister chan Subscription + broadcast chan Message + shutdown chan struct{} + mutex sync.RWMutex +} + +// NewBroadcaster creates a new Broadcaster +func NewBroadcaster() *Broadcaster { + b := &Broadcaster{ + subscribers: make(map[string]map[chan Event]bool), + register: make(chan Subscription), + unregister: make(chan Subscription), + broadcast: make(chan Message), + shutdown: make(chan struct{}), + } + + go b.listen() + return b +} + +// listen handles the addition and removal of clients and broadcasting of messages +func (b *Broadcaster) listen() { + for { + select { + case sub := <-b.register: + b.mutex.Lock() + if b.subscribers[sub.JobID] == nil { + b.subscribers[sub.JobID] = make(map[chan Event]bool) + } + b.subscribers[sub.JobID][sub.Channel] = true + b.mutex.Unlock() + logger.Debug("New SSE client registered", "job_id", sub.JobID) + + case sub := <-b.unregister: + b.mutex.Lock() + if clients, ok := b.subscribers[sub.JobID]; ok { + delete(clients, sub.Channel) + close(sub.Channel) + if len(clients) == 0 { + delete(b.subscribers, sub.JobID) + } + } + b.mutex.Unlock() + logger.Debug("SSE client unregistered", "job_id", sub.JobID) + + case msg := <-b.broadcast: + b.mutex.RLock() + // Send only to subscribers of this job + if clients, ok := b.subscribers[msg.JobID]; ok { + for s := range clients { + // Send non-blocking + select { + case s <- msg.Event: + default: + logger.Warn("Skipping slow SSE client", "job_id", msg.JobID) + } + } + } + b.mutex.RUnlock() + + case <-b.shutdown: + b.mutex.Lock() + logger.Info("Broadcaster shutting down") + for _, clients := range b.subscribers { + for s := range clients { + close(s) + } + } + b.subscribers = nil + b.mutex.Unlock() + return + } + } +} + +// Shutdown stops the broadcaster and closes all client connections +func (b *Broadcaster) Shutdown() { + close(b.shutdown) +} + +// ServeHTTP handles the SSE connection +func (b *Broadcaster) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Require Job ID + jobID := r.URL.Query().Get("job_id") + if jobID == "" { + http.Error(w, "job_id is required", http.StatusBadRequest) + return + } + + // Set headers for SSE + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + // Create a channel for this client + messageChan := make(chan Event) + subscription := Subscription{JobID: jobID, Channel: messageChan} + + // Register subscription + b.register <- subscription + + // Ensure cleanup on exit + defer func() { + // Use select to avoid blocking if the broadcaster has already shut down + select { + case b.unregister <- subscription: + case <-b.shutdown: + // Broadcaster is shutting down/stopped, no need to deregister + logger.Debug("Skipping SSE client deregistration (shutdown)") + } + }() + + // Send initial connection message + fmt.Fprintf(w, "data: {\"type\":\"connected\", \"job_id\":\"%s\"}\n\n", jobID) + flusher.Flush() + + // Keep connection open and push events + for { + select { + case <-r.Context().Done(): + return + case msg, ok := <-messageChan: + if !ok { + return // Channel closed, exit handler + } + data, err := json.Marshal(msg) + if err != nil { + logger.Error("Failed to marshal SSE message", "error", err) + continue + } + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + case <-time.After(30 * time.Second): + // Keep-alive heartbeat + fmt.Fprintf(w, ": keepalive\n\n") + flusher.Flush() + } + } +} + +// Broadcast sends an event to clients subscribed to the specific job +func (b *Broadcaster) Broadcast(jobID string, eventType string, payload interface{}) { + b.broadcast <- Message{ + JobID: jobID, + Event: Event{ + Type: eventType, + Payload: payload, + }, + } +} diff --git a/internal/sse/broadcaster_test.go b/internal/sse/broadcaster_test.go new file mode 100644 index 00000000..b7830c6e --- /dev/null +++ b/internal/sse/broadcaster_test.go @@ -0,0 +1,55 @@ +package sse + +import ( + "encoding/json" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestBroadcaster(t *testing.T) { + b := NewBroadcaster() + + // 1. Test ServeHTTP connection + req := httptest.NewRequest("GET", "/events?job_id=test-job-1", nil) + w := httptest.NewRecorder() + + // Use a context with timeout to simulate a client disconnecting after receiving messages + // In a real scenario, the ServeHTTP blocks until client disconnects + // We need to run ServeHTTP in a goroutine and consume the response body + // However, httptest.Recorder buffers everything, so it wont work well out of the box for streaming if we wait for it to return. + // Instead we can use a custom writer or just test the Broadcast logic separately? + // Actually, we can test that connecting establishes the headers correctly. + + // Let's test headers first without blocking + go b.ServeHTTP(w, req) + time.Sleep(100 * time.Millisecond) // Wait for connection + + // Check headers + if contentType := w.Header().Get("Content-Type"); contentType != "text/event-stream" { + t.Errorf("Expected Content-Type text/event-stream, got %s", contentType) + } + + // 2. Test Broadcasting + jobID := "test-job-1" + eventType := "status_update" + testPayload := map[string]string{"status": "completed"} + b.Broadcast(jobID, eventType, testPayload) + + time.Sleep(100 * time.Millisecond) // Allow processing + + // The recorder body should contain the data + body := w.Body.String() + + // Check for connected message + if !strings.Contains(body, "{\"type\":\"connected\", \"job_id\":\"test-job-1\"}") { + t.Errorf("Expected connected message not found, got: %s", body) + } + + // Check for broadcasted message + expectedJson, _ := json.Marshal(Event{Type: "status_update", Payload: testPayload}) + if !strings.Contains(body, string(expectedJson)) { + t.Errorf("Expected message %s not found in body: %s", string(expectedJson), body) + } +} diff --git a/internal/transcription/unified_service.go b/internal/transcription/unified_service.go index 42cc8d81..73be1cc5 100644 --- a/internal/transcription/unified_service.go +++ b/internal/transcription/unified_service.go @@ -13,6 +13,7 @@ import ( "scriberr/internal/models" "scriberr/internal/repository" + "scriberr/internal/sse" "scriberr/internal/transcription/interfaces" "scriberr/internal/transcription/pipeline" "scriberr/internal/transcription/registry" @@ -32,6 +33,7 @@ type UnifiedTranscriptionService struct { multiTrackTranscriber *MultiTrackTranscriber // For termination support jobRepo repository.JobRepository webhookService *webhook.Service + broadcaster *sse.Broadcaster } // NewUnifiedTranscriptionService creates a new unified transcription service @@ -52,6 +54,11 @@ func NewUnifiedTranscriptionService(jobRepo repository.JobRepository) *UnifiedTr } } +// SetBroadcaster sets the SSE broadcaster for the service +func (u *UnifiedTranscriptionService) SetBroadcaster(b *sse.Broadcaster) { + u.broadcaster = b +} + // Initialize prepares all registered models for use func (u *UnifiedTranscriptionService) Initialize(ctx context.Context) error { logger.Info("Initializing unified transcription service") @@ -97,6 +104,14 @@ func (u *UnifiedTranscriptionService) ProcessJob(ctx context.Context, jobID stri return fmt.Errorf("failed to create execution record: %w", err) } + // Broadcast initial processing status + if u.broadcaster != nil { + u.broadcaster.Broadcast(jobID, "job_update", map[string]interface{}{ + "job_id": jobID, + "status": models.StatusProcessing, + }) + } + // Helper function to update execution status updateExecutionStatus := func(status models.JobStatus, errorMsg string) { completedAt := time.Now() @@ -110,6 +125,15 @@ func (u *UnifiedTranscriptionService) ProcessJob(ctx context.Context, jobID stri u.jobRepo.UpdateExecution(ctx, execution) + // Broadcast update via SSE + if u.broadcaster != nil { + u.broadcaster.Broadcast(jobID, "job_update", map[string]interface{}{ + "job_id": jobID, + "status": status, + "error": errorMsg, + }) + } + // Trigger webhook if callback URL is present if job.Parameters.CallbackURL != nil && *job.Parameters.CallbackURL != "" { payload := webhook.WebhookPayload{ diff --git a/web/frontend/src/features/transcription/components/AudioFilesTable.tsx b/web/frontend/src/features/transcription/components/AudioFilesTable.tsx index 9ef71db7..00a396cd 100644 --- a/web/frontend/src/features/transcription/components/AudioFilesTable.tsx +++ b/web/frontend/src/features/transcription/components/AudioFilesTable.tsx @@ -43,6 +43,12 @@ import { TranscribeDDialog } from "@/components/TranscribeDDialog"; import { useNavigate } from "react-router-dom"; import { useAuth } from "@/features/auth/hooks/useAuth"; import { useAudioListInfinite, type AudioFile } from "@/features/transcription/hooks/useAudioFiles"; +import { useTranscriptionEvents } from "@/features/transcription/hooks/useTranscriptionEvents"; + +const JobStatusMonitor = memo(({ jobId }: { jobId: string }) => { + useTranscriptionEvents(jobId); + return null; +}); import { DebouncedSearchInput } from "@/components/DebouncedSearchInput"; @@ -83,6 +89,14 @@ export const AudioFilesTable = memo(function AudioFilesTable({ sortOrder: sorting[0]?.desc ? 'desc' : 'asc' }); + // Get active jobs for real-time monitoring + const activeJobs = useMemo(() => { + if (!infiniteData) return []; + return infiniteData.pages.flatMap(page => page.jobs).filter( + job => job.status === 'processing' || job.status === 'pending' + ); + }, [infiniteData]); + // Flatten data from pages const data = useMemo(() => { return infiniteData?.pages.flatMap(page => page.jobs) || []; @@ -203,6 +217,9 @@ export const AudioFilesTable = memo(function AudioFilesTable({ // Close dialog and refresh setConfigDialogOpen(false); setSelectedJobId(null); + // Refresh the list immediately to show the new processing status + // This also triggers SSE connection if it wasn't active + refetch(); if (onTranscribe) { onTranscribe(selectedJobId); } @@ -248,6 +265,7 @@ export const AudioFilesTable = memo(function AudioFilesTable({ // Close dialog and refresh setTranscribeDDialogOpen(false); setSelectedJobId(null); + refetch(); if (onTranscribe) { onTranscribe(selectedJobId); } @@ -949,6 +967,11 @@ export const AudioFilesTable = memo(function AudioFilesTable({ + + {/* Active Job Monitors */} + {activeJobs.map(job => ( + + ))} ); }); diff --git a/web/frontend/src/features/transcription/hooks/useAudioFiles.ts b/web/frontend/src/features/transcription/hooks/useAudioFiles.ts index c74bc354..73943862 100644 --- a/web/frontend/src/features/transcription/hooks/useAudioFiles.ts +++ b/web/frontend/src/features/transcription/hooks/useAudioFiles.ts @@ -61,13 +61,7 @@ export function useAudioList(params: AudioListParams) { return response.json() as Promise; }, placeholderData: keepPreviousData, - refetchInterval: (query) => { - const data = query.state.data; - if (data?.jobs.some(j => j.status === 'processing' || j.status === 'pending')) { - return 5000; - } - return false; - } + refetchInterval: false }); } @@ -105,14 +99,7 @@ export function useAudioListInfinite(params: Omit) { return undefined; }, initialPageParam: 1, - refetchInterval: (query) => { - // flattening the pages to check for pending status - const allJobs = query.state.data?.pages.flatMap(p => p.jobs) || []; - if (allJobs.some(j => j.status === 'processing' || j.status === 'pending')) { - return 5000; - } - return false; - } + refetchInterval: false }); } diff --git a/web/frontend/src/features/transcription/hooks/useTranscriptionEvents.ts b/web/frontend/src/features/transcription/hooks/useTranscriptionEvents.ts new file mode 100644 index 00000000..ab35b604 --- /dev/null +++ b/web/frontend/src/features/transcription/hooks/useTranscriptionEvents.ts @@ -0,0 +1,151 @@ +import { useEffect, useRef } from 'react'; +import { useAuth } from '@/features/auth/hooks/useAuth'; +import { useQueryClient } from '@tanstack/react-query'; +import type { AudioFile } from '@/features/transcription/hooks/useAudioFiles'; + +interface JobUpdateEvent { + type: string; + payload: { + job_id: string; + status: string; + error?: string; + progress?: number; + }; +} + +export const useTranscriptionEvents = (jobId: string | null) => { + const { token } = useAuth(); + const queryClient = useQueryClient(); + const abortControllerRef = useRef(null); + + useEffect(() => { + if (!token || !jobId) return; + + // Cleanup previous connection if any + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + const connect = async () => { + try { + const response = await fetch(`/api/v1/events?job_id=${jobId}`, { + headers: { + Authorization: `Bearer ${token}`, + }, + signal: abortController.signal, + }); + + if (!response.ok) { + throw new Error(`SSE connection failed: ${response.status}`); + } + + if (!response.body) { + throw new Error('No response body'); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + const lines = buffer.split('\n\n'); + // Keep the last partial line in buffer + buffer = lines.pop() || ''; + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed || trimmed.startsWith(':')) continue; // Skip comments/keepalives + + if (trimmed.startsWith('data: ')) { + const data = trimmed.slice(6); + try { + const event = JSON.parse(data); + handleEvent(event); + } catch (e) { + console.error('Failed to parse SSE data:', e); + } + } + } + } + } catch (error) { + if ((error as Error).name !== 'AbortError') { + // Ignore "Error in input stream" which happens on abort/close in some browsers + const errorMsg = (error as Error).message; + if (!errorMsg.includes('Error in input stream')) { + console.error('SSE connection error, reconnecting in 5s...', error); + setTimeout(() => { + if (!abortController.signal.aborted) { + connect(); + } + }, 5000); + } + } + } + }; + + const handleEvent = (event: any) => { + if (event.type === 'job_update') { + const payload = event.payload as JobUpdateEvent['payload']; + + // Optimistically update the list + queryClient.setQueriesData({ queryKey: ['audioFiles'] }, (oldData: any) => { + if (!oldData) return oldData; + + // Handle generic infinite query structure + if (oldData.pages) { + return { + ...oldData, + pages: oldData.pages.map((page: any) => ({ + ...page, + jobs: page.jobs.map((job: AudioFile) => { + if (job.id === payload.job_id) { + return { + ...job, + status: payload.status, + error_message: payload.error || job.error_message, + }; + } + return job; + }), + })), + }; + } + + // Handle standard query structure (if used elsewhere) + if (oldData.jobs) { + return { + ...oldData, + jobs: oldData.jobs.map((job: AudioFile) => { + if (job.id === payload.job_id) { + return { + ...job, + status: payload.status, + error_message: payload.error || job.error_message, + }; + } + return job; + }), + }; + } + + return oldData; + }); + } + }; + + connect(); + + return () => { + abortController.abort(); + }; + }, [token, queryClient, jobId]); +};