diff --git a/etl/src/birdxplorer_etl/extract.py b/etl/src/birdxplorer_etl/extract.py
index ee9b306..4aa2719 100644
--- a/etl/src/birdxplorer_etl/extract.py
+++ b/etl/src/birdxplorer_etl/extract.py
@@ -55,7 +55,9 @@ def extract_data(sqlite: Session, postgresql: Session):
         for index, row in enumerate(reader):
             if sqlite.query(RowNoteRecord).filter(RowNoteRecord.note_id == row["note_id"]).first():
                 continue
-            rows_to_add.append(RowNoteRecord(**row))
+            row_note = RowNoteRecord(**row)
+            row_note.row_post_id = row["tweet_id"]
+            rows_to_add.append(row_note)
             if index % 1000 == 0:
                 sqlite.bulk_save_objects(rows_to_add)
                 rows_to_add = []
@@ -96,11 +98,18 @@ def extract_data(sqlite: Session, postgresql: Session):
     sqlite.commit()
 
     # Noteに紐づくtweetデータを取得
+    targetStartDate = int(settings.TARGET_TWITTER_POST_START_UNIX_MILLISECOND) or int(
+        datetime.combine(datetime.now() - timedelta(days=2), datetime.min.time()).timestamp() * 1000
+    )
+    default_end_time = int(settings.TARGET_TWITTER_POST_END_UNIX_MILLISECOND) or int(
+        datetime.combine(datetime.now() - timedelta(days=1) - timedelta(hours=20), datetime.min.time()).timestamp()
+        * 1000
+    )
     postExtract_targetNotes = (
         sqlite.query(RowNoteRecord)
         .filter(RowNoteRecord.tweet_id != None)
-        .filter(RowNoteRecord.created_at_millis >= settings.TARGET_TWITTER_POST_START_UNIX_MILLISECOND)
-        .filter(RowNoteRecord.created_at_millis <= settings.TARGET_TWITTER_POST_END_UNIX_MILLISECOND)
+        .filter(RowNoteRecord.created_at_millis >= targetStartDate)
+        .filter(RowNoteRecord.created_at_millis <= default_end_time)
         .all()
     )
     logging.info(f"Target notes: {len(postExtract_targetNotes)}")
diff --git a/etl/src/birdxplorer_etl/lib/x/postlookup.py b/etl/src/birdxplorer_etl/lib/x/postlookup.py
index a7422f9..343c9f2 100644
--- a/etl/src/birdxplorer_etl/lib/x/postlookup.py
+++ b/etl/src/birdxplorer_etl/lib/x/postlookup.py
@@ -34,7 +34,11 @@ def connect_to_endpoint(url):
     if response.status_code == 429:
         limit = response.headers["x-rate-limit-reset"]
         logging.info("Waiting for rate limit reset...")
-        time.sleep(int(limit) - int(time.time()) + 1)
+        logging.info("Rate limit resets at: {}".format(int(limit)))
+        logging.info("Current time: {}".format(int(time.time())))
+        wait_time = int(limit) - int(time.time()) + 1
+        if wait_time > 0:
+            time.sleep(wait_time)
         data = connect_to_endpoint(url)
         return data
     elif response.status_code != 200:
diff --git a/etl/src/birdxplorer_etl/settings.py b/etl/src/birdxplorer_etl/settings.py
index c09eeaf..fa0d6dc 100644
--- a/etl/src/birdxplorer_etl/settings.py
+++ b/etl/src/birdxplorer_etl/settings.py
@@ -4,10 +4,8 @@
 load_dotenv()
 
-default_start_time = int(datetime.combine(datetime.now() - timedelta(days=2), datetime.min.time()).timestamp() * 1000)
-default_end_time = int(
-    datetime.combine(datetime.now() - timedelta(days=1) - timedelta(hours=20), datetime.min.time()).timestamp() * 1000
-)
+default_start_time = datetime.combine(datetime.now() - timedelta(days=2), datetime.min.time()).timestamp() * 1000
+default_end_time = datetime.combine(datetime.now() - timedelta(days=1), datetime.min.time()).timestamp() * 1000
 
 TARGET_TWITTER_POST_START_UNIX_MILLISECOND = int(
     os.getenv("TARGET_TWITTER_POST_START_UNIX_MILLISECOND", default_start_time)
 )
@@ -21,11 +19,11 @@
 AI_MODEL = os.getenv("AI_MODEL")
 OPENAPI_TOKEN = os.getenv("OPENAPI_TOKEN")
 CLAUDE_TOKEN = os.getenv("CLAUDE_TOKEN")
-TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND = os.getenv(
-    "TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND", default_start_time
+TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND = int(
+    os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND", default_start_time)
 )
-TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND = os.getenv(
-    "TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND", default_end_time
+TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND = int(
+    os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND", default_end_time)
 )
 
 USE_DUMMY_DATA = os.getenv("USE_DUMMY_DATA", "False") == "True"
diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py
index ca450d3..e9ee5aa 100644
--- a/etl/src/birdxplorer_etl/transform.py
+++ b/etl/src/birdxplorer_etl/transform.py
@@ -17,8 +17,8 @@
     RowPostRecord,
     RowUserRecord,
 )
-from birdxplorer_etl.lib.ai_model.ai_model_interface import get_ai_service
-from birdxplorer_etl.settings import (
+from lib.ai_model.ai_model_interface import get_ai_service
+from settings import (
     TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND,
     TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND,
 )
@@ -65,6 +65,12 @@ def transform_data(sqlite: Session, postgresql: Session):
                 RowNoteStatusRecord.current_status,
                 func.cast(RowNoteRecord.created_at_millis, Integer).label("created_at"),
             )
+            .filter(
+                and_(
+                    RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND,
+                    RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND,
+                )
+            )
             .join(RowNoteStatusRecord, RowNoteRecord.note_id == RowNoteStatusRecord.note_id)
             .limit(limit)
             .offset(offset)
@@ -73,6 +79,7 @@ def transform_data(sqlite: Session, postgresql: Session):
             for note in notes:
                 note_as_list = list(note)
                 note_as_list.append(ai_service.detect_language(note[2]))
+                note_as_list.append("ja")
                 writer = csv.writer(file)
                 writer.writerow(note_as_list)
             offset += limit
@@ -156,28 +163,7 @@ def transform_data(sqlite: Session, postgresql: Session):
     generate_post_link(postgresql)
 
     # Transform row post embed url data and generate post_embed_url.csv
-    csv_seed_file_path = "./seed/topic_seed.csv"
-    output_csv_file_path = "./data/transformed/topic.csv"
-    records = []
-
-    if os.path.exists(output_csv_file_path):
-        return
-
-    with open(csv_seed_file_path, newline="", encoding="utf-8") as csvfile:
-        reader = csv.DictReader(csvfile)
-        for index, row in enumerate(reader):
-            if "ja" in row and row["ja"]:
-                topic_id = index + 1
-                label = {"ja": row["ja"], "en": row["en"]}  # Assuming the label is in Japanese
-                record = {"topic_id": topic_id, "label": label}
-                records.append(record)
-
-    with open(output_csv_file_path, "a", newline="", encoding="utf-8") as file:
-        fieldnames = ["topic_id", "label"]
-        writer = csv.DictWriter(file, fieldnames=fieldnames)
-        writer.writeheader()
-        for record in records:
-            writer.writerow({"topic_id": record["topic_id"], "label": {k: v for k, v in record["label"].items()}})
+    generate_topic()
 
     generate_note_topic(sqlite)
 
@@ -272,6 +258,31 @@ def generate_post_link(postgresql: Session):
         offset += limit
 
 
+def generate_topic():
+    csv_seed_file_path = "./seed/topic_seed.csv"
+    output_csv_file_path = "./data/transformed/topic.csv"
+    records = []
+
+    if os.path.exists(output_csv_file_path):
+        return
+
+    with open(csv_seed_file_path, newline="", encoding="utf-8") as csvfile:
+        reader = csv.DictReader(csvfile)
+        for index, row in enumerate(reader):
+            if "ja" in row and row["ja"]:
+                topic_id = index + 1
+                label = {"ja": row["ja"], "en": row["en"]}  # Assuming the label is in Japanese
+                record = {"topic_id": topic_id, "label": label}
+                records.append(record)
+
+    with open(output_csv_file_path, "a", newline="", encoding="utf-8") as file:
+        fieldnames = ["topic_id", "label"]
+        writer = csv.DictWriter(file, fieldnames=fieldnames)
+        writer.writeheader()
+        for record in records:
+            writer.writerow({"topic_id": record["topic_id"], "label": {k: v for k, v in record["label"].items()}})
+
+
 def generate_note_topic(sqlite: Session):
     output_csv_file_path = "./data/transformed/note_topic_association.csv"
     ai_service = get_ai_service()