diff --git a/lib/proto/proto.go b/lib/proto/proto.go index 36f00233..b5b7bac8 100644 --- a/lib/proto/proto.go +++ b/lib/proto/proto.go @@ -9,21 +9,22 @@ // // This package defines several types of Starlark value: // -// Message -- a protocol message -// RepeatedField -- a repeated field of a message, like a list +// Message -- a protocol message +// RepeatedField -- a repeated field of a message, like a list +// MapField -- a map field of a message, like a dict // -// FileDescriptor -- information about a .proto file -// FieldDescriptor -- information about a message field (or extension field) -// MessageDescriptor -- information about the type of a message -// EnumDescriptor -- information about an enumerated type -// EnumValueDescriptor -- a value of an enumerated type +// FileDescriptor -- information about a .proto file +// FieldDescriptor -- information about a message field (or extension field) +// MessageDescriptor -- information about the type of a message +// EnumDescriptor -- information about an enumerated type +// EnumValueDescriptor -- a value of an enumerated type // // A Message value is a wrapper around a protocol message instance. // Starlark programs may access and update Messages using dot notation: // -// x = msg.field -// msg.field = x + 1 -// msg.field += 1 +// x = msg.field +// msg.field = x + 1 +// msg.field += 1 // // Assignments to message fields perform dynamic checks on the type and // range of the value to ensure that the message is at all times valid. @@ -35,31 +36,39 @@ // performs a dynamic check to ensure that the RepeatedField holds // only elements of the correct type. // -// type(msg.uint32s) # "proto.repeated" -// msg.uint32s[0] = 1 -// msg.uint32s[0] = -1 # error: invalid uint32: -1 +// type(msg.uint32s) # "proto.repeated" +// msg.uint32s[0] = 1 +// msg.uint32s[0] = -1 # error: invalid uint32: -1 // // Any iterable may be assigned to a repeated field of a message. If // the iterable is itself a value of type RepeatedField, the message // field holds a reference to it. // -// msg2.uint32s = msg.uint32s # both messages share one RepeatedField -// msg.uint32s[0] = 123 -// print(msg2.uint32s[0]) # "123" +// msg2.uint32s = msg.uint32s # both messages share one RepeatedField +// msg.uint32s[0] = 123 +// print(msg2.uint32s[0]) # "123" // // The RepeatedFields' element types must match. // It is not enough for the values to be merely valid: // -// msg.uint32s = [1, 2, 3] # makes a copy -// msg.uint64s = msg.uint32s # error: repeated field has wrong type -// msg.uint64s = list(msg.uint32s) # ok; makes a copy +// msg.uint32s = [1, 2, 3] # makes a copy +// msg.uint64s = msg.uint32s # error: repeated field has wrong type +// msg.uint64s = list(msg.uint32s) # ok; makes a copy // // For all other iterables, a new RepeatedField is constructed from the // elements of the iterable. // -// msg.uints32s = [1, 2, 3] -// print(type(msg.uints32s)) # "proto.repeated" +// msg.uints32s = [1, 2, 3] +// print(type(msg.uints32s)) # "proto.repeated" // +// The value of a map field of a message is represented by the dict-like data +// type, MapField. Its items can be set and access in the usual ways. As with +// assignments to message fields, and assignment to a MapField performs a +// dynamic check to ensure the key and value are of the correct type. +// +// msg.string_map = {"a": "A", "b", "B"} +// msg.string_map["c"] = "C" +// print(type(msg.string_map)) # "proto.map" // // To construct a Message from encoded binary or text data, call // Unmarshal or UnmarshalText. These two functions are exposed to @@ -72,16 +81,12 @@ // are frozen. // // TODO(adonovan): document descriptors, enums, message instantiation. -// -// See proto_test.go for an example of how to use the 'proto' -// module in an application that embeds Starlark. -// package proto // TODO(adonovan): Go and Starlark API improvements: // - Make Message and RepeatedField comparable. // (NOTE: proto.Equal works only with generated message types.) -// - Support maps, oneof, any. But not messageset if we can avoid it. +// - Support oneof, any. But not messageset if we can avoid it. // - Support "well-known types". // - Defend against cycles in object graph. // - Test missing required fields in marshalling. @@ -111,8 +116,8 @@ import ( // for a Starlark thread to use this package. // // For example: -// SetPool(thread, protoregistry.GlobalFiles) // +// SetPool(thread, protoregistry.GlobalFiles) func SetPool(thread *starlark.Thread, pool DescriptorPool) { thread.SetLocal(contextKey, pool) } @@ -305,10 +310,9 @@ func getFieldStarlark(thread *starlark.Thread, fn *starlark.Builtin, args starla // When a message descriptor is called, it returns a new instance of the // protocol message it describes. // -// Message(msg) -- return a shallow copy of an existing message -// Message(k=v, ...) -- return a new message with the specified fields -// Message(dict(...)) -- return a new message with the specified fields -// +// Message(msg) -- return a shallow copy of an existing message +// Message(k=v, ...) -- return a new message with the specified fields +// Message(dict(...)) -- return a new message with the specified fields func (d MessageDescriptor) CallInternal(thread *starlark.Thread, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { dest := &Message{ msg: newMessage(d.Desc), @@ -597,6 +601,15 @@ func toStarlark(typ protoreflect.FieldDescriptor, x protoreflect.Value, frozen * frozen: frozen, } } + + if mp, ok := x.Interface().(protoreflect.Map); ok { + return &MapField{ + typ: typ, + mp: mp, + frozen: frozen, + } + } + return toStarlark1(typ, x, frozen) } @@ -654,7 +667,7 @@ func toStarlark1(typ protoreflect.FieldDescriptor, x protoreflect.Value, frozen // or RepeatedField wrapper values derived from it. type Message struct { msg protoreflect.Message // any concrete type is allowed - frozen *bool // shared by a group of related Message/RepeatedField wrappers + frozen *bool // shared by a group of related Message/RepeatedField/MapField wrappers } // Message returns the wrapped message. @@ -791,6 +804,11 @@ func defaultValue(fdesc protoreflect.FieldDescriptor) starlark.Value { return &RepeatedField{typ: fdesc, list: emptyList{}, frozen: &frozen} } + // The default value of a map field is an empty map. + if fdesc.IsMap() { + return &MapField{typ: fdesc, mp: emptyMap{}, frozen: &frozen} + } + // The zero value for a message type is an empty instance of that message. if desc := fdesc.Message(); desc != nil { return &Message{msg: newMessage(desc), frozen: &frozen} @@ -806,6 +824,17 @@ type emptyList struct{ protoreflect.List } func (emptyList) Len() int { return 0 } +type emptyMap struct{ protoreflect.Map } + +func (emptyMap) Len() int { return 0 } + +func (emptyMap) Get(_ protoreflect.MapKey) protoreflect.Value { + // An invalid value, signalling the supplied key does not exist. + return protoreflect.Value{} +} + +func (emptyMap) Range(_ func(protoreflect.MapKey, protoreflect.Value) bool) {} + // newMessage returns a new empty instance of the message type described by desc. func newMessage(desc protoreflect.MessageDescriptor) protoreflect.Message { // If desc refers to a built-in message, @@ -898,18 +927,48 @@ type RepeatedField struct { itercount int } +var _ starlark.Iterable = (*RepeatedField)(nil) var _ starlark.HasSetIndex = (*RepeatedField)(nil) +var _ starlark.HasAttrs = (*RepeatedField)(nil) + +func (rf *RepeatedField) AttrNames() []string { + return []string{"append"} +} + +func (rf *RepeatedField) Attr(name string) (starlark.Value, error) { + if name != "append" { + return nil, nil + } + return starlark.NewBuiltin("append", repeatedFieldAppend).BindReceiver(rf), nil +} + +func repeatedFieldAppend(_ *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + var object starlark.Value + if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 1, &object); err != nil { + return nil, err + } + rf := b.Receiver().(*RepeatedField) + + if err := rf.checkMutable("append to"); err != nil { + return nil, fmt.Errorf("%s: %v", b.Name(), err) + } + + po, err := toProto(rf.typ, object) + if err != nil { + return nil, fmt.Errorf("appending to repeated field: %v", err) + } + rf.list.Append(po) + + return starlark.None, nil +} func (rf *RepeatedField) Type() string { return fmt.Sprintf("proto.repeated<%s>", typeString(rf.typ)) } func (rf *RepeatedField) SetIndex(i int, v starlark.Value) error { - if *rf.frozen { - return fmt.Errorf("cannot insert value in frozen repeated field") - } - if rf.itercount > 0 { - return fmt.Errorf("cannot insert value in repeated field with active iterators") + if err := rf.checkMutable("insert value in"); err != nil { + return err } x, err := toProto(rf.typ, v) if err != nil { @@ -922,6 +981,18 @@ func (rf *RepeatedField) SetIndex(i int, v starlark.Value) error { return nil } +// checkMutable reports an error if the repeated field should not be mutated. +// verb+" repeated field" should describe the operation. +func (rf *RepeatedField) checkMutable(verb string) error { + if *rf.frozen { + return fmt.Errorf("cannot %s frozen repeated field", verb) + } + if rf.itercount > 0 { + return fmt.Errorf("cannot %s repeated field during iteration", verb) + } + return nil +} + func (rf *RepeatedField) Freeze() { *rf.frozen = true } func (rf *RepeatedField) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable: %s", rf.Type()) } func (rf *RepeatedField) Index(i int) starlark.Value { @@ -969,6 +1040,145 @@ func (it *repeatedFieldIterator) Done() { } } +type MapField struct { + typ protoreflect.FieldDescriptor + + mp protoreflect.Map + frozen *bool + itercount int +} + +var _ starlark.HasSetKey = (*MapField)(nil) +var _ starlark.IterableMapping = (*MapField)(nil) + +func (mf *MapField) Type() string { + return fmt.Sprintf("proto.map<%s, %s>", typeString(mf.typ.MapKey()), typeString(mf.typ.MapValue())) +} + +func (mf *MapField) SetKey(k, v starlark.Value) error { + if err := mf.checkMutable("set key in"); err != nil { + return err + } + + kx, err := toProto(mf.typ.MapKey(), k) + if err != nil { + return fmt.Errorf("setting key of map field: %v", err) + } + vx, err := toProto(mf.typ.MapValue(), v) + if err != nil { + return fmt.Errorf("setting value of map field: %v", err) + } + + mf.mp.Set(kx.MapKey(), vx) + return nil +} + +// checkMutable reports an error if the map field should not be mutated. +// verb+" map field" should describe the operation. +func (mf *MapField) checkMutable(verb string) error { + if *mf.frozen { + return fmt.Errorf("cannot %s frozen map field", verb) + } + if mf.itercount > 0 { + return fmt.Errorf("cannot %s map field during iteration", verb) + } + return nil +} + +func (mf *MapField) Get(k starlark.Value) (starlark.Value, bool, error) { + pk, err := toProto(mf.typ.MapKey(), k) + if err != nil { + return nil, false, fmt.Errorf("getting value of map field: %v", err) + } + + v := mf.mp.Get(pk.MapKey()) + if !v.IsValid() { + return nil, false, nil + } + + return toStarlark1(mf.typ.MapValue(), v, mf.frozen), true, nil +} + +func (mf *MapField) Freeze() { *mf.frozen = true } +func (mf *MapField) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable: %s", mf.Type()) } + +func (mf *MapField) Iterate() starlark.Iterator { + if !*mf.frozen { + mf.itercount++ + } + + // TODO(negz): Should we store only keys in the iterator? It doesn't + // consume the values. + return &mapFieldIterator{mf, mf.Items(), 0} +} + +func (mf *MapField) Items() []starlark.Tuple { + out := make([]starlark.Tuple, 0, mf.mp.Len()) + + mf.mp.Range(func(mk protoreflect.MapKey, v protoreflect.Value) bool { + out = append(out, starlark.Tuple{ + toStarlark1(mf.typ.MapKey(), mk.Value(), mf.frozen), + toStarlark1(mf.typ.MapValue(), v, mf.frozen), + }) + return true // Keep iterating. + }) + + // A map key can be any scalar protobuf type except floats and bytes. + // In practice in starlark they should be Int, String, or Bool and thus + // either TotallyOrdered or Comparable. None of these return errors when + // compared to the same type. + sort.Slice(out, func(i, j int) bool { + less, _ := starlark.Compare(syntax.LT, out[i][0], out[j][0]) + return less + }) + + return out +} + +func (mf *MapField) Len() int { return mf.mp.Len() } + +func (mf *MapField) String() string { + // We want to use {k1: v1, k2: v2} notation, like a dict. + buf := new(bytes.Buffer) + buf.WriteByte('{') + + for i, kv := range mf.Items() { + if i > 0 { + buf.WriteString(", ") + } + buf.WriteString(kv[0].String()) + buf.WriteString(": ") + buf.WriteString(kv[1].String()) + } + + buf.WriteByte('}') + return buf.String() +} + +func (rf *MapField) Truth() starlark.Bool { return rf.mp.Len() > 0 } + +type mapFieldIterator struct { + mf *MapField + + items []starlark.Tuple + i int +} + +func (it *mapFieldIterator) Next(p *starlark.Value) bool { + if it.i < len(it.items) { + *p = it.items[it.i][0] // We're only iterating over keys. + it.i++ + return true + } + return false +} + +func (it *mapFieldIterator) Done() { + if !*it.mf.frozen { + it.mf.itercount-- + } +} + func writeString(buf *bytes.Buffer, fdesc protoreflect.FieldDescriptor, v protoreflect.Value) { // TODO(adonovan): opt: don't materialize the Starlark value. // TODO(adonovan): skip message type when printing submessages? {...}? @@ -1219,11 +1429,10 @@ func enumValueOf(enum protoreflect.EnumDescriptor, x starlark.Value) (protorefle // // An EnumValueDescriptor has the following fields: // -// index -- int, index of this value within the enum sequence -// name -- string, name of this enum value -// number -- int, numeric value of this enum value -// type -- EnumDescriptor, the enum type to which this value belongs -// +// index -- int, index of this value within the enum sequence +// name -- string, name of this enum value +// number -- int, numeric value of this enum value +// type -- EnumDescriptor, the enum type to which this value belongs type EnumValueDescriptor struct { Desc protoreflect.EnumValueDescriptor } diff --git a/starlark/eval_test.go b/starlark/eval_test.go index 66786711..581f6bd3 100644 --- a/starlark/eval_test.go +++ b/starlark/eval_test.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math" + "os" "os/exec" "path/filepath" "reflect" @@ -19,15 +20,16 @@ import ( "go.starlark.net/internal/chunkedfile" "go.starlark.net/lib/json" starlarkmath "go.starlark.net/lib/math" - "go.starlark.net/lib/proto" + starlarkproto "go.starlark.net/lib/proto" "go.starlark.net/lib/time" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" "go.starlark.net/starlarktest" "go.starlark.net/syntax" - "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" - _ "google.golang.org/protobuf/types/descriptorpb" // example descriptor needed for lib/proto tests + "google.golang.org/protobuf/types/descriptorpb" ) // A test may enable non-standard options by containing (e.g.) "option:recursion". @@ -115,7 +117,23 @@ func TestExecFile(t *testing.T) { testdata := starlarktest.DataFile("starlark", ".") thread := &starlark.Thread{Load: load} starlarktest.SetReporter(thread, t) - proto.SetPool(thread, protoregistry.GlobalFiles) + + // This proto is used for the proto.star tests. It's generated by running: + // protoc --descriptor_set_out=test.fds test.proto + data, err := os.ReadFile("testdata/proto/test.fds") + if err != nil { + t.Fatal(err) + } + fds := &descriptorpb.FileDescriptorSet{} + if err := proto.Unmarshal(data, fds); err != nil { + t.Fatal(err) + } + pool, err := protodesc.NewFiles(fds) + if err != nil { + t.Fatal(err) + } + starlarkproto.SetPool(thread, pool) + for _, file := range []string{ "testdata/assign.star", "testdata/bool.star", @@ -207,7 +225,7 @@ func load(thread *starlark.Thread, module string) (starlark.StringDict, error) { return starlark.StringDict{"math": starlarkmath.Module}, nil } if module == "proto.star" { - return starlark.StringDict{"proto": proto.Module}, nil + return starlark.StringDict{"proto": starlarkproto.Module}, nil } // TODO(adonovan): test load() using this execution path. diff --git a/starlark/testdata/proto.star b/starlark/testdata/proto.star index f1400820..b008d4c7 100644 --- a/starlark/testdata/proto.star +++ b/starlark/testdata/proto.star @@ -3,12 +3,27 @@ load("assert.star", "assert") load("proto.star", "proto") -schema = proto.file("google/protobuf/descriptor.proto") +schema = proto.file("test.proto") -m = schema.FileDescriptorProto(name = "somename.proto", dependency = ["a", "b", "c"]) +m = schema.Test( + string_field="I'm a string!", + int32_field=42, + repeated_field=["a", "b", "c"], + map_field={"a": "A", "b": "B"} +) assert.eq(type(m), "proto.Message") -assert.eq(m.name, "somename.proto") -assert.eq(list(m.dependency), ["a", "b", "c"]) -m.dependency = ["d", "e"] -assert.eq(list(m.dependency), ["d", "e"]) +assert.eq(type(m.repeated_field), "proto.repeated") +assert.eq(type(m.map_field), "proto.map") +assert.eq(m.string_field, "I'm a string!") +assert.eq(m.int32_field, 42) + +assert.eq(list(m.repeated_field), ["a", "b", "c"]) +m.repeated_field = ["d", "e"] +assert.eq(list(m.repeated_field), ["d", "e"]) +m.repeated_field.append("f") +assert.eq(list(m.repeated_field), ["d", "e", "f"]) + +assert.eq(dict(m.map_field), {"a": "A", "b": "B"}) +m.map_field["c"] = "C" +assert.eq(dict(m.map_field), {"a": "A", "b": "B", "c": "C"}) diff --git a/starlark/testdata/proto/test.fds b/starlark/testdata/proto/test.fds new file mode 100644 index 00000000..2d94d85a --- /dev/null +++ b/starlark/testdata/proto/test.fds @@ -0,0 +1,13 @@ + +ª + +test.protogo.starlark.net.testdata"ù +Test! + string_field ( R stringField + int32_field (R +int32Field% +repeated_field ( R repeatedFieldI + map_field ( 2,.go.starlark.net.testdata.Test.MapFieldEntryRmapField; + MapFieldEntry +key ( Rkey +value ( Rvalue:8bproto3 \ No newline at end of file diff --git a/starlark/testdata/proto/test.proto b/starlark/testdata/proto/test.proto new file mode 100644 index 00000000..52f49024 --- /dev/null +++ b/starlark/testdata/proto/test.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package go.starlark.net.testdata; + +message Test { + string string_field = 1; + int32 int32_field = 2; + repeated string repeated_field = 3; + map map_field = 4; +} +