mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-06-28 06:46:25 +00:00
977 lines
33 KiB
Go
977 lines
33 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"scriberr/internal/database"
|
|
"scriberr/internal/llm"
|
|
"scriberr/internal/models"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// StylePrompt is appended to user turns before they are sent to the model
// (never persisted). It asks for bare markdown with LaTeX math, no wrapping
// code fences and no meta-commentary.
const StylePrompt = "\n\n" +
	"INSTRUCTIONS: Return your answer as a raw markdown string. \n" +
	"1. Use LaTeX for equations (e.g., $E=mc^2$). \n" +
	"2. Do NOT use code block fences (```) around the entire response. \n" +
	"3. Do NOT include any meta-comments (e.g., \"Here is the markdown...\"). \n" +
	"4. Just provide the raw content."
|
|
|
|
// ChatCreateRequest represents a request to create a new chat session
|
|
type ChatCreateRequest struct {
|
|
TranscriptionID string `json:"transcription_id" binding:"required"`
|
|
Model string `json:"model" binding:"required"`
|
|
Title string `json:"title,omitempty"`
|
|
}
|
|
|
|
// ChatMessageRequest represents a request to send a message
|
|
type ChatMessageRequest struct {
|
|
Content string `json:"content" binding:"required"`
|
|
}
|
|
|
|
// ChatSessionResponse represents a chat session response
|
|
type ChatSessionResponse struct {
|
|
ID string `json:"id"`
|
|
TranscriptionID string `json:"transcription_id"`
|
|
Title string `json:"title"`
|
|
Model string `json:"model"`
|
|
Provider string `json:"provider"`
|
|
IsActive bool `json:"is_active"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
UpdatedAt time.Time `json:"updated_at"`
|
|
MessageCount int `json:"message_count"`
|
|
LastActivityAt *time.Time `json:"last_activity_at,omitempty"`
|
|
LastMessage *ChatMessageResponse `json:"last_message,omitempty"`
|
|
}
|
|
|
|
// ChatMessageResponse represents a chat message response
|
|
type ChatMessageResponse struct {
|
|
ID uint `json:"id"`
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
}
|
|
|
|
// ChatModelsResponse represents the available chat models
|
|
type ChatModelsResponse struct {
|
|
Models []string `json:"models"`
|
|
}
|
|
|
|
// ChatSessionWithMessages represents a chat session with messages
|
|
type ChatSessionWithMessages struct {
|
|
ChatSessionResponse
|
|
Messages []ChatMessageResponse `json:"messages"`
|
|
}
|
|
|
|
type Transcript struct {
|
|
Segments []Segment `json:"segments"`
|
|
}
|
|
|
|
type Segment struct {
|
|
Start float64 `json:"start"`
|
|
End float64 `json:"end"`
|
|
Text string `json:"text"`
|
|
Speaker string `json:"speaker"`
|
|
}
|
|
|
|
// getLLMService builds an llm.Service from the active LLM configuration.
//
// Besides the service it returns the provider name exactly as stored in the
// configuration (matching is case-insensitive, the returned value is not
// normalised). An error is returned when no configuration is active, the
// repository lookup fails, the selected provider is missing its credentials
// or URL, or the provider is not one of "openai" / "ollama".
func (h *Handler) getLLMService(ctx context.Context) (llm.Service, string, error) {
	cfg, err := h.llmConfigRepo.GetActive(ctx)
	switch {
	case err == gorm.ErrRecordNotFound:
		return nil, "", fmt.Errorf("no active LLM configuration found")
	case err != nil:
		return nil, "", fmt.Errorf("failed to get LLM config: %w", err)
	}

	provider := cfg.Provider
	switch strings.ToLower(provider) {
	case "openai":
		if key := cfg.APIKey; key != nil && *key != "" {
			return llm.NewOpenAIService(*key, cfg.OpenAIBaseURL), provider, nil
		}
		return nil, provider, fmt.Errorf("OpenAI API key not configured")
	case "ollama":
		if base := cfg.BaseURL; base != nil && *base != "" {
			return llm.NewOllamaService(*base), provider, nil
		}
		return nil, provider, fmt.Errorf("Ollama base URL not configured")
	}

	return nil, provider, fmt.Errorf("unsupported LLM provider: %s", provider)
}
|
|
|
|
// GetChatModels returns the models offered by the active LLM provider.
//
// @Summary Get available chat models
// @Description Get list of chat models offered by the active LLM provider
// @Tags chat
// @Produce json
// @Success 200 {object} ChatModelsResponse
// @Failure 400 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /api/v1/chat/models [get]
// @Security ApiKeyAuth
// @Security BearerAuth
func (h *Handler) GetChatModels(c *gin.Context) {
	svc, _, err := h.getLLMService(c.Request.Context())
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Derive the deadline from the request context so that a client
	// disconnect also cancels the upstream provider call.
	ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
	defer cancel()

	// Named "available" so it does not shadow the models package.
	available, err := svc.GetModels(ctx)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch models: " + err.Error()})
		return
	}
	if available == nil {
		// Encode "no models" as [] rather than null.
		available = []string{}
	}

	c.JSON(http.StatusOK, ChatModelsResponse{Models: available})
}
|
|
|
|
// CreateChatSession opens a new chat session on a completed transcription.
// The session records the model requested by the client and the provider of
// the currently active LLM configuration.
//
// @Summary Create a new chat session
// @Description Create a new chat session for a transcription
// @Tags chat
// @Accept json
// @Produce json
// @Param request body ChatCreateRequest true "Chat session creation request"
// @Success 201 {object} ChatSessionResponse
// @Failure 400 {object} map[string]string
// @Failure 404 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /api/v1/chat/sessions [post]
// @Security ApiKeyAuth
// @Security BearerAuth
func (h *Handler) CreateChatSession(c *gin.Context) {
	var req ChatCreateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	ctx := c.Request.Context()

	// The transcription must exist and carry a finished, non-empty transcript;
	// SendChatMessage has nothing to work with otherwise.
	transcription, err := h.jobRepo.FindByID(ctx, req.TranscriptionID)
	if err != nil || transcription == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "Transcription not found"})
		return
	}
	if transcription.Status != models.StatusCompleted ||
		transcription.Transcript == nil || *transcription.Transcript == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Transcription must be completed to create a chat session"})
		return
	}

	// Fail early when no usable LLM backend is configured, and remember which
	// provider the session is created against (previously this was always
	// recorded as "openai", even for Ollama configurations).
	_, provider, err := h.getLLMService(ctx)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	title := strings.TrimSpace(req.Title)
	if title == "" {
		title = "New Chat Session"
	}

	now := time.Now()
	chatSession := &models.ChatSession{
		JobID:           req.TranscriptionID, // JobID and TranscriptionID share the same ID
		TranscriptionID: req.TranscriptionID,
		Title:           title,
		Model:           req.Model,
		Provider:        strings.ToLower(provider),
		MessageCount:    0,
		LastActivityAt:  &now,
		IsActive:        true,
	}

	if err := h.chatRepo.Create(ctx, chatSession); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create chat session"})
		return
	}

	c.JSON(http.StatusCreated, ChatSessionResponse{
		ID:              chatSession.ID,
		TranscriptionID: chatSession.TranscriptionID,
		Title:           chatSession.Title,
		Model:           chatSession.Model,
		Provider:        chatSession.Provider,
		IsActive:        chatSession.IsActive,
		CreatedAt:       chatSession.CreatedAt,
		UpdatedAt:       chatSession.UpdatedAt,
		MessageCount:    chatSession.MessageCount,
		LastActivityAt:  chatSession.LastActivityAt,
	})
}
|
|
|
|
// GetChatSessions lists the chat sessions attached to a transcription. Each
// entry includes the session's message count and its most recent message,
// both loaded with one batched query each (rather than one per session).
//
// @Summary Get chat sessions for a transcription
// @Description Get all chat sessions for a specific transcription
// @Tags chat
// @Produce json
// @Param transcription_id path string true "Transcription ID"
// @Success 200 {array} ChatSessionResponse
// @Failure 400 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /api/v1/chat/transcriptions/{transcription_id}/sessions [get]
// @Security ApiKeyAuth
// @Security BearerAuth
func (h *Handler) GetChatSessions(c *gin.Context) {
	transcriptionID := c.Param("transcription_id")
	if transcriptionID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Transcription ID is required"})
		return
	}

	ctx := c.Request.Context()

	sessions, err := h.chatRepo.ListByJob(ctx, transcriptionID)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get chat sessions"})
		return
	}

	// Non-nil so the body is always a JSON array ("[]"), never "null".
	responses := make([]ChatSessionResponse, 0, len(sessions))
	if len(sessions) == 0 {
		c.JSON(http.StatusOK, responses)
		return
	}

	sessionIDs := make([]string, 0, len(sessions))
	for _, session := range sessions {
		sessionIDs = append(sessionIDs, session.ID)
	}

	// Message counts for every session in a single grouped query. GORM maps
	// the session_id / count columns onto these fields by name.
	type messageCount struct {
		SessionID string
		Count     int64
	}
	var messageCounts []messageCount
	if err := database.DB.WithContext(ctx).
		Model(&models.ChatMessage{}).
		Select("chat_session_id as session_id, COUNT(*) as count").
		Where("chat_session_id IN ?", sessionIDs).
		Group("chat_session_id").
		Scan(&messageCounts).Error; err != nil {
		fmt.Printf("Failed to count chat messages for transcription %s: %v\n", transcriptionID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get chat sessions"})
		return
	}
	messageCountMap := make(map[string]int64, len(messageCounts))
	for _, mc := range messageCounts {
		messageCountMap[mc.SessionID] = mc.Count
	}

	// Latest message (by created_at) of every session in a single query.
	var lastMessages []models.ChatMessage
	if err := database.DB.WithContext(ctx).Where(`id IN (
		SELECT id FROM chat_messages cm1
		WHERE cm1.chat_session_id IN ?
		AND cm1.created_at = (
			SELECT MAX(cm2.created_at)
			FROM chat_messages cm2
			WHERE cm2.chat_session_id = cm1.chat_session_id
		)
	)`, sessionIDs).Find(&lastMessages).Error; err != nil {
		fmt.Printf("Failed to load last chat messages for transcription %s: %v\n", transcriptionID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get chat sessions"})
		return
	}
	lastMessageMap := make(map[string]*ChatMessageResponse, len(lastMessages))
	for _, msg := range lastMessages {
		// Several messages can share the newest timestamp; keep the highest ID
		// (presumably the latest insert) so the pick is deterministic.
		if prev, ok := lastMessageMap[msg.ChatSessionID]; ok && prev.ID >= msg.ID {
			continue
		}
		lastMessageMap[msg.ChatSessionID] = &ChatMessageResponse{
			ID:        msg.ID,
			Role:      msg.Role,
			Content:   msg.Content,
			CreatedAt: msg.CreatedAt,
		}
	}

	for _, session := range sessions {
		responses = append(responses, ChatSessionResponse{
			ID:              session.ID,
			TranscriptionID: session.TranscriptionID,
			Title:           session.Title,
			Model:           session.Model,
			Provider:        session.Provider,
			IsActive:        session.IsActive,
			CreatedAt:       session.CreatedAt,
			UpdatedAt:       session.UpdatedAt,
			MessageCount:    int(messageCountMap[session.ID]),
			LastActivityAt:  session.LastActivityAt,
			LastMessage:     lastMessageMap[session.ID],
		})
	}

	c.JSON(http.StatusOK, responses)
}
|
|
|
|
// @Summary Get a chat session with messages
|
|
// @Description Get a specific chat session with all its messages
|
|
// @Tags chat
|
|
// @Produce json
|
|
// @Param session_id path string true "Chat Session ID"
|
|
// @Success 200 {object} ChatSessionWithMessages
|
|
// @Failure 404 {object} map[string]string
|
|
// @Failure 500 {object} map[string]string
|
|
// @Router /api/v1/chat/sessions/{session_id} [get]
|
|
// @Security ApiKeyAuth
|
|
// @Security BearerAuth
|
|
func (h *Handler) GetChatSession(c *gin.Context) {
|
|
sessionID := c.Param("session_id")
|
|
if sessionID == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Session ID is required"})
|
|
return
|
|
}
|
|
|
|
session, err := h.chatRepo.GetSessionWithMessages(c.Request.Context(), sessionID)
|
|
if err != nil {
|
|
if err == gorm.ErrRecordNotFound {
|
|
c.JSON(http.StatusNotFound, gin.H{"error": "Chat session not found"})
|
|
return
|
|
}
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get chat session"})
|
|
return
|
|
}
|
|
|
|
var messageResponses []ChatMessageResponse
|
|
for _, msg := range session.Messages {
|
|
messageResponses = append(messageResponses, ChatMessageResponse{
|
|
ID: msg.ID,
|
|
Role: msg.Role,
|
|
Content: msg.Content,
|
|
CreatedAt: msg.CreatedAt,
|
|
})
|
|
}
|
|
|
|
response := ChatSessionWithMessages{
|
|
ChatSessionResponse: ChatSessionResponse{
|
|
ID: session.ID,
|
|
TranscriptionID: session.TranscriptionID,
|
|
Title: session.Title,
|
|
Model: session.Model,
|
|
Provider: session.Provider,
|
|
IsActive: session.IsActive,
|
|
CreatedAt: session.CreatedAt,
|
|
UpdatedAt: session.UpdatedAt,
|
|
MessageCount: len(messageResponses),
|
|
LastActivityAt: session.LastActivityAt,
|
|
},
|
|
Messages: messageResponses,
|
|
}
|
|
|
|
c.JSON(http.StatusOK, response)
|
|
}
|
|
|
|
// formatTime renders a transcript offset given in seconds as HH:MM:SS,
// rounding to the nearest second. Hours are not wrapped, so long recordings
// render as e.g. "25:00:00".
//
// Negative and NaN inputs are clamped to zero; otherwise the modulo
// arithmetic would yield strings such as "00:00:-3".
func formatTime(seconds float64) string {
	// !(x > 0) is true for zero, negatives and NaN alike.
	if !(seconds > 0) {
		seconds = 0
	}
	total := int(math.Round(seconds))

	hours := total / 3600
	minutes := (total % 3600) / 60
	secs := total % 60

	return fmt.Sprintf("%02d:%02d:%02d", hours, minutes, secs)
}
|
|
|
|
// SendChatMessage stores the user's message, builds the prompt from the
// transcript plus the conversation history, and streams the assistant's reply
// back as plain text. Once the reply is complete it is stored and the session
// bookkeeping (activity time, message count) is updated.
//
// @Summary Send a message to a chat session
// @Description Send a message to a chat session and get streaming response
// @Tags chat
// @Accept json
// @Produce text/plain
// @Param session_id path string true "Chat Session ID"
// @Param message body ChatMessageRequest true "Message content"
// @Success 200 {string} string "Streaming response"
// @Failure 400 {object} map[string]string
// @Failure 404 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /api/v1/chat/sessions/{session_id}/messages [post]
// @Security ApiKeyAuth
// @Security BearerAuth
func (h *Handler) SendChatMessage(c *gin.Context) {
	sessionID := c.Param("session_id")
	if sessionID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Session ID is required"})
		return
	}

	var req ChatMessageRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	reqCtx := c.Request.Context()

	session, err := h.chatRepo.GetSessionWithTranscription(reqCtx, sessionID)
	if err != nil {
		if err == gorm.ErrRecordNotFound {
			c.JSON(http.StatusNotFound, gin.H{"error": "Chat session not found"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get chat session"})
		return
	}

	svc, _, err := h.getLLMService(reqCtx)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	userMessage := &models.ChatMessage{
		SessionID:     sessionID,
		ChatSessionID: sessionID,
		Role:          "user",
		Content:       req.Content,
	}
	if err := h.chatRepo.AddMessage(reqCtx, userMessage); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save message"})
		return
	}

	// The history includes the message saved above. It is mandatory: without
	// it the model would be called without any user turn at all.
	messages, err := h.chatRepo.GetMessages(reqCtx, sessionID, 0)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to load chat history"})
		return
	}

	// On the session's first user message, derive a provisional title from it.
	userMsgCount := 0
	for _, m := range messages {
		if m.Role == "user" {
			userMsgCount++
		}
	}
	if userMsgCount == 1 {
		session.Title = generateChatTitle(req.Content)
		if err := h.chatRepo.Update(reqCtx, session); err != nil {
			fmt.Printf("Failed to update title for chat session %s: %v\n", sessionID, err)
		}
	}

	contextWindow, err := svc.GetContextWindow(reqCtx, session.Model)
	if err != nil {
		fmt.Printf("Failed to get context window for model %s: %v. Using default 4096.\n", session.Model, err)
		contextWindow = 4096
	}

	// Fallback: if the transcript wasn't loaded via Preload, fetch it directly
	// from the job repository.
	if session.Transcription.Transcript == nil || *session.Transcription.Transcript == "" {
		fmt.Printf("Debug: Transcript not loaded via Preload for session %s (TranscriptionID: %s), fetching directly...\n", sessionID, session.TranscriptionID)
		job, jobErr := h.jobRepo.FindByID(reqCtx, session.TranscriptionID)
		if jobErr == nil && job != nil && job.Transcript != nil && *job.Transcript != "" {
			session.Transcription.Transcript = job.Transcript
			fmt.Printf("Debug: Direct fetch succeeded, transcript length: %d\n", len(*job.Transcript))
		} else {
			fmt.Printf("Debug: Direct fetch failed or transcript empty. Error: %v\n", jobErr)
		}
	}

	var transcriptContext string
	if session.Transcription.Transcript != nil && *session.Transcription.Transcript != "" {
		transcript := *session.Transcription.Transcript
		fmt.Printf("Debug: Transcript found for session %s. Length: %d\n", sessionID, len(transcript))

		// Render segments as "[SPEAKER] [HH:MM:SS - HH:MM:SS] text" lines.
		var t Transcript
		if err := json.Unmarshal([]byte(transcript), &t); err != nil {
			fmt.Printf("Error parsing transcript JSON for session %s: %v\n", sessionID, err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse transcript data"})
			return
		}
		fmt.Printf("Debug: Parsed %d segments from transcript\n", len(t.Segments))

		// Replace raw speaker labels with user-assigned names where available.
		speakerMap := make(map[string]string)
		mappings, mapErr := h.speakerMappingRepo.ListByJob(reqCtx, session.TranscriptionID)
		if mapErr != nil {
			fmt.Printf("Failed to get speaker mappings for job %s: %v\n", session.TranscriptionID, mapErr)
		} else {
			for _, m := range mappings {
				speakerMap[m.OriginalSpeaker] = m.CustomName
			}
		}

		var sb strings.Builder
		for _, seg := range t.Segments {
			speakerName := seg.Speaker
			if customName, ok := speakerMap[speakerName]; ok {
				speakerName = customName
			}
			fmt.Fprintf(&sb, "[%s] [%s - %s] %s\n",
				speakerName,
				formatTime(seg.Start),
				formatTime(seg.End),
				strings.TrimSpace(seg.Text),
			)
		}
		cleanTranscript := sb.String()
		fmt.Printf("Debug: Clean transcript length: %d\n", len(cleanTranscript))

		transcriptContext = fmt.Sprintf("You are analyzing the following transcript. Use this transcript to answer questions:\n\n---TRANSCRIPT START---\n%s\n---TRANSCRIPT END---\n\n", cleanTranscript)
		fmt.Printf("Injecting transcript of length %d into chat context for session %s\n", len(transcriptContext), sessionID)

		// Rough estimate: 1 token ~= 4 characters. Keep ~500 tokens free for
		// the question, history and the response.
		transcriptTokens := len(transcriptContext) / 4
		if transcriptTokens > contextWindow-500 {
			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Transcript is too long for this model's context window (estimated %d tokens, limit %d). Please use a model with a larger context window.", transcriptTokens, contextWindow)})
			return
		}
	} else {
		fmt.Printf("Warning: Transcript is nil or empty for chat session %s. Transcription ID: %s\n", sessionID, session.TranscriptionID)
		if session.Transcription.ID == "" {
			fmt.Println("Warning: Session Transcription relation seems missing or empty")
		}
	}

	// Build the prompt. The transcript is prepended to the first user message
	// rather than sent as a system message, because some models (e.g. Qwen3)
	// don't handle system messages properly. Token estimates are taken from
	// the exact content sent, so the transcript is counted once and the
	// trimming below subtracts exactly what was added here.
	openaiMessages := make([]llm.ChatMessage, 0, len(messages))
	currentTokenCount := 0
	for i, msg := range messages {
		content := msg.Content
		if i == 0 && msg.Role == "user" && transcriptContext != "" {
			content = transcriptContext + "User question: " + msg.Content
			fmt.Printf("Debug: Prepended transcript to first user message\n")
		}
		// Style instructions are injected in memory only, never stored.
		if msg.Role == "user" {
			content += StylePrompt
		}
		openaiMessages = append(openaiMessages, llm.ChatMessage{Role: msg.Role, Content: content})
		currentTokenCount += len(content) / 4
	}

	// Over budget: drop the oldest messages after the first one (which carries
	// the transcript), always keeping at least two messages.
	trimmedCount := 0
	for currentTokenCount > contextWindow && len(openaiMessages) > 2 {
		removedTokens := len(openaiMessages[1].Content) / 4
		openaiMessages = append(openaiMessages[:1], openaiMessages[2:]...)
		currentTokenCount -= removedTokens
		trimmedCount++
		fmt.Printf("Debug: Trimmed message to fit context. Removed %d tokens, new count: %d/%d\n", removedTokens, currentTokenCount, contextWindow)
	}
	if trimmedCount > 0 {
		fmt.Printf("Debug: Trimmed %d messages to fit context window\n", trimmedCount)
	}

	// Still over the limit after trimming everything that may be trimmed.
	if currentTokenCount > contextWindow {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Transcript alone exceeds model context limit (%d tokens > %d). Please use a model with larger context window.", currentTokenCount, contextWindow)})
		return
	}

	// Streaming response with context-usage headers for the client.
	c.Header("Content-Type", "text/plain; charset=utf-8")
	c.Header("Cache-Control", "no-cache, no-store, must-revalidate")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")
	c.Header("X-Accel-Buffering", "no") // Disable nginx buffering
	c.Header("Access-Control-Allow-Origin", "*")
	c.Header("Access-Control-Expose-Headers", "X-Context-Used, X-Context-Limit, X-Messages-Trimmed")
	c.Header("X-Context-Used", fmt.Sprintf("%d", currentTokenCount))
	c.Header("X-Context-Limit", fmt.Sprintf("%d", contextWindow))
	c.Header("X-Messages-Trimmed", fmt.Sprintf("%d", trimmedCount))
	c.Status(http.StatusOK) // Start the response immediately

	ctx, cancel := context.WithTimeout(reqCtx, 5*time.Minute)
	defer cancel()

	// persistReply stores a non-empty assistant reply and bumps the session
	// bookkeeping. It uses context.Background() because the request context
	// may already be cancelled (client gone) by the time the stream ends.
	persistReply := func(content string) {
		if content == "" {
			return
		}
		assistantMessage := &models.ChatMessage{
			SessionID:     sessionID,
			ChatSessionID: sessionID,
			Role:          "assistant",
			Content:       content,
		}
		if err := h.chatRepo.AddMessage(context.Background(), assistantMessage); err != nil {
			fmt.Printf("Failed to save assistant message for chat session %s: %v\n", sessionID, err)
			return
		}

		now := time.Now()
		session.UpdatedAt = now
		session.LastActivityAt = &now
		session.MessageCount += 2 // +2 for user + assistant message
		if err := h.chatRepo.Update(context.Background(), session); err != nil {
			fmt.Printf("Failed to update chat session %s: %v\n", sessionID, err)
		}
	}

	// Use model defaults: do not set temperature explicitly.
	contentChan, errorChan := svc.ChatCompletionStream(ctx, session.Model, openaiMessages, 0.0)

	var reply strings.Builder
	for {
		select {
		case chunk, ok := <-contentChan:
			if !ok {
				// Stream finished: store the complete reply.
				persistReply(reply.String())
				return
			}
			c.Writer.WriteString(chunk)
			c.Writer.Flush()
			reply.WriteString(chunk)

		case streamErr, ok := <-errorChan:
			if !ok {
				// A closed channel is always ready to receive; stop selecting
				// on it so the loop doesn't spin while content still arrives.
				errorChan = nil
				continue
			}
			if streamErr == nil {
				continue
			}

			errStr := streamErr.Error()
			// Streaming not supported for this model/org: fall back to a
			// single non-streaming completion.
			if strings.Contains(errStr, "\"param\": \"stream\"") || strings.Contains(errStr, "unsupported_value") || strings.Contains(errStr, "must be verified to stream") {
				resp, fallbackErr := svc.ChatCompletion(ctx, session.Model, openaiMessages, 0.0)
				if fallbackErr != nil {
					c.Writer.WriteString("\nError: " + fallbackErr.Error())
					c.Writer.Flush()
					return
				}
				if resp == nil || len(resp.Choices) == 0 {
					// Previously this dereferenced a nil error and panicked.
					c.Writer.WriteString("\nError: model returned an empty response")
					c.Writer.Flush()
					return
				}
				content := resp.Choices[0].Message.Content
				c.Writer.WriteString(content)
				c.Writer.Flush()
				reply.WriteString(content)
				persistReply(reply.String())
				return
			}

			c.Writer.WriteString("\nError: " + errStr)
			c.Writer.Flush()
			return

		case <-ctx.Done():
			c.Writer.WriteString("\nRequest timeout")
			c.Writer.Flush()
			return
		}
	}
}
|
|
|
|
// UpdateChatSessionTitle sets a user-chosen title on a chat session.
// Surrounding whitespace is trimmed, and a title that is blank after trimming
// is rejected (the binding's min=1 rule alone lets "   " through).
//
// @Summary Update chat session title
// @Description Update the title of a chat session
// @Tags chat
// @Accept json
// @Produce json
// @Param session_id path string true "Chat Session ID"
// @Param request body map[string]string true "Title update request"
// @Success 200 {object} ChatSessionResponse
// @Failure 400 {object} map[string]string
// @Failure 404 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /api/v1/chat/sessions/{session_id}/title [put]
// @Security ApiKeyAuth
// @Security BearerAuth
func (h *Handler) UpdateChatSessionTitle(c *gin.Context) {
	sessionID := c.Param("session_id")
	if sessionID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Session ID is required"})
		return
	}

	var req struct {
		Title string `json:"title" binding:"required,min=1,max=255"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	title := strings.TrimSpace(req.Title)
	if title == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Title cannot be empty"})
		return
	}

	ctx := c.Request.Context()

	session, err := h.chatRepo.FindByID(ctx, sessionID)
	if err != nil {
		if err == gorm.ErrRecordNotFound {
			c.JSON(http.StatusNotFound, gin.H{"error": "Chat session not found"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get chat session"})
		return
	}

	session.Title = title
	if err := h.chatRepo.Update(ctx, session); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update title"})
		return
	}

	c.JSON(http.StatusOK, ChatSessionResponse{
		ID:              session.ID,
		TranscriptionID: session.TranscriptionID,
		Title:           session.Title,
		Model:           session.Model,
		Provider:        session.Provider,
		IsActive:        session.IsActive,
		CreatedAt:       session.CreatedAt,
		UpdatedAt:       session.UpdatedAt,
		MessageCount:    session.MessageCount,
		LastActivityAt:  session.LastActivityAt,
	})
}
|
|
|
|
// DeleteChatSession removes a chat session via the chat repository and
// answers 204 with no body on success.
//
// @Summary Delete a chat session
// @Description Delete a chat session and all its messages
// @Tags chat
// @Produce json
// @Param session_id path string true "Chat Session ID"
// @Success 204
// @Failure 404 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /api/v1/chat/sessions/{session_id} [delete]
// @Security ApiKeyAuth
// @Security BearerAuth
func (h *Handler) DeleteChatSession(c *gin.Context) {
	id := c.Param("session_id")
	if id == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Session ID is required"})
		return
	}

	switch err := h.chatRepo.DeleteSession(c.Request.Context(), id); {
	case err == nil:
		c.Status(http.StatusNoContent)
	case err == gorm.ErrRecordNotFound:
		c.JSON(http.StatusNotFound, gin.H{"error": "Chat session not found"})
	default:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete chat session"})
	}
}
|
|
|
|
// generateChatTitle derives a provisional session title from the first user
// message. It trims surrounding whitespace, shortens input longer than 50
// bytes to at most 47 bytes plus "...", turns CR/LF into spaces, and
// collapses runs of spaces into one.
//
// AutoGenerateChatTitle recomputes this value to recognise titles the user
// has not edited. For that reason, ASCII input yields the same result as
// before the UTF-8 fix.
func generateChatTitle(message string) string {
	title := strings.TrimSpace(message)
	if len(title) > 50 {
		// Back up to a rune boundary so a multi-byte character is never cut
		// into invalid UTF-8 (continuation bytes have the form 10xxxxxx).
		cut := 47
		for cut > 0 && title[cut]&0xC0 == 0x80 {
			cut--
		}
		title = title[:cut] + "..."
	}

	// Keep the title on one line.
	title = strings.ReplaceAll(title, "\n", " ")
	title = strings.ReplaceAll(title, "\r", " ")

	// Collapse runs of spaces (two-space needle; a single space would loop forever).
	for strings.Contains(title, "  ") {
		title = strings.ReplaceAll(title, "  ", " ")
	}

	return title
}
|
|
|
|
// AutoGenerateChatTitle asks the configured LLM to summarise the start of a
// conversation into a short title. Titles that look user-edited are left
// alone. Only the default "New Chat Session" title, or one equal to
// generateChatTitle of the first user message, is replaced.
//
// @Summary Auto-generate chat session title
// @Description Uses the configured LLM to summarize the first exchange into a concise title. Only updates if the current title appears default/user-unset.
// @Tags chat
// @Produce json
// @Param session_id path string true "Chat Session ID"
// @Success 200 {object} ChatSessionResponse
// @Failure 400 {object} map[string]string
// @Failure 404 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /api/v1/chat/sessions/{session_id}/title/auto [post]
// @Security ApiKeyAuth
// @Security BearerAuth
func (h *Handler) AutoGenerateChatTitle(c *gin.Context) {
	sessionID := c.Param("session_id")
	if sessionID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Session ID is required"})
		return
	}

	reqCtx := c.Request.Context()

	session, err := h.chatRepo.FindByID(reqCtx, sessionID)
	if err != nil {
		if err == gorm.ErrRecordNotFound {
			c.JSON(http.StatusNotFound, gin.H{"error": "Chat session not found"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get chat session"})
		return
	}

	// Decide whether the title is still "unset": the default title, or the
	// simple title derived from the first user message.
	currentTitle := strings.TrimSpace(session.Title)
	isDefaultTitle := strings.EqualFold(currentTitle, "New Chat Session")
	if msgs, err := h.chatRepo.GetMessages(reqCtx, sessionID, 0); err == nil {
		for _, m := range msgs {
			if m.Role != "user" {
				continue
			}
			if strings.EqualFold(currentTitle, strings.TrimSpace(generateChatTitle(m.Content))) {
				isDefaultTitle = true
			}
			break
		}
	}

	if !isDefaultTitle {
		// Respect user-edited titles; return the session unchanged.
		c.JSON(http.StatusOK, ChatSessionResponse{
			ID:              session.ID,
			TranscriptionID: session.TranscriptionID,
			Title:           session.Title,
			Model:           session.Model,
			Provider:        session.Provider,
			IsActive:        session.IsActive,
			CreatedAt:       session.CreatedAt,
			UpdatedAt:       session.UpdatedAt,
			MessageCount:    session.MessageCount,
			LastActivityAt:  session.LastActivityAt,
		})
		return
	}

	// A handful of messages is enough context to name the conversation.
	// NOTE(review): assumes GetMessages(limit) returns the earliest messages — TODO confirm.
	recentMsgs, err := h.chatRepo.GetMessages(reqCtx, sessionID, 6)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get messages"})
		return
	}
	if len(recentMsgs) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Not enough conversation to generate a title"})
		return
	}

	// clip shortens s to at most limit bytes, ending in "...", without
	// splitting a UTF-8 sequence (continuation bytes have the form 10xxxxxx).
	clip := func(s string, limit int) string {
		if len(s) <= limit {
			return s
		}
		cut := limit - 3
		for cut > 0 && s[cut]&0xC0 == 0x80 {
			cut--
		}
		return s[:cut] + "..."
	}

	prompt := `You are an expert at creating concise, meaningful titles for conversations. Based on the conversation below, generate a short, descriptive title (3-8 words) that captures the main topic or purpose.

Guidelines:
- Use Title Case formatting (Every Important Word Capitalized)
- Be specific and descriptive, not generic
- Focus on the core subject matter or task being discussed
- Avoid generic terms like "Chat", "Discussion", "Question", "Conversation"
- Avoid mentioning AI, Assistant, or model names
- No quotation marks, brackets, or punctuation at the end
- No emojis or special characters
- Make it something a user would easily recognize and remember

Examples of good titles:
- "Python Data Analysis Tutorial"
- "Marketing Strategy Planning Session"
- "JavaScript Debugging Help"
- "Recipe for Chocolate Cake"
- "React Component Architecture"

Return only the title, nothing else.`

	chatMsgs := make([]llm.ChatMessage, 0, len(recentMsgs)+1)
	chatMsgs = append(chatMsgs, llm.ChatMessage{Role: "system", Content: prompt})
	for _, msg := range recentMsgs {
		role := msg.Role
		if role != "user" && role != "assistant" {
			role = "user"
		}
		// Truncate very long messages to keep the request small.
		chatMsgs = append(chatMsgs, llm.ChatMessage{Role: role, Content: clip(msg.Content, 500)})
	}

	svc, _, err := h.getLLMService(reqCtx)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	llmCtx, cancel := context.WithTimeout(reqCtx, 30*time.Second)
	defer cancel()

	// Temperature 0.0: use model defaults, the same convention as SendChatMessage.
	resp, err := svc.ChatCompletion(llmCtx, session.Model, chatMsgs, 0.0)
	if err != nil || resp == nil || len(resp.Choices) == 0 || resp.Choices[0].Message.Content == "" {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate title"})
		return
	}

	title := strings.TrimSpace(resp.Choices[0].Message.Content)
	// Strip wrapping quotes/backticks if the model added them anyway.
	for _, q := range []string{`"`, `'`, "`"} {
		if len(title) >= 2 && strings.HasPrefix(title, q) && strings.HasSuffix(title, q) {
			title = strings.Trim(title, "'\"`")
			break
		}
	}
	// Single line, bounded length, never blank.
	title = strings.ReplaceAll(title, "\n", " ")
	title = strings.ReplaceAll(title, "\r", " ")
	title = clip(strings.TrimSpace(title), 60)
	if title == "" {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate title"})
		return
	}

	// Update only the title column so concurrent bookkeeping updates (message
	// count, activity time) are not overwritten.
	if err := database.DB.WithContext(reqCtx).Model(&session).Update("title", title).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update title"})
		return
	}

	var updated models.ChatSession
	if err := database.DB.WithContext(reqCtx).Where("id = ?", sessionID).First(&updated).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to load updated session"})
		return
	}

	c.JSON(http.StatusOK, ChatSessionResponse{
		ID:              updated.ID,
		TranscriptionID: updated.TranscriptionID,
		Title:           updated.Title,
		Model:           updated.Model,
		Provider:        updated.Provider,
		IsActive:        updated.IsActive,
		CreatedAt:       updated.CreatedAt,
		UpdatedAt:       updated.UpdatedAt,
		MessageCount:    updated.MessageCount,
		LastActivityAt:  updated.LastActivityAt,
	})
}
|