diff --git a/api/collections.go b/api/collections.go index 29e652b..0db060d 100644 --- a/api/collections.go +++ b/api/collections.go @@ -9,6 +9,8 @@ import ( _vault "github.com/subrose/vault" ) +// TODO: Extract and share the collection name validation logic + func (core *Core) GetCollection(c *fiber.Ctx) error { collectionName := c.Params("name") principal := GetSessionPrincipal(c) @@ -37,6 +39,10 @@ func (core *Core) CreateCollection(c *fiber.Ctx) error { return c.Status(fiber.StatusBadRequest).JSON(ErrorResponse{"Invalid body", nil}) } + if collection.Type == "" { + collection.Type = _vault.CollectionTypeSubject + } + err := core.vault.CreateCollection(c.Context(), principal, collection) if err != nil { return err @@ -137,9 +143,7 @@ func (core *Core) GetRecords(c *fiber.Ctx) error { if err != nil { return err } - return c.Status(http.StatusOK).JSON(records) - } func (core *Core) GetRecord(c *fiber.Ctx) error { diff --git a/api/go.mod b/api/go.mod index 0aa8576..080deec 100644 --- a/api/go.mod +++ b/api/go.mod @@ -33,7 +33,8 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx/v5 v5.4.3 // indirect - github.com/jmoiron/sqlx v1.3.5 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-colorable v0.1.13 // indirect @@ -47,6 +48,8 @@ require ( golang.org/x/sys v0.14.0 // indirect golang.org/x/text v0.13.0 // indirect google.golang.org/protobuf v1.28.0 // indirect + gorm.io/driver/postgres v1.5.4 // indirect + gorm.io/gorm v1.25.5 // indirect ) replace github.com/subrose/vault v0.0.0 => ../vault diff --git a/api/go.sum b/api/go.sum index e4cdb16..31d9f1f 100644 --- a/api/go.sum +++ b/api/go.sum @@ -66,9 +66,6 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= -github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= -github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-test/deep v1.0.2-0.20181118220953-042da051cf31/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -154,10 +151,12 @@ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= -github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= -github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= @@ -181,7 +180,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= -github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= @@ -199,9 +197,6 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= -github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= -github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI= @@ -456,6 +451,10 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= +gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= diff --git a/api/main.go b/api/main.go index c0c3de6..9227c44 100644 --- a/api/main.go +++ b/api/main.go @@ -29,7 +29,7 @@ func ApiLogger(core *Core) fiber.Handler { ctx.Get("User-Agent"), ctx.Get("X-Trace-Id"), dt, - ctx.Response().StatusCode(), + ctx.Context().Response.StatusCode(), // This will now log the final status code ) return err } @@ -87,6 +87,8 @@ func (core *Core) customErrorHandler(ctx *fiber.Ctx, err error) error { return ctx.Status(code).SendString(err.Error()) } + core.logger.Error(err.Error()) + // Handle custom errors from the vault package var ve *_vault.ValueError var fe *_vault.ForbiddenError @@ -133,8 +135,8 @@ func SetupApi(core *Core) *fiber.App { ErrorHandler: core.customErrorHandler, }) app.Use(helmet.New()) - app.Use(ApiLogger(core)) app.Use(recover.New()) + app.Use(ApiLogger(core)) app.Get("/health", func(c *fiber.Ctx) error { return c.Status(http.StatusOK).SendString("OK") @@ -143,7 +145,7 @@ func SetupApi(core *Core) *fiber.App { principalGroup := app.Group("/principals") principalGroup.Use(authGuard(core)) principalGroup.Get(":username", core.GetPrincipal) - principalGroup.Post("", core.CreatePrincipal) + principalGroup.Post("", JSONOnlyMiddleware, core.CreatePrincipal) principalGroup.Delete(":username", core.DeletePrincipal) collectionsGroup := app.Group("/collections") @@ -151,8 +153,8 @@ func SetupApi(core *Core) *fiber.App { collectionsGroup.Get("", core.GetCollections) collectionsGroup.Get("/:name", core.GetCollection) collectionsGroup.Delete("/:name", core.DeleteCollection) - collectionsGroup.Post("", core.CreateCollection) - collectionsGroup.Post("/:name/records", core.CreateRecord) + collectionsGroup.Post("", JSONOnlyMiddleware, core.CreateCollection) + collectionsGroup.Post("/:name/records", JSONOnlyMiddleware, core.CreateRecord) collectionsGroup.Get("/:name/records", core.GetRecords) collectionsGroup.Get("/:name/records/:id", core.GetRecord) collectionsGroup.Put("/:name/records/:id", core.UpdateRecord) @@ -168,7 +170,7 @@ func SetupApi(core *Core) *fiber.App { tokensGroup := app.Group("/tokens") tokensGroup.Use(authGuard(core)) tokensGroup.Get(":tokenId", core.GetTokenById) - tokensGroup.Post("", core.CreateToken) + tokensGroup.Post("", JSONOnlyMiddleware, core.CreateToken) app.Use(func(c *fiber.Ctx) error { return c.SendStatus(404) diff --git a/api/principals.go b/api/principals.go index 1387efd..959800e 100644 --- a/api/principals.go +++ b/api/principals.go @@ -7,13 +7,6 @@ import ( _vault "github.com/subrose/vault" ) -// type NewPrincipal struct { -// Username string `json:"username" validate:"required,min=1,max=32"` -// Password string `json:"password" validate:"required,min=4,max=32"` // This is to limit the size of the password hash. -// Description string `json:"description"` -// Policies []string `json:"policies"` -// } - type PrincipalResponse struct { Id string `json:"id"` Username string `json:"username" validate:"required,min=3,max=32"` diff --git a/archive/sql/go.mod b/archive/sql/go.mod new file mode 100644 index 0000000..b522457 --- /dev/null +++ b/archive/sql/go.mod @@ -0,0 +1,3 @@ +module sql/m/v2 + +go 1.21.0 diff --git a/archive/sql/main.go b/archive/sql/main.go new file mode 100644 index 0000000..5e2f302 --- /dev/null +++ b/archive/sql/main.go @@ -0,0 +1,197 @@ +package main + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "log" + "time" + + "github.com/lib/pq" + "gorm.io/driver/postgres" + "gorm.io/gorm" +) + +// Define your models +type dbPolicy struct { + Id string `gorm:"primaryKey"` + Name string + Description string + Effect string + Actions pq.StringArray `gorm:"type:text[]"` + Resources pq.StringArray `gorm:"type:text[]"` + CreatedAt time.Time + UpdatedAt time.Time +} + +func (dbPolicy) TableName() string { + return "policies" +} + +type dbPrincipal struct { + Id string `gorm:"primaryKey"` + Username string + Password string + Description string + CreatedAt time.Time + UpdatedAt time.Time + Policies []dbPolicy `gorm:"many2many:principal_policies;"` +} + +func (dbPrincipal) TableName() string { + return "principals" +} + +type dbToken struct { + Id string `gorm:"primaryKey"` + value string + CreatedAt time.Time + UpdatedAt time.Time +} + +func (dbToken) TableName() string { + return "tokens" +} + +type PrincipalPolicy struct { + PrincipalId string `gorm:"primaryKey;autoIncrement:false"` + PolicyId string `gorm:"primaryKey;autoIncrement:false"` +} + +func (PrincipalPolicy) TableName() string { + return "principal_policies" +} + +type Field struct { + Type string `json:"type"` + IsIndexed bool `json:"indexed"` +} + +type FieldSchemaMap map[string]Field + +func (f *FieldSchemaMap) Scan(value interface{}) error { + bytes, ok := value.([]byte) + if !ok { + return errors.New("failed to unmarshal JSONB value") + } + + result := FieldSchemaMap{} + if err := json.Unmarshal(bytes, &result); err != nil { + return err + } + + *f = result + return nil +} + +func (f FieldSchemaMap) Value() (driver.Value, error) { + if len(f) == 0 { + return nil, nil + } + + return json.Marshal(f) +} + +type dbCollectionMetadata struct { + Id string `gorm:"primaryKey"` + Name string `gorm:"unique"` + FieldSchema FieldSchemaMap `gorm:"type:json"` // Ensures JSON storage + CreatedAt time.Time + UpdatedAt time.Time +} + +func (dbCollectionMetadata) TableName() string { + return "collections_metadata" +} + +func main() { + // Setup database connection + time.Local = time.UTC + dsn := "host=localhost user=postgres dbname=postgres sslmode=disable password=postgres" + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{TranslateError: true}) + if err != nil { + log.Fatalf("Failed to connect to database: %v", err) + } + + // Drop tables + db.Exec("DROP TABLE IF EXISTS principal_policies") + db.Exec("DROP TABLE IF EXISTS principals") + db.Exec("DROP TABLE IF EXISTS policies") + db.Exec("DROP TABLE IF EXISTS collections_metadata") + + // AutoMigrate models + db.AutoMigrate(&dbPrincipal{}, &dbPolicy{}, &PrincipalPolicy{}, &dbCollectionMetadata{}) + + // Create a Policy + + policy := dbPolicy{ + Id: "policy1", + Name: "Test Policy", + Description: "A test policy", + Effect: "allow", + Actions: []string{"read", "write"}, + Resources: []string{"resource1", "resource2"}, + } + + db.Create(&policy) + + // Create a Principal with the created Policy + principal := dbPrincipal{ + Id: "principal1", + Username: "john_doe", + Password: "secret", + Description: "A test principal", + Policies: []dbPolicy{policy}, + } + db.Create(&principal) + + // Retrieve a Principal with their Policies - this can be looped over + var retrievedPrincipal dbPrincipal + db.Preload("Policies").First(&retrievedPrincipal, "id = ?", "principal1") + + // Pattern for deleting a Principal and associated records + // Start a new transaction + tx := db.Begin() + + // Defer a function that will commit the transaction if no errors, or rollback if there were any + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } else { + tx.Commit() + } + }() + + // Delete a Principal and cascade delete associated records in PrincipalPolicy + if err := tx.Where("principal_id = ?", principal.Id).Delete(&PrincipalPolicy{}).Error; err != nil { + log.Fatal(err) + } + + // Then, delete the principal + if err := tx.Delete(principal).Error; err != nil { + log.Fatal(err) + } + + // Create a CollectionMetadata + collectionMetadata := dbCollectionMetadata{ + Id: "collection1", + Name: "customers", + FieldSchema: map[string]Field{"name": {Type: "phone", IsIndexed: true}}, + } + + db.Create(&collectionMetadata) + result := db.Create(&collectionMetadata) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrDuplicatedKey) { + log.Print("Collection already exists") // we'd return a 409 here + } else { + log.Fatal(result.Error) + } + } + + // Now let's retrieve the CollectionMetadata + var retrievedCollectionMetadata dbCollectionMetadata + db.First(&retrievedCollectionMetadata, "id = ?", "collection1") + fmt.Println(retrievedCollectionMetadata.FieldSchema["name"].Type) +} diff --git a/simulator/client.py b/simulator/client.py index 1e62c88..a48e7e6 100644 --- a/simulator/client.py +++ b/simulator/client.py @@ -193,3 +193,28 @@ def detokenise( ) check_expected_status(response, expected_statuses) return response.json() + + def create_subject(self, eid: str, expected_statuses: Optional[list[int]] = None): + response = requests.post( + f"{self.vault_url}/subjects", + auth=(self.username, self.password), + json={"eid": eid}, + ) + check_expected_status(response, expected_statuses) + return response.json() + + def get_subject(self, sid: str, expected_statuses: Optional[list[int]] = None): + response = requests.get( + f"{self.vault_url}/subjects/{sid}", + auth=(self.username, self.password), + ) + check_expected_status(response, expected_statuses) + return response.json() + + def delete_subject(self, sid: str, expected_statuses: Optional[list[int]] = None): + response = requests.delete( + f"{self.vault_url}/subjects/{sid}", + auth=(self.username, self.password), + ) + check_expected_status(response, expected_statuses) + return diff --git a/simulator/company-house.py b/simulator/company-house.py new file mode 100644 index 0000000..524793b --- /dev/null +++ b/simulator/company-house.py @@ -0,0 +1,91 @@ +# Simulating the UK company house usecase where we need to register: +# - Company details: name, address, phone, email +# - Company directors: name, address, phone, email, company +# Access control is out of scope for this usecase, just experimenting with Subjects +# If a company is deleted, all associated records should be deleted as well + +from client import Actor, init_client +from faker import Faker +from faker_e164.providers import E164Provider + +vault_url = init_client() + +admin = Actor(vault_url, username="admin", password="admin") + +admin.create_collection( + schema={ + "name": "companies", + "fields": { + "registration_number": {"type": "string", "indexed": True}, + "name": {"type": "name", "indexed": False}, + "email": {"type": "email", "indexed": True}, + "phone": {"type": "phone_number", "indexed": False}, + "address": {"type": "address", "indexed": False}, + }, + }, + expected_statuses=[201, 409], +) + +admin.create_collection( + schema={ + "name": "directors", + "fields": { + "name": {"type": "name", "indexed": False}, + "email": {"type": "email", "indexed": True}, + "phone": {"type": "phone_number", "indexed": False}, + "address": {"type": "address", "indexed": False}, + "shares": {"type": "integer", "indexed": False}, + "company": {"type": "string", "indexed": False}, + }, + }, + expected_statuses=[201, 409], +) + + +fake = Faker() +fake.add_provider(E164Provider) + +# We need to create records one by one and build a map with the returned id: +company_records_map = {} + +for i in range(3): + company = { + "registration_number": "company_" + fake.ean(length=13), + "name": fake.name(), + "email": fake.email(), + "phone": fake.e164(), + "address": fake.address(), + } + # Create a subject + sub = admin.create_subject( + eid=company["registration_number"], + ) + + # Create company record + admin.create_record( + collection="companies", + record={**company, "sid": sub["id"]}, + expected_statuses=[201, 409], + ) + + # Create company director records + for j in range(3): + director = { + "name": fake.name(), + "email": fake.email(), + "phone": fake.e164(), + "address": fake.address(), + "shares": str(fake.random_int(min=1, max=100)), # a hack, until: SUB-32 + "company": company["registration_number"], + } + sub = admin.create_subject( + eid="email_" + director["email"], + ) + admin.create_record( + collection="directors", + record={**director, "sid": sub["id"]}, + expected_statuses=[201, 409], + ) + +# ??? How to delete a company and all associated records? +# Does mixing subjects in one table make sense? diff --git a/vault/sql.go b/vault/sql.go index a42b331..1646db0 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -113,8 +113,10 @@ func (f FieldSchemaMap) Value() (driver.Value, error) { } type dbCollectionMetadata struct { - Id string `gorm:"primaryKey"` - Name string `gorm:"unique"` + Id string `gorm:"primaryKey"` + Name string `gorm:"unique"` + Description string + Type string FieldSchema FieldSchemaMap `gorm:"type:json"` // Ensures JSON storage CreatedAt time.Time UpdatedAt time.Time @@ -148,6 +150,7 @@ func (st *SqlStore) CreateCollection(ctx context.Context, c *Collection) error { collectionMetadata := dbCollectionMetadata{ Id: c.Id, Name: c.Name, + Type: string(c.Type), FieldSchema: c.Fields, } @@ -172,13 +175,14 @@ func (st *SqlStore) CreateCollection(ctx context.Context, c *Collection) error { } query += `, ` + fieldName + ` TEXT` } - query += `)` - + if c.Type == CollectionTypeData { + query += `, subject_id TEXT` + } + query += `, created_at TIMESTAMP, updated_at TIMESTAMP)` result = tx.Exec(query) if result.Error != nil { return result.Error } - tx.Commit() if tx.Error != nil { return tx.Error @@ -201,11 +205,27 @@ func getCollectionFields(ctx context.Context, db *gorm.DB, collectionName string } func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection, error) { - fields, err := getCollectionFields(ctx, st.db, name) - if err != nil { - return nil, err + if !validateInput(name) { + return nil, &ValueError{Msg: fmt.Sprintf("Invalid collection name %s", name)} } - return &Collection{Name: name, Fields: fields}, nil + + dbCollectionMetadata := dbCollectionMetadata{} + result := st.db.Where("name = ?", name).First(&dbCollectionMetadata) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, &NotFoundError{"collection", name} + } + return nil, result.Error + } + return &Collection{ + Id: dbCollectionMetadata.Id, + Name: dbCollectionMetadata.Name, + Description: dbCollectionMetadata.Description, + Type: CollectionType(dbCollectionMetadata.Type), + Fields: dbCollectionMetadata.FieldSchema, + CreatedAt: dbCollectionMetadata.CreatedAt, + UpdatedAt: dbCollectionMetadata.UpdatedAt, + }, nil } func (st SqlStore) GetCollections(ctx context.Context) ([]string, error) { @@ -258,38 +278,22 @@ func (st SqlStore) DeleteCollection(ctx context.Context, name string) error { return nil } -func (st SqlStore) CreateRecord(ctx context.Context, collectionName string, record Record) (string, error) { +func (st SqlStore) CreateRecord(ctx context.Context, collectionName string, record Record) error { if !validateInput(collectionName) { - return "", &ValueError{Msg: fmt.Sprintf("Invalid collection name %s", collectionName)} - } - - fields, err := getCollectionFields(ctx, st.db, collectionName) - if err != nil { - return "", err + return &ValueError{Msg: fmt.Sprintf("Invalid collection name %s", collectionName)} } - recordId := GenerateId("rec") newRecord := make(map[string]interface{}) - newRecord["id"] = recordId - for fieldName := range fields { - if fieldValue, ok := record[fieldName]; !ok { - return "", &ValueError{Msg: fmt.Sprintf("Field %s is missing from the record", fieldName)} - } else { - newRecord[fieldName] = fieldValue - } - } for fieldName := range record { - if _, ok := fields[fieldName]; !ok { - return "", &ValueError{Msg: fmt.Sprintf("Field %s is not existent in the schema", fieldName)} - } + newRecord[fieldName] = record[fieldName] } result := st.db.Table(fmt.Sprintf("collection_%s", collectionName)).Create(&newRecord) if result.Error != nil { - return "", result.Error + return result.Error } - return recordId, nil + return nil } func (st SqlStore) GetRecords(ctx context.Context, collectionName string) ([]string, error) { diff --git a/vault/vault.go b/vault/vault.go index 5c15b5d..627014e 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -15,13 +15,30 @@ type Field struct { IsIndexed bool `json:"indexed" validate:"required,boolean"` } +type CollectionType string + +const ( + CollectionTypeSubject CollectionType = "subject" + CollectionTypeData CollectionType = "data" +) + type Collection struct { Id string `json:"id"` Name string `json:"name" validate:"required,min=3,max=32"` - Fields map[string]Field `json:"fields" validate:"required"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` Description string `json:"description"` + Type CollectionType `json:"type" validate:"oneof=subject data"` // Type can be one of "subject" or "record", default is "record" + Parent string `json:"parent"` + Fields map[string]Field `json:"fields" validate:"required"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +type Subject struct { + Id string `json:"id"` + Eid string `json:"eid" validate:"required"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + Metadata map[string]string `json:"metadata"` } type Record map[string]string // field name -> value @@ -62,8 +79,8 @@ type Policy struct { Id string `json:"id"` Name string `json:"name"` Description string `json:"description"` - Effect PolicyEffect `json:"effect" validate:"required"` - Actions []PolicyAction `json:"actions" validate:"required"` + Effect PolicyEffect `json:"effect" validate:"required,oneof=allow deny"` + Actions []PolicyAction `json:"actions" validate:"dive,required,oneof=read write"` Resources []string `json:"resources" validate:"required"` CreatedAt string `json:"created_at"` UpdatedAt string `json:"updated_at"` @@ -98,6 +115,7 @@ const ( PRINCIPALS_PPATH = "/principals" RECORDS_PPATH = "/records" POLICIES_PPATH = "/policies" + SUBJECTS_PPATH = "/subjects" ) type VaultDB interface { @@ -105,7 +123,7 @@ type VaultDB interface { GetCollections(ctx context.Context) ([]string, error) CreateCollection(ctx context.Context, col *Collection) error DeleteCollection(ctx context.Context, name string) error - CreateRecord(ctx context.Context, collectionName string, record Record) (string, error) + CreateRecord(ctx context.Context, collectionName string, record Record) error GetRecords(ctx context.Context, collectionName string) ([]string, error) GetRecord(ctx context.Context, collectionName string, recordId string) (Record, error) GetRecordsFilter(ctx context.Context, collectionName string, fieldName string, value string) ([]string, error) @@ -182,6 +200,13 @@ func (vault Vault) CreateCollection( if err := vault.Validate(col); err != nil { return err } + if col.Type == CollectionTypeData && col.Parent == "" { + return &ValueError{Msg: "parent collection must be provided for data collections"} + } + if col.Type == CollectionTypeSubject && col.Parent != "" { + return &ValueError{Msg: "parent collection must not be provided for subject collections"} + } + col.Id = GenerateId("col") err = vault.Db.CreateCollection(ctx, col) @@ -232,10 +257,36 @@ func (vault Vault) CreateRecord( return "", err } + // Ensure all fields are present + for fieldName := range collection.Fields { + if _, ok := record[fieldName]; !ok { + return "", &ValueError{Msg: fmt.Sprintf("Field %s is missing from the record", fieldName)} + } + } + + _, sidProvided := record["sid"] + var prefix string // Is this a good idea? + if collection.Type == CollectionTypeSubject { + if sidProvided { + return "", &ValueError{Msg: "sid cannot be provided for subject collections"} + } + prefix = "sub" + } else if collection.Type == CollectionTypeData { + if !sidProvided || record["sid"] == "" { + return "", &ValueError{Msg: "sid must be provided for data collections"} + } + prefix = "dat" + } + encryptedRecord := make(Record) for fieldName, fieldValue := range record { // Ensure field exists on collection - if _, ok := collection.Fields[fieldName]; !ok { + if fieldName == "" || fieldName == "id" || fieldName == "created_at" || fieldName == "updated_at" { + return "", &ValueError{Msg: fmt.Sprintf("Invalid field name: %s", fieldName)} + } + + // Ensure field exists on collection + if _, ok := collection.Fields[fieldName]; !ok && fieldName != "sid" { return "", &ValueError{fmt.Sprintf("Field %s not found on collection %s", fieldName, collectionName)} } @@ -244,14 +295,24 @@ func (vault Vault) CreateRecord( if err != nil { return "", err } - // Encrypt field value + encryptedValue, err := vault.Priv.Encrypt(fieldValue) if err != nil { return "", err } encryptedRecord[fieldName] = encryptedValue } - return vault.Db.CreateRecord(ctx, collectionName, encryptedRecord) + + fmt.Println(prefix, collection.Type) + encryptedRecord["id"] = GenerateId(prefix) + encryptedRecord["created_at"] = time.Now().Format(time.RFC3339) + encryptedRecord["updated_at"] = time.Now().Format(time.RFC3339) + + err = vault.Db.CreateRecord(ctx, collectionName, encryptedRecord) + if err != nil { + return "", err + } + return encryptedRecord["id"], nil } func (vault Vault) GetRecords( @@ -637,7 +698,6 @@ func (vault Vault) GetTokenValue(ctx context.Context, principal Principal, token } return record, nil - } func (vault *Vault) Validate(payload interface{}) error { diff --git a/vault/vault_test.go b/vault/vault_test.go index 5d452bb..e65a6b7 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -115,7 +115,6 @@ func TestVault(t *testing.T) { } // Check if input and output records match - for k, v := range inputRecord { val := v if val != vaultRecord[k] {