Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge stack and temporary variable groups in JuvixReg #2579

Merged
merged 2 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions runtime/src/juvix/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,10 @@
IO_INTERPRET; \
io_print_toplevel(juvix_result);

// Temporary vars
// Temporary / local vars
#define DECL_TMP(k) UNUSED word_t juvix_tmp_##k
#define TMP(k) juvix_tmp_##k

// Value stack temporary vars
#define DECL_STMP(k) word_t juvix_stmp_##k
#define STMP(k) juvix_stmp_##k

// Begin a function definition. `max_stack` is the maximum stack allocation in
// the function.
#define JUVIX_FUNCTION(label, max_stack) \
Expand Down
8 changes: 3 additions & 5 deletions src/Juvix/Compiler/Backend/C/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,9 @@ fromReg lims tab =
fromRegFunction :: (Member CBuilder r) => Reg.ExtraInfo -> Reg.FunctionInfo -> Sem r [Statement]
fromRegFunction info funInfo = do
body <- fromRegCode bNoStack info (funInfo ^. Reg.functionCode)
let stmpDecls = mkDecls "DECL_STMP" (funInfo ^. Reg.functionStackVarsNum)
tmpDecls = mkDecls "DECL_TMP" (funInfo ^. Reg.functionTempVarsNum)
let tmpDecls = mkDecls "DECL_TMP" (funInfo ^. Reg.functionLocalVarsNum)
return
[closureDecl, functionDecl, StatementCompound (stmpDecls ++ tmpDecls ++ body)]
[closureDecl, functionDecl, StatementCompound (tmpDecls ++ body)]
where
mkDecls :: Text -> Int -> [Statement]
mkDecls decl n = map (\i -> StatementExpr (macroCall decl [integer i])) [0 .. n - 1]
Expand Down Expand Up @@ -290,8 +289,7 @@ fromRegInstr bNoStack info = \case
g =
case _varRefGroup of
Reg.VarGroupArgs -> "ARG"
Reg.VarGroupStack -> "STMP"
Reg.VarGroupTemp -> "TMP"
Reg.VarGroupLocal -> "TMP"

fromValue :: Reg.Value -> Expression
fromValue = \case
Expand Down
3 changes: 1 addition & 2 deletions src/Juvix/Compiler/Reg/Data/InfoTable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ data FunctionInfo = FunctionInfo
_functionLocation :: Maybe Location,
_functionSymbol :: Symbol,
_functionArgsNum :: Int,
_functionStackVarsNum :: Int,
_functionTempVarsNum :: Int,
_functionLocalVarsNum :: Int,
_functionCode :: Code
}

Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Reg/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data ConstrField = ConstrField
_constrFieldIndex :: Index
}

data VarGroup = VarGroupArgs | VarGroupStack | VarGroupTemp
data VarGroup = VarGroupArgs | VarGroupLocal

data VarRef = VarRef
{ _varRefGroup :: VarGroup,
Expand Down
76 changes: 41 additions & 35 deletions src/Juvix/Compiler/Reg/Translation/FromAsm.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ fromAsm tab =
_functionLocation = fi ^. Asm.functionLocation,
_functionSymbol = fi ^. Asm.functionSymbol,
_functionArgsNum = fi ^. Asm.functionArgsNum,
_functionStackVarsNum = fi ^. Asm.functionMaxValueStackHeight,
_functionTempVarsNum = fi ^. Asm.functionMaxTempStackHeight,
_functionLocalVarsNum = fi ^. Asm.functionMaxTempStackHeight + fi ^. Asm.functionMaxValueStackHeight,
_functionCode = fromAsmFun tab fi
}

Expand Down Expand Up @@ -64,9 +63,9 @@ fromAsmFun tab fi =
Asm.RecursorSig
{ _recursorInfoTable = tab,
_recurseInstr = fromAsmInstr fi tab,
_recurseBranch = fromAsmBranch,
_recurseCase = fromAsmCase tab,
_recurseSave = fromAsmSave
_recurseBranch = fromAsmBranch fi,
_recurseCase = fromAsmCase fi tab,
_recurseSave = fromAsmSave fi
}

fromAsmInstr ::
Expand All @@ -78,14 +77,14 @@ fromAsmInstr ::
fromAsmInstr funInfo tab si Asm.CmdInstr {..} =
case _cmdInstrInstruction of
Asm.Binop op -> return $ mkBinop (mkOpcode op)
Asm.ValShow -> return $ mkShow (VarRef VarGroupStack n) (VRef $ VarRef VarGroupStack n)
Asm.StrToInt -> return $ mkStrToInt (VarRef VarGroupStack n) (VRef $ VarRef VarGroupStack n)
Asm.Push val -> return $ mkAssign (VarRef VarGroupStack (n + 1)) (mkValue val)
Asm.ValShow -> return $ mkShow (VarRef VarGroupLocal (ntmps + n)) (VRef $ VarRef VarGroupLocal (ntmps + n))
Asm.StrToInt -> return $ mkStrToInt (VarRef VarGroupLocal (ntmps + n)) (VRef $ VarRef VarGroupLocal (ntmps + n))
Asm.Push val -> return $ mkAssign (VarRef VarGroupLocal (ntmps + n + 1)) (mkValue val)
Asm.Pop -> return Nop
Asm.Trace -> return $ Trace $ InstrTrace (VRef $ VarRef VarGroupStack n)
Asm.Trace -> return $ Trace $ InstrTrace (VRef $ VarRef VarGroupLocal (ntmps + n))
Asm.Dump -> return Dump
Asm.Failure -> return $ Failure $ InstrFailure (VRef $ VarRef VarGroupStack n)
Asm.ArgsNum -> return $ mkArgsNum (VarRef VarGroupStack n) (VRef $ VarRef VarGroupStack n)
Asm.Failure -> return $ Failure $ InstrFailure (VRef $ VarRef VarGroupLocal (ntmps + n))
Asm.ArgsNum -> return $ mkArgsNum (VarRef VarGroupLocal (ntmps + n)) (VRef $ VarRef VarGroupLocal (ntmps + n))
Asm.Prealloc x -> return $ mkPrealloc x
Asm.AllocConstr tag -> return $ mkAlloc tag
Asm.AllocClosure x -> return $ mkAllocClosure x
Expand All @@ -95,33 +94,37 @@ fromAsmInstr funInfo tab si Asm.CmdInstr {..} =
Asm.CallClosures x -> return $ mkCallClosures False x
Asm.TailCallClosures x -> return $ mkCallClosures True x
Asm.Return ->
return $ Return InstrReturn {_instrReturnValue = VRef $ VarRef VarGroupStack 0}
return $ Return InstrReturn {_instrReturnValue = VRef $ VarRef VarGroupLocal ntmps}
where
-- `n` is the index of the top of the value stack *before* executing the
-- instruction
n :: Int
n = si ^. Asm.stackInfoValueStackHeight - 1

-- `ntmps` is the number of temporary variables (= max temporary stack height)
ntmps :: Int
ntmps = funInfo ^. Asm.functionMaxTempStackHeight

-- Live variables *after* executing the instruction. `k` is the number of
-- value stack cells that will be popped by the instruction. TODO: proper
-- liveness analysis in JuvixAsm.
liveVars :: Int -> [VarRef]
liveVars k =
map (VarRef VarGroupStack) [0 .. n - k]
++ map (VarRef VarGroupTemp) [0 .. si ^. Asm.stackInfoTempStackHeight - 1]
map (VarRef VarGroupLocal) [0 .. si ^. Asm.stackInfoTempStackHeight - 1]
++ map (VarRef VarGroupLocal) [ntmps .. ntmps + n - k]
++ map (VarRef VarGroupArgs) [0 .. funInfo ^. Asm.functionArgsNum - 1]

getArgs :: Int -> Int -> [Value]
getArgs s k = map (\i -> VRef $ VarRef VarGroupStack (n - i)) [s .. (s + k - 1)]
getArgs s k = map (\i -> VRef $ VarRef VarGroupLocal (ntmps + n - i)) [s .. (s + k - 1)]

mkBinop :: Opcode -> Instruction
mkBinop op =
Binop
( BinaryOp
{ _binaryOpCode = op,
_binaryOpResult = VarRef VarGroupStack (n - 1),
_binaryOpArg1 = VRef $ VarRef VarGroupStack n,
_binaryOpArg2 = VRef $ VarRef VarGroupStack (n - 1)
_binaryOpResult = VarRef VarGroupLocal (ntmps + n - 1),
_binaryOpArg1 = VRef $ VarRef VarGroupLocal (ntmps + n),
_binaryOpArg2 = VRef $ VarRef VarGroupLocal (ntmps + n - 1)
}
)

Expand Down Expand Up @@ -171,9 +174,9 @@ fromAsmInstr funInfo tab si Asm.CmdInstr {..} =

mkVar :: Asm.DirectRef -> VarRef
mkVar = \case
Asm.StackRef -> VarRef VarGroupStack n
Asm.StackRef -> VarRef VarGroupLocal (ntmps + n)
Asm.ArgRef Asm.OffsetRef {..} -> VarRef VarGroupArgs _offsetRefOffset
Asm.TempRef Asm.RefTemp {..} -> VarRef VarGroupTemp (_refTempOffsetRef ^. Asm.offsetRefOffset)
Asm.TempRef Asm.RefTemp {..} -> VarRef VarGroupLocal (_refTempOffsetRef ^. Asm.offsetRefOffset)

mkPrealloc :: Asm.InstrPrealloc -> Instruction
mkPrealloc Asm.InstrPrealloc {..} =
Expand All @@ -188,7 +191,7 @@ fromAsmInstr funInfo tab si Asm.CmdInstr {..} =
Alloc $
InstrAlloc
{ _instrAllocTag = tag,
_instrAllocResult = VarRef VarGroupStack m,
_instrAllocResult = VarRef VarGroupLocal (ntmps + m),
_instrAllocArgs = getArgs 0 (ci ^. Asm.constructorArgsNum),
_instrAllocMemRep = ci ^. Asm.constructorRepresentation
}
Expand All @@ -201,7 +204,7 @@ fromAsmInstr funInfo tab si Asm.CmdInstr {..} =
AllocClosure $
InstrAllocClosure
{ _instrAllocClosureSymbol = fi ^. Asm.functionSymbol,
_instrAllocClosureResult = VarRef VarGroupStack m,
_instrAllocClosureResult = VarRef VarGroupLocal (ntmps + m),
_instrAllocClosureExpectedArgsNum = fi ^. Asm.functionArgsNum,
_instrAllocClosureArgs = getArgs 0 _allocClosureArgsNum
}
Expand All @@ -213,8 +216,8 @@ fromAsmInstr funInfo tab si Asm.CmdInstr {..} =
mkExtendClosure Asm.InstrExtendClosure {..} =
ExtendClosure $
InstrExtendClosure
{ _instrExtendClosureResult = VarRef VarGroupStack m,
_instrExtendClosureValue = VarRef VarGroupStack n,
{ _instrExtendClosureResult = VarRef VarGroupLocal (ntmps + m),
_instrExtendClosureValue = VarRef VarGroupLocal (ntmps + n),
_instrExtendClosureArgs = getArgs 1 _extendClosureArgsNum
}
where
Expand All @@ -224,7 +227,7 @@ fromAsmInstr funInfo tab si Asm.CmdInstr {..} =
mkCall isTail Asm.InstrCall {..} =
Call $
InstrCall
{ _instrCallResult = VarRef VarGroupStack m,
{ _instrCallResult = VarRef VarGroupLocal (ntmps + m),
_instrCallType = ct,
_instrCallIsTail = isTail,
_instrCallArgs = getArgs s _callArgsNum,
Expand All @@ -234,7 +237,7 @@ fromAsmInstr funInfo tab si Asm.CmdInstr {..} =
m = n - _callArgsNum - s + 1
ct = case _callType of
Asm.CallFun f -> CallFun f
Asm.CallClosure -> CallClosure (VarRef VarGroupStack n)
Asm.CallClosure -> CallClosure (VarRef VarGroupLocal (ntmps + n))
s = case _callType of
Asm.CallFun {} -> 0
Asm.CallClosure -> 1
Expand All @@ -243,8 +246,8 @@ fromAsmInstr funInfo tab si Asm.CmdInstr {..} =
mkCallClosures isTail Asm.InstrCallClosures {..} =
CallClosures $
InstrCallClosures
{ _instrCallClosuresResult = VarRef VarGroupStack m,
_instrCallClosuresValue = VarRef VarGroupStack n,
{ _instrCallClosuresResult = VarRef VarGroupLocal (ntmps + m),
_instrCallClosuresValue = VarRef VarGroupLocal (ntmps + n),
_instrCallClosuresIsTail = isTail,
_instrCallClosuresArgs = getArgs 1 _callClosuresArgsNum,
_instrCallClosuresLiveVars = liveVars (_callClosuresArgsNum + 1)
Expand All @@ -254,32 +257,34 @@ fromAsmInstr funInfo tab si Asm.CmdInstr {..} =
m = n - _callClosuresArgsNum

fromAsmBranch ::
Asm.FunctionInfo ->
Asm.StackInfo ->
Asm.CmdBranch ->
Code ->
Code ->
Sem r Instruction
fromAsmBranch si Asm.CmdBranch {} codeTrue codeFalse =
fromAsmBranch fi si Asm.CmdBranch {} codeTrue codeFalse =
return $
Branch $
InstrBranch
{ _instrBranchValue = VRef $ VarRef VarGroupStack (si ^. Asm.stackInfoValueStackHeight - 1),
{ _instrBranchValue = VRef $ VarRef VarGroupLocal (fi ^. Asm.functionMaxTempStackHeight + si ^. Asm.stackInfoValueStackHeight - 1),
_instrBranchTrue = codeTrue,
_instrBranchFalse = codeFalse
}

fromAsmCase ::
Asm.FunctionInfo ->
Asm.InfoTable ->
Asm.StackInfo ->
Asm.CmdCase ->
[Code] ->
Maybe Code ->
Sem r Instruction
fromAsmCase tab si Asm.CmdCase {..} brs def =
fromAsmCase fi tab si Asm.CmdCase {..} brs def =
return $
Case $
InstrCase
{ _instrCaseValue = VRef $ VarRef VarGroupStack (si ^. Asm.stackInfoValueStackHeight - 1),
{ _instrCaseValue = VRef $ VarRef VarGroupLocal (fi ^. Asm.functionMaxTempStackHeight + si ^. Asm.stackInfoValueStackHeight - 1),
_instrCaseInductive = _cmdCaseInductive,
_instrCaseIndRep = ii ^. Asm.inductiveRepresentation,
_instrCaseBranches =
Expand All @@ -306,19 +311,20 @@ fromAsmCase tab si Asm.CmdCase {..} brs def =
HashMap.lookup _cmdCaseInductive (tab ^. Asm.infoInductives)

fromAsmSave ::
Asm.FunctionInfo ->
Asm.StackInfo ->
Asm.CmdSave ->
Code ->
Sem r Instruction
fromAsmSave si Asm.CmdSave {} block =
fromAsmSave fi si Asm.CmdSave {} block =
return $
Block $
InstrBlock
{ _instrBlockCode =
Assign
( InstrAssign
(VarRef VarGroupTemp (si ^. Asm.stackInfoTempStackHeight))
(VRef $ VarRef VarGroupStack (si ^. Asm.stackInfoValueStackHeight - 1))
(VarRef VarGroupLocal (si ^. Asm.stackInfoTempStackHeight))
(VRef $ VarRef VarGroupLocal (fi ^. Asm.functionMaxTempStackHeight + si ^. Asm.stackInfoValueStackHeight - 1))
)
: block
}
8 changes: 4 additions & 4 deletions tests/runtime/positive/test009.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ int main() {

JUVIX_FUNCTION_NS(juvix_function_calculate);
{
DECL_STMP(0);
JUVIX_INT_MUL(STMP(0), ARG(2), ARG(1));
JUVIX_INT_ADD(STMP(0), STMP(0), ARG(0));
juvix_result = STMP(0);
DECL_TMP(0);
JUVIX_INT_MUL(TMP(0), ARG(2), ARG(1));
JUVIX_INT_ADD(TMP(0), TMP(0), ARG(0));
juvix_result = TMP(0);
RETURN_NS;
}

Expand Down
21 changes: 10 additions & 11 deletions tests/runtime/positive/test010.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,21 @@ int main() {

JUVIX_FUNCTION_NS(juvix_function_calculate);
{
DECL_STMP(0);
JUVIX_INT_MUL(STMP(0), ARG(2), ARG(1));
JUVIX_INT_ADD(STMP(0), STMP(0), ARG(0));
juvix_result = STMP(0);
DECL_TMP(0);
JUVIX_INT_MUL(TMP(0), ARG(2), ARG(1));
JUVIX_INT_ADD(TMP(0), TMP(0), ARG(0));
juvix_result = TMP(0);
RETURN_NS;
}

JUVIX_FUNCTION(juvix_function_main, 1);
{
DECL_STMP(0);
ALLOC_CLOSURE(STMP(0), 1, LABEL_ADDR(juvix_closure_calculate), 2, 1);
CLOSURE_ARG(STMP(0), 0) = make_smallint(5);
CLOSURE_ARG(STMP(0), 1) = make_smallint(3);
ASSIGN_CARGS(STMP(0),
{ CARG(juvix_closure_nargs) = make_smallint(2); });
CALL_CLOSURE(STMP(0), juvix_label_1);
DECL_TMP(0);
ALLOC_CLOSURE(TMP(0), 1, LABEL_ADDR(juvix_closure_calculate), 2, 1);
CLOSURE_ARG(TMP(0), 0) = make_smallint(5);
CLOSURE_ARG(TMP(0), 1) = make_smallint(3);
ASSIGN_CARGS(TMP(0), { CARG(juvix_closure_nargs) = make_smallint(2); });
CALL_CLOSURE(TMP(0), juvix_label_1);
RETURN;
}

Expand Down
13 changes: 6 additions & 7 deletions tests/runtime/positive/test011.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ int main() {

JUVIX_FUNCTION(juvix_function_main, 0);
{
DECL_STMP(0);
ALLOC_CLOSURE(STMP(0), 1, LABEL_ADDR(juvix_closure_calculate), 2, 1);
CLOSURE_ARG(STMP(0), 0) = make_smallint(5);
CLOSURE_ARG(STMP(0), 1) = make_smallint(3);
ASSIGN_CARGS(STMP(0),
{ CARG(juvix_closure_nargs) = make_smallint(2); });
TAIL_CALL_CLOSURE(STMP(0));
DECL_TMP(0);
ALLOC_CLOSURE(TMP(0), 1, LABEL_ADDR(juvix_closure_calculate), 2, 1);
CLOSURE_ARG(TMP(0), 0) = make_smallint(5);
CLOSURE_ARG(TMP(0), 1) = make_smallint(3);
ASSIGN_CARGS(TMP(0), { CARG(juvix_closure_nargs) = make_smallint(2); });
TAIL_CALL_CLOSURE(TMP(0));
}

JUVIX_EPILOGUE;
Expand Down