diff --git a/doc/spec.md b/doc/spec.md index c98ee173..eb8e1610 100644 --- a/doc/spec.md +++ b/doc/spec.md @@ -155,9 +155,14 @@ reproducibility is paramount, such as build tools. * [list·remove](#list·remove) * [set·add](#set·add) * [set·clear](#set·clear) + * [set·difference](#set·difference) * [set·discard](#set·discard) + * [set·intersection](#set·intersection) + * [set·issubset](#set·issubset) + * [set·issuperset](#set·issuperset) * [set·pop](#set·pop) * [set·remove](#set·remove) + * [set·symmetric_difference](#set·symmetric_difference) * [set·union](#set·union) * [string·capitalize](#string·capitalize) * [string·codepoint_ords](#string·codepoint_ords) @@ -975,9 +980,14 @@ A set has these methods: * [`add`](#set·add) * [`clear`](#set·clear) +* [`difference`](#set·difference) * [`discard`](#set·discard) +* [`intersection`](#set·intersection) +* [`issubset`](#set·issubset) +* [`issuperset`](#set·issuperset) * [`pop`](#set·pop) * [`remove`](#set·remove) +* [`symmetric_difference`](#set·symmetric_difference) * [`union`](#set·union) @@ -1995,6 +2005,11 @@ which breaks several mathematical identities. For example, if `x` is a `NaN` value, the comparisons `x < y`, `x == y`, and `x > y` all yield false for all values of `y`. +When used to compare two `set` objects, the `<=`, and `>=` operators will report +whether one set is a subset or superset of another. Similarly, using `<` or `>` will +report whether a set is a proper subset or superset of another, thus `x > y` is +equivalent to `x >= y and x != y`. + Applications may define additional types that support ordered comparison. @@ -2045,6 +2060,8 @@ Sets int & int # bitwise intersection (AND) set & set # set intersection set ^ set # set symmetric difference + set - set # set difference + Dict dict | dict # ordered union @@ -2115,6 +2132,7 @@ Implementations may impose a limit on the second operand of a left shift. set([1, 2]) & set([2, 3]) # set([2]) set([1, 2]) | set([2, 3]) # set([1, 2, 3]) set([1, 2]) ^ set([2, 3]) # set([1, 3]) +set([1, 2]) - set([2, 3]) # set([1]) ``` Implementation note: @@ -3782,6 +3800,18 @@ x.clear(2) # None x # set([]) ``` + +### set·difference + +`S.difference(y)` returns a new set into which have been inserted all the elements of set S which are not in y. + +y can be any type of iterable (e.g. set, list, tuple). + +```python +x = set([1, 2, 3]) +x.difference([3, 4, 5]) # set([1, 2]) +``` + ### set·discard @@ -3798,6 +3828,44 @@ x.discard(2) # None x # set([1, 3]) ``` + +### set·intersection + +`S.intersection(y)` returns a new set into which have been inserted all the elements of set S which are also in y. + +y can be any type of iterable (e.g. set, list, tuple). + +```python +x = set([1, 2, 3]) +x.intersection([3, 4, 5]) # set([3]) +``` + + +### set·issubset + +`S.issubset(y)` returns True if all items in S are also in y, otherwise it returns False. + +y can be any type of iterable (e.g. set, list, tuple). + +```python +x = set([1, 2]) +x.issubset([1, 2, 3]) # True +x.issubset([1, 3, 4]) # False +``` + + +### set·issuperset + +`S.issuperset(y)` returns True if all items in y are also in S, otherwise it returns False. + +y can be any type of iterable (e.g. set, list, tuple). + +```python +x = set([1, 2, 3]) +x.issuperset([1, 2]) # True +x.issuperset([1, 3, 4]) # False +``` + ### set·pop @@ -3826,6 +3894,18 @@ x # set([1, 3]) x.remove(2) # error: element not found ``` + +### set·symmetric_difference + +`S.symmetric_difference(y)` creates a new set into which is inserted all of the items which are in S but not y, followed by all of the items which are in y but not S. + +y can be any type of iterable (e.g. set, list, tuple). + +```python +x = set([1, 2, 3]) +x.symmetric_difference([3, 4, 5]) # set([1, 2, 4, 5]) +``` + ### set·union diff --git a/starlark/eval.go b/starlark/eval.go index 706c6249..3ab08e47 100644 --- a/starlark/eval.go +++ b/starlark/eval.go @@ -826,6 +826,12 @@ func Binary(op syntax.Token, x, y Value) (Value, error) { } return x - yf, nil } + case *Set: // difference + if y, ok := y.(*Set); ok { + iter := y.Iterate() + defer iter.Done() + return x.Difference(iter) + } } case syntax.STAR: @@ -1097,17 +1103,9 @@ func Binary(op syntax.Token, x, y Value) (Value, error) { } case *Set: // intersection if y, ok := y.(*Set); ok { - set := new(Set) - if x.Len() > y.Len() { - x, y = y, x // opt: range over smaller set - } - for xe := x.ht.head; xe != nil; xe = xe.next { - // Has, Insert cannot fail here. - if found, _ := y.Has(xe.key); found { - set.Insert(xe.key) - } - } - return set, nil + iter := y.Iterate() + defer iter.Done() + return x.Intersection(iter) } } @@ -1119,18 +1117,9 @@ func Binary(op syntax.Token, x, y Value) (Value, error) { } case *Set: // symmetric difference if y, ok := y.(*Set); ok { - set := new(Set) - for xe := x.ht.head; xe != nil; xe = xe.next { - if found, _ := y.Has(xe.key); !found { - set.Insert(xe.key) - } - } - for ye := y.ht.head; ye != nil; ye = ye.next { - if found, _ := x.Has(ye.key); !found { - set.Insert(ye.key) - } - } - return set, nil + iter := y.Iterate() + defer iter.Done() + return x.SymmetricDifference(iter) } } diff --git a/starlark/library.go b/starlark/library.go index 0e3c7bf7..ef032ee2 100644 --- a/starlark/library.go +++ b/starlark/library.go @@ -140,12 +140,17 @@ var ( } setMethods = map[string]*Builtin{ - "add": NewBuiltin("add", set_add), - "clear": NewBuiltin("clear", set_clear), - "discard": NewBuiltin("discard", set_discard), - "pop": NewBuiltin("pop", set_pop), - "remove": NewBuiltin("remove", set_remove), - "union": NewBuiltin("union", set_union), + "add": NewBuiltin("add", set_add), + "clear": NewBuiltin("clear", set_clear), + "difference": NewBuiltin("difference", set_difference), + "discard": NewBuiltin("discard", set_discard), + "intersection": NewBuiltin("intersection", set_intersection), + "issubset": NewBuiltin("issubset", set_issubset), + "issuperset": NewBuiltin("issuperset", set_issuperset), + "pop": NewBuiltin("pop", set_pop), + "remove": NewBuiltin("remove", set_remove), + "symmetric_difference": NewBuiltin("symmetric_difference", set_symmetric_difference), + "union": NewBuiltin("union", set_union), } ) @@ -2204,6 +2209,68 @@ func set_clear(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) return None, nil } +// https://github.com/google/starlark-go/blob/master/doc/spec.md#set·difference. +func set_difference(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) { + // TODO: support multiple others: s.difference(*others) + var other Iterable + if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &other); err != nil { + return nil, err + } + iter := other.Iterate() + defer iter.Done() + diff, err := b.Receiver().(*Set).Difference(iter) + if err != nil { + return nil, nameErr(b, err) + } + return diff, nil +} + +// https://github.com/google/starlark-go/blob/master/doc/spec.md#set_intersection. +func set_intersection(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) { + // TODO: support multiple others: s.difference(*others) + var other Iterable + if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &other); err != nil { + return nil, err + } + iter := other.Iterate() + defer iter.Done() + diff, err := b.Receiver().(*Set).Intersection(iter) + if err != nil { + return nil, nameErr(b, err) + } + return diff, nil +} + +// https://github.com/google/starlark-go/blob/master/doc/spec.md#set_issubset. +func set_issubset(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) { + var other Iterable + if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &other); err != nil { + return nil, err + } + iter := other.Iterate() + defer iter.Done() + diff, err := b.Receiver().(*Set).IsSubset(iter) + if err != nil { + return nil, nameErr(b, err) + } + return Bool(diff), nil +} + +// https://github.com/google/starlark-go/blob/master/doc/spec.md#set_issuperset. +func set_issuperset(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) { + var other Iterable + if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &other); err != nil { + return nil, err + } + iter := other.Iterate() + defer iter.Done() + diff, err := b.Receiver().(*Set).IsSuperset(iter) + if err != nil { + return nil, nameErr(b, err) + } + return Bool(diff), nil +} + // https://github.com/google/starlark-go/blob/master/doc/spec.md#set·discard. func set_discard(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) { var k Value @@ -2252,6 +2319,21 @@ func set_remove(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error return nil, nameErr(b, "missing key") } +// https://github.com/google/starlark-go/blob/master/doc/spec.md#set·symmetric_difference. +func set_symmetric_difference(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) { + var other Iterable + if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &other); err != nil { + return nil, err + } + iter := other.Iterate() + defer iter.Done() + diff, err := b.Receiver().(*Set).SymmetricDifference(iter) + if err != nil { + return nil, nameErr(b, err) + } + return diff, nil +} + // https://github.com/google/starlark-go/blob/master/doc/spec.md#set·union. func set_union(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) { var iterable Iterable diff --git a/starlark/testdata/builtins.star b/starlark/testdata/builtins.star index c350c967..c7188fa3 100644 --- a/starlark/testdata/builtins.star +++ b/starlark/testdata/builtins.star @@ -196,7 +196,7 @@ assert.eq(getattr(hf, "x"), 2) assert.eq(hf.x, 2) # built-in types can have attributes (methods) too. myset = set([]) -assert.eq(dir(myset), ["add", "clear", "discard", "pop", "remove", "union"]) +assert.eq(dir(myset), ["add", "clear", "difference", "discard", "intersection", "issubset", "issuperset", "pop", "remove", "symmetric_difference", "union"]) assert.true(hasattr(myset, "union")) assert.true(not hasattr(myset, "onion")) assert.eq(str(getattr(myset, "union")), "") diff --git a/starlark/testdata/int.star b/starlark/testdata/int.star index 46c0ad0d..f0e2cde3 100644 --- a/starlark/testdata/int.star +++ b/starlark/testdata/int.star @@ -74,7 +74,6 @@ def compound(): x %= 3 assert.eq(x, 2) - # use resolve.AllowBitwise to enable the ops: x = 2 x &= 1 assert.eq(x, 0) @@ -197,7 +196,6 @@ assert.fails(lambda: int("0x-4", 16), "invalid literal with base 16: 0x-4") # bitwise union (int|int), intersection (int&int), XOR (int^int), unary not (~int), # left shift (int<>int). -# use resolve.AllowBitwise to enable the ops. # TODO(adonovan): this is not yet in the Starlark spec, # but there is consensus that it should be. assert.eq(1 | 2, 3) diff --git a/starlark/testdata/set.star b/starlark/testdata/set.star index 3dcde3c1..303b4472 100644 --- a/starlark/testdata/set.star +++ b/starlark/testdata/set.star @@ -1,5 +1,5 @@ # Tests of Starlark 'set' -# option:set +# option:set option:globalreassign # Sets are not a standard part of Starlark, so the features # tested in this file must be enabled in the application by setting @@ -9,9 +9,7 @@ # TODO(adonovan): support set mutation: # - del set[k] -# - set.remove # - set.update -# - set.clear # - set += iterable, perhaps? # Test iterator invalidation. @@ -48,7 +46,7 @@ y = set([3, 4, 5]) # set + any is not defined assert.fails(lambda : x + y, "unknown.*: set \\+ set") -# set | set (use resolve.AllowBitwise to enable it) +# set | set assert.eq(list(set("a".elems()) | set("b".elems())), ["a", "b"]) assert.eq(list(set("ab".elems()) | set("bc".elems())), ["a", "b", "c"]) assert.fails(lambda : set() | [], "unknown binary op: set | list") @@ -67,12 +65,16 @@ assert.eq(list(x.union([5, 1])), [1, 2, 3, 5]) assert.eq(list(x.union((6, 5, 4))), [1, 2, 3, 6, 5, 4]) assert.fails(lambda : x.union([1, 2, {}]), "unhashable type: dict") -# intersection, set & set (use resolve.AllowBitwise to enable it) +# intersection, set & set or set.intersection(iterable) assert.eq(list(set("a".elems()) & set("b".elems())), []) assert.eq(list(set("ab".elems()) & set("bc".elems())), ["b"]) +assert.eq(list(set("a".elems()).intersection("b".elems())), []) +assert.eq(list(set("ab".elems()).intersection("bc".elems())), ["b"]) -# symmetric difference, set ^ set (use resolve.AllowBitwise to enable it) +# symmetric difference, set ^ set or set.symmetric_difference(iterable) assert.eq(set([1, 2, 3]) ^ set([4, 5, 3]), set([1, 2, 4, 5])) +assert.eq(set([1,2,3,4]).symmetric_difference([3,4,5,6]), set([1,2,5,6])) +assert.eq(set([1,2,3,4]).symmetric_difference(set([])), set([1,2,3,4])) def test_set_augmented_assign(): x = set([1, 2, 3]) @@ -100,7 +102,6 @@ assert.eq(x, x) assert.eq(y, y) assert.true(x != y) assert.eq(set([1, 2, 3]), set([3, 2, 1])) -assert.fails(lambda : x < y, "set < set not implemented") # iteration assert.true(type([elem for elem in x]), "list") @@ -154,7 +155,6 @@ pop_set.add(2) freeze(pop_set) assert.fails(lambda: pop_set.pop(), "pop: cannot delete from frozen hash table") - # clear clear_set = set([1,2,3]) clear_set.clear() @@ -165,3 +165,34 @@ assert.eq(clear_set.clear(), None) other_clear_set = set([1,2,3]) freeze(other_clear_set) assert.fails(lambda: other_clear_set.clear(), "clear: cannot clear frozen hash table") + +# difference: set - set or set.difference(iterable) +assert.eq(set([1,2,3,4]).difference([1,2,3,4]), set([])) +assert.eq(set([1,2,3,4]).difference([1,2]), set([3,4])) +assert.eq(set([1,2,3,4]).difference([]), set([1,2,3,4])) +assert.eq(set([1,2,3,4]).difference(set([1,2,3])), set([4])) + +assert.eq(set([1,2,3,4]) - set([1,2,3,4]), set()) +assert.eq(set([1,2,3,4]) - set([1,2]), set([3,4])) + +# issuperset: set >= set or set.issuperset(iterable) +assert.true(set([1,2,3]).issuperset([1,2])) +assert.true(not set([1,2,3]).issuperset(set([1,2,4]))) +assert.true(set([1,2,3]) >= set([1,2,3])) +assert.true(set([1,2,3]) >= set([1,2])) +assert.true(not set([1,2,3]) >= set([1,2,4])) + +# proper superset: set > set +assert.true(set([1, 2, 3]) > set([1, 2])) +assert.true(not set([1,2, 3]) > set([1, 2, 3])) + +# issubset: set <= set or set.issubset(iterable) +assert.true(set([1,2]).issubset([1,2,3])) +assert.true(not set([1,2,3]).issubset(set([1,2,4]))) +assert.true(set([1,2,3]) <= set([1,2,3])) +assert.true(set([1,2]) <= set([1,2,3])) +assert.true(not set([1,2,3]) <= set([1,2,4])) + +# proper subset: set < set +assert.true(set([1,2]) < set([1,2,3])) +assert.true(not set([1,2,3]) < set([1,2,3])) diff --git a/starlark/value.go b/starlark/value.go index 3ceacc60..db9ba113 100644 --- a/starlark/value.go +++ b/starlark/value.go @@ -1134,6 +1134,34 @@ func (x *Set) CompareSameType(op syntax.Token, y_ Value, depth int) (bool, error case syntax.NEQ: ok, err := setsEqual(x, y, depth) return !ok, err + case syntax.GE: // superset + if x.Len() < y.Len() { + return false, nil + } + iter := y.Iterate() + defer iter.Done() + return x.IsSuperset(iter) + case syntax.LE: // subset + if x.Len() > y.Len() { + return false, nil + } + iter := y.Iterate() + defer iter.Done() + return x.IsSubset(iter) + case syntax.GT: // proper superset + if x.Len() <= y.Len() { + return false, nil + } + iter := y.Iterate() + defer iter.Done() + return x.IsSuperset(iter) + case syntax.LT: // proper subset + if x.Len() >= y.Len() { + return false, nil + } + iter := y.Iterate() + defer iter.Done() + return x.IsSubset(iter) default: return false, fmt.Errorf("%s %s %s not implemented", x.Type(), op, y.Type()) } @@ -1151,11 +1179,28 @@ func setsEqual(x, y *Set, depth int) (bool, error) { return true, nil } -func (s *Set) Union(iter Iterator) (Value, error) { +func setFromIterator(iter Iterator) (*Set, error) { + var x Value + set := new(Set) + for iter.Next(&x) { + err := set.Insert(x) + if err != nil { + return set, err + } + } + return set, nil +} + +func (s *Set) clone() *Set { set := new(Set) for e := s.ht.head; e != nil; e = e.next { set.Insert(e.key) // can't fail } + return set +} + +func (s *Set) Union(iter Iterator) (Value, error) { + set := s.clone() var x Value for iter.Next(&x) { if err := set.Insert(x); err != nil { @@ -1165,6 +1210,74 @@ func (s *Set) Union(iter Iterator) (Value, error) { return set, nil } +func (s *Set) Difference(other Iterator) (Value, error) { + diff := s.clone() + var x Value + for other.Next(&x) { + if _, err := diff.Delete(x); err != nil { + return nil, err + } + } + return diff, nil +} + +func (s *Set) IsSuperset(other Iterator) (bool, error) { + var x Value + for other.Next(&x) { + found, err := s.Has(x) + if err != nil { + return false, err + } + if !found { + return false, nil + } + } + return true, nil +} + +func (s *Set) IsSubset(other Iterator) (bool, error) { + otherset, err := setFromIterator(other) + if err != nil { + return false, err + } + iter := s.Iterate() + defer iter.Done() + return otherset.IsSuperset(iter) +} + +func (s *Set) Intersection(other Iterator) (Value, error) { + intersect := new(Set) + var x Value + for other.Next(&x) { + found, err := s.Has(x) + if err != nil { + return nil, err + } + if found { + err = intersect.Insert(x) + if err != nil { + return nil, err + } + } + } + return intersect, nil +} + +func (s *Set) SymmetricDifference(other Iterator) (Value, error) { + diff := s.clone() + var x Value + for other.Next(&x) { + found, err := diff.Delete(x) + if err != nil { + return nil, err + } + if !found { + diff.Insert(x) + } + } + return diff, nil +} + // toString returns the string form of value v. // It may be more efficient than v.String() for larger values. func toString(v Value) string {