diff --git a/src/tiktok_api_helper/test_api_client.py b/src/tiktok_api_helper/test_api_client.py index a3b4ad7..894e225 100644 --- a/src/tiktok_api_helper/test_api_client.py +++ b/src/tiktok_api_helper/test_api_client.py @@ -3,12 +3,14 @@ from unittest.mock import Mock, call, MagicMock import json import copy +import itertools import pytest import requests import pendulum from tiktok_api_helper import api_client +from tiktok_api_helper import query FAKE_SECRETS_YAML_FILE = Path("src/tiktok_api_helper/testdata/fake_secrets.yaml") @@ -237,17 +239,16 @@ def mock_tiktok_request_client(mock_tiktok_responses): return mock_request_client -def test_tiktok_api_client(mock_tiktok_request_client, mock_tiktok_responses): - config = api_client.AcquitionConfig(query=None, start_date=pendulum.parse('20240601'), final_date=pendulum.parse('20240601'), engine=None, api_credentials_file='') +def test_tiktok_api_client_api_results_iter(mock_tiktok_request_client, mock_tiktok_responses): + config = api_client.AcquitionConfig(query=query.generate_query(include_any_hashtags='test1,test2'), start_date=pendulum.parse('20240601'), final_date=pendulum.parse('20240601'), engine=None, api_credentials_file='') client = api_client.TikTokApiClient(request_client=mock_tiktok_request_client, config=config) for i, response in enumerate(client.api_results_iter()): assert response.videos == mock_tiktok_responses[i].videos assert response.crawl.has_more == (True if i < 2 else False), f"hash_more: {response.crawl.has_more}, i: {i}" assert response.crawl.cursor == 100 * (i + 1) - assert mock_tiktok_request_client.fetch.call_count == 3 - mock_tiktok_request_client.fetch.assert_has_calls( - [call(api_client.TiktokRequest(query=config.query, + assert mock_tiktok_request_client.fetch.call_count == len(mock_tiktok_responses) + assert mock_tiktok_request_client.fetch.mock_calls == [call(api_client.TiktokRequest(query=config.query, start_date=config.start_date.strftime('%Y%m%d'), end_date=config.final_date.strftime('%Y%m%d'), max_count=100, @@ -267,4 +268,36 @@ def test_tiktok_api_client(mock_tiktok_request_client, mock_tiktok_responses): max_count=100, is_random=False, cursor=200, - search_id=mock_tiktok_responses[-1].data['search_id']))]) + search_id=mock_tiktok_responses[-1].data['search_id']))] + +def test_tiktok_api_client_fetch_all(mock_tiktok_request_client, mock_tiktok_responses): + config = api_client.AcquitionConfig(query=query.generate_query(include_any_hashtags='test1,test2'), start_date=pendulum.parse('20240601'), final_date=pendulum.parse('20240601'), engine=None, api_credentials_file='') + client = api_client.TikTokApiClient(request_client=mock_tiktok_request_client, config=config) + + response = client.fetch_all() + + assert response.videos == list(itertools.chain.from_iterable([r.videos for r in mock_tiktok_responses])) + assert response.crawl.has_more == False + assert response.crawl.cursor == 100 * len(mock_tiktok_responses) + assert mock_tiktok_request_client.fetch.call_count == len(mock_tiktok_responses) + assert mock_tiktok_request_client.fetch.mock_calls == [call(api_client.TiktokRequest(query=config.query, + start_date=config.start_date.strftime('%Y%m%d'), + end_date=config.final_date.strftime('%Y%m%d'), + max_count=100, + is_random=False, + cursor=None, + search_id=None)), + call(api_client.TiktokRequest(query=config.query, + start_date=config.start_date.strftime('%Y%m%d'), + end_date=config.final_date.strftime('%Y%m%d'), + max_count=100, + is_random=False, + cursor=100, + search_id=mock_tiktok_responses[-1].data['search_id'])), + call(api_client.TiktokRequest(query=config.query, + start_date=config.start_date.strftime('%Y%m%d'), + end_date=config.final_date.strftime('%Y%m%d'), + max_count=100, + is_random=False, + cursor=200, + search_id=mock_tiktok_responses[-1].data['search_id']))]