Skip to content

Commit

Permalink
record update translation
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszcz committed Aug 22, 2024
1 parent 07f4337 commit 7298271
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 22 deletions.
45 changes: 43 additions & 2 deletions src/Juvix/Compiler/Backend/Isabelle/Translation/FromTyped.hs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ fromInternal Internal.InternalTypedResult {..} = do
goModule :: Bool -> Internal.InfoTable -> Internal.Module -> Theory
goModule onlyTypes infoTable Internal.Module {..} =
Theory
{ _theoryName = over nameText toIsabelleName $ over namePretty toIsabelleName _moduleName,
{ _theoryName = overNameText toIsabelleName _moduleName,
_theoryImports = map (^. Internal.importModuleName) (_moduleBody ^. Internal.moduleImports),
_theoryStatements = case _modulePragmas ^. pragmasIsabelleIgnore of
Just (PragmaIsabelleIgnore True) -> []
Expand Down Expand Up @@ -498,6 +498,16 @@ goModule onlyTypes infoTable Internal.Module {..} =
| Just (name, fields) <- getRecordCreation app = do
fields' <- mapM (secondM goExpression) fields
return $ ExprRecord (Record name fields')
| Just (indName, names, record, fields) <- getRecordUpdate app = do
record' <- goExpression record
let names' = map (qualifyRecordProjection indName) names
nset <- ask @NameSet
nmap <- ask @NameMap
let nset' = foldl' (flip (over nameSet . HashSet.insert . (^. namePretty))) nset names'
exprs = map (\n -> ExprApp (Application (ExprIden n) record')) names'
nmap' = foldl' (flip (over nameMap . uncurry HashMap.insert)) nmap (zipExact names exprs)
fields' <- mapM (secondM (withLocalNames nset' nmap' . goExpression)) fields
return $ ExprRecordUpdate (RecordUpdate record' (Record indName fields'))
| Just (fn, args) <- getIdentApp app = do
fn' <- goExpression fn
args' <- mapM goExpression args
Expand Down Expand Up @@ -666,6 +676,33 @@ goModule onlyTypes infoTable Internal.Module {..} =
where
(fn, args) = Internal.unfoldApplication app

getRecordUpdate :: Internal.Application -> Maybe (Name, [Name], Internal.Expression, [(Name, Internal.Expression)])
getRecordUpdate Internal.Application {..} = case _appLeft of
Internal.ExpressionLambda Internal.Lambda {..} -> case _lambdaClauses of
Internal.LambdaClause {..} :| [] -> case fmap (^. Internal.patternArgPattern) _lambdaPatterns of
Internal.PatternConstructorApp Internal.ConstructorApp {..} :| []
| all isPatternArgVar _constrAppParameters ->
case _lambdaBody of
Internal.ExpressionApplication app -> case fn of
Internal.ExpressionIden (Internal.IdenConstructor name')
| name' == _constrAppConstructor ->
case HashMap.lookup name' (infoTable ^. Internal.infoConstructors) of
Just ctrInfo
| ctrInfo ^. Internal.constructorInfoRecord ->
let names = map (fromJust . getPatternArgName) _constrAppParameters
fields = goRecordFields (getArgtys ctrInfo) (toList args)
fields' = zipWithExact (\n (n', e) -> (setNameText (n' ^. namePretty) n, e)) names fields
fields'' = filter (\(n, e) -> e /= Internal.ExpressionIden (Internal.IdenVar n)) fields'
in Just (ctrInfo ^. Internal.constructorInfoInductive, map fst fields', _appRight, fields'')
_ -> Nothing
_ -> Nothing
where
(fn, args) = Internal.unfoldApplication app
_ -> Nothing
_ -> Nothing
_ -> Nothing
_ -> Nothing

getIdentApp :: Internal.Application -> Maybe (Internal.Expression, [Internal.Expression])
getIdentApp app = case mty of
Just (ty, paramsNum) -> Just (fn, args')
Expand Down Expand Up @@ -1000,7 +1037,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
binders <- gets (^. nameSet)
let adjustName :: Name -> Expression
adjustName name =
let name' = overNameText (\n -> indName ^. nameText <> "." <> n) name
let name' = qualifyRecordProjection indName name
in ExprApp (Application (ExprIden name') (ExprIden vname))
vname = defaultName (disambiguate binders "v")
fieldsVars = map (second (fromJust . getPatternArgName)) $ map (first adjustName) $ filter (isPatternArgVar . snd) fields
Expand Down Expand Up @@ -1112,6 +1149,10 @@ goModule onlyTypes infoTable Internal.Module {..} =
_nameIdModuleId = defaultModuleId
}

qualifyRecordProjection :: Name -> Name -> Name
qualifyRecordProjection indName name =
setNameText (indName ^. namePretty <> "." <> name ^. namePretty) name

setNameText :: Text -> Name -> Name
setNameText txt name =
set namePretty txt
Expand Down
3 changes: 3 additions & 0 deletions tests/positive/Isabelle/Program.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ funR3 (er : Either' R R) : R :=
| Right' mkR@{r1; r2 := zero} := mkR@{r1 := 7; r2 := r1}
| Right' r@(mkR@{r1; r2}) := r@R{r1 := r2 + 2; r2 := r1 + 3};

funR4 : R -> R
| r@mkR@{r1} := r@R{r2 := r1};

-- Standard library

bf (b1 b2 : Bool) : Bool := not (b1 && b2);
Expand Down
33 changes: 13 additions & 20 deletions tests/positive/Isabelle/isabelle/Program.thy
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,14 @@ definition v :: nat where
fun funR :: "R \<Rightarrow> R" where
"funR r' =
(case r' of
v' \<Rightarrow>
(\<lambda> x0 . case x0 of
v'0 \<Rightarrow> (| R.r1 = R.r1 v'0 + R.r2 v'0, R.r2 = R.r2 v'0 |)) r')"
v' \<Rightarrow> r' (| R.r1 := R.r1 r' + R.r2 r' |))"

fun funRR :: "R \<Rightarrow> R" where
"funRR r'0 =
(case (r'0) of
(r') \<Rightarrow>
(case (r') of
(v') \<Rightarrow>
(\<lambda> x0 . case x0 of
v'0 \<Rightarrow>
(| R.r1 = R.r1 v'0 + R.r2 v'0, R.r2 = R.r2 v'0 |)) r'))"
(v') \<Rightarrow> r' (| R.r1 := R.r1 r' + R.r2 r' |)))"

fun funR' :: "R \<Rightarrow> R" where
"funR' (| R.r1 = rr1, R.r2 = rr2 |) =
Expand Down Expand Up @@ -206,17 +201,13 @@ fun funR3 :: "(R, R) Either' \<Rightarrow> R" where
(Right' r') \<Rightarrow>
(case (r') of
(v'3) \<Rightarrow>
(\<lambda> x0 . case x0 of
v'5 \<Rightarrow>
(| R.r1 = R.r2 v'5 + 2, R.r2 = R.r1 v'5 + 3 |)) r')))) |
r' (| R.r1 := R.r2 r' + 2, R.r2 := R.r1 r' + 3 |))))) |
v'4 \<Rightarrow>
(case v'4 of
(Right' r') \<Rightarrow>
(case (r') of
(v'3) \<Rightarrow>
(\<lambda> x0 . case x0 of
v'5 \<Rightarrow>
(| R.r1 = R.r2 v'5 + 2, R.r2 = R.r1 v'5 + 3 |)) r'))))) |
r' (| R.r1 := R.r2 r' + 2, R.r2 := R.r1 r' + 3 |)))))) |
v'2 \<Rightarrow>
(case v'2 of
(Right' v'1) \<Rightarrow>
Expand All @@ -233,17 +224,19 @@ fun funR3 :: "(R, R) Either' \<Rightarrow> R" where
(Right' r') \<Rightarrow>
(case (r') of
(v'3) \<Rightarrow>
(\<lambda> x0 . case x0 of
v'5 \<Rightarrow>
(| R.r1 = R.r2 v'5 + 2, R.r2 = R.r1 v'5 + 3 |)) r')))) |
r' (| R.r1 := R.r2 r' + 2, R.r2 := R.r1 r' + 3 |))))) |
v'4 \<Rightarrow>
(case v'4 of
(Right' r') \<Rightarrow>
(case (r') of
(v'3) \<Rightarrow>
(\<lambda> x0 . case x0 of
v'5 \<Rightarrow>
(| R.r1 = R.r2 v'5 + 2, R.r2 = R.r1 v'5 + 3 |)) r')))))"
(v'3) \<Rightarrow> r' (| R.r1 := R.r2 r' + 2, R.r2 := R.r1 r' + 3 |))))))"

fun funR4 :: "R \<Rightarrow> R" where
"funR4 r'0 =
(case (r'0) of
(r') \<Rightarrow>
(case (r') of
(v') \<Rightarrow> r' (| R.r2 := R.r1 r' |)))"

fun bf :: "bool \<Rightarrow> bool \<Rightarrow> bool" where
"bf b1 b2 = (\<not> (b1 \<and> b2))"
Expand Down

0 comments on commit 7298271

Please sign in to comment.