diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index bfd7b86..31eada4 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -120,11 +120,12 @@ jobs: go-version: '1.23.4' - name: Install dependencies - run: sudo apt-get update && sudo apt-get install -y gcc-aarch64-linux-gnu + run: sudo apt-get update && sudo apt-get install -y gcc-aarch64-linux-gnu g++-aarch64-linux-gnu - name: Build wikilite for Linux ARM64 run: | CC=aarch64-linux-gnu-gcc \ + CXX=aarch64-linux-gnu-g++ \ CGO_ENABLED=1 \ GOOS=linux \ GOARCH=arm64 \ diff --git a/ai.go b/ai.go index 562d62a..f051aa7 100644 --- a/ai.go +++ b/ai.go @@ -8,47 +8,126 @@ import ( "fmt" "math" "math/bits" + "os" + "path/filepath" + "runtime" + "github.com/ollama/ollama/llama" "github.com/openai/openai-go" "github.com/openai/openai-go/option" "github.com/openai/openai-go/shared" ) -func aiInstruct(input string) (output string, err error) { - client := openai.NewClient( - option.WithAPIKey(options.aiApiKey), - option.WithBaseURL(options.aiApiUrl), - ) - chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(input), - }), - Model: openai.F(options.aiLlmModel), - }) - if err != nil { - return "", err +var aiLocal struct { + model *llama.Model + context *llama.Context + batchSize int +} + +func aiInit() error { + if options.aiApiUrl == "" { + exePath, err := os.Executable() + if err != nil { + return fmt.Errorf("AI error getting executable path: %v", err) + } + exeDir := filepath.Dir(exePath) + aiModelPath := filepath.Join(exeDir, options.aiModel+".gguf") + if _, err := os.Stat(aiModelPath); err != nil { + return err + } else { + originalStderr := os.Stderr + if os.Stderr, err = MuteStderr(); err != nil { + return err + } + aiLocal.batchSize = 512 + llama.BackendInit() + aiLocal.model, err = llama.LoadModelFromFile(aiModelPath, llama.ModelParams{UseMmap: true}) + if err != nil { + return err + } + aiLocal.context, err = llama.NewContextWithModel(aiLocal.model, llama.NewContextParams(2048, aiLocal.batchSize, 1, runtime.NumCPU(), false, "")) + if err != nil { + return err + } + os.Stderr = originalStderr + } } - return chatCompletion.Choices[0].Message.Content, nil + + if _, err := aiEmbeddings("test"); err != nil { + return fmt.Errorf("AI error loading embedding model: %v", err) + } + + return nil } func aiEmbeddings(input string) (output []float32, err error) { - client := openai.NewClient( - option.WithAPIKey(options.aiApiKey), - option.WithBaseURL(options.aiApiUrl), - ) - response, err := client.Embeddings.New(context.TODO(), openai.EmbeddingNewParams{ - Model: openai.F(options.aiEmbeddingModel), - Input: openai.F[openai.EmbeddingNewParamsInputUnion](shared.UnionString(input)), - EncodingFormat: openai.F(openai.EmbeddingNewParamsEncodingFormatFloat), - }) - - if err != nil { - return nil, err - } + if options.aiApiUrl == "" { + tokens, err := aiLocal.model.Tokenize(input, true, true) + if err != nil { + return nil, fmt.Errorf("failed to tokenize text: %v", err) + } + + var embeddings []float32 + seqId := 0 + + for i := 0; i < len(tokens); i += aiLocal.batchSize { + end := i + aiLocal.batchSize + if end > len(tokens) { + end = len(tokens) + } + + batchTokens := tokens[i:end] + batch, err := llama.NewBatch(len(batchTokens), 1, 0) + if err != nil { + return nil, fmt.Errorf("failed to create batch: %v", err) + } + + for j, token := range batchTokens { + isLast := (i + j + 1) == len(tokens) + batch.Add(token, nil, j, isLast, seqId) + } + + if err := aiLocal.context.Decode(batch); err != nil { + batch.Free() + return nil, fmt.Errorf("failed to decode batch: %v", err) + } + + if i+len(batchTokens) == len(tokens) { + batchEmbeddings := aiLocal.context.GetEmbeddingsSeq(seqId) + if batchEmbeddings != nil { + embeddings = batchEmbeddings + } + } + + batch.Free() + } + + if embeddings == nil || len(embeddings) == 0 { + return nil, fmt.Errorf("failed to get embeddings") + } + + return embeddings, nil + + } else { + + client := openai.NewClient( + option.WithAPIKey(options.aiApiKey), + option.WithBaseURL(options.aiApiUrl), + ) + response, err := client.Embeddings.New(context.TODO(), openai.EmbeddingNewParams{ + Model: openai.F(options.aiModel), + Input: openai.F[openai.EmbeddingNewParamsInputUnion](shared.UnionString(input)), + EncodingFormat: openai.F(openai.EmbeddingNewParamsEncodingFormatFloat), + }) + + if err != nil { + return nil, err + } - for _, embedding := range response.Data { - for _, value := range embedding.Embedding { - output = append(output, float32(value)) + for _, embedding := range response.Data { + for _, value := range embedding.Embedding { + output = append(output, float32(value)) + } } } return diff --git a/go.mod b/go.mod index f12e9e5..f1be773 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,10 @@ module wikilite -go 1.22.5 - -toolchain go1.22.10 +go 1.23.4 require ( github.com/mattn/go-sqlite3 v1.14.24 + github.com/ollama/ollama v0.5.5 github.com/openai/openai-go v0.1.0-alpha.39 golang.org/x/net v0.33.0 ) diff --git a/go.sum b/go.sum index e412fbe..73fa3d0 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/ollama/ollama v0.5.5 h1:/L4qt66ESqlT4cnwaX0SKmg+JkVCMEHtoXH7MpADAq8= +github.com/ollama/ollama v0.5.5/go.mod h1:x7Jfo4Kgmp98pe5PwtZQr/9vCt9HM52PaHZ38MzdA/A= github.com/openai/openai-go v0.1.0-alpha.39 h1:FvoNWy7BPhA0TjGOK5huRGU5sAUEx2jeubLXz34K9LE= github.com/openai/openai-go v0.1.0-alpha.39/go.mod h1:3SdE6BffOX9HPEQv8IL/fi3LYZ5TUpRYaqGQZbyk11A= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= diff --git a/main.go b/main.go index e277fef..bfdbf29 100644 --- a/main.go +++ b/main.go @@ -10,50 +10,43 @@ import ( "os" ) -const Version = "0.2.0" +const Version = "0.3.0" type Config struct { - ai bool - aiApiKey string - aiApiUrl string - aiEmbeddingModel string - aiLlmModel string - cli bool - dbOptimize bool - dbPath string - dbSyncEmbeddings bool - dbSyncFTS bool - language string - limit int - log bool - logFile string - web bool - webHost string - webPort int - webTlsPrivate string - webTlsPublic string - wikiImport string //https://dumps.wikimedia.org/other/enterprise_html/runs/... + aiApiKey string + aiApiUrl string + aiModel string + aiSync bool + cli bool + dbPath string + language string + limit int + log bool + logFile string + web bool + webHost string + webPort int + webTlsPrivate string + webTlsPublic string + wikiImport string //https://dumps.wikimedia.org/other/enterprise_html/runs/... } var ( + ai bool db *DBHandler options *Config ) func parseConfig() (*Config, error) { options = &Config{} - flag.BoolVar(&options.ai, "ai", false, "Enable AI") flag.StringVar(&options.aiApiKey, "ai-api-key", "", "AI API key") - flag.StringVar(&options.aiApiUrl, "ai-api-url", "http://localhost:11434/v1/", "AI API base url") - flag.StringVar(&options.aiEmbeddingModel, "ai-embedding-model", "all-minilm", "AI embedding model") - flag.StringVar(&options.aiLlmModel, "ai-llm-model", "gemma2", "AI LLM model") + flag.StringVar(&options.aiApiUrl, "ai-api-url", "", "AI API base url") + flag.StringVar(&options.aiModel, "ai-model", "all-minilm", "AI embedding model") + flag.BoolVar(&options.aiSync, "ai-sync", false, "AI generate embeddings") flag.BoolVar(&options.cli, "cli", false, "Interactive search") flag.StringVar(&options.dbPath, "db", "wikilite.db", "SQLite database path") - flag.BoolVar(&options.dbOptimize, "db-optimize", false, "Optimize database") - flag.BoolVar(&options.dbSyncEmbeddings, "db-sync-embeddings", false, "Sync database embeddings") - flag.BoolVar(&options.dbSyncFTS, "db-sync-fts", false, "Sync database full text search") flag.StringVar(&options.language, "language", "en", "Language") flag.IntVar(&options.limit, "limit", 5, "Maximum number of search results") @@ -85,7 +78,7 @@ func parseConfig() (*Config, error) { func main() { options, err := parseConfig() if err != nil { - log.Fatalf("Error parsing command line: %v\n\n", err) + log.Fatalf("Error parsing command line: %v\n", err) } if flag.NFlag() == 0 { @@ -111,25 +104,24 @@ func main() { } defer db.Close() - if options.dbOptimize || options.dbSyncFTS || options.dbSyncEmbeddings || options.wikiImport != "" { + if err := aiInit(); err != nil { + log.Printf("AI initialization error: %v\n", err) + } else { + ai = true + } + + if options.aiSync || options.wikiImport != "" { if err := db.PragmaImportMode(); err != nil { log.Fatalf("Error setting database in import mode: %v\n", err) } + if options.wikiImport != "" { if err = WikiImport(options.wikiImport); err != nil { log.Fatalf("Error processing import: %v\n", err) } - options.dbOptimize = true - options.dbSyncFTS = true - } - - if options.dbOptimize { if err := db.Optimize(); err != nil { log.Fatalf("Error during database optimization: %v\n", err) } - } - - if options.dbSyncFTS { if err := db.ProcessTitles(); err != nil { log.Fatalf("Error processing FTS titles: %v\n", err) } @@ -138,7 +130,7 @@ func main() { } } - if options.dbSyncEmbeddings { + if ai && options.aiSync { if err := db.ProcessEmbeddings(); err != nil { log.Fatalf("Error processing embeddings: %v\n", err) } diff --git a/search.go b/search.go index 49d3588..714363e 100644 --- a/search.go +++ b/search.go @@ -32,7 +32,7 @@ func Search(query string, limit int) ([]SearchResult, error) { results = append(results, content) } - if options.ai { + if ai { log.Println("Vectors searching", query) vectors, err := db.SearchVectors(query, limit) if err != nil { diff --git a/utils.go b/utils.go index bf07ed0..54539c5 100644 --- a/utils.go +++ b/utils.go @@ -8,7 +8,9 @@ import ( "crypto/md5" "embed" "encoding/hex" + "fmt" "io" + "os" "regexp" "strconv" "strings" @@ -83,3 +85,11 @@ func TextDeflate(text string) []byte { return out.Bytes() } + +func MuteStderr() (*os.File, error) { + devNull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + if err != nil { + return nil, fmt.Errorf("failed to mute stderr: %v", err) + } + return devNull, nil +} diff --git a/web.go b/web.go index 09f668c..423f8b8 100644 --- a/web.go +++ b/web.go @@ -1,4 +1,4 @@ -// Copyright (C) 2024 by Ubaldo Porcheddu +// Copyright (C) 2024-2025 by Ubaldo Porcheddu package main @@ -177,7 +177,7 @@ func (s *WebServer) handleAPISearchContent(w http.ResponseWriter, r *http.Reques } func (s *WebServer) handleAPISearchVectors(w http.ResponseWriter, r *http.Request) { - if !options.ai { + if !ai { s.sendAPIError(w, "Vector search is not enabled", http.StatusBadRequest) return }