
Commit fd34371

Merge pull request #257 from bokveizen/fanchen_250306
[Additional features for #255] Added embedding functions for MCP
2 parents d1d238f + 53bfefc commit fd34371

4 files changed: +75 −8 lines changed

rl4co/models/nn/env_embeddings/context.py

+28 −6

@@ -3,7 +3,7 @@
 
 from tensordict import TensorDict
 
-from rl4co.utils.ops import gather_by_index
+from rl4co.utils.ops import gather_by_index, batched_scatter_sum
 
 
 def env_context_embedding(env_name: str, config: dict) -> nn.Module:
@@ -36,6 +36,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module:
         "mtvrp": MTVRPContext,
         "shpp": TSPContext,
         "flp": FLPContext,
+        "mcp": MCPContext,
     }
 
     if env_name not in embedding_registry:
@@ -379,9 +380,8 @@ class FLPContext(EnvContext):
     """
     def __init__(self, embed_dim: int):
         super(FLPContext, self).__init__(embed_dim=embed_dim)
-        self.embed_dim = embed_dim
-        # self.mlp_context = MLP(embed_dim, [embed_dim, embed_dim])
-        self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
+        self.embed_dim = embed_dim
+        self.project_context = nn.Linear(embed_dim, embed_dim, bias=True)
 
     def forward(self, embeddings, td):
         cur_dist = td["distances"].unsqueeze(-2)  # (batch_size, 1, n_points)
@@ -390,5 +390,27 @@ def forward(self, embeddings, td):
 
         # softmax
         loc_best_soft = torch.softmax(dist_improve, dim=-1)  # (batch_size, n_points)
-        embed_best = (embeddings * loc_best_soft[..., None]).sum(-2)
-        return embed_best
+        context_embedding = (embeddings * loc_best_soft[..., None]).sum(-2)
+        return self.project_context(context_embedding)
+
+class MCPContext(EnvContext):
+    """Context embedding for the Maximum Coverage Problem (MCP).
+    """
+    def __init__(self, embed_dim: int):
+        super(MCPContext, self).__init__(embed_dim=embed_dim)
+        self.embed_dim = embed_dim
+        self.project_context = nn.Linear(embed_dim, embed_dim, bias=True)
+
+    def forward(self, embeddings, td):
+        membership_weighted = batched_scatter_sum(
+            td["weights"].unsqueeze(-1), td["membership"].long()
+        )
+        membership_weighted.squeeze_(-1)
+        # membership_weighted: [batch_size, n_sets]
+
+        # softmax; higher weights for better sets
+        membership_weighted = torch.softmax(
+            membership_weighted, dim=-1
+        )  # (batch_size, n_sets)
+        context_embedding = (membership_weighted.unsqueeze(-1) * embeddings).sum(1)
+        return self.project_context(context_embedding)
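
For orientation, here is a minimal standalone sketch of what the new MCPContext.forward computes, written in plain torch with illustrative sizes. It assumes td["weights"] holds per-item weights of shape [batch, n_items], td["membership"] holds 1-based, zero-padded item indices per set of shape [batch, n_sets, K], and embeddings are the per-set encoder outputs of shape [batch, n_sets, embed_dim]; these layouts are assumptions for the sketch, not taken verbatim from this diff.

import torch

from rl4co.utils.ops import batched_scatter_sum  # added in rl4co/utils/ops.py in this PR

batch, n_items, n_sets, K, embed_dim = 2, 10, 5, 3, 128  # illustrative sizes
weights = torch.rand(batch, n_items)                            # per-item weights (assumed layout)
membership = torch.randint(0, n_items + 1, (batch, n_sets, K))  # 1-based item indices, 0 = padding
embeddings = torch.rand(batch, n_sets, embed_dim)               # per-set encoder embeddings

# total item weight covered by each set: [batch, n_sets]
set_weight = batched_scatter_sum(weights.unsqueeze(-1), membership).squeeze(-1)
# soft attention over sets; heavier sets get larger coefficients
set_prob = torch.softmax(set_weight, dim=-1)
# context embedding: weighted pooling of the set embeddings, shape [batch, embed_dim]
context = (set_prob.unsqueeze(-1) * embeddings).sum(1)

MCPContext then passes this pooled vector through project_context, a Linear(embed_dim, embed_dim), before returning it.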

rl4co/models/nn/env_embeddings/init.py

+15 −2

@@ -4,7 +4,7 @@
 from tensordict.tensordict import TensorDict
 
 from rl4co.models.nn.ops import PositionalEncoding
-from rl4co.utils.ops import cartesian_to_polar
+from rl4co.utils.ops import cartesian_to_polar, batched_scatter_sum
 
 
 def env_init_embedding(env_name: str, config: dict) -> nn.Module:
@@ -41,6 +41,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module:
         "mtvrp": MTVRPInitEmbedding,
         "shpp": TSPInitEmbedding,
         "flp": FLPInitEmbedding,
+        "mcp": MCPInitEmbedding,
     }
 
     if env_name not in embedding_registry:
@@ -571,4 +572,16 @@ def __init__(self, embed_dim: int):
 
     def forward(self, td: TensorDict):
         hdim = self.projection(td["locs"])
-        return hdim
+        return hdim
+
+class MCPInitEmbedding(nn.Module):
+    def __init__(self, embed_dim: int):
+        super().__init__()
+        self.projection_items = nn.Linear(1, embed_dim, bias=True)
+
+    def forward(self, td: TensorDict):
+        items_embed = self.projection_items(td["weights"].unsqueeze(-1))
+        # sum pooling
+        membership_emb = batched_scatter_sum(items_embed, td["membership"].long())
+        return membership_emb
+
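
MCPInitEmbedding embeds each item weight with a Linear(1, embed_dim) and then sum-pools the item embeddings into one embedding per set via batched_scatter_sum. A hedged usage sketch, assuming the MCP TensorDict carries "weights" of shape [batch, n_items] and a zero-padded, 1-based "membership" of shape [batch, n_sets, K]; the shapes here are chosen for illustration.

import torch
from tensordict import TensorDict

from rl4co.models.nn.env_embeddings.init import MCPInitEmbedding

batch, n_items, n_sets, K, embed_dim = 2, 10, 5, 3, 128  # illustrative sizes
td = TensorDict(
    {
        "weights": torch.rand(batch, n_items),                            # per-item weights
        "membership": torch.randint(0, n_items + 1, (batch, n_sets, K)),  # 1-based, 0-padded
    },
    batch_size=[batch],
)
init_embed = MCPInitEmbedding(embed_dim)
set_embeddings = init_embed(td)  # [batch, n_sets, embed_dim]: sum of its items' embeddings per set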

rl4co/utils/ops.py

+30

@@ -283,3 +283,33 @@ def select_start_nodes_by_distance(td, env, num_starts, exclude_depot=True):
     )
     selected_nodes = node_index[:, 1:] if exclude_depot else node_index[:, :-1]
     return rearrange(selected_nodes, "b n -> (n b)")
+
+
+def batched_scatter_sum(src, idx):
+    """Performs a batched scatter and sum operation on the source tensor using the provided indices.
+
+    Parameters:
+        src (Tensor): A tensor of shape [batch_size, N, h].
+            Contains the data to be scattered and summed.
+        idx (Tensor): A tensor of shape [batch_size, M, K] with zero-padding.
+            Each non-zero element in idx represents an index (offset by 1)
+            into src. A zero value indicates a padded (invalid) index.
+
+    Returns:
+        Tensor: A tensor of shape [batch_size, M, h] where for each batch and each index j,
+            the output is computed as:
+            Output[batch, j] = sum(src[batch, k - 1] for k in idx[batch, j] if k != 0)
+            The subtraction of 1 is applied because 0 is used as the padding value.
+
+    Details:
+        - A temporary target tensor (tgt) of shape [batch_size, N+1, h] is created,
+          where tgt[:, 1:] is populated with src.
+        - The function reshapes idx to gather the corresponding values and then reshapes
+          the result back to [batch_size, M, K, h] before summing over the scattering dimension.
+    """
+    bs, N, h = src.shape
+    bs, M, K = idx.shape
+    tgt = torch.zeros(bs, N + 1, h, device=src.device)
+    tgt[:, 1:] = src
+    tgt = gather_by_index(tgt, idx.long().reshape(bs, -1), squeeze=False)
+    return tgt.reshape(bs, M, K, h).sum(-2)
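
A tiny worked example of the indexing convention (1-based indices into src, with 0 reserved for padding); the numbers are illustrative, not taken from the PR.

import torch

from rl4co.utils.ops import batched_scatter_sum

src = torch.tensor([[[1.0], [2.0], [3.0]]])  # [batch=1, N=3, h=1]
# two index groups per batch, 1-based and zero-padded:
# group 0 gathers src rows 0 and 1 (indices 1 and 2); group 1 gathers row 2 only (index 3, then padding)
idx = torch.tensor([[[1, 2], [3, 0]]])       # [batch=1, M=2, K=2]
out = batched_scatter_sum(src, idx)          # [1, 2, 1]
# out[0, 0] is 1.0 + 2.0 = 3.0; out[0, 1] is 3.0, since the padding slot contributes zero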

tests/test_policy.py

+2

@@ -20,6 +20,8 @@
         "dpp",
         "mdpp",
         "smtwtp",
+        "flp",
+        "mcp",
     ],
 )
 def test_am_policy(env_name, size=20, batch_size=2):
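
The two new entries extend the existing attention-model policy smoke test to cover "flp" and "mcp". For reference, a rough hedged sketch of exercising the policy on the new "mcp" environment outside the test harness; it assumes the MCP environment added in #255 is registered under "mcp" with usable generator defaults and uses the usual rl4co entry points, and it is not the body of test_am_policy.

import torch

from rl4co.envs import get_env
from rl4co.models.zoo import AttentionModelPolicy

env = get_env("mcp")                               # MCP environment from #255 (assumed defaults)
td = env.reset(batch_size=[2])
policy = AttentionModelPolicy(env_name=env.name, embed_dim=128)
with torch.inference_mode():
    out = policy(td, env, decode_type="greedy")
print(out["reward"].shape)                         # one reward per instance in the batch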
