Skip to content

Commit

Permalink
New models (#61)
Browse files Browse the repository at this point in the history
* WIP: Models refactor

* More models

* New Models Fixes (#60)

* tests passing

* tidy mods

* validations on username

* more validations

* change to actor

---------

Co-authored-by: Paco Nelos <nelospaco@gmail.com>
  • Loading branch information
subroseio and PacoNelos authored Dec 4, 2023
1 parent 1fdda98 commit 1b01303
Show file tree
Hide file tree
Showing 26 changed files with 349 additions and 353 deletions.
45 changes: 5 additions & 40 deletions api/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,15 @@ import (
_vault "github.com/subrose/vault"
)

type CollectionFieldModel struct {
Type string `json:"type" validate:"required"`
IsIndexed bool `json:"indexed" validate:"required, boolean"`
}

type CollectionModel struct {
Name string `json:"name" validate:"required,vaultResourceNames"`
Fields map[string]CollectionFieldModel `json:"fields" validate:"required"`
}

func (core *Core) GetCollection(c *fiber.Ctx) error {
collectionName := c.Params("name")
principal := GetSessionPrincipal(c)
dbCollection, err := core.vault.GetCollection(c.Context(), principal, collectionName)
collection, err := core.vault.GetCollection(c.Context(), principal, collectionName)

if err != nil {
return err
}

collection := CollectionModel{
Name: dbCollection.Name,
Fields: make(map[string]CollectionFieldModel, len(dbCollection.Fields)),
}
for _, field := range dbCollection.Fields {
collection.Fields[field.Name] = CollectionFieldModel{
Type: field.Type,
IsIndexed: field.IsIndexed,
}
}
return c.Status(http.StatusOK).JSON(collection)
}

Expand All @@ -51,31 +31,16 @@ func (core *Core) GetCollections(c *fiber.Ctx) error {

func (core *Core) CreateCollection(c *fiber.Ctx) error {
principal := GetSessionPrincipal(c)
inputCollection := new(CollectionModel)
if err := core.ParseJsonBody(c.Body(), inputCollection); err != nil {
collection := new(_vault.Collection)
if err := core.ParseJsonBody(c.Body(), collection); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(ErrorResponse{"Invalid body", nil})
}
if err := core.Validate(inputCollection); err != nil {
return c.Status(http.StatusBadRequest).JSON(err)
}

newCollection := _vault.Collection{
Name: inputCollection.Name,
Fields: make(map[string]_vault.Field, len(inputCollection.Fields)),
}
for fieldName, field := range inputCollection.Fields {
newCollection.Fields[fieldName] = _vault.Field{
Name: fieldName,
Type: field.Type,
IsIndexed: field.IsIndexed,
}
}

_, err := core.vault.CreateCollection(c.Context(), principal, newCollection)
err := core.vault.CreateCollection(c.Context(), principal, collection)
if err != nil {
return err
}
return c.Status(http.StatusCreated).SendString("Collection created")
return c.Status(http.StatusCreated).JSON(collection)
}

func (core *Core) DeleteCollection(c *fiber.Ctx) error {
Expand Down
10 changes: 5 additions & 5 deletions api/collections_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
func TestCollections(t *testing.T) {
app, core := InitTestingVault(t)

customerCollection := CollectionModel{
customerCollection := &_vault.Collection{
Name: "customers",
Fields: map[string]CollectionFieldModel{
Fields: map[string]_vault.Field{
"name": {Type: "name", IsIndexed: true},
"phone_number": {Type: "phone_number", IsIndexed: true},
"dob": {Type: "date", IsIndexed: false},
Expand All @@ -36,7 +36,7 @@ func TestCollections(t *testing.T) {
}, nil)

response := performRequest(t, app, request)
var returnedCollection CollectionModel
var returnedCollection _vault.Collection
checkResponse(t, response, http.StatusOK, &returnedCollection)

if returnedCollection.Name != "customers" {
Expand All @@ -60,9 +60,9 @@ func TestCollections(t *testing.T) {

t.Run("can delete a collection", func(t *testing.T) {
// Create a dummy collection
collectionToDelete := CollectionModel{
collectionToDelete := _vault.Collection{
Name: "delete_me",
Fields: map[string]CollectionFieldModel{
Fields: map[string]_vault.Field{
"name": {Type: "name", IsIndexed: true},
},
}
Expand Down
50 changes: 15 additions & 35 deletions api/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"encoding/json"
"errors"

"github.com/go-playground/validator/v10"

"github.com/knadh/koanf"
"github.com/knadh/koanf/providers/confmap"
"github.com/knadh/koanf/providers/env"
Expand Down Expand Up @@ -35,10 +33,9 @@ type CoreConfig struct {
// interface for API handlers and is responsible for managing the logical and physical
// backends, router, security barrier, and audit trails.
type Core struct {
vault _vault.Vault
logger _logger.Logger
conf *CoreConfig
validator *validator.Validate
vault _vault.Vault
logger _logger.Logger
conf *CoreConfig
}

func ReadConfigs() (*CoreConfig, error) {
Expand Down Expand Up @@ -132,14 +129,14 @@ func CreateCore(conf *CoreConfig) (*Core, error) {

vaultLogger, err := _logger.NewLogger("VAULT", conf.LOG_SINK, conf.LOG_HANDLER, conf.LOG_LEVEL, conf.DEV_MODE)
vault := _vault.Vault{
Db: db,
Priv: priv,
Logger: vaultLogger,
Signer: signer,
Db: db,
Priv: priv,
Logger: vaultLogger,
Signer: signer,
Validator: _vault.NewValidator(),
}

c.vault = vault
c.validator = newValidator()

return c, err
}
Expand All @@ -155,8 +152,11 @@ func (core *Core) Init() error {
panic(err)
}
}
_, err := core.vault.Db.CreatePolicy(ctx, _vault.Policy{
PolicyId: "root",
// TODO: Move this to a bootstrap function
rootPolicyId := _vault.GenerateId("pol")
err := core.vault.Db.CreatePolicy(ctx, &_vault.Policy{
Id: rootPolicyId,
Name: "root",
Effect: _vault.EffectAllow,
Actions: []_vault.PolicyAction{_vault.PolicyActionWrite, _vault.PolicyActionRead},
Resources: []string{"*"},
Expand All @@ -174,8 +174,8 @@ func (core *Core) Init() error {
Username: core.conf.ADMIN_USERNAME,
Password: core.conf.ADMIN_PASSWORD,
Description: "admin",
Policies: []string{"root"}}
err = core.vault.CreatePrincipal(ctx, adminPrincipal, adminPrincipal.Username, adminPrincipal.Password, adminPrincipal.Description, adminPrincipal.Policies)
Policies: []string{rootPolicyId}}
err = core.vault.CreatePrincipal(ctx, adminPrincipal, &adminPrincipal) // The admin bootstraps himself

var co *_vault.ConflictError
if err != nil {
Expand All @@ -189,26 +189,6 @@ func (core *Core) Init() error {
return nil
}

func (core *Core) Validate(payload interface{}) []*ValidationError {
var errors []*ValidationError

err := core.validator.Struct(payload)
if err != nil {
// Check if the error is a validator.ValidationErrors type
if _, ok := err.(*validator.InvalidValidationError); ok {
return errors
}
for _, err := range err.(validator.ValidationErrors) {
var element ValidationError
element.FailedField = err.Field()
element.Tag = err.Tag()
element.Value = err.Param()
errors = append(errors, &element)
}
}
return errors
}

func (core *Core) ParseJsonBody(data []byte, payload interface{}) error {
decoder := json.NewDecoder(bytes.NewReader(data))
decoder.DisallowUnknownFields()
Expand Down
42 changes: 0 additions & 42 deletions api/errors.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
package main

import (
"reflect"
"regexp"
"strings"

"github.com/go-playground/validator/v10"
)

type AuthError struct{ Msg string }

func (e *AuthError) Error() string {
Expand All @@ -18,37 +10,3 @@ type ErrorResponse struct {
Message string `json:"message"`
Errors []*interface{} `json:"errors"`
}

func ValidateResourceName(fl validator.FieldLevel) bool {
reg := "^[a-zA-Z0-9._]{1,249}$"
match, _ := regexp.MatchString(reg, fl.Field().String())

// Check for prohibited values: single period, double underscore, and hyphen
if fl.Field().String() == "." || fl.Field().String() == "__" || fl.Field().String() == "-" {
return false
}

return match
}

type ValidationError struct {
FailedField string `json:"failed_field"`
Tag string `json:"tag"`
Value string `json:"value"`
}

func newValidator() *validator.Validate {
v := validator.New(validator.WithRequiredStructEnabled())
err := v.RegisterValidation("vaultResourceNames", ValidateResourceName)
if err != nil {
panic(err)
}
v.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
return v
}
2 changes: 1 addition & 1 deletion api/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.21

require (
github.com/go-playground/assert/v2 v2.2.0
github.com/go-playground/validator/v10 v10.16.0
github.com/gofiber/fiber/v2 v2.51.0
github.com/joho/godotenv v1.3.0
github.com/knadh/koanf v1.5.0
Expand All @@ -29,6 +28,7 @@ require (
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.16.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
Expand Down
3 changes: 3 additions & 0 deletions api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func (core *Core) customErrorHandler(ctx *fiber.Ctx, err error) error {
var ne *_vault.NotFoundError
var ae *AuthError
var co *_vault.ConflictError
var va *_vault.ValidationErrors
switch {
case errors.As(err, &ve):
return ctx.Status(http.StatusBadRequest).JSON(ErrorResponse{ve.Error(), nil})
Expand All @@ -106,6 +107,8 @@ func (core *Core) customErrorHandler(ctx *fiber.Ctx, err error) error {
return ctx.Status(http.StatusNotImplemented).JSON(ErrorResponse{err.Error(), nil})
case errors.As(err, &co):
return ctx.Status(http.StatusConflict).JSON(ErrorResponse{co.Error(), nil})
case errors.As(err, &va):
return ctx.Status(http.StatusConflict).JSON(ErrorResponse{va.Error(), nil})
default:
// Handle other types of errors by returning a generic 500 - this should remain obscure as it can leak information
core.logger.Error(fmt.Sprintf("Unhandled error: %s", err.Error()))
Expand Down
6 changes: 1 addition & 5 deletions api/policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@ func (core *Core) CreatePolicy(c *fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).JSON(ErrorResponse{"Invalid body", nil})
}

if validationErrs := core.Validate(policy); validationErrs != nil {
return c.Status(fiber.StatusBadRequest).JSON(validationErrs)
}

_, err := core.vault.CreatePolicy(c.Context(), sessionPrincipal, policy)
err := core.vault.CreatePolicy(c.Context(), sessionPrincipal, &policy)
if err != nil {
core.logger.Error(fmt.Sprintf("Failed to create policy %v", err))
return err
Expand Down
18 changes: 10 additions & 8 deletions api/policies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ import (
func TestPolicies(t *testing.T) {
app, core := InitTestingVault(t)

testPolicyId := "test-policy"
testPolicyId := ""
testResource := "/resources/stuff"

t.Run("can create policy", func(t *testing.T) {
testPolicy := _vault.Policy{
PolicyId: testPolicyId,
Name: "test-policy",
Effect: _vault.EffectAllow,
Actions: []_vault.PolicyAction{_vault.PolicyActionRead},
Resources: []string{fmt.Sprintf("/policies/%s", testPolicyId)},
Resources: []string{testResource},
}

request := newRequest(t, http.MethodPost, "/policies", map[string]string{
Expand All @@ -29,11 +30,12 @@ func TestPolicies(t *testing.T) {
response := performRequest(t, app, request)
var createdPolicy _vault.Policy
checkResponse(t, response, http.StatusCreated, &createdPolicy)
testPolicyId = createdPolicy.Id

// Assertions
assert.Equal(t, _vault.EffectAllow, createdPolicy.Effect)
assert.Equal(t, []_vault.PolicyAction{_vault.PolicyActionRead}, createdPolicy.Actions)
assert.Equal(t, []string{fmt.Sprintf("/policies/%s", testPolicyId)}, createdPolicy.Resources)
assert.Equal(t, []string{testResource}, createdPolicy.Resources)
})

t.Run("can get policy", func(t *testing.T) {
Expand All @@ -48,13 +50,13 @@ func TestPolicies(t *testing.T) {
// Assertions
assert.Equal(t, _vault.EffectAllow, returnedPolicy.Effect)
assert.Equal(t, []_vault.PolicyAction{_vault.PolicyActionRead}, returnedPolicy.Actions)
assert.Equal(t, []string{fmt.Sprintf("/policies/%s", testPolicyId)}, returnedPolicy.Resources)
assert.Equal(t, []string{testResource}, returnedPolicy.Resources)
})

t.Run("can delete policy", func(t *testing.T) {
// Add a dummy policy first before deleting
dummyPolicy := _vault.Policy{
PolicyId: "dummy-policy",
Name: "dummy-policy",
Effect: _vault.EffectAllow,
Actions: []_vault.PolicyAction{_vault.PolicyActionRead},
Resources: []string{"/policies/dummy-policy"},
Expand All @@ -69,15 +71,15 @@ func TestPolicies(t *testing.T) {
checkResponse(t, response, http.StatusCreated, &returnedPolicy)

// Delete it
request = newRequest(t, http.MethodDelete, fmt.Sprintf("/policies/%s", dummyPolicy.PolicyId), map[string]string{
request = newRequest(t, http.MethodDelete, fmt.Sprintf("/policies/%s", returnedPolicy.Id), map[string]string{
"Authorization": createBasicAuthHeader(core.conf.ADMIN_USERNAME, core.conf.ADMIN_PASSWORD),
}, nil)

response = performRequest(t, app, request)
checkResponse(t, response, http.StatusNoContent, nil)

// Check it's gone
request = newRequest(t, http.MethodGet, fmt.Sprintf("/policies/%s", dummyPolicy.PolicyId), map[string]string{
request = newRequest(t, http.MethodGet, fmt.Sprintf("/policies/%s", returnedPolicy.Id), map[string]string{
"Authorization": createBasicAuthHeader(core.conf.ADMIN_USERNAME, core.conf.ADMIN_PASSWORD),
}, nil)

Expand Down
Loading

0 comments on commit 1b01303

Please sign in to comment.