Skip to content

Commit

Permalink
[README]
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez committed Jan 29, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent bbb05f4 commit 1c808b9
Showing 3 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@ Convert any model into an r1-like reasoning hyper-intelligent agent. Leverages TR
- [Open R1 Blog](https://huggingface.co/blog/open-r1)
- [GRPO Documentation from trl](https://huggingface.co/docs/trl/main/en/grpo_trainer)
- [Hugging Face Transformers Docs](https://huggingface.co/docs/transformers/main/en/index)
- [GRPO Docs](https://huggingface.co/docs/trl/main/en/grpo_trainer)


## Installation
50 changes: 50 additions & 0 deletions agentgym/majority_voting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@

import torch
from litellm import encode


class BaseModel:
    """Callable wrapper that forwards calls to an underlying ``model``.

    The wrapped callable may be supplied with the ``model`` keyword at
    construction time, assigned to ``self.model`` later, or provided by a
    subclass (as a class attribute or by setting it before calling
    ``super().__init__``).
    """

    def __init__(self, *args, **kwargs):
        """Store the optional ``model`` keyword and continue the init chain.

        Args:
            *args: Forwarded unchanged to the next ``__init__`` in the MRO.
            **kwargs: Forwarded likewise, except ``model`` which is consumed
                here and stored on the instance.
        """
        model = kwargs.pop("model", None)
        # Do not clobber a model that a subclass already provided.
        if model is not None or not hasattr(self, "model"):
            self.model = model
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped model with the given arguments.

        Returns:
            Whatever the wrapped model returns.

        Raises:
            AttributeError: If no model has been configured on the instance.
        """
        model = getattr(self, "model", None)
        if model is None:
            # Same exception type as the previous bare attribute lookup, but
            # with a message that says what is actually missing.
            raise AttributeError(
                f"{type(self).__name__} has no 'model' configured; pass "
                "model=... to the constructor or set the 'model' attribute "
                "before calling the instance"
            )
        return model(*args, **kwargs)


class MajorityVoting(BaseModel):
    """Model wrapper that holds a list of models and scores answers.

    ``self.models`` starts empty; nothing in this class populates or
    consumes it yet.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base wrapper and an empty ``models`` list."""
        super().__init__(*args, **kwargs)
        self.models = []

    def __call__(self, *args, **kwargs):
        """Delegate the call to ``BaseModel.__call__``."""
        return super().__call__(*args, **kwargs)

    def compute_accuracy(self, answer: str, target: str) -> float:
        """Score ``answer`` against ``target`` by token-vector cosine similarity.

        Both strings are tokenised with ``encode(model="gpt-4o", ...)``, the
        results are converted to float tensors, the shorter one is
        right-padded with zeros so both have the same width, and the mean
        cosine similarity along the token axis is returned.

        NOTE(review): this compares raw token-ID values, so it is only a
        crude lexical proxy, not a semantic similarity measure.

        Args:
            answer: The produced text.
            target: The reference text.

        Returns:
            The mean cosine similarity as a Python float. ``1.0`` when both
            texts produce no tokens, ``0.0`` when exactly one of them does.
        """
        answer_tensor = torch.tensor(
            encode(model="gpt-4o", text=answer), dtype=torch.float32
        )
        target_tensor = torch.tensor(
            encode(model="gpt-4o", text=target), dtype=torch.float32
        )

        # Work with a leading batch axis so that dim=1 is the token axis.
        if answer_tensor.dim() == 1:
            answer_tensor = answer_tensor.unsqueeze(0)
        if target_tensor.dim() == 1:
            target_tensor = target_tensor.unsqueeze(0)

        # Empty token sequences have no direction to compare; decide the
        # outcome explicitly instead of relying on epsilon handling.
        answer_empty = answer_tensor.numel() == 0
        target_empty = target_tensor.numel() == 0
        if answer_empty or target_empty:
            return 1.0 if (answer_empty and target_empty) else 0.0

        # Texts rarely tokenise to the same length. Without padding,
        # cosine_similarity raises on mismatched widths (or silently
        # broadcasts a single-token side), so pad the shorter side.
        width = max(answer_tensor.shape[-1], target_tensor.shape[-1])
        answer_tensor = torch.nn.functional.pad(
            answer_tensor, (0, width - answer_tensor.shape[-1])
        )
        target_tensor = torch.nn.functional.pad(
            target_tensor, (0, width - target_tensor.shape[-1])
        )

        similarity = torch.cosine_similarity(
            answer_tensor, target_tensor, dim=1
        )
        return similarity.mean().item()


# Module-level instance, kept so existing ``from ... import vote`` users
# continue to work.
vote = MajorityVoting()

if __name__ == "__main__":
    # Demo only: run the tokenizer-backed comparison when executed as a
    # script, not as a side effect of importing this module.
    print(vote.compute_accuracy("hello", "chicken"))
1 change: 1 addition & 0 deletions example.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
sft_model="gpt2",
sft_dataset="stanfordnlp/imdb",
sft_args=SFTConfig(output_dir="/tmp"),
only_grpo=True
)

r1_pipeline.run()

0 comments on commit 1c808b9

Please sign in to comment.