Files
Scriberr/internal/api/middleware.go
2026-05-03 15:14:21 -07:00

369 lines
9.9 KiB
Go

package api
import (
"context"
"errors"
"net/http"
"strings"
"scriberr/internal/config"
"scriberr/pkg/logger"
"github.com/gin-gonic/gin"
)
const (
	// requestIDKey is the gin context key under which requestIDMiddleware
	// stores the ID assigned to the current request.
	requestIDKey = "request_id"
)

// principal describes the authenticated caller of a request.
type principal struct {
	// UserID identifies the authenticated user.
	UserID uint
	// Username and Role describe the user; AuthType names the mechanism used
	// to authenticate (the admin guards in this file require "jwt").
	Username, Role, AuthType string
	// APIKeyID identifies the API key used, presumably nil for non-API-key
	// authentication — TODO confirm against currentPrincipal.
	APIKeyID *uint
}
// requestIDMiddleware assigns every request an ID. It stores the ID in the gin
// context under requestIDKey and echoes it in the X-Request-ID response header.
//
// A client-supplied X-Request-ID (after trimming surrounding whitespace) is
// reused only when it is at most 128 bytes long and consists solely of visible
// ASCII characters. Otherwise a fresh ID from newRequestID is used. The header
// is client-controlled and is logged and echoed back, so it is bounded here
// rather than trusted verbatim.
func requestIDMiddleware() gin.HandlerFunc {
	const maxRequestIDLen = 128
	usable := func(id string) bool {
		if id == "" || len(id) > maxRequestIDLen {
			return false
		}
		for i := 0; i < len(id); i++ {
			// Visible ASCII only: '!' (0x21) through '~' (0x7e).
			if id[i] < '!' || id[i] > '~' {
				return false
			}
		}
		return true
	}
	return func(c *gin.Context) {
		// Named id (not requestID) to avoid shadowing the requestID helper.
		id := strings.TrimSpace(c.GetHeader("X-Request-ID"))
		if !usable(id) {
			id = newRequestID()
		}
		c.Set(requestIDKey, id)
		c.Header("X-Request-ID", id)
		c.Next()
	}
}
// recoveryMiddleware converts a panic raised by a downstream handler into a
// logged 500 INTERNAL_ERROR response and aborts the handler chain.
//
// A panic with http.ErrAbortHandler is re-panicked. net/http treats that value
// as a deliberate request abort and suppresses logging for it. If the handler
// had already started writing the response, no error body is appended, since
// that would corrupt the partial output. The request is only aborted.
func recoveryMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		defer func() {
			recovered := recover()
			if recovered == nil {
				return
			}
			if err, ok := recovered.(error); ok && errors.Is(err, http.ErrAbortHandler) {
				panic(recovered)
			}
			// NOTE(review): consider also logging runtime/debug.Stack() here.
			logger.Error("API panic recovered", "request_id", requestID(c), "panic", recovered)
			if !c.Writer.Written() {
				writeError(c, http.StatusInternalServerError, "INTERNAL_ERROR", "internal server error", nil)
			}
			c.Abort()
		}()
		c.Next()
	}
}
// corsMiddleware sets CORS response headers and answers preflight (OPTIONS)
// requests with 204 No Content without running later handlers.
//
// Origin selection:
//   - In production with a non-empty cfg.AllowedOrigins, only an exact-match
//     Origin is echoed; any other origin gets no Access-Control-Allow-Origin.
//   - Otherwise the request's Origin is reflected, or "*" when none was sent.
//
// Because the allowed origin depends on the request's Origin header, every
// response carries "Vary: Origin". This stops shared caches from serving one
// origin's CORS headers to another. Credentials are advertised only for a
// concrete origin, because browsers reject credentials combined with "*".
func corsMiddleware(cfg *config.Config) gin.HandlerFunc {
	return func(c *gin.Context) {
		origin := c.Request.Header.Get("Origin")
		allowOrigin := "*"
		if cfg != nil && cfg.IsProduction() && len(cfg.AllowedOrigins) > 0 {
			allowOrigin = ""
			for _, allowed := range cfg.AllowedOrigins {
				if origin == allowed {
					allowOrigin = origin
					break
				}
			}
		} else if origin != "" {
			// NOTE(review): outside the allowlist branch, including production
			// with an empty AllowedOrigins, any Origin is reflected together with
			// Allow-Credentials. That lets any site make credentialed requests
			// (including cookie-authenticated ones). Confirm this is intended.
			allowOrigin = origin
		}
		// Add rather than Set so any Vary values set earlier are preserved.
		c.Writer.Header().Add("Vary", "Origin")
		if allowOrigin != "" {
			c.Header("Access-Control-Allow-Origin", allowOrigin)
			if allowOrigin != "*" {
				c.Header("Access-Control-Allow-Credentials", "true")
			}
		}
		c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
		c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-API-Key, X-Request-ID, Idempotency-Key, X-Chunk-SHA256, X-Chunk-Duration-Ms")
		if c.Request.Method == http.MethodOptions {
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	}
}
// handleCommandRoute dispatches POST requests for "colon command" endpoints,
// for example "/api/v1/transcriptions/{id}:retry". These are recognized by
// path parsing instead of the gin route table.
//
// It returns true when the path was recognized. In that case a response has
// been produced, either by the command handler or by an auth failure. It
// returns false, writing nothing, for non-POST requests and for paths it does
// not handle.
func (h *Handler) handleCommandRoute(c *gin.Context) bool {
	if c.Request.Method != http.MethodPost {
		return false
	}
	path := c.Request.URL.Path
	notFound := func(rc *gin.Context) {
		writeError(rc, http.StatusNotFound, "NOT_FOUND", "API endpoint not found", nil)
	}

	// Fixed collection-level commands.
	switch path {
	case "/api/v1/files:import-youtube":
		if h.requireAuthForNoRoute(c) {
			h.runIdempotent(c, h.importYouTube)
		}
		return true
	case "/api/v1/transcriptions:submit":
		if h.requireAuthForNoRoute(c) {
			h.runIdempotent(c, h.submitTranscription)
		}
		return true
	}

	// Admin user commands go through the admin guard instead of plain auth.
	if publicUserID, action, matched := parseAdminUserCommandPath(path); matched {
		if h.requireAdminForNoRoute(c) {
			// Inject the public ID as the "user_id" path parameter.
			c.Params = append(c.Params, gin.Param{Key: "user_id", Value: publicUserID})
			h.runIdempotent(c, func(rc *gin.Context) {
				switch action {
				case "reset-password":
					h.resetAdminUserPassword(rc)
				case "disable":
					h.disableAdminUser(rc)
				case "enable":
					h.enableAdminUser(rc)
				default:
					notFound(rc)
				}
			})
		}
		return true
	}

	if sessionID, matched := parseChatMessageStreamPath(path); matched {
		if h.requireAuthForNoRoute(c) {
			h.runIdempotent(c, func(rc *gin.Context) { h.streamChatMessage(rc, sessionID) })
		}
		return true
	}

	// Chat run cancellation is invoked directly, without the idempotency wrapper.
	if runID, matched := parseChatRunCancelPath(path); matched {
		if h.requireAuthForNoRoute(c) {
			h.cancelChatRun(c, runID)
		}
		return true
	}

	if transcriptionID, action, matched := parseTranscriptionCommandPath(path); matched {
		if h.requireAuthForNoRoute(c) {
			h.runIdempotent(c, func(rc *gin.Context) {
				switch action {
				case "stop", "cancel":
					// Both verbs map to the same cancellation handler.
					h.cancelTranscription(rc, transcriptionID)
				case "retry":
					h.retryTranscription(rc, transcriptionID)
				default:
					notFound(rc)
				}
			})
		}
		return true
	}

	if recordingID, action, matched := parseRecordingCommandPath(path); matched {
		if h.requireAuthForNoRoute(c) {
			h.runIdempotent(c, func(rc *gin.Context) {
				switch action {
				case "stop":
					h.stopRecording(rc, recordingID)
				case "cancel":
					h.cancelRecording(rc, recordingID)
				case "retry-finalize":
					h.retryFinalizeRecording(rc, recordingID)
				default:
					// parseRecordingCommandPath accepts any command name, so
					// unknown ones are rejected here after authentication.
					notFound(rc)
				}
			})
		}
		return true
	}

	return false
}
// parseAdminUserCommandPath matches "/api/v1/admin/users/{publicID}:{command}".
// command must be "reset-password", "disable" or "enable". On a match it
// returns the public ID and command. The part after the prefix may not contain
// "/", and the ID must be non-empty. Any other path yields ("", "", false).
func parseAdminUserCommandPath(requestPath string) (string, string, bool) {
	const prefix = "/api/v1/admin/users/"
	if !strings.HasPrefix(requestPath, prefix) {
		return "", "", false
	}
	rest := requestPath[len(prefix):]
	if rest == "" || strings.ContainsRune(rest, '/') {
		return "", "", false
	}
	// sep == -1 means no command separator; sep == 0 means an empty ID.
	sep := strings.IndexByte(rest, ':')
	if sep <= 0 {
		return "", "", false
	}
	id, action := rest[:sep], rest[sep+1:]
	if action != "reset-password" && action != "disable" && action != "enable" {
		return "", "", false
	}
	return id, action, true
}
// parseChatMessageStreamPath matches
// "/api/v1/chat/sessions/{sessionID}/messages:stream" and returns the session
// ID. The ID must be a single non-empty path segment. It may not contain "/",
// which matches the other command-path parsers in this file. Any other path
// yields ("", false).
func parseChatMessageStreamPath(requestPath string) (string, bool) {
	trimmed := strings.TrimPrefix(requestPath, "/api/v1/chat/sessions/")
	if trimmed == requestPath || trimmed == "" {
		return "", false
	}
	sessionID, suffix, ok := strings.Cut(trimmed, "/messages:stream")
	// A non-empty suffix means extra path after ":stream". A "/" in sessionID
	// means a multi-segment ID such as "a/b".
	if !ok || suffix != "" || sessionID == "" || strings.Contains(sessionID, "/") {
		return "", false
	}
	return sessionID, true
}
// parseChatRunCancelPath matches "/api/v1/chat/runs/{runID}:cancel" and returns
// the run ID. The ID must be a single non-empty path segment. It may not
// contain "/", which matches the other command-path parsers in this file. Any
// other path yields ("", false).
func parseChatRunCancelPath(requestPath string) (string, bool) {
	trimmed := strings.TrimPrefix(requestPath, "/api/v1/chat/runs/")
	if trimmed == requestPath || trimmed == "" {
		return "", false
	}
	runID, command, ok := strings.Cut(trimmed, ":")
	if !ok || command != "cancel" || runID == "" || strings.Contains(runID, "/") {
		return "", false
	}
	return runID, true
}
// parseTranscriptionCommandPath matches
// "/api/v1/transcriptions/{publicID}:{command}". command must be "stop",
// "cancel" or "retry". On a match it returns the public ID and command. The
// part after the prefix may not contain "/", and the ID must be non-empty.
// Any other path, including "/api/v1/transcriptions:submit", yields
// ("", "", false).
func parseTranscriptionCommandPath(requestPath string) (string, string, bool) {
	const prefix = "/api/v1/transcriptions/"
	if !strings.HasPrefix(requestPath, prefix) {
		return "", "", false
	}
	segment := requestPath[len(prefix):]
	if segment == "" || strings.Contains(segment, "/") {
		return "", "", false
	}
	colon := strings.Index(segment, ":")
	if colon < 1 {
		return "", "", false
	}
	publicID, command := segment[:colon], segment[colon+1:]
	switch command {
	case "stop", "cancel", "retry":
		return publicID, command, true
	}
	return "", "", false
}
// parseRecordingCommandPath matches "/api/v1/recordings/{id}:{command}" and
// returns the ID and command. It splits at the first ":", and both parts must
// be non-empty. It does not restrict the command name; callers reject unknown
// commands themselves. A "/" anywhere after the prefix yields ("", "", false),
// as does any non-matching path.
func parseRecordingCommandPath(requestPath string) (string, string, bool) {
	const prefix = "/api/v1/recordings/"
	if len(requestPath) <= len(prefix) || requestPath[:len(prefix)] != prefix {
		return "", "", false
	}
	tail := requestPath[len(prefix):]
	// The ID and the command together make up tail, so one check covers both.
	if strings.Contains(tail, "/") {
		return "", "", false
	}
	recordingID, command, found := strings.Cut(tail, ":")
	if !found || recordingID == "" || command == "" {
		return "", "", false
	}
	return recordingID, command, true
}
// requireAuthForNoRoute authenticates a request handled outside the gin route
// table. It tries an API key first, then a JWT. On failure it writes the error
// response and returns false. If the request context is already canceled or
// past its deadline, writeAuthContextError produces that response; otherwise a
// 401 UNAUTHORIZED is written.
func (h *Handler) requireAuthForNoRoute(c *gin.Context) bool {
	authenticated := h.authenticateAPIKey(c) || h.authenticateJWT(c)
	if !authenticated && !writeAuthContextError(c) {
		writeError(c, http.StatusUnauthorized, "UNAUTHORIZED", "missing or invalid authentication", nil)
	}
	return authenticated
}
// requireAdminForNoRoute authorizes an admin-only request handled outside the
// gin route table. The caller must authenticate (JWT tried first, then an API
// key), and the resulting principal must be a JWT session with role "admin".
// API-key principals are therefore rejected with 403. It returns false after
// writing one of these responses:
//   - the context cancellation/timeout error, via writeAuthContextError;
//   - 401 UNAUTHORIZED when no credential is valid;
//   - 403 FORBIDDEN when the principal is not a JWT admin.
func (h *Handler) requireAdminForNoRoute(c *gin.Context) bool {
	if authenticated := h.authenticateJWT(c) || h.authenticateAPIKey(c); !authenticated {
		if !writeAuthContextError(c) {
			writeError(c, http.StatusUnauthorized, "UNAUTHORIZED", "missing or invalid authentication", nil)
		}
		return false
	}
	// p (not "principal") avoids shadowing the principal type.
	p, found := h.currentPrincipal(c)
	isAdmin := found && p.AuthType == "jwt" && p.Role == "admin"
	if !isAdmin {
		writeError(c, http.StatusForbidden, "FORBIDDEN", "admin access is required", nil)
	}
	return isAdmin
}
// authRequired returns middleware that accepts any of three credentials, tried
// in order: an API key, a JWT, or a stream token. Unauthenticated requests are
// aborted. The response is the context cancellation/timeout error when the
// request context is done, and 401 UNAUTHORIZED otherwise.
func (h *Handler) authRequired() gin.HandlerFunc {
	return func(c *gin.Context) {
		switch {
		// Case expressions are evaluated left to right and stop at the first
		// true one, matching || short-circuiting.
		case h.authenticateAPIKey(c), h.authenticateJWT(c), h.authenticateStreamToken(c):
			c.Next()
		case writeAuthContextError(c):
			c.Abort()
		default:
			writeError(c, http.StatusUnauthorized, "UNAUTHORIZED", "missing or invalid authentication", nil)
			c.Abort()
		}
	}
}
// adminRequired returns middleware that admits only JWT-authenticated admins.
// Authentication, the principal check and the error responses (context error,
// 401, or 403) are delegated to requireAdminForNoRoute, which implements the
// same check. Rejected requests are aborted; admitted ones continue down the
// chain.
func (h *Handler) adminRequired() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !h.requireAdminForNoRoute(c) {
			c.Abort()
			return
		}
		c.Next()
	}
}
// jwtRequired returns middleware that accepts only a JWT. API keys and stream
// tokens are not consulted. A request without a valid token is aborted with
// the context cancellation/timeout error, or with a 401 asking for a bearer
// token.
func (h *Handler) jwtRequired() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !h.authenticateJWT(c) {
			if !writeAuthContextError(c) {
				writeError(c, http.StatusUnauthorized, "UNAUTHORIZED", "missing or invalid bearer token", nil)
			}
			c.Abort()
			return
		}
		c.Next()
	}
}
// authenticateJWT tries to authenticate the request with a JWT and reports
// whether it succeeded.
//
// The token comes from the Authorization header (parsed by bearerToken). If
// that is empty, it falls back to the "scriberr_access_token" cookie. The token
// is checked by authService.ValidateToken. When an account service is
// configured, the user must also pass ValidateActiveUser, and that user's
// current username and role replace the values carried in the token. Only on
// success are the auth_type ("jwt"), user_id, username and role context keys
// set. It returns false when no auth service is configured.
func (h *Handler) authenticateJWT(c *gin.Context) bool {
	if h.authService == nil {
		return false
	}
	raw := bearerToken(c.GetHeader("Authorization"))
	if raw == "" {
		cookieValue, cookieErr := c.Cookie("scriberr_access_token")
		if cookieErr != nil || cookieValue == "" {
			return false
		}
		raw = cookieValue
	}
	claims, err := h.authService.ValidateToken(raw)
	if err != nil {
		return false
	}
	if h.account != nil {
		active, userErr := h.account.ValidateActiveUser(c.Request.Context(), claims.UserID)
		if userErr != nil {
			return false
		}
		// Prefer the stored account's identity over what the token recorded.
		claims.Username, claims.Role = active.Username, active.Role
	}
	c.Set("auth_type", "jwt")
	c.Set("user_id", claims.UserID)
	c.Set("username", claims.Username)
	c.Set("role", claims.Role)
	return true
}
// authenticateAPIKey tries to authenticate the request with the X-API-Key
// header (surrounding whitespace trimmed) and reports whether it succeeded.
// It returns false when no account service is configured, when the header is
// empty, or when account.AuthenticateAPIKey rejects the key. On success it
// sets the auth_type ("api_key"), user_id and api_key_id context keys.
func (h *Handler) authenticateAPIKey(c *gin.Context) bool {
	if h.account == nil {
		return false
	}
	presented := strings.TrimSpace(c.GetHeader("X-API-Key"))
	if presented == "" {
		return false
	}
	record, err := h.account.AuthenticateAPIKey(c.Request.Context(), presented)
	if err != nil {
		return false
	}
	c.Set("auth_type", "api_key")
	c.Set("user_id", record.UserID)
	c.Set("api_key_id", record.ID)
	return true
}
// writeAuthContextError is used after authentication fails. It checks whether
// the request's context has already ended and, if so, writes a matching error
// response and returns true:
//   - context.Canceled: statusClientClosedRequest with REQUEST_CANCELED;
//   - context.DeadlineExceeded: 504 with REQUEST_TIMEOUT.
//
// It writes nothing and returns false when the context is still live, or when
// c or c.Request is nil.
func writeAuthContextError(c *gin.Context) bool {
	if c == nil || c.Request == nil {
		return false
	}
	switch err := c.Request.Context().Err(); {
	case errors.Is(err, context.Canceled):
		writeError(c, statusClientClosedRequest, "REQUEST_CANCELED", "request was canceled", nil)
	case errors.Is(err, context.DeadlineExceeded):
		writeError(c, http.StatusGatewayTimeout, "REQUEST_TIMEOUT", "request timed out", nil)
	default:
		return false
	}
	return true
}