From 1b01303f94d3cac8312951878dba2ab370aa4d39 Mon Sep 17 00:00:00 2001 From: subroseio <136270129+subroseio@users.noreply.github.com> Date: Mon, 4 Dec 2023 17:33:52 +0000 Subject: [PATCH] New models (#61) * WIP: Models refactor * More models * New Models Fixes (#60) * tests passing * tidy mods * validations on username * more validations * change to actor --------- Co-authored-by: Paco Nelos --- api/collections.go | 45 ++--------- api/collections_test.go | 10 +-- api/core.go | 50 ++++-------- api/errors.go | 42 ---------- api/go.mod | 2 +- api/main.go | 3 + api/policies.go | 6 +- api/policies_test.go | 18 +++-- api/principals.go | 39 ++++----- api/principals_test.go | 6 +- api/testing_utils.go | 23 ++++-- api/tokens.go | 7 +- api/tokens_test.go | 8 +- go.work.sum | 1 + simulator/client.py | 5 +- simulator/ecommerce.py | 15 ++-- simulator/ops.py | 9 +-- simulator/password_manager.py | 10 +-- simulator/pci.py | 15 ++-- api/test.env => test.env | 0 vault/errors.go | 64 +++++++++++++-- vault/policy_test.go | 27 ++++--- vault/sql.go | 69 ++++++++-------- vault/test.env | 1 + vault/vault.go | 147 +++++++++++++++++++++------------- vault/vault_test.go | 80 +++++++++--------- 26 files changed, 349 insertions(+), 353 deletions(-) rename api/test.env => test.env (100%) create mode 100644 vault/test.env diff --git a/api/collections.go b/api/collections.go index d7d4ebd..46d9127 100644 --- a/api/collections.go +++ b/api/collections.go @@ -8,35 +8,15 @@ import ( _vault "github.com/subrose/vault" ) -type CollectionFieldModel struct { - Type string `json:"type" validate:"required"` - IsIndexed bool `json:"indexed" validate:"required, boolean"` -} - -type CollectionModel struct { - Name string `json:"name" validate:"required,vaultResourceNames"` - Fields map[string]CollectionFieldModel `json:"fields" validate:"required"` -} - func (core *Core) GetCollection(c *fiber.Ctx) error { collectionName := c.Params("name") principal := GetSessionPrincipal(c) - dbCollection, err := core.vault.GetCollection(c.Context(), principal, collectionName) + collection, err := core.vault.GetCollection(c.Context(), principal, collectionName) if err != nil { return err } - collection := CollectionModel{ - Name: dbCollection.Name, - Fields: make(map[string]CollectionFieldModel, len(dbCollection.Fields)), - } - for _, field := range dbCollection.Fields { - collection.Fields[field.Name] = CollectionFieldModel{ - Type: field.Type, - IsIndexed: field.IsIndexed, - } - } return c.Status(http.StatusOK).JSON(collection) } @@ -51,31 +31,16 @@ func (core *Core) GetCollections(c *fiber.Ctx) error { func (core *Core) CreateCollection(c *fiber.Ctx) error { principal := GetSessionPrincipal(c) - inputCollection := new(CollectionModel) - if err := core.ParseJsonBody(c.Body(), inputCollection); err != nil { + collection := new(_vault.Collection) + if err := core.ParseJsonBody(c.Body(), collection); err != nil { return c.Status(fiber.StatusBadRequest).JSON(ErrorResponse{"Invalid body", nil}) } - if err := core.Validate(inputCollection); err != nil { - return c.Status(http.StatusBadRequest).JSON(err) - } - - newCollection := _vault.Collection{ - Name: inputCollection.Name, - Fields: make(map[string]_vault.Field, len(inputCollection.Fields)), - } - for fieldName, field := range inputCollection.Fields { - newCollection.Fields[fieldName] = _vault.Field{ - Name: fieldName, - Type: field.Type, - IsIndexed: field.IsIndexed, - } - } - _, err := core.vault.CreateCollection(c.Context(), principal, newCollection) + err := core.vault.CreateCollection(c.Context(), principal, collection) if err != nil { return err } - return c.Status(http.StatusCreated).SendString("Collection created") + return c.Status(http.StatusCreated).JSON(collection) } func (core *Core) DeleteCollection(c *fiber.Ctx) error { diff --git a/api/collections_test.go b/api/collections_test.go index 5ae3e9d..7e35db9 100644 --- a/api/collections_test.go +++ b/api/collections_test.go @@ -11,9 +11,9 @@ import ( func TestCollections(t *testing.T) { app, core := InitTestingVault(t) - customerCollection := CollectionModel{ + customerCollection := &_vault.Collection{ Name: "customers", - Fields: map[string]CollectionFieldModel{ + Fields: map[string]_vault.Field{ "name": {Type: "name", IsIndexed: true}, "phone_number": {Type: "phone_number", IsIndexed: true}, "dob": {Type: "date", IsIndexed: false}, @@ -36,7 +36,7 @@ func TestCollections(t *testing.T) { }, nil) response := performRequest(t, app, request) - var returnedCollection CollectionModel + var returnedCollection _vault.Collection checkResponse(t, response, http.StatusOK, &returnedCollection) if returnedCollection.Name != "customers" { @@ -60,9 +60,9 @@ func TestCollections(t *testing.T) { t.Run("can delete a collection", func(t *testing.T) { // Create a dummy collection - collectionToDelete := CollectionModel{ + collectionToDelete := _vault.Collection{ Name: "delete_me", - Fields: map[string]CollectionFieldModel{ + Fields: map[string]_vault.Field{ "name": {Type: "name", IsIndexed: true}, }, } diff --git a/api/core.go b/api/core.go index 4463122..60dce34 100644 --- a/api/core.go +++ b/api/core.go @@ -6,8 +6,6 @@ import ( "encoding/json" "errors" - "github.com/go-playground/validator/v10" - "github.com/knadh/koanf" "github.com/knadh/koanf/providers/confmap" "github.com/knadh/koanf/providers/env" @@ -35,10 +33,9 @@ type CoreConfig struct { // interface for API handlers and is responsible for managing the logical and physical // backends, router, security barrier, and audit trails. type Core struct { - vault _vault.Vault - logger _logger.Logger - conf *CoreConfig - validator *validator.Validate + vault _vault.Vault + logger _logger.Logger + conf *CoreConfig } func ReadConfigs() (*CoreConfig, error) { @@ -132,14 +129,14 @@ func CreateCore(conf *CoreConfig) (*Core, error) { vaultLogger, err := _logger.NewLogger("VAULT", conf.LOG_SINK, conf.LOG_HANDLER, conf.LOG_LEVEL, conf.DEV_MODE) vault := _vault.Vault{ - Db: db, - Priv: priv, - Logger: vaultLogger, - Signer: signer, + Db: db, + Priv: priv, + Logger: vaultLogger, + Signer: signer, + Validator: _vault.NewValidator(), } c.vault = vault - c.validator = newValidator() return c, err } @@ -155,8 +152,11 @@ func (core *Core) Init() error { panic(err) } } - _, err := core.vault.Db.CreatePolicy(ctx, _vault.Policy{ - PolicyId: "root", + // TODO: Move this to a bootstrap function + rootPolicyId := _vault.GenerateId("pol") + err := core.vault.Db.CreatePolicy(ctx, &_vault.Policy{ + Id: rootPolicyId, + Name: "root", Effect: _vault.EffectAllow, Actions: []_vault.PolicyAction{_vault.PolicyActionWrite, _vault.PolicyActionRead}, Resources: []string{"*"}, @@ -174,8 +174,8 @@ func (core *Core) Init() error { Username: core.conf.ADMIN_USERNAME, Password: core.conf.ADMIN_PASSWORD, Description: "admin", - Policies: []string{"root"}} - err = core.vault.CreatePrincipal(ctx, adminPrincipal, adminPrincipal.Username, adminPrincipal.Password, adminPrincipal.Description, adminPrincipal.Policies) + Policies: []string{rootPolicyId}} + err = core.vault.CreatePrincipal(ctx, adminPrincipal, &adminPrincipal) // The admin bootstraps himself var co *_vault.ConflictError if err != nil { @@ -189,26 +189,6 @@ func (core *Core) Init() error { return nil } -func (core *Core) Validate(payload interface{}) []*ValidationError { - var errors []*ValidationError - - err := core.validator.Struct(payload) - if err != nil { - // Check if the error is a validator.ValidationErrors type - if _, ok := err.(*validator.InvalidValidationError); ok { - return errors - } - for _, err := range err.(validator.ValidationErrors) { - var element ValidationError - element.FailedField = err.Field() - element.Tag = err.Tag() - element.Value = err.Param() - errors = append(errors, &element) - } - } - return errors -} - func (core *Core) ParseJsonBody(data []byte, payload interface{}) error { decoder := json.NewDecoder(bytes.NewReader(data)) decoder.DisallowUnknownFields() diff --git a/api/errors.go b/api/errors.go index dff076c..6476933 100644 --- a/api/errors.go +++ b/api/errors.go @@ -1,13 +1,5 @@ package main -import ( - "reflect" - "regexp" - "strings" - - "github.com/go-playground/validator/v10" -) - type AuthError struct{ Msg string } func (e *AuthError) Error() string { @@ -18,37 +10,3 @@ type ErrorResponse struct { Message string `json:"message"` Errors []*interface{} `json:"errors"` } - -func ValidateResourceName(fl validator.FieldLevel) bool { - reg := "^[a-zA-Z0-9._]{1,249}$" - match, _ := regexp.MatchString(reg, fl.Field().String()) - - // Check for prohibited values: single period, double underscore, and hyphen - if fl.Field().String() == "." || fl.Field().String() == "__" || fl.Field().String() == "-" { - return false - } - - return match -} - -type ValidationError struct { - FailedField string `json:"failed_field"` - Tag string `json:"tag"` - Value string `json:"value"` -} - -func newValidator() *validator.Validate { - v := validator.New(validator.WithRequiredStructEnabled()) - err := v.RegisterValidation("vaultResourceNames", ValidateResourceName) - if err != nil { - panic(err) - } - v.RegisterTagNameFunc(func(fld reflect.StructField) string { - name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] - if name == "-" { - return "" - } - return name - }) - return v -} diff --git a/api/go.mod b/api/go.mod index caeb15f..0aa8576 100644 --- a/api/go.mod +++ b/api/go.mod @@ -4,7 +4,6 @@ go 1.21 require ( github.com/go-playground/assert/v2 v2.2.0 - github.com/go-playground/validator/v10 v10.16.0 github.com/gofiber/fiber/v2 v2.51.0 github.com/joho/godotenv v1.3.0 github.com/knadh/koanf v1.5.0 @@ -29,6 +28,7 @@ require ( github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.16.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect diff --git a/api/main.go b/api/main.go index 1ef55fd..73cfb57 100644 --- a/api/main.go +++ b/api/main.go @@ -93,6 +93,7 @@ func (core *Core) customErrorHandler(ctx *fiber.Ctx, err error) error { var ne *_vault.NotFoundError var ae *AuthError var co *_vault.ConflictError + var va *_vault.ValidationErrors switch { case errors.As(err, &ve): return ctx.Status(http.StatusBadRequest).JSON(ErrorResponse{ve.Error(), nil}) @@ -106,6 +107,8 @@ func (core *Core) customErrorHandler(ctx *fiber.Ctx, err error) error { return ctx.Status(http.StatusNotImplemented).JSON(ErrorResponse{err.Error(), nil}) case errors.As(err, &co): return ctx.Status(http.StatusConflict).JSON(ErrorResponse{co.Error(), nil}) + case errors.As(err, &va): + return ctx.Status(http.StatusConflict).JSON(ErrorResponse{va.Error(), nil}) default: // Handle other types of errors by returning a generic 500 - this should remain obscure as it can leak information core.logger.Error(fmt.Sprintf("Unhandled error: %s", err.Error())) diff --git a/api/policies.go b/api/policies.go index 735410c..227aba6 100644 --- a/api/policies.go +++ b/api/policies.go @@ -38,11 +38,7 @@ func (core *Core) CreatePolicy(c *fiber.Ctx) error { return c.Status(fiber.StatusBadRequest).JSON(ErrorResponse{"Invalid body", nil}) } - if validationErrs := core.Validate(policy); validationErrs != nil { - return c.Status(fiber.StatusBadRequest).JSON(validationErrs) - } - - _, err := core.vault.CreatePolicy(c.Context(), sessionPrincipal, policy) + err := core.vault.CreatePolicy(c.Context(), sessionPrincipal, &policy) if err != nil { core.logger.Error(fmt.Sprintf("Failed to create policy %v", err)) return err diff --git a/api/policies_test.go b/api/policies_test.go index 9ec359c..22be8b2 100644 --- a/api/policies_test.go +++ b/api/policies_test.go @@ -12,14 +12,15 @@ import ( func TestPolicies(t *testing.T) { app, core := InitTestingVault(t) - testPolicyId := "test-policy" + testPolicyId := "" + testResource := "/resources/stuff" t.Run("can create policy", func(t *testing.T) { testPolicy := _vault.Policy{ - PolicyId: testPolicyId, + Name: "test-policy", Effect: _vault.EffectAllow, Actions: []_vault.PolicyAction{_vault.PolicyActionRead}, - Resources: []string{fmt.Sprintf("/policies/%s", testPolicyId)}, + Resources: []string{testResource}, } request := newRequest(t, http.MethodPost, "/policies", map[string]string{ @@ -29,11 +30,12 @@ func TestPolicies(t *testing.T) { response := performRequest(t, app, request) var createdPolicy _vault.Policy checkResponse(t, response, http.StatusCreated, &createdPolicy) + testPolicyId = createdPolicy.Id // Assertions assert.Equal(t, _vault.EffectAllow, createdPolicy.Effect) assert.Equal(t, []_vault.PolicyAction{_vault.PolicyActionRead}, createdPolicy.Actions) - assert.Equal(t, []string{fmt.Sprintf("/policies/%s", testPolicyId)}, createdPolicy.Resources) + assert.Equal(t, []string{testResource}, createdPolicy.Resources) }) t.Run("can get policy", func(t *testing.T) { @@ -48,13 +50,13 @@ func TestPolicies(t *testing.T) { // Assertions assert.Equal(t, _vault.EffectAllow, returnedPolicy.Effect) assert.Equal(t, []_vault.PolicyAction{_vault.PolicyActionRead}, returnedPolicy.Actions) - assert.Equal(t, []string{fmt.Sprintf("/policies/%s", testPolicyId)}, returnedPolicy.Resources) + assert.Equal(t, []string{testResource}, returnedPolicy.Resources) }) t.Run("can delete policy", func(t *testing.T) { // Add a dummy policy first before deleting dummyPolicy := _vault.Policy{ - PolicyId: "dummy-policy", + Name: "dummy-policy", Effect: _vault.EffectAllow, Actions: []_vault.PolicyAction{_vault.PolicyActionRead}, Resources: []string{"/policies/dummy-policy"}, @@ -69,7 +71,7 @@ func TestPolicies(t *testing.T) { checkResponse(t, response, http.StatusCreated, &returnedPolicy) // Delete it - request = newRequest(t, http.MethodDelete, fmt.Sprintf("/policies/%s", dummyPolicy.PolicyId), map[string]string{ + request = newRequest(t, http.MethodDelete, fmt.Sprintf("/policies/%s", returnedPolicy.Id), map[string]string{ "Authorization": createBasicAuthHeader(core.conf.ADMIN_USERNAME, core.conf.ADMIN_PASSWORD), }, nil) @@ -77,7 +79,7 @@ func TestPolicies(t *testing.T) { checkResponse(t, response, http.StatusNoContent, nil) // Check it's gone - request = newRequest(t, http.MethodGet, fmt.Sprintf("/policies/%s", dummyPolicy.PolicyId), map[string]string{ + request = newRequest(t, http.MethodGet, fmt.Sprintf("/policies/%s", returnedPolicy.Id), map[string]string{ "Authorization": createBasicAuthHeader(core.conf.ADMIN_USERNAME, core.conf.ADMIN_PASSWORD), }, nil) diff --git a/api/principals.go b/api/principals.go index 04d33eb..1387efd 100644 --- a/api/principals.go +++ b/api/principals.go @@ -4,46 +4,37 @@ import ( "net/http" "github.com/gofiber/fiber/v2" + _vault "github.com/subrose/vault" ) -type NewPrincipal struct { - Username string `json:"username" validate:"required,min=1,max=32"` - Password string `json:"password" validate:"required,min=4,max=32"` // This is to limit the size of the password hash. - Description string `json:"description"` - Policies []string `json:"policies"` -} +// type NewPrincipal struct { +// Username string `json:"username" validate:"required,min=1,max=32"` +// Password string `json:"password" validate:"required,min=4,max=32"` // This is to limit the size of the password hash. +// Description string `json:"description"` +// Policies []string `json:"policies"` +// } type PrincipalResponse struct { - Username string `json:"username"` + Id string `json:"id"` + Username string `json:"username" validate:"required,min=3,max=32"` Description string `json:"description"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` Policies []string `json:"policies"` } func (core *Core) CreatePrincipal(c *fiber.Ctx) error { - var newPrincipal NewPrincipal - if err := core.ParseJsonBody(c.Body(), &newPrincipal); err != nil { + var principal _vault.Principal + if err := core.ParseJsonBody(c.Body(), &principal); err != nil { return c.Status(fiber.StatusBadRequest).JSON(ErrorResponse{"Invalid body", nil}) } - if validationErrs := core.Validate(newPrincipal); validationErrs != nil { - return c.Status(fiber.StatusBadRequest).JSON(validationErrs) - } - sessionPrincipal := GetSessionPrincipal(c) - err := core.vault.CreatePrincipal(c.Context(), sessionPrincipal, - newPrincipal.Username, - newPrincipal.Password, - newPrincipal.Description, - newPrincipal.Policies, - ) + err := core.vault.CreatePrincipal(c.Context(), sessionPrincipal, &principal) if err != nil { return err } - return c.Status(http.StatusCreated).JSON(PrincipalResponse{ - Username: newPrincipal.Username, - Description: newPrincipal.Description, - Policies: newPrincipal.Policies, - }) + return c.Status(http.StatusCreated).JSON(PrincipalResponse{Id: principal.Id, Username: principal.Username, Description: principal.Description, Policies: principal.Policies, CreatedAt: principal.CreatedAt, UpdatedAt: principal.UpdatedAt}) } func (core *Core) GetPrincipal(c *fiber.Ctx) error { diff --git a/api/principals_test.go b/api/principals_test.go index a3402fe..69d3aff 100644 --- a/api/principals_test.go +++ b/api/principals_test.go @@ -4,12 +4,14 @@ import ( "fmt" "net/http" "testing" + + _vault "github.com/subrose/vault" ) func TestPrincipals(t *testing.T) { app, core := InitTestingVault(t) - newPrincipal := NewPrincipal{ + newPrincipal := _vault.Principal{ Username: "newprincipal", Password: "password", Description: "A new principal", @@ -50,7 +52,7 @@ func TestPrincipals(t *testing.T) { }) t.Run("can't create principals without assigned roles", func(t *testing.T) { - request := newRequest(t, http.MethodPost, "/principals", nil, NewPrincipal{ + request := newRequest(t, http.MethodPost, "/principals", nil, _vault.Principal{ Username: "newprincipal", Password: "password", Description: "A new principal", diff --git a/api/testing_utils.go b/api/testing_utils.go index b716b29..6328e7e 100644 --- a/api/testing_utils.go +++ b/api/testing_utils.go @@ -16,11 +16,12 @@ import ( _vault "github.com/subrose/vault" ) -var adminPrincipal = _vault.Principal{Username: "admin", Password: "admin", Policies: []string{"root"}} +var rootPolicyId = _vault.GenerateId("pol") +var adminPrincipal = _vault.Principal{Username: "admin", Password: "admin", Policies: []string{rootPolicyId}} func InitTestingVault(t *testing.T) (*fiber.App, *Core) { // Read environment variables if a test.env file exists, error is ignored on purpose - _ = godotenv.Load("test.env") + _ = godotenv.Load("../test.env") coreConfig, err := ReadConfigs() if err != nil { @@ -50,14 +51,16 @@ func InitTestingVault(t *testing.T) (*fiber.App, *Core) { if err != nil { t.Fatal("Failed to create logger", err) } - vault := _vault.Vault{Db: db, Priv: priv, Logger: vaultLogger, Signer: signer} + vault := _vault.Vault{Db: db, Priv: priv, Logger: vaultLogger, Signer: signer, Validator: _vault.NewValidator()} bootstrapContext := context.Background() err = vault.Db.Flush(bootstrapContext) if err != nil { t.Fatal("Failed to flush db", err) } - _, err = db.CreatePolicy(bootstrapContext, _vault.Policy{ - PolicyId: "root", + + err = db.CreatePolicy(bootstrapContext, &_vault.Policy{ + Id: rootPolicyId, + Name: "root", Effect: _vault.EffectAllow, Actions: []_vault.PolicyAction{_vault.PolicyActionRead, _vault.PolicyActionWrite}, Resources: []string{"*"}, @@ -67,7 +70,13 @@ func InitTestingVault(t *testing.T) (*fiber.App, *Core) { t.Fatal("Failed to create root policy", err) } - err = vault.CreatePrincipal(bootstrapContext, adminPrincipal, coreConfig.ADMIN_USERNAME, coreConfig.ADMIN_PASSWORD, "admin principal", []string{"root"}) + err = vault.CreatePrincipal(bootstrapContext, adminPrincipal, &_vault.Principal{ + Id: _vault.GenerateId("prin"), + Username: coreConfig.ADMIN_USERNAME, + Password: coreConfig.ADMIN_PASSWORD, + Description: "admin principal", + Policies: []string{rootPolicyId}, + }) if err != nil { t.Fatal("Failed to create admin principal", err) } @@ -134,7 +143,7 @@ func checkResponse(t *testing.T, response *http.Response, expectedStatusCode int // Check if target is a struct if _, ok := target.(struct{}); ok { - validate := newValidator() + validate := _vault.NewValidator() if err := validate.Struct(target); err != nil { t.Fatalf("Error validating response: %v", err) } diff --git a/api/tokens.go b/api/tokens.go index 4b62dd9..5022230 100644 --- a/api/tokens.go +++ b/api/tokens.go @@ -20,9 +20,10 @@ func (core *Core) CreateToken(c *fiber.Ctx) error { return c.Status(fiber.StatusBadRequest).JSON(ErrorResponse{"Invalid body", nil}) } - if validationErrs := core.Validate(tokenRequest); validationErrs != nil { - return c.Status(fiber.StatusBadRequest).JSON(validationErrs) - } + // TODO: Move validation to vault + // if validationErrs := core.Validate(tokenRequest); validationErrs != nil { + // return c.Status(fiber.StatusBadRequest).JSON(validationErrs) + // } sessionPrincipal := GetSessionPrincipal(c) tokenId, err := core.vault.CreateToken(c.Context(), sessionPrincipal, tokenRequest.Collection, tokenRequest.RecordId, tokenRequest.Field, tokenRequest.Format) diff --git a/api/tokens_test.go b/api/tokens_test.go index d020469..c1825a6 100644 --- a/api/tokens_test.go +++ b/api/tokens_test.go @@ -14,12 +14,12 @@ func TestTokens(t *testing.T) { customerCollection := vault.Collection{ Name: "test", Fields: map[string]vault.Field{ - "name": {Name: "name", Type: "name", IsIndexed: true}, - "phone_number": {Name: "phone_number", Type: "phone_number", IsIndexed: true}, - "dob": {Name: "dob", Type: "date", IsIndexed: false}, + "name": {Type: "name", IsIndexed: true}, + "phone_number": {Type: "phone_number", IsIndexed: true}, + "dob": {Type: "date", IsIndexed: false}, }, } - _, err := core.vault.CreateCollection(context.Background(), adminPrincipal, customerCollection) + err := core.vault.CreateCollection(context.Background(), adminPrincipal, &customerCollection) if err != nil { t.Error(err) t.FailNow() diff --git a/go.work.sum b/go.work.sum index 754d376..bb9f5a1 100644 --- a/go.work.sum +++ b/go.work.sum @@ -139,6 +139,7 @@ golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/simulator/client.py b/simulator/client.py index a1ecbc5..af45b95 100644 --- a/simulator/client.py +++ b/simulator/client.py @@ -9,7 +9,7 @@ class Policy(BaseModel): - policy_id: str + name: Optional[str] = None effect: str actions: List[str] resources: List[str] @@ -134,13 +134,14 @@ def get_record( def create_policy( self, policy: Policy, expected_statuses: Optional[list[int]] = None - ) -> None: + ) -> dict[str, str]: response = requests.post( f"{self.vault_url}/policies", auth=(self.username, self.password), json=policy.model_dump(), ) check_expected_status(response, expected_statuses) + return response.json() def get_policy( self, policy_id: str, expected_statuses: Optional[list[int]] = None diff --git a/simulator/ecommerce.py b/simulator/ecommerce.py index dd5db69..2ca1d68 100644 --- a/simulator/ecommerce.py +++ b/simulator/ecommerce.py @@ -36,9 +36,8 @@ # Create policies # Backend can write customer details -admin.create_policy( +backend_policy = admin.create_policy( policy=Policy( - policy_id="backend", effect="allow", actions=["write"], resources=["/collections/customers/*"], @@ -47,9 +46,8 @@ ) # Marketing can read masked records -admin.create_policy( +marketing_policy = admin.create_policy( policy=Policy( - policy_id="marketing", effect="allow", actions=["read"], resources=[ @@ -60,9 +58,8 @@ ) # Customer service team can read all customer details in plain -admin.create_policy( +cs_policy = admin.create_policy( policy=Policy( - policy_id="customer-service", effect="allow", actions=["read"], resources=[ @@ -83,7 +80,7 @@ username=backend.username, password=backend.password, description="backend", - policies=["backend"], + policies=[backend_policy["id"]], expected_statuses=[201, 409], ) @@ -91,7 +88,7 @@ username=marketing.username, password=marketing.password, description="marketing", - policies=["marketing"], + policies=[marketing_policy["id"]], expected_statuses=[201, 409], ) @@ -99,7 +96,7 @@ username=customer_service.username, password=customer_service.password, description="customer-service", - policies=["customer-service"], + policies=[cs_policy["id"]], expected_statuses=[201, 409], ) diff --git a/simulator/ops.py b/simulator/ops.py index f89edd6..b985867 100644 --- a/simulator/ops.py +++ b/simulator/ops.py @@ -32,9 +32,8 @@ # Create a temporary policy for somebody -admin.create_policy( +temp_policy = admin.create_policy( policy=Policy( - policy_id="secret-access", effect="allow", actions=["read", "write"], resources=["/collections/secrets/*"], @@ -52,7 +51,7 @@ username=SOMEBODY_USERNAME, password=SOMEBODY_PASSWORD, description="somebody", - policies=["secret-access"], + policies=[temp_policy["id"]], expected_statuses=[201, 409], ) @@ -68,13 +67,13 @@ # Policy is removed admin.delete_policy( - policy_id="secret-access", + policy_id=temp_policy["id"], expected_statuses=[204], ) # Policy is removed twice for good measure admin.delete_policy( - policy_id="secret-access", + policy_id=temp_policy["id"], expected_statuses=[404], ) diff --git a/simulator/password_manager.py b/simulator/password_manager.py index a968596..01a1863 100644 --- a/simulator/password_manager.py +++ b/simulator/password_manager.py @@ -40,9 +40,8 @@ # Step 3: Create policies using admin role -admin.create_policy( +alice_policy = admin.create_policy( policy=Policy( - policy_id="alice-access-own_passwords", effect="allow", actions=["read", "write"], resources=["/collections/alice_passwords/*"], @@ -51,9 +50,8 @@ ) -admin.create_policy( +bob_policy = admin.create_policy( policy=Policy( - policy_id="bob-access-own_passwords", effect="allow", actions=["read", "write"], resources=["/collections/bob_passwords/*"], @@ -65,7 +63,7 @@ username=ALICE_USERNAME, password=ALICE_PASSWORD, description="alice", - policies=["alice-access-own_passwords"], + policies=[alice_policy["id"]], expected_statuses=[201, 409], ) @@ -76,7 +74,7 @@ username=BOB_USERNAME, password=BOB_PASSWORD, description="bob", - policies=["bob-access-own_passwords"], + policies=[bob_policy["id"]], expected_statuses=[201, 409], ) diff --git a/simulator/pci.py b/simulator/pci.py index 0277d2d..9c51602 100644 --- a/simulator/pci.py +++ b/simulator/pci.py @@ -26,9 +26,8 @@ # Create policies # Backend can write customer details -admin.create_policy( +backend_policy = admin.create_policy( policy=Policy( - policy_id="backend-ccs", effect="allow", actions=["write"], resources=["/collections/credit_cards/*"], @@ -37,9 +36,8 @@ ) # CS can read masked records -admin.create_policy( +cs_policy = admin.create_policy( policy=Policy( - policy_id="cs-ccs", effect="allow", actions=["read"], resources=[ @@ -50,9 +48,8 @@ ) # Proxy service can read plain for forwarding to payment gateway -admin.create_policy( +proxy_policy = admin.create_policy( policy=Policy( - policy_id="proxy-ccs", effect="allow", actions=["read"], resources=[ @@ -71,7 +68,7 @@ username=backend.username, password=backend.password, description="backend", - policies=["backend-ccs"], + policies=[backend_policy["id"]], expected_statuses=[201, 409], ) @@ -79,7 +76,7 @@ username=cs.username, password=cs.password, description="cs", - policies=["cs-ccs"], + policies=[cs_policy["id"]], expected_statuses=[201, 409], ) @@ -87,7 +84,7 @@ username=proxy.username, password=proxy.password, description="proxy", - policies=["proxy-ccs"], + policies=[proxy_policy["id"]], expected_statuses=[201, 409], ) diff --git a/api/test.env b/test.env similarity index 100% rename from api/test.env rename to test.env diff --git a/vault/errors.go b/vault/errors.go index 870b783..5e69811 100644 --- a/vault/errors.go +++ b/vault/errors.go @@ -3,6 +3,11 @@ package vault import ( "errors" "fmt" + "reflect" + "regexp" + "strings" + + "github.com/go-playground/validator/v10" ) var ( @@ -10,16 +15,10 @@ var ( ErrNotSupported = errors.New("notsupported") ) -type AuthError struct{ Msg string } - -func (e *AuthError) Error() string { - return e.Msg -} - type ForbiddenError struct{ request Request } func (e *ForbiddenError) Error() string { - return fmt.Sprintf("forbidden: principal %s doing %s on %s", e.request.Principal.Username, e.request.Action, e.request.Resource) + return fmt.Sprintf("forbidden: principal %s doing %s on %s", e.request.Actor.Username, e.request.Action, e.request.Resource) } type NotFoundError struct { @@ -42,3 +41,54 @@ type ValueError struct{ Msg string } func (e *ValueError) Error() string { return e.Msg } + +type ErrorResponse struct { + Message string `json:"message"` + Errors []*interface{} `json:"errors"` +} + +func ValidateResourceName(fl validator.FieldLevel) bool { + reg := "^[a-zA-Z0-9._]{1,249}$" + match, _ := regexp.MatchString(reg, fl.Field().String()) + + // Check for prohibited values: single period, double underscore, and hyphen + if fl.Field().String() == "." || fl.Field().String() == "__" || fl.Field().String() == "-" { + return false + } + + return match +} + +type ValidationError struct { + FailedField string `json:"failed_field"` + Tag string `json:"tag"` + Value string `json:"value"` +} + +type ValidationErrors struct { + Errors []*ValidationError `json:"errors"` +} + +func (e ValidationErrors) Error() string { + errors := make([]string, len(e.Errors)) + for i, err := range e.Errors { + errors[i] = fmt.Sprintf("%s: %s", err.FailedField, err.Tag) + } + return strings.Join(errors, ", ") +} + +func NewValidator() *validator.Validate { + v := validator.New(validator.WithRequiredStructEnabled()) + err := v.RegisterValidation("vaultResourceNames", ValidateResourceName) + if err != nil { + panic(err) + } + v.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] + if name == "-" { + return "" + } + return name + }) + return v +} diff --git a/vault/policy_test.go b/vault/policy_test.go index df86017..aa20430 100644 --- a/vault/policy_test.go +++ b/vault/policy_test.go @@ -26,9 +26,9 @@ func (pm DummyPolicyManager) GetPolicies(ctx context.Context, policyIds []string } func (pm DummyPolicyManager) CreatePolicy(ctx context.Context, p Policy) (string, error) { - pm.policies[p.PolicyId] = p + pm.policies[p.Id] = p - return p.PolicyId, nil + return p.Id, nil } func (pm DummyPolicyManager) DeletePolicy(ctx context.Context, policyId string) error { @@ -39,16 +39,18 @@ func (pm DummyPolicyManager) DeletePolicy(ctx context.Context, policyId string) func getDummyPolicy(principal string) []Policy { return []Policy{ { - fmt.Sprintf("%s-allow", principal), - EffectAllow, - []PolicyAction{PolicyActionRead}, - []string{"allowed-resource/*", "restricted-resource"}, + Id: "pol_allow", + Name: fmt.Sprintf("%s-allow", principal), + Effect: EffectAllow, + Actions: []PolicyAction{PolicyActionRead}, + Resources: []string{"allowed-resource/*", "restricted-resource"}, }, { - fmt.Sprintf("%s-deny", principal), - EffectDeny, - []PolicyAction{PolicyActionRead}, - []string{"restricted-resource"}, + Id: "pol_deny", + Name: fmt.Sprintf("%s-deny", principal), + Effect: EffectDeny, + Actions: []PolicyAction{PolicyActionRead}, + Resources: []string{"restricted-resource"}, }, } } @@ -69,8 +71,8 @@ func makePrincipal() Principal { return Principal{ Username: "test", Policies: []string{ - "test-allow", - "test-deny", + "pol_allow", + "pol_deny", }, } } @@ -113,6 +115,7 @@ func TestPolicies(t *testing.T) { PolicyActionRead, "aallowed-resource", } + allowed := EvaluateRequest(request, policies) if allowed { t.Fail() diff --git a/vault/sql.go b/vault/sql.go index f876ca2..e7c292b 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -37,11 +37,11 @@ func NewSqlStore(dsn string) (*SqlStore, error) { func (st *SqlStore) CreateSchemas() error { tables := map[string]string{ - "principals": "CREATE TABLE IF NOT EXISTS principals (username TEXT PRIMARY KEY, password TEXT, description TEXT)", + "principals": "CREATE TABLE IF NOT EXISTS principals (id TEXT PRIMARY KEY, username TEXT UNIQUE, password TEXT, description TEXT)", "policies": "CREATE TABLE IF NOT EXISTS policies (id TEXT PRIMARY KEY, effect TEXT, actions TEXT[], resources TEXT[])", - "principal_policies": "CREATE TABLE IF NOT EXISTS principal_policies (username TEXT, policy_id TEXT, UNIQUE(username, policy_id))", + "principal_policies": "CREATE TABLE IF NOT EXISTS principal_policies (principal_id TEXT, policy_id TEXT, UNIQUE(principal_id, policy_id))", "tokens": "CREATE TABLE IF NOT EXISTS tokens (id TEXT PRIMARY KEY, value TEXT)", - "collection_metadata": "CREATE TABLE IF NOT EXISTS collection_metadata (name TEXT PRIMARY KEY, field_schema JSON)", + "collection_metadata": "CREATE TABLE IF NOT EXISTS collection_metadata (id TEXT PRIMARY KEY, name TEXT UNIQUE, field_schema JSON)", } for _, query := range tables { @@ -54,10 +54,10 @@ func (st *SqlStore) CreateSchemas() error { return nil } -func (st *SqlStore) CreateCollection(ctx context.Context, c Collection) (string, error) { +func (st *SqlStore) CreateCollection(ctx context.Context, c *Collection) error { tx, err := st.db.BeginTxx(ctx, nil) if err != nil { - return "", err + return err } defer func() { @@ -70,20 +70,21 @@ func (st *SqlStore) CreateCollection(ctx context.Context, c Collection) (string, fieldSchema, err := json.Marshal(c.Fields) if err != nil { - return "", err + return err } - _, err = tx.NamedExecContext(ctx, "INSERT INTO collection_metadata (name, field_schema) VALUES (:name, :field_schema)", map[string]interface{}{ + _, err = tx.NamedExecContext(ctx, "INSERT INTO collection_metadata (id, name, field_schema) VALUES (:id, :name, :field_schema)", map[string]interface{}{ + "id": c.Id, "name": c.Name, "field_schema": fieldSchema, }) if err != nil { if pqErr, ok := err.(*pq.Error); ok { if pqErr.Code == "23505" { // unique_violation - return "", &ConflictError{c.Name} + return &ConflictError{c.Name} } } - return "", err + return err } tableName := "collection_" + c.Name @@ -95,15 +96,15 @@ func (st *SqlStore) CreateCollection(ctx context.Context, c Collection) (string, queryBuilder.WriteString(")") _, err = tx.ExecContext(ctx, queryBuilder.String()) if err != nil { - return "", err + return err } err = tx.Commit() if err != nil { - return "", err + return err } - return c.Name, nil + return nil } func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection, error) { @@ -331,8 +332,8 @@ func (st SqlStore) DeleteRecord(ctx context.Context, collectionName string, reco } func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principal, error) { - var principal Principal - err := st.db.GetContext(ctx, &principal, "SELECT * FROM principals WHERE username = $1", username) + var dbPrincipal Principal + err := st.db.GetContext(ctx, &dbPrincipal, "SELECT * FROM principals WHERE username = $1", username) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, &NotFoundError{"principal", username} @@ -340,7 +341,7 @@ func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principa return nil, err } - rows, err := st.db.QueryxContext(ctx, "SELECT policy_id FROM principal_policies WHERE username = $1", username) + rows, err := st.db.QueryxContext(ctx, "SELECT policy_id FROM principal_policies WHERE principal_id = $1", dbPrincipal.Id) if err != nil { return nil, err } @@ -359,12 +360,12 @@ func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principa return nil, err } - principal.Policies = policyIds + dbPrincipal.Policies = policyIds - return &principal, nil + return &dbPrincipal, nil } -func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) error { +func (st SqlStore) CreatePrincipal(ctx context.Context, principal *Principal) error { tx, err := st.db.BeginTxx(ctx, nil) if err != nil { return err @@ -379,7 +380,7 @@ func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) err } }() - _, err = tx.NamedExecContext(ctx, "INSERT INTO principals (username, password, description) VALUES (:username, :password, :description)", &principal) + _, err = tx.NamedExecContext(ctx, "INSERT INTO principals (id, username, password, description) VALUES (:id, :username, :password, :description)", &principal) if err != nil { if pqErr, ok := err.(*pq.Error); ok { if pqErr.Code == "23505" { @@ -390,7 +391,7 @@ func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) err } for _, policyId := range principal.Policies { - _, err = tx.ExecContext(ctx, "INSERT INTO principal_policies (username, policy_id) VALUES ($1, $2)", principal.Username, policyId) + _, err = tx.ExecContext(ctx, "INSERT INTO principal_policies (principal_id, policy_id) VALUES ($1, $2)", principal.Id, policyId) if err != nil { return err } @@ -404,7 +405,7 @@ func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) err return nil } -func (st SqlStore) DeletePrincipal(ctx context.Context, username string) error { +func (st SqlStore) DeletePrincipal(ctx context.Context, id string) error { // Start a transaction tx, err := st.db.BeginTxx(ctx, nil) if err != nil { @@ -423,12 +424,12 @@ func (st SqlStore) DeletePrincipal(ctx context.Context, username string) error { } }() - _, err = tx.ExecContext(ctx, "DELETE FROM principal_policies WHERE username = $1", username) + _, err = tx.ExecContext(ctx, "DELETE FROM principal_policies WHERE principal_id = $1", id) if err != nil { return err } - result, err := tx.ExecContext(ctx, "DELETE FROM principals WHERE username = $1", username) + result, err := tx.ExecContext(ctx, "DELETE FROM principals WHERE username = $1", id) if err != nil { return err } @@ -437,7 +438,7 @@ func (st SqlStore) DeletePrincipal(ctx context.Context, username string) error { return err } if rowsAffected == 0 { - return &NotFoundError{"principal", username} + return &NotFoundError{"principal", id} } err = tx.Commit() if err != nil { @@ -466,7 +467,7 @@ func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, err } p := Policy{ - PolicyId: id, + Name: id, Effect: PolicyEffect(effect), Actions: actionList, Resources: resources, @@ -507,7 +508,7 @@ func (st SqlStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Poli } policies = append(policies, &Policy{ - PolicyId: id, + Name: id, Effect: PolicyEffect(effect), Actions: actionList, Resources: resources, @@ -521,10 +522,10 @@ func (st SqlStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Poli return policies, nil } -func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { +func (st SqlStore) CreatePolicy(ctx context.Context, p *Policy) error { tx, err := st.db.BeginTxx(ctx, nil) if err != nil { - return "", err + return err } defer func() { @@ -543,32 +544,32 @@ func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) { resources := make(pq.StringArray, len(p.Resources)) copy(resources, p.Resources) query, args, err := sqlx.Named(query, map[string]interface{}{ - "id": p.PolicyId, + "id": p.Id, "effect": string(p.Effect), "actions": actions, "resources": resources, }) if err != nil { - return "", err + return err } query = tx.Rebind(query) _, err = tx.ExecContext(ctx, query, args...) if err != nil { if pqErr, ok := err.(*pq.Error); ok { if pqErr.Code == "23505" { // unique_violation - return "", &ConflictError{p.PolicyId} + return &ConflictError{p.Id} } } - return "", err + return err } err = tx.Commit() if err != nil { - return "", err + return err } - return p.PolicyId, nil + return nil } func (st SqlStore) DeletePolicy(ctx context.Context, policyID string) error { diff --git a/vault/test.env b/vault/test.env new file mode 100644 index 0000000..79bd1a4 --- /dev/null +++ b/vault/test.env @@ -0,0 +1 @@ +THORN_DATABASE_URL=postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable diff --git a/vault/vault.go b/vault/vault.go index e2b59e7..9b4a703 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -6,18 +6,22 @@ import ( "strings" "time" + "github.com/go-playground/validator/v10" "golang.org/x/crypto/bcrypt" ) type Field struct { - Name string - Type string - IsIndexed bool + Type string `json:"type" validate:"required"` + IsIndexed bool `json:"indexed" validate:"required,boolean"` } type Collection struct { - Name string - Fields map[string]Field + Id string `json:"id"` + Name string `json:"name" validate:"required,min=3,max=32"` + Fields map[string]Field `json:"fields" validate:"required"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + Description string `json:"description"` } type Record map[string]string // field name -> value @@ -55,31 +59,38 @@ const ( ) type Policy struct { - PolicyId string `json:"policy_id" validate:"required"` - Effect PolicyEffect `json:"effect" validate:"required"` - Actions []PolicyAction `json:"actions" validate:"required"` - Resources []string `json:"resources" validate:"required"` + Id string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Effect PolicyEffect `json:"effect" validate:"required"` + Actions []PolicyAction `json:"actions" validate:"required"` + Resources []string `json:"resources" validate:"required"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` } type Principal struct { - Username string - Password string - Description string - CreatedAt string - Policies []string + Id string `json:"id"` + Username string `json:"username" validate:"required,min=3,max=32"` + Password string `json:"password" validate:"required,min=3"` // This is to limit the size of the password hash. + Description string `json:"description"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + Policies []string `json:"policies"` } type Request struct { - Principal Principal - Action PolicyAction - Resource string + Actor Principal + Action PolicyAction + Resource string } type Vault struct { - Db VaultDB - Priv Privatiser - Logger Logger - Signer Signer + Db VaultDB + Priv Privatiser + Logger Logger + Signer Signer + Validator *validator.Validate } const ( @@ -92,7 +103,7 @@ const ( type VaultDB interface { GetCollection(ctx context.Context, name string) (*Collection, error) GetCollections(ctx context.Context) ([]string, error) - CreateCollection(ctx context.Context, c Collection) (string, error) + CreateCollection(ctx context.Context, col *Collection) error DeleteCollection(ctx context.Context, name string) error CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) GetRecords(ctx context.Context, collectionName string, recordIDs []string) (map[string]*Record, error) @@ -100,11 +111,11 @@ type VaultDB interface { UpdateRecord(ctx context.Context, collectionName string, recordID string, record Record) error DeleteRecord(ctx context.Context, collectionName string, recordID string) error GetPrincipal(ctx context.Context, username string) (*Principal, error) - CreatePrincipal(ctx context.Context, principal Principal) error + CreatePrincipal(ctx context.Context, principal *Principal) error DeletePrincipal(ctx context.Context, username string) error GetPolicy(ctx context.Context, policyId string) (*Policy, error) GetPolicies(ctx context.Context, policyIds []string) ([]*Policy, error) - CreatePolicy(ctx context.Context, p Policy) (string, error) + CreatePolicy(ctx context.Context, p *Policy) error DeletePolicy(ctx context.Context, policyId string) error CreateToken(ctx context.Context, tokenId string, value string) error DeleteToken(ctx context.Context, tokenId string) error @@ -157,25 +168,26 @@ func (vault Vault) GetCollections( func (vault Vault) CreateCollection( ctx context.Context, principal Principal, - col Collection, -) (string, error) { + col *Collection, +) error { request := Request{principal, PolicyActionWrite, COLLECTIONS_PPATH} allowed, err := vault.ValidateAction(ctx, request) if err != nil { - return "", err + return err } if !allowed { - return "", &ForbiddenError{request} + return &ForbiddenError{request} } - - if len(col.Name) < 3 { - return "", &ValueError{Msg: "collection name must be at least 3 characters"} + if err := vault.Validate(col); err != nil { + return err } - collectionId, err := vault.Db.CreateCollection(ctx, col) + col.Id = GenerateId("col") + + err = vault.Db.CreateCollection(ctx, col) if err != nil { - return "", err + return err } - return collectionId, nil + return nil } func (vault Vault) DeleteCollection( @@ -407,13 +419,10 @@ func (vault Vault) GetPrincipal( func (vault Vault) CreatePrincipal( ctx context.Context, - principal Principal, - username, - password, - description string, - policies []string, + actor Principal, + principal *Principal, ) error { - request := Request{principal, PolicyActionWrite, PRINCIPALS_PPATH} + request := Request{actor, PolicyActionWrite, PRINCIPALS_PPATH} allowed, err := vault.ValidateAction(ctx, request) if err != nil { return err @@ -422,16 +431,16 @@ func (vault Vault) CreatePrincipal( return &ForbiddenError{request} } - hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - dbPrincipal := Principal{ - Username: username, - Password: string(hashedPassword), - CreatedAt: time.Now().Format(time.RFC3339), - Description: description, - Policies: policies, + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(principal.Password), bcrypt.DefaultCost) + principal.Password = string(hashedPassword) + principal.Id = GenerateId("prin") + principal.CreatedAt = time.Now().Format(time.RFC3339) + + if err := vault.Validate(principal); err != nil { + return err } - err = vault.Db.CreatePrincipal(ctx, dbPrincipal) + err = vault.Db.CreatePrincipal(ctx, principal) if err != nil { return err } @@ -489,22 +498,27 @@ func (vault Vault) Login( func (vault Vault) CreatePolicy( ctx context.Context, principal Principal, - p Policy, -) (string, error) { + p *Policy, +) error { request := Request{principal, PolicyActionWrite, POLICIES_PPATH} // Ensure resource starts with a slash for _, resource := range p.Resources { if !strings.HasPrefix(resource, "/") { - return "", &ValueError{Msg: fmt.Sprintf("resources must start with a slash - '%s' is not a valid resource", resource)} + return &ValueError{Msg: fmt.Sprintf("resources must start with a slash - '%s' is not a valid resource", resource)} } } allowed, err := vault.ValidateAction(ctx, request) if err != nil { - return "", err + return err } if !allowed { - return "", &ForbiddenError{request} + return &ForbiddenError{request} + } + err = vault.Validate(p) + if err != nil { + return err } + p.Id = GenerateId("pol") return vault.Db.CreatePolicy(ctx, p) } @@ -571,7 +585,7 @@ func (vault Vault) ValidateAction( ctx context.Context, request Request, ) (bool, error) { - policies, err := vault.Db.GetPolicies(ctx, request.Principal.Policies) + policies, err := vault.Db.GetPolicies(ctx, request.Actor.Policies) if err != nil { return false, err } @@ -630,3 +644,28 @@ func (vault Vault) GetTokenValue(ctx context.Context, principal Principal, token } return nil, &NotFoundError{"record", recordId} } + +func (vault *Vault) Validate(payload interface{}) error { + if vault.Validator == nil { + panic("Validator not set") + } + var errors []*ValidationError + + err := vault.Validator.Struct(payload) + if err != nil { + if _, ok := err.(*validator.InvalidValidationError); ok { + return &ValidationErrors{errors} + } + for _, err := range err.(validator.ValidationErrors) { + var element ValidationError + element.FailedField = err.Field() + element.Tag = err.Tag() + element.Value = err.Param() + errors = append(errors, &element) + } + } + if len(errors) == 0 { + return nil + } + return ValidationErrors{errors} +} diff --git a/vault/vault_test.go b/vault/vault_test.go index 31aeb3f..a9f926f 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -5,12 +5,15 @@ import ( "errors" "os" "testing" + "time" + "github.com/joho/godotenv" "github.com/stretchr/testify/assert" _logger "github.com/subrose/logger" ) func initVault(t *testing.T) (Vault, VaultDB, Privatiser) { + _ = godotenv.Load("../test.env") ctx := context.Background() db, err := NewSqlStore(os.Getenv("THORN_DATABASE_URL")) if err != nil { @@ -19,20 +22,28 @@ func initVault(t *testing.T) (Vault, VaultDB, Privatiser) { db.Flush(ctx) priv := NewAESPrivatiser([]byte{35, 46, 57, 24, 85, 35, 24, 74, 87, 35, 88, 98, 66, 32, 14, 05}, "abc&1*~#^2^#s0^=)^^7%b34") signer, _ := NewHMACSigner([]byte("testkey")) - _, _ = db.CreatePolicy(ctx, Policy{ - "root", - EffectAllow, - []PolicyAction{PolicyActionRead, PolicyActionWrite}, - []string{"*"}, + _ = db.CreatePolicy(ctx, &Policy{ + Id: "root", + Name: "root", + Description: "", + Effect: EffectAllow, + Actions: []PolicyAction{PolicyActionRead, PolicyActionWrite}, + Resources: []string{"*"}, + CreatedAt: time.Now().String(), + UpdatedAt: time.Now().String(), }) - _, _ = db.CreatePolicy(ctx, Policy{ - "read-all-customers", - EffectAllow, - []PolicyAction{PolicyActionRead}, - []string{"/collections/customers*"}, + _ = db.CreatePolicy(ctx, &Policy{ + Id: "read-all-customers", + Name: "read-all-customers", + Description: "", + Effect: EffectAllow, + Actions: []PolicyAction{PolicyActionRead}, + Resources: []string{"/collections/customers*"}, + CreatedAt: time.Now().String(), + UpdatedAt: time.Now().String(), }) vaultLogger, _ := _logger.NewLogger("TEST_VAULT", "none", "text", "debug", true) - vault := Vault{Db: db, Priv: priv, Logger: vaultLogger, Signer: signer} + vault := Vault{Db: db, Priv: priv, Logger: vaultLogger, Signer: signer, Validator: NewValidator()} return vault, db, priv } @@ -48,30 +59,26 @@ func TestVault(t *testing.T) { vault, _, _ := initVault(t) col := Collection{Name: "customers", Fields: map[string]Field{ "first_name": { - Name: "first_name", Type: "string", IsIndexed: false, }, "last_name": { - Name: "last_name", Type: "string", IsIndexed: false, }, "email": { - Name: "email", Type: "string", IsIndexed: true, }, "phone_number": { - Name: "phone_number", Type: "string", IsIndexed: true, }, }} // Can create collection - colID, err := vault.CreateCollection(ctx, testPrincipal, col) - if err != nil || colID == "" { + err := vault.CreateCollection(ctx, testPrincipal, &col) + if err != nil || col.Id == "" { t.Fatal(err) } @@ -144,12 +151,11 @@ func TestVault(t *testing.T) { vault, _, _ := initVault(t) col := Collection{Name: "customers", Fields: map[string]Field{ "first_name": { - Name: "first_name", Type: "string", IsIndexed: false, }, }} - _, _ = vault.CreateCollection(ctx, testPrincipal, col) + _ = vault.CreateCollection(ctx, testPrincipal, &col) // Create a dummy record recordID, err := vault.CreateRecords(ctx, testPrincipal, col.Name, []Record{ {"first_name": "dummy"}, @@ -183,14 +189,13 @@ func TestVault(t *testing.T) { vault, _, _ := initVault(t) col := Collection{Name: "testing", Fields: map[string]Field{ "test_field": { - Name: "test_field", Type: "string", IsIndexed: false, }, }} // Create collection - _, _ = vault.CreateCollection(ctx, testPrincipal, col) + _ = vault.CreateCollection(ctx, testPrincipal, &col) // Create a dummy record recordID, err := vault.CreateRecords(ctx, testPrincipal, col.Name, []Record{ @@ -224,12 +229,11 @@ func TestVault(t *testing.T) { vault, _, _ := initVault(t) col := Collection{Name: "test_collection", Fields: map[string]Field{ "first_name": { - Name: "first_name", Type: "string", IsIndexed: false, }, }} - _, _ = vault.CreateCollection(ctx, testPrincipal, col) + _ = vault.CreateCollection(ctx, testPrincipal, &col) inputRecords := []Record{{"invalid_field": "John"}} _, err := vault.CreateRecords(ctx, testPrincipal, col.Name, inputRecords) var ve *ValueError @@ -242,14 +246,13 @@ func TestVault(t *testing.T) { vault, _, _ := initVault(t) col := Collection{Name: "test_collection", Fields: map[string]Field{ "test_field": { - Name: "test_field", Type: "string", IsIndexed: false, }, }} // Create collection - _, _ = vault.CreateCollection(ctx, testPrincipal, col) + _ = vault.CreateCollection(ctx, testPrincipal, &col) // Create a dummy record recordID, err := vault.CreateRecords(ctx, testPrincipal, col.Name, []Record{ @@ -288,7 +291,7 @@ func TestVault(t *testing.T) { t.Error("Should throw a not found error!", err) } // Can create a principal - err = vault.CreatePrincipal(ctx, testPrincipal, testPrincipal.Username, testPrincipal.Password, "a test principal, again", []string{"read-all-customers"}) + err = vault.CreatePrincipal(ctx, testPrincipal, &Principal{Username: testPrincipal.Username, Password: testPrincipal.Password, Description: "a test principal, again", Policies: []string{"read-all-customers"}}) if err != nil { t.Fatal(err) } @@ -298,7 +301,7 @@ func TestVault(t *testing.T) { t.Run("can delete a principal", func(t *testing.T) { vault, _, _ := initVault(t) // Create a principal - err := vault.CreatePrincipal(ctx, testPrincipal, "test_user", "test_password", "test principal", []string{"root"}) + err := vault.CreatePrincipal(ctx, testPrincipal, &Principal{Username: "test_user", Password: "test_password", Description: "test principal", Policies: []string{"root"}}) if err != nil { t.Fatal(err) } @@ -319,12 +322,12 @@ func TestVault(t *testing.T) { t.Run("cant create the same principal twice", func(t *testing.T) { vault, _, _ := initVault(t) - err := vault.CreatePrincipal(ctx, testPrincipal, testPrincipal.Username, testPrincipal.Password, "a test principal", []string{"read-all-customers"}) + err := vault.CreatePrincipal(ctx, testPrincipal, &Principal{Username: testPrincipal.Username, Password: testPrincipal.Password, Description: "a test principal", Policies: []string{"read-all-customers"}}) if err != nil { t.Fatal(err) } - err2 := vault.CreatePrincipal(ctx, testPrincipal, testPrincipal.Username, testPrincipal.Password, "a test principal", []string{"read-all-customers"}) + err2 := vault.CreatePrincipal(ctx, testPrincipal, &Principal{Username: testPrincipal.Username, Password: testPrincipal.Password, Description: "a test principal", Policies: []string{"read-all-customers"}}) switch err2.(type) { case *ConflictError: // success @@ -344,14 +347,13 @@ func TestVault(t *testing.T) { // TODO: Smelly test, make this DRY col := Collection{Name: "customers", Fields: map[string]Field{ "first_name": { - Name: "first_name", Type: "string", IsIndexed: false, }, }} // Can create collection - _, _ = vault.CreateCollection(ctx, testPrincipal, col) + _ = vault.CreateCollection(ctx, testPrincipal, &col) record_ids, _ := vault.CreateRecords(ctx, testPrincipal, col.Name, []Record{ {"first_name": "John"}, {"first_name": "Jane"}, @@ -441,11 +443,11 @@ func TestVaultLogin(t *testing.T) { testPrincipal := Principal{ Username: "test_user", Password: "test_password", - Policies: []string{"root"}, Description: "test principal", + Policies: []string{"root"}, } - err := vault.CreatePrincipal(ctx, testPrincipal, testPrincipal.Username, testPrincipal.Password, testPrincipal.Description, testPrincipal.Policies) + err := vault.CreatePrincipal(ctx, testPrincipal, &Principal{Username: testPrincipal.Username, Password: testPrincipal.Password, Description: testPrincipal.Description, Policies: testPrincipal.Policies}) if err != nil { t.Fatal(err) } @@ -490,21 +492,21 @@ func TestTokens(t *testing.T) { Policies: []string{"read-all-customers"}, Description: "test principal", } - err := vault.CreatePrincipal(ctx, rootPrincipal, rootPrincipal.Username, rootPrincipal.Password, rootPrincipal.Description, rootPrincipal.Policies) + err := vault.CreatePrincipal(ctx, rootPrincipal, &Principal{Username: rootPrincipal.Username, Password: rootPrincipal.Password, Description: rootPrincipal.Description, Policies: rootPrincipal.Policies}) assert.NoError(t, err, "failed to create root principal") - err = vault.CreatePrincipal(ctx, rootPrincipal, testPrincipal.Username, testPrincipal.Password, testPrincipal.Description, testPrincipal.Policies) + err = vault.CreatePrincipal(ctx, rootPrincipal, &Principal{Username: testPrincipal.Username, Password: testPrincipal.Password, Description: testPrincipal.Description, Policies: testPrincipal.Policies}) assert.NoError(t, err, "failed to create test principal") // create collections - _, err = vault.CreateCollection(ctx, rootPrincipal, Collection{ + err = vault.CreateCollection(ctx, rootPrincipal, &Collection{ Name: "customers", - Fields: map[string]Field{"name": {"name", "string", false}, "foo": {"foo", "string", false}}, + Fields: map[string]Field{"name": {"string", false}, "foo": {"string", false}}, }) assert.NoError(t, err, "failed to create customer collection") - _, err = vault.CreateCollection(ctx, rootPrincipal, Collection{ + err = vault.CreateCollection(ctx, rootPrincipal, &Collection{ Name: "employees", - Fields: map[string]Field{"name": {"name", "string", false}, "foo": {"foo", "string", false}}, + Fields: map[string]Field{"name": {"string", false}, "foo": {"string", false}}, }) assert.NoError(t, err, "failed to create employees collection")