def forward(self, src):
    # _generate_square_subsequent_mask always builds a causal (upper-triangular) mask
    mask = self._generate_square_subsequent_mask(len(src)).to(self.device)
    self.src_mask = mask
    # scale token embeddings and add positional encodings
    src = self.input_emb(src) * math.sqrt(self.ninp)
    src = self.pos_encoder(src)
    output_enc = self.encoder(src, mask=self.src_mask)
    output_dec = self.decoder(output_enc)
    return F.log_softmax(output_dec, dim=-1), output_enc
It seems the mask is always a causal mask, regardless of whether the model is an ARM or a LLaDA model. According to the LLaDA paper, LLaDA uses bidirectional attention and should not use a causal mask.
Correct me if I am wrong.
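In case it helps, here is a minimal sketch of what I would have expected instead, assuming the model keeps some flag to tell the two variants apart (the `self.model_type` attribute below is hypothetical, just for illustration):

def forward(self, src):
    if self.model_type == "arm":
        # autoregressive variant: causal mask so each position only attends to earlier positions
        mask = self._generate_square_subsequent_mask(len(src)).to(self.device)
    else:
        # LLaDA / masked-diffusion variant: no causal mask, full bidirectional attention
        mask = None
    self.src_mask = mask
    src = self.input_emb(src) * math.sqrt(self.ninp)
    src = self.pos_encoder(src)
    output_enc = self.encoder(src, mask=self.src_mask)
    output_dec = self.decoder(output_enc)
    return F.log_softmax(output_dec, dim=-1), output_enc

That is, the ARM keeps the causal mask, while the LLaDA path would pass mask=None so the encoder attends over the whole sequence, which is how I read the paper's description of the mask predictor.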