mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-06-29 15:26:02 +00:00
Streamline profile provider parameters
This commit is contained in:
@@ -162,24 +162,28 @@ Commit:
|
||||
|
||||
## BE-ENG-PROVIDER-Sprint 5: Profile Parameter Model Cleanup
|
||||
|
||||
Status: planned
|
||||
Status: complete
|
||||
|
||||
Planned tasks:
|
||||
|
||||
- [ ] Make profiles store pipeline step options as descriptor-keyed maps.
|
||||
- [ ] Validate against provider model descriptors.
|
||||
- [ ] Remove active flat `ASRParams` execution usage.
|
||||
- [ ] Return schemas for frontend dynamic ASR controls.
|
||||
- [x] Make profiles store pipeline step options as descriptor-keyed maps.
|
||||
- [x] Validate against provider model descriptors.
|
||||
- [x] Remove active flat `ASRParams` execution usage.
|
||||
- [x] Return schemas for frontend dynamic ASR controls.
|
||||
|
||||
Acceptance checks:
|
||||
|
||||
- [ ] Frontend can render ASR profile controls from descriptors.
|
||||
- [ ] Unsupported options fail validation.
|
||||
- [ ] Adding local model support does not require backend schema edits.
|
||||
- [x] Frontend can render ASR profile controls from descriptors.
|
||||
- [x] Unsupported options fail validation.
|
||||
- [x] Adding local model support does not require backend schema edits.
|
||||
|
||||
Verification:
|
||||
|
||||
- [ ] Not started.
|
||||
- [x] `GOCACHE=/private/tmp/scriberr-go-cache go test ./internal/profile`
|
||||
- [x] `GOCACHE=/private/tmp/scriberr-go-cache go test ./internal/transcription/orchestrator`
|
||||
- [x] `GOCACHE=/private/tmp/scriberr-go-cache go test ./internal/api`
|
||||
- [x] `GOCACHE=/private/tmp/scriberr-go-cache go test ./internal/transcription/...`
|
||||
- [x] `GOCACHE=/private/tmp/scriberr-go-cache go test ./internal/models`
|
||||
|
||||
Commit:
|
||||
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
# BE-ENG-PROVIDER Sprint 5: Profile Parameter Cleanup
|
||||
|
||||
Status: complete.
|
||||
|
||||
## Changes
|
||||
|
||||
- Profile API options are now pipeline-only.
|
||||
- Model/runtime/chunking/decoding values must live under the owning pipeline step's `options` map.
|
||||
- Legacy top-level profile knobs such as `language`, `threads`, `chunking_strategy`, and `decoding_method` are rejected when present.
|
||||
- Profile normalization validates step options against provider model descriptor schemas through the model catalog.
|
||||
- Profile persistence derives display columns from the normalized transcription step instead of copying flat `ASRParams` fields.
|
||||
- The orchestrator no longer fabricates a default execution pipeline from flat job fields.
|
||||
|
||||
## Verification
|
||||
|
||||
- `GOCACHE=/private/tmp/scriberr-go-cache go test ./internal/profile`
|
||||
- `GOCACHE=/private/tmp/scriberr-go-cache go test ./internal/transcription/orchestrator`
|
||||
- `GOCACHE=/private/tmp/scriberr-go-cache go test ./internal/api`
|
||||
- `GOCACHE=/private/tmp/scriberr-go-cache go test ./internal/transcription/...`
|
||||
- `GOCACHE=/private/tmp/scriberr-go-cache go test ./internal/models`
|
||||
|
||||
## Notes
|
||||
|
||||
- `ASRParams` still contains historical flat fields, retained pending a follow-up migration/runtime cleanup, but profile creation, profile responses, and orchestrator execution no longer depend on them.
|
||||
- Frontend profile controls should source field definitions from provider model descriptors and write values into `options.pipeline[].options`.
|
||||
@@ -255,8 +255,16 @@ func TestFileReadyAutoTranscribesWithDefaultProfile(t *testing.T) {
|
||||
"name": "Default profile",
|
||||
"is_default": true,
|
||||
"options": map[string]any{
|
||||
"pipeline": pipelineRequest("transcription", "whisper-small", "diarization", "diarization-default"),
|
||||
"language": "en",
|
||||
"pipeline": []map[string]any{
|
||||
{
|
||||
"kind": "transcription",
|
||||
"model": "whisper-small",
|
||||
"options": map[string]any{
|
||||
"sherpa.whisper.language": "en",
|
||||
},
|
||||
},
|
||||
{"kind": "diarization", "model": "diarization-default"},
|
||||
},
|
||||
},
|
||||
}, token, "")
|
||||
require.Equal(t, http.StatusCreated, resp.Code)
|
||||
|
||||
@@ -147,94 +147,40 @@ func validateProfileInput(c *gin.Context, name string, options profileOptionsReq
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "pipeline is required", stringPtr("options.pipeline"))
|
||||
return false
|
||||
}
|
||||
if options.Language != nil && strings.TrimSpace(*options.Language) != "" && !validLanguage(strings.TrimSpace(*options.Language)) {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "language is invalid", stringPtr("options.language"))
|
||||
return false
|
||||
}
|
||||
if task := strings.TrimSpace(options.Task); task != "" && task != "transcribe" && task != "translate" {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "task is invalid", stringPtr("options.task"))
|
||||
return false
|
||||
}
|
||||
if method := strings.TrimSpace(options.DecodingMethod); method != "" && method != "greedy_search" && method != "modified_beam_search" {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "decoding method is invalid", stringPtr("options.decoding_method"))
|
||||
return false
|
||||
}
|
||||
if chunking := strings.ToLower(strings.TrimSpace(options.ChunkingStrategy)); chunking != "" && chunking != "fixed" && chunking != "vad" {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "chunking strategy is invalid", stringPtr("options.chunking_strategy"))
|
||||
return false
|
||||
}
|
||||
if options.Threads < 0 {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "threads must be zero or greater", stringPtr("options.threads"))
|
||||
return false
|
||||
}
|
||||
if options.TailPaddings != nil && (*options.TailPaddings < -1 || *options.TailPaddings > 16) {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "tail paddings is invalid", stringPtr("options.tail_paddings"))
|
||||
return false
|
||||
}
|
||||
if options.NumSpeakers < 0 {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "number of speakers must be zero or greater", stringPtr("options.num_speakers"))
|
||||
return false
|
||||
}
|
||||
if options.DiarizationThreshold < 0 || options.DiarizationThreshold > 1 {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "diarization threshold is invalid", stringPtr("options.diarization_threshold"))
|
||||
return false
|
||||
}
|
||||
if options.MinDurationOn < 0 || options.MinDurationOn > 2 {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "minimum speech duration is invalid", stringPtr("options.min_duration_on"))
|
||||
return false
|
||||
}
|
||||
if options.MinDurationOff < 0 || options.MinDurationOff > 2 {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "minimum silence duration is invalid", stringPtr("options.min_duration_off"))
|
||||
if field, ok := legacyProfileOptionField(options); ok {
|
||||
writeError(c, http.StatusUnprocessableEntity, "VALIDATION_ERROR", "profile option must be configured on the owning pipeline step", stringPtr(field))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
func profileParams(options profileOptionsRequest) models.ASRParams {
|
||||
task := strings.TrimSpace(options.Task)
|
||||
if task == "" {
|
||||
task = "transcribe"
|
||||
}
|
||||
decodingMethod := strings.TrimSpace(options.DecodingMethod)
|
||||
if decodingMethod == "" {
|
||||
decodingMethod = "greedy_search"
|
||||
}
|
||||
chunkingStrategy := strings.ToLower(strings.TrimSpace(options.ChunkingStrategy))
|
||||
if chunkingStrategy == "" {
|
||||
chunkingStrategy = "fixed"
|
||||
}
|
||||
var language *string
|
||||
if options.Language != nil {
|
||||
trimmed := strings.TrimSpace(*options.Language)
|
||||
if trimmed != "" && trimmed != "auto" {
|
||||
language = &trimmed
|
||||
}
|
||||
}
|
||||
diarizationThreshold := options.DiarizationThreshold
|
||||
if diarizationThreshold == 0 {
|
||||
diarizationThreshold = 0.5
|
||||
}
|
||||
minDurationOn := options.MinDurationOn
|
||||
if minDurationOn == 0 {
|
||||
minDurationOn = 0.2
|
||||
}
|
||||
minDurationOff := options.MinDurationOff
|
||||
if minDurationOff == 0 {
|
||||
minDurationOff = 0.3
|
||||
}
|
||||
return models.ASRParams{
|
||||
Pipeline: options.Pipeline,
|
||||
Language: language,
|
||||
Task: task,
|
||||
Threads: options.Threads,
|
||||
TailPaddings: options.TailPaddings,
|
||||
EnableTokenTimestamps: boolPtr(true),
|
||||
EnableSegmentTimestamps: boolPtr(true),
|
||||
DecodingMethod: decodingMethod,
|
||||
ChunkingStrategy: chunkingStrategy,
|
||||
NumSpeakers: options.NumSpeakers,
|
||||
DiarizationThreshold: diarizationThreshold,
|
||||
MinDurationOn: minDurationOn,
|
||||
MinDurationOff: minDurationOff,
|
||||
return models.ASRParams{Pipeline: options.Pipeline}
|
||||
}
|
||||
|
||||
// legacyProfileOptionField reports the first legacy top-level profile option
// that is set on the request. It returns the option's field path (for example
// "options.language") and true, or "" and false when none is set.
//
// Pointer-typed options count as set when they are non-nil, even if they point
// at a zero value. String-typed options count as set only when they are
// non-empty after trimming whitespace. The cases are checked in declaration
// order, so only the first offending field is reported.
func legacyProfileOptionField(options profileOptionsRequest) (string, bool) {
|
||||
switch {
|
||||
case options.Language != nil:
|
||||
return "options.language", true
|
||||
case strings.TrimSpace(options.Task) != "":
|
||||
return "options.task", true
|
||||
case options.Threads != nil:
|
||||
return "options.threads", true
|
||||
case options.TailPaddings != nil:
|
||||
return "options.tail_paddings", true
|
||||
case strings.TrimSpace(options.DecodingMethod) != "":
|
||||
return "options.decoding_method", true
|
||||
case strings.TrimSpace(options.ChunkingStrategy) != "":
|
||||
return "options.chunking_strategy", true
|
||||
case options.NumSpeakers != nil:
|
||||
return "options.num_speakers", true
|
||||
case options.DiarizationThreshold != nil:
|
||||
return "options.diarization_threshold", true
|
||||
case options.MinDurationOn != nil:
|
||||
return "options.min_duration_on", true
|
||||
case options.MinDurationOff != nil:
|
||||
return "options.min_duration_off", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
func (h *Handler) profileByPublicID(c *gin.Context, publicID string) (*models.TranscriptionProfile, bool) {
|
||||
|
||||
@@ -32,12 +32,15 @@ func TestProfileCRUDAndDefaultSelection(t *testing.T) {
|
||||
"description": "Fast local transcription",
|
||||
"is_default": true,
|
||||
"options": map[string]any{
|
||||
"pipeline": pipelineRequest("transcription", "whisper-base"),
|
||||
"language": "en",
|
||||
"chunking_strategy": "vad",
|
||||
"threads": 2,
|
||||
"enable_token_timestamps": false,
|
||||
"enable_segment_timestamps": false,
|
||||
"pipeline": []map[string]any{{
|
||||
"kind": "transcription",
|
||||
"model": "whisper-base",
|
||||
"options": map[string]any{
|
||||
"sherpa.whisper.language": "en",
|
||||
"chunking.mode": "vad",
|
||||
"runtime.num_threads": 2,
|
||||
},
|
||||
}},
|
||||
},
|
||||
}, token, "")
|
||||
require.Equal(t, http.StatusCreated, resp.Code)
|
||||
@@ -46,27 +49,26 @@ func TestProfileCRUDAndDefaultSelection(t *testing.T) {
|
||||
require.Equal(t, true, body["is_default"])
|
||||
require.Equal(t, "Fast local", body["name"])
|
||||
options := body["options"].(map[string]any)
|
||||
require.Equal(t, "greedy_search", options["decoding_method"])
|
||||
require.Equal(t, "vad", options["chunking_strategy"])
|
||||
require.Equal(t, float64(0.5), options["diarization_threshold"])
|
||||
require.Equal(t, float64(0.2), options["min_duration_on"])
|
||||
require.Equal(t, float64(0.3), options["min_duration_off"])
|
||||
pipeline := options["pipeline"].([]any)
|
||||
require.Len(t, pipeline, 1)
|
||||
require.Equal(t, "transcription", pipeline[0].(map[string]any)["kind"])
|
||||
require.Equal(t, "whisper-base", pipeline[0].(map[string]any)["model"])
|
||||
require.NotContains(t, options, "enable_token_timestamps")
|
||||
require.NotContains(t, options, "enable_segment_timestamps")
|
||||
step := pipeline[0].(map[string]any)
|
||||
require.Equal(t, "transcription", step["kind"])
|
||||
require.Equal(t, "whisper-base", step["model"])
|
||||
require.Equal(t, "whisper", step["model_family"])
|
||||
stepOptions := step["options"].(map[string]any)
|
||||
require.Equal(t, "en", stepOptions["sherpa.whisper.language"])
|
||||
require.Equal(t, "vad", stepOptions["chunking.mode"])
|
||||
require.Equal(t, float64(2), stepOptions["runtime.num_threads"])
|
||||
require.NotContains(t, options, "decoding_method")
|
||||
require.NotContains(t, options, "chunking_strategy")
|
||||
|
||||
var storedProfile models.TranscriptionProfile
|
||||
require.NoError(t, database.DB.First(&storedProfile, "id = ?", strings.TrimPrefix(firstID, "profile_")).Error)
|
||||
require.NotNil(t, storedProfile.Parameters.EnableTokenTimestamps)
|
||||
require.True(t, *storedProfile.Parameters.EnableTokenTimestamps)
|
||||
require.NotNil(t, storedProfile.Parameters.EnableSegmentTimestamps)
|
||||
require.True(t, *storedProfile.Parameters.EnableSegmentTimestamps)
|
||||
require.Equal(t, "vad", storedProfile.Parameters.ChunkingStrategy)
|
||||
require.Empty(t, storedProfile.Parameters.Model)
|
||||
require.Empty(t, storedProfile.Parameters.ChunkingStrategy)
|
||||
require.Len(t, storedProfile.Parameters.Pipeline, 1)
|
||||
require.Equal(t, models.ASRStepTranscription, storedProfile.Parameters.Pipeline[0].Kind)
|
||||
require.Equal(t, "vad", storedProfile.Parameters.Pipeline[0].Options["chunking.mode"])
|
||||
|
||||
resp, body = s.request(t, http.MethodGet, "/api/v1/settings", nil, token, "")
|
||||
require.Equal(t, http.StatusOK, resp.Code)
|
||||
@@ -77,7 +79,6 @@ func TestProfileCRUDAndDefaultSelection(t *testing.T) {
|
||||
"is_default": true,
|
||||
"options": map[string]any{
|
||||
"pipeline": pipelineRequest("transcription", "whisper-small", "diarization", "diarization-default"),
|
||||
"language": "en",
|
||||
},
|
||||
}, token, "")
|
||||
require.Equal(t, http.StatusCreated, resp.Code)
|
||||
@@ -108,8 +109,6 @@ func TestProfileCRUDAndDefaultSelection(t *testing.T) {
|
||||
"description": "Updated",
|
||||
"options": map[string]any{
|
||||
"pipeline": pipelineRequest("transcription", "parakeet-v2", "diarization", "diarization-default"),
|
||||
"language": "fr",
|
||||
"threads": 4,
|
||||
},
|
||||
}, token, "")
|
||||
require.Equal(t, http.StatusOK, resp.Code)
|
||||
@@ -182,7 +181,7 @@ func TestProfileValidationAndAuth(t *testing.T) {
|
||||
require.Equal(t, http.StatusNotFound, resp.Code)
|
||||
}
|
||||
|
||||
func TestWhisperProfileForcesGreedyDecoding(t *testing.T) {
|
||||
func TestProfileRejectsLegacyDecodingOption(t *testing.T) {
|
||||
s := newAuthTestServer(t)
|
||||
token := registerForFileTests(t, s)
|
||||
|
||||
@@ -193,9 +192,9 @@ func TestWhisperProfileForcesGreedyDecoding(t *testing.T) {
|
||||
"decoding_method": "modified_beam_search",
|
||||
},
|
||||
}, token, "")
|
||||
require.Equal(t, http.StatusCreated, resp.Code)
|
||||
options := body["options"].(map[string]any)
|
||||
require.Equal(t, "greedy_search", options["decoding_method"])
|
||||
require.Equal(t, http.StatusUnprocessableEntity, resp.Code)
|
||||
errBody := body["error"].(map[string]any)
|
||||
require.Equal(t, "options.decoding_method", errBody["field"])
|
||||
}
|
||||
|
||||
func TestGetProfileDoesNotPublishUpdateEvent(t *testing.T) {
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -343,14 +342,7 @@ func profileResponse(profile *models.TranscriptionProfile) ProfileResponse {
|
||||
}
|
||||
}
|
||||
func profileOptionsMap(params models.ASRParams) gin.H {
|
||||
params.EnableTokenTimestamps = nil
|
||||
params.EnableSegmentTimestamps = nil
|
||||
var options gin.H
|
||||
bytes, err := json.Marshal(params)
|
||||
if err != nil || json.Unmarshal(bytes, &options) != nil {
|
||||
options = gin.H{}
|
||||
}
|
||||
return options
|
||||
return gin.H{"pipeline": params.Pipeline}
|
||||
}
|
||||
func settingsResponse(h *Handler, user *models.User) SettingsResponse {
|
||||
defaultProfileID := any(nil)
|
||||
|
||||
@@ -121,10 +121,18 @@ func TestTranscriptionCreateAppliesDefaultAndSelectedProfiles(t *testing.T) {
|
||||
"name": "Default profile",
|
||||
"is_default": true,
|
||||
"options": map[string]any{
|
||||
"pipeline": pipelineRequest("transcription", "whisper-small", "diarization", "diarization-default"),
|
||||
"language": "fr",
|
||||
"chunking_strategy": "vad",
|
||||
"threads": 2,
|
||||
"pipeline": []map[string]any{
|
||||
{
|
||||
"kind": "transcription",
|
||||
"model": "whisper-small",
|
||||
"options": map[string]any{
|
||||
"sherpa.whisper.language": "fr",
|
||||
"chunking.mode": "vad",
|
||||
"runtime.num_threads": 2,
|
||||
},
|
||||
},
|
||||
{"kind": "diarization", "model": "diarization-default"},
|
||||
},
|
||||
},
|
||||
}, token, "")
|
||||
require.Equal(t, http.StatusCreated, resp.Code)
|
||||
@@ -138,19 +146,18 @@ func TestTranscriptionCreateAppliesDefaultAndSelectedProfiles(t *testing.T) {
|
||||
|
||||
var defaultJob models.TranscriptionJob
|
||||
require.NoError(t, database.DB.First(&defaultJob, "id = ?", defaultJobID).Error)
|
||||
require.Equal(t, "whisper-small", defaultJob.Parameters.Model)
|
||||
require.Equal(t, 2, defaultJob.Parameters.Threads)
|
||||
require.NotNil(t, defaultJob.Parameters.Language)
|
||||
require.Equal(t, "fr", *defaultJob.Parameters.Language)
|
||||
require.Equal(t, "vad", defaultJob.Parameters.ChunkingStrategy)
|
||||
require.Empty(t, defaultJob.Parameters.Model)
|
||||
require.Len(t, defaultJob.Parameters.Pipeline, 2)
|
||||
require.Equal(t, "whisper-small", defaultJob.Parameters.Pipeline[0].Model)
|
||||
require.Equal(t, "fr", defaultJob.Parameters.Pipeline[0].Options["sherpa.whisper.language"])
|
||||
require.Equal(t, "vad", defaultJob.Parameters.Pipeline[0].Options["chunking.mode"])
|
||||
require.Equal(t, float64(2), defaultJob.Parameters.Pipeline[0].Options["runtime.num_threads"])
|
||||
require.True(t, defaultJob.Diarization)
|
||||
|
||||
resp, body = s.request(t, http.MethodPost, "/api/v1/profiles", map[string]any{
|
||||
"name": "Selected profile",
|
||||
"options": map[string]any{
|
||||
"pipeline": pipelineRequest("transcription", "parakeet-v2", "diarization", "diarization-default"),
|
||||
"language": "es",
|
||||
"threads": 4,
|
||||
},
|
||||
}, token, "")
|
||||
require.Equal(t, http.StatusCreated, resp.Code)
|
||||
@@ -171,8 +178,9 @@ func TestTranscriptionCreateAppliesDefaultAndSelectedProfiles(t *testing.T) {
|
||||
|
||||
var selectedJob models.TranscriptionJob
|
||||
require.NoError(t, database.DB.First(&selectedJob, "id = ?", selectedJobID).Error)
|
||||
require.Equal(t, "parakeet-v2", selectedJob.Parameters.Model)
|
||||
require.Equal(t, 4, selectedJob.Parameters.Threads)
|
||||
require.Empty(t, selectedJob.Parameters.Model)
|
||||
require.Len(t, selectedJob.Parameters.Pipeline, 1)
|
||||
require.Equal(t, "parakeet-v2", selectedJob.Parameters.Pipeline[0].Model)
|
||||
require.NotNil(t, selectedJob.Parameters.Language)
|
||||
require.Equal(t, "en", *selectedJob.Parameters.Language)
|
||||
require.False(t, selectedJob.Diarization)
|
||||
|
||||
@@ -76,15 +76,15 @@ type updateTranscriptionRequest struct {
|
||||
type profileOptionsRequest struct {
|
||||
Pipeline []models.ASRStep `json:"pipeline,omitempty"`
|
||||
Language *string `json:"language,omitempty"`
|
||||
Task string `json:"task"`
|
||||
Threads int `json:"threads"`
|
||||
Task string `json:"task,omitempty"`
|
||||
Threads *int `json:"threads,omitempty"`
|
||||
TailPaddings *int `json:"tail_paddings,omitempty"`
|
||||
DecodingMethod string `json:"decoding_method"`
|
||||
ChunkingStrategy string `json:"chunking_strategy"`
|
||||
NumSpeakers int `json:"num_speakers"`
|
||||
DiarizationThreshold float64 `json:"diarization_threshold"`
|
||||
MinDurationOn float64 `json:"min_duration_on"`
|
||||
MinDurationOff float64 `json:"min_duration_off"`
|
||||
DecodingMethod string `json:"decoding_method,omitempty"`
|
||||
ChunkingStrategy string `json:"chunking_strategy,omitempty"`
|
||||
NumSpeakers *int `json:"num_speakers,omitempty"`
|
||||
DiarizationThreshold *float64 `json:"diarization_threshold,omitempty"`
|
||||
MinDurationOn *float64 `json:"min_duration_on,omitempty"`
|
||||
MinDurationOff *float64 `json:"min_duration_off,omitempty"`
|
||||
}
|
||||
type createProfileRequest struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -627,8 +627,11 @@ func TestSchemaUpgradeRunsVersionedBackfill(t *testing.T) {
|
||||
Name: "profile-a",
|
||||
IsDefault: true,
|
||||
Parameters: models.ASRParams{
|
||||
Model: "medium",
|
||||
ModelFamily: "whisper",
|
||||
Pipeline: []models.ASRStep{{
|
||||
Kind: models.ASRStepTranscription,
|
||||
Model: "medium",
|
||||
ModelFamily: "whisper",
|
||||
}},
|
||||
},
|
||||
CreatedAt: base,
|
||||
UpdatedAt: base,
|
||||
@@ -639,8 +642,11 @@ func TestSchemaUpgradeRunsVersionedBackfill(t *testing.T) {
|
||||
Name: "profile-b",
|
||||
IsDefault: true,
|
||||
Parameters: models.ASRParams{
|
||||
Model: "large-v3",
|
||||
ModelFamily: "whisper",
|
||||
Pipeline: []models.ASRStep{{
|
||||
Kind: models.ASRStepTranscription,
|
||||
Model: "large-v3",
|
||||
ModelFamily: "whisper",
|
||||
}},
|
||||
},
|
||||
CreatedAt: base,
|
||||
UpdatedAt: base,
|
||||
@@ -660,8 +666,8 @@ func TestSchemaUpgradeRunsVersionedBackfill(t *testing.T) {
|
||||
require.NoError(t, db.First(&reloadedB, "id = ?", profileB.ID).Error)
|
||||
assert.True(t, reloadedA.IsDefault)
|
||||
assert.True(t, reloadedB.IsDefault)
|
||||
assert.Equal(t, "medium", reloadedA.Parameters.Model)
|
||||
assert.Equal(t, "large-v3", reloadedB.Parameters.Model)
|
||||
assert.Empty(t, reloadedA.Parameters.Pipeline)
|
||||
assert.Empty(t, reloadedB.Parameters.Pipeline)
|
||||
}
|
||||
|
||||
func TestSchemaUpgradeBackfillsMissingUserSettingsAndListIndex(t *testing.T) {
|
||||
|
||||
@@ -396,9 +396,16 @@ func (tp *TranscriptionProfile) BeforeSave(tx *gorm.DB) error {
|
||||
if err := requireUserID("transcription profile", tp.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
tp.ModelName = tp.Parameters.Model
|
||||
tp.ModelFamily = tp.Parameters.ModelFamily
|
||||
tp.Language = tp.Parameters.Language
|
||||
if step, ok := firstASRStep(tp.Parameters.Pipeline, ASRStepTranscription); ok {
|
||||
tp.Provider = step.Provider
|
||||
tp.ModelName = step.Model
|
||||
tp.ModelFamily = step.ModelFamily
|
||||
} else {
|
||||
tp.Provider = ""
|
||||
tp.ModelName = ""
|
||||
tp.ModelFamily = ""
|
||||
}
|
||||
tp.Language = nil
|
||||
tp.DiarizationEnabled = hasASRStep(tp.Parameters.Pipeline, ASRStepDiarization)
|
||||
configJSON, err := marshalJSONColumn("transcription_profiles.config_json", tp.Parameters)
|
||||
if err != nil {
|
||||
@@ -419,15 +426,6 @@ func (tp *TranscriptionProfile) AfterFind(tx *gorm.DB) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if tp.Parameters.Model == "" {
|
||||
tp.Parameters.Model = tp.ModelName
|
||||
}
|
||||
if tp.Parameters.ModelFamily == "" {
|
||||
tp.Parameters.ModelFamily = tp.ModelFamily
|
||||
}
|
||||
if tp.Parameters.Language == nil {
|
||||
tp.Parameters.Language = tp.Language
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -440,6 +438,15 @@ func hasASRStep(steps []ASRStep, kind string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// firstASRStep returns the first step in steps whose Kind equals kind, along
// with true. When no step matches (including a nil or empty slice) it returns
// the zero ASRStep and false.
func firstASRStep(steps []ASRStep, kind string) (ASRStep, bool) {
|
||||
for _, step := range steps {
|
||||
if step.Kind == kind {
|
||||
return step, true
|
||||
}
|
||||
}
|
||||
return ASRStep{}, false
|
||||
}
|
||||
|
||||
// LLMConfig represents a saved LLM profile.
|
||||
type LLMConfig struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
|
||||
@@ -101,17 +101,7 @@ func (s *Service) normalizeProfile(ctx context.Context, profile *models.Transcri
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transcription := pipeline[0]
|
||||
info, err := s.catalog.ResolveTranscriptionModel(ctx, transcription.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.Pipeline = pipeline
|
||||
params.Model = info.ID
|
||||
params.ModelFamily = info.Family
|
||||
if info.Family == "whisper" {
|
||||
params.DecodingMethod = "greedy_search"
|
||||
}
|
||||
profile.Parameters = params
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -104,8 +104,8 @@ func TestServiceCreateNormalizesProfileModelFromCatalog(t *testing.T) {
|
||||
if repo.created == nil {
|
||||
t.Fatal("profile was not created")
|
||||
}
|
||||
if repo.created.Parameters.Model != "parakeet-v2" || repo.created.Parameters.ModelFamily != "nemo_transducer" {
|
||||
t.Fatalf("profile parameters were not normalized: %#v", repo.created.Parameters)
|
||||
if repo.created.Parameters.Model != "" || repo.created.Parameters.ModelFamily != "" {
|
||||
t.Fatalf("flat profile parameters should remain empty: %#v", repo.created.Parameters)
|
||||
}
|
||||
if len(repo.created.Parameters.Pipeline) != 1 {
|
||||
t.Fatalf("profile pipeline length = %d, want 1", len(repo.created.Parameters.Pipeline))
|
||||
|
||||
@@ -307,25 +307,7 @@ func pipelineStepsForJob(job *models.TranscriptionJob) []models.ASRStep {
|
||||
if job == nil {
|
||||
return nil
|
||||
}
|
||||
if len(job.Parameters.Pipeline) > 0 {
|
||||
return job.Parameters.Pipeline
|
||||
}
|
||||
return []models.ASRStep{{
|
||||
Kind: models.ASRStepTranscription,
|
||||
Provider: providerFromJob(job),
|
||||
Model: engineprovider.DefaultTranscriptionModel,
|
||||
ModelFamily: "whisper",
|
||||
}}
|
||||
}
|
||||
|
||||
func providerFromJob(job *models.TranscriptionJob) string {
|
||||
if job == nil {
|
||||
return ""
|
||||
}
|
||||
if job.EngineID != nil && strings.TrimSpace(*job.EngineID) != "" {
|
||||
return strings.TrimSpace(*job.EngineID)
|
||||
}
|
||||
return strings.TrimSpace(job.Parameters.Provider)
|
||||
return job.Parameters.Pipeline
|
||||
}
|
||||
|
||||
func firstStepByKind(steps []resolvedASRStep, kind string) (resolvedASRStep, bool) {
|
||||
|
||||
@@ -137,6 +137,10 @@ func createOrchestratorJob(t *testing.T, db *gorm.DB, audioPath string, params m
|
||||
return job
|
||||
}
|
||||
|
||||
// transcriptionOnlyParams returns ASRParams whose pipeline contains a single
// transcription step using the "whisper-base" model and no other steps.
func transcriptionOnlyParams() models.ASRParams {
|
||||
return models.ASRParams{Pipeline: []models.ASRStep{{Kind: models.ASRStepTranscription, Model: "whisper-base"}}}
|
||||
}
|
||||
|
||||
func testHasASRStep(steps []models.ASRStep, kind string) bool {
|
||||
for _, step := range steps {
|
||||
if step.Kind == kind {
|
||||
@@ -508,7 +512,7 @@ func TestProcessorPersistsProviderProgress(t *testing.T) {
|
||||
db := openOrchestratorTestDB(t)
|
||||
audioPath := filepath.Join(t.TempDir(), "audio.wav")
|
||||
require.NoError(t, os.WriteFile(audioPath, []byte("fake wav"), 0o600))
|
||||
job := createOrchestratorJob(t, db, audioPath, models.ASRParams{})
|
||||
job := createOrchestratorJob(t, db, audioPath, transcriptionOnlyParams())
|
||||
progress := 0.31
|
||||
provider := &fakeProvider{
|
||||
id: "local",
|
||||
@@ -685,7 +689,7 @@ func TestProcessorReturnsSanitizedFailure(t *testing.T) {
|
||||
db := openOrchestratorTestDB(t)
|
||||
audioPath := filepath.Join(t.TempDir(), "audio.wav")
|
||||
require.NoError(t, os.WriteFile(audioPath, []byte("fake wav"), 0o600))
|
||||
job := createOrchestratorJob(t, db, audioPath, models.ASRParams{})
|
||||
job := createOrchestratorJob(t, db, audioPath, transcriptionOnlyParams())
|
||||
provider := &fakeProvider{
|
||||
id: "local",
|
||||
transErr: errors.New("open /tmp/private/model.bin failed api_key=secret-value"),
|
||||
|
||||
Reference in New Issue
Block a user