From e525318d6bb8ec221053ff1aeaed550badbb428b Mon Sep 17 00:00:00 2001
From: Mufeez Amjad
Date: Sun, 8 Sep 2024 22:50:16 -0400
Subject: [PATCH] cleanup some more

---
 server/src/rpc/eval.rs | 200 +++++++++++++++++++++++++++--------------
 1 file changed, 131 insertions(+), 69 deletions(-)

diff --git a/server/src/rpc/eval.rs b/server/src/rpc/eval.rs
index 19be9a4..8f3df6b 100644
--- a/server/src/rpc/eval.rs
+++ b/server/src/rpc/eval.rs
@@ -15,6 +15,7 @@ use polay_db::{
 };
 use polay_proto::polay::{
     EvalOutcome, EvalScore, MeaningfulEvalScore, RecordEvalRequest, RecordEvalResponse,
+    VersionedPrompt,
 };
 
 /// Record an eval run and compare it to a previous run
@@ -22,128 +23,189 @@ pub async fn record_eval(
     request: Request<RecordEvalRequest>,
 ) -> Result<Response<RecordEvalResponse>, Status> {
     let message = request.into_inner();
+
     let eval = message
         .eval
         .ok_or_else(|| Status::invalid_argument("Missing eval"))?;
     let prompt = message
         .prompt
         .ok_or_else(|| Status::invalid_argument("Missing prompt"))?;
+
     let eval_scores = message.eval_scores;
     let base_version = message.base_version;
 
     let mut conn = establish_connection();
 
-    // Create a new prompt version if it doesn't exist
-    let mut repo = DieselRepository::new(&mut conn, prompt_version::table);
-    let prompt_version = prompt_version::table
+    // Get or create the prompt version being evaluated
+    let prompt_version = get_or_create_prompt_version(&mut conn, &prompt)?;
+    let existing_eval_version = get_or_create_eval_version(&mut conn, &eval, &prompt_version)?;
+
+    // Get the previous eval result for the base version
+    let previous_eval_result = get_previous_eval_result(&mut conn, &prompt, base_version)?;
+
+    let scores: EvalRunScores = convert_eval_scores(eval_scores);
+    create_new_eval_result(&mut conn, &existing_eval_version, &scores)?;
+
+    if let Some(previous_result) = previous_eval_result {
+        let previous_results: EvalRunScores = serde_json::from_value(previous_result.scores)
+            .map_err(|_| Status::internal("Failed to deserialize previous scores"))?;
+
+        let (result, meaningful_scores) = compare_results(&previous_results, scores);
+
+        Ok(Response::new(RecordEvalResponse {
+            outcome: result.into(),
+            previous_eval_scores: previous_results
+                .into_iter()
+                .map(|res| EvalScore {
+                    eval_hash: res.eval_hash,
+                    score: res.score,
+                })
+                .collect(),
+            meaningful_eval_scores: meaningful_scores,
+            message: "Success".to_string(),
+        }))
+    } else {
+        Ok(Response::new(RecordEvalResponse {
+            outcome: EvalOutcome::NoChange.into(),
+            previous_eval_scores: Vec::new(),
+            meaningful_eval_scores: Vec::new(),
+            message: "Success".to_string(),
+        }))
+    }
+}
+
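+/// Look up the prompt version matching `prompt.version`, creating it if it
+/// does not exist yet.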
+fn get_or_create_prompt_version(
+    conn: &mut PgConnection,
+    prompt: &VersionedPrompt,
+) -> Result<PromptVersion, Status> {
+    let mut repo = DieselRepository::new(conn, prompt_version::table);
+
+    let existing = repo
+        .table
         .filter(prompt_version::version.eq(&prompt.version))
         .first::<PromptVersion>(repo.connection)
         .optional()
-        .map_err(|_| Status::internal("Failed to fetch prompt version"))?
-        .unwrap_or_else(|| {
+        .map_err(|_| Status::internal("Failed to fetch prompt version"))?;
+
+    match existing {
+        Some(version) => Ok(version),
+        None => {
             let new_prompt_version = InsertablePromptVersion {
                 name: prompt.name.clone(),
                 version: prompt.version.clone(),
                 created_at: Utc::now(),
             };
+
             repo.create(&new_prompt_version)
-                .expect("Failed to create new prompt version")
-        });
+                .map_err(|_| Status::internal("Failed to create prompt version"))
+        }
+    }
+}
 
-    // Get the prompt version to compare against
-    let base_prompt_version = if let Some(base_version) = base_version {
-        prompt_version::table
-            .filter(prompt_version::version.eq(&base_version))
-            .first::<PromptVersion>(&mut conn)
-            .optional()
-            .map_err(|_| Status::internal("Failed to fetch base prompt version"))?
-            .ok_or_else(|| Status::invalid_argument("Base version not found"))?
-    } else {
-        prompt_version::table
-            .filter(prompt_version::version.lt(&prompt.version))
-            .order(prompt_version::version.desc())
-            .first::<PromptVersion>(&mut conn)
-            .optional()
-            .map_err(|_| Status::internal("Failed to fetch base prompt version"))?
-            .ok_or_else(|| Status::not_found("No previous version found"))?
-    };
 
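+/// Look up the eval named in the request for this prompt version, creating it
+/// if it does not exist yet.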
+fn get_or_create_eval_version(
+    conn: &mut PgConnection,
+    eval: &polay_proto::polay::Eval,
+    prompt_version: &PromptVersion,
+) -> Result<Eval, Status> {
+    let mut repo = DieselRepository::new(conn, eval::table);
 
-    // Create a new eval version if it doesn't exist
-    let mut repo = DieselRepository::new(&mut conn, eval::table);
-    let existing_eval_version = eval::table
+    let existing = repo
+        .table
         .inner_join(prompt_version::table)
         .filter(
             eval::name
                 .eq(&eval.name)
-                .and(prompt_version::version.eq(&prompt.version)),
+                .and(prompt_version::version.eq(&prompt_version.version)),
         )
         .select(eval::all_columns)
         .first::<Eval>(repo.connection)
         .optional()
-        .map_err(|_| Status::internal("Failed to fetch eval version"))?
-        .unwrap_or_else(|| {
+        .map_err(|_| Status::internal("Failed to fetch eval version"))?;
+
+    match existing {
+        Some(version) => Ok(version),
+        None => {
             let new_eval_version = InsertableEval {
-                name: eval.name,
+                name: eval.name.clone(),
                 prompt_version_id: prompt_version.id,
                 created_at: Utc::now(),
             };
+
             repo.create(&new_eval_version)
-                .expect("Failed to create new eval version")
-        });
+                .map_err(|_| Status::internal("Failed to create eval version"))
+        }
+    }
+}
+
+/// Fetch the most recent eval result recorded against the base prompt version.
+fn get_previous_eval_result(
+    conn: &mut PgConnection,
+    prompt: &VersionedPrompt,
+    base_version: Option<String>,
+) -> Result<Option<EvalResult>, Status> {
+    let base_prompt_version = get_base_prompt_version(conn, prompt, base_version)?;
+
+    let repo = DieselRepository::new(conn, eval_result::table);
 
-    // Get the last eval result for the base prompt version
-    let mut repo = DieselRepository::new(&mut conn, eval_result::table);
-    let previous_eval_result = eval_result::table
+    eval_result::table
         .inner_join(eval::table)
         .filter(eval::prompt_version_id.eq(base_prompt_version.id))
         .order(eval_result::created_at.desc())
         .select(eval_result::all_columns)
         .first::<EvalResult>(repo.connection)
         .optional()
-        .map_err(|_| Status::internal("Failed to fetch previous eval result"))?;
+        .map_err(|_| Status::internal("Failed to fetch previous eval result"))
+}
+
+/// Resolve the prompt version to compare against: the explicit base version if
+/// one was supplied, otherwise the latest version older than the current one.
+fn get_base_prompt_version(
+    conn: &mut PgConnection,
+    prompt: &VersionedPrompt,
+    base_version: Option<String>,
+) -> Result<PromptVersion, Status> {
+    let repo = DieselRepository::new(conn, prompt_version::table);
+
+    if let Some(base_version) = base_version {
+        repo.table
+            .filter(prompt_version::version.eq(&base_version))
+            .first::<PromptVersion>(conn)
+            .optional()
+            .map_err(|_| Status::internal("Failed to fetch base prompt version"))?
+            .ok_or_else(|| Status::invalid_argument("Base version not found"))
+    } else {
+        repo.table
+            .filter(prompt_version::version.lt(&prompt.version))
+            .order(prompt_version::version.desc())
+            .first::<PromptVersion>(conn)
+            .optional()
+            .map_err(|_| Status::internal("Failed to fetch base prompt version"))?
+            .ok_or_else(|| Status::not_found("No previous version found"))
+    }
+}
 
-    let scores: EvalRunScores = eval_scores
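+/// Convert the proto eval scores into the internal EvalRunScores
+/// representation.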
+fn convert_eval_scores(eval_scores: Vec<EvalScore>) -> EvalRunScores {
+    eval_scores
         .into_iter()
         .map(|score| SingleEvalScore {
             eval_hash: score.eval_hash,
             score: score.score,
         })
-        .collect();
+        .collect()
+}
+
+/// Persist a new eval result row for this eval version.
+fn create_new_eval_result(
+    conn: &mut PgConnection,
+    existing_eval_version: &Eval,
+    scores: &EvalRunScores,
+) -> Result<EvalResult, Status> {
+    let mut repo = DieselRepository::new(conn, eval_result::table);
 
-    repo.create(&InsertableEvalResult {
+    let new_eval_result = repo.create(&InsertableEvalResult {
         eval_id: existing_eval_version.id,
-        scores: serde_json::to_value(&scores)
+        scores: serde_json::to_value(scores)
             .map_err(|_| Status::internal("Failed to serialize scores"))?,
         created_at: Utc::now(),
-    })
-    .map_err(|_| Status::internal("Failed to create new eval result"))?;
-
-    if let Some(previous_result) = previous_eval_result {
-        let previous_results: EvalRunScores = serde_json::from_value(previous_result.scores)
-            .map_err(|_| Status::internal("Failed to deserialize previous scores"))?;
-
-        let (result, meaningful_scores) = compare_results(&previous_results, scores);
+    });
 
-        Ok(Response::new(RecordEvalResponse {
-            outcome: result.into(),
-            previous_eval_scores: previous_results
-                .into_iter()
-                .map(|res| EvalScore {
-                    eval_hash: res.eval_hash,
-                    score: res.score,
-                })
-                .collect(),
-            meaningful_eval_scores: meaningful_scores,
-            message: "Success".to_string(),
-        }))
-    } else {
-        Ok(Response::new(RecordEvalResponse {
-            outcome: EvalOutcome::NoChange.into(),
-            previous_eval_scores: Vec::new(),
-            meaningful_eval_scores: Vec::new(),
-            message: "Success".to_string(),
-        }))
-    }
+    new_eval_result.map_err(|_| Status::internal("Failed to create eval result"))
 }
 
 fn compare_results(