Skip to content

Commit

Permalink
Merge pull request #3 from ruslanSorokin/main
Browse files Browse the repository at this point in the history
feat: add `GetAs` function to return typed repository
  • Loading branch information
sesaquecruz authored Jan 8, 2024
2 parents 4833b42 + bcd6261 commit 63212ac
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 66 deletions.
17 changes: 17 additions & 0 deletions uow/uow.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
var (
ErrRepositoryNotRegistered = errors.New("repository not registered")
ErrRepositoryAlreadyRegistered = errors.New("repository already registered")
ErrInvalidRepositoryType = errors.New("invalid repository type")
)

type RepositoryName string
Expand Down Expand Up @@ -47,6 +48,22 @@ func NewTransaction(tx *sql.Tx, repositories map[RepositoryName]RepositoryFactor
}
}

// Return repository of type T if any found.
// In case of type cast error returns ErrInvalidRepositoryType.
func GetAs[T any](t TX, name RepositoryName) (T, error) {
repository, err := t.Get(name)
var res T
if err != nil {
return res, err
}
res, ok := repository.(T)
if !ok {
return res, ErrInvalidRepositoryType
}

return res, nil
}

// Given a repository name returns a repository. Return an error if the repository does not exist.
func (t *Transaction) Get(name RepositoryName) (Repository, error) {
if repository, ok := t.repositories[name]; ok {
Expand Down
122 changes: 56 additions & 66 deletions uow/uow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,45 @@ func Test_Transaction_Get(t *testing.T) {
assert.Same(t, tx, orderRepository.(*OrderRepository).tx)
}

func Test_GetAs(t *testing.T) {
tx := &sql.Tx{}
repositories := make(map[RepositoryName]RepositoryFactory)

transaction := NewTransaction(tx, repositories)

_, err := transaction.Get("ProductRepository")
assert.ErrorIs(t, ErrRepositoryNotRegistered, err)

_, err = transaction.Get("OrderRepository")
assert.ErrorIs(t, ErrRepositoryNotRegistered, err)

transaction.repositories["ProductRepository"] = func(tx *sql.Tx) Repository {
return NewProductRepository(tx)
}

transaction.repositories["OrderRepository"] = func(tx *sql.Tx) Repository {
return NewOrderRepository(tx)
}

productRepository, err := GetAs[*ProductRepository](transaction, "ProductRepository")
assert.Nil(t, err)
assert.IsType(t, &ProductRepository{}, productRepository)
assert.Same(t, tx, productRepository.tx)

_, err = GetAs[ProductRepository](transaction, "ProductRepository")
assert.ErrorIs(t, err, ErrInvalidRepositoryType,
"trying to cast pointer object to value object")

orderRepository, err := GetAs[*OrderRepository](transaction, "OrderRepository")
assert.Nil(t, err)
assert.IsType(t, &OrderRepository{}, orderRepository)
assert.Same(t, tx, orderRepository.tx)

_, err = GetAs[OrderRepository](transaction, "OrderRepository")
assert.ErrorIs(t, err, ErrInvalidRepositoryType,
"trying to cast pointer object to value object")
}

func Test_UnitOfWork_NewUnitOfWork(t *testing.T) {
db := &sql.DB{}

Expand Down Expand Up @@ -288,16 +327,16 @@ func Test_UnitOfWork_Do_WhenTransactionSucceeds(t *testing.T) {

_, err = db.Exec(`
CREATE TABLE products (
id VARCHAR(36) PRIMARY KEY,
id VARCHAR(36) PRIMARY KEY,
amount INT(32) UNSIGNED NOT NULL
);
`)
require.Nil(t, err)

_, err = db.Exec(`
CREATE TABLE orders (
id VARCHAR(36) PRIMARY KEY,
product_id VARCHAR(36) NOT NULL,
id VARCHAR(36) PRIMARY KEY,
product_id VARCHAR(36) NOT NULL,
amount INT(32) UNSIGNED NOT NULL,
FOREIGN KEY (product_id) REFERENCES products(id)
);
Expand All @@ -321,16 +360,12 @@ func Test_UnitOfWork_Do_WhenTransactionSucceeds(t *testing.T) {

err = uow.Do(ctx, func(ctx context.Context, tx TX) error {
// Get repository
repository, err := tx.Get("ProductRepository")

productRepository, err := GetAs[*ProductRepository](tx, "ProductRepository")
if err != nil {
return err
}

productRepository, ok := repository.(*ProductRepository)
if !ok {
return errors.New("invalid type")
}

// Save product
err = productRepository.Save(ctx, product)
return err
Expand All @@ -342,26 +377,16 @@ func Test_UnitOfWork_Do_WhenTransactionSucceeds(t *testing.T) {

err = uow.Do(ctx, func(ctx context.Context, tx TX) error {
// Get repositories
repository, err := tx.Get("ProductRepository")
productRepository, err := GetAs[*ProductRepository](tx, "ProductRepository")
if err != nil {
return err
}

productRepository, ok := repository.(*ProductRepository)
if !ok {
return errors.New("invalid type")
}

repository, err = tx.Get("OrderRepository")
orderRepository, err := GetAs[*OrderRepository](tx, "OrderRepository")
if err != nil {
return err
}

orderRepository, ok := repository.(*OrderRepository)
if !ok {
return errors.New("invalid type")
}

// Get itens
productSaved, err := productRepository.Get(ctx, order.productId)
if err != nil {
Expand Down Expand Up @@ -389,26 +414,16 @@ func Test_UnitOfWork_Do_WhenTransactionSucceeds(t *testing.T) {
// Verify amounts
err = uow.Do(ctx, func(ctx context.Context, tx TX) error {
// Get repositories
repository, err := tx.Get("ProductRepository")
productRepository, err := GetAs[*ProductRepository](tx, "ProductRepository")
if err != nil {
return err
}

productRepository, ok := repository.(*ProductRepository)
if !ok {
return errors.New("invalid type")
}

repository, err = tx.Get("OrderRepository")
orderRepository, err := GetAs[*OrderRepository](tx, "OrderRepository")
if err != nil {
return err
}

orderRepository, ok := repository.(*OrderRepository)
if !ok {
return errors.New("invalid type")
}

// Get itens
productSaved, err := productRepository.Get(ctx, product.id)
if err != nil {
Expand Down Expand Up @@ -448,16 +463,16 @@ func Test_UnitOfWork_Do_WhenTransactionFails(t *testing.T) {

_, err = db.Exec(`
CREATE TABLE products (
id VARCHAR(36) PRIMARY KEY,
id VARCHAR(36) PRIMARY KEY,
amount INT(32) UNSIGNED NOT NULL
);
`)
require.Nil(t, err)

_, err = db.Exec(`
CREATE TABLE orders (
id VARCHAR(36) PRIMARY KEY,
product_id VARCHAR(36) NOT NULL,
id VARCHAR(36) PRIMARY KEY,
product_id VARCHAR(36) NOT NULL,
amount INT(32) UNSIGNED NOT NULL,
FOREIGN KEY (product_id) REFERENCES products(id)
);
Expand All @@ -481,16 +496,11 @@ func Test_UnitOfWork_Do_WhenTransactionFails(t *testing.T) {

err = uow.Do(ctx, func(ctx context.Context, tx TX) error {
// Get repository
repository, err := tx.Get("ProductRepository")
productRepository, err := GetAs[*ProductRepository](tx, "ProductRepository")
if err != nil {
return err
}

productRepository, ok := repository.(*ProductRepository)
if !ok {
return errors.New("invalid type")
}

// Save product
err = productRepository.Save(ctx, product)
return err
Expand All @@ -502,26 +512,16 @@ func Test_UnitOfWork_Do_WhenTransactionFails(t *testing.T) {

err = uow.Do(ctx, func(ctx context.Context, tx TX) error {
// Get repositories
repository, err := tx.Get("ProductRepository")
productRepository, err := GetAs[*ProductRepository](tx, "ProductRepository")
if err != nil {
return err
}

productRepository, ok := repository.(*ProductRepository)
if !ok {
return errors.New("invalid type")
}

repository, err = tx.Get("OrderRepository")
orderRepository, err := GetAs[*OrderRepository](tx, "OrderRepository")
if err != nil {
return err
}

orderRepository, ok := repository.(*OrderRepository)
if !ok {
return errors.New("invalid type")
}

// Get itens
productSaved, err := productRepository.Get(ctx, order.productId)
if err != nil {
Expand Down Expand Up @@ -552,26 +552,16 @@ func Test_UnitOfWork_Do_WhenTransactionFails(t *testing.T) {
// Verify amounts
err = uow.Do(ctx, func(ctx context.Context, tx TX) error {
// Get Repositories
repository, err := tx.Get("ProductRepository")
productRepository, err := GetAs[*ProductRepository](tx, "ProductRepository")
if err != nil {
return err
}

productRepository, ok := repository.(*ProductRepository)
if !ok {
return errors.New("invalid type")
}

repository, err = tx.Get("OrderRepository")
orderRepository, err := GetAs[*OrderRepository](tx, "OrderRepository")
if err != nil {
return err
}

orderRepository, ok := repository.(*OrderRepository)
if !ok {
return errors.New("invalid type")
}

// Get itens
productSaved, err := productRepository.Get(ctx, product.id)
if err != nil {
Expand Down

0 comments on commit 63212ac

Please sign in to comment.