Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Working re-implementation #6

Merged
merged 3 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,074 changes: 21 additions & 1,053 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ indicatif = { version = "0.17.7", features = ["rayon"] }
itertools = "0.12.0"
ndarray = "0.15.6"
num_cpus = "1.16.0"
polars = { version = "0.35.4", features = ["lazy", "ndarray", "performant", "sign"] }
rayon = "1.8.1"
rusqlite = "0.30.0"
serde = "1.0.193"
serde = {version = "1.0.193", features = ["derive"]}
serde_rusqlite = "0.34.0"
tempfile = "3.9.0"

Expand Down
114 changes: 114 additions & 0 deletions src/align.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use std::collections::HashMap;

use crate::io::gwas::{GwasLine, GwasSumstats};
use crate::io::tagging::{TagInfo, TagLine};

/// Reports whether `gwas` and `taginfo` carry identical predictor lists.
///
/// Returns `true` only when `gwas.predictors` and `taginfo.predictors`
/// compare equal with `==`; no alignment is performed here.
pub fn check_aligned(gwas: &GwasSumstats, taginfo: &TagInfo) -> bool {
    let (gwas_predictors, tag_predictors) = (&gwas.predictors, &taginfo.predictors);
    gwas_predictors == tag_predictors
}

/// Restricts a GWAS result set and a tagging file to the predictors they share.
///
/// The shared predictors are taken from `taginfo.predictors`, keeping only the
/// names that also appear in `gwas.results`, so the output order follows
/// `taginfo.predictors`. The matching `GwasLine`s and `TagLine`s are cloned
/// into a new `GwasSumstats` / `TagInfo` pair; `annotation_names` and `ssums`
/// are copied over unchanged.
///
/// If a predictor name occurs more than once in `gwas.results` or
/// `taginfo.taglines`, the last occurrence is the one that is selected.
///
/// # Panics
///
/// Panics if a name in `taginfo.predictors` that is present in `gwas.results`
/// has no corresponding entry in `taginfo.taglines`.
pub fn align_gwas_taginfo(gwas: &GwasSumstats, taginfo: &TagInfo) -> (GwasSumstats, TagInfo) {
    // Key the lookup tables on borrowed names: the inputs outlive this call,
    // so there is no need to clone every predictor string.
    let gwas_predictor_to_idx: HashMap<&str, usize> = gwas
        .results
        .iter()
        .enumerate()
        .map(|(i, line)| (line.predictor.as_str(), i))
        .collect();

    let tag_predictor_to_idx: HashMap<&str, usize> = taginfo
        .taglines
        .iter()
        .enumerate()
        .map(|(i, line)| (line.predictor.as_str(), i))
        .collect();

    // Shared predictors, in the order given by the tagging file.
    let common_predictors: Vec<&str> = taginfo
        .predictors
        .iter()
        .map(String::as_str)
        .filter(|name| gwas_predictor_to_idx.contains_key(*name))
        .collect();

    let gwas_results: Vec<GwasLine> = common_predictors
        .iter()
        .map(|&name| gwas.results[gwas_predictor_to_idx[name]].clone())
        .collect();

    let taginfo_taglines: Vec<TagLine> = common_predictors
        .iter()
        .map(|&name| taginfo.taglines[tag_predictor_to_idx[name]].clone())
        .collect();

    (
        GwasSumstats::new(gwas.phenotype.clone(), gwas_results),
        TagInfo::new(
            taginfo_taglines,
            taginfo.annotation_names.clone(),
            taginfo.ssums.clone(),
        ),
    )
}

/// Restricts two GWAS result sets and a tagging file to their shared predictors.
///
/// The shared predictors are taken from `left.predictors`, keeping only the
/// names that also appear in both `right.results` and `taginfo.taglines`, so
/// the output order follows `left.predictors`. The matching lines are cloned
/// into two new `GwasSumstats` values (each keeping its own phenotype) and a
/// new `TagInfo`; `annotation_names` and `ssums` are copied over unchanged.
///
/// If a predictor name occurs more than once in any of the three inputs'
/// line lists, the last occurrence is the one that is selected.
///
/// # Panics
///
/// Panics if a name in `left.predictors` that is present in both `right` and
/// `taginfo` has no corresponding entry in `left.results`.
pub fn align_gwas_gwas(
    left: &GwasSumstats,
    right: &GwasSumstats,
    taginfo: &TagInfo,
) -> (GwasSumstats, GwasSumstats, TagInfo) {
    // Key the lookup tables on borrowed names: the inputs outlive this call,
    // so there is no need to clone every predictor string.
    let left_predictor_to_idx: HashMap<&str, usize> = left
        .results
        .iter()
        .enumerate()
        .map(|(i, line)| (line.predictor.as_str(), i))
        .collect();

    let right_predictor_to_idx: HashMap<&str, usize> = right
        .results
        .iter()
        .enumerate()
        .map(|(i, line)| (line.predictor.as_str(), i))
        .collect();

    let tag_predictor_to_idx: HashMap<&str, usize> = taginfo
        .taglines
        .iter()
        .enumerate()
        .map(|(i, line)| (line.predictor.as_str(), i))
        .collect();

    // Shared predictors, in the order given by the left-hand GWAS.
    let common_predictors: Vec<&str> = left
        .predictors
        .iter()
        .map(String::as_str)
        .filter(|name| {
            right_predictor_to_idx.contains_key(*name) && tag_predictor_to_idx.contains_key(*name)
        })
        .collect();

    let left_results: Vec<GwasLine> = common_predictors
        .iter()
        .map(|&name| left.results[left_predictor_to_idx[name]].clone())
        .collect();

    let right_results: Vec<GwasLine> = common_predictors
        .iter()
        .map(|&name| right.results[right_predictor_to_idx[name]].clone())
        .collect();

    let taginfo_taglines: Vec<TagLine> = common_predictors
        .iter()
        .map(|&name| taginfo.taglines[tag_predictor_to_idx[name]].clone())
        .collect();

    (
        GwasSumstats::new(left.phenotype.clone(), left_results),
        GwasSumstats::new(right.phenotype.clone(), right_results),
        TagInfo::new(
            taginfo_taglines,
            taginfo.annotation_names.clone(),
            taginfo.ssums.clone(),
        ),
    )
}
7 changes: 4 additions & 3 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use serde_rusqlite::to_params_named;
use std::collections::HashSet;
use std::path::Path;

use crate::hsq::{HsqResult, RgResult};
use crate::hsq::HsqResult;
use crate::rg::RgResult;

pub struct DbConnection {
pub conn: Connection,
Expand Down Expand Up @@ -91,7 +92,7 @@ impl DbConnection {

fn insert_hsq_data(tx: &Transaction, rows: &[HsqResult]) -> Result<()> {
let mut stmt = tx.prepare_cached(
"INSERT INTO h2 (phenotype, component, estimate, std_error) VALUES
"INSERT INTO h2 (phenotype, component, estimate, std_error) VALUES
(:phenotype, :component, :estimate, :se)",
)?;

Expand All @@ -105,7 +106,7 @@ fn insert_hsq_data(tx: &Transaction, rows: &[HsqResult]) -> Result<()> {

fn insert_rg_data(tx: &Transaction, rows: &[RgResult]) -> Result<()> {
let mut stmt = tx.prepare_cached(
"INSERT INTO rg (phenotype1, phenotype2, component, estimate, std_error) VALUES
"INSERT INTO rg (phenotype1, phenotype2, component, estimate, std_error) VALUES
(:phenotype1, :phenotype2, :component, :estimate, :se)",
)?;

Expand Down
94 changes: 89 additions & 5 deletions src/ffi.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::io::gwas::AlignedGwasSumstats;
use crate::io::gwas::GwasSumstats;

use anyhow::{anyhow, Result};

Expand Down Expand Up @@ -165,8 +165,8 @@ pub fn solve_sums_wrapper(

pub fn solve_cors_wrapper(
tagging: &[f64],
gwas_sumstats_1: &AlignedGwasSumstats,
gwas_sumstats_2: &AlignedGwasSumstats,
gwas_sumstats_1: &GwasSumstats,
gwas_sumstats_2: &GwasSumstats,
category_vals: &[Vec<f64>],
category_contribs: &[Vec<f64>],
options: Option<SolveCorsOptions>,
Expand All @@ -192,10 +192,10 @@ pub fn solve_cors_wrapper(
tagging.as_ptr(),
svars.as_ptr(),
ssums.as_ptr(),
gwas_sumstats_1.sample_sizes.as_ptr(),
gwas_sumstats_1.sample_size.as_ptr(),
gwas_sumstats_1.chisq.as_ptr(),
gwas_sumstats_1.rhos.as_ptr(),
gwas_sumstats_2.sample_sizes.as_ptr(),
gwas_sumstats_2.sample_size.as_ptr(),
gwas_sumstats_2.chisq.as_ptr(),
gwas_sumstats_2.rhos.as_ptr(),
options.as_ref().and_then(|x| x.tol).unwrap_or(0.0001),
Expand All @@ -213,3 +213,87 @@ pub fn solve_cors_wrapper(
cept,
})
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing solver output with reference values.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts that each element of `actual` lies within [`TOLERANCE`] of the
    /// entry at the same position in `expected`.
    ///
    /// Elements are compared pairwise, so callers check the lengths separately.
    /// `label` names the vector in the failure message.
    fn assert_all_close<'a>(
        label: &str,
        actual: impl IntoIterator<Item = &'a f64>,
        expected: &[f64],
    ) {
        for (i, (&got, &want)) in actual.into_iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < TOLERANCE,
                "{}[{}]: got {}, expected {}",
                label,
                i,
                got,
                want
            );
        }
    }

    /// Runs `solve_sums_wrapper` on a small fixed input with default options
    /// and checks both the sizes and the values of the returned vectors.
    #[test]
    fn test_solve_sums_wrapper() {
        let tagging = vec![1.0, 2.0, 3.0];
        let chisqs = vec![1.0, 2.0, 3.0];
        let sample_sizes = vec![1.0, 2.0, 3.0];
        let category_vals = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let category_contribs = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];

        let result = solve_sums_wrapper(
            &tagging,
            &chisqs,
            &sample_sizes,
            &category_vals,
            &category_contribs,
            None,
        )
        .unwrap();

        // Check sizes first so the element-wise comparisons below cover every entry.
        assert_eq!(result.stats.len(), 3 * (2 + 1 + 2));
        assert_eq!(result.likes.len(), 11);
        assert_eq!(result.cohers.len(), 2 * 2);
        assert_eq!(result.influs.len(), 2);

        // Check values.
        // NOTE(review): the reference numbers are presumably captured from a
        // known-good run of the solver; confirm their origin.
        let desired_stats = [
            0.4791464197764975,
            -0.5588545017977516,
            -0.07970808202125412,
            0.2556046190573969,
            1.3577311773082386,
            2.3675053949121336,
            4.497316861083112,
            2.2773633495918704,
            0.763803940571641,
            5.227708866852336,
            146.47111303693012,
            146.47111303693012,
            0.0,
            87.88266782215808,
            439.41333911079033,
        ];
        assert_all_close("stats", result.stats.iter(), &desired_stats);

        let desired_likes = [
            -3.5411094874599045,
            -3.3166019866559435,
            -2.601387310875233,
            0.6760515111840476,
            -0.8075528916021616,
            1.0601008469882118,
            1.0,
            0.0,
            0.0,
            0.0,
            0.0,
        ];
        assert_all_close("likes", result.likes.iter(), &desired_likes);

        let desired_cohers = [
            5.605081794938058,
            -10.322278458928105,
            -10.322278458928107,
            20.22585894898246,
        ];
        assert_all_close("cohers", result.cohers.iter(), &desired_cohers);

        let desired_influs = [4.363636363636364, 1.8545454545454547];
        assert_all_close("influs", result.influs.iter(), &desired_influs);
    }
}
Loading
Loading