From 4b97b2fe99e3e61ad16dad8097661696842be7b6 Mon Sep 17 00:00:00 2001
From: Ayush Ishan <77822265+AYUSH-ISHAN@users.noreply.github.com>
Date: Thu, 23 Jun 2022 19:26:21 +0530
Subject: [PATCH] Update updet_agent.py

Add a distance-based adjacency-matrix helper for inter-agent
communication graphs.

(Review fixes: dropped an incomplete `hidden_state =` statement that was
a SyntaxError; `n_ant` was an undefined name inside
get_adjacency_matrix, so it is now an explicit parameter; the distance
threshold is parameterized with the original 0.1 default.)
---
 modules/agents/updet_agent.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/modules/agents/updet_agent.py b/modules/agents/updet_agent.py
index 097cc84..dc81592 100644
--- a/modules/agents/updet_agent.py
+++ b/modules/agents/updet_agent.py
@@ -42,5 +42,24 @@ def forward(self, inputs, hidden_state, task_enemy_num, task_ally_num):
 
         return q, h
 
+
+def get_adjacency_matrix(obs, n_ant, threshold=0.1):
+    """Build a symmetric binary adjacency matrix over ``n_ant`` agents.
+
+    Agents ``i`` and ``j`` are marked adjacent (adj[i][j] = adj[j][i] = 1)
+    when the squared Euclidean distance between their positions, read from
+    ``obs[k][2]`` and ``obs[k][3]``, is below ``threshold``.
+    """
+    # NOTE(review): requires `import numpy as np` at module top -- confirm.
+    adj = np.zeros((n_ant, n_ant))
+    for agent in range(n_ant):
+        # Test each unordered pair once (lower triangle), then mirror.
+        for i in range(agent):
+            dist_sq = ((obs[agent][2] - obs[i][2]) ** 2
+                       + (obs[agent][3] - obs[i][3]) ** 2)
+            if dist_sq < threshold:
+                adj[agent][i] = 1
+                adj[i][agent] = 1
+    return adj
+
 class SelfAttention(nn.Module):
     def __init__(self, emb, heads=8, mask=False):