Skip to content

Commit b2a9109

Browse files
authored
Replace buf.Read with io.ReadFull to prevent partial read failures (#63)
1 parent 96960d9 commit b2a9109

File tree

1 file changed

+10
-36
lines changed

1 file changed

+10
-36
lines changed

subtree.go

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func DeserializeNodesFromReader(reader io.Reader) (subtreeBytes []byte, err erro
135135
// third 8 bytes, number of leaves
136136
// total read at once = len(st.rootHash[:]) + 8 + 8 + 8
137137
byteBuffer := make([]byte, chainhash.HashSize+24)
138-
if _, err = ReadBytes(buf, byteBuffer); err != nil {
138+
if _, err = io.ReadFull(buf, byteBuffer); err != nil {
139139
return nil, fmt.Errorf("unable to read subtree root information: %w", err)
140140
}
141141

@@ -144,7 +144,7 @@ func DeserializeNodesFromReader(reader io.Reader) (subtreeBytes []byte, err erro
144144

145145
byteBuffer = byteBuffer[8:] // reduce read byteBuffer size by 8
146146
for i := uint64(0); i < numLeaves; i++ {
147-
if _, err = ReadBytes(buf, byteBuffer); err != nil {
147+
if _, err = io.ReadFull(buf, byteBuffer); err != nil {
148148
return nil, fmt.Errorf("unable to read subtree node information: %w", err)
149149
}
150150

@@ -669,29 +669,23 @@ func (st *Subtree) DeserializeFromReader(reader io.Reader) (err error) {
669669

670670
buf := bufio.NewReaderSize(reader, 32*1024) // 32KB buffer
671671

672-
var (
673-
n int
674-
bytes8 = make([]byte, 8)
675-
)
672+
bytes8 := make([]byte, 8)
676673

677674
// read root hash
678675
st.rootHash = new(chainhash.Hash)
679-
if n, err = buf.Read(st.rootHash[:]); err != nil || n != chainhash.HashSize {
680-
// if _, err = io.ReadFull(buf, st.rootHash[:]); err != nil {
676+
if _, err = io.ReadFull(buf, st.rootHash[:]); err != nil {
681677
return fmt.Errorf("unable to read root hash: %w", err)
682678
}
683679

684680
// read fees
685-
if n, err = buf.Read(bytes8); err != nil || n != 8 {
686-
// if _, err = io.ReadFull(buf, bytes8); err != nil {
681+
if _, err = io.ReadFull(buf, bytes8); err != nil {
687682
return fmt.Errorf("unable to read fees: %w", err)
688683
}
689684

690685
st.Fees = binary.LittleEndian.Uint64(bytes8)
691686

692687
// read sizeInBytes
693-
if n, err = buf.Read(bytes8); err != nil || n != 8 {
694-
// if _, err = io.ReadFull(buf, bytes8); err != nil {
688+
if _, err = io.ReadFull(buf, bytes8); err != nil {
695689
return fmt.Errorf("unable to read sizeInBytes: %w", err)
696690
}
697691

@@ -713,8 +707,7 @@ func (st *Subtree) deserializeNodes(buf *bufio.Reader) error {
713707
bytes8 := make([]byte, 8)
714708

715709
// read number of leaves
716-
if n, err := buf.Read(bytes8); err != nil || n != 8 {
717-
// if _, err = io.ReadFull(buf, bytes8); err != nil {
710+
if _, err := io.ReadFull(buf, bytes8); err != nil {
718711
return fmt.Errorf("unable to read number of leaves: %w", err)
719712
}
720713

@@ -730,8 +723,7 @@ func (st *Subtree) deserializeNodes(buf *bufio.Reader) error {
730723
bytes48 := make([]byte, 48)
731724
for i := uint64(0); i < numLeaves; i++ {
732725
// read all the node data in 1 go
733-
if n, err := ReadBytes(buf, bytes48); err != nil || n != 48 {
734-
// if _, err = io.ReadFull(buf, bytes48); err != nil {
726+
if _, err := io.ReadFull(buf, bytes48); err != nil {
735727
return fmt.Errorf("unable to read node: %w", err)
736728
}
737729

@@ -748,8 +740,7 @@ func (st *Subtree) deserializeConflictingNodes(buf *bufio.Reader) error {
748740
bytes8 := make([]byte, 8)
749741

750742
// read the number of conflicting nodes
751-
if n, err := buf.Read(bytes8); err != nil || n != 8 {
752-
// if _, err = io.ReadFull(buf, bytes8); err != nil {
743+
if _, err := io.ReadFull(buf, bytes8); err != nil {
753744
return fmt.Errorf("unable to read number of conflicting nodes: %w", err)
754745
}
755746

@@ -759,31 +750,14 @@ func (st *Subtree) deserializeConflictingNodes(buf *bufio.Reader) error {
759750
st.ConflictingNodes = make([]chainhash.Hash, numConflictingLeaves)
760751

761752
for i := uint64(0); i < numConflictingLeaves; i++ {
762-
if n, err := buf.Read(st.ConflictingNodes[i][:]); err != nil || n != 32 {
753+
if _, err := io.ReadFull(buf, st.ConflictingNodes[i][:]); err != nil {
763754
return fmt.Errorf("unable to read conflicting node %d: %w", i, err)
764755
}
765756
}
766757

767758
return nil
768759
}
769760

770-
// ReadBytes reads bytes from the buffered reader into the provided byte slice.
771-
func ReadBytes(buf *bufio.Reader, p []byte) (n int, err error) {
772-
minRead := len(p)
773-
for n < minRead && err == nil {
774-
p[n], err = buf.ReadByte()
775-
n++
776-
}
777-
778-
if n >= minRead {
779-
err = nil
780-
} else if n > 0 && err == io.EOF {
781-
err = io.ErrUnexpectedEOF
782-
}
783-
784-
return n, err
785-
}
786-
787761
// DeserializeSubtreeConflictingFromReader deserializes the conflicting nodes from the provided reader.
788762
func DeserializeSubtreeConflictingFromReader(reader io.Reader) (conflictingNodes []chainhash.Hash, err error) {
789763
defer func() {

0 commit comments

Comments
 (0)