Skip to content

Commit

Permalink
add --video-id and --max-days-per-query flags (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
macpd authored Aug 19, 2024
1 parent 5fde36f commit 00757ef
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 52 deletions.
17 changes: 17 additions & 0 deletions src/tiktok_research_api_helper/cli/custom_argument_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@
),
]

# CLI option type for --video-id: the integer ID of the single video to query for.
VideoIdType = Annotated[int, typer.Option(help="ID of specific video to query for.")]

CrawlTagType = Annotated[
str,
typer.Option(
Expand Down Expand Up @@ -239,3 +241,18 @@
)
),
]

# CLI option type for --max-days-per-query: the largest start-to-end span (in days) that a single
# API query may cover before the date range is split into several consecutive queries.
MaxDaysPerQueryType = Annotated[
int,
typer.Option(
help=(
"Threshold for number of days between start and end dates at which a single query will "
"be split into multiple queries. Often the API gets overloaded and returns 500 (which "
"still consumes request quota) if the query returns lots of videos and the date range "
"is large. So reducing this can reduce 500 responses (and request quota consumption "
"from those) for queries that match lots of videos. IE if this is set to 3 and the "
"start and end date are 7 days apart the query will be split in 3 queries with start "
"and end dates: (start, start + 3), (start + 3, start + 6), (start + 6, start + 7)"
)
),
]
163 changes: 111 additions & 52 deletions src/tiktok_research_api_helper/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@
IncludeAnyKeywordListType,
JsonQueryFileType,
MaxApiRequests,
MaxDaysPerQueryType,
OnlyUsernamesListType,
RawResponsesOutputDir,
RegionCodeListType,
StopAfterOneRequestFlag,
TikTokEndDateFormat,
TikTokStartDateFormat,
VideoIdType,
)
from tiktok_research_api_helper.models import (
get_engine_and_create_tables,
Expand All @@ -59,11 +61,13 @@
VideoQuery,
VideoQueryJSONEncoder,
generate_query,
generate_video_id_query,
)

# Typer application object; CLI commands are registered on it with @APP.command().
APP = typer.Typer(rich_markup_mode="markdown")

# Fixed 7-day query window size.
# NOTE(review): appears superseded by the configurable --max-days-per-query value below; confirm
# there are no remaining users before removing.
_DAYS_PER_ITER = 7
# Default value of the --max-days-per-query option of the `run` command.
_MAX_DAYS_PER_QUERY_DEFAULT = 7
# Largest value `run` accepts for --max-days-per-query; its error message attributes this limit
# to the TikTok research API.
_DAYS_PER_QUERY_MAX_API_ALLOWED = 30
# Credentials file location, relative to the current working directory.
# NOTE(review): presumably the default for the credentials-file option; usage not visible here.
_DEFAULT_CREDENTIALS_FILE_PATH = Path("./secrets.yaml")


Expand All @@ -77,24 +81,28 @@ def driver_single_day(client_config: ApiClientConfig, query_config: VideoQueryCo
api_client.fetch_and_store_all(query_config)


def main_driver(
    api_client_config: ApiClientConfig,
    query_config: VideoQueryConfig,
    max_days_per_query: int,
):
    """Fetch and store all results for ``query_config``, one bounded date window at a time.

    The overall ``[query_config.start_date, query_config.end_date]`` range is walked in steps of
    ``max_days_per_query`` days. For each step a copy of ``query_config`` restricted to that
    window is passed to ``TikTokApiClient.fetch_and_store_all``. Consecutive windows share their
    boundary date (one window's end is the next window's start), and the final window is clamped
    to ``query_config.end_date``.

    Args:
        api_client_config: configuration used to build the ``TikTokApiClient``.
        query_config: the query to run, including the overall start and end dates.
        max_days_per_query: maximum number of days a single window may span. Must be positive.

    Raises:
        ValueError: if ``max_days_per_query`` is not positive. Such a value would never advance
            the window start, so the loop below would not terminate.
    """
    # TODO(macpd): maybe move this logic into TikTokApiClient
    if max_days_per_query <= 0:
        raise ValueError(
            f"max_days_per_query must be a positive integer, got {max_days_per_query!r}"
        )

    # Kept under a separate name so the int parameter is not rebound to a different type.
    # NOTE(review): utils.int_to_days presumably returns a timedelta-like span; confirm.
    window_span = utils.int_to_days(max_days_per_query)

    window_start = copy(query_config.start_date)

    api_client = TikTokApiClient.from_config(api_client_config)

    while window_start < query_config.end_date:
        # Never query past the overall end date, even when the last window is a partial one.
        window_end = min(window_start + window_span, query_config.end_date)
        window_query_config = attrs.evolve(
            query_config, start_date=window_start, end_date=window_end
        )

        api_client.fetch_and_store_all(window_query_config)

        window_start += window_span


@APP.command()
Expand Down Expand Up @@ -191,6 +199,7 @@ def print_query(
exclude_all_keywords: ExcludeAllKeywordListType | None = None,
only_from_usernames: OnlyUsernamesListType | None = None,
exclude_from_usernames: ExcludeUsernamesListType | None = None,
video_id: VideoIdType | None = None,
) -> None:
"""Prints to stdout the query generated from flags. Useful for creating a base from which to
build more complex custom JSON queries."""
Expand All @@ -206,60 +215,81 @@ def print_query(
exclude_all_keywords,
only_from_usernames,
exclude_from_usernames,
video_id,
]
):
raise typer.BadParameter(
"must specify at least one of [--include-any-hashtags, --exclude-any-hashtags, "
"--include-all-hashtags, --exclude-all-hashtags, --include-any-keywords, "
"--include-all-keywords, --exclude-any-keywords, --exclude-all-keywords, "
"--include-any-usernames, --include-all-usernames, --exclude-any-usernames, "
"--exclude-all-usernames]"
"--exclude-all-usernames, --video-id]"
)
validate_mutually_exclusive_flags(
{
"--include-any-hashtags": include_any_hashtags,
"--include-all-hashtags": include_all_hashtags,
}
)
validate_mutually_exclusive_flags(
{
"--exclude-any-hashtags": exclude_any_hashtags,
"--exclude-all-hashtags": exclude_all_hashtags,
}
)
validate_mutually_exclusive_flags(
{
"--include-any-keywords": include_any_keywords,
"--include-all-keywords": include_all_keywords,
}
)
validate_mutually_exclusive_flags(
{
"--exclude-any-keywords": exclude_any_keywords,
"--exclude-all-keywords": exclude_all_keywords,
}
)
validate_mutually_exclusive_flags(
{
"--only-from-usernames": only_from_usernames,
"--exclude-from-usernames": exclude_from_usernames,
}
)
validate_region_code_flag_value(region)
if video_id:
if any(
[
include_any_hashtags,
exclude_any_hashtags,
include_all_hashtags,
exclude_all_hashtags,
include_any_keywords,
exclude_any_keywords,
include_all_keywords,
exclude_all_keywords,
only_from_usernames,
exclude_from_usernames,
]
):
raise typer.BadParameter(
"--video-id is mutually exclusisive with other query specification flags"
)
query = generate_video_id_query(video_id)
else:
validate_mutually_exclusive_flags(
{
"--include-any-hashtags": include_any_hashtags,
"--include-all-hashtags": include_all_hashtags,
}
)
validate_mutually_exclusive_flags(
{
"--exclude-any-hashtags": exclude_any_hashtags,
"--exclude-all-hashtags": exclude_all_hashtags,
}
)
validate_mutually_exclusive_flags(
{
"--include-any-keywords": include_any_keywords,
"--include-all-keywords": include_all_keywords,
}
)
validate_mutually_exclusive_flags(
{
"--exclude-any-keywords": exclude_any_keywords,
"--exclude-all-keywords": exclude_all_keywords,
}
)
validate_mutually_exclusive_flags(
{
"--only-from-usernames": only_from_usernames,
"--exclude-from-usernames": exclude_from_usernames,
}
)
validate_region_code_flag_value(region)

query = generate_query(
region_codes=region,
include_any_hashtags=include_any_hashtags,
include_all_hashtags=include_all_hashtags,
exclude_any_hashtags=exclude_any_hashtags,
exclude_all_hashtags=exclude_all_hashtags,
include_any_keywords=include_any_keywords,
include_all_keywords=include_all_keywords,
exclude_any_keywords=exclude_any_keywords,
exclude_all_keywords=exclude_all_keywords,
only_from_usernames=only_from_usernames,
exclude_from_usernames=exclude_from_usernames,
)
query = generate_query(
region_codes=region,
include_any_hashtags=include_any_hashtags,
include_all_hashtags=include_all_hashtags,
exclude_any_hashtags=exclude_any_hashtags,
exclude_all_hashtags=exclude_all_hashtags,
include_any_keywords=include_any_keywords,
include_all_keywords=include_all_keywords,
exclude_any_keywords=exclude_any_keywords,
exclude_all_keywords=exclude_all_keywords,
only_from_usernames=only_from_usernames,
exclude_from_usernames=exclude_from_usernames,
)

print(json.dumps(query, cls=VideoQueryJSONEncoder, indent=2))

Expand Down Expand Up @@ -413,6 +443,7 @@ def run(
db_url: DBUrlType | None = None,
stop_after_one_request: StopAfterOneRequestFlag = False,
max_api_requests: MaxApiRequests | None = None,
max_days_per_query: MaxDaysPerQueryType = _MAX_DAYS_PER_QUERY_DEFAULT,
crawl_tag: CrawlTagType | None = None,
raw_responses_output_dir: RawResponsesOutputDir | None = None,
query_file_json: JsonQueryFileType | None = None,
Expand All @@ -431,6 +462,7 @@ def run(
exclude_all_keywords: ExcludeAllKeywordListType | None = None,
only_from_usernames: OnlyUsernamesListType | None = None,
exclude_from_usernames: ExcludeUsernamesListType | None = None,
video_id: VideoIdType | None = None,
fetch_user_info: FetchUserInfoFlag | None = None,
fetch_comments: FetchCommentsFlag | None = None,
debug: EnableDebugLoggingFlag = False,
Expand All @@ -456,6 +488,12 @@ def run(
"only one."
)

if max_days_per_query > _DAYS_PER_QUERY_MAX_API_ALLOWED or max_days_per_query <= 0:
raise typer.BadParameter(
"--max-days-per-query must be a positive integer less than or equal to "
f"{_DAYS_PER_QUERY_MAX_API_ALLOWED}. This is a restriction of the tiktok research API."
)

logging.log(logging.INFO, f"Arguments: {locals()}")

# Using an actual datetime object instead of a string would not allows to
Expand All @@ -481,6 +519,7 @@ def run(
only_from_usernames,
exclude_from_usernames,
region,
video_id,
]
):
raise typer.BadParameter(
Expand All @@ -490,6 +529,26 @@ def run(
)

query = get_query_file_json(query_file_json)
elif video_id:
if any(
[
include_any_hashtags,
exclude_any_hashtags,
include_all_hashtags,
exclude_all_hashtags,
include_any_keywords,
exclude_any_keywords,
include_all_keywords,
exclude_all_keywords,
only_from_usernames,
exclude_from_usernames,
]
):
raise typer.BadParameter("--video_id is mutually exclusisive with other flags")
query = generate_video_id_query(video_id)
# Since query for a specific video by ID should only return 1 video, we use the max allowed
# date span for queries.
max_days_per_query = _DAYS_PER_QUERY_MAX_API_ALLOWED
else:
validate_mutually_exclusive_flags(
{
Expand Down Expand Up @@ -569,4 +628,4 @@ def run(
driver_single_day(api_client_config, query_config)
else:
logging.info("Running main driver")
main_driver(api_client_config, query_config)
main_driver(api_client_config, query_config, max_days_per_query=max_days_per_query)
4 changes: 4 additions & 0 deletions src/tiktok_research_api_helper/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def default(self, o):
return super().default(o)


def generate_video_id_query(video_id: int) -> VideoQuery:
    """Build a query that selects the video whose ID equals ``video_id``.

    The query holds a single equality condition on the ``video_id`` field; the ID is passed to
    the condition as a string.
    """
    video_id_condition = Cond(Fields.video_id, str(video_id), Op.EQ)
    return VideoQuery(**{_QUERY_AND_ARG_NAME: video_id_condition})


def get_normalized_hashtag_set(comma_separated_hashtags: str) -> set[str]:
"""Takes a string of comma separated hashtag names and returns a set of hashtag names all
lowercase and stripped of leading "#" if present."""
Expand Down
15 changes: 15 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
VideoQuery,
VideoQueryJSONEncoder,
generate_query,
generate_video_id_query,
get_normalized_hashtag_set,
get_normalized_keyword_set,
get_normalized_username_set,
Expand Down Expand Up @@ -324,6 +325,20 @@ def test_normalized_username_set(test_input, expected):
assert get_normalized_username_set(test_input) == expected


def test_generate_query_video_id():
    """A video ID query is one EQ condition on video_id, with the ID rendered as a string."""
    expected_condition = {
        "field_name": "video_id",
        "field_values": ["1234567"],
        "operation": "EQ",
    }
    assert generate_video_id_query(1234567).as_dict() == {"and": [expected_condition]}


def test_generate_query_include_any_hashtags():
assert generate_query(include_any_hashtags="this,that,other").as_dict() == {
"and": [
Expand Down

0 comments on commit 00757ef

Please sign in to comment.