Skip to content

Commit

Permalink
fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
frederikschubert committed Sep 4, 2023
1 parent a2a1b1e commit 7b50a28
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 12 deletions.
15 changes: 9 additions & 6 deletions carl/envs/mario/carl_mario.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations
import sys

from typing import List

import sys

import numpy as np

from carl.context.context_space import (
Expand Down Expand Up @@ -46,9 +47,7 @@ def __init__(

def _update_context(self) -> None:
self.env: MarioEnv
self.context = CARLMarioEnv.get_context_space().insert_defaults(
self.context
)
self.context = CARLMarioEnv.get_context_space().insert_defaults(self.context)
if not self.levels:
for context in self.contexts.values():
level, _ = generate_level(
Expand All @@ -66,11 +65,15 @@ def _update_context(self) -> None:
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
return {
"level_width": UniformIntegerContextFeature("level_width", 16, 1000, default_value=100),
"level_width": UniformIntegerContextFeature(
"level_width", 16, 1000, default_value=100
),
"level_index": CategoricalContextFeature(
"level_index", choices=np.arange(0, 14), default_value=0
),
"noise_seed": UniformIntegerContextFeature("noise_seed", 0, sys.maxsize, default_value=0),
"noise_seed": UniformIntegerContextFeature(
"noise_seed", 0, sys.maxsize, default_value=0
),
"mario_state": CategoricalContextFeature(
"mario_state", choices=[0, 1, 2], default_value=0
),
Expand Down
8 changes: 6 additions & 2 deletions carl/envs/mario/pcg_smb_env/mario_env.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Any, Dict, List, Literal, Optional, cast

import atexit
import os
import random
import socket
from collections import deque
from typing import Any, Dict, List, Literal, Optional, cast

import cv2
import gymnasium
Expand Down Expand Up @@ -94,7 +95,10 @@ def __init__(
self.display = Display(use_xauth=True)

def reset(
self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None, **kwargs
self,
seed: Optional[int] = None,
options: Optional[Dict[str, Any]] = None,
**kwargs,
):
self._reset_obs()
if self.game is None:
Expand Down
11 changes: 9 additions & 2 deletions carl/envs/mario/pcg_smb_env/toadgan/generate_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@

# Generates a noise tensor. Uses torch.randn.
def generate_spatial_noise(
size: Union[Any, List[int], Tuple[int]], device: Union[str, torch.device] = "cpu", seed: int = 0
size: Union[Any, List[int], Tuple[int]],
device: Union[str, torch.device] = "cpu",
seed: int = 0,
) -> Tensor:
return torch.randn(size, device=device, dtype=torch.float32, generator=torch.Generator().manual_seed(seed))
return torch.randn(
size,
device=device,
dtype=torch.float32,
generator=torch.Generator().manual_seed(seed),
)


# Generate a sample given a TOAD-GAN and additional parameters
Expand Down
4 changes: 3 additions & 1 deletion carl/envs/mario/pcg_smb_env/toadgan/toad_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def generate_level(
return "".join(level), initial_noise.numpy()


def generate_initial_noise(width: int, height: int, level_index: int, seed: int) -> Tensor:
def generate_initial_noise(
width: int, height: int, level_index: int, seed: int
) -> Tensor:
toad_gan = load_generator(level_index)
base_noise_map = toad_gan.noise_maps[0]
nzx = (
Expand Down
2 changes: 1 addition & 1 deletion carl/envs/mario/pcg_smb_env/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def launch_gateway():
die_on_exit=True,
port=free_port,
java_path=str(JAVA),
javaopts=["-Djava.awt.headless=false"]
javaopts=["-Djava.awt.headless=false"],
),
free_port,
)

0 comments on commit 7b50a28

Please sign in to comment.