Skip to content
This repository has been archived by the owner on Jun 20, 2024. It is now read-only.

refactor middleware and cache in an attempt of clarifying its operation #305

Merged
merged 1 commit into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 20 additions & 25 deletions api/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ import (
log "github.com/sirupsen/logrus"
)

var datasetsKey = "datasets"
// requestContextKey holds a name for the request context storage key
// which is used to store and get the permissions after passing middleware
const requestContextKey = "requestContextKey"

// TokenMiddleware performs access token verification and validation
// JWTs are verified and validated by the app, opaque tokens are sent to AAI for verification
// Successful auth results in list of authorised datasets
// Successful auth results in list of authorised datasets.
// The datasets are stored into a session cache for subsequent requests, and also
// to the current request context for use in the endpoints.
func TokenMiddleware() gin.HandlerFunc {

return func(c *gin.Context) {
Expand All @@ -23,11 +27,11 @@ func TokenMiddleware() gin.HandlerFunc {
if err != nil {
log.Debugf("no session cookie received")
}
var datasetCache session.DatasetCache
var cache session.Cache
var exists bool
if sessionCookie != "" {
log.Debug("session cookie received")
datasetCache, exists = session.Get(sessionCookie)
cache, exists = session.Get(sessionCookie)
}

if !exists {
Expand Down Expand Up @@ -57,14 +61,11 @@ func TokenMiddleware() gin.HandlerFunc {
// 200 OK with [] empty dataset list, when listing datasets (use case for sda-filesystem download tool)
// 404 dataset not found, when listing files from a dataset
// 401 unauthorised, when downloading a file
datasets := auth.GetPermissions(*visas)
datasetCache = session.DatasetCache{
Datasets: datasets,
}
cache.Datasets = auth.GetPermissions(*visas)

// Start a new session and store datasets under the session key
key := session.NewSessionKey()
session.Set(key, datasetCache)
session.Set(key, cache)
c.SetCookie(config.Config.Session.Name, // name
key, // value
int(config.Config.Session.Expiration)/1e9, // max age
Expand All @@ -77,31 +78,25 @@ func TokenMiddleware() gin.HandlerFunc {
}

// Store dataset list to request context, for use in the endpoint handlers
c = storeDatasets(c, datasetCache)
log.Debugf("storing %v to request context", cache)
c.Set(requestContextKey, cache)

// Forward request to the next endpoint handler
c.Next()
}

}

// storeDatasets stores the dataset list to the request context
func storeDatasets(c *gin.Context, datasets session.DatasetCache) *gin.Context {
log.Debugf("storing %v datasets to request context", datasets)

c.Set(datasetsKey, datasets)

return c
}

// GetDatasets extracts the dataset list from the request context
var GetDatasets = func(c *gin.Context) session.DatasetCache {
var datasetCache session.DatasetCache
cached, exists := c.Get(datasetsKey)
// GetCacheFromContext is a helper function that endpoints can use to get data
// stored to the *current* request context (not the session storage).
// The request context was populated by the middleware, which in turn uses the session storage.
var GetCacheFromContext = func(c *gin.Context) session.Cache {
var cache session.Cache
cached, exists := c.Get(requestContextKey)
if exists {
datasetCache = cached.(session.DatasetCache)
cache = cached.(session.Cache)
}
log.Debugf("returning %v from request context", cached)

return datasetCache
return cache
}
24 changes: 12 additions & 12 deletions api/middleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ func TestTokenMiddleware_Success_NoCache(t *testing.T) {
// Now that we are modifying the request context, we need to place the context test inside the handler
expectedDatasets := []string{"dataset1", "dataset2"}
testEndpointWithContextData := func(c *gin.Context) {
datasets, _ := c.Get(datasetsKey)
if !reflect.DeepEqual(datasets.(session.DatasetCache).Datasets, expectedDatasets) {
datasets, _ := c.Get(requestContextKey)
if !reflect.DeepEqual(datasets.(session.Cache).Datasets, expectedDatasets) {
t.Errorf("TestTokenMiddleware_Success_NoCache failed, got %s expected %s", datasets, expectedDatasets)
}
}
Expand Down Expand Up @@ -224,9 +224,9 @@ func TestTokenMiddleware_Success_FromCache(t *testing.T) {
originalGetCache := session.Get

// Substitute mock functions
session.Get = func(key string) (session.DatasetCache, bool) {
session.Get = func(key string) (session.Cache, bool) {
log.Warningf("session.Get %v", key)
cached := session.DatasetCache{
cached := session.Cache{
Datasets: []string{"dataset1", "dataset2"},
}

Expand All @@ -248,8 +248,8 @@ func TestTokenMiddleware_Success_FromCache(t *testing.T) {
// Now that we are modifying the request context, we need to place the context test inside the handler
expectedDatasets := []string{"dataset1", "dataset2"}
testEndpointWithContextData := func(c *gin.Context) {
datasets, _ := c.Get(datasetsKey)
if !reflect.DeepEqual(datasets.(session.DatasetCache).Datasets, expectedDatasets) {
datasets, _ := c.Get(requestContextKey)
if !reflect.DeepEqual(datasets.(session.Cache).Datasets, expectedDatasets) {
t.Errorf("TestTokenMiddleware_Success_FromCache failed, got %s expected %s", datasets, expectedDatasets)
}
}
Expand Down Expand Up @@ -284,13 +284,13 @@ func TestStoreDatasets(t *testing.T) {
c, _ := gin.CreateTestContext(w)

// Store data to request context
datasets := session.DatasetCache{
datasets := session.Cache{
Datasets: []string{"dataset1", "dataset2"},
}
modifiedContext := storeDatasets(c, datasets)
c.Set(requestContextKey, datasets)

// Verify that context has new data
storedDatasets := modifiedContext.Value(datasetsKey).(session.DatasetCache)
storedDatasets := c.Value(requestContextKey).(session.Cache)
if !reflect.DeepEqual(datasets, storedDatasets) {
t.Errorf("TestStoreDatasets failed, got %s, expected %s", storedDatasets, datasets)
}
Expand All @@ -304,13 +304,13 @@ func TestGetDatasets(t *testing.T) {
c, _ := gin.CreateTestContext(w)

// Store data to request context
datasets := session.DatasetCache{
datasets := session.Cache{
Datasets: []string{"dataset1", "dataset2"},
}
modifiedContext := storeDatasets(c, datasets)
c.Set(requestContextKey, datasets)

// Verify that context has new data
storedDatasets := GetDatasets(modifiedContext)
storedDatasets := GetCacheFromContext(c)
if !reflect.DeepEqual(datasets, storedDatasets) {
t.Errorf("TestStoreDatasets failed, got %s, expected %s", storedDatasets, datasets)
}
Expand Down
12 changes: 6 additions & 6 deletions api/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ func ListBuckets(c *gin.Context) {
}

buckets := []Bucket{}
datasetCache := middleware.GetDatasets(c)
for _, dataset := range datasetCache.Datasets {
cache := middleware.GetCacheFromContext(c)
for _, dataset := range cache.Datasets {
datasetInfo, err := database.GetDatasetInfo(dataset)
if err != nil {
log.Errorf("Failed to get dataset information: %v", err)
Expand Down Expand Up @@ -123,8 +123,8 @@ func ListObjects(c *gin.Context) {
dataset := c.Param("dataset")

allowed := false
datasetCache := middleware.GetDatasets(c)
for _, known := range datasetCache.Datasets {
cache := middleware.GetCacheFromContext(c)
for _, known := range cache.Datasets {
if dataset == known {
allowed = true

Expand Down Expand Up @@ -244,8 +244,8 @@ func parseParams(c *gin.Context) *gin.Context {
path = string(protocolPattern.ReplaceAll([]byte(path), []byte("$1/$2")))
}

datasetCache := middleware.GetDatasets(c)
for _, dataset := range datasetCache.Datasets {
cache := middleware.GetCacheFromContext(c)
for _, dataset := range cache.Datasets {
// check that the path starts with the dataset name, but also that the
// path is only the dataset, or that the following character is a slash.
// This prevents wrong matches in cases like when one dataset name is a
Expand Down
14 changes: 7 additions & 7 deletions api/sda/sda.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ func Datasets(c *gin.Context) {

// Retrieve dataset list from request context
// generated by the authentication middleware
datasetCache := middleware.GetDatasets(c)
cache := middleware.GetCacheFromContext(c)

// Return response
c.JSON(http.StatusOK, datasetCache.Datasets)
c.JSON(http.StatusOK, cache.Datasets)
}

// find looks for a dataset name in a list of datasets
Expand All @@ -60,11 +60,11 @@ var getFiles = func(datasetID string, ctx *gin.Context) ([]*database.FileInfo, i

// Retrieve dataset list from request context
// generated by the authentication middleware
datasetCache := middleware.GetDatasets(ctx)
cache := middleware.GetCacheFromContext(ctx)

log.Debugf("request to process files for dataset %s", sanitizeString(datasetID))

if find(datasetID, datasetCache.Datasets) {
if find(datasetID, cache.Datasets) {
// Get file metadata
files, err := database.GetFiles(datasetID)
if err != nil {
Expand Down Expand Up @@ -133,12 +133,12 @@ func Download(c *gin.Context) {
}

// Get datasets from request context, parsed previously by token middleware
datasetCache := middleware.GetDatasets(c)
cache := middleware.GetCacheFromContext(c)

// Verify user has permission to datafile
permission := false
for d := range datasetCache.Datasets {
if datasetCache.Datasets[d] == dataset {
for d := range cache.Datasets {
if cache.Datasets[d] == dataset {
permission = true

break
Expand Down
Loading
Loading