From 9bc423a61418c8ec20a70006fdcf0f7bc377c6ad Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 19 Aug 2024 19:35:37 -0400 Subject: [PATCH] make --query-file-json a repeatedable flag so that mutliple query files can be passed and run serially. remove single_day_driver since main_driver can do that. (#79) --- src/tiktok_research_api_helper/api_client.py | 2 +- .../cli/custom_argument_types.py | 8 +- src/tiktok_research_api_helper/cli/main.py | 85 ++++++++----------- 3 files changed, 41 insertions(+), 54 deletions(-) diff --git a/src/tiktok_research_api_helper/api_client.py b/src/tiktok_research_api_helper/api_client.py index 92a0654..628843c 100644 --- a/src/tiktok_research_api_helper/api_client.py +++ b/src/tiktok_research_api_helper/api_client.py @@ -779,7 +779,7 @@ def api_results_iter(self, query_config: VideoQueryConfig) -> TikTokApiClientFet ) logging.debug("Crawl: %s", crawl) - logging.info("Beginning API results fetch.") + logging.info("Beginning API results fetch VideoQueryConfig: %s.", query_config) while crawl.has_more: request = TikTokVideoRequest.from_config( config=query_config, diff --git a/src/tiktok_research_api_helper/cli/custom_argument_types.py b/src/tiktok_research_api_helper/cli/custom_argument_types.py index dcf83c0..b1a40cd 100644 --- a/src/tiktok_research_api_helper/cli/custom_argument_types.py +++ b/src/tiktok_research_api_helper/cli/custom_argument_types.py @@ -32,16 +32,16 @@ DBUrlType = Annotated[Optional[str], typer.Option(help="database URL for storing API results")] -JsonQueryFileType = Annotated[ - Optional[Path], +JsonQueryFileListType = Annotated[ + Optional[list[Path]], typer.Option( + "--query-file-json", exists=True, file_okay=True, dir_okay=False, help=( "Path to file query as JSON. File contents will be parsed as JSON and used directly " - "in query of API requests. Can be used multiple times to run multiple queries (in " - "serial)" + "in query of API requests. Can be used multiple times to run multiple queries serially." ), ), ] diff --git a/src/tiktok_research_api_helper/cli/main.py b/src/tiktok_research_api_helper/cli/main.py index a223d6d..e684726 100644 --- a/src/tiktok_research_api_helper/cli/main.py +++ b/src/tiktok_research_api_helper/cli/main.py @@ -39,7 +39,7 @@ IncludeAllKeywordListType, IncludeAnyHashtagListType, IncludeAnyKeywordListType, - JsonQueryFileType, + JsonQueryFileListType, MaxApiRequests, MaxDaysPerQueryType, OnlyUsernamesListType, @@ -71,16 +71,6 @@ _DEFAULT_CREDENTIALS_FILE_PATH = Path("./secrets.yaml") -def driver_single_day(client_config: ApiClientConfig, query_config: VideoQueryConfig): - """Simpler driver for a single day of query""" - assert ( - query_config.start_date == query_config.end_date - ), "Start and final date must be the same for single day driver" - - api_client = TikTokApiClient.from_config(client_config) - api_client.fetch_and_store_all(query_config) - - def main_driver( api_client_config: ApiClientConfig, query_config: VideoQueryConfig, @@ -93,7 +83,7 @@ def main_driver( api_client = TikTokApiClient.from_config(api_client_config) - while start_date < query_config.end_date: + while start_date <= query_config.end_date: local_end_date = start_date + max_days_per_query local_end_date = min(local_end_date, query_config.end_date) local_query_config = attrs.evolve( @@ -311,7 +301,7 @@ def run_repeated( db_url: DBUrlType = None, crawl_tag: CrawlTagType = None, raw_responses_output_dir: RawResponsesOutputDir = None, - query_file_json: JsonQueryFileType = None, + query_file_json_list: JsonQueryFileListType = None, api_credentials_file: ApiCredentialsFileType = _DEFAULT_CREDENTIALS_FILE_PATH, rate_limit_wait_strategy: ApiRateLimitWaitStrategyType = ( ApiRateLimitWaitStrategy.WAIT_FOUR_HOURS @@ -356,7 +346,7 @@ def run_repeated( db_url=db_url, crawl_tag=crawl_tag, raw_responses_output_dir=raw_responses_output_dir, - query_file_json=query_file_json, + query_file_json_list=query_file_json_list, api_credentials_file=api_credentials_file, rate_limit_wait_strategy=rate_limit_wait_strategy, region=region, @@ -446,7 +436,7 @@ def run( max_days_per_query: MaxDaysPerQueryType = _MAX_DAYS_PER_QUERY_DEFAULT, crawl_tag: CrawlTagType = None, raw_responses_output_dir: RawResponsesOutputDir = None, - query_file_json: JsonQueryFileType = None, + query_file_json_list: JsonQueryFileListType = None, api_credentials_file: ApiCredentialsFileType = _DEFAULT_CREDENTIALS_FILE_PATH, rate_limit_wait_strategy: ApiRateLimitWaitStrategyType = ( ApiRateLimitWaitStrategy.WAIT_FOUR_HOURS @@ -505,7 +495,7 @@ def run( {"--db-url": db_url, "--db-file": db_file}, at_least_one_required=True ) - if query_file_json: + if query_file_json_list: if any( [ include_any_hashtags, @@ -528,7 +518,7 @@ def run( "--include-any-keywords, etc" ) - query = get_query_file_json(query_file_json) + query_list = [get_query_file_json(query_file) for query_file in query_file_json_list] elif video_id: if any( [ @@ -545,7 +535,7 @@ def run( ] ): raise typer.BadParameter("--video_id is mutually exclusisive with other flags") - query = generate_video_id_query(video_id) + query_list = [generate_video_id_query(video_id)] # Since query for a specific video by ID should only return 1 video, we use the max allowed # date span for queries. max_days_per_query = _DAYS_PER_QUERY_MAX_API_ALLOWED @@ -583,21 +573,21 @@ def run( validate_region_code_flag_value(region) - query = generate_query( - region_codes=region, - include_any_hashtags=include_any_hashtags, - include_all_hashtags=include_all_hashtags, - exclude_any_hashtags=exclude_any_hashtags, - exclude_all_hashtags=exclude_all_hashtags, - include_any_keywords=include_any_keywords, - include_all_keywords=include_all_keywords, - exclude_any_keywords=exclude_any_keywords, - exclude_all_keywords=exclude_all_keywords, - only_from_usernames=only_from_usernames, - exclude_from_usernames=exclude_from_usernames, - ) - - logging.log(logging.INFO, f"VideoQuery: {query}") + query_list = [ + generate_query( + region_codes=region, + include_any_hashtags=include_any_hashtags, + include_all_hashtags=include_all_hashtags, + exclude_any_hashtags=exclude_any_hashtags, + exclude_all_hashtags=exclude_all_hashtags, + include_any_keywords=include_any_keywords, + include_all_keywords=include_all_keywords, + exclude_any_keywords=exclude_any_keywords, + exclude_all_keywords=exclude_all_keywords, + only_from_usernames=only_from_usernames, + exclude_from_usernames=exclude_from_usernames, + ) + ] if db_url: engine = get_engine_and_create_tables(db_url) @@ -611,21 +601,18 @@ def run( api_credentials_file=api_credentials_file, api_rate_limit_wait_strategy=rate_limit_wait_strategy, ) - query_config = VideoQueryConfig( - query=query, - start_date=start_date_datetime, - end_date=end_date_datetime, - crawl_tags=[crawl_tag] if crawl_tag else None, - fetch_user_info=fetch_user_info, - fetch_comments=fetch_comments, - ) - logging.info("API client config: %s\nVideo query config: %s", api_client_config, query_config) - - if query_config.start_date == query_config.end_date: - logging.info( - "Start and final date are the same - running single day driver", + query_configs = [ + VideoQueryConfig( + query=query, + start_date=start_date_datetime, + end_date=end_date_datetime, + crawl_tags=[crawl_tag] if crawl_tag else None, + fetch_user_info=fetch_user_info, + fetch_comments=fetch_comments, ) - driver_single_day(api_client_config, query_config) - else: - logging.info("Running main driver") + for query in query_list + ] + logging.info("API client config: %s\nVideo query configs: %s", api_client_config, query_configs) + + for query_config in query_configs: main_driver(api_client_config, query_config, max_days_per_query=max_days_per_query)