-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from worldcoin/0xkitsune/mpc-coordinator
feat(coordinator): Initial logic for coordinator
- Loading branch information
Showing
22 changed files
with
2,333 additions
and
56 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
use std::path::PathBuf; | ||
|
||
use aws_config::BehaviorVersion; | ||
use clap::Parser; | ||
use mpc::config::CoordinatorConfig; | ||
use mpc::coordinator::Coordinator; | ||
use telemetry_batteries::metrics::batteries::StatsdBattery; | ||
use telemetry_batteries::tracing::batteries::DatadogBattery; | ||
use tracing_subscriber::layer::SubscriberExt; | ||
use tracing_subscriber::util::SubscriberInitExt; | ||
|
||
pub const SERVICE_NAME: &str = "mpc-coordinator"; | ||
|
||
pub const METRICS_HOST: &str = "localhost"; | ||
pub const METRICS_PORT: u16 = 8125; | ||
pub const METRICS_QUEUE_SIZE: usize = 5000; | ||
pub const METRICS_BUFFER_SIZE: usize = 1024; | ||
pub const METRICS_PREFIX: &str = "mpc-coordinator"; | ||
|
||
#[derive(Parser)] | ||
#[clap(version)] | ||
pub struct Args { | ||
#[clap(short, long, env)] | ||
telemetry: bool, | ||
|
||
#[clap(short, long, env)] | ||
config: Option<PathBuf>, | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() -> eyre::Result<()> { | ||
dotenv::dotenv().ok(); | ||
|
||
let args = Args::parse(); | ||
|
||
if args.telemetry { | ||
DatadogBattery::init(None, SERVICE_NAME, None, true); | ||
|
||
StatsdBattery::init( | ||
METRICS_HOST, | ||
METRICS_PORT, | ||
METRICS_QUEUE_SIZE, | ||
METRICS_BUFFER_SIZE, | ||
Some(METRICS_PREFIX), | ||
)?; | ||
} else { | ||
tracing_subscriber::registry() | ||
.with(tracing_subscriber::fmt::layer().pretty().compact()) | ||
.with(tracing_subscriber::EnvFilter::from_default_env()) | ||
.init(); | ||
} | ||
|
||
let mut settings = config::Config::builder(); | ||
|
||
if let Some(path) = args.config { | ||
settings = settings.add_source(config::File::from(path).required(true)); | ||
} | ||
|
||
let settings = settings | ||
.add_source(config::Environment::with_prefix("MPC").separator("__")) | ||
.build()?; | ||
|
||
let config = settings.try_deserialize::<CoordinatorConfig>()?; | ||
|
||
let coordinator = | ||
Coordinator::new(vec![], "template_queue_url", "distance_queue_url") | ||
.await?; | ||
|
||
coordinator.spawn().await?; | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
CREATE TABLE masks ( | ||
id BIGINT PRIMARY KEY, | ||
mask BYTEA NOT NULL | ||
); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
CREATE TABLE shares ( | ||
id BIGINT PRIMARY KEY, | ||
share BYTEA NOT NULL | ||
); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
[toolchain] | ||
channel = "nightly-2024-01-26" | ||
channel = "nightly-2024-01-25" | ||
components = ["rustc-dev", "rustc", "cargo", "rustfmt", "clippy"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
#![allow(unused)] | ||
use std::cmp::min; | ||
use std::mem::swap; | ||
use std::thread::JoinHandle; | ||
|
||
use rayon::prelude::*; | ||
|
||
use crate::distance::Bits; | ||
use crate::encoded_bits::EncodedBits; | ||
|
||
pub fn distances<'a>( | ||
query: &'a EncodedBits, | ||
db: &'a [EncodedBits], | ||
) -> impl Iterator<Item = [u16; 31]> + 'a { | ||
const BATCH: usize = 10_000; | ||
|
||
// Prepare 31 rotations of query in advance | ||
let rotations: Box<[_]> = (-15..=15).map(|r| query.rotated(r)).collect(); | ||
|
||
// Iterate over a batch of database entries | ||
db.chunks(BATCH).flat_map(move |chunk| { | ||
let mut results = [[0_u16; 31]; BATCH]; | ||
|
||
// Parallel computation over batch | ||
results.par_iter_mut().zip(chunk.par_iter()).for_each( | ||
|(result, entry)| { | ||
// Compute dot product for each rotation | ||
for (d, rotation) in result.iter_mut().zip(rotations.iter()) { | ||
*d = rotation.dot(entry); | ||
} | ||
}, | ||
); | ||
|
||
// Sequentially output results | ||
results.into_iter().take(chunk.len()) | ||
}) | ||
} | ||
|
||
pub fn denominators<'a>( | ||
query: &'a Bits, | ||
db: &'a [Bits], | ||
) -> impl Iterator<Item = [u16; 31]> + 'a { | ||
const BATCH: usize = 10_000; | ||
|
||
// Prepare 31 rotations of query in advance | ||
let rotations: Box<[_]> = (-15..=15).map(|r| query.rotated(r)).collect(); | ||
|
||
// Iterate over a batch of database entries | ||
db.chunks(BATCH).flat_map(move |chunk| { | ||
// Parallel computation over batch | ||
let results = chunk | ||
.par_iter() | ||
.map(|(entry)| { | ||
let mut result = [0_u16; 31]; | ||
// Compute dot product for each rotation | ||
for (d, rotation) in result.iter_mut().zip(rotations.iter()) { | ||
*d = rotation.dot(entry); | ||
} | ||
result | ||
}) | ||
.collect::<Vec<_>>(); | ||
|
||
// Sequentially output results | ||
results.into_iter().take(chunk.len()) | ||
}) | ||
} | ||
|
||
//TODO: move this to the benches file | ||
#[cfg(feature = "bench")] | ||
pub mod benches { | ||
use core::hint::black_box; | ||
|
||
use criterion::Criterion; | ||
use rand::{thread_rng, Rng}; | ||
|
||
use super::*; | ||
|
||
pub fn group(c: &mut Criterion) { | ||
let mut rng = thread_rng(); | ||
let mut g = c.benchmark_group("generic"); | ||
|
||
g.bench_function("distances 31x1000", |bench| { | ||
let a: EncodedBits = rng.gen(); | ||
let b: Box<[EncodedBits]> = (0..1000).map(|_| rng.gen()).collect(); | ||
bench.iter(|| { | ||
black_box(distances(black_box(&a), black_box(&b))) | ||
.for_each(|_| {}) | ||
}) | ||
}); | ||
|
||
g.bench_function("denominators 31x1000", |bench| { | ||
let a: Bits = rng.gen(); | ||
let b: Box<[Bits]> = (0..1000).map(|_| rng.gen()).collect(); | ||
bench.iter(|| { | ||
black_box(denominators(black_box(&a), black_box(&b))) | ||
.for_each(|_| {}) | ||
}) | ||
}); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
mod generic; // Optimized generic implementation | ||
mod neon; // Optimized aarch64 NEON implementation | ||
mod reference; // Simple generic implementations | ||
|
||
pub use generic::{denominators, distances}; | ||
|
||
//TODO: move this to the benches file | ||
#[cfg(feature = "bench")] | ||
pub mod benches { | ||
use criterion::Criterion; | ||
|
||
use super::*; | ||
|
||
pub fn group(c: &mut Criterion) { | ||
reference::benches::group(c); | ||
|
||
generic::benches::group(c); | ||
|
||
#[cfg(target_feature = "neon")] | ||
neon::benches::group(c); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#![cfg(target_feature = "neon")] | ||
#![allow(unused)] | ||
|
||
// Rust + LLVM already generates good NEON code for the generic implementation. | ||
|
||
//TODO: move this to the benches file | ||
#[cfg(feature = "bench")] | ||
pub mod benches { | ||
use core::hint::black_box; | ||
|
||
use criterion::Criterion; | ||
use rand::{thread_rng, Rng}; | ||
|
||
use super::*; | ||
|
||
pub fn group(c: &mut Criterion) { | ||
let mut g = c.benchmark_group("neon"); | ||
let mut rng = thread_rng(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
#![allow(unused)] | ||
|
||
use crate::distance::Bits; | ||
use crate::encoded_bits::EncodedBits; | ||
|
||
pub fn distances<'a>( | ||
query: &'a EncodedBits, | ||
db: &'a [EncodedBits], | ||
) -> impl Iterator<Item = [u16; 31]> + 'a { | ||
db.iter().map(|entry| { | ||
let mut result = [0_u16; 31]; | ||
for (d, r) in result.iter_mut().zip(-15..=15) { | ||
*d = query.rotated(r).dot(entry); | ||
} | ||
result | ||
}) | ||
} | ||
|
||
pub fn denominators<'a>( | ||
query: &'a Bits, | ||
db: &'a [Bits], | ||
) -> impl Iterator<Item = [u16; 31]> + 'a { | ||
db.iter().map(|entry| { | ||
let mut result = [0_u16; 31]; | ||
for (d, r) in result.iter_mut().zip(-15..=15) { | ||
*d = query.rotated(r).dot(entry); | ||
} | ||
result | ||
}) | ||
} | ||
|
||
//TODO: move this to the benches file | ||
#[cfg(feature = "bench")] | ||
pub mod benches { | ||
use core::hint::black_box; | ||
|
||
use criterion::Criterion; | ||
use rand::{thread_rng, Rng}; | ||
|
||
use super::*; | ||
|
||
pub fn group(c: &mut Criterion) { | ||
let mut rng = thread_rng(); | ||
let mut g = c.benchmark_group("reference"); | ||
|
||
g.bench_function("distances 31x1000", |bench| { | ||
let a: EncodedBits = rng.gen(); | ||
let b: Box<[EncodedBits]> = (0..1000).map(|_| rng.gen()).collect(); | ||
bench.iter(|| { | ||
black_box(distances(black_box(&a), black_box(&b))) | ||
.for_each(|_| {}) | ||
}) | ||
}); | ||
|
||
g.bench_function("denominators 31x1000", |bench| { | ||
let a: Bits = rng.gen(); | ||
let b: Box<[Bits]> = (0..1000).map(|_| rng.gen()).collect(); | ||
bench.iter(|| { | ||
black_box(denominators(black_box(&a), black_box(&b))) | ||
.for_each(|_| {}) | ||
}) | ||
}); | ||
} | ||
} |
Oops, something went wrong.