diff --git a/proto/utils.go b/internal/util/fieldmap.go similarity index 76% rename from proto/utils.go rename to internal/util/fieldmap.go index 8153c50b..1cc80b2d 100644 --- a/proto/utils.go +++ b/internal/util/fieldmap.go @@ -1,4 +1,20 @@ -package proto +/** + * Copyright 2024 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util import ( "unsafe" @@ -23,7 +39,7 @@ type FieldNameMap struct { } // Set sets the field descriptor for the given key -func (ft *FieldNameMap) Set(key string, field *FieldDescriptor) (exist bool) { +func (ft *FieldNameMap) Set(key string, field unsafe.Pointer) (exist bool) { if len(key) > ft.maxKeyLength { ft.maxKeyLength = len(key) } @@ -39,32 +55,37 @@ func (ft *FieldNameMap) Set(key string, field *FieldDescriptor) (exist bool) { } // Get gets the field descriptor for the given key -func (ft FieldNameMap) Get(k string) *FieldDescriptor { +func (ft FieldNameMap) Get(k string) unsafe.Pointer { if ft.trie != nil { - return (*FieldDescriptor)(ft.trie.Get(k)) + return (unsafe.Pointer)(ft.trie.Get(k)) } else if ft.hash != nil { - return (*FieldDescriptor)(ft.hash.Get(k)) + return (unsafe.Pointer)(ft.hash.Get(k)) } return nil } // All returns all field descriptors -func (ft FieldNameMap) All() []*FieldDescriptor { - return *(*[]*FieldDescriptor)(unsafe.Pointer(&ft.all)) +func (ft FieldNameMap) All() []caching.Pair { + return ft.all } // Size returns the size of the map func (ft FieldNameMap) Size() int { if ft.hash != nil { return ft.hash.Size() - } else { + } else if ft.trie != nil { return ft.trie.Size() } + return 0 } // Build builds the map. // It will try to build a trie tree if the dispersion of keys is higher enough (min). func (ft *FieldNameMap) Build() { + if len(ft.all) == 0 { + return + } + var empty unsafe.Pointer // statistics the distrubution for each position: @@ -146,23 +167,23 @@ func (ft *FieldNameMap) Build() { } // FieldIDMap is a map from field id to field descriptor -type FieldNumberMap struct { - m []*FieldDescriptor - all []*FieldDescriptor +type FieldIDMap struct { + m []unsafe.Pointer + all []unsafe.Pointer } // All returns all field descriptors -func (fd FieldNumberMap) All() (ret []*FieldDescriptor) { +func (fd FieldIDMap) All() (ret []unsafe.Pointer) { return fd.all } // Size returns the size of the map -func (fd FieldNumberMap) Size() int { +func (fd FieldIDMap) Size() int { return len(fd.m) } // Get gets the field descriptor for the given id -func (fd FieldNumberMap) Get(id FieldNumber) *FieldDescriptor { +func (fd FieldIDMap) Get(id int32) unsafe.Pointer { if int(id) >= len(fd.m) { return nil } @@ -170,10 +191,10 @@ func (fd FieldNumberMap) Get(id FieldNumber) *FieldDescriptor { } // Set sets the field descriptor for the given id -func (fd *FieldNumberMap) Set(id FieldNumber, f *FieldDescriptor) { +func (fd *FieldIDMap) Set(id int32, f unsafe.Pointer) { if int(id) >= len(fd.m) { len := int(id) + 1 - tmp := make([]*FieldDescriptor, len) + tmp := make([]unsafe.Pointer, len) copy(tmp, fd.m) fd.m = tmp } @@ -189,4 +210,4 @@ func (fd *FieldNumberMap) Set(id FieldNumber, f *FieldDescriptor) { } } fd.m[id] = f -} \ No newline at end of file +} diff --git a/internal/util/fieldmap_test.go b/internal/util/fieldmap_test.go new file mode 100644 index 00000000..aafb7dd0 --- /dev/null +++ b/internal/util/fieldmap_test.go @@ -0,0 +1,35 @@ +/** + * Copyright 2024 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import "testing" + +func TestEmptyFieldMap(t *testing.T) { + // empty test + ids := FieldIDMap{} + if ids.Get(1) != nil { + t.Fatalf("expect nil") + } + names := FieldNameMap{} + if names.Get("a") != nil { + t.Fatalf("expect nil") + } + names.Build() + if names.Get("a") != nil { + t.Fatalf("expect nil") + } +} diff --git a/proto/descriptor.go b/proto/descriptor.go index d40589e4..5e3e0448 100644 --- a/proto/descriptor.go +++ b/proto/descriptor.go @@ -1,5 +1,7 @@ package proto +import "github.com/cloudwego/dynamicgo/internal/util" + type TypeDescriptor struct { baseId FieldNumber // for LIST/MAP to write field tag by baseId typ Type @@ -113,8 +115,8 @@ func (f *FieldDescriptor) IsList() bool { type MessageDescriptor struct { baseId FieldNumber name string - ids FieldNumberMap - names FieldNameMap // store name and jsonName for FieldDescriptor + ids util.FieldIDMap + names util.FieldNameMap // store name and jsonName for FieldDescriptor } func (m *MessageDescriptor) Name() string { @@ -122,15 +124,15 @@ func (m *MessageDescriptor) Name() string { } func (m *MessageDescriptor) ByJSONName(name string) *FieldDescriptor { - return m.names.Get(name) + return (*FieldDescriptor)(m.names.Get(name)) } func (m *MessageDescriptor) ByName(name string) *FieldDescriptor { - return m.names.Get(name) + return (*FieldDescriptor)(m.names.Get(name)) } func (m *MessageDescriptor) ByNumber(id FieldNumber) *FieldDescriptor { - return m.ids.Get(id) + return (*FieldDescriptor)(m.ids.Get(int32(id))) } func (m *MessageDescriptor) FieldsCount() int { diff --git a/proto/idl.go b/proto/idl.go index 4b543c4c..da0353d8 100644 --- a/proto/idl.go +++ b/proto/idl.go @@ -4,7 +4,9 @@ import ( "context" "errors" "math" + "unsafe" + "github.com/cloudwego/dynamicgo/internal/util" "github.com/cloudwego/dynamicgo/meta" "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/desc/protoparse" @@ -171,8 +173,8 @@ func parseMessage(ctx context.Context, msgDesc *desc.MessageDescriptor, cache co fields := msgDesc.GetFields() md := &MessageDescriptor{ baseId: FieldNumber(math.MaxInt32), - ids: FieldNumberMap{}, - names: FieldNameMap{}, + ids: util.FieldIDMap{}, + names: util.FieldNameMap{}, } ty = &TypeDescriptor{ @@ -249,9 +251,9 @@ func parseMessage(ctx context.Context, msgDesc *desc.MessageDescriptor, cache co // add fieldDescriptor to MessageDescriptor // md.ids[FieldNumber(id)] = fieldDesc - md.ids.Set(FieldNumber(id), fieldDesc) - md.names.Set(name, fieldDesc) - md.names.Set(jsonName, fieldDesc) + md.ids.Set(int32(id), unsafe.Pointer(fieldDesc)) + md.names.Set(name, unsafe.Pointer(fieldDesc)) + md.names.Set(jsonName, unsafe.Pointer(fieldDesc)) } md.names.Build() diff --git a/thrift/descriptor.go b/thrift/descriptor.go index 27ec4872..11cd4554 100644 --- a/thrift/descriptor.go +++ b/thrift/descriptor.go @@ -18,8 +18,10 @@ package thrift import ( "fmt" + "unsafe" "github.com/cloudwego/dynamicgo/http" + "github.com/cloudwego/dynamicgo/internal/util" "github.com/cloudwego/thriftgo/parser" ) @@ -166,8 +168,8 @@ func (d TypeDescriptor) Struct() *StructDescriptor { type StructDescriptor struct { baseID FieldID name string - ids FieldIDMap - names FieldNameMap + ids util.FieldIDMap + names util.FieldNameMap requires RequiresBitmap hmFields []*FieldDescriptor annotations []parser.Annotation @@ -212,12 +214,13 @@ func (s StructDescriptor) Name() string { // Len returns the number of fields in the struct func (s StructDescriptor) Len() int { - return len(s.ids.all) + return len(s.ids.All()) } // Fields returns all fields in the struct func (s StructDescriptor) Fields() []*FieldDescriptor { - return s.ids.All() + ret := s.ids.All() + return *(*[]*FieldDescriptor)(unsafe.Pointer(&ret)) } // Fields returns requireness bitmap in the struct. @@ -232,7 +235,7 @@ func (s StructDescriptor) Annotations() []parser.Annotation { // FieldById finds the field by field id func (s StructDescriptor) FieldById(id FieldID) *FieldDescriptor { - return s.ids.Get(id) + return (*FieldDescriptor)(s.ids.Get(int32(id))) } // FieldByName finds the field by key @@ -240,7 +243,7 @@ func (s StructDescriptor) FieldById(id FieldID) *FieldDescriptor { // NOTICE: Options.MapFieldWay can influence the behavior of this method. // ep: if Options.MapFieldWay is MapFieldWayName, then field names should be used as key. func (s StructDescriptor) FieldByKey(k string) (field *FieldDescriptor) { - return s.names.Get(k) + return (*FieldDescriptor)(s.names.Get(k)) } // FieldID is used to identify a field in a struct diff --git a/thrift/idl.go b/thrift/idl.go index b6e5160a..4183ca80 100644 --- a/thrift/idl.go +++ b/thrift/idl.go @@ -25,6 +25,7 @@ import ( "path/filepath" "strings" "time" + "unsafe" "github.com/cloudwego/dynamicgo/http" "github.com/cloudwego/dynamicgo/internal/json" @@ -76,7 +77,7 @@ type Options struct { // 2. it is on the top layer of the root struct of one function. EnableThriftBase bool - // PutNameSpaceToAnnotation indicates to extract the name-space of one type + // PutNameSpaceToAnnotation indicates to extract the name-space of one type // and put it on the type's annotation. The annotion format is: // - Key: "thrift.name_space" (== NameSpaceAnnotationKey) // - Values: pairs of Language and Name. for example: @@ -382,8 +383,8 @@ func addFunction(ctx context.Context, fn *parser.Function, tree *parser.Thrift, typ: STRUCT, struc: &StructDescriptor{ baseID: FieldID(math.MaxUint16), - ids: FieldIDMap{}, - names: FieldNameMap{}, + ids: util.FieldIDMap{}, + names: util.FieldNameMap{}, requires: make(RequiresBitmap, 1), }, } @@ -393,7 +394,7 @@ func addFunction(ctx context.Context, fn *parser.Function, tree *parser.Thrift, return err } if reqType.Type() == STRUCT { - for _, f := range reqType.Struct().names.all { + for _, f := range reqType.Struct().names.All() { x := (*FieldDescriptor)(f.Val) if x.isRequestBase { hasRequestBase = true @@ -406,8 +407,8 @@ func addFunction(ctx context.Context, fn *parser.Function, tree *parser.Thrift, id: FieldID(reqAst.ID), typ: reqType, } - req.Struct().ids.Set(FieldID(reqAst.ID), reqField) - req.Struct().names.Set(reqAst.Name, reqField) + req.Struct().ids.Set(int32(reqAst.ID), unsafe.Pointer(reqField)) + req.Struct().names.Set(reqAst.Name, unsafe.Pointer(reqField)) req.Struct().names.Build() } @@ -418,8 +419,8 @@ func addFunction(ctx context.Context, fn *parser.Function, tree *parser.Thrift, typ: STRUCT, struc: &StructDescriptor{ baseID: FieldID(math.MaxUint16), - ids: FieldIDMap{}, - names: FieldNameMap{}, + ids: util.FieldIDMap{}, + names: util.FieldNameMap{}, requires: make(RequiresBitmap, 1), }, } @@ -430,9 +431,9 @@ func addFunction(ctx context.Context, fn *parser.Function, tree *parser.Thrift, respField := &FieldDescriptor{ typ: respType, } - resp.Struct().ids.Set(0, respField) + resp.Struct().ids.Set(0, unsafe.Pointer(respField)) // response has no name or id - resp.Struct().names.Set("", respField) + resp.Struct().names.Set("", unsafe.Pointer(respField)) // parse exceptions if len(fn.Throws) > 0 { @@ -449,8 +450,8 @@ func addFunction(ctx context.Context, fn *parser.Function, tree *parser.Thrift, // isException: true, typ: exceptionType, } - resp.Struct().ids.Set(FieldID(exp.ID), exceptionField) - resp.Struct().names.Set(exp.Name, exceptionField) + resp.Struct().ids.Set(int32(exp.ID), unsafe.Pointer(exceptionField)) + resp.Struct().names.Set(exp.Name, unsafe.Pointer(exceptionField)) } resp.Struct().names.Build() } @@ -581,8 +582,8 @@ func parseType(ctx context.Context, t *parser.Type, tree *parser.Thrift, cache c struc: &StructDescriptor{ baseID: FieldID(math.MaxUint16), name: typeName, - ids: FieldIDMap{}, - names: FieldNameMap{}, + ids: util.FieldIDMap{}, + names: util.FieldNameMap{}, requires: make(RequiresBitmap, len(st.Fields)), annotations: oannos, }, @@ -677,18 +678,19 @@ func parseType(ctx context.Context, t *parser.Type, tree *parser.Thrift, cache c dv, _ := makeDefaultValue(_f.typ, field.Default, tree) _f.defaultValue = dv } + fp := unsafe.Pointer(_f) // set field id - ty.Struct().ids.Set(FieldID(field.ID), _f) + ty.Struct().ids.Set(int32(field.ID), fp) // set field requireness convertRequireness(field.Requiredness, ty.struc, _f, fopts) // set field key if fopts.MapFieldWay == meta.MapFieldUseAlias { - ty.Struct().names.Set(_f.alias, _f) + ty.Struct().names.Set(_f.alias, fp) } else if fopts.MapFieldWay == meta.MapFieldUseFieldName { - ty.Struct().names.Set(_f.name, _f) + ty.Struct().names.Set(_f.name, fp) } else { - ty.Struct().names.Set(_f.alias, _f) - ty.Struct().names.Set(_f.name, _f) + ty.Struct().names.Set(_f.alias, fp) + ty.Struct().names.Set(_f.name, fp) } } diff --git a/thrift/utils.go b/thrift/utils.go index d43c81c2..711dcdde 100644 --- a/thrift/utils.go +++ b/thrift/utils.go @@ -22,202 +22,17 @@ import ( "sync" "unsafe" - "github.com/cloudwego/dynamicgo/internal/caching" "github.com/cloudwego/dynamicgo/internal/rt" "github.com/cloudwego/dynamicgo/meta" ) -const ( - defaultMaxBucketSize float64 = 10 - defaultMapSize int = 4 - defaultHashMapLoadFactor int = 4 - defaultMaxFieldID = 256 - defaultMaxNestedDepth = 1024 -) - -// FieldNameMap is a map for field name and field descriptor -type FieldNameMap struct { - maxKeyLength int - all []caching.Pair - trie *caching.TrieTree - hash *caching.HashMap -} - -// Set sets the field descriptor for the given key -func (ft *FieldNameMap) Set(key string, field *FieldDescriptor) (exist bool) { - if len(key) > ft.maxKeyLength { - ft.maxKeyLength = len(key) - } - for i, v := range ft.all { - if v.Key == key { - exist = true - ft.all[i].Val = unsafe.Pointer(field) - return - } - } - ft.all = append(ft.all, caching.Pair{Val: unsafe.Pointer(field), Key: key}) - return -} - -// Get gets the field descriptor for the given key -func (ft FieldNameMap) Get(k string) *FieldDescriptor { - if ft.trie != nil { - return (*FieldDescriptor)(ft.trie.Get(k)) - } else if ft.hash != nil { - return (*FieldDescriptor)(ft.hash.Get(k)) - } - return nil -} - -// All returns all field descriptors -func (ft FieldNameMap) All() []*FieldDescriptor { - return *(*[]*FieldDescriptor)(unsafe.Pointer(&ft.all)) -} - -// Size returns the size of the map -func (ft FieldNameMap) Size() int { - if ft.hash != nil { - return ft.hash.Size() - } else { - return ft.trie.Size() - } -} - -// Build builds the map. -// It will try to build a trie tree if the dispersion of keys is higher enough (min). -func (ft *FieldNameMap) Build() { - var empty unsafe.Pointer - - // statistics the distrubution for each position: - // - primary slice store the position as its index - // - secondary map used to merge values with same char at the same position - var positionDispersion = make([]map[byte][]int, ft.maxKeyLength) - - for i, v := range ft.all { - for j := ft.maxKeyLength - 1; j >= 0; j-- { - if v.Key == "" { - // empty key, especially store - empty = v.Val - } - // get the char at the position, defualt (position beyonds key range) is ASCII 0 - var c = byte(0) - if j < len(v.Key) { - c = v.Key[j] - } - - if positionDispersion[j] == nil { - positionDispersion[j] = make(map[byte][]int, 16) - } - // recoder the index i of the value with same char c at the same position j - positionDispersion[j][c] = append(positionDispersion[j][c], i) - } - } - - // calculate the best position which has the highest dispersion - var idealPos = -1 - var min = defaultMaxBucketSize - var count = len(ft.all) - - for i := ft.maxKeyLength - 1; i >= 0; i-- { - cd := positionDispersion[i] - l := len(cd) - // calculate the dispersion (average bucket size) - f := float64(count) / float64(l) - if f < min { - min = f - idealPos = i - } - // 1 means all the value store in different bucket, no need to continue calulating - if min == 1 { - break - } - } - - if idealPos != -1 { - // find the best position, build a trie tree - ft.hash = nil - ft.trie = &caching.TrieTree{} - // NOTICE: we only use a two-layer tree here, for better performance - ft.trie.Positions = append(ft.trie.Positions, idealPos) - // set all key-values to the trie tree - for _, v := range ft.all { - ft.trie.Set(v.Key, v.Val) - } - if empty != nil { - ft.trie.Empty = empty - } - - } else { - // no ideal position, build a hash map - ft.trie = nil - ft.hash = caching.NewHashMap(len(ft.all), defaultHashMapLoadFactor) - // set all key-values to the trie tree - for _, v := range ft.all { - // caching.HashMap does not support duplicate key, so must check if the key exists before set - // WARN: if the key exists, the value WON'T be replaced - o := ft.hash.Get(v.Key) - if o == nil { - ft.hash.Set(v.Key, v.Val) - } - } - if empty != nil { - ft.hash.Set("", empty) - } - } -} - -// FieldIDMap is a map from field id to field descriptor -type FieldIDMap struct { - m []*FieldDescriptor - all []*FieldDescriptor -} - -// All returns all field descriptors -func (fd FieldIDMap) All() (ret []*FieldDescriptor) { - return fd.all -} - -// Size returns the size of the map -func (fd FieldIDMap) Size() int { - return len(fd.m) -} - -// Get gets the field descriptor for the given id -func (fd FieldIDMap) Get(id FieldID) *FieldDescriptor { - if int(id) >= len(fd.m) { - return nil - } - return fd.m[id] -} - -// Set sets the field descriptor for the given id -func (fd *FieldIDMap) Set(id FieldID, f *FieldDescriptor) { - if int(id) >= len(fd.m) { - len := int(id) + 1 - tmp := make([]*FieldDescriptor, len) - copy(tmp, fd.m) - fd.m = tmp - } - o := (fd.m)[id] - if o == nil { - fd.all = append(fd.all, f) - } else { - for i, v := range fd.all { - if v == o { - fd.all[i] = f - break - } - } - } - fd.m[id] = f -} - // RequiresBitmap is a bitmap to mark fields type RequiresBitmap []uint64 const ( - int64BitSize = 64 - int64ByteSize = 8 + defaultMaxFieldID = 256 + int64BitSize = 64 + int64ByteSize = 8 ) var bitmapPool = sync.Pool{ @@ -287,10 +102,11 @@ func FreeRequiresBitmap(b *RequiresBitmap) { bitmapPool.Put(b) } -//go:nocheckptr // CheckRequires scan every bit of the bitmap. When a bit is marked, it will: // - if the corresponding field is required-requireness, it reports error // - if the corresponding is not required-requireness but writeDefault is true, it will call handler to handle this field +// +//go:nocheckptr func (b RequiresBitmap) CheckRequires(desc *StructDescriptor, writeDefault bool, handler func(field *FieldDescriptor) error) error { // handle bitmap first n := len(b) @@ -324,11 +140,12 @@ func (b RequiresBitmap) CheckRequires(desc *StructDescriptor, writeDefault bool, return nil } -//go:nocheckptr // CheckRequires scan every bit of the bitmap. When a bit is marked, it will: // - if the corresponding field is required-requireness and writeRquired is true, it will call handler to handle this field, otherwise report error // - if the corresponding is default-requireness and writeDefault is true, it will call handler to handle this field // - if the corresponding is optional-requireness and writeOptional is true, it will call handler to handle this field +// +//go:nocheckptr func (b RequiresBitmap) HandleRequires(desc *StructDescriptor, writeRquired bool, writeDefault bool, writeOptional bool, handler func(field *FieldDescriptor) error) error { // handle bitmap first n := len(b)