diff --git a/.gitignore b/.gitignore index ad15fb23..0e1dc86a 100644 --- a/.gitignore +++ b/.gitignore @@ -20,4 +20,9 @@ ctoybox/*.model.data-00000-of-00001 ctoybox/*.model.index ctoybox/*.model.meta ctoybox/checkpoint -ctoybox/*.model \ No newline at end of file +ctoybox/*.modeltoybox/build +toybox/dist +ctoybox/toybox/build +ctoybox/toybox/dist +ctoybox/toybox/toybox.egg-info +ctoybox/toybox/toybox/_native__lib.so diff --git a/ctoybox/Cargo.toml b/ctoybox/Cargo.toml index d5ca2812..1659fbb6 100644 --- a/ctoybox/Cargo.toml +++ b/ctoybox/Cargo.toml @@ -2,6 +2,7 @@ name = "ctoybox" version = "0.1.0" authors = ["Emma 'Frank' Tosch "] +build = "build.rs" [lib] name = "ctoybox" @@ -19,3 +20,5 @@ toybox = {path = "../toybox", version="*"} version = "*" path = "../core" +[build-dependencies] +cbindgen = "0.5" \ No newline at end of file diff --git a/ctoybox/REQUIREMENTS.txt b/ctoybox/REQUIREMENTS.txt index fa52e575..de880776 100644 --- a/ctoybox/REQUIREMENTS.txt +++ b/ctoybox/REQUIREMENTS.txt @@ -2,4 +2,5 @@ tensorflow gym[atari] tensorboard pygame +cffi diff --git a/ctoybox/build.rs b/ctoybox/build.rs new file mode 100644 index 00000000..23c50526 --- /dev/null +++ b/ctoybox/build.rs @@ -0,0 +1,13 @@ +// from https://github.com/getsentry/milksnake +extern crate cbindgen; + +use std::env; + +fn main() { + let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); + let mut config: cbindgen::Config = Default::default(); + config.language = cbindgen::Language::C; + cbindgen::generate_with_config(&crate_dir, config) + .unwrap() + .write_to_file("../target/ctoybox.h"); +} diff --git a/ctoybox/human_play.py b/ctoybox/human_play.py index d76ec4d7..80a59f1c 100644 --- a/ctoybox/human_play.py +++ b/ctoybox/human_play.py @@ -60,13 +60,16 @@ break key_state = pygame.key.get_pressed() player_input = Input() - player_input.left = key_state[K_LEFT] or key_state[K_a] - player_input.right = key_state[K_RIGHT] or key_state[K_d] - player_input.up = key_state[K_UP] or key_state[K_w] - player_input.down = key_state[K_DOWN] or key_state[K_s] - player_input.button1 = key_state[K_z] or key_state[K_SPACE] - player_input.button2 = key_state[K_x] or key_state[K_RSHIFT] or key_state[K_LSHIFT] + # Explicitly casting to bools because in some versions, the RHS gets converted + # to ints, causing problems when we load into the associated rust structs. + player_input.left = bool(key_state[K_LEFT] or key_state[K_a]) + player_input.right = bool(key_state[K_RIGHT] or key_state[K_d]) + player_input.up = bool(key_state[K_UP] or key_state[K_w]) + player_input.down = bool(key_state[K_DOWN] or key_state[K_s]) + player_input.button1 = bool(key_state[K_z] or key_state[K_SPACE]) + player_input.button2 = bool(key_state[K_x] or key_state[K_RSHIFT] or key_state[K_LSHIFT]) + tb.apply_action(player_input) if args.query is not None: print(args.query, tb.query_state_json(args.query, args.query_args)) diff --git a/ctoybox/src/core.rs b/ctoybox/src/core.rs index 434f35dd..1f3e41d4 100644 --- a/ctoybox/src/core.rs +++ b/ctoybox/src/core.rs @@ -1,6 +1,6 @@ use super::WrapSimulator; use super::WrapState; -use libc::c_char; +use libc::{c_char, c_void}; use serde_json; use std::boxed::Box; use std::ffi::{CStr, CString}; @@ -215,16 +215,20 @@ pub extern "C" fn state_apply_ale_action(state_ptr: *mut WrapState, input: i32) } #[no_mangle] -pub extern "C" fn state_apply_action(state_ptr: *mut WrapState, input_ptr: *mut Input) { +pub extern "C" fn state_apply_action(state_ptr: *mut WrapState, input_ptr: *const c_char) { let &mut WrapState { ref mut state } = unsafe { assert!(!state_ptr.is_null()); &mut *state_ptr }; - let input = unsafe { + let input_ptr = unsafe { assert!(!input_ptr.is_null()); - &mut *input_ptr + CStr::from_ptr(input_ptr) }; - state.update_mut(*input); + let input_str = input_ptr + .to_str() + .expect("Could not create input string from pointer"); + let input: Input = serde_json::from_str(input_str).expect("Could not input string to Input"); + state.update_mut(input); } #[no_mangle] @@ -246,7 +250,7 @@ pub extern "C" fn state_score(state_ptr: *mut WrapState) -> i32 { } #[no_mangle] -pub extern "C" fn state_to_json(state_ptr: *mut WrapState) -> *const c_char { +pub extern "C" fn state_to_json(state_ptr: *mut WrapState) -> *mut c_void { let &mut WrapState { ref mut state } = unsafe { assert!(!state_ptr.is_null()); &mut *state_ptr @@ -254,7 +258,7 @@ pub extern "C" fn state_to_json(state_ptr: *mut WrapState) -> *const c_char { let json: String = state.to_json(); let cjson: CString = CString::new(json).expect("Conversion to CString should succeed!"); - CString::into_raw(cjson) + CString::into_raw(cjson) as *mut c_void } #[no_mangle] diff --git a/ctoybox/start_python b/ctoybox/start_python index b74d73fa..2ccb5d8e 100755 --- a/ctoybox/start_python +++ b/ctoybox/start_python @@ -1,7 +1,9 @@ #!/bin/bash -#pip3 install -q -r REQUIREMENTS.txt --user -export PYTHONPATH=${PWD}/baselines:${PWD}/toybox:${PYTHONPATH}:${HOME}/toybox/ctoybox/baselines:${HOME}/toybox/ctoybox/toybox -export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${PWD}/../target/release:${HOME}/toybox/target/release:${HOME}/dev/toybox/target/release +# RUN THIS command FROM CTOYBOX. +pip3 install -q -r REQUIREMENTS.txt --user +export PYTHONPATH=${PWD}/baselines:${PWD}/toybox:${PYTHONPATH} +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${PWD}/../target/release +export LIBCTOYBOX=${PWD}/.. echo "PYTHONPATH $PYTHONPATH" echo "LD_LIBRARY_PATH $LD_LIBRARY_PATH" echo "LIBCTOYBOX $LIBCTOYBOX" diff --git a/ctoybox/toybox/setup.py b/ctoybox/toybox/setup.py index 5c6d7202..617082b9 100644 --- a/ctoybox/toybox/setup.py +++ b/ctoybox/toybox/setup.py @@ -1,6 +1,29 @@ +# From https://github.com/getsentry/milksnake from setuptools import setup -setup(name='openai_shim', - version='0.0.1', - install_requires=['gym'] # And any other dependencies foo needs -) \ No newline at end of file +def build_native(spec): + # build an example rust library + build = spec.add_external_build( + cmd=['cargo', 'build', '--release'], + path='.' + ) + + spec.add_cffi_module( + module_path='toybox._native', + dylib=lambda: build.find_dylib('ctoybox', in_path='../../target/release'), + header_filename=lambda: build.find_header('ctoybox.h', in_path='../../target'), + rtld_flags=['NOW', 'NODELETE'] + ) + +setup( + name='toybox', + version='0.1.0', + packages=['toybox', 'toybox.envs', 'toybox.interventions', 'toybox.sample_tests'], + zip_safe=False, + platforms='any', + setup_requires=['milksnake'], + install_requires=['milksnake'], + milksnake_tasks=[ + build_native + ] +) diff --git a/ctoybox/toybox/toybox/__init__.py b/ctoybox/toybox/toybox/__init__.py index a1d147f6..0d96e407 100644 --- a/ctoybox/toybox/toybox/__init__.py +++ b/ctoybox/toybox/toybox/__init__.py @@ -1,26 +1,32 @@ -from gym.envs.registration import register import toybox.toybox as toybox import toybox.envs as envs import toybox.interventions as interventions import toybox.sample_tests as sample_tests -# Updated to use v4 to be analogous with the ALE versioning -register( - id='BreakoutToyboxNoFrameskip-v4', - entry_point='toybox.envs.atari:BreakoutEnv', - nondeterministic=True -) +try: + from gym.envs.registration import register -register( - id='AmidarToyboxNoFrameskip-v4', - entry_point='toybox.envs.atari:AmidarEnv', - nondeterministic=False -) + # Updated to use v4 to be analogous with the ALE versioning + register( + id='BreakoutToyboxNoFrameskip-v4', + entry_point='toybox.envs.atari:BreakoutEnv', + nondeterministic=True + ) -register( - id='SpaceInvadersToyboxNoFrameskip-v4', - entry_point='toybox.envs.atari:SpaceInvadersEnv', - nondeterministic=False -) + register( + id='AmidarToyboxNoFrameskip-v4', + entry_point='toybox.envs.atari:AmidarEnv', + nondeterministic=False + ) -print("Loaded Toybox environments.") \ No newline at end of file + register( + id='SpaceInvadersToyboxNoFrameskip-v4', + entry_point='toybox.envs.atari:SpaceInvadersEnv', + nondeterministic=False + ) + + print("Registered Toybox environments with gym.") + +except: + # ModuleNotFoundError only in 3.6 and above + print("Loaded Toybox environments.") diff --git a/ctoybox/toybox/toybox/clib.py b/ctoybox/toybox/toybox/clib.py deleted file mode 100644 index 1d2ca4fe..00000000 --- a/ctoybox/toybox/toybox/clib.py +++ /dev/null @@ -1,240 +0,0 @@ -from collections import deque -import ctypes -import numpy as np -from PIL import Image -import os -import platform -import time -import json - -platform = platform.system() -lib_env_var = 'LIBCTOYBOX' -lib_dylib = 'libctoybox.dylib' -lib_so = 'libctoybox.so' - - -if platform == 'Darwin': - _lib_prefix = os.environ[lib_env_var] if lib_env_var in os.environ else '..' - _lib_path_debug = os.path.sep.join([_lib_prefix, 'target', 'debug', lib_dylib]) - _lib_path_release = os.path.sep.join([_lib_prefix, 'target', 'release', lib_dylib]) - print('Looking for toybox lib in\n\t%s\nor\n\t%s' % ( - _lib_path_debug, - _lib_path_release - )) - - _lib_ts_release = os.stat(_lib_path_release).st_birthtime \ - if os.path.exists(_lib_path_release) else 0 - _lib_ts_debug = os.stat(_lib_path_debug).st_birthtime \ - if os.path.exists(_lib_path_debug) else 0 - - if (not (_lib_ts_debug or _lib_ts_release)): - raise OSError('%s not found on this machine' % lib_dylib) - - _lib_path = _lib_path_debug if _lib_ts_debug > _lib_ts_release else _lib_path_release - print(_lib_path) - -elif platform == 'Linux': - _lib_path = lib_so - -else: - raise Exception('Unsupported platform: %s' % platform) - - -try: - _lib = ctypes.CDLL(_lib_path) -except Exception: - raise Exception('Could not load libopenai from path %s.' % _lib_path - + """If you are on OSX, this may be due the relative path being different - from `target/(target|release)/libopenai.dylib. If you are on Linux, try - prefixing your call with `LD_LIBRARY_PATH=/path/to/library`.""") - -class WrapSimulator(ctypes.Structure): - pass - -class WrapState(ctypes.Structure): - pass - - -# I don't know how actions will be issued, so let's have lots of options available -NOOP = 'noop' -LEFT = "left" -RIGHT = "right" -UP = "up" -DOWN = "down" -BUTTON1 = "button1" -BUTTON2 = "button2" - -class Input(ctypes.Structure): - _fields_ = [(LEFT, ctypes.c_bool), - (RIGHT, ctypes.c_bool), - (UP, ctypes.c_bool), - (DOWN, ctypes.c_bool), - (BUTTON1, ctypes.c_bool), - (BUTTON2, ctypes.c_bool)] - - def _set_default(self): - self.left = False - self.right = False - self.up = False - self.down = False - self.button1 = False - self.button2 = False - - """ - ALE_ACTION_MEANING = { - 0 : "NOOP", - 1 : "FIRE", - 2 : "UP", - 3 : "RIGHT", - 4 : "LEFT", - 5 : "DOWN", - 6 : "UPRIGHT", - 7 : "UPLEFT", - 8 : "DOWNRIGHT", - 9 : "DOWNLEFT", - 10 : "UPFIRE", - 11 : "RIGHTFIRE", - 12 : "LEFTFIRE", - 13 : "DOWNFIRE", - 14 : "UPRIGHTFIRE", - 15 : "UPLEFTFIRE", - 16 : "DOWNRIGHTFIRE", - 17 : "DOWNLEFTFIRE", - } - """ - def set_ale(self, num): - if num == 0: - pass - elif num == 1: - self.button1 = True - elif num == 2: - self.up = True - elif num == 3: - self.right = True - elif num == 4: - self.left = True - elif num == 5: - self.down = True - elif num == 6: - self.up = True - self.right = True - elif num == 7: - self.up = True - self.left = True - elif num == 8: - self.down = True - self.right = True - elif num == 9: - self.down = True - self.left = True - elif num == 10: - self.up = True - self.button1 = True - elif num == 11: - self.right = True - self.button1 = True - elif num == 12: - self.left = True - self.button1 = True - elif num == 13: - self.down = True - self.button1 = True - elif num == 14: - self.up = True - self.right = True - self.button1 = True - elif num == 15: - self.up = True - self.left = True - self.button1 = True - elif num == 16: - self.down = True - self.right = True - self.button1 = True - elif num == 17: - self.down = True - self.left = True - self.button1 = True - - - def set_input(self, input_dir, button=NOOP): - self._set_default() - input_dir = input_dir.lower() - button = button.lower() - - # reset all directions - if input_dir == NOOP: - pass - elif input_dir == LEFT: - self.left = True - elif input_dir == RIGHT: - self.right = True - elif input_dir == UP: - self.up = True - elif input_dir == DOWN: - self.down = True - else: - print('input_dir:', input_dir) - assert False - - # reset buttons - if button == NOOP: - pass - elif button == BUTTON1: - self.button1 = True - elif button == BUTTON2: - self.button2 = True - else: - assert False - - -_lib.simulator_alloc.argtypes = [ctypes.c_char_p] -_lib.simulator_alloc.restype = ctypes.POINTER(WrapSimulator) - -_lib.simulator_seed.argtypes = [ctypes.POINTER(WrapSimulator), ctypes.c_uint] -_lib.simulator_seed.restype = None - -_lib.simulator_is_legal_action.argtypes = [ctypes.POINTER(WrapSimulator), ctypes.c_int32] -_lib.simulator_is_legal_action.restype = ctypes.c_bool - -_lib.simulator_actions.argtypes = [ctypes.POINTER(WrapSimulator)] -_lib.simulator_actions.restype = ctypes.c_void_p - -_lib.state_alloc.argtypes = [ctypes.POINTER(WrapSimulator)] -_lib.state_alloc.restype = ctypes.POINTER(WrapState) - -_lib.free_str.argtypes = [ctypes.c_void_p] -_lib.free_str.restype = None - -_lib.state_query_json.argtypes = [ctypes.POINTER(WrapState), ctypes.c_char_p, ctypes.c_char_p] -_lib.state_query_json.restype = ctypes.c_void_p - -_lib.state_apply_ale_action.argtypes = [ctypes.POINTER(WrapState), ctypes.c_int32] -_lib.state_apply_ale_action.restype = ctypes.c_bool - -_lib.state_apply_action.argtypes = [ctypes.POINTER(WrapState), ctypes.POINTER(Input)] -_lib.state_apply_action.restype = None - -_lib.simulator_frame_width.argtypes = [ctypes.POINTER(WrapSimulator)] -_lib.simulator_frame_width.restype = ctypes.c_int32 - -_lib.simulator_frame_height.argtypes = [ctypes.POINTER(WrapSimulator)] -_lib.simulator_frame_height.restype = ctypes.c_int32 - -_lib.state_lives.restype = ctypes.c_int32 -_lib.state_score.restype = ctypes.c_int32 - -_lib.render_current_frame.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_bool, ctypes.c_void_p, ctypes.c_void_p] - #(frame_ptr, size, sim.get_simulator(), self.__state) - -_lib.state_to_json.argtypes = [ctypes.POINTER(WrapState)] -_lib.state_to_json.restype = ctypes.c_void_p - -_lib.state_from_json.argtypes = [ctypes.POINTER(WrapSimulator), ctypes.c_char_p] -_lib.state_from_json.restype = ctypes.POINTER(WrapState) - -_lib.simulator_to_json.argtypes = [ctypes.POINTER(WrapSimulator)] -_lib.simulator_to_json.restype = ctypes.c_void_p - -_lib.simulator_from_json.argtypes = [ctypes.POINTER(WrapSimulator), ctypes.c_char_p] -_lib.simulator_from_json.restype = ctypes.POINTER(WrapSimulator) diff --git a/ctoybox/toybox/toybox/envs/atari/constants.py b/ctoybox/toybox/toybox/envs/atari/constants.py index 76916bc7..a34be6e3 100644 --- a/ctoybox/toybox/toybox/envs/atari/constants.py +++ b/ctoybox/toybox/toybox/envs/atari/constants.py @@ -1,17 +1,16 @@ -from toybox.toybox import NOOP, UP, RIGHT, LEFT, DOWN, BUTTON1 -from toybox.clib import Input +from toybox.toybox import Input -NOOP_STR = NOOP.upper() +NOOP_STR = Input._NOOP.upper() FIRE_STR = "FIRE" -UP_STR = UP.upper() -RIGHT_STR = RIGHT.upper() -LEFT_STR = LEFT.upper() -DOWN_STR = DOWN.upper() +UP_STR = Input._UP.upper() +RIGHT_STR = Input._RIGHT.upper() +LEFT_STR = Input._LEFT.upper() +DOWN_STR = Input._DOWN.upper() UPFIRE_STR = "UPFIRE" RIGHTFIRE_STR = "RIGHTFIRE" LEFTFIRE_STR = "LEFTFIRE" DOWNFIRE_STR = "DOWNFIRE" -BUTTON1_STR = BUTTON1.upper() +BUTTON1_STR = Input._BUTTON1.upper() # Copied from, and required by, baselines ACTION_MEANING = { @@ -35,4 +34,4 @@ 17 : "DOWNLEFTFIRE", } -ACTION_LOOKUP = { v : k for (k, v) in ACTION_MEANING.items() } \ No newline at end of file +ACTION_LOOKUP = { v : k for (k, v) in ACTION_MEANING.items() } diff --git a/ctoybox/toybox/toybox/toybox.py b/ctoybox/toybox/toybox/toybox.py index 051fb321..a8578ea4 100644 --- a/ctoybox/toybox/toybox/toybox.py +++ b/ctoybox/toybox/toybox/toybox.py @@ -1,5 +1,4 @@ from collections import deque -import ctypes import numpy as np from PIL import Image import os @@ -7,17 +6,151 @@ import time import json -from toybox.clib import _lib, Input, NOOP, LEFT, RIGHT, UP, DOWN, BUTTON1, BUTTON2 +try: + from toybox._native import ffi, lib +except: + # should be ModuleNotFoundError, but this is not available on the version of python on travis + print('Global setup not found...trying local development install...') + platform = platform.system() + lib_env_var = 'LIBCTOYBOX' + lib_dylib = 'libctoybox.dylib' + lib_so = 'libctoybox.so' + + _lib_prefix = os.environ[lib_env_var] if lib_env_var in os.environ else '..' + + if platform == 'Darwin': + _lib_path_debug = os.path.sep.join([_lib_prefix, 'target', 'debug', lib_dylib]) + _lib_path_release = os.path.sep.join([_lib_prefix, 'target', 'release', lib_dylib]) + print('Looking for toybox lib in\n\t%s\nor\n\t%s' % ( + _lib_path_debug, + _lib_path_release + )) + + _lib_ts_release = os.stat(_lib_path_release).st_birthtime \ + if os.path.exists(_lib_path_release) else 0 + _lib_ts_debug = os.stat(_lib_path_debug).st_birthtime \ + if os.path.exists(_lib_path_debug) else 0 + + if (not (_lib_ts_debug or _lib_ts_release)): + raise OSError('%s not found on this machine' % lib_dylib) + + _lib_path = _lib_path_debug if _lib_ts_debug > _lib_ts_release else _lib_path_release + print(_lib_path) + + elif platform == 'Linux': + _lib_path = lib_so + + else: + raise Exception('Unsupported platform for development: %s' % platform) + + try: + from cffi import FFI + ffi = FFI() + with open(os.sep.join([_lib_prefix, 'target', 'ctoybox.h']), 'r') as f: + # directives not supported! + header = '\n'.join([line for line in f.readlines() if not line.startswith('#')]) + ffi.cdef(header) + lib = ffi.dlopen(_lib_path) + except Exception: + raise Exception('Could not load libctoybox from path %s. ' % _lib_path + + """If you are on OSX, this may be due the relative path being different + from `target/(target|release)/libctoybox.dylib. If you are on Linux, try + prefixing your call with `LD_LIBRARY_PATH=/path/to/library`.""") + + +class Input(): + """An input object represents a game controller having left, right, up, down, and two buttons. + + ALE mapping: + ALE_ACTION_MEANING = { + 0 : "NOOP", + 1 : "FIRE", + 2 : "UP", + 3 : "RIGHT", + 4 : "LEFT", + 5 : "DOWN", + 6 : "UPRIGHT", + 7 : "UPLEFT", + 8 : "DOWNRIGHT", + 9 : "DOWNLEFT", + 10 : "UPFIRE", + 11 : "RIGHTFIRE", + 12 : "LEFTFIRE", + 13 : "DOWNFIRE", + 14 : "UPRIGHTFIRE", + 15 : "UPLEFTFIRE", + 16 : "DOWNRIGHTFIRE", + 17 : "DOWNLEFTFIRE", + } + """ + + _LEFT = "left" + _RIGHT = "right" + _UP = "up" + _DOWN = "down" + _BUTTON1 = "button1" + _BUTTON2 = "button2" + _NOOP = "noop" + + def __init__(self): + self.reset() + + def reset(self): + self.left = False + self.right = False + self.up = False + self.down = False + self.button1 = False + self.button2 = False + + def __str__(self): + return self.__dict__.__str__() + + def __repr__(self): + return self.__dict__.__str__() + + def set_input(self, input_dir, button=_NOOP): + input_dir = input_dir.lower() + button = button.lower() + + # reset all directions + if input_dir == Input._NOOP: + pass + elif input_dir == Input._LEFT: + self.left = True + elif input_dir == Input._RIGHT: + self.right = True + elif input_dir == Input._UP: + self.up = True + elif input_dir == Input._DOWN: + self.down = True + else: + print('input_dir:', input_dir) + assert False + + # reset buttons + if button == Input._NOOP: + pass + elif button == Input._BUTTON1: + self.button1 = True + elif button == Input._BUTTON2: + self.button2 = True + else: + assert False + def rust_str(result): - txt = ctypes.cast(result, ctypes.c_char_p).value.decode('UTF-8') - _lib.free_str(result) + txt = ffi.cast("char *", result) #.value.decode('UTF-8') + txt = ffi.string(txt).decode('UTF-8') + lib.free_str(result) return txt def json_str(js): if type(js) is dict: js = json.dumps(js) + elif type(js) is Input: + js = json.dumps(js.__dict__) elif type(js) is not str: raise ValueError('Unknown json type: %s (only str and dict supported)' % type(js)) return js @@ -25,9 +158,8 @@ def json_str(js): class Simulator(object): def __init__(self, game_name, sim=None): if sim is None: - sim = _lib.simulator_alloc(game_name.encode('utf-8')) + sim = lib.simulator_alloc(game_name.encode('utf-8')) # sim should be a pointer - #self.__sim = ctypes.pointer(ctypes.c_int(sim)) self.game_name = game_name self.__sim = sim self.deleted = False @@ -35,7 +167,7 @@ def __init__(self, game_name, sim=None): def __del__(self): if not self.deleted: self.deleted = True - _lib.simulator_free(self.__sim) + lib.simulator_free(self.__sim) self.__sim = None def __enter__(self): @@ -45,13 +177,13 @@ def __exit__(self, exc_type, exc_value, traceback): self.__del__() def set_seed(self, value): - _lib.simulator_seed(self.__sim, value) + lib.simulator_seed(self.__sim, value) def get_frame_width(self): - return _lib.simulator_frame_width(self.__sim) + return lib.simulator_frame_width(self.__sim) def get_frame_height(self): - return _lib.simulator_frame_height(self.__sim) + return lib.simulator_frame_height(self.__sim) def get_simulator(self): return self.__sim @@ -60,22 +192,22 @@ def new_game(self): return State(self) def state_from_json(self, js): - state = _lib.state_from_json(self.get_simulator(), json_str(js).encode('utf-8')) + state = lib.state_from_json(self.get_simulator(), json_str(js).encode('utf-8')) return State(self, state=state) def to_json(self): - json_str = rust_str(_lib.simulator_to_json(self.get_simulator())) + json_str = rust_str(lib.simulator_to_json(self.get_simulator())) return json.loads(str(json_str)) def from_json(self, config_js): old_sim = self.__sim - self.__sim = _lib.simulator_from_json(self.get_simulator(), json_str(config_js).encode('utf-8')) + self.__sim = lib.simulator_from_json(self.get_simulator(), json_str(config_js).encode('utf-8')) del old_sim class State(object): def __init__(self, sim, state=None): - self.__state = state or _lib.state_alloc(sim.get_simulator()) + self.__state = state or lib.state_alloc(sim.get_simulator()) self.game_name = sim.game_name self.deleted = False @@ -85,7 +217,7 @@ def __enter__(self): def __del__(self): if not self.deleted: self.deleted = True - _lib.state_free(self.__state) + lib.state_free(self.__state) self.__state = None def __exit__(self, exc_type, exc_value, traceback): @@ -96,16 +228,16 @@ def get_state(self): return self.__state def lives(self): - return _lib.state_lives(self.__state) + return lib.state_lives(self.__state) def score(self): - return _lib.state_score(self.__state) + return lib.state_score(self.__state) def game_over(self): return self.lives() == 0 def query_json(self, query, args="null"): - txt = rust_str(_lib.state_query_json(self.__state, json_str(query).encode('utf-8'), json_str(args).encode('utf-8'))) + txt = rust_str(lib.state_query_json(self.__state, json_str(query).encode('utf-8'), json_str(args).encode('utf-8'))) try: out = json.loads(txt) except: @@ -124,8 +256,8 @@ def render_frame_color(self, sim): rgba = 4 size = h * w * rgba frame = np.zeros(size, dtype='uint8') - frame_ptr = frame.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)) - _lib.render_current_frame(frame_ptr, size, False, sim.get_simulator(), self.__state) + frame_ptr = ffi.cast("uint8_t *", frame.ctypes.data) + lib.render_current_frame(frame_ptr, size, False, sim.get_simulator(), self.__state) return np.reshape(frame, (h,w,rgba)) def render_frame_rgb(self, sim): @@ -137,12 +269,12 @@ def render_frame_grayscale(self, sim): w = sim.get_frame_width() size = h * w frame = np.zeros(size, dtype='uint8') - frame_ptr = frame.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)) - _lib.render_current_frame(frame_ptr, size, True, sim.get_simulator(), self.__state) + frame_ptr = ffi.cast("uint8_t *", frame.ctypes.data) + lib.render_current_frame(frame_ptr, size, True, sim.get_simulator(), self.__state) return np.reshape(frame, (h,w,1)) def to_json(self): - json_str = rust_str(_lib.state_to_json(self.__state)) + json_str = rust_str(lib.state_to_json(self.__state)) return json.loads(str(json_str)) class Toybox(object): @@ -168,23 +300,30 @@ def get_width(self): def get_legal_action_set(self): sim = self.rsimulator.get_simulator() - txt = rust_str(_lib.simulator_actions(sim)) + txt = rust_str(lib.simulator_actions(sim)) try: out = json.loads(txt) except: raise ValueError(txt) return out - def apply_ale_action(self, action_int): + def apply_ale_action(self, action_int): + """Takes an integer corresponding to an action, as specified in ALE and applies the action k times, where k is the sticky action constant stored in self.frames_per_action. + """ # implement frameskip(k) by sending the action (k+1) times every time we have an action. for _ in range(self.frames_per_action): - if not _lib.state_apply_ale_action(self.rstate.get_state(), action_int): + if not lib.state_apply_ale_action(self.rstate.get_state(), action_int): raise ValueError("Expected to apply action, but failed: {0}".format(action_int)) def apply_action(self, action_input_obj): + """Takes an Input + """ # implement frameskip(k) by sending the action (k+1) times every time we have an action. for _ in range(self.frames_per_action): - _lib.state_apply_action(self.rstate.get_state(), ctypes.byref(action_input_obj)) + js = json_str(action_input_obj).encode('UTF-8') + print("INPUT JSON", js) + lib.state_apply_action(self.rstate.get_state(), + ffi.new("char []", js)) def get_state(self): return self.rstate.render_frame(self.rsimulator, self.grayscale)