Skip to content

Commit 4cd374b

Browse files
committed
Add Codec interface
1 parent ddba8b6 commit 4cd374b

File tree

4 files changed

+113
-75
lines changed

4 files changed

+113
-75
lines changed

codec.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package scs
2+
3+
import (
4+
"bytes"
5+
"encoding/gob"
6+
"time"
7+
)
8+
9+
// Codec is the interface for encoding/decoding session data to and from a byte
10+
// slice for use by the session store.
11+
type Codec interface {
12+
Encode(deadline time.Time, values map[string]interface{}) ([]byte, error)
13+
Decode([]byte) (deadline time.Time, values map[string]interface{}, err error)
14+
}
15+
16+
type gobCodec struct{}
17+
18+
func (gobCodec) Encode(deadline time.Time, values map[string]interface{}) ([]byte, error) {
19+
aux := &struct {
20+
Deadline time.Time
21+
Values map[string]interface{}
22+
}{
23+
Deadline: deadline,
24+
Values: values,
25+
}
26+
27+
var b bytes.Buffer
28+
err := gob.NewEncoder(&b).Encode(&aux)
29+
if err != nil {
30+
return nil, err
31+
}
32+
33+
return b.Bytes(), nil
34+
}
35+
36+
func (gobCodec) Decode(b []byte) (time.Time, map[string]interface{}, error) {
37+
aux := &struct {
38+
Deadline time.Time
39+
Values map[string]interface{}
40+
}{}
41+
42+
r := bytes.NewReader(b)
43+
err := gob.NewDecoder(r).Decode(&aux)
44+
if err != nil {
45+
return time.Time{}, nil, err
46+
}
47+
48+
return aux.Deadline, aux.Values, nil
49+
}

data.go

Lines changed: 23 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
package scs
22

33
import (
4-
"bytes"
54
"context"
65
"crypto/rand"
76
"encoding/base64"
8-
"encoding/gob"
97
"fmt"
108
"sort"
119
"sync"
@@ -30,18 +28,18 @@ const (
3028
)
3129

3230
type sessionData struct {
33-
Deadline time.Time // Exported for gob encoding.
31+
deadline time.Time
3432
status Status
3533
token string
36-
Values map[string]interface{} // Exported for gob encoding.
34+
values map[string]interface{}
3735
mu sync.Mutex
3836
}
3937

4038
func newSessionData(lifetime time.Duration) *sessionData {
4139
return &sessionData{
42-
Deadline: time.Now().Add(lifetime).UTC(),
40+
deadline: time.Now().Add(lifetime).UTC(),
4341
status: Unmodified,
44-
Values: make(map[string]interface{}),
42+
values: make(map[string]interface{}),
4543
}
4644
}
4745

@@ -71,7 +69,7 @@ func (s *SessionManager) Load(ctx context.Context, token string) (context.Contex
7169
status: Unmodified,
7270
token: token,
7371
}
74-
err = sd.decode(b)
72+
sd.deadline, sd.values, err = s.Codec.Decode(b)
7573
if err != nil {
7674
return nil, err
7775
}
@@ -104,12 +102,12 @@ func (s *SessionManager) Commit(ctx context.Context) (string, time.Time, error)
104102
}
105103
}
106104

107-
b, err := sd.encode()
105+
b, err := s.Codec.Encode(sd.deadline, sd.values)
108106
if err != nil {
109107
return "", time.Time{}, err
110108
}
111109

112-
expiry := sd.Deadline
110+
expiry := sd.deadline
113111
if s.IdleTimeout > 0 {
114112
ie := time.Now().Add(s.IdleTimeout)
115113
if ie.Before(expiry) {
@@ -143,9 +141,9 @@ func (s *SessionManager) Destroy(ctx context.Context) error {
143141

144142
// Reset everything else to defaults.
145143
sd.token = ""
146-
sd.Deadline = time.Now().Add(s.Lifetime).UTC()
147-
for key := range sd.Values {
148-
delete(sd.Values, key)
144+
sd.deadline = time.Now().Add(s.Lifetime).UTC()
145+
for key := range sd.values {
146+
delete(sd.values, key)
149147
}
150148

151149
return nil
@@ -158,7 +156,7 @@ func (s *SessionManager) Put(ctx context.Context, key string, val interface{}) {
158156
sd := s.getSessionDataFromContext(ctx)
159157

160158
sd.mu.Lock()
161-
sd.Values[key] = val
159+
sd.values[key] = val
162160
sd.status = Modified
163161
sd.mu.Unlock()
164162
}
@@ -180,7 +178,7 @@ func (s *SessionManager) Get(ctx context.Context, key string) interface{} {
180178
sd.mu.Lock()
181179
defer sd.mu.Unlock()
182180

183-
return sd.Values[key]
181+
return sd.values[key]
184182
}
185183

186184
// Pop acts like a one-time Get. It returns the value for a given key from the
@@ -193,11 +191,11 @@ func (s *SessionManager) Pop(ctx context.Context, key string) interface{} {
193191
sd.mu.Lock()
194192
defer sd.mu.Unlock()
195193

196-
val, exists := sd.Values[key]
194+
val, exists := sd.values[key]
197195
if !exists {
198196
return nil
199197
}
200-
delete(sd.Values, key)
198+
delete(sd.values, key)
201199
sd.status = Modified
202200

203201
return val
@@ -212,12 +210,12 @@ func (s *SessionManager) Remove(ctx context.Context, key string) {
212210
sd.mu.Lock()
213211
defer sd.mu.Unlock()
214212

215-
_, exists := sd.Values[key]
213+
_, exists := sd.values[key]
216214
if !exists {
217215
return
218216
}
219217

220-
delete(sd.Values, key)
218+
delete(sd.values, key)
221219
sd.status = Modified
222220
}
223221

@@ -230,12 +228,12 @@ func (s *SessionManager) Clear(ctx context.Context) error {
230228
sd.mu.Lock()
231229
defer sd.mu.Unlock()
232230

233-
if len(sd.Values) == 0 {
231+
if len(sd.values) == 0 {
234232
return nil
235233
}
236234

237-
for key := range sd.Values {
238-
delete(sd.Values, key)
235+
for key := range sd.values {
236+
delete(sd.values, key)
239237
}
240238
sd.status = Modified
241239
return nil
@@ -246,7 +244,7 @@ func (s *SessionManager) Exists(ctx context.Context, key string) bool {
246244
sd := s.getSessionDataFromContext(ctx)
247245

248246
sd.mu.Lock()
249-
_, exists := sd.Values[key]
247+
_, exists := sd.values[key]
250248
sd.mu.Unlock()
251249

252250
return exists
@@ -259,9 +257,9 @@ func (s *SessionManager) Keys(ctx context.Context) []string {
259257
sd := s.getSessionDataFromContext(ctx)
260258

261259
sd.mu.Lock()
262-
keys := make([]string, len(sd.Values))
260+
keys := make([]string, len(sd.values))
263261
i := 0
264-
for key := range sd.Values {
262+
for key := range sd.values {
265263
keys[i] = key
266264
i++
267265
}
@@ -298,7 +296,7 @@ func (s *SessionManager) RenewToken(ctx context.Context) error {
298296
}
299297

300298
sd.token = newToken
301-
sd.Deadline = time.Now().Add(s.Lifetime).UTC()
299+
sd.deadline = time.Now().Add(s.Lifetime).UTC()
302300
sd.status = Modified
303301

304302
return nil
@@ -477,21 +475,6 @@ func (s *SessionManager) getSessionDataFromContext(ctx context.Context) *session
477475
return c
478476
}
479477

480-
func (sd *sessionData) encode() ([]byte, error) {
481-
var b bytes.Buffer
482-
err := gob.NewEncoder(&b).Encode(sd)
483-
if err != nil {
484-
return nil, err
485-
}
486-
487-
return b.Bytes(), nil
488-
}
489-
490-
func (sd *sessionData) decode(b []byte) error {
491-
r := bytes.NewReader(b)
492-
return gob.NewDecoder(r).Decode(sd)
493-
}
494-
495478
func generateToken() (string, error) {
496479
b := make([]byte, 32)
497480
_, err := rand.Read(b)

0 commit comments

Comments
 (0)