diff --git a/src/tiktok_research_api_helper/api_client.py b/src/tiktok_research_api_helper/api_client.py
index 5affbd7..b0f16e4 100644
--- a/src/tiktok_research_api_helper/api_client.py
+++ b/src/tiktok_research_api_helper/api_client.py
@@ -28,6 +28,8 @@
 INVALID_SEARCH_ID_ERROR_RETRY_WAIT = 5
 INVALID_SEARCH_ID_ERROR_MAX_NUM_RETRIES = 5
 
+DAILY_API_REQUEST_QUOTA = 1000
+
 
 class ApiRateLimitError(Exception):
     pass
@@ -550,7 +552,7 @@ def num_api_requests_sent(self):
 
     @property
     def expected_remaining_api_request_quota(self):
-        return 1000 - self.num_api_requests_sent
+        return DAILY_API_REQUEST_QUOTA - self.num_api_requests_sent
 
     def api_results_iter(self) -> TikTokApiClientFetchResult:
         """Fetches all results from API (ie requests until API indicates query results have been
@@ -602,8 +604,9 @@ def api_results_iter(self) -> TikTokApiClientFetchResult:
                 break
 
         logging.info(
-            "Crawl completed (or reached configured max_requests). Num api requests: %s. Expected "
+            "Crawl completed (or reached configured max_requests: %s). Num api requests: %s. Expected "
             "remaining API request quota: %s",
+            self._config.max_requests,
            self.num_api_requests_sent,
            self.expected_remaining_api_request_quota,
        )
@@ -614,14 +617,14 @@ def _max_requests_reached(self) -> bool:
     def _should_continue(self, crawl: Crawl) -> bool:
         should_continue = crawl.has_more and not self._max_requests_reached()
         logging.debug(
-            "crawl.has_more: %s, max_requests_reached: %s, shoudld_continue: %s",
+            "crawl.has_more: %s, max_requests_reached: %s, should_continue: %s",
             crawl.has_more,
             self._max_requests_reached(),
             should_continue,
         )
         if crawl.has_more and self._max_requests_reached():
             logging.info(
-                "Max requests rewached. Will discontinue this crawl even though API response "
+                "Max requests reached. Will discontinue this crawl even though API response "
                 "indicates more results."
             )
         return should_continue
diff --git a/src/tiktok_research_api_helper/cli_data_acquisition.py b/src/tiktok_research_api_helper/cli_data_acquisition.py
index 897f432..2a66e1a 100644
--- a/src/tiktok_research_api_helper/cli_data_acquisition.py
+++ b/src/tiktok_research_api_helper/cli_data_acquisition.py
@@ -7,6 +7,7 @@
 from pathlib import Path
 from typing import Annotated, Any
 
+import attrs
 import pause
 import pendulum
 import typer
@@ -14,6 +15,7 @@
 
 from tiktok_research_api_helper import region_codes, utils
 from tiktok_research_api_helper.api_client import (
+    DAILY_API_REQUEST_QUOTA,
     ApiClientConfig,
     ApiRateLimitWaitStrategy,
     TikTokApiClient,
@@ -66,55 +68,87 @@
 CrawlDateWindow = namedtuple("CrawlDateWindow", ["start_date", "end_date"])
 
 
-def run_long_query(config: ApiClientConfig, spider_top_n_music_ids: int | None = None):
+def run_long_query(config: ApiClientConfig) -> dict:
     """Runs a "long" query, defined as one that may need multiple requests to get all the data.
     Unless you have a good reason to believe otherwise, queries should default to be considered
-    "long"."""
+    "long".
+
+    Returns:
+        dict: number of API requests sent (ie likely amount of API quota consumed) and crawl ID.
+    """
     api_client = TikTokApiClient.from_config(config)
     fetch_results = api_client.fetch_and_store_all()
+    # TODO(macpd): fix this return value structure. Maybe a namedtuple.
+    return {
+        "num_api_requests_sent": api_client.num_api_requests_sent,
+        "crawl_id": fetch_results.crawl.id,
+    }
 
-    if spider_top_n_music_ids is None:
-        return
+
+def driver_single_day(config: ApiClientConfig, spider_top_n_music_ids: int | None = None) -> dict:
+    """Simpler driver for a single day of query.
+
+    Returns:
+        dict: number of API requests sent (ie likely amount of API quota consumed) and crawl ID.
+    """
+    assert (
+        config.start_date == config.end_date
+    ), "Start and final date must be the same for single day driver"
+
+    ret = run_long_query(config)
+    if spider_top_n_music_ids is not None:
+        run_spider_top_n_music_ids_query(
+            config=config,
+            spider_top_n_music_ids=spider_top_n_music_ids,
+            crawl_ids=[ret["crawl_id"]],
+            expected_remaining_api_request_quota=(
+                DAILY_API_REQUEST_QUOTA - ret["num_api_requests_sent"]
+            ),
+        )
+    return ret
+
 
-    potentially_remaining_qutoa = api_client.expected_remaining_api_request_quota
-    if potentially_remaining_qutoa <= 0:
-        # TODO(macpd): if crawl takes more than one day, and thus more than 1 day of allowed
-        # requests have been made, this check will be wrong.
-        logging.info("Refusing to spider top music IDs because no API quota remaints")
-        return
+def run_spider_top_n_music_ids_query(
+    config: ApiClientConfig,
+    spider_top_n_music_ids: int,
+    crawl_ids: list[int],
+    expected_remaining_api_request_quota: int,
+) -> int:
+    if expected_remaining_api_request_quota <= 0:
+        # TODO(macpd): add some way to bypass this.
+        logging.info("Refusing to spider top music IDs because no API quota remains")
+        return 0
 
-    config.max_requests = potentially_remaining_qutoa
-
-    crawl_id = fetch_results.crawl.id
     with Session(config.engine) as session:
         top_music_ids = most_used_music_ids(
             session,
             limit=None if spider_top_n_music_ids == 0 else spider_top_n_music_ids,
-            crawl_id=crawl_id,
+            crawl_ids=crawl_ids,
         )
 
     new_query = generate_query(
         include_music_ids=",".join([str(x["music_id"]) for x in top_music_ids])
     )
 
-    config.query = new_query
+    new_crawl_tags = None
     if config.crawl_tags:
-        config.crawl_tags = [f"{tag}-music-id-spidering" for tag in config.crawl_tags]
-
-    api_client = TikTokApiClient.from_config(config)
-    fetch_results = api_client.fetch_and_store_all()
-
+        new_crawl_tags = [f"{tag}-music-id-spidering" for tag in config.crawl_tags]
 
-def driver_single_day(config: ApiClientConfig, spider_top_n_music_ids):
-    """Simpler driver for a single day of query"""
-    assert (
-        config.start_date == config.end_date
-    ), "Start and final date must be the same for single day driver"
+    new_config = attrs.evolve(
+        config,
+        max_requests=expected_remaining_api_request_quota,
+        query=new_query,
+        crawl_tags=new_crawl_tags,
+    )
 
-    run_long_query(config, spider_top_n_music_ids)
+    api_client = TikTokApiClient.from_config(new_config)
+    api_client.fetch_and_store_all()
+    return api_client.num_api_requests_sent
 
 
 def main_driver(config: ApiClientConfig, spider_top_n_music_ids: int | None = None):
+    num_api_requests_sent = 0
+    crawl_ids = []
     days_per_iter = utils.int_to_days(_DAYS_PER_ITER)
 
     start_date = copy(config.start_date)
@@ -124,24 +158,32 @@
         local_end_date = start_date + days_per_iter
         local_end_date = min(local_end_date, config.end_date)
 
-        new_config = ApiClientConfig(
-            query=config.query,
-            start_date=start_date,
-            end_date=local_end_date,
-            engine=config.engine,
-            stop_after_one_request=config.stop_after_one_request,
-            crawl_tags=config.crawl_tags,
-            raw_responses_output_dir=config.raw_responses_output_dir,
-            api_credentials_file=config.api_credentials_file,
-            api_rate_limit_wait_strategy=config.api_rate_limit_wait_strategy,
-        )
-        run_long_query(new_config, spider_top_n_music_ids=spider_top_n_music_ids)
+        new_config = attrs.evolve(config, start_date=start_date, end_date=local_end_date)
+        ret = run_long_query(new_config)
+        num_api_requests_sent += ret["num_api_requests_sent"]
+        crawl_ids.append(ret["crawl_id"])
 
         start_date += days_per_iter
 
         if config.stop_after_one_request:
             logging.log(logging.WARN, "Stopping after one request")
-            break
+            return
+
+    if spider_top_n_music_ids:
+        if num_api_requests_sent == DAILY_API_REQUEST_QUOTA:
+            # TODO(macpd): handle no remaining quota, perhaps add a flag to spider anyway?
+            logging.warning("Refusing to spider top music IDs because no API quota remains")
+            return
+
+        expected_remaining_api_request_quota = DAILY_API_REQUEST_QUOTA - (
+            num_api_requests_sent % DAILY_API_REQUEST_QUOTA
+        )
+        run_spider_top_n_music_ids_query(
+            config=config,
+            crawl_ids=crawl_ids,
+            spider_top_n_music_ids=spider_top_n_music_ids,
+            expected_remaining_api_request_quota=expected_remaining_api_request_quota,
+        )
 
 
 @APP.command()
diff --git a/src/tiktok_research_api_helper/models.py b/src/tiktok_research_api_helper/models.py
index b123c6f..64c51dd 100644
--- a/src/tiktok_research_api_helper/models.py
+++ b/src/tiktok_research_api_helper/models.py
@@ -203,15 +203,17 @@ def effect_ids(self):
         return {effect.effect_id for effect in self.effects}
 
 
-def most_used_music_ids(session: Session, limit: int | None = None, crawl_id: int | None = None):
-    """Returns dict of most used music_ids with count of video id with that music_id. If crawl_id
-    specified, only operates on videos associated to that video id.
+def most_used_music_ids(
+    session: Session, limit: int | None = None, crawl_ids: list[int] | None = None
+):
+    """Returns list of dicts of the most used music_ids, each with a count of videos using that
+    music_id. If crawl_ids specified, only operates on videos associated with those crawl IDs.
""" - if crawl_id: + if crawl_ids: select_stmt = ( select(Video.music_id, func.count(Video.id).label("num_videos")) .join(Video.crawls) - .where(Crawl.id == crawl_id) + .where(Crawl.id.in_(crawl_ids)) ) else: select_stmt = select(Video.music_id, func.count(Video.id).label("num_videos")) diff --git a/tests/test_sql.py b/tests/test_sql.py index 41d4420..fa02e73 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -680,7 +680,7 @@ def test_most_used_music_ids(test_database_engine, mock_crawl, api_response_vide key=lambda x: x["music_id"], ) assert sorted( - most_used_music_ids(session, limit=None, crawl_id=mock_crawl_id), + most_used_music_ids(session, limit=None, crawl_ids=[mock_crawl_id]), key=lambda x: x["music_id"], ) == sorted( [{"music_id": x["music_id"], "num_videos": 1} for x in api_response_videos], @@ -698,11 +698,11 @@ def test_most_used_music_ids(test_database_engine, mock_crawl, api_response_vide {"music_id": new_music_id, "num_videos": len(api_response_videos) + 1}, {"music_id": 6817429116187314177, "num_videos": 1}, ] - assert most_used_music_ids(session, limit=2, crawl_id=mock_crawl_id) == [ + assert most_used_music_ids(session, limit=2, crawl_ids=[mock_crawl_id]) == [ {"music_id": new_music_id, "num_videos": len(api_response_videos) + 1}, {"music_id": 6817429116187314177, "num_videos": 1}, ] - assert most_used_music_ids(session, limit=2, crawl_id=mock_crawl_id + 1) == [] + assert most_used_music_ids(session, limit=2, crawl_ids=[mock_crawl_id + 1]) == [] def test_remove_all(test_database_engine, mock_videos, mock_crawl):