diff --git a/temporal/internal/driver/local/keyvalue.go b/temporal/internal/driver/local/keyvalue.go index 7782653..ca3ff0f 100644 --- a/temporal/internal/driver/local/keyvalue.go +++ b/temporal/internal/driver/local/keyvalue.go @@ -10,43 +10,36 @@ import ( "github.com/TykTechnologies/storage/temporal/temperr" ) -func (api *API) Get(ctx context.Context, key string) (value string, err error) { +func (api *API) Get(ctx context.Context, key string) (string, error) { if key == "" { return "", temperr.KeyEmpty } o, err := api.Store.Get(key) if err != nil { - return "", temperr.KeyNotFound + return "", err } - if o == nil { + if o == nil || o.IsExpired() || o.Deleted { return "", temperr.KeyNotFound } - if o.IsExpired() { - return "", temperr.KeyNotFound - } + return api.convertToString(o.Value) +} - if o.Deleted { - return "", temperr.KeyNotFound +func (api *API) convertToString(value interface{}) (string, error) { + switch v := value.(type) { + case string: + return v, nil + case int: + return strconv.Itoa(v), nil + case int64: + return strconv.FormatInt(v, 10), nil + case int32: + return strconv.FormatInt(int64(v), 10), nil + default: + return "", temperr.KeyMisstype } - - v, ok := o.Value.(string) - if !ok { - switch o.Value.(type) { - case int: - v = strconv.Itoa(o.Value.(int)) - case int64: - v = strconv.FormatInt(o.Value.(int64), 10) - case int32: - v = strconv.FormatInt(int64(o.Value.(int32)), 10) - default: - return "", temperr.KeyMisstype - } - } - - return v, nil } func (api *API) Set(ctx context.Context, key, value string, ttl time.Duration) error { @@ -56,16 +49,14 @@ func (api *API) Set(ctx context.Context, key, value string, ttl time.Duration) e o := &Object{ Value: value, + NoExp: ttl <= 0, } - o.NoExp = true - if ttl > 0 { + if !o.NoExp { o.SetExpire(ttl) - o.NoExp = false } - err := api.Store.Set(key, o) - if err != nil { + if err := api.Store.Set(key, o); err != nil { return err } @@ -82,19 +73,11 @@ func (api *API) SetIfNotExist(ctx context.Context, key, value string, expiration return false, err } - if o != nil { - if !o.Deleted && !o.IsExpired() { - return false, nil - } - } - - err = api.Set(ctx, key, value, expiration) - if err != nil { - return false, err + if o != nil && !o.Deleted && !o.IsExpired() { + return false, nil } - err = api.addToKeyIndex(key) - if err != nil { + if err := api.Set(ctx, key, value, expiration); err != nil { return false, err } @@ -112,14 +95,13 @@ func (api *API) Delete(ctx context.Context, key string) error { } if o == nil { - return nil + return nil // Key doesn't exist, no need to delete } o.Deleted = true o.DeletedAt = time.Now() - err = api.Store.Set(key, o) - if err != nil { + if err := api.Store.Set(key, o); err != nil { return err } @@ -140,169 +122,106 @@ func (api *API) Increment(ctx context.Context, key string) (int64, error) { } o, err := api.Store.Get(key) - if err != nil { - // create the object - o = NewCounter(1) + if err != nil || o == nil || o.Deleted || o.IsExpired() { + return api.createNewCounter(key) + } - api.Store.Set(key, o) - api.addToKeyIndex(key) - return 1, nil + value, err := api.getCounterValue(o) + if err != nil { + return 0, err } - if o == nil { - o = NewCounter(1) - - api.Store.Set(key, o) - api.addToKeyIndex(key) - return 1, nil - } - - if o.Deleted || o.IsExpired() { - o = NewCounter(0) - api.addToKeyIndex(key) - } - - var v int64 = -1 - if o.Type != TypeCounter { - switch o.Value.(type) { - case int: - fmt.Println("int") - o.Type = TypeCounter - v = int64(o.Value.(int)) - case int64: - fmt.Println("int64") - o.Type = TypeCounter - v = o.Value.(int64) - case int32: - fmt.Println("int32") - o.Type = TypeCounter - v = int64(o.Value.(int32)) - case string: - // try to convert - conv, err := strconv.Atoi(o.Value.(string)) - if err != nil { - return 0, temperr.KeyMisstype - } - o.Value = int64(conv) - v = int64(conv) - o.Type = TypeCounter - default: - return 0, temperr.KeyMisstype - } - } else { - var ok bool - v, ok = o.Value.(int64) - if !ok { - return 0, temperr.KeyMisstype - } + newValue := value + 1 + o.Value = newValue + o.Type = TypeCounter + + if err := api.Store.Set(key, o); err != nil { + return 0, err } - o.Value = v + 1 - err = api.Store.Set(key, o) - if err != nil { + return newValue, nil +} + +func (api *API) createNewCounter(key string) (int64, error) { + o := NewCounter(1) + if err := api.Store.Set(key, o); err != nil { + return 0, err + } + if err := api.addToKeyIndex(key); err != nil { return 0, err } + return 1, nil +} - return o.Value.(int64), nil +func (api *API) getCounterValue(o *Object) (int64, error) { + switch v := o.Value.(type) { + case int: + return int64(v), nil + case int64: + return v, nil + case int32: + return int64(v), nil + case string: + i, err := strconv.Atoi(v) + if err != nil { + return 0, temperr.KeyMisstype + } + return int64(i), err + default: + return 0, temperr.KeyMisstype + } } -func (api *API) Decrement(ctx context.Context, key string) (newValue int64, err error) { +func (api *API) Decrement(ctx context.Context, key string) (int64, error) { if key == "" { return 0, temperr.KeyEmpty } o, err := api.Store.Get(key) - if err != nil { - // create the object - o = &Object{ - Value: int64(-1), - Type: TypeCounter, - NoExp: true, - } - - api.Store.Set(key, o) - api.addToKeyIndex(key) - return -1, nil + if err != nil || o == nil || o.Deleted || o.IsExpired() { + return api.createNewCounterWithValue(key, -1) } - if o == nil { - o = &Object{ - Value: int64(-1), - Type: TypeCounter, - NoExp: true, - } - - api.Store.Set(key, o) - api.addToKeyIndex(key) - return -1, nil + value, err := api.getCounterValue(o) + if err != nil { + return 0, err } - if o.Deleted || o.IsExpired() { - o = &Object{ - Value: int64(0), - Type: TypeCounter, - NoExp: true, - } - api.addToKeyIndex(key) - } - - var v int64 - if o.Type != TypeCounter { - switch o.Value.(type) { - case int: - o.Type = TypeCounter - v = int64(o.Value.(int)) - case int64: - o.Type = TypeCounter - v = o.Value.(int64) - case int32: - o.Type = TypeCounter - v = int64(o.Value.(int32)) - case string: - // try to convert - conv, err := strconv.Atoi(o.Value.(string)) - if err != nil { - return 0, temperr.KeyMisstype - } - o.Value = int64(conv) - v = int64(conv) - o.Type = TypeCounter - default: - return 0, temperr.KeyMisstype - } - } else { - var ok bool - v, ok = o.Value.(int64) - if !ok { - return 0, temperr.KeyMisstype - } + newValue := value - 1 + o.Value = newValue + o.Type = TypeCounter + + if err := api.Store.Set(key, o); err != nil { + return 0, err } - o.Value = v - 1 + return newValue, nil +} - err = api.Store.Set(key, o) - if err != nil { +func (api *API) createNewCounterWithValue(key string, value int64) (int64, error) { + o := NewCounter(value) + if err := api.Store.Set(key, o); err != nil { return 0, err } - - return o.Value.(int64), nil + if err := api.addToKeyIndex(key); err != nil { + return 0, err + } + return value, nil } -func (api *API) Exists(ctx context.Context, key string) (exists bool, err error) { +func (api *API) Exists(ctx context.Context, key string) (bool, error) { if key == "" { return false, temperr.KeyEmpty } - o, err := api.Get(ctx, key) - if err != nil { - return false, nil + _, err := api.Get(ctx, key) + if err == nil { + return true, nil } - - if o == "" { + if err == temperr.KeyNotFound { return false, nil } - - return true, nil + return false, err } func (api *API) Expire(ctx context.Context, key string, ttl time.Duration) error { @@ -310,36 +229,35 @@ func (api *API) Expire(ctx context.Context, key string, ttl time.Duration) error return temperr.KeyEmpty } + // non-existing keys for these functions should return nil, not errors o, err := api.Store.Get(key) if err != nil { - return err + return nil } - if o == nil { return nil } if ttl <= 0 { o.NoExp = true - return nil + } else { + o.SetExpire(ttl) + o.NoExp = false } - o.SetExpire(ttl) - o.NoExp = false - return api.Store.Set(key, o) } -func (api *API) TTL(ctx context.Context, key string) (ttl int64, err error) { +func (api *API) TTL(ctx context.Context, key string) (int64, error) { if key == "" { return -2, temperr.KeyEmpty } o, err := api.Store.Get(key) if err != nil { - return -2, err + // bizarre, but should return nil + return -2, nil } - if o == nil { return -2, nil } @@ -348,63 +266,56 @@ func (api *API) TTL(ctx context.Context, key string) (ttl int64, err error) { return -1, nil } - return int64(time.Until(o.Exp).Round(time.Second).Seconds()), nil + ttl := time.Until(o.Exp).Round(time.Second).Seconds() + return int64(ttl), nil } -func (api *API) DeleteKeys(ctx context.Context, keys []string) (numberOfDeletedKeys int64, err error) { +func (api *API) DeleteKeys(ctx context.Context, keys []string) (int64, error) { if len(keys) == 0 { return 0, temperr.KeyEmpty } - var k int64 = 0 + + var deleted int64 for _, key := range keys { - e, _ := api.Exists(ctx, key) - if e { - err = api.Delete(ctx, key) - if err != nil { - return k, err + exists, err := api.Exists(ctx, key) + if err != nil { + return deleted, err + } + if exists { + if err := api.Delete(ctx, key); err != nil { + return deleted, err } - k++ + deleted++ } - } - return k, nil + return deleted, nil } func (api *API) DeleteScanMatch(ctx context.Context, pattern string) (int64, error) { - err := api.Connector.Ping(ctx) - if err != nil { + if err := api.Connector.Ping(ctx); err != nil { return 0, err } keys, err := api.Keys(ctx, pattern) - var k int64 = 0 if err != nil { - return k, err + return 0, err } - for _, key := range keys { - err := api.Delete(ctx, key) - if err != nil { - return k, err - } - k++ + // need to return nil for this function + c, err := api.DeleteKeys(ctx, keys) + if err != nil { + return 0, nil } - return k, nil + return c, nil } func (api *API) Keys(ctx context.Context, pattern string) ([]string, error) { - err := api.Connector.Ping(ctx) - if err != nil { + if err := api.Connector.Ping(ctx); err != nil { return nil, err } - // filter is a prefix, e.g. rumbaba:keys:* - // strip the * - // Strip the trailing "*" from the pattern - pattern = strings.TrimSuffix(pattern, "*") - // Get the key index keyIndexObj, err := api.Store.Get(keyIndexKey) if err != nil { return nil, err @@ -412,75 +323,84 @@ func (api *API) Keys(ctx context.Context, pattern string) ([]string, error) { if keyIndexObj == nil { return nil, nil } - keyIndex := keyIndexObj.Value.(map[string]interface{}) - // Get the deleted key index + keyIndex, ok := keyIndexObj.Value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid key index format") + } + deletedKeyIndexObj, err := api.Store.Get(deletedKeyIndexKey) if err != nil { return nil, err } + deletedKeys, ok := deletedKeyIndexObj.Value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid deleted key index format") + } - deletedKeys := deletedKeyIndexObj.Value.(map[string]interface{}) + pattern = strings.TrimSuffix(pattern, "*") + var matchedKeys []string - var retKeys []string for key := range keyIndex { - // Check if the key matches the pattern and is not deleted - if pattern == "" || strings.HasPrefix(key, pattern) { - _, f := deletedKeys[key] - if !f { - retKeys = append(retKeys, key) - } + if !api.isKeyDeleted(key, deletedKeys) && strings.HasPrefix(key, pattern) { + matchedKeys = append(matchedKeys, key) } } - return retKeys, nil + return matchedKeys, nil } -func (api *API) GetMulti(ctx context.Context, keys []string) (values []interface{}, err error) { - var objects []interface{} - for _, key := range keys { - o, _ := api.Get(ctx, key) - - if o == "" { - objects = append(objects, nil) - continue - } +func (api *API) isKeyDeleted(key string, deletedKeys map[string]interface{}) bool { + _, deleted := deletedKeys[key] + return deleted +} - objects = append(objects, o) +func (api *API) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { + var values []interface{} + for _, key := range keys { + value, err := api.Get(ctx, key) + if err == temperr.KeyNotFound { + values = append(values, nil) + } else if err != nil { + return nil, err + } else { + values = append(values, value) + } } - return objects, nil + return values, nil } -func (api *API) GetKeysAndValuesWithFilter(ctx context.Context, pattern string) (keysAndValues map[string]interface{}, err error) { +func (api *API) GetKeysAndValuesWithFilter(ctx context.Context, pattern string) (map[string]interface{}, error) { keys, err := api.Keys(ctx, pattern) if err != nil { return nil, err } - kv := make(map[string]interface{}) + keysAndValues := make(map[string]interface{}) for _, key := range keys { - o, err := api.Get(ctx, key) - if err != nil { - continue + value, err := api.Get(ctx, key) + if err == nil { + keysAndValues[key] = value + } else if err != temperr.KeyNotFound { + return nil, err } - - kv[key] = o } - return kv, nil + return keysAndValues, nil } func (api *API) GetKeysWithOpts(ctx context.Context, searchStr string, cursors map[string]uint64, count int64) (keys []string, updatedCursor map[string]uint64, continueScan bool, err error) { - err = api.Connector.Ping(ctx) - if err != nil { + if err := api.Connector.Ping(ctx); err != nil { return nil, nil, false, err } - // no op + // TODO: Implement the actual functionality based on your requirements + // This function is currently a no-op and needs to be implemented + return nil, nil, true, nil }