Skip to content

Commit 1bdeaab

Browse files
committed
Add deriveEq for Plinth similar to deriving stock Eq
1 parent 7549152 commit 1bdeaab

File tree

5 files changed

+115
-12
lines changed

5 files changed

+115
-12
lines changed

plutus-tx/plutus-tx.cabal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ library
109109
PlutusTx.Semigroup
110110
PlutusTx.Show
111111
PlutusTx.Show.TH
112+
PlutusTx.Eq.TH
112113
PlutusTx.Sqrt
113114
PlutusTx.TH
114115
PlutusTx.These
@@ -209,6 +210,7 @@ test-suite plutus-tx-test
209210
Blueprint.Spec
210211
List.Spec
211212
Bool.Spec
213+
Eq.Spec
212214
Rational.Laws
213215
Rational.Laws.Additive
214216
Rational.Laws.Construction

plutus-tx/src/PlutusTx/Eq.hs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
{-# OPTIONS_GHC -fno-omit-interface-pragmas #-}
2+
{-# OPTIONS_GHC -Wno-orphans #-}
23

3-
module PlutusTx.Eq (Eq (..), (/=)) where
4+
module PlutusTx.Eq (Eq (..), (/=), deriveEq) where
45

6+
import PlutusTx.Eq.TH
57
import PlutusTx.Bool
68
import PlutusTx.Builtins qualified as Builtins
79
import PlutusTx.Either (Either (..))
@@ -10,17 +12,7 @@ import Prelude (Maybe (..))
1012

1113
{- HLINT ignore -}
1214

13-
infix 4 ==, /=
14-
15-
-- Copied from the GHC definition
16-
17-
-- | The 'Eq' class defines equality ('==').
18-
class Eq a where
19-
(==) :: a -> a -> Bool
20-
21-
-- (/=) deliberately omitted, to make this a one-method class which has a
22-
-- simpler representation
23-
15+
infix 4 /=
2416
(/=) :: (Eq a) => a -> a -> Bool
2517
x /= y = not (x == y)
2618
{-# INLINEABLE (/=) #-}

plutus-tx/src/PlutusTx/Eq/TH.hs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
{-# LANGUAGE TemplateHaskellQuotes #-}
2+
module PlutusTx.Eq.TH (Eq (..), deriveEq) where
3+
4+
import PlutusTx.Bool ((&&), Bool (True))
5+
import Prelude hiding (Eq, (==), (&&), Bool (True))
6+
import Data.Foldable
7+
import Data.Traversable
8+
import Language.Haskell.TH as TH
9+
import Language.Haskell.TH.Datatype as TH
10+
import Data.Deriving.Internal (varTToName)
11+
12+
infix 4 ==
13+
14+
-- Copied from the GHC definition
15+
16+
-- | The 'Eq' class defines equality ('==').
17+
class Eq a where
18+
(==) :: a -> a -> Bool
19+
20+
-- (/=) deliberately omitted, to make this a one-method class which has a
21+
-- simpler representation
22+
23+
deriveEq :: TH.Name -> TH.Q [TH.Dec]
24+
deriveEq name = do
25+
TH.DatatypeInfo
26+
{ TH.datatypeName = tyConName
27+
, TH.datatypeInstTypes = tyVars0
28+
, TH.datatypeCons = cons
29+
} <-
30+
TH.reifyDatatype name
31+
let
32+
-- The purpose of the `TH.VarT . varTToName` roundtrip is to remove the kind
33+
-- signatures attached to the type variables in `tyVars0`. Otherwise, the
34+
-- `KindSignatures` extension would be needed whenever `length tyVars0 > 0`.
35+
tyVars = TH.VarT . varTToName <$> tyVars0
36+
instanceCxt :: TH.Cxt
37+
instanceCxt = TH.AppT (TH.ConT ''Eq) <$> tyVars
38+
instanceType :: TH.Type
39+
instanceType = TH.AppT (TH.ConT ''Eq) $ foldl' TH.AppT (TH.ConT tyConName) tyVars
40+
41+
pure <$> instanceD (pure instanceCxt) (pure instanceType)
42+
[funD '(==) (fmap deriveEqCons cons <> [pure eqDefaultClause])
43+
, TH.pragInlD '(==) TH.Inlinable TH.FunLike TH.AllPhases
44+
]
45+
46+
47+
-- Clause: Cons1 l1 l2 l3 .. ln == Cons1 r1 r2 r3 .. rn
48+
deriveEqCons :: ConstructorInfo -> Q Clause
49+
deriveEqCons (ConstructorInfo {constructorName = name, constructorFields = fields })
50+
= do
51+
argsL <- for [1 .. length fields] $ \i -> TH.newName ("l" <> show i)
52+
argsR <- for [1 .. length fields] $ \i -> TH.newName ("r" <> show i)
53+
pure (TH.Clause [ConP name [] (fmap VarP argsL), ConP name [] (fmap VarP argsR)]
54+
(NormalB $
55+
foldr
56+
(\ (argL,argR) acc ->
57+
TH.InfixE(pure $ TH.InfixE (pure $ TH.VarE argL) (TH.VarE '(==)) (pure $ TH.VarE argR)) (TH.VarE '(&&)) (pure acc))
58+
(TH.ConE 'True)
59+
(zip argsL argsR)
60+
)
61+
[]
62+
)
63+
64+
-- Clause: _ == _ = False
65+
eqDefaultClause :: Clause
66+
eqDefaultClause = TH.Clause [WildP, WildP] (TH.NormalB (ConE 'False)) []

plutus-tx/test/Eq/Spec.hs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
{-# LANGUAGE TemplateHaskell #-}
2+
{-# LANGUAGE DerivingStrategies #-}
3+
{-# LANGUAGE OverloadedStrings #-}
4+
{-# LANGUAGE TypeApplications #-}
5+
module Eq.Spec (eqTests) where
6+
7+
import PlutusTx.Builtins as Tx
8+
import PlutusTx.Bool qualified as Tx
9+
import PlutusTx.Eq as Tx
10+
import Control.Exception
11+
12+
import Data.Either
13+
14+
import Prelude hiding (Eq (..), error)
15+
import Prelude qualified as HS (Eq (..),)
16+
import Test.Tasty
17+
import Test.Tasty.HUnit
18+
19+
data SomeLargeADT a b c d e =
20+
SomeLargeADT1 Integer a Tx.Bool b c d
21+
| SomeLargeADT2
22+
| SomeLargeADT3 { f1 :: e, f2 :: e, _f3 :: e, _f4 :: e, _f5 :: e }
23+
deriving stock HS.Eq
24+
deriveEq ''SomeLargeADT
25+
26+
eqTests :: TestTree
27+
eqTests =
28+
let v1 :: SomeLargeADT () BuiltinString () () () = SomeLargeADT1 1 () Tx.True "foobar" () ()
29+
v2 :: SomeLargeADT () () () () () = SomeLargeADT2
30+
v3 :: SomeLargeADT () () () () Integer = SomeLargeADT3 1 2 3 4 5
31+
v3Error1 = v3 { f1 = 0, f2 = error () } -- mismatch comes first, error comes later
32+
v3Error2 = v3 { f1 = error (), f2 = 0 } -- error comes first, mismatch later
33+
34+
in testGroup
35+
"PlutusTx.Eq tests"
36+
[testCase "reflexive1" $ (v1 Tx.== v1) @?= (v1 HS.== v1)
37+
, testCase "reflexive2" $ (v2 Tx.== v2) @?= (v2 HS.== v2)
38+
, testCase "reflexive3" $ (v3 Tx.== v3) @?= (v3 HS.== v3)
39+
, testCase "shortcircuit" $ (v3 Tx.== v3Error1) @?= (v3 Tx.== v3Error1) -- should not throw an error
40+
, testCase "throws" $ try @SomeException (evaluate $ v3 Tx.== v3Error2) >>= assertBool "did not throw error" . isLeft -- should throw erro
41+
]

plutus-tx/test/Spec.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import Hedgehog.Gen qualified as Gen
2020
import Hedgehog.Range qualified as Range
2121
import List.Spec (listTests)
2222
import Bool.Spec (boolTests)
23+
import Eq.Spec (eqTests)
2324
import PlutusCore.Data (Data (B, Constr, I, List, Map))
2425
import PlutusTx.Enum (Enum (..))
2526
import PlutusTx.Numeric (negate)
@@ -47,6 +48,7 @@ tests =
4748
, enumTests
4849
, listTests
4950
, boolTests
51+
, eqTests
5052
, lawsTests
5153
, Show.Spec.propertyTests
5254
, Show.Spec.goldenTests

0 commit comments

Comments
 (0)