diff --git a/comet_for_mlflow/comet_for_mlflow.py b/comet_for_mlflow/comet_for_mlflow.py
index 8f132bb..7eb8b58 100644
--- a/comet_for_mlflow/comet_for_mlflow.py
+++ b/comet_for_mlflow/comet_for_mlflow.py
@@ -49,7 +49,6 @@
 
 from .compat import (
     get_artifact_repository,
-    get_mlflow_model_name,
     get_mlflow_run_id,
     search_mlflow_store_experiments,
     search_mlflow_store_runs,
@@ -381,17 +380,11 @@ def prepare_single_mlflow_run(self, run, original_experiment_name):
             LOGGER.debug("### Importing artifacts")
             artifact_store = get_artifact_repository(run.info.artifact_uri)
 
-            # List all the registered models if possible
-            models_prefixes = {}
-            if self.model_registry_store:
-                query = "run_id='%s'" % run.info.run_id
-                registered_models = self.model_registry_store.search_model_versions(
-                    query
-                )
+            # Get all of the artifact list as we need to search for the
+            # specific MLModel file to detect models
+            all_artifacts = list(walk_run_artifacts(artifact_store))
 
-                for model in registered_models:
-                    model_relpath = os.path.relpath(model.source, run.info.artifact_uri)
-                    models_prefixes[model_relpath] = model
+            models_prefixes = self.get_model_prefixes(all_artifacts)
 
             for artifact in walk_run_artifacts(artifact_store):
                 artifact_path = artifact.path
@@ -405,27 +398,33 @@ def prepare_single_mlflow_run(self, run, original_experiment_name):
                 self.summary["artifacts"] += 1
 
                 # Check if it's belonging to one of the registered model
-                matching_model = None
-                for model_prefix, model in models_prefixes.items():
-                    if artifact_path.startswith(model_prefix):
-                        matching_model = model
+                matching_model_name = None
+                for model_prefix, model_name in models_prefixes.items():
+                    # Match on a whole path component so that the model
+                    # "model" does not also claim artifacts under "model2/"
+                    if artifact_path.startswith(model_prefix + "/"):
+                        matching_model_name = model_name
                         # We should match at most one model
                         break
 
-                if matching_model:
-                    model_name = get_mlflow_model_name(matching_model)
-
+                if matching_model_name:
                     prefix = "models/"
+
                     if artifact_path.startswith(prefix):
                         comet_artifact_path = artifact_path[len(prefix) :]
                     else:
                         comet_artifact_path = artifact_path
 
+                    # Make the path relative to the model directory when it is
+                    model_dir = model_prefix + "/"
+                    if comet_artifact_path.startswith(model_dir):
+                        comet_artifact_path = comet_artifact_path[len(model_dir) :]
+
                     json_writer.log_artifact_as_model(
                         local_artifact_path,
                         comet_artifact_path,
                         run_start_time,
-                        model_name,
+                        matching_model_name,
                     )
                 else:
                     json_writer.log_artifact_as_asset(
@@ -436,6 +435,33 @@ def prepare_single_mlflow_run(self, run, original_experiment_name):
 
         return self.compress_archive(run.info.run_id)
 
+    def get_model_prefixes(self, artifact_list):
+        """Return a dict mapping each model directory to its model name.
+
+        A directory is treated as a model when it directly contains a file
+        named ``MLmodel`` (compared case-insensitively); the model name is
+        the name of that directory.
+        """
+
+        # Dict of model prefix to model name
+        models = {}
+
+        for artifact in artifact_list:
+            # Similar logic to MLflow UI
+            # https://github.com/mlflow/mlflow/blob/v2.2.2/mlflow/server/js/src/experiment-tracking/components/ArtifactView.js#L253
+            parts = artifact.path.split("/")
+            if parts[-1].lower() != "mlmodel":
+                continue
+
+            if len(parts) < 2:
+                # An MLmodel file at the artifact root has no parent
+                # directory to take the model name from, so skip it
+                continue
+
+            # Comet doesn't support model names with / in their name
+            models["/".join(parts[:-1])] = parts[-2]
+
+        return models
+
     def upload(self, prepared_data):
         LOGGER.info("# Start uploading data to Comet.ml")
 
diff --git a/comet_for_mlflow/compat.py b/comet_for_mlflow/compat.py
index 8ab2619..f82d613 100644
--- a/comet_for_mlflow/compat.py
+++ b/comet_for_mlflow/compat.py
@@ -67,10 +67,3 @@ def get_mlflow_run_id(mlflow_run):
         return mlflow_run.info.run_id
     else:
         return mlflow_run.run_id
-
-
-def get_mlflow_model_name(mlflow_model):
-    if hasattr(mlflow_model, "name"):
-        return mlflow_model.name
-    else:
-        return mlflow_model.registered_model.name