-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2cf092a
commit 9c087c7
Showing
8 changed files
with
184 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import sys | ||
import os | ||
|
||
# Add src directory to the Python path | ||
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')) | ||
sys.path.append(src_path) | ||
|
||
from data_loader import load_data, load_and_concatenate_stock_data | ||
from data_cleaner import clean_data, drop_first_column, format_date_column | ||
from sentiment_analysis import analyze_sentiment, aggregate_sentiments_by_date | ||
from stock_returns import calculate_daily_returns | ||
from data_merger import merge_data, merge_sentiment_with_stock | ||
from correlation_analysis import calculate_correlation | ||
from visualizationt import plot_daily_returns, scatter_plot | ||
|
||
|
||
def main(): | ||
# Load data | ||
news_df = load_data('../data/news.csv') | ||
|
||
# File paths and corresponding company names | ||
file_paths = [ | ||
'../data/apple.csv', | ||
'../data/amazon.csv', | ||
'../data/google.csv', | ||
'../data/meta.csv', | ||
'../data/microsoft.csv', | ||
'../data/nvidia.csv', | ||
'../data/tesla.csv' | ||
] | ||
|
||
company_names = ['Apple', 'Amazon', 'Google', 'Meta', 'Microsoft', 'Nvidia', 'Tesla'] | ||
|
||
# Load and concatenate the stock data | ||
stock_df = load_and_concatenate_stock_data(file_paths, company_names) | ||
|
||
# Drop the first column | ||
news_df = drop_first_column(news_df) | ||
|
||
# Clean data | ||
stock_df = clean_data(stock_df, 'Date') | ||
news_df = clean_data(news_df, 'date') | ||
|
||
# Format the 'date' columns | ||
news_df = format_date_column(news_df, 'date') | ||
stock_df = format_date_column(stock_df, 'date') | ||
|
||
# Merge data | ||
merged_df = merge_data(news_df, stock_df) | ||
|
||
# Analyze sentiment | ||
merged_df['sentiment'] = merged_df['headline'].apply(analyze_sentiment) | ||
|
||
# Calculate daily stock returns | ||
merged_df = calculate_daily_returns(merged_df) | ||
|
||
# Aggregate sentiments by date | ||
daily_sentiment = aggregate_sentiments_by_date(merged_df) | ||
|
||
# Merge sentiment data with stock data | ||
daily_df = merge_sentiment_with_stock(merged_df, daily_sentiment) | ||
|
||
# Calculate correlation | ||
correlation = calculate_correlation(daily_df) | ||
print(f'Pearson correlation coefficient: {correlation}') | ||
|
||
# Visualize | ||
plot_daily_returns(daily_df) | ||
scatter_plot(daily_df) | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
|
||
import pandas as pd | ||
|
||
def calculate_correlation(df, column1='Average Sentiment', column2='Daily Return'): | ||
correlation = df[[column1, column2]].corr().iloc[0, 1] | ||
return correlation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import pandas as pd | ||
|
||
def clean_data(df, date_column): | ||
# Check if the column name needs to be renamed to 'date' | ||
if date_column in df.columns: | ||
df.rename(columns={date_column: 'date'}, inplace=True) | ||
|
||
# Convert the date column to datetime and remove timezone information | ||
df['date'] = pd.to_datetime(df['date'], errors='coerce').dt.tz_localize(None) | ||
|
||
# Drop rows with any missing values | ||
df.dropna(inplace=True) | ||
|
||
return df | ||
|
||
def drop_first_column(df): | ||
|
||
df = df.drop(df.columns[0], axis=1) | ||
return df | ||
|
||
def format_date_column(df, date_column): | ||
|
||
df[date_column] = pd.to_datetime(df[date_column], errors='coerce').dt.strftime('%Y-%m-%d') | ||
return df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import pandas as pd | ||
|
||
def load_data(file_path): | ||
return pd.read_csv(file_path) | ||
|
||
def load_and_concatenate_stock_data(file_paths, company_names): | ||
|
||
# List to store individual dataframes | ||
dataframes = [] | ||
|
||
# Loop through each file path and company name | ||
for file_path, company_name in zip(file_paths, company_names): | ||
# Load the CSV file | ||
df = pd.read_csv(file_path) | ||
# Add a 'Company' column | ||
df['Company'] = company_name | ||
# Append the dataframe to the list | ||
dataframes.append(df) | ||
|
||
# Concatenate all dataframes into one | ||
stock_df = pd.concat(dataframes) | ||
|
||
return stock_df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import pandas as pd | ||
|
||
def merge_data(news_df, stock_df, on='date'): | ||
merged_df = pd.merge(news_df, stock_df, on=on) | ||
return merged_df | ||
|
||
def merge_sentiment_with_stock(stock_df, sentiment_df): | ||
daily_df = pd.merge(stock_df[['date', 'Daily Return']], sentiment_df, on='date', how='inner') | ||
return daily_df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
|
||
import pandas as pd | ||
from textblob import TextBlob | ||
|
||
def analyze_sentiment(text): | ||
analysis = TextBlob(text) | ||
return analysis.sentiment.polarity | ||
|
||
def aggregate_sentiments_by_date(df, date_column='date', sentiment_column='sentiment', output_column='Average Sentiment'): | ||
daily_sentiment = df.groupby(date_column)[sentiment_column].mean().reset_index() | ||
daily_sentiment.columns = [date_column, output_column] | ||
return daily_sentiment |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
|
||
import pandas as pd | ||
|
||
def calculate_daily_returns(df, close_column='Close', return_column='Daily Return'): | ||
df[return_column] = df[close_column].pct_change() * 100 | ||
return df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
|
||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
|
||
def plot_daily_returns(daily_df): | ||
plt.figure(figsize=(14, 7)) | ||
|
||
plt.subplot(2, 1, 1) | ||
plt.plot(daily_df.index, daily_df['Daily Return'], label='Daily Return', color='blue') | ||
plt.title('Daily Stock Returns') | ||
plt.xlabel('Date') | ||
plt.ylabel('Return (%)') | ||
plt.legend() | ||
|
||
plt.subplot(2, 1, 2) | ||
plt.plot(daily_df.index, daily_df['Average Sentiment'], label='Average Sentiment', color='orange') | ||
plt.title('Average Daily Sentiment') | ||
plt.xlabel('Date') | ||
plt.ylabel('Sentiment Score') | ||
plt.legend() | ||
|
||
plt.tight_layout() | ||
plt.show() | ||
|
||
def scatter_plot(daily_df): | ||
plt.figure(figsize=(14, 7)) | ||
sns.scatterplot(data=daily_df, x='Average Sentiment', y='Daily Return', alpha=0.5) | ||
plt.title('Sentiment Score vs. Stock Returns') | ||
plt.xlabel('Average Sentiment Score') | ||
plt.ylabel('Daily Return (%)') | ||
plt.legend() | ||
plt.show() |