Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve todos in the codebase #90

Merged
merged 9 commits into from
Oct 31, 2024
18 changes: 1 addition & 17 deletions rvgo/fast/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ func (m *Memory) pageLookup(pageIndex uint64) (*CachedPage, bool) {
return p, ok
}

// TODO: we never do unaligned writes, this should be simplified
func (m *Memory) SetUnaligned(addr uint64, dat []byte) {
func (m *Memory) SetAligned(addr uint64, dat []byte) {
if len(dat) > 32 {
panic("cannot set more than 32 bytes")
}
Expand All @@ -200,21 +199,6 @@ func (m *Memory) SetUnaligned(addr uint64, dat []byte) {
if d == len(dat) {
return // if all the data fitted in the page, we're done
}
mininny marked this conversation as resolved.
Show resolved Hide resolved

// continue to remaining part
mininny marked this conversation as resolved.
Show resolved Hide resolved
addr += uint64(d)
pageIndex = addr >> PageAddrSize
pageAddr = addr & PageAddrMask
p, ok = m.pageLookup(pageIndex)
if !ok {
// allocate the page if we have not already.
// Go may mmap relatively large ranges, but we only allocate the pages just in time.
p = m.AllocPage(pageIndex)
} else {
m.Invalidate(addr) // invalidate this branch of memory, now that the value changed
}

copy(p.Data[pageAddr:], dat)
}

func (m *Memory) GetUnaligned(addr uint64, dest []byte) {
Expand Down
44 changes: 22 additions & 22 deletions rvgo/fast/memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
func TestMemoryMerkleProof(t *testing.T) {
t.Run("nearly empty tree", func(t *testing.T) {
m := NewMemory()
m.SetUnaligned(0x10000, []byte{0xaa, 0xbb, 0xcc, 0xdd})
m.SetAligned(0x10000, []byte{0xaa, 0xbb, 0xcc, 0xdd})
proof := m.MerkleProof(0x10000)
require.Equal(t, uint32(0xaabbccdd), binary.BigEndian.Uint32(proof[:4]))
for i := 0; i < 32-5; i++ {
Expand All @@ -24,9 +24,9 @@ func TestMemoryMerkleProof(t *testing.T) {
})
t.Run("fuller tree", func(t *testing.T) {
m := NewMemory()
m.SetUnaligned(0x10000, []byte{0xaa, 0xbb, 0xcc, 0xdd})
m.SetUnaligned(0x80004, []byte{42})
m.SetUnaligned(0x13370000, []byte{123})
m.SetAligned(0x10000, []byte{0xaa, 0xbb, 0xcc, 0xdd})
m.SetAligned(0x80004, []byte{42})
m.SetAligned(0x13370000, []byte{123})
root := m.MerkleRoot()
proof := m.MerkleProof(0x80004)
require.Equal(t, uint32(42<<24), binary.BigEndian.Uint32(proof[4:8]))
Expand All @@ -53,35 +53,35 @@ func TestMemoryMerkleRoot(t *testing.T) {
})
t.Run("empty page", func(t *testing.T) {
m := NewMemory()
m.SetUnaligned(0xF000, []byte{0})
m.SetAligned(0xF000, []byte{0})
root := m.MerkleRoot()
require.Equal(t, zeroHashes[64-5], root, "fully zeroed memory should have expected zero hash")
})
t.Run("single page", func(t *testing.T) {
m := NewMemory()
m.SetUnaligned(0xF000, []byte{1})
m.SetAligned(0xF000, []byte{1})
root := m.MerkleRoot()
require.NotEqual(t, zeroHashes[64-5], root, "non-zero memory")
})
t.Run("repeat zero", func(t *testing.T) {
m := NewMemory()
m.SetUnaligned(0xF000, []byte{0})
m.SetUnaligned(0xF004, []byte{0})
m.SetAligned(0xF000, []byte{0})
m.SetAligned(0xF004, []byte{0})
root := m.MerkleRoot()
require.Equal(t, zeroHashes[64-5], root, "zero still")
})
t.Run("two empty pages", func(t *testing.T) {
m := NewMemory()
m.SetUnaligned(PageSize*3, []byte{0})
m.SetUnaligned(PageSize*10, []byte{0})
m.SetAligned(PageSize*3, []byte{0})
m.SetAligned(PageSize*10, []byte{0})
root := m.MerkleRoot()
require.Equal(t, zeroHashes[64-5], root, "zero still")
})
t.Run("random few pages", func(t *testing.T) {
m := NewMemory()
m.SetUnaligned(PageSize*3, []byte{1})
m.SetUnaligned(PageSize*5, []byte{42})
m.SetUnaligned(PageSize*6, []byte{123})
m.SetAligned(PageSize*3, []byte{1})
m.SetAligned(PageSize*5, []byte{42})
m.SetAligned(PageSize*6, []byte{123})
p3 := m.MerkleizeSubtree((1 << PageKeySize) | 3)
p5 := m.MerkleizeSubtree((1 << PageKeySize) | 5)
p6 := m.MerkleizeSubtree((1 << PageKeySize) | 6)
Expand All @@ -101,11 +101,11 @@ func TestMemoryMerkleRoot(t *testing.T) {
})
t.Run("invalidate page", func(t *testing.T) {
m := NewMemory()
m.SetUnaligned(0xF000, []byte{0})
m.SetAligned(0xF000, []byte{0})
require.Equal(t, zeroHashes[64-5], m.MerkleRoot(), "zero at first")
m.SetUnaligned(0xF004, []byte{1})
m.SetAligned(0xF004, []byte{1})
require.NotEqual(t, zeroHashes[64-5], m.MerkleRoot(), "non-zero")
m.SetUnaligned(0xF004, []byte{0})
m.SetAligned(0xF004, []byte{0})
require.Equal(t, zeroHashes[64-5], m.MerkleRoot(), "zero again")
})
}
Expand Down Expand Up @@ -141,30 +141,30 @@ func TestMemoryReadWrite(t *testing.T) {

t.Run("read-write", func(t *testing.T) {
m := NewMemory()
m.SetUnaligned(12, []byte{0xAA, 0xBB, 0xCC, 0xDD, 0xEE})
m.SetAligned(12, []byte{0xAA, 0xBB, 0xCC, 0xDD, 0xEE})
var tmp [5]byte
m.GetUnaligned(12, tmp[:])
require.Equal(t, [5]byte{0xAA, 0xBB, 0xCC, 0xDD, 0xEE}, tmp)
m.SetUnaligned(12, []byte{0xAA, 0xBB, 0x1C, 0xDD, 0xEE})
m.SetAligned(12, []byte{0xAA, 0xBB, 0x1C, 0xDD, 0xEE})
m.GetUnaligned(12, tmp[:])
require.Equal(t, [5]byte{0xAA, 0xBB, 0x1C, 0xDD, 0xEE}, tmp)
})

t.Run("read-write-unaligned", func(t *testing.T) {
m := NewMemory()
m.SetUnaligned(13, []byte{0xAA, 0xBB, 0xCC, 0xDD, 0xEE})
m.SetAligned(13, []byte{0xAA, 0xBB, 0xCC, 0xDD, 0xEE})
var tmp [5]byte
m.GetUnaligned(13, tmp[:])
require.Equal(t, [5]byte{0xAA, 0xBB, 0xCC, 0xDD, 0xEE}, tmp)
m.SetUnaligned(13, []byte{0xAA, 0xBB, 0x1C, 0xDD, 0xEE})
m.SetAligned(13, []byte{0xAA, 0xBB, 0x1C, 0xDD, 0xEE})
m.GetUnaligned(13, tmp[:])
require.Equal(t, [5]byte{0xAA, 0xBB, 0x1C, 0xDD, 0xEE}, tmp)
})
}

func TestMemoryJSON(t *testing.T) {
m := NewMemory()
m.SetUnaligned(8, []byte{123})
m.SetAligned(8, []byte{123})
dat, err := json.Marshal(m)
require.NoError(t, err)
var res Memory
Expand All @@ -176,7 +176,7 @@ func TestMemoryJSON(t *testing.T) {

func TestMemoryBinary(t *testing.T) {
m := NewMemory()
m.SetUnaligned(8, []byte{123})
m.SetAligned(8, []byte{123})
ser := new(bytes.Buffer)
err := m.Serialize(ser)
require.NoError(t, err, "must serialize state")
Expand Down
4 changes: 0 additions & 4 deletions rvgo/fast/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,3 @@ func parseRs2(instr U64) U64 {
func parseFunct7(instr U64) U64 {
return shr64(toU64(25), instr)
}

func parseCSSR(instr U64) U64 {
return shr64(toU64(20), instr)
}
95 changes: 39 additions & 56 deletions rvgo/fast/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) {
panic(fmt.Errorf("addr %d not aligned with 32 bytes", addr))
}
inst.verifyMemChange(addr, proofIndex)
s.Memory.SetUnaligned(addr, v[:])
s.Memory.SetAligned(addr, v[:])
}

// load unaligned, optionally signed, little-endian, integer of 1 ... 8 bytes from memory
Expand Down Expand Up @@ -220,7 +220,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) {
}
inst.verifyMemChange(leftAddr, proofIndexL)
if (addr+size-1)&^31 == addr&^31 { // if aligned
s.Memory.SetUnaligned(addr, bytez[:size])
s.Memory.SetAligned(addr, bytez[:size])
return
}
if proofIndexR == 0xff {
Expand All @@ -229,12 +229,12 @@ func (inst *InstrumentedState) riscvStep() (outErr error) {
// if not aligned
rightAddr := leftAddr + 32
leftSize := rightAddr - addr
s.Memory.SetUnaligned(addr, bytez[:leftSize])
s.Memory.SetAligned(addr, bytez[:leftSize])
if verifyR {
inst.trackMemAccess(rightAddr, proofIndexR)
}
inst.verifyMemChange(rightAddr, proofIndexR)
s.Memory.SetUnaligned(rightAddr, bytez[leftSize:size])
s.Memory.SetAligned(rightAddr, bytez[leftSize:size])
}

storeMem := func(addr U64, size U64, value U64, proofIndexL uint8, proofIndexR uint8, verifyL bool, verifyR bool) {
Expand All @@ -249,7 +249,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) {
}
inst.verifyMemChange(leftAddr, proofIndexL)
if (addr+size-1)&^31 == addr&^31 { // if aligned
s.Memory.SetUnaligned(addr, bytez[:size])
s.Memory.SetAligned(addr, bytez[:size])
return
}
// if not aligned
Expand All @@ -258,40 +258,17 @@ func (inst *InstrumentedState) riscvStep() (outErr error) {
}
rightAddr := leftAddr + 32
leftSize := rightAddr - addr
s.Memory.SetUnaligned(addr, bytez[:leftSize])
s.Memory.SetAligned(addr, bytez[:leftSize])
if verifyR {
inst.trackMemAccess(rightAddr, proofIndexR)
}
inst.verifyMemChange(rightAddr, proofIndexR)
s.Memory.SetUnaligned(rightAddr, bytez[leftSize:size])
s.Memory.SetAligned(rightAddr, bytez[leftSize:size])
}

//
// CSR (control and status registers) functions
//
mininny marked this conversation as resolved.
Show resolved Hide resolved
readCSR := func(num U64) U64 {
// TODO: do we need CSR?
return toU64(0)
}

writeCSR := func(num U64, v U64) {
// TODO: do we need CSR?
}

updateCSR := func(num U64, v U64, mode U64) (out U64) {
out = readCSR(num)
switch mode {
case 1: // ?01 = CSRRW(I)
case 2: // ?10 = CSRRS(I)
v = or64(out, v)
case 3: // ?11 = CSRRC(I)
v = and64(out, not64(v))
default:
revertWithCode(riscv.ErrUnknownCSRMode, fmt.Errorf("unknown CSR mode: %d", mode))
}
writeCSR(num, v)
return
}

//
// Preimage oracle interactions
Expand Down Expand Up @@ -390,28 +367,39 @@ func (inst *InstrumentedState) riscvStep() (outErr error) {
// A1 = n (length)
length := getRegister(toU64(11))
// A2 = prot (memory protection type, can ignore)
// A3 = flags (shared with other process and or written back to file, can ignore) // TODO maybe assert the MAP_ANONYMOUS flag is set
// A3 = flags (shared with other processes and/or written back to file)
flags := getRegister(toU64(13))
// A4 = fd (file descriptor, can ignore because we support anon memory only)
fd := getRegister(toU64(14))
// A5 = offset (offset in file, we don't support any non-anon memory, so we can ignore this)

// ignore: prot, flags, fd, offset
switch addr {
case 0:
// No hint, allocate it ourselves, by as much as the requested length.
// Increase the length to align it with desired page size if necessary.
align := and64(length, shortToU64(4095))
if align != 0 {
length = add64(length, sub64(shortToU64(4096), align))
errCode := toU64(0)

// ensure MAP_ANONYMOUS is set and fd == -1
if (flags&0x20) == 0 || fd != u64Mask() {
addr = u64Mask()
errCode = toU64(0x4d) // EBADFD (0x4d): invalid flags/fd — MAP_ANONYMOUS must be set and fd must be -1
mininny marked this conversation as resolved.
Show resolved Hide resolved
} else {
// ignore: prot, flags, fd, offset
switch addr {
case 0:
// No hint, allocate it ourselves, by as much as the requested length.
// Increase the length to align it with desired page size if necessary.
align := and64(length, shortToU64(4095))
if align != 0 {
length = add64(length, sub64(shortToU64(4096), align))
}
prevHeap := getHeap()
addr = prevHeap
setHeap(add64(prevHeap, length)) // increment heap with length
//fmt.Printf("mmap: 0x%016x (+ 0x%x increase)\n", s.Heap, length)
default:
// allow hinted memory address (leave it in A0 as return argument)
//fmt.Printf("mmap: 0x%016x (0x%x allowed)\n", addr, length)
}
prevHeap := getHeap()
setRegister(toU64(10), prevHeap)
setHeap(add64(prevHeap, length)) // increment heap with length
//fmt.Printf("mmap: 0x%016x (+ 0x%x increase)\n", s.Heap, length)
default:
// allow hinted memory address (leave it in A0 as return argument)
//fmt.Printf("mmap: 0x%016x (0x%x allowed)\n", addr, length)
}
setRegister(toU64(11), toU64(0)) // no error
setRegister(toU64(10), addr)
setRegister(toU64(11), errCode)
case riscv.SysRead: // read
fd := getRegister(toU64(10)) // A0 = fd
addr := getRegister(toU64(11)) // A1 = *buf addr
Expand Down Expand Up @@ -867,14 +855,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) {
setPC(add64(pc, toU64(4))) // ignore breakpoint
}
default: // CSR instructions
imm := parseCSSR(instr)
value := rs1
if iszero64(and64(funct3, toU64(4))) {
value = getRegister(rs1)
}
mode := and64(funct3, toU64(3))
rdValue := updateCSR(imm, value, mode)
setRegister(rd, rdValue)
setRegister(rd, 0) // ignore CSR instructions
setPC(add64(pc, toU64(4)))
}
case 0x2F: // 010_1111: RV32A and RV32A atomic operations extension
Expand All @@ -894,7 +875,9 @@ func (inst *InstrumentedState) riscvStep() (outErr error) {
revertWithCode(riscv.ErrBadAMOSize, fmt.Errorf("bad AMO size: %d", size))
}
addr := getRegister(rs1)
// TODO check if addr is aligned
if addr&3 != 0 { // quick addr alignment check
mininny marked this conversation as resolved.
Show resolved Hide resolved
revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("addr %d not aligned with 4 bytes", addr))
}

op := shr64(toU64(2), funct7)
switch op {
Expand Down
1 change: 0 additions & 1 deletion rvgo/riscv/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ const (
ErrUnexpectedRProofLoad = uint64(0xbad22220)
ErrUnexpectedRProofStoreUnaligned = uint64(0xbad22221)
ErrUnexpectedRProofStore = uint64(0xbad2222f)
ErrUnknownCSRMode = uint64(0xbadc0de0)
ErrBadAMOSize = uint64(0xbada70)
ErrFailToReadPreimage = uint64(0xbadf00d0)
ErrBadMemoryProof = uint64(0xbadf00d1)
Expand Down
4 changes: 2 additions & 2 deletions rvgo/scripts/go-ffi/differential-testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func DiffTestUtils() {
checkErr(err, "Error decoding insn")
instBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(instBytes, uint32(insn))
mem.SetUnaligned(uint64(pc), instBytes)
mem.SetAligned(uint64(pc), instBytes)

// proof size: 64-5+1=60 (a 64-bit mem-address branch to 32 byte leaf, incl leaf itself), all 32 bytes
// 60 * 32 = 1920
Expand All @@ -57,7 +57,7 @@ func DiffTestUtils() {
checkErr(err, "Error decoding memAddr")
memValue, err := hex.DecodeString(strings.TrimPrefix(args[4], "0x"))
checkErr(err, "Error decoding memValue")
mem.SetUnaligned(uint64(memAddr), memValue)
mem.SetAligned(uint64(memAddr), memValue)
memProof = mem.MerkleProof(uint64(memAddr))
}
insnProof = mem.MerkleProof(uint64(pc))
Expand Down
4 changes: 0 additions & 4 deletions rvgo/slow/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,3 @@ func parseRs2(instr U64) U64 {
func parseFunct7(instr U64) U64 {
return shr64(toU64(25), instr)
}

func parseCSSR(instr U64) U64 {
return shr64(toU64(20), instr)
}
Loading
Loading