diff --git a/go.mod b/go.mod index 65359d7..8d082e2 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/ivarprudnikov/secretshare go 1.22 require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2 github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.2.0 github.com/gorilla/sessions v1.2.2 @@ -10,7 +11,6 @@ require ( ) require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect github.com/golang-jwt/jwt/v5 v5.2.1 // indirect diff --git a/internal/storage/aztablestore/messages.go b/internal/storage/aztablestore/messages.go new file mode 100644 index 0000000..738dbc8 --- /dev/null +++ b/internal/storage/aztablestore/messages.go @@ -0,0 +1,228 @@ +package aztablestore + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + + "github.com/Azure/azure-sdk-for-go/sdk/data/aztables" + "github.com/ivarprudnikov/secretshare/internal/storage" +) + +type azMessageStore struct { + accountName string + tableName string + salt string +} + +func NewAzMessageStore(accountName, tableName, salt string) storage.MessageStore { + return &azMessageStore{accountName: accountName, tableName: tableName, salt: salt} +} + +func (s *azMessageStore) getClient() (*aztables.Client, error) { + return getTableClient(s.accountName, s.tableName) +} + +func (s *azMessageStore) CountMessages() (int64, error) { + var count int64 = 0 + client, err := s.getClient() + if err != nil { + return count, err + } + keySelector := "$select=PartitionKey" + metadataFormat := aztables.MetadataFormatNone + listPager := client.NewListEntitiesPager(&aztables.ListEntitiesOptions{ + Select: &keySelector, + Format: &metadataFormat, + }) + for listPager.More() { + response, err := listPager.NextPage(context.TODO()) + if err != nil { + return count, err + } + count += int64(len(response.Entities)) + } + return count, nil +} + +// TODO move to storage +func (s *azMessageStore) Encrypt(text string, pass string) (string, error) { + // derive a key from the pass + key, err := storage.StrongKey(pass, s.salt) + if err != nil { + return "", err + } + ciphertext, err := storage.EncryptAES(key, text) + if err != nil { + return "", err + } + return ciphertext, nil +} + +// Decrypt cipher text with a given PIN which will be used to derive a key +func (s *azMessageStore) Decrypt(ciphertext string, pass string) (string, error) { + // derive a key from the pass + key, err := storage.StrongKey(pass, s.salt) + if err != nil { + return "", err + } + plaintext, err := storage.DecryptAES(key, ciphertext) + if err != nil { + return "", err + } + return plaintext, nil +} + +func (s *azMessageStore) ListMessages(username string) ([]*storage.Message, error) { + var msgs []*storage.Message + client, err := s.getClient() + if err != nil { + return msgs, err + } + userFilter := fmt.Sprintf("RowKey eq '%s'", username) + listPager := client.NewListEntitiesPager(&aztables.ListEntitiesOptions{ + Filter: &userFilter, + }) + for listPager.More() { + response, err := listPager.NextPage(context.TODO()) + if err != nil { + return msgs, err + } + for _, v := range response.Entities { + var msg *storage.Message + err = json.Unmarshal(v, &msg) + if err != nil { + return msgs, err + } + msgs = append(msgs, msg) + } + } + return msgs, nil +} + +// TODO: allow to reset the pin for the owner +func (s *azMessageStore) AddMessage(text string, username string) (*storage.Message, error) { + // an easy to enter pin + pin, err := storage.MakePin() + if err != nil { + return nil, err + } + ciphertext, err := s.Encrypt(text, pin) + if err != nil { + return nil, err + } + msg, err := storage.NewMessage(username, ciphertext, pin) + if err != nil { + return nil, err + } + err = s.saveMessage(&msg) + msg.Pin = pin + return &msg, nil +} + +func (s *azMessageStore) GetMessage(id string) (*storage.Message, error) { + msg, err := s.getMessage(id) + if err != nil { + return nil, err + } + // clear the pin to let the view know it needs decryption + msg.Pin = "" + return msg, nil +} + +func (s *azMessageStore) GetFullMessage(id string, pin string) (*storage.Message, error) { + msg, err := s.getMessage(id) + if err != nil { + return nil, err + } + + if err := storage.CompareHashToPass(msg.Pin, pin); err == nil { + text, err := s.Decrypt(msg.Content, pin) + if err != nil { + return nil, err + } + msg.Content = text + err = s.deleteMessage(msg) + if err != nil { + slog.LogAttrs(context.TODO(), slog.LevelError, "failed to delete message", slog.String("id", msg.PartitionKey), slog.String("username", msg.RowKey)) + } + return msg, nil + } + + msg.AttemptsRemaining -= 1 + // If the pin was wrong then track attempts + if msg.AttemptsRemaining <= 0 { + err = s.deleteMessage(msg) + if err != nil { + slog.LogAttrs(context.TODO(), slog.LevelError, "failed to delete message", slog.String("id", msg.PartitionKey), slog.String("username", msg.RowKey)) + } + } else { + err = s.saveMessage(msg) + if err != nil { + slog.LogAttrs(context.TODO(), slog.LevelError, "failed to update message", slog.String("id", msg.PartitionKey), slog.String("username", msg.RowKey)) + } + } + + return nil, nil +} + +func (s *azMessageStore) getMessage(id string) (*storage.Message, error) { + client, err := s.getClient() + if err != nil { + return nil, err + } + var msgs []*storage.Message + idFilter := fmt.Sprintf("PartitionKey eq '%s'", id) + listPager := client.NewListEntitiesPager(&aztables.ListEntitiesOptions{ + Filter: &idFilter, + }) + for listPager.More() { + response, err := listPager.NextPage(context.TODO()) + if err != nil { + return nil, err + } + for _, v := range response.Entities { + var msg *storage.Message + err = json.Unmarshal(v, &msg) + if err != nil { + return nil, err + } + msgs = append(msgs, msg) + } + } + if len(msgs) > 1 { + slog.LogAttrs(context.TODO(), slog.LevelError, "more than one message with the same id", slog.String("id", id), slog.Int("total", len(msgs))) + } + return msgs[0], nil +} + +func (s *azMessageStore) saveMessage(msg *storage.Message) error { + marshalled, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal: %w", err) + } + client, err := s.getClient() + if err != nil { + return fmt.Errorf("failed to get aztable client: %w", err) + } + _, err = client.UpsertEntity(context.TODO(), marshalled, &aztables.UpsertEntityOptions{ + UpdateMode: aztables.UpdateModeReplace, + }) + if err != nil { + return fmt.Errorf("failed to save: %w", err) + } + return nil +} + +func (s *azMessageStore) deleteMessage(msg *storage.Message) error { + client, err := s.getClient() + if err != nil { + return fmt.Errorf("failed to get aztable client: %w", err) + } + _, err = client.DeleteEntity(context.TODO(), msg.PartitionKey, msg.RowKey, nil) + if err != nil { + return err + } + return nil +} diff --git a/server.go b/server.go index 2219a6a..ae2f8f0 100644 --- a/server.go +++ b/server.go @@ -54,23 +54,30 @@ func getPort() string { // Production environment needs to work with Azure Table Storage which is not // available locally. Locally an in-memory implementation of storage is used. func getStorageImplementation(config *configuration.ConfigReader) (storage.MessageStore, storage.UserStore) { - var messages storage.MessageStore = memstore.NewMemMessageStore(config.GetSalt()) - var users storage.UserStore = aztablestore.NewAzUserStore(config.GetStorageAccountName(), config.GetUsersTableName(), config.GetSalt()) - if !config.IsProd() { + var messages storage.MessageStore + var users storage.UserStore + + if config.IsProd() { + messages = aztablestore.NewAzMessageStore(config.GetStorageAccountName(), config.GetMessagesTableName(), config.GetSalt()) + users = aztablestore.NewAzUserStore(config.GetStorageAccountName(), config.GetUsersTableName(), config.GetSalt()) + } else { messages = memstore.NewMemMessageStore(config.GetSalt()) users = memstore.NewMemUserStore(config.GetSalt()) + bootstrapTestData(messages, users) + } + return messages, users +} - // add test users - users.AddUser("joe", "joe", []string{}) - users.AddUser("alice", "alice", []string{}) - users.AddUser("admin", "admin", []string{storage.PERMISSION_READ_STATS}) +func bootstrapTestData(messages storage.MessageStore, users storage.UserStore) { + // add test users + users.AddUser("joe", "joe", []string{}) + users.AddUser("alice", "alice", []string{}) + users.AddUser("admin", "admin", []string{storage.PERMISSION_READ_STATS}) - // add a test message - msg, err := messages.AddMessage("foobar", "joe") - if err != nil { - panic("Unexpected error") - } - log.Printf("Generated PIN for test message %s", msg.Pin) + // add a test message + msg, err := messages.AddMessage("foobar", "joe") + if err != nil { + panic("Unexpected error") } - return messages, users + log.Printf("Generated PIN for test message %s", msg.Pin) }