diff --git a/containers-tests/tests/intmap-properties.hs b/containers-tests/tests/intmap-properties.hs index 1ead2c581..72771a7ec 100644 --- a/containers-tests/tests/intmap-properties.hs +++ b/containers-tests/tests/intmap-properties.hs @@ -32,7 +32,7 @@ import Test.Tasty import Test.Tasty.HUnit import Test.Tasty.QuickCheck import Test.QuickCheck.Function (apply) -import Test.QuickCheck.Poly (A, B, C) +import Test.QuickCheck.Poly (A, B, C, OrdA) default (Int) @@ -217,6 +217,7 @@ main = defaultMain $ testGroup "intmap-properties" , testProperty "traverseMaybeWithKey->traverseWithKey" prop_traverseMaybeWithKey_degrade_to_traverseWithKey , testProperty "isProperSubmapOfBy" prop_isProperSubmapOfBy , testProperty "isSubmapOfBy" prop_isSubmapOfBy + , testProperty "compare" prop_compare ] {-------------------------------------------------------------------- @@ -1721,3 +1722,6 @@ prop_isSubmapOfBy f m1 m2 = xs = List.intersectBy (\(k1,x1) (k2,x2) -> k1 == k2 && applyFun2 f x1 x2) (assocs m1) (assocs m2) + +prop_compare :: IntMap OrdA -> IntMap OrdA -> Property +prop_compare m1 m2 = compare m1 m2 === compare (toList m1) (toList m2) diff --git a/containers/src/Data/IntMap/Internal.hs b/containers/src/Data/IntMap/Internal.hs index f2c62e9bc..67699ce68 100644 --- a/containers/src/Data/IntMap/Internal.hs +++ b/containers/src/Data/IntMap/Internal.hs @@ -310,6 +310,7 @@ import Data.IntSet.Internal.IntTreeCommons , TreeTreeBranch(..) , treeTreeBranch , i2w + , Order(..) ) import Utils.Containers.Internal.BitUtil (shiftLL, shiftRL, iShiftRL) import Utils.Containers.Internal.StrictPair @@ -3473,12 +3474,97 @@ instance Eq1 IntMap where --------------------------------------------------------------------} instance Ord a => Ord (IntMap a) where - compare m1 m2 = compare (toList m1) (toList m2) + compare m1 m2 = liftCmp compare m1 m2 + {-# INLINABLE compare #-} -- | @since 0.5.9 instance Ord1 IntMap where - liftCompare cmp m n = - liftCompare (liftCompare cmp) (toList m) (toList n) + liftCompare = liftCmp + +liftCmp :: (a -> b -> Ordering) -> IntMap a -> IntMap b -> Ordering +liftCmp cmp = go0 + where + go0 t1@(Bin p1 l1 r1) t2@(Bin p2 l2 r2) = case treeTreeBranch p1 p2 of + ABL | signBranch p1 -> LT + | otherwise -> case go l1 t2 of + Less -> LT + _ -> GT + ABR | signBranch p1 -> case go r1 t2 of + Less -> LT + _ -> GT + | otherwise -> LT + BAL | signBranch p2 -> GT + | otherwise -> case go t1 l2 of + Greater -> GT + _ -> LT + BAR | signBranch p2 -> case go t1 r2 of + Greater -> GT + _ -> LT + | otherwise -> GT + EQL -> + let !(l1', r1', l2', r2') = if signBranch p1 + then (r1, l1, r2, l2) + else (l1, r1, l2, r2) + in case go l1' l2' of + Less -> LT + Prefix' -> GT + Equals -> case go r1' r2' of + Less -> LT + Prefix' -> LT + Equals -> EQ + FlipPrefix -> GT + Greater -> GT + FlipPrefix -> LT + Greater -> GT + NOM -> compare (unPrefix p1) (unPrefix p2) + go0 (Bin p1 l1 r1) (Tip k2 x2) = + case lookupMinSure (if signBranch p1 then r1 else l1) of + KeyValue k1 x1 -> case compare k1 k2 <> cmp x1 x2 of + EQ -> GT + o -> o + go0 (Tip k1 x1) (Bin p2 l2 r2) = + case lookupMinSure (if signBranch p2 then r2 else l2) of + KeyValue k2 x2 -> case compare k1 k2 <> cmp x1 x2 of + EQ -> LT + o -> o + go0 (Tip k1 x1) (Tip k2 x2) = compare k1 k2 <> cmp x1 x2 + go0 Nil Nil = EQ + go0 Nil _ = LT + go0 _ Nil = GT + + go t1@(Bin p1 l1 r1) t2@(Bin p2 l2 r2) = case treeTreeBranch p1 p2 of + ABL -> case go l1 t2 of + Prefix' -> Greater + Equals -> FlipPrefix + o -> o + ABR -> Less + BAL -> case go t1 l2 of + Equals -> Prefix' + FlipPrefix -> Less + o -> o + BAR -> Greater + EQL -> case go l1 l2 of + Prefix' -> Greater + Equals -> go r1 r2 + FlipPrefix -> Less + o -> o + NOM -> if unPrefix p1 < unPrefix p2 then Less else Greater + go (Bin _ l1 _) (Tip k2 x2) = case lookupMinSure l1 of + KeyValue k1 x1 -> case compare k1 k2 <> cmp x1 x2 of + LT -> Less + EQ -> FlipPrefix + GT -> Greater + go (Tip k1 x1) (Bin _ l2 _) = case lookupMinSure l2 of + KeyValue k2 x2 -> case compare k1 k2 <> cmp x1 x2 of + LT -> Less + EQ -> Prefix' + GT -> Greater + go (Tip k1 x1) (Tip k2 x2) = case compare k1 k2 <> cmp x1 x2 of + LT -> Less + EQ -> Equals + GT -> Greater + go _ _ = error "liftCmp.go: Nil" +{-# INLINE liftCmp #-} {-------------------------------------------------------------------- Functor diff --git a/containers/src/Data/IntSet/Internal.hs b/containers/src/Data/IntSet/Internal.hs index 66a43b076..3f0671adc 100644 --- a/containers/src/Data/IntSet/Internal.hs +++ b/containers/src/Data/IntSet/Internal.hs @@ -214,6 +214,7 @@ import Data.IntSet.Internal.IntTreeCommons , TreeTreeBranch(..) , treeTreeBranch , i2w + , Order(..) ) #if __GLASGOW_HASKELL__ @@ -1479,8 +1480,112 @@ equal _ _ = False --------------------------------------------------------------------} instance Ord IntSet where - compare s1 s2 = compare (toAscList s1) (toAscList s2) - -- tentative implementation. See if more efficient exists. + compare = compareIntSets + +compareIntSets :: IntSet -> IntSet -> Ordering +compareIntSets = go0 + where + go0 t1@(Bin p1 l1 r1) t2@(Bin p2 l2 r2) = case treeTreeBranch p1 p2 of + ABL | signBranch p1 -> LT + | otherwise -> case go l1 t2 of + Less -> LT + _ -> GT + ABR | signBranch p1 -> case go r1 t2 of + Less -> LT + _ -> GT + | otherwise -> LT + BAL | signBranch p2 -> GT + | otherwise -> case go t1 l2 of + Greater -> GT + _ -> LT + BAR | signBranch p2 -> case go t1 r2 of + Greater -> GT + _ -> LT + | otherwise -> GT + EQL -> + let !(l1', r1', l2', r2') = if signBranch p1 + then (r1, l1, r2, l2) + else (l1, r1, l2, r2) + in case go l1' l2' of + Less -> LT + Prefix' -> GT + Equals -> case go r1' r2' of + Less -> LT + Prefix' -> LT + Equals -> EQ + FlipPrefix -> GT + Greater -> GT + FlipPrefix -> LT + Greater -> GT + NOM -> compare (unPrefix p1) (unPrefix p2) + go0 (Bin p1 l1 r1) (Tip k2 bm2) = + case leftmostTipSure (if signBranch p1 then r1 else l1) of + k1 :*: bm1 -> case orderTips k1 bm1 k2 bm2 of + Less -> LT + _ -> GT + go0 (Tip k1 bm1) (Bin p2 l2 r2) = + case leftmostTipSure (if signBranch p2 then r2 else l2) of + k2 :*: bm2 -> case orderTips k1 bm1 k2 bm2 of + Greater -> GT + _ -> LT + go0 (Tip k1 bm1) (Tip k2 bm2) = case orderTips k1 bm1 k2 bm2 of + Less -> LT + Prefix' -> LT + Equals -> EQ + FlipPrefix -> GT + Greater -> GT + go0 Nil Nil = EQ + go0 Nil _ = LT + go0 _ Nil = GT + + go t1@(Bin p1 l1 r1) t2@(Bin p2 l2 r2) = case treeTreeBranch p1 p2 of + ABL -> case go l1 t2 of + Prefix' -> Greater + Equals -> FlipPrefix + o -> o + ABR -> Less + BAL -> case go t1 l2 of + Equals -> Prefix' + FlipPrefix -> Less + o -> o + BAR -> Greater + EQL -> case go l1 l2 of + Prefix' -> Greater + Equals -> go r1 r2 + FlipPrefix -> Less + o -> o + NOM -> if unPrefix p1 < unPrefix p2 then Less else Greater + go (Bin _ l1 _) (Tip k2 bm2) = case leftmostTipSure l1 of + k1 :*: bm1 -> case orderTips k1 bm1 k2 bm2 of + Prefix' -> Greater + Equals -> FlipPrefix + o -> o + go (Tip k1 bm1) (Bin _ l2 _) = case leftmostTipSure l2 of + k2 :*: bm2 -> case orderTips k1 bm1 k2 bm2 of + Equals -> Prefix' + FlipPrefix -> Less + o -> o + go (Tip k1 bm1) (Tip k2 bm2) = orderTips k1 bm1 k2 bm2 + go _ _ = error "compareIntSets.go: Nil" + +leftmostTipSure :: IntSet -> StrictPair Int BitMap +leftmostTipSure (Bin _ l _) = leftmostTipSure l +leftmostTipSure (Tip k bm) = k :*: bm +leftmostTipSure Nil = error "leftmostTipSure: Nil" + +orderTips :: Int -> BitMap -> Int -> BitMap -> Order +orderTips k1 bm1 k2 bm2 = case compare k1 k2 of + LT -> Less + EQ | bm1 == bm2 -> Equals + | otherwise -> + let diff = bm1 `xor` bm2 + lowestDiff = diff .&. negate diff + highMask = lowestDiff `xor` negate lowestDiff + in if bm1 .&. lowestDiff == 0 + then (if bm1 .&. highMask == 0 then Prefix' else Greater) + else (if bm2 .&. highMask == 0 then FlipPrefix else Less) + GT -> Greater +{-# INLINE orderTips #-} {-------------------------------------------------------------------- Show diff --git a/containers/src/Data/IntSet/Internal/IntTreeCommons.hs b/containers/src/Data/IntSet/Internal/IntTreeCommons.hs index ba3cbb166..d7cbf8c60 100644 --- a/containers/src/Data/IntSet/Internal/IntTreeCommons.hs +++ b/containers/src/Data/IntSet/Internal/IntTreeCommons.hs @@ -36,6 +36,7 @@ module Data.IntSet.Internal.IntTreeCommons , mask , branchMask , i2w + , Order(..) ) where import Data.Bits (Bits(..), countLeadingZeros) @@ -161,6 +162,14 @@ i2w :: Int -> Word i2w = fromIntegral {-# INLINE i2w #-} +-- Used to compare IntSets and IntMaps +data Order + = Less -- holds for [0,3,4] [0,3,5,1] + | Prefix' -- holds for [0,3,4] [0,3,4,5] + | Equals -- holds for [0,3,4] [0,3,4] + | FlipPrefix -- holds for [0,3,4] [0,3] + | Greater -- holds for [0,3,4] [0,2,5] + {-------------------------------------------------------------------- Notes --------------------------------------------------------------------}