Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
dev0Guy committed Nov 16, 2024
1 parent b262323 commit 50033ae
Showing 1 changed file with 48 additions and 14 deletions.
62 changes: 48 additions & 14 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import gymnasium as gym
import tianshou as ts
import clusterenv.envs
import torch, numpy as np
import torch
import numpy as np
from torch import nn

import logging
Expand All @@ -24,18 +25,44 @@ def __init__(self, machine_shape, job_shape, action_shape):

# Sub-networks for machines and jobs
self.machine_net = nn.Sequential(
nn.Conv2d(machine_channels, 32, kernel_size=(3, 3), stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(
machine_channels,
32,
kernel_size=(
3,
3),
stride=1,
padding=1),
nn.ReLU(
inplace=True),
nn.Flatten(),
nn.Linear(32 * machine_shape[0] * machine_shape[2], 128),
nn.ReLU(inplace=True),
nn.Linear(
32 *
machine_shape[0] *
machine_shape[2],
128),
nn.ReLU(
inplace=True),
)
self.job_net = nn.Sequential(
nn.Conv2d(job_channels, 32, kernel_size=(3, 3), stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(
job_channels,
32,
kernel_size=(
3,
3),
stride=1,
padding=1),
nn.ReLU(
inplace=True),
nn.Flatten(),
nn.Linear(32 * job_shape[0] * job_shape[2], 128),
nn.ReLU(inplace=True),
nn.Linear(
32 *
job_shape[0] *
job_shape[2],
128),
nn.ReLU(
inplace=True),
)
# Combine both outputs
self.combined_net = nn.Sequential(
Expand All @@ -57,7 +84,8 @@ def forward(self, obs, state=None, info={}):
jobs = obs["jobs"]

# Ensure proper shape for the convolutional layers
machines = machines.permute(0, 2, 1, 3) # (batch, channels, height, width)
# (batch, channels, height, width)
machines = machines.permute(0, 2, 1, 3)
jobs = jobs.permute(0, 2, 1, 3) # (batch, channels, height, width)

# Process machines and jobs separately
Expand All @@ -73,11 +101,16 @@ def forward(self, obs, state=None, info={}):

def main() -> None:
env_id = "Cluster-discrete-v0"
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(env_id) for _ in range(10)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(env_id) for _ in range(100)])
train_envs = ts.env.DummyVectorEnv(
[lambda: gym.make(env_id) for _ in range(10)])
test_envs = ts.env.DummyVectorEnv(
[lambda: gym.make(env_id) for _ in range(100)])
state_shape = train_envs.observation_space[0]
action_shape = train_envs.action_space[0]
net = Net(state_shape["machines"].shape, state_shape["jobs"].shape, action_shape.n)
net = Net(
state_shape["machines"].shape,
state_shape["jobs"].shape,
action_shape.n)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = ts.policy.DQNPolicy(
model=net,
Expand All @@ -93,7 +126,8 @@ def main() -> None:
ts.data.VectorReplayBuffer(2_000, 200),
exploration_noise=True,
)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)
test_collector = ts.data.Collector(
policy, test_envs, exploration_noise=True)
result = ts.trainer.OffpolicyTrainer(
policy=policy,
train_collector=train_collector,
Expand Down

0 comments on commit 50033ae

Please sign in to comment.