diff --git a/extract_features.py b/extract_features.py index a2c2e47..47a2876 100755 --- a/extract_features.py +++ b/extract_features.py @@ -31,7 +31,7 @@ args = parser.parse_args() - params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048, + params_model = {'bsize': args.batch-size , 'word_emb_dim': 300, 'enc_lstm_dim': 2048, 'pool_type': 'max', 'dpout_model': 0.0, 'version': args.version} model = InferSent(params_model) model.load_state_dict(torch.load(args.model_path))