diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2ea6ba7..aab7e69 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -15,5 +15,4 @@ jobs: uses: actions/setup-go@v2 with: go-version: ${{ matrix.go }} - - run: go test -v ./... - - run: go vet ./... + - run: make test diff --git a/Makefile b/Makefile index ba027da..7c05efb 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,16 @@ # gets last tag VERSION := $(shell git describe --abbrev=0 --tags) -test: - go vet ./... - go test ./... .PHONY: test +test: unit cover + +.PHONY: unit +unit: + go test -v -timeout 30s -coverprofile=coverage.out ./pkg/... + +.PHONY: cover +cover: + go tool cover -func=coverage.out build-cli: go build -o dracula-cli cmd/cli/main.go diff --git a/pkg/protocol/protocol.go b/pkg/protocol/protocol.go new file mode 100644 index 0000000..c3fe972 --- /dev/null +++ b/pkg/protocol/protocol.go @@ -0,0 +1,177 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + + "github.com/OneOfOne/xxhash" +) + +const ( + PacketSize int = 1500 + + CommandOffset int = 0 + MessageIDOffset int = 1 + NamespaceOffset int = 5 + DataValueOffset int = 69 + HashOffset int = 1492 + + CmdCount byte = 'C' + CmdPut byte = 'P' + CmdPutReplicate byte = 'R' + CmdCountNamespace byte = 'N' + CmdCountServer byte = 'S' + + CmdTCPOnlyKeys byte = 'K' + CmdTCPOnlyValues byte = 'V' + CmdTCPOnlyStore byte = 'T' + CmdTCPOnlyRetrieve byte = 'I' + CmdTCPOnlyNamespaces byte = 'L' +) + +var ( + ErrInvalidPacketSizeTooSmall = errors.New("bad packet: too small, size must be 1500 bytes") + ErrInvalidPacketSizeTooLarge = errors.New("bad packet: too large, size must be 1500 bytes") + ErrInvalidCommandByte = errors.New("bad packet: invalid command byte") + ErrProtocolSpace1 = errors.New("bad packet: expected space 1") + ErrProtocolSpace2 = errors.New("bad packet: expected space 1") + ErrProtocolSpace3 = errors.New("bad packet: expected space 3") + ErrBadHash = errors.New("auth failed: packet hash invalid") + ErrBadOutputSize = errors.New("wrong data size during packet construction") + ErrBadReadCall = errors.New("packet reader error") + ErrBadReadSize = errors.New("invalid packet read size") + ErrBadWriteCall = errors.New("packet writer error") + ErrBadWriteSize = errors.New("invalid packet write size") +) + +// Packet zero copy structure for 1500 byte messages +type Packet struct { + buffer []byte +} + +// NewPacket creates a packet from an existing slice of 1500 bytes or initializes. If +// buffer is nil, a new buffer is allocated. +func NewPacket(buffer []byte) *Packet { + packet := &Packet{ + buffer: buffer, + } + if packet.buffer == nil { + packet.buffer = make([]byte, PacketSize) + } + return packet +} + +// GetCommand gets command byte +func (p *Packet) GetCommand() byte { + return p.buffer[CommandOffset] +} + +// SetCommand sets command byte +func (p *Packet) SetCommand(command byte) { + p.buffer[CommandOffset] = command +} + +// GetMessageID gets message ID +func (p *Packet) GetMessageID() uint32 { + return binary.LittleEndian.Uint32(p.buffer[MessageIDOffset:NamespaceOffset]) +} + +// SetMessageID sets message ID +func (p *Packet) SetMessageID(messageID uint32) { + binary.LittleEndian.PutUint32(p.buffer[MessageIDOffset:NamespaceOffset], messageID) +} + +// GetNamespace gets namespace +func (p *Packet) GetNamespace() string { + return p.getString(NamespaceOffset, 64) +} + +// SetNamespace sets namespace +func (p *Packet) SetNamespace(namespace string) { + p.setString(NamespaceOffset, 64, namespace) +} + +// GetDataValue gets data value +func (p *Packet) GetDataValue() string { + return p.getString(DataValueOffset, 1423) +} + +// SetDataValue sets data value +func (p *Packet) SetDataValue(dataValue string) { + p.setString(DataValueOffset, 1423, dataValue) +} + +// GetHash gets hash +func (p *Packet) GetHash() uint64 { + return binary.LittleEndian.Uint64(p.buffer[HashOffset:PacketSize]) +} + +// GetHash sets hash +func (p *Packet) SetHash(hash uint64) { + binary.LittleEndian.PutUint64(p.buffer[HashOffset:PacketSize], hash) +} + +// Sign hashes the packet with a given key and stores it in the packet hash value +func (p *Packet) Sign(key []byte) { + hash := p.hash(key) + p.SetHash(hash) +} + +// Verify checks the stored hash from the packet using a given key +func (p *Packet) Verify(key []byte) error { + hash := p.hash(key) + if hash != p.GetHash() { + return ErrBadHash + } + return nil +} + +// getString decodes string up to maxSize length starting from offset +func (p *Packet) getString(offset, maxSize int) string { + stringSize := bytes.IndexByte(p.buffer[offset:offset+maxSize], 0) + if stringSize == -1 { + stringSize = maxSize + } + buffer := p.buffer[offset : offset+stringSize] + return string(buffer) +} + +// setString encodes string up to maxSize length starting at offset +func (p *Packet) setString(offset, maxSize int, value string) { + for index := offset; index < offset+maxSize; index++ { + p.buffer[index] = 0 + } + _ = copy(p.buffer[offset:], []byte(value)) +} + +func (p *Packet) hash(key []byte) uint64 { + hasher := xxhash.New64() + _, _ = hasher.Write(key) + _, _ = hasher.Write(p.buffer[0:HashOffset]) + return hasher.Sum64() +} + +// ReadPacket extracts a packet from a reader. Returns error if the reader fails or less than 1500 bytes are read. +func ReadPacket(reader io.Reader) (*Packet, error) { + buffer := make([]byte, PacketSize) + size, err := reader.Read(buffer) + if err != nil { + return nil, ErrBadReadCall + } else if size != PacketSize { + return nil, ErrBadReadSize + } + return NewPacket(buffer), nil +} + +// ReadPacket insterts a packet into a writer. Returns error if the writer fails or less than 1500 bytes are written. +func WritePacket(writer io.Writer, packet *Packet) error { + size, err := writer.Write(packet.buffer) + if err != nil { + return ErrBadWriteCall + } else if size != PacketSize { + return ErrBadWriteSize + } + return nil +} diff --git a/pkg/protocol/protocol_test.go b/pkg/protocol/protocol_test.go new file mode 100644 index 0000000..e71292c --- /dev/null +++ b/pkg/protocol/protocol_test.go @@ -0,0 +1,129 @@ +package protocol + +import ( + "bytes" + "errors" + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + success int = 0 + truncate int = 1 + failure int = 2 + + expectedCommand byte = CmdCount + expectedMessageID uint32 = 1 + expectedHash uint64 = 12345 + expectedDataValue string = "datavalue" +) + +var ( + // test the edge case of having a string of max size + expectedNamespace string = generateAlphaNumericString(64) +) + +// Packet buffer manager with configurable error states +type MockPacketStorage struct { + state int + buffer []byte +} + +// NewMockPacketStorage creates a new MockPacketStorage with an empty buffer and no error state +func NewMockPacketStorage() *MockPacketStorage { + return &MockPacketStorage{ + state: success, + buffer: make([]byte, PacketSize), + } +} + +// Read either copies from the read buffer or simulates some read error +func (m MockPacketStorage) Read(buffer []byte) (int, error) { + if m.state == truncate { + return len(buffer) / 2, nil + } else if m.state == failure { + return 0, errors.New("mock error") + } + return copy(buffer, m.buffer), nil +} + +// Write either copies to the write buffer or simulates some write error +func (m MockPacketStorage) Write(buffer []byte) (int, error) { + if m.state == truncate { + return len(buffer) / 2, nil + } else if m.state == failure { + return 0, errors.New("mock error") + } + return copy(m.buffer, buffer), nil +} + +// TestWriteRead verifies that we can write and read a packet from the same block of bytes +func TestWriteRead(t *testing.T) { + storage := NewMockPacketStorage() + writePacket := NewPacket(nil) + writePacket.SetCommand(expectedCommand) + writePacket.SetMessageID(expectedMessageID) + writePacket.SetNamespace(expectedNamespace) + writePacket.SetDataValue(expectedDataValue) + writePacket.SetHash(expectedHash) + + storage.state = truncate + err := WritePacket(storage, writePacket) + assert.ErrorIs(t, err, ErrBadWriteSize) + + storage.state = failure + err = WritePacket(storage, writePacket) + assert.ErrorIs(t, err, ErrBadWriteCall) + + storage.state = success + err = WritePacket(storage, writePacket) + assert.NoError(t, err) + + storage.state = truncate + readPacket, err := ReadPacket(storage) + assert.Nil(t, readPacket) + assert.ErrorIs(t, err, ErrBadReadSize) + + storage.state = failure + readPacket, err = ReadPacket(storage) + assert.Nil(t, readPacket) + assert.ErrorIs(t, err, ErrBadReadCall) + + storage.state = success + readPacket, err = ReadPacket(storage) + assert.NotNil(t, readPacket) + assert.NoError(t, err) + + assert.Equal(t, expectedCommand, readPacket.GetCommand()) + assert.Equal(t, expectedMessageID, readPacket.GetMessageID()) + assert.Equal(t, expectedNamespace, readPacket.GetNamespace()) + assert.Equal(t, expectedDataValue, readPacket.GetDataValue()) + assert.Equal(t, expectedHash, readPacket.GetHash()) +} + +// TestSignVerify checks the Sign/Verify functions using various keys and the hashes +func TestSignVerify(t *testing.T) { + goodKey := []byte("this is the good key") + badKey := []byte("this is the bad key") + packet := NewPacket(nil) + + packet.Sign(goodKey) + + err := packet.Verify(badKey) + assert.Error(t, err) + + err = packet.Verify(goodKey) + assert.NoError(t, err) +} + +func generateAlphaNumericString(size int) string { + chars := []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") + buffer := bytes.NewBuffer(nil) + buffer.Grow(size) + for buffer.Len() < size { + buffer.WriteByte(chars[rand.Intn(len(chars))]) + } + return buffer.String() +} diff --git a/pkg/store/store.go b/pkg/store/store.go new file mode 100644 index 0000000..620aef0 --- /dev/null +++ b/pkg/store/store.go @@ -0,0 +1,8 @@ +package store + +type Store struct { +} + +func Put(namespace, key, value string) error { + return nil +}