From 27a2c7ea43a0c28265868af90515db4572cbb23b Mon Sep 17 00:00:00 2001
From: ChinYikMing <yikming2222@gmail.com>
Date: Fri, 27 Dec 2024 04:33:15 +0800
Subject: [PATCH] Handle signals properly

Page faults trigger a trap, which is handled by do_page_fault(). This
function calls lock_mm_and_find_vma() to locate and validate the virtual
memory area (VMA), returning the VMA if valid, or NULL otherwise.
Typically, an attempt to read or write through a NULL pointer results in
a NULL return. If the VMA is invalid, bad_area_nosemaphore() is invoked,
which checks whether the fault originated in kernel or user space.

For user-space faults, a SIGSEGV signal is sent to the user process via
do_trap(), which determines if the signal should be ignored or blocked,
and if not, adds it to the task's pending signal list. Kernel-space
faults cause the kernel to crash via die_kernel_fault().

Before returning to user space (via the resume_userspace label), pending
work (indicated by the _TIF_WORK_MASK mask) is processed by
do_work_pending(). Signals are handled by do_signal(), which in turn
calls handle_signal(). handle_signal() creates a signal handler frame
that will be jumped to upon returning to user space. This frame creation
process might modify the Control and Status Register (CSR) SEPC.
If a signal is pending, the SEPC CSR is overwritten and no longer holds
the original trap/fault PC. This caused an assertion failure in
get_ppn_and_offset() when running the vi program, as reported in [1].

To address this, a variable last_csr_sepc was introduced to store the
original SEPC CSR value before entering the trap path. After returning
to user space, last_csr_sepc is compared with the current SEPC CSR
value. If they differ, the faulting ld/st instruction returns early and
jumps to the signal handler frame.

This commit prevents emulator crashes when the guest OS accesses invalid
memory. Consequently, reads or writes to a NULL value now correctly
result in a segmentation fault. In addition, two user-space programs,
mem_null_read and mem_null_write, are bundled into the rootfs for
verification.

Original behaviour:
1. $ make system ENABLE_SYSTEM=1 -j$(nproc)
2. $ mem_null_read                            # Emulator crashes
3. $ mem_null_write                           # Emulator crashes
4. $ vi                                       # Emulator crashes

Reproduction / testing procedure with this patch:
1. $ make system ENABLE_SYSTEM=1 -j$(nproc)
2. $ mem_null_read       # NULL read causes SIGSEGV without crashing
3. $ mem_null_write      # NULL write causes SIGSEGV without crashing
4. $ vi                  # w/o filename causes SIGSEGV without crashing

[1] #508
---
 src/emulate.c       | 12 ++++++++++--
 src/riscv_private.h |  6 ++++++
 src/system.c        | 37 +++++++++++++++++++++++++++++++------
 3 files changed, 47 insertions(+), 8 deletions(-)

diff --git a/src/emulate.c b/src/emulate.c
index 293ce031e..8279eb749 100644
--- a/src/emulate.c
+++ b/src/emulate.c
@@ -46,6 +46,7 @@ static bool need_clear_block_map = false;
 static uint32_t reloc_enable_mmu_jalr_addr;
 static bool reloc_enable_mmu = false;
 bool need_retranslate = false;
+bool need_handle_signal = false;
 #endif
 
 static void rv_trap_default_handler(riscv_t *rv)
@@ -379,8 +380,12 @@ static uint32_t peripheral_update_ctr = 64;
     {                                                                 \
         IIF(RV32_HAS(SYSTEM))(ctr++;, ) cycle++;                      \
         code;                                                         \
-    nextop:                                                           \
-        PC += __rv_insn_##inst##_len;                                 \
+        IIF(RV32_HAS(SYSTEM))                                         \
+        (                                                             \
+            if (need_handle_signal) {                                 \
+                need_handle_signal = false;                           \
+                return true;                                          \
+            }, ) nextop : PC += __rv_insn_##inst##_len;               \
         IIF(RV32_HAS(SYSTEM))                                         \
         (IIF(RV32_HAS(JIT))(                                          \
              , if (unlikely(need_clear_block_map)) {                  \
@@ -1179,6 +1184,9 @@ static void _trap_handler(riscv_t *rv)
         mode = rv->csr_stvec & 0x3;
         cause = rv->csr_scause;
         rv->csr_sepc = rv->PC;
+#if RV32_HAS(SYSTEM)
+        rv->last_csr_sepc = rv->csr_sepc;
+#endif
     } else { /* machine */
         const uint32_t mstatus_mie =
             (rv->csr_mstatus & MSTATUS_MIE) >> MSTATUS_MIE_SHIFT;
diff --git a/src/riscv_private.h b/src/riscv_private.h
index 0ae6f2797..4e052760c 100644
--- a/src/riscv_private.h
+++ b/src/riscv_private.h
@@ -201,6 +201,12 @@ struct riscv_internal {
 #if RV32_HAS(SYSTEM)
     /* The flag is used to indicate the current emulation is in a trap */
     bool is_trapped;
+
+    /*
+     * Stores the SEPC CSR value at the trap point so that the signal
+     * handler can be executed correctly.
+     */
+    uint32_t last_csr_sepc;
 #endif
 };
 
diff --git a/src/system.c b/src/system.c
index 62ecbc11d..60c56b5f1 100644
--- a/src/system.c
+++ b/src/system.c
@@ -25,6 +25,18 @@ void emu_update_uart_interrupts(riscv_t *rv)
     plic_update_interrupts(attr->plic);
 }
 
+#define CLEAR_PENDING_SIGNAL_R()                                \
+    if (rv->csr_sepc != rv->last_csr_sepc) {                    \
+        need_handle_signal = true;                              \
+        return 0; /* early return and jump to signal handler */ \
+    }
+
+#define CLEAR_PENDING_SIGNAL_W()                              \
+    if (rv->csr_sepc != rv->last_csr_sepc) {                  \
+        need_handle_signal = true;                            \
+        return; /* early return and jump to signal handler */ \
+    }
+
 #define MMIO_R 1
 #define MMIO_W 0
 
@@ -269,6 +281,7 @@ MMU_FAULT_CHECK_IMPL(write, pagefault_store)
  * - mmu_write_s
  * - mmu_write_b
  */
+extern bool need_handle_signal;
 extern bool need_retranslate;
 static uint32_t mmu_ifetch(riscv_t *rv, const uint32_t addr)
 {
@@ -297,8 +310,10 @@ static uint32_t mmu_read_w(riscv_t *rv, const uint32_t addr)
     uint32_t level;
     pte_t *pte = mmu_walk(rv, addr, &level);
     bool ok = MMU_FAULT_CHECK(read, rv, pte, addr, PTE_R);
-    if (unlikely(!ok))
+    if (unlikely(!ok)) {
+        CLEAR_PENDING_SIGNAL_R();
         pte = mmu_walk(rv, addr, &level);
+    }
 
     {
         get_ppn_and_offset();
@@ -323,8 +338,10 @@ static uint16_t mmu_read_s(riscv_t *rv, const uint32_t addr)
     uint32_t level;
     pte_t *pte = mmu_walk(rv, addr, &level);
     bool ok = MMU_FAULT_CHECK(read, rv, pte, addr, PTE_R);
-    if (unlikely(!ok))
+    if (unlikely(!ok)) {
+        CLEAR_PENDING_SIGNAL_R();
         pte = mmu_walk(rv, addr, &level);
+    }
 
     get_ppn_and_offset();
     return memory_read_s(ppn | offset);
@@ -338,8 +355,10 @@ static uint8_t mmu_read_b(riscv_t *rv, const uint32_t addr)
     uint32_t level;
     pte_t *pte = mmu_walk(rv, addr, &level);
     bool ok = MMU_FAULT_CHECK(read, rv, pte, addr, PTE_R);
-    if (unlikely(!ok))
+    if (unlikely(!ok)) {
+        CLEAR_PENDING_SIGNAL_R();
         pte = mmu_walk(rv, addr, &level);
+    }
 
     {
         get_ppn_and_offset();
@@ -364,8 +383,10 @@ static void mmu_write_w(riscv_t *rv, const uint32_t addr, const uint32_t val)
     uint32_t level;
     pte_t *pte = mmu_walk(rv, addr, &level);
     bool ok = MMU_FAULT_CHECK(write, rv, pte, addr, PTE_W);
-    if (unlikely(!ok))
+    if (unlikely(!ok)) {
+        CLEAR_PENDING_SIGNAL_W();
         pte = mmu_walk(rv, addr, &level);
+    }
 
     {
         get_ppn_and_offset();
@@ -390,8 +411,10 @@ static void mmu_write_s(riscv_t *rv, const uint32_t addr, const uint16_t val)
     uint32_t level;
     pte_t *pte = mmu_walk(rv, addr, &level);
     bool ok = MMU_FAULT_CHECK(write, rv, pte, addr, PTE_W);
-    if (unlikely(!ok))
+    if (unlikely(!ok)) {
+        CLEAR_PENDING_SIGNAL_W();
         pte = mmu_walk(rv, addr, &level);
+    }
 
     get_ppn_and_offset();
     memory_write_s(ppn | offset, (uint8_t *) &val);
@@ -405,8 +428,10 @@ static void mmu_write_b(riscv_t *rv, const uint32_t addr, const uint8_t val)
     uint32_t level;
     pte_t *pte = mmu_walk(rv, addr, &level);
     bool ok = MMU_FAULT_CHECK(write, rv, pte, addr, PTE_W);
-    if (unlikely(!ok))
+    if (unlikely(!ok)) {
+        CLEAR_PENDING_SIGNAL_W();
         pte = mmu_walk(rv, addr, &level);
+    }
 
     {
         get_ppn_and_offset();