|
24 | 24 |
|
25 | 25 | if TYPE_CHECKING:
|
26 | 26 | import mlflow
|
| 27 | + import mlflow.models |
27 | 28 | else:
|
28 | 29 | mlflow = LazyLoader(
|
29 | 30 | "mlflow",
|
@@ -205,19 +206,16 @@ def import_model(
|
205 | 206 | # For MLflow < 1.25
|
206 | 207 | from mlflow.tracking.artifact_utils import _download_artifact_from_uri
|
207 | 208 |
|
208 |
| - local_path = _download_artifact_from_uri( |
| 209 | + local_path: str = _download_artifact_from_uri( |
209 | 210 | artifact_uri=model_uri, output_path=download_dir
|
210 | 211 | )
|
211 |
| - |
212 |
| - mlflow_model_path = bento_model.path_of(MLFLOW_MODEL_FOLDER) |
213 |
| - # Rename model folder from original artifact name to fixed "mlflow_model" |
214 |
| - shutil.move(local_path, mlflow_model_path) |
215 |
| - # If the temp dir we created still exists now, we never needed it because |
216 |
| - # the provided mlflow url must have provided enough path information. Just |
217 |
| - # delete it. If it's not here, it means we needed it but it's been renamed |
218 |
| - # by now so we don't need to remove it. |
219 |
| - if os.path.isdir(download_dir): |
| 212 | + finally: |
| 213 | + mlflow_model_path = bento_model.path_of(MLFLOW_MODEL_FOLDER) |
| 214 | + # Rename model folder from original artifact name to fixed "mlflow_model" |
| 215 | + shutil.move(local_path, mlflow_model_path) # type: ignore (local_path is bound) |
| 216 | + # Remove the tempdir |
220 | 217 | shutil.rmtree(download_dir)
|
| 218 | + |
221 | 219 | mlflow_model_file = os.path.join(mlflow_model_path, MLMODEL_FILE_NAME)
|
222 | 220 |
|
223 | 221 | if not os.path.exists(mlflow_model_file):
|
@@ -258,7 +256,7 @@ def __init__(self):
|
258 | 256 | input_spec=None,
|
259 | 257 | output_spec=None,
|
260 | 258 | )
|
261 |
| - def predict(self, input_data): |
| 259 | + def predict(self, input_data: t.Any) -> t.Any: |
262 | 260 | return self.model.predict(input_data)
|
263 | 261 |
|
264 | 262 | return MLflowPyfuncRunnable
|
|
0 commit comments