Skip to content

Commit

Permalink
refactor: refactor entry and entry key handling
Browse files Browse the repository at this point in the history
The remaining reads and expiration are now handled at the entry key
level instead of the entry level.
The tests have been updated to reflect these changes. Additionally, some
unnecessary code has been removed and several minor improvements have
been made.
  • Loading branch information
Ajnasz committed May 23, 2024
1 parent dba23c0 commit d07cf51
Show file tree
Hide file tree
Showing 19 changed files with 227 additions and 316 deletions.
13 changes: 9 additions & 4 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,6 @@ func TestCreateEntryWithMaxReads(t *testing.T) {
NewSecretHandler(NewHandlerConfig(db)).ServeHTTP(w, req)

resp := w.Result()
model := &models.EntryModel{}

savedUUID := resp.Header.Get("x-entry-uuid")

Expand All @@ -652,7 +651,8 @@ func TestCreateEntryWithMaxReads(t *testing.T) {
if err != nil {
t.Fatal(err)
}
entry, err := model.ReadEntryMeta(ctx, tx, savedUUID)
model := &models.EntryKeyModel{}
entries, err := model.Get(ctx, tx, savedUUID)

if err != nil {
if err := tx.Rollback(); err != nil {
Expand All @@ -665,8 +665,13 @@ func TestCreateEntryWithMaxReads(t *testing.T) {
t.Errorf("commit failed: %v", err)
}

if entry.RemainingReads != 2 {
t.Fatalf("expected max reads to be: %d, actual: %d", 2, entry.RemainingReads)
if len(entries) != 1 {
t.Fatalf("expected to get entry key %d, got %d", 1, len(entries))
}

remainingReads := entries[0].RemainingReads.Int16
if remainingReads != 2 {
t.Fatalf("expected max reads to be: %d, actual: %d", 2, remainingReads)
}
}

Expand Down
4 changes: 1 addition & 3 deletions internal/api/createentry.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ func (c CreateHandler) handle(w http.ResponseWriter, r *http.Request) error {

// Handle handles http request to create secret
func (c CreateHandler) Handle(w http.ResponseWriter, r *http.Request) {
err := c.handle(w, r)

if err != nil {
if err := c.handle(w, r); err != nil {
log.Println("create error", err)
c.view.RenderError(w, r, err)
}
Expand Down
8 changes: 6 additions & 2 deletions internal/api/generateentrykey.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package api

import (
"context"
"fmt"
"net/http"
"time"

"github.com/Ajnasz/sekret.link/internal/key"
"github.com/Ajnasz/sekret.link/internal/parsers"
Expand All @@ -16,7 +18,7 @@ type GenerateEntryKeyView interface {
}

type GenerateEntryKeyManager interface {
GenerateEntryKey(ctx context.Context, UUID string, k key.Key) (*services.EntryKeyData, error)
GenerateEntryKey(ctx context.Context, UUID string, k key.Key, expire time.Duration, maxReads int) (*services.EntryKeyData, error)
}

type GenerateEntryKeyHandler struct {
Expand Down Expand Up @@ -44,10 +46,12 @@ func (g GenerateEntryKeyHandler) handle(w http.ResponseWriter, r *http.Request)
return err
}

fmt.Println(request)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

entry, err := g.entryManager.GenerateEntryKey(ctx, request.UUID, request.Key)
entry, err := g.entryManager.GenerateEntryKey(ctx, request.UUID, request.Key, request.Expiration, request.MaxReads)
if err != nil {
return err
}
Expand Down
41 changes: 18 additions & 23 deletions internal/models/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,18 @@ var ErrInvalidKey = errors.New("invalid key")
var ErrCreateEntry = errors.New("failed to create entry")

type EntryMeta struct {
UUID string
RemainingReads int
DeleteKey string
Created time.Time
Accessed sql.NullTime
Expire time.Time
ContentType string
UUID string
DeleteKey string
Created time.Time
Accessed sql.NullTime
ContentType string
}

// uuid uuid PRIMARY KEY,
// data BYTEA,
// remaining_reads SMALLINT DEFAULT 1,
// delete_key CHAR(256) NOT NULL,
// created TIMESTAMPTZ,
// accessed TIMESTAMPTZ,
// expire TIMESTAMPTZ
type Entry struct {
EntryMeta
Data []byte
Expand All @@ -51,14 +47,14 @@ func (e *EntryModel) getDeleteKey() (string, error) {
}

// CreateEntry creates a new entry into the database
func (e *EntryModel) CreateEntry(ctx context.Context, tx *sql.Tx, uuid string, contenType string, data []byte, remainingReads int, expire time.Duration) (*EntryMeta, error) {
func (e *EntryModel) CreateEntry(ctx context.Context, tx *sql.Tx, uuid string, contenType string, data []byte) (*EntryMeta, error) {
deleteKey, err := e.getDeleteKey()
if err != nil {
return nil, errors.Join(err, ErrCreateEntry)
}

now := time.Now()
res, err := tx.ExecContext(ctx, `INSERT INTO entries (uuid, data, created, expire, remaining_reads, delete_key, content_type) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING uuid, delete_key;`, uuid, data, now, now.Add(expire), remainingReads, deleteKey, contenType)
res, err := tx.ExecContext(ctx, `INSERT INTO entries (uuid, data, created, delete_key, content_type) VALUES ($1, $2, $3, $4, $5) RETURNING uuid, delete_key;`, uuid, data, now, deleteKey, contenType)

if err != nil {
return nil, errors.Join(err, ErrCreateEntry)
Expand All @@ -74,25 +70,23 @@ func (e *EntryModel) CreateEntry(ctx context.Context, tx *sql.Tx, uuid string, c
}

return &EntryMeta{
UUID: uuid,
RemainingReads: remainingReads,
DeleteKey: deleteKey,
Created: now,
Expire: now.Add(expire),
UUID: uuid,
DeleteKey: deleteKey,
Created: now,
}, err
}

func (e *EntryModel) Use(ctx context.Context, tx *sql.Tx, uuid string) error {
_, err := tx.ExecContext(ctx, "UPDATE entries SET accessed = NOW(), remaining_reads = remaining_reads - 1 WHERE uuid = $1 AND remaining_reads > 0", uuid)
_, err := tx.ExecContext(ctx, "UPDATE entries SET accessed = NOW() WHERE uuid = $1", uuid)
return err
}

// ReadEntry reads a entry from the database
// and updates the read count
func (e *EntryModel) ReadEntry(ctx context.Context, tx *sql.Tx, uuid string) (*Entry, error) {
row := tx.QueryRow("SELECT uuid, data, remaining_reads, delete_key, created, accessed, expire, content_type FROM entries WHERE uuid=$1 AND remaining_reads > 0 LIMIT 1", uuid)
row := tx.QueryRow("SELECT uuid, data, delete_key, created, accessed, content_type FROM entries WHERE uuid=$1 LIMIT 1", uuid)
var s Entry
err := row.Scan(&s.UUID, &s.Data, &s.RemainingReads, &s.DeleteKey, &s.Created, &s.Accessed, &s.Expire, &s.ContentType)
err := row.Scan(&s.UUID, &s.Data, &s.DeleteKey, &s.Created, &s.Accessed, &s.ContentType)
if err != nil {
if err == sql.ErrNoRows {
return nil, ErrEntryNotFound
Expand All @@ -104,9 +98,9 @@ func (e *EntryModel) ReadEntry(ctx context.Context, tx *sql.Tx, uuid string) (*E
}

func (e *EntryModel) ReadEntryMeta(ctx context.Context, tx *sql.Tx, uuid string) (*EntryMeta, error) {
row := tx.QueryRow("SELECT created, accessed, expire, remaining_reads, delete_key, content_type FROM entries WHERE uuid=$1 AND remaining_reads > 0 LIMIT 1", uuid)
row := tx.QueryRow("SELECT created, accessed, delete_key, content_type FROM entries WHERE uuid=$1 LIMIT 1", uuid)
var s EntryMeta
err := row.Scan(&s.Created, &s.Accessed, &s.Expire, &s.RemainingReads, &s.DeleteKey, &s.ContentType)
err := row.Scan(&s.Created, &s.Accessed, &s.DeleteKey, &s.ContentType)
if err != nil {
if err == sql.ErrNoRows {
return nil, ErrEntryNotFound
Expand Down Expand Up @@ -157,7 +151,8 @@ func (e *EntryModel) DeleteEntry(ctx context.Context, tx *sql.Tx, uuid string, d
}

func (e *EntryModel) DeleteExpired(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "DELETE FROM entries WHERE expire < NOW()")
// TODO join with entry_keys table and delete if no living entry found
// _, err := tx.ExecContext(ctx, "DELETE FROM entries WHERE expire < NOW()")

return err
return nil
}
21 changes: 2 additions & 19 deletions internal/models/entry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"testing"
"time"

"github.com/Ajnasz/sekret.link/internal/test/durable"
"github.com/google/uuid"
Expand All @@ -26,20 +25,14 @@ func Test_EntryModel_CreateEntry(t *testing.T) {

uid := uuid.New().String()
data := []byte("test data")
remainingReads := 2
expire := time.Hour * 24

model := &EntryModel{}

meta, err := model.CreateEntry(ctx, tx, uid, "text/plain", data, remainingReads, expire)
meta, err := model.CreateEntry(ctx, tx, uid, "text/plain", data)
if err != nil {
t.Fatal(err)
}

if meta.RemainingReads != 2 {
t.Errorf("expected %d got %d", remainingReads, meta.RemainingReads)
}

if meta.UUID != uid {
t.Errorf("expected %s got %s", uid, meta.UUID)
}
Expand All @@ -52,10 +45,6 @@ func Test_EntryModel_CreateEntry(t *testing.T) {
t.Errorf("expected created to be set")
}

if meta.Expire.IsZero() {
t.Errorf("expected expire to be set")
}

if meta.Accessed.Valid {
t.Errorf("expected accessed not to be set")
}
Expand All @@ -81,12 +70,10 @@ func Test_EntryModel_Use(t *testing.T) {

uid := uuid.New().String()
data := []byte("test data")
remainingReads := 2
expire := time.Hour * 24

model := &EntryModel{}

meta, err := model.CreateEntry(ctx, tx, uid, "text/plain", data, remainingReads, expire)
meta, err := model.CreateEntry(ctx, tx, uid, "text/plain", data)
if err != nil {
t.Fatal(err)
}
Expand All @@ -100,10 +87,6 @@ func Test_EntryModel_Use(t *testing.T) {
t.Fatal(errors.Join(err, errors.New("failed to read entry")))
}

if entry.RemainingReads != 1 {
t.Errorf("expected %d got %d", 0, entry.RemainingReads)
}

if !entry.Accessed.Valid {
t.Errorf("expected accessed to be set")
}
Expand Down
10 changes: 5 additions & 5 deletions internal/models/entrykey.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ type EntryKey struct {

type EntryKeyModel struct{}

func (e *EntryKeyModel) Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte) (*EntryKey, error) {
func (e *EntryKeyModel) Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte, expire time.Time, remainingReads int) (*EntryKey, error) {

now := time.Now()
res := tx.QueryRowContext(ctx, `
INSERT INTO entry_key (uuid, entry_uuid, encrypted_key, key_hash, created)
VALUES (gen_random_uuid(), $1, $2, $3, $4) RETURNING uuid, created;
`, entryUUID, encryptedKey, hash, now)
INSERT INTO entry_key (uuid, entry_uuid, encrypted_key, key_hash, created, remaining_reads, expire)
VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6) RETURNING uuid, created;
`, entryUUID, encryptedKey, hash, now, remainingReads, expire)

var uid string
var created time.Time
Expand All @@ -50,7 +50,7 @@ func (e *EntryKeyModel) Get(ctx context.Context, tx *sql.Tx, entryUUID string) (
SELECT uuid, entry_uuid, encrypted_key, key_hash, created, expire, remaining_reads
FROM entry_key
WHERE entry_uuid = $1
AND (expire IS NULL OR expire > NOW());
;
`, entryUUID)

if err != nil {
Expand Down
15 changes: 8 additions & 7 deletions internal/models/entrykey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ func createTestEntryKey(ctx context.Context, tx *sql.Tx) (string, string, error)

entryModel := &EntryModel{}

_, err := entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"), 2, 3600)
_, err := entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"))

if err != nil {
return "", "", err
}

model := &EntryKeyModel{}

entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hash entrykey use tx"))
entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hash entrykey use tx"), time.Now().Add(time.Hour), 2)

if err != nil {
return "", "", err
Expand All @@ -66,14 +66,14 @@ func Test_EntryKeyModel_Create(t *testing.T) {
uid := uuid.New().String()

entryModel := &EntryModel{}
_, err = entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"), 2, 3600)
_, err = entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"))
if err != nil {
t.Fatal(err)
}

model := &EntryKeyModel{}

entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hashke"))
entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hashke"), time.Now().Add(time.Hour), 2)

if err != nil {
if err := tx.Rollback(); err != nil {
Expand Down Expand Up @@ -124,7 +124,7 @@ func Test_EntryKeyModel_Get(t *testing.T) {
uid := uuid.New().String()

entryModel := &EntryModel{}
_, err = entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"), 2, 3600)
_, err = entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"))
if err != nil {
if err := tx.Rollback(); err != nil {
t.Error(err)
Expand All @@ -135,7 +135,7 @@ func Test_EntryKeyModel_Get(t *testing.T) {
model := &EntryKeyModel{}

for i := 0; i < 10; i++ {
_, err = model.Create(ctx, tx, uid, []byte("test"), []byte(fmt.Sprintf("hashke %d", i)))
_, err = model.Create(ctx, tx, uid, []byte("test"), []byte(fmt.Sprintf("hashke %d", i)), time.Now().Add(time.Hour), 2)

if err != nil {
if err := tx.Rollback(); err != nil {
Expand All @@ -156,6 +156,7 @@ func Test_EntryKeyModel_Get(t *testing.T) {
}

entryKeys, err := model.Get(ctx, tx, uid)
fmt.Println("ENTRY KEYS", entryKeys)

if err != nil {
if err := tx.Rollback(); err != nil {
Expand All @@ -169,7 +170,7 @@ func Test_EntryKeyModel_Get(t *testing.T) {
}

if len(entryKeys) != 10 {
t.Fatalf("expected 1 got %d", len(entryKeys))
t.Fatalf("expected 10 got %d", len(entryKeys))
}

if entryKeys[0].EntryUUID != uid {
Expand Down
18 changes: 18 additions & 0 deletions internal/models/migrate/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ func (e *EntryMigration) Alter(ctx context.Context, tx *sql.Tx) error {
return err
}

if err := e.dropKeyFields(ctx, tx); err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -123,3 +127,17 @@ func (e *EntryMigration) addContentType(ctx context.Context, tx *sql.Tx) error {

return nil
}

func (e *EntryMigration) dropKeyFields(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE entries DROP COLUMN IF EXISTS remaining_reads;")
if err != nil {
return fmt.Errorf("failed to drop remaining_reads column: %w", err)
}

_, err = tx.ExecContext(ctx, "ALTER TABLE entries DROP COLUMN IF EXISTS expire;")
if err != nil {
return fmt.Errorf("failed to drop expire column: %w", err)
}

return nil
}
6 changes: 2 additions & 4 deletions internal/models/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package models
import (
"context"
"database/sql"
"time"

"github.com/stretchr/testify/mock"
)
Expand All @@ -18,9 +17,8 @@ func (m *MockEntryModel) CreateEntry(
UUID string,
contentType string,
data []byte,
remainingReads int,
expire time.Duration) (*EntryMeta, error) {
args := m.Called(ctx, tx, UUID, data, remainingReads, expire)
) (*EntryMeta, error) {
args := m.Called(ctx, tx, UUID, data)
return args.Get(0).(*EntryMeta), args.Error(1)
}

Expand Down
Loading

0 comments on commit d07cf51

Please sign in to comment.