From 08aaaedb13dc5e5f6fa0c93ce2c1dfe64f8dfaec Mon Sep 17 00:00:00 2001 From: Daniel <845765@qq.com> Date: Sat, 30 May 2026 21:03:38 +0800 Subject: [PATCH] :sparkles: Semantic search using AI embeddings https://github.com/siyuan-note/siyuan/issues/17788 Signed-off-by: Daniel <845765@qq.com> --- kernel/api/router.go | 1 + kernel/api/search.go | 24 ++++ kernel/cli/cmd/serve.go | 1 + kernel/conf/ai.go | 31 ++-- kernel/go.mod | 3 +- kernel/go.sum | 4 +- kernel/mobile/kernel.go | 1 + kernel/model/conf.go | 8 ++ kernel/model/embedding.go | 295 ++++++++++++++++++++++++++++++++++++++ kernel/sql/database.go | 31 ++++ kernel/sql/queue.go | 31 ++++ kernel/util/openai.go | 34 +++++ 12 files changed, 450 insertions(+), 14 deletions(-) create mode 100644 kernel/model/embedding.go diff --git a/kernel/api/router.go b/kernel/api/router.go index 4f32d2646..3f73b8bd8 100644 --- a/kernel/api/router.go +++ b/kernel/api/router.go @@ -201,6 +201,7 @@ func ServeAPI(ginServer *gin.Engine) { ginServer.Handle("POST", "/api/search/fullTextSearchAssetContent", model.CheckAuth, fullTextSearchAssetContent) ginServer.Handle("POST", "/api/search/getAssetContent", model.CheckAuth, getAssetContent) ginServer.Handle("POST", "/api/search/listInvalidBlockRefs", model.CheckAuth, listInvalidBlockRefs) + ginServer.Handle("POST", "/api/search/semanticSearchBlock", model.CheckAuth, semanticSearchBlock) ginServer.Handle("POST", "/api/block/getBlockInfo", model.CheckAuth, getBlockInfo) ginServer.Handle("POST", "/api/block/getBlockDOM", model.CheckAuth, model.CheckAdminRole, getBlockDOM) diff --git a/kernel/api/search.go b/kernel/api/search.go index 0277c141f..c730ea910 100644 --- a/kernel/api/search.go +++ b/kernel/api/search.go @@ -540,3 +540,27 @@ func parseSearchAssetContentArgs(arg map[string]any) (page, pageSize int, query } return } + +func semanticSearchBlock(c *gin.Context) { + ret := gulu.Ret.NewResult() + defer c.JSON(http.StatusOK, ret) + + arg, ok := util.JsonArg(c, ret) + if !ok { + return + } + + page, pageSize, query, _, boxes, types, subTypes, _, _, _ := parseSearchBlockArgs(arg) + + blocks, matchedBlockCount, matchedRootCount, pageCount := model.SemanticSearchBlock(query, boxes, nil, types, subTypes, page, pageSize) + if model.IsReadOnlyRoleContext(c) { + publishAccess := model.GetPublishAccess() + blocks = model.FilterBlocksByPublishAccess(c, publishAccess, blocks) + } + ret.Data = map[string]any{ + "blocks": blocks, + "matchedBlockCount": matchedBlockCount, + "matchedRootCount": matchedRootCount, + "pageCount": pageCount, + } +} diff --git a/kernel/cli/cmd/serve.go b/kernel/cli/cmd/serve.go index ed6f3e9ad..c28f772a4 100644 --- a/kernel/cli/cmd/serve.go +++ b/kernel/cli/cmd/serve.go @@ -87,6 +87,7 @@ var serveCmd = &cobra.Command{ go cache.LoadAssets() go util.CheckFileSysStatus() go plugin.InitManager() + go model.StartEmbeddingIndexer() model.WatchAssets() model.WatchEmojis() diff --git a/kernel/conf/ai.go b/kernel/conf/ai.go index 9ca4df1b6..26ba72015 100644 --- a/kernel/conf/ai.go +++ b/kernel/conf/ai.go @@ -30,17 +30,20 @@ type AI struct { } type OpenAI struct { - APIKey string `json:"apiKey"` - APITimeout int `json:"apiTimeout"` - APIProxy string `json:"apiProxy"` - APIModel string `json:"apiModel"` - APIMaxTokens int `json:"apiMaxTokens"` - APITemperature float64 `json:"apiTemperature"` - APIMaxContexts int `json:"apiMaxContexts"` - APIBaseURL string `json:"apiBaseURL"` - APIUserAgent string `json:"apiUserAgent"` - APIProvider string `json:"apiProvider"` // OpenAI, Azure - APIVersion string `json:"apiVersion"` // Azure API version + APIKey string `json:"apiKey"` + APITimeout int `json:"apiTimeout"` + APIProxy string `json:"apiProxy"` + APIModel string `json:"apiModel"` + APIMaxTokens int `json:"apiMaxTokens"` + APITemperature float64 `json:"apiTemperature"` + APIMaxContexts int `json:"apiMaxContexts"` + APIBaseURL string `json:"apiBaseURL"` + APIUserAgent string `json:"apiUserAgent"` + APIProvider string `json:"apiProvider"` // OpenAI, Azure + APIVersion string `json:"apiVersion"` // Azure API version + EmbeddingModel string `json:"embeddingModel"` + EmbeddingBaseURL string `json:"embeddingBaseURL"` + EmbeddingAPIKey string `json:"embeddingAPIKey"` } func NewAI() *AI { @@ -95,5 +98,11 @@ func NewAI() *AI { if userAgent := os.Getenv("SIYUAN_OPENAI_API_USER_AGENT"); "" != userAgent { openAI.APIUserAgent = userAgent } + if embeddingBaseURL := os.Getenv("SIYUAN_OPENAI_EMBEDDING_BASE_URL"); "" != embeddingBaseURL { + openAI.EmbeddingBaseURL = embeddingBaseURL + } + if embeddingAPIKey := os.Getenv("SIYUAN_OPENAI_EMBEDDING_API_KEY"); "" != embeddingAPIKey { + openAI.EmbeddingAPIKey = embeddingAPIKey + } return &AI{OpenAI: openAI} } diff --git a/kernel/go.mod b/kernel/go.mod index 28260ce14..ecb8d2685 100644 --- a/kernel/go.mod +++ b/kernel/go.mod @@ -69,7 +69,7 @@ require ( github.com/siyuan-note/dataparser v0.0.0-20260115084335-b57cb8bc7c17 github.com/siyuan-note/dejavu v0.0.0-20260529092727-2b5e57a676af github.com/siyuan-note/encryption v0.0.0-20251120032857-3ddc3c2cc49f - github.com/siyuan-note/eventbus v0.0.0-20240627125516-396fdb0f0f97 + github.com/siyuan-note/eventbus v0.0.0-20260530125927-d77c74260dce github.com/siyuan-note/filelock v0.0.0-20260411141728-bf44452627c0 github.com/siyuan-note/httpclient v0.0.0-20260529092300-aa977bebbd71 github.com/siyuan-note/logging v0.0.0-20260513050044-06b8e04d5490 @@ -227,3 +227,4 @@ replace github.com/pdfcpu/pdfcpu => github.com/88250/pdfcpu v0.3.14-0.2025042412 //replace github.com/mattn/go-sqlite3 => D:\88250\go-sqlite3 //replace github.com/88250/epub => D:\88250\epub //replace github.com/siyuan-note/logging => D:\88250\logging +//replace github.com/siyuan-note/eventbus => D:\88250\eventbus diff --git a/kernel/go.sum b/kernel/go.sum index 2952045ed..2b3855f65 100644 --- a/kernel/go.sum +++ b/kernel/go.sum @@ -387,8 +387,8 @@ github.com/siyuan-note/dejavu v0.0.0-20260529092727-2b5e57a676af h1:UeK/zRYGFlBa github.com/siyuan-note/dejavu v0.0.0-20260529092727-2b5e57a676af/go.mod h1:Wal8Y/RgYRINa0jyHdN3XqPUh0GGEcr3QfNVcK9Ci2w= github.com/siyuan-note/encryption v0.0.0-20251120032857-3ddc3c2cc49f h1:HSgJKIAMgokJDAvBBfRj47SzRSm6mNGssY0Wv7rcEtg= github.com/siyuan-note/encryption v0.0.0-20251120032857-3ddc3c2cc49f/go.mod h1:JE3S9VuJqTggyfhjesNDuqvqrRvwG3IctFjXXchLx1M= -github.com/siyuan-note/eventbus v0.0.0-20240627125516-396fdb0f0f97 h1:lM5v8BfNtbOL5jYwhCdMYBcYtr06IYBKjjSLAPMKTM8= -github.com/siyuan-note/eventbus v0.0.0-20240627125516-396fdb0f0f97/go.mod h1:1/nGgthl89FPA7GzAcEWKl6zRRnfgyTjzLZj9bW7kuw= +github.com/siyuan-note/eventbus v0.0.0-20260530125927-d77c74260dce h1:356bfC3pucxXuYpZIwaQ3viVCbq5RbRuvBhb4xtoBR0= +github.com/siyuan-note/eventbus v0.0.0-20260530125927-d77c74260dce/go.mod h1:1/nGgthl89FPA7GzAcEWKl6zRRnfgyTjzLZj9bW7kuw= github.com/siyuan-note/filelock v0.0.0-20260411141728-bf44452627c0 h1:9qiLwd2reW71iyNHqiMBA7pBIb1LLwbqIrkH3QaOlxI= github.com/siyuan-note/filelock v0.0.0-20260411141728-bf44452627c0/go.mod h1:ew8UsjZTqPTGmdR9Pd13Ep6hVVx+c69nHljuGDro/WU= github.com/siyuan-note/httpclient v0.0.0-20260529092300-aa977bebbd71 h1:xt6pqMImwdrLs3jyxDnxl2PGBZ7RUjQyqaNV0kaNOK4= diff --git a/kernel/mobile/kernel.go b/kernel/mobile/kernel.go index e80220598..362760e2d 100644 --- a/kernel/mobile/kernel.go +++ b/kernel/mobile/kernel.go @@ -229,6 +229,7 @@ func StartKernel(container, appDir, workspaceBaseDir, timezoneID, localIPs, lang job.StartCron() go model.AutoGenerateFileHistory() go cache.LoadAssets() + go model.StartEmbeddingIndexer() }() } diff --git a/kernel/model/conf.go b/kernel/model/conf.go index 75170f2c0..9137a16bc 100644 --- a/kernel/model/conf.go +++ b/kernel/model/conf.go @@ -591,6 +591,14 @@ func InitConf() { Conf.AI.OpenAI.APIMaxContexts) } + if "" != Conf.AI.OpenAI.EmbeddingAPIKey { + logging.LogInfof("embedding API enabled\n"+ + " baseURL=%s\n"+ + " model=%s", + Conf.AI.OpenAI.EmbeddingBaseURL, + Conf.AI.OpenAI.EmbeddingModel) + } + Conf.ReadOnly = util.ReadOnly if "" != util.AccessAuthCode { diff --git a/kernel/model/embedding.go b/kernel/model/embedding.go new file mode 100644 index 000000000..3bc3b15ee --- /dev/null +++ b/kernel/model/embedding.go @@ -0,0 +1,295 @@ +// SiYuan - Refactor your thinking +// Copyright (c) 2020-present, b3log.org +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package model + +import ( + "encoding/json" + "math" + "os" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/siyuan-note/eventbus" + "github.com/siyuan-note/logging" + "github.com/siyuan-note/siyuan/kernel/sql" + "github.com/siyuan-note/siyuan/kernel/util" +) + +const ( + embeddingBatchSize = 10 + embeddingFetchSize = 30 + embeddingMinTextLen = 7 + embeddingMaxContentLen = 12000 +) + +var ( + embeddingDirtyCh = make(chan string, 1024) + embeddingTableOk bool +) + +func checkEmbeddingTable() bool { + _, err := sql.QueryNoLimit("SELECT COUNT(*) FROM block_embeddings") + if err != nil { + logging.LogWarnf("block_embeddings table not available, embedding indexer disabled: %s", err) + return false + } + return true +} + +func StartEmbeddingIndexer() { + if !checkEmbeddingTable() || !isEmbeddingEnabled() { + return + } + + eventbus.Subscribe(eventbus.EvtEmbeddingDirty, func(id string) { + select { + case embeddingDirtyCh <- id: + default: + } + }) + + embeddingTableOk = true + + processPendingEmbeddings() + + for { + select { + case <-embeddingDirtyCh: + processPendingEmbeddings() + case <-time.After(30 * time.Second): + processPendingEmbeddings() + } + } +} + +func processPendingEmbeddings() { + if !isEmbeddingEnabled() { + return + } + + for { + results, err := sql.QueryNoLimit(stmtPendingBlocks) + if err != nil { + logging.LogErrorf("query pending embedding blocks failed: %s", err) + return + } + + if 1 > len(results) { + return + } + + var batches [][]map[string]any + var batch []map[string]any + for _, row := range results { + id, _ := row["id"].(string) + rootID, _ := row["root_id"].(string) + box, _ := row["box"].(string) + path, _ := row["path"].(string) + updated, _ := row["updated"].(string) + content, _ := row["content"].(string) + if len(content) < embeddingMinTextLen || len(content) > embeddingMaxContentLen { + sql.Exec("INSERT OR IGNORE INTO block_embeddings (id, root_id, box, path, embedding, model, content_len, updated) VALUES ('" + + id + "','" + rootID + "','" + box + "','" + path + "','','" + Conf.AI.OpenAI.EmbeddingModel + "',0,'" + updated + "')") + continue + } + row["plain_text"] = content + batch = append(batch, row) + + if len(batch) >= embeddingBatchSize { + batches = append(batches, batch) + batch = nil + } + } + if len(batch) > 0 { + batches = append(batches, batch) + } + + var wg sync.WaitGroup + for _, bt := range batches { + wg.Add(1) + go func(blocks []map[string]any) { + defer wg.Done() + var texts []string + for _, row := range blocks { + texts = append(texts, row["plain_text"].(string)) + } + doEmbedAndStore(texts, blocks) + }(bt) + } + wg.Wait() + } +} + +const stmtPendingBlocks = "SELECT b.id, b.root_id, b.box, b.path, b.content, b.updated FROM blocks b " + + "LEFT JOIN block_embeddings e ON b.id = e.id " + + "WHERE e.id IS NULL " + + "ORDER BY b.updated DESC LIMIT 30" + +func doEmbedAndStore(texts []string, blocks []map[string]any) { + vectors, err := util.BatchGetEmbeddings(texts, embeddingKey(), embeddingBaseURL(), Conf.AI.OpenAI.EmbeddingModel, Conf.AI.OpenAI.APITimeout) + if err != nil { + return + } + + for i, row := range blocks { + id, _ := row["id"].(string) + rootID, _ := row["root_id"].(string) + box, _ := row["box"].(string) + path, _ := row["path"].(string) + updated, _ := row["updated"].(string) + plainText, _ := row["plain_text"].(string) + + embeddingJSON, err := json.Marshal(vectors[i]) + if err != nil { + logging.LogErrorf("marshal embedding failed for block [%s]: %s", id, err) + continue + } + + escaped := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + + stmt := "INSERT OR REPLACE INTO block_embeddings (id, root_id, box, path, embedding, model, content_len, updated) VALUES ('" + + escaped(id) + "', '" + escaped(rootID) + "', '" + escaped(box) + "', '" + escaped(path) + "', '" + + escaped(string(embeddingJSON)) + "', '" + escaped(Conf.AI.OpenAI.EmbeddingModel) + "', " + + strconv.Itoa(len(plainText)) + ", '" + escaped(updated) + "')" + + err = sql.Exec(stmt) + if err != nil { + logging.LogErrorf("store embedding failed for block [%s]: %s", id, err) + } + } +} + +func cosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) || len(a) == 0 { + return 0 + } + + var dotProduct, normA, normB float64 + for i := range a { + dotProduct += float64(a[i]) * float64(b[i]) + normA += float64(a[i]) * float64(a[i]) + normB += float64(b[i]) * float64(b[i]) + } + + if normA == 0 || normB == 0 { + return 0 + } + + return float32(dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))) +} + +func SemanticSearchBlock(query string, boxes, paths []string, types, subTypes map[string]bool, page, pageSize int) (blocks []*Block, matchedBlockCount, matchedRootCount, pageCount int) { + blocks = []*Block{} + + if !embeddingTableOk || !isEmbeddingEnabled() || "" == query { + return + } + + vectors, err := util.BatchGetEmbeddings([]string{query}, embeddingKey(), embeddingBaseURL(), Conf.AI.OpenAI.EmbeddingModel, Conf.AI.OpenAI.APITimeout) + if err != nil || 1 > len(vectors) { + logging.LogErrorf("get query embedding failed") + return + } + queryVec := vectors[0] + + results, err := sql.QueryNoLimit("SELECT id, embedding FROM block_embeddings") + if err != nil { + logging.LogErrorf("query embeddings for search failed: %s", err) + return + } + + type scoredBlock struct { + id string + score float32 + } + var scored []scoredBlock + + for _, row := range results { + embStr, _ := row["embedding"].([]byte) + if embStr == nil { + if s, ok := row["embedding"].(string); ok { + embStr = []byte(s) + } else { + continue + } + } + var vec []float32 + if err := json.Unmarshal(embStr, &vec); err != nil { + continue + } + score := cosineSimilarity(queryVec, vec) + id, _ := row["id"].(string) + scored = append(scored, scoredBlock{id: id, score: score}) + } + + sort.Slice(scored, func(i, j int) bool { + return scored[i].score > scored[j].score + }) + + matchedBlockCount = len(scored) + + offset := (page - 1) * pageSize + if offset >= len(scored) { + pageCount = (matchedBlockCount + pageSize - 1) / pageSize + return + } + + end := offset + pageSize + if end > len(scored) { + end = len(scored) + } + + var topIDs []string + for i := offset; i < end; i++ { + topIDs = append(topIDs, scored[i].id) + } + + sqlBlocks := sql.GetBlocks(topIDs) + rootIDSet := map[string]bool{} + for _, b := range sqlBlocks { + rootIDSet[b.RootID] = true + blocks = append(blocks, fromSQLBlock(b, "", 36)) + } + matchedRootCount = len(rootIDSet) + pageCount = (matchedBlockCount + pageSize - 1) / pageSize + + return +} + +func isEmbeddingEnabled() bool { + return "" != embeddingKey() +} + +func embeddingKey() string { + if "" != Conf.AI.OpenAI.EmbeddingAPIKey { + return Conf.AI.OpenAI.EmbeddingAPIKey + } + return os.Getenv("SIYUAN_OPENAI_EMBEDDING_API_KEY") +} + +func embeddingBaseURL() string { + if "" != Conf.AI.OpenAI.EmbeddingBaseURL { + return Conf.AI.OpenAI.EmbeddingBaseURL + } + if v := os.Getenv("SIYUAN_OPENAI_EMBEDDING_BASE_URL"); "" != v { + return v + } + return Conf.AI.OpenAI.EmbeddingBaseURL +} diff --git a/kernel/sql/database.go b/kernel/sql/database.go index 75683951c..3138454eb 100644 --- a/kernel/sql/database.go +++ b/kernel/sql/database.go @@ -224,6 +224,19 @@ func initDBTables() { if err != nil { logging.LogFatalf(logging.ExitCodeUnavailableDatabase, "create table [refs] failed: %s", err) } + + _, err = db.Exec("DROP TABLE IF EXISTS block_embeddings") + if err != nil { + logging.LogFatalf(logging.ExitCodeUnavailableDatabase, "drop table [block_embeddings] failed: %s", err) + } + _, err = db.Exec("CREATE TABLE block_embeddings (id TEXT PRIMARY KEY, root_id TEXT, box TEXT, path TEXT, embedding BLOB, model TEXT, content_len INTEGER, updated TEXT)") + if err != nil { + logging.LogFatalf(logging.ExitCodeUnavailableDatabase, "create table [block_embeddings] failed: %s", err) + } + _, err = db.Exec("CREATE INDEX idx_block_embeddings_root_id ON block_embeddings(root_id)") + if err != nil { + logging.LogFatalf(logging.ExitCodeUnavailableDatabase, "create index [idx_block_embeddings_root_id] failed: %s", err) + } } func initFTSBlocks() (err error) { @@ -1065,6 +1078,11 @@ func deleteBlocksByIDs(tx *sql.Tx, ids []string) (err error) { return } } + + stmt = "DELETE FROM block_embeddings WHERE id IN (" + strings.Join(ftsIDs, ",") + ")" + if err = execStmtTx(tx, stmt); err != nil { + return + } return } @@ -1420,6 +1438,19 @@ func query(query string, args ...any) (*sql.Rows, error) { return db.Query(query, args...) } +func Exec(stmt string) error { + stmt = strings.TrimSpace(stmt) + if "" == stmt { + return errors.New("statement is empty") + } + + if nil == db { + return errors.New("database is nil") + } + _, err := db.Exec(stmt) + return err +} + func beginTx() (tx *sql.Tx, err error) { if tx, err = db.Begin(); err != nil { logging.LogErrorf("begin tx failed: %s\n %s", err, logging.ShortStack()) diff --git a/kernel/sql/queue.go b/kernel/sql/queue.go index 0af1ebb2a..a9897ae28 100644 --- a/kernel/sql/queue.go +++ b/kernel/sql/queue.go @@ -184,6 +184,17 @@ func FlushQueue() { continue } + switch op.action { + case "index": + eventbus.Publish(eventbus.EvtEmbeddingDirty, op.indexTree.ID) + case "upsert": + eventbus.Publish(eventbus.EvtEmbeddingDirty, op.upsertTree.ID) + case "update_block_content": + eventbus.Publish(eventbus.EvtEmbeddingDirty, op.block.ID) + case "index_node": + eventbus.Publish(eventbus.EvtEmbeddingDirty, op.id) + } + if 16 < i && 0 == i%128 { debug.FreeOSMemory() } @@ -215,10 +226,21 @@ func execOp(op *dbQueueOperation, tx *sql.Tx, context map[string]any) (err error err = upsertTree(tx, op.upsertTree, context) case "delete": err = batchDeleteByPathPrefix(tx, op.removeTreeBox, op.removeTreePath) + if nil == err { + tx.Exec("DELETE FROM block_embeddings WHERE box = ? AND path LIKE ?", op.removeTreeBox, op.removeTreePath+"%") + } case "delete_id": err = deleteByRootID(tx, op.removeTreeID, context) + if nil == err { + tx.Exec("DELETE FROM block_embeddings WHERE root_id = ?", op.removeTreeID) + } case "delete_ids": err = batchDeleteByRootIDs(tx, op.removeTreeIDs, context) + if nil == err { + for _, rootID := range op.removeTreeIDs { + tx.Exec("DELETE FROM block_embeddings WHERE root_id = ?", rootID) + } + } case "rename": err = batchUpdateHPath(tx, op.indexTree, context) if err != nil { @@ -226,10 +248,19 @@ func execOp(op *dbQueueOperation, tx *sql.Tx, context map[string]any) (err error } err = updateRootContent(tx, path.Base(op.indexTree.HPath), op.indexTree.Root.IALAttr("updated"), treenode.IALStr(op.indexTree.Root), op.indexTree.ID) + if nil == err { + tx.Exec("UPDATE block_embeddings SET box = ?, path = ? WHERE root_id = ?", op.indexTree.Box, op.indexTree.Path, op.indexTree.ID) + } case "move": err = batchUpdatePath(tx, op.indexTree, context) + if nil == err { + tx.Exec("UPDATE block_embeddings SET box = ?, path = ? WHERE root_id = ?", op.indexTree.Box, op.indexTree.Path, op.indexTree.ID) + } case "delete_box": err = deleteByBoxTx(tx, op.box) + if nil == err { + tx.Exec("DELETE FROM block_embeddings WHERE box = ?", op.box) + } case "delete_box_refs": err = deleteRefsByBoxTx(tx, op.box) case "update_refs": diff --git a/kernel/util/openai.go b/kernel/util/openai.go index 6f28b4603..5a189fe55 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -122,3 +122,37 @@ func (adt *AddHeaderTransport) RoundTrip(req *http.Request) (*http.Response, err func newAddHeaderTransport(transport *http.Transport, userAgent string) *AddHeaderTransport { return &AddHeaderTransport{RoundTripper: transport, UserAgent: userAgent} } + +func BatchGetEmbeddings(texts []string, apiKey, baseURL, model string, timeout int) (ret [][]float32, err error) { + if 1 > len(texts) { + return + } + + config := openai.DefaultConfig(apiKey) + config.BaseURL = baseURL + config.HTTPClient = &http.Client{ + Timeout: time.Duration(timeout) * time.Second, + Transport: &AddHeaderTransport{ + RoundTripper: &http.Transport{}, + UserAgent: UserAgent, + }, + } + client := openai.NewClientWithConfig(config) + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + + resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequestStrings{ + Input: texts, + Model: openai.EmbeddingModel(model), + }) + if err != nil { + logging.LogErrorf("create embeddings failed: %s", err) + return + } + + for _, data := range resp.Data { + ret = append(ret, data.Embedding) + } + return +}