Skip to content

Commit

Permalink
feat(trainer): integration with the rl algorithm with the customized …
Browse files Browse the repository at this point in the history
…environment

creation of the robot's cinematics, error in creating the map and error in the lidar
  • Loading branch information
Nicolasalan committed Jun 29, 2024
1 parent 1261e78 commit 1bee538
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 222 deletions.
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@
sim:
@echo "Starting Training ..."
@python microvault/environment/continuous.py

# === Generate World === #
.PHONY: gen
gen:
@echo "Starting Training ..."
@python microvault/environment/generate.py
Empty file.
60 changes: 32 additions & 28 deletions microvault/algorithm/td3agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import os
import sys
from typing import Tuple

import numpy as np
Expand All @@ -8,8 +9,10 @@
import torch.optim as optim
from numpy import inf

from microvault.network.model import ModelActor, ModelCritic
from microvault.components.replaybuffer import PER
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from components.replaybuffer import PER
from network.model import ModelActor, ModelCritic

# Verificar se o cuda está disponivel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -23,8 +26,9 @@
BATCH_SIZE = 128 # Tamanho do lote


class TD3Agent:
class Agent:
"""Treinamento do agente com o Ambiente."""

def __init__(
self,
state_size: int = 24,
Expand Down Expand Up @@ -97,14 +101,19 @@ def __init__(
torch.load("/content/checkpoint_critic.pth", map_location=device)
)
self.critic_target = (
ModelCritic(state_size, action_size).to(device).eval().requires_grad_(False)
ModelCritic(state_size, action_size)
.to(device)
.eval()
.requires_grad_(False)
)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=LR_CRITIC)

else:
# Rede "Actor" (com/ Rede Alvo)
self.actor = ModelActor(state_size, action_size, float(max_action)).to(device)
self.actor = ModelActor(state_size, action_size, float(max_action)).to(
device
)
self.actor_target = (
ModelActor(state_size, action_size, float(max_action))
.to(device)
Expand All @@ -113,14 +122,17 @@ def __init__(
)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=LR_ACTOR)

self.actor_noised = ModelActor(state_size, action_size, float(max_action)).to(
device
)
self.actor_noised = ModelActor(
state_size, action_size, float(max_action)
).to(device)

# Rede "Actor" (com/ Rede Alvo)
self.critic = ModelCritic(state_size, action_size).to(device)
self.critic_target = (
ModelCritic(state_size, action_size).to(device).eval().requires_grad_(False)
ModelCritic(state_size, action_size)
.to(device)
.eval()
.requires_grad_(False)
)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=LR_CRITIC)
Expand All @@ -131,7 +143,7 @@ def __init__(
# Fonte: https://arxiv.org/abs/1511.05952
self.memory = PER(BUFFER_SIZE, BATCH_SIZE, GAMMA, self.nstep)

# Inicializar o modelo RND
# TODO: Inicializar o modelo RND

def step(
self,
Expand All @@ -141,18 +153,13 @@ def step(
next_state: np.ndarray,
done: bool,
) -> None:
if isinstance(state, tuple):
state = np.array(state[0])
"""Salvar experiência na memória de replay (estado, ação, recompensa, próximo estado, feito)."""

self.memory.add(state, action, reward, next_state, done)

def predict(self, states: np.ndarray) -> np.ndarray:
"""Retorna ações para determinado estado de acordo com a política atual."""

if isinstance(states, tuple):
states = np.array(states[0])

assert isinstance(
states, np.ndarray
), "Os estados não são de estrutura de dados (np.ndarray) em PREDICT -> estados: {}.".format(
Expand All @@ -164,8 +171,8 @@ def predict(self, states: np.ndarray) -> np.ndarray:
type(states)
)
assert (
states.shape[0] == 24
), "O Tamanho dos estados não é (24) em PREDICT -> states size: {}.".format(
states.shape[0] == self.state_size
), f"O Tamanho dos estados não é {self.state_size} em PREDICT -> states size: {states.shape[0]}.".format(
states.shape[0]
)
assert (
Expand All @@ -174,10 +181,6 @@ def predict(self, states: np.ndarray) -> np.ndarray:
states.ndim
)

# Verificar se o estado não é uma tupla, se for converter para (np.array)
if isinstance(states, tuple):
states = np.array(states[0])

# Converter estados para tensor
state = torch.from_numpy(states).float().to(device)

Expand Down Expand Up @@ -235,10 +238,11 @@ def learn(
n_iteraion (int): O número de iterações para treinar a rede
gamma (float): Factor de desconto
"""
print("Learning...")

if episode % 200 == 0:
self.save(self.actor, "/content/", "actor", str(episode))
self.save(self.critic, "/content/", "critic", str(episode))
self.save(self.actor, "/content/", "actor", str(episode))
self.save(self.critic, "/content/", "critic", str(episode))

self.actor.train()
self.critic.train()
Expand Down Expand Up @@ -374,11 +378,11 @@ def soft_update(
@staticmethod
def save(model: nn.Module, path: str, filename: str, version: str) -> None:
"""Salvar o modelo"""
torch.save(model.state_dict(), path + "checkpoint_" + filename + "_" + version + ".pth")
torch.save(
model.state_dict(), path + "checkpoint_" + filename + "_" + version + ".pth"
)

@staticmethod
def load(model: nn.Module, path: str, device: str) -> None:
"""Carregar o modelo"""
model.load_state_dict(
torch.load(path, map_location=device)
) # del torch.load
model.load_state_dict(torch.load(path, map_location=device)) # del torch.load
47 changes: 34 additions & 13 deletions microvault/components/replaybuffer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@

from collections import deque

import numpy as np
import torch
from microvault.components.sumtree import SumTree
from components.sumtree import SumTree

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand Down Expand Up @@ -82,13 +81,13 @@ def add(
)

assert (
state.shape[0] == 24
), "The size of the state is not (24) in REPLAY BUFFER -> state size: {}.".format(
state.shape[0] == 23
), "The size of the state is not (23) in REPLAY BUFFER -> state size: {}.".format(
state.shape[0]
)
assert (
action.shape[0] == 4
), "The size of the action is not (4) in REPLAY BUFFER -> action size: {}.".format(
action.shape[0] == 2
), "The size of the action is not (2) in REPLAY BUFFER -> action size: {}.".format(
state.shape[0]
)
if isinstance(reward, np.float64):
Expand All @@ -98,8 +97,8 @@ def add(
reward.size
)
assert (
next_state.shape[0] == 24
), "The size of the next_state is not (24) in REPLAY BUFFER -> next_state size: {}.".format(
next_state.shape[0] == 23
), "The size of the next_state is not (23) in REPLAY BUFFER -> next_state size: {}.".format(
next_state.shape[0]
)

Expand Down Expand Up @@ -192,11 +191,33 @@ def sample(self):
idxs[i] = index
minibatch.append(data)

states = torch.from_numpy(np.vstack([e[0] for e in minibatch if e is not None])).float().to(device)
actions = torch.from_numpy(np.vstack([e[1] for e in minibatch if e is not None])).float().to(device)
rewards = torch.from_numpy(np.vstack([e[2] for e in minibatch if e is not None])).float().to(device)
next_states = torch.from_numpy(np.vstack([e[3] for e in minibatch if e is not None])).float().to(device)
dones = torch.from_numpy(np.vstack([e[4] for e in minibatch if e is not None]).astype(np.uint8)).float().to(device)
states = (
torch.from_numpy(np.vstack([e[0] for e in minibatch if e is not None]))
.float()
.to(device)
)
actions = (
torch.from_numpy(np.vstack([e[1] for e in minibatch if e is not None]))
.float()
.to(device)
)
rewards = (
torch.from_numpy(np.vstack([e[2] for e in minibatch if e is not None]))
.float()
.to(device)
)
next_states = (
torch.from_numpy(np.vstack([e[3] for e in minibatch if e is not None]))
.float()
.to(device)
)
dones = (
torch.from_numpy(
np.vstack([e[4] for e in minibatch if e is not None]).astype(np.uint8)
)
.float()
.to(device)
)

# assert isinstance(
# states, torch.Tensor
Expand Down
Loading

0 comments on commit 1bee538

Please sign in to comment.