Currency Arbitrage with Reinforcement Learning added #578

Merged · 1 commit · Jul 22, 2024
@@ -0,0 +1,290 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "ARmpyIVdteu5"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import random\n",
"from datetime import datetime, timedelta\n",
"\n",
"# Parameters\n",
"num_entries = 1000\n",
"currency_pairs = ['EUR/USD', 'USD/JPY', 'GBP/USD', 'USD/CHF', 'AUD/USD', 'USD/CAD']\n",
"\n",
"# Generate timestamps\n",
"start_time = datetime.now()\n",
"timestamps = [start_time - timedelta(minutes=5*i) for i in range(num_entries)]\n",
"\n",
"# Generate synthetic data\n",
"data = []\n",
"for i in range(num_entries):\n",
" timestamp = timestamps[i]\n",
" pair = random.choice(currency_pairs)\n",
" bid_price = round(random.uniform(1.0, 1.5), 4)\n",
" ask_price = round(bid_price + random.uniform(0.0001, 0.0005), 4)\n",
" data.append([timestamp, pair, bid_price, ask_price])\n",
"\n",
"# Create DataFrame\n",
"columns = ['Timestamp', 'Currency Pair', 'Bid Price', 'Ask Price']\n",
"df = pd.DataFrame(data, columns=columns)\n",
"\n",
"# Save to CSV\n",
"df.to_csv('currency_arbitrage_data.csv', index=False)\n"
]
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"# Load the dataset\n",
"df = pd.read_csv('currency_arbitrage_data.csv')\n",
"\n",
"# Convert the Timestamp column to datetime\n",
"df['Timestamp'] = pd.to_datetime(df['Timestamp'])\n",
"\n",
"# Display the first few rows of the dataset\n",
"print(df.head())\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tUTB5cjltxSh",
"outputId": "fd4c12e8-afb0-423c-a5ad-634b96b2bdcf"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Timestamp Currency Pair Bid Price Ask Price\n",
"0 2024-07-19 11:38:48.716950 USD/CHF 1.2220 1.2224\n",
"1 2024-07-19 11:33:48.716950 USD/JPY 1.4628 1.4631\n",
"2 2024-07-19 11:28:48.716950 USD/JPY 1.3268 1.3271\n",
"3 2024-07-19 11:23:48.716950 USD/CHF 1.1257 1.1262\n",
"4 2024-07-19 11:18:48.716950 USD/CAD 1.3318 1.3319\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.preprocessing import StandardScaler, LabelEncoder\n",
"\n",
"# Encode the currency pairs\n",
"le = LabelEncoder()\n",
"df['Currency Pair'] = le.fit_transform(df['Currency Pair'])\n",
"\n",
"# Normalize the bid and ask prices\n",
"scaler = StandardScaler()\n",
"df[['Bid Price', 'Ask Price']] = scaler.fit_transform(df[['Bid Price', 'Ask Price']])\n",
"\n",
"print(df.head())\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "krofbKLUtxPK",
"outputId": "006a9e1e-d992-4c06-e43e-70763da549c6"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Timestamp Currency Pair Bid Price Ask Price\n",
"0 2024-07-19 11:38:48.716950 4 -0.173319 -0.172652\n",
"1 2024-07-19 11:33:48.716950 5 1.508608 1.508600\n",
"2 2024-07-19 11:28:48.716950 5 0.558683 0.558661\n",
"3 2024-07-19 11:23:48.716950 4 -0.845950 -0.844594\n",
"4 2024-07-19 11:18:48.716950 3 0.593606 0.592188\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import gym\n",
"from gym import spaces\n",
"\n",
"class CurrencyArbitrageEnv(gym.Env):\n",
" def __init__(self, df):\n",
" super(CurrencyArbitrageEnv, self).__init__()\n",
" self.df = df\n",
" self.current_step = 0\n",
"\n",
" # Actions: 0 = hold, 1 = buy, 2 = sell\n",
" self.action_space = spaces.Discrete(3)\n",
"\n",
" # Observation space: bid price, ask price, and currency pair\n",
" self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32)\n",
"\n",
" # Initial account balance\n",
" self.balance = 10000\n",
"\n",
" def reset(self):\n",
" self.current_step = 0\n",
" self.balance = 10000\n",
" return self._next_observation()\n",
"\n",
" def _next_observation(self):\n",
" obs = self.df.iloc[self.current_step][['Bid Price', 'Ask Price', 'Currency Pair']].values\n",
" return obs\n",
"\n",
" def step(self, action):\n",
" obs = self._next_observation()\n",
" self.current_step += 1\n",
"\n",
" if action == 1: # Buy\n",
" self.balance -= obs[1] # Subtract ask price\n",
" elif action == 2: # Sell\n",
" self.balance += obs[0] # Add bid price\n",
"\n",
" reward = self.balance\n",
"\n",
" done = self.current_step >= len(self.df) - 1\n",
"\n",
" return obs, reward, done, {}\n",
"\n",
" def render(self, mode='human'):\n",
" print(f'Step: {self.current_step}, Balance: {self.balance}')\n",
"\n",
"# Create the environment\n",
"env = CurrencyArbitrageEnv(df)\n"
],
"metadata": {
"id": "pzTAsCD3txKv"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import random\n",
"from collections import deque\n",
"\n",
"class QLearningAgent:\n",
" def __init__(self, state_size, action_size):\n",
" self.state_size = state_size\n",
" self.action_size = action_size\n",
" self.memory = deque(maxlen=2000)\n",
" self.gamma = 0.95 # discount rate\n",
" self.epsilon = 1.0 # exploration rate\n",
" self.epsilon_min = 0.01\n",
" self.epsilon_decay = 0.995\n",
" self.learning_rate = 0.001\n",
" self.q_table = {}\n",
"\n",
" def get_qs(self, state):\n",
" return self.q_table.get(tuple(state.flatten()), np.zeros(self.action_size))\n",
"\n",
" def remember(self, state, action, reward, next_state, done):\n",
" self.memory.append((state, action, reward, next_state, done))\n",
"\n",
" def act(self, state):\n",
" if np.random.rand() <= self.epsilon:\n",
" return random.randrange(self.action_size)\n",
" qs = self.get_qs(state)\n",
" return np.argmax(qs)\n",
"\n",
" def replay(self, batch_size):\n",
" minibatch = random.sample(self.memory, batch_size)\n",
" for state, action, reward, next_state, done in minibatch:\n",
" target = reward\n",
" if not done:\n",
" target = reward + self.gamma * np.amax(self.get_qs(next_state))\n",
" qs = self.get_qs(state)\n",
" qs[action] = target\n",
" self.q_table[tuple(state.flatten())] = qs\n",
" if self.epsilon > self.epsilon_min:\n",
" self.epsilon *= self.epsilon_decay\n",
"\n",
"# Initialize the agent\n",
"state_size = env.observation_space.shape[0]\n",
"action_size = env.action_space.n\n",
"agent = QLearningAgent(state_size, action_size)\n",
"\n",
"# Train the agent\n",
"episodes = 300\n",
"batch_size = 32\n",
"\n",
"for e in range(episodes):\n",
" state = env.reset()\n",
" state = np.reshape(state, [1, state_size])\n",
"\n",
" for time in range(50):\n",
" action = agent.act(state)\n",
" next_state, reward, done, _ = env.step(action)\n",
" next_state = np.reshape(next_state, [1, state_size])\n",
"\n",
" agent.remember(state, action, reward, next_state, done)\n",
" state = next_state\n",
"\n",
" if done:\n",
" print(f\"Episode: {e+1}/{episodes}, Score: {reward}, Epsilon: {agent.epsilon:.2}\")\n",
" break\n",
"\n",
" if len(agent.memory) > batch_size:\n",
" agent.replay(batch_size)\n"
],
"metadata": {
"id": "l7MHaOXjtxGQ"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Evaluate the agent\n",
"state = env.reset()\n",
"state = np.reshape(state, [1, state_size])\n",
"total_reward = 0\n",
"\n",
"for time in range(500):\n",
" action = agent.act(state)\n",
" next_state, reward, done, _ = env.step(action)\n",
" next_state = np.reshape(next_state, [1, state_size])\n",
" total_reward += reward\n",
" state = next_state\n",
"\n",
" if done:\n",
" print(f\"Final Balance: {total_reward}\")\n",
" break\n"
],
"metadata": {
"id": "t_-a_DbRuDay"
},
"execution_count": 7,
"outputs": []
}
]
}