Commit

Merge pull request #6 from mnemonica-ai/development
Release 0.0.6
p1nox authored May 3, 2024
2 parents e454f17 + 5785191 commit 7e8a052
Showing 13 changed files with 187 additions and 32 deletions.
48 changes: 36 additions & 12 deletions README.md
@@ -1,8 +1,8 @@
# oshepherd

> The Oshepherd guiding the Ollama(s) inference orchestration.
> _The Oshepherd guiding the Ollama(s) inference orchestration._
A centralized [Flask](https://flask.palletsprojects.com) API service, using [Celery](https://docs.celeryq.dev) ([RabbitMQ](https://www.rabbitmq.com) + [Redis](https://redis.com)) to orchestrate multiple [Ollama](https://ollama.com) workers.
A centralized [Flask](https://flask.palletsprojects.com) API service, using [Celery](https://docs.celeryq.dev) ([RabbitMQ](https://www.rabbitmq.com) + [Redis](https://redis.com)) to orchestrate multiple [Ollama](https://ollama.com) servers as workers.

### Install

@@ -14,9 +14,7 @@ pip install oshepherd

1. Setup RabbitMQ and Redis:

Create instances for free for both:
* [cloudamqp.com](https://www.cloudamqp.com)
* [redislabs.com](https://app.redislabs.com)
[Celery](https://docs.celeryq.dev) uses [RabbitMQ](https://docs.celeryq.dev/en/stable/getting-started/backends-and-brokers/index.html#rabbitmq) as the message broker and [Redis](https://docs.celeryq.dev/en/stable/getting-started/backends-and-brokers/index.html#redis) as the result backend, so you'll need to create one instance of each. You can create small instances for free on [cloudamqp.com](https://www.cloudamqp.com) and [redislabs.com](https://app.redislabs.com) respectively (see the Celery wiring sketch after these steps).

2. Setup Flask API Server:

@@ -43,17 +43,47 @@ pip install oshepherd
oshepherd start-worker --env-file .worker.env
```
4. Done. Now you're ready to execute Ollama completions remotely. You can point your Ollama client to your oshepherd api server by setting the `host`, and it will return your requested completions from any of the workers:
* [ollama-python](https://github.com/ollama/ollama-python) client:
```python
import ollama
client = ollama.Client(host="http://127.0.0.1:5001")
ollama_response = client.generate(model="mistral", prompt="Why is the sky blue?")
```
* [ollama-js](https://github.com/ollama/ollama-js) client:
```javascript
import { Ollama } from "ollama/browser";
const ollama = new Ollama({ host: "http://127.0.0.1:5001" });
const ollamaResponse = await ollama.generate({
model: "mistral",
prompt: "Why is the sky blue?",
});
```
* Raw http request:
```sh
curl -X POST -H "Content-Type: application/json" -L http://127.0.0.1:5001/api/generate/ -d '{
"model": "mistral",
"prompt":"Why is the sky blue?"
}'
```
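
Referring back to step 1: a minimal sketch of how Celery is typically wired to a RabbitMQ broker and a Redis result backend. The app name and URLs below are placeholders, not oshepherd's actual configuration, which comes from the `.env` files in steps 2 and 3.

```python
# Illustrative sketch only: wiring Celery to a RabbitMQ broker and a Redis
# result backend. The URLs below are placeholders; the variables oshepherd
# actually reads come from the .env files described in steps 2 and 3.
from celery import Celery

celery_app = Celery(
    "oshepherd-sketch",
    broker="amqp://user:password@your-rabbit.cloudamqp.com/vhost",
    backend="redis://default:password@your-redis.redislabs.com:6379/0",
)
```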
### Words of advice 🚨
This package is in alpha, its architecture and api might change in the near future. Currently this is getting tested in a closed environment by real users, but haven't been audited, nor tested thorugly. Use it at your own risk.
This package is in alpha; its architecture and api might change in the near future. It is currently being tested in a controlled environment by real users, but it hasn't been audited or tested thoroughly. Use it at your own risk.
### Disclaimer on Support
As this is an alpha version, support and responses might be limited. We'll do our best to address questions and issues as quickly as possible.
### Contribution Guidelines
We welcome contributions! If you find a bug or have suggestions for improvements, please open an issue or submit a pull request.
We welcome contributions! If you find a bug or have suggestions for improvements, please open an [issue](https://github.com/mnemonica-ai/oshepherd/issues) or submit a [pull request](https://github.com/mnemonica-ai/oshepherd/pulls). Before creating a new issue/pull request, take a moment to search through the existing issues/pull requests to avoid duplicates.
##### Conda Support
@@ -76,13 +76,9 @@ Follow usage instructions to start api server and celery worker using a local ol
pytest -s tests/
```
### Reporting Issues
Please report any issues you encounter on the GitHub issues page. Before creating a new issue, take a moment to search through the existing issues to avoid duplicates.
### Author
Currently, [mnemonica.ai](mnemonica.ai) is sponsoring the development of this tool.
This is a project developed and maintained by [mnemonica.ai](https://mnemonica.ai).
### License
2 changes: 2 additions & 0 deletions oshepherd/api/app.py
@@ -1,6 +1,7 @@
from flask import Flask, Blueprint
from oshepherd.api.config import ApiConfig
from oshepherd.api.generate.routes import generate_blueprint
from oshepherd.api.chat.routes import chat_blueprint
from oshepherd.worker.app import create_celery_app_for_flask


@@ -19,6 +20,7 @@ def start_flask_app(config: ApiConfig):
# endpoints
api = Blueprint("api", __name__)
api.register_blueprint(generate_blueprint)
api.register_blueprint(chat_blueprint)
app.register_blueprint(api)

app.run(
Empty file added oshepherd/api/chat/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions oshepherd/api/chat/models.py
@@ -0,0 +1,38 @@
from pydantic import BaseModel
from typing import Optional, List, Literal
from datetime import datetime


class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"]
content: str

# TODO add support for images
# images: NotRequired[Sequence[Any]]


class ChatRequestPayload(BaseModel):
model: str
messages: Optional[List[ChatMessage]] = None
format: Optional[str] = ""
options: Optional[dict] = {}
stream: Optional[bool] = False
keep_alive: Optional[str] = None


class ChatRequest(BaseModel):
type: str = "chat"
payload: ChatRequestPayload


class ChatResponse(BaseModel):
model: str
created_at: datetime
message: ChatMessage
done: bool
total_duration: int
load_duration: int
prompt_eval_count: Optional[int] = None
prompt_eval_duration: int
eval_count: int
eval_duration: int
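
As an aside (not part of this diff), a small sketch of how these models wrap an incoming request body into the envelope that is sent through the broker; the `body` dict is just an example payload:

```python
# Sketch: wrapping a raw /api/chat request body into the broker envelope,
# mirroring what chat/routes.py does with the models above.
from oshepherd.api.chat.models import ChatRequest

body = {
    "model": "mistral",
    "messages": [{"role": "user", "content": "Why is the sky blue?"}],
}
chat_request = ChatRequest(**{"payload": body})
# Serialized with type="chat", so the worker knows which Ollama call to make:
print(chat_request.model_dump_json())
```

The `type` field is what lets the single `exec_completion` task route to either `ollama.generate` or `ollama.chat` on the worker side.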
43 changes: 43 additions & 0 deletions oshepherd/api/chat/routes.py
@@ -0,0 +1,43 @@
"""
Generate a chat completion
API implementation of the `POST /api/chat` endpoint, handling completion orchestration as a replica of the same Ollama server endpoint.
Ollama endpoint reference: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
"""

import time
from flask import Blueprint, request
from oshepherd.api.utils import streamify_json
from oshepherd.api.chat.models import ChatRequest

chat_blueprint = Blueprint("chat", __name__, url_prefix="/api/chat")


@chat_blueprint.route("/", methods=["POST"])
def chat():
from oshepherd.worker.tasks import exec_completion

print(f" # request.json {request.json}")
chat_request = ChatRequest(**{"payload": request.json})

# req as json string ready to be sent through broker
chat_request_json_str = chat_request.model_dump_json()
print(f" # chat request {chat_request_json_str}")

# queue request to remote ollama api server
task = exec_completion.delay(chat_request_json_str)
while not task.ready():
print(" > waiting for response...")
time.sleep(1)
ollama_res = task.get(timeout=1)

status = 200
if ollama_res.get("error"):
ollama_res = {
"error": "Internal Server Error",
"message": f"error executing completion: {ollama_res['error']['message']}",
}
status = 500

print(f" $ ollama response {status}: {ollama_res}")

return streamify_json(ollama_res, status)
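
A usage sketch for this new endpoint (not part of the diff), assuming the oshepherd API server is listening on `127.0.0.1:5001` as in the README and the e2e tests:

```python
# Sketch: calling the new /api/chat endpoint with the ollama-python client,
# assuming the oshepherd API server is running on 127.0.0.1:5001.
import ollama

client = ollama.Client(host="http://127.0.0.1:5001")
res = client.chat(
    model="mistral",
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
)
print(res["message"]["content"])
```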
7 changes: 6 additions & 1 deletion oshepherd/api/generate/models.py
@@ -2,7 +2,7 @@
from typing import Optional, List


class GenerateRequest(BaseModel):
class GenerateRequestPayload(BaseModel):
model: str
prompt: str
images: Optional[List[str]] = None
@@ -16,6 +16,11 @@ class GenerateRequest(BaseModel):
keep_alive: Optional[str] = "5m"


class GenerateRequest(BaseModel):
type: str = "generate"
payload: GenerateRequestPayload


class GenerateResponse(BaseModel):
model: str
created_at: str
21 changes: 13 additions & 8 deletions oshepherd/api/generate/routes.py
@@ -1,3 +1,9 @@
"""
Generate a completion
API implementation of the `POST /api/generate` endpoint, handling completion orchestration as a replica of the same Ollama server endpoint.
Ollama endpoint reference: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
"""

import time
from flask import Blueprint, request
from oshepherd.api.utils import streamify_json
@@ -8,18 +14,17 @@

@generate_blueprint.route("/", methods=["POST"])
def generate():
# TODO use .send_task() instead?
from oshepherd.worker.tasks import make_generate_request
from oshepherd.worker.tasks import exec_completion

print(f" # request.json {request.json}")
generate_request = GenerateRequest(**request.json)
generate_request = GenerateRequest(**{"payload": request.json})

# req as json string ready to be sent though broker
# req as json string ready to be sent through broker
generate_request_json_str = generate_request.model_dump_json()
print(generate_request_json_str)
print(f" # generate request {generate_request_json_str}")

# queue request to remote ollama api server though
task = make_generate_request.delay(generate_request_json_str)
# queue request to remote ollama api server
task = exec_completion.delay(generate_request_json_str)
while not task.ready():
print(" > waiting for response...")
time.sleep(1)
@@ -29,7 +34,7 @@ def generate():
if ollama_res.get("error"):
ollama_res = {
"error": "Internal Server Error",
"message": f"error triggering llm inference: {ollama_res['error']['message']}",
"message": f"error executing completion: {ollama_res['error']['message']}",
}
status = 500

2 changes: 2 additions & 0 deletions oshepherd/cli.py
@@ -39,6 +39,8 @@ def start_worker(env_file):
loglevel=config.LOGLEVEL,
concurrency=config.CONCURRENCY,
prefetch_multiplier=config.PREFETCH_MULTIPLIER,
redis_retry_on_timeout=config.REDIS_RETRY_ON_TIMEOUT,
redis_socket_keepalive=config.REDIS_SOCKET_KEEPALIVE,
)
worker.start()

2 changes: 2 additions & 0 deletions oshepherd/worker/config.py
@@ -15,3 +15,5 @@ class WorkerConfig(BaseModel):
"interval_step": 0.1,
"interval_max": 0.5,
}
REDIS_RETRY_ON_TIMEOUT: bool = True
REDIS_SOCKET_KEEPALIVE: bool = True
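
These two flags carry the same names as Celery's Redis result-backend settings (`redis_retry_on_timeout`, `redis_socket_keepalive`). How the oshepherd `Worker` applies them is not shown in this hunk; the sketch below only illustrates the equivalent plain-Celery configuration, with placeholder broker/backend URLs:

```python
# Illustrative sketch: the Celery settings these worker config flags correspond to.
# How oshepherd's Worker class passes them through is not shown in this diff.
from celery import Celery

celery_app = Celery(
    "oshepherd-sketch",
    broker="amqp://localhost//",
    backend="redis://localhost:6379/0",
)
celery_app.conf.update(
    redis_retry_on_timeout=True,   # retry Redis result-backend operations that hit a timeout
    redis_socket_keepalive=True,   # enable TCP keepalive on the Redis connection
)
```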
17 changes: 11 additions & 6 deletions oshepherd/worker/tasks.py
@@ -5,27 +5,32 @@


@celery_app.task(
name="oshepherd.worker.tasks.make_generate_request",
name="oshepherd.worker.tasks.exec_completion",
bind=True,
base=OllamaCeleryTask,
)
def make_generate_request(self, request_str: str):
def exec_completion(self, request_str: str):
try:
request = json.loads(request_str)
print(f"# make_generate_request request {request}")
print(f"# exec_completion request {request}")
req_type = request["type"]
req_payload = request["payload"]

if req_type == "generate":
response = ollama.generate(**req_payload)
elif req_type == "chat":
response = ollama.chat(**req_payload)

response = ollama.generate(**request)
print(f" $ success {response}")
except Exception as error:
print(f" * error make_generate_request {error}")
print(f" * error exec_completion {error}")
response = {
"error": {"type": str(error.__class__.__name__), "message": str(error)}
}
print(
f" * error response {response}",
)

# Rethrow exception in order to be handled by base class
raise

return response
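
To make the dispatch concrete: the task consumes the JSON envelope built by the API routes. A sketch of queueing one by hand (not part of the diff; assumes the broker, backend, and at least one worker are running):

```python
# Sketch: queueing a completion the same way the Flask routes in this diff do.
# Requires a configured broker/backend and a running oshepherd worker.
import json
from oshepherd.worker.tasks import exec_completion

envelope = json.dumps({
    "type": "generate",  # or "chat", with a matching messages payload
    "payload": {"model": "mistral", "prompt": "Why is the sky blue?"},
})
task = exec_completion.delay(envelope)
print(task.get(timeout=60))  # Ollama's response dict, or an {"error": ...} payload
```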
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "oshepherd"
version = "0.0.5"
version = "0.0.6"
description = "The Oshepherd guiding the Ollama(s) inference orchestration."
readme = "README.md"
authors = [
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name="oshepherd",
version="0.0.5",
version="0.0.6",
description="The Oshepherd guiding the Ollama(s) inference orchestration.",
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
35 changes: 32 additions & 3 deletions tests/test_e2e_api.py
@@ -1,5 +1,6 @@
"""
Basic end to end tests, using an equivalent request as pointing to ollama api server in local:
Basic end to end tests, using the Ollama Python package and its equivalent HTTP requests to a locally running Ollama server,
e.g.:
curl -X POST -H "Content-Type: application/json" -L http://127.0.0.1:11434/api/generate/ -d '{
"model": "mistral",
"prompt":"Why is the sky blue?"
@@ -10,9 +11,10 @@
import requests
import ollama
from oshepherd.api.generate.models import GenerateResponse
from oshepherd.api.chat.models import ChatResponse


def test_basic_api_worker_queueing_using_ollama():
def test_basic_generate_completion_using_ollama():
params = {"model": "mistral", "prompt": "Why is the sky blue?"}
client = ollama.Client(host="http://127.0.0.1:5001")
ollama_res = client.generate(**params)
@@ -21,7 +23,7 @@ def test_basic_api_worker_queueing_using_ollama():
assert ollama_res.response, "response should not be empty"


def test_basic_api_worker_queueing_using_requests():
def test_basic_generate_completion_using_requests():
url = "http://127.0.0.1:5001/api/generate/"
headers = {"Content-Type": "application/json"}
data = {"model": "mistral", "prompt": "Why is the sky blue?"}
@@ -31,3 +33,30 @@ def test_basic_api_worker_queueing_using_requests():
assert "error" not in response
ollama_res = GenerateResponse(**response.json())
assert ollama_res.response, "response should not be empty"


def test_basic_chat_completion_using_ollama():
params = {
"model": "mistral",
"messages": [{"role": "user", "content": "why is the sky blue?"}],
}
client = ollama.Client(host="http://127.0.0.1:5001")
ollama_res = client.chat(**params)

ollama_res = ChatResponse(**ollama_res)
assert ollama_res.message.content, "response should not be empty"


def test_basic_chat_completion_using_requests():
url = "http://127.0.0.1:5001/api/chat/"
headers = {"Content-Type": "application/json"}
data = {
"model": "mistral",
"messages": [{"role": "user", "content": "why is the sky blue?"}],
}
response = requests.post(url, headers=headers, data=json.dumps(data))

assert response.status_code == 200
assert "error" not in response
ollama_res = ChatResponse(**response.json())
assert ollama_res.message.content, "response should not be empty"
