Skip to content

Commit

Permalink
adds aztable message store to production env
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarprudnikov committed Apr 27, 2024
1 parent a1257f7 commit a9027d3
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 15 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ module github.com/ivarprudnikov/secretshare
go 1.22

require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2
github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.2.0
github.com/gorilla/sessions v1.2.2
golang.org/x/crypto v0.22.0
)

require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
Expand Down
228 changes: 228 additions & 0 deletions internal/storage/aztablestore/messages.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
package aztablestore

import (
"context"
"encoding/json"
"fmt"
"log/slog"

"github.com/Azure/azure-sdk-for-go/sdk/data/aztables"
"github.com/ivarprudnikov/secretshare/internal/storage"
)

type azMessageStore struct {
accountName string
tableName string
salt string
}

func NewAzMessageStore(accountName, tableName, salt string) storage.MessageStore {
return &azMessageStore{accountName: accountName, tableName: tableName, salt: salt}
}

func (s *azMessageStore) getClient() (*aztables.Client, error) {
return getTableClient(s.accountName, s.tableName)
}

func (s *azMessageStore) CountMessages() (int64, error) {
var count int64 = 0
client, err := s.getClient()
if err != nil {
return count, err
}
keySelector := "$select=PartitionKey"
metadataFormat := aztables.MetadataFormatNone
listPager := client.NewListEntitiesPager(&aztables.ListEntitiesOptions{
Select: &keySelector,
Format: &metadataFormat,
})
for listPager.More() {
response, err := listPager.NextPage(context.TODO())
if err != nil {
return count, err
}
count += int64(len(response.Entities))
}
return count, nil
}

// TODO move to storage
func (s *azMessageStore) Encrypt(text string, pass string) (string, error) {
// derive a key from the pass
key, err := storage.StrongKey(pass, s.salt)
if err != nil {
return "", err
}
ciphertext, err := storage.EncryptAES(key, text)
if err != nil {
return "", err
}
return ciphertext, nil
}

// Decrypt cipher text with a given PIN which will be used to derive a key
func (s *azMessageStore) Decrypt(ciphertext string, pass string) (string, error) {
// derive a key from the pass
key, err := storage.StrongKey(pass, s.salt)
if err != nil {
return "", err
}
plaintext, err := storage.DecryptAES(key, ciphertext)
if err != nil {
return "", err
}
return plaintext, nil
}

func (s *azMessageStore) ListMessages(username string) ([]*storage.Message, error) {
var msgs []*storage.Message
client, err := s.getClient()
if err != nil {
return msgs, err
}
userFilter := fmt.Sprintf("RowKey eq '%s'", username)
listPager := client.NewListEntitiesPager(&aztables.ListEntitiesOptions{
Filter: &userFilter,
})
for listPager.More() {
response, err := listPager.NextPage(context.TODO())
if err != nil {
return msgs, err
}
for _, v := range response.Entities {
var msg *storage.Message
err = json.Unmarshal(v, &msg)
if err != nil {
return msgs, err
}
msgs = append(msgs, msg)
}
}
return msgs, nil
}

// TODO: allow to reset the pin for the owner
func (s *azMessageStore) AddMessage(text string, username string) (*storage.Message, error) {
// an easy to enter pin
pin, err := storage.MakePin()
if err != nil {
return nil, err
}
ciphertext, err := s.Encrypt(text, pin)
if err != nil {
return nil, err
}
msg, err := storage.NewMessage(username, ciphertext, pin)
if err != nil {
return nil, err
}
err = s.saveMessage(&msg)
msg.Pin = pin
return &msg, nil
}

func (s *azMessageStore) GetMessage(id string) (*storage.Message, error) {
msg, err := s.getMessage(id)
if err != nil {
return nil, err
}
// clear the pin to let the view know it needs decryption
msg.Pin = ""
return msg, nil
}

func (s *azMessageStore) GetFullMessage(id string, pin string) (*storage.Message, error) {
msg, err := s.getMessage(id)
if err != nil {
return nil, err
}

if err := storage.CompareHashToPass(msg.Pin, pin); err == nil {
text, err := s.Decrypt(msg.Content, pin)
if err != nil {
return nil, err
}
msg.Content = text
err = s.deleteMessage(msg)
if err != nil {
slog.LogAttrs(context.TODO(), slog.LevelError, "failed to delete message", slog.String("id", msg.PartitionKey), slog.String("username", msg.RowKey))
}
return msg, nil
}

msg.AttemptsRemaining -= 1
// If the pin was wrong then track attempts
if msg.AttemptsRemaining <= 0 {
err = s.deleteMessage(msg)
if err != nil {
slog.LogAttrs(context.TODO(), slog.LevelError, "failed to delete message", slog.String("id", msg.PartitionKey), slog.String("username", msg.RowKey))
}
} else {
err = s.saveMessage(msg)
if err != nil {
slog.LogAttrs(context.TODO(), slog.LevelError, "failed to update message", slog.String("id", msg.PartitionKey), slog.String("username", msg.RowKey))
}
}

return nil, nil
}

func (s *azMessageStore) getMessage(id string) (*storage.Message, error) {
client, err := s.getClient()
if err != nil {
return nil, err
}
var msgs []*storage.Message
idFilter := fmt.Sprintf("PartitionKey eq '%s'", id)
listPager := client.NewListEntitiesPager(&aztables.ListEntitiesOptions{
Filter: &idFilter,
})
for listPager.More() {
response, err := listPager.NextPage(context.TODO())
if err != nil {
return nil, err
}
for _, v := range response.Entities {
var msg *storage.Message
err = json.Unmarshal(v, &msg)
if err != nil {
return nil, err
}
msgs = append(msgs, msg)
}
}
if len(msgs) > 1 {
slog.LogAttrs(context.TODO(), slog.LevelError, "more than one message with the same id", slog.String("id", id), slog.Int("total", len(msgs)))
}
return msgs[0], nil
}

func (s *azMessageStore) saveMessage(msg *storage.Message) error {
marshalled, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal: %w", err)
}
client, err := s.getClient()
if err != nil {
return fmt.Errorf("failed to get aztable client: %w", err)
}
_, err = client.UpsertEntity(context.TODO(), marshalled, &aztables.UpsertEntityOptions{
UpdateMode: aztables.UpdateModeReplace,
})
if err != nil {
return fmt.Errorf("failed to save: %w", err)
}
return nil
}

func (s *azMessageStore) deleteMessage(msg *storage.Message) error {
client, err := s.getClient()
if err != nil {
return fmt.Errorf("failed to get aztable client: %w", err)
}
_, err = client.DeleteEntity(context.TODO(), msg.PartitionKey, msg.RowKey, nil)
if err != nil {
return err
}
return nil
}
35 changes: 21 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,30 @@ func getPort() string {
// Production environment needs to work with Azure Table Storage which is not
// available locally. Locally an in-memory implementation of storage is used.
func getStorageImplementation(config *configuration.ConfigReader) (storage.MessageStore, storage.UserStore) {
var messages storage.MessageStore = memstore.NewMemMessageStore(config.GetSalt())
var users storage.UserStore = aztablestore.NewAzUserStore(config.GetStorageAccountName(), config.GetUsersTableName(), config.GetSalt())
if !config.IsProd() {
var messages storage.MessageStore
var users storage.UserStore

if config.IsProd() {
messages = aztablestore.NewAzMessageStore(config.GetStorageAccountName(), config.GetMessagesTableName(), config.GetSalt())
users = aztablestore.NewAzUserStore(config.GetStorageAccountName(), config.GetUsersTableName(), config.GetSalt())
} else {
messages = memstore.NewMemMessageStore(config.GetSalt())
users = memstore.NewMemUserStore(config.GetSalt())
bootstrapTestData(messages, users)
}
return messages, users
}

// add test users
users.AddUser("joe", "joe", []string{})
users.AddUser("alice", "alice", []string{})
users.AddUser("admin", "admin", []string{storage.PERMISSION_READ_STATS})
func bootstrapTestData(messages storage.MessageStore, users storage.UserStore) {
// add test users
users.AddUser("joe", "joe", []string{})
users.AddUser("alice", "alice", []string{})
users.AddUser("admin", "admin", []string{storage.PERMISSION_READ_STATS})

// add a test message
msg, err := messages.AddMessage("foobar", "joe")
if err != nil {
panic("Unexpected error")
}
log.Printf("Generated PIN for test message %s", msg.Pin)
// add a test message
msg, err := messages.AddMessage("foobar", "joe")
if err != nil {
panic("Unexpected error")
}
return messages, users
log.Printf("Generated PIN for test message %s", msg.Pin)
}

0 comments on commit a9027d3

Please sign in to comment.