diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 4e98188e..4945c4c1 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -35,6 +35,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI return queryRes.Error } res.DeviceListsOTKCount = queryRes.Count.KeyCount + res.DeviceListsUnusedFallbackAlgorithms = queryRes.UnusedFallbackAlgorithms return nil } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index de20a608..be0fde5c 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -350,13 +350,14 @@ type ToDeviceResponse struct { // Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync type Response struct { - NextBatch StreamingToken `json:"next_batch"` - AccountData *ClientEvents `json:"account_data,omitempty"` - Presence *ClientEvents `json:"presence,omitempty"` - Rooms *RoomsResponse `json:"rooms,omitempty"` - ToDevice *ToDeviceResponse `json:"to_device,omitempty"` - DeviceLists *DeviceLists `json:"device_lists,omitempty"` - DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"` + NextBatch StreamingToken `json:"next_batch"` + AccountData *ClientEvents `json:"account_data,omitempty"` + Presence *ClientEvents `json:"presence,omitempty"` + Rooms *RoomsResponse `json:"rooms,omitempty"` + ToDevice *ToDeviceResponse `json:"to_device,omitempty"` + DeviceLists *DeviceLists `json:"device_lists,omitempty"` + DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"` + DeviceListsUnusedFallbackAlgorithms []string `json:"device_unused_fallback_key_types"` } func (r Response) MarshalJSON() ([]byte, error) { @@ -419,6 +420,7 @@ func NewResponse() *Response { res.DeviceLists = &DeviceLists{} res.ToDevice = &ToDeviceResponse{} res.DeviceListsOTKCount = map[string]int{} + res.DeviceListsUnusedFallbackAlgorithms = []string{} return &res } diff --git a/sytest-whitelist b/sytest-whitelist index 35d700d0..540ae99b 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -793,4 +793,6 @@ remote user can join room with version 11 User can invite remote user to room with version 11 Remote user can backfill in a room with version 11 Can reject invites over federation for rooms with version 11 -Can receive redactions from regular users over federation in room version 11 \ No newline at end of file +Can receive redactions from regular users over federation in room version 11 +Can upload self-signing keys +uploading signed devices gets propagated over federation diff --git a/userapi/api/api.go b/userapi/api/api.go index 6da12fc9..26482129 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -788,12 +788,30 @@ type OneTimeKeysCount struct { KeyCount map[string]int } +// FallbackKeys represents a set of fallback keys for a single device +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload +type FallbackKeys struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // A map of algorithm:key_id => key JSON + KeyJSON map[string]json.RawMessage +} + +// Split a key in KeyJSON into algorithm and key ID +func (k *FallbackKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { + segments := strings.Split(keyIDWithAlgo, ":") + return segments[0], segments[1] +} + // PerformUploadKeysRequest is the request to PerformUploadKeys type PerformUploadKeysRequest struct { - UserID string // Required - User performing the request - DeviceID string // Optional - Device performing the request, for fetching OTK count - DeviceKeys []DeviceKeys - OneTimeKeys []OneTimeKeys + UserID string // Required - User performing the request + DeviceID string // Optional - Device performing the request, for fetching OTK count + DeviceKeys []DeviceKeys + OneTimeKeys []OneTimeKeys + FallbackKeys []FallbackKeys // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update // the display name for their respective device, and NOT to modify the keys. The key // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. @@ -810,8 +828,9 @@ type PerformUploadKeysResponse struct { // A fatal error when processing e.g database failures Error *KeyError // A map of user_id -> device_id -> Error for tracking failures. - KeyErrors map[string]map[string]*KeyError - OneTimeKeyCounts []OneTimeKeysCount + KeyErrors map[string]map[string]*KeyError + OneTimeKeyCounts []OneTimeKeysCount + FallbackKeysUnusedAlgorithms []string } // PerformDeleteKeysRequest asks the keyserver to forget about certain @@ -917,8 +936,9 @@ type QueryOneTimeKeysRequest struct { type QueryOneTimeKeysResponse struct { // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 - Count OneTimeKeysCount - Error *KeyError + Count OneTimeKeysCount + UnusedFallbackAlgorithms []string + Error *KeyError } type QueryDeviceMessagesRequest struct { diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index 09ead2c5..4ae1f09f 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -44,14 +44,22 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor if len(req.DeviceKeys) > 0 { a.uploadLocalDeviceKeys(ctx, req, res) } - if len(req.OneTimeKeys) > 0 { - a.uploadOneTimeKeys(ctx, req, res) + if len(req.OneTimeKeys) > 0 || len(req.FallbackKeys) > 0 { + a.uploadOneTimeAndFallbackKeys(ctx, req, res) } otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { return err } + algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err), + } + return nil + } res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks} + res.FallbackKeysUnusedAlgorithms = algos return nil } @@ -169,7 +177,15 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn } return nil } + algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err), + } + return nil + } res.Count = *count + res.UnusedFallbackAlgorithms = algos return nil } @@ -507,6 +523,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer( for userID := range userIDsForAllDevices { err := a.Updater.ManualUpdate(context.Background(), spec.ServerName(serverName), userID) if err != nil { + if errors.Is(err, context.Canceled) { + return + } logrus.WithFields(logrus.Fields{ logrus.ErrorKey: err, "user_id": userID, @@ -520,6 +539,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer( // user so the fact that we're populating all devices here isn't a problem so long as we have devices. err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil) if err != nil { + if errors.Is(err, context.Canceled) { + return + } logrus.WithFields(logrus.Fields{ logrus.ErrorKey: err, "user_id": userID, @@ -715,7 +737,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe } } -func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { +func (a *UserInternalAPI) uploadOneTimeAndFallbackKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { if req.UserID == "" { res.Error = &api.KeyError{ Err: "user ID missing", @@ -768,7 +790,30 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor // collect counts res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts) } - + if err := a.KeyDatabase.DeleteFallbackKeys(ctx, req.UserID, req.DeviceID); err != nil { + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: fmt.Sprintf("%s device %s : failed to clear fallback keys: %s", req.UserID, req.DeviceID, err.Error()), + }) + return + } + for _, key := range req.FallbackKeys { + // grab existing keys based on (user/device/algorithm/key ID) + keyIDsWithAlgorithms := make([]string, len(key.KeyJSON)) + i := 0 + for keyIDWithAlgo := range key.KeyJSON { + keyIDsWithAlgorithms[i] = keyIDWithAlgo + i++ + } + unused, err := a.KeyDatabase.StoreFallbackKeys(ctx, key) + if err != nil { + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: fmt.Sprintf("%s device %s : failed to store fallback keys: %s", req.UserID, req.DeviceID, err.Error()), + }) + continue + } + // collect counts + res.FallbackKeysUnusedAlgorithms = unused + } } func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 7767f6cd..2a46a7fd 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -167,6 +167,15 @@ type KeyDatabase interface { // OneTimeKeysCount returns a count of all OTKs for this device. OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + // StoreFallbackKeys persists the given fallback keys. + StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) ([]string, error) + + // UnusedFallbackKeyAlgorithms returns unused fallback algorithms for this user/device. + UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) + + // DeleteFallbackKeys deletes all fallback keys for the user. + DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error diff --git a/userapi/storage/postgres/fallback_keys_table.go b/userapi/storage/postgres/fallback_keys_table.go new file mode 100644 index 00000000..acae7ed6 --- /dev/null +++ b/userapi/storage/postgres/fallback_keys_table.go @@ -0,0 +1,134 @@ +// Copyright 2024 New Vector Ltd. +// Copyright 2017 Vector Creations Ltd +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/element-hq/dendrite/internal" + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/element-hq/dendrite/userapi/api" + "github.com/element-hq/dendrite/userapi/storage/tables" +) + +var fallbackKeysSchema = ` +-- Stores one-time public keys for users +CREATE TABLE IF NOT EXISTS keyserver_fallback_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + used BOOLEAN NOT NULL, + -- Clobber based on tuple of user/device/algorithm. + CONSTRAINT keyserver_fallback_keys_unique UNIQUE (user_id, device_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id); +` + +const upsertFallbackKeysSQL = "" + + "INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" + + " VALUES ($1, $2, $3, $4, $5, $6, false)" + + " ON CONFLICT ON CONSTRAINT keyserver_fallback_keys_unique" + + " DO UPDATE SET key_id = $3, key_json = $6, used = false" + +const selectFallbackUnusedAlgorithmsSQL = "" + + "SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false" + +const selectFallbackKeysByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1" + +const deleteFallbackKeysSQL = "" + + "DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2" + +const updateFallbackKeyUsedSQL = "" + + "UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4" + +type fallbackKeysStatements struct { + db *sql.DB + upsertKeysStmt *sql.Stmt + selectUnusedAlgorithmsStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteFallbackKeysStmt *sql.Stmt + updateFallbackKeyUsedStmt *sql.Stmt +} + +func NewPostgresFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) { + s := &fallbackKeysStatements{ + db: db, + } + _, err := db.Exec(fallbackKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertFallbackKeysSQL}, + {&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL}, + {&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL}, + {&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL}, + {&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL}, + }.Prepare(db) +} + +func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) { + rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + algos := []string{} + for rows.Next() { + var algorithm string + if err = rows.Scan(&algorithm); err != nil { + return nil, err + } + algos = append(algos, algorithm) + } + return algos, rows.Err() +} + +func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) { + now := time.Now().Unix() + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return nil, err + } + } + return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID) +} + +func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} + +func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 696e1aa6..c7fb9d29 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -141,6 +141,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp if err != nil { return nil, err } + fk, err := NewPostgresFallbackKeysTable(db) + if err != nil { + return nil, err + } dk, err := NewPostgresDeviceKeysTable(db) if err != nil { return nil, err @@ -164,6 +168,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp return &shared.KeyDatabase{ OneTimeKeysTable: otk, + FallbackKeysTable: fk, DeviceKeysTable: dk, KeyChangesTable: kc, StaleDeviceListsTable: sdl, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 2b1885cd..44ace733 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -57,6 +57,7 @@ type Database struct { type KeyDatabase struct { OneTimeKeysTable tables.OneTimeKeys + FallbackKeysTable tables.FallbackKeys DeviceKeysTable tables.DeviceKeys KeyChangesTable tables.KeyChanges StaleDeviceListsTable tables.StaleDeviceLists @@ -937,6 +938,22 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) } +func (d *KeyDatabase) StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) (unused []string, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + unused, err = d.FallbackKeysTable.InsertFallbackKeys(ctx, txn, keys) + return err + }) + return +} + +func (d *KeyDatabase) DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error { + return d.FallbackKeysTable.DeleteFallbackKeys(ctx, nil, userID, deviceID) +} + +func (d *KeyDatabase) UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) { + return d.FallbackKeysTable.SelectUnusedFallbackKeyAlgorithms(ctx, userID, deviceID) +} + func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } @@ -999,6 +1016,12 @@ func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map if err != nil { return err } + if len(keyJSON) == 0 { + keyJSON, err = d.FallbackKeysTable.SelectAndUpdateFallbackKey(ctx, txn, userID, deviceID, algo) + if err != nil { + return err + } + } if keyJSON != nil { result = append(result, api.OneTimeKeys{ UserID: userID, diff --git a/userapi/storage/sqlite3/fallback_keys_table.go b/userapi/storage/sqlite3/fallback_keys_table.go new file mode 100644 index 00000000..2eb99813 --- /dev/null +++ b/userapi/storage/sqlite3/fallback_keys_table.go @@ -0,0 +1,132 @@ +// Copyright 2024 New Vector Ltd. +// Copyright 2017 Vector Creations Ltd +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/element-hq/dendrite/internal" + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/element-hq/dendrite/userapi/api" + "github.com/element-hq/dendrite/userapi/storage/tables" +) + +var fallbackKeysSchema = ` +-- Stores one-time public keys for users +CREATE TABLE IF NOT EXISTS keyserver_fallback_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + used BOOLEAN NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS keyserver_fallback_keys_unique_idx ON keyserver_fallback_keys(user_id, device_id, algorithm); +CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id); +` + +const upsertFallbackKeysSQL = "" + + "INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" + + " VALUES ($1, $2, $3, $4, $5, $6, false)" + + " ON CONFLICT (user_id, device_id, algorithm)" + + " DO UPDATE SET key_id = $3, key_json = $6, used = false" + +const selectFallbackUnusedAlgorithmsSQL = "" + + "SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false" + +const selectFallbackKeysByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1" + +const deleteFallbackKeysSQL = "" + + "DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2" + +const updateFallbackKeyUsedSQL = "" + + "UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4" + +type fallbackKeysStatements struct { + db *sql.DB + upsertKeysStmt *sql.Stmt + selectUnusedAlgorithmsStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteFallbackKeysStmt *sql.Stmt + updateFallbackKeyUsedStmt *sql.Stmt +} + +func NewSqliteFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) { + s := &fallbackKeysStatements{ + db: db, + } + _, err := db.Exec(fallbackKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertFallbackKeysSQL}, + {&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL}, + {&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL}, + {&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL}, + {&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL}, + }.Prepare(db) +} + +func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) { + rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + algos := []string{} + for rows.Next() { + var algorithm string + if err = rows.Scan(&algorithm); err != nil { + return nil, err + } + algos = append(algos, algorithm) + } + return algos, rows.Err() +} + +func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) { + now := time.Now().Unix() + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return nil, err + } + } + return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID) +} + +func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} + +func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index c57cc153..6d906191 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -138,6 +138,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp if err != nil { return nil, err } + fk, err := NewSqliteFallbackKeysTable(db) + if err != nil { + return nil, err + } dk, err := NewSqliteDeviceKeysTable(db) if err != nil { return nil, err @@ -161,6 +165,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp return &shared.KeyDatabase{ OneTimeKeysTable: otk, + FallbackKeysTable: fk, DeviceKeysTable: dk, KeyChangesTable: kc, StaleDeviceListsTable: sdl, diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 68198b37..189c1dd8 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -809,3 +809,42 @@ func TestOneTimeKeys(t *testing.T) { } }) } + +func TestFallbackKeys(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + userID := "@alice:localhost" + deviceID := "alice_device" + fk := api.FallbackKeys{ + UserID: userID, + DeviceID: deviceID, + KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)}, + } + + _, err := db.StoreFallbackKeys(ctx, fk) + MustNotError(t, err) + + unused, err := db.UnusedFallbackKeyAlgorithms(ctx, userID, deviceID) + MustNotError(t, err) + if c := len(unused); c != 1 { + t.Fatalf("Expected 1 unused key algorithm, got %d", c) + } + if unused[0] != "curve25519" { + t.Fatalf("Expected unused key algorithm to be 'curve25519', got '%s'", unused[0]) + } + + // No other one-time keys have been uploaded so we expect to get the fallback key instead. + claimed, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}}) + MustNotError(t, err) + + switch { + case claimed[0].UserID != fk.UserID: + t.Fatalf("Claimed user ID ID doesn't match, got %q, want %q", claimed[0].UserID, fk.DeviceID) + case claimed[0].DeviceID != fk.DeviceID: + t.Fatalf("Claimed device ID doesn't match, got %q, want %q", claimed[0].DeviceID, fk.DeviceID) + case claimed[0].KeyJSON["curve25519:KEY1"] == nil: + t.Fatalf("Claimed key JSON for curve25519:KEY1 not found") + } + }) +} diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 7d4cfbae..44f31a5c 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -170,6 +170,13 @@ type DeviceKeys interface { DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error } +type FallbackKeys interface { + SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) + InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) + DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error + SelectAndUpdateFallbackKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) +} + type KeyChanges interface { InsertKeyChange(ctx context.Context, userID string) (int64, error) // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.