Skip to content

Commit

Permalink
blksize and tsize support
Browse files Browse the repository at this point in the history
  • Loading branch information
pin committed Apr 5, 2016
1 parent 5d4f062 commit 6a2c3cd
Show file tree
Hide file tree
Showing 7 changed files with 366 additions and 44 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ func writeHanlder(filename string, w io.WriterTo) error {
fmt.Fprintf(os.Stderr, "%v\n", err)
return err
}
// In case client provides tsize option.
if t, ok := wt.(tftp.IncomingTransfer); ok {
if n, ok := t.Size(); ok {
fmt.Printf("Transfer size: %d\n", n)
}
}
n, err := w.WriteTo(file)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
Expand All @@ -25,12 +31,20 @@ func writeHanlder(filename string, w io.WriterTo) error {
fmt.Printf("%d bytes received\n", n)
return nil
}

func readHandler(filename string, r io.ReaderFrom) error {
file, err := os.Open(filename)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
return err
}
// Optional tsize support.
// Set transfer size before calling ReadFrom.
if t, ok := rf.(tftp.OutgoingTransfer); ok {
if fi, err := file.Stat(); err == nil {
t.SetSize(fi.Size())
}
}
n, err := r.ReadFrom(file)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
Expand Down Expand Up @@ -69,6 +83,10 @@ r, err := c.Send("foobar.txt", "octet")
if err != nil {
...
}
// Optional tsize.
if ot, ok := r.(tftp.OutgoingTransfer); ok {
ot.SetSize(length)
}
n, err := r.ReadFrom(file)
fmt.Printf("%d bytes sent\n", n)
```
Expand All @@ -88,6 +106,12 @@ file, err := os.Create(path)
if err != nil {
...
}
// Optional tsize.
if it, ok := readTransfer.(IncomingTransfer); ok {
if n, ok := it.Size(); ok {
fmt.Printf("Transfer size: %d\n", n)
}
}
n, err := w.WriteTo(file)
if err != nil {
...
Expand Down
24 changes: 19 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package tftp

import (
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
"time"
)

Expand All @@ -28,6 +28,8 @@ type Client struct {
addr *net.UDPAddr
retry Retry
timeout time.Duration
blksize int
tsize bool
}

func (c Client) Send(filename string, mode string) (io.ReaderFrom, error) {
Expand All @@ -44,14 +46,18 @@ func (c Client) Send(filename string, mode string) (io.ReaderFrom, error) {
addr: c.addr,
mode: mode,
}
n := packRQ(s.send, opWRQ, filename, mode)
if c.blksize != 0 {
s.opts = make(options)
s.opts["blksize"] = strconv.Itoa(c.blksize)
}
n := packRQ(s.send, opWRQ, filename, mode, s.opts)
addr, err := s.sendWithRetry(n)
if err != nil {
return nil, err // wrap error
}
s.block++
s.addr = addr
binary.BigEndian.PutUint16(s.send[0:2], opDATA)
s.opts = nil
return s, nil
}

Expand All @@ -73,14 +79,22 @@ func (c Client) Receive(filename string, mode string) (io.WriterTo, error) {
autoTerm: true,
mode: mode,
}
n := packRQ(r.send, opRRQ, filename, mode)
if c.blksize != 0 || c.tsize {
r.opts = make(options)
}
if c.blksize != 0 {
r.opts["blksize"] = strconv.Itoa(c.blksize)
}
if c.tsize {
r.opts["tsize"] = "0"
}
n := packRQ(r.send, opRRQ, filename, mode, r.opts)
r.block++
l, addr, err := r.receiveWithRetry(n)
if err != nil {
return nil, err
}
r.l = l
r.addr = addr
binary.BigEndian.PutUint16(r.send[0:2], opACK)
return r, nil
}
87 changes: 68 additions & 19 deletions packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"encoding/binary"
"fmt"
"strings"
)

const (
Expand All @@ -21,6 +20,8 @@ const (
datagramLength = 516
)

type options map[string]string

// RRQ/WRQ packet
//
// 2 bytes string 1 byte string 1 byte
Expand All @@ -31,28 +32,71 @@ type pRRQ []byte
type pWRQ []byte

// packRQ returns length of the packet in b
func packRQ(b []byte, op uint16, filename, mode string) int {
binary.BigEndian.PutUint16(b, op)
n := copy(b[2:len(b)-10], filename)
b[2+n] = 0
m := copy(b[3+n:len(b)-2], mode)
b[2+n+1+m] = 0
return 2 + n + 1 + m + 1
func packRQ(p []byte, op uint16, filename, mode string, opts options) int {
binary.BigEndian.PutUint16(p, op)
n := 2
n += copy(p[2:len(p)-10], filename)
p[n] = 0
n++
n += copy(p[n:], mode)
p[n] = 0
n++
for name, value := range opts {
n += copy(p[n:], name)
p[n] = 0
n++
n += copy(p[n:], value)
p[n] = 0
n++
}
return n
}

func unpackRQ(p []byte) (filename, mode string, opts options, err error) {
bs := bytes.Split(p[2:], []byte{0})
if len(bs) < 2 {
return "", "", nil, fmt.Errorf("missing filename or mode")
}
filename = string(bs[0])
mode = string(bs[1])
if len(bs) < 4 {
return filename, mode, nil, nil
}
opts = make(options)
for i := 2; i+1 < len(bs); i += 2 {
opts[string(bs[i])] = string(bs[i+1])
}
return filename, mode, opts, nil
}

func unpackRQ(p []byte) (filename, mode string, err error) {
buffer := bytes.NewBuffer(p[2:])
s, err := buffer.ReadString(0x0)
if err != nil {
return s, "", err
// OACK packet
//
// +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
// | Opcode | opt1 | 0 | value1 | 0 | optN | 0 | valueN | 0 |
// +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
type pOACK []byte

func packOACK(p []byte, opts options) int {
binary.BigEndian.PutUint16(p, opOACK)
n := 2
for name, value := range opts {
n += copy(p[n:], name)
p[n] = 0
n++
n += copy(p[n:], value)
p[n] = 0
n++
}
filename = strings.TrimSpace(strings.Trim(s, "\x00"))
s, err = buffer.ReadString(0x0)
if err != nil {
return filename, s, err
return n
}

func unpackOACK(p []byte) (opts options, err error) {
bs := bytes.Split(p[2:], []byte{0})
opts = make(options)
for i := 0; i+1 < len(bs); i += 2 {
opts[string(bs[i])] = string(bs[i+1])
}
mode = strings.TrimSpace(strings.Trim(s, "\x00"))
return filename, mode, nil
return opts, nil
}

// ERROR packet
Expand Down Expand Up @@ -135,6 +179,11 @@ func parsePacket(p []byte) (interface{}, error) {
return nil, fmt.Errorf("short ERROR packet: %d", l)
}
return pERROR(p), nil
case opOACK:
if l < 6 {
return nil, fmt.Errorf("short OACK packet: %d", l)
}
return pOACK(p), nil
default:
return nil, fmt.Errorf("unknown opcode: %d", opcode)
}
Expand Down
100 changes: 94 additions & 6 deletions receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,29 @@ import (
"fmt"
"io"
"net"
"strconv"
"time"

"github.com/pin/tftp/netascii"
)

type IncomingTransfer interface {
Size() (n int64, ok bool)
}

func (r *receiver) Size() (n int64, ok bool) {
if r.opts != nil {
if s, ok := r.opts["tsize"]; ok {
n, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0, false
}
return n, true
}
}
return 0, false
}

type receiver struct {
send []byte
receive []byte
Expand All @@ -22,12 +40,21 @@ type receiver struct {
autoTerm bool
dally bool
mode string
opts options
}

func (r *receiver) WriteTo(w io.Writer) (n int64, err error) {
if r.mode == "netascii" {
w = netascii.FromWriter(w)
}
if r.opts != nil {
err := r.sendOptions()
if err != nil {
r.abort(err)
return 0, err
}
}
binary.BigEndian.PutUint16(r.send[0:2], opACK)
for {
if r.l > 0 {
l, err := w.Write(r.receive[4:r.l])
Expand All @@ -36,7 +63,7 @@ func (r *receiver) WriteTo(w io.Writer) (n int64, err error) {
r.abort(err)
return n, err
}
if r.l-4 < blockLength {
if r.l < len(r.receive) {
if r.autoTerm {
r.terminate()
}
Expand All @@ -54,12 +81,53 @@ func (r *receiver) WriteTo(w io.Writer) (n int64, err error) {
}
}

func (s *receiver) receiveWithRetry(l int) (int, *net.UDPAddr, error) {
s.retry.Reset()
func (r *receiver) sendOptions() error {
for name, value := range r.opts {
if name == "blksize" {
err := r.setBlockSize(value)
if err != nil {
delete(r.opts, name)
continue
}
} else {
delete(r.opts, name)
}
}
if len(r.opts) > 0 {
m := packOACK(r.send, r.opts)
r.block = 1
ll, _, err := r.receiveWithRetry(m)
if err != nil {
r.abort(err)
return err
}
r.block = 1
r.l = ll
}
return nil
}

func (r *receiver) setBlockSize(blksize string) error {
n, err := strconv.Atoi(blksize)
if err != nil {
return err
}
if n < 512 {
return fmt.Errorf("blkzise too small: %d", n)
}
if n > 65464 {
return fmt.Errorf("blksize too large: %d", n)
}
r.receive = make([]byte, n+4)
return nil
}

func (r *receiver) receiveWithRetry(l int) (int, *net.UDPAddr, error) {
r.retry.Reset()
for {
n, addr, err := s.receiveDatagram(l)
if _, ok := err.(net.Error); ok && s.retry.Count() < 3 {
s.retry.Backoff()
n, addr, err := r.receiveDatagram(l)
if _, ok := err.(net.Error); ok && r.retry.Count() < 3 {
r.retry.Backoff()
continue
}
return n, addr, err
Expand Down Expand Up @@ -91,6 +159,26 @@ func (r *receiver) receiveDatagram(l int) (int, *net.UDPAddr, error) {
if p.block() == r.block {
return c, addr, nil
}
case pOACK:
opts, err := unpackOACK(p)
if r.block != 1 {
continue
}
if err != nil {
r.abort(err)
return 0, addr, err
}
for name, value := range opts {
if name == "blksize" {
err := r.setBlockSize(value)
if err != nil {
continue
}
}
}
r.block = 0
r.opts = opts
return 0, addr, nil
case pERROR:
return 0, addr, fmt.Errorf("code: %d, message: %s",
p.code(), p.message())
Expand Down
Loading

0 comments on commit 6a2c3cd

Please sign in to comment.