From be85d1264c50521f27f9f28bc84ed57c25d194e3 Mon Sep 17 00:00:00 2001 From: Subrose Date: Wed, 22 Nov 2023 16:04:06 +0000 Subject: [PATCH 01/13] stable --- api/core.go | 1 - docker-compose.ci.yml | 32 ++++++++++++++++++++------------ docker-compose.yml | 24 +++++++++++------------- simulator/wait.py | 1 + vault/sql.go | 39 +++++++++++++++++++++++++++++++++++---- 5 files changed, 67 insertions(+), 30 deletions(-) diff --git a/api/core.go b/api/core.go index 8691d77..8f6347b 100644 --- a/api/core.go +++ b/api/core.go @@ -114,7 +114,6 @@ func CreateCore(conf *CoreConfig) (*Core, error) { // conf.DB_DB, // ) db, err := _vault.NewSqlStore(_vault.FormatDsn(conf.DB_HOST, conf.DB_USER, conf.DB_PASSWORD, conf.DB_NAME, conf.DB_PORT)) - if err != nil { panic(err) } diff --git a/docker-compose.ci.yml b/docker-compose.ci.yml index e7e4f2f..c9fb3c8 100644 --- a/docker-compose.ci.yml +++ b/docker-compose.ci.yml @@ -1,21 +1,29 @@ version: "3.8" services: - keydb: - image: eqalpha/keydb:x86_64_v6.3.3 - container_name: keydb - command: keydb-server --server-threads 1 --protected-mode no --appendonly yes + # keydb: + # image: eqalpha/keydb:x86_64_v6.3.3 + # container_name: keydb + # command: keydb-server --server-threads 1 --protected-mode no --appendonly yes + # ports: + # - 6379:6379 + # restart: unless-stopped + # volumes: + # - ./keydb/redis.conf:/etc/keydb/redis.conf + postgres: + image: postgres:16.1-alpine ports: - - 6379:6379 - restart: unless-stopped - volumes: - - ./keydb/redis.conf:/etc/keydb/redis.conf + - 5432:5432 + environment: + - POSTGRES_PASSWORD=postgres + - POSTGRES_USER=postgres + - POSTGRES_DB=postgres api: profiles: ["simulations"] environment: - - VAULT_DB_HOST=keydb - - VAULT_DB_PORT=6379 - - VAULT_DB_USER=default - - VAULT_DB_PASSWORD= + - VAULT_DB_HOST=postgres + - VAULT_DB_PORT=5432 + - VAULT_DB_USER=postgres + - VAULT_DB_PASSWORD=postgres - VAULT_API_HOST=0.0.0.0 - VAULT_API_PORT=3001 - VAULT_ADMIN_USERNAME=admin diff --git a/docker-compose.yml b/docker-compose.yml index 3ca864a..fff46f6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,24 +19,22 @@ services: ports: - 3001:3001 depends_on: - - keydb + - postgres volumes: - "./:/app" - keydb: - image: eqalpha/keydb:arm64_v6.3.3 - container_name: keydb - command: keydb-server --server-threads 4 --protected-mode no --appendonly yes - ports: - - 6379:6379 - restart: unless-stopped - volumes: - - ./keydb/redis.conf:/etc/keydb/redis.conf + # keydb: + # image: eqalpha/keydb:arm64_v6.3.3 + # container_name: keydb + # command: keydb-server --server-threads 4 --protected-mode no --appendonly yes + # ports: + # - 6379:6379 + # restart: unless-stopped + # volumes: + # - ./keydb/redis.conf:/etc/keydb/redis.conf postgres: - image: postgres:14-alpine + image: postgres:16.1-alpine ports: - 5432:5432 - volumes: - - ~/apps/postgres:/var/lib/postgresql/data environment: - POSTGRES_PASSWORD=postgres - POSTGRES_USER=postgres diff --git a/simulator/wait.py b/simulator/wait.py index 5dcfb21..7ba1641 100644 --- a/simulator/wait.py +++ b/simulator/wait.py @@ -25,6 +25,7 @@ def wait_for_api(vault_url: str) -> None: try: response = requests.get(f"{vault_url}/health") response.raise_for_status() + print("API is up and ready") except requests.exceptions.ConnectionError: logger.info("Waiting for api to be ready ...") raise tenacity.TryAgain diff --git a/vault/sql.go b/vault/sql.go index 3230dae..726b052 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -1,5 +1,12 @@ package vault +// TODO: +// - Dynamic collection creation (no updates) +// - Error handling +// - Tidy DB Models +// - Ensure we never log sensitive data +// - Add indexes + import ( "context" "encoding/json" @@ -47,7 +54,9 @@ func FormatDsn(host string, user string, password string, dbName string, port in } func NewSqlStore(dsn string) (*SqlStore, error) { - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ + TranslateError: true, + }) db.AutoMigrate(&GormCollection{}, &GormRecord{}, &GormPrincipal{}, &GormPolicy{}, &GormToken{}) return &SqlStore{db}, err @@ -84,7 +93,15 @@ func (st SqlStore) CreateCollection(ctx context.Context, c Collection) (string, } gormCol := GormCollection{Name: c.Name, Collection: datatypes.JSON(b)} - return c.Name, st.db.Create(&gormCol).Error + result := st.db.Create(&gormCol) + if result.Error != nil { + switch result.Error { + case gorm.ErrDuplicatedKey: + return "", &ConflictError{c.Name} + } + } + + return c.Name, nil } func (st SqlStore) DeleteCollection(ctx context.Context, name string) error { @@ -172,7 +189,14 @@ func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) err return err } gPrincipal := GormPrincipal{Username: principal.Username, Principal: datatypes.JSON(p)} - return st.db.Create(&gPrincipal).Error + err = st.db.Create(&gPrincipal).Error + if err != nil { + if errors.Is(err, gorm.ErrDuplicatedKey) { + return &ConflictError{principal.Username} + } + return err + } + return nil } func (st SqlStore) DeletePrincipal(ctx context.Context, username string) error { @@ -215,7 +239,14 @@ func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { return "", err } gormPol := GormPolicy{PolicyId: p.PolicyId, Policy: datatypes.JSON(pol)} - return p.PolicyId, st.db.Create(&gormPol).Error + err = st.db.Create(&gormPol).Error + if err != nil { + if errors.Is(err, gorm.ErrDuplicatedKey) { + return "", &ConflictError{p.PolicyId} + } + return "", err + } + return p.PolicyId, nil } func (st SqlStore) DeletePolicy(ctx context.Context, policyId string) error { From fed2b7ce3b8685b74605df95a134aba480cb8e85 Mon Sep 17 00:00:00 2001 From: Subrose Date: Wed, 22 Nov 2023 16:42:45 +0000 Subject: [PATCH 02/13] flattened policies --- vault/go.mod | 1 + vault/go.sum | 2 ++ vault/sql.go | 64 +++++++++++++++++++++++++++++++++----------------- vault/vault.go | 48 ++++++++++++++++++------------------- 4 files changed, 69 insertions(+), 46 deletions(-) diff --git a/vault/go.mod b/vault/go.mod index b343f52..692449d 100644 --- a/vault/go.mod +++ b/vault/go.mod @@ -21,6 +21,7 @@ require ( github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/text v0.13.0 // indirect diff --git a/vault/go.sum b/vault/go.sum index 4347561..e86313f 100644 --- a/vault/go.sum +++ b/vault/go.sum @@ -28,6 +28,8 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/nyaruka/phonenumbers v1.1.6 h1:DcueYq7QrOArAprAYNoQfDgp0KetO4LqtnBtQC6Wyes= diff --git a/vault/sql.go b/vault/sql.go index 726b052..9ec8550 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -12,7 +12,9 @@ import ( "encoding/json" "errors" "fmt" + "time" + "github.com/lib/pq" "gorm.io/datatypes" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -38,9 +40,13 @@ type GormPrincipal struct { Principal datatypes.JSON } -type GormPolicy struct { - PolicyId string `gorm:"primaryKey"` - Policy datatypes.JSON +type DbPolicy struct { + ID string `gorm:"primaryKey"` + Effect string + Actions pq.StringArray `gorm:"type:text[]"` + Resources pq.StringArray `gorm:"type:text[]"` + CreatedAt time.Time + UpdatedAt time.Time } type GormToken struct { @@ -57,7 +63,7 @@ func NewSqlStore(dsn string) (*SqlStore, error) { db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ TranslateError: true, }) - db.AutoMigrate(&GormCollection{}, &GormRecord{}, &GormPrincipal{}, &GormPolicy{}, &GormToken{}) + db.AutoMigrate(&GormCollection{}, &GormRecord{}, &GormPrincipal{}, &DbPolicy{}, &GormToken{}) return &SqlStore{db}, err } @@ -204,8 +210,8 @@ func (st SqlStore) DeletePrincipal(ctx context.Context, username string) error { } func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, error) { - var gp GormPolicy - err := st.db.First(&gp, "policy_id = ?", policyId).Error + var gp DbPolicy + err := st.db.First(&gp, "id = ?", policyId).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, &NotFoundError{"policy", policyId} @@ -213,33 +219,42 @@ func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, err return nil, err } var p Policy - err = json.Unmarshal(gp.Policy, &p) return &p, err } func (st SqlStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Policy, error) { - var gPolicies []GormPolicy - err := st.db.Where("policy_id IN ?", policyIds).Find(&gPolicies).Error + var dBPolicies []DbPolicy + err := st.db.Find(&dBPolicies, policyIds).Error if err != nil { return nil, err } - var policyPtrs []*Policy - for _, policy := range gPolicies { - var p Policy - json.Unmarshal(policy.Policy, &p) - policyPtrs = append(policyPtrs, &p) + + policies := make([]*Policy, len(dBPolicies)) + for i, dbp := range dBPolicies { + actions := dbp.Actions + var policyActions []PolicyAction + for _, action := range actions { + policyActions = append(policyActions, PolicyAction(action)) + } + + policies[i] = &Policy{ + PolicyId: dbp.ID, + Effect: PolicyEffect(dbp.Effect), + Actions: policyActions, + Resources: dbp.Resources, + } } - return policyPtrs, err + return policies, nil } func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { - pol, err := json.Marshal(p) - if err != nil { - return "", err + actionStrings := make([]string, len(p.Actions)) + for i, action := range p.Actions { + actionStrings[i] = string(action) } - gormPol := GormPolicy{PolicyId: p.PolicyId, Policy: datatypes.JSON(pol)} - err = st.db.Create(&gormPol).Error + dbPolicy := DbPolicy{ID: p.PolicyId, Effect: string(p.Effect), Actions: actionStrings, Resources: p.Resources} + err := st.db.Create(&dbPolicy).Error if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { return "", &ConflictError{p.PolicyId} @@ -250,7 +265,7 @@ func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { } func (st SqlStore) DeletePolicy(ctx context.Context, policyId string) error { - gp := GormPolicy{PolicyId: policyId} + gp := DbPolicy{ID: policyId} return st.db.Delete(gp).Error } @@ -270,5 +285,10 @@ func (st SqlStore) GetTokenValue(ctx context.Context, tokenId string) (string, e } func (st SqlStore) Flush(ctx context.Context) error { - return st.db.Exec("DELETE FROM gorm_collections;DELETE FROM gorm_records; DELETE FROM gorm_principals; DELETE FROM gorm_policies;").Error + tables := []string{} + st.db.Raw("SELECT tablename FROM pg_tables WHERE schemaname='public'").Scan(&tables) + for _, table := range tables { + st.db.Exec("DELETE FROM " + table) + } + return nil } diff --git a/vault/vault.go b/vault/vault.go index 37a6b19..3df7bb9 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -16,7 +16,7 @@ type Field struct { } type Collection struct { - Name string `redis:"name" gorm:"primaryKey"` + Name string `redis:"name"` Fields map[string]Field `redis:"fields"` } @@ -30,29 +30,6 @@ type Principal struct { Policies []string `redis:"policies"` } -type VaultDB interface { - GetCollection(ctx context.Context, name string) (*Collection, error) - GetCollections(ctx context.Context) ([]string, error) - CreateCollection(ctx context.Context, c Collection) (string, error) - DeleteCollection(ctx context.Context, name string) error - CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) - GetRecords(ctx context.Context, collectionName string, recordIDs []string) (map[string]*Record, error) - GetRecordsFilter(ctx context.Context, collectionName string, fieldName string, value string) ([]string, error) - UpdateRecord(ctx context.Context, collectionName string, recordID string, record Record) error - DeleteRecord(ctx context.Context, collectionName string, recordID string) error - GetPrincipal(ctx context.Context, username string) (*Principal, error) - CreatePrincipal(ctx context.Context, principal Principal) error - DeletePrincipal(ctx context.Context, username string) error - GetPolicy(ctx context.Context, policyId string) (*Policy, error) - GetPolicies(ctx context.Context, policyIds []string) ([]*Policy, error) - CreatePolicy(ctx context.Context, p Policy) (string, error) - DeletePolicy(ctx context.Context, policyId string) error - CreateToken(ctx context.Context, tokenId string, value string) error - DeleteToken(ctx context.Context, tokenId string) error - GetTokenValue(ctx context.Context, tokenId string) (string, error) - Flush(ctx context.Context) error -} - type Privatiser interface { Encrypt(string) (string, error) Decrypt(string) (string, error) @@ -114,6 +91,29 @@ const ( FIELDS_PPATH = "/fields" ) +type VaultDB interface { + GetCollection(ctx context.Context, name string) (*Collection, error) + GetCollections(ctx context.Context) ([]string, error) + CreateCollection(ctx context.Context, c Collection) (string, error) + DeleteCollection(ctx context.Context, name string) error + CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) + GetRecords(ctx context.Context, collectionName string, recordIDs []string) (map[string]*Record, error) + GetRecordsFilter(ctx context.Context, collectionName string, fieldName string, value string) ([]string, error) + UpdateRecord(ctx context.Context, collectionName string, recordID string, record Record) error + DeleteRecord(ctx context.Context, collectionName string, recordID string) error + GetPrincipal(ctx context.Context, username string) (*Principal, error) + CreatePrincipal(ctx context.Context, principal Principal) error + DeletePrincipal(ctx context.Context, username string) error + GetPolicy(ctx context.Context, policyId string) (*Policy, error) + GetPolicies(ctx context.Context, policyIds []string) ([]*Policy, error) + CreatePolicy(ctx context.Context, p Policy) (string, error) + DeletePolicy(ctx context.Context, policyId string) error + CreateToken(ctx context.Context, tokenId string, value string) error + DeleteToken(ctx context.Context, tokenId string) error + GetTokenValue(ctx context.Context, tokenId string) (string, error) + Flush(ctx context.Context) error +} + func (vault Vault) GetCollection( ctx context.Context, principal Principal, From 8560b27f93b3e759ae1ef460647f5e19cf1aa80d Mon Sep 17 00:00:00 2001 From: Subrose Date: Wed, 22 Nov 2023 17:19:26 +0000 Subject: [PATCH 03/13] fix connection string --- api/core.go | 16 ++------- api/policies.go | 4 ++- api/testing_utils.go | 7 +--- conf/dev.conf.toml | 7 +--- conf/test.conf.toml | 7 +--- docker-compose.ci.yml | 5 +-- docker-compose.yml | 5 +-- vault/redis.go | 2 +- vault/sql.go | 82 +++++++++++++++++++++++-------------------- vault/vault.go | 24 ++++++------- vault/vault_test.go | 6 +--- 11 files changed, 68 insertions(+), 97 deletions(-) diff --git a/api/core.go b/api/core.go index 8f6347b..85ecdb5 100644 --- a/api/core.go +++ b/api/core.go @@ -14,12 +14,7 @@ import ( // CoreConfig is used to parameterize a core type CoreConfig struct { - DB_HOST string - DB_PORT int - DB_USER string - DB_PASSWORD string - DB_DB int - DB_NAME string + DATABASE_URL string VAULT_ENCRYPTION_KEY string VAULT_ENCRYPTION_SECRET string VAULT_SIGNING_KEY string @@ -65,12 +60,7 @@ func ReadConfigs(configPath string) (*CoreConfig, error) { } // Inject - conf.DB_HOST = Config.String("db_host") - conf.DB_PORT = Config.Int("db_port") - conf.DB_USER = Config.String("db_user") - conf.DB_PASSWORD = Config.String("db_password") - conf.DB_DB = Config.Int("db_db") - conf.DB_NAME = Config.String("db_name") + conf.DATABASE_URL = Config.String("database_url") conf.VAULT_ENCRYPTION_KEY = Config.String("encryption_key") conf.VAULT_ENCRYPTION_SECRET = Config.String("encryption_secret") conf.VAULT_ADMIN_USERNAME = Config.String("admin_access_key") @@ -113,7 +103,7 @@ func CreateCore(conf *CoreConfig) (*Core, error) { // conf.DB_PASSWORD, // conf.DB_DB, // ) - db, err := _vault.NewSqlStore(_vault.FormatDsn(conf.DB_HOST, conf.DB_USER, conf.DB_PASSWORD, conf.DB_NAME, conf.DB_PORT)) + db, err := _vault.NewSqlStore(conf.DATABASE_URL) if err != nil { panic(err) } diff --git a/api/policies.go b/api/policies.go index 77b2c3b..fdcfa5b 100644 --- a/api/policies.go +++ b/api/policies.go @@ -9,7 +9,7 @@ import ( func (core *Core) GetPolicies(c *fiber.Ctx) error { sessionPrincipal := GetSessionPrincipal(c) - policies, err := core.vault.GetPolicies(c.Context(), sessionPrincipal) + policies, err := core.vault.GetPrincipalPolicies(c.Context(), sessionPrincipal) if err != nil { return err } @@ -56,3 +56,5 @@ func (core *Core) DeletePolicy(c *fiber.Ctx) error { } return c.SendStatus(http.StatusNoContent) } + +// Note: You cannot update a policy diff --git a/api/testing_utils.go b/api/testing_utils.go index 82bd1ec..248281d 100644 --- a/api/testing_utils.go +++ b/api/testing_utils.go @@ -43,12 +43,7 @@ func InitTestingVault(t *testing.T) (*fiber.App, *Core) { // coreConfig.DB_PASSWORD, // coreConfig.DB_DB, // ) - db, err := _vault.NewSqlStore(_vault.FormatDsn( - coreConfig.DB_HOST, - coreConfig.DB_USER, - coreConfig.DB_PASSWORD, - coreConfig.DB_NAME, - coreConfig.DB_PORT)) + db, err := _vault.NewSqlStore(coreConfig.DATABASE_URL) if err != nil { t.Fatal("Failed to create db", err) diff --git a/conf/dev.conf.toml b/conf/dev.conf.toml index dcb83b8..72a8d9f 100644 --- a/conf/dev.conf.toml +++ b/conf/dev.conf.toml @@ -1,10 +1,5 @@ [db] -host = "localhost" -port = 5432 -name = "postgres" -user = "postgres" -password = "postgres" -db = 0 +url = "postgres://postgres:postgres@postgres:5432/postgres?sslmode=disable" [system] env_prefix = "VAULT_" diff --git a/conf/test.conf.toml b/conf/test.conf.toml index a4cf3ea..a058c05 100644 --- a/conf/test.conf.toml +++ b/conf/test.conf.toml @@ -1,10 +1,5 @@ [db] -host = "localhost" -port = 5432 -name = "postgres" -user = "postgres" -password = "postgres" -db = 0 +url = "postgres://postgres:postgres@postgres:5432/postgres?sslmode=disable" [system] env_prefix = "VAULT_" diff --git a/docker-compose.ci.yml b/docker-compose.ci.yml index c9fb3c8..8a75105 100644 --- a/docker-compose.ci.yml +++ b/docker-compose.ci.yml @@ -20,10 +20,7 @@ services: api: profiles: ["simulations"] environment: - - VAULT_DB_HOST=postgres - - VAULT_DB_PORT=5432 - - VAULT_DB_USER=postgres - - VAULT_DB_PASSWORD=postgres + - VAULT_DATABASE_URL=postgres://postgres:postgres@postgres:5432/postgres?sslmode=disable - VAULT_API_HOST=0.0.0.0 - VAULT_API_PORT=3001 - VAULT_ADMIN_USERNAME=admin diff --git a/docker-compose.yml b/docker-compose.yml index fff46f6..dfbd1bc 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,10 +2,7 @@ version: "3.8" services: api: environment: - - VAULT_DB_HOST=postgres - - VAULT_DB_PORT=5432 - - VAULT_DB_USER=postgres - - VAULT_DB_PASSWORD=postgres + - VAULT_DATABASE_URL=postgres://postgres:postgres@postgres:5432/postgres?sslmode=disable - VAULT_API_HOST=0.0.0.0 - VAULT_API_PORT=3001 - VAULT_ADMIN_USERNAME=admin diff --git a/vault/redis.go b/vault/redis.go index 81f1f93..60c1ccd 100644 --- a/vault/redis.go +++ b/vault/redis.go @@ -454,7 +454,7 @@ func (rs RedisStore) GetPolicy(ctx context.Context, policyId string) (*Policy, e return rawPolicy.toPolicy(), nil } -func (rs RedisStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Policy, error) { +func (rs RedisStore) GetPoliciesById(ctx context.Context, policyIds []string) ([]*Policy, error) { policies := []*Policy{} pipeline := rs.Client.Pipeline() diff --git a/vault/sql.go b/vault/sql.go index 9ec8550..07fa459 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -11,7 +11,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "time" "github.com/lib/pq" @@ -24,20 +23,24 @@ type SqlStore struct { db *gorm.DB } -type GormRecord struct { +type DbRecord struct { Id string CollectionName string Record datatypes.JSON } -type GormCollection struct { +type DbCollection struct { Name string `gorm:"primaryKey"` Collection datatypes.JSON } -type GormPrincipal struct { - Username string `gorm:"primaryKey"` - Principal datatypes.JSON +type DbPrincipal struct { + Username string `gorm:"primaryKey"` + Password string + Description string + CreatedAt time.Time + UpdatedAt time.Time + PolicyIds pq.StringArray `gorm:"type:text[]"` } type DbPolicy struct { @@ -49,27 +52,22 @@ type DbPolicy struct { UpdatedAt time.Time } -type GormToken struct { - TokenId string `gorm:"primaryKey"` - Value string -} - -func FormatDsn(host string, user string, password string, dbName string, port int) string { - // TODO: Add sslmode - return fmt.Sprintf("host=%v user=%v password=%v dbname=%v port=%v sslmode=disable", host, user, password, dbName, port) +type DbToken struct { + ID string `gorm:"primaryKey"` + Value string } func NewSqlStore(dsn string) (*SqlStore, error) { db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ TranslateError: true, }) - db.AutoMigrate(&GormCollection{}, &GormRecord{}, &GormPrincipal{}, &DbPolicy{}, &GormToken{}) + db.AutoMigrate(&DbCollection{}, &DbRecord{}, &DbPrincipal{}, &DbPolicy{}, &DbToken{}) return &SqlStore{db}, err } func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection, error) { - var gc GormCollection + var gc DbCollection err := st.db.First(&gc, "name = ?", name).Error if err != nil { return nil, err @@ -81,7 +79,7 @@ func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection, } func (st SqlStore) GetCollections(ctx context.Context) ([]string, error) { - var gcs []GormCollection + var gcs []DbCollection err := st.db.Find(&gcs).Error @@ -98,7 +96,7 @@ func (st SqlStore) CreateCollection(ctx context.Context, c Collection) (string, return "", err } - gormCol := GormCollection{Name: c.Name, Collection: datatypes.JSON(b)} + gormCol := DbCollection{Name: c.Name, Collection: datatypes.JSON(b)} result := st.db.Create(&gormCol) if result.Error != nil { switch result.Error { @@ -111,14 +109,14 @@ func (st SqlStore) CreateCollection(ctx context.Context, c Collection) (string, } func (st SqlStore) DeleteCollection(ctx context.Context, name string) error { - gc := GormCollection{Name: name, Collection: nil} + gc := DbCollection{Name: name, Collection: nil} return st.db.Delete(&gc).Error } func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) { recordIds := make([]string, len(records)) - gormRecords := make([]GormRecord, len(records)) + gormRecords := make([]DbRecord, len(records)) for i, record := range records { recordId := GenerateId() @@ -126,7 +124,7 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec if err != nil { return nil, err } - gormRecords[i] = GormRecord{Id: recordId, CollectionName: collectionName, Record: datatypes.JSON(jsonBytes)} + gormRecords[i] = DbRecord{Id: recordId, CollectionName: collectionName, Record: datatypes.JSON(jsonBytes)} recordIds[i] = recordId } err := st.db.CreateInBatches(&gormRecords, len(records)).Error @@ -137,7 +135,7 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec } func (st SqlStore) GetRecords(ctx context.Context, collectionName string, recordIDs []string) (map[string]*Record, error) { - var grs []GormRecord + var grs []DbRecord err := st.db.Where("id IN ?", recordIDs).Find(&grs).Error if err != nil { return nil, err @@ -165,18 +163,18 @@ func (st SqlStore) UpdateRecord(ctx context.Context, collectionName string, reco if err != nil { return err } - gr := GormRecord{Id: recordID, CollectionName: collectionName, Record: datatypes.JSON(r)} - return st.db.Model(&GormRecord{}).Where("id = ?", recordID).Updates(gr).Error + gr := DbRecord{Id: recordID, CollectionName: collectionName, Record: datatypes.JSON(r)} + return st.db.Model(&DbRecord{}).Where("id = ?", recordID).Updates(gr).Error } func (st SqlStore) DeleteRecord(ctx context.Context, collectionName string, recordID string) error { - gr := GormRecord{Id: recordID, CollectionName: collectionName} + gr := DbRecord{Id: recordID, CollectionName: collectionName} return st.db.Delete(&gr).Error } func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principal, error) { - var gPrincipal GormPrincipal - err := st.db.First(&gPrincipal, "username = ?", username).Error + var dbPrincipal DbPrincipal + err := st.db.First(&dbPrincipal, "username = ?", username).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, &NotFoundError{"principal", username} @@ -184,18 +182,24 @@ func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principa return nil, err } - var principal Principal - err = json.Unmarshal(gPrincipal.Principal, &principal) + principal := Principal{ + Username: dbPrincipal.Username, + Password: dbPrincipal.Password, + Description: dbPrincipal.Description, + Policies: dbPrincipal.PolicyIds, + } + return &principal, err } func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) error { - p, err := json.Marshal(principal) - if err != nil { - return err + dbPrincipal := DbPrincipal{ + Username: principal.Username, + Password: principal.Password, + Description: principal.Description, + PolicyIds: principal.Policies, } - gPrincipal := GormPrincipal{Username: principal.Username, Principal: datatypes.JSON(p)} - err = st.db.Create(&gPrincipal).Error + err := st.db.Create(&dbPrincipal).Error if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { return &ConflictError{principal.Username} @@ -206,7 +210,7 @@ func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) err } func (st SqlStore) DeletePrincipal(ctx context.Context, username string) error { - return st.db.Delete(&GormPrincipal{}, "username = ?", username).Error + return st.db.Delete(&DbPrincipal{}, "username = ?", username).Error } func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, error) { @@ -223,7 +227,7 @@ func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, err return &p, err } -func (st SqlStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Policy, error) { +func (st SqlStore) GetPoliciesById(ctx context.Context, policyIds []string) ([]*Policy, error) { var dBPolicies []DbPolicy err := st.db.Find(&dBPolicies, policyIds).Error if err != nil { @@ -270,16 +274,16 @@ func (st SqlStore) DeletePolicy(ctx context.Context, policyId string) error { } func (st SqlStore) CreateToken(ctx context.Context, tokenId string, value string) error { - gt := GormToken{TokenId: tokenId, Value: value} + gt := DbToken{ID: tokenId, Value: value} return st.db.Create(>).Error } func (st SqlStore) DeleteToken(ctx context.Context, tokenId string) error { - gt := GormToken{TokenId: tokenId} + gt := DbToken{ID: tokenId} return st.db.Delete(>).Error } func (st SqlStore) GetTokenValue(ctx context.Context, tokenId string) (string, error) { - var gt GormToken + var gt DbToken err := st.db.First(>, "token_id = ?", tokenId).Error return gt.Value, err } diff --git a/vault/vault.go b/vault/vault.go index 3df7bb9..8bb2e4c 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -22,14 +22,6 @@ type Collection struct { type Record map[string]string // field name -> value -type Principal struct { - Username string `redis:"username"` - Password string `redis:"password"` - Description string `redis:"description"` - CreatedAt string `redis:"created_at"` - Policies []string `redis:"policies"` -} - type Privatiser interface { Encrypt(string) (string, error) Decrypt(string) (string, error) @@ -69,6 +61,14 @@ type Policy struct { Resources []string `redis:"resources" json:"resources" validate:"required"` } +type Principal struct { + Username string `redis:"username"` + Password string `redis:"password"` + Description string `redis:"description"` + CreatedAt string `redis:"created_at"` + Policies []string `redis:"policies"` +} + type Request struct { Principal Principal Action PolicyAction @@ -105,7 +105,7 @@ type VaultDB interface { CreatePrincipal(ctx context.Context, principal Principal) error DeletePrincipal(ctx context.Context, username string) error GetPolicy(ctx context.Context, policyId string) (*Policy, error) - GetPolicies(ctx context.Context, policyIds []string) ([]*Policy, error) + GetPoliciesById(ctx context.Context, policyIds []string) ([]*Policy, error) CreatePolicy(ctx context.Context, p Policy) (string, error) DeletePolicy(ctx context.Context, policyId string) error CreateToken(ctx context.Context, tokenId string, value string) error @@ -549,7 +549,7 @@ func (vault Vault) DeletePolicy( return nil } -func (vault Vault) GetPolicies( +func (vault Vault) GetPrincipalPolicies( ctx context.Context, principal Principal, ) ([]*Policy, error) { @@ -562,7 +562,7 @@ func (vault Vault) GetPolicies( return nil, &ForbiddenError{request} } - policies, err := vault.Db.GetPolicies(ctx, principal.Policies) + policies, err := vault.Db.GetPoliciesById(ctx, principal.Policies) if err != nil { return nil, err } @@ -573,7 +573,7 @@ func (vault Vault) ValidateAction( ctx context.Context, request Request, ) (bool, error) { - policies, err := vault.Db.GetPolicies(ctx, request.Principal.Policies) + policies, err := vault.Db.GetPoliciesById(ctx, request.Principal.Policies) if err != nil { return false, err } diff --git a/vault/vault_test.go b/vault/vault_test.go index b7ba050..9664198 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -12,11 +12,7 @@ import ( func initVault(t *testing.T) (Vault, VaultDB, Privatiser) { ctx := context.Background() - db, err := NewRedisStore( - os.Getenv("KEYDB_CONN_STRING"), - "", - 0, - ) + db, err := NewSqlStore(os.Getenv("DATABASE_URL")) if err != nil { panic(err) } From 7648eacc278a20a3d34c27f2e402f25acba2a48b Mon Sep 17 00:00:00 2001 From: Subrose Date: Wed, 22 Nov 2023 17:54:52 +0000 Subject: [PATCH 04/13] tests passing --- docker-compose.ci.yml | 12 +- keydb/redis.conf | 1054 ----------------------------------------- makefile | 2 +- vault/redis.go | 551 --------------------- vault/redis_test.go | 272 ----------- vault/sql.go | 48 +- vault/vault_test.go | 100 ++-- 7 files changed, 98 insertions(+), 1941 deletions(-) delete mode 100644 keydb/redis.conf delete mode 100644 vault/redis.go delete mode 100644 vault/redis_test.go diff --git a/docker-compose.ci.yml b/docker-compose.ci.yml index 8a75105..2cf542d 100644 --- a/docker-compose.ci.yml +++ b/docker-compose.ci.yml @@ -34,7 +34,7 @@ services: ports: - 3001:3001 depends_on: - - keydb + - postgres volumes: - "./:/app" simulations: @@ -48,7 +48,7 @@ services: working_dir: /code depends_on: - api - - keydb + - postgres command: sh -c "cd simulator && ./simulate.sh" volumes: - "./:/code" @@ -59,14 +59,10 @@ services: dockerfile: Dockerfile target: build environment: - - KEYDB_CONN_STRING=keydb:6379 - - VAULT_DB_HOST=keydb - - VAULT_DB_PORT=6379 - - VAULT_DB_USER=default - - VAULT_DB_PASSWORD= + - VAULT_DATABASE_URL=postgres://postgres:postgres@postgres:5432/postgres?sslmode=disable working_dir: /code depends_on: - - keydb + - postgres command: sh -c "go test ./vault && go test ./api" volumes: - "./:/code" diff --git a/keydb/redis.conf b/keydb/redis.conf deleted file mode 100644 index 9bc2f69..0000000 --- a/keydb/redis.conf +++ /dev/null @@ -1,1054 +0,0 @@ -# Redis configuration file example. -# -# Note that in order to read the configuration file, Redis must be -# started with the file path as first argument: -# -# ./redis-server /path/to/redis.conf - -# Note on units: when memory size is needed, it is possible to specify -# it in the usual form of 1k 5GB 4M and so forth: -# -# 1k => 1000 bytes -# 1kb => 1024 bytes -# 1m => 1000000 bytes -# 1mb => 1024*1024 bytes -# 1g => 1000000000 bytes -# 1gb => 1024*1024*1024 bytes -# -# units are case insensitive so 1GB 1Gb 1gB are all the same. - -################################## INCLUDES ################################### - -# Include one or more other config files here. This is useful if you -# have a standard template that goes to all Redis servers but also need -# to customize a few per-server settings. Include files can include -# other files, so use this wisely. -# -# Notice option "include" won't be rewritten by command "CONFIG REWRITE" -# from admin or Redis Sentinel. Since Redis always uses the last processed -# line as value of a configuration directive, you'd better put includes -# at the beginning of this file to avoid overwriting config change at runtime. -# -# If instead you are interested in using includes to override configuration -# options, it is better to use include as the last line. -# -# include /path/to/local.conf -# include /path/to/other.conf - -################################## NETWORK ##################################### - -# By default, if no "bind" configuration directive is specified, Redis listens -# for connections from all the network interfaces available on the server. -# It is possible to listen to just one or multiple selected interfaces using -# the "bind" configuration directive, followed by one or more IP addresses. -# -# Examples: -# -# bind 192.168.1.100 10.0.0.1 -# bind 127.0.0.1 ::1 -# -# ~~~ WARNING ~~~ If the computer running Redis is directly exposed to the -# internet, binding to all the interfaces is dangerous and will expose the -# instance to everybody on the internet. So by default we uncomment the -# following bind directive, that will force Redis to listen only into -# the IPv4 lookback interface address (this means Redis will be able to -# accept connections only from clients running into the same computer it -# is running). -# -# IF YOU ARE SURE YOU WANT YOUR INSTANCE TO LISTEN TO ALL THE INTERFACES -# JUST COMMENT THE FOLLOWING LINE. -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# bind 127.0.0.1 - -# Protected mode is a layer of security protection, in order to avoid that -# Redis instances left open on the internet are accessed and exploited. -# -# When protected mode is on and if: -# -# 1) The server is not binding explicitly to a set of addresses using the -# "bind" directive. -# 2) No password is configured. -# -# The server only accepts connections from clients connecting from the -# IPv4 and IPv6 loopback addresses 127.0.0.1 and ::1, and from Unix domain -# sockets. -# -# By default protected mode is enabled. You should disable it only if -# you are sure you want clients from other hosts to connect to Redis -# even if no authentication is configured, nor a specific set of interfaces -# are explicitly listed using the "bind" directive. -protected-mode no - -# Accept connections on the specified port, default is 6379 (IANA #815344). -# If port 0 is specified Redis will not listen on a TCP socket. -port 6379 - -# TCP listen() backlog. -# -# In high requests-per-second environments you need an high backlog in order -# to avoid slow clients connections issues. Note that the Linux kernel -# will silently truncate it to the value of /proc/sys/net/core/somaxconn so -# make sure to raise both the value of somaxconn and tcp_max_syn_backlog -# in order to get the desired effect. -tcp-backlog 511 - -# Unix socket. -# -# Specify the path for the Unix socket that will be used to listen for -# incoming connections. There is no default, so Redis will not listen -# on a unix socket when not specified. -# -# unixsocket /tmp/redis.sock -# unixsocketperm 700 - -# Close the connection after a client is idle for N seconds (0 to disable) -timeout 0 - -# TCP keepalive. -# -# If non-zero, use SO_KEEPALIVE to send TCP ACKs to clients in absence -# of communication. This is useful for two reasons: -# -# 1) Detect dead peers. -# 2) Take the connection alive from the point of view of network -# equipment in the middle. -# -# On Linux, the specified value (in seconds) is the period used to send ACKs. -# Note that to close the connection the double of the time is needed. -# On other kernels the period depends on the kernel configuration. -# -# A reasonable value for this option is 300 seconds, which is the new -# Redis default starting with Redis 3.2.1. -tcp-keepalive 300 - -################################# GENERAL ##################################### - -# By default Redis does not run as a daemon. Use 'yes' if you need it. -# Note that Redis will write a pid file in /var/run/redis.pid when daemonized. -daemonize no - -# If you run Redis from upstart or systemd, Redis can interact with your -# supervision tree. Options: -# supervised no - no supervision interaction -# supervised upstart - signal upstart by putting Redis into SIGSTOP mode -# supervised systemd - signal systemd by writing READY=1 to $NOTIFY_SOCKET -# supervised auto - detect upstart or systemd method based on -# UPSTART_JOB or NOTIFY_SOCKET environment variables -# Note: these supervision methods only signal "process is ready." -# They do not enable continuous liveness pings back to your supervisor. -supervised no - -# If a pid file is specified, Redis writes it where specified at startup -# and removes it at exit. -# -# When the server runs non daemonized, no pid file is created if none is -# specified in the configuration. When the server is daemonized, the pid file -# is used even if not specified, defaulting to "/var/run/redis.pid". -# -# Creating a pid file is best effort: if Redis is not able to create it -# nothing bad happens, the server will start and run normally. -pidfile /var/run/redis_6379.pid - -# Specify the server verbosity level. -# This can be one of: -# debug (a lot of information, useful for development/testing) -# verbose (many rarely useful info, but not a mess like the debug level) -# notice (moderately verbose, what you want in production probably) -# warning (only very important / critical messages are logged) -loglevel notice - -# Specify the log file name. Also the empty string can be used to force -# Redis to log on the standard output. Note that if you use standard -# output for logging but daemonize, logs will be sent to /dev/null -logfile "" - -# To enable logging to the system logger, just set 'syslog-enabled' to yes, -# and optionally update the other syslog parameters to suit your needs. -# syslog-enabled no - -# Specify the syslog identity. -# syslog-ident redis - -# Specify the syslog facility. Must be USER or between LOCAL0-LOCAL7. -# syslog-facility local0 - -# Set the number of databases. The default database is DB 0, you can select -# a different one on a per-connection basis using SELECT where -# dbid is a number between 0 and 'databases'-1 -databases 16 - -################################ SNAPSHOTTING ################################ -# -# Save the DB on disk: -# -# save -# -# Will save the DB if both the given number of seconds and the given -# number of write operations against the DB occurred. -# -# In the example below the behaviour will be to save: -# after 900 sec (15 min) if at least 1 key changed -# after 300 sec (5 min) if at least 10 keys changed -# after 60 sec if at least 10000 keys changed -# -# Note: you can disable saving completely by commenting out all "save" lines. -# -# It is also possible to remove all the previously configured save -# points by adding a save directive with a single empty string argument -# like in the following example: -# -# save "" - -save 900 1 -save 300 10 -save 60 10000 - -# By default Redis will stop accepting writes if RDB snapshots are enabled -# (at least one save point) and the latest background save failed. -# This will make the user aware (in a hard way) that data is not persisting -# on disk properly, otherwise chances are that no one will notice and some -# disaster will happen. -# -# If the background saving process will start working again Redis will -# automatically allow writes again. -# -# However if you have setup your proper monitoring of the Redis server -# and persistence, you may want to disable this feature so that Redis will -# continue to work as usual even if there are problems with disk, -# permissions, and so forth. -stop-writes-on-bgsave-error yes - -# Compress string objects using LZF when dump .rdb databases? -# For default that's set to 'yes' as it's almost always a win. -# If you want to save some CPU in the saving child set it to 'no' but -# the dataset will likely be bigger if you have compressible values or keys. -rdbcompression yes - -# Since version 5 of RDB a CRC64 checksum is placed at the end of the file. -# This makes the format more resistant to corruption but there is a performance -# hit to pay (around 10%) when saving and loading RDB files, so you can disable it -# for maximum performances. -# -# RDB files created with checksum disabled have a checksum of zero that will -# tell the loading code to skip the check. -rdbchecksum yes - -# The filename where to dump the DB -dbfilename dump.rdb - -# The working directory. -# -# The DB will be written inside this directory, with the filename specified -# above using the 'dbfilename' configuration directive. -# -# The Append Only File will also be created inside this directory. -# -# Note that you must specify a directory here, not a file name. -dir ./ - -################################# REPLICATION ################################# - -# Master-Slave replication. Use slaveof to make a Redis instance a copy of -# another Redis server. A few things to understand ASAP about Redis replication. -# -# 1) Redis replication is asynchronous, but you can configure a master to -# stop accepting writes if it appears to be not connected with at least -# a given number of slaves. -# 2) Redis slaves are able to perform a partial resynchronization with the -# master if the replication link is lost for a relatively small amount of -# time. You may want to configure the replication backlog size (see the next -# sections of this file) with a sensible value depending on your needs. -# 3) Replication is automatic and does not need user intervention. After a -# network partition slaves automatically try to reconnect to masters -# and resynchronize with them. -# -# slaveof - -# If the master is password protected (using the "requirepass" configuration -# directive below) it is possible to tell the slave to authenticate before -# starting the replication synchronization process, otherwise the master will -# refuse the slave request. -# -# masterauth - -# When a slave loses its connection with the master, or when the replication -# is still in progress, the slave can act in two different ways: -# -# 1) if slave-serve-stale-data is set to 'yes' (the default) the slave will -# still reply to client requests, possibly with out of date data, or the -# data set may just be empty if this is the first synchronization. -# -# 2) if slave-serve-stale-data is set to 'no' the slave will reply with -# an error "SYNC with master in progress" to all the kind of commands -# but to INFO and SLAVEOF. -# -slave-serve-stale-data yes - -# You can configure a slave instance to accept writes or not. Writing against -# a slave instance may be useful to store some ephemeral data (because data -# written on a slave will be easily deleted after resync with the master) but -# may also cause problems if clients are writing to it because of a -# misconfiguration. -# -# Since Redis 2.6 by default slaves are read-only. -# -# Note: read only slaves are not designed to be exposed to untrusted clients -# on the internet. It's just a protection layer against misuse of the instance. -# Still a read only slave exports by default all the administrative commands -# such as CONFIG, DEBUG, and so forth. To a limited extent you can improve -# security of read only slaves using 'rename-command' to shadow all the -# administrative / dangerous commands. -slave-read-only yes - -# Replication SYNC strategy: disk or socket. -# -# ------------------------------------------------------- -# WARNING: DISKLESS REPLICATION IS EXPERIMENTAL CURRENTLY -# ------------------------------------------------------- -# -# New slaves and reconnecting slaves that are not able to continue the replication -# process just receiving differences, need to do what is called a "full -# synchronization". An RDB file is transmitted from the master to the slaves. -# The transmission can happen in two different ways: -# -# 1) Disk-backed: The Redis master creates a new process that writes the RDB -# file on disk. Later the file is transferred by the parent -# process to the slaves incrementally. -# 2) Diskless: The Redis master creates a new process that directly writes the -# RDB file to slave sockets, without touching the disk at all. -# -# With disk-backed replication, while the RDB file is generated, more slaves -# can be queued and served with the RDB file as soon as the current child producing -# the RDB file finishes its work. With diskless replication instead once -# the transfer starts, new slaves arriving will be queued and a new transfer -# will start when the current one terminates. -# -# When diskless replication is used, the master waits a configurable amount of -# time (in seconds) before starting the transfer in the hope that multiple slaves -# will arrive and the transfer can be parallelized. -# -# With slow disks and fast (large bandwidth) networks, diskless replication -# works better. -repl-diskless-sync no - -# When diskless replication is enabled, it is possible to configure the delay -# the server waits in order to spawn the child that transfers the RDB via socket -# to the slaves. -# -# This is important since once the transfer starts, it is not possible to serve -# new slaves arriving, that will be queued for the next RDB transfer, so the server -# waits a delay in order to let more slaves arrive. -# -# The delay is specified in seconds, and by default is 5 seconds. To disable -# it entirely just set it to 0 seconds and the transfer will start ASAP. -repl-diskless-sync-delay 5 - -# Slaves send PINGs to server in a predefined interval. It's possible to change -# this interval with the repl_ping_slave_period option. The default value is 10 -# seconds. -# -# repl-ping-slave-period 10 - -# The following option sets the replication timeout for: -# -# 1) Bulk transfer I/O during SYNC, from the point of view of slave. -# 2) Master timeout from the point of view of slaves (data, pings). -# 3) Slave timeout from the point of view of masters (REPLCONF ACK pings). -# -# It is important to make sure that this value is greater than the value -# specified for repl-ping-slave-period otherwise a timeout will be detected -# every time there is low traffic between the master and the slave. -# -# repl-timeout 60 - -# Disable TCP_NODELAY on the slave socket after SYNC? -# -# If you select "yes" Redis will use a smaller number of TCP packets and -# less bandwidth to send data to slaves. But this can add a delay for -# the data to appear on the slave side, up to 40 milliseconds with -# Linux kernels using a default configuration. -# -# If you select "no" the delay for data to appear on the slave side will -# be reduced but more bandwidth will be used for replication. -# -# By default we optimize for low latency, but in very high traffic conditions -# or when the master and slaves are many hops away, turning this to "yes" may -# be a good idea. -repl-disable-tcp-nodelay no - -# Set the replication backlog size. The backlog is a buffer that accumulates -# slave data when slaves are disconnected for some time, so that when a slave -# wants to reconnect again, often a full resync is not needed, but a partial -# resync is enough, just passing the portion of data the slave missed while -# disconnected. -# -# The bigger the replication backlog, the longer the time the slave can be -# disconnected and later be able to perform a partial resynchronization. -# -# The backlog is only allocated once there is at least a slave connected. -# -# repl-backlog-size 1mb - -# After a master has no longer connected slaves for some time, the backlog -# will be freed. The following option configures the amount of seconds that -# need to elapse, starting from the time the last slave disconnected, for -# the backlog buffer to be freed. -# -# A value of 0 means to never release the backlog. -# -# repl-backlog-ttl 3600 - -# The slave priority is an integer number published by Redis in the INFO output. -# It is used by Redis Sentinel in order to select a slave to promote into a -# master if the master is no longer working correctly. -# -# A slave with a low priority number is considered better for promotion, so -# for instance if there are three slaves with priority 10, 100, 25 Sentinel will -# pick the one with priority 10, that is the lowest. -# -# However a special priority of 0 marks the slave as not able to perform the -# role of master, so a slave with priority of 0 will never be selected by -# Redis Sentinel for promotion. -# -# By default the priority is 100. -slave-priority 100 - -# It is possible for a master to stop accepting writes if there are less than -# N slaves connected, having a lag less or equal than M seconds. -# -# The N slaves need to be in "online" state. -# -# The lag in seconds, that must be <= the specified value, is calculated from -# the last ping received from the slave, that is usually sent every second. -# -# This option does not GUARANTEE that N replicas will accept the write, but -# will limit the window of exposure for lost writes in case not enough slaves -# are available, to the specified number of seconds. -# -# For example to require at least 3 slaves with a lag <= 10 seconds use: -# -# min-slaves-to-write 3 -# min-slaves-max-lag 10 -# -# Setting one or the other to 0 disables the feature. -# -# By default min-slaves-to-write is set to 0 (feature disabled) and -# min-slaves-max-lag is set to 10. - -# A Redis master is able to list the address and port of the attached -# slaves in different ways. For example the "INFO replication" section -# offers this information, which is used, among other tools, by -# Redis Sentinel in order to discover slave instances. -# Another place where this info is available is in the output of the -# "ROLE" command of a masteer. -# -# The listed IP and address normally reported by a slave is obtained -# in the following way: -# -# IP: The address is auto detected by checking the peer address -# of the socket used by the slave to connect with the master. -# -# Port: The port is communicated by the slave during the replication -# handshake, and is normally the port that the slave is using to -# list for connections. -# -# However when port forwarding or Network Address Translation (NAT) is -# used, the slave may be actually reachable via different IP and port -# pairs. The following two options can be used by a slave in order to -# report to its master a specific set of IP and port, so that both INFO -# and ROLE will report those values. -# -# There is no need to use both the options if you need to override just -# the port or the IP address. -# -# slave-announce-ip 5.5.5.5 -# slave-announce-port 1234 - -################################## SECURITY ################################### - -# Require clients to issue AUTH before processing any other -# commands. This might be useful in environments in which you do not trust -# others with access to the host running redis-server. -# -# This should stay commented out for backward compatibility and because most -# people do not need auth (e.g. they run their own servers). -# -# Warning: since Redis is pretty fast an outside user can try up to -# 150k passwords per second against a good box. This means that you should -# use a very strong password otherwise it will be very easy to break. -# -# requirepass foobared - -# Command renaming. -# -# It is possible to change the name of dangerous commands in a shared -# environment. For instance the CONFIG command may be renamed into something -# hard to guess so that it will still be available for internal-use tools -# but not available for general clients. -# -# Example: -# -# rename-command CONFIG b840fc02d524045429941cc15f59e41cb7be6c52 -# -# It is also possible to completely kill a command by renaming it into -# an empty string: -# -# rename-command CONFIG "" -# -# Please note that changing the name of commands that are logged into the -# AOF file or transmitted to slaves may cause problems. - -################################### LIMITS #################################### - -# Set the max number of connected clients at the same time. By default -# this limit is set to 10000 clients, however if the Redis server is not -# able to configure the process file limit to allow for the specified limit -# the max number of allowed clients is set to the current file limit -# minus 32 (as Redis reserves a few file descriptors for internal uses). -# -# Once the limit is reached Redis will close all the new connections sending -# an error 'max number of clients reached'. -# -# maxclients 10000 - -# Don't use more memory than the specified amount of bytes. -# When the memory limit is reached Redis will try to remove keys -# according to the eviction policy selected (see maxmemory-policy). -# -# If Redis can't remove keys according to the policy, or if the policy is -# set to 'noeviction', Redis will start to reply with errors to commands -# that would use more memory, like SET, LPUSH, and so on, and will continue -# to reply to read-only commands like GET. -# -# This option is usually useful when using Redis as an LRU cache, or to set -# a hard memory limit for an instance (using the 'noeviction' policy). -# -# WARNING: If you have slaves attached to an instance with maxmemory on, -# the size of the output buffers needed to feed the slaves are subtracted -# from the used memory count, so that network problems / resyncs will -# not trigger a loop where keys are evicted, and in turn the output -# buffer of slaves is full with DELs of keys evicted triggering the deletion -# of more keys, and so forth until the database is completely emptied. -# -# In short... if you have slaves attached it is suggested that you set a lower -# limit for maxmemory so that there is some free RAM on the system for slave -# output buffers (but this is not needed if the policy is 'noeviction'). -# -# maxmemory -maxmemory 1gb - -# MAXMEMORY POLICY: how Redis will select what to remove when maxmemory -# is reached. You can select among five behaviors: -# -# volatile-lru -> remove the key with an expire set using an LRU algorithm -# allkeys-lru -> remove any key according to the LRU algorithm -# volatile-random -> remove a random key with an expire set -# allkeys-random -> remove a random key, any key -# volatile-ttl -> remove the key with the nearest expire time (minor TTL) -# noeviction -> don't expire at all, just return an error on write operations -# -# Note: with any of the above policies, Redis will return an error on write -# operations, when there are no suitable keys for eviction. -# -# At the date of writing these commands are: set setnx setex append -# incr decr rpush lpush rpushx lpushx linsert lset rpoplpush sadd -# sinter sinterstore sunion sunionstore sdiff sdiffstore zadd zincrby -# zunionstore zinterstore hset hsetnx hmset hincrby incrby decrby -# getset mset msetnx exec sort -# -# The default is: -# -# maxmemory-policy noeviction -maxmemory-policy allkeys-lru - -# LRU and minimal TTL algorithms are not precise algorithms but approximated -# algorithms (in order to save memory), so you can tune it for speed or -# accuracy. For default Redis will check five keys and pick the one that was -# used less recently, you can change the sample size using the following -# configuration directive. -# -# The default of 5 produces good enough results. 10 Approximates very closely -# true LRU but costs a bit more CPU. 3 is very fast but not very accurate. -# -# maxmemory-samples 5 - -############################## APPEND ONLY MODE ############################### - -# By default Redis asynchronously dumps the dataset on disk. This mode is -# good enough in many applications, but an issue with the Redis process or -# a power outage may result into a few minutes of writes lost (depending on -# the configured save points). -# -# The Append Only File is an alternative persistence mode that provides -# much better durability. For instance using the default data fsync policy -# (see later in the config file) Redis can lose just one second of writes in a -# dramatic event like a server power outage, or a single write if something -# wrong with the Redis process itself happens, but the operating system is -# still running correctly. -# -# AOF and RDB persistence can be enabled at the same time without problems. -# If the AOF is enabled on startup Redis will load the AOF, that is the file -# with the better durability guarantees. -# -# Please check http://redis.io/topics/persistence for more information. - -appendonly no - -# The name of the append only file (default: "appendonly.aof") - -appendfilename "appendonly.aof" - -# The fsync() call tells the Operating System to actually write data on disk -# instead of waiting for more data in the output buffer. Some OS will really flush -# data on disk, some other OS will just try to do it ASAP. -# -# Redis supports three different modes: -# -# no: don't fsync, just let the OS flush the data when it wants. Faster. -# always: fsync after every write to the append only log. Slow, Safest. -# everysec: fsync only one time every second. Compromise. -# -# The default is "everysec", as that's usually the right compromise between -# speed and data safety. It's up to you to understand if you can relax this to -# "no" that will let the operating system flush the output buffer when -# it wants, for better performances (but if you can live with the idea of -# some data loss consider the default persistence mode that's snapshotting), -# or on the contrary, use "always" that's very slow but a bit safer than -# everysec. -# -# More details please check the following article: -# http://antirez.com/post/redis-persistence-demystified.html -# -# If unsure, use "everysec". - -# appendfsync always -appendfsync everysec -# appendfsync no - -# When the AOF fsync policy is set to always or everysec, and a background -# saving process (a background save or AOF log background rewriting) is -# performing a lot of I/O against the disk, in some Linux configurations -# Redis may block too long on the fsync() call. Note that there is no fix for -# this currently, as even performing fsync in a different thread will block -# our synchronous write(2) call. -# -# In order to mitigate this problem it's possible to use the following option -# that will prevent fsync() from being called in the main process while a -# BGSAVE or BGREWRITEAOF is in progress. -# -# This means that while another child is saving, the durability of Redis is -# the same as "appendfsync none". In practical terms, this means that it is -# possible to lose up to 30 seconds of log in the worst scenario (with the -# default Linux settings). -# -# If you have latency problems turn this to "yes". Otherwise leave it as -# "no" that is the safest pick from the point of view of durability. - -no-appendfsync-on-rewrite no - -# Automatic rewrite of the append only file. -# Redis is able to automatically rewrite the log file implicitly calling -# BGREWRITEAOF when the AOF log size grows by the specified percentage. -# -# This is how it works: Redis remembers the size of the AOF file after the -# latest rewrite (if no rewrite has happened since the restart, the size of -# the AOF at startup is used). -# -# This base size is compared to the current size. If the current size is -# bigger than the specified percentage, the rewrite is triggered. Also -# you need to specify a minimal size for the AOF file to be rewritten, this -# is useful to avoid rewriting the AOF file even if the percentage increase -# is reached but it is still pretty small. -# -# Specify a percentage of zero in order to disable the automatic AOF -# rewrite feature. - -auto-aof-rewrite-percentage 100 -auto-aof-rewrite-min-size 64mb - -# An AOF file may be found to be truncated at the end during the Redis -# startup process, when the AOF data gets loaded back into memory. -# This may happen when the system where Redis is running -# crashes, especially when an ext4 filesystem is mounted without the -# data=ordered option (however this can't happen when Redis itself -# crashes or aborts but the operating system still works correctly). -# -# Redis can either exit with an error when this happens, or load as much -# data as possible (the default now) and start if the AOF file is found -# to be truncated at the end. The following option controls this behavior. -# -# If aof-load-truncated is set to yes, a truncated AOF file is loaded and -# the Redis server starts emitting a log to inform the user of the event. -# Otherwise if the option is set to no, the server aborts with an error -# and refuses to start. When the option is set to no, the user requires -# to fix the AOF file using the "redis-check-aof" utility before to restart -# the server. -# -# Note that if the AOF file will be found to be corrupted in the middle -# the server will still exit with an error. This option only applies when -# Redis will try to read more data from the AOF file but not enough bytes -# will be found. -aof-load-truncated yes - -################################ LUA SCRIPTING ############################### - -# Max execution time of a Lua script in milliseconds. -# -# If the maximum execution time is reached Redis will log that a script is -# still in execution after the maximum allowed time and will start to -# reply to queries with an error. -# -# When a long running script exceeds the maximum execution time only the -# SCRIPT KILL and SHUTDOWN NOSAVE commands are available. The first can be -# used to stop a script that did not yet called write commands. The second -# is the only way to shut down the server in the case a write command was -# already issued by the script but the user doesn't want to wait for the natural -# termination of the script. -# -# Set it to 0 or a negative value for unlimited execution without warnings. -lua-time-limit 5000 - -################################ REDIS CLUSTER ############################### -# -# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -# WARNING EXPERIMENTAL: Redis Cluster is considered to be stable code, however -# in order to mark it as "mature" we need to wait for a non trivial percentage -# of users to deploy it in production. -# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -# -# Normal Redis instances can't be part of a Redis Cluster; only nodes that are -# started as cluster nodes can. In order to start a Redis instance as a -# cluster node enable the cluster support uncommenting the following: -# -# cluster-enabled yes - -# Every cluster node has a cluster configuration file. This file is not -# intended to be edited by hand. It is created and updated by Redis nodes. -# Every Redis Cluster node requires a different cluster configuration file. -# Make sure that instances running in the same system do not have -# overlapping cluster configuration file names. -# -# cluster-config-file nodes-6379.conf - -# Cluster node timeout is the amount of milliseconds a node must be unreachable -# for it to be considered in failure state. -# Most other internal time limits are multiple of the node timeout. -# -# cluster-node-timeout 15000 - -# A slave of a failing master will avoid to start a failover if its data -# looks too old. -# -# There is no simple way for a slave to actually have a exact measure of -# its "data age", so the following two checks are performed: -# -# 1) If there are multiple slaves able to failover, they exchange messages -# in order to try to give an advantage to the slave with the best -# replication offset (more data from the master processed). -# Slaves will try to get their rank by offset, and apply to the start -# of the failover a delay proportional to their rank. -# -# 2) Every single slave computes the time of the last interaction with -# its master. This can be the last ping or command received (if the master -# is still in the "connected" state), or the time that elapsed since the -# disconnection with the master (if the replication link is currently down). -# If the last interaction is too old, the slave will not try to failover -# at all. -# -# The point "2" can be tuned by user. Specifically a slave will not perform -# the failover if, since the last interaction with the master, the time -# elapsed is greater than: -# -# (node-timeout * slave-validity-factor) + repl-ping-slave-period -# -# So for example if node-timeout is 30 seconds, and the slave-validity-factor -# is 10, and assuming a default repl-ping-slave-period of 10 seconds, the -# slave will not try to failover if it was not able to talk with the master -# for longer than 310 seconds. -# -# A large slave-validity-factor may allow slaves with too old data to failover -# a master, while a too small value may prevent the cluster from being able to -# elect a slave at all. -# -# For maximum availability, it is possible to set the slave-validity-factor -# to a value of 0, which means, that slaves will always try to failover the -# master regardless of the last time they interacted with the master. -# (However they'll always try to apply a delay proportional to their -# offset rank). -# -# Zero is the only value able to guarantee that when all the partitions heal -# the cluster will always be able to continue. -# -# cluster-slave-validity-factor 10 - -# Cluster slaves are able to migrate to orphaned masters, that are masters -# that are left without working slaves. This improves the cluster ability -# to resist to failures as otherwise an orphaned master can't be failed over -# in case of failure if it has no working slaves. -# -# Slaves migrate to orphaned masters only if there are still at least a -# given number of other working slaves for their old master. This number -# is the "migration barrier". A migration barrier of 1 means that a slave -# will migrate only if there is at least 1 other working slave for its master -# and so forth. It usually reflects the number of slaves you want for every -# master in your cluster. -# -# Default is 1 (slaves migrate only if their masters remain with at least -# one slave). To disable migration just set it to a very large value. -# A value of 0 can be set but is useful only for debugging and dangerous -# in production. -# -# cluster-migration-barrier 1 - -# By default Redis Cluster nodes stop accepting queries if they detect there -# is at least an hash slot uncovered (no available node is serving it). -# This way if the cluster is partially down (for example a range of hash slots -# are no longer covered) all the cluster becomes, eventually, unavailable. -# It automatically returns available as soon as all the slots are covered again. -# -# However sometimes you want the subset of the cluster which is working, -# to continue to accept queries for the part of the key space that is still -# covered. In order to do so, just set the cluster-require-full-coverage -# option to no. -# -# cluster-require-full-coverage yes - -# In order to setup your cluster make sure to read the documentation -# available at http://redis.io web site. - -################################## SLOW LOG ################################### - -# The Redis Slow Log is a system to log queries that exceeded a specified -# execution time. The execution time does not include the I/O operations -# like talking with the client, sending the reply and so forth, -# but just the time needed to actually execute the command (this is the only -# stage of command execution where the thread is blocked and can not serve -# other requests in the meantime). -# -# You can configure the slow log with two parameters: one tells Redis -# what is the execution time, in microseconds, to exceed in order for the -# command to get logged, and the other parameter is the length of the -# slow log. When a new command is logged the oldest one is removed from the -# queue of logged commands. - -# The following time is expressed in microseconds, so 1000000 is equivalent -# to one second. Note that a negative number disables the slow log, while -# a value of zero forces the logging of every command. -slowlog-log-slower-than 10000 - -# There is no limit to this length. Just be aware that it will consume memory. -# You can reclaim memory used by the slow log with SLOWLOG RESET. -slowlog-max-len 128 - -################################ LATENCY MONITOR ############################## - -# The Redis latency monitoring subsystem samples different operations -# at runtime in order to collect data related to possible sources of -# latency of a Redis instance. -# -# Via the LATENCY command this information is available to the user that can -# print graphs and obtain reports. -# -# The system only logs operations that were performed in a time equal or -# greater than the amount of milliseconds specified via the -# latency-monitor-threshold configuration directive. When its value is set -# to zero, the latency monitor is turned off. -# -# By default latency monitoring is disabled since it is mostly not needed -# if you don't have latency issues, and collecting data has a performance -# impact, that while very small, can be measured under big load. Latency -# monitoring can easily be enabled at runtime using the command -# "CONFIG SET latency-monitor-threshold " if needed. -latency-monitor-threshold 0 - -############################# EVENT NOTIFICATION ############################## - -# Redis can notify Pub/Sub clients about events happening in the key space. -# This feature is documented at http://redis.io/topics/notifications -# -# For instance if keyspace events notification is enabled, and a client -# performs a DEL operation on key "foo" stored in the Database 0, two -# messages will be published via Pub/Sub: -# -# PUBLISH __keyspace@0__:foo del -# PUBLISH __keyevent@0__:del foo -# -# It is possible to select the events that Redis will notify among a set -# of classes. Every class is identified by a single character: -# -# K Keyspace events, published with __keyspace@__ prefix. -# E Keyevent events, published with __keyevent@__ prefix. -# g Generic commands (non-type specific) like DEL, EXPIRE, RENAME, ... -# $ String commands -# l List commands -# s Set commands -# h Hash commands -# z Sorted set commands -# x Expired events (events generated every time a key expires) -# e Evicted events (events generated when a key is evicted for maxmemory) -# A Alias for g$lshzxe, so that the "AKE" string means all the events. -# -# The "notify-keyspace-events" takes as argument a string that is composed -# of zero or multiple characters. The empty string means that notifications -# are disabled. -# -# Example: to enable list and generic events, from the point of view of the -# event name, use: -# -# notify-keyspace-events Elg -# -# Example 2: to get the stream of the expired keys subscribing to channel -# name __keyevent@0__:expired use: -# -# notify-keyspace-events Ex -# -# By default all notifications are disabled because most users don't need -# this feature and the feature has some overhead. Note that if you don't -# specify at least one of K or E, no events will be delivered. -notify-keyspace-events "" - -############################### ADVANCED CONFIG ############################### - -# Hashes are encoded using a memory efficient data structure when they have a -# small number of entries, and the biggest entry does not exceed a given -# threshold. These thresholds can be configured using the following directives. -hash-max-ziplist-entries 512 -hash-max-ziplist-value 64 - -# Lists are also encoded in a special way to save a lot of space. -# The number of entries allowed per internal list node can be specified -# as a fixed maximum size or a maximum number of elements. -# For a fixed maximum size, use -5 through -1, meaning: -# -5: max size: 64 Kb <-- not recommended for normal workloads -# -4: max size: 32 Kb <-- not recommended -# -3: max size: 16 Kb <-- probably not recommended -# -2: max size: 8 Kb <-- good -# -1: max size: 4 Kb <-- good -# Positive numbers mean store up to _exactly_ that number of elements -# per list node. -# The highest performing option is usually -2 (8 Kb size) or -1 (4 Kb size), -# but if your use case is unique, adjust the settings as necessary. -list-max-ziplist-size -2 - -# Lists may also be compressed. -# Compress depth is the number of quicklist ziplist nodes from *each* side of -# the list to *exclude* from compression. The head and tail of the list -# are always uncompressed for fast push/pop operations. Settings are: -# 0: disable all list compression -# 1: depth 1 means "don't start compressing until after 1 node into the list, -# going from either the head or tail" -# So: [head]->node->node->...->node->[tail] -# [head], [tail] will always be uncompressed; inner nodes will compress. -# 2: [head]->[next]->node->node->...->node->[prev]->[tail] -# 2 here means: don't compress head or head->next or tail->prev or tail, -# but compress all nodes between them. -# 3: [head]->[next]->[next]->node->node->...->node->[prev]->[prev]->[tail] -# etc. -list-compress-depth 0 - -# Sets have a special encoding in just one case: when a set is composed -# of just strings that happen to be integers in radix 10 in the range -# of 64 bit signed integers. -# The following configuration setting sets the limit in the size of the -# set in order to use this special memory saving encoding. -set-max-intset-entries 512 - -# Similarly to hashes and lists, sorted sets are also specially encoded in -# order to save a lot of space. This encoding is only used when the length and -# elements of a sorted set are below the following limits: -zset-max-ziplist-entries 128 -zset-max-ziplist-value 64 - -# HyperLogLog sparse representation bytes limit. The limit includes the -# 16 bytes header. When an HyperLogLog using the sparse representation crosses -# this limit, it is converted into the dense representation. -# -# A value greater than 16000 is totally useless, since at that point the -# dense representation is more memory efficient. -# -# The suggested value is ~ 3000 in order to have the benefits of -# the space efficient encoding without slowing down too much PFADD, -# which is O(N) with the sparse encoding. The value can be raised to -# ~ 10000 when CPU is not a concern, but space is, and the data set is -# composed of many HyperLogLogs with cardinality in the 0 - 15000 range. -hll-sparse-max-bytes 3000 - -# Active rehashing uses 1 millisecond every 100 milliseconds of CPU time in -# order to help rehashing the main Redis hash table (the one mapping top-level -# keys to values). The hash table implementation Redis uses (see dict.c) -# performs a lazy rehashing: the more operation you run into a hash table -# that is rehashing, the more rehashing "steps" are performed, so if the -# server is idle the rehashing is never complete and some more memory is used -# by the hash table. -# -# The default is to use this millisecond 10 times every second in order to -# actively rehash the main dictionaries, freeing memory when possible. -# -# If unsure: -# use "activerehashing no" if you have hard latency requirements and it is -# not a good thing in your environment that Redis can reply from time to time -# to queries with 2 milliseconds delay. -# -# use "activerehashing yes" if you don't have such hard requirements but -# want to free memory asap when possible. -activerehashing yes - -# The client output buffer limits can be used to force disconnection of clients -# that are not reading data from the server fast enough for some reason (a -# common reason is that a Pub/Sub client can't consume messages as fast as the -# publisher can produce them). -# -# The limit can be set differently for the three different classes of clients: -# -# normal -> normal clients including MONITOR clients -# slave -> slave clients -# pubsub -> clients subscribed to at least one pubsub channel or pattern -# -# The syntax of every client-output-buffer-limit directive is the following: -# -# client-output-buffer-limit -# -# A client is immediately disconnected once the hard limit is reached, or if -# the soft limit is reached and remains reached for the specified number of -# seconds (continuously). -# So for instance if the hard limit is 32 megabytes and the soft limit is -# 16 megabytes / 10 seconds, the client will get disconnected immediately -# if the size of the output buffers reach 32 megabytes, but will also get -# disconnected if the client reaches 16 megabytes and continuously overcomes -# the limit for 10 seconds. -# -# By default normal clients are not limited because they don't receive data -# without asking (in a push way), but just after a request, so only -# asynchronous clients may create a scenario where data is requested faster -# than it can read. -# -# Instead there is a default limit for pubsub and slave clients, since -# subscribers and slaves receive data in a push fashion. -# -# Both the hard or the soft limit can be disabled by setting them to zero. -client-output-buffer-limit normal 0 0 0 -client-output-buffer-limit slave 256mb 64mb 60 -client-output-buffer-limit pubsub 32mb 8mb 60 - -# Redis calls an internal function to perform many background tasks, like -# closing connections of clients in timeout, purging expired keys that are -# never requested, and so forth. -# -# Not all tasks are performed with the same frequency, but Redis checks for -# tasks to perform according to the specified "hz" value. -# -# By default "hz" is set to 10. Raising the value will use more CPU when -# Redis is idle, but at the same time will make Redis more responsive when -# there are many keys expiring at the same time, and timeouts may be -# handled with more precision. -# -# The range is between 1 and 500, however a value over 100 is usually not -# a good idea. Most users should use the default of 10 and raise this up to -# 100 only in environments where very low latency is required. -hz 10 - -# When a child rewrites the AOF file, if the following option is enabled -# the file will be fsync-ed every 32 MB of data generated. This is useful -# in order to commit the file to the disk more incrementally and avoid -# big latency spikes. -aof-rewrite-incremental-fsync yes diff --git a/makefile b/makefile index a61cb36..8468486 100644 --- a/makefile +++ b/makefile @@ -8,7 +8,7 @@ tests: go test ./api run-gosec: - gosec ./... + gosec ./vault ./api ./logger check-formatting: if [ -n "$(gofmt -l .)" ]; then echo "Go code is not properly formatted:"; gofmt -d .; exit 1; fi diff --git a/vault/redis.go b/vault/redis.go deleted file mode 100644 index 60c1ccd..0000000 --- a/vault/redis.go +++ /dev/null @@ -1,551 +0,0 @@ -package vault - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/redis/go-redis/v9" -) - -const COLLECTIONS_PREFIX = "collection:" -const RECORDS_PREFIX = "record:" -const FIELDS_PREFIX = "field:" -const PRINCIPAL_PREFIX = "principal:" -const POLICY_PREFIX = "policy:" -const INDEX_PREFIX = "idx:" -const TOKEN_PREFIX = "token:" - -var ( - Prefix = map[string]string{ - "collection": COLLECTIONS_PREFIX, - "record": RECORDS_PREFIX, - } -) - -type RedisStore struct { - Client *redis.Client -} - -func NewRedisStore(addr, password string, db int) (*RedisStore, error) { - client := redis.NewClient(&redis.Options{ - Addr: addr, - Password: password, - DB: db, - }) - - _, err := client.Ping(context.Background()).Result() - if err != nil { - return nil, fmt.Errorf("unable to connect to Redis: %w", err) - } - - return &RedisStore{Client: client}, nil -} - -func (rs RedisStore) Flush(ctx context.Context) error { - _, err := rs.Client.FlushDB(ctx).Result() - return err -} - -func (rs RedisStore) GetCollections(ctx context.Context) ([]string, error) { - members, err := rs.Client.SMembers(ctx, COLLECTIONS_PREFIX).Result() - if err != nil { - if errors.Is(err, redis.Nil) { - return []string{}, nil - } - return []string{}, fmt.Errorf("failed to get collections: %w", err) - } - for i, member := range members { - members[i] = member[len(COLLECTIONS_PREFIX):] - } - return members, nil -} - -func (rs RedisStore) GetCollection(ctx context.Context, name string) (*Collection, error) { - colId := fmt.Sprintf("%s%s", COLLECTIONS_PREFIX, name) - dbCol := Collection{} - col, err := rs.Client.HGetAll(ctx, colId).Result() - if err != nil { - return nil, fmt.Errorf("failed to get data from Redis with key %s: %w", colId, err) - } - if len(col) == 0 { - return nil, &NotFoundError{"collection", name} - } - pipe := rs.Client.Pipeline() - for _, v := range col { - pipe.HGetAll(ctx, v) - } - fields, err := pipe.Exec(ctx) - if err != nil { - return nil, fmt.Errorf("failed to execute Redis pipeline: %w", err) - } - dbCol.Fields = make(map[string]Field, len(fields)) - for _, field := range fields { - dbCol.Fields[field.(*redis.MapStringStringCmd).Val()["name"]] = Field{ - Name: field.(*redis.MapStringStringCmd).Val()["name"], - Type: field.(*redis.MapStringStringCmd).Val()["type"], - IsIndexed: field.(*redis.MapStringStringCmd).Val()["is_indexed"] == "1", - } - } - dbCol.Name = name - return &dbCol, nil -} - -func (rs RedisStore) CreateCollection(ctx context.Context, c Collection) (string, error) { - colId := fmt.Sprintf("%s%s", COLLECTIONS_PREFIX, c.Name) - - exists, err := rs.Client.Exists(ctx, colId).Result() - if err != nil { - return c.Name, fmt.Errorf("failed to check collection existence: %w", err) - } - - if exists == 1 { - return c.Name, &ConflictError{colId} - } - - pipe := rs.Client.Pipeline() - pipe.SAdd(ctx, COLLECTIONS_PREFIX, colId) - for fieldName, fieldValue := range c.Fields { - fieldId := fmt.Sprintf("%s:%s%s", colId, FIELDS_PREFIX, fieldName) - pipe.HSet( - ctx, - fieldId, - "name", fieldName, - "type", fieldValue.Type, - "is_indexed", fieldValue.IsIndexed, - ) - pipe.HSet( - ctx, - colId, - fieldName, - fieldId, - ) - } - - _, err = pipe.Exec(ctx) - if err != nil { - return c.Name, fmt.Errorf("failed to execute Redis pipeline: %w", err) - } - - return c.Name, nil -} - -func (rs RedisStore) DeleteCollection(ctx context.Context, name string) error { - dbCollection, err := rs.GetCollection(ctx, name) - if err != nil { - return err - } - - colId := fmt.Sprintf("%s%s", COLLECTIONS_PREFIX, name) - recordIds, err := rs.Client.SMembers(ctx, fmt.Sprintf("%s:r", colId)).Result() - if err != nil { - return err - } - - pipe := rs.Client.Pipeline() - // Delete the collection - pipe.Del(ctx, colId) - pipe.SRem(ctx, COLLECTIONS_PREFIX, colId) - // Delete all records and indexes for the collection - for _, recordId := range recordIds { - dbRecord, err := rs.GetRecords(ctx, name, []string{recordId}) - if err != nil { - return err - } - - redisKey := fmt.Sprintf("%s%s", RECORDS_PREFIX, recordId) - pipe.Del(ctx, redisKey) - pipe.SRem(ctx, fmt.Sprintf("%s:r", colId), recordId) - for fieldName, fieldValue := range *dbRecord[recordId] { - if dbCollection.Fields[fieldName].IsIndexed { - pipe.SRem(ctx, formatIndex(fieldName, fieldValue), recordId) - } - } - } - - _, err = pipe.Exec(ctx) - if err != nil { - return fmt.Errorf("failed to execute Redis pipeline: %w", err) - } - - return nil -} - -func formatIndex(fieldName string, value string) string { - // Given that the value is encrypted for now, this might not be needed. - return fmt.Sprintf("%s%s_%s", INDEX_PREFIX, fieldName, strings.ToLower(value)) -} - -func (rs RedisStore) CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) { - colId := fmt.Sprintf("%s%s", COLLECTIONS_PREFIX, collectionName) - dbCol, err := rs.GetCollection(ctx, collectionName) - if err != nil { - return []string{}, err - } - - recordIds := []string{} - - pipe := rs.Client.Pipeline() - for _, record := range records { - recordId := GenerateId() - redisKey := fmt.Sprintf("%s%s", RECORDS_PREFIX, recordId) - pipe.SAdd(ctx, fmt.Sprintf("%s:r", colId), recordId) - for rFieldName, rFieldValue := range record { - // TODO: Validate types and schema here - field, ok := dbCol.Fields[rFieldName] - if !ok { - return []string{}, &ValueError{fmt.Sprintf("field %s does not exist in collection %s", rFieldName, collectionName)} - } - pipe.HSet( - ctx, - redisKey, - rFieldName, - rFieldValue, - ) - if field.IsIndexed { - pipe.SAdd(ctx, formatIndex(rFieldName, rFieldValue), recordId) - } - // TODO: Add unique constraint here, removed for simplicity - } - - recordIds = append(recordIds, recordId) - } - _, err = pipe.Exec(ctx) - if err != nil { - return nil, fmt.Errorf("failed to execute Redis pipeline: %w", err) - } - return recordIds, nil -} - -func (rs RedisStore) GetRecords(ctx context.Context, collectionName string, recordIds []string) (map[string]*Record, error) { - _, err := rs.GetCollection(ctx, collectionName) - if err != nil { - return nil, err - } - records := map[string]*Record{} - - for _, recordId := range recordIds { - redisKey := fmt.Sprintf("%s%s", RECORDS_PREFIX, recordId) - recordMap, err := rs.Client.HGetAll(ctx, redisKey).Result() - if len(recordMap) == 0 { - return nil, &NotFoundError{"record", recordId} - } - if err != nil { - return nil, fmt.Errorf("failed to get record with ID %s: %w", recordId, err) - } - - record := &Record{} - for key, value := range recordMap { - (*record)[key] = value - } - - records[recordId] = record - } - return records, nil -} - -func (rs RedisStore) GetRecordsFilter(ctx context.Context, collectionName string, fieldName string, value string) ([]string, error) { - dbCol, err := rs.GetCollection(ctx, collectionName) - if err != nil { - return []string{}, err - } - - if !dbCol.Fields[fieldName].IsIndexed { - return []string{}, ErrIndexError - } - - data, err := rs.Client.SMembers(ctx, formatIndex(fieldName, value)).Result() - if err != nil { - return []string{}, err - } - return data, nil -} - -func (rs RedisStore) UpdateRecord(ctx context.Context, collectionName string, recordId string, record Record) error { - dbCol, err := rs.GetCollection(ctx, collectionName) - if err != nil { - return err - } - - redisKey := fmt.Sprintf("%s%s", RECORDS_PREFIX, recordId) - pipe := rs.Client.Pipeline() - pipe.Del(ctx, redisKey) - for rFieldName, rFieldValue := range record { - field, ok := dbCol.Fields[rFieldName] - if !ok { - return &ValueError{fmt.Sprintf("field %s does not exist in collection %s", rFieldName, collectionName)} - } - pipe.HSet( - ctx, - redisKey, - rFieldName, - rFieldValue, - ) - if field.IsIndexed { - pipe.SAdd(ctx, formatIndex(rFieldName, rFieldValue), recordId) - } - } - - _, err = pipe.Exec(ctx) - if err != nil { - return fmt.Errorf("failed to execute Redis pipeline: %w", err) - } - return nil -} - -func (rs RedisStore) DeleteRecord(ctx context.Context, collectionName string, recordId string) error { - dbCol, err := rs.GetCollection(ctx, collectionName) - if err != nil { - return err - } - - dbRecord, err := rs.GetRecords(ctx, collectionName, []string{recordId}) - if err != nil { - return err - } - - pipe := rs.Client.Pipeline() - redisKey := fmt.Sprintf("%s%s", RECORDS_PREFIX, recordId) - pipe.Del(ctx, redisKey) - colId := fmt.Sprintf("%s%s", COLLECTIONS_PREFIX, collectionName) - pipe.SRem(ctx, fmt.Sprintf("%s:r", colId), recordId) - for fieldName, fieldValue := range *dbRecord[recordId] { - if dbCol.Fields[fieldName].IsIndexed { - pipe.SRem(ctx, formatIndex(fieldName, fieldValue), recordId) - } - } - - _, err = pipe.Exec(ctx) - if err != nil { - return fmt.Errorf("failed to execute Redis pipeline: %w", err) - } - - return nil -} - -func (rs RedisStore) CreatePrincipal(ctx context.Context, principal Principal) error { - principalId := fmt.Sprintf("%s%s", PRINCIPAL_PREFIX, principal.Username) - - exists, err := rs.Client.Exists(ctx, principalId).Result() - if err != nil { - return fmt.Errorf("failed to check principal existence: %w", err) - } - - if exists == 1 { - return &ConflictError{principalId} - } - - pipe := rs.Client.Pipeline() - pipe.SAdd(ctx, PRINCIPAL_PREFIX, principalId) - pipe.HSet( - context.Background(), - principalId, - "username", principal.Username, - "password", principal.Password, - "created_at", principal.CreatedAt, - "description", principal.Description, - ) - - for _, policy := range principal.Policies { - // TODO: Is this a bad idea? The sets can get out of sync - pipe.SAdd(ctx, fmt.Sprintf("%s:policies", principalId), policy) - pipe.SAdd(ctx, fmt.Sprintf("%smembers:%s", POLICY_PREFIX, policy), principal.Username) - } - - _, err = pipe.Exec(ctx) - if err != nil { - return fmt.Errorf("failed to execute Redis pipeline: %w", err) - } - return nil -} - -func (rs RedisStore) GetPrincipal(ctx context.Context, username string) (*Principal, error) { - principalId := fmt.Sprintf("%s%s", PRINCIPAL_PREFIX, username) - var dbPrincipal Principal - - pipe := rs.Client.Pipeline() - pipe.HGetAll(ctx, principalId) - pipe.SMembers(ctx, fmt.Sprintf("%s:policies", principalId)) - pipeRes, err := pipe.Exec(ctx) - if err != nil { - if err == redis.Nil { - return nil, &NotFoundError{"principal", principalId} - } - return nil, err - } - err = pipeRes[0].(*redis.MapStringStringCmd).Scan(&dbPrincipal) - if err != nil { - return nil, err - } - dbPrincipal.Policies = pipeRes[1].(*redis.StringSliceCmd).Val() - if dbPrincipal.Username == "" || dbPrincipal.Password == "" { - return nil, &NotFoundError{"principal", principalId} - } - return &dbPrincipal, nil -} - -func (rs RedisStore) DeletePrincipal(ctx context.Context, username string) error { - principalId := fmt.Sprintf("%s%s", PRINCIPAL_PREFIX, username) - - exists, err := rs.Client.Exists(ctx, principalId).Result() - if err != nil { - return fmt.Errorf("failed to check principal existence: %w", err) - } - - if exists != 1 { - return &NotFoundError{"principal", principalId} - } - - pipe := rs.Client.Pipeline() - pipe.Del(ctx, principalId) - pipe.SRem(ctx, PRINCIPAL_PREFIX, principalId) - - _, err = pipe.Exec(ctx) - if err != nil { - return fmt.Errorf("failed to execute Redis pipeline: %w", err) - } - return nil -} - -type RawPolicy struct { - PolicyId string `redis:"policy_id"` - Effect PolicyEffect `redis:"effect"` - Actions string `redis:"actions"` - Resources string `redis:"resources"` -} - -func (rawPolicy RawPolicy) toPolicy() *Policy { - var actions []PolicyAction - for _, action := range strings.Split(rawPolicy.Actions, ",") { - actions = append(actions, PolicyAction(action)) - } - policy := Policy{ - PolicyId: rawPolicy.PolicyId, - Effect: rawPolicy.Effect, - Actions: actions, - Resources: strings.Split(rawPolicy.Resources, ","), - } - - return &policy -} - -func (rs RedisStore) GetPolicy(ctx context.Context, policyId string) (*Policy, error) { - polRedisId := fmt.Sprintf("%s%s", POLICY_PREFIX, policyId) - cmd := rs.Client.HGetAll(ctx, polRedisId) - if err := cmd.Err(); err != nil { - return nil, err - } - - result, err := cmd.Result() - if err != nil { - return nil, err - } - - if len(result) == 0 { - return nil, &NotFoundError{"policy", polRedisId} - } - - var rawPolicy RawPolicy - if err := cmd.Scan(&rawPolicy); err != nil { - return nil, err - } - - return rawPolicy.toPolicy(), nil -} - -func (rs RedisStore) GetPoliciesById(ctx context.Context, policyIds []string) ([]*Policy, error) { - policies := []*Policy{} - pipeline := rs.Client.Pipeline() - - // Prepare the commands - cmds := make([]*redis.MapStringStringCmd, len(policyIds)) - for i, polId := range policyIds { - polRedisId := fmt.Sprintf("%s%s", POLICY_PREFIX, polId) - cmds[i] = pipeline.HGetAll(ctx, polRedisId) - } - - // Execute the pipeline - _, err := pipeline.Exec(ctx) - if err != nil && err != redis.Nil { - return nil, err - } - - // Process the results - for _, cmd := range cmds { - if err := cmd.Err(); err != nil { - if err != redis.Nil { - return nil, err - } - // Skip if not found - continue - } - var rawPolicy RawPolicy - if err := cmd.Scan(&rawPolicy); err != nil { - return nil, err - } - policies = append(policies, rawPolicy.toPolicy()) - } - return policies, nil -} - -func (rs RedisStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { - polRedisId := fmt.Sprintf("%s%s", POLICY_PREFIX, p.PolicyId) - var actions []string - for _, action := range p.Actions { - actions = append(actions, string(action)) - } - - _, err := rs.Client.HSet( - ctx, - polRedisId, - "policy_id", p.PolicyId, - "effect", string(p.Effect), - "actions", strings.Join(actions, ","), - "resources", strings.Join(p.Resources, ","), - ).Result() - - if err != nil { - return "", fmt.Errorf("failed to create policy: %w", err) - } - - return p.PolicyId, nil -} - -func (rs RedisStore) DeletePolicy(ctx context.Context, policyId string) error { - _, err := rs.GetPolicy(ctx, policyId) - if err != nil { - return err - } - - polRedisId := fmt.Sprintf("%s%s", POLICY_PREFIX, policyId) - _, err = rs.Client.Del(ctx, polRedisId).Result() - if err != nil { - return fmt.Errorf("failed to delete policy: %w", err) - } - return nil -} - -// TODO: -// - Set expiration? -func (rs RedisStore) CreateToken(ctx context.Context, tokenId string, value string) error { - return rs.Client.Set(ctx, fmt.Sprintf("%s%s", TOKEN_PREFIX, tokenId), value, 0).Err() -} -func (rs RedisStore) DeleteToken(ctx context.Context, tokenId string) error { - err := rs.Client.Del(ctx, fmt.Sprintf("%s%s", TOKEN_PREFIX, tokenId)).Err() - if err == redis.Nil { - return &NotFoundError{"token", tokenId} - } - - return err -} -func (rs RedisStore) GetTokenValue(ctx context.Context, tokenId string) (string, error) { - res, err := rs.Client.Get(ctx, fmt.Sprintf("%s%s", TOKEN_PREFIX, tokenId)).Result() - if err != nil { - if err == redis.Nil { - return "", &NotFoundError{"token", tokenId} - } - return "", err - } - return res, nil -} diff --git a/vault/redis_test.go b/vault/redis_test.go deleted file mode 100644 index 6eb0211..0000000 --- a/vault/redis_test.go +++ /dev/null @@ -1,272 +0,0 @@ -package vault - -import ( - "context" - "os" - "sort" - "testing" -) - -func initDB() (VaultDB, error) { - db, err := NewRedisStore( - os.Getenv("KEYDB_CONN_STRING"), - "", - 0, - ) - if err != nil { - return nil, err - } - db.Flush(context.Background()) - - return db, nil -} - -// TODO: These need to be separate tests latr on and probably mocked... - -func TestRedisStore(t *testing.T) { - t.Run("can crud collections and records", func(t *testing.T) { - ctx := context.Background() - db, err := initDB() - - if err != nil { - t.Fatal(err) - } - - dbCols, err := db.GetCollections(ctx) - if err != nil { - t.Fatal(err) - } - - if len(dbCols) != 0 { - t.Fatal("Expected 0 collections, got", len(dbCols)) - } - - col := Collection{Name: "customers", Fields: map[string]Field{ - "name": { - Type: "string", - IsIndexed: false, - }, - "age": { - Name: "age", - Type: "integer", - IsIndexed: false, - }, - "country": { - Name: "country", - Type: "string", - IsIndexed: true, - }, - }} - - // Can create collection - colID, err := db.CreateCollection(ctx, col) - if err != nil || colID == "" { - t.Fatal(err) - } - - dbCols, err = db.GetCollections(ctx) - if err != nil { - t.Fatal(err) - } - - if len(dbCols) != 1 { - t.Fatal("Expected 1 collection, got", len(dbCols)) - } - - if dbCols[0] != col.Name { - t.Fatal("Expected collection name to be 'test', got", dbCols[0]) - } - - newCol, _ := db.GetCollection(ctx, dbCols[0]) - fields := newCol.Fields - if fields["country"].Name != "country" || fields["country"].Type != "string" || !fields["country"].IsIndexed { - t.Fatal("Field props not matching.") - } - - // Can add records - records := []Record{ - {"name": "Simon", "age": "10", "country": "Ahibia"}, - {"name": "Ali", "age": "11", "country": "Bolonesia"}, - {"name": "Jim", "age": "22", "country": "Sarumania"}, - {"name": "Jeff", "age": "22", "country": "Ahibia"}, - } - - recordIds, err := db.CreateRecords(ctx, col.Name, records) - if err != nil { - t.Fatal(err) - } - - // Can get records - dbRecords, err := db.GetRecords(ctx, col.Name, recordIds) - if err != nil { - t.Fatal(err) - } - if len(dbRecords) != len(records) { - t.Fatalf("Expected %d records, got %d", len(records), len(dbRecords)) - } - - // Can update records - updateRecord := Record{"name": "UpdatedName", "age": "99", "country": "UpdatedCountry"} - err = db.UpdateRecord(ctx, col.Name, recordIds[0], updateRecord) - if err != nil { - t.Fatal(err) - } - - // Verify update of the record - updatedRecord, err := db.GetRecords(ctx, col.Name, []string{recordIds[0]}) - if err != nil { - t.Fatal(err) - } - if (*updatedRecord[recordIds[0]])["name"] != "UpdatedName" || - (*updatedRecord[recordIds[0]])["age"] != "99" || - (*updatedRecord[recordIds[0]])["country"] != "UpdatedCountry" { - t.Fatal("Record not updated correctly.") - } - - // Can delete records - err = db.DeleteRecord(ctx, col.Name, recordIds[0]) - if err != nil { - t.Fatal(err) - } - - // Verify deletion of the record - deleteRecord, err := db.GetRecords(ctx, col.Name, []string{recordIds[0]}) - if err == nil { - t.Fatal(err) - } - if len(deleteRecord) != 0 { - t.Fatal("Record not deleted.") - } - - }) - - t.Run("can delete collections", func(t *testing.T) { - ctx := context.Background() - db, err := initDB() - - if err != nil { - t.Fatal(err) - } - - col := Collection{Name: "customers", Fields: map[string]Field{ - "name": { - Type: "string", - IsIndexed: false, - }, - "age": { - Name: "age", - Type: "integer", - IsIndexed: false, - }, - "country": { - Name: "country", - Type: "string", - IsIndexed: true, - }, - }} - - // Can create collection - colID, err := db.CreateCollection(ctx, col) - if err != nil || colID == "" { - t.Fatal(err) - } - - // Can delete collection - err = db.DeleteCollection(ctx, colID) - if err != nil { - t.Fatal(err) - } - - // Collection should not exist after deletion - _, err = db.GetCollection(ctx, colID) - if err == nil { - t.Fatal("Expected error when getting deleted collection, got nil") - } - }) - - t.Run("can create and get principals", func(t *testing.T) { - ctx := context.Background() - db, err := initDB() - - if err != nil { - t.Fatal(err) - } - - // Note: password is not encrypted when storing this way, but it's just for testing purposes. - // The principal object should be created at the vault level. - principal := Principal{ - Username: "test", - Password: "test", - Description: "test", - CreatedAt: "0", - Policies: []string{"read-customers", "write-credit-cards"}, - } - - // Can create principal - err = db.CreatePrincipal(ctx, principal) - if err != nil { - t.Fatal(err) - } - - // Can get principal - dbPrincipalRead, err := db.GetPrincipal(ctx, principal.Username) - if err != nil { - t.Fatal(err) - } - - if dbPrincipalRead.Username != principal.Username || dbPrincipalRead.Description != principal.Description { - t.Fatal("Principal props not matching.") - } - - // Returned policies match - if len(principal.Policies) != len(dbPrincipalRead.Policies) { - t.Fatalf("Principal policies not matching. Expected %d, got %d", len(principal.Policies), len(dbPrincipalRead.Policies)) - } - - sort.Strings(principal.Policies) - sort.Strings(dbPrincipalRead.Policies) - for i, role := range principal.Policies { - if role != dbPrincipalRead.Policies[i] { - t.Fatalf("Principal policies not matching at index %d. Expected %s, got %s", i, role, dbPrincipalRead.Policies[i]) - } - } - }) - - t.Run("can delete principals", func(t *testing.T) { - ctx := context.Background() - db, err := initDB() - - if err != nil { - t.Fatal(err) - } - - // Note: password is not encrypted when storing this way, but it's just for testing purposes. - // The principal object should be created at the vault level. - principal := Principal{ - Username: "test", - Password: "test", - Description: "test", - CreatedAt: "0", - Policies: []string{"read-customers", "write-credit-cards"}, - } - - // Can create principal - err = db.CreatePrincipal(ctx, principal) - if err != nil { - t.Fatal(err) - } - - // Can delete principal - err = db.DeletePrincipal(ctx, principal.Username) - if err != nil { - t.Fatal(err) - } - - // Principal should not exist after deletion - _, err = db.GetPrincipal(ctx, principal.Username) - if err == nil { - t.Fatal("Expected error when getting deleted principal, got nil") - } - }) - -} diff --git a/vault/sql.go b/vault/sql.go index 07fa459..efdc0d2 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -3,7 +3,6 @@ package vault // TODO: // - Dynamic collection creation (no updates) // - Error handling -// - Tidy DB Models // - Ensure we never log sensitive data // - Add indexes @@ -11,12 +10,15 @@ import ( "context" "encoding/json" "errors" + "log" + "os" "time" "github.com/lib/pq" "gorm.io/datatypes" "gorm.io/driver/postgres" "gorm.io/gorm" + "gorm.io/gorm/logger" ) type SqlStore struct { @@ -58,10 +60,30 @@ type DbToken struct { } func NewSqlStore(dsn string) (*SqlStore, error) { + // Todo: Make sure we never log in production + dbLogger := logger.New( + log.New(os.Stdout, "\r\n", log.LstdFlags), + logger.Config{ + SlowThreshold: time.Second, // Slow SQL threshold + LogLevel: logger.Silent, // Log level + IgnoreRecordNotFoundError: true, + ParameterizedQueries: true, + Colorful: false, + }, + ) + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ TranslateError: true, + Logger: dbLogger, }) - db.AutoMigrate(&DbCollection{}, &DbRecord{}, &DbPrincipal{}, &DbPolicy{}, &DbToken{}) + + if err != nil { + return nil, err + } + err = db.AutoMigrate(&DbCollection{}, &DbRecord{}, &DbPrincipal{}, &DbPolicy{}, &DbToken{}) + if err != nil { + return nil, err + } return &SqlStore{db}, err } @@ -70,11 +92,17 @@ func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection, var gc DbCollection err := st.db.First(&gc, "name = ?", name).Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, &NotFoundError{"collection", name} + } return nil, err } var col Collection - json.Unmarshal(gc.Collection, &col) + err = json.Unmarshal(gc.Collection, &col) + if err != nil { + return nil, err + } return &col, err } @@ -222,7 +250,17 @@ func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, err } return nil, err } - var p Policy + + var policyActions []PolicyAction + for _, action := range gp.Actions { + policyActions = append(policyActions, PolicyAction(action)) + } + p := Policy{ + PolicyId: gp.ID, + Effect: PolicyEffect(gp.Effect), + Actions: policyActions, + Resources: gp.Resources, + } return &p, err } @@ -284,7 +322,7 @@ func (st SqlStore) DeleteToken(ctx context.Context, tokenId string) error { } func (st SqlStore) GetTokenValue(ctx context.Context, tokenId string) (string, error) { var gt DbToken - err := st.db.First(>, "token_id = ?", tokenId).Error + err := st.db.First(>, "id = ?", tokenId).Error return gt.Value, err } diff --git a/vault/vault_test.go b/vault/vault_test.go index 9664198..964dbcc 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -12,7 +12,7 @@ import ( func initVault(t *testing.T) (Vault, VaultDB, Privatiser) { ctx := context.Background() - db, err := NewSqlStore(os.Getenv("DATABASE_URL")) + db, err := NewSqlStore(os.Getenv("VAULT_DATABASE_URL")) if err != nil { panic(err) } @@ -383,55 +383,55 @@ func TestVault(t *testing.T) { t.Fatal(err) } }) - t.Run("get records by field value", func(t *testing.T) { - vault, _, _ := initVault(t) - col := Collection{Name: "customers", Fields: map[string]Field{ - "first_name": { - Name: "first_name", - Type: "string", - IsIndexed: true, - }, - }} - - // Can create collection - _, _ = vault.CreateCollection(ctx, testPrincipal, col) - _, _ = vault.CreateRecords(ctx, testPrincipal, col.Name, []Record{ - {"first_name": "John"}, - {"first_name": "Jane"}, - {"first_name": "Bob"}, - }) - res, err := vault.GetRecordsFilter(ctx, testPrincipal, "customers", "first_name", "Bob", map[string]string{ - "first_name": "plain", - }) - assert.Equal(t, err, nil) - assert.Equal( - t, - len(res), - 1, - ) - }) - t.Run("get records by field fails when field not indexed", func(t *testing.T) { - vault, _, _ := initVault(t) - col := Collection{Name: "customers", Fields: map[string]Field{ - "first_name": { - Name: "first_name", - Type: "string", - IsIndexed: false, - }, - }} - - // Can create collection - _, _ = vault.CreateCollection(ctx, testPrincipal, col) - _, _ = vault.CreateRecords(ctx, testPrincipal, col.Name, []Record{ - {"first_name": "John"}, - {"first_name": "Jane"}, - {"first_name": "Bob"}, - }) - _, err := vault.GetRecordsFilter(ctx, testPrincipal, "customers", "first_name", "Bob", map[string]string{ - "first_name": "plain", - }) - assert.Equal(t, err, ErrIndexError) - }) + // t.Run("get records by field value", func(t *testing.T) { + // vault, _, _ := initVault(t) + // col := Collection{Name: "customers", Fields: map[string]Field{ + // "first_name": { + // Name: "first_name", + // Type: "string", + // IsIndexed: true, + // }, + // }} + + // // Can create collection + // _, _ = vault.CreateCollection(ctx, testPrincipal, col) + // _, _ = vault.CreateRecords(ctx, testPrincipal, col.Name, []Record{ + // {"first_name": "John"}, + // {"first_name": "Jane"}, + // {"first_name": "Bob"}, + // }) + // res, err := vault.GetRecordsFilter(ctx, testPrincipal, "customers", "first_name", "Bob", map[string]string{ + // "first_name": "plain", + // }) + // assert.Equal(t, err, nil) + // assert.Equal( + // t, + // len(res), + // 1, + // ) + // }) + // t.Run("get records by field fails when field not indexed", func(t *testing.T) { + // vault, _, _ := initVault(t) + // col := Collection{Name: "customers", Fields: map[string]Field{ + // "first_name": { + // Name: "first_name", + // Type: "string", + // IsIndexed: false, + // }, + // }} + + // // Can create collection + // _, _ = vault.CreateCollection(ctx, testPrincipal, col) + // _, _ = vault.CreateRecords(ctx, testPrincipal, col.Name, []Record{ + // {"first_name": "John"}, + // {"first_name": "Jane"}, + // {"first_name": "Bob"}, + // }) + // _, err := vault.GetRecordsFilter(ctx, testPrincipal, "customers", "first_name", "Bob", map[string]string{ + // "first_name": "plain", + // }) + // assert.Equal(t, err, ErrIndexError) + // }) } func TestVaultLogin(t *testing.T) { From 0ab4d1ce337a053ddbe400f32405caa3ce64a569 Mon Sep 17 00:00:00 2001 From: Subrose Date: Wed, 22 Nov 2023 18:09:32 +0000 Subject: [PATCH 05/13] stable --- vault/sql.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vault/sql.go b/vault/sql.go index efdc0d2..69f402b 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -2,8 +2,6 @@ package vault // TODO: // - Dynamic collection creation (no updates) -// - Error handling -// - Ensure we never log sensitive data // - Add indexes import ( @@ -320,6 +318,7 @@ func (st SqlStore) DeleteToken(ctx context.Context, tokenId string) error { gt := DbToken{ID: tokenId} return st.db.Delete(>).Error } + func (st SqlStore) GetTokenValue(ctx context.Context, tokenId string) (string, error) { var gt DbToken err := st.db.First(>, "id = ?", tokenId).Error From 76c5c4d469cf57f53da80f47017d22446f076d5b Mon Sep 17 00:00:00 2001 From: Subrose Date: Wed, 22 Nov 2023 21:57:00 +0000 Subject: [PATCH 06/13] new script and status fixes --- api/principals.go | 2 +- api/principals_test.go | 2 +- archive/redis.go | 551 +++++++++++++++++++++++++++++++++++++++++ archive/redis_test.go | 272 ++++++++++++++++++++ simulator/client.py | 22 ++ simulator/ops.py | 100 ++++++++ vault/sql.go | 90 ++++++- 7 files changed, 1027 insertions(+), 12 deletions(-) create mode 100644 archive/redis.go create mode 100644 archive/redis_test.go create mode 100644 simulator/ops.py diff --git a/api/principals.go b/api/principals.go index 5895da9..d59c4df 100644 --- a/api/principals.go +++ b/api/principals.go @@ -69,5 +69,5 @@ func (core *Core) DeletePrincipal(c *fiber.Ctx) error { if err != nil { return err } - return c.SendStatus(http.StatusOK) + return c.SendStatus(http.StatusNoContent) } diff --git a/api/principals_test.go b/api/principals_test.go index 15c9c3b..7a9bf77 100644 --- a/api/principals_test.go +++ b/api/principals_test.go @@ -70,7 +70,7 @@ func TestPrincipals(t *testing.T) { response := performRequest(t, app, request) - checkResponse(t, response, http.StatusOK, nil) + checkResponse(t, response, http.StatusNoContent, nil) // Check that the principal has been deleted request = newRequest(t, http.MethodGet, fmt.Sprintf("/principals/%s", newPrincipal.Username), map[string]string{ diff --git a/archive/redis.go b/archive/redis.go new file mode 100644 index 0000000..60c1ccd --- /dev/null +++ b/archive/redis.go @@ -0,0 +1,551 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/redis/go-redis/v9" +) + +const COLLECTIONS_PREFIX = "collection:" +const RECORDS_PREFIX = "record:" +const FIELDS_PREFIX = "field:" +const PRINCIPAL_PREFIX = "principal:" +const POLICY_PREFIX = "policy:" +const INDEX_PREFIX = "idx:" +const TOKEN_PREFIX = "token:" + +var ( + Prefix = map[string]string{ + "collection": COLLECTIONS_PREFIX, + "record": RECORDS_PREFIX, + } +) + +type RedisStore struct { + Client *redis.Client +} + +func NewRedisStore(addr, password string, db int) (*RedisStore, error) { + client := redis.NewClient(&redis.Options{ + Addr: addr, + Password: password, + DB: db, + }) + + _, err := client.Ping(context.Background()).Result() + if err != nil { + return nil, fmt.Errorf("unable to connect to Redis: %w", err) + } + + return &RedisStore{Client: client}, nil +} + +func (rs RedisStore) Flush(ctx context.Context) error { + _, err := rs.Client.FlushDB(ctx).Result() + return err +} + +func (rs RedisStore) GetCollections(ctx context.Context) ([]string, error) { + members, err := rs.Client.SMembers(ctx, COLLECTIONS_PREFIX).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return []string{}, nil + } + return []string{}, fmt.Errorf("failed to get collections: %w", err) + } + for i, member := range members { + members[i] = member[len(COLLECTIONS_PREFIX):] + } + return members, nil +} + +func (rs RedisStore) GetCollection(ctx context.Context, name string) (*Collection, error) { + colId := fmt.Sprintf("%s%s", COLLECTIONS_PREFIX, name) + dbCol := Collection{} + col, err := rs.Client.HGetAll(ctx, colId).Result() + if err != nil { + return nil, fmt.Errorf("failed to get data from Redis with key %s: %w", colId, err) + } + if len(col) == 0 { + return nil, &NotFoundError{"collection", name} + } + pipe := rs.Client.Pipeline() + for _, v := range col { + pipe.HGetAll(ctx, v) + } + fields, err := pipe.Exec(ctx) + if err != nil { + return nil, fmt.Errorf("failed to execute Redis pipeline: %w", err) + } + dbCol.Fields = make(map[string]Field, len(fields)) + for _, field := range fields { + dbCol.Fields[field.(*redis.MapStringStringCmd).Val()["name"]] = Field{ + Name: field.(*redis.MapStringStringCmd).Val()["name"], + Type: field.(*redis.MapStringStringCmd).Val()["type"], + IsIndexed: field.(*redis.MapStringStringCmd).Val()["is_indexed"] == "1", + } + } + dbCol.Name = name + return &dbCol, nil +} + +func (rs RedisStore) CreateCollection(ctx context.Context, c Collection) (string, error) { + colId := fmt.Sprintf("%s%s", COLLECTIONS_PREFIX, c.Name) + + exists, err := rs.Client.Exists(ctx, colId).Result() + if err != nil { + return c.Name, fmt.Errorf("failed to check collection existence: %w", err) + } + + if exists == 1 { + return c.Name, &ConflictError{colId} + } + + pipe := rs.Client.Pipeline() + pipe.SAdd(ctx, COLLECTIONS_PREFIX, colId) + for fieldName, fieldValue := range c.Fields { + fieldId := fmt.Sprintf("%s:%s%s", colId, FIELDS_PREFIX, fieldName) + pipe.HSet( + ctx, + fieldId, + "name", fieldName, + "type", fieldValue.Type, + "is_indexed", fieldValue.IsIndexed, + ) + pipe.HSet( + ctx, + colId, + fieldName, + fieldId, + ) + } + + _, err = pipe.Exec(ctx) + if err != nil { + return c.Name, fmt.Errorf("failed to execute Redis pipeline: %w", err) + } + + return c.Name, nil +} + +func (rs RedisStore) DeleteCollection(ctx context.Context, name string) error { + dbCollection, err := rs.GetCollection(ctx, name) + if err != nil { + return err + } + + colId := fmt.Sprintf("%s%s", COLLECTIONS_PREFIX, name) + recordIds, err := rs.Client.SMembers(ctx, fmt.Sprintf("%s:r", colId)).Result() + if err != nil { + return err + } + + pipe := rs.Client.Pipeline() + // Delete the collection + pipe.Del(ctx, colId) + pipe.SRem(ctx, COLLECTIONS_PREFIX, colId) + // Delete all records and indexes for the collection + for _, recordId := range recordIds { + dbRecord, err := rs.GetRecords(ctx, name, []string{recordId}) + if err != nil { + return err + } + + redisKey := fmt.Sprintf("%s%s", RECORDS_PREFIX, recordId) + pipe.Del(ctx, redisKey) + pipe.SRem(ctx, fmt.Sprintf("%s:r", colId), recordId) + for fieldName, fieldValue := range *dbRecord[recordId] { + if dbCollection.Fields[fieldName].IsIndexed { + pipe.SRem(ctx, formatIndex(fieldName, fieldValue), recordId) + } + } + } + + _, err = pipe.Exec(ctx) + if err != nil { + return fmt.Errorf("failed to execute Redis pipeline: %w", err) + } + + return nil +} + +func formatIndex(fieldName string, value string) string { + // Given that the value is encrypted for now, this might not be needed. + return fmt.Sprintf("%s%s_%s", INDEX_PREFIX, fieldName, strings.ToLower(value)) +} + +func (rs RedisStore) CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) { + colId := fmt.Sprintf("%s%s", COLLECTIONS_PREFIX, collectionName) + dbCol, err := rs.GetCollection(ctx, collectionName) + if err != nil { + return []string{}, err + } + + recordIds := []string{} + + pipe := rs.Client.Pipeline() + for _, record := range records { + recordId := GenerateId() + redisKey := fmt.Sprintf("%s%s", RECORDS_PREFIX, recordId) + pipe.SAdd(ctx, fmt.Sprintf("%s:r", colId), recordId) + for rFieldName, rFieldValue := range record { + // TODO: Validate types and schema here + field, ok := dbCol.Fields[rFieldName] + if !ok { + return []string{}, &ValueError{fmt.Sprintf("field %s does not exist in collection %s", rFieldName, collectionName)} + } + pipe.HSet( + ctx, + redisKey, + rFieldName, + rFieldValue, + ) + if field.IsIndexed { + pipe.SAdd(ctx, formatIndex(rFieldName, rFieldValue), recordId) + } + // TODO: Add unique constraint here, removed for simplicity + } + + recordIds = append(recordIds, recordId) + } + _, err = pipe.Exec(ctx) + if err != nil { + return nil, fmt.Errorf("failed to execute Redis pipeline: %w", err) + } + return recordIds, nil +} + +func (rs RedisStore) GetRecords(ctx context.Context, collectionName string, recordIds []string) (map[string]*Record, error) { + _, err := rs.GetCollection(ctx, collectionName) + if err != nil { + return nil, err + } + records := map[string]*Record{} + + for _, recordId := range recordIds { + redisKey := fmt.Sprintf("%s%s", RECORDS_PREFIX, recordId) + recordMap, err := rs.Client.HGetAll(ctx, redisKey).Result() + if len(recordMap) == 0 { + return nil, &NotFoundError{"record", recordId} + } + if err != nil { + return nil, fmt.Errorf("failed to get record with ID %s: %w", recordId, err) + } + + record := &Record{} + for key, value := range recordMap { + (*record)[key] = value + } + + records[recordId] = record + } + return records, nil +} + +func (rs RedisStore) GetRecordsFilter(ctx context.Context, collectionName string, fieldName string, value string) ([]string, error) { + dbCol, err := rs.GetCollection(ctx, collectionName) + if err != nil { + return []string{}, err + } + + if !dbCol.Fields[fieldName].IsIndexed { + return []string{}, ErrIndexError + } + + data, err := rs.Client.SMembers(ctx, formatIndex(fieldName, value)).Result() + if err != nil { + return []string{}, err + } + return data, nil +} + +func (rs RedisStore) UpdateRecord(ctx context.Context, collectionName string, recordId string, record Record) error { + dbCol, err := rs.GetCollection(ctx, collectionName) + if err != nil { + return err + } + + redisKey := fmt.Sprintf("%s%s", RECORDS_PREFIX, recordId) + pipe := rs.Client.Pipeline() + pipe.Del(ctx, redisKey) + for rFieldName, rFieldValue := range record { + field, ok := dbCol.Fields[rFieldName] + if !ok { + return &ValueError{fmt.Sprintf("field %s does not exist in collection %s", rFieldName, collectionName)} + } + pipe.HSet( + ctx, + redisKey, + rFieldName, + rFieldValue, + ) + if field.IsIndexed { + pipe.SAdd(ctx, formatIndex(rFieldName, rFieldValue), recordId) + } + } + + _, err = pipe.Exec(ctx) + if err != nil { + return fmt.Errorf("failed to execute Redis pipeline: %w", err) + } + return nil +} + +func (rs RedisStore) DeleteRecord(ctx context.Context, collectionName string, recordId string) error { + dbCol, err := rs.GetCollection(ctx, collectionName) + if err != nil { + return err + } + + dbRecord, err := rs.GetRecords(ctx, collectionName, []string{recordId}) + if err != nil { + return err + } + + pipe := rs.Client.Pipeline() + redisKey := fmt.Sprintf("%s%s", RECORDS_PREFIX, recordId) + pipe.Del(ctx, redisKey) + colId := fmt.Sprintf("%s%s", COLLECTIONS_PREFIX, collectionName) + pipe.SRem(ctx, fmt.Sprintf("%s:r", colId), recordId) + for fieldName, fieldValue := range *dbRecord[recordId] { + if dbCol.Fields[fieldName].IsIndexed { + pipe.SRem(ctx, formatIndex(fieldName, fieldValue), recordId) + } + } + + _, err = pipe.Exec(ctx) + if err != nil { + return fmt.Errorf("failed to execute Redis pipeline: %w", err) + } + + return nil +} + +func (rs RedisStore) CreatePrincipal(ctx context.Context, principal Principal) error { + principalId := fmt.Sprintf("%s%s", PRINCIPAL_PREFIX, principal.Username) + + exists, err := rs.Client.Exists(ctx, principalId).Result() + if err != nil { + return fmt.Errorf("failed to check principal existence: %w", err) + } + + if exists == 1 { + return &ConflictError{principalId} + } + + pipe := rs.Client.Pipeline() + pipe.SAdd(ctx, PRINCIPAL_PREFIX, principalId) + pipe.HSet( + context.Background(), + principalId, + "username", principal.Username, + "password", principal.Password, + "created_at", principal.CreatedAt, + "description", principal.Description, + ) + + for _, policy := range principal.Policies { + // TODO: Is this a bad idea? The sets can get out of sync + pipe.SAdd(ctx, fmt.Sprintf("%s:policies", principalId), policy) + pipe.SAdd(ctx, fmt.Sprintf("%smembers:%s", POLICY_PREFIX, policy), principal.Username) + } + + _, err = pipe.Exec(ctx) + if err != nil { + return fmt.Errorf("failed to execute Redis pipeline: %w", err) + } + return nil +} + +func (rs RedisStore) GetPrincipal(ctx context.Context, username string) (*Principal, error) { + principalId := fmt.Sprintf("%s%s", PRINCIPAL_PREFIX, username) + var dbPrincipal Principal + + pipe := rs.Client.Pipeline() + pipe.HGetAll(ctx, principalId) + pipe.SMembers(ctx, fmt.Sprintf("%s:policies", principalId)) + pipeRes, err := pipe.Exec(ctx) + if err != nil { + if err == redis.Nil { + return nil, &NotFoundError{"principal", principalId} + } + return nil, err + } + err = pipeRes[0].(*redis.MapStringStringCmd).Scan(&dbPrincipal) + if err != nil { + return nil, err + } + dbPrincipal.Policies = pipeRes[1].(*redis.StringSliceCmd).Val() + if dbPrincipal.Username == "" || dbPrincipal.Password == "" { + return nil, &NotFoundError{"principal", principalId} + } + return &dbPrincipal, nil +} + +func (rs RedisStore) DeletePrincipal(ctx context.Context, username string) error { + principalId := fmt.Sprintf("%s%s", PRINCIPAL_PREFIX, username) + + exists, err := rs.Client.Exists(ctx, principalId).Result() + if err != nil { + return fmt.Errorf("failed to check principal existence: %w", err) + } + + if exists != 1 { + return &NotFoundError{"principal", principalId} + } + + pipe := rs.Client.Pipeline() + pipe.Del(ctx, principalId) + pipe.SRem(ctx, PRINCIPAL_PREFIX, principalId) + + _, err = pipe.Exec(ctx) + if err != nil { + return fmt.Errorf("failed to execute Redis pipeline: %w", err) + } + return nil +} + +type RawPolicy struct { + PolicyId string `redis:"policy_id"` + Effect PolicyEffect `redis:"effect"` + Actions string `redis:"actions"` + Resources string `redis:"resources"` +} + +func (rawPolicy RawPolicy) toPolicy() *Policy { + var actions []PolicyAction + for _, action := range strings.Split(rawPolicy.Actions, ",") { + actions = append(actions, PolicyAction(action)) + } + policy := Policy{ + PolicyId: rawPolicy.PolicyId, + Effect: rawPolicy.Effect, + Actions: actions, + Resources: strings.Split(rawPolicy.Resources, ","), + } + + return &policy +} + +func (rs RedisStore) GetPolicy(ctx context.Context, policyId string) (*Policy, error) { + polRedisId := fmt.Sprintf("%s%s", POLICY_PREFIX, policyId) + cmd := rs.Client.HGetAll(ctx, polRedisId) + if err := cmd.Err(); err != nil { + return nil, err + } + + result, err := cmd.Result() + if err != nil { + return nil, err + } + + if len(result) == 0 { + return nil, &NotFoundError{"policy", polRedisId} + } + + var rawPolicy RawPolicy + if err := cmd.Scan(&rawPolicy); err != nil { + return nil, err + } + + return rawPolicy.toPolicy(), nil +} + +func (rs RedisStore) GetPoliciesById(ctx context.Context, policyIds []string) ([]*Policy, error) { + policies := []*Policy{} + pipeline := rs.Client.Pipeline() + + // Prepare the commands + cmds := make([]*redis.MapStringStringCmd, len(policyIds)) + for i, polId := range policyIds { + polRedisId := fmt.Sprintf("%s%s", POLICY_PREFIX, polId) + cmds[i] = pipeline.HGetAll(ctx, polRedisId) + } + + // Execute the pipeline + _, err := pipeline.Exec(ctx) + if err != nil && err != redis.Nil { + return nil, err + } + + // Process the results + for _, cmd := range cmds { + if err := cmd.Err(); err != nil { + if err != redis.Nil { + return nil, err + } + // Skip if not found + continue + } + var rawPolicy RawPolicy + if err := cmd.Scan(&rawPolicy); err != nil { + return nil, err + } + policies = append(policies, rawPolicy.toPolicy()) + } + return policies, nil +} + +func (rs RedisStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { + polRedisId := fmt.Sprintf("%s%s", POLICY_PREFIX, p.PolicyId) + var actions []string + for _, action := range p.Actions { + actions = append(actions, string(action)) + } + + _, err := rs.Client.HSet( + ctx, + polRedisId, + "policy_id", p.PolicyId, + "effect", string(p.Effect), + "actions", strings.Join(actions, ","), + "resources", strings.Join(p.Resources, ","), + ).Result() + + if err != nil { + return "", fmt.Errorf("failed to create policy: %w", err) + } + + return p.PolicyId, nil +} + +func (rs RedisStore) DeletePolicy(ctx context.Context, policyId string) error { + _, err := rs.GetPolicy(ctx, policyId) + if err != nil { + return err + } + + polRedisId := fmt.Sprintf("%s%s", POLICY_PREFIX, policyId) + _, err = rs.Client.Del(ctx, polRedisId).Result() + if err != nil { + return fmt.Errorf("failed to delete policy: %w", err) + } + return nil +} + +// TODO: +// - Set expiration? +func (rs RedisStore) CreateToken(ctx context.Context, tokenId string, value string) error { + return rs.Client.Set(ctx, fmt.Sprintf("%s%s", TOKEN_PREFIX, tokenId), value, 0).Err() +} +func (rs RedisStore) DeleteToken(ctx context.Context, tokenId string) error { + err := rs.Client.Del(ctx, fmt.Sprintf("%s%s", TOKEN_PREFIX, tokenId)).Err() + if err == redis.Nil { + return &NotFoundError{"token", tokenId} + } + + return err +} +func (rs RedisStore) GetTokenValue(ctx context.Context, tokenId string) (string, error) { + res, err := rs.Client.Get(ctx, fmt.Sprintf("%s%s", TOKEN_PREFIX, tokenId)).Result() + if err != nil { + if err == redis.Nil { + return "", &NotFoundError{"token", tokenId} + } + return "", err + } + return res, nil +} diff --git a/archive/redis_test.go b/archive/redis_test.go new file mode 100644 index 0000000..6eb0211 --- /dev/null +++ b/archive/redis_test.go @@ -0,0 +1,272 @@ +package vault + +import ( + "context" + "os" + "sort" + "testing" +) + +func initDB() (VaultDB, error) { + db, err := NewRedisStore( + os.Getenv("KEYDB_CONN_STRING"), + "", + 0, + ) + if err != nil { + return nil, err + } + db.Flush(context.Background()) + + return db, nil +} + +// TODO: These need to be separate tests latr on and probably mocked... + +func TestRedisStore(t *testing.T) { + t.Run("can crud collections and records", func(t *testing.T) { + ctx := context.Background() + db, err := initDB() + + if err != nil { + t.Fatal(err) + } + + dbCols, err := db.GetCollections(ctx) + if err != nil { + t.Fatal(err) + } + + if len(dbCols) != 0 { + t.Fatal("Expected 0 collections, got", len(dbCols)) + } + + col := Collection{Name: "customers", Fields: map[string]Field{ + "name": { + Type: "string", + IsIndexed: false, + }, + "age": { + Name: "age", + Type: "integer", + IsIndexed: false, + }, + "country": { + Name: "country", + Type: "string", + IsIndexed: true, + }, + }} + + // Can create collection + colID, err := db.CreateCollection(ctx, col) + if err != nil || colID == "" { + t.Fatal(err) + } + + dbCols, err = db.GetCollections(ctx) + if err != nil { + t.Fatal(err) + } + + if len(dbCols) != 1 { + t.Fatal("Expected 1 collection, got", len(dbCols)) + } + + if dbCols[0] != col.Name { + t.Fatal("Expected collection name to be 'test', got", dbCols[0]) + } + + newCol, _ := db.GetCollection(ctx, dbCols[0]) + fields := newCol.Fields + if fields["country"].Name != "country" || fields["country"].Type != "string" || !fields["country"].IsIndexed { + t.Fatal("Field props not matching.") + } + + // Can add records + records := []Record{ + {"name": "Simon", "age": "10", "country": "Ahibia"}, + {"name": "Ali", "age": "11", "country": "Bolonesia"}, + {"name": "Jim", "age": "22", "country": "Sarumania"}, + {"name": "Jeff", "age": "22", "country": "Ahibia"}, + } + + recordIds, err := db.CreateRecords(ctx, col.Name, records) + if err != nil { + t.Fatal(err) + } + + // Can get records + dbRecords, err := db.GetRecords(ctx, col.Name, recordIds) + if err != nil { + t.Fatal(err) + } + if len(dbRecords) != len(records) { + t.Fatalf("Expected %d records, got %d", len(records), len(dbRecords)) + } + + // Can update records + updateRecord := Record{"name": "UpdatedName", "age": "99", "country": "UpdatedCountry"} + err = db.UpdateRecord(ctx, col.Name, recordIds[0], updateRecord) + if err != nil { + t.Fatal(err) + } + + // Verify update of the record + updatedRecord, err := db.GetRecords(ctx, col.Name, []string{recordIds[0]}) + if err != nil { + t.Fatal(err) + } + if (*updatedRecord[recordIds[0]])["name"] != "UpdatedName" || + (*updatedRecord[recordIds[0]])["age"] != "99" || + (*updatedRecord[recordIds[0]])["country"] != "UpdatedCountry" { + t.Fatal("Record not updated correctly.") + } + + // Can delete records + err = db.DeleteRecord(ctx, col.Name, recordIds[0]) + if err != nil { + t.Fatal(err) + } + + // Verify deletion of the record + deleteRecord, err := db.GetRecords(ctx, col.Name, []string{recordIds[0]}) + if err == nil { + t.Fatal(err) + } + if len(deleteRecord) != 0 { + t.Fatal("Record not deleted.") + } + + }) + + t.Run("can delete collections", func(t *testing.T) { + ctx := context.Background() + db, err := initDB() + + if err != nil { + t.Fatal(err) + } + + col := Collection{Name: "customers", Fields: map[string]Field{ + "name": { + Type: "string", + IsIndexed: false, + }, + "age": { + Name: "age", + Type: "integer", + IsIndexed: false, + }, + "country": { + Name: "country", + Type: "string", + IsIndexed: true, + }, + }} + + // Can create collection + colID, err := db.CreateCollection(ctx, col) + if err != nil || colID == "" { + t.Fatal(err) + } + + // Can delete collection + err = db.DeleteCollection(ctx, colID) + if err != nil { + t.Fatal(err) + } + + // Collection should not exist after deletion + _, err = db.GetCollection(ctx, colID) + if err == nil { + t.Fatal("Expected error when getting deleted collection, got nil") + } + }) + + t.Run("can create and get principals", func(t *testing.T) { + ctx := context.Background() + db, err := initDB() + + if err != nil { + t.Fatal(err) + } + + // Note: password is not encrypted when storing this way, but it's just for testing purposes. + // The principal object should be created at the vault level. + principal := Principal{ + Username: "test", + Password: "test", + Description: "test", + CreatedAt: "0", + Policies: []string{"read-customers", "write-credit-cards"}, + } + + // Can create principal + err = db.CreatePrincipal(ctx, principal) + if err != nil { + t.Fatal(err) + } + + // Can get principal + dbPrincipalRead, err := db.GetPrincipal(ctx, principal.Username) + if err != nil { + t.Fatal(err) + } + + if dbPrincipalRead.Username != principal.Username || dbPrincipalRead.Description != principal.Description { + t.Fatal("Principal props not matching.") + } + + // Returned policies match + if len(principal.Policies) != len(dbPrincipalRead.Policies) { + t.Fatalf("Principal policies not matching. Expected %d, got %d", len(principal.Policies), len(dbPrincipalRead.Policies)) + } + + sort.Strings(principal.Policies) + sort.Strings(dbPrincipalRead.Policies) + for i, role := range principal.Policies { + if role != dbPrincipalRead.Policies[i] { + t.Fatalf("Principal policies not matching at index %d. Expected %s, got %s", i, role, dbPrincipalRead.Policies[i]) + } + } + }) + + t.Run("can delete principals", func(t *testing.T) { + ctx := context.Background() + db, err := initDB() + + if err != nil { + t.Fatal(err) + } + + // Note: password is not encrypted when storing this way, but it's just for testing purposes. + // The principal object should be created at the vault level. + principal := Principal{ + Username: "test", + Password: "test", + Description: "test", + CreatedAt: "0", + Policies: []string{"read-customers", "write-credit-cards"}, + } + + // Can create principal + err = db.CreatePrincipal(ctx, principal) + if err != nil { + t.Fatal(err) + } + + // Can delete principal + err = db.DeletePrincipal(ctx, principal.Username) + if err != nil { + t.Fatal(err) + } + + // Principal should not exist after deletion + _, err = db.GetPrincipal(ctx, principal.Username) + if err == nil { + t.Fatal("Expected error when getting deleted principal, got nil") + } + }) + +} diff --git a/simulator/client.py b/simulator/client.py index 0e0ac47..5599438 100644 --- a/simulator/client.py +++ b/simulator/client.py @@ -56,6 +56,18 @@ def create_principal( check_expected_status(response, expected_statuses) return response.json() + def delete_principal( + self, + username: str, + expected_statuses: Optional[list[int]] = None, + ) -> None: + response = requests.delete( + f"{self.vault_url}/principals/{username}", + auth=(self.username, self.password), + ) + check_expected_status(response, expected_statuses) + return + def create_collection( self, schema: dict[str, Any], expected_statuses: Optional[list[int]] = None ) -> None: @@ -114,3 +126,13 @@ def get_policy( ) check_expected_status(response, expected_statuses) return response.json() + + def delete_policy( + self, policy_id: str, expected_statuses: Optional[list[int]] = None + ) -> None: + response = requests.delete( + f"{self.vault_url}/policies/{policy_id}", + auth=(self.username, self.password), + ) + check_expected_status(response, expected_statuses) + return diff --git a/simulator/ops.py b/simulator/ops.py new file mode 100644 index 0000000..bbd8caf --- /dev/null +++ b/simulator/ops.py @@ -0,0 +1,100 @@ +from client import Actor, Policy +from wait import wait_for_api +import os + +# VAULT_URL from your client.py +VAULT_URL = os.environ.get("VAULT_URL", "http://localhost:3001") +wait_for_api(VAULT_URL) + +# Step 0: Initialize your actors + +ADMIN_USERNAME = "admin" +ADMIN_PASSWORD = "admin" + +SOMEBODY_USERNAME = "somebody" +SOMEBODY_PASSWORD = "somebody-password" + +admin = Actor(VAULT_URL, username=ADMIN_USERNAME, password=ADMIN_PASSWORD) + +# Create collection and some records +admin.create_collection( + schema={ + "name": "secrets", + "fields": { + "name": {"type": "string", "indexed": False}, + "value": {"type": "string", "indexed": False}, + }, + }, + expected_statuses=[201, 409], +) + +# Admin adds some records +records = admin.create_records( + "secrets", + [{"name": "admin-password", "value": "admin-password-value"}], + expected_statuses=[201], +) + + +# Create a temporary policy for somebody +admin.create_policy( + policy=Policy( + policy_id="secret-access", + effect="allow", + actions=["read", "write"], + resources=["/collections/secrets/*"], + ), + expected_statuses=[201, 409], +) + +# Admin recreates somebody +admin.delete_principal( + username=SOMEBODY_USERNAME, + expected_statuses=[204, 404], +) + +admin.create_principal( + username=SOMEBODY_USERNAME, + password=SOMEBODY_PASSWORD, + description="somebody", + policies=["secret-access"], + expected_statuses=[201, 409], +) + +somebody = Actor(VAULT_URL, SOMEBODY_USERNAME, SOMEBODY_PASSWORD) + +# Somebody reads the records +record = somebody.get_record( + collection="secrets", + record_id=records[0], + return_formats="name.plain,value.plain", + expected_statuses=[200], +) + +# Policy is removed +admin.delete_policy( + policy_id="secret-access", + expected_statuses=[204], +) + +# Policy is removed twice for good measure +admin.delete_policy( + policy_id="secret-access", + expected_statuses=[404], +) + +# Somebody can't read the records anymore +somebody.get_record( + collection="secrets", + record_id=records[0], + return_formats="name.plain,value.plain", + expected_statuses=[403], +) + +# Admin deletes somebody +admin.delete_principal( + username=SOMEBODY_USERNAME, + expected_statuses=[204], +) + +print("Ops use case completed successfully!") diff --git a/vault/sql.go b/vault/sql.go index 69f402b..d5ea3fc 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -40,7 +40,7 @@ type DbPrincipal struct { Description string CreatedAt time.Time UpdatedAt time.Time - PolicyIds pq.StringArray `gorm:"type:text[]"` + Policies []DbPolicy `gorm:"many2many:principal_policies;"` } type DbPolicy struct { @@ -200,7 +200,7 @@ func (st SqlStore) DeleteRecord(ctx context.Context, collectionName string, reco func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principal, error) { var dbPrincipal DbPrincipal - err := st.db.First(&dbPrincipal, "username = ?", username).Error + err := st.db.Preload("Policies").First(&dbPrincipal, "username = ?", username).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, &NotFoundError{"principal", username} @@ -208,35 +208,75 @@ func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principa return nil, err } + policyIds := make([]string, len(dbPrincipal.Policies)) + for i, policy := range dbPrincipal.Policies { + policyIds[i] = policy.ID + } + principal := Principal{ Username: dbPrincipal.Username, Password: dbPrincipal.Password, Description: dbPrincipal.Description, - Policies: dbPrincipal.PolicyIds, + Policies: policyIds, } return &principal, err } func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) error { + dbPrincipal := DbPrincipal{ Username: principal.Username, Password: principal.Password, Description: principal.Description, - PolicyIds: principal.Policies, } - err := st.db.Create(&dbPrincipal).Error + tx := st.db.Begin() + if err := tx.Error; err != nil { + return err + } + err := tx.Create(&dbPrincipal).Error if err != nil { + tx.Rollback() if errors.Is(err, gorm.ErrDuplicatedKey) { return &ConflictError{principal.Username} } return err } + var dbPolicies []DbPolicy + if err := tx.Where("id IN ?", principal.Policies).Find(&dbPolicies).Error; err != nil { + tx.Rollback() + return err + } + if err := tx.Model(&dbPrincipal).Association("Policies").Append(&dbPolicies); err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil } func (st SqlStore) DeletePrincipal(ctx context.Context, username string) error { - return st.db.Delete(&DbPrincipal{}, "username = ?", username).Error + // Start a transaction + tx := st.db.Begin() + if err := tx.Error; err != nil { + return err + } + + // First, delete associations in the many-to-many join table + if err := tx.Table("principal_policies").Where("db_principal_username = ?", username).Delete(nil).Error; err != nil { + tx.Rollback() + return err + } + + // Now, delete the principal itself + if err := tx.Where("username = ?", username).Delete(&DbPrincipal{}).Error; err != nil { + tx.Rollback() + return err + } + + // Commit the transaction + return tx.Commit().Error } func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, error) { @@ -265,7 +305,7 @@ func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, err func (st SqlStore) GetPoliciesById(ctx context.Context, policyIds []string) ([]*Policy, error) { var dBPolicies []DbPolicy - err := st.db.Find(&dBPolicies, policyIds).Error + err := st.db.Where("id IN ?", policyIds).Find(&dBPolicies).Error if err != nil { return nil, err } @@ -304,9 +344,36 @@ func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { return p.PolicyId, nil } -func (st SqlStore) DeletePolicy(ctx context.Context, policyId string) error { - gp := DbPolicy{ID: policyId} - return st.db.Delete(gp).Error +func (st SqlStore) DeletePolicy(ctx context.Context, policyID string) error { + // Start a transaction + tx := st.db.Begin() + if err := tx.Error; err != nil { + return err + } + + // Check if the policy exists + var policy DbPolicy + if err := tx.First(&policy, "id = ?", policyID).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return &NotFoundError{"policy", policyID} + } + return err + } + + // Directly delete associations in the many-to-many join table + if err := tx.Table("principal_policies").Where("db_policy_id = ?", policyID).Delete(nil).Error; err != nil { + tx.Rollback() + return err + } + + // Delete the policy itself + if err := tx.Where("id = ?", policyID).Delete(&DbPolicy{}).Error; err != nil { + tx.Rollback() + return err + } + + // Commit the transaction + return tx.Commit().Error } func (st SqlStore) CreateToken(ctx context.Context, tokenId string, value string) error { @@ -326,6 +393,9 @@ func (st SqlStore) GetTokenValue(ctx context.Context, tokenId string) (string, e } func (st SqlStore) Flush(ctx context.Context) error { + // Delete fk constraints first + st.db.Exec("delete from principal_policies") + // Delete all records tables := []string{} st.db.Raw("SELECT tablename FROM pg_tables WHERE schemaname='public'").Scan(&tables) for _, table := range tables { From 39ace49579ed7693d93e963d718105515f56d24f Mon Sep 17 00:00:00 2001 From: Subrose Date: Thu, 23 Nov 2023 14:56:17 +0000 Subject: [PATCH 07/13] raw working --- api/core.go | 11 +- simulator/password_manager.py | 28 +- vault/go.mod | 1 + vault/go.sum | 5 + vault/sql.go | 699 ++++++++++++++++++++++++---------- vault/vault.go | 28 +- vault/vault_test.go | 2 +- 7 files changed, 547 insertions(+), 227 deletions(-) diff --git a/api/core.go b/api/core.go index 85ecdb5..dde26eb 100644 --- a/api/core.go +++ b/api/core.go @@ -132,18 +132,23 @@ func (core *Core) Init() error { if core.conf.DEV_MODE { _ = core.vault.Db.Flush(ctx) } - _, _ = core.vault.Db.CreatePolicy(ctx, _vault.Policy{ + _, err := core.vault.Db.CreatePolicy(ctx, _vault.Policy{ PolicyId: "root", Effect: _vault.EffectAllow, Actions: []_vault.PolicyAction{_vault.PolicyActionWrite, _vault.PolicyActionRead}, Resources: []string{"*"}, }) + if err != nil { + panic(err) + } adminPrincipal := _vault.Principal{ Username: core.conf.VAULT_ADMIN_USERNAME, Password: core.conf.VAULT_ADMIN_PASSWORD, Description: "admin", Policies: []string{"root"}} - err := core.vault.CreatePrincipal(ctx, adminPrincipal, adminPrincipal.Username, adminPrincipal.Password, adminPrincipal.Description, adminPrincipal.Policies) - + err = core.vault.CreatePrincipal(ctx, adminPrincipal, adminPrincipal.Username, adminPrincipal.Password, adminPrincipal.Description, adminPrincipal.Policies) + if err != nil { + panic(err) + } return err } diff --git a/simulator/password_manager.py b/simulator/password_manager.py index 2e314a0..62f263e 100644 --- a/simulator/password_manager.py +++ b/simulator/password_manager.py @@ -23,7 +23,7 @@ # Step 2: Create collection admin.create_collection( schema={ - "name": "alice-passwords", + "name": "alice_passwords", "fields": { "service": {"type": "string", "indexed": False}, "password": {"type": "string", "indexed": False}, @@ -34,7 +34,7 @@ admin.create_collection( schema={ - "name": "bob-passwords", + "name": "bob_passwords", "fields": { "service": {"type": "string", "indexed": False}, "password": {"type": "string", "indexed": False}, @@ -47,10 +47,10 @@ admin.create_policy( policy=Policy( - policy_id="alice-access-own-passwords", + policy_id="alice-access-own_passwords", effect="allow", actions=["read", "write"], - resources=["/collections/alice-passwords/*"], + resources=["/collections/alice_passwords/*"], ), expected_statuses=[201, 409], ) @@ -58,10 +58,10 @@ admin.create_policy( policy=Policy( - policy_id="bob-access-own-passwords", + policy_id="bob-access-own_passwords", effect="allow", actions=["read", "write"], - resources=["/collections/bob-passwords/*"], + resources=["/collections/bob_passwords/*"], ), expected_statuses=[201, 409], ) @@ -70,7 +70,7 @@ username=ALICE_USERNAME, password=ALICE_PASSWORD, description="alice", - policies=["alice-access-own-passwords"], + policies=["alice-access-own_passwords"], expected_statuses=[201, 409], ) @@ -81,7 +81,7 @@ username=BOB_USERNAME, password=BOB_PASSWORD, description="bob", - policies=["bob-access-own-passwords"], + policies=["bob-access-own_passwords"], expected_statuses=[201, 409], ) @@ -90,7 +90,7 @@ # 2) Alice adds a password alice_password = "alicerocks" alice_password_res = alice.create_records( - "alice-passwords", + "alice_passwords", [{"service": "email", "password": alice_password}], expected_statuses=[201], ) @@ -98,7 +98,7 @@ # 4) Bob adds a password bob_password = "bobisthebest" bob_password_res = bob.create_records( - "bob-passwords", + "bob_passwords", [{"service": "email", "password": bob_password}], expected_statuses=[201], ) @@ -108,7 +108,7 @@ # 5) Alice views her passwords alice_retrieved_password = alice.get_record( - collection="alice-passwords", + collection="alice_passwords", record_id=alice_password_id, return_formats="service.plain,password.plain", expected_statuses=[200], @@ -118,7 +118,7 @@ # 6) Bob views his passwords bob_retrieved_password = bob.get_record( - collection="bob-passwords", + collection="bob_passwords", record_id=bob_password_id, return_formats="service.plain,password.plain", expected_statuses=[200], @@ -127,7 +127,7 @@ # 7) Alice can't CRUD Bob's passwords alice.get_record( - collection="bob-passwords", + collection="bob_passwords", record_id=bob_password_id, return_formats="service.plain,password.plain", expected_statuses=[403], @@ -135,7 +135,7 @@ # 8) Bob can't CRUD Alice's passwords bob.get_record( - collection="alice-passwords", + collection="alice_passwords", record_id=alice_password_id, return_formats="service.plain,password.plain", expected_statuses=[403], diff --git a/vault/go.mod b/vault/go.mod index 692449d..4892e64 100644 --- a/vault/go.mod +++ b/vault/go.mod @@ -21,6 +21,7 @@ require ( github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect + github.com/jmoiron/sqlx v1.3.5 // indirect github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/vault/go.sum b/vault/go.sum index e86313f..d890e8a 100644 --- a/vault/go.sum +++ b/vault/go.sum @@ -9,6 +9,7 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -28,8 +29,12 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/nyaruka/phonenumbers v1.1.6 h1:DcueYq7QrOArAprAYNoQfDgp0KetO4LqtnBtQC6Wyes= diff --git a/vault/sql.go b/vault/sql.go index d5ea3fc..be41044 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -1,182 +1,369 @@ package vault -// TODO: -// - Dynamic collection creation (no updates) -// - Add indexes - import ( "context" "encoding/json" "errors" - "log" - "os" - "time" + "fmt" + "reflect" + "strings" + + "database/sql" + _ "github.com/jackc/pgx/v5" + "github.com/jmoiron/sqlx" "github.com/lib/pq" - "gorm.io/datatypes" - "gorm.io/driver/postgres" - "gorm.io/gorm" - "gorm.io/gorm/logger" + _ "github.com/lib/pq" ) type SqlStore struct { - db *gorm.DB + db *sqlx.DB } type DbRecord struct { Id string CollectionName string - Record datatypes.JSON + Record json.RawMessage } type DbCollection struct { - Name string `gorm:"primaryKey"` - Collection datatypes.JSON -} - -type DbPrincipal struct { - Username string `gorm:"primaryKey"` - Password string - Description string - CreatedAt time.Time - UpdatedAt time.Time - Policies []DbPolicy `gorm:"many2many:principal_policies;"` -} - -type DbPolicy struct { - ID string `gorm:"primaryKey"` - Effect string - Actions pq.StringArray `gorm:"type:text[]"` - Resources pq.StringArray `gorm:"type:text[]"` - CreatedAt time.Time - UpdatedAt time.Time + Name string `db:"name"` + Collection json.RawMessage } type DbToken struct { - ID string `gorm:"primaryKey"` + ID string `db:"id"` Value string } -func NewSqlStore(dsn string) (*SqlStore, error) { - // Todo: Make sure we never log in production - dbLogger := logger.New( - log.New(os.Stdout, "\r\n", log.LstdFlags), - logger.Config{ - SlowThreshold: time.Second, // Slow SQL threshold - LogLevel: logger.Silent, // Log level - IgnoreRecordNotFoundError: true, - ParameterizedQueries: true, - Colorful: false, - }, - ) - - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ - TranslateError: true, - Logger: dbLogger, - }) +type CollectionMetadata struct { + Name string + FieldSchema json.RawMessage `db:"field_schema"` +} +func NewSqlStore(dsn string) (*SqlStore, error) { + db, err := sqlx.Connect("postgres", dsn) if err != nil { return nil, err } - err = db.AutoMigrate(&DbCollection{}, &DbRecord{}, &DbPrincipal{}, &DbPolicy{}, &DbToken{}) + + store := &SqlStore{db} + + err = store.CreateSchemas() if err != nil { return nil, err } - return &SqlStore{db}, err + return store, nil +} + +func (st *SqlStore) CreateSchemas() error { + tables := map[string]string{ + "principals": "CREATE TABLE IF NOT EXISTS principals (username TEXT PRIMARY KEY, password TEXT, description TEXT)", + "policies": "CREATE TABLE IF NOT EXISTS policies (id TEXT, effect TEXT, actions TEXT[], resources TEXT[])", + "tokens": "CREATE TABLE IF NOT EXISTS tokens (id TEXT, value TEXT)", + "collection_metadata": "CREATE TABLE IF NOT EXISTS collection_metadata (name TEXT, field_schema JSON)", + "principal_policies": "CREATE TABLE IF NOT EXISTS principal_policies (username TEXT, policy_id TEXT)", + } + + for _, query := range tables { + _, err := st.db.Exec(query) + if err != nil { + return err + } + } + + return nil +} + +func (st SqlStore) createCollectionTable(ctx context.Context, c Collection) error { + // Define a dynamic struct based on the Fields of the collection + var dynamicStructFields []reflect.StructField + + // Add an ID field to the struct + idField := reflect.StructField{ + Name: "ID", + Type: reflect.TypeOf(""), + Tag: reflect.StructTag(`db:"id"`), + } + dynamicStructFields = append(dynamicStructFields, idField) + + for fieldName := range c.Fields { + exportedFieldName := strings.Title(fieldName) + structField := reflect.StructField{ + Name: exportedFieldName, + Type: reflect.TypeOf(""), // Assuming all fields are strings for simplicity + Tag: reflect.StructTag(fmt.Sprintf(`db:"%s"`, fieldName)), + } + dynamicStructFields = append(dynamicStructFields, structField) + } + + // dynamicStruct := reflect.StructOf(dynamicStructFields) + // dynamicStructPtr := reflect.New(dynamicStruct).Interface() // Create a pointer to a new instance of the dynamic struct + + tableName := "collection_" + c.Name // Create a unique table name + + // Create the table using SQLX's MustExec with a pointer to the dynamic struct + // st.db.MustExecContext(ctx, "CREATE TABLE IF NOT EXISTS "+tableName+" (?)", dynamicStructPtr) + // Instead of using the dynamic struct directly, we will generate the SQL query manually + var queryBuilder strings.Builder + queryBuilder.WriteString("CREATE TABLE IF NOT EXISTS " + tableName + " (id TEXT") + for fieldName := range c.Fields { + queryBuilder.WriteString(", " + fieldName + " TEXT") + } + queryBuilder.WriteString(")") + st.db.MustExecContext(ctx, queryBuilder.String()) + + return nil +} + +func (st SqlStore) CreateCollection(ctx context.Context, c Collection) (string, error) { + // Convert the Fields map to JSON for storing in the collection_metadata table + fieldSchema, err := json.Marshal(c.Fields) + if err != nil { + return "", err + } + + // Create a new CollectionMetadata instance + collectionMetadata := CollectionMetadata{ + Name: c.Name, + FieldSchema: fieldSchema, + } + + // Save the collection metadata + _, err = st.db.NamedExecContext(ctx, "INSERT INTO collection_metadata (name, field_schema) VALUES (:name, :field_schema)", &collectionMetadata) + if err != nil { + return "", err + } + + // Dynamically create a table for the collection + if err := st.createCollectionTable(ctx, c); err != nil { + return "", err + } + + return collectionMetadata.Name, nil } func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection, error) { - var gc DbCollection - err := st.db.First(&gc, "name = ?", name).Error + var collectionMetadata CollectionMetadata + err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", name) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, sql.ErrNoRows) { return nil, &NotFoundError{"collection", name} } return nil, err } - var col Collection - err = json.Unmarshal(gc.Collection, &col) + var fields map[string]Field + err = json.Unmarshal(collectionMetadata.FieldSchema, &fields) if err != nil { return nil, err } - return &col, err + return &Collection{Name: collectionMetadata.Name, Fields: fields}, nil } func (st SqlStore) GetCollections(ctx context.Context) ([]string, error) { - var gcs []DbCollection + var collectionMetadatas []CollectionMetadata - err := st.db.Find(&gcs).Error + err := st.db.SelectContext(ctx, &collectionMetadatas, "SELECT * FROM collection_metadata") - colNames := make([]string, len(gcs)) - for i, col := range gcs { - colNames[i] = col.Name + collectionNames := make([]string, len(collectionMetadatas)) + for i, collectionMetadata := range collectionMetadatas { + collectionNames[i] = collectionMetadata.Name } - return colNames, err + return collectionNames, err } -func (st SqlStore) CreateCollection(ctx context.Context, c Collection) (string, error) { - b, err := json.Marshal(c) +func (st SqlStore) DeleteCollection(ctx context.Context, name string) error { + // Start a transaction + tx, err := st.db.BeginTxx(ctx, nil) if err != nil { - return "", err + return err } - gormCol := DbCollection{Name: c.Name, Collection: datatypes.JSON(b)} - result := st.db.Create(&gormCol) - if result.Error != nil { - switch result.Error { - case gorm.ErrDuplicatedKey: - return "", &ConflictError{c.Name} + // Delete the collection metadata + _, err = tx.ExecContext(ctx, "DELETE FROM collection_metadata WHERE name = $1", name) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + tx.Rollback() + return &NotFoundError{"collection", name} } + tx.Rollback() + return err } - return c.Name, nil -} + // Delete the dynamic table + tableName := "collection_" + name + _, err = tx.ExecContext(ctx, "DROP TABLE IF EXISTS "+tableName) + if err != nil { + tx.Rollback() + return err + } -func (st SqlStore) DeleteCollection(ctx context.Context, name string) error { - gc := DbCollection{Name: name, Collection: nil} + // Commit the transaction + err = tx.Commit() + if err != nil { + return err + } - return st.db.Delete(&gc).Error + return nil } - func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) { - recordIds := make([]string, len(records)) - gormRecords := make([]DbRecord, len(records)) + var collectionMetadata CollectionMetadata + err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", collectionName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, &NotFoundError{"collection", collectionName} + } + return nil, err + } + var fields map[string]Field + err = json.Unmarshal(collectionMetadata.FieldSchema, &fields) + if err != nil { + return nil, err + } + + // Create slice for field names and initialize placeholders + fieldNames := make([]string, 0, len(fields)) + placeholders := make([]string, 0, len(fields)) + idx := 2 // Start from 2 because $1 is reserved for recordId + + for fieldName := range fields { + fieldNames = append(fieldNames, fieldName) + placeholders = append(placeholders, fmt.Sprintf("$%d", idx)) + idx++ + } + + // Prepare SQL statement + query := fmt.Sprintf("INSERT INTO collection_%s (id, %s) VALUES ($1, %s)", collectionName, strings.Join(fieldNames, ", "), strings.Join(placeholders, ", ")) + stmt, err := st.db.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + defer stmt.Close() + + // Processing records + recordIds := make([]string, len(records)) for i, record := range records { + // Validate record fields + if len(record) != len(fields) { + return nil, errors.New("record does not match schema") + } + recordId := GenerateId() - jsonBytes, err := json.Marshal(record) + recordIds[i] = recordId + + // Prepare values for insertion + values := make([]interface{}, len(fields)+1) + values[0] = recordId + for j, fieldName := range fieldNames { + if value, ok := record[fieldName]; ok { + values[j+1] = value + } else { + return nil, fmt.Errorf("missing field: %s", fieldName) + } + } + + // Execute the prepared statement + _, err = stmt.ExecContext(ctx, values...) if err != nil { return nil, err } - gormRecords[i] = DbRecord{Id: recordId, CollectionName: collectionName, Record: datatypes.JSON(jsonBytes)} - recordIds[i] = recordId - } - err := st.db.CreateInBatches(&gormRecords, len(records)).Error - if err != nil { - return nil, err } return recordIds, nil } +// func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) { +// var collectionMetadata CollectionMetadata +// err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", collectionName) +// if err != nil { +// if errors.Is(err, sql.ErrNoRows) { +// return nil, &NotFoundError{"collection", collectionName} +// } +// return nil, err +// } + +// var fields map[string]Field +// err = json.Unmarshal(collectionMetadata.FieldSchema, &fields) +// if err != nil { +// return nil, err +// } + +// recordIds := make([]string, len(records)) +// for i, record := range records { +// recordId := GenerateId() +// recordIds[i] = recordId +// // Create a slice to hold the field names and values +// var fieldNames []string +// var fieldValues []interface{} +// for fieldName, fieldValue := range record { +// if _, ok := fields[fieldName]; ok { +// fieldNames = append(fieldNames, fieldName) +// fieldValues = append(fieldValues, fieldValue) +// } +// } + +// // Generate placeholders for each field value +// placeholders := make([]string, len(fieldValues)) +// for i := range placeholders { +// placeholders[i] = "$" + strconv.Itoa(i+2) // Start from $2 as $1 is reserved for recordId +// } + +// // Construct the query +// query := fmt.Sprintf( +// "INSERT INTO collection_%s (id, %s) VALUES ($1, %s)", +// collectionName, +// strings.Join(fieldNames, ", "), +// strings.Join(placeholders, ", "), +// ) + +// // Append the recordId and fieldValues to form the final values for the query +// values := make([]interface{}, len(fieldValues)+1) +// values[0] = recordId +// copy(values[1:], fieldValues) + +// _, err = st.db.ExecContext(ctx, query, values...) +// if err != nil { +// return nil, err +// } +// } +// return recordIds, nil +// } + func (st SqlStore) GetRecords(ctx context.Context, collectionName string, recordIDs []string) (map[string]*Record, error) { - var grs []DbRecord - err := st.db.Where("id IN ?", recordIDs).Find(&grs).Error + var collectionMetadata CollectionMetadata + err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", collectionName) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, &NotFoundError{"collection", collectionName} + } return nil, err } - var records = make(map[string]*Record) - for _, gr := range grs { - var record Record - err := json.Unmarshal(gr.Record, &record) - if err != nil { - return nil, err + + var fields map[string]Field + err = json.Unmarshal(collectionMetadata.FieldSchema, &fields) + if err != nil { + return nil, err + } + + records := make(map[string]*Record) + for _, recordId := range recordIDs { + record := make(Record) + for fieldName := range fields { + var fieldValue string + err := st.db.GetContext(ctx, &fieldValue, "SELECT "+fieldName+" FROM collection_"+collectionName+" WHERE id = $1", recordId) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, &NotFoundError{"record", recordId} + } + return nil, err + } + record[fieldName] = fieldValue } - records[gr.Id] = &record + records[recordId] = &record } return records, nil - } func (st SqlStore) GetRecordsFilter(ctx context.Context, collectionName string, fieldName string, value string) ([]string, error) { @@ -185,221 +372,343 @@ func (st SqlStore) GetRecordsFilter(ctx context.Context, collectionName string, } func (st SqlStore) UpdateRecord(ctx context.Context, collectionName string, recordID string, record Record) error { - r, err := json.Marshal(record) + var collectionMetadata CollectionMetadata + err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", collectionName) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return &NotFoundError{"collection", collectionName} + } return err } - gr := DbRecord{Id: recordID, CollectionName: collectionName, Record: datatypes.JSON(r)} - return st.db.Model(&DbRecord{}).Where("id = ?", recordID).Updates(gr).Error + + var fields map[string]Field + err = json.Unmarshal(collectionMetadata.FieldSchema, &fields) + if err != nil { + return err + } + + var fieldNames []string + var fieldValues []interface{} + for fieldName, fieldValue := range record { + if _, ok := fields[fieldName]; ok { + fieldNames = append(fieldNames, fieldName) + fieldValues = append(fieldValues, fieldValue) + } + } + + setClause := make([]string, len(fieldNames)) + for i, fieldName := range fieldNames { + setClause[i] = fmt.Sprintf("%s = $%d", fieldName, i+1) + } + + query := fmt.Sprintf("UPDATE collection_%s SET %s WHERE id = $%d", collectionName, strings.Join(setClause, ", "), len(fieldNames)+1) + _, err = st.db.ExecContext(ctx, query, append(fieldValues, recordID)...) + return err } func (st SqlStore) DeleteRecord(ctx context.Context, collectionName string, recordID string) error { - gr := DbRecord{Id: recordID, CollectionName: collectionName} - return st.db.Delete(&gr).Error + _, err := st.db.ExecContext(ctx, "DELETE FROM collection_"+collectionName+" WHERE id = $1", recordID) + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return &NotFoundError{"record", recordID} + } + return err + } + return nil } func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principal, error) { - var dbPrincipal DbPrincipal - err := st.db.Preload("Policies").First(&dbPrincipal, "username = ?", username).Error + var principal Principal + err := st.db.GetContext(ctx, &principal, "SELECT * FROM principals WHERE username = $1", username) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, sql.ErrNoRows) { return nil, &NotFoundError{"principal", username} } return nil, err } - policyIds := make([]string, len(dbPrincipal.Policies)) - for i, policy := range dbPrincipal.Policies { - policyIds[i] = policy.ID + rows, err := st.db.QueryxContext(ctx, "SELECT policy_id FROM principal_policies WHERE username = $1", username) + if err != nil { + return nil, err } + defer rows.Close() - principal := Principal{ - Username: dbPrincipal.Username, - Password: dbPrincipal.Password, - Description: dbPrincipal.Description, - Policies: policyIds, + var policyIds []string + for rows.Next() { + var policyId string + if err := rows.Scan(&policyId); err != nil { + return nil, err + } + policyIds = append(policyIds, policyId) } - return &principal, err + if err := rows.Err(); err != nil { + return nil, err + } + + principal.Policies = policyIds + + return &principal, nil } func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) error { - - dbPrincipal := DbPrincipal{ - Username: principal.Username, - Password: principal.Password, - Description: principal.Description, - } - tx := st.db.Begin() - if err := tx.Error; err != nil { + tx, err := st.db.BeginTxx(ctx, nil) + if err != nil { return err } - err := tx.Create(&dbPrincipal).Error + _, err = tx.NamedExecContext(ctx, "INSERT INTO principals (username, password, description) VALUES (:username, :password, :description)", &principal) if err != nil { tx.Rollback() - if errors.Is(err, gorm.ErrDuplicatedKey) { - return &ConflictError{principal.Username} + if pqErr, ok := err.(*pq.Error); ok { + if pqErr.Code == "23505" { + return &ConflictError{principal.Username} + } } return err } - var dbPolicies []DbPolicy - if err := tx.Where("id IN ?", principal.Policies).Find(&dbPolicies).Error; err != nil { - tx.Rollback() - return err + + for _, policyId := range principal.Policies { + _, err = tx.ExecContext(ctx, "INSERT INTO principal_policies (username, policy_id) VALUES ($1, $2)", principal.Username, policyId) + if err != nil { + tx.Rollback() + return err + } } - if err := tx.Model(&dbPrincipal).Association("Policies").Append(&dbPolicies); err != nil { - tx.Rollback() + err = tx.Commit() + if err != nil { return err } - tx.Commit() return nil } func (st SqlStore) DeletePrincipal(ctx context.Context, username string) error { // Start a transaction - tx := st.db.Begin() - if err := tx.Error; err != nil { + tx, err := st.db.BeginTxx(ctx, nil) + if err != nil { return err } // First, delete associations in the many-to-many join table - if err := tx.Table("principal_policies").Where("db_principal_username = ?", username).Delete(nil).Error; err != nil { + _, err = tx.ExecContext(ctx, "DELETE FROM principal_policies WHERE username = $1", username) + if err != nil { tx.Rollback() return err } // Now, delete the principal itself - if err := tx.Where("username = ?", username).Delete(&DbPrincipal{}).Error; err != nil { + _, err = tx.ExecContext(ctx, "DELETE FROM principals WHERE username = $1", username) + if err != nil { tx.Rollback() return err } // Commit the transaction - return tx.Commit().Error + err = tx.Commit() + if err != nil { + return err + } + + return nil } func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, error) { - var gp DbPolicy - err := st.db.First(&gp, "id = ?", policyId).Error + var p Policy + err := st.db.GetContext(ctx, &p, "SELECT * FROM policies WHERE id = $1", policyId) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, sql.ErrNoRows) { return nil, &NotFoundError{"policy", policyId} } return nil, err } - var policyActions []PolicyAction - for _, action := range gp.Actions { - policyActions = append(policyActions, PolicyAction(action)) - } - p := Policy{ - PolicyId: gp.ID, - Effect: PolicyEffect(gp.Effect), - Actions: policyActions, - Resources: gp.Resources, - } - return &p, err } func (st SqlStore) GetPoliciesById(ctx context.Context, policyIds []string) ([]*Policy, error) { - var dBPolicies []DbPolicy - err := st.db.Where("id IN ?", policyIds).Find(&dBPolicies).Error + if len(policyIds) == 0 { + return []*Policy{}, nil + } + query, args, err := sqlx.In("SELECT id, effect, actions, resources FROM policies WHERE id IN (?)", policyIds) if err != nil { return nil, err } + query = st.db.Rebind(query) + rows, err := st.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() - policies := make([]*Policy, len(dBPolicies)) - for i, dbp := range dBPolicies { - actions := dbp.Actions - var policyActions []PolicyAction - for _, action := range actions { - policyActions = append(policyActions, PolicyAction(action)) + policies := make([]*Policy, 0) + for rows.Next() { + var id, effect string + var actions []string + var resources []string + + err = rows.Scan(&id, &effect, pq.Array(&actions), pq.Array(&resources)) + if err != nil { + return nil, err } - policies[i] = &Policy{ - PolicyId: dbp.ID, - Effect: PolicyEffect(dbp.Effect), - Actions: policyActions, - Resources: dbp.Resources, + actionList := make([]PolicyAction, len(actions)) + for i, action := range actions { + actionList[i] = PolicyAction(action) } + + policies = append(policies, &Policy{ + PolicyId: id, + Effect: PolicyEffect(effect), + Actions: actionList, + Resources: resources, + }) + } + + if err = rows.Err(); err != nil { + return nil, err } + return policies, nil } func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { - actionStrings := make([]string, len(p.Actions)) + + // Start a transaction + tx, err := st.db.BeginTxx(ctx, nil) + if err != nil { + return "", err + } + + query := "INSERT INTO policies (id, effect, actions, resources) VALUES (:id, :effect, :actions, :resources)" + actions := make(pq.StringArray, len(p.Actions)) for i, action := range p.Actions { - actionStrings[i] = string(action) + actions[i] = string(action) } - dbPolicy := DbPolicy{ID: p.PolicyId, Effect: string(p.Effect), Actions: actionStrings, Resources: p.Resources} - err := st.db.Create(&dbPolicy).Error + resources := make(pq.StringArray, len(p.Resources)) + for i, resource := range p.Resources { + resources[i] = resource + } + query, args, err := sqlx.Named(query, map[string]interface{}{ + "id": p.PolicyId, + "effect": string(p.Effect), + "actions": actions, + "resources": resources, + }) + if err != nil { - if errors.Is(err, gorm.ErrDuplicatedKey) { + tx.Rollback() + return "", err + } + query = tx.Rebind(query) + _, err = tx.ExecContext(ctx, query, args...) + if err != nil { + tx.Rollback() + if errors.Is(err, sql.ErrNoRows) { return "", &ConflictError{p.PolicyId} } return "", err } + + // Commit the transaction + err = tx.Commit() + if err != nil { + return "", err + } + return p.PolicyId, nil } func (st SqlStore) DeletePolicy(ctx context.Context, policyID string) error { // Start a transaction - tx := st.db.Begin() - if err := tx.Error; err != nil { + tx, err := st.db.BeginTxx(ctx, nil) + if err != nil { return err } - - // Check if the policy exists - var policy DbPolicy - if err := tx.First(&policy, "id = ?", policyID).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return &NotFoundError{"policy", policyID} - } + // Delete the policy itself + result, err := tx.ExecContext(ctx, "DELETE FROM policies WHERE id = $1", policyID) + if err != nil { + tx.Rollback() return err } - - // Directly delete associations in the many-to-many join table - if err := tx.Table("principal_policies").Where("db_policy_id = ?", policyID).Delete(nil).Error; err != nil { + rowsAffected, err := result.RowsAffected() + if err != nil { tx.Rollback() return err } + if rowsAffected == 0 { + tx.Rollback() + return &NotFoundError{"policy", policyID} + } - // Delete the policy itself - if err := tx.Where("id = ?", policyID).Delete(&DbPolicy{}).Error; err != nil { + // Directly delete associations in the many-to-many join table + _, err = tx.ExecContext(ctx, "DELETE FROM principal_policies WHERE policy_id = $1", policyID) + if err != nil { tx.Rollback() return err } // Commit the transaction - return tx.Commit().Error + err = tx.Commit() + if err != nil { + return err + } + + return nil } func (st SqlStore) CreateToken(ctx context.Context, tokenId string, value string) error { gt := DbToken{ID: tokenId, Value: value} - return st.db.Create(>).Error + _, err := st.db.NamedExecContext(ctx, "INSERT INTO tokens (id, value) VALUES (:id, :value)", >) + return err } func (st SqlStore) DeleteToken(ctx context.Context, tokenId string) error { - gt := DbToken{ID: tokenId} - return st.db.Delete(>).Error + _, err := st.db.ExecContext(ctx, "DELETE FROM tokens WHERE id = $1", tokenId) + return err } func (st SqlStore) GetTokenValue(ctx context.Context, tokenId string) (string, error) { var gt DbToken - err := st.db.First(>, "id = ?", tokenId).Error + err := st.db.GetContext(ctx, >, "SELECT * FROM tokens WHERE id = $1", tokenId) return gt.Value, err } func (st SqlStore) Flush(ctx context.Context) error { - // Delete fk constraints first - st.db.Exec("delete from principal_policies") - // Delete all records + // Drop all tables tables := []string{} - st.db.Raw("SELECT tablename FROM pg_tables WHERE schemaname='public'").Scan(&tables) + err := st.db.SelectContext(ctx, &tables, "SELECT tablename FROM pg_tables WHERE schemaname='public'") + if err != nil { + return err + } for _, table := range tables { - st.db.Exec("DELETE FROM " + table) + _, err = st.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+table) + if err != nil { + return err + } + } + // Recreate schemas + err = st.CreateSchemas() + if err != nil { + return err } return nil } + +// func (st SqlStore) CreateCollection(ctx context.Context, c Collection) (string, error) { +// b, err := json.Marshal(c) +// if err != nil { +// return "", err +// } + +// gormCol := DbCollection{Name: c.Name, Collection: datatypes.JSON(b)} +// result := st.db.Create(&gormCol) +// if result.Error != nil { +// switch result.Error { +// case gorm.ErrDuplicatedKey: +// return "", &ConflictError{c.Name} +// } +// } + +// return c.Name, nil +// } diff --git a/vault/vault.go b/vault/vault.go index 8bb2e4c..00850f0 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -10,14 +10,14 @@ import ( ) type Field struct { - Name string `redis:"name"` - Type string `redis:"type"` - IsIndexed bool `redis:"is_indexed"` + Name string + Type string + IsIndexed bool } type Collection struct { - Name string `redis:"name"` - Fields map[string]Field `redis:"fields"` + Name string + Fields map[string]Field } type Record map[string]string // field name -> value @@ -55,18 +55,18 @@ const ( ) type Policy struct { - PolicyId string `redis:"policy_id" json:"policy_id" validate:"required"` - Effect PolicyEffect `redis:"effect" json:"effect" validate:"required"` - Actions []PolicyAction `redis:"actions" json:"actions" validate:"required"` - Resources []string `redis:"resources" json:"resources" validate:"required"` + PolicyId string `json:"policy_id" validate:"required"` + Effect PolicyEffect `json:"effect" validate:"required"` + Actions []PolicyAction `json:"actions" validate:"required"` + Resources []string `json:"resources" validate:"required"` } type Principal struct { - Username string `redis:"username"` - Password string `redis:"password"` - Description string `redis:"description"` - CreatedAt string `redis:"created_at"` - Policies []string `redis:"policies"` + Username string + Password string + Description string + CreatedAt string + Policies []string } type Request struct { diff --git a/vault/vault_test.go b/vault/vault_test.go index 964dbcc..f4086cc 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -181,7 +181,7 @@ func TestVault(t *testing.T) { t.Run("can update records", func(t *testing.T) { vault, _, _ := initVault(t) - col := Collection{Name: "test_collection", Fields: map[string]Field{ + col := Collection{Name: "testing", Fields: map[string]Field{ "test_field": { Name: "test_field", Type: "string", From f09a4f2b387bc9ccaaf7ad701b1796f6ac374aea Mon Sep 17 00:00:00 2001 From: Subrose Date: Thu, 23 Nov 2023 15:38:37 +0000 Subject: [PATCH 08/13] sql working --- api/collections_test.go | 4 +- api/go.mod | 9 +-- api/go.sum | 29 +++----- archive/redis.go | 2 +- go.work.sum | 2 + vault/sql.go | 145 ++++++++++++++++------------------------ vault/vault.go | 7 +- vault/vault_test.go | 2 +- 8 files changed, 80 insertions(+), 120 deletions(-) diff --git a/api/collections_test.go b/api/collections_test.go index 449e9e5..24f564a 100644 --- a/api/collections_test.go +++ b/api/collections_test.go @@ -62,7 +62,7 @@ func TestCollections(t *testing.T) { t.Run("can delete a collection", func(t *testing.T) { // Create a dummy collection collectionToDelete := CollectionModel{ - Name: "delete-me", + Name: "delete_me", Fields: map[string]CollectionFieldModel{ "name": {Type: "name", IsIndexed: true}, }, @@ -74,7 +74,7 @@ func TestCollections(t *testing.T) { response := performRequest(t, app, request) checkResponse(t, response, http.StatusCreated, nil) // Delete it - request = newRequest(t, http.MethodDelete, "/collections/delete-me", map[string]string{ + request = newRequest(t, http.MethodDelete, "/collections/delete_me", map[string]string{ "Authorization": createBasicAuthHeader(core.conf.VAULT_ADMIN_USERNAME, core.conf.VAULT_ADMIN_PASSWORD), }, nil) diff --git a/api/go.mod b/api/go.mod index 25bce10..6f1cfe7 100644 --- a/api/go.mod +++ b/api/go.mod @@ -31,21 +31,18 @@ require ( ) require ( - github.com/cespare/xxhash/v2 v2.2.0 // indirect - github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-playground/locales v0.14.0 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx/v5 v5.4.3 // indirect - github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.5 // indirect + github.com/jmoiron/sqlx v1.3.5 // indirect github.com/leodido/go-urn v1.2.1 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.17 // indirect github.com/nyaruka/phonenumbers v1.1.6 // indirect - github.com/redis/go-redis/v9 v9.0.2 // indirect github.com/segmentio/ksuid v1.0.4 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect golang.org/x/crypto v0.14.0 // indirect @@ -53,8 +50,6 @@ require ( golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect google.golang.org/protobuf v1.28.0 // indirect - gorm.io/driver/postgres v1.5.4 // indirect - gorm.io/gorm v1.25.5 // indirect ) replace github.com/subrose/vault v0.0.0 => ../vault diff --git a/api/go.sum b/api/go.sum index f455b76..7dff707 100644 --- a/api/go.sum +++ b/api/go.sum @@ -29,14 +29,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= -github.com/bsm/ginkgo/v2 v2.5.0 h1:aOAnND1T40wEdAtkGSkvSICWeQ8L3UASX7YVCqQx+eQ= -github.com/bsm/ginkgo/v2 v2.5.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= -github.com/bsm/gomega v1.20.0 h1:JhAwLmtRzXFTx2AkALSLa8ijZafntmhSoU63Ok18Uq8= -github.com/bsm/gomega v1.20.0/go.mod h1:JifAceMQ4crZIWYUKrlGcmbN3bqHogVTADMD2ATsbwk= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= @@ -46,8 +40,6 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= -github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -76,6 +68,9 @@ github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/j github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ= github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-test/deep v1.0.2-0.20181118220953-042da051cf31/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -161,12 +156,10 @@ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= -github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= @@ -194,6 +187,9 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= @@ -209,6 +205,9 @@ github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPn github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI= @@ -269,8 +268,6 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/redis/go-redis/v9 v9.0.2 h1:BA426Zqe/7r56kCcvxYLWe1mkaz71LKF77GwgFzSxfE= -github.com/redis/go-redis/v9 v9.0.2/go.mod h1:/xDTe9EF1LM61hek62Poq2nzQSGj0xSrEtEHbBQevps= github.com/rhnvrm/simples3 v0.6.1/go.mod h1:Y+3vYm2V7Y4VijFoJHHTrja6OgPrJ2cBti8dPGkC3sA= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= @@ -503,10 +500,6 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= -gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= -gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= -gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= diff --git a/archive/redis.go b/archive/redis.go index 60c1ccd..81f1f93 100644 --- a/archive/redis.go +++ b/archive/redis.go @@ -454,7 +454,7 @@ func (rs RedisStore) GetPolicy(ctx context.Context, policyId string) (*Policy, e return rawPolicy.toPolicy(), nil } -func (rs RedisStore) GetPoliciesById(ctx context.Context, policyIds []string) ([]*Policy, error) { +func (rs RedisStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Policy, error) { policies := []*Policy{} pipeline := rs.Client.Pipeline() diff --git a/go.work.sum b/go.work.sum index 03a9f54..6e82773 100644 --- a/go.work.sum +++ b/go.work.sum @@ -17,6 +17,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.7.2 h1:ol2Y5DWqnJeKqNd8th7JWzBtqu63x github.com/aws/smithy-go v1.8.0 h1:AEwwwXQZtUwP5Mz506FeXXrKBe0jA8gVM+1gEcSRooc= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY= +github.com/bsm/ginkgo/v2 v2.5.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= +github.com/bsm/gomega v1.20.0/go.mod h1:JifAceMQ4crZIWYUKrlGcmbN3bqHogVTADMD2ATsbwk= github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk= github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403 h1:cqQfy1jclcSy/FwLjemeg3SR1yaINm74aQyupQ0Bl8M= diff --git a/vault/sql.go b/vault/sql.go index be41044..eefa343 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -20,22 +20,11 @@ type SqlStore struct { db *sqlx.DB } -type DbRecord struct { - Id string - CollectionName string - Record json.RawMessage -} - type DbCollection struct { Name string `db:"name"` Collection json.RawMessage } -type DbToken struct { - ID string `db:"id"` - Value string -} - type CollectionMetadata struct { Name string FieldSchema json.RawMessage `db:"field_schema"` @@ -274,63 +263,6 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec return recordIds, nil } -// func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) { -// var collectionMetadata CollectionMetadata -// err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", collectionName) -// if err != nil { -// if errors.Is(err, sql.ErrNoRows) { -// return nil, &NotFoundError{"collection", collectionName} -// } -// return nil, err -// } - -// var fields map[string]Field -// err = json.Unmarshal(collectionMetadata.FieldSchema, &fields) -// if err != nil { -// return nil, err -// } - -// recordIds := make([]string, len(records)) -// for i, record := range records { -// recordId := GenerateId() -// recordIds[i] = recordId -// // Create a slice to hold the field names and values -// var fieldNames []string -// var fieldValues []interface{} -// for fieldName, fieldValue := range record { -// if _, ok := fields[fieldName]; ok { -// fieldNames = append(fieldNames, fieldName) -// fieldValues = append(fieldValues, fieldValue) -// } -// } - -// // Generate placeholders for each field value -// placeholders := make([]string, len(fieldValues)) -// for i := range placeholders { -// placeholders[i] = "$" + strconv.Itoa(i+2) // Start from $2 as $1 is reserved for recordId -// } - -// // Construct the query -// query := fmt.Sprintf( -// "INSERT INTO collection_%s (id, %s) VALUES ($1, %s)", -// collectionName, -// strings.Join(fieldNames, ", "), -// strings.Join(placeholders, ", "), -// ) - -// // Append the recordId and fieldValues to form the final values for the query -// values := make([]interface{}, len(fieldValues)+1) -// values[0] = recordId -// copy(values[1:], fieldValues) - -// _, err = st.db.ExecContext(ctx, query, values...) -// if err != nil { -// return nil, err -// } -// } -// return recordIds, nil -// } - func (st SqlStore) GetRecords(ctx context.Context, collectionName string, recordIDs []string) (map[string]*Record, error) { var collectionMetadata CollectionMetadata err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", collectionName) @@ -347,22 +279,45 @@ func (st SqlStore) GetRecords(ctx context.Context, collectionName string, record return nil, err } + // Prepare the query + query, args, err := sqlx.In("SELECT * FROM collection_"+collectionName+" WHERE id IN (?)", recordIDs) + if err != nil { + return nil, err + } + query = st.db.Rebind(query) + + // Execute the query + rows, err := st.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + // Process the results records := make(map[string]*Record) - for _, recordId := range recordIDs { + for rows.Next() { + recordMap := make(map[string]interface{}) + err = rows.MapScan(recordMap) + if err != nil { + return nil, err + } + recordID := recordMap["id"].(string) record := make(Record) - for fieldName := range fields { - var fieldValue string - err := st.db.GetContext(ctx, &fieldValue, "SELECT "+fieldName+" FROM collection_"+collectionName+" WHERE id = $1", recordId) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, &NotFoundError{"record", recordId} - } - return nil, err + for k, v := range recordMap { + if str, ok := v.(string); ok { + record[k] = str + } else { + // We're assuming all record fields are strings as they are encrypted in the db, this might change + return nil, fmt.Errorf("unexpected type for field %s", k) } - record[fieldName] = fieldValue } - records[recordId] = &record + records[recordID] = &record + } + + if err = rows.Err(); err != nil { + return nil, err } + return records, nil } @@ -514,8 +469,11 @@ func (st SqlStore) DeletePrincipal(ctx context.Context, username string) error { } func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, error) { - var p Policy - err := st.db.GetContext(ctx, &p, "SELECT * FROM policies WHERE id = $1", policyId) + var id, effect string + var actions []string + var resources []string + + err := st.db.QueryRowxContext(ctx, "SELECT id, effect, actions, resources FROM policies WHERE id = $1", policyId).Scan(&id, &effect, pq.Array(&actions), pq.Array(&resources)) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, &NotFoundError{"policy", policyId} @@ -523,10 +481,22 @@ func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, err return nil, err } - return &p, err + actionList := make([]PolicyAction, len(actions)) + for i, action := range actions { + actionList[i] = PolicyAction(action) + } + + p := Policy{ + PolicyId: id, + Effect: PolicyEffect(effect), + Actions: actionList, + Resources: resources, + } + + return &p, nil } -func (st SqlStore) GetPoliciesById(ctx context.Context, policyIds []string) ([]*Policy, error) { +func (st SqlStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Policy, error) { if len(policyIds) == 0 { return []*Policy{}, nil } @@ -658,8 +628,7 @@ func (st SqlStore) DeletePolicy(ctx context.Context, policyID string) error { } func (st SqlStore) CreateToken(ctx context.Context, tokenId string, value string) error { - gt := DbToken{ID: tokenId, Value: value} - _, err := st.db.NamedExecContext(ctx, "INSERT INTO tokens (id, value) VALUES (:id, :value)", >) + _, err := st.db.ExecContext(ctx, "INSERT INTO tokens (id, value) VALUES ($1, $2)", tokenId, value) return err } @@ -669,9 +638,9 @@ func (st SqlStore) DeleteToken(ctx context.Context, tokenId string) error { } func (st SqlStore) GetTokenValue(ctx context.Context, tokenId string) (string, error) { - var gt DbToken - err := st.db.GetContext(ctx, >, "SELECT * FROM tokens WHERE id = $1", tokenId) - return gt.Value, err + var value string + err := st.db.GetContext(ctx, &value, "SELECT value FROM tokens WHERE id = $1", tokenId) + return value, err } func (st SqlStore) Flush(ctx context.Context) error { diff --git a/vault/vault.go b/vault/vault.go index 00850f0..20b3ad0 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -105,7 +105,7 @@ type VaultDB interface { CreatePrincipal(ctx context.Context, principal Principal) error DeletePrincipal(ctx context.Context, username string) error GetPolicy(ctx context.Context, policyId string) (*Policy, error) - GetPoliciesById(ctx context.Context, policyIds []string) ([]*Policy, error) + GetPolicies(ctx context.Context, policyIds []string) ([]*Policy, error) CreatePolicy(ctx context.Context, p Policy) (string, error) DeletePolicy(ctx context.Context, policyId string) error CreateToken(ctx context.Context, tokenId string, value string) error @@ -562,7 +562,7 @@ func (vault Vault) GetPrincipalPolicies( return nil, &ForbiddenError{request} } - policies, err := vault.Db.GetPoliciesById(ctx, principal.Policies) + policies, err := vault.Db.GetPolicies(ctx, principal.Policies) if err != nil { return nil, err } @@ -573,7 +573,7 @@ func (vault Vault) ValidateAction( ctx context.Context, request Request, ) (bool, error) { - policies, err := vault.Db.GetPoliciesById(ctx, request.Principal.Policies) + policies, err := vault.Db.GetPolicies(ctx, request.Principal.Policies) if err != nil { return false, err } @@ -605,6 +605,7 @@ func (vault Vault) CreateToken(ctx context.Context, principal Principal, collect // I don't think this is needed since it's already handled in the GetRecords error return. return "", &NotFoundError{"record", recordId} } + func (vault Vault) DeleteToken(ctx context.Context, tokenId string) error { return vault.Db.DeleteToken(ctx, tokenId) } diff --git a/vault/vault_test.go b/vault/vault_test.go index f4086cc..b7deda7 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -544,7 +544,7 @@ func TestTokens(t *testing.T) { } }) t.Run("getting token value fails without access to underlying record", func(t *testing.T) { - rId := customerRecords[0] + rId := employeeRecords[0] tokenId, err := vault.CreateToken(ctx, rootPrincipal, "employees", rId, "name", "plain") assert.NoError(t, err) assert.NotEqual(t, 0, len(tokenId), "tokenId was empty") From 488705fb10e0e4b8ea657934e7520eb23db52100 Mon Sep 17 00:00:00 2001 From: Subrose Date: Thu, 23 Nov 2023 16:32:39 +0000 Subject: [PATCH 09/13] fixes --- vault/sql.go | 179 ++++++++++++++++++++++----------------------------- 1 file changed, 76 insertions(+), 103 deletions(-) diff --git a/vault/sql.go b/vault/sql.go index eefa343..a9e0526 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "reflect" "strings" "database/sql" @@ -50,9 +49,9 @@ func (st *SqlStore) CreateSchemas() error { tables := map[string]string{ "principals": "CREATE TABLE IF NOT EXISTS principals (username TEXT PRIMARY KEY, password TEXT, description TEXT)", "policies": "CREATE TABLE IF NOT EXISTS policies (id TEXT, effect TEXT, actions TEXT[], resources TEXT[])", + "principal_policies": "CREATE TABLE IF NOT EXISTS principal_policies (username TEXT, policy_id TEXT)", "tokens": "CREATE TABLE IF NOT EXISTS tokens (id TEXT, value TEXT)", "collection_metadata": "CREATE TABLE IF NOT EXISTS collection_metadata (name TEXT, field_schema JSON)", - "principal_policies": "CREATE TABLE IF NOT EXISTS principal_policies (username TEXT, policy_id TEXT)", } for _, query := range tables { @@ -65,68 +64,49 @@ func (st *SqlStore) CreateSchemas() error { return nil } -func (st SqlStore) createCollectionTable(ctx context.Context, c Collection) error { - // Define a dynamic struct based on the Fields of the collection - var dynamicStructFields []reflect.StructField - - // Add an ID field to the struct - idField := reflect.StructField{ - Name: "ID", - Type: reflect.TypeOf(""), - Tag: reflect.StructTag(`db:"id"`), +func (st *SqlStore) CreateCollection(ctx context.Context, c Collection) (string, error) { + tx, err := st.db.BeginTxx(ctx, nil) + if err != nil { + return "", err } - dynamicStructFields = append(dynamicStructFields, idField) - for fieldName := range c.Fields { - exportedFieldName := strings.Title(fieldName) - structField := reflect.StructField{ - Name: exportedFieldName, - Type: reflect.TypeOf(""), // Assuming all fields are strings for simplicity - Tag: reflect.StructTag(fmt.Sprintf(`db:"%s"`, fieldName)), + defer func() { + if err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err) + } } - dynamicStructFields = append(dynamicStructFields, structField) - } - - // dynamicStruct := reflect.StructOf(dynamicStructFields) - // dynamicStructPtr := reflect.New(dynamicStruct).Interface() // Create a pointer to a new instance of the dynamic struct - - tableName := "collection_" + c.Name // Create a unique table name + }() - // Create the table using SQLX's MustExec with a pointer to the dynamic struct - // st.db.MustExecContext(ctx, "CREATE TABLE IF NOT EXISTS "+tableName+" (?)", dynamicStructPtr) - // Instead of using the dynamic struct directly, we will generate the SQL query manually - var queryBuilder strings.Builder - queryBuilder.WriteString("CREATE TABLE IF NOT EXISTS " + tableName + " (id TEXT") - for fieldName := range c.Fields { - queryBuilder.WriteString(", " + fieldName + " TEXT") - } - queryBuilder.WriteString(")") - st.db.MustExecContext(ctx, queryBuilder.String()) - - return nil -} - -func (st SqlStore) CreateCollection(ctx context.Context, c Collection) (string, error) { - // Convert the Fields map to JSON for storing in the collection_metadata table fieldSchema, err := json.Marshal(c.Fields) if err != nil { return "", err } - // Create a new CollectionMetadata instance collectionMetadata := CollectionMetadata{ Name: c.Name, FieldSchema: fieldSchema, } - // Save the collection metadata - _, err = st.db.NamedExecContext(ctx, "INSERT INTO collection_metadata (name, field_schema) VALUES (:name, :field_schema)", &collectionMetadata) + _, err = tx.NamedExecContext(ctx, "INSERT INTO collection_metadata (name, field_schema) VALUES (:name, :field_schema)", &collectionMetadata) if err != nil { return "", err } - // Dynamically create a table for the collection - if err := st.createCollectionTable(ctx, c); err != nil { + tableName := "collection_" + c.Name + var queryBuilder strings.Builder + queryBuilder.WriteString("CREATE TABLE IF NOT EXISTS " + tableName + " (id TEXT PRIMARY KEY") + for fieldName := range c.Fields { + queryBuilder.WriteString(", " + fieldName + " TEXT") + } + queryBuilder.WriteString(")") + _, err = tx.ExecContext(ctx, queryBuilder.String()) + if err != nil { + return "", err + } + + err = tx.Commit() + if err != nil { return "", err } @@ -164,32 +144,34 @@ func (st SqlStore) GetCollections(ctx context.Context) ([]string, error) { } func (st SqlStore) DeleteCollection(ctx context.Context, name string) error { - // Start a transaction tx, err := st.db.BeginTxx(ctx, nil) if err != nil { return err } - // Delete the collection metadata + defer func() { + if err != nil { + rbErr := tx.Rollback() + if rbErr != nil { + err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err) + } + } + }() + _, err = tx.ExecContext(ctx, "DELETE FROM collection_metadata WHERE name = $1", name) if err != nil { if errors.Is(err, sql.ErrNoRows) { - tx.Rollback() return &NotFoundError{"collection", name} } - tx.Rollback() return err } - // Delete the dynamic table tableName := "collection_" + name _, err = tx.ExecContext(ctx, "DROP TABLE IF EXISTS "+tableName) if err != nil { - tx.Rollback() return err } - // Commit the transaction err = tx.Commit() if err != nil { return err @@ -197,6 +179,7 @@ func (st SqlStore) DeleteCollection(ctx context.Context, name string) error { return nil } + func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) { var collectionMetadata CollectionMetadata err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", collectionName) @@ -213,7 +196,6 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec return nil, err } - // Create slice for field names and initialize placeholders fieldNames := make([]string, 0, len(fields)) placeholders := make([]string, 0, len(fields)) idx := 2 // Start from 2 because $1 is reserved for recordId @@ -224,7 +206,6 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec idx++ } - // Prepare SQL statement query := fmt.Sprintf("INSERT INTO collection_%s (id, %s) VALUES ($1, %s)", collectionName, strings.Join(fieldNames, ", "), strings.Join(placeholders, ", ")) stmt, err := st.db.PrepareContext(ctx, query) if err != nil { @@ -232,7 +213,6 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec } defer stmt.Close() - // Processing records recordIds := make([]string, len(records)) for i, record := range records { // Validate record fields @@ -243,7 +223,6 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec recordId := GenerateId() recordIds[i] = recordId - // Prepare values for insertion values := make([]interface{}, len(fields)+1) values[0] = recordId for j, fieldName := range fieldNames { @@ -254,7 +233,6 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec } } - // Execute the prepared statement _, err = stmt.ExecContext(ctx, values...) if err != nil { return nil, err @@ -279,21 +257,18 @@ func (st SqlStore) GetRecords(ctx context.Context, collectionName string, record return nil, err } - // Prepare the query query, args, err := sqlx.In("SELECT * FROM collection_"+collectionName+" WHERE id IN (?)", recordIDs) if err != nil { return nil, err } query = st.db.Rebind(query) - // Execute the query rows, err := st.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() - // Process the results records := make(map[string]*Record) for rows.Next() { recordMap := make(map[string]interface{}) @@ -412,9 +387,18 @@ func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) err if err != nil { return err } + + defer func() { + if err != nil { + rbErr := tx.Rollback() + if rbErr != nil { + err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err) + } + } + }() + _, err = tx.NamedExecContext(ctx, "INSERT INTO principals (username, password, description) VALUES (:username, :password, :description)", &principal) if err != nil { - tx.Rollback() if pqErr, ok := err.(*pq.Error); ok { if pqErr.Code == "23505" { return &ConflictError{principal.Username} @@ -426,10 +410,10 @@ func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) err for _, policyId := range principal.Policies { _, err = tx.ExecContext(ctx, "INSERT INTO principal_policies (username, policy_id) VALUES ($1, $2)", principal.Username, policyId) if err != nil { - tx.Rollback() return err } } + err = tx.Commit() if err != nil { return err @@ -445,21 +429,28 @@ func (st SqlStore) DeletePrincipal(ctx context.Context, username string) error { return err } - // First, delete associations in the many-to-many join table + defer func() { + if p := recover(); p != nil { + if rbErr := tx.Rollback(); rbErr != nil { + err = fmt.Errorf("rollback failed: %v, after panic: %v", rbErr, p) + } + } else if err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err) + } + } + }() + _, err = tx.ExecContext(ctx, "DELETE FROM principal_policies WHERE username = $1", username) if err != nil { - tx.Rollback() return err } - // Now, delete the principal itself _, err = tx.ExecContext(ctx, "DELETE FROM principals WHERE username = $1", username) if err != nil { - tx.Rollback() return err } - // Commit the transaction err = tx.Commit() if err != nil { return err @@ -543,22 +534,26 @@ func (st SqlStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Poli } func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { - - // Start a transaction tx, err := st.db.BeginTxx(ctx, nil) if err != nil { return "", err } + defer func() { + if err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err) + } + } + }() + query := "INSERT INTO policies (id, effect, actions, resources) VALUES (:id, :effect, :actions, :resources)" actions := make(pq.StringArray, len(p.Actions)) for i, action := range p.Actions { actions[i] = string(action) } resources := make(pq.StringArray, len(p.Resources)) - for i, resource := range p.Resources { - resources[i] = resource - } + copy(resources, p.Resources) query, args, err := sqlx.Named(query, map[string]interface{}{ "id": p.PolicyId, "effect": string(p.Effect), @@ -567,20 +562,17 @@ func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { }) if err != nil { - tx.Rollback() return "", err } query = tx.Rebind(query) _, err = tx.ExecContext(ctx, query, args...) if err != nil { - tx.Rollback() if errors.Is(err, sql.ErrNoRows) { return "", &ConflictError{p.PolicyId} } return "", err } - // Commit the transaction err = tx.Commit() if err != nil { return "", err @@ -590,35 +582,35 @@ func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { } func (st SqlStore) DeletePolicy(ctx context.Context, policyID string) error { - // Start a transaction tx, err := st.db.BeginTxx(ctx, nil) if err != nil { return err } - // Delete the policy itself + defer func() { + if err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err) + } + } + }() + result, err := tx.ExecContext(ctx, "DELETE FROM policies WHERE id = $1", policyID) if err != nil { - tx.Rollback() return err } rowsAffected, err := result.RowsAffected() if err != nil { - tx.Rollback() return err } if rowsAffected == 0 { - tx.Rollback() return &NotFoundError{"policy", policyID} } - // Directly delete associations in the many-to-many join table _, err = tx.ExecContext(ctx, "DELETE FROM principal_policies WHERE policy_id = $1", policyID) if err != nil { - tx.Rollback() return err } - // Commit the transaction err = tx.Commit() if err != nil { return err @@ -656,28 +648,9 @@ func (st SqlStore) Flush(ctx context.Context) error { return err } } - // Recreate schemas err = st.CreateSchemas() if err != nil { return err } return nil } - -// func (st SqlStore) CreateCollection(ctx context.Context, c Collection) (string, error) { -// b, err := json.Marshal(c) -// if err != nil { -// return "", err -// } - -// gormCol := DbCollection{Name: c.Name, Collection: datatypes.JSON(b)} -// result := st.db.Create(&gormCol) -// if result.Error != nil { -// switch result.Error { -// case gorm.ErrDuplicatedKey: -// return "", &ConflictError{c.Name} -// } -// } - -// return c.Name, nil -// } From 8b7305bdc9e49b6d1c6230a4b2cd1dec4a2c2ef0 Mon Sep 17 00:00:00 2001 From: Subrose Date: Thu, 23 Nov 2023 16:41:31 +0000 Subject: [PATCH 10/13] more fixes --- vault/sql.go | 73 ++++++++++++++-------------------------------------- 1 file changed, 20 insertions(+), 53 deletions(-) diff --git a/vault/sql.go b/vault/sql.go index a9e0526..7023e03 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -19,16 +19,6 @@ type SqlStore struct { db *sqlx.DB } -type DbCollection struct { - Name string `db:"name"` - Collection json.RawMessage -} - -type CollectionMetadata struct { - Name string - FieldSchema json.RawMessage `db:"field_schema"` -} - func NewSqlStore(dsn string) (*SqlStore, error) { db, err := sqlx.Connect("postgres", dsn) if err != nil { @@ -83,12 +73,10 @@ func (st *SqlStore) CreateCollection(ctx context.Context, c Collection) (string, return "", err } - collectionMetadata := CollectionMetadata{ - Name: c.Name, - FieldSchema: fieldSchema, - } - - _, err = tx.NamedExecContext(ctx, "INSERT INTO collection_metadata (name, field_schema) VALUES (:name, :field_schema)", &collectionMetadata) + _, err = tx.NamedExecContext(ctx, "INSERT INTO collection_metadata (name, field_schema) VALUES (:name, :field_schema)", map[string]interface{}{ + "name": c.Name, + "field_schema": fieldSchema, + }) if err != nil { return "", err } @@ -110,12 +98,12 @@ func (st *SqlStore) CreateCollection(ctx context.Context, c Collection) (string, return "", err } - return collectionMetadata.Name, nil + return c.Name, nil } func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection, error) { - var collectionMetadata CollectionMetadata - err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", name) + var fieldSchema json.RawMessage + err := st.db.GetContext(ctx, &fieldSchema, "SELECT field_schema FROM collection_metadata WHERE name = $1", name) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, &NotFoundError{"collection", name} @@ -124,22 +112,18 @@ func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection, } var fields map[string]Field - err = json.Unmarshal(collectionMetadata.FieldSchema, &fields) + err = json.Unmarshal(fieldSchema, &fields) if err != nil { return nil, err } - return &Collection{Name: collectionMetadata.Name, Fields: fields}, nil + return &Collection{Name: name, Fields: fields}, nil } func (st SqlStore) GetCollections(ctx context.Context) ([]string, error) { - var collectionMetadatas []CollectionMetadata + var collectionNames []string - err := st.db.SelectContext(ctx, &collectionMetadatas, "SELECT * FROM collection_metadata") + err := st.db.SelectContext(ctx, &collectionNames, "SELECT name FROM collection_metadata") - collectionNames := make([]string, len(collectionMetadatas)) - for i, collectionMetadata := range collectionMetadatas { - collectionNames[i] = collectionMetadata.Name - } return collectionNames, err } @@ -181,8 +165,8 @@ func (st SqlStore) DeleteCollection(ctx context.Context, name string) error { } func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) { - var collectionMetadata CollectionMetadata - err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", collectionName) + var fieldSchema json.RawMessage + err := st.db.GetContext(ctx, &fieldSchema, "SELECT field_schema FROM collection_metadata WHERE name = $1", collectionName) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, &NotFoundError{"collection", collectionName} @@ -191,7 +175,7 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec } var fields map[string]Field - err = json.Unmarshal(collectionMetadata.FieldSchema, &fields) + err = json.Unmarshal(fieldSchema, &fields) if err != nil { return nil, err } @@ -242,8 +226,8 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec } func (st SqlStore) GetRecords(ctx context.Context, collectionName string, recordIDs []string) (map[string]*Record, error) { - var collectionMetadata CollectionMetadata - err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", collectionName) + var fieldSchema json.RawMessage + err := st.db.GetContext(ctx, &fieldSchema, "SELECT field_schema FROM collection_metadata WHERE name = $1", collectionName) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, &NotFoundError{"collection", collectionName} @@ -252,7 +236,7 @@ func (st SqlStore) GetRecords(ctx context.Context, collectionName string, record } var fields map[string]Field - err = json.Unmarshal(collectionMetadata.FieldSchema, &fields) + err = json.Unmarshal(fieldSchema, &fields) if err != nil { return nil, err } @@ -302,28 +286,11 @@ func (st SqlStore) GetRecordsFilter(ctx context.Context, collectionName string, } func (st SqlStore) UpdateRecord(ctx context.Context, collectionName string, recordID string, record Record) error { - var collectionMetadata CollectionMetadata - err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", collectionName) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return &NotFoundError{"collection", collectionName} - } - return err - } - - var fields map[string]Field - err = json.Unmarshal(collectionMetadata.FieldSchema, &fields) - if err != nil { - return err - } - var fieldNames []string var fieldValues []interface{} for fieldName, fieldValue := range record { - if _, ok := fields[fieldName]; ok { - fieldNames = append(fieldNames, fieldName) - fieldValues = append(fieldValues, fieldValue) - } + fieldNames = append(fieldNames, fieldName) + fieldValues = append(fieldValues, fieldValue) } setClause := make([]string, len(fieldNames)) @@ -332,7 +299,7 @@ func (st SqlStore) UpdateRecord(ctx context.Context, collectionName string, reco } query := fmt.Sprintf("UPDATE collection_%s SET %s WHERE id = $%d", collectionName, strings.Join(setClause, ", "), len(fieldNames)+1) - _, err = st.db.ExecContext(ctx, query, append(fieldValues, recordID)...) + _, err := st.db.ExecContext(ctx, query, append(fieldValues, recordID)...) return err } From ac585999595716621c1a84fd73b029bf03c1380b Mon Sep 17 00:00:00 2001 From: Subrose Date: Thu, 23 Nov 2023 17:02:30 +0000 Subject: [PATCH 11/13] tests complete --- api/errors.go | 7 +++---- simulator/client.py | 13 +++++++++++++ simulator/simulate.sh | 6 +----- vault/sql.go | 8 ++++---- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/api/errors.go b/api/errors.go index 303dcf9..35fb8d6 100644 --- a/api/errors.go +++ b/api/errors.go @@ -20,12 +20,11 @@ type ErrorResponse struct { } func ValidateResourceName(fl validator.FieldLevel) bool { - // Validation for vault internal resource names - reg := "^[a-zA-Z0-9._-]{1,249}$" + reg := "^[a-zA-Z0-9._]{1,249}$" match, _ := regexp.MatchString(reg, fl.Field().String()) - // Check for prohibited values: single period and double underscore - if fl.Field().String() == "." || fl.Field().String() == "__" { + // Check for prohibited values: single period, double underscore, and hyphen + if fl.Field().String() == "." || fl.Field().String() == "__" || fl.Field().String() == "-" { return false } diff --git a/simulator/client.py b/simulator/client.py index 5599438..b86b229 100644 --- a/simulator/client.py +++ b/simulator/client.py @@ -78,6 +78,19 @@ def create_collection( ) check_expected_status(response, expected_statuses) + def update_collection( + self, + collection: str, + schema: dict[str, Any], + expected_statuses: Optional[list[int]] = None, + ) -> None: + response = requests.put( + f"{self.vault_url}/collections/{collection}", + json=schema, + auth=(self.username, self.password), + ) + check_expected_status(response, expected_statuses) + def create_records( self, collection: str, diff --git a/simulator/simulate.sh b/simulator/simulate.sh index 9d358bc..aede4c1 100755 --- a/simulator/simulate.sh +++ b/simulator/simulate.sh @@ -1,10 +1,6 @@ #!/bin/bash -# redis-cli -h keydb FLUSHALL || exit 1 python ecommerce.py || exit 1 - -# redis-cli -h keydb FLUSHALL || exit 1 python pci.py || exit 1 - -# redis-cli -h keydb FLUSHALL || exit 1 python password_manager.py || exit 1 +python ops.py || exit 1 diff --git a/vault/sql.go b/vault/sql.go index 7023e03..b816034 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -38,10 +38,10 @@ func NewSqlStore(dsn string) (*SqlStore, error) { func (st *SqlStore) CreateSchemas() error { tables := map[string]string{ "principals": "CREATE TABLE IF NOT EXISTS principals (username TEXT PRIMARY KEY, password TEXT, description TEXT)", - "policies": "CREATE TABLE IF NOT EXISTS policies (id TEXT, effect TEXT, actions TEXT[], resources TEXT[])", - "principal_policies": "CREATE TABLE IF NOT EXISTS principal_policies (username TEXT, policy_id TEXT)", - "tokens": "CREATE TABLE IF NOT EXISTS tokens (id TEXT, value TEXT)", - "collection_metadata": "CREATE TABLE IF NOT EXISTS collection_metadata (name TEXT, field_schema JSON)", + "policies": "CREATE TABLE IF NOT EXISTS policies (id TEXT PRIMARY KEY, effect TEXT, actions TEXT[], resources TEXT[])", + "principal_policies": "CREATE TABLE IF NOT EXISTS principal_policies (username TEXT, policy_id TEXT, UNIQUE(username, policy_id))", + "tokens": "CREATE TABLE IF NOT EXISTS tokens (id TEXT PRIMARY KEY, value TEXT)", + "collection_metadata": "CREATE TABLE IF NOT EXISTS collection_metadata (name TEXT PRIMARY KEY, field_schema JSON)", } for _, query := range tables { From 7b0cd9ceaa86343a4469390fb56b27181aaf0f7e Mon Sep 17 00:00:00 2001 From: Subrose Date: Thu, 23 Nov 2023 18:14:13 +0000 Subject: [PATCH 12/13] ci fix --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b450d61..5916618 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ jobs: - name: Run Gosec Security Scanner uses: securego/gosec@master with: - args: ./... + args: ./api ./vault ./logger - name: Check code formatting run: | From 6f7ab3cc0297fd65d60521482c316cd96282330b Mon Sep 17 00:00:00 2001 From: Subrose Date: Fri, 24 Nov 2023 10:00:21 +0000 Subject: [PATCH 13/13] fi xi --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5916618..27f87f0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -66,7 +66,7 @@ jobs: up-flags: "--remove-orphans --abort-on-container-exit" down-flags: "--volumes --remove-orphans" services: | - keydb + postgres tests - name: Run simulations @@ -76,6 +76,6 @@ jobs: up-flags: "--remove-orphans --abort-on-container-exit" down-flags: "--volumes --remove-orphans" services: | - keydb + postgres api simulations