Skip to content

Commit

Permalink
Update code
Browse files Browse the repository at this point in the history
  • Loading branch information
rd20karim committed Aug 21, 2024
1 parent d771bd9 commit 2d402d6
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 31 deletions.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
## Description

Official implementation of our paper for interpretable motion to text generation:
[BMVC 2024] Official implementation of our paper for interpretable motion to text generation:

<div align="center">

[<span style="font-size: 25px;">Guided Attention for Interpretable Motion Captioning</span>](https://hal.science/hal-04251363v1)
[<span style="font-size: 25px;">Guided Attention for Interpretable Motion Captioning (BMVC 2024)</span>](https://hal.science/hal-04251363v1)

[![arxiv](https://img.shields.io/badge/arXiv-Motion_to_Text-red?logo=arxiv)](https://arxiv.org/abs/2310.07324v1)
[![BMVC](https://img.shields.io/badge/BMVC2024-gold)](https://bmvc2024.org/programme/accepted_papers/)
[![License](https://img.shields.io/badge/License-MIT-green)]()

</div>
Expand All @@ -20,13 +21,11 @@ Official implementation of our paper for interpretable motion to text generation
If you find this code or paper useful in your work, please cite:

```
@article{radouane2023guided,
@INPROCEEDINGS{radouane2024guided,
title={Guided Attention for Interpretable Motion Captioning},
author={Karim Radouane and Andon Tchechmedjiev and Sylvie Ranwez and Julien Lagarde},
year={2023},
eprint={2310.07324},
archivePrefix={arXiv},
primaryClass={cs.CV}
booktitle = {Proceedings of the 35th British Machine Vision Conference},
year={2024}
}
```

Expand Down Expand Up @@ -130,6 +129,7 @@ python nlg_evaluation.py
Generate skeleton animation and attention maps (adaptive+spatio-temporal):

```
set PYTHONPATH=Project_Absolute_Path
python visualizations/visu_word_attentions.py --path PATH --n_map NUMBER_ATTENTION_MAP --n_gifs NUMBER_3D_ANIMATIONS --save_results DIRECTORY_SAVE_PLOTS
```

Expand Down
22 changes: 4 additions & 18 deletions architectures/decode_beam_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ def __init__(self, hiddenstate, previousNode, wordId, logProb, length,att_weight
self.att_position = att_position
def eval(self, alpha=1.0):
reward = 0
# Add here a function for shaping a reward

return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward


Expand All @@ -50,18 +48,16 @@ def beam_decode(self,target_tensor, decoder_hiddens, encoder_outputs=None):
decoded_batch = []
dec_pred_output = []


# decoding goes sentence by sentence
target_tensor = target_tensor.permute(1,0)
B = target_tensor.size(0) # batch_size
#TODO REMOVE DEBUG PARAM IN THE FINAL VERSION
debug_param = B

# ------------ Spatial-Temp pose feature : R # shape : (T,N,V*C), with V*C = 6*hidden_size
T, N, CV = self.x.size()
R = encoder_outputs.reshape(T, N, 6 * self.hidden_size).to(self.device)

# TODO CHANGE THIS CODE TO DECODE PER BATCH
for idx in range(debug_param):
for idx in range(B):
#x = xparts #shape[T,B,K,D] (k=6) (D=feature_dimension)
x = self.x[:,idx,:].unsqueeze(1)
enc_masks = torch.zeros((x.shape[0],1)) # (seq_len,batch_size)
Expand Down Expand Up @@ -100,7 +96,6 @@ def beam_decode(self,target_tensor, decoder_hiddens, encoder_outputs=None):
qsize = 1
# start beam search
while True:
#for idx in range(trg_len-1):
# give up when decoding takes too long
if qsize > 100: break
# fetch the best node
Expand Down Expand Up @@ -138,7 +133,7 @@ def beam_decode(self,target_tensor, decoder_hiddens, encoder_outputs=None):

# ------ TEMPORAL ATTENTION shape : (T,N,V*C)--------------------------------------------------------
ep_t = self.temporal_att(torch.tanh(
self.tempfeat_extract_hdec(bot_ht_1).expand(x.size()[:2] + (hidden_dim,)) + self.feat_extract_g(R))) # +self.speed_temp(velocity)))
self.tempfeat_extract_hdec(bot_ht_1).expand(x.size()[:2] + (hidden_dim,)) + self.feat_extract_g(R)))
ep_t = ep_t.masked_fill(enc_masks.unsqueeze(-1).to(self.device) == 0, float('-inf'))
b_t = torch.softmax(ep_t, dim=0)

Expand Down Expand Up @@ -172,9 +167,6 @@ def beam_decode(self,target_tensor, decoder_hiddens, encoder_outputs=None):
decoder_hidden = (bot_ht_1,bot_mt_1,top_ht_1,top_mt_1)

"""--------------------------------- END MODEL PREDICTION--------------------------------"""
#TODO ADAPT
# decoder_logits, decoder_hidden = self.dec(torch.tensor([[decoder_input]],device=self.device), decoder_hidden, encoder_output,enc_masks)

decoder_output = torch.log_softmax(decoder_logits,axis=-1)
# TODO: PUT THE REAL TOP-K BEAM SEARCH HERE
log_prob, indexes = torch.topk(decoder_output.squeeze(0), beam_width)
Expand Down Expand Up @@ -213,10 +205,4 @@ def beam_decode(self,target_tensor, decoder_hiddens, encoder_outputs=None):
utterances.append(utterance)
decoded_batch.append(utterances)

return decoded_batch

if __name__=="__main__":
hidden_size = 64
embedding_dim = 64
#decoder = seq2seq(642, hidden_size, embedding_dim, num_layers=1, device=device,
# bidirectional=False, attention="local", mask=True, beam_size=2).to(device)
return decoded_batch
12 changes: 6 additions & 6 deletions src/evaluate_m2L.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def load_model_config(args= None,device=None):
path_txt = project_path+"\datasets\sentences_corrections.csv"
path_motion = aug_path+"\kit_with_splits_2023.npz"

# ----------- H3D IMPORTS ---------------------
# ----------- [H3D IMPORTS] ---------------------
elif args.dataset_name=="h3D":
from architectures.LSTM_h3D import seq2seq
from datasets.h3d_m2t_dataset_ import dataset_class
Expand Down Expand Up @@ -55,12 +55,12 @@ def evaluate(loaded_model,data_loader,mode,multiple_references=False,name_file=N
loaded_model.eval()
epoch_loss = 0
output_per_batch,target_per_batch = [],[]
name_file = f"./Predictions/LSTM_{args.dataset_name}_{name_file}_{args.lambdas}"
BLEU_scores = []
with open(name_file + ".csv", mode="w") as _:pass
logging.info(f"Compute BLEU scores per batch and write predictions/refs to {name_file}")
loaded_model.eval()
if beam_size==1:
name_file = f"./Predictions/LSTM_{args.dataset_name}_{name_file}_{args.lambdas}"
BLEU_scores = []
with open(name_file + ".csv", mode="w") as _:pass
logging.info(f"Compute BLEU scores per batch and write predictions/refs to {name_file}")
for i, batch in enumerate(data_loader):
loss_b,bleu_score_4,pred,refs = run_batch(model=loaded_model,batch=batch,data_loader=data_loader,mode=mode,teacher_force_ratio=0,
device=device,multiple_references=multiple_references,attention_type=args.attention_type)
Expand Down Expand Up @@ -88,7 +88,7 @@ def evaluate(loaded_model,data_loader,mode,multiple_references=False,name_file=N

else:
logging.info("START BEAM SEARCHING")
file_save_beam = f"./Predictions/LSTM_{args.dataset_name}_preds_{args.lambdas}_beam_size_{beam_size}.csv"
file_save_beam = f"./Predictions/LSTM_{args.dataset_name}_{name_file}_{args.lambdas}_beam_size_{beam_size}.csv"
with open(file_save_beam,'w'): pass #create the file
beam_bleus = [[] for _ in range(beam_size)]
for i, batch in enumerate(data_loader):
Expand Down

0 comments on commit 2d402d6

Please sign in to comment.