mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-06-29 15:26:02 +00:00
158 lines
4.0 KiB
Go
158 lines
4.0 KiB
Go
package api
|
|
|
|
import (
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"scriberr/internal/config"
|
|
"scriberr/internal/database"
|
|
"scriberr/internal/models"
|
|
"scriberr/pkg/logger"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
const requestIDKey = "request_id"
|
|
|
|
func requestIDMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
requestID := strings.TrimSpace(c.GetHeader("X-Request-ID"))
|
|
if requestID == "" {
|
|
requestID = newRequestID()
|
|
}
|
|
c.Set(requestIDKey, requestID)
|
|
c.Header("X-Request-ID", requestID)
|
|
c.Next()
|
|
}
|
|
}
|
|
func recoveryMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
defer func() {
|
|
if recovered := recover(); recovered != nil {
|
|
logger.Error("API panic recovered", "request_id", requestID(c), "panic", recovered)
|
|
writeError(c, http.StatusInternalServerError, "INTERNAL_ERROR", "internal server error", nil)
|
|
c.Abort()
|
|
}
|
|
}()
|
|
c.Next()
|
|
}
|
|
}
|
|
func corsMiddleware(cfg *config.Config) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
origin := c.Request.Header.Get("Origin")
|
|
allowOrigin := "*"
|
|
if cfg != nil && cfg.IsProduction() && len(cfg.AllowedOrigins) > 0 {
|
|
allowOrigin = ""
|
|
for _, allowed := range cfg.AllowedOrigins {
|
|
if origin == allowed {
|
|
allowOrigin = origin
|
|
break
|
|
}
|
|
}
|
|
} else if origin != "" {
|
|
allowOrigin = origin
|
|
}
|
|
if allowOrigin != "" {
|
|
c.Header("Access-Control-Allow-Origin", allowOrigin)
|
|
c.Header("Access-Control-Allow-Credentials", "true")
|
|
}
|
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS")
|
|
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-API-Key, X-Request-ID, Idempotency-Key")
|
|
if c.Request.Method == http.MethodOptions {
|
|
c.AbortWithStatus(http.StatusNoContent)
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
func (h *Handler) handleCommandRoute(c *gin.Context) bool {
|
|
if c.Request.Method != http.MethodPost {
|
|
return false
|
|
}
|
|
switch c.Request.URL.Path {
|
|
case "/api/v1/files:import-youtube":
|
|
if !h.requireAuthForNoRoute(c) {
|
|
return true
|
|
}
|
|
h.runIdempotent(c, h.importYouTube)
|
|
return true
|
|
case "/api/v1/transcriptions:submit":
|
|
if !h.requireAuthForNoRoute(c) {
|
|
return true
|
|
}
|
|
h.runIdempotent(c, h.submitTranscription)
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
func (h *Handler) requireAuthForNoRoute(c *gin.Context) bool {
|
|
if h.authenticateAPIKey(c) || h.authenticateJWT(c) {
|
|
return true
|
|
}
|
|
writeError(c, http.StatusUnauthorized, "UNAUTHORIZED", "missing or invalid authentication", nil)
|
|
return false
|
|
}
|
|
func (h *Handler) authRequired() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if h.authenticateAPIKey(c) || h.authenticateJWT(c) {
|
|
c.Next()
|
|
return
|
|
}
|
|
writeError(c, http.StatusUnauthorized, "UNAUTHORIZED", "missing or invalid authentication", nil)
|
|
c.Abort()
|
|
}
|
|
}
|
|
func (h *Handler) jwtRequired() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if h.authenticateJWT(c) {
|
|
c.Next()
|
|
return
|
|
}
|
|
writeError(c, http.StatusUnauthorized, "UNAUTHORIZED", "missing or invalid bearer token", nil)
|
|
c.Abort()
|
|
}
|
|
}
|
|
func (h *Handler) authenticateJWT(c *gin.Context) bool {
|
|
if h.authService == nil {
|
|
return false
|
|
}
|
|
token := bearerToken(c.GetHeader("Authorization"))
|
|
if token == "" {
|
|
if cookie, err := c.Cookie("scriberr_access_token"); err == nil {
|
|
token = cookie
|
|
}
|
|
}
|
|
if token == "" {
|
|
return false
|
|
}
|
|
claims, err := h.authService.ValidateToken(token)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
c.Set("auth_type", "jwt")
|
|
c.Set("user_id", claims.UserID)
|
|
c.Set("username", claims.Username)
|
|
return true
|
|
}
|
|
func (h *Handler) authenticateAPIKey(c *gin.Context) bool {
|
|
key := strings.TrimSpace(c.GetHeader("X-API-Key"))
|
|
if key == "" || database.DB == nil {
|
|
return false
|
|
}
|
|
|
|
var apiKey models.APIKey
|
|
if err := database.DB.Where("key_hash = ? AND revoked_at IS NULL", sha256Hex(key)).First(&apiKey).Error; err != nil {
|
|
return false
|
|
}
|
|
now := time.Now()
|
|
apiKey.LastUsed = &now
|
|
_ = database.DB.Save(&apiKey).Error
|
|
|
|
c.Set("auth_type", "api_key")
|
|
c.Set("user_id", apiKey.UserID)
|
|
c.Set("api_key_id", apiKey.ID)
|
|
return true
|
|
}
|