Batch refactor update (#49)
gregorybchris authored May 2, 2023
1 parent 99685d9 commit 5801eb2
Showing 34 changed files with 596 additions and 467 deletions.
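At a glance, this commit replaces `BatchJobResult` with `BatchJobInfo`/`BatchJobState` and moves prediction and artifact downloads onto `BatchJob` itself. A minimal before/after sketch, assembled from the hunks below:

```python
# Before this commit: predictions were downloaded through a BatchJobResult.
result = job.await_complete()
result.download_predictions("predictions.json")

# After this commit: the job object handles downloads directly.
job.await_complete()
job.download_predictions("predictions.json")
job.download_artifacts("artifacts.zip")
```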
1 change: 1 addition & 0 deletions docs/batch/batch-job-info.md
@@ -0,0 +1 @@
::: hume._batch.batch_job_info.BatchJobInfo
1 change: 0 additions & 1 deletion docs/batch/batch-job-result.md

This file was deleted.

1 change: 1 addition & 0 deletions docs/batch/batch-job-state.md
@@ -0,0 +1 @@
::: hume._batch.batch_job_state.BatchJobState
8 changes: 5 additions & 3 deletions docs/index.md
@@ -36,10 +36,12 @@ job = client.submit_job(urls, [config])
print(job)
print("Running...")

result = job.await_complete()
result.download_predictions("predictions.json")
job.await_complete()
job.download_predictions("predictions.json")
print("Predictions downloaded to predictions.json")

print("Predictions downloaded!")
job.download_artifacts("artifacts.zip")
print("Artifacts downloaded to artifacts.zip")
```

### Rehydrate a batch job from a job ID
@@ -30,15 +30,15 @@
"print(job)\n",
"print(\"Running...\")\n",
"\n",
"result = job.await_complete()\n",
"job.await_complete()\n",
"download_filepath = \"predictions.json\"\n",
"result.download_predictions(download_filepath)\n",
"job.download_predictions(download_filepath)\n",
"\n",
"print(\"Predictions ready!\")\n",
"print()\n",
"\n",
"with open(\"predictions.json\", \"r\") as fp:\n",
" predictions = json.load(fp)\n",
"with open(\"predictions.json\", \"r\") as f:\n",
" predictions = json.load(f)\n",
" for prediction in predictions:\n",
" for file_data in prediction[\"files\"]:\n",
" face_predictions = file_data[\"models\"][\"face\"]\n",
@@ -30,15 +30,15 @@
"print(job)\n",
"print(\"Running...\")\n",
"\n",
"result = job.await_complete()\n",
"job.await_complete()\n",
"download_filepath = \"predictions.json\"\n",
"result.download_predictions(download_filepath)\n",
"job.download_predictions(download_filepath)\n",
"\n",
"print(\"Predictions ready!\")\n",
"print()\n",
"\n",
"with open(\"predictions.json\", \"r\") as fp:\n",
" predictions = json.load(fp)\n",
"with open(\"predictions.json\", \"r\") as f:\n",
" predictions = json.load(f)\n",
" for prediction in predictions:\n",
" for file_data in prediction[\"files\"]:\n",
" ner_predictions = file_data[\"models\"][\"ner\"]\n",
@@ -30,15 +30,15 @@
"print(job)\n",
"print(\"Running...\")\n",
"\n",
"result = job.await_complete()\n",
"job.await_complete()\n",
"download_filepath = \"predictions.json\"\n",
"result.download_predictions(download_filepath)\n",
"job.download_predictions(download_filepath)\n",
"\n",
"print(\"Predictions ready!\")\n",
"print()\n",
"\n",
"with open(\"predictions.json\", \"r\") as fp:\n",
" predictions = json.load(fp)\n",
"with open(\"predictions.json\", \"r\") as f:\n",
" predictions = json.load(f)\n",
" for prediction in predictions:\n",
" for file_data in prediction[\"files\"]:\n",
" language_predictions = file_data[\"models\"][\"language\"]\n",
@@ -30,15 +30,15 @@
"print(job)\n",
"print(\"Running...\")\n",
"\n",
"result = job.await_complete()\n",
"job.await_complete()\n",
"download_filepath = \"predictions.json\"\n",
"result.download_predictions(download_filepath)\n",
"job.download_predictions(download_filepath)\n",
"\n",
"print(\"Predictions ready!\")\n",
"print()\n",
"\n",
"with open(\"predictions.json\", \"r\") as fp:\n",
" predictions = json.load(fp)\n",
"with open(\"predictions.json\", \"r\") as f:\n",
" predictions = json.load(f)\n",
" for prediction in predictions:\n",
" for file_data in prediction[\"files\"]:\n",
" language_predictions = file_data[\"models\"][\"language\"]\n",
8 changes: 4 additions & 4 deletions examples/batch-voice-expression/batch-voice-expression.ipynb
@@ -32,14 +32,14 @@
"print(job)\n",
"print(\"Running...\")\n",
"\n",
"result = job.await_complete()\n",
"job.await_complete()\n",
"download_filepath = \"predictions.json\"\n",
"result.download_predictions(download_filepath)\n",
"job.download_predictions(download_filepath)\n",
"\n",
"print(\"Predictions ready!\")\n",
"\n",
"with open(\"predictions.json\", \"r\") as fp:\n",
" predictions = json.load(fp)\n",
"with open(\"predictions.json\", \"r\") as f:\n",
" predictions = json.load(f)\n",
" for prediction in predictions:\n",
" for file_data in prediction[\"files\"]:\n",
" print()\n",
6 changes: 4 additions & 2 deletions hume/__init__.py
@@ -1,7 +1,7 @@
"""Module init."""
from importlib.metadata import version

from hume._batch import BatchJob, BatchJobResult, BatchJobStatus, HumeBatchClient
from hume._batch import BatchJob, BatchJobInfo, BatchJobState, BatchJobStatus, HumeBatchClient, TranscriptionConfig
from hume._stream import HumeStreamClient, StreamSocket
from hume.error.hume_client_exception import HumeClientException

@@ -10,10 +10,12 @@
__all__ = [
"__version__",
"BatchJob",
"BatchJobResult",
"BatchJobInfo",
"BatchJobState",
"BatchJobStatus",
"HumeBatchClient",
"HumeClientException",
"HumeStreamClient",
"StreamSocket",
"TranscriptionConfig",
]
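With the updated `hume/__init__.py`, the renamed classes are exported from the top-level package. A small sketch of the imports under this change, limited to the names listed in the new `__all__`:

```python
# BatchJobResult is no longer exported; these are the new/remaining names.
from hume import (
    BatchJob,
    BatchJobInfo,
    BatchJobState,
    BatchJobStatus,
    HumeBatchClient,
    TranscriptionConfig,
)
```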
8 changes: 6 additions & 2 deletions hume/_batch/__init__.py
@@ -1,12 +1,16 @@
"""Module init."""
from hume._batch.batch_job import BatchJob
from hume._batch.batch_job_result import BatchJobResult
from hume._batch.batch_job_info import BatchJobInfo
from hume._batch.batch_job_state import BatchJobState
from hume._batch.batch_job_status import BatchJobStatus
from hume._batch.hume_batch_client import HumeBatchClient
from hume._batch.transcription_config import TranscriptionConfig

__all__ = [
"BatchJob",
"BatchJobResult",
"BatchJobInfo",
"BatchJobState",
"BatchJobStatus",
"HumeBatchClient",
"TranscriptionConfig",
]
58 changes: 43 additions & 15 deletions hume/_batch/batch_job.py
@@ -1,7 +1,9 @@
"""Batch job."""
from typing import TYPE_CHECKING
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Union

from hume._batch.batch_job_result import BatchJobResult
from hume._batch.batch_job_info import BatchJobInfo
from hume._batch.batch_job_status import BatchJobStatus
from hume._common.retry_utils import retry, RetryIterError

@@ -28,21 +30,47 @@ def get_status(self) -> BatchJobStatus:
Returns:
BatchJobStatus: The status of the `BatchJob`.
"""
return self.get_result().status
return self.get_info().state.status

def get_result(self) -> BatchJobResult:
"""Get the result of the BatchJob.
def get_predictions(self) -> Any:
"""Get `BatchJob` predictions.
Note that the result of a job may be fetched before the job has completed.
Returns:
Any: Predictions for the `BatchJob`.
"""
return self._client.get_job_predictions(self.id)

def download_predictions(self, filepath: Union[str, Path]) -> None:
"""Download `BatchJob` predictions file.
Args:
filepath (Union[str, Path]): Filepath where predictions will be downloaded.
"""
predictions = self.get_predictions()
with Path(filepath).open("w") as f:
json.dump(predictions, f)

def download_artifacts(self, filepath: Union[str, Path]) -> None:
"""Download `BatchJob` artifacts zip file.
Args:
filepath (Union[str, Path]): Filepath where artifacts will be downloaded.
"""
self._client.download_job_artifacts(self.id, filepath)

def get_info(self) -> BatchJobInfo:
"""Get info for the BatchJob.
Note that the info for a job may be fetched before the job has completed.
You may want to use `job.await_complete()` which will wait for the job to
reach a terminal state before returning the result.
reach a terminal state before returning.
Returns:
BatchJobResult: The result of the `BatchJob`.
BatchJobInfo: Info for the `BatchJob`.
"""
return self._client.get_job_result(self.id)
return self._client.get_job_info(self.id)

def await_complete(self, timeout: int = 300) -> BatchJobResult:
def await_complete(self, timeout: int = 300) -> BatchJobInfo:
"""Block until the job has reached a terminal status.
Args:
@@ -54,7 +82,7 @@ def await_complete(self, timeout: int = 300) -> BatchJobResult:
ValueError: If the timeout is not valid.
Returns:
BatchJobResult: The result of the `BatchJob`.
BatchJobInfo: Info for the `BatchJob`.
"""
if timeout < 1:
raise ValueError("timeout must be at least 1 second")
@@ -63,11 +91,11 @@ def await_complete(self, timeout: int = 300) -> BatchJobResult:

# pylint: disable=unused-argument
@retry()
def _await_complete(self, timeout: int = 300) -> BatchJobResult:
result = self._client.get_job_result(self.id)
if not BatchJobStatus.is_terminal(result.status):
def _await_complete(self, timeout: int = 300) -> BatchJobInfo:
info = self._client.get_job_info(self.id)
if not BatchJobStatus.is_terminal(info.state.status):
raise RetryIterError
return result
return info

def __repr__(self) -> str:
"""Get the string representation of the `BatchJob`.
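Taken together, the reworked `BatchJob` replaces the old `get_result()` flow: `await_complete()` now returns a `BatchJobInfo`, status lives on `info.state`, and downloads happen on the job. A usage sketch under the assumption of an already-configured client; the API key, URL, and `FaceConfig` import path are placeholders, not part of this diff:

```python
from hume import HumeBatchClient
from hume.models.config import FaceConfig  # assumed config import path

client = HumeBatchClient("<your-api-key>")  # placeholder key
job = client.submit_job(["https://example.com/photo.jpg"], [FaceConfig()])

# Block until the job reaches a terminal status; returns BatchJobInfo.
info = job.await_complete()
print(info.state.status)

# Predictions and artifacts are now fetched from the job itself.
predictions = job.get_predictions()
job.download_predictions("predictions.json")
job.download_artifacts("artifacts.zip")
```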
101 changes: 101 additions & 0 deletions hume/_batch/batch_job_info.py
@@ -0,0 +1,101 @@
"""Batch job info."""
import json
from typing import Any, Dict, List, Optional

from hume._batch.batch_job_state import BatchJobState
from hume._batch.batch_job_status import BatchJobStatus
from hume._common.config_utils import config_from_model_type
from hume.error.hume_client_exception import HumeClientException
from hume.models import ModelType
from hume.models.config.model_config_base import ModelConfigBase


class BatchJobInfo:
"""Batch job info."""

def __init__(
self,
*,
configs: Dict[ModelType, ModelConfigBase],
urls: List[str],
files: List[str],
state: BatchJobState,
callback_url: Optional[str] = None,
notify: bool = False,
):
"""Construct a BatchJobInfo.
Args:
configs (Dict[ModelType, ModelConfigBase]): Configurations for the `BatchJob`.
urls (List[str]): URLs processed in the `BatchJob`.
files (List[str]): Files processed in the `BatchJob`.
state (BatchJobState): State of `BatchJob`.
callback_url (Optional[str]): A URL to which a POST request is sent upon job completion.
notify (bool): Whether an email notification should be sent upon job completion.
"""
self.configs = configs
self.urls = urls
self.files = files
self.state = state
self.callback_url = callback_url
self.notify = notify

@classmethod
def from_response(cls, response: Any) -> "BatchJobInfo":
"""Construct a `BatchJobInfo` from a batch API job response.
Args:
response (Any): Batch API job response.
Returns:
BatchJobInfo: A `BatchJobInfo` based on a batch API job response.
"""
try:
request = response["request"]

configs = {}
for model_name, config_dict in request["models"].items():
if config_dict is None:
continue
model_type = ModelType.from_str(model_name)
config = config_from_model_type(model_type).from_dict(config_dict)
configs[model_type] = config

urls = request["urls"]
files = request["files"]
callback_url = request["callback_url"]
notify = request["notify"]

state_dict = response["state"]
state = BatchJobState(
status=BatchJobStatus.from_str(state_dict["status"]),
created_timestamp_ms=state_dict.get("created_timestamp_ms"),
started_timestamp_ms=state_dict.get("started_timestamp_ms"),
ended_timestamp_ms=state_dict.get("ended_timestamp_ms"),
)

return cls(
configs=configs,
urls=urls,
files=files,
state=state,
callback_url=callback_url,
notify=notify,
)
# pylint: disable=broad-except
except Exception as exc:
message = cls._get_invalid_response_message(response)
raise HumeClientException(message) from exc

@classmethod
def _get_invalid_response_message(cls, response: Any) -> str:
response_str = json.dumps(response)
message = f"Could not parse response into BatchJobInfo: {response_str}"

# Check for invalid API key
if "fault" in response and "faultstring" in response["fault"]:
fault_string = response["fault"]["faultstring"]
if fault_string == "Invalid ApiKey":
message = "HumeBatchClient initialized with invalid API key."

return message
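To make the parsing above concrete, a hedged sketch of feeding `BatchJobInfo.from_response` a response-shaped dict; the keys match what the method reads, but the values are invented for illustration:

```python
from hume import BatchJobInfo

# Invented response payload with only the fields from_response reads.
response = {
    "request": {
        "models": {"face": {}},  # model name is mapped via ModelType.from_str
        "urls": ["https://example.com/photo.jpg"],
        "files": [],
        "callback_url": None,
        "notify": False,
    },
    "state": {
        "status": "COMPLETED",
        "created_timestamp_ms": 1683000000000,
        "started_timestamp_ms": 1683000001000,
        "ended_timestamp_ms": 1683000005000,
    },
}

info = BatchJobInfo.from_response(response)
print(info.state.status, info.urls, info.notify)
```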