Skip to content

Commit

Permalink
Fix mypy and pylint issues in tests (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorybchris authored Feb 26, 2024
1 parent 3234bf8 commit f3f07fd
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 97 deletions.
19 changes: 9 additions & 10 deletions tests/batch/test_batch_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

import pytest

from hume import BatchJob, BatchJobDetails, BatchJobState, BatchJobStatus
from hume import BatchJob, BatchJobDetails, BatchJobState, BatchJobStatus, HumeClientException
from hume.models import ModelType
from hume.models.config import FaceConfig
from hume import HumeClientException


@pytest.fixture(scope="function")
def batch_client() -> Mock:
@pytest.fixture(name="batch_client", scope="function")
def batch_client_fixture() -> Mock:
mock_client = Mock()
job_details = BatchJobDetails(
configs={
Expand All @@ -32,34 +31,34 @@ def batch_client() -> Mock:
@pytest.mark.batch
class TestBatchJob:

def test_job_id(self, batch_client: Mock):
def test_job_id(self, batch_client: Mock) -> None:
mock_job_id = "mock-job-id"
job = BatchJob(batch_client, mock_job_id)
assert job.id == mock_job_id

def test_invalid_await_timeout(self, batch_client: Mock):
def test_invalid_await_timeout(self, batch_client: Mock) -> None:
job = BatchJob(batch_client, "mock-job-id")

message = "timeout must be at least 1 second"
with pytest.raises(ValueError, match=re.escape(message)):
job.await_complete(timeout=0)

def test_get_details(self, batch_client: Mock):
def test_get_details(self, batch_client: Mock) -> None:
job = BatchJob(batch_client, "mock-job-id")
details = job.get_details()
assert details.state.status == BatchJobStatus.FAILED

def test_get_status(self, batch_client: Mock):
def test_get_status(self, batch_client: Mock) -> None:
job = BatchJob(batch_client, "mock-job-id")
status = job.get_status()
assert status == BatchJobStatus.FAILED

def test_await_complete(self, batch_client: Mock):
def test_await_complete(self, batch_client: Mock) -> None:
job = BatchJob(batch_client, "mock-job-id")
details = job.await_complete()
assert details.state.status == BatchJobStatus.FAILED

def test_raise_on_failed(self, batch_client: Mock):
def test_raise_on_failed(self, batch_client: Mock) -> None:
job = BatchJob(batch_client, "mock-job-id")
message = "BatchJob mock-job-id failed."
with pytest.raises(HumeClientException, match=re.escape(message)):
Expand Down
42 changes: 24 additions & 18 deletions tests/batch/test_batch_job_details.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Optional

import pytest

from hume import BatchJobDetails, BatchJobStatus


@pytest.fixture(scope="function")
def completed_details() -> BatchJobDetails:
@pytest.fixture(name="completed_details", scope="function")
def completed_details_fixture() -> BatchJobDetails:
response_filepath = Path(__file__).parent / "data" / "details-response-completed.json"
with response_filepath.open() as f:
response = json.load(f)
return BatchJobDetails.from_response(response)


@pytest.fixture(scope="function")
def queued_details() -> BatchJobDetails:
@pytest.fixture(name="queued_details", scope="function")
def queued_details_fixture() -> BatchJobDetails:
response_filepath = Path(__file__).parent / "data" / "details-response-queued.json"
with response_filepath.open() as f:
response = json.load(f)
return BatchJobDetails.from_response(response)


@pytest.fixture(scope="function")
def failed_details() -> BatchJobDetails:
@pytest.fixture(name="failed_details", scope="function")
def failed_details_fixture() -> BatchJobDetails:
response_filepath = Path(__file__).parent / "data" / "details-response-failed.json"
with response_filepath.open() as f:
response = json.load(f)
Expand All @@ -33,31 +35,35 @@ def failed_details() -> BatchJobDetails:
@pytest.mark.batch
class TestBatchJobDetails:

def test_queued_status(self, queued_details: BatchJobDetails):
def test_queued_status(self, queued_details: BatchJobDetails) -> None:
assert queued_details.get_status() == BatchJobStatus.QUEUED

def test_completed(self, completed_details: BatchJobDetails):
def test_completed(self, completed_details: BatchJobDetails) -> None:
assert completed_details.get_status() == BatchJobStatus.COMPLETED
assert completed_details.configs is not None
assert completed_details.urls is not None
assert completed_details.files is not None
assert completed_details.callback_url is not None
assert completed_details.notify is not None

def test_job_time_completed(self, completed_details: BatchJobDetails):
assert completed_details.get_created_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:10'
assert completed_details.get_started_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:12'
assert completed_details.get_ended_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:15'
def test_job_time_completed(self, completed_details: BatchJobDetails) -> None:
self.check_time(completed_details.get_created_time(), "2022-08-15 20:20:10")
self.check_time(completed_details.get_started_time(), "2022-08-15 20:20:12")
self.check_time(completed_details.get_ended_time(), "2022-08-15 20:20:15")
assert completed_details.get_run_time_ms() == 3000

def test_job_time_failed(self, failed_details: BatchJobDetails):
assert failed_details.get_created_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:11'
assert failed_details.get_started_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:14'
assert failed_details.get_ended_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:16'
def test_job_time_failed(self, failed_details: BatchJobDetails) -> None:
self.check_time(failed_details.get_created_time(), "2022-08-15 20:20:11")
self.check_time(failed_details.get_started_time(), "2022-08-15 20:20:14")
self.check_time(failed_details.get_ended_time(), "2022-08-15 20:20:16")
assert failed_details.get_run_time_ms() == 2000

def test_job_time_queued(self, queued_details: BatchJobDetails):
assert queued_details.get_created_time().strftime('%Y-%m-%d %H:%M:%S') == '2022-08-15 20:20:15'
def test_job_time_queued(self, queued_details: BatchJobDetails) -> None:
self.check_time(queued_details.get_created_time(), "2022-08-15 20:20:15")
assert queued_details.get_started_time() is None
assert queued_details.get_ended_time() is None
assert queued_details.get_run_time_ms() is None

def check_time(self, date: Optional[datetime], formatted_date: str) -> None:
assert date is not None
assert date.strftime("%Y-%m-%d %H:%M:%S") == formatted_date
2 changes: 1 addition & 1 deletion tests/batch/test_batch_job_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
@pytest.mark.batch
class TestBatchJobState:

def test_create(self):
def test_create(self) -> None:
state = BatchJobState(
status=BatchJobStatus.COMPLETED,
created_timestamp_ms=1,
Expand Down
40 changes: 23 additions & 17 deletions tests/batch/test_batch_job_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,40 @@
@pytest.mark.batch
class TestBatchJobStatus:

def test_update(self):
def test_update(self) -> None:
# Note: If another status is added to the enum make sure to update parametrized tests below:
# - test_continuity
# - test_is_terminal
assert len(BatchJobStatus) == 4

@pytest.mark.parametrize("status_str", [
"COMPLETED",
"FAILED",
"IN_PROGRESS",
"QUEUED",
])
def test_continuity(self, status_str: str):
@pytest.mark.parametrize(
"status_str",
[
"COMPLETED",
"FAILED",
"IN_PROGRESS",
"QUEUED",
],
)
def test_continuity(self, status_str: str) -> None:
assert BatchJobStatus[status_str].value == status_str

def test_from_str(self):
def test_from_str(self) -> None:
assert BatchJobStatus.from_str("COMPLETED") == BatchJobStatus.COMPLETED

def test_from_str_fail(self):
def test_from_str_fail(self) -> None:
message = "Unknown status 'COMPLETE'"
with pytest.raises(ValueError, match=re.escape(message)):
BatchJobStatus.from_str("COMPLETE")

@pytest.mark.parametrize("status, is_terminal", [
(BatchJobStatus.COMPLETED, True),
(BatchJobStatus.FAILED, True),
(BatchJobStatus.IN_PROGRESS, False),
(BatchJobStatus.QUEUED, False),
])
def test_is_terminal(self, status: BatchJobStatus, is_terminal: bool):
@pytest.mark.parametrize(
"status, is_terminal",
[
(BatchJobStatus.COMPLETED, True),
(BatchJobStatus.FAILED, True),
(BatchJobStatus.IN_PROGRESS, False),
(BatchJobStatus.QUEUED, False),
],
)
def test_is_terminal(self, status: BatchJobStatus, is_terminal: bool) -> None:
assert BatchJobStatus.is_terminal(status) == is_terminal
29 changes: 15 additions & 14 deletions tests/batch/test_hume_batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from hume.models.config import BurstConfig, FaceConfig, LanguageConfig, ProsodyConfig


@pytest.fixture(scope="function")
def batch_client(monkeypatch: MonkeyPatch) -> HumeBatchClient:
@pytest.fixture(name="batch_client", scope="function")
def batch_client_fixture(monkeypatch: MonkeyPatch) -> HumeBatchClient:
mock_submit_request = MagicMock(return_value="temp-job-value")
monkeypatch.setattr(HumeBatchClient, "_submit_job", mock_submit_request)
client = HumeBatchClient("0000-0000-0000-0000", timeout=15)
Expand All @@ -19,7 +19,7 @@ def batch_client(monkeypatch: MonkeyPatch) -> HumeBatchClient:
@pytest.mark.batch
class TestHumeBatchClient:

def test_face(self, batch_client: HumeBatchClient):
def test_face(self, batch_client: HumeBatchClient) -> None:
mock_url = "mock-url"
config = FaceConfig(fps_pred=5, prob_threshold=0.24, identify_faces=True, min_face_size=78)
job = batch_client.submit_job([mock_url], [config])
Expand All @@ -40,7 +40,7 @@ def test_face(self, batch_client: HumeBatchClient):
None,
)

def test_burst(self, batch_client: HumeBatchClient):
def test_burst(self, batch_client: HumeBatchClient) -> None:
mock_url = "mock-url"
config = BurstConfig()
job = batch_client.submit_job([mock_url], [config])
Expand All @@ -56,7 +56,7 @@ def test_burst(self, batch_client: HumeBatchClient):
None,
)

def test_prosody(self, batch_client: HumeBatchClient):
def test_prosody(self, batch_client: HumeBatchClient) -> None:
mock_url = "mock-url"
config = ProsodyConfig(identify_speakers=True)
job = batch_client.submit_job([mock_url], [config])
Expand All @@ -74,7 +74,7 @@ def test_prosody(self, batch_client: HumeBatchClient):
None,
)

def test_language(self, batch_client: HumeBatchClient):
def test_language(self, batch_client: HumeBatchClient) -> None:
mock_url = "mock-url"
config = LanguageConfig(granularity="word", identify_speakers=True)
job = batch_client.submit_job([mock_url], [config])
Expand All @@ -93,7 +93,7 @@ def test_language(self, batch_client: HumeBatchClient):
None,
)

def test_language_with_raw_text(self, batch_client: HumeBatchClient):
def test_language_with_raw_text(self, batch_client: HumeBatchClient) -> None:
mock_text = "Test!"
config = LanguageConfig(granularity="word", identify_speakers=True)
job = batch_client.submit_job([], [config], text=[mock_text])
Expand All @@ -113,11 +113,11 @@ def test_language_with_raw_text(self, batch_client: HumeBatchClient):
None,
)

def test_get_job(self, batch_client: HumeBatchClient):
def test_get_job(self, batch_client: HumeBatchClient) -> None:
job = batch_client.get_job("mock-job-id")
assert job.id == "mock-job-id"

def test_files(self, batch_client: HumeBatchClient):
def test_files(self, batch_client: HumeBatchClient) -> None:
mock_filepath = "my-audio.mp3"
config = ProsodyConfig(identify_speakers=True)
job = batch_client.submit_job(None, [config], files=[mock_filepath])
Expand All @@ -130,22 +130,23 @@ def test_files(self, batch_client: HumeBatchClient):
"prosody": {
"identify_speakers": True,
},
}
},
},
['my-audio.mp3'],
["my-audio.mp3"],
)

def test_get_multipart_form_data(self, batch_client: HumeBatchClient, tmp_path_factory: TempPathFactory):
def test_get_multipart_form_data(self, batch_client: HumeBatchClient, tmp_path_factory: TempPathFactory) -> None:
dirpath = tmp_path_factory.mktemp("multipart")
filepath = dirpath / "my-audio.mp3"
with filepath.open("w") as f:
f.write("I can't believe this test passed!")

request_body = {"mock": "body"}
filepaths = [filepath]
# pylint: disable=protected-access
result = batch_client._get_multipart_form_data(request_body, filepaths)

assert result == [
('file', ('my-audio.mp3', b"I can't believe this test passed!")),
('json', b'{"mock": "body"}'),
("file", ("my-audio.mp3", b"I can't believe this test passed!")),
("json", b'{"mock": "body"}'),
]
Loading

0 comments on commit f3f07fd

Please sign in to comment.