Skip to content

Commit

Permalink
Fourmulize code
Browse files Browse the repository at this point in the history
  • Loading branch information
dnadales committed Jul 6, 2024
1 parent d388094 commit 4f70e7a
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 82 deletions.
45 changes: 22 additions & 23 deletions leios-sim/src/Leios/Model.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ import Control.Monad.Class.MonadAsync (Async, Concurrently (Concurrently, runCon
import Control.Monad.Class.MonadSTM (MonadSTM, STM, atomically, retry)
import Control.Monad.Class.MonadTimer (MonadDelay, MonadTimer, threadDelay)
import Control.Monad.Extra (whenM)
import Control.Monad.State (State, get, put, runState)
import Control.Tracer (Tracer, traceWith)
import qualified Data.Aeson as Aeson
import Data.Foldable (for_)
import Data.List (partition)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.PQueue.Max (MaxQueue)
import qualified Data.PQueue.Max as PQueue
import Data.Word (Word64)
import GHC.Generics (Generic)
import Leios.Trace (mkTracer)
import System.Random (RandomGen, mkStdGen, randomR, split)
import Text.Pretty.Simple (pPrint)
import Data.List (partition)
import System.Random (randomR, RandomGen, mkStdGen, split)
import Control.Monad.State (State, get, put, runState)
import Data.Word (Word64)

--------------------------------------------------------------------------------
-- FIXME: constants that should be configurable
Expand Down Expand Up @@ -118,7 +118,7 @@ newtype NodeStakePercent = NodeStakePercent Double
deriving newtype (Show)

-- Frequency of IB slots per Praos slots.
newtype IBFrequency = IBFrequency {getIBFrequency :: Double}
newtype IBFrequency = IBFrequency {getIBFrequency :: Double}
deriving stock (Generic)
deriving newtype (Show, Eq, Ord, Aeson.ToJSON, Aeson.FromJSON)

Expand Down Expand Up @@ -207,20 +207,20 @@ run tracer paramsTVar continueTVar = do
gen = mkStdGen (initialSeed params)
nodesGens = splitIn totalNodes gen
-- TODO: for now we don't allow this to be configurable
stakePercent = NodeStakePercent (1/ fromIntegral totalNodes)
stakePercent = NodeStakePercent (1 / fromIntegral totalNodes)
raceAll
[ do
register (NodeId i) world
let nodeGen = nodesGens !! fromIntegral i
node (NodeId i) stakePercent nodeGen tracer world continueTVar
register (NodeId i) world
let nodeGen = nodesGens !! fromIntegral i
node (NodeId i) stakePercent nodeGen tracer world continueTVar
| i <- [0 .. totalNodes - 1]
]
where
splitIn 0 _ = []
splitIn 1 gen = [gen]
splitIn n gen = gen0 : splitIn (n-1) gen1
where
(gen0, gen1) = split gen
where
splitIn 0 _ = []
splitIn 1 gen = [gen]
splitIn n gen = gen0 : splitIn (n - 1) gen1
where
(gen0, gen1) = split gen

-- | Determine if the node with the given stake leads.
--
Expand All @@ -232,16 +232,15 @@ run tracer paramsTVar continueTVar = do
-- where @α@ is the given node stake, @f@ is the frequency of slots and
--
-- > asc = f / ceiling f
--
leads :: RandomGen g => NodeStakePercent -> Double -> State g Bool
leads (NodeStakePercent α) f_I = do
generator <- get
let (n, nextGenerator) = randomR (0, 1) generator
put nextGenerator
pure $! n <= asc_I * α
where
asc_I = f_I / ceiling' f_I
ceiling' = fromIntegral . ceiling
where
asc_I = f_I / ceiling' f_I
ceiling' = fromIntegral . ceiling

-- | Given a number of IB slots, determine if the node with the given
-- stake percent leads on those slots.
Expand Down Expand Up @@ -282,7 +281,7 @@ node nodeId nodeStakePercent initialGenerator tracer world continueTVar = do
let loop generator = do
slot <- nextSlot clock
traceWith tracer (NextSlot nodeId slot)
Parameters {f_I, f_E} <- getParams world
Parameters{f_I, f_E} <- getParams world
-- Generate IB blocks
let (numberOfIBsInThisSlot, generator1) =
slotsLed generator (getIBFrequency f_I)
Expand All @@ -308,7 +307,7 @@ node nodeId nodeStakePercent initialGenerator tracer world continueTVar = do
let q = ceiling $ f -- For practical reasons we want this to be a minimal value.
(nodeLeads, nextGenerator) =
leadsMultiple generator q nodeStakePercent f
in (length $ filter id nodeLeads, nextGenerator)
in (length $ filter id nodeLeads, nextGenerator)

produceIB slot = do
let newIB = IB{ib_slot = slot, ib_producer = nodeId, ib_size = gIBSize}
Expand Down Expand Up @@ -407,7 +406,8 @@ runStandalone = do
-- TODO: we might want to add some mechanism to cancel the async tick
-- thread when the thread that has a reference to the returned clock
-- is canceled.
runClock :: (Monad m, MonadSTM m, MonadAsync m, MonadDelay m) =>
runClock ::
(Monad m, MonadSTM m, MonadAsync m, MonadDelay m) =>
TVar m ShouldContinue ->
m (Clock m)
runClock continueTVar = do
Expand Down Expand Up @@ -539,7 +539,6 @@ storeIB nodeId ib OutsideWorld{storedIBsTVar} =

-- | Retrieve the downloaded IBs by the given node, which correspond
-- to the given slice. Once retrieved the IBs are deleted.
--
storedIBsBy :: MonadSTM m => NodeId -> OutsideWorld m -> Slice -> m [IB]
storedIBsBy nodeId OutsideWorld{storedIBsTVar} slice = atomically $ do
storedIBs <- readTVar storedIBsTVar
Expand Down
126 changes: 67 additions & 59 deletions leios-sim/src/Leios/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,28 @@

module Leios.Server where

import Control.Concurrent.Class.MonadSTM (MonadSTM (modifyTVar), TVar, atomically, newBroadcastTChanIO, writeTChan, newTVarIO, stateTVar, readTVarIO, writeTVar)
import Control.Concurrent (threadDelay)
import Control.Concurrent.Class.MonadSTM (MonadSTM (modifyTVar), TVar, atomically, newBroadcastTChanIO, newTVarIO, readTVarIO, stateTVar, writeTChan, writeTVar)
import Control.Concurrent.Class.MonadSTM.TChan (TChan, dupTChan, readTChan)
import Control.Concurrent.Class.MonadSTM.TQueue
import Control.Exception (SomeException, handle, throw)
import Control.Monad (forever)
import Control.Monad.Class.MonadAsync (race_)
import Control.Monad.IO.Class (liftIO)
import Data.Aeson (Value, encode)
import Data.Text.Lazy.Encoding (decodeUtf8)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Text.Lazy as Text
import Control.Concurrent (threadDelay)
import Control.Exception (handle, SomeException, throw)
import Data.Text.Lazy.Encoding (decodeUtf8)
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.Warp as Warp
import qualified Network.Wai.Handler.WebSockets as WS
import Network.Wai.Middleware.RequestLogger (logStdoutDev)
import qualified Network.WebSockets as WS
import qualified Web.Scotty as Sc
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

-- FIXME: import explicitly
import Leios.Model (Parameters (..), BitsPerSecond (..), NumberOfSlots (..), NumberOfSlices (..), BitsPerSecond (..), NumberOfBits (..), IBFrequency (..), EBFrequency (..), ShouldContinue(..))
import Leios.Model (BitsPerSecond (..), EBFrequency (..), IBFrequency (..), NumberOfBits (..), NumberOfSlices (..), NumberOfSlots (..), Parameters (..), ShouldContinue (..))
import qualified Leios.Model as Model
import Leios.Trace (mkTracer)
import Network.HTTP.Types.Status (badRequest400)
Expand All @@ -38,20 +39,23 @@ import Network.HTTP.Types.Status (badRequest400)
type SessionId = Int

nextId :: MonadSTM m => ServerState m -> m SessionId
nextId ServerState { nextIdTVar } =
atomically $ stateTVar nextIdTVar (\i -> (i, i+1))
nextId ServerState{nextIdTVar} =
atomically $ stateTVar nextIdTVar (\i -> (i, i + 1))

-- | Create a new session with the given parameters. Return the id of the new session.
newSession :: MonadSTM m =>
ServerState m
-> TVar m Parameters
-> TVar m ShouldContinue
-> m SessionId
newSession ::
MonadSTM m =>
ServerState m ->
TVar m Parameters ->
TVar m ShouldContinue ->
m SessionId
newSession state paramsTVar continueTVar = do
sid <- nextId state
let clientState = ClientState {
paramsTVar = paramsTVar, continueTVar = continueTVar
}
let clientState =
ClientState
{ paramsTVar = paramsTVar
, continueTVar = continueTVar
}
atomically $ modifyTVar (sessionsTVar state) (Map.insert sid clientState) -- We could assert the session does not exist in the map.
pure sid

Expand All @@ -65,30 +69,33 @@ lookupParamsTVar sid serverState =
fmap (fmap paramsTVar) $ lookupClientState sid serverState

lookupClientState :: MonadSTM m => SessionId -> ServerState m -> m (Maybe (ClientState m))
lookupClientState sid ServerState { sessionsTVar } = do
lookupClientState sid ServerState{sessionsTVar} = do
sessions <- readTVarIO sessionsTVar
pure $ Map.lookup sid sessions

lookupContinueTVar :: MonadSTM m => SessionId -> ServerState m
-> m (Maybe (TVar m ShouldContinue))
lookupContinueTVar sid serverState = do
lookupContinueTVar ::
MonadSTM m =>
SessionId ->
ServerState m ->
m (Maybe (TVar m ShouldContinue))
lookupContinueTVar sid serverState = do
fmap (fmap continueTVar) $ lookupClientState sid serverState

data ServerState m = ServerState {
sessionsTVar :: TVar m (Map SessionId (ClientState m)),
nextIdTVar :: TVar m Int
data ServerState m = ServerState
{ sessionsTVar :: TVar m (Map SessionId (ClientState m))
, nextIdTVar :: TVar m Int
}

data ClientState m = ClientState {
paramsTVar :: TVar m Parameters,
continueTVar :: TVar m ShouldContinue
data ClientState m = ClientState
{ paramsTVar :: TVar m Parameters
, continueTVar :: TVar m ShouldContinue
}

mkServerState :: (Monad m, MonadSTM m) => m (ServerState m)
mkServerState = do
sessionsTVar <- newTVarIO mempty
nextIdTVar <- newTVarIO 0
pure $ ServerState { sessionsTVar = sessionsTVar, nextIdTVar = nextIdTVar }
pure $ ServerState{sessionsTVar = sessionsTVar, nextIdTVar = nextIdTVar}

--------------------------------------------------------------------------------
-- Server
Expand All @@ -102,10 +109,10 @@ runServer = do
sapp <- scottyApp serverState
Warp.runSettings
settings
(WS.websocketsOr
WS.defaultConnectionOptions
(wsapp serverState)
sapp
( WS.websocketsOr
WS.defaultConnectionOptions
(wsapp serverState)
sapp
)

feedClient :: MonadSTM m => TQueue m Value -> TChan m Value -> m ()
Expand Down Expand Up @@ -165,9 +172,9 @@ scottyApp serverState =
(id :: SessionId) <- Sc.queryParam "sessionId"
liftIO $ print id
pure ()
-- liftIO $
-- atomically $
-- modifyTVar params (\p -> p{nodeBandwidth = BitsPerSecond bps})
-- liftIO $
-- atomically $
-- modifyTVar params (\p -> p{nodeBandwidth = BitsPerSecond bps})

Sc.post "/api/lambda" $ do
λ <- Sc.jsonData
Expand All @@ -187,9 +194,10 @@ wsapp serverState pending = do
continueTVar <- newTVarIO Stop
sid <- newSession serverState paramsTVar continueTVar
-- For now we send the session ID it this way. We can make this more robust if needed.
WS.sendTextData conn $ "{ \"tag\": \"SessionId\", \"sessionId\": "
<> Text.pack (show sid)
<> " }"
WS.sendTextData conn $
"{ \"tag\": \"SessionId\", \"sessionId\": "
<> Text.pack (show sid)
<> " }"

WS.withPingThread conn 30 (pure ()) $ do
eventQueue <- newTQueueIO
Expand All @@ -198,24 +206,24 @@ wsapp serverState pending = do

-- raceAll could be moved to some `Utils` package if we want to use it here.
handle cleanup $
Model.raceAll
[ feedClient eventQueue clientChannel
, Model.run (mkTracer eventQueue) paramsTVar continueTVar
, forever $ do
msg <- atomically $ readTChan clientQueue
WS.sendTextData conn $ decodeUtf8 $ encode msg
]
where
cleanup :: SomeException -> IO ()
cleanup e = putStrLn "TODO: perform cleanup." >> throw e

defaultParams =
Parameters
{ _L = NumberOfSlots 4
, λ = NumberOfSlices 3
, nodeBandwidth = BitsPerSecond 1000
, ibSize = NumberOfBits 300
, f_I = IBFrequency 5
, f_E = EBFrequency 1
, initialSeed = 22595838
}
Model.raceAll
[ feedClient eventQueue clientChannel
, Model.run (mkTracer eventQueue) paramsTVar continueTVar
, forever $ do
msg <- atomically $ readTChan clientQueue
WS.sendTextData conn $ decodeUtf8 $ encode msg
]
where
cleanup :: SomeException -> IO ()
cleanup e = putStrLn "TODO: perform cleanup." >> throw e

defaultParams =
Parameters
{ _L = NumberOfSlots 4
, λ = NumberOfSlices 3
, nodeBandwidth = BitsPerSecond 1000
, ibSize = NumberOfBits 300
, f_I = IBFrequency 5
, f_E = EBFrequency 1
, initialSeed = 22595838
}

0 comments on commit 4f70e7a

Please sign in to comment.