Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions examples/tree_generator/demo.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}

module Main where

import Text.Parsec.String (parseFromFile)
import qualified Traq.Data.Context as Ctx
import qualified Traq.Data.Symbolic as Sym

import qualified Traq.Analysis as A
import Traq.Prelude
import qualified Traq.ProtoLang as P

import Traq.Analysis.CostModel.QueryCost (SimpleQueryCost (..))
import Traq.Primitives (Primitive)
import Traq.Primitives.Amplify.CAmplify (CAmplify (..))
import Traq.Primitives.Amplify.QAmplify (QAmplify (..))
import qualified Traq.Utils.Printing as PP

type Matrix = SizeT -> SizeT -> Bool

listToFun :: [SizeT] -> [P.Value SizeT] -> [P.Value SizeT]
listToFun xs [P.FinV i] = [P.toValue $ xs !! i]
listToFun _ _ = error "invalid index"

matrixToFun :: Matrix -> [P.Value SizeT] -> [P.Value SizeT]
matrixToFun matrix [P.FinV i, P.FinV j] = [P.toValue $ matrix i j]
matrixToFun _ _ = error "invalid indices"

data Ctx = Ctx
{ n :: Int
, num_iter :: Int
, capacity :: Int
, profits, weights :: [Int]
}

substCtx :: Ctx -> Sym.Sym Int -> Int
substCtx Ctx{..} =
Sym.unSym
. Sym.subst "N" (Sym.con n)
. Sym.subst "K" (Sym.con num_iter)
. Sym.subst "W" (Sym.con 1000)
. Sym.subst "P" (Sym.con 1000)

worstCaseCost
, expectedCost ::
forall primT primT'.
( P.Parseable primT'
, A.AnnotateWithErrorBudgetU primT
, A.AnnotateWithErrorBudgetQ primT
, A.ExpCostQ (A.AnnFailProb primT) SizeT Double
, SizeType primT' ~ Sym.Sym Int
, P.MapSize primT'
, primT ~ P.MappedSize primT' Int
, primT' ~ P.MappedSize primT (Sym.Sym Int)
, PP.ToCodeString primT
) =>
Ctx ->
Double ->
IO Double
-- worst case cost (ignores data)
worstCaseCost ctx eps = do
-- load the program
loaded_program <- either (fail . show) pure =<< parseFromFile (P.programParser @primT') "examples/tree_generator/tree_generator_01_knapsack.qb"
let program = P.mapSize (substCtx ctx) loaded_program
program_annotated <- either fail pure $ A.annotateProgWithErrorBudget (A.failProb eps) program

return $ getCost $ A.costQProg program_annotated
-- expected cost (depends on data)
expectedCost ctx@Ctx{..} eps = do
-- load the program
loaded_program <- either (fail . show) pure =<< parseFromFile (P.programParser @primT') "examples/tree_generator/tree_generator_01_knapsack.qb"
let program = P.mapSize (substCtx ctx) loaded_program
-- putStrLn $ replicate 80 '='
-- putStrLn $ PP.toCodeString program
-- putStrLn $ replicate 80 '='
program_annotated <- either fail pure $ A.annotateProgWithErrorBudget (A.failProb eps) program
-- putStrLn $ replicate 80 '='
-- putStrLn $ PP.toCodeString program_annotated
-- putStrLn $ replicate 80 '='

-- the functionality of Matrix, provided as input data
let interp =
Ctx.fromList
[ ("Capacity", \_ -> [P.toValue capacity])
, ("Profit", listToFun profits)
, ("Weight", listToFun weights)
]

return $ getCost $ A.expCostQProg program_annotated mempty interp

main :: IO ()
main = do
putStrLn "Demo: Matrix Search"

let ctx =
Ctx
{ n = 3
, capacity = 10
, profits = [1, 2, 3, 4, 5]
, weights = [2, 2, 1, 1, 1]
, num_iter = 1
}
let eps = 0.005

putStrLn "Costs for sample 0-1 knapsack instance:"

putStr " Quantum (worst-case): "
print =<< worstCaseCost @(Primitive (QAmplify _ _)) ctx eps
putStr " Classical (worst-case): "
print =<< worstCaseCost @(Primitive (CAmplify _ _)) ctx eps
putStr " Quantum (expected): "
print =<< expectedCost @(Primitive (QAmplify _ _)) ctx eps
putStr " Classical (expected): "
print =<< expectedCost @(Primitive (CAmplify _ _)) ctx eps
4 changes: 4 additions & 0 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ executables:
main: demo.hs
<<: *expt_opts
source-dirs: examples/matrix_search
knapsackdemo:
main: demo.hs
<<: *expt_opts
source-dirs: examples/tree_generator
timing:
main: timing.hs
<<: *expt_opts
Expand Down
17 changes: 15 additions & 2 deletions src/Traq/Analysis/Annotate/SplitBudget.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import Traq.Analysis.Annotate.Basic
import Traq.Analysis.Annotate.Prelude
import Traq.Analysis.Error.Prelude
import Traq.Analysis.Prelude (sizeToPrec)
import Traq.Prelude
import Traq.ProtoLang

Expand All @@ -50,6 +51,8 @@
canError FunCallE{fname} = do
use (_funCtx . Ctx.at fname) >>= maybe (return False) canError
canError PrimCallE{} = return True
canError LoopE{loop_body_fun} = do
use (_funCtx . Ctx.at loop_body_fun) >>= maybe (return False) canError
canError _ = return False

instance CanError (Stmt ext) where
Expand Down Expand Up @@ -135,7 +138,12 @@
annEpsU1 eps (NamedFunDef fname fn)
pure FunCallE{..}
annEpsU1 eps (PrimCallE ext') = PrimCallE <$> annEpsU eps ext'
annEpsU1 _ _ = error "UNSUPPORTED"
annEpsU1 eps LoopE{..} = do
fn@FunDef{param_types} <- use (_funCtx . Ctx.at loop_body_fun) >>= maybeWithError "cannot find loop body function"
let Fin n_iters = last param_types

Check warning on line 143 in src/Traq/Analysis/Annotate/SplitBudget.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

Pattern match(es) are non-exhaustive

Check warning on line 143 in src/Traq/Analysis/Annotate/SplitBudget.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

Pattern match(es) are non-exhaustive
let eps' = splitFailProb eps (sizeToPrec n_iters)
annEpsU1 eps' (NamedFunDef loop_body_fun fn)
pure LoopE{..}

instance AnnotateWithErrorBudgetQ1 Expr where
annEpsQ1 _ BasicExprE{..} = pure BasicExprE{..}
Expand All @@ -145,7 +153,12 @@
annEpsQ1 eps (NamedFunDef fname fn)
pure FunCallE{..}
annEpsQ1 eps (PrimCallE ext) = PrimCallE <$> annEpsQ eps ext
annEpsQ1 _ _ = error "UNSUPPORTED"
annEpsQ1 eps LoopE{..} = do
fn@FunDef{param_types} <- use (_funCtx . Ctx.at loop_body_fun) >>= maybeWithError "cannot find loop body function"
let Fin n_iters = last param_types

Check warning on line 158 in src/Traq/Analysis/Annotate/SplitBudget.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

Pattern match(es) are non-exhaustive

Check warning on line 158 in src/Traq/Analysis/Annotate/SplitBudget.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

Pattern match(es) are non-exhaustive
let eps' = splitFailProb eps (sizeToPrec n_iters)
annEpsQ1 eps' (NamedFunDef loop_body_fun fn)
pure LoopE{..}

instance AnnotateWithErrorBudgetU1 Stmt where
annEpsU1 eps ExprS{rets, expr} = do
Expand Down
1 change: 1 addition & 0 deletions src/Traq/Analysis/Cost/Prelude.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type CostReqs size prec =
( Floating prec
, Num size
, Ord prec
, SizeToPrec size prec
)

type CostModelReqs size prec cost =
Expand Down
25 changes: 23 additions & 2 deletions src/Traq/Analysis/Cost/Quantum.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}

Expand Down Expand Up @@ -29,6 +30,7 @@
import Traq.Analysis.Cost.Prelude
import Traq.Analysis.Cost.Unitary
import Traq.Analysis.CostModel.Class
import Traq.Analysis.Prelude
import Traq.ProtoLang.Eval
import Traq.ProtoLang.Syntax

Expand Down Expand Up @@ -66,7 +68,11 @@
fn <- view $ _funCtx . Ctx.at fname . non' (error $ "unable to find function " ++ fname)
costQ $ NamedFunDef fname fn
costQ PrimCallE{prim} = costQ prim
costQ _ = error "unsupported"
costQ LoopE{loop_body_fun} = do
fn@FunDef{param_types} <- view $ _funCtx . Ctx.at loop_body_fun . non' (error $ "unable to find function " ++ loop_body_fun)
body_cost <- costQ $ NamedFunDef loop_body_fun fn
let Fin n_iters = last param_types

Check warning on line 74 in src/Traq/Analysis/Cost/Quantum.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

Pattern match(es) are non-exhaustive

Check warning on line 74 in src/Traq/Analysis/Cost/Quantum.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

Pattern match(es) are non-exhaustive
return $ (sizeToPrec n_iters :: prec) Alg..* body_cost

instance (CostQ ext size prec) => CostQ (Stmt ext) size prec where
costQ ExprS{expr} = costQ expr
Expand Down Expand Up @@ -95,7 +101,22 @@
let sigma_fn = Ctx.fromList $ zip [show i | i <- [0 :: Int ..]] arg_vals
expCostQ (NamedFunDef fname fn) sigma_fn
expCostQ PrimCallE{prim} sigma = expCostQ prim sigma
expCostQ _ _ = error "unsupported"
expCostQ LoopE{initial_args, loop_body_fun} sigma = do
fn@FunDef{param_types} <- view $ _funCtx . Ctx.at loop_body_fun . non' (error $ "unable to find function " ++ loop_body_fun)
let init_vals = [sigma ^?! Ctx.at x . non (error $ "could not find var " ++ x) | x <- initial_args]
let loop_domain = domain (last param_types)

-- evaluate each iteration
env <- view _evaluationEnv
let run_loop_body i args =
evalFun (args ++ [i]) (NamedFunDef loop_body_fun fn)
& (runReaderT ?? env)

(_, cs) <- forAccumM (pure init_vals) loop_domain $ \distr i -> do
let sigma_fn = fmap (\xs -> Ctx.fromList $ zip [show j | j <- [0 :: Int ..]] (xs ++ [i])) distr
iter_cost <- Prob.expectationA (expCostQ (NamedFunDef loop_body_fun fn)) sigma_fn
return (distr >>= run_loop_body i, iter_cost)
return $ Alg.sum cs

-- | TODO unify this as a class instance, after unifying evaluation
expCostQStmt ::
Expand Down
8 changes: 7 additions & 1 deletion src/Traq/Analysis/Cost/Unitary.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}

module Traq.Analysis.Cost.Unitary (
Expand All @@ -19,6 +20,7 @@

import Traq.Analysis.Cost.Prelude
import Traq.Analysis.CostModel.Class
import Traq.Analysis.Prelude (sizeToPrec)
import Traq.Prelude
import Traq.ProtoLang.Syntax

Expand Down Expand Up @@ -51,7 +53,11 @@
fn <- view $ _funCtx . Ctx.at fname . non' (error $ "unable to find function " ++ fname)
costU $ NamedFunDef fname fn
costU PrimCallE{prim} = costU prim
costU _ = error "unsupported"
costU LoopE{loop_body_fun} = do
fn@FunDef{param_types} <- view $ _funCtx . Ctx.at loop_body_fun . non' (error $ "unable to find function " ++ loop_body_fun)
body_cost <- costU $ NamedFunDef loop_body_fun fn
let Fin n_iters = last param_types

Check warning on line 59 in src/Traq/Analysis/Cost/Unitary.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

Pattern match(es) are non-exhaustive

Check warning on line 59 in src/Traq/Analysis/Cost/Unitary.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

Pattern match(es) are non-exhaustive
return $ (sizeToPrec n_iters :: prec) Alg..* body_cost

instance (CostU ext size prec) => CostU (Stmt ext) size prec where
costU ExprS{expr} = costU expr
Expand Down
2 changes: 2 additions & 0 deletions src/Traq/Analysis/Error/Prelude.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ module Traq.Analysis.Error.Prelude (

import Control.Monad.Reader (Reader)

import Traq.Analysis.Prelude (SizeToPrec)
import Traq.ProtoLang.Syntax (FunCtx)

-- ================================================================================
Expand Down Expand Up @@ -70,6 +71,7 @@ type ErrorReqs size prec =
( Floating prec
, Num size
, Ord prec
, SizeToPrec size prec
)

type ErrorAnalysisMonad ext = Reader (FunCtx ext)
5 changes: 5 additions & 0 deletions src/Traq/Primitives/Amplify/CAmplify.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

module Traq.Primitives.Amplify.CAmplify (
CAmplify (..),
Expand Down Expand Up @@ -35,6 +36,10 @@ type instance PrimFnShape (CAmplify size prec) = SamplerFn

instance Amplify sizeT precT :<: CAmplify sizeT precT

instance P.MapSize (CAmplify size prec) where
type MappedSize (CAmplify size prec) size' = CAmplify size' prec
mapSize f (CAmplify p) = CAmplify (P.mapSize f p)

-- Inherited instances
instance (Show prec, Fractional prec) => SerializePrim (CAmplify size prec) where
primNames = ["amplify"]
Expand Down
5 changes: 5 additions & 0 deletions src/Traq/Primitives/Amplify/QAmplify.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

module Traq.Primitives.Amplify.QAmplify (
QAmplify (..),
Expand Down Expand Up @@ -35,6 +36,10 @@ type instance PrimFnShape (QAmplify size prec) = SamplerFn

instance Amplify sizeT precT :<: QAmplify sizeT precT

instance P.MapSize (QAmplify size prec) where
type MappedSize (QAmplify size prec) size' = QAmplify size' prec
mapSize f (QAmplify p) = QAmplify (P.mapSize f p)

-- Inherited instances
instance (Show prec, Fractional prec) => SerializePrim (QAmplify size prec) where
primNames = ["amplify"]
Expand Down
2 changes: 1 addition & 1 deletion src/Traq/ProtoLang/Eval.hs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ valueToBool = fromValue
domainSize :: (Integral sizeT) => VarType sizeT -> sizeT
domainSize (Fin _N) = _N
domainSize (Bitvec n) = 2 ^ n
domainSize (Arr n t) = n * domainSize t
domainSize (Arr n t) = domainSize t ^ n
domainSize (Tup ts) = product $ map domainSize ts

-- | Set of all values of a given type
Expand Down
2 changes: 1 addition & 1 deletion src/Traq/ProtoLang/Lenses.hs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ instance HasExts Expr where
_exts _ RandomSampleE{distr_expr} = pure RandomSampleE{distr_expr}
_exts _ FunCallE{fname, args} = pure FunCallE{fname, args}
_exts focus (PrimCallE p) = PrimCallE <$> focus p
_exts _ _ = error "TODO"
_exts _ LoopE{initial_args, loop_body_fun} = pure LoopE{initial_args, loop_body_fun}

instance HasExts Stmt where
_exts focus ExprS{rets, expr} = do
Expand Down
35 changes: 35 additions & 0 deletions traq.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,41 @@ executable cqplcompile
, traq
default-language: Haskell2010

executable knapsackdemo
main-is: demo.hs
hs-source-dirs:
examples/tree_generator
default-extensions:
LambdaCase
NamedFieldPuns
ScopedTypeVariables
ApplicativeDo
RankNTypes
FlexibleContexts
TypeFamilies
TypeOperators
MultiWayIf
EmptyCase
RecordWildCards
ghc-options: -Wall -fprint-typechecker-elaboration
build-depends:
algebra ==4.3.*
, base >=4.10 && <5
, containers >=0.6 && <1
, extra >=1.8 && <2
, lens ==5.3.*
, microlens-ghc ==0.4.*
, microlens-mtl ==0.2.*
, mtl >=2.2.2
, optparse-applicative ==0.18.*
, parsec >=3.1.17 && <3.2
, random
, random-shuffle
, timeit ==2.0.*
, transformers >=0.5
, traq
default-language: Haskell2010

executable matrixsearch
main-is: matrixsearch.hs
hs-source-dirs:
Expand Down
Loading