From 6a2c3cde8f105824e623b13ac583c9310e5b7caf Mon Sep 17 00:00:00 2001 From: Dmitri Popov Date: Tue, 5 Apr 2016 14:07:21 -0700 Subject: [PATCH] blksize and tsize support --- README.md | 24 +++++++++++++ client.go | 24 ++++++++++--- packet.go | 87 ++++++++++++++++++++++++++++++++++---------- receiver.go | 100 +++++++++++++++++++++++++++++++++++++++++++++++---- sender.go | 92 +++++++++++++++++++++++++++++++++++++++++++++-- server.go | 13 ++++--- tftp_test.go | 70 +++++++++++++++++++++++++++++++++--- 7 files changed, 366 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index a977b2a..6700452 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,12 @@ func writeHanlder(filename string, w io.WriterTo) error { fmt.Fprintf(os.Stderr, "%v\n", err) return err } + // In case client provides tsize option. + if t, ok := wt.(tftp.IncomingTransfer); ok { + if n, ok := t.Size(); ok { + fmt.Printf("Transfer size: %d\n", n) + } + } n, err := w.WriteTo(file) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) @@ -25,12 +31,20 @@ func writeHanlder(filename string, w io.WriterTo) error { fmt.Printf("%d bytes received\n", n) return nil } + func readHandler(filename string, r io.ReaderFrom) error { file, err := os.Open(filename) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) return err } + // Optional tsize support. + // Set transfer size before calling ReadFrom. + if t, ok := rf.(tftp.OutgoingTransfer); ok { + if fi, err := file.Stat(); err == nil { + t.SetSize(fi.Size()) + } + } n, err := r.ReadFrom(file) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) @@ -69,6 +83,10 @@ r, err := c.Send("foobar.txt", "octet") if err != nil { ... } +// Optional tsize. +if ot, ok := r.(tftp.OutgoingTransfer); ok { + ot.SetSize(length) +} n, err := r.ReadFrom(file) fmt.Printf("%d bytes sent\n", n) ``` @@ -88,6 +106,12 @@ file, err := os.Create(path) if err != nil { ... } +// Optional tsize. +if it, ok := readTransfer.(IncomingTransfer); ok { + if n, ok := it.Size(); ok { + fmt.Printf("Transfer size: %d\n", n) + } +} n, err := w.WriteTo(file) if err != nil { ... diff --git a/client.go b/client.go index c5f13a9..c2056b2 100644 --- a/client.go +++ b/client.go @@ -1,10 +1,10 @@ package tftp import ( - "encoding/binary" "fmt" "io" "net" + "strconv" "time" ) @@ -28,6 +28,8 @@ type Client struct { addr *net.UDPAddr retry Retry timeout time.Duration + blksize int + tsize bool } func (c Client) Send(filename string, mode string) (io.ReaderFrom, error) { @@ -44,14 +46,18 @@ func (c Client) Send(filename string, mode string) (io.ReaderFrom, error) { addr: c.addr, mode: mode, } - n := packRQ(s.send, opWRQ, filename, mode) + if c.blksize != 0 { + s.opts = make(options) + s.opts["blksize"] = strconv.Itoa(c.blksize) + } + n := packRQ(s.send, opWRQ, filename, mode, s.opts) addr, err := s.sendWithRetry(n) if err != nil { return nil, err // wrap error } s.block++ s.addr = addr - binary.BigEndian.PutUint16(s.send[0:2], opDATA) + s.opts = nil return s, nil } @@ -73,7 +79,16 @@ func (c Client) Receive(filename string, mode string) (io.WriterTo, error) { autoTerm: true, mode: mode, } - n := packRQ(r.send, opRRQ, filename, mode) + if c.blksize != 0 || c.tsize { + r.opts = make(options) + } + if c.blksize != 0 { + r.opts["blksize"] = strconv.Itoa(c.blksize) + } + if c.tsize { + r.opts["tsize"] = "0" + } + n := packRQ(r.send, opRRQ, filename, mode, r.opts) r.block++ l, addr, err := r.receiveWithRetry(n) if err != nil { @@ -81,6 +96,5 @@ func (c Client) Receive(filename string, mode string) (io.WriterTo, error) { } r.l = l r.addr = addr - binary.BigEndian.PutUint16(r.send[0:2], opACK) return r, nil } diff --git a/packet.go b/packet.go index 85d9120..1ac7742 100644 --- a/packet.go +++ b/packet.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "fmt" - "strings" ) const ( @@ -21,6 +20,8 @@ const ( datagramLength = 516 ) +type options map[string]string + // RRQ/WRQ packet // // 2 bytes string 1 byte string 1 byte @@ -31,28 +32,71 @@ type pRRQ []byte type pWRQ []byte // packRQ returns length of the packet in b -func packRQ(b []byte, op uint16, filename, mode string) int { - binary.BigEndian.PutUint16(b, op) - n := copy(b[2:len(b)-10], filename) - b[2+n] = 0 - m := copy(b[3+n:len(b)-2], mode) - b[2+n+1+m] = 0 - return 2 + n + 1 + m + 1 +func packRQ(p []byte, op uint16, filename, mode string, opts options) int { + binary.BigEndian.PutUint16(p, op) + n := 2 + n += copy(p[2:len(p)-10], filename) + p[n] = 0 + n++ + n += copy(p[n:], mode) + p[n] = 0 + n++ + for name, value := range opts { + n += copy(p[n:], name) + p[n] = 0 + n++ + n += copy(p[n:], value) + p[n] = 0 + n++ + } + return n +} + +func unpackRQ(p []byte) (filename, mode string, opts options, err error) { + bs := bytes.Split(p[2:], []byte{0}) + if len(bs) < 2 { + return "", "", nil, fmt.Errorf("missing filename or mode") + } + filename = string(bs[0]) + mode = string(bs[1]) + if len(bs) < 4 { + return filename, mode, nil, nil + } + opts = make(options) + for i := 2; i+1 < len(bs); i += 2 { + opts[string(bs[i])] = string(bs[i+1]) + } + return filename, mode, opts, nil } -func unpackRQ(p []byte) (filename, mode string, err error) { - buffer := bytes.NewBuffer(p[2:]) - s, err := buffer.ReadString(0x0) - if err != nil { - return s, "", err +// OACK packet +// +// +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+ +// | Opcode | opt1 | 0 | value1 | 0 | optN | 0 | valueN | 0 | +// +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+ +type pOACK []byte + +func packOACK(p []byte, opts options) int { + binary.BigEndian.PutUint16(p, opOACK) + n := 2 + for name, value := range opts { + n += copy(p[n:], name) + p[n] = 0 + n++ + n += copy(p[n:], value) + p[n] = 0 + n++ } - filename = strings.TrimSpace(strings.Trim(s, "\x00")) - s, err = buffer.ReadString(0x0) - if err != nil { - return filename, s, err + return n +} + +func unpackOACK(p []byte) (opts options, err error) { + bs := bytes.Split(p[2:], []byte{0}) + opts = make(options) + for i := 0; i+1 < len(bs); i += 2 { + opts[string(bs[i])] = string(bs[i+1]) } - mode = strings.TrimSpace(strings.Trim(s, "\x00")) - return filename, mode, nil + return opts, nil } // ERROR packet @@ -135,6 +179,11 @@ func parsePacket(p []byte) (interface{}, error) { return nil, fmt.Errorf("short ERROR packet: %d", l) } return pERROR(p), nil + case opOACK: + if l < 6 { + return nil, fmt.Errorf("short OACK packet: %d", l) + } + return pOACK(p), nil default: return nil, fmt.Errorf("unknown opcode: %d", opcode) } diff --git a/receiver.go b/receiver.go index e923ad4..67fbd17 100644 --- a/receiver.go +++ b/receiver.go @@ -5,11 +5,29 @@ import ( "fmt" "io" "net" + "strconv" "time" "github.com/pin/tftp/netascii" ) +type IncomingTransfer interface { + Size() (n int64, ok bool) +} + +func (r *receiver) Size() (n int64, ok bool) { + if r.opts != nil { + if s, ok := r.opts["tsize"]; ok { + n, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, false + } + return n, true + } + } + return 0, false +} + type receiver struct { send []byte receive []byte @@ -22,12 +40,21 @@ type receiver struct { autoTerm bool dally bool mode string + opts options } func (r *receiver) WriteTo(w io.Writer) (n int64, err error) { if r.mode == "netascii" { w = netascii.FromWriter(w) } + if r.opts != nil { + err := r.sendOptions() + if err != nil { + r.abort(err) + return 0, err + } + } + binary.BigEndian.PutUint16(r.send[0:2], opACK) for { if r.l > 0 { l, err := w.Write(r.receive[4:r.l]) @@ -36,7 +63,7 @@ func (r *receiver) WriteTo(w io.Writer) (n int64, err error) { r.abort(err) return n, err } - if r.l-4 < blockLength { + if r.l < len(r.receive) { if r.autoTerm { r.terminate() } @@ -54,12 +81,53 @@ func (r *receiver) WriteTo(w io.Writer) (n int64, err error) { } } -func (s *receiver) receiveWithRetry(l int) (int, *net.UDPAddr, error) { - s.retry.Reset() +func (r *receiver) sendOptions() error { + for name, value := range r.opts { + if name == "blksize" { + err := r.setBlockSize(value) + if err != nil { + delete(r.opts, name) + continue + } + } else { + delete(r.opts, name) + } + } + if len(r.opts) > 0 { + m := packOACK(r.send, r.opts) + r.block = 1 + ll, _, err := r.receiveWithRetry(m) + if err != nil { + r.abort(err) + return err + } + r.block = 1 + r.l = ll + } + return nil +} + +func (r *receiver) setBlockSize(blksize string) error { + n, err := strconv.Atoi(blksize) + if err != nil { + return err + } + if n < 512 { + return fmt.Errorf("blkzise too small: %d", n) + } + if n > 65464 { + return fmt.Errorf("blksize too large: %d", n) + } + r.receive = make([]byte, n+4) + return nil +} + +func (r *receiver) receiveWithRetry(l int) (int, *net.UDPAddr, error) { + r.retry.Reset() for { - n, addr, err := s.receiveDatagram(l) - if _, ok := err.(net.Error); ok && s.retry.Count() < 3 { - s.retry.Backoff() + n, addr, err := r.receiveDatagram(l) + if _, ok := err.(net.Error); ok && r.retry.Count() < 3 { + r.retry.Backoff() continue } return n, addr, err @@ -91,6 +159,26 @@ func (r *receiver) receiveDatagram(l int) (int, *net.UDPAddr, error) { if p.block() == r.block { return c, addr, nil } + case pOACK: + opts, err := unpackOACK(p) + if r.block != 1 { + continue + } + if err != nil { + r.abort(err) + return 0, addr, err + } + for name, value := range opts { + if name == "blksize" { + err := r.setBlockSize(value) + if err != nil { + continue + } + } + } + r.block = 0 + r.opts = opts + return 0, addr, nil case pERROR: return 0, addr, fmt.Errorf("code: %d, message: %s", p.code(), p.message()) diff --git a/sender.go b/sender.go index 54900bf..229b908 100644 --- a/sender.go +++ b/sender.go @@ -5,11 +5,16 @@ import ( "fmt" "io" "net" + "strconv" "time" "github.com/pin/tftp/netascii" ) +type OutgoingTransfer interface { + SetSize(n int64) +} + type sender struct { conn *net.UDPConn addr *net.UDPAddr @@ -19,12 +24,29 @@ type sender struct { timeout time.Duration block uint16 mode string + opts options +} + +func (s *sender) SetSize(n int64) { + if s.opts != nil { + if _, ok := s.opts["tsize"]; ok { + s.opts["tsize"] = strconv.FormatInt(n, 10) + } + } } func (s *sender) ReadFrom(r io.Reader) (n int64, err error) { if s.mode == "netascii" { r = netascii.ToReader(r) } + if s.opts != nil { + err = s.sendOptions() + if err != nil { + s.abort(err) + return 0, err + } + } + binary.BigEndian.PutUint16(s.send[0:2], opDATA) for { l, err := io.ReadFull(r, s.send[4:]) n += int64(l) @@ -47,13 +69,59 @@ func (s *sender) ReadFrom(r io.Reader) (n int64, err error) { s.abort(err) return n, err } - if l < blockLength { + if l < len(s.send)-4 { return n, nil } s.block++ } } +func (s *sender) sendOptions() error { + for name, value := range s.opts { + if name == "blksize" { + err := s.setBlockSize(value) + if err != nil { + delete(s.opts, name) + continue + } + } else if name == "tsize" { + if value != "0" { + s.opts["tsize"] = value + } else { + delete(s.opts, name) + continue + } + } else { + delete(s.opts, name) + } + } + if len(s.opts) > 0 { + m := packOACK(s.send, s.opts) + s.block = 0 + _, err := s.sendWithRetry(m) + if err != nil { + return err + } + s.block = 1 + } + return nil +} + +func (s *sender) setBlockSize(blksize string) error { + n, err := strconv.Atoi(blksize) + if err != nil { + return err + } + if n < 512 { + return fmt.Errorf("blkzise too small: %d", n) + } + if n > 65464 { + return fmt.Errorf("blksize too large: %d", n) + } + s.send = make([]byte, n+4) + return nil +} + func (s *sender) sendWithRetry(l int) (*net.UDPAddr, error) { s.retry.Reset() for { @@ -86,10 +154,28 @@ func (s *sender) sendDatagram(l int) (*net.UDPAddr, error) { } switch p := p.(type) { case pACK: - block := p.block() - if s.block == block { + if p.block() == s.block { return addr, nil } + case pOACK: + opts, err := unpackOACK(p) + if false && s.block != 1 { + continue + } + if err != nil { + s.abort(err) + return addr, err + } + for name, value := range opts { + if name == "blksize" { + err := s.setBlockSize(value) + if err != nil { + continue + } + } + } + s.block = 0 + return addr, nil case pERROR: return nil, fmt.Errorf("sending block %d: code=%d, error: %s", s.block, p.code(), p.message()) diff --git a/server.go b/server.go index 7eff74a..b695990 100644 --- a/server.go +++ b/server.go @@ -1,7 +1,6 @@ package tftp import ( - "encoding/binary" "fmt" "io" "net" @@ -82,11 +81,11 @@ func (s *Server) processRequest(conn *net.UDPConn) error { } switch p := p.(type) { case pWRQ: - filename, mode, err := unpackRQ(p) + filename, mode, opts, err := unpackRQ(p) if err != nil { return fmt.Errorf("unpack WRQ: %v", err) } - //fmt.Printf("got WRQ (filename=%s, mode=%s)\n", filename, mode) + //fmt.Printf("got WRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts) transmissionConn, err := transmissionConn() if err != nil { return fmt.Errorf("open transmission: %v", err) @@ -99,8 +98,8 @@ func (s *Server) processRequest(conn *net.UDPConn) error { timeout: s.timeout, addr: remoteAddr, mode: mode, + opts: opts, } - binary.BigEndian.PutUint16(wt.send[0:2], opACK) s.wg.Add(1) go func() { err := s.writeHandler(filename, wt) @@ -112,11 +111,11 @@ func (s *Server) processRequest(conn *net.UDPConn) error { s.wg.Done() }() case pRRQ: - filename, mode, err := unpackRQ(p) + filename, mode, opts, err := unpackRQ(p) if err != nil { return fmt.Errorf("unpack RRQ: %v", err) } - //fmt.Printf("got RRQ (filename=%s, mode=%s)\n", filename, mode) + //fmt.Printf("got RRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts) transmissionConn, err := transmissionConn() if err != nil { return fmt.Errorf("open transmission: %v", err) @@ -130,8 +129,8 @@ func (s *Server) processRequest(conn *net.UDPConn) error { addr: remoteAddr, block: 1, mode: mode, + opts: opts, } - binary.BigEndian.PutUint16(rf.send[0:2], opDATA) s.wg.Add(1) go func() { err := s.readHandler(filename, rf) diff --git a/tftp_test.go b/tftp_test.go index 11e467f..308c6c0 100644 --- a/tftp_test.go +++ b/tftp_test.go @@ -16,18 +16,27 @@ import ( func TestPackUnpack(t *testing.T) { v := []string{"test-filename/with-subdir"} + testOptsList := []options{ + nil, + options{ + "tsize": "1234", + "blksize": "22", + }, + } for _, filename := range v { for _, mode := range []string{"octet", "netascii"} { - packUnpack(t, filename, mode) + for _, opts := range testOptsList { + packUnpack(t, filename, mode, opts) + } } } } -func packUnpack(t *testing.T, filename, mode string) { +func packUnpack(t *testing.T, filename, mode string, opts options) { b := make([]byte, datagramLength) for _, op := range []uint16{opRRQ, opWRQ} { - n := packRQ(b, op, filename, mode) - f, m, err := unpackRQ(b[:n]) + n := packRQ(b, op, filename, mode, opts) + f, m, o, err := unpackRQ(b[:n]) if err != nil { t.Errorf("%s pack/unpack: %v", filename, err) } @@ -39,6 +48,17 @@ func packUnpack(t *testing.T, filename, mode string) { t.Errorf("mode mismatch (%s): '%x' vs '%x'", mode, m, mode) } + if opts != nil { + for name, value := range opts { + v, ok := o[name] + if !ok { + t.Errorf("missing %s option", name) + } + if v != value { + t.Errorf("option %s mismatch: '%x' vs '%x'", name, v, value) + } + } + } } } @@ -48,6 +68,29 @@ func TestZeroLength(t *testing.T) { testSendReceive(t, c, 0) } +func Test900(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + for i := 600; i < 4000; i += 10 { + c.blksize = i + testSendReceive(t, c, 9000+int64(i)) + } +} + +func Test1810(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + c.blksize = 1810 + testSendReceive(t, c, 9000+1810) +} + +func TestTSize(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + c.tsize = true + testSendReceive(t, c, 640) +} + func TestNearBlockLength(t *testing.T) { s, c := makeTestServer() defer s.Shutdown() @@ -150,6 +193,9 @@ func testSendReceive(t *testing.T, client *Client, length int64) { if err != nil { t.Fatalf("requesting write %s: %v", filename, err) } + if ot, ok := writeTransfer.(OutgoingTransfer); ok { + ot.SetSize(length) + } r := io.LimitReader(newRandReader(rand.NewSource(42)), length) n, err := writeTransfer.ReadFrom(r) if err != nil { @@ -162,6 +208,14 @@ func testSendReceive(t *testing.T, client *Client, length int64) { if err != nil { t.Fatalf("requesting read %s: %v", filename, err) } + if it, ok := readTransfer.(IncomingTransfer); ok { + if n, ok := it.Size(); ok { + fmt.Printf("Transfer size: %d\n", n) + if n != length { + t.Errorf("tsize mismatch: %d vs %d", n, length) + } + } + } buf := &bytes.Buffer{} n, err = readTransfer.WriteTo(buf) if err != nil { @@ -217,6 +271,11 @@ func (b *testBackend) handleWrite(filename string, wt io.WriterTo) error { fmt.Fprintf(os.Stderr, "File %s already exists\n", filename) return fmt.Errorf("file already exists") } + if t, ok := wt.(IncomingTransfer); ok { + if n, ok := t.Size(); ok { + fmt.Printf("Transfer size: %d\n", n) + } + } buf := &bytes.Buffer{} n, err := wt.WriteTo(buf) if err != nil { @@ -236,6 +295,9 @@ func (b *testBackend) handleRead(filename string, rf io.ReaderFrom) error { fmt.Fprintf(os.Stderr, "File %s not found\n", filename) return fmt.Errorf("file not found") } + if t, ok := rf.(OutgoingTransfer); ok { + t.SetSize(int64(len(bs))) + } n, err := rf.ReadFrom(bytes.NewBuffer(bs)) if err != nil { fmt.Fprintf(os.Stderr, "Can't send %s: %v\n", filename, err)