|
| 1 | +{-# LANGUAGE TemplateHaskellQuotes #-} |
| 2 | +module PlutusTx.Eq.TH (Eq (..), deriveEq) where |
| 3 | + |
| 4 | +import PlutusTx.Bool ((&&), Bool (True)) |
| 5 | +import Prelude hiding (Eq, (==), (&&), Bool (True)) |
| 6 | +import Data.Foldable |
| 7 | +import Data.Traversable |
| 8 | +import Language.Haskell.TH as TH |
| 9 | +import Language.Haskell.TH.Datatype as TH |
| 10 | +import Data.Deriving.Internal (varTToName) |
| 11 | + |
| 12 | +infix 4 == |
| 13 | + |
| 14 | +-- Copied from the GHC definition |
| 15 | + |
| 16 | +-- | The 'Eq' class defines equality ('=='). |
| 17 | +class Eq a where |
| 18 | + (==) :: a -> a -> Bool |
| 19 | + |
| 20 | +-- (/=) deliberately omitted, to make this a one-method class which has a |
| 21 | +-- simpler representation |
| 22 | + |
| 23 | +deriveEq :: TH.Name -> TH.Q [TH.Dec] |
| 24 | +deriveEq name = do |
| 25 | + TH.DatatypeInfo |
| 26 | + { TH.datatypeName = tyConName |
| 27 | + , TH.datatypeInstTypes = tyVars0 |
| 28 | + , TH.datatypeCons = cons |
| 29 | + } <- |
| 30 | + TH.reifyDatatype name |
| 31 | + let |
| 32 | + -- The purpose of the `TH.VarT . varTToName` roundtrip is to remove the kind |
| 33 | + -- signatures attached to the type variables in `tyVars0`. Otherwise, the |
| 34 | + -- `KindSignatures` extension would be needed whenever `length tyVars0 > 0`. |
| 35 | + tyVars = TH.VarT . varTToName <$> tyVars0 |
| 36 | + instanceCxt :: TH.Cxt |
| 37 | + instanceCxt = TH.AppT (TH.ConT ''Eq) <$> tyVars |
| 38 | + instanceType :: TH.Type |
| 39 | + instanceType = TH.AppT (TH.ConT ''Eq) $ foldl' TH.AppT (TH.ConT tyConName) tyVars |
| 40 | + |
| 41 | + pure <$> instanceD (pure instanceCxt) (pure instanceType) |
| 42 | + [funD '(==) (fmap deriveEqCons cons <> [pure eqDefaultClause]) |
| 43 | + , TH.pragInlD '(==) TH.Inlinable TH.FunLike TH.AllPhases |
| 44 | + ] |
| 45 | + |
| 46 | + |
| 47 | +-- Clause: Cons1 l1 l2 l3 .. ln == Cons1 r1 r2 r3 .. rn |
| 48 | +deriveEqCons :: ConstructorInfo -> Q Clause |
| 49 | +deriveEqCons (ConstructorInfo {constructorName = name, constructorFields = fields }) |
| 50 | + = do |
| 51 | + argsL <- for [1 .. length fields] $ \i -> TH.newName ("l" <> show i) |
| 52 | + argsR <- for [1 .. length fields] $ \i -> TH.newName ("r" <> show i) |
| 53 | + pure (TH.Clause [ConP name [] (fmap VarP argsL), ConP name [] (fmap VarP argsR)] |
| 54 | + (NormalB $ |
| 55 | + foldr |
| 56 | + (\ (argL,argR) acc -> |
| 57 | + TH.InfixE(pure $ TH.InfixE (pure $ TH.VarE argL) (TH.VarE '(==)) (pure $ TH.VarE argR)) (TH.VarE '(&&)) (pure acc)) |
| 58 | + (TH.ConE 'True) |
| 59 | + (zip argsL argsR) |
| 60 | + ) |
| 61 | + [] |
| 62 | + ) |
| 63 | + |
| 64 | +-- Clause: _ == _ = False |
| 65 | +eqDefaultClause :: Clause |
| 66 | +eqDefaultClause = TH.Clause [WildP, WildP] (TH.NormalB (ConE 'False)) [] |
0 commit comments