Skip to content

Commit

Permalink
Save graph via the API instead of directly to the DB (#61)
Browse files Browse the repository at this point in the history
* Persist graph through API

* wrap up, clean up
  • Loading branch information
neutralino1 authored Jul 5, 2022
1 parent 4c1e19f commit 56d2298
Show file tree
Hide file tree
Showing 24 changed files with 219 additions and 266 deletions.
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ FROM --platform=linux/amd64 python:3.9-bullseye
RUN python3 -m pip install --upgrade pip

RUN pip install sematic
# When debugging use the wheel directly
# COPY sematic-*.whl .
# RUN pip install sematic-*.whl

EXPOSE 80
CMD python3 -m sematic.db.migrate --env cloud --verbose; python3 -m sematic.api.server --env cloud
CMD python3 -m sematic.db.migrate --env cloud; python3 -m sematic.api.server --env cloud
9 changes: 3 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,23 @@ clear_sqlite:

create_pg: start_db_container db_migrate_up

pre_commit:
pre-commit:
flake8
mypy sematic
black sematic --check

refresh_dependencies:
refresh-dependencies:
pip-compile --allow-unsafe requirements/requirements.in

test:
bazel test //sematic/... --test_output=all

build_ui:
ui:
cd sematic/ui; npm run build

server-image:
docker build -t sematicai/sematic-server:dev .

server_image_interpreter: build_server_image
docker run -it sematic-server python3

start:
cd sematic/api; docker compose up

Expand Down
11 changes: 10 additions & 1 deletion sematic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,23 @@ sematic_py_lib(
sematic_py_lib(
name = "config",
srcs = ["config.py"],
deps = [
":config_dir",
":user_settings",
]
)

sematic_py_lib(
name = "config_dir",
srcs = ["config_dir.py"],
deps = []
)

sematic_py_lib(
name = "user_settings",
srcs = ["user_settings.py"],
deps = [
":config",
":config_dir",
requirement("pyyaml"),
]
)
Expand Down
2 changes: 2 additions & 0 deletions sematic/api/endpoints/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ sematic_py_lib(
"//sematic/db:db",
"//sematic/db:queries",
"//sematic/db/models:run",
"//sematic/db/models:artifact",
"//sematic/db/models:edge",
":request_parameters",
requirement("flask"),
requirement("sqlalchemy"),
Expand Down
7 changes: 5 additions & 2 deletions sematic/api/endpoints/notes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@

# Sematic
from sematic.api.app import sematic_api
from sematic.api.endpoints.request_parameters import get_request_parameters, jsonify_404
from sematic.api.endpoints.request_parameters import (
get_request_parameters,
jsonify_error,
)
from sematic.db.models.note import Note
from sematic.db.models.run import Run
from sematic.db.db import db
Expand Down Expand Up @@ -75,7 +78,7 @@ def delete_note_endpoint(note_id: str) -> flask.Response:
try:
note = get_note(note_id)
except NoResultFound:
return jsonify_404("No such note: {}".format(note_id))
return jsonify_error("No such note: {}".format(note_id), HTTPStatus.NOT_FOUND)

delete_note(note)

Expand Down
5 changes: 3 additions & 2 deletions sematic/api/endpoints/request_parameters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Standard library
from http import HTTPStatus
from typing import Dict, Literal, Tuple, Optional, List, Union, cast
import json

Expand Down Expand Up @@ -77,10 +78,10 @@ def _none_if_empty(name: str) -> Optional[str]:
return limit, cursor, group_by_column, sql_predicates


def jsonify_404(error: str):
def jsonify_error(error: str, status: HTTPStatus):
return flask.Response(
json.dumps(dict(error=error)),
status=404,
status=status.value,
mimetype="application/json",
)

Expand Down
40 changes: 36 additions & 4 deletions sematic/api/endpoints/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# Standard library
import base64
from http import HTTPStatus
import typing
from urllib.parse import urlunsplit, urlencode, urlsplit

Expand All @@ -16,9 +17,14 @@
# Sematic
from sematic.api.app import sematic_api
from sematic.db.db import db
from sematic.db.models.artifact import Artifact
from sematic.db.models.edge import Edge
from sematic.db.models.run import Run
from sematic.db.queries import get_root_graph, get_run
from sematic.api.endpoints.request_parameters import get_request_parameters, jsonify_404
from sematic.db.queries import get_root_graph, get_run, save_graph
from sematic.api.endpoints.request_parameters import (
get_request_parameters,
jsonify_error,
)


@sematic_api.route("/api/v1/runs", methods=["GET"])
Expand Down Expand Up @@ -158,7 +164,9 @@ def get_run_endpoint(run_id: str) -> flask.Response:
try:
run = get_run(run_id)
except NoResultFound:
return jsonify_404("No runs with id {}".format(repr(run_id)))
return jsonify_error(
"No runs with id {}".format(repr(run_id)), HTTPStatus.NOT_FOUND
)

payload = dict(
content=run.to_json_encodable(),
Expand Down Expand Up @@ -196,11 +204,35 @@ def get_run_graph(run_id: str) -> flask.Response:


@sematic_api.route("/api/v1/events/<namespace>/<event>", methods=["POST"])
def graph_update(namespace: str, event: str) -> flask.Response:
def events(namespace: str, event: str) -> flask.Response:
flask_socketio.emit(
event,
flask.request.json,
namespace="/{}".format(namespace),
broadcast=True,
)
return flask.jsonify({})


@sematic_api.route("/api/v1/graph", methods=["PUT"])
def save_graph_endpoint():
if not flask.request or not flask.request.json or "graph" not in flask.request.json:
return jsonify_error(
"Please provide a graph payload in JSON format.",
HTTPStatus.BAD_REQUEST.value,
)

graph = flask.request.json["graph"]

runs = [Run.from_json_encodable(run) for run in graph["runs"]]
artifacts = [
Artifact.from_json_encodable(artifact) for artifact in graph["artifacts"]
]
edges = [Edge.from_json_encodable(edge) for edge in graph["edges"]]

# try:
save_graph(runs, artifacts, edges)
# except Exception as e:
# return jsonify_error(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)

return flask.jsonify({})
9 changes: 6 additions & 3 deletions sematic/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Third-party
import argparse
from flask import jsonify, send_file
from flask_socketio import SocketIO # type: ignore
from flask_socketio import SocketIO, Namespace # type: ignore

# Sematic
from sematic.api.app import sematic_api
Expand All @@ -16,7 +16,6 @@
import sematic.api.endpoints.edges # noqa: F401
import sematic.api.endpoints.artifacts # noqa: F401
from sematic.config import (
DEFAULT_ENV,
get_config,
switch_env,
) # noqa: F401
Expand Down Expand Up @@ -54,11 +53,15 @@ def ping():


socketio = SocketIO(sematic_api, cors_allowed_origins="*")
# This is necessary because starting version 5.7.0 python-socketio does not
# accept connections to undeclared namespaces
socketio.on_namespace(Namespace("/pipeline"))
socketio.on_namespace(Namespace("/graph"))


def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser("Sematic API server")
parser.add_argument("--env", required=False, default=DEFAULT_ENV, type=str)
parser.add_argument("--env", required=False, default="local", type=str)
parser.add_argument("--debug", required=False, default=False, action="store_true")
parser.add_argument("--daemon", required=False, default=False, action="store_true")
return parser.parse_args()
Expand Down
31 changes: 30 additions & 1 deletion sematic/api_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
# Standard library
from typing import Any
from typing import Any, List

# Third party
import requests

# Sematic
from sematic.config import get_config
from sematic.db.models.artifact import Artifact
from sematic.db.models.edge import Edge
from sematic.db.models.run import Run


def save_graph(runs: List[Run], artifacts: List[Artifact], edges: List[Edge]):
"""
Persist a graph.
"""
payload = {
"graph": {
"runs": [run.to_json_encodable() for run in runs],
"artifacts": [artifact.to_json_encodable() for artifact in artifacts],
"edges": [edge.to_json_encodable() for edge in edges],
}
}

_put("/graph", payload)


def notify_pipeline_update(calculator_path: str):
Expand All @@ -31,5 +49,16 @@ def _post(endpoint, json_payload) -> Any:
return response.json()


def _put(endpoint, json_payload) -> Any:
url = _url(endpoint)
response = requests.put(url, json=json_payload)
response.raise_for_status()

if len(response.content) == 0:
return None

return response.json()


def _url(endpoint) -> str:
return "{}{}".format(get_config().api_url, endpoint)
2 changes: 1 addition & 1 deletion sematic/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ def cli():
Run an example:
$ sematic run examples/mnist/pytorch
"""
switch_env("local_sqlite")
switch_env("local")
migrate()
Loading

0 comments on commit 56d2298

Please sign in to comment.