From b03ec76abcb4b2028ad52683e29c1ec4dd932a5b Mon Sep 17 00:00:00 2001 From: Ben Moon Date: Mon, 13 Jul 2020 16:18:32 +0100 Subject: [PATCH] Add tropical semiring --- examples/tropical.gr | 34 ++++ frontend/package.yaml | 2 + .../Language/Granule/Checker/Constraints.hs | 12 ++ .../Granule/Checker/Constraints/SExt.hs | 147 ++++++++++++++++++ .../Granule/Checker/Constraints/Semiring.hs | 13 ++ .../Checker/Constraints/SymbolicGrades.hs | 8 + .../Language/Granule/Checker/Primitives.hs | 1 + frontend/src/Language/Granule/Syntax/Type.hs | 3 +- 8 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 examples/tropical.gr create mode 100644 frontend/src/Language/Granule/Checker/Constraints/SExt.hs create mode 100644 frontend/src/Language/Granule/Checker/Constraints/Semiring.hs diff --git a/examples/tropical.gr b/examples/tropical.gr new file mode 100644 index 000000000..0e4965657 --- /dev/null +++ b/examples/tropical.gr @@ -0,0 +1,34 @@ +-- examples for tropical semirings + +isAssocPlus1 : forall {a b : Type, n m k : Tropical} . (a [(n + m) + k] -> b) -> (a [n + (m + k)] -> b) +isAssocPlus1 f = f + +isAssocPlus2 : forall {a b : Type, n m k : Tropical} . (a [n + (m + k)] -> b) -> (a [(n + m) + k] -> b) +isAssocPlus2 f = f + +isAssocMult1 : forall {a b : Type, n m k : Tropical} . (a [(n * m) * k] -> b) -> (a [n * (m * k)] -> b) +isAssocMult1 f = f + +isAssocMult2 : forall {a b : Type, n m k : Tropical} . (a [n * (m * k)] -> b) -> (a [(n * m) * k] -> b) +isAssocMult2 f = f + +isDistrib1 : forall {a b : Type, n m k : Tropical} . (a [(n + m) * k] -> b) -> (a [(n * k) + (m * k)] -> b) +isDistrib1 f = f + +isDistrib2 : forall {a b : Type, n m k : Tropical} . (a [(n * k) + (m * k)] -> b) -> (a [(n + m) * k] -> b) +isDistrib2 f = f + +isCommutePlus : forall {a b : Type, n m : Tropical} . (a [n + m] -> b) -> (a [m + n] -> b) +isCommutePlus f = f + +isCommuteMult : forall {a b : Type, n m : Tropical} . (a [n * m] -> b) -> (a [m * n] -> b) +isCommuteMult f = f + +zeroPlusIdentity : forall {a b : Type, n : Tropical} . (a [n + (Inf : Tropical)] -> b) -> (a [n] -> b) +zeroPlusIdentity f = f + +oneTimesIdentity : forall {a b : Type, n : Tropical} . (a [n * (1 : Tropical)] -> b) -> (a [n] -> b) +oneTimesIdentity f = f + +zeroAbsorbs : forall {a b : Type, n : Tropical} . (a [n * (Inf : Tropical)] -> b) -> (a [Inf : Tropical] -> b) +zeroAbsorbs f = f diff --git a/frontend/package.yaml b/frontend/package.yaml index 5eb644ea4..e4bcbf95c 100644 --- a/frontend/package.yaml +++ b/frontend/package.yaml @@ -43,7 +43,9 @@ library: - Language.Granule.Checker.Checker - Language.Granule.Checker.Coeffects - Language.Granule.Checker.Constraints + - Language.Granule.Checker.Constraints.Semiring - Language.Granule.Checker.Constraints.SNatX + - Language.Granule.Checker.Constraints.SExt - Language.Granule.Checker.Flatten - Language.Granule.Checker.KindsHelpers - Language.Granule.Checker.LaTeX diff --git a/frontend/src/Language/Granule/Checker/Constraints.hs b/frontend/src/Language/Granule/Checker/Constraints.hs index 8b0d6ee7c..ab5bed68c 100644 --- a/frontend/src/Language/Granule/Checker/Constraints.hs +++ b/frontend/src/Language/Granule/Checker/Constraints.hs @@ -21,6 +21,7 @@ import Language.Granule.Checker.Predicates import Language.Granule.Context (Ctxt) import Language.Granule.Checker.Constraints.SymbolicGrades +import Language.Granule.Checker.Constraints.SExt import qualified Language.Granule.Checker.Constraints.SNatX as SNatX import Language.Granule.Syntax.Helpers @@ -170,6 +171,8 @@ freshCVarScoped quant name (TyCon conName) q k = .|| solverVar .== literal unusedRepresentation , SLevel solverVar) "OOZ" -> k (solverVar .== 0 .|| solverVar .== 1, SOOZ (ite (solverVar .== 0) sFalse sTrue)) + "Tropical" -> k ( solverVar .>= 0 + , STropical (fromSInteger solverVar)) k -> solverError $ "I don't know how to make a fresh solver variable of type " <> show conName) freshCVarScoped quant name t q k | t == extendedNat = do @@ -306,6 +309,8 @@ compileCoeffect (CInfinity (Just (TyVar _))) _ _ = return (zeroToInfinity, sTrue compileCoeffect (CInfinity Nothing) _ _ = return (zeroToInfinity, sTrue) compileCoeffect (CInfinity _) t _| t == extendedNat = return (SExtNat SNatX.inf, sTrue) +compileCoeffect (CInfinity _) t _| t == tropical = + return (STropical top, sTrue) compileCoeffect (CNat n) k _ | k == nat = return (SNat . fromInteger . toInteger $ n, sTrue) @@ -313,6 +318,9 @@ compileCoeffect (CNat n) k _ | k == nat = compileCoeffect (CNat n) k _ | k == extendedNat = return (SExtNat . fromInteger . toInteger $ n, sTrue) +compileCoeffect (CNat n) k _ | k == tropical = + return (STropical . fromSInteger . fromInteger . toInteger $ n, sTrue) + compileCoeffect (CFloat r) (TyCon k) _ | internalName k == "Q" = return (SFloat . fromRational $ r, sTrue) @@ -361,6 +369,7 @@ compileCoeffect (CZero k') k vars = "Q" -> return (SFloat (fromRational 0), sTrue) "Set" -> return (SSet (S.fromList []), sTrue) "OOZ" -> return (SOOZ sFalse, sTrue) + "Tropical" -> return (STropical zero, sTrue) _ -> solverError $ "I don't know how to compile a 0 for " <> pretty k' (otherK', otherK) | (otherK' == extendedNat || otherK == extendedNat) -> return (SExtNat 0, sTrue) @@ -387,6 +396,7 @@ compileCoeffect (COne k') k vars = "Q" -> return (SFloat (fromRational 1), sTrue) "Set" -> return (SSet (S.fromList []), sTrue) "OOZ" -> return (SOOZ sTrue, sTrue) + "Tropical" -> return (STropical one, sTrue) _ -> solverError $ "I don't know how to compile a 1 for " <> pretty k' (otherK', otherK) | (otherK' == extendedNat || otherK == extendedNat) -> @@ -436,6 +446,7 @@ eqConstraint (SFloat n) (SFloat m) = return $ n .== m eqConstraint (SLevel l) (SLevel k) = return $ l .== k eqConstraint u@(SUnknown{}) u'@(SUnknown{}) = symGradeEq u u' eqConstraint (SExtNat x) (SExtNat y) = return $ x .== y +eqConstraint (STropical x) (STropical y) = return $ x .== y eqConstraint SPoint SPoint = return sTrue eqConstraint (SInterval lb1 ub1) (SInterval lb2 ub2) = @@ -454,6 +465,7 @@ approximatedByOrEqualConstraint (SFloat n) (SFloat m) = return $ n .<= m approximatedByOrEqualConstraint SPoint SPoint = return $ sTrue approximatedByOrEqualConstraint (SExtNat x) (SExtNat y) = return $ x .== y approximatedByOrEqualConstraint (SOOZ s) (SOOZ r) = pure $ s .== r +approximatedByOrEqualConstraint (STropical x) (STropical y) = return $ x .== y approximatedByOrEqualConstraint (SSet s) (SSet t) = return $ if s == t then sTrue else sFalse diff --git a/frontend/src/Language/Granule/Checker/Constraints/SExt.hs b/frontend/src/Language/Granule/Checker/Constraints/SExt.hs new file mode 100644 index 000000000..0be7f9409 --- /dev/null +++ b/frontend/src/Language/Granule/Checker/Constraints/SExt.hs @@ -0,0 +1,147 @@ +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TypeSynonymInstances #-} + +-- | Represents extended coeffects. +module Language.Granule.Checker.Constraints.SExt + ( + -- * Extended coeffects. + SExt + , pattern SExtUn + , pattern SExtX + , isExt + , extBase + + -- * Top-completed coeffects. + , TopCompleted + , isTop + , notTop + , HasTop(..) + + -- * Semirings + , Semiring(..) + , Tropical + + -- * Helpers + , FromSInteger(..) + ) where + + +import Data.Either (isRight) +import Data.SBV + + +import Language.Granule.Checker.Constraints.Semiring + + +-- | Base @b@, extension type @e@. +newtype SExt b e = SExt { unExt :: Either b e } + deriving (EqSymbolic, Mergeable) + + +pattern SExtUn :: b -> SExt b e +pattern SExtUn b = SExt (Left b) + +pattern SExtX :: e -> SExt b e +pattern SExtX e = SExt (Right e) + + +instance (Show b, Show e) => Show (SExt b e) where + show = either show show . unExt + + +isExt :: SExt b e -> Bool +isExt = isRight . unExt + + +extBase :: b -> SExt b e +extBase = SExtUn + + +extWith :: e -> SExt b e +extWith = SExtX + + +exts :: (b -> e -> c) -> + (e -> b -> c) -> + (b -> b -> c) -> + (e -> e -> c) -> + SExt b e -> SExt b e -> c +exts lr rl ll rr x y = either (\b -> either (ll b) (lr b) (unExt y)) + (\e -> either (rl e) (rr e) (unExt y)) (unExt x) + + +const2 :: c -> a -> b -> c +const2 x _ _ = x + + +newtype Top = Top () + deriving (Mergeable) + + +instance EqSymbolic Top where + (Top ()) .== (Top ()) = sTrue + + +instance Show Top where + show _ = "∞" + + +type TopCompleted a = SExt a Top + + +instance HasTop (TopCompleted b) where + top = extWith (Top ()) + + +notTop :: b -> TopCompleted b +notTop = extBase + + +isTop :: TopCompleted b -> Bool +isTop = isExt + + +isZero :: (EqSymbolic a, Semiring a) => a -> SBool +isZero = ((.==) zero) + + +instance (OrdSymbolic b) => OrdSymbolic (TopCompleted b) where + x .< y = exts (const2 sTrue) (const2 sFalse) (\a b -> a .< b) (const2 sFalse) x y + + +----------------------- +-- Tropical Semiring -- +----------------------- + + +newtype Tropical = Tropical { unTropical :: TopCompleted SInteger } + deriving (Show, EqSymbolic, HasTop, Mergeable, OrdSymbolic) + + +instance Semiring Tropical where + zero = Tropical top + one = Tropical (notTop 0) + plus x y = + Tropical $ exts (\b _ -> extBase b) (\_ b -> extBase b) (\a b -> extBase (smin a b)) (const2 top) (unTropical x) (unTropical y) + + times x y = ite (isZero x .|| isZero y) zero + (Tropical $ (exts (const2 top) (const2 top) (\a b -> extBase (a + b)) (const2 top) (unTropical x) (unTropical y))) + + +instance FromSInteger Tropical where + fromSInteger = Tropical . extBase + + +------------------- +----- Helpers ----- +------------------- + + +class FromSInteger a where + fromSInteger :: SInteger -> a + + +class HasTop a where + top :: a diff --git a/frontend/src/Language/Granule/Checker/Constraints/Semiring.hs b/frontend/src/Language/Granule/Checker/Constraints/Semiring.hs new file mode 100644 index 000000000..f2d449b3d --- /dev/null +++ b/frontend/src/Language/Granule/Checker/Constraints/Semiring.hs @@ -0,0 +1,13 @@ +-- | Represents extended coeffects. +module Language.Granule.Checker.Constraints.Semiring + ( + -- * Semirings + Semiring(..) + ) where + + +class Semiring a where + zero :: a + one :: a + plus :: a -> a -> a + times :: a -> a -> a diff --git a/frontend/src/Language/Granule/Checker/Constraints/SymbolicGrades.hs b/frontend/src/Language/Granule/Checker/Constraints/SymbolicGrades.hs index 95651e5da..d8395df14 100644 --- a/frontend/src/Language/Granule/Checker/Constraints/SymbolicGrades.hs +++ b/frontend/src/Language/Granule/Checker/Constraints/SymbolicGrades.hs @@ -8,6 +8,7 @@ module Language.Granule.Checker.Constraints.SymbolicGrades where import Language.Granule.Syntax.Identifiers import Language.Granule.Syntax.Type +import Language.Granule.Checker.Constraints.SExt import Language.Granule.Checker.Constraints.SNatX import Data.Functor.Identity @@ -31,6 +32,7 @@ data SGrade = | SLevel SInteger | SSet (S.Set (Id, Type)) | SExtNat { sExtNat :: SNatX } + | STropical { sTropical :: Tropical } | SInterval { sLowerBound :: SGrade, sUpperBound :: SGrade } -- Single point coeffect (not exposed at the moment) | SPoint @@ -111,6 +113,7 @@ match (SFloat _) (SFloat _) = True match (SLevel _) (SLevel _) = True match (SSet _) (SSet _) = True match (SExtNat _) (SExtNat _) = True +match (STropical _) (STropical _) = True match (SInterval s1 s2) (SInterval t1 t2) = match s1 t1 && match t1 t2 match SPoint SPoint = True match (SProduct s1 s2) (SProduct t1 t2) = match s1 t1 && match s2 t2 @@ -154,6 +157,7 @@ applyToProducts _ _ _ a b = natLike :: SGrade -> Bool natLike (SNat _) = True natLike (SExtNat _) = True +natLike (STropical _) = True natLike _ = False instance Mergeable SGrade where @@ -162,6 +166,7 @@ instance Mergeable SGrade where symbolicMerge s sb (SLevel n) (SLevel n') = SLevel (symbolicMerge s sb n n') symbolicMerge s sb (SSet n) (SSet n') = error "Can't symbolic merge sets yet" symbolicMerge s sb (SExtNat n) (SExtNat n') = SExtNat (symbolicMerge s sb n n') + symbolicMerge s sb (STropical n) (STropical n') = STropical (symbolicMerge s sb n n') symbolicMerge s sb (SInterval lb1 ub1) (SInterval lb2 ub2) = SInterval (symbolicMerge s sb lb1 lb2) (symbolicMerge s sb ub1 ub2) symbolicMerge s sb SPoint SPoint = SPoint @@ -218,6 +223,7 @@ symGradeEq (SFloat n) (SFloat n') = return $ n .== n' symGradeEq (SLevel n) (SLevel n') = return $ n .== n' symGradeEq (SSet n) (SSet n') = solverError "Can't compare symbolic sets yet" symGradeEq (SExtNat n) (SExtNat n') = return $ n .== n' +symGradeEq (STropical n) (STropical n') = return $ n .== n' symGradeEq SPoint SPoint = return $ sTrue symGradeEq (SOOZ s) (SOOZ r) = pure $ s .== r symGradeEq s t | isSProduct s || isSProduct t = @@ -271,6 +277,7 @@ symGradePlus (SSet s) (SSet t) = return $ SSet $ S.union s t symGradePlus (SLevel lev1) (SLevel lev2) = return $ SLevel $ lev1 `smax` lev2 symGradePlus (SFloat n1) (SFloat n2) = return $ SFloat $ n1 + n2 symGradePlus (SExtNat x) (SExtNat y) = return $ SExtNat (x + y) +symGradePlus (STropical x) (STropical y) = pure $ STropical (plus x y) symGradePlus (SInterval lb1 ub1) (SInterval lb2 ub2) = liftM2 SInterval (lb1 `symGradePlus` lb2) (ub1 `symGradePlus` ub2) symGradePlus SPoint SPoint = return $ SPoint @@ -310,6 +317,7 @@ symGradeTimes (SLevel lev1) (SLevel lev2) = return $ symGradeTimes (SFloat n1) (SFloat n2) = return $ SFloat $ n1 * n2 symGradeTimes (SExtNat x) (SExtNat y) = return $ SExtNat (x * y) symGradeTimes (SOOZ s) (SOOZ r) = pure . SOOZ $ s .&& r +symGradeTimes (STropical x) (STropical y) = pure $ STropical (times x y) symGradeTimes (SInterval lb1 ub1) (SInterval lb2 ub2) = liftM2 SInterval (comb symGradeMeet) (comb symGradeJoin) diff --git a/frontend/src/Language/Granule/Checker/Primitives.hs b/frontend/src/Language/Granule/Checker/Primitives.hs index a1094207c..164fd548d 100644 --- a/frontend/src/Language/Granule/Checker/Primitives.hs +++ b/frontend/src/Language/Granule/Checker/Primitives.hs @@ -41,6 +41,7 @@ typeConstructors = , (mkId "Public", (KPromote (TyCon $ mkId "Level"), [], False)) , (mkId "Unused", (KPromote (TyCon $ mkId "Level"), [], False)) , (mkId "OOZ", (KCoeffect, [], False)) -- 1 + 1 = 0 + , (mkId "Tropical", (KCoeffect, [], False)) -- Tropical semiring , (mkId "Interval", (KFun KCoeffect KCoeffect, [], False)) , (mkId "Set", (KFun (KVar $ mkId "k") (KFun (kConstr $ mkId "k") KCoeffect), [], False)) -- Channels and protocol types diff --git a/frontend/src/Language/Granule/Syntax/Type.hs b/frontend/src/Language/Granule/Syntax/Type.hs index 9f5241e51..cb980a146 100644 --- a/frontend/src/Language/Granule/Syntax/Type.hs +++ b/frontend/src/Language/Granule/Syntax/Type.hs @@ -253,9 +253,10 @@ publicRepresentation = 2 unusedRepresentation :: Integer unusedRepresentation = 0 -nat, extendedNat :: Type +nat, extendedNat, tropical :: Type nat = TyCon $ mkId "Nat" extendedNat = TyApp (TyCon $ mkId "Ext") (TyCon $ mkId "Nat") +tropical = TyCon $ mkId "Tropical" infinity :: Coeffect infinity = CInfinity (Just extendedNat)