move top music ID spidering code out of run_long_query and into own func. handle multiple crawls/crawl_ids (in case main_driver divides crawl due to > 28 days). try to handle more than 1 day of fetch where api requests sent will be greater than 1000.
macpd committed Jul 1, 2024
1 parent 36374bf commit 90ce142
Showing 4 changed files with 98 additions and 50 deletions.
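Taken together, the four files implement the flow sketched below (a condensed, illustrative sketch assembled from the diffs that follow; `split_into_day_windows` is a hypothetical stand-in for the windowing loop `main_driver` actually inlines):

```python
DAILY_API_REQUEST_QUOTA = 1000

def main_driver(config, spider_top_n_music_ids=None):
    num_api_requests_sent = 0
    crawl_ids = []
    # main_driver walks the full date range in windows, since the API rejects
    # queries spanning more than 28 days; each window is its own crawl.
    for window_config in split_into_day_windows(config):  # hypothetical helper
        ret = run_long_query(window_config)
        num_api_requests_sent += ret["num_api_requests_sent"]
        crawl_ids.append(ret["crawl_id"])

    if spider_top_n_music_ids:
        # Reduce modulo the daily quota so multi-day crawls (> 1000 requests
        # in total) do not produce a negative remaining-quota estimate.
        remaining = DAILY_API_REQUEST_QUOTA - (
            num_api_requests_sent % DAILY_API_REQUEST_QUOTA
        )
        run_spider_top_n_music_ids_query(
            config=config,
            spider_top_n_music_ids=spider_top_n_music_ids,
            crawl_ids=crawl_ids,
            expected_remaining_api_request_quota=remaining,
        )
```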
11 changes: 7 additions & 4 deletions src/tiktok_research_api_helper/api_client.py
@@ -28,6 +28,8 @@
INVALID_SEARCH_ID_ERROR_RETRY_WAIT = 5
INVALID_SEARCH_ID_ERROR_MAX_NUM_RETRIES = 5

DAILY_API_REQUEST_QUOTA = 1000


class ApiRateLimitError(Exception):
pass
@@ -550,7 +552,7 @@ def num_api_requests_sent(self):

@property
def expected_remaining_api_request_quota(self):
return 1000 - self.num_api_requests_sent
return DAILY_API_REQUEST_QUOTA - self.num_api_requests_sent

def api_results_iter(self) -> TikTokApiClientFetchResult:
"""Fetches all results from API (ie requests until API indicates query results have been
@@ -602,8 +604,9 @@ def api_results_iter(self) -> TikTokApiClientFetchResult:
break

logging.info(
"Crawl completed (or reached configured max_requests). Num api requests: %s. Expected "
"Crawl completed (or reached configured max_requests: %s). Num api requests: %s. Expected "
"remaining API request quota: %s",
self._config.max_requests,
self.num_api_requests_sent,
self.expected_remaining_api_request_quota,
)
@@ -614,14 +617,14 @@ def _max_requests_reached(self) -> bool:
def _should_continue(self, crawl: Crawl) -> bool:
should_continue = crawl.has_more and not self._max_requests_reached()
logging.debug(
"crawl.has_more: %s, max_requests_reached: %s, shoudld_continue: %s",
"crawl.has_more: %s, max_requests_reached: %s, should_continue: %s",
crawl.has_more,
self._max_requests_reached(),
should_continue,
)
if crawl.has_more and self._max_requests_reached():
logging.info(
"Max requests rewached. Will discontinue this crawl even though API response "
"Max requests reached. Will discontinue this crawl even though API response "
"indicates more results."
)
return should_continue
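Distilled from the hunks above, a minimal self-contained sketch of the quota bookkeeping pattern (the constant and property names follow the diff; the counter and stop check are illustrative assumptions about the surrounding client class):

```python
DAILY_API_REQUEST_QUOTA = 1000  # requests allowed per day per set of credentials

class QuotaTracker:
    """Illustrative sketch of the client's quota bookkeeping."""

    def __init__(self, max_requests: int | None = None):
        self._num_api_requests_sent = 0
        self._max_requests = max_requests

    @property
    def num_api_requests_sent(self) -> int:
        return self._num_api_requests_sent

    @property
    def expected_remaining_api_request_quota(self) -> int:
        # "Expected" because only this client's own requests are counted;
        # other processes sharing the credentials also draw on the quota.
        return DAILY_API_REQUEST_QUOTA - self._num_api_requests_sent

    def record_request(self) -> None:
        self._num_api_requests_sent += 1

    def max_requests_reached(self) -> bool:
        return (
            self._max_requests is not None
            and self._num_api_requests_sent >= self._max_requests
        )
```

Hoisting the literal `1000` into `DAILY_API_REQUEST_QUOTA` is what lets this property and the modulo arithmetic in `main_driver` below agree on the daily limit.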
119 changes: 81 additions & 38 deletions src/tiktok_research_api_helper/cli_data_acquisition.py
@@ -7,13 +7,15 @@
from pathlib import Path
from typing import Annotated, Any

import attrs
import pause
import pendulum
import typer
from sqlalchemy.orm import Session

from tiktok_research_api_helper import region_codes, utils
from tiktok_research_api_helper.api_client import (
DAILY_API_REQUEST_QUOTA,
ApiClientConfig,
ApiRateLimitWaitStrategy,
TikTokApiClient,
@@ -66,55 +68,77 @@
CrawlDateWindow = namedtuple("CrawlDateWindow", ["start_date", "end_date"])


def run_long_query(config: ApiClientConfig, spider_top_n_music_ids: int | None = None):
def run_long_query(config: ApiClientConfig) -> int:
"""Runs a "long" query, defined as one that may need multiple requests to get all the data.
Unless you have a good reason to believe otherwise, queries should default to be considered
"long"."""
"long".
Returns:
int: number of API requests sent (ie likely amount of API quota consumed).
"""
api_client = TikTokApiClient.from_config(config)
fetch_results = api_client.fetch_and_store_all()
# TODO(macpd): fix this return value structure. maybe a namedtuple
return {
"num_api_requests_sent": api_client.num_api_requests_sent,
"crawl_id": fetch_results.crawl.id,
}

if spider_top_n_music_ids is None:
return

potentially_remaining_qutoa = api_client.expected_remaining_api_request_quota
if potentially_remaining_qutoa <= 0:
# TODO(macpd): if crawl takes more than one day, and thus more than 1 day of allowed
# requests have been made, this check will be wrong.
logging.info("Refusing to spider top music IDs because no API quota remaints")
return
def driver_single_day(config: ApiClientConfig, spider_top_n_music_ids) -> int:
"""Simpler driver for a single day of query.
Returns:
int: number of API requests sent (ie likely amount of API quota consumed).
"""
assert (
config.start_date == config.end_date
), "Start and final date must be the same for single day driver"

return run_long_query(config, spider_top_n_music_ids)


config.max_requests = potentially_remaining_qutoa
def run_spider_top_n_music_ids_query(
config: ApiClientConfig,
spider_top_n_music_ids: int,
crawl_ids: Sequence[int],
expected_remaining_api_request_quota: int,
) -> int:
if expected_remaining_api_request_quota <= 0:
# TODO(macpd): add some way to bypass this.
logging.info("Refusing to spider top music IDs because no API quota remains")
return

crawl_id = fetch_results.crawl.id
with Session(config.engine) as session:
top_music_ids = most_used_music_ids(
session,
limit=None if spider_top_n_music_ids == 0 else spider_top_n_music_ids,
crawl_id=crawl_id,
crawl_ids=crawl_ids,
)
new_query = generate_query(
include_music_ids=",".join([str(x["music_id"]) for x in top_music_ids])
)
config.query = new_query

new_crawl_tags = None
if config.crawl_tags:
config.crawl_tags = [f"{tag}-music-id-spidering" for tag in config.crawl_tags]

api_client = TikTokApiClient.from_config(config)
fetch_results = api_client.fetch_and_store_all()

new_crawl_tags = [f"{tag}-music-id-spidering" for tag in config.crawl_tags]

def driver_single_day(config: ApiClientConfig, spider_top_n_music_ids):
"""Simpler driver for a single day of query"""
assert (
config.start_date == config.end_date
), "Start and final date must be the same for single day driver"
new_config = attrs.evolve(
config,
max_requests=expected_remaining_api_request_quota,
query=new_query,
crawl_tags=new_crawl_tags,
)

run_long_query(config, spider_top_n_music_ids)
api_client = TikTokApiClient.from_config(new_config)
api_client.fetch_and_store_all()
return api_client.num_api_requests_sent


def main_driver(config: ApiClientConfig, spider_top_n_music_ids: int | None = None):
num_api_requests_sent = 0
crawl_ids = []
days_per_iter = utils.int_to_days(_DAYS_PER_ITER)

start_date = copy(config.start_date)
@@ -124,24 +148,43 @@ def main_driver(config: ApiClientConfig, spider_top_n_music_ids: int | None = None):
local_end_date = start_date + days_per_iter
local_end_date = min(local_end_date, config.end_date)

new_config = ApiClientConfig(
query=config.query,
start_date=start_date,
end_date=local_end_date,
engine=config.engine,
stop_after_one_request=config.stop_after_one_request,
crawl_tags=config.crawl_tags,
raw_responses_output_dir=config.raw_responses_output_dir,
api_credentials_file=config.api_credentials_file,
api_rate_limit_wait_strategy=config.api_rate_limit_wait_strategy,
)
run_long_query(new_config, spider_top_n_music_ids=spider_top_n_music_ids)
new_config = attrs.evolve(config, start_date=start_date, end_date=local_end_date)

# ApiClientConfig(
# query=config.query,
# start_date=start_date,
# end_date=local_end_date,
# engine=config.engine,
# stop_after_one_request=config.stop_after_one_request,
# crawl_tags=config.crawl_tags,
# raw_responses_output_dir=config.raw_responses_output_dir,
# api_credentials_file=config.api_credentials_file,
# api_rate_limit_wait_strategy=config.api_rate_limit_wait_strategy,
# )
ret = run_long_query(new_config)
num_api_requests_sent += ret["num_api_requests_sent"]
crawl_ids.append(ret["crawl_id"])

start_date += days_per_iter

if config.stop_after_one_request:
logging.log(logging.WARN, "Stopping after one request")
break
return

expected_remaining_api_request_quota = 0
if spider_top_n_music_ids:
if num_api_requests_sent == DAILY_API_REQUEST_QUOTA:
# TODO(macpd): handle no remaining quota, perhaps flag to do anyway?
logging.warning("Refusing to spider top music IDs because no API quota remains")
return

expected_remaining_api_request_quota = DAILY_API_REQUEST_QUOTA - (num_api_requests_sent % DAILY_API_REQUEST_QUOTA)
run_spider_top_n_music_ids_query(
config=config,
crawl_ids=crawl_ids,
spider_top_n_music_ids=spider_top_n_music_ids,
expected_remaining_api_request_quota=expected_remaining_api_request_quota,
)


@APP.command()
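The modulo in `main_driver`'s quota computation is what the commit message means by handling "more than 1 day of fetch": a crawl that runs past the daily quota reset can send more than 1000 requests in total, so subtracting the raw count from the daily quota would go negative. A worked sketch of the arithmetic (`estimated_remaining_quota` is a hypothetical helper name; the exact-multiple caveat mirrors the `num_api_requests_sent == DAILY_API_REQUEST_QUOTA` check above):

```python
DAILY_API_REQUEST_QUOTA = 1000

def estimated_remaining_quota(num_api_requests_sent: int) -> int:
    # Full multiples of the daily quota are assumed to have been spent on
    # earlier days, so only the remainder counts against today's quota.
    return DAILY_API_REQUEST_QUOTA - (num_api_requests_sent % DAILY_API_REQUEST_QUOTA)

assert estimated_remaining_quota(400) == 600   # single-day crawl
assert estimated_remaining_quota(2300) == 700  # two full days + 300 requests today
# Caveat the modulo alone cannot catch: exactly 1000 requests leaves a
# remainder of 0, which looks like an untouched day. main_driver therefore
# checks num_api_requests_sent == DAILY_API_REQUEST_QUOTA before this math.
assert estimated_remaining_quota(1000) == 1000
```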
12 changes: 7 additions & 5 deletions src/tiktok_research_api_helper/models.py
@@ -203,15 +203,17 @@ def effect_ids(self):
return {effect.effect_id for effect in self.effects}


def most_used_music_ids(session: Session, limit: int | None = None, crawl_id: int | None = None):
"""Returns dict of most used music_ids with count of video id with that music_id. If crawl_id
specified, only operates on videos associated to that video id.
def most_used_music_ids(
session: Session, limit: int | None = None, crawl_ids: Sequence[int] | None = None
):
"""Returns dict of most used music_ids with count of video id with that music_id. If crawl_ids
specified, only operates on videos associated to those crawl IDs.
"""
if crawl_id:
if crawl_ids:
select_stmt = (
select(Video.music_id, func.count(Video.id).label("num_videos"))
.join(Video.crawls)
.where(Crawl.id == crawl_id)
.where(Crawl.id.in_(crawl_ids))
)
else:
select_stmt = select(Video.music_id, func.count(Video.id).label("num_videos"))
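A usage sketch of the updated signature (the session wiring and result shape follow the diff and the tests below; the concrete limit and crawl IDs are made up for illustration):

```python
from sqlalchemy.orm import Session

# Aggregate music-ID usage across every crawl produced by one main_driver
# run, e.g. a date range that was split into several <= 28-day crawls.
with Session(config.engine) as session:
    top_music_ids = most_used_music_ids(
        session,
        limit=10,                # None would return all music IDs
        crawl_ids=[12, 13, 14],  # hypothetical IDs collected from run_long_query
    )
    for row in top_music_ids:
        print(row["music_id"], row["num_videos"])
```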
6 changes: 3 additions & 3 deletions tests/test_sql.py
@@ -680,7 +680,7 @@ def test_most_used_music_ids(test_database_engine, mock_crawl, api_response_videos):
key=lambda x: x["music_id"],
)
assert sorted(
most_used_music_ids(session, limit=None, crawl_id=mock_crawl_id),
most_used_music_ids(session, limit=None, crawl_ids=[mock_crawl_id]),
key=lambda x: x["music_id"],
) == sorted(
[{"music_id": x["music_id"], "num_videos": 1} for x in api_response_videos],
@@ -698,11 +698,11 @@ def test_most_used_music_ids(test_database_engine, mock_crawl, api_response_videos):
{"music_id": new_music_id, "num_videos": len(api_response_videos) + 1},
{"music_id": 6817429116187314177, "num_videos": 1},
]
assert most_used_music_ids(session, limit=2, crawl_id=mock_crawl_id) == [
assert most_used_music_ids(session, limit=2, crawl_ids=[mock_crawl_id]) == [
{"music_id": new_music_id, "num_videos": len(api_response_videos) + 1},
{"music_id": 6817429116187314177, "num_videos": 1},
]
assert most_used_music_ids(session, limit=2, crawl_id=mock_crawl_id + 1) == []
assert most_used_music_ids(session, limit=2, crawl_ids=[mock_crawl_id + 1]) == []


def test_remove_all(test_database_engine, mock_videos, mock_crawl):
