Skip to content

Commit

Permalink
zmq4: add r/w test for greeting
Browse files Browse the repository at this point in the history
Updates go-zeromq#56.
  • Loading branch information
sbinet committed Jan 20, 2020
1 parent 78ce94b commit c154975
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 16 deletions.
30 changes: 14 additions & 16 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ const (
hasMoreBitFlag = 0x1
isLongBitFlag = 0x2
isCommandBitFlag = 0x4

zmtpMsgLen = 64
)

var (
Expand Down Expand Up @@ -90,46 +92,42 @@ type greeting struct {
}

func (g *greeting) read(r io.Reader) error {
var data [64]byte
var data [zmtpMsgLen]byte
_, err := io.ReadFull(r, data[:])
if err != nil {
return err
return xerrors.Errorf("could not read ZMTP greeting: %w", err)
}

err = g.unmarshal(data[:])
if err != nil {
return err
}
g.unmarshal(data[:])

if g.Sig.Header != sigHeader {
return errGreeting
return xerrors.Errorf("invalid ZMTP signature header: %w", errGreeting)
}

if g.Sig.Footer != sigFooter {
return errGreeting
return xerrors.Errorf("invalid ZMTP signature footer: %w", errGreeting)
}

// FIXME(sbinet): handle version negotiations as per
// https://rfc.zeromq.org/spec:23/ZMTP/#version-negotiation
if g.Version != defaultVersion {
return errGreeting
return xerrors.Errorf(
"invalid ZMTP version (got=%v, want=%v): %w",
g.Version, defaultVersion, errGreeting,
)
}

return nil
}

func (g *greeting) unmarshal(data []byte) error {
if len(data) < 64 {
return io.ErrShortBuffer
}
_ = data[:64]
func (g *greeting) unmarshal(data []byte) {
_ = data[:zmtpMsgLen]
g.Sig.Header = data[0]
g.Sig.Footer = data[9]
g.Version[0] = data[10]
g.Version[1] = data[11]
copy(g.Mechanism[:], data[12:32])
g.Server = data[32]
return nil
}

func (g *greeting) write(w io.Writer) error {
Expand All @@ -138,7 +136,7 @@ func (g *greeting) write(w io.Writer) error {
}

func (g *greeting) marshal() []byte {
var buf [64]byte
var buf [zmtpMsgLen]byte
buf[0] = g.Sig.Header
// padding 1 ignored
buf[9] = g.Sig.Footer
Expand Down
115 changes: 115 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Copyright 2020 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package zmq4

import (
"bytes"
"io"
"testing"

"golang.org/x/xerrors"
)

func TestGreeting(t *testing.T) {
for _, tc := range []struct {
name string
data []byte
want error
}{
{
name: "valid",
data: func() []byte {
w := new(bytes.Buffer)
g := greeting{
Version: defaultVersion,
}
g.Sig.Header = sigHeader
g.Sig.Footer = sigFooter
g.write(w)
return w.Bytes()
}(),
},
{
name: "empty-buffer",
data: nil,
want: xerrors.Errorf("could not read ZMTP greeting: %w", io.EOF),
},
{
name: "unexpected-EOF",
data: make([]byte, 1),
want: xerrors.Errorf("could not read ZMTP greeting: %w", io.ErrUnexpectedEOF),
},
{
name: "invalid-header",
data: func() []byte {
w := new(bytes.Buffer)
g := greeting{
Version: defaultVersion,
}
g.Sig.Header = sigFooter // err
g.Sig.Footer = sigFooter
g.write(w)
return w.Bytes()
}(),
want: xerrors.Errorf("invalid ZMTP signature header: %w", errGreeting),
},
{
name: "invalid-footer",
data: func() []byte {
w := new(bytes.Buffer)
g := greeting{
Version: defaultVersion,
}
g.Sig.Header = sigHeader
g.Sig.Footer = sigHeader // err
g.write(w)
return w.Bytes()
}(),
want: xerrors.Errorf("invalid ZMTP signature footer: %w", errGreeting),
},
{
name: "invalid-version", // FIXME(sbinet): adapt for when/if we support multiple ZMTP versions
data: func() []byte {
w := new(bytes.Buffer)
g := greeting{
Version: [2]uint8{1, 1},
}
g.Sig.Header = sigHeader
g.Sig.Footer = sigFooter
g.write(w)
return w.Bytes()
}(),
want: xerrors.Errorf("invalid ZMTP version (got=%v, want=%v): %w",
[2]uint{1, 1},
defaultVersion,
errGreeting,
),
},
} {
t.Run(tc.name, func(t *testing.T) {
var (
g greeting
r = bytes.NewReader(tc.data)
)

err := g.read(r)
switch {
case err == nil && tc.want == nil:
// ok
case err == nil && tc.want != nil:
t.Fatalf("expected an error (%s)", tc.want)
case err != nil && tc.want == nil:
t.Fatalf("could not read ZMTP greeting: %+v", err)
case err != nil && tc.want != nil:
if got, want := err.Error(), tc.want.Error(); got != want {
t.Fatalf("invalid ZMTP greeting error:\ngot= %+v\nwant=%+v\n",
got, want,
)
}
}

})
}
}

0 comments on commit c154975

Please sign in to comment.