Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tropical semiring #151

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions examples/tropical.gr
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
-- examples for tropical semirings

isAssocPlus1 : forall {a b : Type, n m k : Tropical} . (a [(n + m) + k] -> b) -> (a [n + (m + k)] -> b)
isAssocPlus1 f = f

isAssocPlus2 : forall {a b : Type, n m k : Tropical} . (a [n + (m + k)] -> b) -> (a [(n + m) + k] -> b)
isAssocPlus2 f = f

isAssocMult1 : forall {a b : Type, n m k : Tropical} . (a [(n * m) * k] -> b) -> (a [n * (m * k)] -> b)
isAssocMult1 f = f

isAssocMult2 : forall {a b : Type, n m k : Tropical} . (a [n * (m * k)] -> b) -> (a [(n * m) * k] -> b)
isAssocMult2 f = f

isDistrib1 : forall {a b : Type, n m k : Tropical} . (a [(n + m) * k] -> b) -> (a [(n * k) + (m * k)] -> b)
isDistrib1 f = f

isDistrib2 : forall {a b : Type, n m k : Tropical} . (a [(n * k) + (m * k)] -> b) -> (a [(n + m) * k] -> b)
isDistrib2 f = f

isCommutePlus : forall {a b : Type, n m : Tropical} . (a [n + m] -> b) -> (a [m + n] -> b)
isCommutePlus f = f

isCommuteMult : forall {a b : Type, n m : Tropical} . (a [n * m] -> b) -> (a [m * n] -> b)
isCommuteMult f = f

zeroPlusIdentity : forall {a b : Type, n : Tropical} . (a [n + (Inf : Tropical)] -> b) -> (a [n] -> b)
zeroPlusIdentity f = f

oneTimesIdentity : forall {a b : Type, n : Tropical} . (a [n * (1 : Tropical)] -> b) -> (a [n] -> b)
oneTimesIdentity f = f

zeroAbsorbs : forall {a b : Type, n : Tropical} . (a [n * (Inf : Tropical)] -> b) -> (a [Inf : Tropical] -> b)
zeroAbsorbs f = f
2 changes: 2 additions & 0 deletions frontend/package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ library:
- Language.Granule.Checker.Checker
- Language.Granule.Checker.Coeffects
- Language.Granule.Checker.Constraints
- Language.Granule.Checker.Constraints.Semiring
- Language.Granule.Checker.Constraints.SNatX
- Language.Granule.Checker.Constraints.SExt
- Language.Granule.Checker.Flatten
- Language.Granule.Checker.KindsHelpers
- Language.Granule.Checker.LaTeX
Expand Down
12 changes: 12 additions & 0 deletions frontend/src/Language/Granule/Checker/Constraints.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import Language.Granule.Checker.Predicates
import Language.Granule.Context (Ctxt)

import Language.Granule.Checker.Constraints.SymbolicGrades
import Language.Granule.Checker.Constraints.SExt
import qualified Language.Granule.Checker.Constraints.SNatX as SNatX

import Language.Granule.Syntax.Helpers
Expand Down Expand Up @@ -170,6 +171,8 @@ freshCVarScoped quant name (TyCon conName) q k =
.|| solverVar .== literal unusedRepresentation
, SLevel solverVar)
"OOZ" -> k (solverVar .== 0 .|| solverVar .== 1, SOOZ (ite (solverVar .== 0) sFalse sTrue))
"Tropical" -> k ( solverVar .>= 0
, STropical (fromSInteger solverVar))
k -> solverError $ "I don't know how to make a fresh solver variable of type " <> show conName)

freshCVarScoped quant name t q k | t == extendedNat = do
Expand Down Expand Up @@ -306,13 +309,18 @@ compileCoeffect (CInfinity (Just (TyVar _))) _ _ = return (zeroToInfinity, sTrue
compileCoeffect (CInfinity Nothing) _ _ = return (zeroToInfinity, sTrue)
compileCoeffect (CInfinity _) t _| t == extendedNat =
return (SExtNat SNatX.inf, sTrue)
compileCoeffect (CInfinity _) t _| t == tropical =
return (STropical top, sTrue)

compileCoeffect (CNat n) k _ | k == nat =
return (SNat . fromInteger . toInteger $ n, sTrue)

compileCoeffect (CNat n) k _ | k == extendedNat =
return (SExtNat . fromInteger . toInteger $ n, sTrue)

compileCoeffect (CNat n) k _ | k == tropical =
return (STropical . fromSInteger . fromInteger . toInteger $ n, sTrue)

compileCoeffect (CFloat r) (TyCon k) _ | internalName k == "Q" =
return (SFloat . fromRational $ r, sTrue)

Expand Down Expand Up @@ -361,6 +369,7 @@ compileCoeffect (CZero k') k vars =
"Q" -> return (SFloat (fromRational 0), sTrue)
"Set" -> return (SSet (S.fromList []), sTrue)
"OOZ" -> return (SOOZ sFalse, sTrue)
"Tropical" -> return (STropical zero, sTrue)
_ -> solverError $ "I don't know how to compile a 0 for " <> pretty k'
(otherK', otherK) | (otherK' == extendedNat || otherK == extendedNat) ->
return (SExtNat 0, sTrue)
Expand All @@ -387,6 +396,7 @@ compileCoeffect (COne k') k vars =
"Q" -> return (SFloat (fromRational 1), sTrue)
"Set" -> return (SSet (S.fromList []), sTrue)
"OOZ" -> return (SOOZ sTrue, sTrue)
"Tropical" -> return (STropical one, sTrue)
_ -> solverError $ "I don't know how to compile a 1 for " <> pretty k'

(otherK', otherK) | (otherK' == extendedNat || otherK == extendedNat) ->
Expand Down Expand Up @@ -436,6 +446,7 @@ eqConstraint (SFloat n) (SFloat m) = return $ n .== m
eqConstraint (SLevel l) (SLevel k) = return $ l .== k
eqConstraint u@(SUnknown{}) u'@(SUnknown{}) = symGradeEq u u'
eqConstraint (SExtNat x) (SExtNat y) = return $ x .== y
eqConstraint (STropical x) (STropical y) = return $ x .== y
eqConstraint SPoint SPoint = return sTrue

eqConstraint (SInterval lb1 ub1) (SInterval lb2 ub2) =
Expand All @@ -454,6 +465,7 @@ approximatedByOrEqualConstraint (SFloat n) (SFloat m) = return $ n .<= m
approximatedByOrEqualConstraint SPoint SPoint = return $ sTrue
approximatedByOrEqualConstraint (SExtNat x) (SExtNat y) = return $ x .== y
approximatedByOrEqualConstraint (SOOZ s) (SOOZ r) = pure $ s .== r
approximatedByOrEqualConstraint (STropical x) (STropical y) = return $ x .== y
approximatedByOrEqualConstraint (SSet s) (SSet t) =
return $ if s == t then sTrue else sFalse

Expand Down
147 changes: 147 additions & 0 deletions frontend/src/Language/Granule/Checker/Constraints/SExt.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeSynonymInstances #-}

-- | Represents extended coeffects.
module Language.Granule.Checker.Constraints.SExt
(
-- * Extended coeffects.
SExt
, pattern SExtUn
, pattern SExtX
, isExt
, extBase

-- * Top-completed coeffects.
, TopCompleted
, isTop
, notTop
, HasTop(..)

-- * Semirings
, Semiring(..)
, Tropical

-- * Helpers
, FromSInteger(..)
) where


import Data.Either (isRight)
import Data.SBV


import Language.Granule.Checker.Constraints.Semiring


-- | Base @b@, extension type @e@.
newtype SExt b e = SExt { unExt :: Either b e }
deriving (EqSymbolic, Mergeable)


pattern SExtUn :: b -> SExt b e
pattern SExtUn b = SExt (Left b)

pattern SExtX :: e -> SExt b e
pattern SExtX e = SExt (Right e)


instance (Show b, Show e) => Show (SExt b e) where
show = either show show . unExt


isExt :: SExt b e -> Bool
isExt = isRight . unExt


extBase :: b -> SExt b e
extBase = SExtUn


extWith :: e -> SExt b e
extWith = SExtX


exts :: (b -> e -> c) ->
(e -> b -> c) ->
(b -> b -> c) ->
(e -> e -> c) ->
SExt b e -> SExt b e -> c
exts lr rl ll rr x y = either (\b -> either (ll b) (lr b) (unExt y))
(\e -> either (rl e) (rr e) (unExt y)) (unExt x)


const2 :: c -> a -> b -> c
const2 x _ _ = x


newtype Top = Top ()
deriving (Mergeable)


instance EqSymbolic Top where
(Top ()) .== (Top ()) = sTrue


instance Show Top where
show _ = "∞"


type TopCompleted a = SExt a Top


instance HasTop (TopCompleted b) where
top = extWith (Top ())


notTop :: b -> TopCompleted b
notTop = extBase


isTop :: TopCompleted b -> Bool
isTop = isExt


isZero :: (EqSymbolic a, Semiring a) => a -> SBool
isZero = ((.==) zero)


instance (OrdSymbolic b) => OrdSymbolic (TopCompleted b) where
x .< y = exts (const2 sTrue) (const2 sFalse) (\a b -> a .< b) (const2 sFalse) x y


-----------------------
-- Tropical Semiring --
-----------------------


newtype Tropical = Tropical { unTropical :: TopCompleted SInteger }
deriving (Show, EqSymbolic, HasTop, Mergeable, OrdSymbolic)


instance Semiring Tropical where
zero = Tropical top
one = Tropical (notTop 0)
plus x y =
Tropical $ exts (\b _ -> extBase b) (\_ b -> extBase b) (\a b -> extBase (smin a b)) (const2 top) (unTropical x) (unTropical y)

times x y = ite (isZero x .|| isZero y) zero
(Tropical $ (exts (const2 top) (const2 top) (\a b -> extBase (a + b)) (const2 top) (unTropical x) (unTropical y)))


instance FromSInteger Tropical where
fromSInteger = Tropical . extBase


-------------------
----- Helpers -----
-------------------


class FromSInteger a where
fromSInteger :: SInteger -> a


class HasTop a where
top :: a
13 changes: 13 additions & 0 deletions frontend/src/Language/Granule/Checker/Constraints/Semiring.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- | Represents extended coeffects.
module Language.Granule.Checker.Constraints.Semiring
(
-- * Semirings
Semiring(..)
) where


class Semiring a where
zero :: a
one :: a
plus :: a -> a -> a
times :: a -> a -> a
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module Language.Granule.Checker.Constraints.SymbolicGrades where

import Language.Granule.Syntax.Identifiers
import Language.Granule.Syntax.Type
import Language.Granule.Checker.Constraints.SExt
import Language.Granule.Checker.Constraints.SNatX

import Data.Functor.Identity
Expand All @@ -31,6 +32,7 @@ data SGrade =
| SLevel SInteger
| SSet (S.Set (Id, Type))
| SExtNat { sExtNat :: SNatX }
| STropical { sTropical :: Tropical }
| SInterval { sLowerBound :: SGrade, sUpperBound :: SGrade }
-- Single point coeffect (not exposed at the moment)
| SPoint
Expand Down Expand Up @@ -111,6 +113,7 @@ match (SFloat _) (SFloat _) = True
match (SLevel _) (SLevel _) = True
match (SSet _) (SSet _) = True
match (SExtNat _) (SExtNat _) = True
match (STropical _) (STropical _) = True
match (SInterval s1 s2) (SInterval t1 t2) = match s1 t1 && match t1 t2
match SPoint SPoint = True
match (SProduct s1 s2) (SProduct t1 t2) = match s1 t1 && match s2 t2
Expand Down Expand Up @@ -154,6 +157,7 @@ applyToProducts _ _ _ a b =
natLike :: SGrade -> Bool
natLike (SNat _) = True
natLike (SExtNat _) = True
natLike (STropical _) = True
natLike _ = False

instance Mergeable SGrade where
Expand All @@ -162,6 +166,7 @@ instance Mergeable SGrade where
symbolicMerge s sb (SLevel n) (SLevel n') = SLevel (symbolicMerge s sb n n')
symbolicMerge s sb (SSet n) (SSet n') = error "Can't symbolic merge sets yet"
symbolicMerge s sb (SExtNat n) (SExtNat n') = SExtNat (symbolicMerge s sb n n')
symbolicMerge s sb (STropical n) (STropical n') = STropical (symbolicMerge s sb n n')
symbolicMerge s sb (SInterval lb1 ub1) (SInterval lb2 ub2) =
SInterval (symbolicMerge s sb lb1 lb2) (symbolicMerge s sb ub1 ub2)
symbolicMerge s sb SPoint SPoint = SPoint
Expand Down Expand Up @@ -218,6 +223,7 @@ symGradeEq (SFloat n) (SFloat n') = return $ n .== n'
symGradeEq (SLevel n) (SLevel n') = return $ n .== n'
symGradeEq (SSet n) (SSet n') = solverError "Can't compare symbolic sets yet"
symGradeEq (SExtNat n) (SExtNat n') = return $ n .== n'
symGradeEq (STropical n) (STropical n') = return $ n .== n'
symGradeEq SPoint SPoint = return $ sTrue
symGradeEq (SOOZ s) (SOOZ r) = pure $ s .== r
symGradeEq s t | isSProduct s || isSProduct t =
Expand Down Expand Up @@ -271,6 +277,7 @@ symGradePlus (SSet s) (SSet t) = return $ SSet $ S.union s t
symGradePlus (SLevel lev1) (SLevel lev2) = return $ SLevel $ lev1 `smax` lev2
symGradePlus (SFloat n1) (SFloat n2) = return $ SFloat $ n1 + n2
symGradePlus (SExtNat x) (SExtNat y) = return $ SExtNat (x + y)
symGradePlus (STropical x) (STropical y) = pure $ STropical (plus x y)
symGradePlus (SInterval lb1 ub1) (SInterval lb2 ub2) =
liftM2 SInterval (lb1 `symGradePlus` lb2) (ub1 `symGradePlus` ub2)
symGradePlus SPoint SPoint = return $ SPoint
Expand Down Expand Up @@ -310,6 +317,7 @@ symGradeTimes (SLevel lev1) (SLevel lev2) = return $
symGradeTimes (SFloat n1) (SFloat n2) = return $ SFloat $ n1 * n2
symGradeTimes (SExtNat x) (SExtNat y) = return $ SExtNat (x * y)
symGradeTimes (SOOZ s) (SOOZ r) = pure . SOOZ $ s .&& r
symGradeTimes (STropical x) (STropical y) = pure $ STropical (times x y)

symGradeTimes (SInterval lb1 ub1) (SInterval lb2 ub2) =
liftM2 SInterval (comb symGradeMeet) (comb symGradeJoin)
Expand Down
1 change: 1 addition & 0 deletions frontend/src/Language/Granule/Checker/Primitives.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ typeConstructors =
, (mkId "Public", (KPromote (TyCon $ mkId "Level"), [], False))
, (mkId "Unused", (KPromote (TyCon $ mkId "Level"), [], False))
, (mkId "OOZ", (KCoeffect, [], False)) -- 1 + 1 = 0
, (mkId "Tropical", (KCoeffect, [], False)) -- Tropical semiring
, (mkId "Interval", (KFun KCoeffect KCoeffect, [], False))
, (mkId "Set", (KFun (KVar $ mkId "k") (KFun (kConstr $ mkId "k") KCoeffect), [], False))
-- Channels and protocol types
Expand Down
3 changes: 2 additions & 1 deletion frontend/src/Language/Granule/Syntax/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,10 @@ publicRepresentation = 2
unusedRepresentation :: Integer
unusedRepresentation = 0

nat, extendedNat :: Type
nat, extendedNat, tropical :: Type
nat = TyCon $ mkId "Nat"
extendedNat = TyApp (TyCon $ mkId "Ext") (TyCon $ mkId "Nat")
tropical = TyCon $ mkId "Tropical"

infinity :: Coeffect
infinity = CInfinity (Just extendedNat)
Expand Down