diff --git a/Demos/Stratego/cartpole/README.md b/Demos/Stratego/cartpole/README.md
new file mode 100644
index 0000000..9cf7e74
--- /dev/null
+++ b/Demos/Stratego/cartpole/README.md
@@ -0,0 +1,107 @@
+# CartPole model for UPPAAL Stratego
+
+This model is a [UPPAAL Stratego][1] implementation of the classical
+reinforcement learning problem in which an agent has to balance a pole vertically
+on a moving cart. The model corresponds to the [Gymnasium environment][4] that
+is widely used for testing RL techniques.
+
+![cartpole-gif](./reinforcement-learning-cartpole-v0.gif)
+([source][2])
+
+
+## Model description
+
+The model consists of a constantly moving cart that can move either left or
+right on a flat, horizontal surface. A pole balances on top of the cart, and
+the learning agent has to keep the pole standing upright by changing the
+direction of the cart without letting it travel too far in either direction.
+The agent is handed control every 0.02 seconds and then has to decide in which
+direction to push the cart.
+
+![agent template](./imgs/agent-screenshot.png)
+
+The observable state space for the agent is 4-dimensional and consists of the
+position and velocity of the cart, the angle of the pole and the angular
+velocity with which the pole is falling. In the initial location of the model,
+all these variables are instantiated with a random value between -0.05 and
+0.05. The main location of the CartPole model is Alive. Here, the position of
+the cart and the angle of the pole change according to the corresponding
+velocities, while the velocities themselves are updated by a set of functions
+described in detail in [the paper][3] by Florian, R. (2007). These functions
+are quite intricate and depend on the mass of the cart, the mass and length of
+the pole, and the magnitude of the applied force (the Python sketch further
+below gives a rough idea of the computation).
+
+In this model, the functions are written so that it is easy to change these
+values, but by default they are set to match the configuration of the Gymnasium
+environment that this model is meant to correspond to.
+
+![acceleration functions](./imgs/functions-screenshot.png)
+
+Whenever the agent takes an action, the CartPole model visits the IsDead
+location, where it is checked whether the agent has lost or is able to
+continue. There are four ways to lose: the cart is too far to the left, the
+cart is too far to the right, the pole has dropped too low on the left, or the
+pole has dropped too low on the right. If none of these apply, the model
+transitions back to the Alive location. Otherwise, the agent is terminated, the
+system is reset and the death counter is increased by one.
+
+![cartpole template](./imgs/cartpole-screenshot.png)
+
+
+## Training and evaluation
+
+For evaluation, we start by getting UPPAAL Stratego to calculate the expected
+number of deaths when we apply random control. The query `E[<=10;1000]
+(max: CartPole.num_deaths)` estimates the maximal number of deaths over 10
+seconds on the basis of 1000 simulated runs. We see that this estimate lies at
+around 23, which indicates a pretty poor control strategy.
+
+We therefore train a strategy that has access to the aforementioned state
+variables and whose objective is to minimize the number of deaths. The query
+`strategy StayAlive = minE (CartPole.num_deaths) [<=10] {} ->
+{CartPole.cart_pos, CartPole.cart_vel, CartPole.pole_ang, CartPole.pole_vel}:
+<> time >= 10` does the job.
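+
+As an aside, the acceleration functions referred to in the model description are
+essentially the Euler-integrated cart-pole dynamics used by the Gymnasium
+environment. Below is a minimal Python sketch of that update, assuming the
+default constants (cart mass 1.0, pole mass 0.1, pole half-length 0.5, force
+10.0, timestep 0.02); the function and constant names are illustrative and do
+not necessarily match the UPPAAL declarations.
+
+```python
+import math
+
+# Assumed defaults, matching the standard Gymnasium CartPole configuration.
+GRAVITY = 9.8
+CART_MASS = 1.0
+POLE_MASS = 0.1
+TOTAL_MASS = CART_MASS + POLE_MASS
+POLE_HALF_LENGTH = 0.5
+FORCE_MAG = 10.0
+TAU = 0.02  # seconds between agent decisions
+
+
+def step(cart_pos, cart_vel, pole_ang, pole_vel, push_right):
+    """One Euler step of the cart-pole dynamics (Florian, 2007)."""
+    force = FORCE_MAG if push_right else -FORCE_MAG
+    cos_a, sin_a = math.cos(pole_ang), math.sin(pole_ang)
+
+    # Accelerations of the pole (angular) and the cart (linear).
+    temp = (force + POLE_MASS * POLE_HALF_LENGTH * pole_vel ** 2 * sin_a) / TOTAL_MASS
+    pole_acc = (GRAVITY * sin_a - cos_a * temp) / (
+        POLE_HALF_LENGTH * (4.0 / 3.0 - POLE_MASS * cos_a ** 2 / TOTAL_MASS)
+    )
+    cart_acc = temp - POLE_MASS * POLE_HALF_LENGTH * pole_acc * cos_a / TOTAL_MASS
+
+    # Euler integration, as in the Gymnasium implementation.
+    cart_pos += TAU * cart_vel
+    cart_vel += TAU * cart_acc
+    pole_ang += TAU * pole_vel
+    pole_vel += TAU * pole_acc
+    return cart_pos, cart_vel, pole_ang, pole_vel
+```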
+
+Now we can rerun the estimation query, this time appending `under StayAlive`
+to it in order to utilize our newfound strategy. We now see dramatically
+better performance, as the estimated maximal number of deaths is less than
+0.01 (the exact value may vary slightly between executions of the query)! This
+means that UPPAAL Stratego has done what every other RL technique worth its
+salt should be able to do: solve the cartpole problem!
+
+
+## Evaluating in the Gymnasium environment
+
+We can verify that the strategy learned by UPPAAL Stratego is actually useful
+outside Stratego. [Gymnasium][4] is a widely used Python framework for working
+with a wide range of standard RL environments such as cartpole, and since this
+model is based on that specific implementation, our Stratego strategies should
+work in the Gymnasium setting. The directory of this model contains some Python
+files that allow you to try this out!
+
+First, be a good pythonista and create and activate a virtual environment using
+your favorite method (or just the easy method: `python3 -m venv ./env
+&& source env/bin/activate`). Then run the command `pip install -r
+requirements.txt`, which will install Gymnasium and [stratetrees][5], a small
+package designed for working with UPPAAL Stratego strategies in a Python
+setting.
+
+Now you need to save a UPPAAL strategy as a JSON file. In UPPAAL, create a new
+query that says `saveStrategy("/path/to/somewhere/strategy.json", StayAlive)`
+(if your strategy is not called 'StayAlive', you should obviously write
+something else at the end) and run it. Now you should be able to run the command
+`python main.py --strategy /path/to/somewhere/strategy.json` and see your UPPAAL
+Stratego strategy being applied to the Gymnasium environment!
+
+As the output suggests, a perfect score would be a mean of 500. That
+will probably not be the case (more likely something like 490). This indicates
+that the UPPAAL strategy is not perfect and has not solved the problem
+completely. You can tinker with the learning parameters in UPPAAL
+(Options -> Learning parameters...) or increase the number of seconds in the
+training query above (e.g. change 10 to 20 to force it to learn to balance for
+a longer time period).
+
+[1]: https://people.cs.aau.dk/~marius/stratego/
+[2]: https://tenor.com/view/reinforcement-learning-cartpole-v0-tensorflow-open-ai-gif-18474251
+[3]: https://coneural.org/florian/papers/05_cart_pole.pdf
+[4]: https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/classic_control/cartpole.py
+[5]: https://pypi.org/project/stratetrees/
diff --git a/Demos/Stratego/cartpole/cartpole.xml b/Demos/Stratego/cartpole/cartpole.xml
new file mode 100644
index 0000000..f78a340
--- /dev/null
+++ b/Demos/Stratego/cartpole/cartpole.xml
@@ -0,0 +1,490 @@
+
+
+
+/**
+Implementation of the classical CartPole environment.
+Author: Andreas Holck Høeg-Petersen
+
+Based on the Gymnasium env (https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/classic_control/cartpole.py),
+which uses the equations from Florian, R. 2007 (https://coneural.org/florian/papers/05_cart_pole.pdf).
+
+The system models a cart that can be pushed left or right with a pole balancing
+on top of the cart. The objective of the learning agent is to keep the pole
+from falling for 10 seconds without pushing the cart too far in either
+direction. Every time the agent fails, the system is reset and time continues.
+We thus count the number of deaths during a single run and seek to minimize +this number. +**/ + +broadcast chan left, right; +clock time; + + + system Agent, CartPole; + + + + // How often do we die with random control? + + + + E[<=10;1000] (max: CartPole.num_deaths) + Expected number of deaths with random control. +Should finish in a couple of seconds and be around 23. + +
23.77 ± 0.162791 (95% CI)
+	[Plot data for this query's result: probability density, probability distribution and cumulative probability series with 95% confidence bounds, plus a frequency histogram over 1000 runs; sample span [15, 31]; mean estimate 23.77 ± 0.16279 (95% CI); parameters α=0.05, ε=0.05, bucket width=1, bucket count=17.]
+
+
+
+
+
+
+// Train strategy
+
+
+
+strategy StayAlive = minE (CartPole.num_deaths) [<=10] {} -> {CartPole.cart_pos, CartPole.cart_vel, CartPole.pole_ang, CartPole.pole_vel}: <> time >= 10
+Train a strategy from a partially observable state space that minimizes the number of deaths over a 10-second run.
+Should find a strategy within a minute.
+
+
+
+
+
+
+
+
+// How often do we die with trained controller?
+
+
+
+E[<=10;1000] (max: CartPole.num_deaths) under StayAlive
+Expected number of deaths under the well-trained agent.
+Should finish in a couple of seconds and be around 0.01.
+
0.001 ± 0.00196234 (95% CI)
+	[Plot data for this query's result: probability density, probability distribution and cumulative probability series with 95% confidence bounds, plus a frequency histogram over 1000 runs; sample span [0, 1]; mean estimate 0.001 ± 0.00196 (95% CI); parameters α=0.05, ε=0.05, bucket width=1, bucket count=2.]
+
+
+
diff --git a/Demos/Stratego/cartpole/imgs/agent-screenshot.png b/Demos/Stratego/cartpole/imgs/agent-screenshot.png new file mode 100644 index 0000000..337cc13 Binary files /dev/null and b/Demos/Stratego/cartpole/imgs/agent-screenshot.png differ diff --git a/Demos/Stratego/cartpole/imgs/cartpole-screenshot.png b/Demos/Stratego/cartpole/imgs/cartpole-screenshot.png new file mode 100644 index 0000000..0ec8aa0 Binary files /dev/null and b/Demos/Stratego/cartpole/imgs/cartpole-screenshot.png differ diff --git a/Demos/Stratego/cartpole/imgs/functions-screenshot.png b/Demos/Stratego/cartpole/imgs/functions-screenshot.png new file mode 100644 index 0000000..c3ec0dd Binary files /dev/null and b/Demos/Stratego/cartpole/imgs/functions-screenshot.png differ diff --git a/Demos/Stratego/cartpole/main.py b/Demos/Stratego/cartpole/main.py new file mode 100644 index 0000000..4d0ad36 --- /dev/null +++ b/Demos/Stratego/cartpole/main.py @@ -0,0 +1,52 @@ +import argparse +import numpy as np +import gymnasium as gym + +from tqdm import tqdm +from stratetrees.models import QTree + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--strategy', '-s', type=str, + help='Path to the json file containing the UPPAAL Stratego strategy' + ) + parser.add_argument( + '--epochs', '-e', type=int, default=1000, + help='Number of epochs to run' + ) + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + model = QTree(args.strategy) + + env = gym.make('CartPole-v1') + + print(f'Evaluating {args.strategy} on {args.epochs} episodes') + + rewards = [] + for epoch in tqdm(range(args.epochs)): + + obs, info = env.reset() + epoch_reward = 0 + + done, trunc = False, False + while not (done or trunc): + + action = model.predict(obs) + obs, reward, done, trunc, info = env.step(action) + + epoch_reward += reward + + rewards.append(epoch_reward) + + print() + print('Mean score: {} (+/- {})'.format( + np.round(np.mean(rewards), 2), + np.round(np.std(rewards), 2) + )) + print('(perfect score is 500)') + print() diff --git a/Demos/Stratego/cartpole/reinforcement-learning-cartpole-v0.gif b/Demos/Stratego/cartpole/reinforcement-learning-cartpole-v0.gif new file mode 100644 index 0000000..44e99c4 Binary files /dev/null and b/Demos/Stratego/cartpole/reinforcement-learning-cartpole-v0.gif differ diff --git a/Demos/Stratego/cartpole/requirements.txt b/Demos/Stratego/cartpole/requirements.txt new file mode 100644 index 0000000..90e3503 --- /dev/null +++ b/Demos/Stratego/cartpole/requirements.txt @@ -0,0 +1,2 @@ +gymnasium==0.28.1 +stratetrees==0.0.1