Commit: Update files

rd20karim committed May 16, 2024
1 parent bf38026 commit d771bd9
Showing 7 changed files with 29 additions and 102 deletions.
36 changes: 10 additions & 26 deletions README.md
@@ -64,10 +64,10 @@ For both HumanML3D and KIT-MLD (augmented versions) you can follow the steps here

You can download the best models here: [models_weights](https://drive.google.com/drive/folders/1LiirfvZsU5FX1SgNQ1Y6xMeAt52lbVX6?usp=sharing)

| Dataset | Human-ML3D | Human-ML3D | Human-ML3D | KIT-ML | KIT-ML | KIT-ML |
|-----------------------|------------|------------|------------|-----------|------------|----------|
| Run ID | ge2gc507 | v6tv9rsx | hnjlc7r6 | u8qatt2y | ton5mfwh | lupw04om |
| Attention supervision | [0,0] | [2,3] | [3,3] | [0,0] | [2,3] | [1,3] |
| Dataset | Human-ML3D | Human-ML3D | Human-ML3D | Human-ML3D | KIT-ML | KIT-ML | KIT-ML | KIT-ML |
|-----------------------|------------|------------|------------|------------|-----------|----------|-----------|----------|
| Run ID | ge2gc507 | ba9hhkji | v6tv9rsx | hnjlc7r6 | u8qatt2y | yxitfbp7 | ton5mfwh | lupw04om |
| Attention supervision | [0,0] | [0,3] | [2,3] | [3,3] | [0,0] | [0,3] | [2,3] | [1,3] |

The attention supervision parameters refer, respectively, to the spatial and adaptive attention guidance weights.
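For intuition, each ``[l_spat, l_adapt]`` pair can be read as a pair of loss weights. Below is a minimal sketch of how such guidance terms might be combined with the captioning loss; the function and argument names are illustrative assumptions, not the repository's actual training code:

```python
def total_objective(ce_loss, spat_loss, adapt_loss, l_spat=2.0, l_adapt=3.0):
    """Hypothetical weighting: [l_spat, l_adapt] mirrors the 'Attention supervision'
    pairs in the table above (e.g. [2, 3]); [0, 0] disables both guidance terms."""
    return ce_loss + l_spat * spat_loss + l_adapt * adapt_loss

print(total_objective(1.20, 0.05, 0.08))  # 1.20 + 2*0.05 + 3*0.08 = 1.54
```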

@@ -103,13 +103,6 @@ python train_wandb.py --config {config_path} --dataset_name {dataset_name}
```
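
For example, a concrete call might look like the following (the config filename is only an assumed placeholder; point ``--config`` at whichever training config exists for the chosen model, and use the dataset names expected by the scripts, e.g. ``h3D`` or ``kit``):

```
python train_wandb.py --config ./configs/lstm_h3D.yaml --dataset_name h3D
```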




[//]: # ()
[//]: # (As described in the paper [Hal](https://hal.science/hal-04251363v1) according to ``(l_spat,l_adapt)`` values:)

[//]: # (* No attention supervision: ``(0,0)`` and Spatial or Adaptive:``(l_spat ,l_adapt)``)

* The config path specifies the model to train and the hyperparameters to experiment with; other values can be added by editing the config file of the chosen model.
* SEED is fixed to ensure the same model initialization across runs, for reproducibility.
* Replace the variables ``project_path`` and ``aug_path`` with your absolute data paths.
@@ -137,7 +130,7 @@ python nlg_evaluation.py
Generate skeleton animations and attention maps (adaptive + spatio-temporal):

```
python visualizations/poses2concepts.py --path PATH --attention_type ATTENTION_TYPE --n_map NUMBER_ATTENTION_MAP --n_gifs NUMBER_3D_ANIMATIONS --save_results DIRECTORY_SAVE_PLOTS
python visualizations/visu_word_attentions.py --path PATH --n_map NUMBER_ATTENTION_MAP --n_gifs NUMBER_3D_ANIMATIONS --save_results DIRECTORY_SAVE_PLOTS
```
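
For example (all values below are illustrative placeholders; adjust the weights path, counts, and output directory to your setup):

```
python visualizations/visu_word_attentions.py --path ./models/model_ton5mfwh --n_map 10 --n_gifs 20 --save_results ./visualizations/output
```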

The directory in which to save the visualizations can be set in the ```.yaml``` evaluation file or passed via the argument ```--save_results path```.
@@ -148,19 +141,6 @@ The directory to which save the visualization could be set in the ```.yaml``` file
* The disk radius of each keypoint indicates the magnitude of the corresponding spatial attention weight.
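
As an illustration of this encoding only (not the repository's plotting code), a minimal sketch that draws keypoints with disk sizes scaled by their spatial attention weights could look like:

```python
import numpy as np
import matplotlib.pyplot as plt

# Dummy 2D keypoints and spatial attention weights (illustrative data only)
keypoints = np.random.rand(21, 2)             # (n_joints, xy)
attention = np.random.dirichlet(np.ones(21))  # non-negative weights summing to 1

plt.scatter(keypoints[:, 0], keypoints[:, 1],
            s=2000 * attention,  # marker size scales with the attention weight
            alpha=0.6)
plt.title("Keypoint disk size scaled by spatial attention weight")
plt.show()
```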


[//]: # (<div align="center">)

[//]: # ()
[//]: # (<img src="./visualizations/readme/attention_sample_55.gif" alt="GIF 1" width="250" height="250">)

[//]: # (<img src="./visualizations/readme/attention_sample_69.gif" alt="GIF 2" width="250" height="250">)

[//]: # (<img src="./visualizations/readme/attention_sample_95.gif" alt="GIF 3" width="250" height="250">)

[//]: # (<img src="./visualizations/readme/attention_sample_2.gif" alt="GIF 3" width="250" height="250">)

[//]: # (</div>)


<div align="center">
<img src="./visualizations/readme/walk_bends_pickup_turns.gif" alt="GIF 1" width="250" height="250">
@@ -172,6 +152,10 @@ The directory to which save the visualization could be set in the ```.yaml``` file

## Interpretability analysis

<div align="center">
<img src="./visualizations/readme/interpretablity_applications.png" alt="GIF 1" width="540" height="249">
</div>

The following steps can be explored for interpretability analysis:

* __Adaptive gate density__
@@ -220,5 +204,5 @@ This script will print the BLEU-4 score for each beam and write beam predictions

## License

This code is distributed under MIT LICENSE.
This code is distributed under [MIT LICENSE](https://github.com/rd20karim/M2T-Interpretable?tab=MIT-1-ov-file).

2 changes: 1 addition & 1 deletion architectures/LSTM_h3D.py
@@ -156,7 +156,7 @@ def forward(self, x, y, teacher_force_ratio=0,src_lens=None):


if self.beam_size==1:
# TODO USE FULL TEACHER FORCING TO SPEED UP TRAINING AND GET OUT THIS LOOP !

for j in range(trg_len-1):
# ------- CLIP COIN
thr = random.random()
2 changes: 1 addition & 1 deletion architectures/LSTM_kit.py
@@ -154,7 +154,7 @@ def forward(self, x, y, teacher_force_ratio=0,src_lens=None):


if self.beam_size==1:
# TODO USE FULL TEACHER FORCING TO SPEED UP TRAINING AND GET OUT THIS LOOP !

for j in range(trg_len-1):
# ------- CLIP COIN
thr = random.random()
65 changes: 4 additions & 61 deletions bleu_from_csv.py
@@ -13,8 +13,6 @@
from nltk.corpus.reader.wordnet import WordNetCorpusReader
from bert_score import score

#from nlgeval import NLGEval,compute_metrics

def read_pred_refs(path,split=True,tokenize=True):
predictions = []
references = []
@@ -119,22 +117,6 @@ def compute_bert_score(predictions,references,device="cpu"):
P,R,F1 = score(predictions,references,lang='en',rescale_with_baseline=True,idf=True,device=device)
_bert_sc = F1.mean().item()
return _bert_sc
def bleu_rouge_cider_dict(predictions,references):
with open("hyp_temp.txt",'w') as f:
f.writelines(predictions)
for i, refs in enumerate(references):
with open(f"./temp/ref_{i}.txt","w") as f:
f.write('\n'.join(refs))
nlg_eval = NLGEval(
metrics_to_omit=['METEOR','EmbeddingAverageCosineSimilarity','SkipThoughtCS','VectorExtremaCosineSimilarity','GreedyMatchingScore'])
return compute_metrics("hyp_temp.txt",[f"./temp/ref_{i}.txt" for i in range(len(references))])
def bleu_rouge_cider_dict_2(predictions,references):
nlg_eval = NLGEval(
metrics_to_omit=['METEOR','EmbeddingAverageCosineSimilarity','SkipThoughtCS','VectorExtremaCosineSimilarity','GreedyMatchingScore'])
ref_list = [list(refs) for refs in zip(*references)]
cand_list = predictions
return nlg_eval.compute_metrics(ref_list, cand_list)


def write_first_beam_and_refs(name_file,path,path_refs):
_, refs = read_pred_refs(path_refs, tokenize=False)
@@ -150,12 +132,11 @@ def write_first_beam_and_refs(name_file,path,path_refs):
if __name__=="__main__":

name_file = 'LSTM_h3D_preds_[0, 3]_beam_size_2' #'LSTM_kit_preds_[2, 3]_beam_size_3'#
abs_path = r"C:\Users\karim\PycharmProjects\M2T-Interpretable"
path = abs_path + "/Predictions/"+name_file+'.csv'
path_refs = abs_path + "/Predictions/LSTM_h3D_preds_[0, 3].csv"

path = "/home/karim/PycharmProjects/m2LSpTemp/src/Predictions/"+name_file+'.csv'
path_refs = "/home/karim/PycharmProjects/m2LSpTemp/src/Predictions/LSTM_h3D_preds_[0, 3].csv"

write_first_beam_and_refs("./src/temp/"+name_file,path,path_refs)

write_first_beam_and_refs(abs_path+"/src/temp/"+name_file,path,path_refs)


# predictions,references = read_pred_refs(path_refs,tokenize=False)
@@ -171,41 +152,3 @@ def write_first_beam_and_refs(name_file,path,path_refs):
# bleu_score = calculate_bleu(pa,ra,num_grams=4)
# df_bleu_ = bleu_to_df(pa,ra,smooth_method=SmoothingFunction().method0)
#
#

nlg_eval = NLGEval(
metrics_to_omit=['METEOR',
'EmbeddingAverageCosineSimilarity' ,
'SkipThoughtCS',
'VectorExtremaCosineSimilarity',
'GreedyMatchingScore']
)

# from collections import OrderedDict
# def evaluate_bleu_rouge_cider(text_loaders, file):
# bleu_score_dict = OrderedDict({})
# rouge_score_dict = OrderedDict({})
# cider_score_dict = OrderedDict({})
# # print(text_loaders.keys())
# print('========== Evaluating NLG Score ==========')
# for text_loader_name, text_loader in text_loaders.items():
#
# ref_list = [list(refs) for refs in zip(*text_loader.dataset.all_caption_list)]
# cand_list = text_loader.dataset.generated_texts_list
# scores = nlg_eval.compute_metrics(ref_list, cand_list)
# bleu_score_dict[text_loader_name] = np.array(
# [scores['Bleu_1'], scores['Bleu_2'], scores['Bleu_3'], scores['Bleu_4']])
# rouge_score_dict[text_loader_name] = scores['ROUGE_L']
# cider_score_dict[text_loader_name] = scores['CIDEr']
#
# line = f'---> [{text_loader_name}] BLEU: '
# for i in range(4):
# line += '(%d): %.4f ' % (i + 1, scores['Bleu_%d' % (i + 1)])
# print(line)
# print(line, file=file, flush=True)
#
# print(f'---> [{text_loader_name}] ROUGE_L: {scores["ROUGE_L"]:.4f}')
# print(f'---> [{text_loader_name}] ROUGE_L: {scores["ROUGE_L"]:.4f}', file=file, flush=True)
# print(f'---> [{text_loader_name}] CIDER: {scores["CIDEr"]:.4f}')
# print(f'---> [{text_loader_name}] CIDER: {scores["CIDEr"]:.4f}', file=file, flush=True)
# return bleu_score_dict, rouge_score_dict, cider_score_dict
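
For reference, a minimal usage sketch of the two helpers this commit keeps in ``bleu_from_csv.py`` (the CSV path below is a placeholder):

```python
# Hypothetical usage of read_pred_refs / compute_bert_score defined above
preds, refs = read_pred_refs("Predictions/LSTM_h3D_preds.csv", tokenize=False)
bert_f1 = compute_bert_score(preds, refs, device="cpu")
print(f"BERTScore F1: {bert_f1:.4f}")
```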
9 changes: 4 additions & 5 deletions datasets/vocabulary.py
@@ -30,8 +30,6 @@ def __init__(self, sentences, context_correction=False, correct_tokens=False,ask
#contextualSpellCheck.add_to_pipe(self.spacy_eng)
self.context_sentence_correction()
elif self.correct_tokens:
# Correction don't take account the context but more fast
#self.token_correction(ask_user)
pass
else:
self.corrected_sentences = self.sentences[:] # independent copy
@@ -49,7 +47,6 @@ def context_sentence_correction(self):
else:
self.corrected_sentences.append(desc)


def build_vocabulary(self, min_freq=1):
self.token_to_idx = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
self.idx_to_token = {}
@@ -72,7 +69,8 @@ def build_vocabulary(self, min_freq=1):
sp_count = len(sp_token) # count special tokens
# sanity check
assert self.vocab_size ==len(self.token_freq)+sp_count
# with open("./kit_vocab.txt", mode='w') as fw:

# with open("./dataset_vocab.txt", mode='w',encoding='utf-8') as fw:
# for key, value in self.token_freq.items():
# fw.write("%s:%s\n" % (key, value))

@@ -95,7 +93,8 @@ def build_vocabulary(self, min_freq=1):
# Note that self.idx_to_token don't have the same length as self.token_to_idx when min_freq !=1

self.vocab_size_unk = len(self.idx_to_token)

logging.info(f"The vocab size is {self.vocab_size}"
f" with minimum frequency of {min_freq} it becomes --> {self.vocab_size_unk} tokens")
# Verify if we have successive int indexes of tokens
idxs = set(self.token_to_idx.values())
try:
(One changed file is binary or invalid and cannot be displayed.)
17 changes: 9 additions & 8 deletions visualizations/visu_word_attentions.py
@@ -146,6 +146,7 @@ def save_spatial_attention_figs(args,save_spTemp=True):
parser = argparse.ArgumentParser()
parser.add_argument("--path",type=str,help="Path of model weights not used if the config file is passed")
parser.add_argument("--dataset_name",type=str,default="h3D",choices=["h3D","kit"])
parser.add_argument("--run_id",type=str,default="h3D",choices=["h3D","kit"])
parser.add_argument("--kind",type=str,default="map",choices=["map","hist",'adapt','all'])
parser.add_argument("--config",type=str,default="../configs/lstm_eval_h3D.yaml")
parser.add_argument("--device",type=str,default="cpu")
@@ -163,19 +164,21 @@ def save_spatial_attention_figs(args,save_spTemp=True):
parser.add_argument("--batch_size", type=int, default=1024, help='Batch size should be >= to length of data to not have a variable BLEU score')
parser.add_argument("--n_map",type=int,default=10,required=False,help="Number of attention map to generate")
parser.add_argument("--n_gifs",type=int,default=100,help="Number of animation to generate")
parser.add_argument("--root_trajectory_sample",type=int,default=-1,help="Plot trajectory")
parser.add_argument("--save_results",type=str,help="Directory where to save generated plots")
# parser.add_argument("--start",type=int,default=0,help="Start sample index generation point")
# parser.add_argument('--indexs',type=int, nargs='+', help='word indexes')
# parser.add_argument('--spat',type=int, nargs='+', help='sample indexes')

args = parser.parse_args()

home_path = r"C:\Users\karim\PycharmProjects"
abs_path_project = home_path + "\m2LSpTemp"
abs_path_project = home_path + "\M2T-Interpretable"

# # Fix manually ---------------------------------------------------------------------------------
# run_id = 'ton5mfwh'
# args.dataset_name = 'kit'

run_id = args.run_id

run_id = 'ton5mfwh'
args.dataset_name = 'kit'
args.path = abs_path_project + f"\models\Interpretable_MC_{args.dataset_name}_f\model_{run_id}"

# From the loaded model -------------------------------------------------------------------
@@ -224,16 +227,14 @@ def save_spatial_attention_figs(args,save_spTemp=True):
for idp, part in enumerate(body_parts):
intensity[:,:,:,part] = spat_temp[:,:,:,idp:idp+1]

# Isolate root from Torso for HumanML3D
intensity[:, :, :, 0] = spat_temp[:, :, :, 0]

# --------------------------- Human Pose 3D Animation ---------------------------------------------------------

for ll, id_sample in enumerate((range(args.n_gifs))):#[0,5,13,25,42,43,53,54,59]
pred = preds[id_sample]
bleus = [bleu_score([pred], [[ref]]) for ref in trgs[id_sample]]
id_best_ref = np.argmax(bleus)
if np.max(bleus)>=.3:
if np.max(bleus)>=.3: # Select predictions above a threshold
trg = trgs[id_sample][id_best_ref]
trg = ' '.join(trg)
start_pad = lens[id_sample]
