diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index cda6da1f4..3052837dc 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,7 +16,7 @@ concurrency: jobs: build_and_test: - name: cargo ${{ matrix.cargo_flags }} + name: cargo clippy + test runs-on: ubuntu-latest env: RUSTFLAGS: -D warnings @@ -31,14 +31,23 @@ jobs: - name: Install protobuf run: | - sudo apt-get update - sudo apt-get install -y protobuf-compiler + sudo apt-get update -qq + sudo apt-get install -yqq protobuf-compiler - name: Configure CI cache uses: Swatinem/rust-cache@v2 - - name: Build - run: cargo build --all-targets + - name: Prepare .sqlx files + working-directory: presage-store-sqlite + env: + DATABASE_URL: sqlite:presage.sqlite + run: | + cargo install --locked sqlx-cli + yes | cargo sqlx database reset + cargo sqlx prepare + + - name: Clippy + run: cargo clippy --all-targets - name: Test run: cargo test --all-targets @@ -61,27 +70,3 @@ jobs: - name: Check code format run: cargo fmt -- --check - - clippy: - name: clippy - runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Setup Rust toolchain - uses: dtolnay/rust-toolchain@v1 - with: - toolchain: stable - components: clippy - - - name: Install protobuf - run: | - sudo apt-get update - sudo apt-get install -y protobuf-compiler - - - name: Setup CI cache - uses: Swatinem/rust-cache@v2 - - - name: Run clippy lints - run: cargo clippy diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..9e26dfeeb --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/presage-cli/Cargo.toml b/presage-cli/Cargo.toml index 738d810ac..fcf070838 100644 --- a/presage-cli/Cargo.toml +++ b/presage-cli/Cargo.toml @@ -5,9 +5,15 @@ edition = "2021" authors = ["Gabriel FĂ©ron "] license = "AGPL-3.0-only" +[features] +default = ["sqlite"] +sled = ["presage-store-sled"] +sqlite = ["presage-store-sqlite"] + [dependencies] presage = { path = "../presage" } -presage-store-sled = { path = "../presage-store-sled" } +presage-store-sled = { path = "../presage-store-sled", optional = true } +presage-store-sqlite = { path = "../presage-store-sqlite", optional = true } anyhow = { version = "1.0", features = ["backtrace"] } base64 = "0.22" diff --git a/presage-cli/src/main.rs b/presage-cli/src/main.rs index 9f4e70409..3a8e5b7ad 100644 --- a/presage-cli/src/main.rs +++ b/presage-cli/src/main.rs @@ -39,8 +39,6 @@ use presage::{ store::{Store, Thread}, Manager, }; -use presage_store_sled::MigrationConflictStrategy; -use presage_store_sled::SledStore; use tempfile::Builder; use tokio::task; use tokio::{ @@ -223,14 +221,22 @@ async fn main() -> anyhow::Result<()> { .config_dir() .into() }); - debug!(db_path =% db_path.display(), "opening config database"); - let config_store = SledStore::open_with_passphrase( + debug!(dir =% db_path.display(), "opening database in dir"); + + #[cfg(feature = "sled")] + let config_store = presage_store_sled::SledStore::open_with_passphrase( db_path, args.passphrase, - MigrationConflictStrategy::Raise, + presage_store_sled::MigrationConflictStrategy::Raise, OnNewIdentity::Trust, ) .await?; + + #[cfg(feature = "sqlite")] + let config_store = + presage_store_sqlite::SqliteStore::open(db_path.join("db.sqlite"), OnNewIdentity::Trust) + .await?; + run(args.subcommand, config_store).await } diff --git a/presage-store-sled/src/content.rs b/presage-store-sled/src/content.rs index de986e886..b76841d84 100644 --- a/presage-store-sled/src/content.rs +++ b/presage-store-sled/src/content.rs @@ -70,7 +70,7 @@ impl ContentsStore for SledStore { Ok(()) } - async fn save_contact(&mut self, contact: &Contact) -> Result<(), SledStoreError> { + async fn save_contact(&mut self, contact: Contact) -> Result<(), SledStoreError> { self.insert(SLED_TREE_CONTACTS, contact.uuid, contact)?; debug!("saved contact"); Ok(()) diff --git a/presage-store-sqlite/.env b/presage-store-sqlite/.env new file mode 100644 index 000000000..62df85634 --- /dev/null +++ b/presage-store-sqlite/.env @@ -0,0 +1 @@ +DATABASE_URL="sqlite:presage.sqlite" diff --git a/presage-store-sqlite/.gitignore b/presage-store-sqlite/.gitignore new file mode 100644 index 000000000..9262b372b --- /dev/null +++ b/presage-store-sqlite/.gitignore @@ -0,0 +1,2 @@ +.sqlx/* +presage.sqlite diff --git a/presage-store-sqlite/Cargo.toml b/presage-store-sqlite/Cargo.toml index 4f3d6057b..e562c1ecd 100644 --- a/presage-store-sqlite/Cargo.toml +++ b/presage-store-sqlite/Cargo.toml @@ -5,9 +5,15 @@ edition = "2021" [dependencies] async-trait = "0.1.83" +bytes = "1.8.0" chrono = "0.4.38" +hex = "0.4.3" +postcard = { version = "1.0.10", features = ["alloc"] } presage = { path = "../presage" } presage-store-cipher = { path = "../presage-store-cipher", optional = true } +prost = "0.13.3" -sqlx = { version = "0.8.2", features = ["sqlite"] } +sqlx = { version = "0.8.2", features = ["sqlite", "uuid", "runtime-tokio"] } thiserror = "1.0.65" +tracing = "0.1.40" +uuid = { version = "1.11.0", features = ["v4"] } diff --git a/presage-store-sqlite/migrations/20241024072558_Initial_data_model.sql b/presage-store-sqlite/migrations/20241024072558_Initial_data_model.sql new file mode 100644 index 000000000..0854fc2e0 --- /dev/null +++ b/presage-store-sqlite/migrations/20241024072558_Initial_data_model.sql @@ -0,0 +1,164 @@ +CREATE TABLE config( + key TEXT PRIMARY KEY NOT NULL ON CONFLICT REPLACE, + value BLOB NOT NULL +); + +CREATE TABLE sessions ( + address VARCHAR(36) NOT NULL, + device_id INTEGER NOT NULL, + record BLOB NOT NULL, + identity TEXT CHECK(identity IN ('aci', 'pni')) NOT NULL DEFAULT 'aci', + + PRIMARY KEY(address, device_id, identity) ON CONFLICT REPLACE +); + +CREATE TABLE identities ( + address VARCHAR(36) NOT NULL, + record BLOB NOT NULL, + identity TEXT CHECK(identity IN ('aci', 'pni')) NOT NULL DEFAULT 'aci', + + -- TODO: Signal adds a lot more fields here that I don't yet care about. + + PRIMARY KEY(address, identity) ON CONFLICT REPLACE +); + +CREATE TABLE prekeys ( + id INTEGER NOT NULL, + record BLOB NOT NULL, + identity TEXT CHECK(identity IN ('aci', 'pni')) NOT NULL, + + PRIMARY KEY(id, identity) ON CONFLICT REPLACE +); + +CREATE TABLE signed_prekeys ( + id INTEGER, + record BLOB NOT NULL, + identity TEXT CHECK(identity IN ('aci', 'pni')) NOT NULL DEFAULT 'aci', + + PRIMARY KEY(id, identity) ON CONFLICT REPLACE +); + +CREATE TABLE kyber_prekeys ( + id INTEGER, + record BLOB NOT NULL, + is_last_resort BOOLEAN DEFAULT FALSE NOT NULL, + identity TEXT CHECK(identity IN ('aci', 'pni')) NOT NULL, + + PRIMARY KEY(id, identity) ON CONFLICT REPLACE +); + +CREATE TABLE sender_keys ( + address VARCHAR(36), + device INTEGER NOT NULL, + distribution_id TEXT NOT NULL, + record BLOB NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + identity TEXT CHECK(identity IN ('aci', 'pni')) NOT NULL DEFAULT 'aci', + + PRIMARY KEY(address, device, distribution_id) ON CONFLICT REPLACE +); + +-- Groups +CREATE TABLE groups( + id INTEGER PRIMARY KEY AUTOINCREMENT, + master_key BLOB NOT NULL, + title TEXT NOT NULL, + revision INTEGER NOT NULL DEFAULT 0, + invite_link_password BLOB, + access_required BLOB, + avatar TEXT NOT NULL, + description TEXT, + members BLOB NOT NULL, + pending_members BLOB NOT NULL, + requesting_members BLOB NOT NULL +); + +CREATE TABLE group_avatars( + id INTEGER PRIMARY KEY AUTOINCREMENT, + bytes BLOB NOT NULL, + + FOREIGN KEY(id) REFERENCES groups(id) ON DELETE CASCADE +); + +CREATE TABLE contacts( + uuid VARCHAR(36) NOT NULL, + -- E.164 numbers should never be longer than 15 chars (excl. international prefix) + phone_number VARCHAR(20), + name TEXT NOT NULL, + color VARCHAR(32), + profile_key BLOB NOT NULL, + expire_timer INTEGER NOT NULL, + expire_timer_version INTEGER NOT NULL DEFAULT 2, + inbox_position INTEGER NOT NULL, + archived BOOLEAN NOT NULL, + avatar BLOB, + + PRIMARY KEY(uuid) ON CONFLICT REPLACE +); + +CREATE TABLE contacts_verification_state( + destination_aci VARCHAR(36) NOT NULL, + identity_key BLOB NOT NULL, + is_verified BOOLEAN, + + FOREIGN KEY(destination_aci) REFERENCES contacts(uuid) ON UPDATE CASCADE, + PRIMARY KEY(destination_aci) ON CONFLICT REPLACE +); + +CREATE TABLE profile_keys( + uuid VARCHAR(36) NOT NULL, + key BLOB NOT NULL, + + PRIMARY KEY(uuid) ON CONFLICT REPLACE +); + +CREATE TABLE profiles( + uuid VARCHAR(36) NOT NULL, + given_name TEXT, + family_name TEXT, + about TEXT, + about_emoji TEXT, + avatar TEXT, + + FOREIGN KEY(uuid) REFERENCES profile_keys(uuid) ON UPDATE CASCADE + PRIMARY KEY(uuid) ON CONFLICT REPLACE +); + +CREATE TABLE profile_avatars( + uuid VARCHAR(36) NOT NULL, + bytes BLOB NOT NULL, + + FOREIGN KEY(uuid) REFERENCES profile_keys(uuid) ON UPDATE CASCADE +); + +-- Threads +CREATE TABLE threads ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + group_id BLOB DEFAULT NULL, + recipient_id VARCHAR(36) DEFAULT NULL, + + FOREIGN KEY(id) REFERENCES groups(id) ON DELETE CASCADE +); +CREATE UNIQUE INDEX threads_target ON threads(group_id, recipient_id); + +CREATE TABLE thread_messages( + ts INTEGER NOT NULL, + thread_id INTEGER NOT NULL, + + sender_service_id TEXT NOT NULL, + sender_device_id INTEGER NOT NULL, + destination_service_id TEXT NOT NULL, + needs_receipt BOOLEAN NOT NULL, + unidentified_sender BOOLEAN NOT NULL, + + content_body BLOB NOT NULL, + + PRIMARY KEY(ts, thread_id) ON CONFLICT REPLACE, + FOREIGN KEY(thread_id) REFERENCES threads(id) ON UPDATE CASCADE +); + +CREATE TABLE sticker_packs( + id BLOB PRIMARY KEY NOT NULL, + key BLOB NOT NULL, + manifest BLOB NOT NULL +); diff --git a/presage-store-sqlite/src/content.rs b/presage-store-sqlite/src/content.rs index b06932ca4..83bc3a1e7 100644 --- a/presage-store-sqlite/src/content.rs +++ b/presage-store-sqlite/src/content.rs @@ -1,217 +1,625 @@ use std::marker::PhantomData; +use bytes::Bytes; use presage::{ - libsignal_service::{prelude::Content, zkgroup::GroupMasterKeyBytes}, - model::{contacts::Contact, groups::Group}, - store::{ContentsStore, StickerPack}, + libsignal_service::{ + self, + content::{ContentBody, Metadata}, + models::Attachment, + prelude::{ + phonenumber::{self, PhoneNumber}, + AccessControl, Content, GroupMasterKey, Member, ProfileKey, ServiceError, + }, + profile_name::ProfileName, + protocol::ServiceId, + zkgroup::{self, GroupMasterKeyBytes}, + Profile, ServiceAddress, + }, + model::{ + contacts::Contact, + groups::{Group, PendingMember, RequestingMember}, + }, + proto::{self, verified, Verified}, + store::{ContentsStore, StickerPack, Thread}, }; +use sqlx::{query, query_as, query_scalar, types::Uuid, QueryBuilder, Sqlite}; +use tracing::warn; +use uuid::timestamp; -use crate::{SqliteStore, SqliteStoreError}; +use crate::{ + data::{SqlContact, SqlGroup, SqlMessage, SqlProfile}, + SqliteStore, SqliteStoreError, +}; impl ContentsStore for SqliteStore { type ContentsStoreError = SqliteStoreError; - type ContactsIter = DummyIter>; + type ContactsIter = Box>>; - type GroupsIter = DummyIter>; + type GroupsIter = + Box>>; - type MessagesIter = DummyIter>; + type MessagesIter = Box>>; - type StickerPacksIter = DummyIter>; + type StickerPacksIter = Box>>; async fn clear_profiles(&mut self) -> Result<(), Self::ContentsStoreError> { - todo!() + query!("DELETE FROM profiles").execute(&self.db).await?; + Ok(()) } async fn clear_contents(&mut self) -> Result<(), Self::ContentsStoreError> { - todo!() + let mut tx = self.db.begin().await?; + + query!("DELETE FROM groups").execute(&mut *tx).await; + query!("DELETE FROM contacts").execute(&mut *tx).await; + + tx.commit().await?; + Ok(()) } async fn clear_messages(&mut self) -> Result<(), Self::ContentsStoreError> { - todo!() + let mut tx = self.db.begin().await?; + query!("DELETE FROM thread_messages") + .execute(&mut *tx) + .await?; + query!("DELETE FROM threads").execute(&mut *tx).await?; + tx.commit().await?; + + Ok(()) } - async fn clear_thread( - &mut self, - thread: &presage::store::Thread, - ) -> Result<(), Self::ContentsStoreError> { - todo!() + async fn clear_thread(&mut self, thread: &Thread) -> Result<(), Self::ContentsStoreError> { + if let Some(thread_id) = self.thread_id(thread).await? { + query!("DELETE FROM thread_messages WHERE thread_id = ?", thread_id) + .execute(&self.db) + .await?; + }; + + Ok(()) } async fn save_message( &self, - thread: &presage::store::Thread, - message: presage::libsignal_service::prelude::Content, + thread: &Thread, + Content { metadata, body }: Content, ) -> Result<(), Self::ContentsStoreError> { - todo!() + let mut tx = self.db.begin().await?; + + let thread_id = match thread { + Thread::Contact(uuid) => { + query_scalar!( + "INSERT INTO threads(recipient_id, group_id) VALUES (?, NULL) ON CONFLICT DO NOTHING RETURNING id", + metadata.sender.uuid, + ) + .fetch_one(&mut *tx) + .await? + } + Thread::Group(master_key_bytes) => { + let master_key_bytes = master_key_bytes.as_slice(); + query_scalar!( + "INSERT INTO threads(group_id) SELECT id FROM groups WHERE groups.master_key = ? ON CONFLICT DO NOTHING RETURNING id", + master_key_bytes + ) + .fetch_one(&mut *tx) + .await? + } + }; + + let Metadata { + sender, + destination, + sender_device, + timestamp, + needs_receipt, + unidentified_sender, + server_guid, + } = metadata; + + let proto_bytes = prost::Message::encode_to_vec(&body.into_proto()); + + let timestamp: i64 = timestamp.try_into()?; + + query!( + "INSERT OR REPLACE INTO + thread_messages(ts, thread_id, sender_service_id, needs_receipt, unidentified_sender, content_body) + VALUES(?, ?, ?, ?, ?, ?)", + timestamp, + thread_id, + sender.uuid, + needs_receipt, + unidentified_sender, + proto_bytes + ) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + Ok(()) } async fn delete_message( &mut self, - thread: &presage::store::Thread, + thread: &Thread, timestamp: u64, ) -> Result { - todo!() + let timestamp: i64 = timestamp.try_into()?; + let deleted: u64 = match thread { + Thread::Contact(uuid) => query_scalar!( + "DELETE FROM thread_messages WHERE ts = ? AND thread_id IN ( + SELECT thread_id FROM threads + WHERE recipient_id = ? + )", + timestamp, + uuid + ) + .execute(&self.db) + .await? + .rows_affected(), + Thread::Group(master_key) => { + let master_key = master_key.as_slice(); + query_scalar!( + " + DELETE FROM thread_messages WHERE ts = ? AND thread_id IN ( + SELECT thread_id FROM threads + WHERE group_id = ? + )", + timestamp, + master_key + ) + .execute(&self.db) + .await? + .rows_affected() + } + }; + Ok(deleted > 0) } async fn message( &self, - thread: &presage::store::Thread, + thread: &Thread, timestamp: u64, - ) -> Result, Self::ContentsStoreError> - { - todo!() + ) -> Result, Self::ContentsStoreError> { + let timestamp: i64 = timestamp.try_into()?; + let Some(thread_id) = self.thread_id(thread).await? else { + warn!(%thread, "thread not found"); + return Ok(None); + }; + + query_as!( + SqlMessage, + "SELECT * FROM thread_messages WHERE ts = ? AND thread_id = ?", + timestamp, + thread_id + ) + .fetch_optional(&self.db) + .await? + .map(TryInto::try_into) + .transpose() } async fn messages( &self, - thread: &presage::store::Thread, + thread: &Thread, range: impl std::ops::RangeBounds, ) -> Result { - todo!() + let Some(thread_id) = self.thread_id(thread).await? else { + warn!(%thread, "thread not found"); + return Ok(Box::new(std::iter::empty())); + }; + + let start = range.start_bound(); + + let end = range.end_bound(); + + let mut query_builder: QueryBuilder = + QueryBuilder::new("SELECT * FROM thread_messages WHERE thread_id = "); + query_builder.push_bind(thread_id); + match range.start_bound() { + std::ops::Bound::Included(ts) => { + query_builder.push("AND ts >= "); + query_builder.push_bind(*ts as i64); + } + std::ops::Bound::Excluded(ts) => { + query_builder.push("AND ts > "); + query_builder.push_bind(*ts as i64); + } + std::ops::Bound::Unbounded => (), + } + match range.end_bound() { + std::ops::Bound::Included(ts) => { + query_builder.push("AND ts <= "); + query_builder.push_bind(*ts as i64); + } + std::ops::Bound::Excluded(ts) => { + query_builder.push("AND ts < "); + query_builder.push_bind(*ts as i64); + } + std::ops::Bound::Unbounded => (), + } + + query_builder.push("ORDER BY ts DESC"); + + let messages: Vec = query_builder.build_query_as().fetch_all(&self.db).await?; + Ok(Box::new(messages.into_iter().map(TryInto::try_into))) } async fn clear_contacts(&mut self) -> Result<(), Self::ContentsStoreError> { - todo!() + query!("DELETE FROM contacts").execute(&self.db).await?; + Ok(()) } - async fn save_contact( - &mut self, - contacts: &presage::model::contacts::Contact, - ) -> Result<(), Self::ContentsStoreError> { - todo!() + async fn save_contact(&mut self, contact: Contact) -> Result<(), Self::ContentsStoreError> { + let profile_key: &[u8] = contact.profile_key.as_ref(); + let avatar_bytes = contact.avatar.map(|a| a.reader.to_vec()); + let phone_number = contact.phone_number.map(|p| p.to_string()); + + let mut tx = self.db.begin().await?; + + query!( + "INSERT OR REPLACE INTO contacts + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + contact.uuid, + phone_number, + contact.name, + contact.color, + profile_key, + contact.expire_timer, + contact.expire_timer_version, + contact.inbox_position, + contact.archived, + avatar_bytes, + ) + .execute(&mut *tx) + .await?; + + let Verified { + destination_aci, + identity_key, + state, + .. + } = contact.verified; + let verified_state = match verified::State::from_i32(state.unwrap_or_default()) { + None | Some(verified::State::Default) => None, + Some(verified::State::Unverified) => Some("unverified"), + Some(verified::State::Verified) => Some("verified"), + }; + + query!( + "INSERT OR REPLACE INTO contacts_verification_state(destination_aci, identity_key, is_verified) + VALUES(?, ?, ?)", + destination_aci, + identity_key, + verified_state, + ) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + + Ok(()) } async fn contacts(&self) -> Result { - todo!() + let contacts = query_as!( + SqlContact, + "SELECT * + FROM contacts c + LEFT JOIN contacts_verification_state cv ON c.uuid = cv.destination_aci + ORDER BY inbox_position + " + ) + .fetch_all(&self.db) + .await? + .into_iter() + .map(TryInto::try_into); + + Ok(Box::new(contacts)) } async fn contact_by_id( &self, - id: &presage::libsignal_service::prelude::Uuid, + uuid: &Uuid, ) -> Result, Self::ContentsStoreError> { - todo!() + query_as!( + SqlContact, + "SELECT * + FROM contacts c + LEFT JOIN contacts_verification_state cv ON c.uuid = cv.destination_aci + WHERE c.uuid = ? + ORDER BY inbox_position + LIMIT 1 + ", + uuid + ) + .fetch_optional(&self.db) + .await? + .map(TryInto::try_into) + .transpose() } async fn clear_groups(&mut self) -> Result<(), Self::ContentsStoreError> { - todo!() + query!("DELETE FROM groups").execute(&self.db).await?; + Ok(()) } async fn save_group( &self, - master_key: presage::libsignal_service::zkgroup::GroupMasterKeyBytes, + master_key: zkgroup::GroupMasterKeyBytes, group: impl Into, ) -> Result<(), Self::ContentsStoreError> { - todo!() + let group = SqlGroup::from_group(master_key, group.into())?; + query_as!( + SqlGroup, + "INSERT OR REPLACE INTO groups( + id, + master_key, + title, + revision, + invite_link_password, + access_required, + avatar, + description, + members, + pending_members, + requesting_members + ) VALUES (NULL, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + group.master_key, + group.title, + group.revision, + group.invite_link_password, + group.access_required, + group.avatar, + group.description, + group.members, + group.pending_members, + group.requesting_members, + ) + .execute(&self.db) + .await?; + Ok(()) } async fn groups(&self) -> Result { - todo!() + let groups = query_as!(SqlGroup, "SELECT * FROM groups") + .fetch_all(&self.db) + .await? + .into_iter() + .map(|g| { + let group_master_key_bytes: GroupMasterKeyBytes = + g.master_key.clone().try_into().expect("invalid master key"); + let group = g.into_group()?; + Ok((group_master_key_bytes, group)) + }); + Ok(Box::new(groups)) } async fn group( &self, - master_key: presage::libsignal_service::zkgroup::GroupMasterKeyBytes, + master_key: zkgroup::GroupMasterKeyBytes, ) -> Result, Self::ContentsStoreError> { - todo!() + query_as!(SqlGroup, "SELECT * FROM groups") + .fetch_optional(&self.db) + .await? + .map(|g| g.into_group()) + .transpose() } async fn save_group_avatar( &self, - master_key: presage::libsignal_service::zkgroup::GroupMasterKeyBytes, + master_key: zkgroup::GroupMasterKeyBytes, avatar: &presage::AvatarBytes, ) -> Result<(), Self::ContentsStoreError> { - todo!() + let mut tx = self.db.begin().await?; + + let group_id = self.group_id(&master_key).await?; + query!( + "INSERT OR REPLACE INTO group_avatars(id, bytes) VALUES(?, ?)", + group_id, + avatar + ) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + + Ok(()) } async fn group_avatar( &self, - master_key: presage::libsignal_service::zkgroup::GroupMasterKeyBytes, + master_key: zkgroup::GroupMasterKeyBytes, ) -> Result, Self::ContentsStoreError> { - todo!() + let group_id = self.group_id(&master_key).await?; + query_scalar!("SELECT bytes FROM group_avatars WHERE id = ?", group_id) + .fetch_optional(&self.db) + .await + .map_err(Into::into) } async fn upsert_profile_key( &mut self, - uuid: &presage::libsignal_service::prelude::Uuid, - key: presage::libsignal_service::prelude::ProfileKey, + uuid: &Uuid, + key: ProfileKey, ) -> Result { - todo!() + let profile_key_bytes = key.get_bytes(); + let profile_key_slice = profile_key_bytes.as_slice(); + let rows_upserted = query!( + "INSERT OR REPLACE INTO profile_keys VALUES(?, ?)", + uuid, + profile_key_slice + ) + .execute(&self.db) + .await? + .rows_affected(); + Ok(rows_upserted == 1) } async fn profile_key( &self, - uuid: &presage::libsignal_service::prelude::Uuid, - ) -> Result, Self::ContentsStoreError> - { - todo!() + uuid: &Uuid, + ) -> Result, Self::ContentsStoreError> { + let profile_key = + query_scalar!("SELECT key FROM profile_keys WHERE uuid = ? LIMIT 1", uuid) + .fetch_optional(&self.db) + .await? + .and_then(|key_bytes| key_bytes.try_into().ok().map(ProfileKey::create)); + Ok(profile_key) } async fn save_profile( &mut self, - uuid: presage::libsignal_service::prelude::Uuid, - key: presage::libsignal_service::prelude::ProfileKey, - profile: presage::libsignal_service::Profile, + uuid: Uuid, + _key: ProfileKey, + profile: Profile, ) -> Result<(), Self::ContentsStoreError> { - todo!() + let given_name = profile.name.clone().map(|n| n.given_name); + let family_name = profile.name.map(|n| n.family_name).flatten(); + query!( + "INSERT OR REPLACE INTO profiles VALUES(?, ?, ?, ?, ?, ?)", + uuid, + given_name, + family_name, + profile.about, + profile.about_emoji, + profile.avatar + ) + .execute(&self.db) + .await?; + Ok(()) } async fn profile( &self, - uuid: presage::libsignal_service::prelude::Uuid, - key: presage::libsignal_service::prelude::ProfileKey, - ) -> Result, Self::ContentsStoreError> { - todo!() + uuid: Uuid, + key: ProfileKey, + ) -> Result, Self::ContentsStoreError> { + let key_bytes = key.get_bytes(); + let key_slice = key_bytes.as_slice(); + query_as!( + SqlProfile, + "SELECT pk.key, p.* FROM profile_keys pk + LEFT JOIN profiles p ON pk.uuid = p.uuid + WHERE pk.uuid = ? + LIMIT 1", + uuid + ) + .fetch_optional(&self.db) + .await? + .map(TryInto::try_into) + .transpose() } async fn save_profile_avatar( &mut self, - uuid: presage::libsignal_service::prelude::Uuid, - key: presage::libsignal_service::prelude::ProfileKey, + uuid: Uuid, + _key: ProfileKey, profile: &presage::AvatarBytes, ) -> Result<(), Self::ContentsStoreError> { - todo!() + query!( + "INSERT OR REPLACE INTO profile_avatars(uuid, bytes) VALUES(?, ?)", + uuid, + profile + ) + .execute(&self.db) + .await?; + + Ok(()) } async fn profile_avatar( &self, - uuid: presage::libsignal_service::prelude::Uuid, - key: presage::libsignal_service::prelude::ProfileKey, + uuid: Uuid, + key: ProfileKey, ) -> Result, Self::ContentsStoreError> { - todo!() + query_scalar!("SELECT bytes FROM profile_avatars WHERE uuid = ?", uuid) + .fetch_optional(&self.db) + .await + .map_err(Into::into) } async fn add_sticker_pack( &mut self, - pack: &presage::store::StickerPack, + pack: &StickerPack, ) -> Result<(), Self::ContentsStoreError> { - todo!() + let manifest_json = postcard::to_allocvec(&pack.manifest)?; + query!( + "INSERT OR REPLACE INTO sticker_packs(id, key, manifest) VALUES(?, ?, ?)", + pack.id, + pack.key, + manifest_json, + ) + .execute(&self.db) + .await?; + + Ok(()) } async fn sticker_pack( &self, id: &[u8], - ) -> Result, Self::ContentsStoreError> { - todo!() + ) -> Result, Self::ContentsStoreError> { + query_scalar!("SELECT manifest FROM sticker_packs WHERE id = ?", id) + .fetch_optional(&self.db) + .await? + .map(|bytes| postcard::from_bytes(&bytes).map_err(Into::into)) + .transpose() } async fn remove_sticker_pack(&mut self, id: &[u8]) -> Result { - todo!() + query!("DELETE FROM sticker_packs WHERE id = ?", id) + .execute(&self.db) + .await + .map_err(Into::into) + .map(|r| r.rows_affected() > 0) } async fn sticker_packs(&self) -> Result { - todo!() + let sticker_packs = query!("SELECT * FROM sticker_packs") + .fetch_all(&self.db) + .await? + .into_iter() + .map(|r| { + Ok(StickerPack { + id: r.id, + key: r.key, + manifest: postcard::from_bytes(&r.manifest)?, + }) + }); + Ok(Box::new(sticker_packs)) } } -pub struct DummyIter { - _data: PhantomData, -} - -impl Iterator for DummyIter { - type Item = T; +impl SqliteStore { + async fn group_id(&self, master_key: &[u8]) -> Result { + query_scalar!( + "SELECT id FROM groups WHERE groups.master_key = ?", + master_key + ) + .fetch_one(&self.db) + .await + .map_err(Into::into) + } - fn next(&mut self) -> Option { - todo!() + async fn thread_id(&self, thread: &Thread) -> Result, SqliteStoreError> { + Ok(match thread { + Thread::Contact(uuid) => { + query_scalar!( + "SELECT id FROM threads WHERE recipient_id = ? LIMIT 1", + uuid + ) + .fetch_optional(&self.db) + .await? + } + Thread::Group(master_key) => { + let master_key = master_key.as_slice(); + query_scalar!( + "SELECT id FROM threads WHERE group_id = ? LIMIT 1", + master_key + ) + .fetch_optional(&self.db) + .await? + } + }) } } diff --git a/presage-store-sqlite/src/data.rs b/presage-store-sqlite/src/data.rs new file mode 100644 index 000000000..33a7c26ad --- /dev/null +++ b/presage-store-sqlite/src/data.rs @@ -0,0 +1,203 @@ +use bytes::Bytes; +use presage::{ + libsignal_service::{ + content::Metadata, + models::Attachment, + prelude::{phonenumber, Content}, + profile_name::ProfileName, + protocol::ServiceId, + zkgroup::GroupMasterKeyBytes, + Profile, + }, + model::{ + contacts::Contact, + groups::{Group, Member, PendingMember, RequestingMember}, + }, + proto::{self, verified, Verified}, +}; + +use crate::SqliteStoreError; + +#[derive(Debug, sqlx::FromRow)] +pub struct SqlContact { + pub uuid: String, + pub phone_number: Option, + pub name: String, + pub color: Option, + pub profile_key: Vec, + pub expire_timer: i64, + pub expire_timer_version: i64, + pub inbox_position: i64, + pub archived: bool, + pub avatar: Option>, + + pub destination_aci: Option, + pub identity_key: Option>, + pub is_verified: Option, +} + +impl TryInto for SqlContact { + type Error = SqliteStoreError; + + #[tracing::instrument] + fn try_into(self) -> Result { + Ok(Contact { + uuid: self.uuid.parse()?, + phone_number: self + .phone_number + .map(|p| phonenumber::parse(None, &p)) + .transpose()?, + name: self.name, + color: self.color, + verified: Verified { + destination_aci: self.destination_aci, + identity_key: self.identity_key, + state: self.is_verified.map(|v| { + match v { + true => verified::State::Verified, + false => verified::State::Unverified, + } + .into() + }), + null_message: None, + }, + profile_key: self.profile_key, + expire_timer: self.expire_timer as u32, + expire_timer_version: self.expire_timer_version as u32, + inbox_position: self.inbox_position as u32, + archived: self.archived, + avatar: self.avatar.map(|b| Attachment { + content_type: "application/octet-stream".into(), + reader: Bytes::from(b), + }), + }) + } +} + +#[derive(Debug, sqlx::FromRow)] +pub struct SqlProfile { + pub uuid: String, + pub key: Vec, + pub given_name: Option, + pub family_name: Option, + pub about: Option, + pub about_emoji: Option, + pub avatar: Option, +} + +impl TryInto for SqlProfile { + type Error = SqliteStoreError; + + #[tracing::instrument] + fn try_into(self) -> Result { + Ok(Profile { + name: self.given_name.map(|gn| ProfileName { + given_name: gn, + family_name: self.family_name, + }), + about: self.about, + about_emoji: self.about_emoji, + avatar: self.avatar, + }) + } +} + +#[derive(Debug, sqlx::FromRow)] +pub struct SqlGroup { + pub id: Option, + pub master_key: Vec, + pub title: String, + pub revision: i64, + pub invite_link_password: Option>, + pub access_required: Option>, + pub avatar: String, + pub description: Option, + pub members: Vec, + pub pending_members: Vec, + pub requesting_members: Vec, +} + +impl SqlGroup { + #[tracing::instrument] + pub fn from_group( + master_key: GroupMasterKeyBytes, + group: Group, + ) -> Result { + Ok(SqlGroup { + id: None, + master_key: master_key.to_vec(), + title: group.title, + revision: group.revision as i64, + invite_link_password: Some(group.invite_link_password), + access_required: group + .access_control + .map(|ac| postcard::to_allocvec(&ac)) + .transpose()?, + avatar: group.avatar, + description: group.description, + members: postcard::to_allocvec(&group.members)?, + pending_members: postcard::to_allocvec(&group.pending_members)?, + requesting_members: postcard::to_allocvec(&group.requesting_members)?, + }) + } + + #[tracing::instrument] + pub fn into_group(self) -> Result { + let members: Vec = postcard::from_bytes(&self.members)?; + let pending_members: Vec = postcard::from_bytes(&self.pending_members)?; + let requesting_members: Vec = + postcard::from_bytes(&self.requesting_members)?; + Ok(Group { + title: self.title, + avatar: self.avatar, + disappearing_messages_timer: None, + access_control: None, + revision: self.revision.try_into()?, + members, + pending_members, + requesting_members, + invite_link_password: self.invite_link_password.unwrap_or_default(), + description: self.description, + }) + } +} + +#[derive(Debug, sqlx::FromRow)] +pub struct SqlMessage { + pub ts: i64, + pub thread_id: i64, + + pub sender_service_id: String, + pub sender_device_id: i64, + pub destination_service_id: String, + pub needs_receipt: bool, + pub unidentified_sender: bool, + + pub content_body: Vec, +} + +impl TryInto for SqlMessage { + type Error = SqliteStoreError; + + #[tracing::instrument] + fn try_into(self) -> Result { + let body: proto::Content = prost::Message::decode(&self.content_body[..]).unwrap(); + let sender_service_id = + ServiceId::parse_from_service_id_string(&self.sender_service_id).unwrap(); + let destination_service_id = + ServiceId::parse_from_service_id_string(&self.destination_service_id).unwrap(); + Content::from_proto( + body, + Metadata { + sender: sender_service_id.into(), + destination: destination_service_id.into(), + sender_device: self.sender_device_id.try_into().unwrap(), + timestamp: self.ts.try_into().unwrap(), + needs_receipt: self.needs_receipt, + unidentified_sender: self.unidentified_sender, + server_guid: None, + }, + ) + .map_err(|err| SqliteStoreError::InvalidFormat) + } +} diff --git a/presage-store-sqlite/src/error.rs b/presage-store-sqlite/src/error.rs index fa7764462..eaea67e17 100644 --- a/presage-store-sqlite/src/error.rs +++ b/presage-store-sqlite/src/error.rs @@ -1,11 +1,26 @@ -use presage::store::StoreError; +use presage::{libsignal_service::prelude::phonenumber, store::StoreError}; +use sqlx::types::uuid; #[derive(Debug, thiserror::Error)] pub enum SqliteStoreError { - #[error("database migration is not supported")] - MigrationConflict, + #[error("database migration failed: {0}")] + Migration(#[from] sqlx::migrate::MigrateError), #[error("data store error: {0}")] Db(#[from] sqlx::Error), + + #[error("error parsing phonenumber: {0}")] + PhoneNumber(#[from] phonenumber::ParseError), + #[error("error parsing UUID: {0}")] + Uuid(#[from] uuid::Error), + + #[error("failed to convert int: {0}")] + TryFromInt(#[from] std::num::TryFromIntError), + + #[error("invalid format")] + InvalidFormat, + + #[error(transparent)] + Postcard(#[from] postcard::Error), } impl StoreError for SqliteStoreError {} diff --git a/presage-store-sqlite/src/lib.rs b/presage-store-sqlite/src/lib.rs index 90e58977b..233546b88 100644 --- a/presage-store-sqlite/src/lib.rs +++ b/presage-store-sqlite/src/lib.rs @@ -1,19 +1,34 @@ #![allow(warnings)] -use std::path::Path; +use std::{future::Future, path::Path, pin::Pin}; use presage::{ + libsignal_service::{ + configuration::SignalServers, + prelude::{phonenumber::PhoneNumber, ProfileKey, SignalingKey}, + protocol::SignalProtocolError, + push_service::ServiceIds, + }, model::identity::OnNewIdentity, store::{StateStore, Store}, }; use protocol::SqliteProtocolStore; -use sqlx::{sqlite::SqliteConnectOptions, SqlitePool}; +use sqlx::{ + migrate::{MigrateDatabase, Migration, MigrationSource, Migrator}, + query, query_scalar, + sqlite::SqliteConnectOptions, + Sqlite, SqlitePool, +}; mod content; +pub(crate) mod data; mod error; mod protocol; pub use error::SqliteStoreError; +use tracing::{debug, trace}; + +static MIGRATOR: Migrator = sqlx::migrate!(); #[derive(Debug, Clone)] pub struct SqliteStore { @@ -27,9 +42,21 @@ impl SqliteStore { db_path: impl AsRef, trust_new_identities: OnNewIdentity, ) -> Result { + let db_path = db_path.as_ref(); + let db_path_str = db_path.to_str().unwrap(); + + // Create database if it doesn't exist + if !Sqlite::database_exists(db_path_str).await.unwrap_or(false) { + debug!(path = db_path_str, "creating sqlite database"); + Sqlite::create_database(db_path_str).await?; + } + let connect_options = SqliteConnectOptions::new().filename(db_path); let pool = SqlitePool::connect_with(connect_options).await?; + debug!("applying sqlite migrations"); + MIGRATOR.run(&pool).await?; + Ok(Self { db: pool, trust_new_identities, @@ -37,6 +64,16 @@ impl SqliteStore { } } +trait SqlxErrorExt { + fn into_protocol_error(self) -> Result; +} + +impl SqlxErrorExt for Result { + fn into_protocol_error(self) -> Result { + self.map_err(|error| SignalProtocolError::InvalidState("sqlite", error.to_string())) + } +} + impl Store for SqliteStore { type Error = SqliteStoreError; @@ -45,18 +82,21 @@ impl Store for SqliteStore { type PniStore = SqliteProtocolStore; async fn clear(&mut self) -> Result<(), SqliteStoreError> { - todo!() + query!("DELETE FROM config").execute(&self.db).await?; + Ok(()) } fn aci_protocol_store(&self) -> Self::AciStore { SqliteProtocolStore { store: self.clone(), + identity_type: "aci", } } fn pni_protocol_store(&self) -> Self::PniStore { SqliteProtocolStore { store: self.clone(), + identity_type: "pni", } } } @@ -64,38 +104,71 @@ impl Store for SqliteStore { impl StateStore for SqliteStore { type StateStoreError = SqliteStoreError; + #[tracing::instrument(skip(self))] async fn load_registration_data( &self, ) -> Result, Self::StateStoreError> { - todo!() + query_scalar!("SELECT value FROM config WHERE key = 'registration'") + .fetch_optional(&self.db) + .await? + .map(|value: Vec| postcard::from_bytes(&value).map_err(Into::into)) + .transpose() } async fn set_aci_identity_key_pair( &self, key_pair: presage::libsignal_service::protocol::IdentityKeyPair, ) -> Result<(), Self::StateStoreError> { - todo!() + trace!("setting ACI identity key pair"); + let key_pair_bytes = key_pair.serialize(); + query!( + "INSERT OR REPLACE INTO config(key, value) VALUES('aci_identity_key_pair', ?)", + key_pair_bytes + ) + .execute(&self.db) + .await?; + Ok(()) } async fn set_pni_identity_key_pair( &self, key_pair: presage::libsignal_service::protocol::IdentityKeyPair, ) -> Result<(), Self::StateStoreError> { - todo!() + trace!("setting PNI identity key pair"); + let key_pair_bytes = key_pair.serialize(); + query!( + "INSERT OR REPLACE INTO config(key, value) VALUES('pni_identity_key_pair', ?)", + key_pair_bytes + ) + .execute(&self.db) + .await?; + Ok(()) } async fn save_registration_data( &mut self, state: &presage::manager::RegistrationData, ) -> Result<(), Self::StateStoreError> { - todo!() + trace!("saving registration data"); + let registration_data_json = postcard::to_allocvec(&state)?; + query!( + "INSERT OR REPLACE INTO config(key, value) VALUES('registration', ?)", + registration_data_json + ) + .execute(&self.db) + .await?; + + Ok(()) } async fn is_registered(&self) -> bool { - todo!() + self.load_registration_data().await.ok().is_some() } async fn clear_registration(&mut self) -> Result<(), Self::StateStoreError> { - todo!() + query!("DELETE FROM config WHERE key = 'registration'") + .execute(&self.db) + .await?; + Ok(()) } } diff --git a/presage-store-sqlite/src/protocol.rs b/presage-store-sqlite/src/protocol.rs index 196229f26..1150cb181 100644 --- a/presage-store-sqlite/src/protocol.rs +++ b/presage-store-sqlite/src/protocol.rs @@ -1,23 +1,32 @@ +use std::fmt::{self, Formatter}; + use async_trait::async_trait; use chrono::{DateTime, Utc}; -use presage::libsignal_service::{ - pre_keys::{KyberPreKeyStoreExt, PreKeysStore}, - prelude::{IdentityKeyStore, SessionStoreExt, Uuid}, - protocol::{ - Direction, IdentityKey, IdentityKeyPair, KyberPreKeyId, KyberPreKeyRecord, - KyberPreKeyStore, PreKeyId, PreKeyRecord, PreKeyStore, ProtocolAddress, ProtocolStore, - SenderKeyRecord, SenderKeyStore, SessionRecord, SessionStore, - SignalProtocolError as ProtocolError, SignedPreKeyId, SignedPreKeyRecord, - SignedPreKeyStore, +use presage::{ + libsignal_service::{ + pre_keys::{KyberPreKeyStoreExt, PreKeysStore}, + prelude::{IdentityKeyStore, SessionStoreExt, Uuid}, + protocol::{ + Direction, GenericSignedPreKey, IdentityKey, IdentityKeyPair, KyberPreKeyId, + KyberPreKeyRecord, KyberPreKeyStore, PreKeyId, PreKeyRecord, PreKeyStore, + ProtocolAddress, ProtocolStore, SenderKeyRecord, SenderKeyStore, SessionRecord, + SessionStore, SignalProtocolError as ProtocolError, SignedPreKeyId, SignedPreKeyRecord, + SignedPreKeyStore, + }, + push_service::DEFAULT_DEVICE_ID, + ServiceAddress, }, - ServiceAddress, + store::StateStore, }; +use sqlx::{query, query_scalar, Executor, QueryBuilder}; +use tracing::trace; -use crate::SqliteStore; +use crate::{SqliteStore, SqlxErrorExt}; #[derive(Clone)] pub struct SqliteProtocolStore { pub(crate) store: SqliteStore, + pub(crate) identity_type: &'static str, } impl ProtocolStore for SqliteProtocolStore {} @@ -29,7 +38,19 @@ impl SessionStore for SqliteProtocolStore { &self, address: &ProtocolAddress, ) -> Result, ProtocolError> { - todo!() + let uuid = address.name(); + let device_id: u32 = address.device_id().into(); + query!( + "SELECT record FROM sessions WHERE address = $1 AND device_id = $2 AND identity = $3 LIMIT 1", + uuid, + device_id, + self.identity_type + ) + .fetch_optional(&self.store.db) + .await + .into_protocol_error()? + .map(|record| SessionRecord::deserialize(&record.record)) + .transpose() } /// Set the entry for `address` to the value of `record`. @@ -38,7 +59,21 @@ impl SessionStore for SqliteProtocolStore { address: &ProtocolAddress, record: &SessionRecord, ) -> Result<(), ProtocolError> { - todo!(); + let uuid = address.name(); + let device_id: u32 = address.device_id().into(); + let record_data = record.serialize()?; + query!( + "INSERT OR REPLACE INTO sessions ( address, device_id, identity, record ) VALUES ( $1, $2, $3, $4 )", + uuid, + device_id, + self.identity_type, + record_data, + ) + .execute(&self.store.db) + .await + .into_protocol_error()?; + + Ok(()) } } @@ -51,12 +86,31 @@ impl SessionStoreExt for SqliteProtocolStore { &self, name: &ServiceAddress, ) -> Result, ProtocolError> { - todo!() + query_scalar!( + "SELECT device_id AS 'id: u32' FROM sessions WHERE address = ? AND device_id != ?", + name.uuid, + DEFAULT_DEVICE_ID + ) + .fetch_all(&self.store.db) + .await + .into_protocol_error() } /// Remove a session record for a recipient ID + device ID tuple. async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), ProtocolError> { - todo!() + let uuid = address.name(); + let device_id: u32 = address.device_id().into(); + query!( + "DELETE FROM sessions WHERE address = $1 AND device_id = $2 AND identity = $3", + uuid, + device_id, + self.identity_type + ) + .execute(&self.store.db) + .await + .into_protocol_error()?; + + Ok(()) } /// Remove the session records corresponding to all devices of a recipient @@ -64,7 +118,18 @@ impl SessionStoreExt for SqliteProtocolStore { /// /// Returns the number of deleted sessions. async fn delete_all_sessions(&self, address: &ServiceAddress) -> Result { - todo!() + let uuid = address.uuid.to_string(); + let rows = query!( + "DELETE FROM sessions WHERE address = $1 AND identity = $3", + uuid, + self.identity_type + ) + .execute(&self.store.db) + .await + .into_protocol_error()? + .rows_affected(); + + Ok(rows as usize) } } @@ -72,7 +137,12 @@ impl SessionStoreExt for SqliteProtocolStore { impl PreKeyStore for SqliteProtocolStore { /// Look up the pre-key corresponding to `prekey_id`. async fn get_pre_key(&self, prekey_id: PreKeyId) -> Result { - todo!() + let id: u32 = prekey_id.into(); + query!("SELECT id, record FROM prekeys WHERE id = $1 LIMIT 1", id) + .fetch_one(&self.store.db) + .await + .into_protocol_error() + .and_then(|record| PreKeyRecord::deserialize(&record.record)) } /// Set the entry for `prekey_id` to the value of `record`. @@ -81,12 +151,29 @@ impl PreKeyStore for SqliteProtocolStore { prekey_id: PreKeyId, record: &PreKeyRecord, ) -> Result<(), ProtocolError> { - todo!() + let id: u32 = prekey_id.into(); + let record_data = record.serialize()?; + query!( + "INSERT OR REPLACE INTO prekeys( id, record, identity ) VALUES( ?1, ?2, ?3 )", + id, + record_data, + self.identity_type, + ) + .execute(&self.store.db) + .await + .into_protocol_error()?; + + Ok(()) } /// Remove the entry for `prekey_id`. async fn remove_pre_key(&mut self, prekey_id: PreKeyId) -> Result<(), ProtocolError> { - todo!() + let id: u32 = prekey_id.into(); + let rows_affected = query!("DELETE FROM prekeys WHERE id = $1", id) + .execute(&self.store.db) + .await + .into_protocol_error()?; + Ok(()) } } @@ -94,27 +181,47 @@ impl PreKeyStore for SqliteProtocolStore { impl PreKeysStore for SqliteProtocolStore { /// ID of the next pre key async fn next_pre_key_id(&self) -> Result { - todo!() + query_scalar!("SELECT MAX(id) FROM prekeys") + .fetch_one(&self.store.db) + .await + .into_protocol_error() + .map(|record| record.map(|i| i as u32 + 1).unwrap_or_default()) } /// ID of the next signed pre key async fn next_signed_pre_key_id(&self) -> Result { - todo!() + query_scalar!("SELECT MAX(id) FROM signed_prekeys") + .fetch_one(&self.store.db) + .await + .into_protocol_error() + .map(|record| record.map(|i| i as u32 + 1).unwrap_or_default()) } /// ID of the next PQ pre key async fn next_pq_pre_key_id(&self) -> Result { - todo!() + query!("SELECT MAX(id) as 'max_id: u32' FROM kyber_prekeys") + .fetch_one(&self.store.db) + .await + .into_protocol_error() + .map(|record| record.max_id.map(|i| i + 1).unwrap_or_default()) } /// number of signed pre-keys we currently have in store async fn signed_pre_keys_count(&self) -> Result { - todo!() + let count = query_scalar!("SELECT COUNT(id) FROM signed_prekeys") + .fetch_one(&self.store.db) + .await + .into_protocol_error()?; + Ok(count as usize) } /// number of kyber pre-keys we currently have in store async fn kyber_pre_keys_count(&self, last_resort: bool) -> Result { - todo!() + let count = query_scalar!("SELECT COUNT(id) FROM kyber_prekeys") + .fetch_one(&self.store.db) + .await + .into_protocol_error()?; + Ok(count as usize) } } @@ -125,7 +232,15 @@ impl SignedPreKeyStore for SqliteProtocolStore { &self, signed_prekey_id: SignedPreKeyId, ) -> Result { - todo!() + let id: u32 = signed_prekey_id.into(); + query!( + "SELECT id, record FROM signed_prekeys WHERE id = $1 LIMIT 1", + id + ) + .fetch_one(&self.store.db) + .await + .into_protocol_error() + .and_then(|record| SignedPreKeyRecord::deserialize(&record.record)) } /// Set the entry for `signed_prekey_id` to the value of `record`. @@ -134,7 +249,19 @@ impl SignedPreKeyStore for SqliteProtocolStore { signed_prekey_id: SignedPreKeyId, record: &SignedPreKeyRecord, ) -> Result<(), ProtocolError> { - todo!() + let id: u32 = signed_prekey_id.into(); + let record_data = record.serialize()?; + query!( + "INSERT OR REPLACE INTO signed_prekeys( id, record, identity ) VALUES( ?1, ?2, ?3 )", + id, + record_data, + self.identity_type + ) + .execute(&self.store.db) + .await + .into_protocol_error()?; + + Ok(()) } } @@ -145,7 +272,15 @@ impl KyberPreKeyStore for SqliteProtocolStore { &self, kyber_prekey_id: KyberPreKeyId, ) -> Result { - todo!() + let id: u32 = kyber_prekey_id.into(); + query!( + "SELECT id, record FROM kyber_prekeys WHERE id = $1 LIMIT 1", + id + ) + .fetch_one(&self.store.db) + .await + .into_protocol_error() + .and_then(|record| KyberPreKeyRecord::deserialize(&record.record)) } /// Set the entry for `kyber_prekey_id` to the value of `record`. @@ -154,7 +289,19 @@ impl KyberPreKeyStore for SqliteProtocolStore { kyber_prekey_id: KyberPreKeyId, record: &KyberPreKeyRecord, ) -> Result<(), ProtocolError> { - todo!() + let id: u32 = kyber_prekey_id.into(); + let record_data = record.serialize()?; + query!( + "INSERT OR REPLACE INTO kyber_prekeys( id, record, identity ) VALUES( ?1, ?2, ?3 )", + id, + record_data, + self.identity_type, + ) + .execute(&self.store.db) + .await + .into_protocol_error()?; + + Ok(()) } /// Mark the entry for `kyber_prekey_id` as "used". @@ -163,7 +310,17 @@ impl KyberPreKeyStore for SqliteProtocolStore { &mut self, kyber_prekey_id: KyberPreKeyId, ) -> Result<(), ProtocolError> { - todo!() + let id: u32 = kyber_prekey_id.into(); + query!( + "DELETE FROM kyber_prekeys WHERE id = $1 AND identity = $2", + id, + self.identity_type, + ) + .execute(&self.store.db) + .await + .into_protocol_error()?; + + Ok(()) } } @@ -174,37 +331,61 @@ impl KyberPreKeyStoreExt for SqliteProtocolStore { kyber_prekey_id: KyberPreKeyId, record: &KyberPreKeyRecord, ) -> Result<(), ProtocolError> { - todo!() + let id: u32 = kyber_prekey_id.into(); + let record_data = record.serialize()?; + query!( + "INSERT OR REPLACE INTO kyber_prekeys( id, record, is_last_resort, identity ) + VALUES( ?, ?, TRUE, ? )", + id, + record_data, + self.identity_type, + ) + .execute(&self.store.db) + .await + .into_protocol_error()?; + + Ok(()) } async fn load_last_resort_kyber_pre_keys( &self, ) -> Result, ProtocolError> { - todo!() + let records = query!( + "SELECT * FROM kyber_prekeys WHERE is_last_resort = true AND identity = $1", + self.identity_type, + ) + .fetch_all(&self.store.db) + .await + .into_protocol_error()?; + + let kyber_prekeys: Result, ProtocolError> = records + .into_iter() + .map(|record| KyberPreKeyRecord::deserialize(&record.record)) + .collect(); + + Ok(kyber_prekeys?) } async fn remove_kyber_pre_key( &mut self, kyber_prekey_id: KyberPreKeyId, ) -> Result<(), ProtocolError> { - todo!() + unimplemented!("unexpected in this flow") } - /// Analogous to markAllOneTimeKyberPreKeysStaleIfNecessary async fn mark_all_one_time_kyber_pre_keys_stale_if_necessary( &mut self, stale_time: DateTime, ) -> Result<(), ProtocolError> { - todo!() + unimplemented!("unexpected in this flow") } - /// Analogue of deleteAllStaleOneTimeKyberPreKeys async fn delete_all_stale_one_time_kyber_pre_keys( &mut self, threshold: DateTime, min_count: usize, ) -> Result<(), ProtocolError> { - todo!() + unimplemented!("unexpected in this flow") } } @@ -212,7 +393,12 @@ impl KyberPreKeyStoreExt for SqliteProtocolStore { impl IdentityKeyStore for SqliteProtocolStore { /// Return the single specific identity the store is assumed to represent, with private key. async fn get_identity_key_pair(&self) -> Result { - todo!() + let key = format!("{}_identity_key_pair", self.identity_type); + let key_pair_bytes = query_scalar!("SELECT value FROM config WHERE key = ?", key) + .fetch_one(&self.store.db) + .await + .into_protocol_error()?; + IdentityKeyPair::try_from(key_pair_bytes.as_slice()) } /// Return a [u32] specific to this store instance. @@ -224,7 +410,13 @@ impl IdentityKeyStore for SqliteProtocolStore { /// may be the same, but the store registration id returned by this method should /// be regenerated. async fn get_local_registration_id(&self) -> Result { - todo!() + Ok(self + .store + .load_registration_data() + .await + .map_err(|error| ProtocolError::InvalidState("sqlite", error.to_string()))? + .map(|data| data.registration_id) + .unwrap_or_default()) } // TODO: make this into an enum instead of a bool! @@ -235,19 +427,41 @@ impl IdentityKeyStore for SqliteProtocolStore { async fn save_identity( &mut self, address: &ProtocolAddress, - identity: &IdentityKey, + identity_key: &IdentityKey, ) -> Result { - todo!() + let previous = self.get_identity(address).await?; + let ret = previous.as_ref() == Some(identity_key); + + let address = address.name(); + let record_data = identity_key.serialize(); + query!( + "INSERT OR REPLACE INTO identities ( address, record, identity ) VALUES ( $1, $2, $3 )", + address, + record_data, + self.identity_type + ) + .execute(&self.store.db) + .await + .into_protocol_error()?; + + Ok(ret) } + // TODO: take this out of the store trait! /// Return whether an identity is trusted for the role specified by `direction`. async fn is_trusted_identity( &self, address: &ProtocolAddress, - identity: &IdentityKey, + identity_key: &IdentityKey, direction: Direction, ) -> Result { - todo!() + if let Some(trusted_key) = self.get_identity(address).await? { + Ok(trusted_key == *identity_key) + } else { + // Trust on first use + // TODO: we should most likely expose this behaviour as a setting + Ok(true) + } } /// Return the public identity for the given `address`, if known. @@ -255,7 +469,17 @@ impl IdentityKeyStore for SqliteProtocolStore { &self, address: &ProtocolAddress, ) -> Result, ProtocolError> { - todo!() + let address_name = address.name(); + query!( + "SELECT record FROM identities WHERE address = $1 AND identity = $2", + address_name, + self.identity_type + ) + .fetch_optional(&self.store.db) + .await + .into_protocol_error()? + .map(|record| IdentityKey::decode(&record.record)) + .transpose() } } @@ -266,10 +490,24 @@ impl SenderKeyStore for SqliteProtocolStore { &mut self, sender: &ProtocolAddress, distribution_id: Uuid, - // TODO: pass this by value! record: &SenderKeyRecord, ) -> Result<(), ProtocolError> { - todo!() + let address = sender.name(); + let device_id: u32 = sender.device_id().into(); + let record_data = record.serialize()?; + query!( + "INSERT OR REPLACE INTO sender_keys (address, device, distribution_id, record, identity) VALUES ($1, $2, $3, $4, $5)", + address, + device_id, + distribution_id, + record_data, + self.identity_type + ) + .execute(&self.store.db) + .await + .into_protocol_error()?; + + Ok(()) } /// Look up the entry corresponding to `(sender, distribution_id)`. @@ -278,6 +516,18 @@ impl SenderKeyStore for SqliteProtocolStore { sender: &ProtocolAddress, distribution_id: Uuid, ) -> Result, ProtocolError> { - todo!() + let address = sender.name(); + let device_id: u32 = sender.device_id().into(); + query!( + "SELECT record FROM sender_keys WHERE address = $1 AND device = $2 AND distribution_id = $3 AND identity = $4", + address, + device_id, + distribution_id, + self.identity_type + ) + .fetch_optional(&self.store.db) .await + .into_protocol_error()? + .map(|record| SenderKeyRecord::deserialize(&record.record)) + .transpose() } } diff --git a/presage/src/manager/confirmation.rs b/presage/src/manager/confirmation.rs index 01b8a89fb..48a215611 100644 --- a/presage/src/manager/confirmation.rs +++ b/presage/src/manager/confirmation.rs @@ -132,7 +132,8 @@ impl Manager { signal_servers: self.state.signal_servers, device_name: None, phone_number, - service_ids: ServiceIds { aci, pni }, + aci, + pni, password, signaling_key, device_id: None, diff --git a/presage/src/manager/linking.rs b/presage/src/manager/linking.rs index a954a6b6f..1db9f5558 100644 --- a/presage/src/manager/linking.rs +++ b/presage/src/manager/linking.rs @@ -129,7 +129,8 @@ impl Manager { signal_servers, device_name: Some(device_name), phone_number, - service_ids, + aci: service_ids.aci, + pni: service_ids.pni, password, signaling_key, device_id: Some(device_id.into()), @@ -154,7 +155,7 @@ impl Manager { store.save_registration_data(®istration_data).await?; info!( "successfully registered device {}", - ®istration_data.service_ids + ®istration_data.service_ids() ); let mut manager = Manager { diff --git a/presage/src/manager/registered.rs b/presage/src/manager/registered.rs index 40e846b91..27d8c34b3 100644 --- a/presage/src/manager/registered.rs +++ b/presage/src/manager/registered.rs @@ -3,6 +3,7 @@ use std::pin::pin; use std::sync::{Arc, OnceLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use derivative::Derivative; use futures::{future, AsyncReadExt, Stream, StreamExt}; use libsignal_service::attachment_cipher::decrypt_in_place; use libsignal_service::configuration::{ServiceConfiguration, SignalServers, SignalingKey}; @@ -100,8 +101,8 @@ pub struct RegistrationData { pub signal_servers: SignalServers, pub device_name: Option, pub phone_number: PhoneNumber, - #[serde(flatten)] - pub service_ids: ServiceIds, + pub aci: Uuid, + pub pni: Uuid, pub(crate) password: String, #[serde(with = "serde_signaling_key")] pub(crate) signaling_key: SignalingKey, @@ -114,14 +115,21 @@ pub struct RegistrationData { } impl RegistrationData { + pub fn service_ids(&self) -> ServiceIds { + ServiceIds { + aci: self.aci, + pni: self.pni, + } + } + /// Account identity pub fn aci(&self) -> Uuid { - self.service_ids.aci + self.aci } /// Phone number identity pub fn pni(&self) -> Uuid { - self.service_ids.pni + self.pni } /// Our own profile key @@ -327,7 +335,7 @@ impl Manager { if self.state.data.pni_registration_id.is_none() { debug!("fetching PNI UUID and updating state"); let whoami = self.whoami().await?; - self.state.data.service_ids.pni = whoami.pni; + self.state.data.pni = whoami.pni; self.store.save_registration_data(&self.state.data).await?; } @@ -377,7 +385,7 @@ impl Manager { .as_millis() as u64; self.send_message( - ServiceAddress::from_aci(self.state.data.service_ids.aci), + ServiceAddress::from_aci(self.state.data.aci), sync_message, timestamp, ) @@ -445,7 +453,7 @@ impl Manager { /// Fetches the profile (name, about, status emoji) of the registered user. pub async fn retrieve_profile(&mut self) -> Result> { - self.retrieve_profile_by_uuid(self.state.data.service_ids.aci, self.state.data.profile_key) + self.retrieve_profile_by_uuid(self.state.data.aci, self.state.data.profile_key) .await } @@ -601,7 +609,7 @@ impl Manager { let groups_credentials_cache = InMemoryCredentialsCache::default(); let groups_manager = GroupsManager::new( - self.state.data.service_ids.clone(), + self.state.data.service_ids(), self.identified_push_service(), groups_credentials_cache, server_public_params, @@ -671,7 +679,7 @@ impl Manager { info!("saving contacts"); for contact in contacts.filter_map(Result::ok) { if let Err(error) = - state.store.save_contact(&contact.into()).await + state.store.save_contact(contact.into()).await { warn!(%error, "failed to save contacts"); break; @@ -882,7 +890,7 @@ impl Manager { // save the message let content = Content { metadata: Metadata { - sender: ServiceAddress::from_aci(self.state.data.service_ids.aci), + sender: ServiceAddress::from_aci(self.state.data.aci), sender_device: self.state.device_id(), destination: recipient, server_guid: None, @@ -947,7 +955,7 @@ impl Manager { for member in group .members .into_iter() - .filter(|m| m.uuid != self.state.data.service_ids.aci) + .filter(|m| m.uuid != self.state.data.aci) { let unidentified_access = self.store @@ -987,8 +995,8 @@ impl Manager { let content = Content { metadata: Metadata { - sender: ServiceAddress::from_aci(self.state.data.service_ids.aci), - destination: ServiceAddress::from_aci(self.state.data.service_ids.aci), + sender: ServiceAddress::from_aci(self.state.data.aci), + destination: ServiceAddress::from_aci(self.state.data.aci), sender_device: self.state.device_id(), server_guid: None, timestamp, @@ -1167,8 +1175,8 @@ impl Manager { fn credentials(&self) -> Option { Some(ServiceCredentials { - aci: Some(self.state.data.service_ids.aci), - pni: Some(self.state.data.service_ids.pni), + aci: Some(self.state.data.aci), + pni: Some(self.state.data.pni), phonenumber: self.state.data.phone_number.clone(), password: Some(self.state.data.password.clone()), signaling_key: Some(self.state.data.signaling_key), @@ -1196,8 +1204,8 @@ impl Manager { self.new_service_cipher_aci(), self.rng.clone(), aci_protocol_store, - ServiceAddress::from_aci(self.state.data.service_ids.aci), - ServiceAddress::from_pni(self.state.data.service_ids.pni), + ServiceAddress::from_aci(self.state.data.aci), + ServiceAddress::from_pni(self.state.data.pni), aci_identity_keypair, Some(pni_identity_keypair), self.state.device_id().into(), @@ -1211,7 +1219,7 @@ impl Manager { self.state .service_configuration() .unidentified_sender_trust_root, - self.state.data.service_ids.aci, + self.state.data.aci, self.state.device_id(), ) } @@ -1223,7 +1231,7 @@ impl Manager { self.state .service_configuration() .unidentified_sender_trust_root, - self.state.data.service_ids.pni, + self.state.data.pni, self.state.device_id(), ) } @@ -1517,7 +1525,7 @@ async fn save_message( }; info!(%sender_uuid, "saved contact on first sight"); - store.save_contact(&contact).await?; + store.save_contact(contact).await?; } store.upsert_profile_key(&sender_uuid, profile_key).await?; diff --git a/presage/src/model/groups.rs b/presage/src/model/groups.rs index 124d355e3..2fa99c650 100644 --- a/presage/src/model/groups.rs +++ b/presage/src/model/groups.rs @@ -1,7 +1,7 @@ use derivative::Derivative; use libsignal_service::{ groups_v2::Role, - prelude::{AccessControl, Member, ProfileKey, Timer, Uuid}, + prelude::{AccessControl, ProfileKey, Timer, Uuid}, }; use serde::{Deserialize, Serialize}; @@ -21,6 +21,16 @@ pub struct Group { pub description: Option, } +#[derive(Derivative, Clone, Deserialize, Serialize)] +#[derivative(Debug)] +pub struct Member { + pub uuid: Uuid, + pub role: Role, + #[derivative(Debug = "ignore")] + pub profile_key: ProfileKey, + pub joined_at_revision: u32, +} + #[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] pub struct PendingMember { // for backwards compatibility @@ -48,7 +58,7 @@ impl From for Group { disappearing_messages_timer: val.disappearing_messages_timer, access_control: val.access_control, revision: val.revision, - members: val.members, + members: val.members.into_iter().map(Into::into).collect(), pending_members: val.pending_members.into_iter().map(Into::into).collect(), requesting_members: val.requesting_members.into_iter().map(Into::into).collect(), invite_link_password: val.invite_link_password, @@ -57,6 +67,17 @@ impl From for Group { } } +impl From for Member { + fn from(val: libsignal_service::groups_v2::Member) -> Self { + Member { + uuid: val.uuid, + role: val.role, + profile_key: val.profile_key, + joined_at_revision: val.joined_at_revision, + } + } +} + impl From for PendingMember { fn from(val: libsignal_service::groups_v2::PendingMember) -> Self { PendingMember { diff --git a/presage/src/serde.rs b/presage/src/serde.rs index 4beed7371..1141978dc 100644 --- a/presage/src/serde.rs +++ b/presage/src/serde.rs @@ -3,11 +3,13 @@ pub(crate) mod serde_profile_key { use base64::{engine::general_purpose, Engine}; use libsignal_service::prelude::ProfileKey; use serde::{Deserialize, Deserializer, Serializer}; + use tracing::trace; pub(crate) fn serialize(profile_key: &ProfileKey, serializer: S) -> Result where S: Serializer, { + trace!("serializing profile key"); serializer.serialize_str(&general_purpose::STANDARD.encode(profile_key.bytes)) } @@ -15,6 +17,7 @@ pub(crate) mod serde_profile_key { where D: Deserializer<'de>, { + trace!("deserializing profile key"); let bytes: [u8; 32] = general_purpose::STANDARD .decode(String::deserialize(deserializer)?) .map_err(serde::de::Error::custom)? diff --git a/presage/src/store.rs b/presage/src/store.rs index 956e95068..629a10a1c 100644 --- a/presage/src/store.rs +++ b/presage/src/store.rs @@ -166,7 +166,7 @@ pub trait ContentsStore: Send + Sync { } contact.expire_timer_version = version; contact.expire_timer = timer; - self.save_contact(&contact).await?; + self.save_contact(contact).await?; } Ok(()) } @@ -190,7 +190,7 @@ pub trait ContentsStore: Send + Sync { /// Save a contact fn save_contact( &mut self, - contacts: &Contact, + contacts: Contact, ) -> impl Future>; /// Get an iterator on all stored (synchronized) contacts