Skip to content

Commit

Permalink
Fix mapping of unregistered MLFlow models to Comet Experiment Models
Browse files Browse the repository at this point in the history
Previously, only models that were registered were imported as Comet Models. Now
MLflow models that are not in the model registry are also imported as Comet
Models.
  • Loading branch information
Lothiraldan committed Nov 9, 2023
1 parent c52366a commit 188cc35
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 25 deletions.
51 changes: 33 additions & 18 deletions comet_for_mlflow/comet_for_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@

from .compat import (
get_artifact_repository,
get_mlflow_model_name,
get_mlflow_run_id,
search_mlflow_store_experiments,
search_mlflow_store_runs,
Expand Down Expand Up @@ -381,17 +380,11 @@ def prepare_single_mlflow_run(self, run, original_experiment_name):
LOGGER.debug("### Importing artifacts")
artifact_store = get_artifact_repository(run.info.artifact_uri)

# List all the registered models if possible
models_prefixes = {}
if self.model_registry_store:
query = "run_id='%s'" % run.info.run_id
registered_models = self.model_registry_store.search_model_versions(
query
)
# Get all of the artifact list as we need to search for the
# specific MLModel file to detect models
all_artifacts = list(walk_run_artifacts(artifact_store))

for model in registered_models:
model_relpath = os.path.relpath(model.source, run.info.artifact_uri)
models_prefixes[model_relpath] = model
models_prefixes = self.get_model_prefixes(all_artifacts)

for artifact in walk_run_artifacts(artifact_store):
artifact_path = artifact.path
Expand All @@ -405,27 +398,33 @@ def prepare_single_mlflow_run(self, run, original_experiment_name):
self.summary["artifacts"] += 1

# Check whether the artifact belongs to one of the detected models
matching_model = None
for model_prefix, model in models_prefixes.items():
matching_model_name = None
for model_prefix, model_name in models_prefixes.items():
if artifact_path.startswith(model_prefix):
matching_model = model
matching_model_name = model_name
# We should match at most one model
break

if matching_model:
model_name = get_mlflow_model_name(matching_model)

if matching_model_name:
prefix = "models/"

if artifact_path.startswith(prefix):
comet_artifact_path = artifact_path[len(prefix) :]
else:
comet_artifact_path = artifact_path

if comet_artifact_path.startswith(model_prefix):
comet_artifact_path = comet_artifact_path[
len(model_prefix) + 1 :
]
else:
comet_artifact_path = comet_artifact_path

json_writer.log_artifact_as_model(
local_artifact_path,
comet_artifact_path,
run_start_time,
model_name,
matching_model_name,
)
else:
json_writer.log_artifact_as_asset(
Expand All @@ -436,6 +435,22 @@ def prepare_single_mlflow_run(self, run, original_experiment_name):

return self.compress_archive(run.info.run_id)

def get_model_prefixes(self, artifact_list):
    """Map every model directory found in a list of artifacts to a model name.

    A directory is considered a model when it directly contains a file named
    ``MLmodel`` (compared case-insensitively). This is similar to the logic
    used by the MLflow UI:
    https://github.com/mlflow/mlflow/blob/v2.2.2/mlflow/server/js/src/experiment-tracking/components/ArtifactView.js#L253

    Args:
        artifact_list: iterable of artifact objects exposing a ``path``
            attribute holding a "/"-separated path.

    Returns:
        A dict mapping the model directory path (the artifact path without
        its trailing ``/MLmodel``) to the model name, which is the last
        component of that directory path.
    """
    # Dict of model prefix to model name
    models = {}

    for artifact in artifact_list:
        directory, _, filename = artifact.path.rpartition("/")
        if filename.lower() != "mlmodel":
            continue

        if not directory:
            # An MLmodel file without a parent directory gives no directory
            # to name the model after (indexing the parent used to raise
            # IndexError), and an empty prefix would match every artifact
            # path in a ``startswith`` check. Skip it.
            continue

        # Comet doesn't support model names containing "/", so only the last
        # directory component is used as the model name.
        models[directory] = directory.rpartition("/")[2]

    return models

def upload(self, prepared_data):
LOGGER.info("# Start uploading data to Comet.ml")

Expand Down
7 changes: 0 additions & 7 deletions comet_for_mlflow/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,3 @@ def get_mlflow_run_id(mlflow_run):
return mlflow_run.info.run_id
else:
return mlflow_run.run_id


def get_mlflow_model_name(mlflow_model):
    """Return the name of an MLflow model object.

    The ``name`` attribute is used when the object has one; otherwise the
    name is read from ``mlflow_model.registered_model.name``.
    """
    if not hasattr(mlflow_model, "name"):
        return mlflow_model.registered_model.name

    return mlflow_model.name

0 comments on commit 188cc35

Please sign in to comment.