Skip to content

Commit 5d931a0

Browse files
committed
Add deriveEq for Plinth similar to deriving stock Eq
Add some derived Eq instances
1 parent 62c6d6c commit 5d931a0

File tree

10 files changed

+158
-57
lines changed

10 files changed

+158
-57
lines changed

plutus-tx-plugin/test/IsData/9.6/MyMonoData.golden.th

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ instance PlutusTx.IsData.Class.FromData Plugin.Data.Spec.MyMonoData
88
where {{-# INLINABLE PlutusTx.IsData.Class.fromBuiltinData #-};
99
PlutusTx.IsData.Class.fromBuiltinData d_4 = let constrFun_5 (!index_6) (!args_7) = case (index_6,
1010
args_7) of
11-
{(((PlutusTx.Eq.==) (0 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
11+
{(((PlutusTx.Eq.TH.==) (0 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
1212
(PlutusTx.Builtins.uncons -> GHC.Maybe.Just ((PlutusTx.IsData.Class.fromBuiltinData -> GHC.Maybe.Just arg_8),
1313
(PlutusTx.Builtins.headMaybe -> GHC.Maybe.Just (PlutusTx.IsData.Class.fromBuiltinData -> GHC.Maybe.Just arg_9))))) -> GHC.Maybe.Just (Plugin.Data.Spec.Mono1 arg_8 arg_9);
14-
(((PlutusTx.Eq.==) (1 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
14+
(((PlutusTx.Eq.TH.==) (1 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
1515
(PlutusTx.Builtins.headMaybe -> GHC.Maybe.Just (PlutusTx.IsData.Class.fromBuiltinData -> GHC.Maybe.Just arg_10))) -> GHC.Maybe.Just (Plugin.Data.Spec.Mono2 arg_10);
16-
(((PlutusTx.Eq.==) (2 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
16+
(((PlutusTx.Eq.TH.==) (2 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
1717
(PlutusTx.Builtins.headMaybe -> GHC.Maybe.Just (PlutusTx.IsData.Class.fromBuiltinData -> GHC.Maybe.Just arg_11))) -> GHC.Maybe.Just (Plugin.Data.Spec.Mono3 arg_11);
1818
_ -> GHC.Maybe.Nothing}
1919
in PlutusTx.Builtins.matchData' d_4 constrFun_5 (GHC.Base.const GHC.Maybe.Nothing) (GHC.Base.const GHC.Maybe.Nothing) (GHC.Base.const GHC.Maybe.Nothing) (GHC.Base.const GHC.Maybe.Nothing)}

plutus-tx-plugin/test/IsData/9.6/MyMonoRecord.golden.th

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ instance PlutusTx.IsData.Class.FromData Plugin.Data.Spec.MyMonoRecord
66
where {{-# INLINABLE PlutusTx.IsData.Class.fromBuiltinData #-};
77
PlutusTx.IsData.Class.fromBuiltinData d_2 = let constrFun_3 (!index_4) (!args_5) = case (index_4,
88
args_5) of
9-
{(((PlutusTx.Eq.==) (0 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
9+
{(((PlutusTx.Eq.TH.==) (0 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
1010
(PlutusTx.Builtins.uncons -> GHC.Maybe.Just ((PlutusTx.IsData.Class.fromBuiltinData -> GHC.Maybe.Just arg_6),
1111
(PlutusTx.Builtins.headMaybe -> GHC.Maybe.Just (PlutusTx.IsData.Class.fromBuiltinData -> GHC.Maybe.Just arg_7))))) -> GHC.Maybe.Just (Plugin.Data.Spec.MyMonoRecord arg_6 arg_7);
1212
_ -> GHC.Maybe.Nothing}

plutus-tx-plugin/test/IsData/9.6/MyPolyData.golden.th

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ instance (PlutusTx.IsData.Class.FromData a_0,
1111
where {{-# INLINABLE PlutusTx.IsData.Class.fromBuiltinData #-};
1212
PlutusTx.IsData.Class.fromBuiltinData d_5 = let constrFun_6 (!index_7) (!args_8) = case (index_7,
1313
args_8) of
14-
{(((PlutusTx.Eq.==) (0 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
14+
{(((PlutusTx.Eq.TH.==) (0 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
1515
(PlutusTx.Builtins.uncons -> GHC.Maybe.Just ((PlutusTx.IsData.Class.fromBuiltinData -> GHC.Maybe.Just arg_9),
1616
(PlutusTx.Builtins.headMaybe -> GHC.Maybe.Just (PlutusTx.IsData.Class.fromBuiltinData -> GHC.Maybe.Just arg_10))))) -> GHC.Maybe.Just (Plugin.Data.Spec.Poly1 arg_9 arg_10);
17-
(((PlutusTx.Eq.==) (1 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
17+
(((PlutusTx.Eq.TH.==) (1 :: GHC.Num.Integer.Integer) -> GHC.Types.True),
1818
(PlutusTx.Builtins.headMaybe -> GHC.Maybe.Just (PlutusTx.IsData.Class.fromBuiltinData -> GHC.Maybe.Just arg_11))) -> GHC.Maybe.Just (Plugin.Data.Spec.Poly2 arg_11);
1919
_ -> GHC.Maybe.Nothing}
2020
in PlutusTx.Builtins.matchData' d_5 constrFun_6 (GHC.Base.const GHC.Maybe.Nothing) (GHC.Base.const GHC.Maybe.Nothing) (GHC.Base.const GHC.Maybe.Nothing) (GHC.Base.const GHC.Maybe.Nothing)}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
### Added
2+
3+
- A `deriveEq` command to derive PlutusTx.Eq instances for datatypes/newtypes, similar to Haskell's
4+
`deriving stock Eq`

plutus-tx/plutus-tx.cabal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ library
109109
PlutusTx.Semigroup
110110
PlutusTx.Show
111111
PlutusTx.Show.TH
112+
PlutusTx.Eq.TH
112113
PlutusTx.Sqrt
113114
PlutusTx.TH
114115
PlutusTx.These
@@ -209,6 +210,7 @@ test-suite plutus-tx-test
209210
Blueprint.Spec
210211
List.Spec
211212
Bool.Spec
213+
Eq.Spec
212214
Rational.Laws
213215
Rational.Laws.Additive
214216
Rational.Laws.Construction

plutus-tx/src/PlutusTx/Eq.hs

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,18 @@
1+
{-# LANGUAGE TemplateHaskell #-}
12
{-# OPTIONS_GHC -fno-omit-interface-pragmas #-}
3+
{-# OPTIONS_GHC -Wno-orphans #-}
24

3-
module PlutusTx.Eq (Eq (..), (/=)) where
5+
module PlutusTx.Eq (Eq (..), (/=), deriveEq) where
46

57
import PlutusTx.Bool
68
import PlutusTx.Builtins qualified as Builtins
79
import PlutusTx.Either (Either (..))
10+
import PlutusTx.Eq.TH
811
import Prelude (Maybe (..))
912

1013
{- HLINT ignore -}
1114

12-
infix 4 ==, /=
13-
14-
-- Copied from the GHC definition
15-
16-
-- | The 'Eq' class defines equality ('==').
17-
class Eq a where
18-
(==) :: a -> a -> Bool
19-
20-
-- (/=) deliberately omitted, to make this a one-method class which has a
21-
-- simpler representation
22-
15+
infix 4 /=
2316
(/=) :: (Eq a) => a -> a -> Bool
2417
x /= y = not (x == y)
2518
{-# INLINEABLE (/=) #-}
@@ -48,34 +41,34 @@ instance Eq Builtins.BuiltinBLS12_381_G2_Element where
4841
{-# INLINEABLE (==) #-}
4942
(==) = Builtins.bls12_381_G2_equals
5043

51-
instance (Eq a) => Eq [a] where
52-
{-# INLINEABLE (==) #-}
53-
[] == [] = True
54-
(x : xs) == (y : ys) = x == y && xs == ys
55-
_ == _ = False
56-
57-
instance Eq Bool where
58-
{-# INLINEABLE (==) #-}
59-
True == True = True
60-
False == False = True
61-
_ == _ = False
62-
63-
instance (Eq a) => Eq (Maybe a) where
64-
{-# INLINEABLE (==) #-}
65-
(Just a1) == (Just a2) = a1 == a2
66-
Nothing == Nothing = True
67-
_ == _ = False
68-
69-
instance (Eq a, Eq b) => Eq (Either a b) where
70-
{-# INLINEABLE (==) #-}
71-
(Left a1) == (Left a2) = a1 == a2
72-
(Right b1) == (Right b2) = b1 == b2
73-
_ == _ = False
74-
75-
instance Eq () where
76-
{-# INLINEABLE (==) #-}
77-
_ == _ = True
78-
79-
instance (Eq a, Eq b) => Eq (a, b) where
80-
{-# INLINEABLE (==) #-}
81-
(a, b) == (a', b') = a == a' && b == b'
44+
deriveEq ''[]
45+
deriveEq ''Bool
46+
deriveEq ''Maybe
47+
deriveEq ''Either
48+
deriveEq ''()
49+
deriveEq ''(,)
50+
deriveEq ''(,,)
51+
deriveEq ''(,,,)
52+
deriveEq ''(,,,,)
53+
deriveEq ''(,,,,,)
54+
deriveEq ''(,,,,,,)
55+
deriveEq ''(,,,,,,,)
56+
deriveEq ''(,,,,,,,,)
57+
deriveEq ''(,,,,,,,,,)
58+
deriveEq ''(,,,,,,,,,,)
59+
deriveEq ''(,,,,,,,,,,,)
60+
deriveEq ''(,,,,,,,,,,,,)
61+
deriveEq ''(,,,,,,,,,,,,,)
62+
deriveEq ''(,,,,,,,,,,,,,,)
63+
deriveEq ''(,,,,,,,,,,,,,,,)
64+
deriveEq ''(,,,,,,,,,,,,,,,,)
65+
deriveEq ''(,,,,,,,,,,,,,,,,,)
66+
deriveEq ''(,,,,,,,,,,,,,,,,,,)
67+
deriveEq ''(,,,,,,,,,,,,,,,,,,,)
68+
deriveEq ''(,,,,,,,,,,,,,,,,,,,,)
69+
deriveEq ''(,,,,,,,,,,,,,,,,,,,,,)
70+
deriveEq ''(,,,,,,,,,,,,,,,,,,,,,,)
71+
deriveEq ''(,,,,,,,,,,,,,,,,,,,,,,,)
72+
deriveEq ''(,,,,,,,,,,,,,,,,,,,,,,,,)
73+
deriveEq ''(,,,,,,,,,,,,,,,,,,,,,,,,,)
74+
deriveEq ''(,,,,,,,,,,,,,,,,,,,,,,,,,,)

plutus-tx/src/PlutusTx/Eq/TH.hs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
{-# LANGUAGE TemplateHaskellQuotes #-}
2+
module PlutusTx.Eq.TH (Eq (..), deriveEq) where
3+
4+
import PlutusTx.Bool ((&&), Bool (True))
5+
import Prelude hiding (Eq, (==), (&&), Bool (True))
6+
import Data.Foldable
7+
import Data.Traversable
8+
import Language.Haskell.TH as TH
9+
import Language.Haskell.TH.Datatype as TH
10+
import Data.Deriving.Internal (varTToName)
11+
12+
infix 4 ==
13+
14+
-- Copied from the GHC definition
15+
16+
-- | The 'Eq' class defines equality ('==').
17+
class Eq a where
18+
(==) :: a -> a -> Bool
19+
20+
-- (/=) deliberately omitted, to make this a one-method class which has a
21+
-- simpler representation
22+
23+
deriveEq :: TH.Name -> TH.Q [TH.Dec]
24+
deriveEq name = do
25+
TH.DatatypeInfo
26+
{ TH.datatypeName = tyConName
27+
, TH.datatypeInstTypes = tyVars0
28+
, TH.datatypeCons = cons
29+
} <-
30+
TH.reifyDatatype name
31+
let
32+
-- The purpose of the `TH.VarT . varTToName` roundtrip is to remove the kind
33+
-- signatures attached to the type variables in `tyVars0`. Otherwise, the
34+
-- `KindSignatures` extension would be needed whenever `length tyVars0 > 0`.
35+
tyVars = TH.VarT . varTToName <$> tyVars0
36+
instanceCxt :: TH.Cxt
37+
instanceCxt = TH.AppT (TH.ConT ''Eq) <$> tyVars
38+
instanceType :: TH.Type
39+
instanceType = TH.AppT (TH.ConT ''Eq) $ foldl' TH.AppT (TH.ConT tyConName) tyVars
40+
41+
pure <$> instanceD (pure instanceCxt) (pure instanceType)
42+
[funD '(==) (fmap deriveEqCons cons <> [pure eqDefaultClause])
43+
, TH.pragInlD '(==) TH.Inlinable TH.FunLike TH.AllPhases
44+
]
45+
46+
47+
-- Clause: Cons1 l1 l2 l3 .. ln == Cons1 r1 r2 r3 .. rn
48+
deriveEqCons :: ConstructorInfo -> Q Clause
49+
deriveEqCons (ConstructorInfo {constructorName = name, constructorFields = fields })
50+
= do
51+
argsL <- for [1 .. length fields] $ \i -> TH.newName ("l" <> show i)
52+
argsR <- for [1 .. length fields] $ \i -> TH.newName ("r" <> show i)
53+
pure (TH.Clause [ConP name [] (fmap VarP argsL), ConP name [] (fmap VarP argsR)]
54+
(NormalB $
55+
foldr
56+
(\ (argL,argR) acc ->
57+
TH.InfixE(pure $ TH.InfixE (pure $ TH.VarE argL) (TH.VarE '(==)) (pure $ TH.VarE argR)) (TH.VarE '(&&)) (pure acc))
58+
(TH.ConE 'True)
59+
(zip argsL argsR)
60+
)
61+
[]
62+
)
63+
64+
-- Clause: _ == _ = False
65+
eqDefaultClause :: Clause
66+
eqDefaultClause = TH.Clause [WildP, WildP] (TH.NormalB (ConE 'False)) []

plutus-tx/src/PlutusTx/These.hs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ module PlutusTx.These (
2121
import GHC.Generics (Generic)
2222
import PlutusTx.Blueprint.Definition (HasBlueprintDefinition, definitionRef)
2323
import PlutusTx.Blueprint.TH (makeIsDataSchemaIndexed)
24-
import PlutusTx.Bool
2524
import PlutusTx.Eq
2625
import PlutusTx.Lift
2726
import PlutusTx.Ord
@@ -35,6 +34,7 @@ data These a b = This a | That b | These a b
3534
deriving stock (Generic, Haskell.Eq, Haskell.Show)
3635
deriving anyclass (HasBlueprintDefinition)
3736

37+
deriveEq ''These
3838
deriveShow ''These
3939
makeLift ''These
4040
makeIsDataSchemaIndexed ''These [('This, 0), ('That, 1), ('These, 2)]
@@ -68,10 +68,3 @@ instance (Ord a, Ord b) => Ord (These a b) where
6868
compare (That _) (These _ _) = LT
6969
compare (These _ _) (This _) = GT
7070
compare (These _ _) (That _) = GT
71-
72-
instance (Eq a, Eq b) => Eq (These a b) where
73-
{-# INLINEABLE (==) #-}
74-
(This a) == (This a') = a == a'
75-
(That b) == (That b') = b == b'
76-
(These a b) == (These a' b') = a == a' && b == b'
77-
_ == _ = False

plutus-tx/test/Eq/Spec.hs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
{-# LANGUAGE TemplateHaskell #-}
2+
{-# LANGUAGE DerivingStrategies #-}
3+
{-# LANGUAGE OverloadedStrings #-}
4+
{-# LANGUAGE TypeApplications #-}
5+
module Eq.Spec (eqTests) where
6+
7+
import PlutusTx.Builtins as Tx
8+
import PlutusTx.Bool qualified as Tx
9+
import PlutusTx.Eq as Tx
10+
import Control.Exception
11+
12+
import Data.Either
13+
14+
import Prelude hiding (Eq (..), error)
15+
import Prelude qualified as HS (Eq (..),)
16+
import Test.Tasty
17+
import Test.Tasty.HUnit
18+
19+
data SomeLargeADT a b c d e =
20+
SomeLargeADT1 Integer a Tx.Bool b c d
21+
| SomeLargeADT2
22+
| SomeLargeADT3 { f1 :: e, f2 :: e, _f3 :: e, _f4 :: e, _f5 :: e }
23+
deriving stock HS.Eq
24+
deriveEq ''SomeLargeADT
25+
26+
eqTests :: TestTree
27+
eqTests =
28+
let v1 :: SomeLargeADT () BuiltinString () () () = SomeLargeADT1 1 () Tx.True "foobar" () ()
29+
v2 :: SomeLargeADT () () () () () = SomeLargeADT2
30+
v3 :: SomeLargeADT () () () () Integer = SomeLargeADT3 1 2 3 4 5
31+
v3Error1 = v3 { f1 = 0, f2 = error () } -- mismatch comes first, error comes later
32+
v3Error2 = v3 { f1 = error (), f2 = 0 } -- error comes first, mismatch later
33+
34+
in testGroup
35+
"PlutusTx.Eq tests"
36+
[testCase "reflexive1" $ (v1 Tx.== v1) @?= (v1 HS.== v1)
37+
, testCase "reflexive2" $ (v2 Tx.== v2) @?= (v2 HS.== v2)
38+
, testCase "reflexive3" $ (v3 Tx.== v3) @?= (v3 HS.== v3)
39+
, testCase "shortcircuit" $ (v3 Tx.== v3Error1) @?= (v3 Tx.== v3Error1) -- should not throw an error
40+
, testCase "throws" $ try @SomeException (evaluate $ v3 Tx.== v3Error2) >>= assertBool "did not throw error" . isLeft -- should throw erro
41+
]

plutus-tx/test/Spec.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import Hedgehog.Gen qualified as Gen
2020
import Hedgehog.Range qualified as Range
2121
import List.Spec (listTests)
2222
import Bool.Spec (boolTests)
23+
import Eq.Spec (eqTests)
2324
import PlutusCore.Data (Data (B, Constr, I, List, Map))
2425
import PlutusTx.Enum (Enum (..))
2526
import PlutusTx.Numeric (negate)
@@ -47,6 +48,7 @@ tests =
4748
, enumTests
4849
, listTests
4950
, boolTests
51+
, eqTests
5052
, lawsTests
5153
, Show.Spec.propertyTests
5254
, Show.Spec.goldenTests

0 commit comments

Comments
 (0)