From 5801eb202c6ba3046368ebe92471032a98382b1b Mon Sep 17 00:00:00 2001 From: Chris Gregory <8800689+gregorybchris@users.noreply.github.com> Date: Mon, 1 May 2023 18:06:06 -0700 Subject: [PATCH] Batch refactor update (#49) --- docs/batch/batch-job-info.md | 1 + docs/batch/batch-job-result.md | 1 - docs/batch/batch-job-state.md | 1 + docs/index.md | 8 +- .../batch-facial-action-coding-system.ipynb | 8 +- .../batch-text-entity-recognition.ipynb | 8 +- .../batch-text-sentiment-analysis.ipynb | 8 +- .../batch-text-toxicity-detection.ipynb | 8 +- .../batch-voice-expression.ipynb | 8 +- hume/__init__.py | 6 +- hume/_batch/__init__.py | 8 +- hume/_batch/batch_job.py | 58 ++++-- hume/_batch/batch_job_info.py | 101 ++++++++++ hume/_batch/batch_job_result.py | 183 ------------------ hume/_batch/batch_job_state.py | 67 +++++++ hume/_batch/hume_batch_client.py | 134 ++++++++++--- hume/_batch/transcription_config.py | 25 +++ hume/_common/client_base.py | 2 +- hume/_common/config_base.py | 48 +++++ hume/_stream/hume_stream_client.py | 2 +- hume/models/config/model_config_base.py | 44 +---- hume/models/config/ner_config.py | 15 -- hume/models/config/prosody_config.py | 7 + mkdocs.yml | 3 +- pyproject.toml | 4 +- tests/batch/data/info-response-completed.json | 22 +++ ...-failed.json => info-response-failed.json} | 18 +- ...-queued.json => info-response-queued.json} | 11 +- .../batch/data/result-response-completed.json | 23 --- tests/batch/test_batch_job.py | 26 ++- tests/batch/test_batch_job_info.py | 63 ++++++ tests/batch/test_batch_job_result.py | 77 -------- tests/batch/test_service_hume_batch_client.py | 63 +++--- tests/stream/test_hume_stream_client.py | 2 +- 34 files changed, 596 insertions(+), 467 deletions(-) create mode 100644 docs/batch/batch-job-info.md delete mode 100644 docs/batch/batch-job-result.md create mode 100644 docs/batch/batch-job-state.md create mode 100644 hume/_batch/batch_job_info.py delete mode 100644 hume/_batch/batch_job_result.py create mode 100644 
hume/_batch/batch_job_state.py create mode 100644 hume/_batch/transcription_config.py create mode 100644 hume/_common/config_base.py create mode 100644 tests/batch/data/info-response-completed.json rename tests/batch/data/{result-response-failed.json => info-response-failed.json} (52%) rename tests/batch/data/{result-response-queued.json => info-response-queued.json} (60%) delete mode 100644 tests/batch/data/result-response-completed.json create mode 100644 tests/batch/test_batch_job_info.py delete mode 100644 tests/batch/test_batch_job_result.py diff --git a/docs/batch/batch-job-info.md b/docs/batch/batch-job-info.md new file mode 100644 index 00000000..2bf5b680 --- /dev/null +++ b/docs/batch/batch-job-info.md @@ -0,0 +1 @@ +::: hume._batch.batch_job_info.BatchJobInfo diff --git a/docs/batch/batch-job-result.md b/docs/batch/batch-job-result.md deleted file mode 100644 index b388a86f..00000000 --- a/docs/batch/batch-job-result.md +++ /dev/null @@ -1 +0,0 @@ -::: hume._batch.batch_job_result.BatchJobResult diff --git a/docs/batch/batch-job-state.md b/docs/batch/batch-job-state.md new file mode 100644 index 00000000..513b7faa --- /dev/null +++ b/docs/batch/batch-job-state.md @@ -0,0 +1 @@ +::: hume._batch.batch_job_state.BatchJobState diff --git a/docs/index.md b/docs/index.md index 13082e00..eb481f44 100644 --- a/docs/index.md +++ b/docs/index.md @@ -36,10 +36,12 @@ job = client.submit_job(urls, [config]) print(job) print("Running...") -result = job.await_complete() -result.download_predictions("predictions.json") +job.await_complete() +job.download_predictions("predictions.json") +print("Predictions downloaded to predictions.json") -print("Predictions downloaded!") +job.download_artifacts("artifacts.zip") +print("Artifacts downloaded to artifacts.zip") ``` ### Rehydrate a batch job from a job ID diff --git a/examples/batch-facial-action-coding-system/batch-facial-action-coding-system.ipynb 
b/examples/batch-facial-action-coding-system/batch-facial-action-coding-system.ipynb index 5c5444e3..b3533ad4 100644 --- a/examples/batch-facial-action-coding-system/batch-facial-action-coding-system.ipynb +++ b/examples/batch-facial-action-coding-system/batch-facial-action-coding-system.ipynb @@ -30,15 +30,15 @@ "print(job)\n", "print(\"Running...\")\n", "\n", - "result = job.await_complete()\n", + "job.await_complete()\n", "download_filepath = \"predictions.json\"\n", - "result.download_predictions(download_filepath)\n", + "job.download_predictions(download_filepath)\n", "\n", "print(\"Predictions ready!\")\n", "print()\n", "\n", - "with open(\"predictions.json\", \"r\") as fp:\n", - " predictions = json.load(fp)\n", + "with open(\"predictions.json\", \"r\") as f:\n", + " predictions = json.load(f)\n", " for prediction in predictions:\n", " for file_data in prediction[\"files\"]:\n", " face_predictions = file_data[\"models\"][\"face\"]\n", diff --git a/examples/batch-text-entity-recognition/batch-text-entity-recognition.ipynb b/examples/batch-text-entity-recognition/batch-text-entity-recognition.ipynb index 2c9cd89f..8e4e9f9d 100644 --- a/examples/batch-text-entity-recognition/batch-text-entity-recognition.ipynb +++ b/examples/batch-text-entity-recognition/batch-text-entity-recognition.ipynb @@ -30,15 +30,15 @@ "print(job)\n", "print(\"Running...\")\n", "\n", - "result = job.await_complete()\n", + "job.await_complete()\n", "download_filepath = \"predictions.json\"\n", - "result.download_predictions(download_filepath)\n", + "job.download_predictions(download_filepath)\n", "\n", "print(\"Predictions ready!\")\n", "print()\n", "\n", - "with open(\"predictions.json\", \"r\") as fp:\n", - " predictions = json.load(fp)\n", + "with open(\"predictions.json\", \"r\") as f:\n", + " predictions = json.load(f)\n", " for prediction in predictions:\n", " for file_data in prediction[\"files\"]:\n", " ner_predictions = file_data[\"models\"][\"ner\"]\n", diff --git 
a/examples/batch-text-sentiment-analysis/batch-text-sentiment-analysis.ipynb b/examples/batch-text-sentiment-analysis/batch-text-sentiment-analysis.ipynb index 89313dfb..b9acf5a6 100644 --- a/examples/batch-text-sentiment-analysis/batch-text-sentiment-analysis.ipynb +++ b/examples/batch-text-sentiment-analysis/batch-text-sentiment-analysis.ipynb @@ -30,15 +30,15 @@ "print(job)\n", "print(\"Running...\")\n", "\n", - "result = job.await_complete()\n", + "job.await_complete()\n", "download_filepath = \"predictions.json\"\n", - "result.download_predictions(download_filepath)\n", + "job.download_predictions(download_filepath)\n", "\n", "print(\"Predictions ready!\")\n", "print()\n", "\n", - "with open(\"predictions.json\", \"r\") as fp:\n", - " predictions = json.load(fp)\n", + "with open(\"predictions.json\", \"r\") as f:\n", + " predictions = json.load(f)\n", " for prediction in predictions:\n", " for file_data in prediction[\"files\"]:\n", " language_predictions = file_data[\"models\"][\"language\"]\n", diff --git a/examples/batch-text-toxicity-detection/batch-text-toxicity-detection.ipynb b/examples/batch-text-toxicity-detection/batch-text-toxicity-detection.ipynb index 0af91254..cbad30ee 100644 --- a/examples/batch-text-toxicity-detection/batch-text-toxicity-detection.ipynb +++ b/examples/batch-text-toxicity-detection/batch-text-toxicity-detection.ipynb @@ -30,15 +30,15 @@ "print(job)\n", "print(\"Running...\")\n", "\n", - "result = job.await_complete()\n", + "job.await_complete()\n", "download_filepath = \"predictions.json\"\n", - "result.download_predictions(download_filepath)\n", + "job.download_predictions(download_filepath)\n", "\n", "print(\"Predictions ready!\")\n", "print()\n", "\n", - "with open(\"predictions.json\", \"r\") as fp:\n", - " predictions = json.load(fp)\n", + "with open(\"predictions.json\", \"r\") as f:\n", + " predictions = json.load(f)\n", " for prediction in predictions:\n", " for file_data in prediction[\"files\"]:\n", " 
language_predictions = file_data[\"models\"][\"language\"]\n", diff --git a/examples/batch-voice-expression/batch-voice-expression.ipynb b/examples/batch-voice-expression/batch-voice-expression.ipynb index ca811825..3ffbbaea 100644 --- a/examples/batch-voice-expression/batch-voice-expression.ipynb +++ b/examples/batch-voice-expression/batch-voice-expression.ipynb @@ -32,14 +32,14 @@ "print(job)\n", "print(\"Running...\")\n", "\n", - "result = job.await_complete()\n", + "job.await_complete()\n", "download_filepath = \"predictions.json\"\n", - "result.download_predictions(download_filepath)\n", + "job.download_predictions(download_filepath)\n", "\n", "print(\"Predictions ready!\")\n", "\n", - "with open(\"predictions.json\", \"r\") as fp:\n", - " predictions = json.load(fp)\n", + "with open(\"predictions.json\", \"r\") as f:\n", + " predictions = json.load(f)\n", " for prediction in predictions:\n", " for file_data in prediction[\"files\"]:\n", " print()\n", diff --git a/hume/__init__.py b/hume/__init__.py index 390744f4..06b9efae 100644 --- a/hume/__init__.py +++ b/hume/__init__.py @@ -1,7 +1,7 @@ """Module init.""" from importlib.metadata import version -from hume._batch import BatchJob, BatchJobResult, BatchJobStatus, HumeBatchClient +from hume._batch import BatchJob, BatchJobInfo, BatchJobState, BatchJobStatus, HumeBatchClient, TranscriptionConfig from hume._stream import HumeStreamClient, StreamSocket from hume.error.hume_client_exception import HumeClientException @@ -10,10 +10,12 @@ __all__ = [ "__version__", "BatchJob", - "BatchJobResult", + "BatchJobInfo", + "BatchJobState", "BatchJobStatus", "HumeBatchClient", "HumeClientException", "HumeStreamClient", "StreamSocket", + "TranscriptionConfig", ] diff --git a/hume/_batch/__init__.py b/hume/_batch/__init__.py index 65c78a60..c8eb2335 100644 --- a/hume/_batch/__init__.py +++ b/hume/_batch/__init__.py @@ -1,12 +1,16 @@ """Module init.""" from hume._batch.batch_job import BatchJob -from 
hume._batch.batch_job_result import BatchJobResult +from hume._batch.batch_job_info import BatchJobInfo +from hume._batch.batch_job_state import BatchJobState from hume._batch.batch_job_status import BatchJobStatus from hume._batch.hume_batch_client import HumeBatchClient +from hume._batch.transcription_config import TranscriptionConfig __all__ = [ "BatchJob", - "BatchJobResult", + "BatchJobInfo", + "BatchJobState", "BatchJobStatus", "HumeBatchClient", + "TranscriptionConfig", ] diff --git a/hume/_batch/batch_job.py b/hume/_batch/batch_job.py index 238f9cfe..ca9dc5c9 100644 --- a/hume/_batch/batch_job.py +++ b/hume/_batch/batch_job.py @@ -1,7 +1,9 @@ """Batch job.""" -from typing import TYPE_CHECKING +import json +from pathlib import Path +from typing import TYPE_CHECKING, Any, Union -from hume._batch.batch_job_result import BatchJobResult +from hume._batch.batch_job_info import BatchJobInfo from hume._batch.batch_job_status import BatchJobStatus from hume._common.retry_utils import retry, RetryIterError @@ -28,21 +30,47 @@ def get_status(self) -> BatchJobStatus: Returns: BatchJobStatus: The status of the `BatchJob`. """ - return self.get_result().status + return self.get_info().state.status - def get_result(self) -> BatchJobResult: - """Get the result of the BatchJob. + def get_predictions(self) -> Any: + """Get `BatchJob` predictions. - Note that the result of a job may be fetched before the job has completed. + Returns: + Any: Predictions for the `BatchJob`. + """ + return self._client.get_job_predictions(self.id) + + def download_predictions(self, filepath: Union[str, Path]) -> None: + """Download `BatchJob` predictions file. + + Args: + filepath (Union[str, Path]): Filepath where predictions will be downloaded. + """ + predictions = self.get_predictions() + with Path(filepath).open("w") as f: + json.dump(predictions, f) + + def download_artifacts(self, filepath: Union[str, Path]) -> None: + """Download `BatchJob` artifacts zip file. 
+ + Args: + filepath (Optional[Union[str, Path]]): Filepath where artifacts will be downloaded. + """ + self._client.download_job_artifacts(self.id, filepath) + + def get_info(self) -> BatchJobInfo: + """Get info for the BatchJob. + + Note that the info for a job may be fetched before the job has completed. You may want to use `job.await_complete()` which will wait for the job to - reach a terminal state before returning the result. + reach a terminal state before returning. Returns: - BatchJobResult: The result of the `BatchJob`. + BatchJobInfo: Info for the `BatchJob`. """ - return self._client.get_job_result(self.id) + return self._client.get_job_info(self.id) - def await_complete(self, timeout: int = 300) -> BatchJobResult: + def await_complete(self, timeout: int = 300) -> BatchJobInfo: """Block until the job has reached a terminal status. Args: @@ -54,7 +82,7 @@ def await_complete(self, timeout: int = 300) -> BatchJobResult: ValueError: If the timeout is not valid. Returns: - BatchJobResult: The result of the `BatchJob`. + BatchJobInfo: Info for the `BatchJob`. """ if timeout < 1: raise ValueError("timeout must be at least 1 second") @@ -63,11 +91,11 @@ def await_complete(self, timeout: int = 300) -> BatchJobResult: # pylint: disable=unused-argument @retry() - def _await_complete(self, timeout: int = 300) -> BatchJobResult: - result = self._client.get_job_result(self.id) - if not BatchJobStatus.is_terminal(result.status): + def _await_complete(self, timeout: int = 300) -> BatchJobInfo: + info = self._client.get_job_info(self.id) + if not BatchJobStatus.is_terminal(info.state.status): raise RetryIterError - return result + return info def __repr__(self) -> str: """Get the string representation of the `BatchJob`. 
diff --git a/hume/_batch/batch_job_info.py b/hume/_batch/batch_job_info.py new file mode 100644 index 00000000..e6bffa57 --- /dev/null +++ b/hume/_batch/batch_job_info.py @@ -0,0 +1,101 @@ +"""Batch job info.""" +import json +from typing import Any, Dict, List, Optional + +from hume._batch.batch_job_state import BatchJobState +from hume._batch.batch_job_status import BatchJobStatus +from hume._common.config_utils import config_from_model_type +from hume.error.hume_client_exception import HumeClientException +from hume.models import ModelType +from hume.models.config.model_config_base import ModelConfigBase + + +class BatchJobInfo: + """Batch job info.""" + + def __init__( + self, + *, + configs: Dict[ModelType, ModelConfigBase], + urls: List[str], + files: List[str], + state: BatchJobState, + callback_url: Optional[str] = None, + notify: bool = False, + ): + """Construct a BatchJobInfo. + + Args: + configs (Dict[ModelType, ModelConfigBase]): Configurations for the `BatchJob`. + urls (List[str]): URLs processed in the `BatchJob`. + files (List[str]): Files processed in the `BatchJob`. + state (BatchJobState): State of `BatchJob`. + callback_url (Optional[str]): A URL to which a POST request is sent upon job completion. + notify (bool): Whether an email notification should be sent upon job completion. + """ + self.configs = configs + self.urls = urls + self.files = files + self.state = state + self.callback_url = callback_url + self.notify = notify + + @classmethod + def from_response(cls, response: Any) -> "BatchJobInfo": + """Construct a `BatchJobInfo` from a batch API job response. + + Args: + response (Any): Batch API job response. + + Returns: + BatchJobInfo: A `BatchJobInfo` based on a batch API job response. 
+ """ + try: + request = response["request"] + + configs = {} + for model_name, config_dict in request["models"].items(): + if config_dict is None: + continue + model_type = ModelType.from_str(model_name) + config = config_from_model_type(model_type).from_dict(config_dict) + configs[model_type] = config + + urls = request["urls"] + files = request["files"] + callback_url = request["callback_url"] + notify = request["notify"] + + state_dict = response["state"] + state = BatchJobState( + status=BatchJobStatus.from_str(state_dict["status"]), + created_timestamp_ms=state_dict.get("created_timestamp_ms"), + started_timestamp_ms=state_dict.get("started_timestamp_ms"), + ended_timestamp_ms=state_dict.get("ended_timestamp_ms"), + ) + + return cls( + configs=configs, + urls=urls, + files=files, + state=state, + callback_url=callback_url, + notify=notify, + ) + # pylint: disable=broad-except + except Exception as exc: + message = cls._get_invalid_response_message(response) + raise HumeClientException(message) from exc + + @classmethod + def _get_invalid_response_message(cls, response: Any) -> str: + response_str = json.dumps(response) + message = f"Could not parse response into BatchJobInfo: {response_str}" + + # Check for invalid API key + if "fault" in response and "faultstring" in response["fault"]: + fault_string = response["fault"]["faultstring"] + if fault_string == "Invalid ApiKey": + message = "HumeBatchClient initialized with invalid API key." 
+ + return message diff --git a/hume/_batch/batch_job_result.py b/hume/_batch/batch_job_result.py deleted file mode 100644 index cf5b15f4..00000000 --- a/hume/_batch/batch_job_result.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Batch job result.""" -import json -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, List, Optional, Union -from urllib.request import urlretrieve - -from hume._batch.batch_job_status import BatchJobStatus -from hume._common.config_utils import config_from_model_type -from hume.error.hume_client_exception import HumeClientException -from hume.models import ModelType -from hume.models.config.model_config_base import ModelConfigBase - - -class BatchJobResult: - """Batch job result.""" - - def __init__( - self, - *, - configs: Dict[ModelType, ModelConfigBase], - urls: List[str], - status: BatchJobStatus, - predictions_url: Optional[str] = None, - artifacts_url: Optional[str] = None, - errors_url: Optional[str] = None, - error_message: Optional[str] = None, - job_start_time: Optional[int] = None, - job_end_time: Optional[int] = None, - ): - """Construct a BatchJobResult. - - Args: - configs (Dict[ModelType, ModelConfigBase]): Configurations for the `BatchJob`. - urls (List[str]): URLs processed in the `BatchJob`. - status (BatchJobStatus): Status of `BatchJob`. - predictions_url (Optional[str]): URL to predictions file. - artifacts_url (Optional[str]): URL to artifacts zip archive. - errors_url (Optional[str]): URL to errors file. - error_message (Optional[str]): Error message for request. - job_start_time (Optional[int]): Time when job started. - job_end_time (Optional[int]): Time when job completed. 
- """ - self.configs = configs - self.urls = urls - self.status = status - self.predictions_url = predictions_url - self.artifacts_url = artifacts_url - self.errors_url = errors_url - self.error_message = error_message - self.job_start_time = job_start_time - self.job_end_time = job_end_time - - def download_predictions(self, filepath: Optional[Union[str, Path]] = None) -> None: - """Download `BatchJob` predictions file. - - Args: - filepath (Optional[Union[str, Path]]): Filepath where predictions will be downloaded. - """ - if self.predictions_url is None: - raise HumeClientException("Could not download job predictions. No predictions found on job result.") - urlretrieve(self.predictions_url, filepath) - - def download_artifacts(self, filepath: Optional[Union[str, Path]] = None) -> None: - """Download `BatchJob` artifacts zip archive. - - Args: - filepath (Optional[Union[str, Path]]): Filepath where artifacts zip archive will be downloaded. - """ - if self.artifacts_url is None: - raise HumeClientException("Could not download job artifacts. No artifacts found on job result.") - urlretrieve(self.artifacts_url, filepath) - - def download_errors(self, filepath: Optional[Union[str, Path]] = None) -> None: - """Download `BatchJob` errors file. - - Args: - filepath (Optional[Union[str, Path]]): Filepath where errors will be downloaded. - """ - if self.errors_url is None: - raise HumeClientException("Could not download job errors. No errors found on job result.") - urlretrieve(self.errors_url, filepath) - - def get_error_message(self) -> Optional[str]: - """Get any available error messages on the job. - - Returns: - Optional[str]: A string with the error message if there was an error, otherwise `None`. - """ - return self.error_message - - def get_run_time(self) -> Optional[int]: - """Get the total time in seconds it took for the job to run if the job is in a terminal state. - - Returns: - Optional[int]: Time in seconds it took for the job to run. 
If the job is not in a terminal - state then `None` is returned. - """ - if self.job_start_time is not None and self.job_end_time is not None: - return self.job_end_time - self.job_start_time - return None - - def get_start_time(self) -> Optional[datetime]: - """Get the time the job started running. - - Returns: - Optional[datetime]: Datetime when the job started running. If the job has not started - then `None` is returned. - """ - if self.job_start_time is None: - return None - return datetime.utcfromtimestamp(self.job_start_time) - - def get_end_time(self) -> Optional[datetime]: - """Get the time the job stopped running if the job is in a terminal state. - - Returns: - Optional[datetime]: Datetime when the job started running. If the job is not in a terminal - state then `None` is returned. - """ - if self.job_end_time is None: - return None - return datetime.utcfromtimestamp(self.job_end_time) - - @classmethod - def from_response(cls, response: Any) -> "BatchJobResult": - """Construct a `BatchJobResult` from a batch API job response. - - Args: - response (Any): Batch API job response. - - Returns: - BatchJobResult: A `BatchJobResult` based on a batch API job response. 
- """ - try: - request = response["request"] - configs = {} - for model_name, config_dict in request["models"].items(): - model_type = ModelType.from_str(model_name) - config = config_from_model_type(model_type).from_dict(config_dict) - configs[model_type] = config - - kwargs = {} - if "completed" in response: - completed_dict = response["completed"] - kwargs["artifacts_url"] = completed_dict["artifacts_url"] - kwargs["errors_url"] = completed_dict["errors_url"] - kwargs["predictions_url"] = completed_dict["predictions_url"] - - if "failed" in response: - failed_dict = response["failed"] - if "message" in failed_dict: - kwargs["error_message"] = failed_dict["message"] - - if "creation_timestamp" in response: - kwargs["job_start_time"] = response["creation_timestamp"] - - if "completion_timestamp" in response: - kwargs["job_end_time"] = response["completion_timestamp"] - - return cls( - configs=configs, - urls=request["urls"], - status=BatchJobStatus.from_str(response["status"]), - **kwargs, - ) - # pylint: disable=broad-except - except Exception as exc: - message = cls._get_invalid_response_message(response) - raise HumeClientException(message) from exc - - @classmethod - def _get_invalid_response_message(cls, response: Any) -> str: - response_str = json.dumps(response) - message = f"Could not parse response into BatchJobResult: {response_str}" - - # Check for invalid API key - if "fault" in response and "faultstring" in response["fault"]: - fault_string = response["fault"]["faultstring"] - if fault_string == "Invalid ApiKey": - message = "HumeBatchClient initialized with invalid API key." 
- - return message diff --git a/hume/_batch/batch_job_state.py b/hume/_batch/batch_job_state.py new file mode 100644 index 00000000..2f49626a --- /dev/null +++ b/hume/_batch/batch_job_state.py @@ -0,0 +1,67 @@ +"""Batch job state.""" +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + +from hume._batch.batch_job_status import BatchJobStatus + + +@dataclass +class BatchJobState: + """Batch job state. + + Args: + status (BatchJobStatus): Status of the batch job. + created_timestamp_ms (Optional[int]): Time when job was created. + started_timestamp_ms (Optional[int]): Time when job started. + ended_timestamp_ms (Optional[int]): Time when job ended. + """ + + status: BatchJobStatus + created_timestamp_ms: Optional[int] + started_timestamp_ms: Optional[int] + ended_timestamp_ms: Optional[int] + + def get_run_time_ms(self) -> Optional[int]: + """Get the total time in milliseconds it took for the job to run if the job is in a terminal state. + + Returns: + Optional[int]: Time in milliseconds it took for the job to run. If the job is not in a terminal + state then `None` is returned. + """ + if self.started_timestamp_ms is not None and self.ended_timestamp_ms is not None: + return self.ended_timestamp_ms - self.started_timestamp_ms + return None + + def get_created_time(self) -> Optional[datetime]: + """Get the time the job was created. + + Returns: + Optional[datetime]: Datetime when the job was created. If the job has not started + then `None` is returned. + """ + if self.created_timestamp_ms is None: + return None + return datetime.utcfromtimestamp(self.created_timestamp_ms / 1000) + + def get_started_time(self) -> Optional[datetime]: + """Get the time the job started running. + + Returns: + Optional[datetime]: Datetime when the job started running. If the job has not started + then `None` is returned. 
+ """ + if self.started_timestamp_ms is None: + return None + return datetime.utcfromtimestamp(self.started_timestamp_ms / 1000) + + def get_ended_time(self) -> Optional[datetime]: + """Get the time the job stopped running if the job is in a terminal state. + + Returns: + Optional[datetime]: Datetime when the job started running. If the job is not in a terminal + state then `None` is returned. + """ + if self.ended_timestamp_ms is None: + return None + return datetime.utcfromtimestamp(self.ended_timestamp_ms / 1000) diff --git a/hume/_batch/hume_batch_client.py b/hume/_batch/hume_batch_client.py index ba55dd59..c4cccea3 100644 --- a/hume/_batch/hume_batch_client.py +++ b/hume/_batch/hume_batch_client.py @@ -1,11 +1,13 @@ """Batch API client.""" import json -from typing import Any, Dict, List +from pathlib import Path +from typing import Any, Dict, List, Optional, Union import requests from hume._batch.batch_job import BatchJob -from hume._batch.batch_job_result import BatchJobResult +from hume._batch.batch_job_info import BatchJobInfo +from hume._batch.transcription_config import TranscriptionConfig from hume._common.api_type import ApiType from hume._common.client_base import ClientBase from hume._common.config_utils import serialize_configs @@ -22,17 +24,19 @@ class HumeBatchClient(ClientBase): from hume.models.config import FaceConfig client = HumeBatchClient("") - urls = [""] + urls = ["https://tinyurl.com/hume-img"] config = FaceConfig(identify_faces=True) - job = client.submit_job(urls, [configs]) + job = client.submit_job(urls, [config]) print(job) print("Running...") - result = job.await_complete() - result.download_predictions("predictions.json") + job.await_complete() + job.download_predictions("predictions.json") + print("Predictions downloaded to predictions.json") - print("Predictions downloaded!") + job.download_artifacts("artifacts.zip") + print("Artifacts downloaded to artifacts.zip") ``` """ @@ -55,17 +59,52 @@ def get_api_type(cls) -> ApiType: """ 
return ApiType.BATCH - def get_job_result(self, job_id: str) -> BatchJobResult: - """Get the result of the batch job. + def get_job(self, job_id: str) -> BatchJob: + """Rehydrate a job based on a Job ID. + + Args: + job_id (str): ID of the job to rehydrate. + + Returns: + BatchJob: Job associated with the given ID. + """ + return BatchJob(self, job_id) + + def submit_job( + self, + urls: List[str], + configs: List[ModelConfigBase], + transcription_config: Optional[TranscriptionConfig] = None, + callback_url: Optional[str] = None, + ) -> BatchJob: + """Submit a job for batch processing. + + Note: Only one config per model type should be passed. + If more than one config is passed for a given model type, only the last config will be used. + + Args: + urls (List[str]): List of URLs to media files to be processed. + configs (List[ModelConfigBase]): List of model config objects to run on each media URL. + transcription_config (Optional[TranscriptionConfig]): A `TranscriptionConfig` object. + callback_url (Optional[str]): A URL to which a POST request will be sent upon job completion. + + Returns: + BatchJob: The `BatchJob` representing the batch computation. + """ + request = self._construct_request(configs, urls, transcription_config, callback_url) + return self._submit_job_from_request(request) + + def get_job_info(self, job_id: str) -> BatchJobInfo: + """Get info for the batch job. Args: job_id (str): Job ID. Raises: - HumeClientException: If the job result cannot be loaded. + HumeClientException: If the job info cannot be loaded. Returns: - BatchJobResult: Batch job result. + BatchJobInfo: Batch job info. 
""" endpoint = self._construct_endpoint(f"jobs/{job_id}") response = requests.get( @@ -78,46 +117,83 @@ def get_job_result(self, job_id: str) -> BatchJobResult: body = response.json() except json.JSONDecodeError: # pylint: disable=raise-missing-from - raise HumeClientException("Unexpected error when getting job result") + raise HumeClientException("Unexpected error when getting job info") if "message" in body and body["message"] == "job not found": raise HumeClientException(f"Could not find a job with ID {job_id}") - return BatchJobResult.from_response(body) + return BatchJobInfo.from_response(body) - def get_job(self, job_id: str) -> BatchJob: - """Rehydrate a job based on a Job ID. + def get_job_predictions(self, job_id: str) -> Any: + """Get a batch job's predictions. Args: - job_id (str): ID of the job to rehydrate. + job_id (str): Job ID. + + Raises: + HumeClientException: If the job predictions cannot be loaded. Returns: - BatchJob: Job associated with the given ID. + Any: Batch job predictions. """ - return BatchJob(self, job_id) + endpoint = self._construct_endpoint(f"jobs/{job_id}/predictions") + response = requests.get( + endpoint, + timeout=self._DEFAULT_API_TIMEOUT, + headers=self._get_client_headers(), + ) - def submit_job(self, urls: List[str], configs: List[ModelConfigBase]) -> BatchJob: - """Submit a job for batch processing. + try: + body = response.json() + except json.JSONDecodeError: + # pylint: disable=raise-missing-from + raise HumeClientException("Unexpected error when getting job predictions") - Note: Only one config per model type should be passed. - If more than one config is passed for a given model type, only the last config will be used. + if "message" in body and body["message"] == "job not found": + raise HumeClientException(f"Could not find a job with ID {job_id}") + + return body + + def download_job_artifacts(self, job_id: str, filepath: Union[str, Path]) -> None: + """Download a batch job's artifacts as a zip file. 
Args: - urls (List[str]): List of URLs to media files to be processed. - configs (List[ModelConfigBase]): List of model config objects to run on each media URL. + job_id (str): Job ID. + filepath (Optional[Union[str, Path]]): Filepath where artifacts will be downloaded. + + Raises: + HumeClientException: If the job artifacts cannot be loaded. Returns: - BatchJob: The `BatchJob` representing the batch computation. + Any: Batch job artifacts. """ - request = self._get_request(configs, urls) - return self._submit_job_from_request(request) + endpoint = self._construct_endpoint(f"jobs/{job_id}/artifacts") + response = requests.get( + endpoint, + timeout=self._DEFAULT_API_TIMEOUT, + headers=self._get_client_headers(), + ) + + with Path(filepath).open("wb") as f: + f.write(response.content) @classmethod - def _get_request(cls, configs: List[ModelConfigBase], urls: List[str]) -> Dict[str, Any]: - return { + def _construct_request( + cls, + configs: List[ModelConfigBase], + urls: List[str], + transcription_config: Optional[TranscriptionConfig], + callback_url: Optional[str], + ) -> Dict[str, Any]: + request = { "urls": urls, "models": serialize_configs(configs), } + if transcription_config is not None: + request["transcription"] = transcription_config.to_dict() + if callback_url is not None: + request["callback_url"] = callback_url + return request def _submit_job_from_request(self, request_body: Any) -> BatchJob: """Start a job for batch processing by passing a JSON request body. diff --git a/hume/_batch/transcription_config.py b/hume/_batch/transcription_config.py new file mode 100644 index 00000000..e654746e --- /dev/null +++ b/hume/_batch/transcription_config.py @@ -0,0 +1,25 @@ +"""Configuration for speech transcription.""" +from dataclasses import dataclass +from typing import Optional + +from hume._common.config_base import ConfigBase + + +@dataclass +class TranscriptionConfig(ConfigBase["TranscriptionConfig"]): + """Configuration for speech transcription. 
+ + Args: + language (Optional[str]): The BCP-47 tag of the language spoken in your media samples; + If missing or null, it will be automatically detected. Values are `zh`, `da`, `nl`, `en`, `en-AU`, + `en-IN`, `en-NZ`, `en-GB`, `fr`, `fr-CA`, `de`, `hi`, `hi-Latn`, `id`, `it`, `ja`, `ko`, `no`, + `pl`, `pt`, `pt-BR`, `pt-PT`, `ru`, `es`, `es-419`, `sv`, `ta`, `tr`, or `uk`. + This configuration is not available for the streaming API. + identify_speakers (Optional[bool]): Whether to return identifiers for speakers over time. + If true, unique identifiers will be assigned to spoken words to differentiate different speakers. + If false, all speakers will be tagged with an "unknown" ID. + This configuration is not available for the streaming API. + """ + + language: Optional[str] = None + identify_speakers: Optional[bool] = None diff --git a/hume/_common/client_base.py b/hume/_common/client_base.py index 771f6281..693a953e 100644 --- a/hume/_common/client_base.py +++ b/hume/_common/client_base.py @@ -39,7 +39,7 @@ def _get_client_headers(self) -> Dict[str, str]: package_version = version("hume") return { "X-Hume-Api-Key": self._api_key, - "X-Hume-Client-Name": "python-sdk", + "X-Hume-Client-Name": "python_sdk", "X-Hume-Client-Version": package_version, } diff --git a/hume/_common/config_base.py b/hume/_common/config_base.py new file mode 100644 index 00000000..6860821f --- /dev/null +++ b/hume/_common/config_base.py @@ -0,0 +1,48 @@ +"""Abstract base class for configurations.""" +import warnings +from abc import ABC +from dataclasses import asdict, dataclass, fields +from typing import Any, Dict, Generic, TypeVar, cast + +T = TypeVar("T")  # Type for subclasses of ConfigBase + + +@dataclass +class ConfigBase(ABC, Generic[T]): + """Abstract base class for configurations.""" + + def to_dict(self, skip_none: bool = True) -> Dict[str, Any]: + """Serialize configuration to dictionary.
+ + Args: + skip_none (bool): Whether None configurations should be skipped during serialization. + + Returns: + Dict[str, Any]: Serialized configuration object. + """ + return {k: v for k, v in asdict(self).items() if v is not None or not skip_none} + + @classmethod + def from_dict(cls, request_dict: Dict[str, Any]) -> T: + """Deserialize configuration from request JSON. + + Args: + request_dict (Dict[str, Any]): Request JSON data. + + Returns: + T: Deserialized configuration object. + """ + class_fields = set(field.name for field in fields(cls)) + removal_params = [] + for param in request_dict: + if param not in class_fields: + removal_params.append(param) + class_name = cls.__name__ + warnings.warn(f"Got an unknown parameter `{param}` when loading `{class_name}`. " + "Your installed version of the Python SDK may be out of date " + "with the latest Hume APIs. " + "Run `pip install --upgrade hume` to get the latest version of the Python SDK.") + for removal_param in removal_params: + request_dict.pop(removal_param) + + return cast(T, cls(**request_dict)) diff --git a/hume/_stream/hume_stream_client.py b/hume/_stream/hume_stream_client.py index 39c12f75..e17f1d55 100644 --- a/hume/_stream/hume_stream_client.py +++ b/hume/_stream/hume_stream_client.py @@ -29,7 +29,7 @@ class HumeStreamClient(ClientBase): async def main(): client = HumeStreamClient("") config = FaceConfig(identify_faces=True) - async with client.connect([configs]) as socket: + async with client.connect([config]) as socket: result = await socket.send_file("") print(result) diff --git a/hume/models/config/model_config_base.py b/hume/models/config/model_config_base.py index 17ba6c5d..833b7740 100644 --- a/hume/models/config/model_config_base.py +++ b/hume/models/config/model_config_base.py @@ -1,8 +1,8 @@ """Abstract base class for model configurations.""" -import warnings from abc import abstractmethod, ABC -from dataclasses import asdict, dataclass, fields -from typing import Any, Dict, Generic, 
TypeVar +from dataclasses import dataclass +from typing import Generic, TypeVar +from hume._common.config_base import ConfigBase from hume.models import ModelType @@ -10,7 +10,7 @@ @dataclass -class ModelConfigBase(ABC, Generic[T]): +class ModelConfigBase(ConfigBase["ModelConfigBase"], ABC, Generic[T]): """Abstract base class for model configurations.""" @classmethod @@ -21,39 +21,3 @@ def get_model_type(cls) -> ModelType: Returns: ModelType: Model type. """ - - def to_dict(self, skip_none: bool = True) -> Dict[str, Any]: - """Serialize configuration to dictionary. - - Args: - skip_none (bool): Whether None configurations should be skipped during serialization. - - Returns: - Dict[str, Any]: Serialized configuration object. - """ - return {k: v for k, v in asdict(self).items() if v is not None or not skip_none} - - @classmethod - def from_dict(cls, request_dict: Dict[str, Any]) -> "ModelConfigBase[T]": - """Deserialize configuration from request JSON. - - Args: - request_dict (Dict[str, Any]): Request JSON data. - - Returns: - T: Deserialized configuration object. - """ - class_fields = set(field.name for field in fields(cls)) - removal_params = [] - for param in request_dict: - if param not in class_fields: - removal_params.append(param) - class_name = cls.__name__ - warnings.warn(f"Got an unknown parameter `{param}` when loading `{class_name}`. " - "Your installed version of the Python SDK may be out of date " - "with the latest Hume APIs. 
" - "Run `pip install --upgrade hume` to get the latest version of the Python SDK.") - for removal_param in removal_params: - request_dict.pop(removal_param) - - return cls(**request_dict) diff --git a/hume/models/config/ner_config.py b/hume/models/config/ner_config.py index 8805fda9..f7f1c10f 100644 --- a/hume/models/config/ner_config.py +++ b/hume/models/config/ner_config.py @@ -1,6 +1,5 @@ """Configuration for the named-entity emotion model.""" from dataclasses import dataclass -from typing import Optional from hume.models import ModelType from hume.models.config.model_config_base import ModelConfigBase @@ -11,22 +10,8 @@ class NerConfig(ModelConfigBase["NerConfig"]): """Configuration for the named-entity emotion model. This model is not available for the streaming API. - - Args: - language (Optional[str]): The BCP-47 tag (see above) of the language spoken in your media samples; - If missing or null, it will be automatically detected. Values are `zh`, `da`, `nl`, `en`, `en-AU`, - `en-IN`, `en-NZ`, `en-GB`, `fr`, `fr-CA`, `de`, `hi`, `hi-Latn`, `id`, `it`, `ja`, `ko`, `no`, - `pl`, `pt`, `pt-BR`, `pt-PT`, `ru`, `es`, `es-419`, `sv`, `ta`, `tr`, or `uk`. - This configuration is not available for the streaming API. - identify_speakers (Optional[bool]): Whether to return identifiers for speakers over time. - If true, unique identifiers will be assigned to spoken words to differentiate different speakers. - If false, all speakers will be tagged with an "unknown" ID. - This configuration is not available for the streaming API. """ - language: Optional[str] = None - identify_speakers: Optional[bool] = None - @classmethod def get_model_type(cls) -> ModelType: """Get the configuration model type. 
diff --git a/hume/models/config/prosody_config.py b/hume/models/config/prosody_config.py index 108152fa..84ea5afe 100644 --- a/hume/models/config/prosody_config.py +++ b/hume/models/config/prosody_config.py @@ -20,10 +20,17 @@ class ProsodyConfig(ModelConfigBase["ProsodyConfig"]): unique identifiers will be assigned to spoken words to differentiate different speakers. If false, all speakers will be tagged with an "unknown" ID. This configuration is not available for the streaming API. + granularity (Optional[str]): The granularity at which to generate predictions. + Values are `word`, `sentence`, `utterance`, or `conversational_turn`. + Default value is `utterance`. + `utterance` corresponds to a natural pause or break in conversation + `conversational_turn` corresponds to a change in speaker. + This configuration is not available for the streaming API. """ language: Optional[str] = None identify_speakers: Optional[bool] = None + granularity: Optional[str] = None @classmethod def get_model_type(cls) -> ModelType: diff --git a/mkdocs.yml b/mkdocs.yml index 20f13c72..408aa954 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -37,8 +37,9 @@ nav: - "Batch Reference": - "HumeBatchClient": batch/hume-batch-client.md - "BatchJob": batch/batch-job.md + - "BatchJobState": batch/batch-job-state.md - "BatchJobStatus": batch/batch-job-status.md - - "BatchJobResult": batch/batch-job-result.md + - "BatchJobInfo": batch/batch-job-info.md - "Streaming Reference": - "HumeStreamClient": stream/hume-stream-client.md - "StreamSocket": stream/stream-socket.md diff --git a/pyproject.toml b/pyproject.toml index 186d913b..26410bdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ license = "Proprietary" name = "hume" readme = "README.md" repository = "https://github.com/HumeAI/hume-python-sdk" -version = "0.2.0" +version = "0.3.0" [tool.poetry.dependencies] python = ">=3.8.1,<4" @@ -76,7 +76,7 @@ requires = ["poetry-core>=1.0.0"] [tool.covcheck.group.unit.coverage] branch = 
63.0 -line = 82.0 +line = 80.0 [tool.covcheck.group.service.coverage] branch = 64.0 diff --git a/tests/batch/data/info-response-completed.json b/tests/batch/data/info-response-completed.json new file mode 100644 index 00000000..9dff7e36 --- /dev/null +++ b/tests/batch/data/info-response-completed.json @@ -0,0 +1,22 @@ +{ + "request": { + "models": { + "face": { + "fps_pred": 3.0, + "identify_faces": true, + "min_face_size": 60.0, + "prob_threshold": 0.9900000095367432 + } + }, + "urls": ["https://storage.googleapis.com/hume-test-data/video/royal-opera-house.mp4"], + "files": [], + "callback_url": "https:/fake-callback", + "notify": false + }, + "state": { + "status": "COMPLETED", + "created_timestamp_ms": 1660594810000, + "started_timestamp_ms": 1660594812000, + "ended_timestamp_ms": 1660594815000 +} +} diff --git a/tests/batch/data/result-response-failed.json b/tests/batch/data/info-response-failed.json similarity index 52% rename from tests/batch/data/result-response-failed.json rename to tests/batch/data/info-response-failed.json index eb15cdab..26a0c05a 100644 --- a/tests/batch/data/result-response-failed.json +++ b/tests/batch/data/info-response-failed.json @@ -1,9 +1,4 @@ { - "completion_timestamp": 1663790201, - "creation_timestamp": 1663790199, - "failed": { - "message": "user 'abcde' has exceeded their usage limit" - }, "request": { "models": { "face": { @@ -13,8 +8,15 @@ "prob_threshold": 0.9900000095367432 } }, - "notify": false, - "urls": ["https://storage.googleapis.com/hume-test-data/image/obama.png"] + "urls": ["https://storage.googleapis.com/hume-test-data/image/obama.png"], + "files": [], + "callback_url": null, + "notify": false }, - "status": "FAILED" + "state": { + "status": "FAILED", + "created_timestamp_ms": 1660594811000, + "started_timestamp_ms": 1660594814000, + "ended_timestamp_ms": 1660594816000 + } } diff --git a/tests/batch/data/result-response-queued.json b/tests/batch/data/info-response-queued.json similarity index 60% rename from 
tests/batch/data/result-response-queued.json rename to tests/batch/data/info-response-queued.json index 2863ce1d..632e9195 100644 --- a/tests/batch/data/result-response-queued.json +++ b/tests/batch/data/info-response-queued.json @@ -1,5 +1,4 @@ { - "creation_timestamp": 1660594810, "request": { "models": { "face": { @@ -9,7 +8,13 @@ "prob_threshold": 0.9900000095367432 } }, - "urls": ["https://storage.googleapis.com/hume-test-data/video/royal-opera-house.mp4"] + "urls": ["https://storage.googleapis.com/hume-test-data/video/royal-opera-house.mp4"], + "files": [], + "callback_url": null, + "notify": false }, - "status": "QUEUED" + "state": { + "status": "QUEUED", + "created_timestamp_ms": 1660594815000 + } } diff --git a/tests/batch/data/result-response-completed.json b/tests/batch/data/result-response-completed.json deleted file mode 100644 index a4970c17..00000000 --- a/tests/batch/data/result-response-completed.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "creation_timestamp": 1660594810, - "request": { - "models": { - "face": { - "fps_pred": 3.0, - "identify_faces": true, - "min_face_size": 60.0, - "prob_threshold": 0.9900000095367432 - } - }, - "urls": ["https://storage.googleapis.com/hume-test-data/video/royal-opera-house.mp4"] - }, - "status": "COMPLETED", - "completion_timestamp": 1660594813, - "completed": { - "num_errors": 0, - "num_predictions": 20, - "artifacts_url": 
"https://storage.googleapis.com/humeai-batch-api-prod-results/c8e7d7a5-e74a-48dd-aa90-a788b5aad36e/4de3d8dac4704cf1aa4a281d0a90a26f/job-artifacts.zip?Expires=1661199613&GoogleAccessId=batch-api-subscriber%40hume-data.iam.gserviceaccount.com&Signature=reD6CDQ9xhy5OWX%2FEYvyFGziYAPGyMsacARf5lp4X5I7LBQ%2Bb8a5%2F6LLUJlAuZT0DEZt3hOmMDg3znIeTf8U99tn9tRIa%2Fdv6%2BAMvBAT7w4sQQbkThS4E5IajZWhNSCgzOh1MJ8JLFA%2BCeLetjgUti5KmiLHe%2FErpd7grHUf3yYOB41TDMVJH5K7BrL2jss5v%2BCaQIKue2geYvcOZnv4loM66JYZ2%2FQ0aQ3007R2TnitnJeDZKMNjnBBakQ6iSxLfOcoweerr7SY53KITZhIwgjsq7454a73klr3FjOTZWlI9TVyFVRR7KnnURcFk5AHa0oDJPORSCiOez2GuKAogQ%3D%3D", - "errors_url": "https://storage.googleapis.com/humeai-batch-api-prod-results/c8e7d7a5-e74a-48dd-aa90-a788b5aad36e/4de3d8dac4704cf1aa4a281d0a90a26f/job-errors.txt?Expires=1661199613&GoogleAccessId=batch-api-subscriber%40hume-data.iam.gserviceaccount.com&Signature=H61hUom2tLjLg7Q8o2ji6QbraiOJ1BJisbiymERMKNrqcXIWRnj5cfPcSBoiqw9bEQsM7xA4OQLkS%2FRRZ7iQ7RVwNNKyt%2FG4%2FdcI59OH3Z4uh9u7jsaRJbtvpKoozQbupVCcrz2OS%2BrY5sCuCZVZ9qIKd6ZVlUXfZrdyyPNB0mN2eVojhpEYjN4HwSoqtuCxuVNSEcLgzdmW5TxGaHmg0fzF3AqnrrnVUhVouWKmZr95UMhTL9mRqDLlWpxA%2FR8eWnQBq3dmsOKkrLjihnpooP1tn48Smwd4YUGNsbg43LzYElQJ82gBrEl4qOEaGolgv5vYMxYwCqy5Wpxz4UPwoA%3D%3D", - "predictions_url": "https://storage.googleapis.com/humeai-batch-api-prod-results/c8e7d7a5-e74a-48dd-aa90-a788b5aad36e/4de3d8dac4704cf1aa4a281d0a90a26f/job-predictions.json?Expires=1661199613&GoogleAccessId=batch-api-subscriber%40hume-data.iam.gserviceaccount.com&Signature=mLAqbIdkEdUsyFgxaExODJsLL%2F1Qx2fj55HKzXJesXry%2Fi9zS8ue7Xr5gI6YyZAyytRI5pvP9bpTzvMhvatG7clTz8wIRvhlDFZom62sQTgg5PforhAWCou5TmZT4NcbqhhU6utd7%2B%2BIC1xDp5kOo33fqzWBUpeptoQwBuz2uxNZE44%2FIXAUSesu4v0qM%2B%2BwYLEBReIRR%2BXh9okfN%2B4PxaNe0owScZ%2BL3ZK70CLdkCNjWkn1dpol%2FGCUsfIzQhDR2GrfK9J7GIGBaRZpD3ve%2FkNhA3CjfCyH9DCuwmPg0GISz204APG3RuQ%2B7SoU6LcG6nBUyvvLPdVr7xg9X%2FyCvw%3D%3D" - } -} diff --git a/tests/batch/test_batch_job.py b/tests/batch/test_batch_job.py index 
471fa673..2babc78c 100644 --- a/tests/batch/test_batch_job.py +++ b/tests/batch/test_batch_job.py @@ -3,7 +3,7 @@ import pytest -from hume import BatchJob, BatchJobResult, BatchJobStatus +from hume import BatchJob, BatchJobInfo, BatchJobState, BatchJobStatus from hume.models import ModelType from hume.models.config import FaceConfig @@ -11,14 +11,20 @@ @pytest.fixture(scope="function") def batch_client() -> Mock: mock_client = Mock() - job_result = BatchJobResult( + job_info = BatchJobInfo( configs={ ModelType.FACE: FaceConfig(), }, - urls="mock-url", - status=BatchJobStatus.FAILED, + urls=["mock-url"], + files=["mock-file"], + state=BatchJobState( + BatchJobStatus.FAILED, + created_timestamp_ms=0, + started_timestamp_ms=1, + ended_timestamp_ms=2, + ), ) - mock_client.get_job_result = Mock(return_value=job_result) + mock_client.get_job_info = Mock(return_value=job_info) return mock_client @@ -37,10 +43,10 @@ def test_invalid_await_timeout(self, batch_client: Mock): with pytest.raises(ValueError, match=re.escape(message)): job.await_complete(timeout=0) - def test_get_result(self, batch_client: Mock): + def test_get_info(self, batch_client: Mock): job = BatchJob(batch_client, "mock-job-id") - result = job.get_result() - assert result.status == BatchJobStatus.FAILED + info = job.get_info() + assert info.state.status == BatchJobStatus.FAILED def test_get_status(self, batch_client: Mock): job = BatchJob(batch_client, "mock-job-id") @@ -49,5 +55,5 @@ def test_get_status(self, batch_client: Mock): def test_await_complete(self, batch_client: Mock): job = BatchJob(batch_client, "mock-job-id") - result = job.await_complete() - assert result.status == BatchJobStatus.FAILED + info = job.await_complete() + assert info.state.status == BatchJobStatus.FAILED diff --git a/tests/batch/test_batch_job_info.py b/tests/batch/test_batch_job_info.py new file mode 100644 index 00000000..4b1075b4 --- /dev/null +++ b/tests/batch/test_batch_job_info.py @@ -0,0 +1,63 @@ +import json +from 
pathlib import Path + +import pytest + +from hume import BatchJobInfo, BatchJobStatus + + +@pytest.fixture(scope="function") +def completed_info() -> BatchJobInfo: + response_filepath = Path(__file__).parent / "data" / "info-response-completed.json" + with response_filepath.open() as f: + response = json.load(f) + return BatchJobInfo.from_response(response) + + +@pytest.fixture(scope="function") +def queued_info() -> BatchJobInfo: + response_filepath = Path(__file__).parent / "data" / "info-response-queued.json" + with response_filepath.open() as f: + response = json.load(f) + return BatchJobInfo.from_response(response) + + +@pytest.fixture(scope="function") +def failed_info() -> BatchJobInfo: + response_filepath = Path(__file__).parent / "data" / "info-response-failed.json" + with response_filepath.open() as f: + response = json.load(f) + return BatchJobInfo.from_response(response) + + +@pytest.mark.batch +class TestBatchJobInfo: + + def test_queued_status(self, queued_info: BatchJobInfo): + assert queued_info.state.status == BatchJobStatus.QUEUED + + def test_completed(self, completed_info: BatchJobInfo): + assert completed_info.state.status == BatchJobStatus.COMPLETED + assert completed_info.configs is not None + assert completed_info.urls is not None + assert completed_info.files is not None + assert completed_info.callback_url is not None + assert completed_info.notify is not None + + def test_job_time_completed(self, completed_info: BatchJobInfo): + assert completed_info.state.get_created_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:10' + assert completed_info.state.get_started_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:12' + assert completed_info.state.get_ended_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:15' + assert completed_info.state.get_run_time_ms() == 3000 + + def test_job_time_failed(self, failed_info: BatchJobInfo): + assert failed_info.state.get_created_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 
20:20:11' + assert failed_info.state.get_started_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:14' + assert failed_info.state.get_ended_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:16' + assert failed_info.state.get_run_time_ms() == 2000 + + def test_job_time_queued(self, queued_info: BatchJobInfo): + assert queued_info.state.get_created_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:15' + assert queued_info.state.get_started_time() is None + assert queued_info.state.get_ended_time() is None + assert queued_info.state.get_run_time_ms() is None diff --git a/tests/batch/test_batch_job_result.py b/tests/batch/test_batch_job_result.py deleted file mode 100644 index eaf26926..00000000 --- a/tests/batch/test_batch_job_result.py +++ /dev/null @@ -1,77 +0,0 @@ -import json -import re -from pathlib import Path - -import pytest - -from hume import BatchJobResult, BatchJobStatus, HumeClientException - - -@pytest.fixture(scope="function") -def completed_result() -> BatchJobResult: - response_filepath = Path(__file__).parent / "data" / "result-response-completed.json" - with response_filepath.open() as f: - response = json.load(f) - return BatchJobResult.from_response(response) - - -@pytest.fixture(scope="function") -def queued_result() -> BatchJobResult: - response_filepath = Path(__file__).parent / "data" / "result-response-queued.json" - with response_filepath.open() as f: - response = json.load(f) - return BatchJobResult.from_response(response) - - -@pytest.fixture(scope="function") -def failed_result() -> BatchJobResult: - response_filepath = Path(__file__).parent / "data" / "result-response-failed.json" - with response_filepath.open() as f: - response = json.load(f) - return BatchJobResult.from_response(response) - - -@pytest.mark.batch -class TestBatchJobResult: - - def test_queued_status(self, queued_result: BatchJobResult): - assert queued_result.status == BatchJobStatus.QUEUED - - def test_queued_download_fail(self, queued_result: 
BatchJobResult): - - message = "Could not download job artifacts. No artifacts found on job result." - with pytest.raises(HumeClientException, match=re.escape(message)): - queued_result.download_artifacts("fake-path") - - message = "Could not download job errors. No errors found on job result." - with pytest.raises(HumeClientException, match=re.escape(message)): - queued_result.download_errors("fake-path") - - message = "Could not download job predictions. No predictions found on job result." - with pytest.raises(HumeClientException, match=re.escape(message)): - queued_result.download_predictions("fake-path") - - def test_completed(self, completed_result: BatchJobResult): - assert completed_result.status == BatchJobStatus.COMPLETED - assert completed_result.predictions_url is not None - assert completed_result.errors_url is not None - assert completed_result.artifacts_url is not None - - def test_failed_message(self, failed_result: BatchJobResult): - assert failed_result.status == BatchJobStatus.FAILED - assert failed_result.get_error_message() == "user 'abcde' has exceeded their usage limit" - - def test_job_time_completed(self, completed_result: BatchJobResult): - assert completed_result.get_run_time() == 3 - assert completed_result.get_start_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:10' - assert completed_result.get_end_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:13' - - def test_job_time_failed(self, failed_result: BatchJobResult): - assert failed_result.get_run_time() == 2 - assert failed_result.get_start_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-09-21 19:56:39' - assert failed_result.get_end_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-09-21 19:56:41' - - def test_job_time_queued(self, queued_result: BatchJobResult): - assert queued_result.get_run_time() is None - assert queued_result.get_start_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:10' - assert queued_result.get_end_time() is None diff --git 
a/tests/batch/test_service_hume_batch_client.py b/tests/batch/test_service_hume_batch_client.py index d86d86a9..4db82e8b 100644 --- a/tests/batch/test_service_hume_batch_client.py +++ b/tests/batch/test_service_hume_batch_client.py @@ -1,6 +1,7 @@ import json import logging import re +import zipfile from datetime import datetime from pathlib import Path from typing import Dict @@ -8,7 +9,7 @@ import pytest from pytest import TempPathFactory -from hume import BatchJob, BatchJobResult, HumeBatchClient, HumeClientException +from hume import BatchJob, BatchJobInfo, HumeBatchClient, HumeClientException from hume.models.config import BurstConfig, FaceConfig, LanguageConfig, ProsodyConfig EvalData = Dict[str, str] @@ -30,44 +31,44 @@ def test_face(self, eval_data: EvalData, batch_client: HumeBatchClient, tmp_path config = FaceConfig(fps_pred=5, prob_threshold=0.24, identify_faces=True, min_face_size=78) job = batch_client.submit_job([data_url], [config]) assert isinstance(job, BatchJob) - assert len(job.id) == 32 + assert len(job.id) == 36 logger.info(f"Running test job {job.id}") - result = job.await_complete() + info = job.await_complete() job_files_dir = tmp_path_factory.mktemp("job-files") - self.check_result(result, job_files_dir) + self.check_info(job, info, job_files_dir) def test_burst(self, eval_data: EvalData, batch_client: HumeBatchClient, tmp_path_factory: TempPathFactory): data_url = eval_data["burst-amusement-009"] config = BurstConfig() job = batch_client.submit_job([data_url], [config]) assert isinstance(job, BatchJob) - assert len(job.id) == 32 + assert len(job.id) == 36 logger.info(f"Running test job {job.id}") - result = job.await_complete() + info = job.await_complete() job_files_dir = tmp_path_factory.mktemp("job-files") - self.check_result(result, job_files_dir) + self.check_info(job, info, job_files_dir) def test_prosody(self, eval_data: EvalData, batch_client: HumeBatchClient, tmp_path_factory: TempPathFactory): data_url = 
eval_data["prosody-horror-1051"] config = ProsodyConfig(identify_speakers=True) job = batch_client.submit_job([data_url], [config]) assert isinstance(job, BatchJob) - assert len(job.id) == 32 + assert len(job.id) == 36 logger.info(f"Running test job {job.id}") - result = job.await_complete() + info = job.await_complete() job_files_dir = tmp_path_factory.mktemp("job-files") - self.check_result(result, job_files_dir) + self.check_info(job, info, job_files_dir) def test_language(self, eval_data: EvalData, batch_client: HumeBatchClient, tmp_path_factory: TempPathFactory): data_url = eval_data["text-happy-place"] config = LanguageConfig(granularity="word", identify_speakers=True) job = batch_client.submit_job([data_url], [config]) assert isinstance(job, BatchJob) - assert len(job.id) == 32 + assert len(job.id) == 36 logger.info(f"Running test job {job.id}") - result = job.await_complete() + info = job.await_complete() job_files_dir = tmp_path_factory.mktemp("job-files") - self.check_result(result, job_files_dir) + self.check_info(job, info, job_files_dir) def test_client_invalid_api_key(self, eval_data: EvalData): invalid_client = HumeBatchClient("invalid-api-key") @@ -85,21 +86,23 @@ def test_job_invalid_api_key(self, eval_data: EvalData, batch_client: HumeBatchC rehydrated_job = BatchJob(invalid_client, job.id) rehydrated_job.await_complete(10) - def check_result(self, result: BatchJobResult, job_files_dir: Path): - predictions_filepath = job_files_dir / "results.json" - result.download_predictions(predictions_filepath) - with predictions_filepath.open() as f: - json.load(f) - - artifacts_filepath = job_files_dir / "artifacts" - result.download_artifacts(artifacts_filepath) - - error_filepath = job_files_dir / "errors.json" - result.download_errors(error_filepath) + def check_info(self, job: BatchJob, info: BatchJobInfo, job_files_dir: Path): + assert isinstance(info.state.get_run_time_ms(), int) + assert isinstance(info.state.get_started_time(), datetime) + assert 
isinstance(info.state.get_ended_time(), datetime) - error_message = result.get_error_message() - assert error_message is None - - assert isinstance(result.get_run_time(), int) - assert isinstance(result.get_start_time(), datetime) - assert isinstance(result.get_end_time(), datetime) + predictions_filepath = job_files_dir / "predictions.json" + job.download_predictions(predictions_filepath) + with predictions_filepath.open() as f: + predictions = json.load(f) + assert len(predictions) == 1 + assert "results" in predictions[0] + + artifacts_filepath = job_files_dir / "artifacts.zip" + job.download_artifacts(artifacts_filepath) + logger.info(f"Artifacts for job {job.id} downloaded to {artifacts_filepath}") + + extracted_artifacts_dir = job_files_dir / "extract" + with zipfile.ZipFile(artifacts_filepath, "r") as zip_ref: + zip_ref.extractall(extracted_artifacts_dir) + assert len(list(extracted_artifacts_dir.iterdir())) == 1 diff --git a/tests/stream/test_hume_stream_client.py b/tests/stream/test_hume_stream_client.py index 32c59ae6..e80567df 100644 --- a/tests/stream/test_hume_stream_client.py +++ b/tests/stream/test_hume_stream_client.py @@ -13,7 +13,7 @@ def mock_connect(uri: str, extra_headers: Optional[Dict[str, str]] = None): assert uri == "wss://api.hume.ai/v0/stream/models" assert isinstance(extra_headers, dict) - assert extra_headers.get("X-Hume-Client-Name") == "python-sdk" + assert extra_headers.get("X-Hume-Client-Name") == "python_sdk" assert extra_headers.get("X-Hume-Api-Key") is not None assert isinstance(extra_headers.get("X-Hume-Client-Version"), str)