From ff7d26962a2298a7c11ecc865b186b2d5b0c159c Mon Sep 17 00:00:00 2001 From: lvzi Date: Wed, 11 Dec 2024 17:35:39 +0800 Subject: [PATCH] feat(ml): pass@ metric --- src/nlpertools/ml.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/nlpertools/ml.py b/src/nlpertools/ml.py index 5ed0f93..a9c6add 100644 --- a/src/nlpertools/ml.py +++ b/src/nlpertools/ml.py @@ -17,6 +17,28 @@ from .utils.package import * +def estimate_pass_at_k(num_samples:list, num_correct:list, k): + """ + copy from https://huggingface.co/spaces/evaluate-metric/code_eval/blob/main/code_eval.py + num_samples: list + """ + """Estimates pass@k of each problem and returns them in an array.""" + + def estimator(n: int, c: int, k: int) -> float: + """Calculates 1 - comb(n - c, k) / comb(n, k).""" + if n - c < k: + return 1.0 + return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) + + if isinstance(num_samples, int): + num_samples_it = itertools.repeat(num_samples, len(num_correct)) + else: + assert len(num_samples) == len(num_correct) + num_samples_it = iter(num_samples) + + return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]) + + def calc_llm_train_activation_memory( model_name, sequence_length, batch_size, hidden_dim, lay_number, attention_heads_num, gpu_num=1 ):