Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions airflow/dags/plugins/spark_snowflake_conn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import shutil
from datetime import datetime

from dags.plugins.variables import SPARK_JARS
from pyspark.sql import SparkSession
from snowflake.connector.pandas_tools import write_pandas

from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook

Expand Down Expand Up @@ -37,6 +39,29 @@ def create_spark_session(app_name: str):
return spark


def write_pandas_snowflake(df, table_name):
    """Upload a pandas DataFrame into the Snowflake table *table_name*.

    Uses the ``SNOWFLAKE_CONN`` Airflow connection and the ``RAW_DATA``
    schema, and returns the ``(success, num_chunks, num_rows)`` triple
    reported by ``write_pandas`` so callers can check the load outcome.
    """
    # NOTE(review): schema was "RAW_DAT" here but "RAW_DATA" everywhere
    # else in this module — treated as a typo and made consistent.
    hook = SnowflakeHook(snowflake_conn_id="SNOWFLAKE_CONN", schema="RAW_DATA")

    # write_pandas expects a snowflake.connector Connection, not the hook
    # itself — unwrap it via get_conn().
    success, num_chunks, num_rows, output = write_pandas(
        hook.get_conn(), df, table_name)
    return success, num_chunks, num_rows


def write_spark_csv(file_name, df):
    """Write Spark DataFrame *df* as a single CSV at ``data/<file_name>.csv``.

    Spark normally writes a directory of part files; ``coalesce(1)`` forces
    a single partition, which is then moved to the final path and the
    temporary directory removed.
    """
    temp_folder = f"data/{file_name}"
    df.coalesce(1).write.mode("overwrite").csv(temp_folder, header=True)

    # Locate the single part file Spark produced in the output directory;
    # fail with a clear message instead of an opaque IndexError.
    part_files = [f for f in os.listdir(temp_folder) if f.startswith("part-")]
    if not part_files:
        raise FileNotFoundError(f"no Spark part file found in {temp_folder}")

    # Move it to its final name and drop the temporary directory.
    output_path = f"data/{file_name}.csv"
    shutil.move(os.path.join(temp_folder, part_files[0]), output_path)

    shutil.rmtree(temp_folder)


def create_snowflake_table(sql):

hook = SnowflakeHook(snowflake_conn_id="SNOWFLAKE_CONN", schema="RAW_DATA")
Expand Down
43 changes: 7 additions & 36 deletions airflow/dags/scripts/ELT_artist_info_globalTop50.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
import requests
import snowflake.connector
from plugins.spark_snowflake_conn import *
from pyspark.sql.functions import (col, current_date, explode, lit,
regexp_replace, split, udf)
from pyspark.sql.functions import (col, concat_ws, current_date, explode, lit,
regexp_replace, split)
from pyspark.sql.types import (ArrayType, IntegerType, StringType, StructField,
StructType)

LAST_FM_API_KEY = os.getenv("LAST_FM_API_KEY")

BUCKET_NAME = "de5-s4tify"
OBJECT_NAME = "raw_data"

Expand All @@ -19,7 +17,6 @@

def load():

# 테이블 있는지 확인하는 sql
sql = """
CREATE TABLE IF NOT EXISTS artist_info_globalTop50(
artist_id VARCHAR(100),
Expand All @@ -36,12 +33,8 @@ def load():
create_snowflake_table(sql)

transform_df = transformation()
transform_df.show()

# Null 값이 있는 행 출력
# transform_df.filter(col("title") == "Sweet Dreams (feat. Miguel)").show(truncate=False)

write_snowflake_spark_dataframe("artist_info_globalTop50", transform_df)
write_spark_csv(f"join_artsit_info_chart_{TODAY}", transform_df)


def transformation():
Expand Down Expand Up @@ -80,37 +73,15 @@ def transformation():
"date_time", current_date())

artist_info_top50_df = artist_info_top50_df.withColumn(
"song_genre", add_song_genre_udf(col("artist_name"), col("title"))
"artist", concat_ws(",", col("artist"))
)
artist_info_top50_df = artist_info_top50_df.withColumn(
"artist_genre", concat_ws(",", col("artist_genre"))
)

return artist_info_top50_df


def add_song_genre(artist, track):

url = f"https://ws.audioscrobbler.com/2.0/?method=track.getInfo&api_key={LAST_FM_API_KEY}&artist={artist}&track={track}&format=json"
print(url)

try:
response = requests.get(url).json()
return [
genre["name"] for genre in response.get(
"track",
{}).get(
"toptags",
{}).get(
"tag",
[])]
except requests.exceptions.RequestException as e:
print(f"API 요청 오류: {e}")
return ["API Error"]
except KeyError:
return ["Unknown"]


add_song_genre_udf = udf(add_song_genre, ArrayType(StringType()))


def extract(file_name, schema):

spark = create_spark_session("artist_global_table")
Expand Down
32 changes: 3 additions & 29 deletions airflow/dags/scripts/ELT_artist_info_top10.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import requests
from dags.plugins.spark_snowflake_conn import *
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, current_date, regexp_replace, split, udf
from pyspark.sql.functions import (col, concat_ws, current_date,
regexp_replace, split)
from pyspark.sql.types import ArrayType, StringType, StructField, StructType

LAST_FM_API_KEY = os.getenv("LAST_FM_API_KEY")

BUCKET_NAME = "de5-s4tify"
OBJECT_NAME = "raw_data"

Expand Down Expand Up @@ -75,38 +74,13 @@ def transformation():
artist_info_top10_df = artist_info_top10_df.withColumn(
"date_time", current_date())

# 노래 장르 데이터 추가
artist_info_top10_df = artist_info_top10_df.withColumn(
"song_genre", add_song_genre_udf(col("artist"), col("title"))
"artist_genre", concat_ws(",", col("artist_genre"))
)

return artist_info_top10_df


def add_song_genre(artist, track):

url = f"https://ws.audioscrobbler.com/2.0/?method=track.getInfo&api_key={LAST_FM_API_KEY}&artist={artist}&track={track}&format=json"

try:
response = requests.get(url).json()
return [
genre["name"] for genre in response.get(
"track",
{}).get(
"toptags",
{}).get(
"tag",
[])]
except requests.exceptions.RequestException as e:
print(f"API 요청 오류: {e}")
return ["API Error"]
except KeyError:
return ["Unknown"]


add_song_genre_udf = udf(add_song_genre, ArrayType(StringType()))


def extract(file_name, schema):

spark = create_spark_session("artist_top10_table")
Expand Down
55 changes: 55 additions & 0 deletions airflow/dags/scripts/add_song_genre.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os

import pandas as pd
import requests
from plugins.spark_snowflake_conn import *

LAST_FM_API_KEY = os.getenv("LAST_FM_API_KEY")


def add_song_genre(file_name, table_name):
    """Enrich the joined chart CSV with Last.fm song genres and load it.

    Reads ``data/<file_name>``, queries the Last.fm ``track.getInfo`` API
    for each (artist, title) row, appends a ``song_genre`` column, restores
    ``artist_genre`` from a comma-joined string back to a list, uppercases
    the column names, and writes the result to Snowflake *table_name*.
    """
    join_data = pd.read_csv(f"data/{file_name}")
    song_genres = []

    base_url = "https://ws.audioscrobbler.com/2.0/"

    for _, row in join_data.iterrows():

        artist = row["artist"]
        track = row["title"]
        # Pass query values via `params` so names containing '&', '/', '#'
        # etc. are URL-encoded — the original f-string URL broke the query
        # string for such artists/titles. A timeout keeps a stuck request
        # from hanging the Airflow task.
        params = {
            "method": "track.getInfo",
            "api_key": LAST_FM_API_KEY,
            "artist": artist,
            "track": track,
            "format": "json",
        }

        try:
            response = requests.get(base_url, params=params, timeout=10).json()
            # Extract the genre tags from the API response.
            genre = response.get("track", {}).get("toptags", {}).get("tag", [])

            if genre:
                genre_list = [g["name"] for g in genre]  # convert to a list of genre names
                song_genres.append(", ".join(genre_list))  # store as one string
            else:
                song_genres.append("Unknown")

        except Exception as e:
            # Best-effort: one failed lookup must not abort the whole load.
            print(f"Error fetching genre for {artist} - {track}: {e}")
            song_genres.append("Error")

    # Add the new column.
    join_data["song_genre"] = song_genres

    # artist_genre was flattened to "a,b,c" for the CSV; restore a list.
    join_data["artist_genre"] = join_data["artist_genre"].apply(
        lambda x: x.split(",") if isinstance(x, str) else []
    )

    # Snowflake resolves unquoted identifiers as uppercase.
    join_data.columns = [col.upper() for col in join_data.columns]

    write_pandas_snowflake(join_data, table_name)


def main(logical_date):
    """Run genre enrichment for both chart tables of *logical_date*."""
    # (source CSV, target Snowflake table) pairs produced upstream.
    # NOTE(review): "artsit" is a typo, but it matches the name the
    # upstream writer uses — do not fix one side only.
    jobs = (
        (f"join_artist_info_track10_{logical_date}.csv", "ARTIST_INFO_TOP10"),
        (f"join_artsit_info_chart_{logical_date}.csv", "ARTIST_INFO_GLOBALTOP50"),
    )
    for csv_name, table in jobs:
        add_song_genre(csv_name, table)
21 changes: 17 additions & 4 deletions airflow/dags/spotify_ELT_DAG.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from datetime import datetime, timedelta

from scripts.add_song_genre import *
from scripts.crawling_spotify_data import *
from scripts.load_spotify_data import *
from scripts.request_spotify_api import *

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.apache.spark.operators.spark_submit import \
SparkSubmitOperator

Expand Down Expand Up @@ -52,6 +54,13 @@
dag=dag,
)

# Enrich the joined chart CSVs with Last.fm song genres and load them to
# Snowflake. "{{ ds }}" is rendered by Airflow to the run's logical date
# (YYYY-MM-DD) and passed through to main(logical_date).
add_song_genre_col = PythonOperator(
    task_id="add_song_genre_col",
    python_callable=main,
    op_kwargs={"logical_date": "{{ ds }}"},
    dag=dag,
)

spotify_genre_count_table = SparkSubmitOperator(
task_id="spotify_genre_count_table",
application="dags/scripts/ELT_chart_genre_count.py",
Expand All @@ -60,7 +69,11 @@
dag=dag,
)

[
artist_info_Top10_table,
artist_info_globalTop50_table,
] >> spotify_genre_count_table
# Both artist-info load tasks must finish before genre enrichment runs,
# which in turn gates the genre-count aggregation task.
(
    [
        artist_info_Top10_table,
        artist_info_globalTop50_table,
    ]
    >> add_song_genre_col
    >> spotify_genre_count_table
)