From 4e25297b6cb304191e6adf8b53b13f115d5c9a31 Mon Sep 17 00:00:00 2001 From: Franklyn Dunbar Date: Sat, 18 Apr 2020 16:31:20 -0600 Subject: [PATCH] . --- .../AlphaZero_exercise-checkpoint.ipynb | 1400 +++++++++++++++++ AlphaZero_exercise.ipynb | 944 ++++++----- 2 files changed, 1950 insertions(+), 394 deletions(-) create mode 100644 .ipynb_checkpoints/AlphaZero_exercise-checkpoint.ipynb diff --git a/.ipynb_checkpoints/AlphaZero_exercise-checkpoint.ipynb b/.ipynb_checkpoints/AlphaZero_exercise-checkpoint.ipynb new file mode 100644 index 0000000..41bb522 --- /dev/null +++ b/.ipynb_checkpoints/AlphaZero_exercise-checkpoint.ipynb @@ -0,0 +1,1400 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xST_HD4EI-uL", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# AlphaGo Zero\n", + "---\n", + "For this exercise we will implement AlphaZero which is a more generalized form of AlphaGO and train it on a game of \n", + "connect 4. We will first play against AlphaZero when it has not trained at all. We will then let it train for a few training \n", + "cycles and play it again to observe how it has improved. " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kqVyB4PiI-uM", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# 1. AlphaZero Configuration\n", + "---\n", + "\n", + "First lets make a class to keep all of our hyper parameters in one spot. This has been done for you." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "0ESdWe4NI-uN", + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "class AlphaZeroConfig(object):\n", + " \"\"\"\n", + " This holds the configuration parameters\n", + " \"\"\"\n", + " def __init__(self):\n", + " # Self-Play ==\n", + " self.max_moves = 42\n", + " self.num_simulations = 25\n", + "\n", + " # Root prior exploration noise.\n", + " self.root_dirichlet_alpha = 0.3 # for chess, 0.03 for Go and 0.15 for shogi.\n", + " self.root_exploration_fraction = 0.25\n", + "\n", + " # UCB formula\n", + " self.pb_c_base = 19652\n", + " self.pb_c_init = 1.25\n", + "\n", + " # Training ==\n", + " self.self_play_games = 30 # number of selfplay games per cycle\n", + " self.training_steps = int(40) # number of times we perform gradient descent\n", + " self.batch_size = 50 # size of training batch\n", + " self.cycles = 5 # number of policy iterations to do\n", + "\n", + " self.weight_decay = 1e-4\n", + " self.momentum = 0.9\n", + " self.learning_rate = 5e-4" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "SNHXdaS0I-uR" + }, + "source": [ + "# 2. Game Definition\n", + "---\n", + "Next lets set up the connect 4 game. This part has been done for you. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "K2aARZTnI-uS", + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import math\n", + "import numpy\n", + "from typing import List\n", + "import numpy as np\n", + "from torch.utils.data import TensorDataset\n", + "import torch\n", + "import torch.nn as nn\n", + "\n", + "if torch.cuda.is_available(): \n", + " device = \"cuda:0\" \n", + "else: \n", + " device = \"cpu\" \n", + "\n", + "class Node(object):\n", + "\n", + " def __init__(self, prior: float): # prior = how good the network thought it would be\n", + " self.visit_count = 0\n", + " self.to_play = -1\n", + " self.prior = prior\n", + " self.value_sum = 0\n", + " self.children = {}\n", + "\n", + " def expanded(self):\n", + " return len(self.children) > 0\n", + "\n", + " def value(self):\n", + " if self.visit_count == 0:\n", + " return 0\n", + " return self.value_sum / self.visit_count\n", + "\n", + "class Game(object):\n", + "\n", + " def __init__(self, history=None):\n", + " # Connect 4 specific ===\n", + " self._num_rows = 6\n", + " self._num_cols = 7\n", + "\n", + " self._winner = None\n", + "\n", + " # Masks to \"convolve\" over board and detect a winner\n", + " self._win_masks = []\n", + " # Horizontal wins\n", + " for i in range(4):\n", + " mask = np.zeros((4, 4), dtype=np.bool)\n", + " mask[i, :] = True\n", + " self._win_masks.append(mask)\n", + " # Vertical wins\n", + " for j in range(4):\n", + " mask = np.zeros((4, 4), dtype=np.bool)\n", + " mask[:, j] = True\n", + " self._win_masks.append(mask)\n", + " # Diagonal wins\n", + " down = np.zeros((4, 4), dtype=np.bool)\n", + " for i, j in zip(range(4), range(4)):\n", + " down[i, j] = True\n", + " self._win_masks.append(down)\n", + " up = np.zeros((4, 4), dtype=np.bool)\n", + " for i, j in zip(reversed(range(4)), range(4)):\n", + " up[i, j] = True\n", + " self._win_masks.append(up)\n", + "\n", + " # All games will have these ===\n", + " self.history = history or []\n", + " self.child_visits = []\n", + " self.num_actions = self._num_cols # 7 for connect 4, 512 for chess/shogi, and 722 for Go.\n", + "\n", + " def terminal(self):\n", + " \"\"\"\n", + " returns bool if the game is finished or not\n", + " \"\"\"\n", + " if self._winner is not None or len(self.history) == 42:\n", + " return True\n", + "\n", + " image = self.make_image(len(self.history))\n", + " # check for wins from the bottom of the board up. Wins are more likely to appear there.\n", + " for i in reversed(range(self._num_rows - 3)):\n", + " for j in range(self._num_cols - 3):\n", + " for mask in self._win_masks:\n", + " for player in range(2):\n", + " test = image[player, i:i + 4, j:j + 4][mask]\n", + " if np.alltrue(test == 1):\n", + " self._winner = player\n", + " return True\n", + "\n", + " return False\n", + "\n", + " def terminal_value(self, to_play):\n", + " \"\"\"\n", + " The result of the game from the player that's going to_play? 
If player 1\n", + " won then and to_play is 1 then return 1 if to_play is 2 then return -1?\n", + " \"\"\"\n", + "\n", + " # call just to ensure that state is set\n", + " self.terminal()\n", + "\n", + " if self._winner is None and len(self.history) == 42:\n", + " return 0\n", + " if to_play == self._winner:\n", + " return 1\n", + " else:\n", + " return -1\n", + "\n", + " def legal_actions(self):\n", + " image = self.make_image(len(self.history))\n", + " return [j for j in range(self._num_cols) if image[0, 0, j] == 0 and image[1, 0, j] == 0]\n", + "\n", + " def clone(self):\n", + " return Game(list(self.history))\n", + "\n", + " def apply(self, action: int):\n", + " self.history.append(action)\n", + "\n", + " def store_search_statistics(self, root: Node):\n", + " sum_visits = sum(child.visit_count for child in iter(root.children.values()))\n", + " self.child_visits.append([\n", + " root.children[a].visit_count / sum_visits if a in root.children else 0\n", + " for a in range(self.num_actions)\n", + " ])\n", + "\n", + " def make_image(self, state_index: int):\n", + " \"\"\"\n", + " returns what the game looked like at state_index i\n", + " \"\"\"\n", + " player_0 = np.zeros((self._num_rows, self._num_cols), dtype=numpy.float)\n", + " player_1 = np.zeros((self._num_rows, self._num_cols), dtype=numpy.float)\n", + " for move_i, move in enumerate(self.history[:state_index+1]):\n", + " for row in reversed(range(self._num_rows)):\n", + " if player_0[row, move] == 0 and player_1[row, move] == 0:\n", + " if move_i % 2 == 0:\n", + " player_0[row, move] = 1\n", + " if move_i % 2 == 1:\n", + " player_1[row, move] = 1\n", + " break\n", + "\n", + " to_play = (state_index + 1) % 2 * np.ones((self._num_rows, self._num_cols), dtype=numpy.float)\n", + "\n", + " return np.array([player_0, player_1, to_play], dtype=numpy.float)\n", + "\n", + " def make_target(self, state_index: int):\n", + " \"\"\"\n", + " returns the nural network target i.e. what the NN should be gessing given the image\n", + " \"\"\"\n", + " return (self.terminal_value(state_index % 2), # state_index % 2 will always be who's playing\n", + " self.child_visits[state_index])\n", + "\n", + " def to_play(self):\n", + " \"\"\"\n", + " Return the player that is about to play\n", + " \"\"\"\n", + " return len(self.history) % 2\n", + "\n", + " def __str__(self):\n", + " board_state = self.make_image(len(self.history))\n", + "\n", + " out = \"\"\n", + " for i in range(self._num_rows):\n", + " out += f\"{i}|\"\n", + " for j in range(self._num_cols):\n", + " if board_state[0, i, j] == 1:\n", + " out += \" ○ \"\n", + " elif board_state[1, i, j] == 1:\n", + " out += \" ● \"\n", + " else:\n", + " out += \" \"\n", + " out += \"|\\n\"\n", + "\n", + " out += \" \"\n", + " for j in range(self._num_cols):\n", + " out += f\" \\u0305{j} \"\n", + " return out" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "efYv8f0bI-uW", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# 3. One Network Two Heads\n", + "---\n", + "Two heads are smarter than one right? Lets implement the tow headed Neural Network. Recall that AlphaGo Zero uses Convolutional \n", + "ResNet architecture that supplies features to two convolutional networks. One that outputs a probability distribution over all possible moves $(p)$ and another that \n", + "outputs a single scalar value $(v)$ representing the value of the current state. 
\n", + "\n", + "The neural network is defined as: \n", + "$$f_\\theta (s) = (\\mathbf{p,v})$$ \n", + "\n", + "The game board is 6 spaces tall and 7 spaces wide. That means that we have 7 possible moves to make, hint...." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "pgV4dCsNI-uX", + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + "\n", + " \"\"\"\n", + " Convolution Block\n", + " \"\"\"\n", + " self.conv_block = nn.Sequential(\n", + " nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, stride=1, padding=1),\n", + " nn.BatchNorm2d(num_features=128),\n", + " nn.ReLU(inplace=True)\n", + " )\n", + "\n", + " \"\"\"\n", + " ResNet Block\n", + " \"\"\"\n", + " self.res_block = nn.Sequential(\n", + " nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),\n", + " nn.BatchNorm2d(num_features=128),\n", + " nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),\n", + " nn.BatchNorm2d(num_features=128)\n", + " )\n", + "\n", + " \"\"\"\n", + " Value Head\n", + " \"\"\"\n", + " self.value_convolv = nn.Sequential(\n", + " nn.Conv2d(in_channels=128, out_channels=3, kernel_size=1),\n", + " nn.BatchNorm2d(num_features=3),\n", + " nn.ReLU(inplace=True),\n", + " )\n", + "\n", + " # TODO: The value head outputs a what?\n", + " self.value_linear = nn.Sequential(\n", + " nn.Linear(in_features=126, out_features=32),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(in_features=32, out_features=1),\n", + " nn.Tanh()\n", + " )\n", + "\n", + "\n", + " \"\"\"\n", + " Policy Head\n", + " \"\"\"\n", + " self.policy_convolv = nn.Sequential(\n", + " nn.Conv2d(in_channels=128, out_channels=32, kernel_size=1),\n", + " nn.BatchNorm2d(num_features=32),\n", + " nn.ReLU(inplace=True)\n", + " )\n", + "\n", + " # TODO: the policy network outputs what?\n", + " # how many moves can we make? 
hint in section header\n", + " self.policy_linear = nn.Sequential(\n", + " nn.Linear(6*7*32, 7),\n", + " nn.LogSoftmax(dim=1)\n", + " )\n", + "\n", + " def inference(self, image):\n", + " \"\"\"\n", + " Use this for the evaluate() function in the next section, hint...\n", + " The game class has some nifty functions that feed this palatable images...\n", + " \"\"\"\n", + " image = torch.from_numpy(image)\n", + " image = image.to(torch.float)\n", + " image = image.unsqueeze(0)\n", + "\n", + " p, v = self.forward(image)\n", + "\n", + " return float(v.squeeze().detach()), p.squeeze().detach().cpu().numpy()\n", + "\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Perform forward.\"\"\"\n", + "\n", + " # you can mess with the number of residual blocks here if youd like\n", + " # the paper uses 20\n", + " num_blocks = 10\n", + "\n", + " x = x.to(device)\n", + " \"\"\"\n", + " ResNet\n", + " \"\"\"\n", + " x = self.conv_block(x)\n", + " for i in range(num_blocks):\n", + " residual = x\n", + " x = self.res_block(x)\n", + " x += residual\n", + " x = nn.functional.relu(x, inplace=True)\n", + "\n", + " \"\"\"\n", + " Value Head\n", + " \"\"\"\n", + " v = self.value_convolv(x)\n", + " v = v.view(-1, 3 * 6 * 7)\n", + " v = self.value_linear(v)\n", + "\n", + " \"\"\"\n", + " Policy Head\n", + " \"\"\"\n", + " p = self.policy_convolv(x)\n", + " p = p.view(-1, 6 * 7 * 32)\n", + " p = self.policy_linear(p)\n", + "\n", + " return p, v" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PbVMq1bbI-ua", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# 4. The training pipeline\n", + "---\n", + "AlphaZero training is split into two independent parts: Network training and self-play data generation.\n", + "These two parts only communicate by transferring the latest network checkpoint\n", + "from the training to the self-play, and the finished games from the self-play\n", + "to the training." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "GabAq2klI-ub", + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "def alphazero(config: AlphaZeroConfig, network: Net):\n", + "\n", + " # TODO: Here we'll have to do something! 
Remember the basic steps of the \n", + " # algorithm:\n", + " # 1: Create training data using the current neural network, this evaluates and improves our policy\n", + " # 2: Improve our policy by training our nural network\n", + " # 3: Repete!\n", + " #\n", + " # We have provided helper functions that may be of use to you\n", + " for i in range(config.cycles):\n", + " print(f\"self play {i} of {config.cycles}\")\n", + " network.eval()\n", + " games = run_selfplay(config, network) # TODO \n", + " print(f\"train network {i} of {config.cycles}\")\n", + " network.train()\n", + " train_network(config, games) # TODO\n", + "\n", + " return network\n", + "\n", + "# Each self-play job is independent of all others; it takes the latest network\n", + "# snapshot, produces a game and makes it available to the training job by\n", + "# writing it to a shared replay buffer.\n", + "def run_selfplay(config: AlphaZeroConfig, network: Net):\n", + " games = []\n", + " for i in range(config.self_play_games): \n", + " if i % 10 == 0:\n", + " print(f\"game {i} of {config.self_play_games}\")\n", + " game = play_game(config, network)\n", + " games.append(game)\n", + " return games\n", + "\n", + "\n", + "# Each game is produced by starting at the initial board position, then\n", + "# repeatedly executing a Monte Carlo Tree Search to generate moves until the end\n", + "# of the game is reached.\n", + "def play_game(config: AlphaZeroConfig, network: Net):\n", + " game = Game()\n", + " while not game.terminal() and len(game.history) < config.max_moves:\n", + " action, root = run_mcts(config, game, network)\n", + " game.apply(action)\n", + " game.store_search_statistics(root)\n", + " return game\n", + "\n", + "\n", + "# Core Monte Carlo Tree Search algorithm.\n", + "# To decide on an action, we run N simulations, always starting at the root of\n", + "# the search tree and traversing the tree according to the UCB formula until we\n", + "# reach a leaf node.\n", + "def run_mcts(config: AlphaZeroConfig, game: Game, network: Net):\n", + " root = Node(0)\n", + " # Populate child nodes AKA the states that the actions available at this\n", + " # states would take you too\n", + " evaluate(root, game, network)\n", + " add_exploration_noise(config, root)\n", + "\n", + " for i in range(config.num_simulations):\n", + " node = root\n", + " scratch_game = game.clone()\n", + " search_path = [node]\n", + "\n", + " while node.expanded():\n", + " # Here we take one step down our search tree towards a win or loss. 
Note\n", + " # that we are resetting the node variable here to be the state that our\n", + " # game picked given the action we took.\n", + " #\n", + " # On the first run all child nodes will not be expanded, so we'll only\n", + " # take one step before backpropatagating back up the tree.\n", + " action, node = select_child(config, node)\n", + " scratch_game.apply(action)\n", + " search_path.append(node)\n", + "\n", + " value = evaluate(node, scratch_game, network)\n", + " backpropagate(search_path, value, scratch_game.to_play())\n", + " return select_action(config, game, root), root\n", + "\n", + "\n", + "def select_action(config: AlphaZeroConfig, game: Game, root: Node):\n", + " # This is where we would do a softmax sample for the first 30 moves then\n", + " # turn down the temperature, our game is simple enough that we will just \n", + " # always pick the best computed move.\n", + " visit_counts = [(child.visit_count, action)\n", + " for action, child in iter(root.children.items())]\n", + " _, action = max(visit_counts)\n", + " return action\n", + "\n", + "\n", + "# Select the child with the highest UCB score.\n", + "def select_child(config: AlphaZeroConfig, node: Node):\n", + " \"\"\"\n", + " Return the child node, i.e. action to take, that UCB likes best\n", + " \"\"\"\n", + " _, action, child = max((ucb_score(config, node, child), action, child)\n", + " for action, child in iter(node.children.items()))\n", + " return action, child\n", + "\n", + "\n", + "# The score for a node is based on its value, plus an exploration bonus based on\n", + "# the prior.\n", + "def ucb_score(config: AlphaZeroConfig, parent: Node, child: Node):\n", + " pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /\n", + " config.pb_c_base) + config.pb_c_init\n", + " pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)\n", + "\n", + " prior_score = pb_c * child.prior\n", + " value_score = child.value()\n", + " return prior_score + value_score\n", + "\n", + "\n", + "# We use the neural network to obtain a value and policy prediction.\n", + "def evaluate(node: Node, game: Game, network: Net):\n", + " # TODO:\n", + " # Here we need to populate the input nodes's children list. Note that we\n", + " # only want to populate nodes for legal next actions. The Node class takes\n", + " # a prior on construction Node(prior_probability). Where will we get these\n", + " # priors? A network is passed in, that might be useful! Note: that the\n", + " # network returns a value and policy logits so something needs to be done\n", + " # to convert to proper probabilities (maybe softmax, ok definitely\n", + " # softmax). This function is also supposed to return the value, which you\n", + " # might also be able to get from the neural network.\n", + " #\n", + " # It would probably be helpful to know that the policy returned from the NN\n", + " # should be in the order of the columns of our connect 4 board. \n", + " # i.e. policy_logits[0] ∝ how much our network likes column 1. 
\n", + " \n", + " value, policy_logits = network.inference(game.make_image(len(game.history))) # TODO: take a look at back at the NN for a hint\n", + " # the game class may also have some useful functions for this\n", + "\n", + " # Expand the node.\n", + " node.to_play = game.to_play() \n", + " policy = {a: math.exp(policy_logits[a]) for a in game.legal_actions()}\n", + " policy_sum = sum(iter(policy.values()))\n", + " for action, p in iter(policy.items()): \n", + " # this is just softmax, notice the math.exp 3 lines up\n", + " node.children[action] = Node(p / policy_sum) \n", + " return value # TODO: what are we returning from this?\n", + "\n", + "\n", + "# At the end of a simulation, we propagate the evaluation all the way up the\n", + "# tree to the root.\n", + "def backpropagate(search_path: List[Node], value: float, to_play):\n", + " for node in search_path:\n", + " node.value_sum += value if node.to_play == to_play else (1 - value)\n", + " node.visit_count += 1\n", + "\n", + "\n", + "# At the start of each search, we add dirichlet noise to the prior of the root\n", + "# to encourage the search to explore new actions.\n", + "def add_exploration_noise(config: AlphaZeroConfig, node: Node):\n", + " \"\"\"\n", + " Modifies the priors stored in nodes children with dirichlet noise whatever\n", + " that is\n", + " \"\"\"\n", + " actions = node.children.keys()\n", + " noise = numpy.random.gamma(config.root_dirichlet_alpha, 1, len(actions))\n", + " frac = config.root_exploration_fraction\n", + " for a, n in zip(actions, noise):\n", + " node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac\n", + "\n", + "def create_data_loader(config: AlphaZeroConfig, games: List[Game]):\n", + " game_pos = [(g, i) for g in games for i in range(len(g.history))]\n", + "\n", + " image = np.array([g.make_image(i) for (g, i) in game_pos], dtype=np.float)\n", + " image = torch.from_numpy(image)\n", + " image = image.to(torch.float)\n", + "\n", + " policy_target = np.array([g.make_target(i)[1] for (g, i) in game_pos])\n", + " policy_target = torch.from_numpy(policy_target)\n", + " policy_target = policy_target.to(torch.float)\n", + "\n", + " value_target = np.array([g.make_target(i)[0] for (g, i) in game_pos])\n", + " value_target = torch.from_numpy(value_target)\n", + " value_target = value_target.to(torch.float)\n", + "\n", + " batch_data = TensorDataset(image, policy_target, value_target)\n", + " return torch.utils.data.DataLoader(dataset=batch_data,\n", + " batch_size=config.batch_size,\n", + " shuffle=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Z2uhK7q3YHOx" + }, + "source": [ + "# 5. To improve is to change, to be perfect is to change often\n", + "--- \n", + "Bonus points if you can tell me who said this without googleing it. =)\n", + "\n", + "Recall that the loss function is:\n", + "$$l = (z - \\mathbf{v})^2 - \\pi^T log(\\mathbf{p}) + c||\\theta||^2$$\n", + "\n", + "This is the mean-squared error for the target value and predicted value and binary cross entropy for the target policy and predicted policy. \n", + "\n", + "The last term is the L2 regularization. We can take care of this in the optimizer and it is known as weight_decay. Our config class has a field for this parameter. Hint hint...." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ao5Xx6onYEwd" + }, + "outputs": [], + "source": [ + "def train_network(config: AlphaZeroConfig, games: List[Game]):\n", + " \n", + " # TODO: add the L2 regularization here. Look at section heading for hint.\n", + " # Weight decay takes care of our L2 regularization so it doesn't need to be in the loss function\n", + " optimizer = torch.optim.SGD(\n", + " network.parameters(),\n", + " lr=config.learning_rate,\n", + " momentum=config.momentum,\n", + " weight_decay=config.weight_decay\n", + " )\n", + "\n", + "\n", + " for i in range(config.training_steps): #(config.training_steps):\n", + " data_loader = create_data_loader(config, games)\n", + " update_weights(optimizer, network, data_loader, i)\n", + "\n", + "\n", + "def update_weights(optimizer, network, data_loader, batch_num):\n", + " # Loop over each subset of data\n", + " for image, policy_target, value_target in data_loader:\n", + " # Zero out the optimizer's gradient buffer\n", + " optimizer.zero_grad()\n", + " \n", + " # TODO: get the policy and the value from the network\n", + " policy, value = network(image)\n", + " \n", + " # convert data to correct type\n", + " policy = policy.exp()\n", + " value = value.squeeze()\n", + "\n", + " # TODO: Compute the loss here\n", + " # for the value_target and policy_target add .to(device) to make the tensors happy because \n", + " # we like happy tensors. The value and the policy do not need it, they are happy tensors already.\n", + " # Also nn.functional has nifty functions for computing loss\n", + " value_loss = nn.functional.mse_loss(value, value_target.to(device))\n", + " policy_loss = nn.functional.binary_cross_entropy(policy, policy_target.to(device))\n", + "\n", + " loss = value_loss + policy_loss\n", + "\n", + " # Use backpropagation to compute the derivative of the loss with respect to the parameters\n", + " loss.backward()\n", + "\n", + " # Use the derivative information to update the parameters\n", + " optimizer.step()\n", + " print(\"Batch: %d Loss: %f\" % (batch_num, loss))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9U3hnFz5YU5r" + }, + "source": [ + "# 6. Challenge Accepted!\n", + "---\n", + "Making an algorithm that can play against itself is cool and all but do you think your massive human brain can be it? To do so we need to make it possible to play against the singularity. This part has been done for you." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Iq9bmoypYcxs" + }, + "outputs": [], + "source": [ + "def get_human_action(i: int, game: Game):\n", + " while True:\n", + " print(f\"Player {i} choose move please: \", end='')\n", + " human_action = input()\n", + " try:\n", + " if int(human_action) not in game.legal_actions():\n", + " print(\"illegal action\")\n", + " else:\n", + " return int(human_action)\n", + " except ValueError:\n", + " print(\"illegal action\")\n", + "\n", + "def interactive_game(config: AlphaZeroConfig, network: Net, player_1_human=False):\n", + " play_again = 'y'\n", + " while play_again == 'y':\n", + " game = Game()\n", + " print(game)\n", + " while not game.terminal():\n", + " if player_1_human:\n", + " for i in range(2):\n", + " game.apply(get_human_action(i, game))\n", + " print(game)\n", + " if game.terminal():\n", + " break\n", + " else:\n", + " game.apply(get_human_action(0, game))\n", + " print(game)\n", + "\n", + " ai_action, _ = run_mcts(config, game, network)\n", + " print(f\"ai chooses {ai_action}\")\n", + " game.apply(ai_action)\n", + " print(game)\n", + " win_string = {-1: \"lost\", 1: \"won\", 0: \"tied\"}\n", + " print(f\"player 0 {win_string[game.terminal_value(0)]}\")\n", + " print(f\"player 1 {win_string[game.terminal_value(1)]}\")\n", + " while True:\n", + " print(\"Play again? y or n?\", end='')\n", + " play_again = input()\n", + " try:\n", + " if play_again != 'y' and play_again != 'n':\n", + " print(\"I didn't understand\")\n", + " else:\n", + " break\n", + " except ValueError:\n", + " print(\"illegal action\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "_eAu4afHR-oN" + }, + "source": [ + "# 7. Let the Singularity Begin!\n", + "---\n", + "Now try playing the game with the untrained network. If you mess with the hyper parameters in the config class make sure to rerun this cell." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "TxQRHX2qI-ue", + "outputId": "daea46fd-4539-4ed8-ec0f-f07a2c7a8c51", + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Device: cuda:0\n" + ] + } + ], + "source": [ + "\n", + "print(\"Device: %s\" % device)\n", + "network = Net().to(device)\n", + "config = AlphaZeroConfig()\n", + "#interactive_game(config, network)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8LKc9cWURREV" + }, + "source": [ + "# 8. 
I need more data to beat you human\n", + "---\n", + "Now train the network for a few cycles and observe its change in behavior" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "colab_type": "code", + "id": "eIx0dFdDQpSN", + "outputId": "1cfd4a66-c88d-452a-c0a5-1d181ce80833" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "self play 0 of 5\n", + "game 0 of 30\n", + "game 10 of 30\n", + "game 20 of 30\n", + "train network 0 of 5\n", + "Batch: 0 Loss: 1.297457\n", + "Batch: 1 Loss: 1.070922\n", + "Batch: 2 Loss: 1.005536\n", + "Batch: 3 Loss: 0.739464\n", + "Batch: 4 Loss: 0.658090\n", + "Batch: 5 Loss: 0.601338\n", + "Batch: 6 Loss: 0.618326\n", + "Batch: 7 Loss: 0.574111\n", + "Batch: 8 Loss: 0.593461\n", + "Batch: 9 Loss: 0.523399\n", + "Batch: 10 Loss: 0.508888\n", + "Batch: 11 Loss: 0.526801\n", + "Batch: 12 Loss: 0.483138\n", + "Batch: 13 Loss: 0.520305\n", + "Batch: 14 Loss: 0.396360\n", + "Batch: 15 Loss: 0.429715\n", + "Batch: 16 Loss: 0.392540\n", + "Batch: 17 Loss: 0.409773\n", + "Batch: 18 Loss: 0.430012\n", + "Batch: 19 Loss: 0.475215\n", + "Batch: 20 Loss: 0.499511\n", + "Batch: 21 Loss: 0.478047\n", + "Batch: 22 Loss: 0.518550\n", + "Batch: 23 Loss: 0.351893\n", + "Batch: 24 Loss: 0.361895\n", + "Batch: 25 Loss: 0.481955\n", + "Batch: 26 Loss: 0.377299\n", + "Batch: 27 Loss: 0.337830\n", + "Batch: 28 Loss: 0.408330\n", + "Batch: 29 Loss: 0.605096\n", + "Batch: 30 Loss: 0.362263\n", + "Batch: 31 Loss: 0.361151\n", + "Batch: 32 Loss: 0.410002\n", + "Batch: 33 Loss: 0.304817\n", + "Batch: 34 Loss: 0.323277\n", + "Batch: 35 Loss: 0.495506\n", + "Batch: 36 Loss: 0.416801\n", + "Batch: 37 Loss: 0.395975\n", + "Batch: 38 Loss: 0.281300\n", + "Batch: 39 Loss: 0.317990\n", + "self play 1 of 5\n", + "game 0 of 30\n", + "game 10 of 30\n", + "game 20 of 30\n", + "train network 1 of 5\n", + "Batch: 0 Loss: 1.061677\n", + "Batch: 1 Loss: 0.573583\n", + "Batch: 2 Loss: 0.667789\n", + "Batch: 3 Loss: 0.447696\n", + "Batch: 4 Loss: 0.539596\n", + "Batch: 5 Loss: 0.469717\n", + "Batch: 6 Loss: 0.474496\n", + "Batch: 7 Loss: 0.495829\n", + "Batch: 8 Loss: 0.508060\n", + "Batch: 9 Loss: 0.553491\n", + "Batch: 10 Loss: 0.631098\n", + "Batch: 11 Loss: 0.371213\n", + "Batch: 12 Loss: 0.416165\n", + "Batch: 13 Loss: 0.339832\n", + "Batch: 14 Loss: 0.342511\n", + "Batch: 15 Loss: 0.442871\n", + "Batch: 16 Loss: 0.557983\n", + "Batch: 17 Loss: 0.490888\n", + "Batch: 18 Loss: 0.307327\n", + "Batch: 19 Loss: 0.370288\n", + "Batch: 20 Loss: 0.510485\n", + "Batch: 21 Loss: 0.630249\n", + "Batch: 22 Loss: 0.407610\n", + "Batch: 23 Loss: 0.306060\n", + "Batch: 24 Loss: 0.394436\n", + "Batch: 25 Loss: 0.400511\n", + "Batch: 26 Loss: 0.333179\n", + "Batch: 27 Loss: 0.328875\n", + "Batch: 28 Loss: 0.512536\n", + "Batch: 29 Loss: 0.296035\n", + "Batch: 30 Loss: 0.341999\n", + "Batch: 31 Loss: 0.392925\n", + "Batch: 32 Loss: 0.368058\n", + "Batch: 33 Loss: 0.470829\n", + "Batch: 34 Loss: 0.387786\n", + "Batch: 35 Loss: 0.458152\n", + "Batch: 36 Loss: 0.328859\n", + "Batch: 37 Loss: 0.425190\n", + "Batch: 38 Loss: 0.304928\n", + "Batch: 39 Loss: 0.258104\n", + "self play 2 of 5\n", + "game 0 of 30\n", + "game 10 of 30\n", + "game 20 of 30\n", + "train network 2 of 5\n", + "Batch: 0 Loss: 0.989657\n", + "Batch: 1 Loss: 0.716117\n", + "Batch: 2 Loss: 0.535005\n", + "Batch: 3 Loss: 0.671297\n", + "Batch: 4 Loss: 0.435962\n", + "Batch: 5 Loss: 0.379089\n", + "Batch: 6 Loss: 
0.640533\n", + "Batch: 7 Loss: 0.457454\n", + "Batch: 8 Loss: 0.448093\n", + "Batch: 9 Loss: 0.546048\n", + "Batch: 10 Loss: 0.503211\n", + "Batch: 11 Loss: 0.394836\n", + "Batch: 12 Loss: 0.358743\n", + "Batch: 13 Loss: 0.360110\n", + "Batch: 14 Loss: 0.405433\n", + "Batch: 15 Loss: 0.350829\n", + "Batch: 16 Loss: 0.323355\n", + "Batch: 17 Loss: 0.375008\n", + "Batch: 18 Loss: 0.327038\n", + "Batch: 19 Loss: 0.449789\n", + "Batch: 20 Loss: 0.333693\n", + "Batch: 21 Loss: 0.389275\n", + "Batch: 22 Loss: 0.461503\n", + "Batch: 23 Loss: 0.438522\n", + "Batch: 24 Loss: 0.381604\n", + "Batch: 25 Loss: 0.497237\n", + "Batch: 26 Loss: 0.396383\n", + "Batch: 27 Loss: 0.315628\n", + "Batch: 28 Loss: 0.437005\n", + "Batch: 29 Loss: 0.308866\n", + "Batch: 30 Loss: 0.363945\n", + "Batch: 31 Loss: 0.455962\n", + "Batch: 32 Loss: 0.379210\n", + "Batch: 33 Loss: 0.413945\n", + "Batch: 34 Loss: 0.314033\n", + "Batch: 35 Loss: 0.470453\n", + "Batch: 36 Loss: 0.421071\n", + "Batch: 37 Loss: 0.342935\n", + "Batch: 38 Loss: 0.390231\n", + "Batch: 39 Loss: 0.349400\n", + "self play 3 of 5\n", + "game 0 of 30\n", + "game 10 of 30\n", + "game 20 of 30\n", + "train network 3 of 5\n", + "Batch: 0 Loss: 0.718287\n", + "Batch: 1 Loss: 0.609580\n", + "Batch: 2 Loss: 0.427867\n", + "Batch: 3 Loss: 0.436488\n", + "Batch: 4 Loss: 0.403250\n", + "Batch: 5 Loss: 0.376567\n", + "Batch: 6 Loss: 0.482127\n", + "Batch: 7 Loss: 0.410601\n", + "Batch: 8 Loss: 0.407043\n", + "Batch: 9 Loss: 0.429694\n", + "Batch: 10 Loss: 0.389290\n", + "Batch: 11 Loss: 0.408060\n", + "Batch: 12 Loss: 0.426255\n", + "Batch: 13 Loss: 0.483737\n", + "Batch: 14 Loss: 0.439357\n", + "Batch: 15 Loss: 0.636192\n", + "Batch: 16 Loss: 0.480884\n", + "Batch: 17 Loss: 0.449515\n", + "Batch: 18 Loss: 0.383923\n", + "Batch: 19 Loss: 0.438431\n", + "Batch: 20 Loss: 0.396062\n", + "Batch: 21 Loss: 0.362935\n", + "Batch: 22 Loss: 0.465432\n", + "Batch: 23 Loss: 0.373293\n", + "Batch: 24 Loss: 0.450536\n", + "Batch: 25 Loss: 0.368027\n", + "Batch: 26 Loss: 0.347392\n", + "Batch: 27 Loss: 0.372168\n", + "Batch: 28 Loss: 0.410154\n", + "Batch: 29 Loss: 0.402116\n", + "Batch: 30 Loss: 0.410266\n", + "Batch: 31 Loss: 0.394792\n", + "Batch: 32 Loss: 0.368578\n", + "Batch: 33 Loss: 0.351414\n", + "Batch: 34 Loss: 0.397556\n", + "Batch: 35 Loss: 0.329613\n", + "Batch: 36 Loss: 0.330286\n", + "Batch: 37 Loss: 0.371017\n", + "Batch: 38 Loss: 0.342738\n", + "Batch: 39 Loss: 0.370894\n", + "self play 4 of 5\n", + "game 0 of 30\n", + "game 10 of 30\n", + "game 20 of 30\n", + "train network 4 of 5\n", + "Batch: 0 Loss: 1.389253\n", + "Batch: 1 Loss: 1.023016\n", + "Batch: 2 Loss: 0.778526\n", + "Batch: 3 Loss: 0.621589\n", + "Batch: 4 Loss: 0.748947\n", + "Batch: 5 Loss: 0.588811\n", + "Batch: 6 Loss: 0.616743\n", + "Batch: 7 Loss: 0.619884\n", + "Batch: 8 Loss: 0.599840\n", + "Batch: 9 Loss: 0.453021\n", + "Batch: 10 Loss: 0.570860\n", + "Batch: 11 Loss: 0.547868\n", + "Batch: 12 Loss: 1.003815\n", + "Batch: 13 Loss: 0.430471\n", + "Batch: 14 Loss: 0.526788\n", + "Batch: 15 Loss: 0.551975\n", + "Batch: 16 Loss: 0.606334\n", + "Batch: 17 Loss: 0.448370\n", + "Batch: 18 Loss: 0.570277\n", + "Batch: 19 Loss: 0.415983\n", + "Batch: 20 Loss: 0.476248\n", + "Batch: 21 Loss: 0.461724\n", + "Batch: 22 Loss: 0.640585\n", + "Batch: 23 Loss: 0.595707\n", + "Batch: 24 Loss: 0.654544\n", + "Batch: 25 Loss: 0.461624\n", + "Batch: 26 Loss: 0.533261\n", + "Batch: 27 Loss: 0.572199\n", + "Batch: 28 Loss: 0.531856\n", + "Batch: 29 Loss: 0.519722\n", + "Batch: 30 Loss: 0.554501\n", + 
"Batch: 31 Loss: 0.716420\n", + "Batch: 32 Loss: 0.490347\n", + "Batch: 33 Loss: 0.412190\n", + "Batch: 34 Loss: 0.524030\n", + "Batch: 35 Loss: 0.618971\n", + "Batch: 36 Loss: 0.461805\n", + "Batch: 37 Loss: 0.508472\n", + "Batch: 38 Loss: 0.597108\n", + "Batch: 39 Loss: 0.558429\n", + "0| |\n", + "1| |\n", + "2| |\n", + "3| |\n", + "4| |\n", + "5| |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 3\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0| |\n", + "1| |\n", + "2| |\n", + "3| |\n", + "4| |\n", + "5| ○ |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "ai chooses 6\n", + "0| |\n", + "1| |\n", + "2| |\n", + "3| |\n", + "4| |\n", + "5| ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 3\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0| |\n", + "1| |\n", + "2| |\n", + "3| |\n", + "4| ○ |\n", + "5| ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "ai chooses 3\n", + "0| |\n", + "1| |\n", + "2| |\n", + "3| ● |\n", + "4| ○ |\n", + "5| ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 3\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0| |\n", + "1| |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ |\n", + "5| ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "ai chooses 3\n", + "0| |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ |\n", + "5| ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0| |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ |\n", + "5| ○ ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "ai chooses 2\n", + "0| |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ |\n", + "5| ● ○ ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0| |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ ○ |\n", + "5| ● ○ ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "ai chooses 3\n", + "0| ● |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ ○ |\n", + "5| ● ○ ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0| ● |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● ○ |\n", + "4| ○ ○ |\n", + "5| ● ○ ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "ai chooses 5\n", + "0| ● |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● ○ |\n", + "4| ○ ○ |\n", + "5| ● ○ ○ ● ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0| ● |\n", + "1| ● |\n", + "2| ○ ○ |\n", + "3| ● ○ |\n", + "4| ○ ○ |\n", + "5| ● ○ ○ ● ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "ai chooses 6\n", + "0| ● |\n", + "1| ● |\n", + "2| ○ ○ |\n", + "3| ● ○ |\n", + "4| ○ ○ ● |\n", + "5| ● ○ ○ ● ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "player 0 won\n", + "player 1 lost\n", + "Play again? y or n?" 
+ ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "I didn't understand\n", + "Play again? y or n?" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " n\n" + ] + } + ], + "source": [ + "# this will train the network on self play games\n", + "alphazero(config, network)\n", + "interactive_game(config, network)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kiQd0snNdeBP" + }, + "source": [ + "# Question\n", + "---\n", + "Each time we start training the NN the error increases from the previous gradient descent steps. Is this an artifact? or why does this make sense?\n", + "\n", + "This does make sense. As the network explore more avenues of decision making, greater losses are incured with each path and thus there is an appeareance of stagnant learning relative to our typical experiences with classification. \n", + "\n", + "# 9. Round Two!\n", + "---\n", + "Ok our algorithm just turned bits into strait gains. Lets try to play it again." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "veD94lDcd54y" + }, + "outputs": [], + "source": [ + "interactive_game(config, network)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "cPC8Rs-VSOfg" + }, + "source": [ + "\n", + "# 10. Final Thoughts\n", + "---\n", + "Ok so its not a connect 4 champ in a few cycles but it does get better. Time permitting, run another few cycles to see how good you can get it. " + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [ + "kqVyB4PiI-uM", + "SNHXdaS0I-uR", + "efYv8f0bI-uW", + "Z2uhK7q3YHOx", + "9U3hnFz5YU5r", + "_eAu4afHR-oN" + ], + "name": "AlphaZero_exercise.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/AlphaZero_exercise.ipynb b/AlphaZero_exercise.ipynb index b61afd6..41bb522 100644 --- a/AlphaZero_exercise.ipynb +++ b/AlphaZero_exercise.ipynb @@ -1,57 +1,13 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "kernelspec": { - "name": "python3", - "language": "python", - "display_name": "Python 3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "source": [], - "metadata": { - "collapsed": false - } - } - }, - "colab": { - "name": "AlphaZero_exercise.ipynb", - "provenance": [], - "collapsed_sections": [ - "kqVyB4PiI-uM", - "SNHXdaS0I-uR", - "efYv8f0bI-uW", - "Z2uhK7q3YHOx", - "9U3hnFz5YU5r", - "_eAu4afHR-oN" - ] - }, - "accelerator": "GPU" - }, "cells": [ { "cell_type": "markdown", "metadata": { - "collapsed": true, + "colab_type": "text", + "id": "xST_HD4EI-uL", "pycharm": { "name": "#%% md\n" - }, - "id": 
"xST_HD4EI-uL", - "colab_type": "text" + } }, "source": [ "# AlphaGo Zero\n", @@ -64,12 +20,11 @@ { "cell_type": "markdown", "metadata": { - "collapsed": false, + "colab_type": "text", + "id": "kqVyB4PiI-uM", "pycharm": { "name": "#%% md\n" - }, - "id": "kqVyB4PiI-uM", - "colab_type": "text" + } }, "source": [ "# 1. AlphaZero Configuration\n", @@ -80,16 +35,17 @@ }, { "cell_type": "code", + "execution_count": 1, "metadata": { + "colab": {}, + "colab_type": "code", + "id": "0ESdWe4NI-uN", "pycharm": { "name": "#%%\n" - }, - "id": "0ESdWe4NI-uN", - "colab_type": "code", - "colab": {} + } }, + "outputs": [], "source": [ - " \n", "class AlphaZeroConfig(object):\n", " \"\"\"\n", " This holds the configuration parameters\n", @@ -116,16 +72,13 @@ " self.weight_decay = 1e-4\n", " self.momentum = 0.9\n", " self.learning_rate = 5e-4" - ], - "execution_count": 0, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "collapsed": false, - "id": "SNHXdaS0I-uR", - "colab_type": "text" + "colab_type": "text", + "id": "SNHXdaS0I-uR" }, "source": [ "# 2. Game Definition\n", @@ -135,14 +88,16 @@ }, { "cell_type": "code", + "execution_count": 2, "metadata": { + "colab": {}, + "colab_type": "code", + "id": "K2aARZTnI-uS", "pycharm": { "name": "#%%\n" - }, - "id": "K2aARZTnI-uS", - "colab_type": "code", - "colab": {} + } }, + "outputs": [], "source": [ "import math\n", "import numpy\n", @@ -314,19 +269,16 @@ " for j in range(self._num_cols):\n", " out += f\" \\u0305{j} \"\n", " return out" - ], - "execution_count": 0, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "collapsed": false, + "colab_type": "text", + "id": "efYv8f0bI-uW", "pycharm": { "name": "#%% md\n" - }, - "id": "efYv8f0bI-uW", - "colab_type": "text" + } }, "source": [ "# 3. One Network Two Heads\n", @@ -343,14 +295,16 @@ }, { "cell_type": "code", + "execution_count": 3, "metadata": { + "colab": {}, + "colab_type": "code", + "id": "pgV4dCsNI-uX", "pycharm": { "name": "#%%\n" - }, - "id": "pgV4dCsNI-uX", - "colab_type": "code", - "colab": {} + } }, + "outputs": [], "source": [ "class Net(nn.Module):\n", " def __init__(self):\n", @@ -388,7 +342,7 @@ " self.value_linear = nn.Sequential(\n", " nn.Linear(in_features=126, out_features=32),\n", " nn.ReLU(inplace=True),\n", - " nn.Linear(in_features=32, out_features=???),\n", + " nn.Linear(in_features=32, out_features=1),\n", " nn.Tanh()\n", " )\n", "\n", @@ -405,7 +359,7 @@ " # TODO: the policy network outputs what?\n", " # how many moves can we make? hint in section header\n", " self.policy_linear = nn.Sequential(\n", - " nn.Linear(6*7*32, ???),\n", + " nn.Linear(6*7*32, 7),\n", " nn.LogSoftmax(dim=1)\n", " )\n", "\n", @@ -455,21 +409,17 @@ " p = p.view(-1, 6 * 7 * 32)\n", " p = self.policy_linear(p)\n", "\n", - " return p, v\n", - " " - ], - "execution_count": 0, - "outputs": [] + " return p, v" + ] }, { "cell_type": "markdown", "metadata": { - "collapsed": false, + "colab_type": "text", + "id": "PbVMq1bbI-ua", "pycharm": { "name": "#%% md\n" - }, - "id": "PbVMq1bbI-ua", - "colab_type": "text" + } }, "source": [ "# 4. 
The training pipeline\n", @@ -482,14 +432,16 @@ }, { "cell_type": "code", + "execution_count": 4, "metadata": { + "colab": {}, + "colab_type": "code", + "id": "GabAq2klI-ub", "pycharm": { "name": "#%%\n" - }, - "id": "GabAq2klI-ub", - "colab_type": "code", - "colab": {} + } }, + "outputs": [], "source": [ "def alphazero(config: AlphaZeroConfig, network: Net):\n", "\n", @@ -503,10 +455,10 @@ " for i in range(config.cycles):\n", " print(f\"self play {i} of {config.cycles}\")\n", " network.eval()\n", - " games = ??? # TODO \n", + " games = run_selfplay(config, network) # TODO \n", " print(f\"train network {i} of {config.cycles}\")\n", " network.train()\n", - " train_network(???, ???) # TODO\n", + " train_network(config, games) # TODO\n", "\n", " return network\n", "\n", @@ -615,7 +567,7 @@ " # should be in the order of the columns of our connect 4 board. \n", " # i.e. policy_logits[0] ∝ how much our network likes column 1. \n", " \n", - " value, policy_logits = ??? # TODO: take a look at back at the NN for a hint\n", + " value, policy_logits = network.inference(game.make_image(len(game.history))) # TODO: take a look at back at the NN for a hint\n", " # the game class may also have some useful functions for this\n", "\n", " # Expand the node.\n", @@ -625,7 +577,7 @@ " for action, p in iter(policy.items()): \n", " # this is just softmax, notice the math.exp 3 lines up\n", " node.children[action] = Node(p / policy_sum) \n", - " return ??? # TODO: what are we returning from this?\n", + " return value # TODO: what are we returning from this?\n", "\n", "\n", "# At the end of a simulation, we propagate the evaluation all the way up the\n", @@ -668,15 +620,13 @@ " return torch.utils.data.DataLoader(dataset=batch_data,\n", " batch_size=config.batch_size,\n", " shuffle=True)" - ], - "execution_count": 0, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "Z2uhK7q3YHOx", - "colab_type": "text" + "colab_type": "text", + "id": "Z2uhK7q3YHOx" }, "source": [ "# 5. To improve is to change, to be perfect is to change often\n", @@ -693,11 +643,13 @@ }, { "cell_type": "code", + "execution_count": 5, "metadata": { - "id": "Ao5Xx6onYEwd", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "Ao5Xx6onYEwd" }, + "outputs": [], "source": [ "def train_network(config: AlphaZeroConfig, games: List[Game]):\n", " \n", @@ -707,7 +659,7 @@ " network.parameters(),\n", " lr=config.learning_rate,\n", " momentum=config.momentum,\n", - " ???\n", + " weight_decay=config.weight_decay\n", " )\n", "\n", "\n", @@ -723,7 +675,7 @@ " optimizer.zero_grad()\n", " \n", " # TODO: get the policy and the value from the network\n", - " policy, value = ???\n", + " policy, value = network(image)\n", " \n", " # convert data to correct type\n", " policy = policy.exp()\n", @@ -733,8 +685,8 @@ " # for the value_target and policy_target add .to(device) to make the tensors happy because \n", " # we like happy tensors. 
The value and the policy do not need it, they are happy tensors already.\n", " # Also nn.functional has nifty functions for computing loss\n", - " value_loss = ???\n", - " policy_loss = ???\n", + " value_loss = nn.functional.mse_loss(value, value_target.to(device))\n", + " policy_loss = nn.functional.binary_cross_entropy(policy, policy_target.to(device))\n", "\n", " loss = value_loss + policy_loss\n", "\n", @@ -743,16 +695,14 @@ "\n", " # Use the derivative information to update the parameters\n", " optimizer.step()\n", - " print(\"Batch: %d Loss: %f\" % (batch_num, loss))\n" - ], - "execution_count": 0, - "outputs": [] + " print(\"Batch: %d Loss: %f\" % (batch_num, loss))" + ] }, { "cell_type": "markdown", "metadata": { - "id": "9U3hnFz5YU5r", - "colab_type": "text" + "colab_type": "text", + "id": "9U3hnFz5YU5r" }, "source": [ "# 6. Challenge Accepted!\n", @@ -762,11 +712,13 @@ }, { "cell_type": "code", + "execution_count": 6, "metadata": { - "id": "Iq9bmoypYcxs", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "Iq9bmoypYcxs" }, + "outputs": [], "source": [ "def get_human_action(i: int, game: Game):\n", " while True:\n", @@ -813,15 +765,13 @@ " break\n", " except ValueError:\n", " print(\"illegal action\")" - ], - "execution_count": 0, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "_eAu4afHR-oN", - "colab_type": "text" + "colab_type": "text", + "id": "_eAu4afHR-oN" }, "source": [ "# 7. Let the Singularity Begin!\n", @@ -831,41 +781,41 @@ }, { "cell_type": "code", + "execution_count": 7, "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "TxQRHX2qI-ue", - "colab_type": "code", - "outputId": "daea46fd-4539-4ed8-ec0f-f07a2c7a8c51", "colab": { "base_uri": "https://localhost:8080/", "height": 34 + }, + "colab_type": "code", + "id": "TxQRHX2qI-ue", + "outputId": "daea46fd-4539-4ed8-ec0f-f07a2c7a8c51", + "pycharm": { + "name": "#%%\n" } }, - "source": [ - "\n", - "print(\"Device: %s\" % device)\n", - "network = Net().to(device)\n", - "config = AlphaZeroConfig()\n", - "interactive_game(config, network)\n" - ], - "execution_count": 12, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "Device: cuda:0\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "\n", + "print(\"Device: %s\" % device)\n", + "network = Net().to(device)\n", + "config = AlphaZeroConfig()\n", + "#interactive_game(config, network)\n" ] }, { "cell_type": "markdown", "metadata": { - "id": "8LKc9cWURREV", - "colab_type": "text" + "colab_type": "text", + "id": "8LKc9cWURREV" }, "source": [ "# 8. 
I need more data to beat you human\n", @@ -875,23 +825,19 @@ }, { "cell_type": "code", + "execution_count": 8, "metadata": { - "id": "eIx0dFdDQpSN", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, + "colab_type": "code", + "id": "eIx0dFdDQpSN", "outputId": "1cfd4a66-c88d-452a-c0a5-1d181ce80833" }, - "source": [ - "# this will train the network on self play games\n", - "alphazero(config, network)\n", - "interactive_game(config, network)" - ], - "execution_count": 13, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "self play 0 of 5\n", @@ -899,226 +845,226 @@ "game 10 of 30\n", "game 20 of 30\n", "train network 0 of 5\n", - "Batch: 0 Loss: 1.261694\n", - "Batch: 1 Loss: 1.116446\n", - "Batch: 2 Loss: 1.341696\n", - "Batch: 3 Loss: 0.909902\n", - "Batch: 4 Loss: 1.124683\n", - "Batch: 5 Loss: 1.176385\n", - "Batch: 6 Loss: 0.939910\n", - "Batch: 7 Loss: 1.149868\n", - "Batch: 8 Loss: 0.891490\n", - "Batch: 9 Loss: 1.200505\n", - "Batch: 10 Loss: 1.013344\n", - "Batch: 11 Loss: 0.928823\n", - "Batch: 12 Loss: 0.666806\n", - "Batch: 13 Loss: 0.627564\n", - "Batch: 14 Loss: 0.985499\n", - "Batch: 15 Loss: 0.615684\n", - "Batch: 16 Loss: 0.554353\n", - "Batch: 17 Loss: 0.730892\n", - "Batch: 18 Loss: 0.651778\n", - "Batch: 19 Loss: 0.558712\n", - "Batch: 20 Loss: 0.582786\n", - "Batch: 21 Loss: 0.609675\n", - "Batch: 22 Loss: 0.466753\n", - "Batch: 23 Loss: 0.523876\n", - "Batch: 24 Loss: 0.704813\n", - "Batch: 25 Loss: 0.547899\n", - "Batch: 26 Loss: 0.490743\n", - "Batch: 27 Loss: 0.487145\n", - "Batch: 28 Loss: 0.389103\n", - "Batch: 29 Loss: 0.394802\n", - "Batch: 30 Loss: 0.924192\n", - "Batch: 31 Loss: 0.693056\n", - "Batch: 32 Loss: 0.501567\n", - "Batch: 33 Loss: 0.470897\n", - "Batch: 34 Loss: 0.515070\n", - "Batch: 35 Loss: 0.418546\n", - "Batch: 36 Loss: 1.033680\n", - "Batch: 37 Loss: 0.425282\n", - "Batch: 38 Loss: 0.530426\n", - "Batch: 39 Loss: 0.429059\n", + "Batch: 0 Loss: 1.297457\n", + "Batch: 1 Loss: 1.070922\n", + "Batch: 2 Loss: 1.005536\n", + "Batch: 3 Loss: 0.739464\n", + "Batch: 4 Loss: 0.658090\n", + "Batch: 5 Loss: 0.601338\n", + "Batch: 6 Loss: 0.618326\n", + "Batch: 7 Loss: 0.574111\n", + "Batch: 8 Loss: 0.593461\n", + "Batch: 9 Loss: 0.523399\n", + "Batch: 10 Loss: 0.508888\n", + "Batch: 11 Loss: 0.526801\n", + "Batch: 12 Loss: 0.483138\n", + "Batch: 13 Loss: 0.520305\n", + "Batch: 14 Loss: 0.396360\n", + "Batch: 15 Loss: 0.429715\n", + "Batch: 16 Loss: 0.392540\n", + "Batch: 17 Loss: 0.409773\n", + "Batch: 18 Loss: 0.430012\n", + "Batch: 19 Loss: 0.475215\n", + "Batch: 20 Loss: 0.499511\n", + "Batch: 21 Loss: 0.478047\n", + "Batch: 22 Loss: 0.518550\n", + "Batch: 23 Loss: 0.351893\n", + "Batch: 24 Loss: 0.361895\n", + "Batch: 25 Loss: 0.481955\n", + "Batch: 26 Loss: 0.377299\n", + "Batch: 27 Loss: 0.337830\n", + "Batch: 28 Loss: 0.408330\n", + "Batch: 29 Loss: 0.605096\n", + "Batch: 30 Loss: 0.362263\n", + "Batch: 31 Loss: 0.361151\n", + "Batch: 32 Loss: 0.410002\n", + "Batch: 33 Loss: 0.304817\n", + "Batch: 34 Loss: 0.323277\n", + "Batch: 35 Loss: 0.495506\n", + "Batch: 36 Loss: 0.416801\n", + "Batch: 37 Loss: 0.395975\n", + "Batch: 38 Loss: 0.281300\n", + "Batch: 39 Loss: 0.317990\n", "self play 1 of 5\n", "game 0 of 30\n", "game 10 of 30\n", "game 20 of 30\n", "train network 1 of 5\n", - "Batch: 0 Loss: 1.173493\n", - "Batch: 1 Loss: 0.866215\n", - "Batch: 2 Loss: 0.892076\n", - "Batch: 3 Loss: 0.860861\n", - "Batch: 4 Loss: 0.776542\n", - "Batch: 5 Loss: 0.606288\n", - "Batch: 6 Loss: 
0.777582\n", - "Batch: 7 Loss: 0.679524\n", - "Batch: 8 Loss: 0.555076\n", - "Batch: 9 Loss: 0.627292\n", - "Batch: 10 Loss: 0.605402\n", - "Batch: 11 Loss: 0.639736\n", - "Batch: 12 Loss: 0.661965\n", - "Batch: 13 Loss: 0.627451\n", - "Batch: 14 Loss: 0.463242\n", - "Batch: 15 Loss: 0.561799\n", - "Batch: 16 Loss: 0.507968\n", - "Batch: 17 Loss: 0.438971\n", - "Batch: 18 Loss: 0.496615\n", - "Batch: 19 Loss: 0.448660\n", - "Batch: 20 Loss: 0.476654\n", - "Batch: 21 Loss: 0.435443\n", - "Batch: 22 Loss: 0.507137\n", - "Batch: 23 Loss: 0.492124\n", - "Batch: 24 Loss: 0.399255\n", - "Batch: 25 Loss: 0.413275\n", - "Batch: 26 Loss: 0.437310\n", - "Batch: 27 Loss: 0.479034\n", - "Batch: 28 Loss: 0.514366\n", - "Batch: 29 Loss: 0.411214\n", - "Batch: 30 Loss: 0.390972\n", - "Batch: 31 Loss: 0.388191\n", - "Batch: 32 Loss: 0.405707\n", - "Batch: 33 Loss: 0.389139\n", - "Batch: 34 Loss: 0.444981\n", - "Batch: 35 Loss: 0.536180\n", - "Batch: 36 Loss: 0.514434\n", - "Batch: 37 Loss: 0.528833\n", - "Batch: 38 Loss: 0.378289\n", - "Batch: 39 Loss: 0.402099\n", + "Batch: 0 Loss: 1.061677\n", + "Batch: 1 Loss: 0.573583\n", + "Batch: 2 Loss: 0.667789\n", + "Batch: 3 Loss: 0.447696\n", + "Batch: 4 Loss: 0.539596\n", + "Batch: 5 Loss: 0.469717\n", + "Batch: 6 Loss: 0.474496\n", + "Batch: 7 Loss: 0.495829\n", + "Batch: 8 Loss: 0.508060\n", + "Batch: 9 Loss: 0.553491\n", + "Batch: 10 Loss: 0.631098\n", + "Batch: 11 Loss: 0.371213\n", + "Batch: 12 Loss: 0.416165\n", + "Batch: 13 Loss: 0.339832\n", + "Batch: 14 Loss: 0.342511\n", + "Batch: 15 Loss: 0.442871\n", + "Batch: 16 Loss: 0.557983\n", + "Batch: 17 Loss: 0.490888\n", + "Batch: 18 Loss: 0.307327\n", + "Batch: 19 Loss: 0.370288\n", + "Batch: 20 Loss: 0.510485\n", + "Batch: 21 Loss: 0.630249\n", + "Batch: 22 Loss: 0.407610\n", + "Batch: 23 Loss: 0.306060\n", + "Batch: 24 Loss: 0.394436\n", + "Batch: 25 Loss: 0.400511\n", + "Batch: 26 Loss: 0.333179\n", + "Batch: 27 Loss: 0.328875\n", + "Batch: 28 Loss: 0.512536\n", + "Batch: 29 Loss: 0.296035\n", + "Batch: 30 Loss: 0.341999\n", + "Batch: 31 Loss: 0.392925\n", + "Batch: 32 Loss: 0.368058\n", + "Batch: 33 Loss: 0.470829\n", + "Batch: 34 Loss: 0.387786\n", + "Batch: 35 Loss: 0.458152\n", + "Batch: 36 Loss: 0.328859\n", + "Batch: 37 Loss: 0.425190\n", + "Batch: 38 Loss: 0.304928\n", + "Batch: 39 Loss: 0.258104\n", "self play 2 of 5\n", "game 0 of 30\n", "game 10 of 30\n", "game 20 of 30\n", "train network 2 of 5\n", - "Batch: 0 Loss: 0.947907\n", - "Batch: 1 Loss: 0.754295\n", - "Batch: 2 Loss: 0.661248\n", - "Batch: 3 Loss: 0.693746\n", - "Batch: 4 Loss: 0.644416\n", - "Batch: 5 Loss: 0.544810\n", - "Batch: 6 Loss: 0.499344\n", - "Batch: 7 Loss: 0.610480\n", - "Batch: 8 Loss: 0.460812\n", - "Batch: 9 Loss: 0.523646\n", - "Batch: 10 Loss: 0.466662\n", - "Batch: 11 Loss: 0.432250\n", - "Batch: 12 Loss: 0.470015\n", - "Batch: 13 Loss: 0.488527\n", - "Batch: 14 Loss: 0.465875\n", - "Batch: 15 Loss: 0.441104\n", - "Batch: 16 Loss: 0.537610\n", - "Batch: 17 Loss: 0.432288\n", - "Batch: 18 Loss: 0.404103\n", - "Batch: 19 Loss: 0.401291\n", - "Batch: 20 Loss: 0.414535\n", - "Batch: 21 Loss: 0.427214\n", - "Batch: 22 Loss: 0.401292\n", - "Batch: 23 Loss: 0.466994\n", - "Batch: 24 Loss: 0.526220\n", - "Batch: 25 Loss: 0.485886\n", - "Batch: 26 Loss: 0.480067\n", - "Batch: 27 Loss: 0.386039\n", - "Batch: 28 Loss: 0.388183\n", - "Batch: 29 Loss: 0.444859\n", - "Batch: 30 Loss: 0.391749\n", - "Batch: 31 Loss: 0.418970\n", - "Batch: 32 Loss: 0.423082\n", - "Batch: 33 Loss: 0.375627\n", - "Batch: 34 Loss: 0.400309\n", - 
"Batch: 35 Loss: 0.366517\n", - "Batch: 36 Loss: 0.439064\n", - "Batch: 37 Loss: 0.455128\n", - "Batch: 38 Loss: 0.460750\n", - "Batch: 39 Loss: 0.535893\n", + "Batch: 0 Loss: 0.989657\n", + "Batch: 1 Loss: 0.716117\n", + "Batch: 2 Loss: 0.535005\n", + "Batch: 3 Loss: 0.671297\n", + "Batch: 4 Loss: 0.435962\n", + "Batch: 5 Loss: 0.379089\n", + "Batch: 6 Loss: 0.640533\n", + "Batch: 7 Loss: 0.457454\n", + "Batch: 8 Loss: 0.448093\n", + "Batch: 9 Loss: 0.546048\n", + "Batch: 10 Loss: 0.503211\n", + "Batch: 11 Loss: 0.394836\n", + "Batch: 12 Loss: 0.358743\n", + "Batch: 13 Loss: 0.360110\n", + "Batch: 14 Loss: 0.405433\n", + "Batch: 15 Loss: 0.350829\n", + "Batch: 16 Loss: 0.323355\n", + "Batch: 17 Loss: 0.375008\n", + "Batch: 18 Loss: 0.327038\n", + "Batch: 19 Loss: 0.449789\n", + "Batch: 20 Loss: 0.333693\n", + "Batch: 21 Loss: 0.389275\n", + "Batch: 22 Loss: 0.461503\n", + "Batch: 23 Loss: 0.438522\n", + "Batch: 24 Loss: 0.381604\n", + "Batch: 25 Loss: 0.497237\n", + "Batch: 26 Loss: 0.396383\n", + "Batch: 27 Loss: 0.315628\n", + "Batch: 28 Loss: 0.437005\n", + "Batch: 29 Loss: 0.308866\n", + "Batch: 30 Loss: 0.363945\n", + "Batch: 31 Loss: 0.455962\n", + "Batch: 32 Loss: 0.379210\n", + "Batch: 33 Loss: 0.413945\n", + "Batch: 34 Loss: 0.314033\n", + "Batch: 35 Loss: 0.470453\n", + "Batch: 36 Loss: 0.421071\n", + "Batch: 37 Loss: 0.342935\n", + "Batch: 38 Loss: 0.390231\n", + "Batch: 39 Loss: 0.349400\n", "self play 3 of 5\n", "game 0 of 30\n", "game 10 of 30\n", "game 20 of 30\n", "train network 3 of 5\n", - "Batch: 0 Loss: 1.034783\n", - "Batch: 1 Loss: 0.856359\n", - "Batch: 2 Loss: 0.596523\n", - "Batch: 3 Loss: 0.529779\n", - "Batch: 4 Loss: 0.439595\n", - "Batch: 5 Loss: 0.441252\n", - "Batch: 6 Loss: 0.464327\n", - "Batch: 7 Loss: 0.433184\n", - "Batch: 8 Loss: 0.388274\n", - "Batch: 9 Loss: 0.414981\n", - "Batch: 10 Loss: 0.406084\n", - "Batch: 11 Loss: 0.473081\n", - "Batch: 12 Loss: 0.440911\n", - "Batch: 13 Loss: 0.433353\n", - "Batch: 14 Loss: 0.389613\n", - "Batch: 15 Loss: 0.392770\n", - "Batch: 16 Loss: 0.416716\n", - "Batch: 17 Loss: 0.440191\n", - "Batch: 18 Loss: 0.495331\n", - "Batch: 19 Loss: 0.386110\n", - "Batch: 20 Loss: 0.404552\n", - "Batch: 21 Loss: 0.440497\n", - "Batch: 22 Loss: 0.391942\n", - "Batch: 23 Loss: 0.410947\n", - "Batch: 24 Loss: 0.379390\n", - "Batch: 25 Loss: 0.421499\n", - "Batch: 26 Loss: 0.386556\n", - "Batch: 27 Loss: 0.404068\n", - "Batch: 28 Loss: 0.397863\n", - "Batch: 29 Loss: 0.351861\n", - "Batch: 30 Loss: 0.358579\n", - "Batch: 31 Loss: 0.352469\n", - "Batch: 32 Loss: 0.347363\n", - "Batch: 33 Loss: 0.401403\n", - "Batch: 34 Loss: 0.405339\n", - "Batch: 35 Loss: 0.391409\n", - "Batch: 36 Loss: 0.417322\n", - "Batch: 37 Loss: 0.375573\n", - "Batch: 38 Loss: 0.423612\n", - "Batch: 39 Loss: 0.371811\n", + "Batch: 0 Loss: 0.718287\n", + "Batch: 1 Loss: 0.609580\n", + "Batch: 2 Loss: 0.427867\n", + "Batch: 3 Loss: 0.436488\n", + "Batch: 4 Loss: 0.403250\n", + "Batch: 5 Loss: 0.376567\n", + "Batch: 6 Loss: 0.482127\n", + "Batch: 7 Loss: 0.410601\n", + "Batch: 8 Loss: 0.407043\n", + "Batch: 9 Loss: 0.429694\n", + "Batch: 10 Loss: 0.389290\n", + "Batch: 11 Loss: 0.408060\n", + "Batch: 12 Loss: 0.426255\n", + "Batch: 13 Loss: 0.483737\n", + "Batch: 14 Loss: 0.439357\n", + "Batch: 15 Loss: 0.636192\n", + "Batch: 16 Loss: 0.480884\n", + "Batch: 17 Loss: 0.449515\n", + "Batch: 18 Loss: 0.383923\n", + "Batch: 19 Loss: 0.438431\n", + "Batch: 20 Loss: 0.396062\n", + "Batch: 21 Loss: 0.362935\n", + "Batch: 22 Loss: 0.465432\n", + "Batch: 23 Loss: 
0.373293\n", + "Batch: 24 Loss: 0.450536\n", + "Batch: 25 Loss: 0.368027\n", + "Batch: 26 Loss: 0.347392\n", + "Batch: 27 Loss: 0.372168\n", + "Batch: 28 Loss: 0.410154\n", + "Batch: 29 Loss: 0.402116\n", + "Batch: 30 Loss: 0.410266\n", + "Batch: 31 Loss: 0.394792\n", + "Batch: 32 Loss: 0.368578\n", + "Batch: 33 Loss: 0.351414\n", + "Batch: 34 Loss: 0.397556\n", + "Batch: 35 Loss: 0.329613\n", + "Batch: 36 Loss: 0.330286\n", + "Batch: 37 Loss: 0.371017\n", + "Batch: 38 Loss: 0.342738\n", + "Batch: 39 Loss: 0.370894\n", "self play 4 of 5\n", "game 0 of 30\n", "game 10 of 30\n", "game 20 of 30\n", "train network 4 of 5\n", - "Batch: 0 Loss: 0.818256\n", - "Batch: 1 Loss: 1.123486\n", - "Batch: 2 Loss: 0.876669\n", - "Batch: 3 Loss: 0.402042\n", - "Batch: 4 Loss: 0.561577\n", - "Batch: 5 Loss: 0.449640\n", - "Batch: 6 Loss: 0.392212\n", - "Batch: 7 Loss: 0.434513\n", - "Batch: 8 Loss: 0.388681\n", - "Batch: 9 Loss: 0.397521\n", - "Batch: 10 Loss: 0.405278\n", - "Batch: 11 Loss: 0.428488\n", - "Batch: 12 Loss: 0.519868\n", - "Batch: 13 Loss: 0.387626\n", - "Batch: 14 Loss: 0.552459\n", - "Batch: 15 Loss: 0.383136\n", - "Batch: 16 Loss: 0.366464\n", - "Batch: 17 Loss: 0.373240\n", - "Batch: 18 Loss: 0.398293\n", - "Batch: 19 Loss: 0.440705\n", - "Batch: 20 Loss: 0.431868\n", - "Batch: 21 Loss: 0.513310\n", - "Batch: 22 Loss: 0.370319\n", - "Batch: 23 Loss: 0.416570\n", - "Batch: 24 Loss: 0.837449\n", - "Batch: 25 Loss: 0.342340\n", - "Batch: 26 Loss: 0.370385\n", - "Batch: 27 Loss: 0.403717\n", - "Batch: 28 Loss: 0.357792\n", - "Batch: 29 Loss: 0.372590\n", - "Batch: 30 Loss: 0.512121\n", - "Batch: 31 Loss: 0.372972\n", - "Batch: 32 Loss: 0.712874\n", - "Batch: 33 Loss: 0.334123\n", - "Batch: 34 Loss: 0.627990\n", - "Batch: 35 Loss: 0.524298\n", - "Batch: 36 Loss: 0.439857\n", - "Batch: 37 Loss: 0.469408\n", - "Batch: 38 Loss: 0.338335\n", - "Batch: 39 Loss: 0.367137\n", + "Batch: 0 Loss: 1.389253\n", + "Batch: 1 Loss: 1.023016\n", + "Batch: 2 Loss: 0.778526\n", + "Batch: 3 Loss: 0.621589\n", + "Batch: 4 Loss: 0.748947\n", + "Batch: 5 Loss: 0.588811\n", + "Batch: 6 Loss: 0.616743\n", + "Batch: 7 Loss: 0.619884\n", + "Batch: 8 Loss: 0.599840\n", + "Batch: 9 Loss: 0.453021\n", + "Batch: 10 Loss: 0.570860\n", + "Batch: 11 Loss: 0.547868\n", + "Batch: 12 Loss: 1.003815\n", + "Batch: 13 Loss: 0.430471\n", + "Batch: 14 Loss: 0.526788\n", + "Batch: 15 Loss: 0.551975\n", + "Batch: 16 Loss: 0.606334\n", + "Batch: 17 Loss: 0.448370\n", + "Batch: 18 Loss: 0.570277\n", + "Batch: 19 Loss: 0.415983\n", + "Batch: 20 Loss: 0.476248\n", + "Batch: 21 Loss: 0.461724\n", + "Batch: 22 Loss: 0.640585\n", + "Batch: 23 Loss: 0.595707\n", + "Batch: 24 Loss: 0.654544\n", + "Batch: 25 Loss: 0.461624\n", + "Batch: 26 Loss: 0.533261\n", + "Batch: 27 Loss: 0.572199\n", + "Batch: 28 Loss: 0.531856\n", + "Batch: 29 Loss: 0.519722\n", + "Batch: 30 Loss: 0.554501\n", + "Batch: 31 Loss: 0.716420\n", + "Batch: 32 Loss: 0.490347\n", + "Batch: 33 Loss: 0.412190\n", + "Batch: 34 Loss: 0.524030\n", + "Batch: 35 Loss: 0.618971\n", + "Batch: 36 Loss: 0.461805\n", + "Batch: 37 Loss: 0.508472\n", + "Batch: 38 Loss: 0.597108\n", + "Batch: 39 Loss: 0.558429\n", "0| |\n", "1| |\n", "2| |\n", @@ -1126,89 +1072,256 @@ "4| |\n", "5| |\n", " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", - "Player 0 choose move please: 5\n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 3\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "0| |\n", "1| |\n", "2| |\n", "3| |\n", "4| |\n", - "5| ○ 
|\n", + "5| ○ |\n", " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", - "ai chooses 0\n", + "ai chooses 6\n", "0| |\n", "1| |\n", "2| |\n", "3| |\n", "4| |\n", - "5| ● ○ |\n", + "5| ○ ● |\n", " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", - "Player 0 choose move please: 4\n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 3\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "0| |\n", "1| |\n", "2| |\n", "3| |\n", - "4| |\n", - "5| ● ○ ○ |\n", + "4| ○ |\n", + "5| ○ ● |\n", " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", - "ai chooses 4\n", + "ai chooses 3\n", "0| |\n", "1| |\n", "2| |\n", - "3| |\n", - "4| ● |\n", - "5| ● ○ ○ |\n", + "3| ● |\n", + "4| ○ |\n", + "5| ○ ● |\n", " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", - "Player 0 choose move please: 3\n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 3\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "0| |\n", "1| |\n", - "2| |\n", - "3| |\n", - "4| ● |\n", - "5| ● ○ ○ ○ |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ |\n", + "5| ○ ● |\n", " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", - "ai chooses 4\n", + "ai chooses 3\n", "0| |\n", - "1| |\n", - "2| |\n", - "3| ● |\n", - "4| ● |\n", - "5| ● ○ ○ ○ |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ |\n", + "5| ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0| |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ |\n", + "5| ○ ○ ● |\n", " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", - "Player 0 choose move please: 2\n", + "ai chooses 2\n", "0| |\n", - "1| |\n", - "2| |\n", - "3| ● |\n", - "4| ● |\n", - "5| ● ○ ○ ○ ○ |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ |\n", + "5| ● ○ ○ ● |\n", " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", - "ai chooses 6\n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "0| |\n", - "1| |\n", - "2| |\n", - "3| ● |\n", - "4| ● |\n", - "5| ● ○ ○ ○ ○ ● |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ ○ |\n", + "5| ● ○ ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "ai chooses 3\n", + "0| ● |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● |\n", + "4| ○ ○ |\n", + "5| ● ○ ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0| ● |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● ○ |\n", + "4| ○ ○ |\n", + "5| ● ○ ○ ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "ai chooses 5\n", + "0| ● |\n", + "1| ● |\n", + "2| ○ |\n", + "3| ● ○ |\n", + "4| ○ ○ |\n", + "5| ● ○ ○ ● ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "Player 0 choose move please: " + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " 4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0| ● |\n", + "1| ● |\n", + "2| ○ ○ |\n", + "3| ● ○ |\n", + "4| ○ ○ |\n", + "5| ● ○ ○ ● ● |\n", + " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", + "ai chooses 6\n", + "0| ● |\n", + "1| ● |\n", + "2| ○ ○ |\n", + "3| ● ○ |\n", + "4| ○ ○ ● |\n", + "5| ● ○ ○ ● ● |\n", " ̅0 ̅1 ̅2 ̅3 ̅4 ̅5 ̅6 \n", "player 0 won\n", "player 1 lost\n", - "Play again? y or n?n\n" - ], - "name": "stdout" + "Play again? y or n?" 
+ ]
+ },
+ {
+ "name": "stdin",
+ "output_type": "stream",
+ "text": [
+ " 4\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "I didn't understand\n",
+ "Play again? y or n?"
+ ]
+ },
+ {
+ "name": "stdin",
+ "output_type": "stream",
+ "text": [
+ " n\n"
+ ]
 }
+ ],
+ "source": [
+ "# this will train the network on self-play games\n",
+ "alphazero(config, network)\n",
+ "interactive_game(config, network)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
- "id": "kiQd0snNdeBP",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "kiQd0snNdeBP"
 },
 "source": [
 "# Question\n",
 "---\n",
 "Each time we start training the NN, the loss begins higher than where the previous round of gradient descent steps left off. Is this an artifact, or is there a reason this makes sense?\n",
 "\n",
+ "This does make sense. Each self-play cycle generates a fresh set of games with the improved (and still exploring) policy, so the training targets shift and greater losses are incurred at the start of each cycle before decreasing again. Relative to our typical experience with classification on a fixed dataset, this gives the appearance of stagnant learning even though the agent is improving.\n",
+ "\n",
 "# 9. Round Two!\n",
 "---\n",
 "Ok, our algorithm just turned bits into straight gains. Let's try playing it again."
 ]
 },
 {
 "cell_type": "code",
+ "execution_count": null,
 "metadata": {
- "id": "veD94lDcd54y",
+ "colab": {},
 "colab_type": "code",
- "colab": {}
+ "id": "veD94lDcd54y"
 },
+ "outputs": [],
 "source": [
 "interactive_game(config, network)"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
- "id": "cPC8Rs-VSOfg",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "cPC8Rs-VSOfg"
 },
 "source": [
 "\n",
 "Ok, so it's not a connect 4 champion after only a few cycles, but it does get better. Time permitting, run another few cycles to see how good you can get it. "
 ]
 }
- ]
-} \ No newline at end of file
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [
+ "kqVyB4PiI-uM",
+ "SNHXdaS0I-uR",
+ "efYv8f0bI-uW",
+ "Z2uhK7q3YHOx",
+ "9U3hnFz5YU5r",
+ "_eAu4afHR-oN"
+ ],
+ "name": "AlphaZero_exercise.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.3"
+ },
+ "pycharm": {
+ "stem_cell": {
+ "cell_type": "raw",
+ "metadata": {
+ "collapsed": false
+ },
+ "source": []
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
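
Note on the loss behaviour discussed in the "Question" cell above: the "Batch: n Loss: x" lines in the output come from a combined value-plus-policy objective, and the buffer of training examples is refilled from new self-play games every cycle, which is why each "train network k of 5" block starts from a higher loss. The sketch below is illustrative only: it assumes a `network(images) -> (values, policy_logits)` interface and a pre-sampled batch of `(images, target_values, target_policies)` tensors, which are placeholder names rather than the notebook's exact helpers.

```python
# Minimal AlphaZero-style training step (sketch, not the notebook's code).
# Assumes: values in [-1, 1], target_policies are normalized MCTS visit counts.
import torch
import torch.nn.functional as F

def train_step(network, optimizer, batch):
    images, target_values, target_policies = batch

    values, policy_logits = network(images)

    # Value head: mean-squared error against the final game outcome.
    value_loss = F.mse_loss(values.squeeze(-1), target_values)

    # Policy head: cross-entropy against the search's visit-count distribution.
    policy_loss = -(target_policies * F.log_softmax(policy_logits, dim=1)).sum(dim=1).mean()

    loss = value_loss + policy_loss  # L2 regularisation comes from the optimizer's weight_decay
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# An optimizer mirroring the hyperparameters in AlphaZeroConfig might look like:
# optimizer = torch.optim.SGD(network.parameters(), lr=config.learning_rate,
#                             momentum=config.momentum, weight_decay=config.weight_decay)
```

Because every cycle replaces the self-play data with games produced by the newer policy, the targets fed to this step shift between cycles, so a jump in loss at the start of each training phase is expected rather than an artifact.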