diff --git a/airflow/dags/plugins/spark_snowflake_conn.py b/airflow/dags/plugins/spark_snowflake_conn.py index 69ae3d7..5ca6375 100644 --- a/airflow/dags/plugins/spark_snowflake_conn.py +++ b/airflow/dags/plugins/spark_snowflake_conn.py @@ -1,8 +1,10 @@ import os +import shutil from datetime import datetime from dags.plugins.variables import SPARK_JARS from pyspark.sql import SparkSession +from snowflake.connector.pandas_tools import write_pandas from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook @@ -37,6 +39,29 @@ def create_spark_session(app_name: str): return spark +def write_pandas_snowflake(df, table_name): + + conn = SnowflakeHook(snowflake_conn_id="SNOWFLAKE_CONN", schema="RAW_DATA") + + success, num_chunks, num_rows, output = write_pandas( + conn.get_conn(), df, f"{table_name}") + + +def write_spark_csv(file_name, df): + temp_folder = f"data/{file_name}" + df.coalesce(1).write.mode("overwrite").csv(temp_folder, header=True) + + # 저장된 디렉터리에서 CSV 파일 찾기 + csv_file = [f for f in os.listdir(temp_folder) if f.startswith("part-")][0] + + # 새 파일명 설정과 이동 + output_path = f"data/{file_name}.csv" + shutil.move(os.path.join(temp_folder, csv_file), output_path) + + # 임시 폴더 삭제 + shutil.rmtree(temp_folder) + + def create_snowflake_table(sql): hook = SnowflakeHook(snowflake_conn_id="SNOWFLAKE_CONN", schema="RAW_DATA") diff --git a/airflow/dags/scripts/ELT_artist_info_globalTop50.py b/airflow/dags/scripts/ELT_artist_info_globalTop50.py index f090d71..1939f62 100644 --- a/airflow/dags/scripts/ELT_artist_info_globalTop50.py +++ b/airflow/dags/scripts/ELT_artist_info_globalTop50.py @@ -4,13 +4,11 @@ import requests import snowflake.connector from plugins.spark_snowflake_conn import * -from pyspark.sql.functions import (col, current_date, explode, lit, - regexp_replace, split, udf) +from pyspark.sql.functions import (col, concat_ws, current_date, explode, lit, + regexp_replace, split) from pyspark.sql.types import (ArrayType, IntegerType, StringType, StructField, 
StructType) -LAST_FM_API_KEY = os.getenv("LAST_FM_API_KEY") - BUCKET_NAME = "de5-s4tify" OBJECT_NAME = "raw_data" @@ -19,7 +17,6 @@ def load(): - # 테이블 있는지 확인하는 sql sql = """ CREATE TABLE IF NOT EXISTS artist_info_globalTop50( artist_id VARCHAR(100), @@ -36,12 +33,8 @@ def load(): create_snowflake_table(sql) transform_df = transformation() - transform_df.show() - - # Null 값이 있는 행 출력 - # transform_df.filter(col("title") == "Sweet Dreams (feat. Miguel)").show(truncate=False) - write_snowflake_spark_dataframe("artist_info_globalTop50", transform_df) + write_spark_csv(f"join_artsit_info_chart_{TODAY}", transform_df) def transformation(): @@ -80,37 +73,15 @@ def transformation(): "date_time", current_date()) artist_info_top50_df = artist_info_top50_df.withColumn( - "song_genre", add_song_genre_udf(col("artist_name"), col("title")) + "artist", concat_ws(",", col("artist")) + ) + artist_info_top50_df = artist_info_top50_df.withColumn( + "artist_genre", concat_ws(",", col("artist_genre")) ) return artist_info_top50_df -def add_song_genre(artist, track): - - url = f"https://ws.audioscrobbler.com/2.0/?method=track.getInfo&api_key={LAST_FM_API_KEY}&artist={artist}&track={track}&format=json" - print(url) - - try: - response = requests.get(url).json() - return [ - genre["name"] for genre in response.get( - "track", - {}).get( - "toptags", - {}).get( - "tag", - [])] - except requests.exceptions.RequestException as e: - print(f"API 요청 오류: {e}") - return ["API Error"] - except KeyError: - return ["Unknown"] - - -add_song_genre_udf = udf(add_song_genre, ArrayType(StringType())) - - def extract(file_name, schema): spark = create_spark_session("artist_global_table") diff --git a/airflow/dags/scripts/ELT_artist_info_top10.py b/airflow/dags/scripts/ELT_artist_info_top10.py index a525f37..8e640c9 100644 --- a/airflow/dags/scripts/ELT_artist_info_top10.py +++ b/airflow/dags/scripts/ELT_artist_info_top10.py @@ -3,11 +3,10 @@ import requests from dags.plugins.spark_snowflake_conn import * 
from pyspark.sql import SparkSession -from pyspark.sql.functions import col, current_date, regexp_replace, split, udf +from pyspark.sql.functions import (col, concat_ws, current_date, + regexp_replace, split) from pyspark.sql.types import ArrayType, StringType, StructField, StructType -LAST_FM_API_KEY = os.getenv("LAST_FM_API_KEY") - BUCKET_NAME = "de5-s4tify" OBJECT_NAME = "raw_data" @@ -75,38 +74,13 @@ def transformation(): artist_info_top10_df = artist_info_top10_df.withColumn( "date_time", current_date()) - # 노래 장르 데이터 추가 artist_info_top10_df = artist_info_top10_df.withColumn( - "song_genre", add_song_genre_udf(col("artist"), col("title")) + "artist_genre", concat_ws(",", col("artist_genre")) ) return artist_info_top10_df -def add_song_genre(artist, track): - - url = f"https://ws.audioscrobbler.com/2.0/?method=track.getInfo&api_key={LAST_FM_API_KEY}&artist={artist}&track={track}&format=json" - - try: - response = requests.get(url).json() - return [ - genre["name"] for genre in response.get( - "track", - {}).get( - "toptags", - {}).get( - "tag", - [])] - except requests.exceptions.RequestException as e: - print(f"API 요청 오류: {e}") - return ["API Error"] - except KeyError: - return ["Unknown"] - - -add_song_genre_udf = udf(add_song_genre, ArrayType(StringType())) - - def extract(file_name, schema): spark = create_spark_session("artist_top10_table") diff --git a/airflow/dags/scripts/add_song_genre.py b/airflow/dags/scripts/add_song_genre.py new file mode 100644 index 0000000..6fa6cde --- /dev/null +++ b/airflow/dags/scripts/add_song_genre.py @@ -0,0 +1,55 @@ +import os + +import pandas as pd +import requests +from plugins.spark_snowflake_conn import * + +LAST_FM_API_KEY = os.getenv("LAST_FM_API_KEY") + + +def add_song_genre(file_name, table_name): + + join_data = pd.read_csv(f"data/{file_name}") + song_genres = [] + + for _, row in join_data.iterrows(): + + artist = row["artist"] + track = row["title"] + url = 
f"https://ws.audioscrobbler.com/2.0/?method=track.getInfo&api_key={LAST_FM_API_KEY}&artist={artist}&track={track}&format=json" + + try: + response = requests.get(url).json() + # API 응답에서 장르 정보 추출 + genre = response.get("track", {}).get("toptags", {}).get("tag", []) + + if genre: + genre_list = [g["name"] for g in genre] # 장르 리스트로 변환 + song_genres.append(", ".join(genre_list)) # 문자열로 저장 + else: + song_genres.append("Unknown") + + except Exception as e: + print(f"Error fetching genre for {artist} - {track}: {e}") + song_genres.append("Error") + + # 새로운 컬럼 추가 + join_data["song_genre"] = song_genres + + # string으로 변경 되었던 아티스트 장르 다시 array 변경 + join_data["artist_genre"] = join_data["artist_genre"].apply( + lambda x: x.split(",") if isinstance(x, str) else [] + ) + + join_data.columns = [col.upper() for col in join_data.columns] + + write_pandas_snowflake(join_data, table_name) + + +def main(logical_date): + add_song_genre( + f"join_artist_info_track10_{logical_date}.csv", + "ARTIST_INFO_TOP10") + add_song_genre( + f"join_artsit_info_chart_{logical_date}.csv", "ARTIST_INFO_GLOBALTOP50" + ) diff --git a/airflow/dags/spotify_ELT_DAG.py b/airflow/dags/spotify_ELT_DAG.py index 8cb81c3..c4dd834 100644 --- a/airflow/dags/spotify_ELT_DAG.py +++ b/airflow/dags/spotify_ELT_DAG.py @@ -1,10 +1,12 @@ from datetime import datetime, timedelta +from scripts.add_song_genre import * from scripts.crawling_spotify_data import * from scripts.load_spotify_data import * from scripts.request_spotify_api import * from airflow import DAG +from airflow.operators.python import PythonOperator from airflow.providers.apache.spark.operators.spark_submit import \ SparkSubmitOperator @@ -52,6 +54,13 @@ dag=dag, ) + add_song_genre_col = PythonOperator( + task_id="add_song_genre_col", + python_callable=main, + op_kwargs={"logical_date": "{{ ds }}"}, + dag=dag, + ) + spotify_genre_count_table = SparkSubmitOperator( task_id="spotify_genre_count_table", application="dags/scripts/ELT_chart_genre_count.py", @@ 
-60,7 +69,11 @@ dag=dag, ) - [ - artist_info_Top10_table, - artist_info_globalTop50_table, - ] >> spotify_genre_count_table + ( + [ + artist_info_Top10_table, + artist_info_globalTop50_table, + ] + >> add_song_genre_col + >> spotify_genre_count_table + )