Skip to content

Commit

Permalink
remove embedding workers
Browse files Browse the repository at this point in the history
  • Loading branch information
ubaldus committed Dec 29, 2024
1 parent 12b01b6 commit dbac607
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 92 deletions.
102 changes: 34 additions & 68 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,11 +534,6 @@ func (h *DBHandler) ProcessEmbeddings() error {
offset := 0
totalCount := 0

type HashData struct {
Hash string
Text string
}

err := h.db.QueryRow("SELECT COUNT(*) FROM hashes").Scan(&totalCount)
if err != nil {
return fmt.Errorf("error getting total count of hashes: %w", err)
Expand All @@ -550,65 +545,17 @@ func (h *DBHandler) ProcessEmbeddings() error {
}

startTime := time.Now()

numWorkers := options.aiEmbeddingWorkers
hashChan := make(chan HashData, numWorkers)
embeddingChan := make(chan map[string][]float32, numWorkers)
doneChan := make(chan bool)

// Start workers
for i := 0; i < numWorkers; i++ {
go func() {
for hashData := range hashChan {
hash := hashData.Hash
text := hashData.Text

exists, err := qdrantHashExists(qd.PointsClient, options.qdrantCollection, hash)
if err != nil {
log.Printf("Error checking qdrant existence for hash %s: %v", hash, err)
continue
}
if exists {
log.Printf("Hash %s already exists in qdrant", hash)
continue
}

embedding, err := aiEmbeddings(text)
if err != nil {
log.Printf("Embedding generation error for hash %s: %v", hash, err)
continue
}

embeddingChan <- map[string][]float32{hash: embedding}
}
doneChan <- true
}()
}

// Collect embeddings
go func() {
hashEmbeddings := make(map[string][]float32)
for embedding := range embeddingChan {
for hash, emb := range embedding {
hashEmbeddings[hash] = emb
}
}

if len(hashEmbeddings) > 0 {
err = qdrantUpsertPoints(qd.PointsClient, options.qdrantCollection, hashEmbeddings)
if err != nil {
log.Printf("Error upserting batch to qdrant: %v", err)
}
}
}()

for {
rows, err := h.db.Query(`SELECT hash, text FROM hashes LIMIT ? OFFSET ?`, batchSize, offset)
if err != nil {
return fmt.Errorf("error querying hashes: %w", err)
}
defer rows.Close()

type HashData struct {
Hash string
Text string
}
var hashesData []HashData
for rows.Next() {
var data HashData
Expand All @@ -626,9 +573,37 @@ func (h *DBHandler) ProcessEmbeddings() error {
break
}

// Send hashes to workers
hashEmbeddings := make(map[string][]float32)
var hashesToUpdate []string

for _, hashData := range hashesData {
hashChan <- hashData
hash := hashData.Hash
text := hashData.Text

exists, err := qdrantHashExists(qd.PointsClient, options.qdrantCollection, hash)
if err != nil {
log.Printf("Error checking qdrant existence for hash %s: %v", hash, err)
continue
}
if exists {
log.Printf("Hash %s already exists in qdrant", hash)
continue
}

embedding, err := aiEmbeddings(text)
if err != nil {
log.Printf("Embedding generation error for hash %s: %v", hash, err)
continue
}
hashEmbeddings[hash] = embedding
hashesToUpdate = append(hashesToUpdate, hash)
}

if len(hashEmbeddings) > 0 {
err = qdrantUpsertPoints(qd.PointsClient, options.qdrantCollection, hashEmbeddings)
if err != nil {
log.Printf("Error upserting batch to qdrant: %v", err)
}
}

processedCount := offset + len(hashesData)
Expand All @@ -642,15 +617,6 @@ func (h *DBHandler) ProcessEmbeddings() error {
offset += batchSize
}

close(hashChan)

// Wait for all workers to finish
for i := 0; i < numWorkers; i++ {
<-doneChan
}

close(embeddingChan)

err = qdrantIndexOn(qd.CollectionsClient, options.qdrantCollection)
if err != nil {
return fmt.Errorf("error enabling qdrant indexing: %w", err)
Expand Down
46 changes: 22 additions & 24 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,30 @@ import (
"os"
)

const Version = "0.0.40"
const Version = "0.0.41"

type Config struct {
importPath string //https://dumps.wikimedia.org/other/enterprise_html/runs/...
dbPath string
web bool
webHost string
webPort int
ai bool
aiApiKey string
aiEmbeddingModel string
aiEmbeddingSize int
aiEmbeddingWorkers int
aiLlmModel string
aiUrl string
qdrant bool
qdrantHost string
qdrantPort int
qdrantSync bool
qdrantCollection string
log bool
logFile string
cli bool
limit int
language string
importPath string //https://dumps.wikimedia.org/other/enterprise_html/runs/...
dbPath string
web bool
webHost string
webPort int
ai bool
aiApiKey string
aiEmbeddingModel string
aiEmbeddingSize int
aiLlmModel string
aiUrl string
qdrant bool
qdrantHost string
qdrantPort int
qdrantSync bool
qdrantCollection string
log bool
logFile string
cli bool
limit int
language string
}

var (
Expand All @@ -47,7 +46,6 @@ func parseConfig() (*Config, error) {
options = &Config{}
flag.BoolVar(&options.ai, "ai", false, "Enable AI")
flag.IntVar(&options.aiEmbeddingSize, "ai-embedding-size", 384, "AI embedding size")
flag.IntVar(&options.aiEmbeddingWorkers, "ai-embedding-workers", 1, "AI embedding workers")
flag.StringVar(&options.aiEmbeddingModel, "ai-embedding-model", "all-minilm", "AI embedding model")
flag.StringVar(&options.aiLlmModel, "ai-llm-model", "gemma2", "AI LLM model")
flag.StringVar(&options.aiUrl, "ai-url", "http://localhost:11434/v1/", "AI base url")
Expand Down

0 comments on commit dbac607

Please sign in to comment.