diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index 4f548dcd5..e071a43a7 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import requests from truss.remote.baseten.auth import AuthService @@ -59,38 +58,57 @@ def model_s3_upload_credentials(self): def create_model_from_truss( self, - model_name, - s3_key, - config, - semver_bump, - client_version, - is_trusted=False, - model_id: Optional[str] = None, + model_name: str, + s3_key: str, + config: str, + semver_bump: str, + client_version: str, + is_trusted: bool, ): - if model_id: - mutation = "create_model_version_from_truss" - first_arg = f'model_id: "{model_id}"' - else: - mutation = "create_model_from_truss" - first_arg = f'name: "{model_name}"' - query_string = f""" mutation {{ - {mutation}({first_arg}, - s3_key: "{s3_key}", - config: "{config}", - semver_bump: "{semver_bump}", - client_version: "{client_version}", - is_trusted: {'true' if is_trusted else 'false'} - ) {{ - id, - name, - version_id + create_model_from_truss( + name: "{model_name}", + s3_key: "{s3_key}", + config: "{config}", + semver_bump: "{semver_bump}", + client_version: "{client_version}", + is_trusted: {'true' if is_trusted else 'false'} + ) {{ + id, + name, + version_id + }} }} + """ + resp = self._post_graphql_query(query_string) + return resp["data"]["create_model_from_truss"] + + def create_model_version_from_truss( + self, + model_id: str, + s3_key: str, + config: str, + semver_bump: str, + client_version: str, + is_trusted: bool, + ): + query_string = f""" + mutation {{ + create_model_version_from_truss( + model_id: "{model_id}" + s3_key: "{s3_key}", + config: "{config}", + semver_bump: "{semver_bump}", + client_version: "{client_version}", + is_trusted: {'true' if is_trusted else 'false'} + ) {{ + id + }} }} """ resp = self._post_graphql_query(query_string) - return resp["data"][mutation] + return resp["data"]["create_model_version_from_truss"] def create_development_model_from_truss( self, diff --git a/truss/remote/baseten/core.py b/truss/remote/baseten/core.py index d92557ad1..7707bbce2 100644 --- a/truss/remote/baseten/core.py +++ b/truss/remote/baseten/core.py @@ -106,8 +106,8 @@ def create_truss_service( model_name: str, s3_key: str, config: str, - semver_bump: Optional[str] = "MINOR", - is_trusted: Optional[bool] = False, + semver_bump: str = "MINOR", + is_trusted: bool = False, is_draft: Optional[bool] = False, model_id: Optional[str] = None, ) -> Tuple[str, str]: @@ -133,15 +133,28 @@ def create_truss_service( f"truss=={truss.version()}", is_trusted, ) - else: + + return (model_version_json["id"], model_version_json["version_id"]) + + if model_id is None: model_version_json = api.create_model_from_truss( - model_name, - s3_key, - config, - semver_bump, - f"truss=={truss.version()}", - is_trusted, - model_id, + model_name=model_name, + s3_key=s3_key, + config=config, + semver_bump=semver_bump, + client_version=f"truss=={truss.version()}", + is_trusted=is_trusted, ) - - return (model_version_json["id"], model_version_json["version_id"]) + return (model_version_json["id"], model_version_json["version_id"]) + + # Case where there is a model id already, create another version + model_version_json = api.create_model_version_from_truss( + model_id=model_id, + s3_key=s3_key, + config=config, + semver_bump=semver_bump, + client_version=f"truss=={truss.version()}", + is_trusted=is_trusted, + ) + model_version_id = model_version_json["id"] + return (model_id, model_version_id)