Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/matrix_search/matrix_search.qpl
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,7 @@ uproc IsRowAllOnes_U(i : IN Fin<20>, okr : OUT Fin<2>, hasZero : AUX Fin<2>, has
}

// Grover[...]
uproc Grover[k](i : IN Fin<20>, x : IN Fin<10>, hasZero : OUT Fin<2>, aux_4 : AUX Fin<2>, aux_5 : AUX Fin<2>, aux_6 : AUX Fin<2>) {
uproc Grover[k](i : Fin<20>, x : IN Fin<10>, hasZero : OUT Fin<2>, aux_4 : AUX Fin<2>, aux_5 : AUX Fin<2>, aux_6 : AUX Fin<2>) {
hasZero *= X;
hasZero *= H;
x *= Distr[uniform : Fin<10>];
Expand Down
17 changes: 3 additions & 14 deletions src/Traq/Primitives.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ module Traq.Primitives (
import GHC.Generics

import qualified Traq.Analysis as A
import qualified Traq.Compiler as Compiler
import Traq.Prelude
import Traq.Primitives.Class
import Traq.Primitives.Search.DetSearch
Expand Down Expand Up @@ -89,24 +88,14 @@ instance
instance
(P.TypingReqs size, Integral size, RealFloat prec, Show prec) =>
UnitaryCompilePrim (DefaultPrimCollection size prec) size prec
instance
(size ~ SizeT, P.TypingReqs size, Integral size, RealFloat prec, Show prec) =>
QuantumCompilePrim (DefaultPrimCollection size prec) size prec

type DefaultPrims sizeT precT = Primitive (DefaultPrimCollection sizeT precT)

type DefaultPrims' = DefaultPrims SizeT Double

instance
( Integral sizeT
, Floating precT
, RealFloat precT
, P.TypingReqs sizeT
, Show precT
, sizeT ~ SizeT
) =>
Compiler.CompileQ (A.AnnFailProb (DefaultPrims sizeT precT))
where
compileQ (A.AnnFailProb eps (Primitive fs (QAny q))) = Compiler.compileQ (A.AnnFailProb eps (Primitive fs q))
compileQ _ = error "TODO: lowerPrimitive"

-- ================================================================================
-- Worst-cost prim collection
-- ================================================================================
Expand Down
93 changes: 90 additions & 3 deletions src/Traq/Primitives/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module Traq.Primitives.Class (
module Traq.Primitives.Class.Eval,
module Traq.Primitives.Class.UnitaryCost,
module Traq.Primitives.Class.QuantumCost,
module Traq.Primitives.Class.UnitaryCompile,
module Traq.Primitives.Class.Compile,
) where

import Control.Applicative (Alternative ((<|>)), many)
Expand All @@ -38,12 +38,12 @@ import qualified Traq.Analysis as A
import qualified Traq.CQPL as CQPL
import qualified Traq.Compiler as Compiler
import Traq.Prelude
import Traq.Primitives.Class.Compile
import Traq.Primitives.Class.Eval
import Traq.Primitives.Class.Prelude
import Traq.Primitives.Class.QuantumCost
import Traq.Primitives.Class.Serialize
import Traq.Primitives.Class.TypeCheck
import Traq.Primitives.Class.UnitaryCompile
import Traq.Primitives.Class.UnitaryCost
import qualified Traq.ProtoLang as P
import qualified Traq.Utils.Printing as PP
Expand Down Expand Up @@ -316,7 +316,14 @@ instance
>>= maybeWithError (printf "could not find uproc `%s` for fun `%s`" uproc_name pfun_name)
return $ Compiler.aux_tys sign

let builder = UnitaryCompilePrimBuilder{mk_ucall, uproc_aux_types, ret_vars = rets}
let builder =
PrimCompileEnv
{ mk_ucall
, mk_call = reshapeUnsafe $ replicate (length par_funs) (error "cannot call proc from UPrim")
, mk_meas = reshapeUnsafe $ replicate (length par_funs) (error "cannot meas uproc from UPrim")
, uproc_aux_types
, ret_vars = rets
}
let arg_bounder = prependBoundArgs (map Compiler.mkUProcName pfun_names) bound_args
prim_proc_raw <-
runReaderT (compileUPrim prim eps) builder
Expand Down Expand Up @@ -496,3 +503,83 @@ instance
when (n_query_u > 0) $ void $ A.annEpsU1 eps_fn_u named_fn

pure $ A.AnnFailProb eps_alg $ Primitive par_funs prim

-- --------------------------------------------------------------------------------
-- Compilation
-- --------------------------------------------------------------------------------

instance
{-# OVERLAPPABLE #-}
( TypeCheckPrim prim (SizeType prim)
, P.TypingReqs (SizeType prim)
, UnitaryCompilePrim prim (SizeType prim) (PrecType prim)
, QuantumCompilePrim prim (SizeType prim) (PrecType prim)
) =>
Compiler.CompileQ (A.AnnFailProb (Primitive prim))
where
compileQ (A.AnnFailProb eps (Primitive par_funs prim)) rets = do
let pfun_names = map pfun_name par_funs
let bound_args_names = concatMap (catMaybes . pfun_args) par_funs
bound_args_tys <- forM bound_args_names $ \x -> use $ P._typingCtx . Ctx.at x . non' (error $ "invalid arg " ++ x)
let bound_args = zip bound_args_names bound_args_tys

mk_ucall <-
reshape $
par_funs <&> \PartialFun{pfun_name, pfun_args} xs ->
CQPL.UCallS
{ uproc_id = Compiler.mkUProcName pfun_name
, dagger = False
, qargs = placeArgsWithExcess pfun_args xs
}

uproc_aux_types <-
reshape =<< do
forM par_funs $ \PartialFun{pfun_name} -> do
let uproc_name = Compiler.mkUProcName pfun_name
sign <-
use (Compiler._procSignatures . at uproc_name)
>>= maybeWithError (printf "could not find uproc `%s` for fun `%s`" uproc_name pfun_name)
return $ Compiler.aux_tys sign

mk_call <-
reshape $
par_funs <&> \PartialFun{pfun_name, pfun_args} xs ->
CQPL.CallS
{ fun = CQPL.FunctionCall $ Compiler.mkQProcName pfun_name
, meta_params = []
, args = placeArgsWithExcess pfun_args xs
}

mk_meas <-
reshape $
par_funs <&> \PartialFun{pfun_name, pfun_args} xs ->
CQPL.CallS
{ fun = CQPL.UProcAndMeas $ Compiler.mkUProcName pfun_name
, meta_params = []
, args = placeArgsWithExcess pfun_args xs
}

let builder =
PrimCompileEnv
{ mk_ucall
, mk_call
, mk_meas
, uproc_aux_types
, ret_vars = rets
}
let arg_bounder =
prependBoundArgs
(map Compiler.mkUProcName pfun_names ++ map Compiler.mkQProcName pfun_names)
bound_args
prim_proc_raw <-
runReaderT (compileQPrim prim eps) builder
& censor (Compiler._loweredProcs . each %~ arg_bounder)
let prim_proc = arg_bounder prim_proc_raw
Compiler.addProc prim_proc

return $
CQPL.CallS
{ fun = CQPL.FunctionCall $ CQPL.proc_name prim_proc
, meta_params = []
, args = map fst bound_args ++ rets
}
191 changes: 191 additions & 0 deletions src/Traq/Primitives/Class/Compile.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE UndecidableInstances #-}

module Traq.Primitives.Class.Compile (
PrimCompileEnv (..),
UnitaryCompilePrim (..),
QuantumCompilePrim (..),
) where

import Control.Monad.Reader (ReaderT (..))
import Control.Monad.Trans (lift)
import GHC.Generics

import Lens.Micro.Mtl

import qualified Traq.Analysis as A
import qualified Traq.CQPL as CQPL
import Traq.Compiler
import Traq.Prelude
import Traq.Primitives.Class.Prelude
import qualified Traq.ProtoLang as P

-- --------------------------------------------------------------------------------
-- Environment and enclosing monad for compiling primitives.
-- --------------------------------------------------------------------------------

type UCallBuilder size = [Ident] -> CQPL.UStmt size
type CallBuilder size = [Ident] -> CQPL.Stmt size

-- | Helpers to compile a primitive.
data PrimCompileEnv shape size = PrimCompileEnv
{ mk_ucall :: shape (UCallBuilder size)
-- ^ helper to generate a call to a unitary function argument.
, mk_call :: shape (CallBuilder size)
-- ^ helper to generate a call to a classical function argument.
, mk_meas :: shape (CallBuilder size)
-- ^ helper to generate a call-and-meas to a unitary proc arg.
, uproc_aux_types :: shape [P.VarType size]
-- ^ auxiliary variables for each unitary function argument.
, ret_vars :: [Ident]
-- ^ return variables to store the result in.
}

reshapeBuilder ::
(ValidPrimShape shape, ValidPrimShape shape') =>
PrimCompileEnv shape size ->
Either String (PrimCompileEnv shape' size)
reshapeBuilder PrimCompileEnv{..} = do
mk_ucall <- reshape mk_ucall

Check warning on line 52 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

This binding for ‘mk_ucall’ shadows the existing binding

Check warning on line 52 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

This binding for ‘mk_ucall’ shadows the existing binding

Check warning on line 52 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

This binding for ‘mk_ucall’ shadows the existing binding

Check warning on line 52 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

This binding for ‘mk_ucall’ shadows the existing binding
mk_call <- reshape mk_call

Check warning on line 53 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

This binding for ‘mk_call’ shadows the existing binding

Check warning on line 53 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

This binding for ‘mk_call’ shadows the existing binding

Check warning on line 53 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

This binding for ‘mk_call’ shadows the existing binding

Check warning on line 53 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

This binding for ‘mk_call’ shadows the existing binding
mk_meas <- reshape mk_meas

Check warning on line 54 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

This binding for ‘mk_meas’ shadows the existing binding

Check warning on line 54 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

This binding for ‘mk_meas’ shadows the existing binding

Check warning on line 54 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

This binding for ‘mk_meas’ shadows the existing binding

Check warning on line 54 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

This binding for ‘mk_meas’ shadows the existing binding
uproc_aux_types <- reshape uproc_aux_types

Check warning on line 55 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

This binding for ‘uproc_aux_types’ shadows the existing binding

Check warning on line 55 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.6.7)

This binding for ‘uproc_aux_types’ shadows the existing binding

Check warning on line 55 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

This binding for ‘uproc_aux_types’ shadows the existing binding

Check warning on line 55 in src/Traq/Primitives/Class/Compile.hs

View workflow job for this annotation

GitHub Actions / Build and test (cabal) (9.4.8)

This binding for ‘uproc_aux_types’ shadows the existing binding
return PrimCompileEnv{..}

type PrimCompileMonad ext prim =
ReaderT
(PrimCompileEnv (PrimFnShape prim) (SizeType prim))
(CompilerT ext)

-- --------------------------------------------------------------------------------
-- Unitary Compilation
-- --------------------------------------------------------------------------------

-- | Compile a primitive to a uproc
class
( size ~ SizeType prim
, prec ~ PrecType prim
, ValidPrimShape (PrimFnShape prim)
) =>
UnitaryCompilePrim prim size prec
| prim -> size prec
where
compileUPrim ::
forall ext' m shape.
( m ~ PrimCompileMonad ext' prim
, size ~ SizeType ext'
, prec ~ PrecType ext'
, shape ~ PrimFnShape prim
) =>
prim ->
A.FailProb prec ->
m (CQPL.ProcDef size)
default compileUPrim ::
forall ext' m shape.
( Generic prim
, GUnitaryCompilePrim (Rep prim) size prec
, m ~ PrimCompileMonad ext' prim
, size ~ SizeType ext'
, prec ~ PrecType ext'
, shape ~ PrimFnShape prim
) =>
prim ->
A.FailProb prec ->
m (CQPL.ProcDef size)
compileUPrim prim eps = do
builder <- view id
lift $ do
builder' <- lift $ reshapeBuilder builder
gcompileUPrim (from prim) eps builder'

class GUnitaryCompilePrim f size prec | f -> size prec where
gcompileUPrim ::
forall ext' m p.
( m ~ CompilerT ext'
, size ~ SizeType ext'
, prec ~ PrecType ext'
) =>
f p ->
A.FailProb prec ->
PrimCompileEnv [] size ->
m (CQPL.ProcDef size)

instance (GUnitaryCompilePrim a size prec, GUnitaryCompilePrim b size prec) => GUnitaryCompilePrim (a :+: b) size prec where
gcompileUPrim (L1 x) = gcompileUPrim x
gcompileUPrim (R1 x) = gcompileUPrim x

instance (GUnitaryCompilePrim f size prec) => GUnitaryCompilePrim (M1 i c f) size prec where
gcompileUPrim (M1 x) = gcompileUPrim x

instance (UnitaryCompilePrim a size prec) => GUnitaryCompilePrim (K1 i a) size prec where
gcompileUPrim (K1 x) eps builder = do
builder' <- lift $ reshapeBuilder builder
runReaderT (compileUPrim x eps) builder'

-- --------------------------------------------------------------------------------
-- Quantum Compilation
-- --------------------------------------------------------------------------------

-- | Compile a primitive to a cq-proc
class
( size ~ SizeType prim
, prec ~ PrecType prim
, ValidPrimShape (PrimFnShape prim)
) =>
QuantumCompilePrim prim size prec
| prim -> size prec
where
compileQPrim ::
forall ext' m shape.
( m ~ PrimCompileMonad ext' prim
, size ~ SizeType ext'
, prec ~ PrecType ext'
, shape ~ PrimFnShape prim
) =>
prim ->
A.FailProb prec ->
m (CQPL.ProcDef size)
default compileQPrim ::
forall ext' m shape.
( Generic prim
, GQuantumCompilePrim (Rep prim) size prec
, m ~ PrimCompileMonad ext' prim
, size ~ SizeType ext'
, prec ~ PrecType ext'
, shape ~ PrimFnShape prim
) =>
prim ->
A.FailProb prec ->
m (CQPL.ProcDef size)
compileQPrim prim eps = do
builder <- view id
lift $ do
builder' <- lift $ reshapeBuilder builder
gcompileQPrim (from prim) eps builder'

class GQuantumCompilePrim f size prec | f -> size prec where
gcompileQPrim ::
forall ext' m p.
( m ~ CompilerT ext'
, size ~ SizeType ext'
, prec ~ PrecType ext'
) =>
f p ->
A.FailProb prec ->
PrimCompileEnv [] size ->
m (CQPL.ProcDef size)

instance (GQuantumCompilePrim a size prec, GQuantumCompilePrim b size prec) => GQuantumCompilePrim (a :+: b) size prec where
gcompileQPrim (L1 x) = gcompileQPrim x
gcompileQPrim (R1 x) = gcompileQPrim x

instance (GQuantumCompilePrim f size prec) => GQuantumCompilePrim (M1 i c f) size prec where
gcompileQPrim (M1 x) = gcompileQPrim x

instance (QuantumCompilePrim a size prec) => GQuantumCompilePrim (K1 i a) size prec where
gcompileQPrim (K1 x) eps builder = do
builder' <- lift $ reshapeBuilder builder
runReaderT (compileQPrim x eps) builder'
Loading
Loading