diff --git a/scripts/run_rewardbench.py b/scripts/run_rewardbench.py
index 4196c94..54c88e2 100644
--- a/scripts/run_rewardbench.py
+++ b/scripts/run_rewardbench.py
@@ -79,6 +79,7 @@ def main():
     parser.add_argument("--save_all", action="store_true", default=False, help="save all results (include scores per instance)")
     parser.add_argument("--force_truncation", action="store_true", default=False, help="force truncation (if model errors)")
     parser.add_argument("--torch_dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32", "float64"], help="set PyTorch dtype (default: float16)")
+    parser.add_argument("--attn_implementation", type=str, default=None, choices=["eager", "sdpa", "flash_attention_2"], help="Attention implementation to use (default: None)")
     # fmt: on
     args = parser.parse_args()
     args.torch_dtype = torch_dtype_mapping(args.torch_dtype)
@@ -283,6 +284,11 @@ def main():
         "torch_dtype": torch_dtype,
     }
 
+    # if attn_implementation is not specified, this falls back to Hugging Face's default
+    # strategy (which chooses between sdpa and eager depending on pytorch version)
+    if args.attn_implementation:
+        model_kwargs["attn_implementation"] = args.attn_implementation
+
     model = model_builder(args.model, **model_kwargs, trust_remote_code=args.trust_remote_code)
     reward_pipe = pipeline_builder(
         "text-classification",  # often not used
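
For reference, a minimal sketch of what the new flag amounts to downstream, assuming `model_builder` ultimately wraps a standard Hugging Face `from_pretrained` call; the model name and dtype below are placeholders, not values from this script:

```python
import torch
from transformers import AutoModelForSequenceClassification

# mirrors model_kwargs in run_rewardbench.py
model_kwargs = {"torch_dtype": torch.bfloat16}

# only pass attn_implementation when explicitly requested; otherwise
# transformers picks its own default (sdpa when available, else eager)
attn_implementation = "flash_attention_2"  # e.g. value of --attn_implementation
if attn_implementation:
    model_kwargs["attn_implementation"] = attn_implementation

model = AutoModelForSequenceClassification.from_pretrained(
    "my-org/my-reward-model",  # placeholder model name
    **model_kwargs,
)
```

Note that `flash_attention_2` additionally requires the `flash-attn` package to be installed and a supported GPU; `eager` and `sdpa` work with stock PyTorch.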