Skip to content

Commit d1d238f

Browse files
authored
Merge pull request #256 from bokveizen/fanchen_250306
[Bug fixes for #255] Added environment config files for FLP and MCP
2 parents cb72927 + ce7e96c commit d1d238f

File tree

5 files changed

+61
-1
lines changed

5 files changed

+61
-1
lines changed

configs/env/flp.yaml

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
_target_: rl4co.envs.FLPEnv
2+
name: flp
3+
4+
generator_params:
5+
num_loc: 100
6+
min_loc: 0.0
7+
max_loc: 1.0
8+
loc_distribution: uniform
9+
to_choose: 10
10+
11+
# data_dir: ${paths.root_dir}/data/mcp
12+
# val_file: mcp${env.generator_params.num_loc}_val_seed4321.npz
13+
# test_file: mcp${env.generator_params.num_loc}_test_seed1234.npz

configs/env/mcp.yaml

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
_target_: rl4co.envs.MCPEnv
2+
name: mcp
3+
4+
generator_params:
5+
num_items: 200
6+
num_sets: 100
7+
min_weight: 1
8+
max_weight: 10
9+
min_size: 5
10+
max_size: 15
11+
n_sets_to_choose: 10
12+
size_distribution: uniform
13+
weight_distribution: uniform
14+
15+
# data_dir: ${paths.root_dir}/data/mcp
16+
# val_file: mcp${env.generator_params.num_loc}_val_seed4321.npz
17+
# test_file: mcp${env.generator_params.num_loc}_test_seed1234.npz

configs/experiment/graph/am.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
defaults:
44
- override /model: am.yaml
5-
- override /env: mcp.yaml
5+
- override /env: flp.yaml
66
- override /callbacks: default.yaml
77
- override /trainer: default.yaml
88
- override /logger: wandb.yaml

rl4co/models/nn/env_embeddings/context.py

+20
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module:
3535
"mdcpdp": MDCPDPContext,
3636
"mtvrp": MTVRPContext,
3737
"shpp": TSPContext,
38+
"flp": FLPContext,
3839
}
3940

4041
if env_name not in embedding_registry:
@@ -372,3 +373,22 @@ def _state_embedding(self, embeddings, td):
372373
],
373374
-1,
374375
)
376+
377+
class FLPContext(EnvContext):
378+
"""Context embedding for the Facility Location Problem (FLP).
379+
"""
380+
def __init__(self, embed_dim: int):
381+
super(FLPContext, self).__init__(embed_dim=embed_dim)
382+
self.embed_dim = embed_dim
383+
# self.mlp_context = MLP(embed_dim, [embed_dim, embed_dim])
384+
self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
385+
386+
def forward(self, embeddings, td):
387+
cur_dist = td["distances"].unsqueeze(-2) # (batch_size, 1, n_points)
388+
dist_improve = cur_dist - td["orig_distances"] # (batch_size, n_points, n_points)
389+
dist_improve = torch.clamp(dist_improve, min=0).sum(-1) # (batch_size, n_points)
390+
391+
# softmax
392+
loc_best_soft = torch.softmax(dist_improve, dim=-1) # (batch_size, n_points)
393+
embed_best = (embeddings * loc_best_soft[..., None]).sum(-2)
394+
return embed_best

rl4co/models/nn/env_embeddings/init.py

+10
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module:
4040
"jssp": FJSPInitEmbedding,
4141
"mtvrp": MTVRPInitEmbedding,
4242
"shpp": TSPInitEmbedding,
43+
"flp": FLPInitEmbedding,
4344
}
4445

4546
if env_name not in embedding_registry:
@@ -562,3 +563,12 @@ def forward(self, td):
562563
)
563564
)
564565
return torch.cat((depot_embedding, node_embeddings), -2)
566+
567+
class FLPInitEmbedding(nn.Module):
568+
def __init__(self, embed_dim: int):
569+
super().__init__()
570+
self.projection = nn.Linear(2, embed_dim, bias=True)
571+
572+
def forward(self, td: TensorDict):
573+
hdim = self.projection(td["locs"])
574+
return hdim

0 commit comments

Comments
 (0)