From 76fd012c6ab65450c327de81a19f0ff476371888 Mon Sep 17 00:00:00 2001 From: siyul-park Date: Wed, 29 Nov 2023 02:47:24 -0500 Subject: [PATCH] refactor: logic --- pkg/primitive/iterator.go | 19 +++++ pkg/primitive/map.go | 162 ++++++++++++++++++-------------------- pkg/primitive/slice.go | 109 ++++++++++++------------- pkg/primitive/string.go | 22 +++--- 4 files changed, 163 insertions(+), 149 deletions(-) create mode 100644 pkg/primitive/iterator.go diff --git a/pkg/primitive/iterator.go b/pkg/primitive/iterator.go new file mode 100644 index 00000000..b6d4b440 --- /dev/null +++ b/pkg/primitive/iterator.go @@ -0,0 +1,19 @@ +package primitive + +import "reflect" + +func getCommonType(values []any) reflect.Type { + if len(values) == 0 { + return typeAny + } + + commonType := reflect.TypeOf(values[0]) + for _, value := range values { + typ := reflect.TypeOf(value) + if typ != commonType { + return typeAny + } + } + + return commonType +} diff --git a/pkg/primitive/map.go b/pkg/primitive/map.go index 8c9c3c0c..66af91bb 100644 --- a/pkg/primitive/map.go +++ b/pkg/primitive/map.go @@ -12,11 +12,12 @@ import ( ) type ( - // Map is a representation of a map. + // Map represents a map structure. Map struct { value *immutable.SortedMap[Value, Value] } + // mapTag represents the tag for map fields. mapTag struct { alias string ignore bool @@ -31,44 +32,49 @@ const ( tagMap = "map" ) -var _ Value = (*Map)(nil) -var _ immutable.Comparer[Value] = (*comparer)(nil) +var ( + _ Value = (*Map)(nil) + _ immutable.Comparer[Value] = (*comparer)(nil) +) -// NewMap returns a new Map. +// NewMap creates a new Map with key-value pairs. func NewMap(pairs ...Value) *Map { - b := immutable.NewSortedMapBuilder[Value, Value](&comparer{}) + builder := immutable.NewSortedMapBuilder[Value, Value](&comparer{}) for i := 0; i < len(pairs)/2; i++ { - k := pairs[i*2] - v := pairs[i*2+1] - - b.Set(k, v) + k, v := pairs[i*2], pairs[i*2+1] + builder.Set(k, v) } - return &Map{value: b.Map()} + return &Map{value: builder.Map()} } -func (o *Map) Get(key Value) (Value, bool) { - return o.value.Get(key) +// Get retrieves the value for a given key. +func (m *Map) Get(key Value) (Value, bool) { + return m.value.Get(key) } -func (o *Map) GetOr(key, value Value) Value { - if v, ok := o.Get(key); ok { +// GetOr returns the value for a given key or a default value if the key is not found. +func (m *Map) GetOr(key, value Value) Value { + if v, ok := m.Get(key); ok { return v } return value } -func (o *Map) Set(key, value Value) *Map { - return &Map{value: o.value.Set(key, value)} +// Set adds or updates a key-value pair in the map. +func (m *Map) Set(key, value Value) *Map { + return &Map{value: m.value.Set(key, value)} } -func (o *Map) Delete(key Value) *Map { - return &Map{value: o.value.Delete(key)} +// Delete removes a key and its corresponding value from the map. +func (m *Map) Delete(key Value) *Map { + return &Map{value: m.value.Delete(key)} } -func (o *Map) Keys() []Value { +// Keys returns all keys in the map. +func (m *Map) Keys() []Value { var keys []Value + itr := m.value.Iterator() - itr := o.value.Iterator() for !itr.Done() { k, _, _ := itr.Next() keys = append(keys, k) @@ -76,58 +82,54 @@ func (o *Map) Keys() []Value { return keys } -func (o *Map) Len() int { - return o.value.Len() +// Len returns the number of key-value pairs in the map. +func (m *Map) Len() int { + return m.value.Len() } -// Map returns a raw representation. -func (o *Map) Map() map[any]any { - m := make(map[any]any, o.value.Len()) +// Map converts the Map to a raw Go map. +func (m *Map) Map() map[any]any { + result := make(map[any]any, m.value.Len()) + itr := m.value.Iterator() - itr := o.value.Iterator() for !itr.Done() { k, v, _ := itr.Next() - // FIXME: check interface is can't be map key. if k != nil { - if v != nil { - m[k.Interface()] = v.Interface() - } else { - m[k.Interface()] = nil - } + result[k.Interface()] = v.Interface() } } - return m + return result } -func (o *Map) Kind() Kind { +// Kind returns the kind of the Map. +func (m *Map) Kind() Kind { return KindMap } -func (o *Map) Compare(v Value) int { - if r, ok := v.(*Map); !ok { - if o.Kind() > v.Kind() { - return 1 - } else { +// Compare compares two maps. +func (m *Map) Compare(v Value) int { + if r, ok := v.(*Map); ok { + keys1, keys2 := m.Keys(), r.Keys() + + // Compare lengths + if len(keys1) < len(keys2) { return -1 + } else if len(keys1) > len(keys2) { + return 1 } - } else { - keys1 := o.Keys() - keys2 := r.Keys() + // Compare individual keys and values for i, k1 := range keys1 { - if len(keys2) == i { - return 1 - } - k2 := keys2[i] if diff := Compare(k1, k2); diff != 0 { return diff } - v1, ok1 := o.Get(k1) - v2, ok2 := o.Get(k2) + v1, ok1 := m.Get(k1) + v2, ok2 := r.Get(k2) + if diff := Compare(NewBool(ok1), NewBool(ok2)); diff != 0 { return diff } @@ -136,51 +138,34 @@ func (o *Map) Compare(v Value) int { } } - if len(keys2) > len(keys1) { - return -1 - } return 0 } + + // If the types are different, compare based on type kind. + if m.Kind() > v.Kind() { + return 1 + } + return -1 } -func (o *Map) Interface() any { +// Interface converts the Map to an interface{}. +func (m *Map) Interface() any { var keys []any var values []any - itr := o.value.Iterator() + itr := m.value.Iterator() + for !itr.Done() { k, v, _ := itr.Next() - // FIXME: check interface is can't be map key. if k != nil { keys = append(keys, k.Interface()) - if v != nil { - values = append(values, v.Interface()) - } else { - values = append(values, nil) - } + values = append(values, v.Interface()) } } - keyType := typeAny - valueType := typeAny - - for i, key := range keys { - typ := reflect.TypeOf(key) - if i == 0 { - keyType = typ - } else if keyType != typ { - keyType = typeAny - } - } - for i, value := range values { - typ := reflect.TypeOf(value) - if i == 0 { - valueType = typ - } else if valueType != typ { - valueType = typeAny - } - } + keyType := getCommonType(keys) + valueType := getCommonType(values) t := reflect.MakeMapWithSize(reflect.MapOf(keyType, valueType), len(keys)) for i, key := range keys { @@ -190,11 +175,12 @@ func (o *Map) Interface() any { return t.Interface() } +// comparer.Compare compares two Values. func (*comparer) Compare(a Value, b Value) int { return Compare(a, b) } -// NewMapEncoder is encode map or struct to Map. +// NewMapEncoder encodes a map or struct to a Map. func NewMapEncoder(encoder encoding.Encoder[any, Value]) encoding.Encoder[any, Value] { return encoding.EncoderFunc[any, Value](func(source any) (Value, error) { if s := reflect.ValueOf(source); s.Kind() == reflect.Map { @@ -251,12 +237,13 @@ func NewMapEncoder(encoder encoding.Encoder[any, Value]) encoding.Encoder[any, V }) } -// NewMapDecoder is decode Map to map or struct. +// NewMapDecoder decodes a Map to a map or struct. func NewMapDecoder(decoder encoding.Decoder[Value, any]) encoding.Decoder[Value, any] { return encoding.DecoderFunc[Value, any](func(source Value, target any) error { if s, ok := source.(*Map); ok { if t := reflect.ValueOf(target); t.Kind() == reflect.Pointer { - if t.Elem().Kind() == reflect.Map { + switch t.Elem().Kind() { + case reflect.Map: if t.Elem().IsNil() { t.Elem().Set(reflect.MakeMapWithSize(t.Type().Elem(), s.Len())) } @@ -279,7 +266,7 @@ func NewMapDecoder(decoder encoding.Decoder[Value, any]) encoding.Decoder[Value, t.Elem().SetMapIndex(k.Elem(), v.Elem()) } return nil - } else if t.Elem().Kind() == reflect.Struct { + case reflect.Struct: for i := 0; i < t.Elem().NumField(); i++ { field := t.Elem().Type().Field(i) if !field.IsExported() { @@ -310,11 +297,14 @@ func NewMapDecoder(decoder encoding.Decoder[Value, any]) encoding.Decoder[Value, return errors.WithMessage(err, fmt.Sprintf("value(%v) corresponding to the key(%v) cannot be decoded", value.Interface(), field.Name)) } } - return nil - } else if t.Elem().Type() == typeAny { - t.Elem().Set(reflect.ValueOf(s.Interface())) - return nil + default: + if t.Type() == typeAny { + t.Elem().Set(reflect.ValueOf(s.Interface())) + } else { + return errors.WithStack(encoding.ErrUnsupportedValue) + } } + return nil } } return errors.WithStack(encoding.ErrUnsupportedValue) diff --git a/pkg/primitive/slice.go b/pkg/primitive/slice.go index 406b9447..66e44a76 100644 --- a/pkg/primitive/slice.go +++ b/pkg/primitive/slice.go @@ -20,94 +20,99 @@ var _ Value = (*Slice)(nil) // NewSlice returns a new Slice. func NewSlice(values ...Value) *Slice { - b := immutable.NewListBuilder[Value]() + builder := immutable.NewListBuilder[Value]() for _, v := range values { - b.Append(v) + builder.Append(v) } - return &Slice{value: b.List()} + return &Slice{value: builder.List()} } -func (o *Slice) Prepend(value Value) *Slice { - return &Slice{value: o.value.Prepend(value)} +func (s *Slice) Prepend(value Value) *Slice { + return &Slice{value: s.value.Prepend(value)} } -func (o *Slice) Append(value Value) *Slice { - return &Slice{value: o.value.Append(value)} +func (s *Slice) Append(value Value) *Slice { + return &Slice{value: s.value.Append(value)} } -func (o *Slice) Sub(start, end int) *Slice { - return &Slice{value: o.value.Slice(start, end)} +func (s *Slice) Sub(start, end int) *Slice { + return &Slice{value: s.value.Slice(start, end)} } -func (o *Slice) Get(index int) Value { - if index >= o.value.Len() { +func (s *Slice) Get(index int) Value { + if index >= s.value.Len() { return nil } - return o.value.Get(index) + return s.value.Get(index) } -func (o *Slice) Set(index int, value Value) *Slice { - if index < 0 && index >= o.value.Len() { - return o +func (s *Slice) Set(index int, value Value) *Slice { + if index < 0 || index >= s.value.Len() { + return s } - return &Slice{value: o.value.Set(index, value)} + return &Slice{value: s.value.Set(index, value)} } -func (o *Slice) Len() int { - return o.value.Len() +func (s *Slice) Len() int { + return s.value.Len() } // Slice returns a raw representation. -func (o *Slice) Slice() []any { - // TODO: support more type defined slice. - s := make([]any, o.value.Len()) +func (s *Slice) Slice() []any { + rawSlice := make([]any, s.value.Len()) - itr := o.value.Iterator() - for !itr.Done() { - i, v := itr.Next() + itr := s.value.Iterator() + for i := 0; !itr.Done(); i++ { + _, v := itr.Next() if v != nil { - s[i] = v.Interface() + rawSlice[i] = v.Interface() } } - return s + return rawSlice } -func (o *Slice) Kind() Kind { +func (s *Slice) Kind() Kind { return KindSlice } -func (o *Slice) Compare(v Value) int { - if r, ok := v.(*Slice); !ok { - if o.Kind() > v.Kind() { - return 1 - } else { - return -1 +func (s *Slice) Compare(v Value) int { + if r, ok := v.(*Slice); ok { + minLen := s.Len() + if minLen > r.Len() { + minLen = r.Len() } - } else { - for i := 0; i < o.Len(); i++ { - if r.Len() == i { - return 1 - } - if diff := Compare(o.Get(i), r.Get(i)); diff != 0 { + for i := 0; i < minLen; i++ { + if diff := Compare(s.Get(i), r.Get(i)); diff != 0 { return diff } } - if o.Len() > r.Len() { + if s.Len() < r.Len() { return -1 + } else if s.Len() > r.Len() { + return 1 } + return 0 } + + // If the types are different, compare based on type kind. + if s.Kind() > v.Kind() { + return 1 + } + return -1 } -func (o *Slice) Interface() any { +func (s *Slice) Interface() any { var values []any - itr := o.value.Iterator() - for !itr.Done() { + + itr := s.value.Iterator() + for i := 0; !itr.Done(); i++ { _, v := itr.Next() + if v != nil { values = append(values, v.Interface()) } else { @@ -115,22 +120,18 @@ func (o *Slice) Interface() any { } } - valueType := typeAny + // Check if all elements have the same type. + elementType := getCommonType(values) + // Create a slice of the common type. + sliceValue := reflect.MakeSlice(reflect.SliceOf(elementType), s.value.Len(), s.value.Len()) for i, value := range values { - typ := reflect.TypeOf(value) - if i == 0 { - valueType = typ - } else if valueType != typ { - valueType = typeAny + if value != nil { + sliceValue.Index(i).Set(reflect.ValueOf(value)) } } - t := reflect.MakeSlice(reflect.SliceOf(valueType), o.value.Len(), o.value.Len()) - for i, value := range values { - t.Index(i).Set(reflect.ValueOf(value)) - } - return t.Interface() + return sliceValue.Interface() } // NewSliceEncoder is encode slice or array to Slice. diff --git a/pkg/primitive/string.go b/pkg/primitive/string.go index 51f52c4c..9f4cc187 100644 --- a/pkg/primitive/string.go +++ b/pkg/primitive/string.go @@ -8,10 +8,8 @@ import ( encoding2 "github.com/siyul-park/uniflow/pkg/encoding" ) -type ( - // String is a representation of a string. - String string -) +// String is a representation of a string. +type String string var _ Value = (String)("") @@ -20,26 +18,31 @@ func NewString(value string) String { return String(value) } +// Len returns the length of the string. func (o String) Len() int { return len([]rune(o)) } +// Get returns the rune at the specified index in the string. func (o String) Get(index int) rune { - if index >= len([]rune(o)) { + runes := []rune(o) + if index >= len(runes) { return rune(0) } - return []rune(o)[index] + return runes[index] } -// String returns a raw representation. +// String returns the raw string representation. func (o String) String() string { return string(o) } +// Kind returns the kind of the value. func (o String) Kind() Kind { return KindString } +// Compare compares two String values. func (o String) Compare(v Value) int { if r, ok := v.(String); !ok { if o.Kind() > v.Kind() { @@ -52,11 +55,12 @@ func (o String) Compare(v Value) int { } } +// Interface converts String to its underlying string. func (o String) Interface() any { return string(o) } -// NewStringEncoder is encode string to String. +// NewStringEncoder encodes a string to a String. func NewStringEncoder() encoding2.Encoder[any, Value] { return encoding2.EncoderFunc[any, Value](func(source any) (Value, error) { if s, ok := source.(encoding.TextMarshaler); ok { @@ -72,7 +76,7 @@ func NewStringEncoder() encoding2.Encoder[any, Value] { }) } -// NewStringDecoder is decode String to string. +// NewStringDecoder decodes a String to a string. func NewStringDecoder() encoding2.Decoder[Value, any] { return encoding2.DecoderFunc[Value, any](func(source Value, target any) error { if s, ok := source.(String); ok {