Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New models #61

Merged
merged 3 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 5 additions & 40 deletions api/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,15 @@ import (
_vault "github.com/subrose/vault"
)

type CollectionFieldModel struct {
Type string `json:"type" validate:"required"`
IsIndexed bool `json:"indexed" validate:"required, boolean"`
}

type CollectionModel struct {
Name string `json:"name" validate:"required,vaultResourceNames"`
Fields map[string]CollectionFieldModel `json:"fields" validate:"required"`
}

func (core *Core) GetCollection(c *fiber.Ctx) error {
collectionName := c.Params("name")
principal := GetSessionPrincipal(c)
dbCollection, err := core.vault.GetCollection(c.Context(), principal, collectionName)
collection, err := core.vault.GetCollection(c.Context(), principal, collectionName)

if err != nil {
return err
}

collection := CollectionModel{
Name: dbCollection.Name,
Fields: make(map[string]CollectionFieldModel, len(dbCollection.Fields)),
}
for _, field := range dbCollection.Fields {
collection.Fields[field.Name] = CollectionFieldModel{
Type: field.Type,
IsIndexed: field.IsIndexed,
}
}
return c.Status(http.StatusOK).JSON(collection)
}

Expand All @@ -51,31 +31,16 @@ func (core *Core) GetCollections(c *fiber.Ctx) error {

func (core *Core) CreateCollection(c *fiber.Ctx) error {
principal := GetSessionPrincipal(c)
inputCollection := new(CollectionModel)
if err := core.ParseJsonBody(c.Body(), inputCollection); err != nil {
collection := new(_vault.Collection)
if err := core.ParseJsonBody(c.Body(), collection); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(ErrorResponse{"Invalid body", nil})
}
if err := core.Validate(inputCollection); err != nil {
return c.Status(http.StatusBadRequest).JSON(err)
}

newCollection := _vault.Collection{
Name: inputCollection.Name,
Fields: make(map[string]_vault.Field, len(inputCollection.Fields)),
}
for fieldName, field := range inputCollection.Fields {
newCollection.Fields[fieldName] = _vault.Field{
Name: fieldName,
Type: field.Type,
IsIndexed: field.IsIndexed,
}
}

_, err := core.vault.CreateCollection(c.Context(), principal, newCollection)
err := core.vault.CreateCollection(c.Context(), principal, collection)
if err != nil {
return err
}
return c.Status(http.StatusCreated).SendString("Collection created")
return c.Status(http.StatusCreated).JSON(collection)
}

func (core *Core) DeleteCollection(c *fiber.Ctx) error {
Expand Down
10 changes: 5 additions & 5 deletions api/collections_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
func TestCollections(t *testing.T) {
app, core := InitTestingVault(t)

customerCollection := CollectionModel{
customerCollection := &_vault.Collection{
Name: "customers",
Fields: map[string]CollectionFieldModel{
Fields: map[string]_vault.Field{
"name": {Type: "name", IsIndexed: true},
"phone_number": {Type: "phone_number", IsIndexed: true},
"dob": {Type: "date", IsIndexed: false},
Expand All @@ -36,7 +36,7 @@ func TestCollections(t *testing.T) {
}, nil)

response := performRequest(t, app, request)
var returnedCollection CollectionModel
var returnedCollection _vault.Collection
checkResponse(t, response, http.StatusOK, &returnedCollection)

if returnedCollection.Name != "customers" {
Expand All @@ -60,9 +60,9 @@ func TestCollections(t *testing.T) {

t.Run("can delete a collection", func(t *testing.T) {
// Create a dummy collection
collectionToDelete := CollectionModel{
collectionToDelete := _vault.Collection{
Name: "delete_me",
Fields: map[string]CollectionFieldModel{
Fields: map[string]_vault.Field{
"name": {Type: "name", IsIndexed: true},
},
}
Expand Down
50 changes: 15 additions & 35 deletions api/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"encoding/json"
"errors"

"github.com/go-playground/validator/v10"

"github.com/knadh/koanf"
"github.com/knadh/koanf/providers/confmap"
"github.com/knadh/koanf/providers/env"
Expand Down Expand Up @@ -35,10 +33,9 @@ type CoreConfig struct {
// interface for API handlers and is responsible for managing the logical and physical
// backends, router, security barrier, and audit trails.
type Core struct {
vault _vault.Vault
logger _logger.Logger
conf *CoreConfig
validator *validator.Validate
vault _vault.Vault
logger _logger.Logger
conf *CoreConfig
}

func ReadConfigs() (*CoreConfig, error) {
Expand Down Expand Up @@ -132,14 +129,14 @@ func CreateCore(conf *CoreConfig) (*Core, error) {

vaultLogger, err := _logger.NewLogger("VAULT", conf.LOG_SINK, conf.LOG_HANDLER, conf.LOG_LEVEL, conf.DEV_MODE)
vault := _vault.Vault{
Db: db,
Priv: priv,
Logger: vaultLogger,
Signer: signer,
Db: db,
Priv: priv,
Logger: vaultLogger,
Signer: signer,
Validator: _vault.NewValidator(),
}

c.vault = vault
c.validator = newValidator()

return c, err
}
Expand All @@ -155,8 +152,11 @@ func (core *Core) Init() error {
panic(err)
}
}
_, err := core.vault.Db.CreatePolicy(ctx, _vault.Policy{
PolicyId: "root",
// TODO: Move this to a bootstrap function
rootPolicyId := _vault.GenerateId("pol")
err := core.vault.Db.CreatePolicy(ctx, &_vault.Policy{
Id: rootPolicyId,
Name: "root",
Effect: _vault.EffectAllow,
Actions: []_vault.PolicyAction{_vault.PolicyActionWrite, _vault.PolicyActionRead},
Resources: []string{"*"},
Expand All @@ -174,8 +174,8 @@ func (core *Core) Init() error {
Username: core.conf.ADMIN_USERNAME,
Password: core.conf.ADMIN_PASSWORD,
Description: "admin",
Policies: []string{"root"}}
err = core.vault.CreatePrincipal(ctx, adminPrincipal, adminPrincipal.Username, adminPrincipal.Password, adminPrincipal.Description, adminPrincipal.Policies)
Policies: []string{rootPolicyId}}
err = core.vault.CreatePrincipal(ctx, adminPrincipal, &adminPrincipal) // The admin bootstraps himself

var co *_vault.ConflictError
if err != nil {
Expand All @@ -189,26 +189,6 @@ func (core *Core) Init() error {
return nil
}

func (core *Core) Validate(payload interface{}) []*ValidationError {
var errors []*ValidationError

err := core.validator.Struct(payload)
if err != nil {
// Check if the error is a validator.ValidationErrors type
if _, ok := err.(*validator.InvalidValidationError); ok {
return errors
}
for _, err := range err.(validator.ValidationErrors) {
var element ValidationError
element.FailedField = err.Field()
element.Tag = err.Tag()
element.Value = err.Param()
errors = append(errors, &element)
}
}
return errors
}

func (core *Core) ParseJsonBody(data []byte, payload interface{}) error {
decoder := json.NewDecoder(bytes.NewReader(data))
decoder.DisallowUnknownFields()
Expand Down
42 changes: 0 additions & 42 deletions api/errors.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
package main

import (
"reflect"
"regexp"
"strings"

"github.com/go-playground/validator/v10"
)

type AuthError struct{ Msg string }

func (e *AuthError) Error() string {
Expand All @@ -18,37 +10,3 @@ type ErrorResponse struct {
Message string `json:"message"`
Errors []*interface{} `json:"errors"`
}

func ValidateResourceName(fl validator.FieldLevel) bool {
reg := "^[a-zA-Z0-9._]{1,249}$"
match, _ := regexp.MatchString(reg, fl.Field().String())

// Check for prohibited values: single period, double underscore, and hyphen
if fl.Field().String() == "." || fl.Field().String() == "__" || fl.Field().String() == "-" {
return false
}

return match
}

type ValidationError struct {
FailedField string `json:"failed_field"`
Tag string `json:"tag"`
Value string `json:"value"`
}

func newValidator() *validator.Validate {
v := validator.New(validator.WithRequiredStructEnabled())
err := v.RegisterValidation("vaultResourceNames", ValidateResourceName)
if err != nil {
panic(err)
}
v.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
return v
}
2 changes: 1 addition & 1 deletion api/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.21

require (
github.com/go-playground/assert/v2 v2.2.0
github.com/go-playground/validator/v10 v10.16.0
github.com/gofiber/fiber/v2 v2.51.0
github.com/joho/godotenv v1.3.0
github.com/knadh/koanf v1.5.0
Expand All @@ -29,6 +28,7 @@ require (
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.16.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
Expand Down
3 changes: 3 additions & 0 deletions api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func (core *Core) customErrorHandler(ctx *fiber.Ctx, err error) error {
var ne *_vault.NotFoundError
var ae *AuthError
var co *_vault.ConflictError
var va *_vault.ValidationErrors
switch {
case errors.As(err, &ve):
return ctx.Status(http.StatusBadRequest).JSON(ErrorResponse{ve.Error(), nil})
Expand All @@ -106,6 +107,8 @@ func (core *Core) customErrorHandler(ctx *fiber.Ctx, err error) error {
return ctx.Status(http.StatusNotImplemented).JSON(ErrorResponse{err.Error(), nil})
case errors.As(err, &co):
return ctx.Status(http.StatusConflict).JSON(ErrorResponse{co.Error(), nil})
case errors.As(err, &va):
return ctx.Status(http.StatusConflict).JSON(ErrorResponse{va.Error(), nil})
default:
// Handle other types of errors by returning a generic 500 - this should remain obscure as it can leak information
core.logger.Error(fmt.Sprintf("Unhandled error: %s", err.Error()))
Expand Down
6 changes: 1 addition & 5 deletions api/policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@ func (core *Core) CreatePolicy(c *fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).JSON(ErrorResponse{"Invalid body", nil})
}

if validationErrs := core.Validate(policy); validationErrs != nil {
return c.Status(fiber.StatusBadRequest).JSON(validationErrs)
}

_, err := core.vault.CreatePolicy(c.Context(), sessionPrincipal, policy)
err := core.vault.CreatePolicy(c.Context(), sessionPrincipal, &policy)
if err != nil {
core.logger.Error(fmt.Sprintf("Failed to create policy %v", err))
return err
Expand Down
18 changes: 10 additions & 8 deletions api/policies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ import (
func TestPolicies(t *testing.T) {
app, core := InitTestingVault(t)

testPolicyId := "test-policy"
testPolicyId := ""
testResource := "/resources/stuff"

t.Run("can create policy", func(t *testing.T) {
testPolicy := _vault.Policy{
PolicyId: testPolicyId,
Name: "test-policy",
Effect: _vault.EffectAllow,
Actions: []_vault.PolicyAction{_vault.PolicyActionRead},
Resources: []string{fmt.Sprintf("/policies/%s", testPolicyId)},
Resources: []string{testResource},
}

request := newRequest(t, http.MethodPost, "/policies", map[string]string{
Expand All @@ -29,11 +30,12 @@ func TestPolicies(t *testing.T) {
response := performRequest(t, app, request)
var createdPolicy _vault.Policy
checkResponse(t, response, http.StatusCreated, &createdPolicy)
testPolicyId = createdPolicy.Id

// Assertions
assert.Equal(t, _vault.EffectAllow, createdPolicy.Effect)
assert.Equal(t, []_vault.PolicyAction{_vault.PolicyActionRead}, createdPolicy.Actions)
assert.Equal(t, []string{fmt.Sprintf("/policies/%s", testPolicyId)}, createdPolicy.Resources)
assert.Equal(t, []string{testResource}, createdPolicy.Resources)
})

t.Run("can get policy", func(t *testing.T) {
Expand All @@ -48,13 +50,13 @@ func TestPolicies(t *testing.T) {
// Assertions
assert.Equal(t, _vault.EffectAllow, returnedPolicy.Effect)
assert.Equal(t, []_vault.PolicyAction{_vault.PolicyActionRead}, returnedPolicy.Actions)
assert.Equal(t, []string{fmt.Sprintf("/policies/%s", testPolicyId)}, returnedPolicy.Resources)
assert.Equal(t, []string{testResource}, returnedPolicy.Resources)
})

t.Run("can delete policy", func(t *testing.T) {
// Add a dummy policy first before deleting
dummyPolicy := _vault.Policy{
PolicyId: "dummy-policy",
Name: "dummy-policy",
Effect: _vault.EffectAllow,
Actions: []_vault.PolicyAction{_vault.PolicyActionRead},
Resources: []string{"/policies/dummy-policy"},
Expand All @@ -69,15 +71,15 @@ func TestPolicies(t *testing.T) {
checkResponse(t, response, http.StatusCreated, &returnedPolicy)

// Delete it
request = newRequest(t, http.MethodDelete, fmt.Sprintf("/policies/%s", dummyPolicy.PolicyId), map[string]string{
request = newRequest(t, http.MethodDelete, fmt.Sprintf("/policies/%s", returnedPolicy.Id), map[string]string{
"Authorization": createBasicAuthHeader(core.conf.ADMIN_USERNAME, core.conf.ADMIN_PASSWORD),
}, nil)

response = performRequest(t, app, request)
checkResponse(t, response, http.StatusNoContent, nil)

// Check it's gone
request = newRequest(t, http.MethodGet, fmt.Sprintf("/policies/%s", dummyPolicy.PolicyId), map[string]string{
request = newRequest(t, http.MethodGet, fmt.Sprintf("/policies/%s", returnedPolicy.Id), map[string]string{
"Authorization": createBasicAuthHeader(core.conf.ADMIN_USERNAME, core.conf.ADMIN_PASSWORD),
}, nil)

Expand Down
Loading
Loading