mirror of
https://github.com/siyuan-note/siyuan.git
synced 2026-06-27 22:36:00 +00:00
✨ Semantic search using AI embeddings https://github.com/siyuan-note/siyuan/issues/17788
Signed-off-by: Daniel <845765@qq.com>
This commit is contained in:
@@ -17,11 +17,12 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
@@ -213,9 +214,25 @@ func cosineSimilarity(a, b []float32) float32 {
|
||||
return float32(dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)))
|
||||
}
|
||||
|
||||
type embeddingEntry struct {
|
||||
id string
|
||||
vec []float32
|
||||
type scoredBlock struct {
|
||||
id string
|
||||
score float32
|
||||
}
|
||||
|
||||
type scoredHeap []scoredBlock
|
||||
|
||||
func (h scoredHeap) Len() int { return len(h) }
|
||||
func (h scoredHeap) Less(i, j int) bool { return h[i].score < h[j].score } // min-heap
|
||||
func (h scoredHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
func (h *scoredHeap) Push(x any) {
|
||||
*h = append(*h, x.(scoredBlock))
|
||||
}
|
||||
func (h *scoredHeap) Pop() any {
|
||||
old := *h
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*h = old[:n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
func SemanticSearchBlock(query string, boxes, paths []string, types, subTypes map[string]bool, page, pageSize int) (blocks []*Block, matchedBlockCount, matchedRootCount, pageCount int) {
|
||||
@@ -232,87 +249,108 @@ func SemanticSearchBlock(query string, boxes, paths []string, types, subTypes ma
|
||||
}
|
||||
queryVec := vectors[0]
|
||||
|
||||
results, err := sql.QueryNoLimit("SELECT id, embedding FROM block_embeddings WHERE embedding IS NOT NULL AND length(embedding) > 0")
|
||||
if err != nil {
|
||||
logging.LogErrorf("query embeddings for search failed: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var entries []embeddingEntry
|
||||
for _, row := range results {
|
||||
embRaw := row["embedding"].([]byte)
|
||||
if len(embRaw) == 0 {
|
||||
continue
|
||||
}
|
||||
id, _ := row["id"].(string)
|
||||
buf := make([]byte, len(embRaw))
|
||||
copy(buf, embRaw)
|
||||
entries = append(entries, embeddingEntry{id: id, vec: decodeVector(buf)})
|
||||
}
|
||||
|
||||
type scoredBlock struct {
|
||||
id string
|
||||
score float32
|
||||
}
|
||||
|
||||
numWorkers := runtime.GOMAXPROCS(0)
|
||||
if numWorkers < 1 {
|
||||
numWorkers = 1
|
||||
}
|
||||
chunkSize := (len(entries) + numWorkers - 1) / numWorkers
|
||||
|
||||
scoredCh := make(chan []scoredBlock, numWorkers)
|
||||
var wg sync.WaitGroup
|
||||
topK := page * pageSize
|
||||
h := &scoredHeap{}
|
||||
heap.Init(h)
|
||||
|
||||
for w := 0; w < numWorkers; w++ {
|
||||
start := w * chunkSize
|
||||
end := start + chunkSize
|
||||
if end > len(entries) {
|
||||
end = len(entries)
|
||||
scanSize := 4096
|
||||
cursor := int64(0)
|
||||
|
||||
for {
|
||||
q := fmt.Sprintf("SELECT rowid, id, embedding FROM block_embeddings WHERE embedding IS NOT NULL AND length(embedding) > 0 AND rowid > %d ORDER BY rowid LIMIT %d", cursor, scanSize)
|
||||
rows, qErr := sql.QueryNoLimit(q)
|
||||
if qErr != nil {
|
||||
logging.LogErrorf("query embeddings for search failed: %s", qErr)
|
||||
break
|
||||
}
|
||||
if start >= end {
|
||||
continue
|
||||
if 1 > len(rows) {
|
||||
break
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(chunk []embeddingEntry) {
|
||||
defer wg.Done()
|
||||
local := make([]scoredBlock, len(chunk))
|
||||
for i, e := range chunk {
|
||||
local[i] = scoredBlock{id: e.id, score: cosineSimilarity(queryVec, e.vec)}
|
||||
rawCursor, _ := rows[len(rows)-1]["rowid"].(int64)
|
||||
if rawCursor > cursor {
|
||||
cursor = rawCursor
|
||||
}
|
||||
|
||||
chunkSize := (len(rows) + numWorkers - 1) / numWorkers
|
||||
scoredCh := make(chan []scoredBlock, numWorkers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for w := 0; w < numWorkers; w++ {
|
||||
start := w * chunkSize
|
||||
end := start + chunkSize
|
||||
if end > len(rows) {
|
||||
end = len(rows)
|
||||
}
|
||||
scoredCh <- local
|
||||
}(entries[start:end])
|
||||
if start >= end {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(chunk []map[string]any) {
|
||||
defer wg.Done()
|
||||
local := make([]scoredBlock, 0, len(chunk))
|
||||
for _, row := range chunk {
|
||||
embRaw := row["embedding"].([]byte)
|
||||
if len(embRaw) == 0 {
|
||||
continue
|
||||
}
|
||||
buf := make([]byte, len(embRaw))
|
||||
copy(buf, embRaw)
|
||||
vec := decodeVector(buf)
|
||||
score := cosineSimilarity(queryVec, vec)
|
||||
id, _ := row["id"].(string)
|
||||
local = append(local, scoredBlock{id: id, score: score})
|
||||
}
|
||||
scoredCh <- local
|
||||
}(rows[start:end])
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(scoredCh)
|
||||
|
||||
for ch := range scoredCh {
|
||||
for _, s := range ch {
|
||||
if h.Len() < topK {
|
||||
heap.Push(h, s)
|
||||
} else if s.score > (*h)[0].score {
|
||||
heap.Pop(h)
|
||||
heap.Push(h, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(scoredCh)
|
||||
|
||||
var scored []scoredBlock
|
||||
for ch := range scoredCh {
|
||||
scored = append(scored, ch...)
|
||||
matchedBlockCount = h.Len()
|
||||
if 1 > matchedBlockCount {
|
||||
pageCount = 0
|
||||
return
|
||||
}
|
||||
|
||||
sort.Slice(scored, func(i, j int) bool {
|
||||
return scored[i].score > scored[j].score
|
||||
})
|
||||
|
||||
matchedBlockCount = len(scored)
|
||||
result := make([]scoredBlock, h.Len())
|
||||
for i := len(result) - 1; i >= 0; i-- {
|
||||
result[i] = heap.Pop(h).(scoredBlock)
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
if offset >= len(scored) {
|
||||
if offset >= len(result) {
|
||||
pageCount = (matchedBlockCount + pageSize - 1) / pageSize
|
||||
return
|
||||
}
|
||||
|
||||
end := offset + pageSize
|
||||
if end > len(scored) {
|
||||
end = len(scored)
|
||||
if end > len(result) {
|
||||
end = len(result)
|
||||
}
|
||||
|
||||
var topIDs []string
|
||||
for i := offset; i < end; i++ {
|
||||
topIDs = append(topIDs, scored[i].id)
|
||||
topIDs = append(topIDs, result[i].id)
|
||||
}
|
||||
|
||||
sqlBlocks := sql.GetBlocks(topIDs)
|
||||
|
||||
Reference in New Issue
Block a user