diff --git a/examples/tree_generator/demo.hs b/examples/tree_generator/demo.hs new file mode 100644 index 00000000..76e7696f --- /dev/null +++ b/examples/tree_generator/demo.hs @@ -0,0 +1,117 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ExplicitForAll #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TypeApplications #-} + +module Main where + +import Text.Parsec.String (parseFromFile) +import qualified Traq.Data.Context as Ctx +import qualified Traq.Data.Symbolic as Sym + +import qualified Traq.Analysis as A +import Traq.Prelude +import qualified Traq.ProtoLang as P + +import Traq.Analysis.CostModel.QueryCost (SimpleQueryCost (..)) +import Traq.Primitives (Primitive) +import Traq.Primitives.Amplify.CAmplify (CAmplify (..)) +import Traq.Primitives.Amplify.QAmplify (QAmplify (..)) +import qualified Traq.Utils.Printing as PP + +type Matrix = SizeT -> SizeT -> Bool + +listToFun :: [SizeT] -> [P.Value SizeT] -> [P.Value SizeT] +listToFun xs [P.FinV i] = [P.toValue $ xs !! i] +listToFun _ _ = error "invalid index" + +matrixToFun :: Matrix -> [P.Value SizeT] -> [P.Value SizeT] +matrixToFun matrix [P.FinV i, P.FinV j] = [P.toValue $ matrix i j] +matrixToFun _ _ = error "invalid indices" + +data Ctx = Ctx + { n :: Int + , num_iter :: Int + , capacity :: Int + , profits, weights :: [Int] + } + +substCtx :: Ctx -> Sym.Sym Int -> Int +substCtx Ctx{..} = + Sym.unSym + . Sym.subst "N" (Sym.con n) + . Sym.subst "K" (Sym.con num_iter) + . Sym.subst "W" (Sym.con 1000) + . Sym.subst "P" (Sym.con 1000) + +worstCaseCost + , expectedCost :: + forall primT primT'. + ( P.Parseable primT' + , A.AnnotateWithErrorBudgetU primT + , A.AnnotateWithErrorBudgetQ primT + , A.ExpCostQ (A.AnnFailProb primT) SizeT Double + , SizeType primT' ~ Sym.Sym Int + , P.MapSize primT' + , primT ~ P.MappedSize primT' Int + , primT' ~ P.MappedSize primT (Sym.Sym Int) + , PP.ToCodeString primT + ) => + Ctx -> + Double -> + IO Double +-- worst case cost (ignores data) +worstCaseCost ctx eps = do + -- load the program + loaded_program <- either (fail . show) pure =<< parseFromFile (P.programParser @primT') "examples/tree_generator/tree_generator_01_knapsack.qb" + let program = P.mapSize (substCtx ctx) loaded_program + program_annotated <- either fail pure $ A.annotateProgWithErrorBudget (A.failProb eps) program + + return $ getCost $ A.costQProg program_annotated +-- expected cost (depends on data) +expectedCost ctx@Ctx{..} eps = do + -- load the program + loaded_program <- either (fail . show) pure =<< parseFromFile (P.programParser @primT') "examples/tree_generator/tree_generator_01_knapsack.qb" + let program = P.mapSize (substCtx ctx) loaded_program + -- putStrLn $ replicate 80 '=' + -- putStrLn $ PP.toCodeString program + -- putStrLn $ replicate 80 '=' + program_annotated <- either fail pure $ A.annotateProgWithErrorBudget (A.failProb eps) program + -- putStrLn $ replicate 80 '=' + -- putStrLn $ PP.toCodeString program_annotated + -- putStrLn $ replicate 80 '=' + + -- the functionality of Matrix, provided as input data + let interp = + Ctx.fromList + [ ("Capacity", \_ -> [P.toValue capacity]) + , ("Profit", listToFun profits) + , ("Weight", listToFun weights) + ] + + return $ getCost $ A.expCostQProg program_annotated mempty interp + +main :: IO () +main = do + putStrLn "Demo: Matrix Search" + + let ctx = + Ctx + { n = 3 + , capacity = 10 + , profits = [1, 2, 3, 4, 5] + , weights = [2, 2, 1, 1, 1] + , num_iter = 1 + } + let eps = 0.005 + + putStrLn "Costs for sample 0-1 knapsack instance:" + + putStr " Quantum (worst-case): " + print =<< worstCaseCost @(Primitive (QAmplify _ _)) ctx eps + putStr " Classical (worst-case): " + print =<< worstCaseCost @(Primitive (CAmplify _ _)) ctx eps + putStr " Quantum (expected): " + print =<< expectedCost @(Primitive (QAmplify _ _)) ctx eps + putStr " Classical (expected): " + print =<< expectedCost @(Primitive (CAmplify _ _)) ctx eps diff --git a/package.yaml b/package.yaml index 97153390..fe054934 100644 --- a/package.yaml +++ b/package.yaml @@ -98,6 +98,10 @@ executables: main: demo.hs <<: *expt_opts source-dirs: examples/matrix_search + knapsackdemo: + main: demo.hs + <<: *expt_opts + source-dirs: examples/tree_generator timing: main: timing.hs <<: *expt_opts diff --git a/src/Traq/Analysis/Annotate/SplitBudget.hs b/src/Traq/Analysis/Annotate/SplitBudget.hs index 21985292..a21938cf 100644 --- a/src/Traq/Analysis/Annotate/SplitBudget.hs +++ b/src/Traq/Analysis/Annotate/SplitBudget.hs @@ -25,6 +25,7 @@ import Traq.Data.Default (default_) import Traq.Analysis.Annotate.Basic import Traq.Analysis.Annotate.Prelude import Traq.Analysis.Error.Prelude +import Traq.Analysis.Prelude (sizeToPrec) import Traq.Prelude import Traq.ProtoLang @@ -50,6 +51,8 @@ instance CanError (Expr ext) where canError FunCallE{fname} = do use (_funCtx . Ctx.at fname) >>= maybe (return False) canError canError PrimCallE{} = return True + canError LoopE{loop_body_fun} = do + use (_funCtx . Ctx.at loop_body_fun) >>= maybe (return False) canError canError _ = return False instance CanError (Stmt ext) where @@ -135,7 +138,12 @@ instance AnnotateWithErrorBudgetU1 Expr where annEpsU1 eps (NamedFunDef fname fn) pure FunCallE{..} annEpsU1 eps (PrimCallE ext') = PrimCallE <$> annEpsU eps ext' - annEpsU1 _ _ = error "UNSUPPORTED" + annEpsU1 eps LoopE{..} = do + fn@FunDef{param_types} <- use (_funCtx . Ctx.at loop_body_fun) >>= maybeWithError "cannot find loop body function" + let Fin n_iters = last param_types + let eps' = splitFailProb eps (sizeToPrec n_iters) + annEpsU1 eps' (NamedFunDef loop_body_fun fn) + pure LoopE{..} instance AnnotateWithErrorBudgetQ1 Expr where annEpsQ1 _ BasicExprE{..} = pure BasicExprE{..} @@ -145,7 +153,12 @@ instance AnnotateWithErrorBudgetQ1 Expr where annEpsQ1 eps (NamedFunDef fname fn) pure FunCallE{..} annEpsQ1 eps (PrimCallE ext) = PrimCallE <$> annEpsQ eps ext - annEpsQ1 _ _ = error "UNSUPPORTED" + annEpsQ1 eps LoopE{..} = do + fn@FunDef{param_types} <- use (_funCtx . Ctx.at loop_body_fun) >>= maybeWithError "cannot find loop body function" + let Fin n_iters = last param_types + let eps' = splitFailProb eps (sizeToPrec n_iters) + annEpsQ1 eps' (NamedFunDef loop_body_fun fn) + pure LoopE{..} instance AnnotateWithErrorBudgetU1 Stmt where annEpsU1 eps ExprS{rets, expr} = do diff --git a/src/Traq/Analysis/Cost/Prelude.hs b/src/Traq/Analysis/Cost/Prelude.hs index 97731ce1..a47fe2fa 100644 --- a/src/Traq/Analysis/Cost/Prelude.hs +++ b/src/Traq/Analysis/Cost/Prelude.hs @@ -13,6 +13,7 @@ type CostReqs size prec = ( Floating prec , Num size , Ord prec + , SizeToPrec size prec ) type CostModelReqs size prec cost = diff --git a/src/Traq/Analysis/Cost/Quantum.hs b/src/Traq/Analysis/Cost/Quantum.hs index b171f473..6188f40e 100644 --- a/src/Traq/Analysis/Cost/Quantum.hs +++ b/src/Traq/Analysis/Cost/Quantum.hs @@ -1,6 +1,7 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE UndecidableInstances #-} @@ -29,6 +30,7 @@ import qualified Traq.Data.Probability as Prob import Traq.Analysis.Cost.Prelude import Traq.Analysis.Cost.Unitary import Traq.Analysis.CostModel.Class +import Traq.Analysis.Prelude import Traq.ProtoLang.Eval import Traq.ProtoLang.Syntax @@ -66,7 +68,11 @@ instance (CostQ ext size prec) => CostQ (Expr ext) size prec where fn <- view $ _funCtx . Ctx.at fname . non' (error $ "unable to find function " ++ fname) costQ $ NamedFunDef fname fn costQ PrimCallE{prim} = costQ prim - costQ _ = error "unsupported" + costQ LoopE{loop_body_fun} = do + fn@FunDef{param_types} <- view $ _funCtx . Ctx.at loop_body_fun . non' (error $ "unable to find function " ++ loop_body_fun) + body_cost <- costQ $ NamedFunDef loop_body_fun fn + let Fin n_iters = last param_types + return $ (sizeToPrec n_iters :: prec) Alg..* body_cost instance (CostQ ext size prec) => CostQ (Stmt ext) size prec where costQ ExprS{expr} = costQ expr @@ -95,7 +101,22 @@ instance (ExpCostQ ext size prec) => ExpCostQ (Expr ext) size prec where let sigma_fn = Ctx.fromList $ zip [show i | i <- [0 :: Int ..]] arg_vals expCostQ (NamedFunDef fname fn) sigma_fn expCostQ PrimCallE{prim} sigma = expCostQ prim sigma - expCostQ _ _ = error "unsupported" + expCostQ LoopE{initial_args, loop_body_fun} sigma = do + fn@FunDef{param_types} <- view $ _funCtx . Ctx.at loop_body_fun . non' (error $ "unable to find function " ++ loop_body_fun) + let init_vals = [sigma ^?! Ctx.at x . non (error $ "could not find var " ++ x) | x <- initial_args] + let loop_domain = domain (last param_types) + + -- evaluate each iteration + env <- view _evaluationEnv + let run_loop_body i args = + evalFun (args ++ [i]) (NamedFunDef loop_body_fun fn) + & (runReaderT ?? env) + + (_, cs) <- forAccumM (pure init_vals) loop_domain $ \distr i -> do + let sigma_fn = fmap (\xs -> Ctx.fromList $ zip [show j | j <- [0 :: Int ..]] (xs ++ [i])) distr + iter_cost <- Prob.expectationA (expCostQ (NamedFunDef loop_body_fun fn)) sigma_fn + return (distr >>= run_loop_body i, iter_cost) + return $ Alg.sum cs -- | TODO unify this as a class instance, after unifying evaluation expCostQStmt :: diff --git a/src/Traq/Analysis/Cost/Unitary.hs b/src/Traq/Analysis/Cost/Unitary.hs index d3d50eea..51209fde 100644 --- a/src/Traq/Analysis/Cost/Unitary.hs +++ b/src/Traq/Analysis/Cost/Unitary.hs @@ -1,5 +1,6 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE UndecidableInstances #-} module Traq.Analysis.Cost.Unitary ( @@ -19,6 +20,7 @@ import Traq.Data.Default (HasDefault (default_)) import Traq.Analysis.Cost.Prelude import Traq.Analysis.CostModel.Class +import Traq.Analysis.Prelude (sizeToPrec) import Traq.Prelude import Traq.ProtoLang.Syntax @@ -51,7 +53,11 @@ instance (CostU ext size prec) => CostU (Expr ext) size prec where fn <- view $ _funCtx . Ctx.at fname . non' (error $ "unable to find function " ++ fname) costU $ NamedFunDef fname fn costU PrimCallE{prim} = costU prim - costU _ = error "unsupported" + costU LoopE{loop_body_fun} = do + fn@FunDef{param_types} <- view $ _funCtx . Ctx.at loop_body_fun . non' (error $ "unable to find function " ++ loop_body_fun) + body_cost <- costU $ NamedFunDef loop_body_fun fn + let Fin n_iters = last param_types + return $ (sizeToPrec n_iters :: prec) Alg..* body_cost instance (CostU ext size prec) => CostU (Stmt ext) size prec where costU ExprS{expr} = costU expr diff --git a/src/Traq/Analysis/Error/Prelude.hs b/src/Traq/Analysis/Error/Prelude.hs index 23d27b8a..55891b09 100644 --- a/src/Traq/Analysis/Error/Prelude.hs +++ b/src/Traq/Analysis/Error/Prelude.hs @@ -21,6 +21,7 @@ module Traq.Analysis.Error.Prelude ( import Control.Monad.Reader (Reader) +import Traq.Analysis.Prelude (SizeToPrec) import Traq.ProtoLang.Syntax (FunCtx) -- ================================================================================ @@ -70,6 +71,7 @@ type ErrorReqs size prec = ( Floating prec , Num size , Ord prec + , SizeToPrec size prec ) type ErrorAnalysisMonad ext = Reader (FunCtx ext) diff --git a/src/Traq/Primitives/Amplify/CAmplify.hs b/src/Traq/Primitives/Amplify/CAmplify.hs index 1208b08a..4832dc39 100644 --- a/src/Traq/Primitives/Amplify/CAmplify.hs +++ b/src/Traq/Primitives/Amplify/CAmplify.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeFamilies #-} module Traq.Primitives.Amplify.CAmplify ( CAmplify (..), @@ -35,6 +36,10 @@ type instance PrimFnShape (CAmplify size prec) = SamplerFn instance Amplify sizeT precT :<: CAmplify sizeT precT +instance P.MapSize (CAmplify size prec) where + type MappedSize (CAmplify size prec) size' = CAmplify size' prec + mapSize f (CAmplify p) = CAmplify (P.mapSize f p) + -- Inherited instances instance (Show prec, Fractional prec) => SerializePrim (CAmplify size prec) where primNames = ["amplify"] diff --git a/src/Traq/Primitives/Amplify/QAmplify.hs b/src/Traq/Primitives/Amplify/QAmplify.hs index 301e6727..8554e656 100644 --- a/src/Traq/Primitives/Amplify/QAmplify.hs +++ b/src/Traq/Primitives/Amplify/QAmplify.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeFamilies #-} module Traq.Primitives.Amplify.QAmplify ( QAmplify (..), @@ -35,6 +36,10 @@ type instance PrimFnShape (QAmplify size prec) = SamplerFn instance Amplify sizeT precT :<: QAmplify sizeT precT +instance P.MapSize (QAmplify size prec) where + type MappedSize (QAmplify size prec) size' = QAmplify size' prec + mapSize f (QAmplify p) = QAmplify (P.mapSize f p) + -- Inherited instances instance (Show prec, Fractional prec) => SerializePrim (QAmplify size prec) where primNames = ["amplify"] diff --git a/src/Traq/ProtoLang/Eval.hs b/src/Traq/ProtoLang/Eval.hs index 9e05256a..fa2f7c84 100644 --- a/src/Traq/ProtoLang/Eval.hs +++ b/src/Traq/ProtoLang/Eval.hs @@ -127,7 +127,7 @@ valueToBool = fromValue domainSize :: (Integral sizeT) => VarType sizeT -> sizeT domainSize (Fin _N) = _N domainSize (Bitvec n) = 2 ^ n -domainSize (Arr n t) = n * domainSize t +domainSize (Arr n t) = domainSize t ^ n domainSize (Tup ts) = product $ map domainSize ts -- | Set of all values of a given type diff --git a/src/Traq/ProtoLang/Lenses.hs b/src/Traq/ProtoLang/Lenses.hs index 19121322..eccc2f1a 100644 --- a/src/Traq/ProtoLang/Lenses.hs +++ b/src/Traq/ProtoLang/Lenses.hs @@ -85,7 +85,7 @@ instance HasExts Expr where _exts _ RandomSampleE{distr_expr} = pure RandomSampleE{distr_expr} _exts _ FunCallE{fname, args} = pure FunCallE{fname, args} _exts focus (PrimCallE p) = PrimCallE <$> focus p - _exts _ _ = error "TODO" + _exts _ LoopE{initial_args, loop_body_fun} = pure LoopE{initial_args, loop_body_fun} instance HasExts Stmt where _exts focus ExprS{rets, expr} = do diff --git a/traq.cabal b/traq.cabal index 1cc2dfc2..7799761e 100644 --- a/traq.cabal +++ b/traq.cabal @@ -190,6 +190,41 @@ executable cqplcompile , traq default-language: Haskell2010 +executable knapsackdemo + main-is: demo.hs + hs-source-dirs: + examples/tree_generator + default-extensions: + LambdaCase + NamedFieldPuns + ScopedTypeVariables + ApplicativeDo + RankNTypes + FlexibleContexts + TypeFamilies + TypeOperators + MultiWayIf + EmptyCase + RecordWildCards + ghc-options: -Wall -fprint-typechecker-elaboration + build-depends: + algebra ==4.3.* + , base >=4.10 && <5 + , containers >=0.6 && <1 + , extra >=1.8 && <2 + , lens ==5.3.* + , microlens-ghc ==0.4.* + , microlens-mtl ==0.2.* + , mtl >=2.2.2 + , optparse-applicative ==0.18.* + , parsec >=3.1.17 && <3.2 + , random + , random-shuffle + , timeit ==2.0.* + , transformers >=0.5 + , traq + default-language: Haskell2010 + executable matrixsearch main-is: matrixsearch.hs hs-source-dirs: