Skip to content

Commit

Permalink
Move cv2 into render funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
Maximilian Weichart committed May 10, 2024
1 parent 50c5eaa commit fd2caad
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 45 deletions.
20 changes: 8 additions & 12 deletions examples/play_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,14 @@

if __name__ == "__main__":
# Create an instance of Tetris
tetris_game = gym.make("tetris_gymnasium/Tetris", render_mode="rgb_array")
tetris_game = gym.make("tetris_gymnasium/Tetris", render_mode="human")
tetris_game.reset(seed=42)

window_name = "Tetris Gymnasium"
cv2.namedWindow(window_name, cv2.WINDOW_GUI_NORMAL)
cv2.resizeWindow(window_name, 200, 400)

# Main game loop
terminated = False
while not terminated:
# Render the current state of the game as text
rgb = tetris_game.render()

# Render the current state of the game as an image using CV2q
# CV2 uses BGR color format, so we need to convert the RGB image to BGR
cv2.imshow(window_name, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
cv2.waitKey(50)
tetris_game.render()

# Pick an action from user input mapped to the keyboard
action = None
Expand All @@ -49,7 +40,12 @@
tetris_game.reset(seed=42)
break

if cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) == 0:
if (
cv2.getWindowProperty(
tetris_game.unwrapped.window_name, cv2.WND_PROP_VISIBLE
)
== 0
):
sys.exit()

# Perform the action
Expand Down
20 changes: 8 additions & 12 deletions examples/play_interactive_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,15 @@

if __name__ == "__main__":
# Create an instance of Tetris
tetris_game = gym.make("tetris_gymnasium/Tetris", render_mode="rgb_array")
tetris_game = gym.make("tetris_gymnasium/Tetris", render_mode="human")
tetris_game.reset(seed=42)
tetris_game = CnnObservation(tetris_game)

window_name = "Tetris Gymnasium"
cv2.namedWindow(window_name, cv2.WINDOW_GUI_NORMAL)
cv2.resizeWindow(window_name, 395, 250)

# Main game loop
terminated = False
while not terminated:
# Render the current state of the game as text
rgb = tetris_game.render()

# Render the current state of the game as an image using CV2
# CV2 uses BGR color format, so we need to convert the RGB image to BGR
cv2.imshow(window_name, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
cv2.waitKey(50)
tetris_game.render()

# Pick an action from user input mapped to the keyboard
action = None
Expand All @@ -50,7 +41,12 @@
tetris_game.reset(seed=42)
break

if cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) == 0:
if (
cv2.getWindowProperty(
tetris_game.unwrapped.window_name, cv2.WND_PROP_VISIBLE
)
== 0
):
sys.exit()

# Perform the action
Expand Down
2 changes: 1 addition & 1 deletion tetris_gymnasium/components/tetromino.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Pixel:
A pixel is the basic building block of the game and has an id and a color.
The basic pixels are in the most cases the empty pixel (id=0) and the bedrock pixel (id=1).
The basic pixels are by default the empty pixel (id=0) and the bedrock pixel (id=1).
Additionally, multiple pixels can be combined to form a tetromino.
"""

Expand Down
2 changes: 1 addition & 1 deletion tetris_gymnasium/components/tetromino_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class TetrominoQueue:
"""The queue shows the incoming tetrominoes in a game of Tetris.
"""The `TetrominoQueue` stores all incoming tetrominoes in a queue.
The sequence of pieces is generated by a :class:`Randomizer`, which can be customized by the user.
"""
Expand Down
39 changes: 28 additions & 11 deletions tetris_gymnasium/envs/tetris.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import fields
from typing import Any, List

import cv2
import gymnasium as gym
import numpy as np
from gymnasium.core import ActType, RenderFrame
Expand Down Expand Up @@ -74,15 +75,15 @@ def __init__(
Args:
render_mode: The mode to use for rendering. If None, no rendering will be done.
width: The width of the game board.
height: The height of the game board.
randomizer: The randomizer to use for selecting tetrominoes.
holder: The holder to use for storing tetrominoes.
queue: The queue to use for storing tetrominoes.
width: The width of the board.
height: The height of the board.
randomizer: The :class:`Randomizer` to use for selecting tetrominoes.
holder: The :class:`TetrominoHolder` to use for storing tetrominoes.
queue: The :class:`TetrominoQueue` to use for holding tetrominoes temporarily.
actions_mapping: The mapping for the actions that the agent can take.
rewards_mapping: The mapping for the rewards that the agent can receive.
base_pixels: The base pixels to use for the environment (e.g. empty, bedrock).
tetrominoes: The tetrominoes to use for the environment.
base_pixels: A list of base (non-Tetromino) :class:`Pixel` to use for the environment (e.g. empty, bedrock).
tetrominoes: A list of :class:`Tetromino` to use in the environment.
"""
# Dimensions
self.height: int = height
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
# Utilities
self.queue = queue(randomizer(len(tetrominoes)), 5)
self.holder = holder()
self.has_swapped = False

# Position
self.x: int = 0
Expand Down Expand Up @@ -161,7 +163,7 @@ def __init__(

assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.has_swapped = False
self.window_name = None

def step(self, action: ActType) -> "tuple[dict, float, bool, bool, dict]":
"""Perform one step of the environment's dynamics.
Expand Down Expand Up @@ -234,6 +236,8 @@ def reset(
) -> "tuple[dict[str, Any], dict[str, Any]]":
"""Resets the state of the environment.
As with all Gymnasium environments, the reset method is called once at the beginning of an episode.
Args:
seed: The random seed to use for the reset.
options: A dictionary of options to use for the reset.
Expand All @@ -256,6 +260,9 @@ def reset(
self.holder.reset()
self.has_swapped = False

# Render
self.window_name = None

return self._get_obs(), self._get_info()

def render(self) -> "RenderFrame | list[RenderFrame] | None":
Expand All @@ -271,7 +278,7 @@ def render(self) -> "RenderFrame | list[RenderFrame] | None":
char_field = np.where(projection == 0, ".", projection.astype(str))
field_str = "\n".join("".join(row) for row in char_field)
return field_str
elif self.render_mode == "rgb_array":
elif self.render_mode == "human" or self.render_mode == "rgb_array":
# Initialize rgb array
rgb = np.zeros(
(self.board.shape[0], self.board.shape[1], 3), dtype=np.uint8
Expand All @@ -295,7 +302,17 @@ def render(self) -> "RenderFrame | list[RenderFrame] | None":
rgb[slices] += active_tetromino_rgb

# Crop padding away as we don't want to render it
return self.crop_padding(rgb)
rgb = self.crop_padding(rgb)

if self.render_mode == "rgb_array":
return rgb

if self.render_mode == "human":
if self.window_name is None:
self.window_name = "Tetris Gymnasium"
cv2.namedWindow(self.window_name, cv2.WINDOW_GUI_NORMAL)
cv2.resizeWindow(self.window_name, 200, 400)
cv2.imshow(self.window_name, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

return None

Expand All @@ -318,7 +335,7 @@ def collision(self, tetromino: Tetromino, x: int, y: int) -> bool:
"""Check if the tetromino collides with the board at the given position.
A collision is detected if the tetromino overlaps with any non-zero cell on the board.
These non-zero cells represent the padding / bedrock (value 1) or other tetrominoes (values 2+).
These non-zero cells represent the padding / bedrock (value 1) or other tetrominoes (values >=2).
Args:
tetromino: The tetromino to check for collision.
Expand Down
27 changes: 19 additions & 8 deletions tetris_gymnasium/wrappers/observation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Observation wrapper module for the Tetris Gymnasium environment."""
import cv2
import gymnasium as gym
import numpy as np
from gymnasium.core import RenderFrame
Expand All @@ -8,16 +9,13 @@


class CnnObservation(gym.ObservationWrapper):
"""Wrapper that displays all observations (board, holder, queue) in a single 2D matrix.
"""Observation wrapper that displays all observations (board, holder, queue) in a single 2D matrix, instead of a dictionary.
The 2D matrix contains the board on the left, the queue on the top right and the holder on the bottom right.
"""

def __init__(self, env: Tetris):
"""Initializes the observation space to be a single 2D matrix.
The size of the matrix depends on how many tetrominoes can be stored in the queue / holder.
"""
"""The size of the matrix depends on how many tetrominoes can be stored in the queue / holder."""
super().__init__(env)
self.observation_space = Box(
low=0,
Expand Down Expand Up @@ -63,21 +61,34 @@ def observation(self, observation):
def render(self) -> "RenderFrame | list[RenderFrame] | None":
"""Renders the environment in various formats.
This render function is different from the default as it uses the observation space to render the environment.
This render function is different from the default as it uses the values from :func:`observation` to render
the environment.
"""
matrix = self.observation(self.env.unwrapped._get_obs()).astype(np.integer)

if self.render_mode == "ansi":
char_field = np.where(matrix == 0, ".", matrix.astype(str))
field_str = "\n".join("".join(row) for row in char_field)
return field_str
if self.render_mode == "rgb_array":
if self.render_mode == "human" or self.render_mode == "rgb_array":
# Initialize rgb array
rgb = np.zeros((matrix.shape[0], matrix.shape[1], 3), dtype=np.uint8)
# Render the board
colors = np.array(list(p.color_rgb for p in self.pixels), dtype=np.uint8)
rgb[...] = colors[matrix]

return rgb
if self.render_mode == "rgb_array":
return rgb

if self.render_mode == "human":
if self.env.unwrapped.window_name is None:
self.env.unwrapped.window_name = "Tetris Gymnasium"
cv2.namedWindow(
self.env.unwrapped.window_name, cv2.WINDOW_GUI_NORMAL
)
cv2.resizeWindow(self.env.unwrapped.window_name, 395, 250)
cv2.imshow(
self.env.unwrapped.window_name, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
)

return None

0 comments on commit fd2caad

Please sign in to comment.