Skip to content

Commit

Permalink
Fix bug with truss push with non-draft. (#697)
Browse files Browse the repository at this point in the history
* Fix bug with truss push with non-draft.

* Remove unnecessary call.
  • Loading branch information
squidarth authored Oct 10, 2023
1 parent 6a65458 commit 1ff6d6e
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 38 deletions.
70 changes: 44 additions & 26 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from typing import Optional

import requests
from truss.remote.baseten.auth import AuthService
Expand Down Expand Up @@ -59,38 +58,57 @@ def model_s3_upload_credentials(self):

def create_model_from_truss(
self,
model_name,
s3_key,
config,
semver_bump,
client_version,
is_trusted=False,
model_id: Optional[str] = None,
model_name: str,
s3_key: str,
config: str,
semver_bump: str,
client_version: str,
is_trusted: bool,
):
if model_id:
mutation = "create_model_version_from_truss"
first_arg = f'model_id: "{model_id}"'
else:
mutation = "create_model_from_truss"
first_arg = f'name: "{model_name}"'

query_string = f"""
mutation {{
{mutation}({first_arg},
s3_key: "{s3_key}",
config: "{config}",
semver_bump: "{semver_bump}",
client_version: "{client_version}",
is_trusted: {'true' if is_trusted else 'false'}
) {{
id,
name,
version_id
create_model_from_truss(
name: "{model_name}",
s3_key: "{s3_key}",
config: "{config}",
semver_bump: "{semver_bump}",
client_version: "{client_version}",
is_trusted: {'true' if is_trusted else 'false'}
) {{
id,
name,
version_id
}}
}}
"""
resp = self._post_graphql_query(query_string)
return resp["data"]["create_model_from_truss"]

def create_model_version_from_truss(
self,
model_id: str,
s3_key: str,
config: str,
semver_bump: str,
client_version: str,
is_trusted: bool,
):
query_string = f"""
mutation {{
create_model_version_from_truss(
model_id: "{model_id}"
s3_key: "{s3_key}",
config: "{config}",
semver_bump: "{semver_bump}",
client_version: "{client_version}",
is_trusted: {'true' if is_trusted else 'false'}
) {{
id
}}
}}
"""
resp = self._post_graphql_query(query_string)
return resp["data"][mutation]
return resp["data"]["create_model_version_from_truss"]

def create_development_model_from_truss(
self,
Expand Down
37 changes: 25 additions & 12 deletions truss/remote/baseten/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def create_truss_service(
model_name: str,
s3_key: str,
config: str,
semver_bump: Optional[str] = "MINOR",
is_trusted: Optional[bool] = False,
semver_bump: str = "MINOR",
is_trusted: bool = False,
is_draft: Optional[bool] = False,
model_id: Optional[str] = None,
) -> Tuple[str, str]:
Expand All @@ -133,15 +133,28 @@ def create_truss_service(
f"truss=={truss.version()}",
is_trusted,
)
else:

return (model_version_json["id"], model_version_json["version_id"])

if model_id is None:
model_version_json = api.create_model_from_truss(
model_name,
s3_key,
config,
semver_bump,
f"truss=={truss.version()}",
is_trusted,
model_id,
model_name=model_name,
s3_key=s3_key,
config=config,
semver_bump=semver_bump,
client_version=f"truss=={truss.version()}",
is_trusted=is_trusted,
)

return (model_version_json["id"], model_version_json["version_id"])
return (model_version_json["id"], model_version_json["version_id"])

# Case where there is a model id already, create another version
model_version_json = api.create_model_version_from_truss(
model_id=model_id,
s3_key=s3_key,
config=config,
semver_bump=semver_bump,
client_version=f"truss=={truss.version()}",
is_trusted=is_trusted,
)
model_version_id = model_version_json["id"]
return (model_id, model_version_id)

0 comments on commit 1ff6d6e

Please sign in to comment.