diff --git a/bin/zobristTable.dat b/bin/zobristTable.dat index 1fd7b4f..85c4cd5 100644 Binary files a/bin/zobristTable.dat and b/bin/zobristTable.dat differ diff --git a/src/main/java/Main.java b/src/main/java/Main.java index 54156f1..134ae90 100644 --- a/src/main/java/Main.java +++ b/src/main/java/Main.java @@ -1,4 +1,5 @@ import State.State; +import Tree.Heuristics; import Tree.MonteCarloTree; import ygraph.ai.smartfox.games.BaseGameGUI; import ygraph.ai.smartfox.games.GameClient; @@ -11,6 +12,7 @@ import java.io.ObjectInputStream; import java.util.ArrayList; import java.util.Map; +import java.util.Scanner; import State.*; @@ -27,7 +29,7 @@ public class Main extends GamePlayer { private int colour; private MonteCarloTree monteCarloTree; - private final double cValue = 2.0; + private final double cValue = 1.4; private int depth = 0; private static int[] moveDictionary; @@ -38,8 +40,10 @@ public class Main extends GamePlayer { */ public static void main(String[] args) { GamePlayer player; + String name = "new"; + System.out.println(name); if (args.length == 2) - player = new Main(args[0] + "-" + ((int) (Math.random() * 1000)), args[1]); + player = new Main(name, args[1]); else player = new HumanPlayer(); @@ -132,13 +136,38 @@ public boolean handleGameMessage(String messageType, Map msgDeta private void makeMove() { makeMonteCarloMove(); +// theClown(); depth++; } + private void theClown() { + ArrayList actions = ActionGenerator.generateActions(state, colour); + Action definitelyTheBestAction = null; + double bestH = Math.pow(-1, colour) * Integer.MAX_VALUE; + for (Action a : actions) { + double h = Heuristics.bigPoppa(new State(state, a), colour); + if (colour == 1 && h > bestH) { + definitelyTheBestAction = a; + bestH = h; + } else if (colour == 2 && h < bestH) { + definitelyTheBestAction = a; + bestH = h; + } + } + if (definitelyTheBestAction == null) { + System.out.println("OPPONENT WINS!!"); + new Scanner(System.in).nextLine(); + } + state = new State(state, definitelyTheBestAction); + getGameClient().sendMoveMessage(definitelyTheBestAction.toServerResponse()); + getGameGUI().updateGameState(definitelyTheBestAction.toServerResponse()); + } + private void makeMonteCarloMove() { + System.out.println("colour = " + colour); long start = System.currentTimeMillis(); if (action != null) - monteCarloTree.updateRoot(state, action, colour, depth); + monteCarloTree = new MonteCarloTree(state, cValue, colour, depth, moveDictionary); Action definitelyTheBestAction = monteCarloTree.search(); if (definitelyTheBestAction == null) { System.out.println("OPPONENT WINS!!"); diff --git a/src/main/java/State/ZobristHash.java b/src/main/java/State/ZobristHash.java new file mode 100644 index 0000000..5158ff5 --- /dev/null +++ b/src/main/java/State/ZobristHash.java @@ -0,0 +1,121 @@ +package State; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashSet; +import java.util.Random; + +public class ZobristHash { + private static boolean initialized = false; + private static long[] zobristTable; + private static long blackToMove; + + public static long zobristHash(BitBoard bitBoard, int colorToMove) { + if (!initialized) init(); + + long hash = 0; + + if (colorToMove == State.BLACK_QUEEN) + hash = blackToMove; + + int index = 0; // Start at 0 for black + long currentBoard = bitBoard.getBlackQueensBottom(); + while (currentBoard > 0) { + if ((currentBoard & 1L) == 1) + hash ^= zobristTable[index]; + currentBoard >>= 1; + index += 3; + } + index = 150; // Make sure that the index is correct going into the next loop because the last loop likely didn't execute a full 50 times + currentBoard = bitBoard.getBlackQueensTop(); + while (currentBoard > 0) { + if ((currentBoard & 1L) == 1) + hash ^= zobristTable[index]; + currentBoard >>= 1; + index += 3; + } + + index = 1; // Start at 1 for white + currentBoard = bitBoard.getWhiteQueensBottom(); + while (currentBoard > 0) { + if ((currentBoard & 1L) == 1) + hash ^= zobristTable[index]; + currentBoard >>= 1; + index += 3; + } + index = 151; // Make sure that the index is correct going into the next loop because the last loop likely didn't execute a full 50 times + currentBoard = bitBoard.getWhiteQueensTop(); + while (currentBoard > 0) { + if ((currentBoard & 1L) == 1) + hash ^= zobristTable[index]; + currentBoard >>= 1; + index += 3; + } + + index = 2; // Start at 2 for arrows + currentBoard = bitBoard.getArrowBottom(); + while (currentBoard > 0) { + if ((currentBoard & 1L) == 1) + hash ^= zobristTable[index]; + currentBoard >>= 1; + index += 3; + } + index = 152; // Make sure that the index is correct going into the next loop because the last loop likely didn't execute a full 50 times + currentBoard = bitBoard.getArrowTop(); + while (currentBoard > 0) { + if ((currentBoard & 1L) == 1) + hash ^= zobristTable[index]; + currentBoard >>= 1; + index += 3; + } + + return hash; + } + + public static void init() { + Path path = Paths.get("bin/zobristTable.dat"); + try (ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(Files.newInputStream(path)))) { + zobristTable = (long[]) in.readObject(); + blackToMove = in.readLong(); + } catch (Exception e) { + // Generate table if it doesn't exist or there are any errors reading it + + // Addressed boardIndex * 3 + (pieceNumber - 1) // 0 = black, 1 = white, 2 = arrow + // This does waste 25% of the space, but is more efficient for lookup + zobristTable = new long[300]; + + // Fill table with random longs + HashSet randomLongs = new HashSet<>(); + Random r = new Random(); + for (int i = 0; i < zobristTable.length; i++) { + // Fill table with random longs making sure there are no duplicates (even though it's an extremely small chance) + long randLong = r.nextLong(); + while (randomLongs.contains(randLong)) + randLong = r.nextLong(); + randomLongs.add(randLong); + zobristTable[i] = randLong; + } + + // Also generate another bitstring for if black is the player that makes the next move + blackToMove = r.nextLong(); + while (randomLongs.contains(blackToMove)) + blackToMove = r.nextLong(); + + // Save to file + try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(path))) { + out.writeObject(zobristTable); + out.writeLong(blackToMove); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + + // Only set initialized to true if we make it here without an exception + initialized = true; + } +} diff --git a/src/main/java/Tests/HeuristicTesting.java b/src/main/java/Tests/HeuristicTesting.java index eade447..2629568 100644 --- a/src/main/java/Tests/HeuristicTesting.java +++ b/src/main/java/Tests/HeuristicTesting.java @@ -1,44 +1,829 @@ package Tests; -import State.Action; -import State.ActionGenerator; -import State.State; +import State.*; import Tree.Heuristics; +import java.sql.SQLOutput; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Random; public class HeuristicTesting { public static void main(String[] args) { - State s = new State(new ArrayList<>(Arrays.asList( + State state = new State(new ArrayList<>(Arrays.asList( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, + 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ))); + State s = randomState(); + BitBoard b = s.getBitBoard(); + System.out.println(b.boardToString()); + long n = (long) 1e6; + long start = System.nanoTime(); + for (long i = 0; i < n; i++) { + Heuristics.bigPoppa(s,1); + } + long end = System.nanoTime(); + System.out.println("Time: " + (end - start) / n + "ns"); + BitBoard testBoard = new BitBoard(); + +// for (int y = 9; y >= 0; y--) { +// for (int x = 0; x < 10; x++) { +// if (result[0][y*10 + x] == 1000000) { +// System.out.print("- "); +// } else { +// System.out.print(result[0][y*10+x] + " "); +// } +// } +// System.out.println(); +// } + +// testBoard.setArrowTop(result.getWhiteQueensTop()); +// testBoard.setArrowBottom(result.getWhiteQueensBottom()); +// System.out.println(testBoard.boardToString()); + } + + private static int[][] D2(State s) { + // Clone input so as not to modify it + BitBoard input; + try { + input = (BitBoard) s.getBitBoard().clone(); + } catch (CloneNotSupportedException e) { + throw new RuntimeException(e); + } + + // Initialize reachable array. Indexed as [color][moveNum][top/bottom] + long[][][] reachable = new long[2][10][2]; + BitBoard result; + + for (int nMoves = 0; nMoves < 10; nMoves++) { + // Find all reachable squares + result = kingReachableInOneMove(State.BLACK_QUEEN, input); + + reachable[0][nMoves][0] = result.getArrowTop();// result[0]; + reachable[0][nMoves][1] = result.getArrowBottom();// result[1]; + + if (nMoves > 0) { + if (result.getArrowTop() == 0L && result.getArrowBottom() == 0L) { + break; + } + } + + input.setArrowTop(input.getArrowTop() | input.getBlackQueensTop()); + input.setArrowBottom(input.getArrowBottom() | input.getBlackQueensBottom()); + input.setBlackQueensTop(result.getArrowTop()); + input.setBlackQueensBottom(result.getArrowBottom()); + } + + try { + input = (BitBoard) s.getBitBoard().clone(); + } catch (CloneNotSupportedException e) { + throw new RuntimeException(e); + } + + for (int nMoves = 0; nMoves < 10; nMoves++) { + // Find all reachable squares + result = kingReachableInOneMove(State.WHITE_QUEEN, input); + + reachable[1][nMoves][0] = result.getArrowTop();// result[0]; + reachable[1][nMoves][1] = result.getArrowTop();// result[1]; + + if (nMoves > 0) { + if (result.getArrowTop() == 0L && result.getArrowBottom() == 0L) { + break; + } + } + + input.setArrowTop(input.getArrowTop() | input.getWhiteQueensTop()); + input.setArrowBottom(input.getArrowBottom() | input.getWhiteQueensBottom()); + input.setWhiteQueensTop(result.getArrowTop()); + input.setWhiteQueensBottom(result.getArrowBottom()); + } + + int[][] output = new int[2][100]; + + for (int color = 0; color < 2; color++) { + loop: for (int i = 0; i < 100; i++) { + for (int nMoves = 0; nMoves < 10; nMoves++) { + if (i < 50 && (reachable[color][nMoves][1] & (1L << i)) != 0) { + output[color][i] = nMoves + 1; + continue loop; + } else if (i >= 50 && (reachable[color][nMoves][0] & (1L << (i - 50))) != 0) { + output[color][i] = nMoves + 1; + continue loop; + } + } + output[color][i] = 1000000; + } + } + + return output; + } + + private static BitBoard kingReachableInOneMove(int color, BitBoard b) { + long boardMask = -1L >>> (64 - 50); + long aroundMask = 0b11100000001010000000111L; + long notAFile = 0b11111111101111111110111111111011111111101111111110L; + long notJFile = 0b01111111110111111111011111111101111111110111111111L; + + long blackTop = b.getBlackQueensTop(); + long blackBottom = b.getBlackQueensBottom(); + long whiteTop = b.getWhiteQueensTop(); + long whiteBottom = b.getWhiteQueensBottom(); + long arrowTop = b.getArrowTop(); + long arrowBottom = b.getArrowBottom(); + + long occupiedTop = blackTop | whiteTop | arrowTop; + long occupiedBottom = blackBottom | whiteBottom | arrowBottom; + + // All squares reachable in one move + long reachableTop = 0L; + long reachableBottom = 0L; + + int queenCount; + long[] queensTop, queensBottom; + + if (color == State.BLACK_QUEEN) { + // Find black queens + queenCount = Long.bitCount(blackTop) + Long.bitCount(blackBottom); + queensTop = new long[queenCount]; + queensBottom = new long[queenCount]; + for (int i = 0; i < queenCount; i++) { + if (blackBottom > 1) { + queensBottom[i] = Long.lowestOneBit(blackBottom); + queensTop[i] = 0L; + blackBottom ^= queensBottom[i]; + } else { + queensBottom[i] = 0L; + queensTop[i] = Long.lowestOneBit(blackTop); + blackTop ^= queensTop[i]; + } + } + } else { + // Find white queens + queenCount = Long.bitCount(whiteTop) + Long.bitCount(whiteBottom); + queensTop = new long[queenCount]; + queensBottom = new long[queenCount]; + for (int i = 0; i < queenCount; i++) { + if (whiteBottom > 1) { + queensBottom[i] = Long.lowestOneBit(whiteBottom); + queensTop[i] = 0L; + whiteBottom ^= queensBottom[i]; + } else { + queensBottom[i] = 0L; + queensTop[i] = Long.lowestOneBit(whiteTop); + whiteTop ^= queensTop[i]; + } + } + } + + + // Find all reachable squares + for (int pieceNum = 0; pieceNum < queenCount; pieceNum++) { + + // The piece we are moving + long startBottom = queensBottom[pieceNum]; + long startTop = queensTop[pieceNum]; + + /////////////////// + // Generate mask // + /////////////////// + + int index; + if (startTop > 0) { + index = 50 + Long.numberOfTrailingZeros(startTop); + } else { + index = Long.numberOfTrailingZeros(startBottom); + } + + long maskBottom, maskTop; + if (index > 39) { + if (index < 61) + maskTop = aroundMask >>> (61 - index); + else + maskTop = aroundMask << (index - 61); + } else { + maskTop = 0L; + } + + if (index < 61) { + if (index < 11) { + maskBottom = aroundMask >>> (11 - index); + } else { + maskBottom = aroundMask << (index - 11); + } + } else { + maskBottom = 0L; + } + + if (index % 10 == 0) { + maskTop &= notJFile; + maskBottom &= notJFile; + } else if (index % 10 == 9) { + maskTop &= notAFile; + maskBottom &= notAFile; + } + + maskTop &= boardMask; + maskBottom &= boardMask; + + + // Actually get squares we can move to + reachableTop |= ~occupiedTop & maskTop; + reachableBottom |= ~occupiedBottom & maskBottom; + } + + BitBoard out = new BitBoard(); + out.setArrowTop(reachableTop); + out.setArrowBottom(reachableBottom); + + return out; + } + + /** + * Returns the minimum number of moves required to reach each square from the given state. + * @param b The state to start from + * @return A 2D array of integers. The first index is the color, the second is the square. + */ + private static int[][] minMoves(BitBoard b) { + // Clone input so as not to modify it + BitBoard input; + try { + input = (BitBoard) b.clone(); + } catch (CloneNotSupportedException e) { + throw new RuntimeException(e); + } + + // Initialize reachable array. Indexed as [color][moveNum][top/bottom] + long[][][] reachable = new long[2][10][2]; + BitBoard result; + + for (int nMoves = 0; nMoves < 10; nMoves++) { + // Find all reachable squares + result = reachableInOneMove(State.BLACK_QUEEN, input); + + reachable[0][nMoves][0] = result.getArrowTop();// result[0]; + reachable[0][nMoves][1] = result.getArrowBottom();// result[1]; + + if (nMoves > 0) { + if (result.getArrowTop() == 0L && result.getArrowBottom() == 0L) { + break; + } + } + + input.setArrowTop(input.getArrowTop() | input.getBlackQueensTop()); + input.setArrowBottom(input.getArrowBottom() | input.getBlackQueensBottom()); + input.setBlackQueensTop(result.getArrowTop()); + input.setBlackQueensBottom(result.getArrowBottom()); + } + + try { + input = (BitBoard) b.clone(); + } catch (CloneNotSupportedException e) { + throw new RuntimeException(e); + } + + for (int nMoves = 0; nMoves < 10; nMoves++) { + // Find all reachable squares + result = reachableInOneMove(State.WHITE_QUEEN, input); + + reachable[1][nMoves][0] = result.getArrowTop();// result[0]; + reachable[1][nMoves][1] = result.getArrowTop();// result[1]; + + if (nMoves > 0) { + if (result.getArrowTop() == 0L && result.getArrowBottom() == 0L) { + break; + } + } + + input.setArrowTop(input.getArrowTop() | input.getWhiteQueensTop()); + input.setArrowBottom(input.getArrowBottom() | input.getWhiteQueensBottom()); + input.setWhiteQueensTop(result.getArrowTop()); + input.setWhiteQueensBottom(result.getArrowBottom()); + } + + int[][] output = new int[2][100]; + + for (int color = 0; color < 2; color++) { + loop: for (int i = 0; i < 100; i++) { + for (int nMoves = 0; nMoves < 10; nMoves++) { + if (i < 50 && (reachable[color][nMoves][1] & (1L << i)) != 0) { + output[color][i] = nMoves + 1; + continue loop; + } else if (i >= 50 && (reachable[color][nMoves][0] & (1L << (i - 50))) != 0) { + output[color][i] = nMoves + 1; + continue loop; + } + } + output[color][i] = 1000000; + } + } + + return output; + } + + /** + * Returns a bitboard with all the reachable positions in one move from the given board. Uses almost only bitwise operations, making it very fast. + * + * @param b The board to find reachable positions from. + * @return A bitboard with all the reachable positions in one move from the given board for both colors. + */ + private static BitBoard reachableInOneMove(int color, BitBoard b) { + // Constants + long columnMask = 0b00000000010000000001000000000100000000010000000001L; + long rowMask = 0b1111111111L; + long diagonalMask = 0b10000000000100000000001000000000010000000000100000000001L; + long antiDiagonalMask = 0b000000001000000001000000001000000001000000001L; + long boardMask = -1L >>> (64 - 50); + + long blackTop = b.getBlackQueensTop(); + long blackBottom = b.getBlackQueensBottom(); + long whiteTop = b.getWhiteQueensTop(); + long whiteBottom = b.getWhiteQueensBottom(); + long arrowTop = b.getArrowTop(); + long arrowBottom = b.getArrowBottom(); + + long occupiedTop = blackTop | whiteTop | arrowTop; + long occupiedBottom = blackBottom | whiteBottom | arrowBottom; + + // All squares reachable in one move + long reachableTop = 0L; + long reachableBottom = 0L; + + int queenCount; + long[] queensTop, queensBottom; + + if (color == State.BLACK_QUEEN) { + // Find black queens + queenCount = Long.bitCount(blackTop) + Long.bitCount(blackBottom); + queensTop = new long[queenCount]; + queensBottom = new long[queenCount]; + for (int i = 0; i < queenCount; i++) { + if (blackBottom > 1) { + queensBottom[i] = Long.lowestOneBit(blackBottom); + queensTop[i] = 0L; + blackBottom ^= queensBottom[i]; + } else { + queensBottom[i] = 0L; + queensTop[i] = Long.lowestOneBit(blackTop); + blackTop ^= queensTop[i]; + } + } + } else { + // Find white queens + queenCount = Long.bitCount(whiteTop) + Long.bitCount(whiteBottom); + queensTop = new long[queenCount]; + queensBottom = new long[queenCount]; + for (int i = 0; i < queenCount; i++) { + if (whiteBottom > 1) { + queensBottom[i] = Long.lowestOneBit(whiteBottom); + queensTop[i] = 0L; + whiteBottom ^= queensBottom[i]; + } else { + queensBottom[i] = 0L; + queensTop[i] = Long.lowestOneBit(whiteTop); + whiteTop ^= queensTop[i]; + } + } + } + + for (int pieceNum = 0; pieceNum < queenCount; pieceNum++) { + + // The piece we are moving + long startBottom = queensBottom[pieceNum]; + long startTop = queensTop[pieceNum]; + + int startRow = getRow(startBottom, startTop); + int startCol = getCol(startBottom, startTop); + + // All pieces except the one we are moving + long currentOccupiedTop = occupiedTop ^ startTop; + long currentOccupiedBottom = occupiedBottom ^ startBottom; + + + /////////////////////////////////////////// + // Generate masks along the 4 directions // + /////////////////////////////////////////// + + // Horizontal + long rowMaskTop, rowMaskBottom; + if (startRow < 5) { + rowMaskTop = 0L; + rowMaskBottom = rowMask << (startRow * 10); + } else { + rowMaskTop = rowMask << ((startRow - 5) * 10); + rowMaskBottom = 0L; + } + + // Vertical + long colMaskTop = columnMask << startCol; + long colMaskBottom = columnMask << startCol; + + // Diagonal + long diagMaskTop, diagMaskBottom; + int diagShift = startCol - startRow; + if (diagShift >= 0) { + if (diagShift >= 5) { + diagMaskTop = 0L; + diagMaskBottom = (diagonalMask << diagShift) & ~(-1L << ((10 - diagShift) * 10)); + } else { + diagMaskTop = (diagonalMask << (diagShift + 5)) & ~(-1L << ((5 - diagShift) * 10)); + diagMaskBottom = diagonalMask << diagShift; + } + } else { + if (diagShift < -5) { + diagMaskTop = (diagonalMask >>> (-diagShift + 6)) & (-1L << ((-diagShift - 5) * 10)); + diagMaskBottom = 0L; + } else { + diagMaskTop = diagonalMask >>> (-diagShift + 6); + diagMaskBottom = (diagonalMask >>> (-diagShift)) & (-1L << (-diagShift * 10)); + } + } + + long antiDiagMaskTop, antiDiagMaskBottom; + int antiDiagShift = startCol + startRow; + if (antiDiagShift >= 9) { + if (antiDiagShift > 13) { + antiDiagMaskTop = (antiDiagonalMask << (antiDiagShift - 5)) & (-1L << ((antiDiagShift - 14) * 10 + 1)); + antiDiagMaskBottom = 0L; + } else { + antiDiagMaskTop = (antiDiagonalMask << (antiDiagShift - 5)); + antiDiagMaskBottom = (antiDiagonalMask << (antiDiagShift)) & (-1L << ((antiDiagShift - 9) * 10 + 1)); + } + } else { + if (antiDiagShift < 4) { + antiDiagMaskTop = 0L; + antiDiagMaskBottom = (antiDiagonalMask << (antiDiagShift)) & ~(-1L << ((antiDiagShift + 1) * 10 - 1)); + } else { + antiDiagMaskTop = (antiDiagonalMask << (antiDiagShift - 5)) & ~(-1L << ((antiDiagShift - 4) * 10 - 1)); + antiDiagMaskBottom = (antiDiagonalMask << (antiDiagShift)); + } + } + + // Make sure that all masks are only on the board + rowMaskTop &= boardMask; + rowMaskBottom &= boardMask; + colMaskTop &= boardMask; + colMaskBottom &= boardMask; + diagMaskTop &= boardMask; + diagMaskBottom &= boardMask; + antiDiagMaskTop &= boardMask; + antiDiagMaskBottom &= boardMask; + + + //////////////////////// + // Get pieces // + //////////////////////// + + // Get the possible pieces below the current piece + long piecesBelowBottom, piecesBelowTop; + piecesBelowBottom = startBottom - 1; + if (startTop == 0) + piecesBelowTop = 0L; + else + piecesBelowTop = startTop - 1; + + // Get the possible pieces above the current piece + long piecesAboveBottom = ~startBottom ^ piecesBelowBottom; + long piecesAboveTop = ~startTop ^ piecesBelowTop; + + + // Get the pieces that are along the path of the current piece in each direction + long occupiedRowTop = currentOccupiedTop & rowMaskTop; + long occupiedRowBottom = currentOccupiedBottom & rowMaskBottom; + long occupiedColTop = currentOccupiedTop & colMaskTop; + long occupiedColBottom = currentOccupiedBottom & colMaskBottom; + long occupiedDiagTop = currentOccupiedTop & diagMaskTop; + long occupiedDiagBottom = currentOccupiedBottom & diagMaskBottom; + long occupiedAntiDiagTop = currentOccupiedTop & antiDiagMaskTop; + long occupiedAntiDiagBottom = currentOccupiedBottom & antiDiagMaskBottom; + + + //////////////////////// + // Get blocking piece // + //////////////////////// + + // Get the first blocking piece in col above the current piece + long blockingPieceColAboveBottom, blockingPieceColAboveTop; + if (startTop != 0) { + blockingPieceColAboveBottom = 0L; + blockingPieceColAboveTop = Long.lowestOneBit(occupiedColTop & piecesAboveTop); + } else { + blockingPieceColAboveBottom = Long.lowestOneBit(occupiedColBottom & piecesAboveBottom); + if (blockingPieceColAboveBottom == 0) + blockingPieceColAboveTop = Long.lowestOneBit(occupiedColTop & piecesAboveTop); + else + blockingPieceColAboveTop = 0L; + } + + // Get the first blocking piece in col below the current piece + long blockingPieceColBelowBottom, blockingPieceColBelowTop; + if (startTop == 0) { + blockingPieceColBelowTop = 0L; + blockingPieceColBelowBottom = Long.highestOneBit(occupiedColBottom & piecesBelowBottom); + } else { + blockingPieceColBelowTop = Long.highestOneBit(occupiedColTop & piecesBelowTop); + if (blockingPieceColBelowTop == 0) + blockingPieceColBelowBottom = Long.highestOneBit(occupiedColBottom & piecesBelowBottom); + else + blockingPieceColBelowBottom = 0L; + } + + // Get the first blocking piece in row above the current piece + long blockingPieceRowAboveBottom, blockingPieceRowAboveTop; + if (startTop != 0) { + blockingPieceRowAboveBottom = 0L; + blockingPieceRowAboveTop = Long.lowestOneBit(occupiedRowTop & piecesAboveTop); + } else { + blockingPieceRowAboveBottom = Long.lowestOneBit(occupiedRowBottom & piecesAboveBottom); + blockingPieceRowAboveTop = 0L; + } + + // Get the first blocking piece in row below the current piece + long blockingPieceRowBelowBottom, blockingPieceRowBelowTop; + if (startTop == 0) { + blockingPieceRowBelowTop = 0L; + blockingPieceRowBelowBottom = Long.highestOneBit(occupiedRowBottom & piecesBelowBottom); + } else { + blockingPieceRowBelowTop = Long.highestOneBit(occupiedRowTop & piecesBelowTop); + blockingPieceRowBelowBottom = 0L; + } + + // Get the first blocking piece in diag above the current piece + long blockingPieceDiagAboveBottom, blockingPieceDiagAboveTop; + if (startTop != 0) { + blockingPieceDiagAboveBottom = 0L; + blockingPieceDiagAboveTop = Long.lowestOneBit(occupiedDiagTop & piecesAboveTop); + } else { + blockingPieceDiagAboveBottom = Long.lowestOneBit(occupiedDiagBottom & piecesAboveBottom); + if (blockingPieceDiagAboveBottom == 0) + blockingPieceDiagAboveTop = Long.lowestOneBit(occupiedDiagTop & piecesAboveTop); + else + blockingPieceDiagAboveTop = 0L; + } + + // Get the first blocking piece in diag below the current piece + long blockingPieceDiagBelowBottom, blockingPieceDiagBelowTop; + if (startTop == 0) { + blockingPieceDiagBelowTop = 0L; + blockingPieceDiagBelowBottom = Long.highestOneBit(occupiedDiagBottom & piecesBelowBottom); + } else { + blockingPieceDiagBelowTop = Long.highestOneBit(occupiedDiagTop & piecesBelowTop); + if (blockingPieceDiagBelowTop == 0) + blockingPieceDiagBelowBottom = Long.highestOneBit(occupiedDiagBottom & piecesBelowBottom); + else + blockingPieceDiagBelowBottom = 0L; + } + + // Get the first blocking piece in anti-diag above the current piece + long blockingPieceAntiDiagAboveBottom, blockingPieceAntiDiagAboveTop; + if (startTop != 0) { + blockingPieceAntiDiagAboveBottom = 0L; + blockingPieceAntiDiagAboveTop = Long.lowestOneBit(occupiedAntiDiagTop & piecesAboveTop); + } else { + blockingPieceAntiDiagAboveBottom = Long.lowestOneBit(occupiedAntiDiagBottom & piecesAboveBottom); + if (blockingPieceAntiDiagAboveBottom == 0) + blockingPieceAntiDiagAboveTop = Long.lowestOneBit(occupiedAntiDiagTop & piecesAboveTop); + else + blockingPieceAntiDiagAboveTop = 0L; + } + + // Get the first blocking piece in anti-diag below the current piece + long blockingPieceAntiDiagBelowBottom, blockingPieceAntiDiagBelowTop; + if (startTop == 0) { + blockingPieceAntiDiagBelowTop = 0L; + blockingPieceAntiDiagBelowBottom = Long.highestOneBit(occupiedAntiDiagBottom & piecesBelowBottom); + } else { + blockingPieceAntiDiagBelowTop = Long.highestOneBit(occupiedAntiDiagTop & piecesBelowTop); + if (blockingPieceAntiDiagBelowTop == 0) + blockingPieceAntiDiagBelowBottom = Long.highestOneBit(occupiedAntiDiagBottom & piecesBelowBottom); + else + blockingPieceAntiDiagBelowBottom = 0L; + } + + + /////////////////////////// + // Get pieces in between // + /////////////////////////// + + // Get squares movable to in col above the current piece + long betweenColAboveBottom, betweenColAboveTop; + if (startTop != 0) { + betweenColAboveBottom = 0L; + if (blockingPieceColAboveTop == 0) + betweenColAboveTop = piecesAboveTop & colMaskTop; + else + betweenColAboveTop = piecesAboveTop & (blockingPieceColAboveTop - 1) & colMaskTop; + } else { + if (blockingPieceColAboveBottom == 0) { + betweenColAboveBottom = piecesAboveBottom & colMaskBottom; + betweenColAboveTop = (blockingPieceColAboveTop - 1) & colMaskTop; + } else { + betweenColAboveBottom = piecesAboveBottom & (blockingPieceColAboveBottom - 1) & colMaskBottom; + betweenColAboveTop = 0L; + } + } + + // Get squares movable to in col below the current piece + long betweenColBelowBottom, betweenColBelowTop; + if (startTop == 0) { + betweenColBelowTop = 0L; + if (blockingPieceColBelowBottom == 0) + betweenColBelowBottom = piecesBelowBottom & colMaskBottom; + else + betweenColBelowBottom = piecesBelowBottom & (-blockingPieceColBelowBottom ^ blockingPieceColBelowBottom) & colMaskBottom; + } else { + if (blockingPieceColBelowTop == 0) { + betweenColBelowTop = piecesBelowTop & colMaskTop; + if (blockingPieceColBelowBottom == 0) + betweenColBelowBottom = piecesBelowBottom & colMaskBottom; + else + betweenColBelowBottom = (-blockingPieceColBelowBottom ^ blockingPieceColBelowBottom) & colMaskBottom; + } else { + betweenColBelowTop = piecesBelowTop & (-blockingPieceColBelowTop ^ blockingPieceColBelowTop) & colMaskTop; + betweenColBelowBottom = 0L; + } + } + + // Get squares movable to in row above the current piece + long betweenRowAboveBottom, betweenRowAboveTop; + if (startTop != 0) { + betweenRowAboveBottom = 0L; + if (blockingPieceRowAboveTop == 0) + betweenRowAboveTop = piecesAboveTop & rowMaskTop; + else + betweenRowAboveTop = piecesAboveTop & (blockingPieceRowAboveTop - 1) & rowMaskTop; + } else { + if (blockingPieceRowAboveBottom == 0) { + betweenRowAboveBottom = piecesAboveBottom & rowMaskBottom; + } else { + betweenRowAboveBottom = piecesAboveBottom & (blockingPieceRowAboveBottom - 1) & rowMaskBottom; + } + betweenRowAboveTop = 0L; + } + + // Get squares movable to in row below the current piece + long betweenRowBelowBottom, betweenRowBelowTop; + if (startTop == 0) { + betweenRowBelowTop = 0L; + if (blockingPieceRowBelowBottom == 0) + betweenRowBelowBottom = piecesBelowBottom & rowMaskBottom; + else + betweenRowBelowBottom = piecesBelowBottom & (-blockingPieceRowBelowBottom ^ blockingPieceRowBelowBottom) & rowMaskBottom; + } else { + if (blockingPieceRowBelowTop == 0) { + betweenRowBelowTop = piecesBelowTop & rowMaskTop; + } else { + betweenRowBelowTop = piecesBelowTop & (-blockingPieceRowBelowTop ^ blockingPieceRowBelowTop) & rowMaskTop; + } + betweenRowBelowBottom = 0L; + } + + // Get squares movable to in diag above the current piece + long betweenDiagAboveBottom, betweenDiagAboveTop; + if (startTop != 0) { + betweenDiagAboveBottom = 0L; + if (blockingPieceDiagAboveTop == 0) + betweenDiagAboveTop = piecesAboveTop & diagMaskTop; + else + betweenDiagAboveTop = piecesAboveTop & (blockingPieceDiagAboveTop - 1) & diagMaskTop; + } else { + if (blockingPieceDiagAboveBottom == 0) { + betweenDiagAboveBottom = piecesAboveBottom & diagMaskBottom; + betweenDiagAboveTop = (blockingPieceDiagAboveTop - 1) & diagMaskTop; + } else { + betweenDiagAboveBottom = piecesAboveBottom & (blockingPieceDiagAboveBottom - 1) & diagMaskBottom; + betweenDiagAboveTop = 0L; + } + } + + // Get squares movable to in diag below the current piece + long betweenDiagBelowBottom, betweenDiagBelowTop; + if (startTop == 0) { + betweenDiagBelowTop = 0L; + if (blockingPieceDiagBelowBottom == 0) + betweenDiagBelowBottom = piecesBelowBottom & diagMaskBottom; + else + betweenDiagBelowBottom = piecesBelowBottom & (-blockingPieceDiagBelowBottom ^ blockingPieceDiagBelowBottom) & diagMaskBottom; + } else { + if (blockingPieceDiagBelowTop == 0) { + betweenDiagBelowTop = piecesBelowTop & diagMaskTop; + if (blockingPieceDiagBelowBottom == 0) + betweenDiagBelowBottom = piecesBelowBottom & diagMaskBottom; + else + betweenDiagBelowBottom = (-blockingPieceDiagBelowBottom ^ blockingPieceDiagBelowBottom) & diagMaskBottom; + } else { + betweenDiagBelowTop = piecesBelowTop & (-blockingPieceDiagBelowTop ^ blockingPieceDiagBelowTop) & diagMaskTop; + betweenDiagBelowBottom = 0L; + } + } + + // Get squares movable to in anti-diag above the current piece + long betweenAntiDiagAboveBottom, betweenAntiDiagAboveTop; + if (startTop != 0) { + betweenAntiDiagAboveBottom = 0L; + if (blockingPieceAntiDiagAboveTop == 0) + betweenAntiDiagAboveTop = piecesAboveTop & antiDiagMaskTop; + else + betweenAntiDiagAboveTop = piecesAboveTop & (blockingPieceAntiDiagAboveTop - 1) & antiDiagMaskTop; + } else { + if (blockingPieceAntiDiagAboveBottom == 0) { + betweenAntiDiagAboveBottom = piecesAboveBottom & antiDiagMaskBottom; + betweenAntiDiagAboveTop = (blockingPieceAntiDiagAboveTop - 1) & antiDiagMaskTop; + } else { + betweenAntiDiagAboveBottom = piecesAboveBottom & (blockingPieceAntiDiagAboveBottom - 1) & antiDiagMaskBottom; + betweenAntiDiagAboveTop = 0L; + } + } + + // Get squares movable to in anti-diag below the current piece + long betweenAntiDiagBelowBottom, betweenAntiDiagBelowTop; + if (startTop == 0) { + betweenAntiDiagBelowTop = 0L; + if (blockingPieceAntiDiagBelowBottom == 0) + betweenAntiDiagBelowBottom = piecesBelowBottom & antiDiagMaskBottom; + else + betweenAntiDiagBelowBottom = piecesBelowBottom & (-blockingPieceAntiDiagBelowBottom ^ blockingPieceAntiDiagBelowBottom) & antiDiagMaskBottom; + } else { + if (blockingPieceAntiDiagBelowTop == 0) { + betweenAntiDiagBelowTop = piecesBelowTop & antiDiagMaskTop; + if (blockingPieceAntiDiagBelowBottom == 0) + betweenAntiDiagBelowBottom = piecesBelowBottom & antiDiagMaskBottom; + else + betweenAntiDiagBelowBottom = (-blockingPieceAntiDiagBelowBottom ^ blockingPieceAntiDiagBelowBottom) & antiDiagMaskBottom; + } else { + betweenAntiDiagBelowTop = piecesBelowTop & (-blockingPieceAntiDiagBelowTop ^ blockingPieceAntiDiagBelowTop) & antiDiagMaskTop; + betweenAntiDiagBelowBottom = 0L; + } + } + + // Or results onto the total reachable squares in pne move + reachableTop |= betweenColAboveTop | betweenColBelowTop | betweenRowAboveTop | betweenRowBelowTop | betweenDiagAboveTop | betweenDiagBelowTop | betweenAntiDiagAboveTop | betweenAntiDiagBelowTop; + reachableBottom |= betweenColAboveBottom | betweenColBelowBottom | betweenRowAboveBottom | betweenRowBelowBottom | betweenDiagAboveBottom | betweenDiagBelowBottom | betweenAntiDiagAboveBottom | betweenAntiDiagBelowBottom; + } + + // Return the reachable squares + BitBoard out = new BitBoard(); + out.setArrowTop(reachableTop); + out.setArrowBottom(reachableBottom); + return out; +// return new long[]{reachableTop, reachableBottom}; + } + + private static int getCol(long startBottom, long startTop) { + if (startBottom == 0) + return Long.bitCount(startTop - 1) % 10; + else + return Long.bitCount(startBottom - 1) % 10; + } + + private static int getRow(long startBottom, long startTop) { + if (startBottom == 0) + return Long.bitCount(startTop - 1) / 10 + 5; + else + return Long.bitCount(startBottom - 1) / 10; + } + + public static State randomState() { Random r = new Random(); - ArrayList states = new ArrayList<>(); - states.add(s); - int numStates = 10000; - for (int i = 0; i < numStates; i++) { - State randomState = states.get(r.nextInt(states.size())); - ArrayList actions = ActionGenerator.generateActions(randomState, r.nextInt(1) + 1, 1); - states.add(new State(s, actions.get(r.nextInt(actions.size())))); - } - - long start = System.currentTimeMillis(); - for (State state : states) - Heuristics.bigPoppa(state, 1); - long end = System.currentTimeMillis(); - System.out.println((double)(end - start) / numStates + "ms"); + ArrayList board = new ArrayList<>(100); + board.add(1); + board.add(1); + board.add(1); + board.add(1); + board.add(2); + board.add(2); + board.add(2); + board.add(2); + int numArrows = r.nextInt(30) + 10; + for (int i = 0; i < numArrows; i++) { + board.add(3); + } + while (board.size() < 100) { + board.add(0); + } + Collections.shuffle(board); + // Turn into an 11 by 11 arraylist + ArrayList newBoard = new ArrayList<>(121); + for (int y = 0; y < 11; y++) { + for (int x = 0; x < 11; x++) { + if (x == 0 || y == 0) { + newBoard.add(0); + continue; + } + newBoard.add(board.get((y - 1) * 10 + x - 1)); + } + } + return new State(newBoard); } } diff --git a/src/main/java/Tree/ExpansionPolicy.java b/src/main/java/Tree/ExpansionPolicy.java index 694a55d..abfca59 100644 --- a/src/main/java/Tree/ExpansionPolicy.java +++ b/src/main/java/Tree/ExpansionPolicy.java @@ -6,26 +6,76 @@ import State.BitBoard; import java.util.ArrayList; +import java.util.Arrays; import java.util.PriorityQueue; public class ExpansionPolicy { public static Node[] expansionNode(Node node, int numToExpand) { - return bitBoardLibertyExpansionPolicy(node, numToExpand); +// return bitBoardLibertyExpansionPolicy(node, numToExpand); +// return heuristicExpansion(node, numToExpand); + return randomExpansion(node, numToExpand); } - private static Node randomExpansion(Node node) { - int randomInt = (int) (Math.random() * node.getPossibleActions().length); - while (node.getChildren()[randomInt] != null) { - randomInt = (int) (Math.random() * node.getPossibleActions().length); + private static Node[] heuristicExpansion(Node node, int numToExpand) { + int colour = node.getColour(); + Action[] actions = node.getPossibleActions(); + Node[] children = node.getChildren(); + int definitelyTheBestAction = 0; + double bestH = Math.pow(-1, colour) * Integer.MAX_VALUE; + for (int i = 0; i < actions.length; i++) { + if (children[i] == null) { + Action a = actions[i]; + double h = Heuristics.bigPoppa(new State(node.getState(), a), colour); + if (colour == 1 && h > bestH) { + definitelyTheBestAction = i; + bestH = h; + } else if (colour == 2 && h < bestH) { + definitelyTheBestAction = i; + bestH = h; + } + } + } + + State state = node.getState(); + colour = node.getColour() == State.BLACK_QUEEN ? State.WHITE_QUEEN : State.BLACK_QUEEN; + State newState = new State(state, actions[definitelyTheBestAction]); + Action[] newActions = ActionGenerator.generateActions(newState, colour).toArray(new Action[0]); + Node expansion = new Node(newState, actions[definitelyTheBestAction], node, colour, 0, 0, newActions, node.getDepth() + 1); + node.getChildren()[definitelyTheBestAction] = expansion; + + return new Node[]{expansion}; + } + + private static Node[] randomExpansion(Node node, int numToExpand) { + // Count the number of children that are null + int numNull = 0; + for (Node child : node.getChildren()) { + if (child == null) { + numNull++; + } + } + + if (numNull < numToExpand) { + numToExpand = numNull; + } + + ArrayList expanded = new ArrayList<>(); + for (int i = 0; i < numToExpand; i++) { + int randomInt = (int) (Math.random() * node.getPossibleActions().length); + while (node.getChildren()[randomInt] != null) { + randomInt = (int) (Math.random() * node.getPossibleActions().length); + } + Action randomAction = node.getPossibleActions()[randomInt]; + State state = new State(node.getState(), randomAction); + int colour = node.getColour() == State.BLACK_QUEEN ? State.WHITE_QUEEN : State.BLACK_QUEEN; + Action[] actions = ActionGenerator.generateActions(state, colour).toArray(new Action[0]); + Node expansion = new Node(state, randomAction, node, colour, 0, 0, actions, node.getDepth() + 1); + node.getChildren()[randomInt] = expansion; + + expanded.add(expansion); } - Action randomAction = node.getPossibleActions()[randomInt]; - State state = new State(node.getState(), randomAction); - int colour = node.getColour() == State.BLACK_QUEEN ? State.WHITE_QUEEN : State.BLACK_QUEEN; - Action[] actions = ActionGenerator.generateActions(state, colour).toArray(new Action[0]); - Node expansion = new Node(state, randomAction, node, colour, 0, 0, actions, node.getDepth() + 1); - node.getChildren()[randomInt] = expansion; - return expansion; + return expanded.toArray(new Node[0]); } private static Node[] bitBoardLibertyExpansionPolicy(Node node, int numToExpand) { diff --git a/src/main/java/Tree/Heuristics.java b/src/main/java/Tree/Heuristics.java index 4b2f569..ced1ef1 100644 --- a/src/main/java/Tree/Heuristics.java +++ b/src/main/java/Tree/Heuristics.java @@ -4,6 +4,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; enum Direction { DR, @@ -21,21 +22,35 @@ enum Direction { * Heuristics from this paper */ public class Heuristics { + private static final HashMap previouslyEvaluated = new HashMap<>(); + public static double bigPoppa(State s, int playerToMove) { - int[][] D1 = D(1, s); - int[][] D2 = D(2, s); + // Check to see if we have previously evaluated this state + long hash = ZobristHash.zobristHash(s.getBitBoard(), playerToMove); + Double heuristic = previouslyEvaluated.get(hash); + if (heuristic != null) + return heuristic; + + // Else evaluate it + int[][] D1 = D1(s); + int[][] D2 = D2(s); double t1 = t(D1, playerToMove); double t2 = t(D2, playerToMove); +// start = System.nanoTime(); double c1 = c(1, D1); double c2 = c(2, D2); double w = w(D1); - double[] f = f(w); + double[] f = planBf(w); + + heuristic = f[0] * t1 + f[1] * c1 + f[2] * c2 + f[3] * t2; - return f[0] * t1 + f[1] * c1 + f[2] * c2 + f[3] * t2; + previouslyEvaluated.put(hash, heuristic); + + return heuristic; } private static double[] f(double w) { @@ -51,10 +66,18 @@ private static double[] f(double w) { return f; } + private static double f1(double w, double a) { return 4 * Math.exp(-a * w) / Math.pow(1 + Math.exp(-a * w), 2); } + private static double[] planBf(double w) { + double[] f = new double[4]; + f[0] = (100 - w) / 100; + f[1] = f[2] = f[3] = (1 - f[0]) / 3; + return f; + } + private static double w(int[][] D) { double w = 0; for (int a = 0; a < 100; a++) @@ -295,6 +318,733 @@ else if (m == 6) return new int[][]{blackDistances, whiteDistances}; } + /** + * Returns the minimum number of moves required to reach each square from the given state. + * @param s The state to start from + * @return A 2D array of integers. The first index is the color, the second is the square. + */ + private static int[][] D1(State s) { + // Clone input so as not to modify it + BitBoard input; + try { + input = (BitBoard) s.getBitBoard().clone(); + } catch (CloneNotSupportedException e) { + throw new RuntimeException(e); + } + + // Initialize reachable array. Indexed as [color][moveNum][top/bottom] + long[][][] reachable = new long[2][10][2]; + BitBoard result; + + for (int nMoves = 0; nMoves < 10; nMoves++) { + // Find all reachable squares + result = queenReachableInOneMove(State.BLACK_QUEEN, input); + + reachable[0][nMoves][0] = result.getArrowTop();// result[0]; + reachable[0][nMoves][1] = result.getArrowBottom();// result[1]; + + if (nMoves > 0) { + if (result.getArrowTop() == 0L && result.getArrowBottom() == 0L) { + break; + } + } + + input.setArrowTop(input.getArrowTop() | input.getBlackQueensTop()); + input.setArrowBottom(input.getArrowBottom() | input.getBlackQueensBottom()); + input.setBlackQueensTop(result.getArrowTop()); + input.setBlackQueensBottom(result.getArrowBottom()); + } + + try { + input = (BitBoard) s.getBitBoard().clone(); + } catch (CloneNotSupportedException e) { + throw new RuntimeException(e); + } + + for (int nMoves = 0; nMoves < 10; nMoves++) { + // Find all reachable squares + result = queenReachableInOneMove(State.WHITE_QUEEN, input); + + reachable[1][nMoves][0] = result.getArrowTop();// result[0]; + reachable[1][nMoves][1] = result.getArrowTop();// result[1]; + + if (nMoves > 0) { + if (result.getArrowTop() == 0L && result.getArrowBottom() == 0L) { + break; + } + } + + input.setArrowTop(input.getArrowTop() | input.getWhiteQueensTop()); + input.setArrowBottom(input.getArrowBottom() | input.getWhiteQueensBottom()); + input.setWhiteQueensTop(result.getArrowTop()); + input.setWhiteQueensBottom(result.getArrowBottom()); + } + + int[][] output = new int[2][100]; + + for (int color = 0; color < 2; color++) { + loop: for (int i = 0; i < 100; i++) { + for (int nMoves = 0; nMoves < 10; nMoves++) { + if (i < 50 && (reachable[color][nMoves][1] & (1L << i)) != 0) { + output[color][i] = nMoves + 1; + continue loop; + } else if (i >= 50 && (reachable[color][nMoves][0] & (1L << (i - 50))) != 0) { + output[color][i] = nMoves + 1; + continue loop; + } + } + output[color][i] = 1000000; + } + } + + return output; + } + + /** + * Returns a bitboard with all the reachable positions in one move from the given board. Uses almost only bitwise operations, making it very fast. + * + * @param b The board to find reachable positions from. + * @return A bitboard with all the reachable positions in one move from the given board for both colors. + */ + private static BitBoard queenReachableInOneMove(int color, BitBoard b) { + // Constants + long columnMask = 0b00000000010000000001000000000100000000010000000001L; + long rowMask = 0b1111111111L; + long diagonalMask = 0b10000000000100000000001000000000010000000000100000000001L; + long antiDiagonalMask = 0b000000001000000001000000001000000001000000001L; + long boardMask = -1L >>> (64 - 50); + + long blackTop = b.getBlackQueensTop(); + long blackBottom = b.getBlackQueensBottom(); + long whiteTop = b.getWhiteQueensTop(); + long whiteBottom = b.getWhiteQueensBottom(); + long arrowTop = b.getArrowTop(); + long arrowBottom = b.getArrowBottom(); + + long occupiedTop = blackTop | whiteTop | arrowTop; + long occupiedBottom = blackBottom | whiteBottom | arrowBottom; + + // All squares reachable in one move + long reachableTop = 0L; + long reachableBottom = 0L; + + int queenCount; + long[] queensTop, queensBottom; + + if (color == State.BLACK_QUEEN) { + // Find black queens + queenCount = Long.bitCount(blackTop) + Long.bitCount(blackBottom); + queensTop = new long[queenCount]; + queensBottom = new long[queenCount]; + for (int i = 0; i < queenCount; i++) { + if (blackBottom > 1) { + queensBottom[i] = Long.lowestOneBit(blackBottom); + queensTop[i] = 0L; + blackBottom ^= queensBottom[i]; + } else { + queensBottom[i] = 0L; + queensTop[i] = Long.lowestOneBit(blackTop); + blackTop ^= queensTop[i]; + } + } + } else { + // Find white queens + queenCount = Long.bitCount(whiteTop) + Long.bitCount(whiteBottom); + queensTop = new long[queenCount]; + queensBottom = new long[queenCount]; + for (int i = 0; i < queenCount; i++) { + if (whiteBottom > 1) { + queensBottom[i] = Long.lowestOneBit(whiteBottom); + queensTop[i] = 0L; + whiteBottom ^= queensBottom[i]; + } else { + queensBottom[i] = 0L; + queensTop[i] = Long.lowestOneBit(whiteTop); + whiteTop ^= queensTop[i]; + } + } + } + + for (int pieceNum = 0; pieceNum < queenCount; pieceNum++) { + + // The piece we are moving + long startBottom = queensBottom[pieceNum]; + long startTop = queensTop[pieceNum]; + + int startRow = getRow(startBottom, startTop); + int startCol = getCol(startBottom, startTop); + + // All pieces except the one we are moving + long currentOccupiedTop = occupiedTop ^ startTop; + long currentOccupiedBottom = occupiedBottom ^ startBottom; + + + /////////////////////////////////////////// + // Generate masks along the 4 directions // + /////////////////////////////////////////// + + // Horizontal + long rowMaskTop, rowMaskBottom; + if (startRow < 5) { + rowMaskTop = 0L; + rowMaskBottom = rowMask << (startRow * 10); + } else { + rowMaskTop = rowMask << ((startRow - 5) * 10); + rowMaskBottom = 0L; + } + + // Vertical + long colMaskTop = columnMask << startCol; + long colMaskBottom = columnMask << startCol; + + // Diagonal + long diagMaskTop, diagMaskBottom; + int diagShift = startCol - startRow; + if (diagShift >= 0) { + if (diagShift >= 5) { + diagMaskTop = 0L; + diagMaskBottom = (diagonalMask << diagShift) & ~(-1L << ((10 - diagShift) * 10)); + } else { + diagMaskTop = (diagonalMask << (diagShift + 5)) & ~(-1L << ((5 - diagShift) * 10)); + diagMaskBottom = diagonalMask << diagShift; + } + } else { + if (diagShift < -5) { + diagMaskTop = (diagonalMask >>> (-diagShift + 6)) & (-1L << ((-diagShift - 5) * 10)); + diagMaskBottom = 0L; + } else { + diagMaskTop = diagonalMask >>> (-diagShift + 6); + diagMaskBottom = (diagonalMask >>> (-diagShift)) & (-1L << (-diagShift * 10)); + } + } + + long antiDiagMaskTop, antiDiagMaskBottom; + int antiDiagShift = startCol + startRow; + if (antiDiagShift >= 9) { + if (antiDiagShift > 13) { + antiDiagMaskTop = (antiDiagonalMask << (antiDiagShift - 5)) & (-1L << ((antiDiagShift - 14) * 10 + 1)); + antiDiagMaskBottom = 0L; + } else { + antiDiagMaskTop = (antiDiagonalMask << (antiDiagShift - 5)); + antiDiagMaskBottom = (antiDiagonalMask << (antiDiagShift)) & (-1L << ((antiDiagShift - 9) * 10 + 1)); + } + } else { + if (antiDiagShift < 4) { + antiDiagMaskTop = 0L; + antiDiagMaskBottom = (antiDiagonalMask << (antiDiagShift)) & ~(-1L << ((antiDiagShift + 1) * 10 - 1)); + } else { + antiDiagMaskTop = (antiDiagonalMask << (antiDiagShift - 5)) & ~(-1L << ((antiDiagShift - 4) * 10 - 1)); + antiDiagMaskBottom = (antiDiagonalMask << (antiDiagShift)); + } + } + + // Make sure that all masks are only on the board + rowMaskTop &= boardMask; + rowMaskBottom &= boardMask; + colMaskTop &= boardMask; + colMaskBottom &= boardMask; + diagMaskTop &= boardMask; + diagMaskBottom &= boardMask; + antiDiagMaskTop &= boardMask; + antiDiagMaskBottom &= boardMask; + + + //////////////////////// + // Get pieces // + //////////////////////// + + // Get the possible pieces below the current piece + long piecesBelowBottom, piecesBelowTop; + piecesBelowBottom = startBottom - 1; + if (startTop == 0) + piecesBelowTop = 0L; + else + piecesBelowTop = startTop - 1; + + // Get the possible pieces above the current piece + long piecesAboveBottom = ~startBottom ^ piecesBelowBottom; + long piecesAboveTop = ~startTop ^ piecesBelowTop; + + + // Get the pieces that are along the path of the current piece in each direction + long occupiedRowTop = currentOccupiedTop & rowMaskTop; + long occupiedRowBottom = currentOccupiedBottom & rowMaskBottom; + long occupiedColTop = currentOccupiedTop & colMaskTop; + long occupiedColBottom = currentOccupiedBottom & colMaskBottom; + long occupiedDiagTop = currentOccupiedTop & diagMaskTop; + long occupiedDiagBottom = currentOccupiedBottom & diagMaskBottom; + long occupiedAntiDiagTop = currentOccupiedTop & antiDiagMaskTop; + long occupiedAntiDiagBottom = currentOccupiedBottom & antiDiagMaskBottom; + + + //////////////////////// + // Get blocking piece // + //////////////////////// + + // Get the first blocking piece in col above the current piece + long blockingPieceColAboveBottom, blockingPieceColAboveTop; + if (startTop != 0) { + blockingPieceColAboveBottom = 0L; + blockingPieceColAboveTop = Long.lowestOneBit(occupiedColTop & piecesAboveTop); + } else { + blockingPieceColAboveBottom = Long.lowestOneBit(occupiedColBottom & piecesAboveBottom); + if (blockingPieceColAboveBottom == 0) + blockingPieceColAboveTop = Long.lowestOneBit(occupiedColTop & piecesAboveTop); + else + blockingPieceColAboveTop = 0L; + } + + // Get the first blocking piece in col below the current piece + long blockingPieceColBelowBottom, blockingPieceColBelowTop; + if (startTop == 0) { + blockingPieceColBelowTop = 0L; + blockingPieceColBelowBottom = Long.highestOneBit(occupiedColBottom & piecesBelowBottom); + } else { + blockingPieceColBelowTop = Long.highestOneBit(occupiedColTop & piecesBelowTop); + if (blockingPieceColBelowTop == 0) + blockingPieceColBelowBottom = Long.highestOneBit(occupiedColBottom & piecesBelowBottom); + else + blockingPieceColBelowBottom = 0L; + } + + // Get the first blocking piece in row above the current piece + long blockingPieceRowAboveBottom, blockingPieceRowAboveTop; + if (startTop != 0) { + blockingPieceRowAboveBottom = 0L; + blockingPieceRowAboveTop = Long.lowestOneBit(occupiedRowTop & piecesAboveTop); + } else { + blockingPieceRowAboveBottom = Long.lowestOneBit(occupiedRowBottom & piecesAboveBottom); + blockingPieceRowAboveTop = 0L; + } + + // Get the first blocking piece in row below the current piece + long blockingPieceRowBelowBottom, blockingPieceRowBelowTop; + if (startTop == 0) { + blockingPieceRowBelowTop = 0L; + blockingPieceRowBelowBottom = Long.highestOneBit(occupiedRowBottom & piecesBelowBottom); + } else { + blockingPieceRowBelowTop = Long.highestOneBit(occupiedRowTop & piecesBelowTop); + blockingPieceRowBelowBottom = 0L; + } + + // Get the first blocking piece in diag above the current piece + long blockingPieceDiagAboveBottom, blockingPieceDiagAboveTop; + if (startTop != 0) { + blockingPieceDiagAboveBottom = 0L; + blockingPieceDiagAboveTop = Long.lowestOneBit(occupiedDiagTop & piecesAboveTop); + } else { + blockingPieceDiagAboveBottom = Long.lowestOneBit(occupiedDiagBottom & piecesAboveBottom); + if (blockingPieceDiagAboveBottom == 0) + blockingPieceDiagAboveTop = Long.lowestOneBit(occupiedDiagTop & piecesAboveTop); + else + blockingPieceDiagAboveTop = 0L; + } + + // Get the first blocking piece in diag below the current piece + long blockingPieceDiagBelowBottom, blockingPieceDiagBelowTop; + if (startTop == 0) { + blockingPieceDiagBelowTop = 0L; + blockingPieceDiagBelowBottom = Long.highestOneBit(occupiedDiagBottom & piecesBelowBottom); + } else { + blockingPieceDiagBelowTop = Long.highestOneBit(occupiedDiagTop & piecesBelowTop); + if (blockingPieceDiagBelowTop == 0) + blockingPieceDiagBelowBottom = Long.highestOneBit(occupiedDiagBottom & piecesBelowBottom); + else + blockingPieceDiagBelowBottom = 0L; + } + + // Get the first blocking piece in anti-diag above the current piece + long blockingPieceAntiDiagAboveBottom, blockingPieceAntiDiagAboveTop; + if (startTop != 0) { + blockingPieceAntiDiagAboveBottom = 0L; + blockingPieceAntiDiagAboveTop = Long.lowestOneBit(occupiedAntiDiagTop & piecesAboveTop); + } else { + blockingPieceAntiDiagAboveBottom = Long.lowestOneBit(occupiedAntiDiagBottom & piecesAboveBottom); + if (blockingPieceAntiDiagAboveBottom == 0) + blockingPieceAntiDiagAboveTop = Long.lowestOneBit(occupiedAntiDiagTop & piecesAboveTop); + else + blockingPieceAntiDiagAboveTop = 0L; + } + + // Get the first blocking piece in anti-diag below the current piece + long blockingPieceAntiDiagBelowBottom, blockingPieceAntiDiagBelowTop; + if (startTop == 0) { + blockingPieceAntiDiagBelowTop = 0L; + blockingPieceAntiDiagBelowBottom = Long.highestOneBit(occupiedAntiDiagBottom & piecesBelowBottom); + } else { + blockingPieceAntiDiagBelowTop = Long.highestOneBit(occupiedAntiDiagTop & piecesBelowTop); + if (blockingPieceAntiDiagBelowTop == 0) + blockingPieceAntiDiagBelowBottom = Long.highestOneBit(occupiedAntiDiagBottom & piecesBelowBottom); + else + blockingPieceAntiDiagBelowBottom = 0L; + } + + + /////////////////////////// + // Get pieces in between // + /////////////////////////// + + // Get squares movable to in col above the current piece + long betweenColAboveBottom, betweenColAboveTop; + if (startTop != 0) { + betweenColAboveBottom = 0L; + if (blockingPieceColAboveTop == 0) + betweenColAboveTop = piecesAboveTop & colMaskTop; + else + betweenColAboveTop = piecesAboveTop & (blockingPieceColAboveTop - 1) & colMaskTop; + } else { + if (blockingPieceColAboveBottom == 0) { + betweenColAboveBottom = piecesAboveBottom & colMaskBottom; + betweenColAboveTop = (blockingPieceColAboveTop - 1) & colMaskTop; + } else { + betweenColAboveBottom = piecesAboveBottom & (blockingPieceColAboveBottom - 1) & colMaskBottom; + betweenColAboveTop = 0L; + } + } + + // Get squares movable to in col below the current piece + long betweenColBelowBottom, betweenColBelowTop; + if (startTop == 0) { + betweenColBelowTop = 0L; + if (blockingPieceColBelowBottom == 0) + betweenColBelowBottom = piecesBelowBottom & colMaskBottom; + else + betweenColBelowBottom = piecesBelowBottom & (-blockingPieceColBelowBottom ^ blockingPieceColBelowBottom) & colMaskBottom; + } else { + if (blockingPieceColBelowTop == 0) { + betweenColBelowTop = piecesBelowTop & colMaskTop; + if (blockingPieceColBelowBottom == 0) + betweenColBelowBottom = piecesBelowBottom & colMaskBottom; + else + betweenColBelowBottom = (-blockingPieceColBelowBottom ^ blockingPieceColBelowBottom) & colMaskBottom; + } else { + betweenColBelowTop = piecesBelowTop & (-blockingPieceColBelowTop ^ blockingPieceColBelowTop) & colMaskTop; + betweenColBelowBottom = 0L; + } + } + + // Get squares movable to in row above the current piece + long betweenRowAboveBottom, betweenRowAboveTop; + if (startTop != 0) { + betweenRowAboveBottom = 0L; + if (blockingPieceRowAboveTop == 0) + betweenRowAboveTop = piecesAboveTop & rowMaskTop; + else + betweenRowAboveTop = piecesAboveTop & (blockingPieceRowAboveTop - 1) & rowMaskTop; + } else { + if (blockingPieceRowAboveBottom == 0) { + betweenRowAboveBottom = piecesAboveBottom & rowMaskBottom; + } else { + betweenRowAboveBottom = piecesAboveBottom & (blockingPieceRowAboveBottom - 1) & rowMaskBottom; + } + betweenRowAboveTop = 0L; + } + + // Get squares movable to in row below the current piece + long betweenRowBelowBottom, betweenRowBelowTop; + if (startTop == 0) { + betweenRowBelowTop = 0L; + if (blockingPieceRowBelowBottom == 0) + betweenRowBelowBottom = piecesBelowBottom & rowMaskBottom; + else + betweenRowBelowBottom = piecesBelowBottom & (-blockingPieceRowBelowBottom ^ blockingPieceRowBelowBottom) & rowMaskBottom; + } else { + if (blockingPieceRowBelowTop == 0) { + betweenRowBelowTop = piecesBelowTop & rowMaskTop; + } else { + betweenRowBelowTop = piecesBelowTop & (-blockingPieceRowBelowTop ^ blockingPieceRowBelowTop) & rowMaskTop; + } + betweenRowBelowBottom = 0L; + } + + // Get squares movable to in diag above the current piece + long betweenDiagAboveBottom, betweenDiagAboveTop; + if (startTop != 0) { + betweenDiagAboveBottom = 0L; + if (blockingPieceDiagAboveTop == 0) + betweenDiagAboveTop = piecesAboveTop & diagMaskTop; + else + betweenDiagAboveTop = piecesAboveTop & (blockingPieceDiagAboveTop - 1) & diagMaskTop; + } else { + if (blockingPieceDiagAboveBottom == 0) { + betweenDiagAboveBottom = piecesAboveBottom & diagMaskBottom; + betweenDiagAboveTop = (blockingPieceDiagAboveTop - 1) & diagMaskTop; + } else { + betweenDiagAboveBottom = piecesAboveBottom & (blockingPieceDiagAboveBottom - 1) & diagMaskBottom; + betweenDiagAboveTop = 0L; + } + } + + // Get squares movable to in diag below the current piece + long betweenDiagBelowBottom, betweenDiagBelowTop; + if (startTop == 0) { + betweenDiagBelowTop = 0L; + if (blockingPieceDiagBelowBottom == 0) + betweenDiagBelowBottom = piecesBelowBottom & diagMaskBottom; + else + betweenDiagBelowBottom = piecesBelowBottom & (-blockingPieceDiagBelowBottom ^ blockingPieceDiagBelowBottom) & diagMaskBottom; + } else { + if (blockingPieceDiagBelowTop == 0) { + betweenDiagBelowTop = piecesBelowTop & diagMaskTop; + if (blockingPieceDiagBelowBottom == 0) + betweenDiagBelowBottom = piecesBelowBottom & diagMaskBottom; + else + betweenDiagBelowBottom = (-blockingPieceDiagBelowBottom ^ blockingPieceDiagBelowBottom) & diagMaskBottom; + } else { + betweenDiagBelowTop = piecesBelowTop & (-blockingPieceDiagBelowTop ^ blockingPieceDiagBelowTop) & diagMaskTop; + betweenDiagBelowBottom = 0L; + } + } + + // Get squares movable to in anti-diag above the current piece + long betweenAntiDiagAboveBottom, betweenAntiDiagAboveTop; + if (startTop != 0) { + betweenAntiDiagAboveBottom = 0L; + if (blockingPieceAntiDiagAboveTop == 0) + betweenAntiDiagAboveTop = piecesAboveTop & antiDiagMaskTop; + else + betweenAntiDiagAboveTop = piecesAboveTop & (blockingPieceAntiDiagAboveTop - 1) & antiDiagMaskTop; + } else { + if (blockingPieceAntiDiagAboveBottom == 0) { + betweenAntiDiagAboveBottom = piecesAboveBottom & antiDiagMaskBottom; + betweenAntiDiagAboveTop = (blockingPieceAntiDiagAboveTop - 1) & antiDiagMaskTop; + } else { + betweenAntiDiagAboveBottom = piecesAboveBottom & (blockingPieceAntiDiagAboveBottom - 1) & antiDiagMaskBottom; + betweenAntiDiagAboveTop = 0L; + } + } + + // Get squares movable to in anti-diag below the current piece + long betweenAntiDiagBelowBottom, betweenAntiDiagBelowTop; + if (startTop == 0) { + betweenAntiDiagBelowTop = 0L; + if (blockingPieceAntiDiagBelowBottom == 0) + betweenAntiDiagBelowBottom = piecesBelowBottom & antiDiagMaskBottom; + else + betweenAntiDiagBelowBottom = piecesBelowBottom & (-blockingPieceAntiDiagBelowBottom ^ blockingPieceAntiDiagBelowBottom) & antiDiagMaskBottom; + } else { + if (blockingPieceAntiDiagBelowTop == 0) { + betweenAntiDiagBelowTop = piecesBelowTop & antiDiagMaskTop; + if (blockingPieceAntiDiagBelowBottom == 0) + betweenAntiDiagBelowBottom = piecesBelowBottom & antiDiagMaskBottom; + else + betweenAntiDiagBelowBottom = (-blockingPieceAntiDiagBelowBottom ^ blockingPieceAntiDiagBelowBottom) & antiDiagMaskBottom; + } else { + betweenAntiDiagBelowTop = piecesBelowTop & (-blockingPieceAntiDiagBelowTop ^ blockingPieceAntiDiagBelowTop) & antiDiagMaskTop; + betweenAntiDiagBelowBottom = 0L; + } + } + + // Or results onto the total reachable squares in pne move + reachableTop |= betweenColAboveTop | betweenColBelowTop | betweenRowAboveTop | betweenRowBelowTop | betweenDiagAboveTop | betweenDiagBelowTop | betweenAntiDiagAboveTop | betweenAntiDiagBelowTop; + reachableBottom |= betweenColAboveBottom | betweenColBelowBottom | betweenRowAboveBottom | betweenRowBelowBottom | betweenDiagAboveBottom | betweenDiagBelowBottom | betweenAntiDiagAboveBottom | betweenAntiDiagBelowBottom; + } + + // Return the reachable squares + BitBoard out = new BitBoard(); + out.setArrowTop(reachableTop); + out.setArrowBottom(reachableBottom); + return out; +// return new long[]{reachableTop, reachableBottom}; + } + + private static int[][] D2(State s) { + // Clone input so as not to modify it + BitBoard input; + try { + input = (BitBoard) s.getBitBoard().clone(); + } catch (CloneNotSupportedException e) { + throw new RuntimeException(e); + } + + // Initialize reachable array. Indexed as [color][moveNum][top/bottom] + long[][][] reachable = new long[2][10][2]; + BitBoard result; + + for (int nMoves = 0; nMoves < 10; nMoves++) { + // Find all reachable squares + result = kingReachableInOneMove(State.BLACK_QUEEN, input); + + reachable[0][nMoves][0] = result.getArrowTop();// result[0]; + reachable[0][nMoves][1] = result.getArrowBottom();// result[1]; + + if (nMoves > 0) { + if (result.getArrowTop() == 0L && result.getArrowBottom() == 0L) { + break; + } + } + + input.setArrowTop(input.getArrowTop() | input.getBlackQueensTop()); + input.setArrowBottom(input.getArrowBottom() | input.getBlackQueensBottom()); + input.setBlackQueensTop(result.getArrowTop()); + input.setBlackQueensBottom(result.getArrowBottom()); + } + + try { + input = (BitBoard) s.getBitBoard().clone(); + } catch (CloneNotSupportedException e) { + throw new RuntimeException(e); + } + + for (int nMoves = 0; nMoves < 10; nMoves++) { + // Find all reachable squares + result = kingReachableInOneMove(State.WHITE_QUEEN, input); + + reachable[1][nMoves][0] = result.getArrowTop();// result[0]; + reachable[1][nMoves][1] = result.getArrowTop();// result[1]; + + if (nMoves > 0) { + if (result.getArrowTop() == 0L && result.getArrowBottom() == 0L) { + break; + } + } + + input.setArrowTop(input.getArrowTop() | input.getWhiteQueensTop()); + input.setArrowBottom(input.getArrowBottom() | input.getWhiteQueensBottom()); + input.setWhiteQueensTop(result.getArrowTop()); + input.setWhiteQueensBottom(result.getArrowBottom()); + } + + int[][] output = new int[2][100]; + + for (int color = 0; color < 2; color++) { + loop: for (int i = 0; i < 100; i++) { + for (int nMoves = 0; nMoves < 10; nMoves++) { + if (i < 50 && (reachable[color][nMoves][1] & (1L << i)) != 0) { + output[color][i] = nMoves + 1; + continue loop; + } else if (i >= 50 && (reachable[color][nMoves][0] & (1L << (i - 50))) != 0) { + output[color][i] = nMoves + 1; + continue loop; + } + } + output[color][i] = 1000000; + } + } + + return output; + } + + private static BitBoard kingReachableInOneMove(int color, BitBoard b) { + long boardMask = -1L >>> (64 - 50); + long aroundMask = 0b11100000001010000000111L; + long notAFile = 0b11111111101111111110111111111011111111101111111110L; + long notJFile = 0b01111111110111111111011111111101111111110111111111L; + + long blackTop = b.getBlackQueensTop(); + long blackBottom = b.getBlackQueensBottom(); + long whiteTop = b.getWhiteQueensTop(); + long whiteBottom = b.getWhiteQueensBottom(); + long arrowTop = b.getArrowTop(); + long arrowBottom = b.getArrowBottom(); + + long occupiedTop = blackTop | whiteTop | arrowTop; + long occupiedBottom = blackBottom | whiteBottom | arrowBottom; + + // All squares reachable in one move + long reachableTop = 0L; + long reachableBottom = 0L; + + int queenCount; + long[] queensTop, queensBottom; + + if (color == State.BLACK_QUEEN) { + // Find black queens + queenCount = Long.bitCount(blackTop) + Long.bitCount(blackBottom); + queensTop = new long[queenCount]; + queensBottom = new long[queenCount]; + for (int i = 0; i < queenCount; i++) { + if (blackBottom > 1) { + queensBottom[i] = Long.lowestOneBit(blackBottom); + queensTop[i] = 0L; + blackBottom ^= queensBottom[i]; + } else { + queensBottom[i] = 0L; + queensTop[i] = Long.lowestOneBit(blackTop); + blackTop ^= queensTop[i]; + } + } + } else { + // Find white queens + queenCount = Long.bitCount(whiteTop) + Long.bitCount(whiteBottom); + queensTop = new long[queenCount]; + queensBottom = new long[queenCount]; + for (int i = 0; i < queenCount; i++) { + if (whiteBottom > 1) { + queensBottom[i] = Long.lowestOneBit(whiteBottom); + queensTop[i] = 0L; + whiteBottom ^= queensBottom[i]; + } else { + queensBottom[i] = 0L; + queensTop[i] = Long.lowestOneBit(whiteTop); + whiteTop ^= queensTop[i]; + } + } + } + + + // Find all reachable squares + for (int pieceNum = 0; pieceNum < queenCount; pieceNum++) { + + // The piece we are moving + long startBottom = queensBottom[pieceNum]; + long startTop = queensTop[pieceNum]; + + /////////////////// + // Generate mask // + /////////////////// + + int index; + if (startTop > 0) { + index = 50 + Long.numberOfTrailingZeros(startTop); + } else { + index = Long.numberOfTrailingZeros(startBottom); + } + + long maskBottom, maskTop; + if (index > 39) { + if (index < 61) + maskTop = aroundMask >>> (61 - index); + else + maskTop = aroundMask << (index - 61); + } else { + maskTop = 0L; + } + + if (index < 61) { + if (index < 11) { + maskBottom = aroundMask >>> (11 - index); + } else { + maskBottom = aroundMask << (index - 11); + } + } else { + maskBottom = 0L; + } + + if (index % 10 == 0) { + maskTop &= notJFile; + maskBottom &= notJFile; + } else if (index % 10 == 9) { + maskTop &= notAFile; + maskBottom &= notAFile; + } + + maskTop &= boardMask; + maskBottom &= boardMask; + + + // Actually get squares we can move to + reachableTop |= ~occupiedTop & maskTop; + reachableBottom |= ~occupiedBottom & maskBottom; + } + + BitBoard out = new BitBoard(); + out.setArrowTop(reachableTop); + out.setArrowBottom(reachableBottom); + + return out; + } + private static double delta(int blackDistance, int whiteDistance, int nextPlayer) { final double k = 0.2; if (blackDistance >= 1000000 && whiteDistance >= 1000000) @@ -368,6 +1118,522 @@ private static int[] getSurroundingValues(int index, int[] distanceBoard) { return values; } + public static BitBoard queenReachableInOneMove(BitBoard b) { + // Constants + long columnMask = 0b00000000010000000001000000000100000000010000000001L; + long rowMask = 0b1111111111L; + long diagonalMask = 0b10000000000100000000001000000000010000000000100000000001L; + long antiDiagonalMask = 0b000000001000000001000000001000000001000000001L; + long boardMask = -1L >>> (64 - 50); + + long whiteTop = b.getWhiteQueensTop(); + long whiteBottom = b.getWhiteQueensBottom(); + long blackTop = b.getBlackQueensTop(); + long blackBottom = b.getBlackQueensBottom(); + long arrowTop = b.getArrowTop(); + long arrowBottom = b.getArrowBottom(); + + // All squares reachable in one move + long blackReachableTop = 0L; + long blackReachableBottom = 0L; + long whiteReachableTop = 0L; + long whiteReachableBottom = 0L; + + // Find black queens + long[] blackQueensTop = new long[4]; + long[] blackQueensBottom = new long[4]; + if (blackBottom > 1) { + blackQueensBottom[0] = Long.lowestOneBit(blackBottom); + blackQueensTop[0] = 0L; + blackBottom ^= blackQueensBottom[0]; + } else { + blackQueensBottom[0] = 0L; + blackQueensTop[0] = Long.lowestOneBit(blackTop); + blackTop ^= blackQueensTop[0]; + } + if (blackBottom > 1) { + blackQueensBottom[1] = Long.lowestOneBit(blackBottom); + blackQueensTop[1] = 0L; + blackBottom ^= blackQueensBottom[1]; + } else { + blackQueensBottom[1] = 0L; + blackQueensTop[1] = Long.lowestOneBit(blackTop); + blackTop ^= blackQueensTop[1]; + } + if (blackBottom > 1) { + blackQueensBottom[2] = Long.lowestOneBit(blackBottom); + blackQueensTop[2] = 0L; + blackBottom ^= blackQueensBottom[2]; + } else { + blackQueensBottom[2] = 0L; + blackQueensTop[2] = Long.lowestOneBit(blackTop); + blackTop ^= blackQueensTop[2]; + } + if (blackBottom > 1) { + blackQueensBottom[3] = Long.lowestOneBit(blackBottom); + blackQueensTop[3] = 0L; + blackBottom ^= blackQueensBottom[3]; + } else { + blackQueensBottom[3] = 0L; + blackQueensTop[3] = Long.lowestOneBit(blackTop); + blackTop ^= blackQueensTop[3]; + } + + + // Find white queens + long[] whiteQueensTop = new long[4]; + long[] whiteQueensBottom = new long[4]; + if (whiteBottom > 1) { + whiteQueensBottom[0] = Long.lowestOneBit(whiteBottom); + whiteQueensTop[0] = 0L; + whiteBottom ^= whiteQueensBottom[0]; + } else { + whiteQueensBottom[0] = 0L; + whiteQueensTop[0] = Long.lowestOneBit(whiteTop); + whiteTop ^= whiteQueensTop[0]; + } + if (whiteBottom > 1) { + whiteQueensBottom[1] = Long.lowestOneBit(whiteBottom); + whiteQueensTop[1] = 0L; + whiteBottom ^= whiteQueensBottom[1]; + } else { + whiteQueensBottom[1] = 0L; + whiteQueensTop[1] = Long.lowestOneBit(whiteTop); + whiteTop ^= whiteQueensTop[1]; + } + if (whiteBottom > 1) { + whiteQueensBottom[2] = Long.lowestOneBit(whiteBottom); + whiteQueensTop[2] = 0L; + whiteBottom ^= whiteQueensBottom[2]; + } else { + whiteQueensBottom[2] = 0L; + whiteQueensTop[2] = Long.lowestOneBit(whiteTop); + whiteTop ^= whiteQueensTop[2]; + } + if (whiteBottom > 1) { + whiteQueensBottom[3] = Long.lowestOneBit(whiteBottom); + whiteQueensTop[3] = 0L; + whiteBottom ^= whiteQueensBottom[3]; + } else { + whiteQueensBottom[3] = 0L; + whiteQueensTop[3] = Long.lowestOneBit(whiteTop); + whiteTop ^= whiteQueensTop[3]; + } + + for (int pieceNum = 0; pieceNum < 8; pieceNum++) { + + // The piece we are moving + long startBottom, startTop; + if (pieceNum < 4) { + startBottom = blackQueensBottom[pieceNum]; + startTop = blackQueensTop[pieceNum]; + } else { + startBottom = whiteQueensBottom[pieceNum - 4]; + startTop = whiteQueensTop[pieceNum - 4]; + } + + int startRow = getRow(startBottom, startTop); + int startCol = getCol(startBottom, startTop); + + // All pieces except the one we are moving + long occupiedTop = blackTop | whiteTop | arrowTop; + long occupiedBottom = blackBottom | whiteBottom | arrowBottom; + occupiedTop ^= startTop; + occupiedBottom ^= startBottom; + + + /////////////////////////////////////////// + // Generate masks along the 4 directions // + /////////////////////////////////////////// + + // Horizontal + long rowMaskTop, rowMaskBottom; + if (startRow < 5) { + rowMaskTop = 0L; + rowMaskBottom = rowMask << (startRow * 10); + } else { + rowMaskTop = rowMask << ((startRow - 5) * 10); + rowMaskBottom = 0L; + } + + // Vertical + long colMaskTop = columnMask << startCol; + long colMaskBottom = columnMask << startCol; + + // Diagonal + long diagMaskTop, diagMaskBottom; + int diagShift = startCol - startRow; + if (diagShift >= 0) { + if (diagShift >= 5) { + diagMaskTop = 0L; + diagMaskBottom = (diagonalMask << diagShift) & ~(-1L << ((10 - diagShift) * 10)); + } else { + diagMaskTop = (diagonalMask << (diagShift + 5)) & ~(-1L << ((5 - diagShift) * 10)); + diagMaskBottom = diagonalMask << diagShift; + } + } else { + if (diagShift < -5) { + diagMaskTop = (diagonalMask >>> (-diagShift + 6)) & (-1L << ((-diagShift - 5) * 10)); + diagMaskBottom = 0L; + } else { + diagMaskTop = diagonalMask >>> (-diagShift + 6); + diagMaskBottom = (diagonalMask >>> (-diagShift)) & (-1L << (-diagShift * 10)); + } + } + + long antiDiagMaskTop, antiDiagMaskBottom; + int antiDiagShift = startCol + startRow; + if (antiDiagShift >= 9) { + if (antiDiagShift > 13) { + antiDiagMaskTop = (antiDiagonalMask << (antiDiagShift - 5)) & (-1L << ((antiDiagShift - 14) * 10 + 1)); + antiDiagMaskBottom = 0L; + } else { + antiDiagMaskTop = (antiDiagonalMask << (antiDiagShift - 5)); + antiDiagMaskBottom = (antiDiagonalMask << (antiDiagShift)) & (-1L << ((antiDiagShift - 9) * 10 + 1)); + } + } else { + if (antiDiagShift < 4) { + antiDiagMaskTop = 0L; + antiDiagMaskBottom = (antiDiagonalMask << (antiDiagShift)) & ~(-1L << ((antiDiagShift + 1) * 10 - 1)); + } else { + antiDiagMaskTop = (antiDiagonalMask << (antiDiagShift - 5)) & ~(-1L << ((antiDiagShift - 4) * 10 - 1)); + antiDiagMaskBottom = (antiDiagonalMask << (antiDiagShift)); + } + } + + // Make sure that all masks are only on the board + rowMaskTop &= boardMask; + rowMaskBottom &= boardMask; + colMaskTop &= boardMask; + colMaskBottom &= boardMask; + diagMaskTop &= boardMask; + diagMaskBottom &= boardMask; + antiDiagMaskTop &= boardMask; + antiDiagMaskBottom &= boardMask; + + + //////////////////////// + // Get pieces // + //////////////////////// + + // Get the possible pieces below the current piece + long piecesBelowBottom, piecesBelowTop; + piecesBelowBottom = startBottom - 1; + if (startTop == 0) + piecesBelowTop = 0L; + else + piecesBelowTop = startTop - 1; + + // Get the possible pieces above the current piece + long piecesAboveBottom = ~startBottom ^ piecesBelowBottom; + long piecesAboveTop = ~startTop ^ piecesBelowTop; + + + // Get the pieces that are along the path of the current piece in each direction + long occupiedRowTop = occupiedTop & rowMaskTop; + long occupiedRowBottom = occupiedBottom & rowMaskBottom; + long occupiedColTop = occupiedTop & colMaskTop; + long occupiedColBottom = occupiedBottom & colMaskBottom; + long occupiedDiagTop = occupiedTop & diagMaskTop; + long occupiedDiagBottom = occupiedBottom & diagMaskBottom; + long occupiedAntiDiagTop = occupiedTop & antiDiagMaskTop; + long occupiedAntiDiagBottom = occupiedBottom & antiDiagMaskBottom; + + + //////////////////////// + // Get blocking piece // + //////////////////////// + + // Get the first blocking piece in col above the current piece + long blockingPieceColAboveBottom, blockingPieceColAboveTop; + if (startTop != 0) { + blockingPieceColAboveBottom = 0L; + blockingPieceColAboveTop = Long.lowestOneBit(occupiedColTop & piecesAboveTop); + } else { + blockingPieceColAboveBottom = Long.lowestOneBit(occupiedColBottom & piecesAboveBottom); + if (blockingPieceColAboveBottom == 0) + blockingPieceColAboveTop = Long.lowestOneBit(occupiedColTop & piecesAboveTop); + else + blockingPieceColAboveTop = 0L; + } + + // Get the first blocking piece in col below the current piece + long blockingPieceColBelowBottom, blockingPieceColBelowTop; + if (startTop == 0) { + blockingPieceColBelowTop = 0L; + blockingPieceColBelowBottom = Long.highestOneBit(occupiedColBottom & piecesBelowBottom); + } else { + blockingPieceColBelowTop = Long.highestOneBit(occupiedColTop & piecesBelowTop); + if (blockingPieceColBelowTop == 0) + blockingPieceColBelowBottom = Long.highestOneBit(occupiedColBottom & piecesBelowBottom); + else + blockingPieceColBelowBottom = 0L; + } + + // Get the first blocking piece in row above the current piece + long blockingPieceRowAboveBottom, blockingPieceRowAboveTop; + if (startTop != 0) { + blockingPieceRowAboveBottom = 0L; + blockingPieceRowAboveTop = Long.lowestOneBit(occupiedRowTop & piecesAboveTop); + } else { + blockingPieceRowAboveBottom = Long.lowestOneBit(occupiedRowBottom & piecesAboveBottom); + blockingPieceRowAboveTop = 0L; + } + + // Get the first blocking piece in row below the current piece + long blockingPieceRowBelowBottom, blockingPieceRowBelowTop; + if (startTop == 0) { + blockingPieceRowBelowTop = 0L; + blockingPieceRowBelowBottom = Long.highestOneBit(occupiedRowBottom & piecesBelowBottom); + } else { + blockingPieceRowBelowTop = Long.highestOneBit(occupiedRowTop & piecesBelowTop); + blockingPieceRowBelowBottom = 0L; + } + + // Get the first blocking piece in diag above the current piece + long blockingPieceDiagAboveBottom, blockingPieceDiagAboveTop; + if (startTop != 0) { + blockingPieceDiagAboveBottom = 0L; + blockingPieceDiagAboveTop = Long.lowestOneBit(occupiedDiagTop & piecesAboveTop); + } else { + blockingPieceDiagAboveBottom = Long.lowestOneBit(occupiedDiagBottom & piecesAboveBottom); + if (blockingPieceDiagAboveBottom == 0) + blockingPieceDiagAboveTop = Long.lowestOneBit(occupiedDiagTop & piecesAboveTop); + else + blockingPieceDiagAboveTop = 0L; + } + + // Get the first blocking piece in diag below the current piece + long blockingPieceDiagBelowBottom, blockingPieceDiagBelowTop; + if (startTop == 0) { + blockingPieceDiagBelowTop = 0L; + blockingPieceDiagBelowBottom = Long.highestOneBit(occupiedDiagBottom & piecesBelowBottom); + } else { + blockingPieceDiagBelowTop = Long.highestOneBit(occupiedDiagTop & piecesBelowTop); + if (blockingPieceDiagBelowTop == 0) + blockingPieceDiagBelowBottom = Long.highestOneBit(occupiedDiagBottom & piecesBelowBottom); + else + blockingPieceDiagBelowBottom = 0L; + } + + // Get the first blocking piece in anti-diag above the current piece + long blockingPieceAntiDiagAboveBottom, blockingPieceAntiDiagAboveTop; + if (startTop != 0) { + blockingPieceAntiDiagAboveBottom = 0L; + blockingPieceAntiDiagAboveTop = Long.lowestOneBit(occupiedAntiDiagTop & piecesAboveTop); + } else { + blockingPieceAntiDiagAboveBottom = Long.lowestOneBit(occupiedAntiDiagBottom & piecesAboveBottom); + if (blockingPieceAntiDiagAboveBottom == 0) + blockingPieceAntiDiagAboveTop = Long.lowestOneBit(occupiedAntiDiagTop & piecesAboveTop); + else + blockingPieceAntiDiagAboveTop = 0L; + } + + // Get the first blocking piece in anti-diag below the current piece + long blockingPieceAntiDiagBelowBottom, blockingPieceAntiDiagBelowTop; + if (startTop == 0) { + blockingPieceAntiDiagBelowTop = 0L; + blockingPieceAntiDiagBelowBottom = Long.highestOneBit(occupiedAntiDiagBottom & piecesBelowBottom); + } else { + blockingPieceAntiDiagBelowTop = Long.highestOneBit(occupiedAntiDiagTop & piecesBelowTop); + if (blockingPieceAntiDiagBelowTop == 0) + blockingPieceAntiDiagBelowBottom = Long.highestOneBit(occupiedAntiDiagBottom & piecesBelowBottom); + else + blockingPieceAntiDiagBelowBottom = 0L; + } + + + /////////////////////////// + // Get pieces in between // + /////////////////////////// + + // Get squares movable to in col above the current piece + long betweenColAboveBottom, betweenColAboveTop; + if (startTop != 0) { + betweenColAboveBottom = 0L; + if (blockingPieceColAboveTop == 0) + betweenColAboveTop = piecesAboveTop & colMaskTop; + else + betweenColAboveTop = piecesAboveTop & (blockingPieceColAboveTop - 1) & colMaskTop; + } else { + if (blockingPieceColAboveBottom == 0) { + betweenColAboveBottom = piecesAboveBottom & colMaskBottom; + betweenColAboveTop = (blockingPieceColAboveTop - 1) & colMaskTop; + } else { + betweenColAboveBottom = piecesAboveBottom & (blockingPieceColAboveBottom - 1) & colMaskBottom; + betweenColAboveTop = 0L; + } + } + + // Get squares movable to in col below the current piece + long betweenColBelowBottom, betweenColBelowTop; + if (startTop == 0) { + betweenColBelowTop = 0L; + if (blockingPieceColBelowBottom == 0) + betweenColBelowBottom = piecesBelowBottom & colMaskBottom; + else + betweenColBelowBottom = piecesBelowBottom & (-blockingPieceColBelowBottom ^ blockingPieceColBelowBottom) & colMaskBottom; + } else { + if (blockingPieceColBelowTop == 0) { + betweenColBelowTop = piecesBelowTop & colMaskTop; + if (blockingPieceColBelowBottom == 0) + betweenColBelowBottom = piecesBelowBottom & colMaskBottom; + else + betweenColBelowBottom = (-blockingPieceColBelowBottom ^ blockingPieceColBelowBottom) & colMaskBottom; + } else { + betweenColBelowTop = piecesBelowTop & (-blockingPieceColBelowTop ^ blockingPieceColBelowTop) & colMaskTop; + betweenColBelowBottom = 0L; + } + } + + // Get squares movable to in row above the current piece + long betweenRowAboveBottom, betweenRowAboveTop; + if (startTop != 0) { + betweenRowAboveBottom = 0L; + if (blockingPieceRowAboveTop == 0) + betweenRowAboveTop = piecesAboveTop & rowMaskTop; + else + betweenRowAboveTop = piecesAboveTop & (blockingPieceRowAboveTop - 1) & rowMaskTop; + } else { + if (blockingPieceRowAboveBottom == 0) { + betweenRowAboveBottom = piecesAboveBottom & rowMaskBottom; + } else { + betweenRowAboveBottom = piecesAboveBottom & (blockingPieceRowAboveBottom - 1) & rowMaskBottom; + } + betweenRowAboveTop = 0L; + } + + // Get squares movable to in row below the current piece + long betweenRowBelowBottom, betweenRowBelowTop; + if (startTop == 0) { + betweenRowBelowTop = 0L; + if (blockingPieceRowBelowBottom == 0) + betweenRowBelowBottom = piecesBelowBottom & rowMaskBottom; + else + betweenRowBelowBottom = piecesBelowBottom & (-blockingPieceRowBelowBottom ^ blockingPieceRowBelowBottom) & rowMaskBottom; + } else { + if (blockingPieceRowBelowTop == 0) { + betweenRowBelowTop = piecesBelowTop & rowMaskTop; + } else { + betweenRowBelowTop = piecesBelowTop & (-blockingPieceRowBelowTop ^ blockingPieceRowBelowTop) & rowMaskTop; + } + betweenRowBelowBottom = 0L; + } + + // Get squares movable to in diag above the current piece + long betweenDiagAboveBottom, betweenDiagAboveTop; + if (startTop != 0) { + betweenDiagAboveBottom = 0L; + if (blockingPieceDiagAboveTop == 0) + betweenDiagAboveTop = piecesAboveTop & diagMaskTop; + else + betweenDiagAboveTop = piecesAboveTop & (blockingPieceDiagAboveTop - 1) & diagMaskTop; + } else { + if (blockingPieceDiagAboveBottom == 0) { + betweenDiagAboveBottom = piecesAboveBottom & diagMaskBottom; + betweenDiagAboveTop = (blockingPieceDiagAboveTop - 1) & diagMaskTop; + } else { + betweenDiagAboveBottom = piecesAboveBottom & (blockingPieceDiagAboveBottom - 1) & diagMaskBottom; + betweenDiagAboveTop = 0L; + } + } + + // Get squares movable to in diag below the current piece + long betweenDiagBelowBottom, betweenDiagBelowTop; + if (startTop == 0) { + betweenDiagBelowTop = 0L; + if (blockingPieceDiagBelowBottom == 0) + betweenDiagBelowBottom = piecesBelowBottom & diagMaskBottom; + else + betweenDiagBelowBottom = piecesBelowBottom & (-blockingPieceDiagBelowBottom ^ blockingPieceDiagBelowBottom) & diagMaskBottom; + } else { + if (blockingPieceDiagBelowTop == 0) { + betweenDiagBelowTop = piecesBelowTop & diagMaskTop; + if (blockingPieceDiagBelowBottom == 0) + betweenDiagBelowBottom = piecesBelowBottom & diagMaskBottom; + else + betweenDiagBelowBottom = (-blockingPieceDiagBelowBottom ^ blockingPieceDiagBelowBottom) & diagMaskBottom; + } else { + betweenDiagBelowTop = piecesBelowTop & (-blockingPieceDiagBelowTop ^ blockingPieceDiagBelowTop) & diagMaskTop; + betweenDiagBelowBottom = 0L; + } + } + + // Get squares movable to in anti-diag above the current piece + long betweenAntiDiagAboveBottom, betweenAntiDiagAboveTop; + if (startTop != 0) { + betweenAntiDiagAboveBottom = 0L; + if (blockingPieceAntiDiagAboveTop == 0) + betweenAntiDiagAboveTop = piecesAboveTop & antiDiagMaskTop; + else + betweenAntiDiagAboveTop = piecesAboveTop & (blockingPieceAntiDiagAboveTop - 1) & antiDiagMaskTop; + } else { + if (blockingPieceAntiDiagAboveBottom == 0) { + betweenAntiDiagAboveBottom = piecesAboveBottom & antiDiagMaskBottom; + betweenAntiDiagAboveTop = (blockingPieceAntiDiagAboveTop - 1) & antiDiagMaskTop; + } else { + betweenAntiDiagAboveBottom = piecesAboveBottom & (blockingPieceAntiDiagAboveBottom - 1) & antiDiagMaskBottom; + betweenAntiDiagAboveTop = 0L; + } + } + + // Get squares movable to in anti-diag below the current piece + long betweenAntiDiagBelowBottom, betweenAntiDiagBelowTop; + if (startTop == 0) { + betweenAntiDiagBelowTop = 0L; + if (blockingPieceAntiDiagBelowBottom == 0) + betweenAntiDiagBelowBottom = piecesBelowBottom & antiDiagMaskBottom; + else + betweenAntiDiagBelowBottom = piecesBelowBottom & (-blockingPieceAntiDiagBelowBottom ^ blockingPieceAntiDiagBelowBottom) & antiDiagMaskBottom; + } else { + if (blockingPieceAntiDiagBelowTop == 0) { + betweenAntiDiagBelowTop = piecesBelowTop & antiDiagMaskTop; + if (blockingPieceAntiDiagBelowBottom == 0) + betweenAntiDiagBelowBottom = piecesBelowBottom & antiDiagMaskBottom; + else + betweenAntiDiagBelowBottom = (-blockingPieceAntiDiagBelowBottom ^ blockingPieceAntiDiagBelowBottom) & antiDiagMaskBottom; + } else { + betweenAntiDiagBelowTop = piecesBelowTop & (-blockingPieceAntiDiagBelowTop ^ blockingPieceAntiDiagBelowTop) & antiDiagMaskTop; + betweenAntiDiagBelowBottom = 0L; + } + } + + // Or results onto the total reachable squares in pne move + long reachableTop = betweenColAboveTop | betweenColBelowTop | betweenRowAboveTop | betweenRowBelowTop | betweenDiagAboveTop | betweenDiagBelowTop | betweenAntiDiagAboveTop | betweenAntiDiagBelowTop; + long reachableBottom = betweenColAboveBottom | betweenColBelowBottom | betweenRowAboveBottom | betweenRowBelowBottom | betweenDiagAboveBottom | betweenDiagBelowBottom | betweenAntiDiagAboveBottom | betweenAntiDiagBelowBottom; + if (pieceNum < 4) { + blackReachableTop |= reachableTop; + blackReachableBottom |= reachableBottom; + } else { + whiteReachableTop |= reachableTop; + whiteReachableBottom |= reachableBottom; + } + } + + // Put the squares that are reachable by both players in one move into a BitBoard for return + BitBoard reachable = new BitBoard(); + reachable.setBlackQueensBottom(blackReachableBottom); + reachable.setBlackQueensTop(blackReachableTop); + reachable.setWhiteQueensBottom(whiteReachableBottom); + reachable.setWhiteQueensTop(whiteReachableTop); + + return reachable; + } + + private static int getCol(long startBottom, long startTop) { + if (startBottom == 0) + return Long.bitCount(startTop - 1) % 10; + else + return Long.bitCount(startBottom - 1) % 10; + } + + private static int getRow(long startBottom, long startTop) { + if (startBottom == 0) + return Long.bitCount(startTop - 1) / 10 + 5; + else + return Long.bitCount(startBottom - 1) / 10; + } + public static int calculateTileControl(int x, int y, BitBoard board) { // this is the number of tiles you can move to from this position int moves = 0; diff --git a/src/main/java/Tree/MonteCarloTree.java b/src/main/java/Tree/MonteCarloTree.java index e4c74b9..fac3843 100644 --- a/src/main/java/Tree/MonteCarloTree.java +++ b/src/main/java/Tree/MonteCarloTree.java @@ -1,6 +1,7 @@ package Tree; -import State.*; +import State.Action; +import State.State; import java.util.ArrayList; import java.util.Arrays; @@ -19,8 +20,7 @@ public class MonteCarloTree { * The number of nodes that the expansion policy will attempt to expand. May expand less if there aren't that many nodes left to expand */ private final int NUM_TO_EXPAND = Runtime.getRuntime().availableProcessors(); - - private final ExecutorService executor; + private ExecutorService executor; public MonteCarloTree(State state, double cValue, int colour, int depth, int[] moveDictionary) { this.cValue = cValue; @@ -38,11 +38,9 @@ public Action search() { searching our tree. */ - Action selectedAction = null; - - boolean useMoveDictionary = root.getDepth() < 8; + boolean useMoveDictionary = root.getDepth() < 6; if (useMoveDictionary) { - selectedAction = getMoveDictionaryMove(); + return getMoveDictionaryMove(); } try { while (time.timeLeft()) { @@ -51,39 +49,44 @@ public Action search() { // Get most promising nodes to simulate Node[] children = ExpansionPolicy.expansionNode(leaf, NUM_TO_EXPAND); - // Simulate each child in its own Thread // There should never be zero nodes returned because select() should not have chosen it if (children.length == 0) - throw new RuntimeException("HELP, THIS IS BAD"); + throw new RuntimeException("HELP, THIS IS BAD " + root.getTotalPlayouts()); + + // Simulate each child in its own Thread // Create a list of runnable tasks that will be executed in separate threads - List> callables = new ArrayList<>(); - for (Node child : children) - callables.add(() -> Simulate.simulate(child)); - - try { - // Execute all tasks. Will block until all threads have returned a value - List> futures = executor.invokeAll(callables); - - // Backpropagation of results - for (int i = 0; i < children.length; i++) { - Future future = futures.get(i); - int result = future.get(); - backPropagate(result, children[i]); - } - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); + List> callables = new ArrayList<>(); + for (Node child : children) { + backPropagate(Simulate.simulate(child), child); +// callables.add(() -> Simulate.simulate(child)); } + +// try { +// // Execute all tasks. Will block until all threads have returned a value +// List> futures = executor.invokeAll(callables); +// +// // Backpropagation of results +// for (int i = 0; i < children.length; i++) { +// Future future = futures.get(i); +// double result = future.get(); +// backPropagate(result, children[i]); +// } +// } catch (InterruptedException | ExecutionException e) { +// throw new RuntimeException(e); +// } } } catch (NullPointerException ignore) { } System.out.println("Ran " + getRoot().getTotalPlayouts() + " times"); - if (mostVisitedNode() != null && !useMoveDictionary) - selectedAction = mostVisitedNode().getAction(); + Node mostVisitedNode = mostVisitedNode(); + if (mostVisitedNode != null) + return mostVisitedNode.getAction(); - return selectedAction; + // End of game + return null; } private Node select(Node tree) { @@ -119,6 +122,27 @@ private void backPropagate(int result, Node child) { } } + private void backPropagate(double result, Node child) { + if (child.getColour() != State.BLACK_QUEEN) { + child.setTotalPlayouts(child.getTotalPlayouts() + 1); + child.setTotalWins(child.getTotalWins() + result); + } else { + child.setTotalPlayouts(child.getTotalPlayouts() + 1); + child.setTotalWins(child.getTotalWins() + (1 - result)); + } + + while (child.getParent() != null) { + child = child.getParent(); + if (child.getColour() != State.BLACK_QUEEN) { + child.setTotalPlayouts(child.getTotalPlayouts() + 1); + child.setTotalWins(child.getTotalWins() + result); + } else { + child.setTotalPlayouts(child.getTotalPlayouts() + 1); + child.setTotalWins(child.getTotalWins() + (1 - result)); + } + } + } + private Node mostVisitedNode() { Node bestNode = null; int bestCount = -1; @@ -151,7 +175,7 @@ private Node UCBMove(Node n) { } private double UCBEquation(Node n) { - return (double) n.getTotalWins() / n.getTotalPlayouts() + cValue * Math.sqrt(Math.log((double) n.getParent().getTotalPlayouts() / n.getTotalPlayouts())); + return (double) n.getTotalWins() / n.getTotalPlayouts() + cValue * Math.sqrt(Math.log((double) n.getParent().getTotalPlayouts()) / n.getTotalPlayouts()); } private Action getMoveDictionaryMove() { diff --git a/src/main/java/Tree/Node.java b/src/main/java/Tree/Node.java index 087d318..ce73e70 100644 --- a/src/main/java/Tree/Node.java +++ b/src/main/java/Tree/Node.java @@ -8,7 +8,7 @@ import java.util.Objects; public class Node { - private int totalWins; + private double totalWins; private int totalPlayouts; private Node parent; private Node[] children; @@ -69,11 +69,11 @@ public Action[] getPossibleActions() { return possibleActions; } - public int getTotalWins() { + public double getTotalWins() { return totalWins; } - public void setTotalWins(int totalWins) { + public void setTotalWins(double totalWins) { this.totalWins = totalWins; } diff --git a/src/main/java/Tree/Simulate.java b/src/main/java/Tree/Simulate.java index 1fc49d4..1cb9525 100644 --- a/src/main/java/Tree/Simulate.java +++ b/src/main/java/Tree/Simulate.java @@ -3,8 +3,10 @@ import State.*; import java.util.ArrayList; +import java.util.Random; public class Simulate { + private static Random r = new Random(); /** * We perform a playout from the newly generated child node, choosing moves for both players according to the playout policy. These moves are not recorded in the search tree. In the figure, the simulation results in a win for black. @@ -12,8 +14,46 @@ public class Simulate { * @param node The Node to be played out * @return The player that won. Either State.BLACK or State.WHITE */ - public static int simulate(Node node) { - return earlyTerminationPlayout(node); + public static double simulate(Node node) { +// return earlyTerminationPlayout(node); + return heuristicSimulation(node); + } + + private static double heuristicSimulation(Node node) { + State state = node.getState(); + + double heuristic = Heuristics.bigPoppa(state, node.getColour()); + return 1 / (1 + Math.exp(-3*heuristic)); + } + + + /** + * Hot garbage. Waaaaaaaaaay too slow + */ + private static int heuristicSimulation2(Node node) { + State state = node.getState(); + int color = node.getColour(); + ArrayList actions = ActionGenerator.generateActions(state, color); + while (actions.size() > 0) { + Action definitelyTheBestAction = null; + double bestH = Math.pow(-1, color) * Integer.MAX_VALUE; + for (Action a : actions) { + double h = Heuristics.bigPoppa(new State(state, a), color); + if (color == 1 && h > bestH) { + definitelyTheBestAction = a; + bestH = h; + } else if (color == 2 && h < bestH) { + definitelyTheBestAction = a; + bestH = h; + } + } + assert definitelyTheBestAction != null; + state = new State(state, definitelyTheBestAction); + color = color == 1 ? 2 : 1; + actions = ActionGenerator.generateActions(state, color); + } + + return color == 1 ? 2 : 1; } /** @@ -37,7 +77,7 @@ private static int randomPlayout(Node node) { private static int earlyTerminationPlayout(Node node) { int i = 0; - final int TERMINATION_DEPTH = 35; + final int TERMINATION_DEPTH = r.nextInt(3) + 5; State state = new State(node.getState(), node.getAction()); int color = node.getColour(); int depth = node.getDepth(); @@ -53,11 +93,16 @@ private static int earlyTerminationPlayout(Node node) { if (i < TERMINATION_DEPTH) return color == State.BLACK_QUEEN ? State.WHITE_QUEEN : State.BLACK_QUEEN; else { - int blackControl = boardControlHeuristic(state, State.BLACK_QUEEN); - int whiteControl = boardControlHeuristic(state, State.WHITE_QUEEN); - if (blackControl == whiteControl) - return 0; - else if (blackControl > whiteControl) +// int blackControl = boardControlHeuristic(state, State.BLACK_QUEEN); +// int whiteControl = boardControlHeuristic(state, State.WHITE_QUEEN); +// if (blackControl == whiteControl) +// return 0; +// else if (blackControl > whiteControl) +// return State.BLACK_QUEEN; +// else +// return State.WHITE_QUEEN; + double heuristic = Heuristics.bigPoppa(state, color); + if (heuristic > 0) return State.BLACK_QUEEN; else return State.WHITE_QUEEN;