Skip to content

Commit

Permalink
feat: move core and mux specific function in specific package
Browse files Browse the repository at this point in the history
  • Loading branch information
davidebianchi committed Jun 29, 2023
1 parent 4c0d6fd commit 618e5f1
Show file tree
Hide file tree
Showing 11 changed files with 332 additions and 183 deletions.
56 changes: 0 additions & 56 deletions core/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,12 @@
package core

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"

"github.com/rond-authz/rond/internal/utils"
"github.com/rond-authz/rond/types"

"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -115,56 +112,3 @@ func CreateRegoQueryInput(
type RondInput interface {
Input(user types.User, responseBody any) (Input, error)
}

type requestInfo struct {
*http.Request
clientTypeHeaderKey string
pathParams map[string]string
}

func (req requestInfo) Input(user types.User, responseBody any) (Input, error) {
shouldParseJSONBody := utils.HasApplicationJSONContentType(req.Header) &&
req.ContentLength > 0 &&
(req.Method == http.MethodPatch || req.Method == http.MethodPost || req.Method == http.MethodPut || req.Method == http.MethodDelete)

var requestBody any
if shouldParseJSONBody {
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return Input{}, fmt.Errorf("failed request body parse: %s", err.Error())
}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
return Input{}, fmt.Errorf("failed request body deserialization: %s", err.Error())
}
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}

return Input{
ClientType: req.Header.Get(req.clientTypeHeaderKey),
Request: InputRequest{
Method: req.Method,
Path: req.URL.Path,
Headers: req.Header,
Query: req.URL.Query(),
PathParams: req.pathParams,
Body: requestBody,
},
Response: InputResponse{
Body: responseBody,
},
User: InputUser{
Properties: user.Properties,
Groups: user.UserGroups,
Bindings: user.UserBindings,
Roles: user.UserRoles,
},
}, nil
}

func NewRondInput(req *http.Request, clientTypeHeaderKey string, pathParams map[string]string) RondInput {
return requestInfo{
Request: req,
clientTypeHeaderKey: clientTypeHeaderKey,
pathParams: pathParams,
}
}
103 changes: 27 additions & 76 deletions core/input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,9 @@
package core

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/rond-authz/rond/internal/utils"
"github.com/rond-authz/rond/types"

"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -247,78 +242,34 @@ func BenchmarkBuildOptimizedResourcePermissionsMap(b *testing.B) {
}
}

func TestRondInput(t *testing.T) {
user := types.User{}
clientTypeHeaderKey := "clienttypeheader"
pathParams := map[string]string{}

t.Run("request body integration", func(t *testing.T) {
expectedRequestBody := map[string]interface{}{
"Key": float64(42),
}
reqBody := struct{ Key int }{
Key: 42,
}
reqBodyBytes, err := json.Marshal(reqBody)
require.Nil(t, err, "Unexpected error")

t.Run("ignored on method GET", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(reqBodyBytes))

rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams)
input, err := rondRequest.Input(user, nil)
require.NoError(t, err, "Unexpected error")
require.Nil(t, input.Request.Body)
})

t.Run("ignore nil body on method POST", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(utils.ContentTypeHeaderKey, "application/json")

rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams)
input, err := rondRequest.Input(user, nil)
require.NoError(t, err, "Unexpected error")
require.Nil(t, input.Request.Body)
})

t.Run("added on accepted methods", func(t *testing.T) {
acceptedMethods := []string{http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete}

for _, method := range acceptedMethods {
req := httptest.NewRequest(method, "/", bytes.NewReader(reqBodyBytes))
req.Header.Set(utils.ContentTypeHeaderKey, "application/json")
rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams)
input, err := rondRequest.Input(user, nil)
require.NoError(t, err, "Unexpected error")
require.Equal(t, expectedRequestBody, input.Request.Body)
}
})

t.Run("added with content-type specifying charset", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBodyBytes))
req.Header.Set(utils.ContentTypeHeaderKey, "application/json;charset=UTF-8")
rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams)
input, err := rondRequest.Input(user, nil)
require.NoError(t, err, "Unexpected error")
require.Equal(t, expectedRequestBody, input.Request.Body)
})
type FakeInput struct {
request InputRequest
clientType string
}

t.Run("reject on method POST but with invalid body", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("{notajson}")))
req.Header.Set(utils.ContentTypeHeaderKey, "application/json")
rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams)
_, err := rondRequest.Input(user, nil)
require.ErrorContains(t, err, "failed request body deserialization:")
})
func (i FakeInput) Input(user types.User, responseBody any) (Input, error) {
return Input{
User: InputUser{
Properties: user.Properties,
Groups: user.UserGroups,
Bindings: user.UserBindings,
Roles: user.UserRoles,
},
Request: i.request,
Response: InputResponse{
Body: responseBody,
},
ClientType: i.clientType,
}, nil
}

t.Run("ignore body on method POST but with another content type", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("{notajson}")))
req.Header.Set(utils.ContentTypeHeaderKey, "multipart/form-data")
func getFakeInput(t require.TestingT, request InputRequest, clientType string) RondInput {
if h, ok := t.(tHelper); ok {
h.Helper()
}

rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams)
input, err := rondRequest.Input(user, nil)
require.NoError(t, err, "Unexpected error")
require.Nil(t, input.Request.Body)
})
})
return FakeInput{
request: request,
clientType: clientType,
}
}
42 changes: 29 additions & 13 deletions core/sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ func TestSDK(t *testing.T) {
func TestEvaluateRequestPolicy(t *testing.T) {
logger := logrus.NewEntry(logrus.New())

clientTypeHeaderKey := "client-header-key"

t.Run("throws without RondInput", func(t *testing.T) {
sdk := getSdk(t, nil)
evaluator, err := sdk.FindEvaluator(logger, http.MethodGet, "/users/")
Expand Down Expand Up @@ -387,13 +385,17 @@ func TestEvaluateRequestPolicy(t *testing.T) {
evaluate, err := sdk.FindEvaluator(logger, testCase.method, testCase.path)
require.NoError(t, err)

req := httptest.NewRequest(testCase.method, testCase.path, nil)
headers := http.Header{}
if testCase.reqHeaders != nil {
for k, v := range testCase.reqHeaders {
req.Header.Set(k, v)
headers.Set(k, v)
}
}
rondInput := NewRondInput(req, clientTypeHeaderKey, nil)
rondInput := getFakeInput(t, InputRequest{
Headers: headers,
Path: testCase.path,
Method: testCase.method,
}, "")

actual, err := evaluate.EvaluateRequestPolicy(context.Background(), rondInput, testCase.user)
if testCase.expectedErr {
Expand Down Expand Up @@ -469,8 +471,6 @@ func assertCorrectMetrics(t *testing.T, registry *prometheus.Registry, expected
func TestEvaluateResponsePolicy(t *testing.T) {
logger := logrus.NewEntry(logrus.New())

clientTypeHeaderKey := "client-header-key"

t.Run("throws without RondInput", func(t *testing.T) {
sdk := getSdk(t, nil)
evaluator, err := sdk.FindEvaluator(logger, http.MethodGet, "/users/")
Expand Down Expand Up @@ -587,7 +587,17 @@ func TestEvaluateResponsePolicy(t *testing.T) {
req.Header.Set(k, v)
}
}
rondInput := NewRondInput(req, clientTypeHeaderKey, nil)
headers := http.Header{}
if testCase.reqHeaders != nil {
for k, v := range testCase.reqHeaders {
headers.Set(k, v)
}
}
rondInput := getFakeInput(t, InputRequest{
Headers: headers,
Path: testCase.path,
Method: testCase.method,
}, "")

actual, err := evaluate.EvaluateResponsePolicy(context.Background(), rondInput, testCase.user, testCase.decodedBody)
if testCase.expectedErr {
Expand Down Expand Up @@ -677,12 +687,18 @@ func BenchmarkEvaluateRequest(b *testing.B) {

for n := 0; n < b.N; n++ {
b.StopTimer()
req := httptest.NewRequest(http.MethodGet, "/projects/project123", nil)
req.Header.Set("my-header", "value")
headers := http.Header{}
headers.Set("my-header", "value")
recorder := httptest.NewRecorder()
rondInput := NewRondInput(req, "", map[string]string{
"projectId": "project123",
})

rondInput := getFakeInput(b, InputRequest{
Path: "/projects/project123",
Headers: headers,
Method: http.MethodGet,
PathParams: map[string]string{
"projectId": "project123",
},
}, "")
b.StartTimer()
evaluator, err := sdk.FindEvaluator(logger, http.MethodGet, "/projects/project123")
require.NoError(b, err)
Expand Down
80 changes: 80 additions & 0 deletions routers/mux/input.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2023 Mia srl
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package rondmux

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/rond-authz/rond/core"
"github.com/rond-authz/rond/internal/utils"
"github.com/rond-authz/rond/types"
)

type requestInfo struct {
*http.Request
clientTypeHeaderKey string
pathParams map[string]string
}

func (req requestInfo) Input(user types.User, responseBody any) (core.Input, error) {
shouldParseJSONBody := utils.HasApplicationJSONContentType(req.Header) &&
req.ContentLength > 0 &&
(req.Method == http.MethodPatch || req.Method == http.MethodPost || req.Method == http.MethodPut || req.Method == http.MethodDelete)

var requestBody any
if shouldParseJSONBody {
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return core.Input{}, fmt.Errorf("failed request body parse: %s", err.Error())
}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
return core.Input{}, fmt.Errorf("failed request body deserialization: %s", err.Error())
}
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}

return core.Input{
ClientType: req.Header.Get(req.clientTypeHeaderKey),
Request: core.InputRequest{
Method: req.Method,
Path: req.URL.Path,
Headers: req.Header,
Query: req.URL.Query(),
PathParams: req.pathParams,
Body: requestBody,
},
Response: core.InputResponse{
Body: responseBody,
},
User: core.InputUser{
Properties: user.Properties,
Groups: user.UserGroups,
Bindings: user.UserBindings,
Roles: user.UserRoles,
},
}, nil
}

func NewRondInput(req *http.Request, clientTypeHeaderKey string, pathParams map[string]string) core.RondInput {
return requestInfo{
Request: req,
clientTypeHeaderKey: clientTypeHeaderKey,
pathParams: pathParams,
}
}
Loading

0 comments on commit 618e5f1

Please sign in to comment.