From c4dcdfd275d0f3bdabfceab61dcb64055e864e8a Mon Sep 17 00:00:00 2001 From: Anurudh Peduri Date: Tue, 27 Jan 2026 21:24:55 +0100 Subject: [PATCH 1/4] env --- src/Traq/Primitives/Class.hs | 8 +++++- src/Traq/Primitives/Class/UnitaryCompile.hs | 29 ++++++++++----------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/Traq/Primitives/Class.hs b/src/Traq/Primitives/Class.hs index f11154d..92db4fe 100644 --- a/src/Traq/Primitives/Class.hs +++ b/src/Traq/Primitives/Class.hs @@ -316,7 +316,13 @@ instance >>= maybeWithError (printf "could not find uproc `%s` for fun `%s`" uproc_name pfun_name) return $ Compiler.aux_tys sign - let builder = UnitaryCompilePrimBuilder{mk_ucall, uproc_aux_types, ret_vars = rets} + let builder = + PrimCompileEnv + { mk_ucall + , mk_call = reshapeUnsafe $ replicate (length par_funs) (error "cannot call proc from UPrim") + , uproc_aux_types + , ret_vars = rets + } let arg_bounder = prependBoundArgs (map Compiler.mkUProcName pfun_names) bound_args prim_proc_raw <- runReaderT (compileUPrim prim eps) builder diff --git a/src/Traq/Primitives/Class/UnitaryCompile.hs b/src/Traq/Primitives/Class/UnitaryCompile.hs index 32351f3..f23af2b 100644 --- a/src/Traq/Primitives/Class/UnitaryCompile.hs +++ b/src/Traq/Primitives/Class/UnitaryCompile.hs @@ -6,7 +6,7 @@ module Traq.Primitives.Class.UnitaryCompile ( UnitaryCompilePrim (..), - UnitaryCompilePrimBuilder (..), + PrimCompileEnv (..), ) where import Control.Monad.Reader (ReaderT (..)) @@ -27,12 +27,15 @@ import qualified Traq.ProtoLang as P -- -------------------------------------------------------------------------------- type UCallBuilder size = [Ident] -> CQPL.UStmt size +type CallBuilder size = [Ident] -> CQPL.Stmt size -- type UProcBuilder size = [(Ident, P.VarType size)] -> CQPL.UStmt size -> CQPL.ProcDef size -data UnitaryCompilePrimBuilder shape size = UnitaryCompilePrimBuilder +data PrimCompileEnv shape size = PrimCompileEnv { mk_ucall :: shape (UCallBuilder size) -- ^ helper to generate a call to a unitary function argument. + , mk_call :: shape (CallBuilder size) + -- ^ helper to generate a call to a classical function argument. , uproc_aux_types :: shape [P.VarType size] -- ^ auxiliary variables for each unitary function argument. , ret_vars :: [Ident] @@ -41,21 +44,17 @@ data UnitaryCompilePrimBuilder shape size = UnitaryCompilePrimBuilder reshapeBuilder :: (ValidPrimShape shape, ValidPrimShape shape') => - UnitaryCompilePrimBuilder shape size -> - Either String (UnitaryCompilePrimBuilder shape' size) -reshapeBuilder UnitaryCompilePrimBuilder{..} = do - mk_ucall' <- reshape mk_ucall - uproc_aux_types' <- reshape uproc_aux_types - return - UnitaryCompilePrimBuilder - { mk_ucall = mk_ucall' - , uproc_aux_types = uproc_aux_types' - , .. - } + PrimCompileEnv shape size -> + Either String (PrimCompileEnv shape' size) +reshapeBuilder PrimCompileEnv{..} = do + mk_ucall <- reshape mk_ucall + mk_call <- reshape mk_call + uproc_aux_types <- reshape uproc_aux_types + return PrimCompileEnv{..} type UnitaryCompilePrimMonad ext prim = ReaderT - (UnitaryCompilePrimBuilder (PrimFnShape prim) (SizeType prim)) + (PrimCompileEnv (PrimFnShape prim) (SizeType prim)) (CompilerT ext) -- | Compile a primitive to a unitary statement. @@ -104,7 +103,7 @@ class GUnitaryCompilePrim f size prec | f -> size prec where ) => f p -> A.FailProb prec -> - UnitaryCompilePrimBuilder [] size -> + PrimCompileEnv [] size -> m (CQPL.ProcDef size) instance (GUnitaryCompilePrim a size prec, GUnitaryCompilePrim b size prec) => GUnitaryCompilePrim (a :+: b) size prec where From ef3bfb37285ba4e5eaaaff6ed987fd56c479431d Mon Sep 17 00:00:00 2001 From: Anurudh Peduri Date: Tue, 27 Jan 2026 21:25:53 +0100 Subject: [PATCH 2/4] move --- src/Traq/Primitives/Class.hs | 4 ++-- src/Traq/Primitives/Class/{UnitaryCompile.hs => Compile.hs} | 2 +- traq.cabal | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename src/Traq/Primitives/Class/{UnitaryCompile.hs => Compile.hs} (98%) diff --git a/src/Traq/Primitives/Class.hs b/src/Traq/Primitives/Class.hs index 92db4fe..ef2839f 100644 --- a/src/Traq/Primitives/Class.hs +++ b/src/Traq/Primitives/Class.hs @@ -12,7 +12,7 @@ module Traq.Primitives.Class ( module Traq.Primitives.Class.Eval, module Traq.Primitives.Class.UnitaryCost, module Traq.Primitives.Class.QuantumCost, - module Traq.Primitives.Class.UnitaryCompile, + module Traq.Primitives.Class.Compile, ) where import Control.Applicative (Alternative ((<|>)), many) @@ -38,12 +38,12 @@ import qualified Traq.Analysis as A import qualified Traq.CQPL as CQPL import qualified Traq.Compiler as Compiler import Traq.Prelude +import Traq.Primitives.Class.Compile import Traq.Primitives.Class.Eval import Traq.Primitives.Class.Prelude import Traq.Primitives.Class.QuantumCost import Traq.Primitives.Class.Serialize import Traq.Primitives.Class.TypeCheck -import Traq.Primitives.Class.UnitaryCompile import Traq.Primitives.Class.UnitaryCost import qualified Traq.ProtoLang as P import qualified Traq.Utils.Printing as PP diff --git a/src/Traq/Primitives/Class/UnitaryCompile.hs b/src/Traq/Primitives/Class/Compile.hs similarity index 98% rename from src/Traq/Primitives/Class/UnitaryCompile.hs rename to src/Traq/Primitives/Class/Compile.hs index f23af2b..103deb7 100644 --- a/src/Traq/Primitives/Class/UnitaryCompile.hs +++ b/src/Traq/Primitives/Class/Compile.hs @@ -4,7 +4,7 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE UndecidableInstances #-} -module Traq.Primitives.Class.UnitaryCompile ( +module Traq.Primitives.Class.Compile ( UnitaryCompilePrim (..), PrimCompileEnv (..), ) where diff --git a/traq.cabal b/traq.cabal index ac063fe..b26747b 100644 --- a/traq.cabal +++ b/traq.cabal @@ -67,12 +67,12 @@ library Traq.Primitives.Amplify.Prelude Traq.Primitives.Amplify.QAmplify Traq.Primitives.Class + Traq.Primitives.Class.Compile Traq.Primitives.Class.Eval Traq.Primitives.Class.Prelude Traq.Primitives.Class.QuantumCost Traq.Primitives.Class.Serialize Traq.Primitives.Class.TypeCheck - Traq.Primitives.Class.UnitaryCompile Traq.Primitives.Class.UnitaryCost Traq.Primitives.Count.QCount Traq.Primitives.Max.QMax From 5a6b5535683395f43f0a7a226d8b5c2b7d6fd04d Mon Sep 17 00:00:00 2001 From: Anurudh Peduri Date: Tue, 27 Jan 2026 21:48:16 +0100 Subject: [PATCH 3/4] class --- src/Traq/Primitives.hs | 16 +--- src/Traq/Primitives/Class.hs | 70 ++++++++++++++++++ src/Traq/Primitives/Class/Compile.hs | 85 ++++++++++++++++++++-- src/Traq/Primitives/Search/DetSearch.hs | 6 +- src/Traq/Primitives/Search/QSearchCFNW.hs | 9 +++ src/Traq/Primitives/Search/RandomSearch.hs | 6 +- 6 files changed, 169 insertions(+), 23 deletions(-) diff --git a/src/Traq/Primitives.hs b/src/Traq/Primitives.hs index 8af72e8..e7ef4ea 100644 --- a/src/Traq/Primitives.hs +++ b/src/Traq/Primitives.hs @@ -89,24 +89,14 @@ instance instance (P.TypingReqs size, Integral size, RealFloat prec, Show prec) => UnitaryCompilePrim (DefaultPrimCollection size prec) size prec +instance + (P.TypingReqs size, Integral size, RealFloat prec, Show prec) => + QuantumCompilePrim (DefaultPrimCollection size prec) size prec type DefaultPrims sizeT precT = Primitive (DefaultPrimCollection sizeT precT) type DefaultPrims' = DefaultPrims SizeT Double -instance - ( Integral sizeT - , Floating precT - , RealFloat precT - , P.TypingReqs sizeT - , Show precT - , sizeT ~ SizeT - ) => - Compiler.CompileQ (A.AnnFailProb (DefaultPrims sizeT precT)) - where - compileQ (A.AnnFailProb eps (Primitive fs (QAny q))) = Compiler.compileQ (A.AnnFailProb eps (Primitive fs q)) - compileQ _ = error "TODO: lowerPrimitive" - -- ================================================================================ -- Worst-cost prim collection -- ================================================================================ diff --git a/src/Traq/Primitives/Class.hs b/src/Traq/Primitives/Class.hs index ef2839f..211428b 100644 --- a/src/Traq/Primitives/Class.hs +++ b/src/Traq/Primitives/Class.hs @@ -502,3 +502,73 @@ instance when (n_query_u > 0) $ void $ A.annEpsU1 eps_fn_u named_fn pure $ A.AnnFailProb eps_alg $ Primitive par_funs prim + +-- -------------------------------------------------------------------------------- +-- Compilation +-- -------------------------------------------------------------------------------- + +instance + {-# OVERLAPPABLE #-} + ( TypeCheckPrim prim (SizeType prim) + , P.TypingReqs (SizeType prim) + , UnitaryCompilePrim prim (SizeType prim) (PrecType prim) + , QuantumCompilePrim prim (SizeType prim) (PrecType prim) + ) => + Compiler.CompileQ (A.AnnFailProb (Primitive prim)) + where + compileQ (A.AnnFailProb eps (Primitive par_funs prim)) rets = do + let pfun_names = map pfun_name par_funs + let bound_args_names = concatMap (catMaybes . pfun_args) par_funs + bound_args_tys <- forM bound_args_names $ \x -> use $ P._typingCtx . Ctx.at x . non' (error $ "invalid arg " ++ x) + let bound_args = zip bound_args_names bound_args_tys + + mk_ucall <- + reshape $ + par_funs <&> \PartialFun{pfun_name, pfun_args} xs -> + CQPL.UCallS + { uproc_id = Compiler.mkUProcName pfun_name + , dagger = False + , qargs = placeArgsWithExcess pfun_args xs + } + + uproc_aux_types <- + reshape =<< do + forM par_funs $ \PartialFun{pfun_name} -> do + let uproc_name = Compiler.mkUProcName pfun_name + sign <- + use (Compiler._procSignatures . at uproc_name) + >>= maybeWithError (printf "could not find uproc `%s` for fun `%s`" uproc_name pfun_name) + return $ Compiler.aux_tys sign + + mk_call <- + reshape $ + par_funs <&> \PartialFun{pfun_name, pfun_args} xs -> + CQPL.CallS + { fun = CQPL.FunctionCall $ Compiler.mkQProcName pfun_name + , meta_params = [] + , args = placeArgsWithExcess pfun_args xs + } + + let builder = + PrimCompileEnv + { mk_ucall + , mk_call + , uproc_aux_types + , ret_vars = rets + } + let arg_bounder = + prependBoundArgs + (map Compiler.mkUProcName pfun_names ++ map Compiler.mkQProcName pfun_names) + bound_args + prim_proc_raw <- + runReaderT (compileQPrim prim eps) builder + & censor (Compiler._loweredProcs . each %~ arg_bounder) + let prim_proc = arg_bounder prim_proc_raw + Compiler.addProc prim_proc + + return $ + CQPL.CallS + { fun = CQPL.FunctionCall $ CQPL.proc_name prim_proc + , meta_params = [] + , args = map fst bound_args ++ rets + } diff --git a/src/Traq/Primitives/Class/Compile.hs b/src/Traq/Primitives/Class/Compile.hs index 103deb7..d0631f6 100644 --- a/src/Traq/Primitives/Class/Compile.hs +++ b/src/Traq/Primitives/Class/Compile.hs @@ -5,8 +5,9 @@ {-# LANGUAGE UndecidableInstances #-} module Traq.Primitives.Class.Compile ( - UnitaryCompilePrim (..), PrimCompileEnv (..), + UnitaryCompilePrim (..), + QuantumCompilePrim (..), ) where import Control.Monad.Reader (ReaderT (..)) @@ -23,14 +24,13 @@ import Traq.Primitives.Class.Prelude import qualified Traq.ProtoLang as P -- -------------------------------------------------------------------------------- --- Unitary Compilation +-- Environment and enclosing monad for compiling primitives. -- -------------------------------------------------------------------------------- type UCallBuilder size = [Ident] -> CQPL.UStmt size type CallBuilder size = [Ident] -> CQPL.Stmt size --- type UProcBuilder size = [(Ident, P.VarType size)] -> CQPL.UStmt size -> CQPL.ProcDef size - +-- | Helpers to compile a primitive. data PrimCompileEnv shape size = PrimCompileEnv { mk_ucall :: shape (UCallBuilder size) -- ^ helper to generate a call to a unitary function argument. @@ -52,12 +52,16 @@ reshapeBuilder PrimCompileEnv{..} = do uproc_aux_types <- reshape uproc_aux_types return PrimCompileEnv{..} -type UnitaryCompilePrimMonad ext prim = +type PrimCompileMonad ext prim = ReaderT (PrimCompileEnv (PrimFnShape prim) (SizeType prim)) (CompilerT ext) --- | Compile a primitive to a unitary statement. +-- -------------------------------------------------------------------------------- +-- Unitary Compilation +-- -------------------------------------------------------------------------------- + +-- | Compile a primitive to a uproc class ( size ~ SizeType prim , prec ~ PrecType prim @@ -68,7 +72,7 @@ class where compileUPrim :: forall ext' m shape. - ( m ~ UnitaryCompilePrimMonad ext' prim + ( m ~ PrimCompileMonad ext' prim , size ~ SizeType ext' , prec ~ PrecType ext' , shape ~ PrimFnShape prim @@ -80,7 +84,7 @@ class forall ext' m shape. ( Generic prim , GUnitaryCompilePrim (Rep prim) size prec - , m ~ UnitaryCompilePrimMonad ext' prim + , m ~ PrimCompileMonad ext' prim , size ~ SizeType ext' , prec ~ PrecType ext' , shape ~ PrimFnShape prim @@ -117,3 +121,68 @@ instance (UnitaryCompilePrim a size prec) => GUnitaryCompilePrim (K1 i a) size p gcompileUPrim (K1 x) eps builder = do builder' <- lift $ reshapeBuilder builder runReaderT (compileUPrim x eps) builder' + +-- -------------------------------------------------------------------------------- +-- Quantum Compilation +-- -------------------------------------------------------------------------------- + +-- | Compile a primitive to a cq-proc +class + ( size ~ SizeType prim + , prec ~ PrecType prim + , ValidPrimShape (PrimFnShape prim) + ) => + QuantumCompilePrim prim size prec + | prim -> size prec + where + compileQPrim :: + forall ext' m shape. + ( m ~ PrimCompileMonad ext' prim + , size ~ SizeType ext' + , prec ~ PrecType ext' + , shape ~ PrimFnShape prim + ) => + prim -> + A.FailProb prec -> + m (CQPL.ProcDef size) + default compileQPrim :: + forall ext' m shape. + ( Generic prim + , GQuantumCompilePrim (Rep prim) size prec + , m ~ PrimCompileMonad ext' prim + , size ~ SizeType ext' + , prec ~ PrecType ext' + , shape ~ PrimFnShape prim + ) => + prim -> + A.FailProb prec -> + m (CQPL.ProcDef size) + compileQPrim prim eps = do + builder <- view id + lift $ do + builder' <- lift $ reshapeBuilder builder + gcompileQPrim (from prim) eps builder' + +class GQuantumCompilePrim f size prec | f -> size prec where + gcompileQPrim :: + forall ext' m p. + ( m ~ CompilerT ext' + , size ~ SizeType ext' + , prec ~ PrecType ext' + ) => + f p -> + A.FailProb prec -> + PrimCompileEnv [] size -> + m (CQPL.ProcDef size) + +instance (GQuantumCompilePrim a size prec, GQuantumCompilePrim b size prec) => GQuantumCompilePrim (a :+: b) size prec where + gcompileQPrim (L1 x) = gcompileQPrim x + gcompileQPrim (R1 x) = gcompileQPrim x + +instance (GQuantumCompilePrim f size prec) => GQuantumCompilePrim (M1 i c f) size prec where + gcompileQPrim (M1 x) = gcompileQPrim x + +instance (QuantumCompilePrim a size prec) => GQuantumCompilePrim (K1 i a) size prec where + gcompileQPrim (K1 x) eps builder = do + builder' <- lift $ reshapeBuilder builder + runReaderT (compileQPrim x eps) builder' diff --git a/src/Traq/Primitives/Search/DetSearch.hs b/src/Traq/Primitives/Search/DetSearch.hs index b955b7d..28c7862 100644 --- a/src/Traq/Primitives/Search/DetSearch.hs +++ b/src/Traq/Primitives/Search/DetSearch.hs @@ -100,4 +100,8 @@ instance instance UnitaryCompilePrim (DetSearch size prec) size prec where compileUPrim (DetSearch PrimSearch{search_kind, search_ty}) eps = do - error "TODO: CompileU andomSearch" + error "TODO: CompileU DetSearch" + +instance QuantumCompilePrim (DetSearch size prec) size prec where + compileQPrim (DetSearch PrimSearch{search_kind, search_ty}) eps = do + error "TODO: CompileQ DetSearch" diff --git a/src/Traq/Primitives/Search/QSearchCFNW.hs b/src/Traq/Primitives/Search/QSearchCFNW.hs index 2b74559..8de7c0e 100644 --- a/src/Traq/Primitives/Search/QSearchCFNW.hs +++ b/src/Traq/Primitives/Search/QSearchCFNW.hs @@ -580,6 +580,15 @@ algoQSearch ty n_samples eps grover_k_caller pred_caller ok = do nxt m = min (lambda * m) sqrt_n instance + (P.TypingReqs size, Integral size, RealFloat prec, Show prec) => + QuantumCompilePrim (QSearchCFNW size prec) size prec + where + compileQPrim (QSearchCFNW PrimSearch{search_kind = AnyK, search_ty}) eps = do + error "TODO" + compileQPrim _ _ = error "unsupported" + +instance + {-# OVERLAPPABLE #-} ( Integral sizeT , RealFloat precT , sizeT ~ SizeT diff --git a/src/Traq/Primitives/Search/RandomSearch.hs b/src/Traq/Primitives/Search/RandomSearch.hs index 8c7f402..02ec303 100644 --- a/src/Traq/Primitives/Search/RandomSearch.hs +++ b/src/Traq/Primitives/Search/RandomSearch.hs @@ -121,4 +121,8 @@ instance instance UnitaryCompilePrim (RandomSearch size prec) size prec where compileUPrim (RandomSearch PrimSearch{search_kind, search_ty}) eps = do - error "TODO: CompileU andomSearch" + error "TODO: CompileU RandomSearch" + +instance QuantumCompilePrim (RandomSearch size prec) size prec where + compileQPrim (RandomSearch PrimSearch{search_kind, search_ty}) eps = do + error "TODO: CompileQ RandomSearch" From 4af45ef84ae8791570e0ffea839dd07b5bce48be Mon Sep 17 00:00:00 2001 From: Anurudh Peduri Date: Tue, 27 Jan 2026 22:10:13 +0100 Subject: [PATCH 4/4] upgrade --- examples/matrix_search/matrix_search.qpl | 2 +- src/Traq/Primitives.hs | 3 +- src/Traq/Primitives/Class.hs | 11 ++++ src/Traq/Primitives/Class/Compile.hs | 3 + src/Traq/Primitives/Search/QSearchCFNW.hs | 76 +++++++---------------- 5 files changed, 39 insertions(+), 56 deletions(-) diff --git a/examples/matrix_search/matrix_search.qpl b/examples/matrix_search/matrix_search.qpl index 65672aa..f3e6f61 100644 --- a/examples/matrix_search/matrix_search.qpl +++ b/examples/matrix_search/matrix_search.qpl @@ -1504,7 +1504,7 @@ uproc IsRowAllOnes_U(i : IN Fin<20>, okr : OUT Fin<2>, hasZero : AUX Fin<2>, has } // Grover[...] -uproc Grover[k](i : IN Fin<20>, x : IN Fin<10>, hasZero : OUT Fin<2>, aux_4 : AUX Fin<2>, aux_5 : AUX Fin<2>, aux_6 : AUX Fin<2>) { +uproc Grover[k](i : Fin<20>, x : IN Fin<10>, hasZero : OUT Fin<2>, aux_4 : AUX Fin<2>, aux_5 : AUX Fin<2>, aux_6 : AUX Fin<2>) { hasZero *= X; hasZero *= H; x *= Distr[uniform : Fin<10>]; diff --git a/src/Traq/Primitives.hs b/src/Traq/Primitives.hs index e7ef4ea..d622ccf 100644 --- a/src/Traq/Primitives.hs +++ b/src/Traq/Primitives.hs @@ -23,7 +23,6 @@ module Traq.Primitives ( import GHC.Generics import qualified Traq.Analysis as A -import qualified Traq.Compiler as Compiler import Traq.Prelude import Traq.Primitives.Class import Traq.Primitives.Search.DetSearch @@ -90,7 +89,7 @@ instance (P.TypingReqs size, Integral size, RealFloat prec, Show prec) => UnitaryCompilePrim (DefaultPrimCollection size prec) size prec instance - (P.TypingReqs size, Integral size, RealFloat prec, Show prec) => + (size ~ SizeT, P.TypingReqs size, Integral size, RealFloat prec, Show prec) => QuantumCompilePrim (DefaultPrimCollection size prec) size prec type DefaultPrims sizeT precT = Primitive (DefaultPrimCollection sizeT precT) diff --git a/src/Traq/Primitives/Class.hs b/src/Traq/Primitives/Class.hs index 211428b..78118b6 100644 --- a/src/Traq/Primitives/Class.hs +++ b/src/Traq/Primitives/Class.hs @@ -320,6 +320,7 @@ instance PrimCompileEnv { mk_ucall , mk_call = reshapeUnsafe $ replicate (length par_funs) (error "cannot call proc from UPrim") + , mk_meas = reshapeUnsafe $ replicate (length par_funs) (error "cannot meas uproc from UPrim") , uproc_aux_types , ret_vars = rets } @@ -549,10 +550,20 @@ instance , args = placeArgsWithExcess pfun_args xs } + mk_meas <- + reshape $ + par_funs <&> \PartialFun{pfun_name, pfun_args} xs -> + CQPL.CallS + { fun = CQPL.UProcAndMeas $ Compiler.mkUProcName pfun_name + , meta_params = [] + , args = placeArgsWithExcess pfun_args xs + } + let builder = PrimCompileEnv { mk_ucall , mk_call + , mk_meas , uproc_aux_types , ret_vars = rets } diff --git a/src/Traq/Primitives/Class/Compile.hs b/src/Traq/Primitives/Class/Compile.hs index d0631f6..b13b695 100644 --- a/src/Traq/Primitives/Class/Compile.hs +++ b/src/Traq/Primitives/Class/Compile.hs @@ -36,6 +36,8 @@ data PrimCompileEnv shape size = PrimCompileEnv -- ^ helper to generate a call to a unitary function argument. , mk_call :: shape (CallBuilder size) -- ^ helper to generate a call to a classical function argument. + , mk_meas :: shape (CallBuilder size) + -- ^ helper to generate a call-and-meas to a unitary proc arg. , uproc_aux_types :: shape [P.VarType size] -- ^ auxiliary variables for each unitary function argument. , ret_vars :: [Ident] @@ -49,6 +51,7 @@ reshapeBuilder :: reshapeBuilder PrimCompileEnv{..} = do mk_ucall <- reshape mk_ucall mk_call <- reshape mk_call + mk_meas <- reshape mk_meas uproc_aux_types <- reshape uproc_aux_types return PrimCompileEnv{..} diff --git a/src/Traq/Primitives/Search/QSearchCFNW.hs b/src/Traq/Primitives/Search/QSearchCFNW.hs index 8de7c0e..85757a7 100644 --- a/src/Traq/Primitives/Search/QSearchCFNW.hs +++ b/src/Traq/Primitives/Search/QSearchCFNW.hs @@ -35,7 +35,7 @@ import Control.Monad.Except (throwError) import Control.Monad.RWS (RWST, evalRWST) import Control.Monad.Trans (lift) import Control.Monad.Writer (WriterT (..), censor, execWriterT, listen) -import Data.Maybe (catMaybes, fromJust) +import Data.Maybe (fromJust) import Data.String (fromString) import GHC.Generics (Generic) import Text.Printf (printf) @@ -474,7 +474,6 @@ algoQSearch :: ( Integral sizeT , RealFloat precT , sizeT ~ SizeT - , Compiler.CompileQ ext , Show sizeT , Show precT , P.TypingReqs sizeT @@ -580,32 +579,17 @@ algoQSearch ty n_samples eps grover_k_caller pred_caller ok = do nxt m = min (lambda * m) sqrt_n instance - (P.TypingReqs size, Integral size, RealFloat prec, Show prec) => - QuantumCompilePrim (QSearchCFNW size prec) size prec + (RealFloat prec, Show prec) => + QuantumCompilePrim (QSearchCFNW SizeT prec) SizeT prec where compileQPrim (QSearchCFNW PrimSearch{search_kind = AnyK, search_ty}) eps = do - error "TODO" - compileQPrim _ _ = error "unsupported" - -instance - {-# OVERLAPPABLE #-} - ( Integral sizeT - , RealFloat precT - , sizeT ~ SizeT - , Show sizeT - , Show precT - , P.TypingReqs sizeT - ) => - Compiler.CompileQ (A.AnnFailProb (Primitive (QSearchCFNW sizeT precT))) - where - compileQ (A.AnnFailProb eps (Primitive [PartialFun{pfun_name, pfun_args}] (QSearchCFNW (PrimSearch _ s_ty)))) (ret : rets) = do -- lowered unitary predicate - let upred_proc_name = Compiler.mkUProcName pfun_name - Compiler.ProcSignature - { Compiler.in_tys = pred_inp_tys - , Compiler.aux_tys = pred_aux_tys - } <- - use (Compiler._procSignatures . at upred_proc_name) >>= maybeWithError "missing uproc" + (BooleanPredicate call_upred) <- view $ to mk_ucall + (BooleanPredicate pred_aux_tys) <- view $ to uproc_aux_types + ret <- + view (to ret_vars) >>= \case + [b] -> pure b + _ -> throwError "bool predicate must return single bool" -- make the Grover_k uproc -- TODO this should ideally be done by algoQSearch, but requires a lot of aux information. @@ -616,17 +600,11 @@ instance let uproc_grover_k_body = groverK meta_k - (grover_arg_name, s_ty) + (grover_arg_name, search_ty) ret - ( \x b -> - CQPL.UCallS - { CQPL.uproc_id = upred_proc_name - , CQPL.dagger = False - , CQPL.qargs = catMaybes pfun_args ++ [x, b] ++ upred_aux_vars - } - ) + (\x b -> call_upred ([x, b] ++ upred_aux_vars)) let uproc_grover_k_params = - Compiler.withTag CQPL.ParamInp (zip (catMaybes pfun_args ++ [grover_arg_name]) pred_inp_tys) + Compiler.withTag CQPL.ParamInp [(grover_arg_name, search_ty)] ++ Compiler.withTag CQPL.ParamOut [(ret, P.tbool)] ++ Compiler.withTag CQPL.ParamAux (zip upred_aux_vars pred_aux_tys) let uproc_grover_k = @@ -649,24 +627,21 @@ instance CQPL.CallS { CQPL.fun = CQPL.UProcAndMeas uproc_grover_k_name , CQPL.meta_params = [k] - , CQPL.args = catMaybes pfun_args ++ [x, b] + , CQPL.args = [x, b] } -- emit the QSearch algorithm - qsearch_params <- forM (catMaybes pfun_args ++ [ret]) $ \x -> do - ty <- use $ P._typingCtx . Ctx.at x . singular _Just - return (x, ty) + let qsearch_params = [(ret, P.tbool)] - let pred_caller x b = - CQPL.CallS - { CQPL.fun = CQPL.UProcAndMeas upred_proc_name - , CQPL.meta_params = [] - , CQPL.args = catMaybes pfun_args ++ [x, b] - } + (BooleanPredicate meas_upred) <- view $ to mk_meas + let pred_caller x b = meas_upred [x, b] - (qsearch_body, qsearch_local_vars) <- execWriterT $ algoQSearch s_ty 0 eps grover_k_caller pred_caller ret + (qsearch_body, qsearch_local_vars) <- + lift $ + execWriterT $ + algoQSearch search_ty 0 eps grover_k_caller pred_caller ret qsearch_proc_name <- Compiler.newIdent "QAny" - Compiler.addProc $ + return CQPL.ProcDef { CQPL.info_comment = printf "QAny[%s]" (show $ A.getFailProb eps) , CQPL.proc_name = qsearch_proc_name @@ -681,10 +656,5 @@ instance } } - return - CQPL.CallS - { CQPL.fun = CQPL.FunctionCall qsearch_proc_name - , CQPL.args = catMaybes pfun_args ++ [ret] ++ rets - , CQPL.meta_params = [] - } - compileQ _ _ = error "Unsupported" + -- TODO variants + compileQPrim _ _ = error "unsupported"