Signed-off-by: Daniel <845765@qq.com>
This commit is contained in:
Daniel
2026-05-31 10:18:07 +08:00
parent 5a02e7b28f
commit de8a724a9a

View File

@@ -17,11 +17,12 @@
package model
import (
"container/heap"
"encoding/binary"
"fmt"
"math"
"os"
"runtime"
"sort"
"sync"
"time"
"unsafe"
@@ -213,9 +214,25 @@ func cosineSimilarity(a, b []float32) float32 {
return float32(dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)))
}
type embeddingEntry struct {
id string
vec []float32
// scoredBlock pairs a block ID with the similarity score computed for it.
type scoredBlock struct {
	id    string
	score float32
}

// scoredHeap is a min-heap of scoredBlock values keyed on score. It supplies
// the Len/Less/Swap/Push/Pop methods that container/heap operates on.
type scoredHeap []scoredBlock

// Len reports how many entries the heap currently holds.
func (sh scoredHeap) Len() int {
	return len(sh)
}

// Less orders entries by ascending score (min-heap ordering).
func (sh scoredHeap) Less(i, j int) bool {
	return sh[j].score > sh[i].score
}

// Swap exchanges the entries at positions i and j.
func (sh scoredHeap) Swap(i, j int) {
	tmp := sh[i]
	sh[i] = sh[j]
	sh[j] = tmp
}

// Push appends x to the end of the underlying slice. x must be a
// scoredBlock; any other dynamic type makes the type assertion panic.
func (sh *scoredHeap) Push(x any) {
	item := x.(scoredBlock)
	*sh = append(*sh, item)
}

// Pop removes the final element of the underlying slice and returns it.
func (sh *scoredHeap) Pop() any {
	last := len(*sh) - 1
	item := (*sh)[last]
	*sh = (*sh)[:last]
	return item
}
func SemanticSearchBlock(query string, boxes, paths []string, types, subTypes map[string]bool, page, pageSize int) (blocks []*Block, matchedBlockCount, matchedRootCount, pageCount int) {
@@ -232,87 +249,108 @@ func SemanticSearchBlock(query string, boxes, paths []string, types, subTypes ma
}
queryVec := vectors[0]
results, err := sql.QueryNoLimit("SELECT id, embedding FROM block_embeddings WHERE embedding IS NOT NULL AND length(embedding) > 0")
if err != nil {
logging.LogErrorf("query embeddings for search failed: %s", err)
return
}
var entries []embeddingEntry
for _, row := range results {
embRaw := row["embedding"].([]byte)
if len(embRaw) == 0 {
continue
}
id, _ := row["id"].(string)
buf := make([]byte, len(embRaw))
copy(buf, embRaw)
entries = append(entries, embeddingEntry{id: id, vec: decodeVector(buf)})
}
type scoredBlock struct {
id string
score float32
}
numWorkers := runtime.GOMAXPROCS(0)
if numWorkers < 1 {
numWorkers = 1
}
chunkSize := (len(entries) + numWorkers - 1) / numWorkers
scoredCh := make(chan []scoredBlock, numWorkers)
var wg sync.WaitGroup
topK := page * pageSize
h := &scoredHeap{}
heap.Init(h)
for w := 0; w < numWorkers; w++ {
start := w * chunkSize
end := start + chunkSize
if end > len(entries) {
end = len(entries)
scanSize := 4096
cursor := int64(0)
for {
q := fmt.Sprintf("SELECT rowid, id, embedding FROM block_embeddings WHERE embedding IS NOT NULL AND length(embedding) > 0 AND rowid > %d ORDER BY rowid LIMIT %d", cursor, scanSize)
rows, qErr := sql.QueryNoLimit(q)
if qErr != nil {
logging.LogErrorf("query embeddings for search failed: %s", qErr)
break
}
if start >= end {
continue
if 1 > len(rows) {
break
}
wg.Add(1)
go func(chunk []embeddingEntry) {
defer wg.Done()
local := make([]scoredBlock, len(chunk))
for i, e := range chunk {
local[i] = scoredBlock{id: e.id, score: cosineSimilarity(queryVec, e.vec)}
rawCursor, _ := rows[len(rows)-1]["rowid"].(int64)
if rawCursor > cursor {
cursor = rawCursor
}
chunkSize := (len(rows) + numWorkers - 1) / numWorkers
scoredCh := make(chan []scoredBlock, numWorkers)
var wg sync.WaitGroup
for w := 0; w < numWorkers; w++ {
start := w * chunkSize
end := start + chunkSize
if end > len(rows) {
end = len(rows)
}
scoredCh <- local
}(entries[start:end])
if start >= end {
continue
}
wg.Add(1)
go func(chunk []map[string]any) {
defer wg.Done()
local := make([]scoredBlock, 0, len(chunk))
for _, row := range chunk {
embRaw := row["embedding"].([]byte)
if len(embRaw) == 0 {
continue
}
buf := make([]byte, len(embRaw))
copy(buf, embRaw)
vec := decodeVector(buf)
score := cosineSimilarity(queryVec, vec)
id, _ := row["id"].(string)
local = append(local, scoredBlock{id: id, score: score})
}
scoredCh <- local
}(rows[start:end])
}
wg.Wait()
close(scoredCh)
for ch := range scoredCh {
for _, s := range ch {
if h.Len() < topK {
heap.Push(h, s)
} else if s.score > (*h)[0].score {
heap.Pop(h)
heap.Push(h, s)
}
}
}
}
wg.Wait()
close(scoredCh)
var scored []scoredBlock
for ch := range scoredCh {
scored = append(scored, ch...)
matchedBlockCount = h.Len()
if 1 > matchedBlockCount {
pageCount = 0
return
}
sort.Slice(scored, func(i, j int) bool {
return scored[i].score > scored[j].score
})
matchedBlockCount = len(scored)
result := make([]scoredBlock, h.Len())
for i := len(result) - 1; i >= 0; i-- {
result[i] = heap.Pop(h).(scoredBlock)
}
offset := (page - 1) * pageSize
if offset >= len(scored) {
if offset >= len(result) {
pageCount = (matchedBlockCount + pageSize - 1) / pageSize
return
}
end := offset + pageSize
if end > len(scored) {
end = len(scored)
if end > len(result) {
end = len(result)
}
var topIDs []string
for i := offset; i < end; i++ {
topIDs = append(topIDs, scored[i].id)
topIDs = append(topIDs, result[i].id)
}
sqlBlocks := sql.GetBlocks(topIDs)