Skip to content

Commit

Permalink
Merge pull request #4 from Klavionik/improve-ai-targeting
Browse files Browse the repository at this point in the history
Improve AI targeting
  • Loading branch information
Klavionik authored Sep 20, 2024
2 parents feaa29e + 9991930 commit a73401f
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 135 deletions.
92 changes: 64 additions & 28 deletions battleship/engine/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,52 +6,84 @@


class TargetCaller:
def __init__(self, board: domain.Board) -> None:
def __init__(self, board: domain.Board, no_adjacent_ships: bool = False) -> None:
self.board = board
self.next_targets: deque[domain.Cell] = deque()
self.no_adjacent_ships = no_adjacent_ships
self.next_targets: deque[domain.Coordinate] = deque()
self.excluded_cells: set[domain.Coordinate] = set()

def call_out(self, *, count: int = 1) -> list[str]:
cells = self._get_targets(count)
return [cell.coordinate.to_human() for cell in cells]
targets = self._get_targets(count)
return [target.to_human() for target in targets]

def provide_feedback(self, shots: Iterable[domain.Shot]) -> None:
for shot in shots:
if shot.hit and not shot.ship.destroyed: # type: ignore
cell = self.board.get_cell(shot.coordinate)

if cell is None:
raise errors.CellOutOfRange(f"Cell at {shot.coordinate} doesn't exist.")

neighbors = self._find_neighbor_cells(cell)
self.next_targets.extend(neighbors)

def _get_targets(self, count: int) -> list[domain.Cell]:
targets: list[domain.Cell] = []
# If shot was a hit, we can learn something from it.
if shot.hit:
assert shot.ship, "Shot was a hit, but no ship present"

if shot.ship.destroyed and self.no_adjacent_ships:
# If ship was destroyed and there's "No adjacent ships"
# rule enabled, there's no point in firing cells
# that surrounds the ship - it's impossible to place
# another ship there.
coordinates = self._find_cells_around_ship(shot.ship)
self.excluded_cells.update(coordinates)
self.next_targets.clear()
elif not shot.ship.destroyed:
# If ship was hit, but not destroyed, keep on
# firing cells around until it is destroyed.
cells = self._target_ship(shot.coordinate)
self.next_targets.extend(cells)

def _get_targets(self, count: int) -> list[domain.Coordinate]:
targets: list[domain.Coordinate] = []

while len(self.next_targets) and len(targets) != count:
next_target = self.next_targets.popleft()
targets.append(next_target)

if len(targets) != count:
random_targets = self._find_random_targets(count - len(targets))
random_targets = self._target_random_cells(count - len(targets))
targets.extend(random_targets)

return targets

def _find_random_targets(self, count: int) -> list[domain.Cell]:
candidates = [cell for cell in self.board.cells if not cell.is_shot]
def _target_random_cells(self, count: int) -> list[domain.Coordinate]:
candidates = [
cell.coordinate
for cell in self.board.cells
if not (cell.is_shot or cell.coordinate in self.excluded_cells)
]
return random.sample(candidates, k=min(len(candidates), count))

def _find_neighbor_cells(self, cell: domain.Cell) -> list[domain.Cell]:
def _target_ship(self, coordinate: domain.Coordinate) -> list[domain.Coordinate]:
cells = []

for direction in list(domain.Direction):
candidate = self.board.get_adjacent_cell(cell, direction) # type: ignore[arg-type]

if candidate is None or candidate.is_shot or candidate in self.next_targets:
for cell_ in self.board.get_adjacent_cells(coordinate, with_diagonals=False):
if (
cell_.is_shot
or cell_.coordinate in self.next_targets
or cell_.coordinate in self.excluded_cells
):
continue

cells.append(candidate)
cells.append(cell_.coordinate)

return cells

def _find_cells_around_ship(self, ship: domain.Ship) -> list[domain.Coordinate]:
cells = []

for coordinate in ship.cells:
adjacent_cells = self.board.get_adjacent_cells(coordinate)
adjacent_coordinates = [
cell.coordinate
for cell in adjacent_cells
if not cell.is_shot and cell.coordinate not in ship.cells
]

cells.extend(adjacent_coordinates)

return cells

Expand All @@ -67,7 +99,11 @@ def __init__(self, board: domain.Board, ship_suite: rosters.Roster, no_adjacent_
def place(self, ship_type: rosters.ShipType) -> list[domain.Coordinate]:
ship_hp = self.ship_hp_map[ship_type]
position: list[domain.Coordinate] = []
empty_cells = [cell for cell in self.board.cells if cell.ship is None]
empty_cells = [
cell.coordinate
for cell in self.board.cells
if not self.board.has_ship_at(cell.coordinate)
]
directions = list[domain.Direction](domain.Direction)
random.shuffle(empty_cells)
random.shuffle(directions)
Expand All @@ -79,16 +115,16 @@ def place(self, ship_type: rosters.ShipType) -> list[domain.Coordinate]:
# Try to found enough empty cells to place the ship in this direction.
for _ in range(ship_hp):
# Get the next cell in this direction.
next_cell = self.board.get_adjacent_cell(start_cell, direction)
next_cell = start_cell.next(direction)

# If there is no cell or the cell is taken,
# clear the progress and try another direction.
if next_cell is None or next_cell.ship is not None:
if not self.board.has_cell(next_cell) or self.board.has_ship_at(next_cell):
position.clear()
break

# Otherwise, save the coordinate.
position.append(next_cell.coordinate)
position.append(next_cell)

# If there is enough cells to place the ship, return the position.
if len(position) == ship_hp:
Expand Down
151 changes: 87 additions & 64 deletions battleship/engine/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,26 @@
ASCII_OFFSET = 64


@enum.unique
class Direction(StrEnum):
UP = "up"
DOWN = "down"
RIGHT = "right"
LEFT = "left"
UP = enum.auto()
DOWN = enum.auto()
RIGHT = enum.auto()
LEFT = enum.auto()


@enum.unique
class DiagonalDirection(StrEnum):
UP_RIGHT = enum.auto()
UP_LEFT = enum.auto()
DOWN_RIGHT = enum.auto()
DOWN_LEFT = enum.auto()


@enum.unique
class FiringOrder(StrEnum):
ALTERNATELY = "alternately"
UNTIL_MISS = "until_miss"
ALTERNATELY = enum.auto()
UNTIL_MISS = enum.auto()


@dataclasses.dataclass
Expand Down Expand Up @@ -58,17 +68,16 @@ def __eq__(self, other: object) -> bool:

raise NotImplementedError(f"Cannot compare Coordinate to {other.__class__.__name__}.")

def up(self) -> "Coordinate":
return Coordinate(self.x, self.y - 1)
def __hash__(self) -> int:
return hash((self.x, self.y))

def right(self) -> "Coordinate":
return Coordinate(self.x + 1, self.y)

def down(self) -> "Coordinate":
return Coordinate(self.x, self.y + 1)
def __repr__(self) -> str:
return f"Coordinate(x={self.x}, y={self.y}, human={self.to_human()})"

def left(self) -> "Coordinate":
return Coordinate(self.x - 1, self.y)
@classmethod
def from_human(cls, coordinate: str) -> "Coordinate":
col, row = parse_coordinate(coordinate)
return Coordinate(ord(col) - ASCII_OFFSET - 1, row - 1)

@property
def col(self) -> str:
Expand All @@ -78,14 +87,46 @@ def col(self) -> str:
def row(self) -> int:
return self.y + 1

@classmethod
def from_human(cls, coordinate: str) -> "Coordinate":
col, row = parse_coordinate(coordinate)
return Coordinate(ord(col) - ASCII_OFFSET - 1, row - 1)

def to_human(self) -> str:
return f"{self.col}{self.row}"

def up(self) -> "Coordinate":
return Coordinate(self.x, self.y - 1)

def right(self) -> "Coordinate":
return Coordinate(self.x + 1, self.y)

def down(self) -> "Coordinate":
return Coordinate(self.x, self.y + 1)

def left(self) -> "Coordinate":
return Coordinate(self.x - 1, self.y)

def next(self, direction: Direction | DiagonalDirection) -> "Coordinate":
match direction:
case Direction.UP:
return self.up()
case Direction.DOWN:
return self.down()
case Direction.RIGHT:
return self.right()
case Direction.LEFT:
return self.left()
case DiagonalDirection.UP_RIGHT:
up = self.up()
return up.right()
case DiagonalDirection.UP_LEFT:
up = self.up()
return up.left()
case DiagonalDirection.DOWN_RIGHT:
down = self.down()
return down.right()
case DiagonalDirection.DOWN_LEFT:
down = self.down()
return down.left()
case _:
raise ValueError(f"Invalid direction {direction}.")


@dataclasses.dataclass
class Cell:
Expand Down Expand Up @@ -164,50 +205,40 @@ def __repr__(self) -> str:
def cells(self) -> list[Cell]:
return [cell for row in self.grid for cell in row]

def get_adjacent_cell(self, cell: Cell, direction: Direction) -> Cell | None:
match direction:
case Direction.UP:
coordinate = cell.coordinate.up()
case Direction.DOWN:
coordinate = cell.coordinate.down()
case Direction.RIGHT:
coordinate = cell.coordinate.right()
case Direction.LEFT:
coordinate = cell.coordinate.left()
case _:
raise ValueError(f"Invalid direction {direction}.")
def has_cell(self, coordinate: Coordinate) -> bool:
return 0 <= coordinate.x < self.size and 0 <= coordinate.y < self.size

return self.get_cell(coordinate)
def has_ship_at(self, coordinate: Coordinate) -> bool:
return self.get_cell(coordinate).ship is not None

def get_cell(self, coordinate: Coordinate) -> Cell | None:
if not (0 <= coordinate.x < self.size and 0 <= coordinate.y < self.size):
return None
def get_adjacent_cells(self, coordinate: Coordinate, with_diagonals: bool = True) -> list[Cell]:
cells = []

return self.grid[coordinate.y][coordinate.x]
if with_diagonals:
directions = itertools.chain(Direction, DiagonalDirection)
else:
directions = itertools.chain(Direction)

def has_adjacent_ship(self, coordinate: Coordinate) -> bool:
cell = self.get_cell(coordinate)
for direction in directions:
try:
next_coordinate = coordinate.next(direction) # type: ignore[arg-type]
adjacent_cell = self.get_cell(next_coordinate)
except errors.CellOutOfRange:
continue
else:
cells.append(adjacent_cell)

if not cell:
raise errors.CellOutOfRange(f"Cell at {coordinate=} does not exist.")
return cells

adjacent_coordinates = [
cell.coordinate.up(),
cell.coordinate.right(),
cell.coordinate.down(),
cell.coordinate.left(),
]
diagonals = [
adjacent_coordinates[1].up(),
adjacent_coordinates[1].down(),
adjacent_coordinates[3].up(),
adjacent_coordinates[3].down(),
]
adjacent_coordinates.extend(diagonals)
def get_cell(self, coordinate: Coordinate) -> Cell:
if not self.has_cell(coordinate):
raise errors.CellOutOfRange(f"Cell at {coordinate} doesn't exist.")

cells = [self.get_cell(coor) for coor in adjacent_coordinates]
return self.grid[coordinate.y][coordinate.x]

return any([cell is not None and cell.ship is not None for cell in cells])
def has_adjacent_ship(self, coordinate: Coordinate) -> bool:
adjacent_cells = self.get_adjacent_cells(coordinate)
return any([self.has_ship_at(cell.coordinate) for cell in adjacent_cells])

def place_ship(
self, coordinates: Collection[Coordinate], ship: Ship, no_adjacent_ships: bool = False
Expand All @@ -226,21 +257,13 @@ def place_ship(

for coordinate in coordinates:
cell = self.get_cell(coordinate)

if cell is None:
raise errors.CellOutOfRange(f"Cell at {coordinate} doesn't exist.")

cell.set_ship(ship)

self.ships.append(ship)
ship.cells.extend(coordinates)

def hit_cell(self, coordinate: Coordinate) -> Ship | None:
cell = self.get_cell(coordinate)

if cell is None:
raise errors.CellOutOfRange(f"Cell at {coordinate} doesn't exist.")

cell.hit()
return cell.ship

Expand Down
10 changes: 6 additions & 4 deletions battleship/tui/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ def __init__(self, game: domain.Game):
self._enable_move_delay = not is_debug()
self._human_player = game.player_a
self._bot_player = game.player_b
self._target_caller = ai.TargetCaller(self._human_player.board)
self._target_caller = ai.TargetCaller(
self._human_player.board, self._game.no_adjacent_ships
)
self._autoplacer = ai.Autoplacer(
self._bot_player.board, self._game.roster, self._game.no_adjacent_ships
)
Expand Down Expand Up @@ -319,14 +321,14 @@ def fire(self, position: Collection[str]) -> None:
def cancel(self) -> None:
pass

def _call_bot_target(self) -> Collection[str]:
def _call_bot_target(self) -> list[str]:
if self._game.salvo_mode:
count = self._bot_player.ships_alive
else:
count = 1

position = self._target_caller.call_out(count=count)
return position
targets = self._target_caller.call_out(count=count)
return targets

def _spawn_bot_fleet(self) -> None:
for item in self._game.roster:
Expand Down
Loading

0 comments on commit a73401f

Please sign in to comment.