Skip to content

Commit

Permalink
Supports /embed endpoint. Ref #7
Browse files Browse the repository at this point in the history
  • Loading branch information
lebrunel committed Aug 12, 2024
1 parent 020d47c commit d38f791
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 1 deletion.
58 changes: 57 additions & 1 deletion lib/ollama.ex
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ defmodule Ollama do
],
context: [
type: {:list, {:or, [:integer, :float]}},
doc: "The context parameter returned from a previous `f:completion/2` call (enabling short conversational memory).",
doc: "The context parameter returned from a previous `completion/2` call (enabling short conversational memory).",
],
format: [
type: :string,
Expand Down Expand Up @@ -842,6 +842,61 @@ defmodule Ollama do
end


# NimbleOptions schema for `embed/2` request parameters. Validated in
# `embed/2` via `NimbleOptions.validate/2` and interpolated into the
# function's docs with `doc(:embed)`. Key order is preserved in the
# generated docs, so do not reorder.
schema :embed, [
model: [
type: :string,
required: true,
doc: "The name of the model used to generate the embeddings.",
],
input: [
# Accepts a single string or a list of strings; the API returns one
# embedding per input.
type: {:or, [:string, {:list, :string}]},
required: true,
doc: "Text or list of text to generate embeddings for.",
],
truncate: [
type: :boolean,
doc: "Truncates the end of each input to fit within context length.",
],
keep_alive: [
# Either seconds as an integer or a duration string (e.g. "5m").
type: {:or, [:integer, :string]},
doc: "How long to keep the model loaded.",
],
options: [
type: {:map, {:or, [:atom, :string]}, :any},
doc: "Additional advanced [model parameters](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).",
],
]

@doc """
Generates embeddings from a model for the given input. The input may be a
single string or a list of strings; the response contains one embedding per
input under the `"embeddings"` key.

## Options

#{doc(:embed)}

## Example

    iex> Ollama.embed(client, [
    ...>   model: "nomic-embed-text",
    ...>   input: ["Why is the sky blue?", "Why is the grass green?"],
    ...> ])
    {:ok, %{"embeddings" => [
      [0.009724553, 0.04449892, -0.14063916, 0.0013168337, 0.032128844,
        0.10730086, -0.008447222, 0.010106917, 5.2289694e-4, -0.03554127, ...],
      [0.028196355, 0.043162502, -0.18592504, 0.035034444, 0.055619627,
        0.12082449, -0.0090096295, 0.047170386, -0.032078084, 0.0047163847, ...]
    ]}}
"""
@spec embed(client(), keyword()) :: response()
def embed(%__MODULE__{} = client, params) when is_list(params) do
  # Validate against the :embed schema before hitting the API; a failed
  # validation short-circuits the `with` and returns the
  # {:error, %NimbleOptions.ValidationError{}} tuple as-is.
  with {:ok, params} <- NimbleOptions.validate(params, schema(:embed)) do
    client
    |> req(:post, "/embed", json: Enum.into(params, %{}))
    |> res()
  end
end


schema :embeddings, [
model: [
type: :string,
Expand Down Expand Up @@ -881,6 +936,7 @@ defmodule Ollama do
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
]}}
"""
@deprecated "Superseded by embed/2"
@spec embeddings(client(), keyword()) :: response()
def embeddings(%__MODULE__{} = client, params) when is_list(params) do
with {:ok, params} <- NimbleOptions.validate(params, schema(:embeddings)) do
Expand Down
33 changes: 33 additions & 0 deletions test/ollama_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,39 @@ defmodule OllamaTest do
end
end

# Tests for Ollama.embed/2 against the mock /embed endpoint. The describe
# name follows the "fun/arity" convention used by the sibling
# "embeddings/2" block — the original "embed/1" misstated the arity.
describe "embed/2" do
  test "generates an embedding for a given input", %{client: client} do
    assert {:ok, res} = Ollama.embed(client, [
      model: "nomic-embed-text",
      input: "Why is the sky blue?",
    ])

    assert res["model"] == "nomic-embed-text"
    assert is_list(res["embeddings"])
    assert length(res["embeddings"]) == 1
    assert Enum.all?(res["embeddings"], &is_list/1)
  end

  test "generates an embedding for a list of input texts", %{client: client} do
    assert {:ok, res} = Ollama.embed(client, [
      model: "nomic-embed-text",
      input: ["Why is the sky blue?", "Why is the grass green?"],
    ])

    assert res["model"] == "nomic-embed-text"
    assert is_list(res["embeddings"])
    assert length(res["embeddings"]) == 2
    assert Enum.all?(res["embeddings"], &is_list/1)
  end

  test "returns error when model not found", %{client: client} do
    assert {:error, %HTTPError{status: 404}} = Ollama.embed(client, [
      model: "not-found",
      input: "Why is the sky blue?",
    ])
  end
end

describe "embeddings/2" do
test "generates an embedding for a given prompt", %{client: client} do
assert {:ok, res} = Ollama.embeddings(client, [
Expand Down
44 changes: 44 additions & 0 deletions test/support/mock_server.ex
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,42 @@ defmodule Ollama.MockServer do
}
""",

# truncated for simplicity
embed_one: """
{
"embeddings": [
[
0.009724553, 0.04449892, -0.14063916, 0.0013168337, 0.032128844,
0.10730086, -0.008447222, 0.010106917, 5.2289694e-4, -0.03554127
]
],
"load_duration": 1881917,
"model": "nomic-embed-text",
"prompt_eval_count": 8,
"total_duration": 48675959
}
""",

# truncated for simplicity
embed_many: """
{
"embeddings": [
[
0.009724553, 0.04449892, -0.14063916, 0.0013168337, 0.032128844,
0.10730086, -0.008447222, 0.010106917, 5.2289694e-4, -0.03554127
],
[
0.028196355, 0.043162502, -0.18592504, 0.035034444, 0.055619627,
0.12082449, -0.0090096295, 0.047170386, -0.032078084, 0.0047163847
]
],
"load_duration": 1902709,
"model": "nomic-embed-text",
"prompt_eval_count": 16,
"total_duration": 53473292
}
""",

embeddings: """
{
"embedding": [
Expand Down Expand Up @@ -339,6 +375,14 @@ defmodule Ollama.MockServer do
post "/blobs/:digest", do: respond(conn, 200)
post "/embeddings", do: handle_request(conn, :embeddings)

post "/embed" do
  # Dispatch on the shape of the request body: an unknown model gets a 404,
  # a single binary input returns one embedding, and a list of inputs
  # returns one embedding per element.
  case conn.body_params do
    %{"model" => "not-found"} -> respond(conn, 404)
    %{"input" => input} when is_binary(input) -> respond(conn, :embed_one)
    # Fixed guard: the original `is_list(input) > 1` compared a boolean to
    # an integer and only matched via Erlang term ordering (atom > number).
    %{"input" => input} when is_list(input) -> respond(conn, :embed_many)
  end
end

defp handle_request(conn, name) do
case conn.body_params do
%{"model" => "not-found"} -> respond(conn, 404)
Expand Down

0 comments on commit d38f791

Please sign in to comment.