Skip to content

Commit

Permalink
cleanup some more
Browse files Browse the repository at this point in the history
  • Loading branch information
mufeez-amjad committed Sep 9, 2024
1 parent 07076db commit cfd1c5c
Showing 1 changed file with 82 additions and 46 deletions.
128 changes: 82 additions & 46 deletions server/src/rpc/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,114 +15,150 @@ use polay_db::{
};
use polay_proto::polay::{
EvalOutcome, EvalScore, MeaningfulEvalScore, RecordEvalRequest, RecordEvalResponse,
VersionedPrompt,
};

/// Record an eval run and compare it to a previous run
pub async fn record_eval(
request: Request<RecordEvalRequest>,
) -> Result<Response<RecordEvalResponse>, Status> {
let message = request.into_inner();
let eval = message
.eval
.ok_or_else(|| Status::invalid_argument("Missing eval"))?;
let prompt = message
.prompt
.ok_or_else(|| Status::invalid_argument("Missing prompt"))?;
let eval_scores = message.eval_scores;
let base_version = message.base_version;
let (eval, prompt, eval_scores, base_version) = extract_message_data(message)?;

let mut conn = establish_connection();

// Create a new prompt version if it doesn't exist
let mut repo = DieselRepository::new(&mut conn, prompt_version::table);
let prompt_version = prompt_version::table
let prompt_version = get_or_create_prompt_version(&mut conn, &prompt)?;
let base_prompt_version = get_base_prompt_version(&mut conn, &prompt, base_version)?;
let existing_eval_version = get_or_create_eval_version(&mut conn, &eval, &prompt_version)?;
let previous_eval_result = get_previous_eval_result(&mut conn, &base_prompt_version)?;

let scores: EvalRunScores = convert_eval_scores(eval_scores);
create_new_eval_result(&mut conn, &existing_eval_version, &scores)?;

create_response(previous_eval_result, scores)
}

fn extract_message_data(message: RecordEvalRequest) -> Result<(polay_proto::polay::Eval, VersionedPrompt, Vec<EvalScore>, Option<String>), Status> {
let eval = message.eval.ok_or_else(|| Status::invalid_argument("Missing eval"))?;
let prompt = message.prompt.ok_or_else(|| Status::invalid_argument("Missing prompt"))?;
Ok((eval, prompt, message.eval_scores, message.base_version))
}

fn get_or_create_prompt_version(conn: &mut PgConnection, prompt: &VersionedPrompt) -> Result<PromptVersion, Status> {
let mut repo = DieselRepository::new(conn, prompt_version::table);

let existing = repo.table
.filter(prompt_version::version.eq(&prompt.version))
.first::<PromptVersion>(repo.connection)
.optional()
.map_err(|_| Status::internal("Failed to fetch prompt version"))?
.unwrap_or_else(|| {
.map_err(|_| Status::internal("Failed to fetch prompt version"))?;

match existing {
Some(version) => Ok(version),
None => {
let new_prompt_version = InsertablePromptVersion {
name: prompt.name.clone(),
version: prompt.version.clone(),
created_at: Utc::now(),
};
repo.create(&new_prompt_version)
.expect("Failed to create new prompt version")
});

// Get the prompt version to compare against
let base_prompt_version = if let Some(base_version) = base_version {
prompt_version::table
repo.create(&new_prompt_version).map_err(|_| Status::internal("Failed to create prompt version"))
}
}
}

fn get_base_prompt_version(conn: &mut PgConnection, prompt: &VersionedPrompt, base_version: Option<String>) -> Result<PromptVersion, Status> {
let repo = DieselRepository::new(conn, prompt_version::table);

if let Some(base_version) = base_version {
repo.table
.filter(prompt_version::version.eq(&base_version))
.first::<PromptVersion>(&mut conn)
.first::<PromptVersion>(conn)
.optional()
.map_err(|_| Status::internal("Failed to fetch base prompt version"))?
.ok_or_else(|| Status::invalid_argument("Base version not found"))?
.ok_or_else(|| Status::invalid_argument("Base version not found"))
} else {
prompt_version::table
repo.table
.filter(prompt_version::version.lt(&prompt.version))
.order(prompt_version::version.desc())
.first::<PromptVersion>(&mut conn)
.first::<PromptVersion>(conn)
.optional()
.map_err(|_| Status::internal("Failed to fetch base prompt version"))?
.ok_or_else(|| Status::not_found("No previous version found"))?
};
.ok_or_else(|| Status::not_found("No previous version found"))
}
}

fn get_or_create_eval_version(conn: &mut PgConnection, eval: &polay_proto::polay::Eval, prompt_version: &PromptVersion) -> Result<Eval, Status> {
let mut repo = DieselRepository::new(conn, eval::table);

// Create a new eval version if it doesn't exist
let mut repo = DieselRepository::new(&mut conn, eval::table);
let existing_eval_version = eval::table
let existing = repo.table
.inner_join(prompt_version::table)
.filter(
eval::name
.eq(&eval.name)
.and(prompt_version::version.eq(&prompt.version)),
.and(prompt_version::version.eq(&prompt_version.version)),
)
.select(eval::all_columns)
.first::<Eval>(repo.connection)
.optional()
.map_err(|_| Status::internal("Failed to fetch eval version"))?
.unwrap_or_else(|| {
.map_err(|_| Status::internal("Failed to fetch eval version"))?;

match existing {
Some(version) => Ok(version),
None => {
let new_eval_version = InsertableEval {
name: eval.name,
name: eval.name.clone(),
prompt_version_id: prompt_version.id,
created_at: Utc::now(),
};
repo.create(&new_eval_version)
.expect("Failed to create new eval version")
});

// Get the last eval result for the base prompt version
let mut repo = DieselRepository::new(&mut conn, eval_result::table);
let previous_eval_result = eval_result::table
repo.create(&new_eval_version).map_err(|_| Status::internal("Failed to create eval version"))
}
}
}

fn get_previous_eval_result(conn: &mut PgConnection, base_prompt_version: &PromptVersion) -> Result<Option<EvalResult>, Status> {
let repo = DieselRepository::new(conn, eval_result::table);

eval_result::table
.inner_join(eval::table)
.filter(eval::prompt_version_id.eq(base_prompt_version.id))
.order(eval_result::created_at.desc())
.select(eval_result::all_columns)
.first::<EvalResult>(repo.connection)
.optional()
.map_err(|_| Status::internal("Failed to fetch previous eval result"))?;
.map_err(|_| Status::internal("Failed to fetch previous eval result"))
}

let scores: EvalRunScores = eval_scores
fn convert_eval_scores(eval_scores: Vec<EvalScore>) -> EvalRunScores {
eval_scores
.into_iter()
.map(|score| SingleEvalScore {
eval_hash: score.eval_hash,
score: score.score,
})
.collect();
.collect()
}

fn create_new_eval_result(conn: &mut PgConnection, existing_eval_version: &Eval, scores: &EvalRunScores) -> Result<EvalResult, Status> {
let mut repo = DieselRepository::new(conn, eval_result::table);

repo.create(&InsertableEvalResult {
let new_eval_result = repo.create(&InsertableEvalResult {
eval_id: existing_eval_version.id,
scores: serde_json::to_value(&scores)
scores: serde_json::to_value(scores)
.map_err(|_| Status::internal("Failed to serialize scores"))?,
created_at: Utc::now(),
})
.map_err(|_| Status::internal("Failed to create new eval result"))?;
});

new_eval_result.map_err(|_| Status::internal("Failed to create eval result"))
}

fn create_response(previous_eval_result: Option<EvalResult>, current_scores: EvalRunScores) -> Result<Response<RecordEvalResponse>, Status> {
if let Some(previous_result) = previous_eval_result {
let previous_results: EvalRunScores = serde_json::from_value(previous_result.scores)
.map_err(|_| Status::internal("Failed to deserialize previous scores"))?;

let (result, meaningful_scores) = compare_results(&previous_results, scores);
let (result, meaningful_scores) = compare_results(&previous_results, current_scores);

Ok(Response::new(RecordEvalResponse {
outcome: result.into(),
Expand Down

0 comments on commit cfd1c5c

Please sign in to comment.