Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ GIT_PYTHON_REFRESH=quiet
# container uri for mlflow -- adjust this if you have a remote tracking server
MLFLOW_TRACKING_URI=http://mlflow:8080
# adjust this if you have a remote artifact store (e.g. gs, s3)
MLFLOW_ARTIFACT_ROOT=/mlruns
MLFLOW_ARTIFACT_DESTINATION=./mlruns
# this path is relative to where jupyter is started
MODEL_SECRETS_PATH=./config/secrets.toml
43 changes: 43 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
name: Tests

on:
push:
branches:
- main
pull_request:

jobs:
integration-test:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.12"

- name: Start MLflow server
run: |
docker compose up -d mlflow

- name: Wait for MLflow server to be ready
run: |
until curl -f http://localhost:8080/health; do
echo "Waiting for MLflow server..."
sleep 5
done

- name: Install poetry
run: pipx install "poetry == 1.8.5"

- name: Install dependencies
run: poetry install --with test --no-cache

- name: Run all tests
run: MLFLOW_TRACKING_URI=http://localhost:8080 poetry run pytest

- name: Stop MLflow server
run: docker compose down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ secrets.toml
.python-version
data/
*.pyc
.vscode/
.coverage*
5 changes: 4 additions & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ services:
command: >
mlflow server
--backend-store-uri postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
--default-artifact-root /mlruns
--default-artifact-root http://localhost:8080/api/2.0/mlflow-artifacts/artifacts/experiments
--artifacts-destination ${MLFLOW_ARTIFACT_DESTINATION}
--host 0.0.0.0
--port 8080
ports:
- "8080:8080"
volumes:
# Volume not needed if using cloud storage for artifacts
- ./mlruns:/mlruns

jupyter:
Expand All @@ -34,6 +36,7 @@ services:
- "8888:8888"
volumes:
- ./flightdeck:/app/flightdeck
# Volume not needed if using cloud storage for artifacts
- ./mlruns:/mlruns

volumes:
Expand Down
2,188 changes: 1,917 additions & 271 deletions poetry.lock

Large diffs are not rendered by default.

22 changes: 20 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,37 @@ packages = [
[tool.poetry.dependencies]
python = ">=3.10,!=3.12.5,<3.13"
click = "^8"
modelbench = {git = "https://github.com/mlcommons/modelbench.git"}
modelbench = {git = "https://github.com/mlcommons/modelbench.git", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
mlflow = "^2"
python-dotenv = "^1"
requests = "^2"
prometheus-client = "*" # version controlled by dependencies
jsonlines = "^4"
numpy = "^2"
matplotlib = "^3"
notebook = "^7"
jupyter = "^1"
# plugins (would like to figure out a better way to manage these)
modelgauge_anthropic = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "plugins/anthropic", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
modelgauge-azure = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "plugins/azure", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
modelgauge_baseten = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "plugins/baseten", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
modelgauge_demo_plugin = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "demo_plugin", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
modelgauge_nvidia = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "plugins/nvidia", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
modelgauge_openai = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "plugins/openai", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
modelgauge_huggingface = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "plugins/huggingface", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
modelgauge_perspective_api = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "plugins/perspective_api", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
modelgauge_google = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "plugins/google", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
modelgauge_vertexai = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "plugins/vertexai", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
modelgauge_mistral = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "plugins/mistral", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }
modelgauge_amazon = { git = "https://github.com/mlcommons/modelbench.git", subdirectory = "plugins/amazon", rev = "3bc3cdbc910eaef3a70ec3c2cb3c5d7c8fb098b5" }


[tool.poetry.group.dev.dependencies]
ipython = "<9"

[tool.poetry.group.test.dependencies]
pytest = "^8"
pytest-cov = "^6"

[tool.poetry.scripts]
modelplane = "modelplane.runways.run:cli"

Expand Down
3 changes: 2 additions & 1 deletion src/modelplane/runways/annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def annotate(
overwrite: bool = False,
cache_dir: str | None = None,
n_jobs: int = 1,
) -> None:
) -> str:
"""
Run annotations and record measurements.
"""
Expand Down Expand Up @@ -148,6 +148,7 @@ def annotate(
/ pipeline_runner.output_file_name,
dir=tmp,
)
return mlflow.active_run().info.run_id # type: ignore


def transform_mlflow_responder_artifact(run_id: str, dir: str) -> str:
Expand Down
3 changes: 0 additions & 3 deletions src/modelplane/runways/responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,4 @@ def respond(
local_path=pipeline_runner.output_dir()
/ pipeline_runner.output_file_name,
)
print(
f"Logged prompt responses to {pipeline_runner.output_dir() / pipeline_runner.output_file_name}"
)
return mlflow.active_run().info.run_id # type: ignore
13 changes: 9 additions & 4 deletions src/modelplane/runways/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,28 @@ def is_debug_mode() -> bool:


def setup_sut_credentials(uid: str) -> RawSecrets:
"""Load secrets from the config file and check for missing secrets."""
secrets = load_secrets_from_config(path=os.getenv(SECRETS_PATH_ENV, SECRETS_PATH))
missing_secrets = []
secrets = safe_load_secrets_from_config()
missing_secrets.extend(SUTS.get_missing_dependencies(uid, secrets=secrets))
raise_if_missing_from_config(missing_secrets)
return secrets


def setup_annotator_credentials(uid: str) -> RawSecrets:
"""Load secrets from the config file and check for missing secrets."""
secrets = load_secrets_from_config(path=os.getenv(SECRETS_PATH_ENV, SECRETS_PATH))
missing_secrets = []
secrets = safe_load_secrets_from_config()
missing_secrets.extend(ANNOTATORS.get_missing_dependencies(uid, secrets=secrets))
raise_if_missing_from_config(missing_secrets)
return secrets


def safe_load_secrets_from_config() -> RawSecrets:
path = os.getenv(SECRETS_PATH_ENV, SECRETS_PATH)
if os.path.exists(path):
return load_secrets_from_config(path=path)
return {}


def get_experiment_id(experiment_name: str) -> str:
"""
Get the experiment ID from MLflow. If the experiment does not exist, create it.
Expand Down
11 changes: 11 additions & 0 deletions tests/data/prompts.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
prompt_uid,prompt_text
0,What is the capital of France?
1,Explain the process of photosynthesis.
2,Write a Python function to reverse a string.
3,Who wrote the play 'Romeo and Juliet'?
4,Describe the water cycle in simple terms.
5,What are the main differences between HTTP and HTTPS?
6,Translate 'Good morning' to Spanish.
7,What is the value of Pi up to 3 decimal places?
8,Name three primary colors.
9,What is the largest planet in our solar system?
52 changes: 52 additions & 0 deletions tests/it/runways/random_annotator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import random

from pydantic import BaseModel

from modelgauge.annotator import CompletionAnnotator
from modelgauge.annotator_registry import ANNOTATORS
from modelgauge.annotators.llama_guard_annotator import LlamaGuardAnnotation
from modelgauge.single_turn_prompt_response import TestItem
from modelgauge.sut import SUTResponse
from modelgauge.tests.safe_v1 import BaseSafeTestVersion1


class AnnotationRequest(BaseModel):
    # Bundles the test item together with the SUT's response; this is the
    # payload that TestAnnotator.translate_request produces and
    # TestAnnotator.annotate consumes.
    test_item: TestItem
    response: SUTResponse


class TestAnnotator(CompletionAnnotator[LlamaGuardAnnotation]):
    """Random stub annotator for integration tests.

    Emits LlamaGuardAnnotation (the same type used by AILuminate benchmarks)
    so downstream measurement code sees realistic data. The request/response
    translation hooks mirror the shape of a real annotator, where `annotate`
    would call an external system and the translators prepare/convert its
    payloads.
    """

    # Per-hazard-category probability that a response is flagged unsafe.
    PROB_UNSAFE = 0.01

    def translate_request(self, test_item: TestItem, response: SUTResponse):
        return AnnotationRequest(test_item=test_item, response=response)

    def annotate(self, annotation_request: AnnotationRequest) -> LlamaGuardAnnotation:
        # Draw one random number per hazard category, in category order, so
        # the global-RNG call sequence (and thus seeded reproducibility) is
        # deterministic.
        flagged = [
            hazard
            for hazard in BaseSafeTestVersion1.hazards
            if random.random() < self.PROB_UNSAFE
        ]
        return LlamaGuardAnnotation(
            is_safe=not flagged,
            is_safe_logprob=random.random(),
            violation_categories=flagged,
        )

    def translate_response(
        self, request: AnnotationRequest, response: LlamaGuardAnnotation
    ) -> LlamaGuardAnnotation:
        # The "remote" response is already the final annotation type.
        return response


# Seed the global RNG at import time so every test run draws the same
# annotation outcomes (test_e2e asserts exact safe/unsafe counts).
random.seed(42)  # For reproducibility in tests
TEST_ANNOTATOR_ID = "test_annotator"
# Register under a stable uid so tests can request this annotator through
# the modelgauge ANNOTATORS registry like any production annotator.
ANNOTATORS.register(TestAnnotator, TEST_ANNOTATOR_ID)
123 changes: 123 additions & 0 deletions tests/it/runways/test_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import csv
import tempfile

import mlflow
import mlflow.artifacts

from modelplane.runways.annotator import annotate
from modelplane.runways.responder import respond
from modelplane.runways.utils import PROMPT_RESPONSE_ARTIFACT_NAME
from random_annotator import TEST_ANNOTATOR_ID


def test_e2e():
    """End-to-end flight: generate responses with a demo SUT, then annotate them."""
    # sut that responds based on the number of words in the prompt (even = "yes", odd = "no")
    shared_kwargs = dict(
        experiment="test_experiment",
        cache_dir=None,
        n_jobs=1,
    )

    responder_run_id = check_responder(
        sut_id="demo_yes_no",
        prompts="tests/data/prompts.csv",
        **shared_kwargs,
    )
    check_annotator(
        response_run_id=responder_run_id,
        annotator_id=TEST_ANNOTATOR_ID,
        **shared_kwargs,
    )


def check_responder(
    sut_id: str,
    prompts: str,
    experiment: str,
    cache_dir: str | None,
    n_jobs: int,
) -> str:
    """Run the responder runway and validate what it logged to MLflow.

    Args:
        sut_id: registered SUT uid (tests use the word-parity demo SUT).
        prompts: path to the prompts CSV.
        experiment: MLflow experiment name.
        cache_dir: SUT cache directory; if None a throwaway temp dir is used.
        n_jobs: parallelism passed through to `respond`.

    Returns:
        The MLflow run id of the responder run, so the annotator check can
        chain off it.
    """
    with tempfile.TemporaryDirectory() as tmp:
        # BUG FIX: the original bound the TemporaryDirectory to the name
        # `cache_dir`, shadowing the parameter — a caller-supplied cache_dir
        # was silently ignored. Honor it when given; fall back to the temp dir.
        cache_dir = cache_dir if cache_dir is not None else tmp
        run_id = respond(
            sut_id=sut_id,
            prompts=prompts,
            experiment=experiment,
            cache_dir=cache_dir,
            n_jobs=n_jobs,
        )

        # confirm experiment exists and a run was produced
        exp = mlflow.get_experiment_by_name(experiment)
        assert exp is not None
        assert run_id is not None

        # validate params / tags logged
        run = mlflow.get_run(run_id)
        params = run.data.params
        tags = run.data.tags
        assert params.get("cache_dir") == cache_dir
        assert params.get("n_jobs") == str(n_jobs)
        assert tags.get("sut_id") == sut_id

        # validate responses
        with tempfile.TemporaryDirectory() as temp_dir:
            # download/validate the prompt responses artifact
            responses_file = mlflow.artifacts.download_artifacts(
                run_id=run_id,
                artifact_path=PROMPT_RESPONSE_ARTIFACT_NAME,
                dst_path=temp_dir,
            )
            with open(responses_file, "r") as f:
                reader = csv.DictReader(f)
                responses = list(reader)
            assert len(responses) == 10
            for response in responses:
                # demo_yes_no answers "yes" for an even word count, "no" for odd
                expected = "no" if len(response["Text"].split()) % 2 else "yes"
                yesno = response[sut_id]
                # BUG FIX: the original message was missing the closing quote
                # after {yesno}.
                assert (
                    yesno.lower() == expected
                ), f"Unexpectedly got '{yesno}' for prompt '{response['Text']}'"
        return run_id


def check_annotator(
    response_run_id: str,
    annotator_id: str,
    experiment: str,
    cache_dir: str | None,
    n_jobs: int,
) -> None:
    """Run the annotator runway against an existing responder run and validate it.

    Args:
        response_run_id: MLflow run id of the responder run to annotate.
        annotator_id: registered annotator uid (tests use the random stub).
        experiment: MLflow experiment name.
        cache_dir: annotator cache directory; if None a throwaway temp dir is used.
        n_jobs: parallelism passed through to `annotate`.
    """
    # run the annotator
    with tempfile.TemporaryDirectory() as tmp:
        # BUG FIX: the original bound the TemporaryDirectory to the name
        # `cache_dir`, shadowing the parameter — a caller-supplied cache_dir
        # was silently ignored. Honor it when given; fall back to the temp dir.
        cache_dir = cache_dir if cache_dir is not None else tmp
        run_id = annotate(
            response_run_id=response_run_id,
            annotator_id=annotator_id,
            experiment=experiment,
            cache_dir=cache_dir,
            n_jobs=n_jobs,
        )
        # confirm experiment exists
        exp = mlflow.get_experiment_by_name(experiment)
        assert exp is not None

        # validate params / tags / metrics logged
        run = mlflow.get_run(run_id)
        params = run.data.params
        tags = run.data.tags
        metrics = run.data.metrics
        assert params.get("cache_dir") == cache_dir
        assert params.get("n_jobs") == str(n_jobs)
        assert tags.get("annotator_id") == annotator_id

        # expect 8 safe based on seed
        assert metrics.get("total_count") == 10, "Expected total_count to be 10"
        assert metrics.get("total_safe") == 8, "Expected total_safe to be 8"

        # confirm annotations.jsonl exists
        artifacts = mlflow.artifacts.list_artifacts(run_id=run_id)
        assert any(
            artifact.path == "annotations.jsonl" for artifact in artifacts
        ), "Expected 'annotations.jsonl' artifact not found in run"
9 changes: 9 additions & 0 deletions tests/it/test_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Ensures the mlflow tracking server is live.

import mlflow
from modelplane.mlflow.health import tracking_server_is_live


def test_tracking_server_is_live():
    """Test if the MLflow tracking server is live."""
    is_live = tracking_server_is_live()
    assert is_live, "MLflow tracking server should be live"