Skip to content

Commit

Permalink
Merge pull request #7 from comet-ml/fix-model-detection
Browse files Browse the repository at this point in the history
Fix mapping of unregistered MLFlow models to Comet Experiment Models
  • Loading branch information
Lothiraldan authored Nov 13, 2023
2 parents c52366a + c59379b commit 3936662
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 44 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ on: [push]
jobs:
test:

runs-on: ubuntu-18.04
runs-on: ubuntu-22.04
strategy:
max-parallel: 2
fail-fast: false
matrix:
# There is no Python 3.4 on ubuntu-latest
python-version: [2.7, 3.7]
python-version: [3.11]

steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand All @@ -30,13 +30,13 @@ jobs:
run: |
pytest tests
lint:
runs-on: ubuntu-18.04
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v4
- name: Set up Python 3.7
uses: actions/setup-python@v1
uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.11
- name: Install pre-commit
run: |
python -m pip install --upgrade pip
Expand Down
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
repos:
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/ambv/black
rev: 22.10.0
rev: 23.11.0
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 5.0.4
rev: 6.1.0
hooks:
- id: flake8
args: ['--config=.flake8']
additional_dependencies: ['flake8-coding==1.3.2', 'flake8-copyright==0.2.3', 'flake8-debugger==4.1.2', 'flake8-mypy==17.8.0']
additional_dependencies: ['flake8-coding==1.3.2', 'flake8-copyright==0.2.4', 'flake8-debugger==4.1.2']
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.5.0
hooks:
- id: check-json
- id: check-merge-conflict
Expand All @@ -24,7 +24,7 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace
- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
rev: v2.2.6
hooks:
- id: codespell
exclude_types: [json]
Expand Down
54 changes: 33 additions & 21 deletions comet_for_mlflow/comet_for_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import tempfile
import traceback
from os.path import abspath
from typing import Optional
from zipfile import ZipFile

from comet_ml import API
Expand All @@ -49,7 +48,6 @@

from .compat import (
get_artifact_repository,
get_mlflow_model_name,
get_mlflow_run_id,
search_mlflow_store_experiments,
search_mlflow_store_runs,
Expand Down Expand Up @@ -104,7 +102,6 @@ def __init__(
answer,
email,
):
# type: (bool, str, str, bool, str, Optional[bool], str) -> None
self.answer = answer
self.email = email
self.config = get_config()
Expand Down Expand Up @@ -164,7 +161,6 @@ def prepare(self):

# First prepare all the data except the metadata as we need a project name
for experiment_number, experiment in enumerate(self.mlflow_experiments):

experiment_name = experiment.experiment_id
if experiment.name:
experiment_name = experiment.name
Expand Down Expand Up @@ -381,17 +377,11 @@ def prepare_single_mlflow_run(self, run, original_experiment_name):
LOGGER.debug("### Importing artifacts")
artifact_store = get_artifact_repository(run.info.artifact_uri)

# List all the registered models if possible
models_prefixes = {}
if self.model_registry_store:
query = "run_id='%s'" % run.info.run_id
registered_models = self.model_registry_store.search_model_versions(
query
)
# Get all of the artifact list as we need to search for the
# specific MLModel file to detect models
all_artifacts = list(walk_run_artifacts(artifact_store))

for model in registered_models:
model_relpath = os.path.relpath(model.source, run.info.artifact_uri)
models_prefixes[model_relpath] = model
models_prefixes = self.get_model_prefixes(all_artifacts)

for artifact in walk_run_artifacts(artifact_store):
artifact_path = artifact.path
Expand All @@ -405,27 +395,33 @@ def prepare_single_mlflow_run(self, run, original_experiment_name):
self.summary["artifacts"] += 1

# Check if it belongs to one of the registered models
matching_model = None
for model_prefix, model in models_prefixes.items():
matching_model_name = None
for model_prefix, model_name in models_prefixes.items():
if artifact_path.startswith(model_prefix):
matching_model = model
matching_model_name = model_name
# We should match at most one model
break

if matching_model:
model_name = get_mlflow_model_name(matching_model)

if matching_model_name:
prefix = "models/"

if artifact_path.startswith(prefix):
comet_artifact_path = artifact_path[len(prefix) :]
else:
comet_artifact_path = artifact_path

if comet_artifact_path.startswith(model_prefix):
comet_artifact_path = comet_artifact_path[
len(model_prefix) + 1 :
]
else:
comet_artifact_path = comet_artifact_path

json_writer.log_artifact_as_model(
local_artifact_path,
comet_artifact_path,
run_start_time,
model_name,
matching_model_name,
)
else:
json_writer.log_artifact_as_asset(
Expand All @@ -436,6 +432,22 @@ def prepare_single_mlflow_run(self, run, original_experiment_name):

return self.compress_archive(run.info.run_id)

def get_model_prefixes(self, artifact_list):
    """Map model directory prefixes to model names for a list of artifacts.

    An artifact whose file name is ``MLmodel`` (compared case-insensitively)
    marks its parent directory as a model. The full parent path becomes the
    prefix and the last directory component becomes the model name.

    Args:
        artifact_list: iterable of objects exposing a ``path`` attribute
            holding a "/"-separated artifact path.

    Returns:
        A dict of ``{model_prefix: model_name}``; empty when no ``MLmodel``
        file sits inside a directory.
    """

    # Dict of model prefix to model name
    models = {}

    for artifact in artifact_list:
        # Similar logic to the MLflow UI
        # https://github.com/mlflow/mlflow/blob/v2.2.2/mlflow/server/js/src/experiment-tracking/components/ArtifactView.js#L253
        parts = artifact.path.split("/")
        if parts[-1].lower() != "mlmodel":
            continue

        if len(parts) < 2:
            # An MLmodel file at the artifact root has no parent directory to
            # name the model after. It would also yield an empty prefix, which
            # matches every artifact path, so skip it.
            continue

        # Comet doesn't support model names containing "/", so only the last
        # directory component is used as the name
        models["/".join(parts[:-1])] = parts[-2]

    return models

def upload(self, prepared_data):
LOGGER.info("# Start uploading data to Comet.ml")

Expand Down
7 changes: 0 additions & 7 deletions comet_for_mlflow/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,3 @@ def get_mlflow_run_id(mlflow_run):
return mlflow_run.info.run_id
else:
return mlflow_run.run_id


def get_mlflow_model_name(mlflow_model):
    """Return the name carried by an MLflow model object.

    Uses the object's own ``name`` attribute when it has one; otherwise
    falls back to ``registered_model.name``.
    """
    # Objects without a direct name expose it through their registered model
    # (presumably a difference between MLflow versions; verify with callers).
    if not hasattr(mlflow_model, "name"):
        return mlflow_model.registered_model.name
    return mlflow_model.name
2 changes: 0 additions & 2 deletions comet_for_mlflow/file_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ def log_artifact_as_visualization(
def log_artifact_as_model(
self, artifact_path, artifact_name, timestamp, model_name
):

_, extension = os.path.splitext(
artifact_path
) # TODO: Support extension less file names?
Expand Down Expand Up @@ -328,7 +327,6 @@ def log_artifact_as_model(
self.write_line_data(data)

def log_artifact_as_asset(self, artifact_path, artifact_name, timestamp):

_, extension = os.path.splitext(
artifact_path
) # TODO: Support extension less file names?
Expand Down

0 comments on commit 3936662

Please sign in to comment.