Skip to content

Commit

Permalink
Merge pull request #190 from brevdev/ollama-simplify-further
Browse files Browse the repository at this point in the history
ollama: arbitrary model loading
  • Loading branch information
ishandhanani authored May 17, 2024
2 parents 13c8398 + ca9c632 commit c5575b2
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 20 deletions.
88 changes: 70 additions & 18 deletions pkg/cmd/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ package ollama
import (
_ "embed"
"fmt"
"os"
"os/exec"
"strings"
"time"

"github.com/google/uuid"

"github.com/brevdev/brev-cli/pkg/cmd/refresh"
"github.com/brevdev/brev-cli/pkg/cmd/util"
"github.com/brevdev/brev-cli/pkg/collections"
"github.com/brevdev/brev-cli/pkg/config"
Expand All @@ -25,13 +29,13 @@ var (
ollamaExample = `
brev ollama --model llama3
`
modelTypes = []string{"llama3"}
)

//go:embed ollamaverb.yaml
var verbYaml string

type OllamaStore interface {
refresh.RefreshStore
util.GetWorkspaceByNameOrIDErrStore
GetActiveOrganizationOrDefault() (*entity.Organization, error)
GetCurrentUser() (*entity.User, error)
Expand All @@ -41,13 +45,26 @@ type OllamaStore interface {
ModifyPublicity(workspace *entity.Workspace, applicationName string, publicity bool) (*entity.Tunnel, error)
}

func validateModelType(modelType string) bool {
for _, v := range modelTypes {
if modelType == v {
return true
}
func validateModelType(input string) (bool, error) {
var model string
var tag string

split := strings.Split(input, ":")
switch len(split) {
case 2:
model = split[0]
tag = split[1]
case 1:
model = input
tag = "latest"
default:
return false, fmt.Errorf("invalid model type: %s", input)
}
valid, err := store.ValidateOllamaModel(model, tag)
if err != nil {
return false, fmt.Errorf("error validating model: %s", err)
}
return false
return valid, nil
}

func NewCmdOllama(t *terminal.Terminal, ollamaStore OllamaStore) *cobra.Command {
Expand All @@ -67,7 +84,10 @@ func NewCmdOllama(t *terminal.Terminal, ollamaStore OllamaStore) *cobra.Command
return fmt.Errorf("model type must be specified")
}

isValid := validateModelType(model)
isValid, valErr := validateModelType(model)
if valErr != nil {
return valErr
}
if !isValid {
return fmt.Errorf("invalid model type: %s", model)
}
Expand Down Expand Up @@ -115,7 +135,7 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaSt
instanceName := fmt.Sprintf("ollama-%s", uuid)
cwOptions := store.NewCreateWorkspacesOptions(clusterID, instanceName).WithInstanceType(instanceType)

hello.TypeItToMeUnskippable27(fmt.Sprintf("Creating Ollama server %s with model %s in org %s", t.Green(cwOptions.Name), t.Green(model), t.Green(org.ID)))
hello.TypeItToMeUnskippable27(fmt.Sprintf("Creating Ollama server %s with model %s in org %s\n", t.Green(cwOptions.Name), t.Green(model), t.Green(org.ID)))

s := t.NewSpinner()

Expand All @@ -140,14 +160,15 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaSt
if err != nil {
return breverrors.WrapAndTrace(err)
}

if !vmStatus {
return breverrors.New("instance did not start")
}
s.Stop()
hello.TypeItToMeUnskippable27(fmt.Sprintf("VM is ready!\n"))

Check failure on line 167 in pkg/cmd/ollama/ollama.go

View workflow job for this annotation

GitHub Actions / ci (ubuntu-20.04)

S1039: unnecessary use of fmt.Sprintf (gosimple)

Check failure on line 167 in pkg/cmd/ollama/ollama.go

View workflow job for this annotation

GitHub Actions / goreleaser

S1039: unnecessary use of fmt.Sprintf (gosimple)
s.Start()

// sleep for 10 seconds to solve for possible race condition
// TODO: look into timing of verb call
time.Sleep(time.Second * 10)
time.Sleep(time.Second * 5)

verbBuildRes := collections.Async(func() (*store.BuildVerbRes, error) {
lf, errr := ollamaStore.BuildVerbContainer(w.ID, verbYaml)
Expand All @@ -157,8 +178,7 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaSt
return lf, nil
})

s.Start()
s.Suffix = "Starting the Ollama server. Hang tight 🤙"
s.Suffix = " Building the Ollama container. Hang tight 🤙"

_, err = verbBuildRes.Await()
if err != nil {
Expand All @@ -175,9 +195,15 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaSt
if !vstatus {
return breverrors.New("verb container did not build correctly")
}

s.Stop()

s = t.NewSpinner()
s.Suffix = "(connectivity) Pulling the %s model, just a bit more! 🏄"

// shell in and run ollama pull:
if err := refresh.RunRefresh(ollamaStore); err != nil {

Check failure on line 204 in pkg/cmd/ollama/ollama.go

View workflow job for this annotation

GitHub Actions / ci (ubuntu-20.04)

shadow: declaration of "err" shadows declaration at line 121 (govet)

Check failure on line 204 in pkg/cmd/ollama/ollama.go

View workflow job for this annotation

GitHub Actions / goreleaser

shadow: declaration of "err" shadows declaration at line 121 (govet)
return breverrors.WrapAndTrace(err)
}
// Reload workspace to get the latest status
w, err = ollamaStore.GetWorkspace(w.ID)
if err != nil {
Expand All @@ -194,15 +220,26 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaSt
return breverrors.WrapAndTrace(err)
}

s.Suffix = "Pulling the %s model, just a bit more! 🏄"

// shell in and run ollama pull:
if err := runSSHExec(instanceName, []string{"ollama", "pull", model}, false); err != nil {
return breverrors.WrapAndTrace(err)
}
if err := runSSHExec(instanceName, []string{"ollama", "run", model, "hello world"}, true); err != nil {
return breverrors.WrapAndTrace(err)
}
s.Stop()

fmt.Print("\n")
t.Vprint(t.Green("Ollama is ready to go!\n"))
displayOllamaConnectBreadCrumb(t, link)
displayOllamaConnectBreadCrumb(t, link, model)
return nil
}

func displayOllamaConnectBreadCrumb(t *terminal.Terminal, link string) {
func displayOllamaConnectBreadCrumb(t *terminal.Terminal, link string, model string) {
t.Vprintf(t.Green("Query the Ollama API with the following command:\n"))
t.Vprintf(t.Yellow(fmt.Sprintf("curl %s/api/chat -d '{\n \"model\": \"llama3\",\n \"messages\": [\n {\n \"role\": \"user\",\n \"content\": \"why is the sky blue?\"\n }\n ]\n}'", link)))
t.Vprintf(t.Yellow(fmt.Sprintf("curl %s/api/chat -d '{\n \"model\": \"%s\",\n \"messages\": [\n {\n \"role\": \"user\",\n \"content\": \"why is the sky blue?\"\n }\n ]\n}'\n", link, model)))
}

func pollInstanceUntilVMReady(workspace *entity.Workspace, interval time.Duration, timeout time.Duration, ollamaStore OllamaStore) (bool, error) {
Expand Down Expand Up @@ -267,3 +304,18 @@ func makeTunnelPublic(workspace *entity.Workspace, applicationName string, ollam
}
return false, breverrors.New("Could not find Ollama tunnel")
}

func runSSHExec(sshAlias string, args []string, fireAndForget bool) error {
sshAgentEval := "eval $(ssh-agent -s)"
cmd := fmt.Sprintf("ssh %s -- %s", sshAlias, strings.Join(args, " "))
cmd = fmt.Sprintf("%s && %s", sshAgentEval, cmd)
sshCmd := exec.Command("bash", "-c", cmd) //nolint:gosec //cmd is user input

if fireAndForget {
return sshCmd.Start()

Check failure on line 315 in pkg/cmd/ollama/ollama.go

View workflow job for this annotation

GitHub Actions / ci (ubuntu-20.04)

error returned from external package is unwrapped: sig: func (*os/exec.Cmd).Start() error (wrapcheck)

Check failure on line 315 in pkg/cmd/ollama/ollama.go

View workflow job for this annotation

GitHub Actions / goreleaser

error returned from external package is unwrapped: sig: func (*os/exec.Cmd).Start() error (wrapcheck)
}
sshCmd.Stderr = os.Stderr
sshCmd.Stdout = os.Stdout
sshCmd.Stdin = os.Stdin
return sshCmd.Run()

Check failure on line 320 in pkg/cmd/ollama/ollama.go

View workflow job for this annotation

GitHub Actions / ci (ubuntu-20.04)

error returned from external package is unwrapped: sig: func (*os/exec.Cmd).Run() error (wrapcheck)

Check failure on line 320 in pkg/cmd/ollama/ollama.go

View workflow job for this annotation

GitHub Actions / goreleaser

error returned from external package is unwrapped: sig: func (*os/exec.Cmd).Run() error (wrapcheck)
}
Empty file.
3 changes: 1 addition & 2 deletions pkg/cmd/ollama/ollamaverb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ build:
- jupyterlab
run:
- curl -fsSL https://ollama.com/install.sh | sh
- ollama serve & sleep 10; ollama pull llama3; echo "kill 'ollama serve' process"; ps -ef | grep 'ollama serve' | grep -v grep | awk '{print $2}' | xargs -r kill -9
user:
shell: zsh
authorized_keys_path: /home/ubuntu/.ssh/authorized_keys
Expand All @@ -16,4 +15,4 @@ services:
- name: ollama-server
entrypoint: OLLAMA_HOST=0.0.0.0 ollama serve
ports:
- 127.0.0.1:11434:11434
- 127.0.0.1:11434:11434
5 changes: 5 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const (
defaultWorkspaceTemplate EnvVarName = "DEFAULT_WORKSPACE_TEMPLATE"
sentryURL EnvVarName = "DEFAULT_SENTRY_URL"
debugHTTP EnvVarName = "DEBUG_HTTP"
ollamaAPIURL EnvVarName = "OLLAMA_API_URL"
)

type ConstantsConfig struct{}
Expand All @@ -27,6 +28,10 @@ func (c ConstantsConfig) GetBrevAPIURl() string {
return getEnvOrDefault(brevAPIURL, "https://brevapi.us-west-2-prod.control-plane.brev.dev")
}

func (c ConstantsConfig) GetOllamaAPIURL() string {
return getEnvOrDefault(ollamaAPIURL, "https://registry.ollama.ai")
}

func (c ConstantsConfig) GetServiceMeshCoordServerURL() string {
return getEnvOrDefault(coordURL, "")
}
Expand Down
12 changes: 12 additions & 0 deletions pkg/store/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ type NoAuthHTTPClient struct {
restyClient *resty.Client
}

type OllamaHTTPClient struct {
restyClient *resty.Client
}

func NewOllamaHTTPClient(ollamaAPIURL string) *OllamaHTTPClient {
restyClient := resty.New().SetBaseURL(ollamaAPIURL)

return &OllamaHTTPClient{
restyClient: restyClient,
}
}

func NewNoAuthHTTPClient(brevAPIURL string) *NoAuthHTTPClient {
restyClient := NewRestyClient(brevAPIURL)
return &NoAuthHTTPClient{restyClient}
Expand Down
75 changes: 75 additions & 0 deletions pkg/store/workspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
breverrors "github.com/brevdev/brev-cli/pkg/errors"
"github.com/brevdev/brev-cli/pkg/setupscript"
"github.com/brevdev/brev-cli/pkg/uri"
resty "github.com/go-resty/resty/v2"
"github.com/google/uuid"
"github.com/spf13/afero"
)
Expand Down Expand Up @@ -647,3 +648,77 @@ func (s AuthHTTPStore) ModifyPublicity(workspace *entity.Workspace, applicationN
}
return &result, nil
}

type OllamaRegistrySuccessResponse struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config OllamaConfig `json:"config"`
Layers []OllamaLayer `json:"layers"`
}

type OllamaConfig struct {
MediaType string `json:"mediaType"`
Size int `json:"size"`
Digest string `json:"digest"`
}

type OllamaLayer struct {
MediaType string `json:"mediaType"`
Size int `json:"size"`
Digest string `json:"digest"`
}

type OllamaRegistryFailureResponse struct {
Errors []OllamaRegistryError `json:"errors"`
}

type OllamaRegistryError struct {
Code string `json:"code"`
Message string `json:"message"`
Detail OllamaRegistryErrorDetail `json:"detail"`
}

type OllamaRegistryErrorDetail struct {
Tag string `json:"Tag"`
}

type OllamaModelRequest struct {
Model string
Tag string
}

var (
modelNameParamName = "modelName"
tagNameParamName = "tagName"
ollamaModelPathPattern = "v2/library/%s/manifests/%s"
ollamaModelPath = fmt.Sprintf(ollamaModelPathPattern, fmt.Sprintf("{%s}", modelNameParamName), fmt.Sprintf("{%s}", tagNameParamName))
)

func ValidateOllamaModel(model string, tag string) (bool, error) {
restyClient := resty.New().SetBaseURL(config.NewConstants().GetOllamaAPIURL())
if tag == "" {
tag = "latest"
}
res, err := restyClient.R().
SetHeader("Accept", "application/vnd.docker.distribution.manifest.v2+json").
SetPathParam(modelNameParamName, model).
SetPathParam(tagNameParamName, tag).
Get(ollamaModelPath)

if err != nil {
return false, breverrors.WrapAndTrace(err)
}
if res.StatusCode() == 200 { //nolint:gocritic // 200 is a valid status code
if err := json.Unmarshal(res.Body(), &OllamaRegistrySuccessResponse{}); err != nil {
return false, breverrors.WrapAndTrace(err)
}
return true, nil
} else if res.StatusCode() == 404 {
if err := json.Unmarshal(res.Body(), &OllamaRegistryFailureResponse{}); err != nil {
return false, breverrors.WrapAndTrace(err)
}
return false, nil
} else {
return false, breverrors.New("invalid response from ollama registry")
}
}
29 changes: 29 additions & 0 deletions pkg/store/workspace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,32 @@ func TestDeleteWorkspace(t *testing.T) { //nolint:dupl // ok to have this be dup
return
}
}

func TestValidateOllamaModel(t *testing.T) {
type args struct {
model string
tag string
}
tests := []struct {
name string
args args
want bool
wantErr bool
}{
{"empty", args{"", ""}, false, false},
{"llama3", args{"llama3", ""}, true, false},
{"llama3:80b", args{"llama3", "80b"}, false, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ValidateOllamaModel(tt.args.model, tt.args.tag)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateOllamaModel() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ValidateOllamaModel() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit c5575b2

Please sign in to comment.