Skip to content
This repository was archived by the owner on Jun 20, 2024. It is now read-only.

Commit 94ab3e7

Browse files
authored
Merge pull request #305 from neicnordic/refactor/middleware-and-cache
refactor middleware and cache in an attempt of clarifying its operation
2 parents 7babab0 + 208114a commit 94ab3e7

File tree

9 files changed

+110
-111
lines changed

9 files changed

+110
-111
lines changed

api/middleware/middleware.go

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@ import (
1010
log "github.com/sirupsen/logrus"
1111
)
1212

13-
var datasetsKey = "datasets"
13+
// requestContextKey holds a name for the request context storage key
14+
// which is used to store and get the permissions after passing middleware
15+
const requestContextKey = "requestContextKey"
1416

1517
// TokenMiddleware performs access token verification and validation
1618
// JWTs are verified and validated by the app, opaque tokens are sent to AAI for verification
17-
// Successful auth results in list of authorised datasets
19+
// Successful auth results in list of authorised datasets.
20+
// The datasets are stored into a session cache for subsequent requests, and also
21+
// to the current request context for use in the endpoints.
1822
func TokenMiddleware() gin.HandlerFunc {
1923

2024
return func(c *gin.Context) {
@@ -23,11 +27,11 @@ func TokenMiddleware() gin.HandlerFunc {
2327
if err != nil {
2428
log.Debugf("no session cookie received")
2529
}
26-
var datasetCache session.DatasetCache
30+
var cache session.Cache
2731
var exists bool
2832
if sessionCookie != "" {
2933
log.Debug("session cookie received")
30-
datasetCache, exists = session.Get(sessionCookie)
34+
cache, exists = session.Get(sessionCookie)
3135
}
3236

3337
if !exists {
@@ -57,14 +61,11 @@ func TokenMiddleware() gin.HandlerFunc {
5761
// 200 OK with [] empty dataset list, when listing datasets (use case for sda-filesystem download tool)
5862
// 404 dataset not found, when listing files from a dataset
5963
// 401 unauthorised, when downloading a file
60-
datasets := auth.GetPermissions(*visas)
61-
datasetCache = session.DatasetCache{
62-
Datasets: datasets,
63-
}
64+
cache.Datasets = auth.GetPermissions(*visas)
6465

6566
// Start a new session and store datasets under the session key
6667
key := session.NewSessionKey()
67-
session.Set(key, datasetCache)
68+
session.Set(key, cache)
6869
c.SetCookie(config.Config.Session.Name, // name
6970
key, // value
7071
int(config.Config.Session.Expiration)/1e9, // max age
@@ -77,31 +78,25 @@ func TokenMiddleware() gin.HandlerFunc {
7778
}
7879

7980
// Store dataset list to request context, for use in the endpoint handlers
80-
c = storeDatasets(c, datasetCache)
81+
log.Debugf("storing %v to request context", cache)
82+
c.Set(requestContextKey, cache)
8183

8284
// Forward request to the next endpoint handler
8385
c.Next()
8486
}
8587

8688
}
8789

88-
// storeDatasets stores the dataset list to the request context
89-
func storeDatasets(c *gin.Context, datasets session.DatasetCache) *gin.Context {
90-
log.Debugf("storing %v datasets to request context", datasets)
91-
92-
c.Set(datasetsKey, datasets)
93-
94-
return c
95-
}
96-
97-
// GetDatasets extracts the dataset list from the request context
98-
var GetDatasets = func(c *gin.Context) session.DatasetCache {
99-
var datasetCache session.DatasetCache
100-
cached, exists := c.Get(datasetsKey)
90+
// GetCacheFromContext is a helper function that endpoints can use to get data
91+
// stored to the *current* request context (not the session storage).
92+
// The request context was populated by the middleware, which in turn uses the session storage.
93+
var GetCacheFromContext = func(c *gin.Context) session.Cache {
94+
var cache session.Cache
95+
cached, exists := c.Get(requestContextKey)
10196
if exists {
102-
datasetCache = cached.(session.DatasetCache)
97+
cache = cached.(session.Cache)
10398
}
10499
log.Debugf("returning %v from request context", cached)
105100

106-
return datasetCache
101+
return cache
107102
}

api/middleware/middleware_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ func TestTokenMiddleware_Success_NoCache(t *testing.T) {
182182
// Now that we are modifying the request context, we need to place the context test inside the handler
183183
expectedDatasets := []string{"dataset1", "dataset2"}
184184
testEndpointWithContextData := func(c *gin.Context) {
185-
datasets, _ := c.Get(datasetsKey)
186-
if !reflect.DeepEqual(datasets.(session.DatasetCache).Datasets, expectedDatasets) {
185+
datasets, _ := c.Get(requestContextKey)
186+
if !reflect.DeepEqual(datasets.(session.Cache).Datasets, expectedDatasets) {
187187
t.Errorf("TestTokenMiddleware_Success_NoCache failed, got %s expected %s", datasets, expectedDatasets)
188188
}
189189
}
@@ -224,9 +224,9 @@ func TestTokenMiddleware_Success_FromCache(t *testing.T) {
224224
originalGetCache := session.Get
225225

226226
// Substitute mock functions
227-
session.Get = func(key string) (session.DatasetCache, bool) {
227+
session.Get = func(key string) (session.Cache, bool) {
228228
log.Warningf("session.Get %v", key)
229-
cached := session.DatasetCache{
229+
cached := session.Cache{
230230
Datasets: []string{"dataset1", "dataset2"},
231231
}
232232

@@ -248,8 +248,8 @@ func TestTokenMiddleware_Success_FromCache(t *testing.T) {
248248
// Now that we are modifying the request context, we need to place the context test inside the handler
249249
expectedDatasets := []string{"dataset1", "dataset2"}
250250
testEndpointWithContextData := func(c *gin.Context) {
251-
datasets, _ := c.Get(datasetsKey)
252-
if !reflect.DeepEqual(datasets.(session.DatasetCache).Datasets, expectedDatasets) {
251+
datasets, _ := c.Get(requestContextKey)
252+
if !reflect.DeepEqual(datasets.(session.Cache).Datasets, expectedDatasets) {
253253
t.Errorf("TestTokenMiddleware_Success_FromCache failed, got %s expected %s", datasets, expectedDatasets)
254254
}
255255
}
@@ -284,13 +284,13 @@ func TestStoreDatasets(t *testing.T) {
284284
c, _ := gin.CreateTestContext(w)
285285

286286
// Store data to request context
287-
datasets := session.DatasetCache{
287+
datasets := session.Cache{
288288
Datasets: []string{"dataset1", "dataset2"},
289289
}
290-
modifiedContext := storeDatasets(c, datasets)
290+
c.Set(requestContextKey, datasets)
291291

292292
// Verify that context has new data
293-
storedDatasets := modifiedContext.Value(datasetsKey).(session.DatasetCache)
293+
storedDatasets := c.Value(requestContextKey).(session.Cache)
294294
if !reflect.DeepEqual(datasets, storedDatasets) {
295295
t.Errorf("TestStoreDatasets failed, got %s, expected %s", storedDatasets, datasets)
296296
}
@@ -304,13 +304,13 @@ func TestGetDatasets(t *testing.T) {
304304
c, _ := gin.CreateTestContext(w)
305305

306306
// Store data to request context
307-
datasets := session.DatasetCache{
307+
datasets := session.Cache{
308308
Datasets: []string{"dataset1", "dataset2"},
309309
}
310-
modifiedContext := storeDatasets(c, datasets)
310+
c.Set(requestContextKey, datasets)
311311

312312
// Verify that context has new data
313-
storedDatasets := GetDatasets(modifiedContext)
313+
storedDatasets := GetCacheFromContext(c)
314314
if !reflect.DeepEqual(datasets, storedDatasets) {
315315
t.Errorf("TestStoreDatasets failed, got %s, expected %s", storedDatasets, datasets)
316316
}

api/s3/s3.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ func ListBuckets(c *gin.Context) {
9494
}
9595

9696
buckets := []Bucket{}
97-
datasetCache := middleware.GetDatasets(c)
98-
for _, dataset := range datasetCache.Datasets {
97+
cache := middleware.GetCacheFromContext(c)
98+
for _, dataset := range cache.Datasets {
9999
datasetInfo, err := database.GetDatasetInfo(dataset)
100100
if err != nil {
101101
log.Errorf("Failed to get dataset information: %v", err)
@@ -123,8 +123,8 @@ func ListObjects(c *gin.Context) {
123123
dataset := c.Param("dataset")
124124

125125
allowed := false
126-
datasetCache := middleware.GetDatasets(c)
127-
for _, known := range datasetCache.Datasets {
126+
cache := middleware.GetCacheFromContext(c)
127+
for _, known := range cache.Datasets {
128128
if dataset == known {
129129
allowed = true
130130

@@ -244,8 +244,8 @@ func parseParams(c *gin.Context) *gin.Context {
244244
path = string(protocolPattern.ReplaceAll([]byte(path), []byte("$1/$2")))
245245
}
246246

247-
datasetCache := middleware.GetDatasets(c)
248-
for _, dataset := range datasetCache.Datasets {
247+
cache := middleware.GetCacheFromContext(c)
248+
for _, dataset := range cache.Datasets {
249249
// check that the path starts with the dataset name, but also that the
250250
// path is only the dataset, or that the following character is a slash.
251251
// This prevents wrong matches in cases like when one dataset name is a

api/sda/sda.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ func Datasets(c *gin.Context) {
3535

3636
// Retrieve dataset list from request context
3737
// generated by the authentication middleware
38-
datasetCache := middleware.GetDatasets(c)
38+
cache := middleware.GetCacheFromContext(c)
3939

4040
// Return response
41-
c.JSON(http.StatusOK, datasetCache.Datasets)
41+
c.JSON(http.StatusOK, cache.Datasets)
4242
}
4343

4444
// find looks for a dataset name in a list of datasets
@@ -60,11 +60,11 @@ var getFiles = func(datasetID string, ctx *gin.Context) ([]*database.FileInfo, i
6060

6161
// Retrieve dataset list from request context
6262
// generated by the authentication middleware
63-
datasetCache := middleware.GetDatasets(ctx)
63+
cache := middleware.GetCacheFromContext(ctx)
6464

6565
log.Debugf("request to process files for dataset %s", sanitizeString(datasetID))
6666

67-
if find(datasetID, datasetCache.Datasets) {
67+
if find(datasetID, cache.Datasets) {
6868
// Get file metadata
6969
files, err := database.GetFiles(datasetID)
7070
if err != nil {
@@ -133,12 +133,12 @@ func Download(c *gin.Context) {
133133
}
134134

135135
// Get datasets from request context, parsed previously by token middleware
136-
datasetCache := middleware.GetDatasets(c)
136+
cache := middleware.GetCacheFromContext(c)
137137

138138
// Verify user has permission to datafile
139139
permission := false
140-
for d := range datasetCache.Datasets {
141-
if datasetCache.Datasets[d] == dataset {
140+
for d := range cache.Datasets {
141+
if cache.Datasets[d] == dataset {
142142
permission = true
143143

144144
break

0 commit comments

Comments
 (0)