diff --git a/src/tiktok_research_api_helper/cli/custom_argument_types.py b/src/tiktok_research_api_helper/cli/custom_argument_types.py index fcee4d6..8ffdd6b 100644 --- a/src/tiktok_research_api_helper/cli/custom_argument_types.py +++ b/src/tiktok_research_api_helper/cli/custom_argument_types.py @@ -181,6 +181,8 @@ ), ] +VideoIdType = Annotated[int, typer.Option(help="ID of specific video to query for.")] + CrawlTagType = Annotated[ str, typer.Option( @@ -239,3 +241,18 @@ ) ), ] + +MaxDaysPerQueryType = Annotated[ + int, + typer.Option( + help=( + "Threshold for number of days between start and end dates at which a single query will " + "be split into multiple queries. Often the API gets overloaded and returns 500 (which " + "still consumes request quota) if the query returns lots of videos and the date range " + "is large. So reducing this can reduce 500 responses (and request quota consumption " + "from those) for queries that match lots of videos. IE if this is set to 3 and the " + "start and end date are 7 days apart the query will be split in 3 queries with start " + "and end dates: (start, start + 3), (start + 3, start + 6), (start + 6, start + 7)" + ) + ), +] diff --git a/src/tiktok_research_api_helper/cli/main.py b/src/tiktok_research_api_helper/cli/main.py index 678b75b..8ab0af3 100644 --- a/src/tiktok_research_api_helper/cli/main.py +++ b/src/tiktok_research_api_helper/cli/main.py @@ -41,12 +41,14 @@ IncludeAnyKeywordListType, JsonQueryFileType, MaxApiRequests, + MaxDaysPerQueryType, OnlyUsernamesListType, RawResponsesOutputDir, RegionCodeListType, StopAfterOneRequestFlag, TikTokEndDateFormat, TikTokStartDateFormat, + VideoIdType, ) from tiktok_research_api_helper.models import ( get_engine_and_create_tables, @@ -59,11 +61,13 @@ VideoQuery, VideoQueryJSONEncoder, generate_query, + generate_video_id_query, ) APP = typer.Typer(rich_markup_mode="markdown") -_DAYS_PER_ITER = 7 +_MAX_DAYS_PER_QUERY_DEFAULT = 7 +_DAYS_PER_QUERY_MAX_API_ALLOWED = 30 _DEFAULT_CREDENTIALS_FILE_PATH = Path("./secrets.yaml") @@ -77,16 +81,20 @@ def driver_single_day(client_config: ApiClientConfig, query_config: VideoQueryCo api_client.fetch_and_store_all(query_config) -def main_driver(api_client_config: ApiClientConfig, query_config: VideoQueryConfig): - days_per_iter = utils.int_to_days(_DAYS_PER_ITER) +def main_driver( + api_client_config: ApiClientConfig, + query_config: VideoQueryConfig, + max_days_per_query: int, +): + # TODO(macpd): maybe move this logic int TikTokApiClient + max_days_per_query = utils.int_to_days(max_days_per_query) start_date = copy(query_config.start_date) api_client = TikTokApiClient.from_config(api_client_config) while start_date < query_config.end_date: - # API limit is 30, we maintain 28 to be safe - local_end_date = start_date + days_per_iter + local_end_date = start_date + max_days_per_query local_end_date = min(local_end_date, query_config.end_date) local_query_config = attrs.evolve( query_config, start_date=start_date, end_date=local_end_date @@ -94,7 +102,7 @@ def main_driver(api_client_config: ApiClientConfig, query_config: VideoQueryConf api_client.fetch_and_store_all(local_query_config) - start_date += days_per_iter + start_date += max_days_per_query @APP.command() @@ -191,6 +199,7 @@ def print_query( exclude_all_keywords: ExcludeAllKeywordListType | None = None, only_from_usernames: OnlyUsernamesListType | None = None, exclude_from_usernames: ExcludeUsernamesListType | None = None, + video_id: VideoIdType | None = None, ) -> None: """Prints to stdout the query generated from flags. Useful for creating a base from which to build more complex custom JSON queries.""" @@ -206,6 +215,7 @@ def print_query( exclude_all_keywords, only_from_usernames, exclude_from_usernames, + video_id, ] ): raise typer.BadParameter( @@ -213,53 +223,73 @@ def print_query( "--include-all-hashtags, --exclude-all-hashtags, --include-any-keywords, " "--include-all-keywords, --exclude-any-keywords, --exclude-all-keywords, " "--include-any-usernames, --include-all-usernames, --exclude-any-usernames, " - "--exclude-all-usernames]" + "--exclude-all-usernames, --video-id]" ) - validate_mutually_exclusive_flags( - { - "--include-any-hashtags": include_any_hashtags, - "--include-all-hashtags": include_all_hashtags, - } - ) - validate_mutually_exclusive_flags( - { - "--exclude-any-hashtags": exclude_any_hashtags, - "--exclude-all-hashtags": exclude_all_hashtags, - } - ) - validate_mutually_exclusive_flags( - { - "--include-any-keywords": include_any_keywords, - "--include-all-keywords": include_all_keywords, - } - ) - validate_mutually_exclusive_flags( - { - "--exclude-any-keywords": exclude_any_keywords, - "--exclude-all-keywords": exclude_all_keywords, - } - ) - validate_mutually_exclusive_flags( - { - "--only-from-usernames": only_from_usernames, - "--exclude-from-usernames": exclude_from_usernames, - } - ) - validate_region_code_flag_value(region) + if video_id: + if any( + [ + include_any_hashtags, + exclude_any_hashtags, + include_all_hashtags, + exclude_all_hashtags, + include_any_keywords, + exclude_any_keywords, + include_all_keywords, + exclude_all_keywords, + only_from_usernames, + exclude_from_usernames, + ] + ): + raise typer.BadParameter( + "--video-id is mutually exclusisive with other query specification flags" + ) + query = generate_video_id_query(video_id) + else: + validate_mutually_exclusive_flags( + { + "--include-any-hashtags": include_any_hashtags, + "--include-all-hashtags": include_all_hashtags, + } + ) + validate_mutually_exclusive_flags( + { + "--exclude-any-hashtags": exclude_any_hashtags, + "--exclude-all-hashtags": exclude_all_hashtags, + } + ) + validate_mutually_exclusive_flags( + { + "--include-any-keywords": include_any_keywords, + "--include-all-keywords": include_all_keywords, + } + ) + validate_mutually_exclusive_flags( + { + "--exclude-any-keywords": exclude_any_keywords, + "--exclude-all-keywords": exclude_all_keywords, + } + ) + validate_mutually_exclusive_flags( + { + "--only-from-usernames": only_from_usernames, + "--exclude-from-usernames": exclude_from_usernames, + } + ) + validate_region_code_flag_value(region) - query = generate_query( - region_codes=region, - include_any_hashtags=include_any_hashtags, - include_all_hashtags=include_all_hashtags, - exclude_any_hashtags=exclude_any_hashtags, - exclude_all_hashtags=exclude_all_hashtags, - include_any_keywords=include_any_keywords, - include_all_keywords=include_all_keywords, - exclude_any_keywords=exclude_any_keywords, - exclude_all_keywords=exclude_all_keywords, - only_from_usernames=only_from_usernames, - exclude_from_usernames=exclude_from_usernames, - ) + query = generate_query( + region_codes=region, + include_any_hashtags=include_any_hashtags, + include_all_hashtags=include_all_hashtags, + exclude_any_hashtags=exclude_any_hashtags, + exclude_all_hashtags=exclude_all_hashtags, + include_any_keywords=include_any_keywords, + include_all_keywords=include_all_keywords, + exclude_any_keywords=exclude_any_keywords, + exclude_all_keywords=exclude_all_keywords, + only_from_usernames=only_from_usernames, + exclude_from_usernames=exclude_from_usernames, + ) print(json.dumps(query, cls=VideoQueryJSONEncoder, indent=2)) @@ -413,6 +443,7 @@ def run( db_url: DBUrlType | None = None, stop_after_one_request: StopAfterOneRequestFlag = False, max_api_requests: MaxApiRequests | None = None, + max_days_per_query: MaxDaysPerQueryType = _MAX_DAYS_PER_QUERY_DEFAULT, crawl_tag: CrawlTagType | None = None, raw_responses_output_dir: RawResponsesOutputDir | None = None, query_file_json: JsonQueryFileType | None = None, @@ -431,6 +462,7 @@ def run( exclude_all_keywords: ExcludeAllKeywordListType | None = None, only_from_usernames: OnlyUsernamesListType | None = None, exclude_from_usernames: ExcludeUsernamesListType | None = None, + video_id: VideoIdType | None = None, fetch_user_info: FetchUserInfoFlag | None = None, fetch_comments: FetchCommentsFlag | None = None, debug: EnableDebugLoggingFlag = False, @@ -456,6 +488,12 @@ def run( "only one." ) + if max_days_per_query > _DAYS_PER_QUERY_MAX_API_ALLOWED or max_days_per_query <= 0: + raise typer.BadParameter( + "--max-days-per-query must be a positive integer less than or equal to " + f"{_DAYS_PER_QUERY_MAX_API_ALLOWED}. This is a restriction of the tiktok research API." + ) + logging.log(logging.INFO, f"Arguments: {locals()}") # Using an actual datetime object instead of a string would not allows to @@ -481,6 +519,7 @@ def run( only_from_usernames, exclude_from_usernames, region, + video_id, ] ): raise typer.BadParameter( @@ -490,6 +529,26 @@ def run( ) query = get_query_file_json(query_file_json) + elif video_id: + if any( + [ + include_any_hashtags, + exclude_any_hashtags, + include_all_hashtags, + exclude_all_hashtags, + include_any_keywords, + exclude_any_keywords, + include_all_keywords, + exclude_all_keywords, + only_from_usernames, + exclude_from_usernames, + ] + ): + raise typer.BadParameter("--video_id is mutually exclusisive with other flags") + query = generate_video_id_query(video_id) + # Since query for a specific video by ID should only return 1 video, we use the max allowed + # date span for queries. + max_days_per_query = _DAYS_PER_QUERY_MAX_API_ALLOWED else: validate_mutually_exclusive_flags( { @@ -569,4 +628,4 @@ def run( driver_single_day(api_client_config, query_config) else: logging.info("Running main driver") - main_driver(api_client_config, query_config) + main_driver(api_client_config, query_config, max_days_per_query=max_days_per_query) diff --git a/src/tiktok_research_api_helper/query.py b/src/tiktok_research_api_helper/query.py index 5a61091..5a508d5 100644 --- a/src/tiktok_research_api_helper/query.py +++ b/src/tiktok_research_api_helper/query.py @@ -161,6 +161,10 @@ def default(self, o): return super().default(o) +def generate_video_id_query(video_id: int) -> VideoQuery: + return VideoQuery(**{_QUERY_AND_ARG_NAME: Cond(Fields.video_id, str(video_id), Op.EQ)}) + + def get_normalized_hashtag_set(comma_separated_hashtags: str) -> set[str]: """Takes a string of comma separated hashtag names and returns a set of hashtag names all lowercase and stripped of leading "#" if present.""" diff --git a/tests/test_query.py b/tests/test_query.py index 524acef..4451b38 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -9,6 +9,7 @@ VideoQuery, VideoQueryJSONEncoder, generate_query, + generate_video_id_query, get_normalized_hashtag_set, get_normalized_keyword_set, get_normalized_username_set, @@ -324,6 +325,20 @@ def test_normalized_username_set(test_input, expected): assert get_normalized_username_set(test_input) == expected +def test_generate_query_video_id(): + assert generate_video_id_query(1234567).as_dict() == { + "and": [ + { + "field_name": "video_id", + "field_values": [ + "1234567", + ], + "operation": "EQ", + } + ] + } + + def test_generate_query_include_any_hashtags(): assert generate_query(include_any_hashtags="this,that,other").as_dict() == { "and": [