Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions api/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
func TestGetAccountAPI(t *testing.T){
user, _ := randomUser(t)
account := randomAccount(user.Username)
role := util.DepositorRole

testCases := []struct{
name string
Expand All @@ -35,7 +36,7 @@ func TestGetAccountAPI(t *testing.T){
name: "OK",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
Expand All @@ -52,7 +53,7 @@ func TestGetAccountAPI(t *testing.T){
name: "UnauthorizedUser",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "unauthorized_user", time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "unauthorized_user", role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
Expand Down Expand Up @@ -82,7 +83,7 @@ func TestGetAccountAPI(t *testing.T){
name: "NotFound",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
Expand All @@ -98,7 +99,7 @@ func TestGetAccountAPI(t *testing.T){
name: "InternalError",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
Expand All @@ -114,7 +115,7 @@ func TestGetAccountAPI(t *testing.T){
name: "InvalidID",
accountID: 0,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
Expand Down Expand Up @@ -154,6 +155,7 @@ func TestGetAccountAPI(t *testing.T){
func TestCreateAccountAPI(t *testing.T) {
user, _ := randomUser(t)
account := randomAccount(user.Username)
role := util.DepositorRole

testCases := []struct {
name string
Expand All @@ -168,7 +170,7 @@ func TestCreateAccountAPI(t *testing.T) {
"currency": account.Currency,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
arg := db.CreateAccountParams{
Expand Down Expand Up @@ -209,7 +211,7 @@ func TestCreateAccountAPI(t *testing.T) {
"currency": account.Currency,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
Expand All @@ -227,7 +229,7 @@ func TestCreateAccountAPI(t *testing.T) {
"currency": "invalid",
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
Expand Down
13 changes: 8 additions & 5 deletions api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/Samudra-G/simplebank/token"
"github.com/Samudra-G/simplebank/util"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
Expand All @@ -18,16 +19,18 @@ func addAuthorization(
tokenMaker token.Maker,
authorizationType string,
username string,
role string,
duration time.Duration,
) {
token, payload, err := tokenMaker.CreateToken(username, duration)
token, payload, err := tokenMaker.CreateToken(username, role, duration)
require.NoError(t, err)
require.NotEmpty(t, payload)

authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token)
request.Header.Set(authorizationHeaderKey, authorizationHeader)
}
func TestAuthMiddleware(t *testing.T) {
role := util.DepositorRole
testCases := []struct{
name string
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
Expand All @@ -36,7 +39,7 @@ func TestAuthMiddleware(t *testing.T) {
{
name: "OK",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", role, time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder){
require.Equal(t, http.StatusOK, recorder.Code)
Expand All @@ -53,7 +56,7 @@ func TestAuthMiddleware(t *testing.T) {
{
name: "UnsupportedAuthorization",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, "unsupported", "user", time.Minute)
addAuthorization(t, request, tokenMaker, "unsupported", "user", role, time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
Expand All @@ -62,7 +65,7 @@ func TestAuthMiddleware(t *testing.T) {
{
name: "InvalidAuthorizationFormat",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, "", "user", time.Minute)
addAuthorization(t, request, tokenMaker, "", "user", role, time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
Expand All @@ -71,7 +74,7 @@ func TestAuthMiddleware(t *testing.T) {
{
name: "ExpiredToken",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", -time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", role, -time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
Expand Down
1 change: 1 addition & 0 deletions api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func (server *Server) renewAccessToken(ctx *gin.Context) {

accessToken, accessPayload, err := server.tokenMaker.CreateToken(
refreshPayload.Username,
refreshPayload.Role,
server.config.AccessTokenDuration,
)
if err != nil {
Expand Down
21 changes: 11 additions & 10 deletions api/transfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

func TestTransferAPI(t *testing.T) {
amount := int64(10)
role := util.DepositorRole

user1, _ := randomUser(t)
user2, _ := randomUser(t)
Expand Down Expand Up @@ -49,7 +50,7 @@ func TestTransferAPI(t *testing.T) {
"currency": util.USD,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil)
Expand All @@ -75,7 +76,7 @@ func TestTransferAPI(t *testing.T) {
"currency": util.USD,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user2.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user2.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil)
Expand Down Expand Up @@ -113,7 +114,7 @@ func TestTransferAPI(t *testing.T) {
"currency": util.USD,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(db.Account{}, sql.ErrNoRows)
Expand All @@ -133,7 +134,7 @@ func TestTransferAPI(t *testing.T) {
"currency": util.USD,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil)
Expand All @@ -153,7 +154,7 @@ func TestTransferAPI(t *testing.T) {
"currency": util.USD,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user3.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user3.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account3.ID)).Times(1).Return(account3, nil)
Expand All @@ -173,7 +174,7 @@ func TestTransferAPI(t *testing.T) {
"currency": util.USD,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil)
Expand All @@ -193,7 +194,7 @@ func TestTransferAPI(t *testing.T) {
"currency": "XYZ",
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0)
Expand All @@ -212,7 +213,7 @@ func TestTransferAPI(t *testing.T) {
"currency": util.USD,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0)
Expand All @@ -231,7 +232,7 @@ func TestTransferAPI(t *testing.T) {
"currency": util.USD,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(1).Return(db.Account{}, sql.ErrConnDone)
Expand All @@ -250,7 +251,7 @@ func TestTransferAPI(t *testing.T) {
"currency": util.USD,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute)
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil)
Expand Down
2 changes: 2 additions & 0 deletions api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func (server *Server) loginUser(ctx *gin.Context) {

accessToken, accessPayload, err := server.tokenMaker.CreateToken(
user.Username,
user.Role,
server.config.AccessTokenDuration,
)
if err != nil {
Expand All @@ -115,6 +116,7 @@ func (server *Server) loginUser(ctx *gin.Context) {

refreshToken, refreshPayload, err := server.tokenMaker.CreateToken(
user.Username,
user.Role,
server.config.RefreshTokenDuration,
)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions db/migration/000005_add_role_to_users.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "users" DROP COLUMN "role";
1 change: 1 addition & 0 deletions db/migration/000005_add_role_to_users.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "users" ADD COLUMN "role" varchar NOT NULL DEFAULT 'depositor';
1 change: 1 addition & 0 deletions db/sqlc/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions db/sqlc/user.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions doc/db.dbml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Project simple_bank {

Table users as U {
username varchar [pk]
role varchar [not null, default: 'depositor']
hashed_password varchar [not null]
full_name varchar [not null]
email varchar [unique, not null]
Expand Down
3 changes: 2 additions & 1 deletion doc/schema.sql
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
-- SQL dump generated using DBML (dbml.dbdiagram.io)
-- Database: PostgreSQL
-- Generated at: 2025-08-07T17:11:40.328Z
-- Generated at: 2025-08-28T16:39:27.644Z

CREATE TABLE "users" (
"username" varchar PRIMARY KEY,
"role" varchar NOT NULL DEFAULT 'depositor',
"hashed_password" varchar NOT NULL,
"full_name" varchar NOT NULL,
"email" varchar UNIQUE NOT NULL,
Expand Down
15 changes: 14 additions & 1 deletion gapi/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const (
authorizationBearer = "bearer"
)

func (server *Server) authorizeUser(ctx context.Context) (*token.Payload, error) {
func (server *Server) authorizeUser(ctx context.Context, accessibleRoles []string) (*token.Payload, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, fmt.Errorf("missing metadata")
Expand Down Expand Up @@ -42,5 +42,18 @@ func (server *Server) authorizeUser(ctx context.Context) (*token.Payload, error)
return nil, fmt.Errorf("invalid access token: %s", err)
}

if !hasPermission(payload.Role, accessibleRoles) {
return nil, fmt.Errorf("permission denied")
}

return payload, nil
}

func hasPermission(userRole string, accessibleRoles []string) bool {
for _, role := range accessibleRoles {
if userRole == role {
return true
}
}
return false
}
4 changes: 2 additions & 2 deletions gapi/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ func newTestServer(t *testing.T, store db.Store, taskDistributor worker.TaskDist
return server
}

func newContextWithBearerToken(t *testing.T, tokenMaker token.Maker, username string, duration time.Duration) context.Context{
accessToken, _, err := tokenMaker.CreateToken(username, duration)
func newContextWithBearerToken(t *testing.T, tokenMaker token.Maker, username string, role string, duration time.Duration) context.Context{
accessToken, _, err := tokenMaker.CreateToken(username, role, duration)
require.NoError(t, err)
bearerToken := fmt.Sprintf("%s %s", authorizationBearer, accessToken)
md := metadata.MD{
Expand Down
Loading
Loading