Skip to content

Commit

Permalink
refactor: logic
Browse files Browse the repository at this point in the history
  • Loading branch information
siyul-park committed Nov 29, 2023
1 parent 8de7019 commit 76fd012
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 149 deletions.
19 changes: 19 additions & 0 deletions pkg/primitive/iterator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package primitive

import "reflect"

func getCommonType(values []any) reflect.Type {
if len(values) == 0 {
return typeAny
}

commonType := reflect.TypeOf(values[0])
for _, value := range values {
typ := reflect.TypeOf(value)
if typ != commonType {
return typeAny
}
}

return commonType
}
162 changes: 76 additions & 86 deletions pkg/primitive/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ import (
)

type (
// Map is a representation of a map.
// Map represents a map structure.
Map struct {
value *immutable.SortedMap[Value, Value]
}

// mapTag represents the tag for map fields.
mapTag struct {
alias string
ignore bool
Expand All @@ -31,103 +32,104 @@ const (
tagMap = "map"
)

var _ Value = (*Map)(nil)
var _ immutable.Comparer[Value] = (*comparer)(nil)
var (
_ Value = (*Map)(nil)
_ immutable.Comparer[Value] = (*comparer)(nil)
)

// NewMap returns a new Map.
// NewMap creates a new Map with key-value pairs.
func NewMap(pairs ...Value) *Map {
b := immutable.NewSortedMapBuilder[Value, Value](&comparer{})
builder := immutable.NewSortedMapBuilder[Value, Value](&comparer{})
for i := 0; i < len(pairs)/2; i++ {
k := pairs[i*2]
v := pairs[i*2+1]

b.Set(k, v)
k, v := pairs[i*2], pairs[i*2+1]
builder.Set(k, v)
}
return &Map{value: b.Map()}
return &Map{value: builder.Map()}
}

func (o *Map) Get(key Value) (Value, bool) {
return o.value.Get(key)
// Get retrieves the value for a given key.
func (m *Map) Get(key Value) (Value, bool) {
return m.value.Get(key)
}

func (o *Map) GetOr(key, value Value) Value {
if v, ok := o.Get(key); ok {
// GetOr returns the value for a given key or a default value if the key is not found.
func (m *Map) GetOr(key, value Value) Value {
if v, ok := m.Get(key); ok {
return v
}
return value
}

func (o *Map) Set(key, value Value) *Map {
return &Map{value: o.value.Set(key, value)}
// Set adds or updates a key-value pair in the map.
func (m *Map) Set(key, value Value) *Map {
return &Map{value: m.value.Set(key, value)}
}

func (o *Map) Delete(key Value) *Map {
return &Map{value: o.value.Delete(key)}
// Delete removes a key and its corresponding value from the map.
func (m *Map) Delete(key Value) *Map {
return &Map{value: m.value.Delete(key)}
}

func (o *Map) Keys() []Value {
// Keys returns all keys in the map.
func (m *Map) Keys() []Value {
var keys []Value
itr := m.value.Iterator()

itr := o.value.Iterator()
for !itr.Done() {
k, _, _ := itr.Next()
keys = append(keys, k)
}
return keys
}

func (o *Map) Len() int {
return o.value.Len()
// Len returns the number of key-value pairs in the map.
func (m *Map) Len() int {
return m.value.Len()
}

// Map returns a raw representation.
func (o *Map) Map() map[any]any {
m := make(map[any]any, o.value.Len())
// Map converts the Map to a raw Go map.
func (m *Map) Map() map[any]any {
result := make(map[any]any, m.value.Len())
itr := m.value.Iterator()

itr := o.value.Iterator()
for !itr.Done() {
k, v, _ := itr.Next()

// FIXME: check interface is can't be map key.
if k != nil {
if v != nil {
m[k.Interface()] = v.Interface()
} else {
m[k.Interface()] = nil
}
result[k.Interface()] = v.Interface()
}
}

return m
return result
}

func (o *Map) Kind() Kind {
// Kind returns the kind of the Map.
func (m *Map) Kind() Kind {
return KindMap
}

func (o *Map) Compare(v Value) int {
if r, ok := v.(*Map); !ok {
if o.Kind() > v.Kind() {
return 1
} else {
// Compare compares two maps.
func (m *Map) Compare(v Value) int {
if r, ok := v.(*Map); ok {
keys1, keys2 := m.Keys(), r.Keys()

// Compare lengths
if len(keys1) < len(keys2) {
return -1
} else if len(keys1) > len(keys2) {
return 1
}
} else {
keys1 := o.Keys()
keys2 := r.Keys()

// Compare individual keys and values
for i, k1 := range keys1 {
if len(keys2) == i {
return 1
}

k2 := keys2[i]
if diff := Compare(k1, k2); diff != 0 {
return diff
}

v1, ok1 := o.Get(k1)
v2, ok2 := o.Get(k2)
v1, ok1 := m.Get(k1)
v2, ok2 := r.Get(k2)

if diff := Compare(NewBool(ok1), NewBool(ok2)); diff != 0 {
return diff
}
Expand All @@ -136,51 +138,34 @@ func (o *Map) Compare(v Value) int {
}
}

if len(keys2) > len(keys1) {
return -1
}
return 0
}

// If the types are different, compare based on type kind.
if m.Kind() > v.Kind() {
return 1
}
return -1
}

func (o *Map) Interface() any {
// Interface converts the Map to an interface{}.
func (m *Map) Interface() any {
var keys []any
var values []any

itr := o.value.Iterator()
itr := m.value.Iterator()

for !itr.Done() {
k, v, _ := itr.Next()

// FIXME: check interface is can't be map key.
if k != nil {
keys = append(keys, k.Interface())
if v != nil {
values = append(values, v.Interface())
} else {
values = append(values, nil)
}
values = append(values, v.Interface())
}
}

keyType := typeAny
valueType := typeAny

for i, key := range keys {
typ := reflect.TypeOf(key)
if i == 0 {
keyType = typ
} else if keyType != typ {
keyType = typeAny
}
}
for i, value := range values {
typ := reflect.TypeOf(value)
if i == 0 {
valueType = typ
} else if valueType != typ {
valueType = typeAny
}
}
keyType := getCommonType(keys)
valueType := getCommonType(values)

t := reflect.MakeMapWithSize(reflect.MapOf(keyType, valueType), len(keys))
for i, key := range keys {
Expand All @@ -190,11 +175,12 @@ func (o *Map) Interface() any {
return t.Interface()
}

// comparer.Compare compares two Values.
func (*comparer) Compare(a Value, b Value) int {
return Compare(a, b)
}

// NewMapEncoder is encode map or struct to Map.
// NewMapEncoder encodes a map or struct to a Map.
func NewMapEncoder(encoder encoding.Encoder[any, Value]) encoding.Encoder[any, Value] {
return encoding.EncoderFunc[any, Value](func(source any) (Value, error) {
if s := reflect.ValueOf(source); s.Kind() == reflect.Map {
Expand Down Expand Up @@ -251,12 +237,13 @@ func NewMapEncoder(encoder encoding.Encoder[any, Value]) encoding.Encoder[any, V
})
}

// NewMapDecoder is decode Map to map or struct.
// NewMapDecoder decodes a Map to a map or struct.
func NewMapDecoder(decoder encoding.Decoder[Value, any]) encoding.Decoder[Value, any] {
return encoding.DecoderFunc[Value, any](func(source Value, target any) error {
if s, ok := source.(*Map); ok {
if t := reflect.ValueOf(target); t.Kind() == reflect.Pointer {
if t.Elem().Kind() == reflect.Map {
switch t.Elem().Kind() {
case reflect.Map:
if t.Elem().IsNil() {
t.Elem().Set(reflect.MakeMapWithSize(t.Type().Elem(), s.Len()))
}
Expand All @@ -279,7 +266,7 @@ func NewMapDecoder(decoder encoding.Decoder[Value, any]) encoding.Decoder[Value,
t.Elem().SetMapIndex(k.Elem(), v.Elem())
}
return nil
} else if t.Elem().Kind() == reflect.Struct {
case reflect.Struct:
for i := 0; i < t.Elem().NumField(); i++ {
field := t.Elem().Type().Field(i)
if !field.IsExported() {
Expand Down Expand Up @@ -310,11 +297,14 @@ func NewMapDecoder(decoder encoding.Decoder[Value, any]) encoding.Decoder[Value,
return errors.WithMessage(err, fmt.Sprintf("value(%v) corresponding to the key(%v) cannot be decoded", value.Interface(), field.Name))
}
}
return nil
} else if t.Elem().Type() == typeAny {
t.Elem().Set(reflect.ValueOf(s.Interface()))
return nil
default:
if t.Type() == typeAny {
t.Elem().Set(reflect.ValueOf(s.Interface()))
} else {
return errors.WithStack(encoding.ErrUnsupportedValue)
}
}
return nil
}
}
return errors.WithStack(encoding.ErrUnsupportedValue)
Expand Down
Loading

0 comments on commit 76fd012

Please sign in to comment.