@@ -35,6 +35,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module:
         "mdcpdp": MDCPDPContext,
         "mtvrp": MTVRPContext,
         "shpp": TSPContext,
+        "flp": FLPContext,
     }
 
     if env_name not in embedding_registry:
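The hunk above wires the new FLP entry into the usual lookup-and-instantiate pattern. A minimal usage sketch, assuming env_context_embedding builds the registered class from the config kwargs and that the import path below is where it is exposed (neither is shown in this diff):

    from rl4co.models.nn.env_embeddings import env_context_embedding  # import path assumed

    # Look up the newly registered "flp" entry and instantiate its context module.
    context_net = env_context_embedding("flp", {"embed_dim": 128})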
@@ -372,3 +373,22 @@ def _state_embedding(self, embeddings, td):
             ],
             -1,
         )
+
+class FLPContext(EnvContext):
+    """Context embedding for the Facility Location Problem (FLP).
+    """
+    def __init__(self, embed_dim: int):
+        super(FLPContext, self).__init__(embed_dim=embed_dim)
+        self.embed_dim = embed_dim
+        # self.mlp_context = MLP(embed_dim, [embed_dim, embed_dim])
+        self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
+
+    def forward(self, embeddings, td):
+        cur_dist = td["distances"].unsqueeze(-2)  # (batch_size, 1, n_points)
+        dist_improve = cur_dist - td["orig_distances"]  # (batch_size, n_points, n_points)
+        dist_improve = torch.clamp(dist_improve, min=0).sum(-1)  # (batch_size, n_points)
+
+        # softmax over per-point improvement: a soft argmax over candidate points
+        loc_best_soft = torch.softmax(dist_improve, dim=-1)  # (batch_size, n_points)
+        embed_best = (embeddings * loc_best_soft[..., None]).sum(-2)
+        return embed_best
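A quick smoke test helps pin down the tensor contract of the new class: td["distances"] reads as the current best distance per point with shape (batch_size, n_points), td["orig_distances"] as the pairwise distance matrix (batch_size, n_points, n_points), and forward returns one pooled context vector per instance. A minimal sketch under those assumptions; a plain dict stands in for the environment's TensorDict, and the import path is assumed:

    import torch

    from rl4co.models.nn.env_embeddings.context import FLPContext  # path assumed

    batch, n_points, embed_dim = 4, 10, 128
    ctx = FLPContext(embed_dim=embed_dim)

    # A plain dict with the two keys forward() reads is enough for a smoke test.
    td = {
        "distances": torch.rand(batch, n_points),
        "orig_distances": torch.rand(batch, n_points, n_points),
    }
    embeddings = torch.randn(batch, n_points, embed_dim)

    out = ctx(embeddings, td)
    assert out.shape == (batch, embed_dim)  # one pooled context vector per instance

The softmax over per-point distance improvement acts as a soft argmax, so the pooled embedding concentrates on the most promising candidate facility while remaining differentiable.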