diff --git a/arrays/chunk.go b/arrays/chunk.go deleted file mode 100644 index 4cd9236..0000000 --- a/arrays/chunk.go +++ /dev/null @@ -1,19 +0,0 @@ -package arrays - -// Chunk will take a slice of any kind and a chunk size and return a slice of slices -func Chunk[T any](arr []T, chunkSize int) [][]T { - chunks := [][]T{} - if len(arr) == 0 { - return chunks - } - - for i := 0; i < len(arr); i += chunkSize { - end := i + chunkSize - if end > len(arr) { - end = len(arr) - } - chunks = append(chunks, arr[i:end]) - } - - return chunks -} diff --git a/auth/encrryption.go b/auth/encrryption.go new file mode 100644 index 0000000..3020e1b --- /dev/null +++ b/auth/encrryption.go @@ -0,0 +1,84 @@ +package auth + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" + "io" + + "github.com/ochom/gutils/env" +) + +// Vault +type Vault struct { + key string +} + +// NewVault creates a new Vault with the given key +func NewVault(keys ...string) (*Vault, error) { + key := env.Get("VAULT_KEY") + + if len(keys) > 0 && keys[0] != "" { + key = keys[0] + } + + if len(key) != 32 { + return nil, fmt.Errorf("vault key must be 32 bytes long") + } + + return &Vault{key: key}, nil +} + +// Encrypt text using AES-GCM +func (v Vault) Encrypt(plaintext string) (string, error) { + block, err := aes.NewCipher([]byte(v.key)) + if err != nil { + return "", err + } + + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonce := make([]byte, aesGCM.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + + ciphertext := aesGCM.Seal(nonce, nonce, []byte(plaintext), nil) + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// Decrypt text using AES-GCM +func (v Vault) Decrypt(encryptedText string) (string, error) { + data, err := base64.StdEncoding.DecodeString(encryptedText) + if err != nil { + return "", err + } + + block, err := aes.NewCipher([]byte(v.key)) + if err != nil { + return "", err + } + + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonceSize := aesGCM.NonceSize() + if len(data) < nonceSize { + return "", fmt.Errorf("ciphertext too short") + } + + nonce, ciphertext := data[:nonceSize], data[nonceSize:] + plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", err + } + + return string(plaintext), nil +} diff --git a/auth/encrryption_test.go b/auth/encrryption_test.go new file mode 100644 index 0000000..cb158cf --- /dev/null +++ b/auth/encrryption_test.go @@ -0,0 +1,38 @@ +package auth_test + +import ( + "testing" + + "github.com/ochom/gutils/auth" +) + +func TestVault_Encrypt(t *testing.T) { + vault, err := auth.NewVault("12345678901234567890123456789012") + if err != nil { + t.Fatalf("failed to create vault: %v", err) + } + + plaintext := "Hello, World!" + encryptedText, err := vault.Encrypt(plaintext) + if err != nil { + t.Fatalf("encryption failed: %v", err) + } + + if encryptedText == plaintext { + t.Fatalf("encrypted text should not be the same as plaintext") + } + + vault2, err := auth.NewVault("12345678901234567890123456789012") + if err != nil { + t.Fatalf("failed to create vault: %v", err) + } + + decryptedText, err := vault2.Decrypt(encryptedText) + if err != nil { + t.Fatalf("decryption failed: %v", err) + } + + if decryptedText != plaintext { + t.Fatalf("decrypted text does not match original plaintext") + } +} diff --git a/env/env_test.go b/env/env_test.go index eb9164f..db4b159 100644 --- a/env/env_test.go +++ b/env/env_test.go @@ -8,7 +8,7 @@ import ( ) func TestGetEnv(t *testing.T) { - os.Setenv("HELLO", "world") + _ = os.Setenv("HELLO", "world") type args struct { key string defaultValue string @@ -45,8 +45,8 @@ func TestGetEnv(t *testing.T) { } func TestInt(t *testing.T) { - os.Setenv("HELLO", "45") - os.Setenv("HELLO2", "45s") + _ = os.Setenv("HELLO", "45") + _ = os.Setenv("HELLO2", "45s") type args struct { key string @@ -92,9 +92,9 @@ func TestInt(t *testing.T) { } func TestBool(t *testing.T) { - os.Setenv("HELLO", "true") - os.Setenv("HELLO2", "false") - os.Setenv("HELLO3", "test") + _ = os.Setenv("HELLO", "true") + _ = os.Setenv("HELLO2", "false") + _ = os.Setenv("HELLO3", "test") type args struct { key string @@ -148,8 +148,8 @@ func TestBool(t *testing.T) { } func TestFloat(t *testing.T) { - os.Setenv("HELLO", "45.5") - os.Setenv("HELLO2", "45.5s") + _ = os.Setenv("HELLO", "45.5") + _ = os.Setenv("HELLO2", "45.5s") type args struct { key string diff --git a/gttp/gofiber.go b/gttp/gofiber.go index 5b560d8..b129059 100644 --- a/gttp/gofiber.go +++ b/gttp/gofiber.go @@ -1,7 +1,6 @@ package gttp import ( - "context" "errors" "fmt" "strings" @@ -23,71 +22,60 @@ func (c *fiberClient) get(url string, headers M, timeouts ...time.Duration) (res } // sendRequest sends a request to the specified URL. -func (c *fiberClient) sendRequest(url, method string, headers M, body []byte, timeouts ...time.Duration) (resp *Response, err error) { +func (c *fiberClient) sendRequest(url, method string, headers M, body []byte, timeouts ...time.Duration) (*Response, error) { timeout := time.Hour if len(timeouts) > 0 { timeout = timeouts[0] } - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - result := make(chan *Response, 1) - go func() { - resp := c.makeRequest(url, method, headers, body) - result <- resp - }() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case r := <-result: - if len(r.Errors) == 0 { - return r, nil - } - - errStrings := []string{} - for _, err := range r.Errors { - errStrings = append(errStrings, err.Error()) - } + resp := c.makeRequest(url, method, headers, body, timeout) + if len(resp.Errors) == 0 { + return resp, nil + } - return r, errors.New(strings.Join(errStrings, ", ")) + errStrings := []string{} + for _, err := range resp.Errors { + errStrings = append(errStrings, err.Error()) } + + return resp, errors.New(strings.Join(errStrings, ", ")) } // makeRequest sends a request to the specified URL. -func (c *fiberClient) makeRequest(url, method string, headers M, body []byte) (resp *Response) { - client := fiber.AcquireClient() - var req *fiber.Agent - +func (c *fiberClient) makeRequest(url, method string, headers M, body []byte, timeout time.Duration) *Response { + var agent *fiber.Agent switch method { case "POST": - req = client.Post(url) + agent = fiber.Post(url) case "GET": - req = client.Get(url) + agent = fiber.Get(url) case "DELETE": - req = client.Delete(url) + agent = fiber.Delete(url) case "PUT": - req = client.Put(url) + agent = fiber.Put(url) case "PATCH": - req = client.Patch(url) + agent = fiber.Patch(url) default: err := fmt.Errorf("unknown method: %s", method) return NewResponse(500, []error{err}, nil) } // skip ssl verification - req.InsecureSkipVerify() + agent.InsecureSkipVerify() + agent.Timeout(timeout) + // add request headers for k, v := range headers { - req.Add(k, v) + agent.Add(k, v) } + // add request body if method == "POST" || method == "PUT" || method == "PATCH" { - req.Body(body) + agent.Body(body) } - code, content, errs := req.Bytes() + // make request + code, content, errs := agent.Bytes() if code == 0 { code = 500 } diff --git a/helpers/address.go b/helpers/address.go new file mode 100644 index 0000000..ea6b7bc --- /dev/null +++ b/helpers/address.go @@ -0,0 +1,20 @@ +package helpers + +import ( + "fmt" + "net" + "time" + + "github.com/ochom/gutils/logs" +) + +// GetAvailableAddress returns the next available address e.g :8080 +func GetAvailableAddress(port int) string { + _, err := net.DialTimeout("tcp", net.JoinHostPort("", fmt.Sprintf("%d", port)), time.Second) + if err == nil { + logs.Warn("[🥵] address :%d is not available trying another port...", port) + return GetAvailableAddress(port + 1) + } + + return fmt.Sprintf(":%d", port) +} diff --git a/helpers/address_test.go b/helpers/address_test.go new file mode 100644 index 0000000..74fe22f --- /dev/null +++ b/helpers/address_test.go @@ -0,0 +1,27 @@ +package helpers + +import "testing" + +func TestGetAvailableAddress(t *testing.T) { + type args struct { + port int + } + tests := []struct { + name string + args args + want string + }{ + { + name: "Test GetAvailableAddress", + args: args{port: 8080}, + want: ":8080", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetAvailableAddress(tt.args.port); got != tt.want { + t.Errorf("GetAvailableAddress() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/helpers/json.go b/helpers/json.go deleted file mode 100644 index 51c3eda..0000000 --- a/helpers/json.go +++ /dev/null @@ -1,29 +0,0 @@ -package helpers - -import ( - "encoding/json" - - "github.com/ochom/gutils/logs" -) - -// ToBytes converts provided interface to slice of bytes -func ToBytes[T any](payload T) []byte { - bytesPayload, err := json.Marshal(&payload) - if err != nil { - logs.Error("Failed to marshal JSON: %s", err.Error()) - return nil - } - - return bytesPayload -} - -// FromBytes converts slice of bytes to provided interface -func FromBytes[T any](payload []byte) T { - var data T - if err := json.Unmarshal(payload, &data); err != nil { - logs.Error("Failed to unmarshal JSON: %s", err.Error()) - return data - } - - return data -} diff --git a/helpers/mobiles.go b/helpers/mobiles.go index f658473..744cab5 100644 --- a/helpers/mobiles.go +++ b/helpers/mobiles.go @@ -3,40 +3,36 @@ package helpers import ( "crypto/sha256" "fmt" - "slices" "strings" + "unicode" "github.com/ochom/gutils/logs" ) // ParseMobile parses phone number to 254 format -func ParseMobile(mobile string) string { - // replace all non-digit characters - mobile = strings.Map(func(r rune) rune { - if slices.Contains([]rune("0123456789"), r) { - return r +func ParseMobile(mobile string) (string, bool) { + var digits []rune + for _, r := range mobile { + if unicode.IsDigit(r) { + digits = append(digits, r) } - - return -1 - }, mobile) - - // remove leading zeros - mobile = strings.TrimLeft(mobile, "0") - - // remove leading 254 - mobile = strings.TrimPrefix(mobile, "254") - - // check if remaining mobile is 9 digits - if len(mobile) != 9 { - return "" } + cleaned := string(digits) - // check if mobile starts with 7 or 1 - if mobile[0] != '7' && mobile[0] != '1' { - return "" + switch { + case strings.HasPrefix(cleaned, "254"): + if len(cleaned) == 12 && (cleaned[3] == '7' || cleaned[3] == '1') { + return cleaned, true + } + case strings.HasPrefix(cleaned, "07") || strings.HasPrefix(cleaned, "01"): + if len(cleaned) == 10 { + return "254" + cleaned[1:], true + } + case len(cleaned) == 9 && (cleaned[0] == '7' || cleaned[0] == '1'): + return "254" + cleaned, true } - return "254" + mobile + return "", false } // HashPhone hashes phone number to sha256 hash diff --git a/helpers/mobiles_test.go b/helpers/mobiles_test.go index 1e38bec..b1c3284 100644 --- a/helpers/mobiles_test.go +++ b/helpers/mobiles_test.go @@ -11,51 +11,61 @@ func TestParseMobile(t *testing.T) { mobile string } tests := []struct { - name string - args args - want string + name string + args args + wantPhone string + wantOk bool }{ { name: "test 1", args: args{ mobile: "712345678", }, - want: "254712345678", + wantPhone: "254712345678", + wantOk: true, }, { name: "test 2", args: args{ mobile: "0712345678", }, - want: "254712345678", + wantPhone: "254712345678", + wantOk: true, }, { name: "tst 3", args: args{ mobile: "+254712345678", }, - want: "254712345678", + wantPhone: "254712345678", + wantOk: true, }, { name: "test 4", args: args{ mobile: "2547123456", }, - want: "", + wantPhone: "", + wantOk: false, }, { name: "test 5", args: args{ mobile: "254212345678", }, - want: "", + wantPhone: "", + wantOk: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - phone := helpers.ParseMobile(tt.args.mobile) - if phone != "" && phone != tt.want { - t.Errorf("ParseMobile() phone = %v, want %v", phone, tt.want) + phone, ok := helpers.ParseMobile(tt.args.mobile) + if ok != tt.wantOk { + t.Errorf("ParseMobile() ok = %v, wantOk %v", ok, tt.wantOk) + } + + if phone != tt.wantPhone { + t.Errorf("ParseMobile() phone = %v, wantPhone %v", phone, tt.wantPhone) } }) } diff --git a/helpers/test-ports.go b/helpers/test-ports.go deleted file mode 100644 index 9b2c51c..0000000 --- a/helpers/test-ports.go +++ /dev/null @@ -1,20 +0,0 @@ -package helpers - -import ( - "fmt" - "net" - "time" - - "github.com/ochom/gutils/logs" -) - -// TestPorts is a list of ports that are used for testing -func GetPort(startingPort int) int { - _, err := net.DialTimeout("tcp", net.JoinHostPort("", fmt.Sprintf("%d", startingPort)), time.Second) - if err == nil { - logs.Warn("[🥵] port %d is not available trying another port: %d", startingPort, startingPort+1) - return GetPort(startingPort + 1) - } - - return startingPort -} diff --git a/jsonx/json.go b/jsonx/json.go new file mode 100644 index 0000000..09176a3 --- /dev/null +++ b/jsonx/json.go @@ -0,0 +1,44 @@ +package jsonx + +import ( + baseJSON "encoding/json" + + "github.com/ochom/gutils/logs" +) + +type byteData []byte + +// String converts byte data to string +func (b byteData) String() string { + if b == nil { + return "" + } + return string(b) +} + +// Bytes converts byte data to slice of bytes +func (b byteData) Bytes() []byte { + return b +} + +// Encode encodes the given payload into JSON format in byte slice +func Encode(payload any) byteData { + bytesPayload, err := baseJSON.Marshal(&payload) + if err != nil { + logs.Error("Failed to marshal JSON: %s", err.Error()) + return nil + } + + return bytesPayload +} + +// Decode decodes JSON payload into the specified type +func Decode[T any](payload []byte) T { + var data T + if err := baseJSON.Unmarshal(payload, &data); err != nil { + logs.Error("Failed to unmarshal JSON: %s", err.Error()) + return data + } + + return data +} diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 30906ae..b6bdc60 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -10,7 +10,9 @@ import ( type consumer struct { connectionName string url string + exchange string queue string + routingKey string // basic durable bool @@ -24,6 +26,16 @@ type consumer struct { noWait bool } +// SetExchangeName implements Consumer. +func (c *consumer) SetExchangeName(exchangeName string) { + c.exchange = exchangeName +} + +// SetRoutingKey implements Consumer. +func (c *consumer) SetRoutingKey(routingKey string) { + c.routingKey = routingKey +} + // SetConnectionName implements Consumer. func (c *consumer) SetConnectionName(connectionName string) { c.connectionName = connectionName @@ -108,6 +120,11 @@ func (c *consumer) Consume(workerFunc func([]byte)) error { return fmt.Errorf("queue Declare: %s", err.Error()) } + err = bindQueue(ch, c.exchange, q.Name, c.routingKey) + if err != nil { + return fmt.Errorf("queue Bind: %s", err.Error()) + } + deliveries, err := ch.Consume( q.Name, // queue c.tag, // consumerTag diff --git a/pubsub/publisher.go b/pubsub/publisher.go index 0b7e4df..a4bd3ce 100644 --- a/pubsub/publisher.go +++ b/pubsub/publisher.go @@ -22,19 +22,27 @@ func NewPublisher(rabbitURL, exchangeName, queueName string) Publisher { return &publisher{ url: rabbitURL, exchange: exchangeName, - queue: queueName, exchangeType: Direct, + queue: queueName, } } +// SetQueueName ... +func (p *publisher) SetQueueName(queueName string) { + p.queue = queueName +} + +// SetConnectionName ... func (p *publisher) SetConnectionName(connectionName string) { p.connectionName = connectionName } +// SetRoutingKey ... func (p *publisher) SetRoutingKey(routingKey string) { p.routingKey = routingKey } +// SetExchangeType ... func (p *publisher) SetExchangeType(exchangeType ExchangeType) { p.exchangeType = exchangeType } @@ -49,49 +57,9 @@ func (p *publisher) Publish(body []byte) error { return p.publish(body, 0) } -// initPubSub ... -func (p *publisher) initPubSub(ch *amqp.Channel) error { - err := ch.ExchangeDeclare( - p.exchange, // name - string(p.exchangeType), // type - true, // durable - false, // auto-deleted - false, // internal - false, // no-wait - amqp.Table{ - "x-delayed-type": "direct", - }, // arguments - ) - if err != nil { - return fmt.Errorf("exchange Declare: %s", err.Error()) - } - - // declare queue - q, err := ch.QueueDeclare( - p.queue, // name - true, // durable - false, // delete when unused - false, // exclusive - false, // no-wait - nil, // arguments - ) - if err != nil { - return fmt.Errorf("queue Declare: %s", err.Error()) - } - - // bind queue to exchange - err = ch.QueueBind( - q.Name, // queue name - p.routingKey, // routing key - p.exchange, // exchange - false, // no-wait - nil, - ) - if err != nil { - return fmt.Errorf("queue Bind: %s", err.Error()) - } - - return nil +// declare create exchange and queue +func (p *publisher) declare(ch *amqp.Channel) error { + return declare(ch, p.exchange, p.queue, p.routingKey, p.exchangeType) } // publish ... @@ -119,7 +87,7 @@ func (p *publisher) publish(body []byte, delay time.Duration) error { return fmt.Errorf("failed to set QoS: %s", err.Error()) } - if err := p.initPubSub(channel); err != nil { + if err := p.declare(channel); err != nil { return fmt.Errorf("failed to initialize a pubsub: %s", err.Error()) } diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index 07ac686..4ee3875 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -1,6 +1,11 @@ package pubsub -import "time" +import ( + "fmt" + "time" + + "github.com/streadway/amqp" +) // ExchangeType ... type ExchangeType string @@ -16,6 +21,7 @@ var ( type Publisher interface { SetConnectionName(string) SetExchangeType(ExchangeType) + SetQueueName(string) SetRoutingKey(string) Publish([]byte) error PublishWithDelay([]byte, time.Duration) error @@ -23,6 +29,8 @@ type Publisher interface { type Consumer interface { SetConnectionName(string) + SetExchangeName(string) + SetRoutingKey(string) SetDurable(bool) SetDeleteWhenUnused(bool) SetTag(string) @@ -32,3 +40,62 @@ type Consumer interface { SetNoWait(bool) Consume(func([]byte)) error } + +// declare create exchange and queue +func declare(ch *amqp.Channel, exchange, queue, routingKey string, exchangeType ExchangeType) error { + err := ch.ExchangeDeclare( + exchange, // name + string(exchangeType), // type + true, // durable + false, // auto-deleted + false, // internal + false, // no-wait + amqp.Table{ + "x-delayed-type": "direct", + }, // arguments + ) + if err != nil { + return fmt.Errorf("exchange Declare: %s", err.Error()) + } + + // if queue is empty, then no need to declare queue + if queue == "" { + return nil + } + + // declare queue + q, err := ch.QueueDeclare( + queue, // name + true, // durable + false, // delete when unused + false, // exclusive + false, // no-wait + nil, // arguments + ) + if err != nil { + return fmt.Errorf("queue Declare: %s", err.Error()) + } + + // bind queue to exchange + return bindQueue(ch, exchange, q.Name, routingKey) +} + +// bind queue to exchange ... +func bindQueue(ch *amqp.Channel, exchange, queue, routingKey string) error { + if exchange == "" || queue == "" { + return nil + } + + err := ch.QueueBind( + queue, // queue name + routingKey, // routing key + exchange, // exchange + false, // no-wait + nil, + ) + if err != nil { + return fmt.Errorf("queue Bind: %s", err.Error()) + } + + return nil +} diff --git a/pubsub/streamx.go b/pubsub/streamx.go index f3264a8..45dde6b 100644 --- a/pubsub/streamx.go +++ b/pubsub/streamx.go @@ -5,47 +5,37 @@ import ( "github.com/ochom/gutils/env" "github.com/ochom/gutils/gttp" - "github.com/ochom/gutils/helpers" + "github.com/ochom/gutils/jsonx" "github.com/ochom/gutils/logs" ) -// StreamMessage ... -type StreamMessage struct { - InstanceID string `json:"instanceID"` - Channel string `json:"channel"` - ID string `json:"id"` - Event string `json:"event"` - Data any `json:"message"` -} - -type StreamX struct { - Url string - apiKey string -} - -var streamX *StreamX - -func init() { - InitStreamX(env.Get("STREAMX_API_KEY")) -} +var ( + url = env.Get("STREAMX_URL", "https://api.StreamSdk.co.ke") + apiKey = env.Get("STREAMX_API_KEY", "") + instanceID = env.Get("STREAMX_INSTANCE_ID", "default") +) -func InitStreamX(apiKey string) { - streamX = &StreamX{apiKey: apiKey} +type StreamSdk struct { + url string + apiKey string + instanceID string } -func (s *StreamX) publish(message *StreamMessage) { - if s == nil { - logs.Error("StreamX not initialized") - return - } - +// PublishStream publishes a message to the stream +func (s StreamSdk) PublishStream(channel string, event string, data any) { headers := map[string]string{ "Content-Type": "application/json", - "Authorization": streamX.apiKey, + "Authorization": s.apiKey, } - url := fmt.Sprintf("%s/publish", env.Get("STREAMX_URL", "https://api.streamx.co.ke")) - res, err := gttp.Post(url, headers, helpers.ToBytes(message)) + url := fmt.Sprintf("%s/publish", s.url) + res, err := gttp.Post(url, headers, jsonx.Encode(map[string]any{ + "instance_id": s.instanceID, + "channel": channel, + "event": event, + "data": data, + })) + if err != nil { logs.Error("Failed to publish message to stream: %v", err) return @@ -55,11 +45,28 @@ func (s *StreamX) publish(message *StreamMessage) { logs.Error("Failed to publish message to stream: %v", string(res.Body)) return } - - logs.Info("StreamMessage published to StreamX ==> msgID: %s", message.ID) } -// PublishStream publishes a message to the stream -func PublishStream(message *StreamMessage) { - go streamX.publish(message) +// NewStreamX create new instance of StreamSdk +// with optional parameters for instance ID, URL, and API key. +func NewStreamX(params ...string) (sdk *StreamSdk) { + sdk = &StreamSdk{ + url: url, + apiKey: apiKey, + instanceID: instanceID, + } + + if len(params) > 0 { + sdk.instanceID = params[0] + } + + if len(params) > 1 { + sdk.url = params[1] + } + + if len(params) > 2 { + sdk.apiKey = params[2] + } + + return sdk } diff --git a/sqlr/db.go b/sqlr/db.go index bfadf9c..b1b6546 100644 --- a/sqlr/db.go +++ b/sqlr/db.go @@ -110,7 +110,7 @@ func getGormConfig(config *Config) *gorm.Config { logger.Config{ SlowThreshold: 200 * time.Millisecond, LogLevel: config.LogLevel, - IgnoreRecordNotFoundError: false, + IgnoreRecordNotFoundError: config.IgnoreRecordNotFoundError, Colorful: true, }) diff --git a/sqlr/helpers.go b/sqlr/helpers.go index 7f63da2..45566a4 100644 --- a/sqlr/helpers.go +++ b/sqlr/helpers.go @@ -1,6 +1,8 @@ package sqlr import ( + "context" + "github.com/ochom/gutils/logs" "gorm.io/gorm" ) @@ -10,11 +12,21 @@ func Create[T any](data *T) error { return instance.gormDB.Create(data).Error } +// CreateWithCtx ... +func CreateWithCtx[T any](ctx context.Context, data *T) error { + return instance.gormDB.WithContext(ctx).Create(data).Error +} + // Update ... func Update[T any](data *T) error { return instance.gormDB.Save(data).Error } +// UpdateWithCtx ... +func UpdateWithCtx[T any](ctx context.Context, data *T) error { + return instance.gormDB.WithContext(ctx).Save(data).Error +} + // UpdateOne ... func UpdateOne[T any](scope func(db *gorm.DB) *gorm.DB, updates map[string]any) error { var model T @@ -26,6 +38,11 @@ func Delete[T any](scopes ...func(db *gorm.DB) *gorm.DB) error { return instance.gormDB.Scopes(scopes...).Delete(new(T)).Error } +// DeleteWithCtx ... +func DeleteWithCtx[T any](ctx context.Context, scopes ...func(db *gorm.DB) *gorm.DB) error { + return instance.gormDB.WithContext(ctx).Scopes(scopes...).Delete(new(T)).Error +} + // DeleteById ... func DeleteById[T any](id any, scopes ...func(db *gorm.DB) *gorm.DB) error { return Delete[T](append(scopes, func(db *gorm.DB) *gorm.DB { @@ -74,11 +91,60 @@ func FindWithLimit[T any](page, limit int, scopes ...func(db *gorm.DB) *gorm.DB) // Count ... func Count[T any](scopes ...func(db *gorm.DB) *gorm.DB) int { var count int64 - var model T - if err := instance.gormDB.Model(&model).Scopes(scopes...).Count(&count).Error; err != nil { + if err := instance.gormDB.Model(new(T)).Scopes(scopes...).Count(&count).Error; err != nil { logs.Info("Count: %s", err.Error()) return 0 } return int(count) } + +// Exists ... +func Exists[T any](scopes ...func(db *gorm.DB) *gorm.DB) bool { + query := instance.gormDB.Model(new(T)).Scopes(scopes...).Select("1").Limit(1) + var exists bool + if err := query.Scan(&exists).Error; err != nil { + logs.Info("Exists: %s", err.Error()) + return false + } + + return exists +} + +// Raw ... +func Raw(query string, values ...any) *gorm.DB { + return instance.gormDB.Raw(query, values...) +} + +// Exec ... +func Exec(query string, values ...any) error { + return instance.gormDB.Exec(query, values...).Error +} + +// Transact ... +func Transact(fn ...func(tx *gorm.DB) error) error { + err := instance.gormDB.Transaction(func(db *gorm.DB) error { + for _, f := range fn { + if err := f(db); err != nil { + return err + } + } + return nil + }) + + return err +} + +// TransactWithCtx ... +func TransactWithCtx(ctx context.Context, fn ...func(tx *gorm.DB) error) error { + err := instance.gormDB.WithContext(ctx).Transaction(func(db *gorm.DB) error { + for _, f := range fn { + if err := f(db); err != nil { + return err + } + } + return nil + }) + + return err +} diff --git a/sqlr/types.go b/sqlr/types.go index 86f0f5c..00f9077 100644 --- a/sqlr/types.go +++ b/sqlr/types.go @@ -8,13 +8,14 @@ import ( // Database configuration type Config struct { - Url string - LogLevel logger.LogLevel - MaxOpenConns int - MaxIdleConns int - MaxConnIdleTime time.Duration - MaxConnLifeTime time.Duration - SkipDefaultTransaction bool + Url string + LogLevel logger.LogLevel + IgnoreRecordNotFoundError bool + MaxOpenConns int + MaxIdleConns int + MaxConnIdleTime time.Duration + MaxConnLifeTime time.Duration + SkipDefaultTransaction bool } // defaultConfig ... diff --git a/uuid/short_id.go b/uuid/short_id.go new file mode 100644 index 0000000..2066cec --- /dev/null +++ b/uuid/short_id.go @@ -0,0 +1,38 @@ +package uuid + +import ( + "strconv" + "strings" +) + +// ShortID translates a number to a unique character string +func ShortID(input int) string { + modulus := 1 << 16 // Modulus for Feistel function + rounds := 4 // Number of Feistel rounds + keys := []int{12345, 67890, 54321, 98765} // Round keys + + output := feistelNetwork(input, rounds, modulus, keys) + + res := strings.ToUpper(strconv.FormatInt(int64(output), 36)) + return res +} + +// feistelNetwork for bijective mapping +func feistelNetwork(input, rounds, modulus int, keys []int) int { + left := input >> 16 // Higher 16 bits + right := input & 0xFFFF // Lower 16 bits + + for i := range rounds { + newRight := left ^ feistelFunction(right, keys[i], modulus) + left = right + right = newRight + } + + // Combine left and right parts + return (right << 16) | left +} + +// feistelFunction: a simple example with modular arithmetic +func feistelFunction(data, key, modulus int) int { + return (data*key ^ key) % modulus +}