Skip to content

Commit

Permalink
fix: add support for attrs in sshFxpOpenPacket for Server
Browse files Browse the repository at this point in the history
  • Loading branch information
mafredri committed Dec 12, 2023
1 parent 5bdc2b0 commit c771381
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 26 deletions.
17 changes: 13 additions & 4 deletions packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -681,12 +681,13 @@ type sshFxpOpenPacket struct {
ID uint32
Path string
Pflags uint32
Flags uint32 // ignored
Flags uint32
Attrs interface{}
}

func (p *sshFxpOpenPacket) id() uint32 { return p.ID }

func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) {
l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
4 + len(p.Path) +
4 + 4
Expand All @@ -698,7 +699,14 @@ func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
b = marshalUint32(b, p.Pflags)
b = marshalUint32(b, p.Flags)

return b, nil
payload := marshal(nil, p.Attrs)

return b, payload, nil
}

func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
header, payload, err := p.marshalPacket()
return append(header, payload...), err
}

func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
Expand All @@ -709,9 +717,10 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
return err
} else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil {
return err
} else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil {
} else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil {
return err
}
p.Attrs = b
return nil
}

Expand Down
18 changes: 18 additions & 0 deletions packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,24 @@ func TestSendPacket(t *testing.T) {
0x0, 0x0, 0x0, 0x0,
},
},
{
packet: &sshFxpOpenPacket{
ID: 3,
Path: "/foo",
Pflags: flags(os.O_WRONLY | os.O_CREATE | os.O_TRUNC),
Flags: sshFileXferAttrPermissions,
Attrs: []uint8{0x0, 0x0, 0x1, 0xed}, // 0o755
},
want: []byte{
0x0, 0x0, 0x0, 0x19,
0x3,
0x0, 0x0, 0x0, 0x3,
0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o',
0x0, 0x0, 0x0, 0x1a,
0x0, 0x0, 0x0, 0x4,
0x0, 0x0, 0x1, 0xed,
},
},
{
packet: &sshFxpWritePacket{
ID: 124,
Expand Down
113 changes: 91 additions & 22 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,59 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
return statusFromError(p.ID, err)
}

// Both `sshFileXferAttrPermissions` and `sshFileXferAttrACmodTime` are set
// by e.g. `sftp`. Just in case, we handle all other cases as well.
if b, ok := p.Attrs.([]byte); ok {
if (p.Flags & sshFileXferAttrSize) != 0 {
var size uint64
if size, b, err = unmarshalUint64Safe(b); err == nil {
err = f.Truncate(int64(size))
}
}
if err != nil {
_ = f.Close()
return statusFromError(p.ID, err)
}
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, b, err = unmarshalUint32Safe(b); err != nil {
} else {
err = f.Chown(int(uid), int(gid))
}
}
if err != nil {
_ = f.Close()
return statusFromError(p.ID, err)
}
if (p.Flags & sshFileXferAttrPermissions) != 0 {
var mode uint32
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = f.Chmod(os.FileMode(mode))
}
}
if err != nil {
_ = f.Close()
return statusFromError(p.ID, err)
}
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
var atime uint32
var mtime uint32
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, _, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(f.Name(), atimeT, mtimeT)
}
}
if err != nil {
_ = f.Close()
return statusFromError(p.ID, err)
}
}

handle := svr.nextHandle(f)
return &sshFxpHandlePacket{ID: p.ID, Handle: handle}
}
Expand Down Expand Up @@ -509,33 +562,41 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {
err = os.Truncate(p.Path, int64(size))
}
}
if err != nil {
return statusFromError(p.ID, err)
}
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, b, err = unmarshalUint32Safe(b); err != nil {
} else {
err = os.Chown(p.Path, int(uid), int(gid))
}
}
if err != nil {
return statusFromError(p.ID, err)
}
if (p.Flags & sshFileXferAttrPermissions) != 0 {
var mode uint32
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = os.Chmod(p.Path, os.FileMode(mode))
}
}
if err != nil {
return statusFromError(p.ID, err)
}
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
var atime uint32
var mtime uint32
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, _, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(p.Path, atimeT, mtimeT)
}
}
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, _, err = unmarshalUint32Safe(b); err != nil {
} else {
err = os.Chown(p.Path, int(uid), int(gid))
}
}

return statusFromError(p.ID, err)
}

Expand All @@ -556,33 +617,41 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
err = f.Truncate(int64(size))
}
}
if err != nil {
return statusFromError(p.ID, err)
}
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, b, err = unmarshalUint32Safe(b); err != nil {
} else {
err = f.Chown(int(uid), int(gid))
}
}
if err != nil {
return statusFromError(p.ID, err)
}
if (p.Flags & sshFileXferAttrPermissions) != 0 {
var mode uint32
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = f.Chmod(os.FileMode(mode))
}
}
if err != nil {
return statusFromError(p.ID, err)
}
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
var atime uint32
var mtime uint32
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, _, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(f.Name(), atimeT, mtimeT)
}
}
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, _, err = unmarshalUint32Safe(b); err != nil {
} else {
err = f.Chown(int(uid), int(gid))
}
}

return statusFromError(p.ID, err)
}

Expand Down
26 changes: 26 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,32 @@ func TestOpenStatRace(t *testing.T) {
checkServerAllocator(t, server)
}

func TestOpenWithPermissions(t *testing.T) {
client, server := clientServerPair(t)
defer client.Close()
defer server.Close()

tmppath := path.Join(os.TempDir(), "open_permissions")
pflags := flags(os.O_RDWR | os.O_CREATE | os.O_TRUNC)
ch := make(chan result, 2)
id1 := client.nextID()
client.dispatchRequest(ch, &sshFxpOpenPacket{
ID: id1,
Path: tmppath,
Pflags: pflags,
Flags: sshFileXferAttrPermissions,
Attrs: []byte{0x0, 0x0, 0x1, 0xe5}, // 0o745 -- a slightly strange permission to test.
})
<-ch
stat, err := os.Stat(tmppath)
assert.NoError(t, err)
if !assert.Equal(t, os.FileMode(0o745), stat.Mode()&os.ModePerm) {
t.Logf("stat.Mode() = %v", stat.Mode())
}
os.Remove(tmppath)
checkServerAllocator(t, server)
}

// Ensure that proper error codes are returned for non existent files, such
// that they are mapped back to a 'not exists' error on the client side.
func TestStatNonExistent(t *testing.T) {
Expand Down

0 comments on commit c771381

Please sign in to comment.