69 changes: 54 additions & 15 deletions strategy_train_env/bidding_train_env/baseline/cql/cql.py
@@ -85,6 +85,7 @@ def __init__(dim_obs=2, dim_actions=1, gamma=1, tau=0.001, V_lr=1e-4, crit
self.log_alpha = torch.zeros(1, requires_grad=True)
self.alpha_optimizer = optim.Adam([self.log_alpha], lr=actor_lr)
self.min_q_weight = 1.0
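# Number of actions sampled per state for the CQL regularizer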
self.num_random = 10

self.to(self.device)

@@ -97,7 +98,7 @@ def build_network(self, input_dim: int, output_dim: int) -> nn.Sequential:
nn.Linear(256, output_dim),
)

def step(self, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, dones: torch.Tensor):
def step(self, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, dones: torch.Tensor, step_num: int = 0) -> tuple:
'''
train model
'''
@@ -124,6 +125,15 @@ def step(self, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tenso
)

policy_loss = (alpha * log_pi - q_new_actions).mean()
if step_num < 2000:
"""
For the first few epochs, optionally do behavioral cloning.
Conventionally, there is not much difference in performance between
running 20k behavioral-cloning gradient steps here and skipping them.
"""
policy_log_prob = self.policy.compute_log_prob(actions, states)
policy_loss = (alpha * log_pi - policy_log_prob).mean()

self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()
@@ -140,16 +150,24 @@ def step(self, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tenso
self.target_qf2(torch.cat([next_states, next_actions], dim=-1))
)
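# SAC-style target: reward + gamma * (min target Q - alpha * log pi of next action) on non-terminal transitions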
target_q_values = rewards + (1. - dones) * self.gamma * (target_q_values - alpha * next_log_pi)
target_q_values = target_q_values.detach()
#print("target_q_values",target_q_values)

# CQL Regularization
# Ensure the sampled action range is broader than the data action range for effective training
random_actions = torch.FloatTensor(states.shape[0], actions.shape[-1]).uniform_(-1000, 1000).to(self.device)
q1_rand = self.qf1(torch.cat([states, random_actions], dim=-1))
q2_rand = self.qf2(torch.cat([states, random_actions], dim=-1))
random_actions = torch.FloatTensor(states.shape[0] * self.num_random, actions.shape[-1]).uniform_(0, 300).to(self.device)
curr_actions_tensor = self._get_policy_actions(states, self.num_random, self.policy)
next_actions_tensor = self._get_policy_actions(next_states, self.num_random, self.policy)
q1_rand = self._get_tensor_values(states, random_actions, self.qf1)
q2_rand = self._get_tensor_values(states, random_actions, self.qf2)
q1_curr_actions = self._get_tensor_values(states, curr_actions_tensor, self.qf1)
q2_curr_actions = self._get_tensor_values(states, curr_actions_tensor, self.qf2)
q1_next_actions = self._get_tensor_values(next_states, next_actions_tensor, self.qf1)
q2_next_actions = self._get_tensor_values(next_states, next_actions_tensor, self.qf2)


cat_q1 = torch.cat([q1_rand, q1_pred, q1_pred], 1)
cat_q2 = torch.cat([q2_rand, q2_pred, q2_pred], 1)
cat_q1 = torch.cat([q1_rand, q1_pred.unsqueeze(1), q1_curr_actions, q1_next_actions], 1)
cat_q2 = torch.cat([q2_rand, q2_pred.unsqueeze(1), q2_curr_actions, q2_next_actions], 1)

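# Conservative term: soft maximum of Q over sampled and dataset actions, minus the mean Q on dataset actions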
min_qf1_loss = torch.logsumexp(cat_q1 / self.temperature,
dim=1).mean() * self.min_q_weight * self.temperature - q1_pred.mean() * self.min_q_weight
@@ -160,15 +178,18 @@ def step(self, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tenso
diff1 = target_q_values - q1_pred
diff2 = target_q_values - q2_pred

qf1_loss = torch.where(diff1 > 0, self.expectile * diff1 ** 2, (1 - self.expectile) * diff1 ** 2).mean() + min_qf1_loss
qf2_loss = torch.where(diff2 > 0, self.expectile * diff2 ** 2, (1 - self.expectile) * diff2 ** 2).mean() + min_qf2_loss
qf1_loss = torch.nn.MSELoss()(q1_pred, target_q_values) + min_qf1_loss
qf2_loss = torch.nn.MSELoss()(q2_pred, target_q_values) + min_qf2_loss

#qf1_loss = torch.where(diff1 > 0, self.expectile * diff1 ** 2, (1 - self.expectile) * diff1 ** 2).mean() + min_qf1_loss
#qf2_loss = torch.where(diff2 > 0, self.expectile * diff2 ** 2, (1 - self.expectile) * diff2 ** 2).mean() + min_qf2_loss

self.qf1_optimizer.zero_grad()
qf1_loss.backward()
qf1_loss.backward(retain_graph=True)
self.qf1_optimizer.step()

self.qf2_optimizer.zero_grad()
qf2_loss.backward()
qf2_loss.backward(retain_graph=True)
self.qf2_optimizer.step()

# Soft update target networks
@@ -209,14 +230,14 @@ def save_net(self, save_path: str) -> None:
torch.save(self.qf2.state_dict(), save_path + "/qf2" + ".pkl")
torch.save(self.policy.state_dict(), save_path + "/policy" + ".pkl")

def save_jit(self, save_path: str) -> None:
def save_jit(self, save_path: str, seed:int, step:int) -> None:
'''
save model as JIT
'''
if not os.path.isdir(save_path):
os.makedirs(save_path)
scripted_policy = torch.jit.script(self.cpu())
scripted_policy.save(save_path + "/cql_model" + ".pth")
scripted_policy.save(save_path + "/cql_model" + "_" +str(seed) + "_" +str(step) + ".pth")

def load_net(self, load_path="saved_model/fixed_initial_budget", device='cuda:0') -> None:
'''
@@ -232,6 +253,24 @@ def load_net(self, load_path="saved_model/fixed_initial_budget", device='cuda:0'
self.target_qf1.to(self.device)
self.target_qf2.to(self.device)

def _get_tensor_values(self, obs, actions, network=None):
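# Tile each state num_repeat times so every sampled action is scored, then reshape Q values to (batch, num_repeat, 1)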
action_shape = actions.shape[0]
obs_shape = obs.shape[0]
num_repeat = int(action_shape / obs_shape)
obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(obs.shape[0] * num_repeat, obs.shape[1])
state_action_input = torch.cat([obs_temp, actions], dim=-1)
preds = network(state_action_input)
preds = preds.view(obs.shape[0], num_repeat, 1)
return preds

def _get_policy_actions(self, obs, num_actions, network=None):
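# Draw num_actions reparameterized action samples from the policy for each state in the batch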
obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
new_obs_actions = network(
obs_temp, reparameterize=True
)
return new_obs_actions


if __name__ == '__main__':
model = CQL(dim_obs=2)
model.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -254,10 +293,10 @@ def load_net(self, load_path="saved_model/fixed_initial_budget", device='cuda:0'
dtype=torch.float), torch.tensor(
terminals, dtype=torch.float)

q_loss, v_loss, a_loss = model.step(states, actions, rewards, next_states, terminals)
print(f'step:{i} q_loss:{q_loss} v_loss:{v_loss} a_loss:{a_loss}')
q1_loss, q2_loss, a_loss = model.step(states, actions, rewards, next_states, terminals)

#print(f'step:{i} q_loss:{q_loss} v_loss:{v_loss} a_loss:{a_loss}')

total_params = sum(p.numel() for p in model.parameters())
print("Learnable parameters: {:,}".format(total_params))


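For reference, a minimal self-contained sketch of the conservative term assembled in the step() hunk above. The function name and toy shapes are illustrative, not part of the PR; it assumes sampled-action Q values of shape (batch, num_random, 1) and dataset-action Q values of shape (batch, 1), matching the tensors built by _get_tensor_values.

import torch

def cql_penalty(data_q, rand_q, curr_q, next_q, temperature=1.0, min_q_weight=1.0):
    # rand_q / curr_q / next_q: (batch, num_random, 1); data_q: (batch, 1)
    cat_q = torch.cat([rand_q, data_q.unsqueeze(1), curr_q, next_q], dim=1)
    # Push down a soft maximum of Q over the concatenated actions ...
    soft_max = torch.logsumexp(cat_q / temperature, dim=1).mean() * temperature
    # ... and push up Q on the actions actually taken in the data.
    return (soft_max - data_q.mean()) * min_q_weight

# Toy usage: batch of 4 states, 10 sampled actions per state
batch, num_random = 4, 10
penalty = cql_penalty(torch.randn(batch, 1),
                      torch.randn(batch, num_random, 1),
                      torch.randn(batch, num_random, 1),
                      torch.randn(batch, num_random, 1))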
7 changes: 4 additions & 3 deletions strategy_train_env/main/main_cql.py
@@ -6,8 +6,9 @@

from run.run_cql import run_cql

torch.manual_seed(1)
np.random.seed(1)
seed = 3
torch.manual_seed(seed)
np.random.seed(seed)

if __name__ == "__main__":
run_cql()
run_cql(seed=seed)
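For strict run-to-run reproducibility one could also seed Python's random module and the CUDA generators; a small sketch, not part of this PR:

import random
import numpy as np
import torch

def set_seed(seed: int) -> None:
    # Seed every RNG the training stack may touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable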
30 changes: 18 additions & 12 deletions strategy_train_env/run/run_cql.py
@@ -8,14 +8,15 @@
import pandas as pd
import ast
import pickle
from torch.utils.tensorboard import SummaryWriter
# Configure logging
logging.basicConfig(level=logging.INFO,
format="[%(asctime)s] [%(name)s] [%(filename)s(%(lineno)d)] [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

STATE_DIM = 16

def train_cql_model():
def train_cql_model(seed=1):
"""
Train the CQL model.
"""
@@ -45,15 +46,14 @@ def safe_literal_eval(val):
# Build replay buffer
replay_buffer = ReplayBuffer()
add_to_replay_buffer(replay_buffer, training_data, is_normalize)
print(len(replay_buffer.memory))

# Train model
model = CQL(dim_obs=STATE_DIM)
train_model_steps(model, replay_buffer)
train_model_steps(model, replay_buffer, seed=seed)

# Save model
# model.save_net("saved_model/CQLtest")
model.save_jit("saved_model/CQLtest")
#model.save_jit("saved_model/CQLtest", seed=seed, step=i)

# Test trained model
test_trained_model(model, replay_buffer)
@@ -69,13 +69,20 @@ def add_to_replay_buffer(replay_buffer, training_data, is_normalize):
replay_buffer.push(np.array(state), np.array([action]), np.array([reward]), np.zeros_like(state),
np.array([done]))

def train_model_steps(model, replay_buffer, step_num=100, batch_size=100):
def train_model_steps(model, replay_buffer, step_num=100000, batch_size=100, seed=1):
for i in range(step_num):
if i==8000:
pass
states, actions, rewards, next_states, terminals = replay_buffer.sample(batch_size)
q_loss, v_loss, a_loss = model.step(states, actions, rewards, next_states, terminals)
logger.info(f'Step: {i} Q_loss: {q_loss} V_loss: {v_loss} A_loss: {a_loss}')
q1_loss, q2_loss, policy_loss = model.step(states, actions, rewards, next_states, terminals, i)
if i == 0:
writer = SummaryWriter(log_dir="tensorboard/CQL/" + str(seed))
writer.add_scalar('Loss/q1_loss', q1_loss, i)
writer.add_scalar('Loss/q2_loss', q2_loss, i)
writer.add_scalar('Loss/policy_loss', policy_loss, i)
if i == step_num - 1:
writer.close()
if (i + 1) % 10000 == 0:
model.save_jit("saved_model/CQLtest", seed=seed, step=i+1)
#logger.info(f'Step: {i} Q_loss: {q_loss} V_loss: {v_loss} A_loss: {a_loss}')

def test_trained_model(model, replay_buffer):
for i in range(100):
@@ -85,12 +92,11 @@ def test_trained_model(model, replay_buffer):
tem = np.concatenate((actions, pred_actions), axis=1)
print("concate:",tem)

def run_cql():
print(sys.path)
def run_cql(seed=1):
"""
Run CQL model training and evaluation.
"""
train_cql_model()
train_cql_model(seed=seed)

if __name__ == '__main__':
run_cql()
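A minimal sketch of how the logging and checkpointing in train_model_steps could be structured with the SummaryWriter created once before the loop; save_interval is an illustrative parameter, and model and replay_buffer are assumed to expose the same interfaces used above:

from torch.utils.tensorboard import SummaryWriter

def train_model_steps_sketch(model, replay_buffer, step_num=100000, batch_size=100, seed=1, save_interval=10000):
    writer = SummaryWriter(log_dir="tensorboard/CQL/" + str(seed))
    for i in range(step_num):
        states, actions, rewards, next_states, terminals = replay_buffer.sample(batch_size)
        q1_loss, q2_loss, policy_loss = model.step(states, actions, rewards, next_states, terminals, i)
        # Log the three losses every step
        writer.add_scalar('Loss/q1_loss', q1_loss, i)
        writer.add_scalar('Loss/q2_loss', q2_loss, i)
        writer.add_scalar('Loss/policy_loss', policy_loss, i)
        # Periodically export a TorchScript checkpoint
        if (i + 1) % save_interval == 0:
            model.save_jit("saved_model/CQLtest", seed=seed, step=i + 1)
    writer.close()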