
Commit

fix etl
yu23ki14 committed Nov 1, 2024
1 parent 2add18a commit 7e0e6ad
Showing 4 changed files with 58 additions and 36 deletions.
15 changes: 12 additions & 3 deletions etl/src/birdxplorer_etl/extract.py
@@ -55,7 +55,9 @@ def extract_data(sqlite: Session, postgresql: Session):
         for index, row in enumerate(reader):
             if sqlite.query(RowNoteRecord).filter(RowNoteRecord.note_id == row["note_id"]).first():
                 continue
-            rows_to_add.append(RowNoteRecord(**row))
+            row_note = RowNoteRecord(**row)
+            row_note.row_post_id = row["tweet_id"]
+            rows_to_add.append(row_note)
             if index % 1000 == 0:
                 sqlite.bulk_save_objects(rows_to_add)
                 rows_to_add = []
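
Note: the hunk above flushes inserts in batches of 1,000 rows. A minimal sketch of the same batch-insert pattern, assuming a SQLAlchemy session and a mapped class such as the project's RowNoteRecord; the final flush of the leftover partial batch is an assumption about the surrounding code, not shown in this hunk:

    from sqlalchemy.orm import Session

    BATCH_SIZE = 1000

    def bulk_insert_rows(sqlite: Session, rows: list[dict], record_cls) -> None:
        """Insert rows in fixed-size batches to bound memory and DB round-trips."""
        batch = []
        for index, row in enumerate(rows):
            batch.append(record_cls(**row))
            if (index + 1) % BATCH_SIZE == 0:  # a full batch is ready
                sqlite.bulk_save_objects(batch)
                batch = []
        if batch:  # assumed: flush whatever remains after the loop
            sqlite.bulk_save_objects(batch)
        sqlite.commit()
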
@@ -96,11 +98,18 @@ def extract_data(sqlite: Session, postgresql: Session):
     sqlite.commit()

     # Fetch the tweet data linked to each note
+    targetStartDate = int(settings.TARGET_TWITTER_POST_START_UNIX_MILLISECOND) or int(
+        datetime.combine(datetime.now() - timedelta(days=2), datetime.min.time()).timestamp() * 1000
+    )
+    default_end_time = int(settings.TARGET_TWITTER_POST_END_UNIX_MILLISECOND) or int(
+        datetime.combine(datetime.now() - timedelta(days=1) - timedelta(hours=20), datetime.min.time()).timestamp()
+        * 1000
+    )
     postExtract_targetNotes = (
         sqlite.query(RowNoteRecord)
         .filter(RowNoteRecord.tweet_id != None)
-        .filter(RowNoteRecord.created_at_millis >= settings.TARGET_TWITTER_POST_START_UNIX_MILLISECOND)
-        .filter(RowNoteRecord.created_at_millis <= settings.TARGET_TWITTER_POST_END_UNIX_MILLISECOND)
+        .filter(RowNoteRecord.created_at_millis >= targetStartDate)
+        .filter(RowNoteRecord.created_at_millis <= default_end_time)
         .all()
     )
     logging.info(f"Target notes: {len(postExtract_targetNotes)}")
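Note: targetStartDate and default_end_time fall back to local-midnight Unix-millisecond timestamps whenever the configured setting is falsy (0 or unset). A small sketch of the arithmetic and the or-based fallback; variable names here are illustrative, not from the project:

    from datetime import datetime, timedelta

    def start_of_day_ms(days_ago: int) -> int:
        """Unix milliseconds at local midnight `days_ago` days before today."""
        midnight = datetime.combine(datetime.now() - timedelta(days=days_ago), datetime.min.time())
        return int(midnight.timestamp() * 1000)

    configured_start = 0  # stand-in for settings.TARGET_TWITTER_POST_START_UNIX_MILLISECOND
    # `or` selects the computed default whenever the configured value is falsy.
    target_start = int(configured_start) or start_of_day_ms(2)
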
6 changes: 5 additions & 1 deletion etl/src/birdxplorer_etl/lib/x/postlookup.py
@@ -34,7 +34,11 @@ def connect_to_endpoint(url):
     if response.status_code == 429:
         limit = response.headers["x-rate-limit-reset"]
         logging.info("Waiting for rate limit reset...")
-        time.sleep(int(limit) - int(time.time()) + 1)
+        logging.info("Rate limit resets at: {}".format(int(limit)))
+        logging.info("Time now: {}".format(int(time.time())))
+        wait_time = int(limit) - int(time.time()) + 1
+        if wait_time > 0:
+            time.sleep(wait_time)
         data = connect_to_endpoint(url)
         return data
     elif response.status_code != 200:
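Note: x-rate-limit-reset carries an epoch-seconds value, so the wait is that value minus the current time plus a one-second safety margin, clamped at zero to avoid sleeping on a negative duration. A minimal helper sketching the guard; the function name is hypothetical:

    import time

    def seconds_until_reset(reset_epoch: str) -> int:
        """Seconds to sleep until the rate-limit window resets; clamped at zero."""
        return max(0, int(reset_epoch) - int(time.time()) + 1)

    # e.g. wait = seconds_until_reset(response.headers["x-rate-limit-reset"])
    #      if wait:
    #          time.sleep(wait)
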
14 changes: 6 additions & 8 deletions etl/src/birdxplorer_etl/settings.py
@@ -4,10 +4,8 @@

 load_dotenv()

-default_start_time = int(datetime.combine(datetime.now() - timedelta(days=2), datetime.min.time()).timestamp() * 1000)
-default_end_time = int(
-    datetime.combine(datetime.now() - timedelta(days=1) - timedelta(hours=20), datetime.min.time()).timestamp() * 1000
-)
+default_start_time = datetime.combine(datetime.now() - timedelta(days=2), datetime.min.time()).timestamp() * 1000
+default_end_time = datetime.combine(datetime.now() - timedelta(days=1), datetime.min.time()).timestamp() * 1000

 TARGET_TWITTER_POST_START_UNIX_MILLISECOND = int(
     os.getenv("TARGET_TWITTER_POST_START_UNIX_MILLISECOND", default_start_time)
@@ -21,11 +19,11 @@
 AI_MODEL = os.getenv("AI_MODEL")
 OPENAPI_TOKEN = os.getenv("OPENAPI_TOKEN")
 CLAUDE_TOKEN = os.getenv("CLAUDE_TOKEN")
-TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND = os.getenv(
-    "TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND", default_start_time
+TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND = int(
+    os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND", default_start_time)
 )
-TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND = os.getenv(
-    "TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND", default_end_time
+TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND = int(
+    os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND", default_end_time)
 )

 USE_DUMMY_DATA = os.getenv("USE_DUMMY_DATA", "False") == "True"
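Note: os.getenv returns a string when the variable is set and the given default (here a float) otherwise, so wrapping the whole call in int() normalizes both cases, provided the env value is an integer-like string. A sketch of the pattern; the helper name is illustrative:

    import os

    def env_unix_ms(name: str, default_ms: float) -> int:
        """Read a Unix-millisecond setting, tolerating a str env value or a float default."""
        # int() accepts both the float default and an integer-like string from the environment.
        return int(os.getenv(name, default_ms))

    # e.g. START_MS = env_unix_ms("TARGET_TWITTER_POST_START_UNIX_MILLISECOND", 1730419200000.0)
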
59 changes: 35 additions & 24 deletions etl/src/birdxplorer_etl/transform.py
@@ -17,8 +17,8 @@
     RowPostRecord,
     RowUserRecord,
 )
-from birdxplorer_etl.lib.ai_model.ai_model_interface import get_ai_service
-from birdxplorer_etl.settings import (
+from lib.ai_model.ai_model_interface import get_ai_service
+from settings import (
     TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND,
     TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND,
 )
@@ -65,6 +65,12 @@ def transform_data(sqlite: Session, postgresql: Session):
                 RowNoteStatusRecord.current_status,
                 func.cast(RowNoteRecord.created_at_millis, Integer).label("created_at"),
             )
+            .filter(
+                and_(
+                    RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND,
+                    RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND,
+                )
+            )
             .join(RowNoteStatusRecord, RowNoteRecord.note_id == RowNoteStatusRecord.note_id)
             .limit(limit)
             .offset(offset)
@@ -73,6 +79,7 @@
         for note in notes:
             note_as_list = list(note)
             note_as_list.append(ai_service.detect_language(note[2]))
+            note_as_list.append("ja")
             writer = csv.writer(file)
             writer.writerow(note_as_list)
         offset += limit
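
Note: the new and_ filter windows notes by created_at_millis, and the surrounding code pages through the results with limit/offset. A sketch of that loop shape, assuming the project's RowNoteRecord mapping and millisecond bounds; the generator framing is illustrative, not the project's actual structure:

    from sqlalchemy import and_
    from sqlalchemy.orm import Session

    def iter_windowed_notes(sqlite: Session, start_ms: int, end_ms: int, limit: int = 1000):
        """Yield pages of notes whose created_at_millis falls inside [start_ms, end_ms]."""
        offset = 0
        while True:
            page = (
                sqlite.query(RowNoteRecord)  # RowNoteRecord is the project's mapped class
                .filter(
                    and_(
                        RowNoteRecord.created_at_millis >= start_ms,
                        RowNoteRecord.created_at_millis <= end_ms,
                    )
                )
                .limit(limit)
                .offset(offset)
                .all()
            )
            if not page:
                break
            yield page
            offset += limit
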
@@ -156,28 +163,7 @@ def transform_data(sqlite: Session, postgresql: Session):
     generate_post_link(postgresql)

     # Transform row post embed url data and generate post_embed_url.csv
-    csv_seed_file_path = "./seed/topic_seed.csv"
-    output_csv_file_path = "./data/transformed/topic.csv"
-    records = []
-
-    if os.path.exists(output_csv_file_path):
-        return
-
-    with open(csv_seed_file_path, newline="", encoding="utf-8") as csvfile:
-        reader = csv.DictReader(csvfile)
-        for index, row in enumerate(reader):
-            if "ja" in row and row["ja"]:
-                topic_id = index + 1
-                label = {"ja": row["ja"], "en": row["en"]}  # Assuming the label is in Japanese
-                record = {"topic_id": topic_id, "label": label}
-                records.append(record)
-
-    with open(output_csv_file_path, "a", newline="", encoding="utf-8") as file:
-        fieldnames = ["topic_id", "label"]
-        writer = csv.DictWriter(file, fieldnames=fieldnames)
-        writer.writeheader()
-        for record in records:
-            writer.writerow({"topic_id": record["topic_id"], "label": {k: v for k, v in record["label"].items()}})
+    generate_topic()

     generate_note_topic(sqlite)

@@ -272,6 +258,31 @@ def generate_post_link(postgresql: Session):
         offset += limit


+def generate_topic():
+    csv_seed_file_path = "./seed/topic_seed.csv"
+    output_csv_file_path = "./data/transformed/topic.csv"
+    records = []
+
+    if os.path.exists(output_csv_file_path):
+        return
+
+    with open(csv_seed_file_path, newline="", encoding="utf-8") as csvfile:
+        reader = csv.DictReader(csvfile)
+        for index, row in enumerate(reader):
+            if "ja" in row and row["ja"]:
+                topic_id = index + 1
+                label = {"ja": row["ja"], "en": row["en"]}  # Assuming the label is in Japanese
+                record = {"topic_id": topic_id, "label": label}
+                records.append(record)
+
+    with open(output_csv_file_path, "a", newline="", encoding="utf-8") as file:
+        fieldnames = ["topic_id", "label"]
+        writer = csv.DictWriter(file, fieldnames=fieldnames)
+        writer.writeheader()
+        for record in records:
+            writer.writerow({"topic_id": record["topic_id"], "label": {k: v for k, v in record["label"].items()}})
+
+
 def generate_note_topic(sqlite: Session):
     output_csv_file_path = "./data/transformed/note_topic_association.csv"
     ai_service = get_ai_service()
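Note: generate_topic writes a dict into the label column, so csv.DictWriter emits the Python repr (e.g. {'ja': ..., 'en': ...}) rather than JSON. If a downstream loader expects JSON, an explicit dump is safer; a sketch under that assumption:

    import csv
    import json

    def write_topics(records: list[dict], path: str) -> None:
        """Write topic rows with the label column serialized as JSON, not a dict repr."""
        with open(path, "w", newline="", encoding="utf-8") as file:
            writer = csv.DictWriter(file, fieldnames=["topic_id", "label"])
            writer.writeheader()
            for record in records:
                writer.writerow(
                    {"topic_id": record["topic_id"], "label": json.dumps(record["label"], ensure_ascii=False)}
                )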
