diff --git a/doc/spec.md b/doc/spec.md
index c98ee173..eb8e1610 100644
--- a/doc/spec.md
+++ b/doc/spec.md
@@ -155,9 +155,14 @@ reproducibility is paramount, such as build tools.
* [list·remove](#list·remove)
* [set·add](#set·add)
* [set·clear](#set·clear)
+ * [set·difference](#set·difference)
* [set·discard](#set·discard)
+ * [set·intersection](#set·intersection)
+ * [set·issubset](#set·issubset)
+ * [set·issuperset](#set·issuperset)
* [set·pop](#set·pop)
* [set·remove](#set·remove)
+ * [set·symmetric_difference](#set·symmetric_difference)
* [set·union](#set·union)
* [string·capitalize](#string·capitalize)
* [string·codepoint_ords](#string·codepoint_ords)
@@ -975,9 +980,14 @@ A set has these methods:
* [`add`](#set·add)
* [`clear`](#set·clear)
+* [`difference`](#set·difference)
* [`discard`](#set·discard)
+* [`intersection`](#set·intersection)
+* [`issubset`](#set·issubset)
+* [`issuperset`](#set·issuperset)
* [`pop`](#set·pop)
* [`remove`](#set·remove)
+* [`symmetric_difference`](#set·symmetric_difference)
* [`union`](#set·union)
@@ -1995,6 +2005,11 @@ which breaks several mathematical identities. For example, if `x` is
a `NaN` value, the comparisons `x < y`, `x == y`, and `x > y` all
yield false for all values of `y`.
+When used to compare two `set` objects, the `<=` and `>=` operators report
+whether one set is a subset or superset of another. Similarly, `<` and `>`
+report whether a set is a proper subset or proper superset of another; thus `x > y` is
+equivalent to `x >= y and x != y`.
+
Applications may define additional types that support ordered
comparison.
@@ -2045,6 +2060,8 @@ Sets
int & int # bitwise intersection (AND)
set & set # set intersection
set ^ set # set symmetric difference
+ set - set # set difference
+
Dict
dict | dict # ordered union
@@ -2115,6 +2132,7 @@ Implementations may impose a limit on the second operand of a left shift.
set([1, 2]) & set([2, 3]) # set([2])
set([1, 2]) | set([2, 3]) # set([1, 2, 3])
set([1, 2]) ^ set([2, 3]) # set([1, 3])
+set([1, 2]) - set([2, 3]) # set([1])
```
Implementation note:
@@ -3782,6 +3800,18 @@ x.clear(2) # None
x # set([])
```
+
+### set·difference
+
+`S.difference(y)` returns a new set into which have been inserted all the elements of set S which are not in y.
+
+y can be any type of iterable (e.g. set, list, tuple).
+
+```python
+x = set([1, 2, 3])
+x.difference([3, 4, 5]) # set([1, 2])
+```
+
### set·discard
@@ -3798,6 +3828,44 @@ x.discard(2) # None
x # set([1, 3])
```
+
+### set·intersection
+
+`S.intersection(y)` returns a new set into which have been inserted all the elements of set S which are also in y.
+
+y can be any type of iterable (e.g. set, list, tuple).
+
+```python
+x = set([1, 2, 3])
+x.intersection([3, 4, 5]) # set([3])
+```
+
+
+### set·issubset
+
+`S.issubset(y)` returns True if all items in S are also in y, otherwise it returns False.
+
+y can be any type of iterable (e.g. set, list, tuple).
+
+```python
+x = set([1, 2])
+x.issubset([1, 2, 3]) # True
+x.issubset([1, 3, 4]) # False
+```
+
+
+### set·issuperset
+
+`S.issuperset(y)` returns True if all items in y are also in S, otherwise it returns False.
+
+y can be any type of iterable (e.g. set, list, tuple).
+
+```python
+x = set([1, 2, 3])
+x.issuperset([1, 2]) # True
+x.issuperset([1, 3, 4]) # False
+```
+
### set·pop
@@ -3826,6 +3894,18 @@ x # set([1, 3])
x.remove(2) # error: element not found
```
+
+### set·symmetric_difference
+
+`S.symmetric_difference(y)` creates a new set into which are inserted all of the items which are in S but not y, followed by all of the items which are in y but not S.
+
+y can be any type of iterable (e.g. set, list, tuple).
+
+```python
+x = set([1, 2, 3])
+x.symmetric_difference([3, 4, 5]) # set([1, 2, 4, 5])
+```
+
### set·union
diff --git a/starlark/eval.go b/starlark/eval.go
index 706c6249..3ab08e47 100644
--- a/starlark/eval.go
+++ b/starlark/eval.go
@@ -826,6 +826,12 @@ func Binary(op syntax.Token, x, y Value) (Value, error) {
}
return x - yf, nil
}
+ case *Set: // difference
+ if y, ok := y.(*Set); ok {
+ iter := y.Iterate()
+ defer iter.Done()
+ return x.Difference(iter)
+ }
}
case syntax.STAR:
@@ -1097,17 +1103,9 @@ func Binary(op syntax.Token, x, y Value) (Value, error) {
}
case *Set: // intersection
if y, ok := y.(*Set); ok {
- set := new(Set)
- if x.Len() > y.Len() {
- x, y = y, x // opt: range over smaller set
- }
- for xe := x.ht.head; xe != nil; xe = xe.next {
- // Has, Insert cannot fail here.
- if found, _ := y.Has(xe.key); found {
- set.Insert(xe.key)
- }
- }
- return set, nil
+ iter := y.Iterate()
+ defer iter.Done()
+ return x.Intersection(iter)
}
}
@@ -1119,18 +1117,9 @@ func Binary(op syntax.Token, x, y Value) (Value, error) {
}
case *Set: // symmetric difference
if y, ok := y.(*Set); ok {
- set := new(Set)
- for xe := x.ht.head; xe != nil; xe = xe.next {
- if found, _ := y.Has(xe.key); !found {
- set.Insert(xe.key)
- }
- }
- for ye := y.ht.head; ye != nil; ye = ye.next {
- if found, _ := x.Has(ye.key); !found {
- set.Insert(ye.key)
- }
- }
- return set, nil
+ iter := y.Iterate()
+ defer iter.Done()
+ return x.SymmetricDifference(iter)
}
}
diff --git a/starlark/library.go b/starlark/library.go
index 0e3c7bf7..ef032ee2 100644
--- a/starlark/library.go
+++ b/starlark/library.go
@@ -140,12 +140,17 @@ var (
}
setMethods = map[string]*Builtin{
- "add": NewBuiltin("add", set_add),
- "clear": NewBuiltin("clear", set_clear),
- "discard": NewBuiltin("discard", set_discard),
- "pop": NewBuiltin("pop", set_pop),
- "remove": NewBuiltin("remove", set_remove),
- "union": NewBuiltin("union", set_union),
+ "add": NewBuiltin("add", set_add),
+ "clear": NewBuiltin("clear", set_clear),
+ "difference": NewBuiltin("difference", set_difference),
+ "discard": NewBuiltin("discard", set_discard),
+ "intersection": NewBuiltin("intersection", set_intersection),
+ "issubset": NewBuiltin("issubset", set_issubset),
+ "issuperset": NewBuiltin("issuperset", set_issuperset),
+ "pop": NewBuiltin("pop", set_pop),
+ "remove": NewBuiltin("remove", set_remove),
+ "symmetric_difference": NewBuiltin("symmetric_difference", set_symmetric_difference),
+ "union": NewBuiltin("union", set_union),
}
)
@@ -2204,6 +2209,68 @@ func set_clear(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error)
return None, nil
}
+// https://github.com/google/starlark-go/blob/master/doc/spec.md#set·difference.
+func set_difference(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) {
+ // TODO: support multiple others: s.difference(*others)
+ var other Iterable
+ if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &other); err != nil {
+ return nil, err
+ }
+ iter := other.Iterate()
+ defer iter.Done()
+ diff, err := b.Receiver().(*Set).Difference(iter)
+ if err != nil {
+ return nil, nameErr(b, err)
+ }
+ return diff, nil
+}
+
+// https://github.com/google/starlark-go/blob/master/doc/spec.md#set·intersection.
+func set_intersection(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) {
+ // TODO: support multiple others: s.intersection(*others)
+ var other Iterable
+ if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &other); err != nil {
+ return nil, err
+ }
+ iter := other.Iterate()
+ defer iter.Done()
+ diff, err := b.Receiver().(*Set).Intersection(iter)
+ if err != nil {
+ return nil, nameErr(b, err)
+ }
+ return diff, nil
+}
+
+// https://github.com/google/starlark-go/blob/master/doc/spec.md#set·issubset.
+func set_issubset(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) {
+ var other Iterable
+ if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &other); err != nil {
+ return nil, err
+ }
+ iter := other.Iterate()
+ defer iter.Done()
+ diff, err := b.Receiver().(*Set).IsSubset(iter)
+ if err != nil {
+ return nil, nameErr(b, err)
+ }
+ return Bool(diff), nil
+}
+
+// https://github.com/google/starlark-go/blob/master/doc/spec.md#set·issuperset.
+func set_issuperset(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) {
+ var other Iterable
+ if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &other); err != nil {
+ return nil, err
+ }
+ iter := other.Iterate()
+ defer iter.Done()
+ diff, err := b.Receiver().(*Set).IsSuperset(iter)
+ if err != nil {
+ return nil, nameErr(b, err)
+ }
+ return Bool(diff), nil
+}
+
// https://github.com/google/starlark-go/blob/master/doc/spec.md#set·discard.
func set_discard(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) {
var k Value
@@ -2252,6 +2319,21 @@ func set_remove(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error
return nil, nameErr(b, "missing key")
}
+// https://github.com/google/starlark-go/blob/master/doc/spec.md#set·symmetric_difference.
+func set_symmetric_difference(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) {
+ var other Iterable
+ if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &other); err != nil {
+ return nil, err
+ }
+ iter := other.Iterate()
+ defer iter.Done()
+ diff, err := b.Receiver().(*Set).SymmetricDifference(iter)
+ if err != nil {
+ return nil, nameErr(b, err)
+ }
+ return diff, nil
+}
+
// https://github.com/google/starlark-go/blob/master/doc/spec.md#set·union.
func set_union(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) {
var iterable Iterable
diff --git a/starlark/testdata/builtins.star b/starlark/testdata/builtins.star
index c350c967..c7188fa3 100644
--- a/starlark/testdata/builtins.star
+++ b/starlark/testdata/builtins.star
@@ -196,7 +196,7 @@ assert.eq(getattr(hf, "x"), 2)
assert.eq(hf.x, 2)
# built-in types can have attributes (methods) too.
myset = set([])
-assert.eq(dir(myset), ["add", "clear", "discard", "pop", "remove", "union"])
+assert.eq(dir(myset), ["add", "clear", "difference", "discard", "intersection", "issubset", "issuperset", "pop", "remove", "symmetric_difference", "union"])
assert.true(hasattr(myset, "union"))
assert.true(not hasattr(myset, "onion"))
assert.eq(str(getattr(myset, "union")), "<built-in method union of set value>")
diff --git a/starlark/testdata/int.star b/starlark/testdata/int.star
index 46c0ad0d..f0e2cde3 100644
--- a/starlark/testdata/int.star
+++ b/starlark/testdata/int.star
@@ -74,7 +74,6 @@ def compound():
x %= 3
assert.eq(x, 2)
- # use resolve.AllowBitwise to enable the ops:
x = 2
x &= 1
assert.eq(x, 0)
@@ -197,7 +196,6 @@ assert.fails(lambda: int("0x-4", 16), "invalid literal with base 16: 0x-4")
# bitwise union (int|int), intersection (int&int), XOR (int^int), unary not (~int),
# left shift (int<<int), and right shift (int>>int).
-# use resolve.AllowBitwise to enable the ops.
# TODO(adonovan): this is not yet in the Starlark spec,
# but there is consensus that it should be.
assert.eq(1 | 2, 3)
diff --git a/starlark/testdata/set.star b/starlark/testdata/set.star
index 3dcde3c1..303b4472 100644
--- a/starlark/testdata/set.star
+++ b/starlark/testdata/set.star
@@ -1,5 +1,5 @@
# Tests of Starlark 'set'
-# option:set
+# option:set option:globalreassign
# Sets are not a standard part of Starlark, so the features
# tested in this file must be enabled in the application by setting
@@ -9,9 +9,7 @@
# TODO(adonovan): support set mutation:
# - del set[k]
-# - set.remove
# - set.update
-# - set.clear
# - set += iterable, perhaps?
# Test iterator invalidation.
@@ -48,7 +46,7 @@ y = set([3, 4, 5])
# set + any is not defined
assert.fails(lambda : x + y, "unknown.*: set \\+ set")
-# set | set (use resolve.AllowBitwise to enable it)
+# set | set
assert.eq(list(set("a".elems()) | set("b".elems())), ["a", "b"])
assert.eq(list(set("ab".elems()) | set("bc".elems())), ["a", "b", "c"])
assert.fails(lambda : set() | [], "unknown binary op: set | list")
@@ -67,12 +65,16 @@ assert.eq(list(x.union([5, 1])), [1, 2, 3, 5])
assert.eq(list(x.union((6, 5, 4))), [1, 2, 3, 6, 5, 4])
assert.fails(lambda : x.union([1, 2, {}]), "unhashable type: dict")
-# intersection, set & set (use resolve.AllowBitwise to enable it)
+# intersection, set & set or set.intersection(iterable)
assert.eq(list(set("a".elems()) & set("b".elems())), [])
assert.eq(list(set("ab".elems()) & set("bc".elems())), ["b"])
+assert.eq(list(set("a".elems()).intersection("b".elems())), [])
+assert.eq(list(set("ab".elems()).intersection("bc".elems())), ["b"])
-# symmetric difference, set ^ set (use resolve.AllowBitwise to enable it)
+# symmetric difference, set ^ set or set.symmetric_difference(iterable)
assert.eq(set([1, 2, 3]) ^ set([4, 5, 3]), set([1, 2, 4, 5]))
+assert.eq(set([1,2,3,4]).symmetric_difference([3,4,5,6]), set([1,2,5,6]))
+assert.eq(set([1,2,3,4]).symmetric_difference(set([])), set([1,2,3,4]))
def test_set_augmented_assign():
x = set([1, 2, 3])
@@ -100,7 +102,6 @@ assert.eq(x, x)
assert.eq(y, y)
assert.true(x != y)
assert.eq(set([1, 2, 3]), set([3, 2, 1]))
-assert.fails(lambda : x < y, "set < set not implemented")
# iteration
assert.true(type([elem for elem in x]), "list")
@@ -154,7 +155,6 @@ pop_set.add(2)
freeze(pop_set)
assert.fails(lambda: pop_set.pop(), "pop: cannot delete from frozen hash table")
-
# clear
clear_set = set([1,2,3])
clear_set.clear()
@@ -165,3 +165,34 @@ assert.eq(clear_set.clear(), None)
other_clear_set = set([1,2,3])
freeze(other_clear_set)
assert.fails(lambda: other_clear_set.clear(), "clear: cannot clear frozen hash table")
+
+# difference: set - set or set.difference(iterable)
+assert.eq(set([1,2,3,4]).difference([1,2,3,4]), set([]))
+assert.eq(set([1,2,3,4]).difference([1,2]), set([3,4]))
+assert.eq(set([1,2,3,4]).difference([]), set([1,2,3,4]))
+assert.eq(set([1,2,3,4]).difference(set([1,2,3])), set([4]))
+
+assert.eq(set([1,2,3,4]) - set([1,2,3,4]), set())
+assert.eq(set([1,2,3,4]) - set([1,2]), set([3,4]))
+
+# issuperset: set >= set or set.issuperset(iterable)
+assert.true(set([1,2,3]).issuperset([1,2]))
+assert.true(not set([1,2,3]).issuperset(set([1,2,4])))
+assert.true(set([1,2,3]) >= set([1,2,3]))
+assert.true(set([1,2,3]) >= set([1,2]))
+assert.true(not set([1,2,3]) >= set([1,2,4]))
+
+# proper superset: set > set
+assert.true(set([1, 2, 3]) > set([1, 2]))
+assert.true(not set([1,2, 3]) > set([1, 2, 3]))
+
+# issubset: set <= set or set.issubset(iterable)
+assert.true(set([1,2]).issubset([1,2,3]))
+assert.true(not set([1,2,3]).issubset(set([1,2,4])))
+assert.true(set([1,2,3]) <= set([1,2,3]))
+assert.true(set([1,2]) <= set([1,2,3]))
+assert.true(not set([1,2,3]) <= set([1,2,4]))
+
+# proper subset: set < set
+assert.true(set([1,2]) < set([1,2,3]))
+assert.true(not set([1,2,3]) < set([1,2,3]))
diff --git a/starlark/value.go b/starlark/value.go
index 3ceacc60..db9ba113 100644
--- a/starlark/value.go
+++ b/starlark/value.go
@@ -1134,6 +1134,34 @@ func (x *Set) CompareSameType(op syntax.Token, y_ Value, depth int) (bool, error
case syntax.NEQ:
ok, err := setsEqual(x, y, depth)
return !ok, err
+ case syntax.GE: // superset
+ if x.Len() < y.Len() {
+ return false, nil
+ }
+ iter := y.Iterate()
+ defer iter.Done()
+ return x.IsSuperset(iter)
+ case syntax.LE: // subset
+ if x.Len() > y.Len() {
+ return false, nil
+ }
+ iter := y.Iterate()
+ defer iter.Done()
+ return x.IsSubset(iter)
+ case syntax.GT: // proper superset
+ if x.Len() <= y.Len() {
+ return false, nil
+ }
+ iter := y.Iterate()
+ defer iter.Done()
+ return x.IsSuperset(iter)
+ case syntax.LT: // proper subset
+ if x.Len() >= y.Len() {
+ return false, nil
+ }
+ iter := y.Iterate()
+ defer iter.Done()
+ return x.IsSubset(iter)
default:
return false, fmt.Errorf("%s %s %s not implemented", x.Type(), op, y.Type())
}
@@ -1151,11 +1179,28 @@ func setsEqual(x, y *Set, depth int) (bool, error) {
return true, nil
}
-func (s *Set) Union(iter Iterator) (Value, error) {
+func setFromIterator(iter Iterator) (*Set, error) {
+ var x Value
+ set := new(Set)
+ for iter.Next(&x) {
+ err := set.Insert(x)
+ if err != nil {
+ return set, err
+ }
+ }
+ return set, nil
+}
+
+func (s *Set) clone() *Set {
set := new(Set)
for e := s.ht.head; e != nil; e = e.next {
set.Insert(e.key) // can't fail
}
+ return set
+}
+
+func (s *Set) Union(iter Iterator) (Value, error) {
+ set := s.clone()
var x Value
for iter.Next(&x) {
if err := set.Insert(x); err != nil {
@@ -1165,6 +1210,74 @@ func (s *Set) Union(iter Iterator) (Value, error) {
return set, nil
}
+func (s *Set) Difference(other Iterator) (Value, error) {
+ diff := s.clone()
+ var x Value
+ for other.Next(&x) {
+ if _, err := diff.Delete(x); err != nil {
+ return nil, err
+ }
+ }
+ return diff, nil
+}
+
+func (s *Set) IsSuperset(other Iterator) (bool, error) {
+ var x Value
+ for other.Next(&x) {
+ found, err := s.Has(x)
+ if err != nil {
+ return false, err
+ }
+ if !found {
+ return false, nil
+ }
+ }
+ return true, nil
+}
+
+func (s *Set) IsSubset(other Iterator) (bool, error) {
+ otherset, err := setFromIterator(other)
+ if err != nil {
+ return false, err
+ }
+ iter := s.Iterate()
+ defer iter.Done()
+ return otherset.IsSuperset(iter)
+}
+
+func (s *Set) Intersection(other Iterator) (Value, error) {
+ intersect := new(Set)
+ var x Value
+ for other.Next(&x) {
+ found, err := s.Has(x)
+ if err != nil {
+ return nil, err
+ }
+ if found {
+ err = intersect.Insert(x)
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
+ return intersect, nil
+}
+
+func (s *Set) SymmetricDifference(other Iterator) (Value, error) {
+ diff := s.clone()
+ var x Value
+ for other.Next(&x) {
+ found, err := diff.Delete(x)
+ if err != nil {
+ return nil, err
+ }
+ if !found {
+ diff.Insert(x)
+ }
+ }
+ return diff, nil
+}
+
// toString returns the string form of value v.
// It may be more efficient than v.String() for larger values.
func toString(v Value) string {