Skip to content

Commit

Permalink
update client to work with new bento version
Browse files Browse the repository at this point in the history
  • Loading branch information
christinab12 committed Apr 30, 2024
1 parent b987cbc commit db8ceac
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions src/client/dcp_client/utils/bentoml_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
from typing import Optional, List
from bentoml.client import Client as BentoClient
from bentoml.client import Client as BentoClient # Client is type SyncHTTPClient
from bentoml.exceptions import BentoMLException
import numpy as np

Expand Down Expand Up @@ -44,7 +43,7 @@ def is_connected(self) -> bool:
"""
return bool(self.client)

async def _run_train(self, data_path: str) -> Optional[str]:
def _run_train(self, data_path: str) -> Optional[str]:
"""Runs the training task asynchronously.
:param data_path: Path to the training data.
Expand All @@ -53,7 +52,7 @@ async def _run_train(self, data_path: str) -> Optional[str]:
:rtype: str, or None
"""
try:
response = await self.client.async_train(data_path)
response = self.client.train(data_path) # train is part of running server
return response
except BentoMLException:
return None
Expand All @@ -65,9 +64,9 @@ def run_train(self, data_path: str):
:type data_path: str
:return: Response from the server if successful, None otherwise.
"""
return asyncio.run(self._run_train(data_path))
return self._run_train(data_path)

async def _run_inference(self, data_path: str) -> Optional[np.ndarray]:
def _run_inference(self, data_path: str) -> Optional[np.ndarray]:
"""Runs the inference task asynchronously.
:param data_path: Path to the data for inference.
Expand All @@ -76,7 +75,7 @@ async def _run_inference(self, data_path: str) -> Optional[np.ndarray]:
:rtype: np.ndarray, or None
"""
try:
response = await self.client.async_segment_image(data_path)
response = self.client.segment_image(data_path) # segment_image is part of running server
return response
except BentoMLException:
return None
Expand All @@ -88,5 +87,5 @@ def run_inference(self, data_path: str) -> List:
:type data_path: str
:return: List of files not supported by the server if unsuccessful, otherwise returns None.
"""
list_of_files_not_suported = asyncio.run(self._run_inference(data_path))
list_of_files_not_suported = self._run_inference(data_path)
return list_of_files_not_suported

0 comments on commit db8ceac

Please sign in to comment.