Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add control signals and use general rsp #120

Open
wants to merge 18 commits into
base: v4
Choose a base branch
from
Open
11 changes: 11 additions & 0 deletions akita.code-workspace
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"folders": [
{
"path": "."
},
{
"path": "../../rewrite_mem"
}
],
"settings": {}
}
14 changes: 14 additions & 0 deletions mem/idealmemcontroller/111.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package idealmemcontroller

import (
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

//go:generate mockgen -destination "mock_sim_test.go" -package $GOPACKAGE -write_package_comment=false github.com/sarchlab/akita/v4/sim Port,Connection,Engine
func TestIdealmemcontroller(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Idealmemcontroller Suite")
}
11 changes: 8 additions & 3 deletions mem/idealmemcontroller/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ func (b Builder) Build(
c := &Comp{
Latency: b.latency,
width: b.width,
state: "enable",
}

c.TickingComponent = sim.NewTickingComponent(name, b.engine, b.freq, c)
Expand All @@ -102,11 +103,15 @@ func (b Builder) Build(
c.Storage = b.storage
}

ctrlMiddleware := &ctrlMiddleware{Comp: c}
c.AddMiddleware(ctrlMiddleware)
funcMiddleware := &funcMiddleware{Comp: c}
c.AddMiddleware(funcMiddleware)

c.topPort = sim.NewLimitNumMsgPort(c, b.topBufSize, name+".TopPort")
c.AddPort("Top", c.topPort)

middleware := &middleware{Comp: c}
c.AddMiddleware(middleware)
c.ctrlPort = sim.NewLimitNumMsgPort(c, b.topBufSize, name+".CtrlPort")
c.AddPort("Control", c.ctrlPort)

return c
}
159 changes: 7 additions & 152 deletions mem/idealmemcontroller/comp.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
package idealmemcontroller

import (
"log"
"reflect"

"github.com/sarchlab/akita/v4/mem/mem"

"github.com/sarchlab/akita/v4/sim"

"github.com/sarchlab/akita/v4/tracing"
)

type readRespondEvent struct {
Expand Down Expand Up @@ -41,159 +36,19 @@ type Comp struct {
sim.MiddlewareHolder

topPort sim.Port
ctrlPort sim.Port
Storage *mem.Storage
Latency int
addressConverter mem.AddressConverter

width int
}

func (c *Comp) Tick() bool {
return c.MiddlewareHolder.Tick()
}

// Handle defines how the Comp handles event
func (c *Comp) Handle(e sim.Event) error {
switch e := e.(type) {
case *readRespondEvent:
return c.handleReadRespondEvent(e)
case *writeRespondEvent:
return c.handleWriteRespondEvent(e)
case sim.TickEvent:
return c.TickingComponent.Handle(e)
default:
log.Panicf("cannot handle event of %s", reflect.TypeOf(e))
}

return nil
}

type middleware struct {
*Comp
}

// Tick updates ideal memory controller state.
func (m *middleware) Tick() bool {
msg := m.topPort.RetrieveIncoming()
if msg == nil {
return false
}

tracing.TraceReqReceive(msg, m.Comp)

switch msg := msg.(type) {
case *mem.ReadReq:
m.handleReadReq(msg)
return true
case *mem.WriteReq:
m.handleWriteReq(msg)
return true
default:
log.Panicf("cannot handle request of type %s", reflect.TypeOf(msg))
}

return false
}

func (m *middleware) handleReadReq(req *mem.ReadReq) {
now := m.CurrentTime()
timeToSchedule := m.Freq.NCyclesLater(m.Latency, now)
respondEvent := newReadRespondEvent(timeToSchedule, m.Comp, req)
m.Engine.Schedule(respondEvent)
}

func (m *middleware) handleWriteReq(req *mem.WriteReq) {
now := m.CurrentTime()
timeToSchedule := m.Freq.NCyclesLater(m.Latency, now)
respondEvent := newWriteRespondEvent(timeToSchedule, m.Comp, req)
m.Engine.Schedule(respondEvent)
}

func (c *Comp) handleReadRespondEvent(e *readRespondEvent) error {
now := e.Time()
req := e.req

addr := req.Address
if c.addressConverter != nil {
addr = c.addressConverter.ConvertExternalToInternal(addr)
}

data, err := c.Storage.Read(addr, req.AccessByteSize)
if err != nil {
log.Panic(err)
}
respondReq *mem.ControlMsg
width int

rsp := mem.DataReadyRspBuilder{}.
WithSrc(c.topPort).
WithDst(req.Src).
WithRspTo(req.ID).
WithData(data).
Build()
state string

networkErr := c.topPort.Send(rsp)

if networkErr != nil {
retry := newReadRespondEvent(c.Freq.NextTick(now), c, req)
c.Engine.Schedule(retry)
return nil
}

tracing.TraceReqComplete(req, c)
c.TickLater()

return nil
inflightbuffer []sim.Msg
}

func (c *Comp) handleWriteRespondEvent(e *writeRespondEvent) error {
now := e.Time()
req := e.req

rsp := mem.WriteDoneRspBuilder{}.
WithSrc(c.topPort).
WithDst(req.Src).
WithRspTo(req.ID).
Build()

networkErr := c.topPort.Send(rsp)
if networkErr != nil {
retry := newWriteRespondEvent(c.Freq.NextTick(now), c, req)
c.Engine.Schedule(retry)
return nil
}

addr := req.Address

if c.addressConverter != nil {
addr = c.addressConverter.ConvertExternalToInternal(addr)
}

if req.DirtyMask == nil {
err := c.Storage.Write(addr, req.Data)
if err != nil {
log.Panic(err)
}
} else {
data, err := c.Storage.Read(addr, uint64(len(req.Data)))
if err != nil {
panic(err)
}
for i := 0; i < len(req.Data); i++ {
if req.DirtyMask[i] {
data[i] = req.Data[i]
}
}
err = c.Storage.Write(addr, data)
if err != nil {
panic(err)
}
}

tracing.TraceReqComplete(req, c)
c.TickLater()

return nil
}

func (c *Comp) CurrentTime() sim.VTimeInSec {
return c.Engine.CurrentTime()
func (c *Comp) Tick() bool {
return c.MiddlewareHolder.Tick()
}
119 changes: 119 additions & 0 deletions mem/idealmemcontroller/ctrlMiddleware_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package idealmemcontroller

import (
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/sarchlab/akita/v4/mem/mem"
"github.com/sarchlab/akita/v4/sim"
)

var _ = FDescribe("CtrlMiddleware", func() {
var (
mockCtrl *gomock.Controller
comp *Comp
engine *MockEngine
remoteCtrlPort *MockPort
ctrlPort *MockPort
ctrlMW *ctrlMiddleware
)

BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())

engine = NewMockEngine(mockCtrl)
ctrlPort = NewMockPort(mockCtrl)
remoteCtrlPort = NewMockPort(mockCtrl)

comp = MakeBuilder().
WithEngine(engine).
WithNewStorage(1 * mem.MB).
Build("MemCtrl")
comp.Freq = 1000 * sim.MHz
comp.Latency = 10
comp.ctrlPort = ctrlPort

ctrlMW = &ctrlMiddleware{
Comp: comp,
}
})

AfterEach(func() {
mockCtrl.Finish()
})

It("should do nothing if no ctrl message", func() {
ctrlPort.EXPECT().RetrieveIncoming().Return(nil)

madeProgress := ctrlMW.Tick()

Expect(madeProgress).To(BeFalse())
})

It("should handle enable message", func() {
comp.state = "paused"

ctrlMsg := mem.ControlMsgBuilder{}.
WithSrc(remoteCtrlPort).
WithDst(ctrlPort).
WithCtrlInfo(true, false, false, false).
Build()
ctrlPort.EXPECT().RetrieveIncoming().Return(ctrlMsg)
ctrlPort.EXPECT().
Send(gomock.Any()).
Do(func(msg *sim.GeneralRsp) {
Expect(msg.Src).To(Equal(ctrlPort))
Expect(msg.Dst).To(Equal(remoteCtrlPort))
Expect(msg.OriginalReq).To(Equal(ctrlMsg))
}).
Return(nil).
AnyTimes()

madeProgress := ctrlMW.Tick()

Expect(madeProgress).To(BeTrue())
Expect(comp.state).To(Equal("enable"))
})

It("should handle pause message", func() {
comp.state = "enable"

ctrlMsg := mem.ControlMsgBuilder{}.
WithSrc(remoteCtrlPort).
WithDst(ctrlPort).
WithCtrlInfo(false, false, false, false).
Build()
ctrlPort.EXPECT().RetrieveIncoming().Return(ctrlMsg)
ctrlPort.EXPECT().
Send(gomock.Any()).
Do(func(msg *sim.GeneralRsp) {
Expect(msg.Src).To(Equal(ctrlPort))
Expect(msg.Dst).To(Equal(remoteCtrlPort))
Expect(msg.OriginalReq).To(Equal(ctrlMsg))
}).
Return(nil).
AnyTimes()

madeProgress := ctrlMW.Tick()

Expect(madeProgress).To(BeTrue())
Expect(comp.state).To(Equal("pause"))
})

It("should handle drain message", func() {
comp.state = "enable"

ctrlMsg := mem.ControlMsgBuilder{}.
WithSrc(remoteCtrlPort).
WithDst(ctrlPort).
WithCtrlInfo(false, true, false, false).
Build()
ctrlPort.EXPECT().RetrieveIncoming().Return(ctrlMsg)
madeProgress := ctrlMW.Tick()

Expect(madeProgress).To(BeTrue())
Expect(comp.state).To(Equal("drain"))
Expect(comp.respondReq).To(Equal(ctrlMsg))
})

})
Loading
Loading