diff --git a/db.go b/db.go index 7271dfe..cd9f4b4 100644 --- a/db.go +++ b/db.go @@ -534,11 +534,6 @@ func (h *DBHandler) ProcessEmbeddings() error { offset := 0 totalCount := 0 - type HashData struct { - Hash string - Text string - } - err := h.db.QueryRow("SELECT COUNT(*) FROM hashes").Scan(&totalCount) if err != nil { return fmt.Errorf("error getting total count of hashes: %w", err) @@ -550,58 +545,6 @@ func (h *DBHandler) ProcessEmbeddings() error { } startTime := time.Now() - - numWorkers := options.aiEmbeddingWorkers - hashChan := make(chan HashData, numWorkers) - embeddingChan := make(chan map[string][]float32, numWorkers) - doneChan := make(chan bool) - - // Start workers - for i := 0; i < numWorkers; i++ { - go func() { - for hashData := range hashChan { - hash := hashData.Hash - text := hashData.Text - - exists, err := qdrantHashExists(qd.PointsClient, options.qdrantCollection, hash) - if err != nil { - log.Printf("Error checking qdrant existence for hash %s: %v", hash, err) - continue - } - if exists { - log.Printf("Hash %s already exists in qdrant", hash) - continue - } - - embedding, err := aiEmbeddings(text) - if err != nil { - log.Printf("Embedding generation error for hash %s: %v", hash, err) - continue - } - - embeddingChan <- map[string][]float32{hash: embedding} - } - doneChan <- true - }() - } - - // Collect embeddings - go func() { - hashEmbeddings := make(map[string][]float32) - for embedding := range embeddingChan { - for hash, emb := range embedding { - hashEmbeddings[hash] = emb - } - } - - if len(hashEmbeddings) > 0 { - err = qdrantUpsertPoints(qd.PointsClient, options.qdrantCollection, hashEmbeddings) - if err != nil { - log.Printf("Error upserting batch to qdrant: %v", err) - } - } - }() - for { rows, err := h.db.Query(`SELECT hash, text FROM hashes LIMIT ? OFFSET ?`, batchSize, offset) if err != nil { @@ -609,6 +552,10 @@ func (h *DBHandler) ProcessEmbeddings() error { } defer rows.Close() + type HashData struct { + Hash string + Text string + } var hashesData []HashData for rows.Next() { var data HashData @@ -626,9 +573,37 @@ func (h *DBHandler) ProcessEmbeddings() error { break } - // Send hashes to workers + hashEmbeddings := make(map[string][]float32) + var hashesToUpdate []string + for _, hashData := range hashesData { - hashChan <- hashData + hash := hashData.Hash + text := hashData.Text + + exists, err := qdrantHashExists(qd.PointsClient, options.qdrantCollection, hash) + if err != nil { + log.Printf("Error checking qdrant existence for hash %s: %v", hash, err) + continue + } + if exists { + log.Printf("Hash %s already exists in qdrant", hash) + continue + } + + embedding, err := aiEmbeddings(text) + if err != nil { + log.Printf("Embedding generation error for hash %s: %v", hash, err) + continue + } + hashEmbeddings[hash] = embedding + hashesToUpdate = append(hashesToUpdate, hash) + } + + if len(hashEmbeddings) > 0 { + err = qdrantUpsertPoints(qd.PointsClient, options.qdrantCollection, hashEmbeddings) + if err != nil { + log.Printf("Error upserting batch to qdrant: %v", err) + } } processedCount := offset + len(hashesData) @@ -642,15 +617,6 @@ func (h *DBHandler) ProcessEmbeddings() error { offset += batchSize } - close(hashChan) - - // Wait for all workers to finish - for i := 0; i < numWorkers; i++ { - <-doneChan - } - - close(embeddingChan) - err = qdrantIndexOn(qd.CollectionsClient, options.qdrantCollection) if err != nil { return fmt.Errorf("error enabling qdrant indexing: %w", err) diff --git a/main.go b/main.go index bf9e37e..af1ef52 100644 --- a/main.go +++ b/main.go @@ -10,31 +10,30 @@ import ( "os" ) -const Version = "0.0.40" +const Version = "0.0.41" type Config struct { - importPath string //https://dumps.wikimedia.org/other/enterprise_html/runs/... - dbPath string - web bool - webHost string - webPort int - ai bool - aiApiKey string - aiEmbeddingModel string - aiEmbeddingSize int - aiEmbeddingWorkers int - aiLlmModel string - aiUrl string - qdrant bool - qdrantHost string - qdrantPort int - qdrantSync bool - qdrantCollection string - log bool - logFile string - cli bool - limit int - language string + importPath string //https://dumps.wikimedia.org/other/enterprise_html/runs/... + dbPath string + web bool + webHost string + webPort int + ai bool + aiApiKey string + aiEmbeddingModel string + aiEmbeddingSize int + aiLlmModel string + aiUrl string + qdrant bool + qdrantHost string + qdrantPort int + qdrantSync bool + qdrantCollection string + log bool + logFile string + cli bool + limit int + language string } var ( @@ -47,7 +46,6 @@ func parseConfig() (*Config, error) { options = &Config{} flag.BoolVar(&options.ai, "ai", false, "Enable AI") flag.IntVar(&options.aiEmbeddingSize, "ai-embedding-size", 384, "AI embedding size") - flag.IntVar(&options.aiEmbeddingWorkers, "ai-embedding-workers", 1, "AI embedding workers") flag.StringVar(&options.aiEmbeddingModel, "ai-embedding-model", "all-minilm", "AI embedding model") flag.StringVar(&options.aiLlmModel, "ai-llm-model", "gemma2", "AI LLM model") flag.StringVar(&options.aiUrl, "ai-url", "http://localhost:11434/v1/", "AI base url")