Skip to content

Commit

Permalink
feat(ml): pass@ metric
Browse files Browse the repository at this point in the history
  • Loading branch information
lvzii committed Dec 11, 2024
1 parent 46d4d80 commit ff7d269
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/nlpertools/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,28 @@
from .utils.package import *


def estimate_pass_at_k(num_samples:list, num_correct:list, k):
"""
copy from https://huggingface.co/spaces/evaluate-metric/code_eval/blob/main/code_eval.py
num_samples: list
"""
"""Estimates pass@k of each problem and returns them in an array."""

def estimator(n: int, c: int, k: int) -> float:
"""Calculates 1 - comb(n - c, k) / comb(n, k)."""
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

if isinstance(num_samples, int):
num_samples_it = itertools.repeat(num_samples, len(num_correct))
else:
assert len(num_samples) == len(num_correct)
num_samples_it = iter(num_samples)

return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])


def calc_llm_train_activation_memory(
model_name, sequence_length, batch_size, hidden_dim, lay_number, attention_heads_num, gpu_num=1
):
Expand Down

0 comments on commit ff7d269

Please sign in to comment.