Skip to content

Commit

Permalink
ollama/llama integration
Browse files Browse the repository at this point in the history
  • Loading branch information
ubaldus committed Jan 12, 2025
1 parent 021629d commit 7a958a4
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 76 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,12 @@ jobs:
go-version: '1.23.4'

- name: Install dependencies
run: sudo apt-get update && sudo apt-get install -y gcc-aarch64-linux-gnu
run: sudo apt-get update && sudo apt-get install -y gcc-aarch64-linux-gnu g++-aarch64-linux-gnu

- name: Build wikilite for Linux ARM64
run: |
CC=aarch64-linux-gnu-gcc \
CXX=aarch64-linux-gnu-g++ \
CGO_ENABLED=1 \
GOOS=linux \
GOARCH=arm64 \
Expand Down
139 changes: 109 additions & 30 deletions ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,47 +8,126 @@ import (
"fmt"
"math"
"math/bits"
"os"
"path/filepath"
"runtime"

"github.com/ollama/ollama/llama"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/shared"
)

func aiInstruct(input string) (output string, err error) {
client := openai.NewClient(
option.WithAPIKey(options.aiApiKey),
option.WithBaseURL(options.aiApiUrl),
)
chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
openai.UserMessage(input),
}),
Model: openai.F(options.aiLlmModel),
})
if err != nil {
return "", err
var aiLocal struct {
model *llama.Model
context *llama.Context
batchSize int
}

func aiInit() error {
if options.aiApiUrl == "" {
exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("AI error getting executable path: %v", err)
}
exeDir := filepath.Dir(exePath)
aiModelPath := filepath.Join(exeDir, options.aiModel+".gguf")
if _, err := os.Stat(aiModelPath); err != nil {
return err
} else {
originalStderr := os.Stderr
if os.Stderr, err = MuteStderr(); err != nil {
return err
}
aiLocal.batchSize = 512
llama.BackendInit()
aiLocal.model, err = llama.LoadModelFromFile(aiModelPath, llama.ModelParams{UseMmap: true})
if err != nil {
return err
}
aiLocal.context, err = llama.NewContextWithModel(aiLocal.model, llama.NewContextParams(2048, aiLocal.batchSize, 1, runtime.NumCPU(), false, ""))
if err != nil {
return err
}
os.Stderr = originalStderr
}
}
return chatCompletion.Choices[0].Message.Content, nil

if _, err := aiEmbeddings("test"); err != nil {
return fmt.Errorf("AI error loading embedding model: %v", err)
}

return nil
}

func aiEmbeddings(input string) (output []float32, err error) {
client := openai.NewClient(
option.WithAPIKey(options.aiApiKey),
option.WithBaseURL(options.aiApiUrl),
)
response, err := client.Embeddings.New(context.TODO(), openai.EmbeddingNewParams{
Model: openai.F(options.aiEmbeddingModel),
Input: openai.F[openai.EmbeddingNewParamsInputUnion](shared.UnionString(input)),
EncodingFormat: openai.F(openai.EmbeddingNewParamsEncodingFormatFloat),
})

if err != nil {
return nil, err
}
if options.aiApiUrl == "" {
tokens, err := aiLocal.model.Tokenize(input, true, true)
if err != nil {
return nil, fmt.Errorf("failed to tokenize text: %v", err)
}

var embeddings []float32
seqId := 0

for i := 0; i < len(tokens); i += aiLocal.batchSize {
end := i + aiLocal.batchSize
if end > len(tokens) {
end = len(tokens)
}

batchTokens := tokens[i:end]
batch, err := llama.NewBatch(len(batchTokens), 1, 0)
if err != nil {
return nil, fmt.Errorf("failed to create batch: %v", err)
}

for j, token := range batchTokens {
isLast := (i + j + 1) == len(tokens)
batch.Add(token, nil, j, isLast, seqId)
}

if err := aiLocal.context.Decode(batch); err != nil {
batch.Free()
return nil, fmt.Errorf("failed to decode batch: %v", err)
}

if i+len(batchTokens) == len(tokens) {
batchEmbeddings := aiLocal.context.GetEmbeddingsSeq(seqId)
if batchEmbeddings != nil {
embeddings = batchEmbeddings
}
}

batch.Free()
}

if embeddings == nil || len(embeddings) == 0 {
return nil, fmt.Errorf("failed to get embeddings")
}

return embeddings, nil

} else {

client := openai.NewClient(
option.WithAPIKey(options.aiApiKey),
option.WithBaseURL(options.aiApiUrl),
)
response, err := client.Embeddings.New(context.TODO(), openai.EmbeddingNewParams{
Model: openai.F(options.aiModel),
Input: openai.F[openai.EmbeddingNewParamsInputUnion](shared.UnionString(input)),
EncodingFormat: openai.F(openai.EmbeddingNewParamsEncodingFormatFloat),
})

if err != nil {
return nil, err
}

for _, embedding := range response.Data {
for _, value := range embedding.Embedding {
output = append(output, float32(value))
for _, embedding := range response.Data {
for _, value := range embedding.Embedding {
output = append(output, float32(value))
}
}
}
return
Expand Down
5 changes: 2 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
module wikilite

go 1.22.5

toolchain go1.22.10
go 1.23.4

require (
github.com/mattn/go-sqlite3 v1.14.24
github.com/ollama/ollama v0.5.5
github.com/openai/openai-go v0.1.0-alpha.39
golang.org/x/net v0.33.0
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/ollama/ollama v0.5.5 h1:/L4qt66ESqlT4cnwaX0SKmg+JkVCMEHtoXH7MpADAq8=
github.com/ollama/ollama v0.5.5/go.mod h1:x7Jfo4Kgmp98pe5PwtZQr/9vCt9HM52PaHZ38MzdA/A=
github.com/openai/openai-go v0.1.0-alpha.39 h1:FvoNWy7BPhA0TjGOK5huRGU5sAUEx2jeubLXz34K9LE=
github.com/openai/openai-go v0.1.0-alpha.39/go.mod h1:3SdE6BffOX9HPEQv8IL/fi3LYZ5TUpRYaqGQZbyk11A=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
Expand Down
70 changes: 31 additions & 39 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,50 +10,43 @@ import (
"os"
)

const Version = "0.2.0"
const Version = "0.3.0"

type Config struct {
ai bool
aiApiKey string
aiApiUrl string
aiEmbeddingModel string
aiLlmModel string
cli bool
dbOptimize bool
dbPath string
dbSyncEmbeddings bool
dbSyncFTS bool
language string
limit int
log bool
logFile string
web bool
webHost string
webPort int
webTlsPrivate string
webTlsPublic string
wikiImport string //https://dumps.wikimedia.org/other/enterprise_html/runs/...
aiApiKey string
aiApiUrl string
aiModel string
aiSync bool
cli bool
dbPath string
language string
limit int
log bool
logFile string
web bool
webHost string
webPort int
webTlsPrivate string
webTlsPublic string
wikiImport string //https://dumps.wikimedia.org/other/enterprise_html/runs/...
}

var (
ai bool
db *DBHandler
options *Config
)

func parseConfig() (*Config, error) {
options = &Config{}
flag.BoolVar(&options.ai, "ai", false, "Enable AI")
flag.StringVar(&options.aiApiKey, "ai-api-key", "", "AI API key")
flag.StringVar(&options.aiApiUrl, "ai-api-url", "http://localhost:11434/v1/", "AI API base url")
flag.StringVar(&options.aiEmbeddingModel, "ai-embedding-model", "all-minilm", "AI embedding model")
flag.StringVar(&options.aiLlmModel, "ai-llm-model", "gemma2", "AI LLM model")
flag.StringVar(&options.aiApiUrl, "ai-api-url", "", "AI API base url")
flag.StringVar(&options.aiModel, "ai-model", "all-minilm", "AI embedding model")
flag.BoolVar(&options.aiSync, "ai-sync", false, "AI generate embeddings")

flag.BoolVar(&options.cli, "cli", false, "Interactive search")

flag.StringVar(&options.dbPath, "db", "wikilite.db", "SQLite database path")
flag.BoolVar(&options.dbOptimize, "db-optimize", false, "Optimize database")
flag.BoolVar(&options.dbSyncEmbeddings, "db-sync-embeddings", false, "Sync database embeddings")
flag.BoolVar(&options.dbSyncFTS, "db-sync-fts", false, "Sync database full text search")

flag.StringVar(&options.language, "language", "en", "Language")
flag.IntVar(&options.limit, "limit", 5, "Maximum number of search results")
Expand Down Expand Up @@ -85,7 +78,7 @@ func parseConfig() (*Config, error) {
func main() {
options, err := parseConfig()
if err != nil {
log.Fatalf("Error parsing command line: %v\n\n", err)
log.Fatalf("Error parsing command line: %v\n", err)
}

if flag.NFlag() == 0 {
Expand All @@ -111,25 +104,24 @@ func main() {
}
defer db.Close()

if options.dbOptimize || options.dbSyncFTS || options.dbSyncEmbeddings || options.wikiImport != "" {
if err := aiInit(); err != nil {
log.Printf("AI initialization error: %v\n", err)
} else {
ai = true
}

if options.aiSync || options.wikiImport != "" {
if err := db.PragmaImportMode(); err != nil {
log.Fatalf("Error setting database in import mode: %v\n", err)
}

if options.wikiImport != "" {
if err = WikiImport(options.wikiImport); err != nil {
log.Fatalf("Error processing import: %v\n", err)
}
options.dbOptimize = true
options.dbSyncFTS = true
}

if options.dbOptimize {
if err := db.Optimize(); err != nil {
log.Fatalf("Error during database optimization: %v\n", err)
}
}

if options.dbSyncFTS {
if err := db.ProcessTitles(); err != nil {
log.Fatalf("Error processing FTS titles: %v\n", err)
}
Expand All @@ -138,7 +130,7 @@ func main() {
}
}

if options.dbSyncEmbeddings {
if ai && options.aiSync {
if err := db.ProcessEmbeddings(); err != nil {
log.Fatalf("Error processing embeddings: %v\n", err)
}
Expand Down
2 changes: 1 addition & 1 deletion search.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func Search(query string, limit int) ([]SearchResult, error) {
results = append(results, content)
}

if options.ai {
if ai {
log.Println("Vectors searching", query)
vectors, err := db.SearchVectors(query, limit)
if err != nil {
Expand Down
10 changes: 10 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"crypto/md5"
"embed"
"encoding/hex"
"fmt"
"io"
"os"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -83,3 +85,11 @@ func TextDeflate(text string) []byte {

return out.Bytes()
}

func MuteStderr() (*os.File, error) {
devNull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
if err != nil {
return nil, fmt.Errorf("failed to mute stderr: %v", err)
}
return devNull, nil
}
4 changes: 2 additions & 2 deletions web.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2024 by Ubaldo Porcheddu <ubaldo@eja.it>
// Copyright (C) 2024-2025 by Ubaldo Porcheddu <ubaldo@eja.it>

package main

Expand Down Expand Up @@ -177,7 +177,7 @@ func (s *WebServer) handleAPISearchContent(w http.ResponseWriter, r *http.Reques
}

func (s *WebServer) handleAPISearchVectors(w http.ResponseWriter, r *http.Request) {
if !options.ai {
if !ai {
s.sendAPIError(w, "Vector search is not enabled", http.StatusBadRequest)
return
}
Expand Down

0 comments on commit 7a958a4

Please sign in to comment.