Skip to content

Commit

Permalink
Add attn_implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 committed Jan 28, 2025
1 parent 7561957 commit 236e99d
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions scripts/run_rewardbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def main():
parser.add_argument("--save_all", action="store_true", default=False, help="save all results (include scores per instance)")
parser.add_argument("--force_truncation", action="store_true", default=False, help="force truncation (if model errors)")
parser.add_argument("--torch_dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32", "float64"], help="set PyTorch dtype (default: float16)")
parser.add_argument("--attn_implementation", type=str, default=None, choices=["eager", "sdpa", "flash_attention_2"], help="Attention implementation to use (default: None)")
# fmt: on
args = parser.parse_args()
args.torch_dtype = torch_dtype_mapping(args.torch_dtype)
Expand Down Expand Up @@ -283,6 +284,11 @@ def main():
"torch_dtype": torch_dtype,
}

# if attn_implementation is not specified, this falls back to Hugging Face's default
# strategy (which chooses between sdpa and eager depending on pytorch version)
if args.attn_implementation:
model_kwargs["attn_implementation"] = args.attn_implementation

model = model_builder(args.model, **model_kwargs, trust_remote_code=args.trust_remote_code)
reward_pipe = pipeline_builder(
"text-classification", # often not used
Expand Down

0 comments on commit 236e99d

Please sign in to comment.