mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-06-30 07:46:16 +00:00
because I messed up
This commit is contained in:
@@ -217,3 +217,87 @@ func (s *OllamaService) ChatCompletionStream(ctx context.Context, model string,
|
||||
|
||||
return contentChan, errorChan
|
||||
}
|
||||
|
||||
// ollamaShowRequest is the JSON payload posted to Ollama's /api/show endpoint
// to ask for information about a single model.
type ollamaShowRequest struct {
	// Name is the model identifier; it is serialized under the "name" key.
	Name string `json:"name"`
}
|
||||
|
||||
// ollamaShowResponse holds the subset of the /api/show reply that is decoded
// when looking up a model's context length.
type ollamaShowResponse struct {
	// ModelInfo is the free-form "model_info" object. Numeric values arrive
	// as float64 because the map is decoded into interface{} values.
	ModelInfo map[string]interface{} `json:"model_info"`

	// Details mirrors the "details" object; only its context length is kept.
	Details struct {
		// ContextLength is only present in some Ollama versions.
		ContextLength int `json:"context_length"`
	} `json:"details"`

	// Parameters is the raw "parameters" text (for example a "num_ctx 8192"
	// line).
	Parameters string `json:"parameters"`
}
|
||||
|
||||
// GetContextWindow returns the context window size for a given Ollama model
|
||||
func (s *OllamaService) GetContextWindow(ctx context.Context, model string) (int, error) {
|
||||
// Default to 4096 if we can't determine
|
||||
defaultContext := 4096
|
||||
|
||||
reqBody := ollamaShowRequest{
|
||||
Name: model,
|
||||
}
|
||||
data, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return defaultContext, nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", s.baseURL+"/api/show", bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return defaultContext, nil
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return defaultContext, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return defaultContext, nil
|
||||
}
|
||||
|
||||
var showResp ollamaShowResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&showResp); err != nil {
|
||||
return defaultContext, nil
|
||||
}
|
||||
|
||||
// Try to find context length in details
|
||||
// Note: Ollama API response format varies.
|
||||
// Sometimes it's in model_info -> llama.context_length
|
||||
// Sometimes it's in parameters string "num_ctx 8192"
|
||||
|
||||
// Check model_info
|
||||
if showResp.ModelInfo != nil {
|
||||
for k, v := range showResp.ModelInfo {
|
||||
if strings.Contains(k, "context_length") {
|
||||
if f, ok := v.(float64); ok {
|
||||
return int(f), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse parameters string
|
||||
if showResp.Parameters != "" {
|
||||
lines := strings.Split(showResp.Parameters, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "num_ctx") {
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) >= 2 {
|
||||
var ctxLen int
|
||||
if _, err := fmt.Sscanf(parts[1], "%d", &ctxLen); err == nil {
|
||||
return ctxLen, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return defaultContext, nil
|
||||
}
|
||||
|
||||
@@ -310,3 +310,24 @@ func (s *OpenAIService) ValidateAPIKey(ctx context.Context) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetContextWindow returns the context window size for a given OpenAI model
|
||||
func (s *OpenAIService) GetContextWindow(ctx context.Context, model string) (int, error) {
|
||||
// Known context windows for OpenAI models
|
||||
// As of late 2024/early 2025
|
||||
switch {
|
||||
case strings.HasPrefix(model, "gpt-4-turbo"), strings.HasPrefix(model, "gpt-4o"):
|
||||
return 128000, nil
|
||||
case strings.HasPrefix(model, "gpt-4-32k"):
|
||||
return 32768, nil
|
||||
case strings.HasPrefix(model, "gpt-4"):
|
||||
return 8192, nil
|
||||
case strings.HasPrefix(model, "gpt-3.5-turbo-16k"):
|
||||
return 16385, nil
|
||||
case strings.HasPrefix(model, "gpt-3.5-turbo"):
|
||||
return 16385, nil // Most recent gpt-3.5-turbo is 16k
|
||||
default:
|
||||
// Default fallback
|
||||
return 4096, nil
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user