Skip to content

Commit

Permalink
make --query-file-json a repeatedable flag so that mutliple query fil…
Browse files Browse the repository at this point in the history
…es can be passed and run serially. remove single_day_driver since main_driver can do that.
  • Loading branch information
macpd committed Aug 19, 2024
1 parent cacd6e3 commit 146ebd7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 54 deletions.
2 changes: 1 addition & 1 deletion src/tiktok_research_api_helper/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def api_results_iter(self, query_config: VideoQueryConfig) -> TikTokApiClientFet
)
logging.debug("Crawl: %s", crawl)

logging.info("Beginning API results fetch.")
logging.info("Beginning API results fetch VideoQueryConfig: %s.", query_config)
while crawl.has_more:
request = TikTokVideoRequest.from_config(
config=query_config,
Expand Down
8 changes: 4 additions & 4 deletions src/tiktok_research_api_helper/cli/custom_argument_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@

DBUrlType = Annotated[Optional[str], typer.Option(help="database URL for storing API results")]

JsonQueryFileType = Annotated[
Optional[Path],
JsonQueryFileListType = Annotated[
Optional[list[Path]],
typer.Option(
"--query-file-json",
exists=True,
file_okay=True,
dir_okay=False,
help=(
"Path to file query as JSON. File contents will be parsed as JSON and used directly "
"in query of API requests. Can be used multiple times to run multiple queries (in "
"serial)"
"in query of API requests. Can be used multiple times to run multiple queries serially."
),
),
]
Expand Down
85 changes: 36 additions & 49 deletions src/tiktok_research_api_helper/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
IncludeAllKeywordListType,
IncludeAnyHashtagListType,
IncludeAnyKeywordListType,
JsonQueryFileType,
JsonQueryFileListType,
MaxApiRequests,
MaxDaysPerQueryType,
OnlyUsernamesListType,
Expand Down Expand Up @@ -71,16 +71,6 @@
_DEFAULT_CREDENTIALS_FILE_PATH = Path("./secrets.yaml")


def driver_single_day(client_config: ApiClientConfig, query_config: VideoQueryConfig):
"""Simpler driver for a single day of query"""
assert (
query_config.start_date == query_config.end_date
), "Start and final date must be the same for single day driver"

api_client = TikTokApiClient.from_config(client_config)
api_client.fetch_and_store_all(query_config)


def main_driver(
api_client_config: ApiClientConfig,
query_config: VideoQueryConfig,
Expand All @@ -93,7 +83,7 @@ def main_driver(

api_client = TikTokApiClient.from_config(api_client_config)

while start_date < query_config.end_date:
while start_date <= query_config.end_date:
local_end_date = start_date + max_days_per_query
local_end_date = min(local_end_date, query_config.end_date)
local_query_config = attrs.evolve(
Expand Down Expand Up @@ -311,7 +301,7 @@ def run_repeated(
db_url: DBUrlType = None,
crawl_tag: CrawlTagType = None,
raw_responses_output_dir: RawResponsesOutputDir = None,
query_file_json: JsonQueryFileType = None,
query_file_json_list: JsonQueryFileListType = None,
api_credentials_file: ApiCredentialsFileType = _DEFAULT_CREDENTIALS_FILE_PATH,
rate_limit_wait_strategy: ApiRateLimitWaitStrategyType = (
ApiRateLimitWaitStrategy.WAIT_FOUR_HOURS
Expand Down Expand Up @@ -356,7 +346,7 @@ def run_repeated(
db_url=db_url,
crawl_tag=crawl_tag,
raw_responses_output_dir=raw_responses_output_dir,
query_file_json=query_file_json,
query_file_json_list=query_file_json_list,
api_credentials_file=api_credentials_file,
rate_limit_wait_strategy=rate_limit_wait_strategy,
region=region,
Expand Down Expand Up @@ -446,7 +436,7 @@ def run(
max_days_per_query: MaxDaysPerQueryType = _MAX_DAYS_PER_QUERY_DEFAULT,
crawl_tag: CrawlTagType = None,
raw_responses_output_dir: RawResponsesOutputDir = None,
query_file_json: JsonQueryFileType = None,
query_file_json_list: JsonQueryFileListType = None,
api_credentials_file: ApiCredentialsFileType = _DEFAULT_CREDENTIALS_FILE_PATH,
rate_limit_wait_strategy: ApiRateLimitWaitStrategyType = (
ApiRateLimitWaitStrategy.WAIT_FOUR_HOURS
Expand Down Expand Up @@ -505,7 +495,7 @@ def run(
{"--db-url": db_url, "--db-file": db_file}, at_least_one_required=True
)

if query_file_json:
if query_file_json_list:
if any(
[
include_any_hashtags,
Expand All @@ -528,7 +518,7 @@ def run(
"--include-any-keywords, etc"
)

query = get_query_file_json(query_file_json)
query_list = [get_query_file_json(query_file) for query_file in query_file_json_list]
elif video_id:
if any(
[
Expand All @@ -545,7 +535,7 @@ def run(
]
):
raise typer.BadParameter("--video_id is mutually exclusisive with other flags")
query = generate_video_id_query(video_id)
query_list = [generate_video_id_query(video_id)]
# Since query for a specific video by ID should only return 1 video, we use the max allowed
# date span for queries.
max_days_per_query = _DAYS_PER_QUERY_MAX_API_ALLOWED
Expand Down Expand Up @@ -583,21 +573,21 @@ def run(

validate_region_code_flag_value(region)

query = generate_query(
region_codes=region,
include_any_hashtags=include_any_hashtags,
include_all_hashtags=include_all_hashtags,
exclude_any_hashtags=exclude_any_hashtags,
exclude_all_hashtags=exclude_all_hashtags,
include_any_keywords=include_any_keywords,
include_all_keywords=include_all_keywords,
exclude_any_keywords=exclude_any_keywords,
exclude_all_keywords=exclude_all_keywords,
only_from_usernames=only_from_usernames,
exclude_from_usernames=exclude_from_usernames,
)

logging.log(logging.INFO, f"VideoQuery: {query}")
query_list = [
generate_query(
region_codes=region,
include_any_hashtags=include_any_hashtags,
include_all_hashtags=include_all_hashtags,
exclude_any_hashtags=exclude_any_hashtags,
exclude_all_hashtags=exclude_all_hashtags,
include_any_keywords=include_any_keywords,
include_all_keywords=include_all_keywords,
exclude_any_keywords=exclude_any_keywords,
exclude_all_keywords=exclude_all_keywords,
only_from_usernames=only_from_usernames,
exclude_from_usernames=exclude_from_usernames,
)
]

if db_url:
engine = get_engine_and_create_tables(db_url)
Expand All @@ -611,21 +601,18 @@ def run(
api_credentials_file=api_credentials_file,
api_rate_limit_wait_strategy=rate_limit_wait_strategy,
)
query_config = VideoQueryConfig(
query=query,
start_date=start_date_datetime,
end_date=end_date_datetime,
crawl_tags=[crawl_tag] if crawl_tag else None,
fetch_user_info=fetch_user_info,
fetch_comments=fetch_comments,
)
logging.info("API client config: %s\nVideo query config: %s", api_client_config, query_config)

if query_config.start_date == query_config.end_date:
logging.info(
"Start and final date are the same - running single day driver",
query_configs = [
VideoQueryConfig(
query=query,
start_date=start_date_datetime,
end_date=end_date_datetime,
crawl_tags=[crawl_tag] if crawl_tag else None,
fetch_user_info=fetch_user_info,
fetch_comments=fetch_comments,
)
driver_single_day(api_client_config, query_config)
else:
logging.info("Running main driver")
for query in query_list
]
logging.info("API client config: %s\nVideo query configs: %s", api_client_config, query_configs)

for query_config in query_configs:
main_driver(api_client_config, query_config, max_days_per_query=max_days_per_query)

0 comments on commit 146ebd7

Please sign in to comment.