From 241e8fc27fee8886594068b033374772423a94e7 Mon Sep 17 00:00:00 2001 From: Samudra-G Date: Thu, 28 Aug 2025 17:38:02 +0000 Subject: [PATCH] introduced rbac for auth routes --- api/account_test.go | 18 +++++++++------- api/middleware_test.go | 13 +++++++----- api/token.go | 1 + api/transfer_test.go | 21 ++++++++++--------- api/user.go | 2 ++ .../000005_add_role_to_users.down.sql | 1 + db/migration/000005_add_role_to_users.up.sql | 1 + db/sqlc/models.go | 1 + db/sqlc/user.sql.go | 9 +++++--- doc/db.dbml | 1 + doc/schema.sql | 3 ++- gapi/authorization.go | 15 ++++++++++++- gapi/main_test.go | 4 ++-- gapi/rpc_create_user_test.go | 1 + gapi/rpc_login_user.go | 2 ++ gapi/rpc_update_user.go | 4 ++-- gapi/rpc_update_user_test.go | 8 +++---- token/jwt_maker.go | 4 ++-- token/jwt_maker_test.go | 8 ++++--- token/maker.go | 2 +- token/paseto_maker.go | 4 ++-- token/paseto_maker_test.go | 6 ++++-- token/payload.go | 4 +++- util/role.go | 6 ++++++ 24 files changed, 92 insertions(+), 47 deletions(-) create mode 100644 db/migration/000005_add_role_to_users.down.sql create mode 100644 db/migration/000005_add_role_to_users.up.sql create mode 100644 util/role.go diff --git a/api/account_test.go b/api/account_test.go index 0ed266c..4b69f9d 100644 --- a/api/account_test.go +++ b/api/account_test.go @@ -23,6 +23,7 @@ import ( func TestGetAccountAPI(t *testing.T){ user, _ := randomUser(t) account := randomAccount(user.Username) + role := util.DepositorRole testCases := []struct{ name string @@ -35,7 +36,7 @@ func TestGetAccountAPI(t *testing.T){ name: "OK", accountID: account.ID, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). @@ -52,7 +53,7 @@ func TestGetAccountAPI(t *testing.T){ name: "UnauthorizedUser", accountID: account.ID, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "unauthorized_user", time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "unauthorized_user", role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). @@ -82,7 +83,7 @@ func TestGetAccountAPI(t *testing.T){ name: "NotFound", accountID: account.ID, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). @@ -98,7 +99,7 @@ func TestGetAccountAPI(t *testing.T){ name: "InternalError", accountID: account.ID, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). @@ -114,7 +115,7 @@ func TestGetAccountAPI(t *testing.T){ name: "InvalidID", accountID: 0, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). @@ -154,6 +155,7 @@ func TestGetAccountAPI(t *testing.T){ func TestCreateAccountAPI(t *testing.T) { user, _ := randomUser(t) account := randomAccount(user.Username) + role := util.DepositorRole testCases := []struct { name string @@ -168,7 +170,7 @@ func TestCreateAccountAPI(t *testing.T) { "currency": account.Currency, }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { arg := db.CreateAccountParams{ @@ -209,7 +211,7 @@ func TestCreateAccountAPI(t *testing.T) { "currency": account.Currency, }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). @@ -227,7 +229,7 @@ func TestCreateAccountAPI(t *testing.T) { "currency": "invalid", }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). diff --git a/api/middleware_test.go b/api/middleware_test.go index 63d3383..be2ab71 100644 --- a/api/middleware_test.go +++ b/api/middleware_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Samudra-G/simplebank/token" + "github.com/Samudra-G/simplebank/util" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -18,9 +19,10 @@ func addAuthorization( tokenMaker token.Maker, authorizationType string, username string, + role string, duration time.Duration, ) { - token, payload, err := tokenMaker.CreateToken(username, duration) + token, payload, err := tokenMaker.CreateToken(username, role, duration) require.NoError(t, err) require.NotEmpty(t, payload) @@ -28,6 +30,7 @@ func addAuthorization( request.Header.Set(authorizationHeaderKey, authorizationHeader) } func TestAuthMiddleware(t *testing.T) { + role := util.DepositorRole testCases := []struct{ name string setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) @@ -36,7 +39,7 @@ func TestAuthMiddleware(t *testing.T) { { name: "OK", setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", role, time.Minute) }, checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder){ require.Equal(t, http.StatusOK, recorder.Code) @@ -53,7 +56,7 @@ func TestAuthMiddleware(t *testing.T) { { name: "UnsupportedAuthorization", setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, "unsupported", "user", time.Minute) + addAuthorization(t, request, tokenMaker, "unsupported", "user", role, time.Minute) }, checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusUnauthorized, recorder.Code) @@ -62,7 +65,7 @@ func TestAuthMiddleware(t *testing.T) { { name: "InvalidAuthorizationFormat", setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, "", "user", time.Minute) + addAuthorization(t, request, tokenMaker, "", "user", role, time.Minute) }, checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusUnauthorized, recorder.Code) @@ -71,7 +74,7 @@ func TestAuthMiddleware(t *testing.T) { { name: "ExpiredToken", setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", -time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", role, -time.Minute) }, checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusUnauthorized, recorder.Code) diff --git a/api/token.go b/api/token.go index 716182c..3fda8ad 100644 --- a/api/token.go +++ b/api/token.go @@ -66,6 +66,7 @@ func (server *Server) renewAccessToken(ctx *gin.Context) { accessToken, accessPayload, err := server.tokenMaker.CreateToken( refreshPayload.Username, + refreshPayload.Role, server.config.AccessTokenDuration, ) if err != nil { diff --git a/api/transfer_test.go b/api/transfer_test.go index e9c4029..08b15f8 100644 --- a/api/transfer_test.go +++ b/api/transfer_test.go @@ -20,6 +20,7 @@ import ( func TestTransferAPI(t *testing.T) { amount := int64(10) + role := util.DepositorRole user1, _ := randomUser(t) user2, _ := randomUser(t) @@ -49,7 +50,7 @@ func TestTransferAPI(t *testing.T) { "currency": util.USD, }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil) @@ -75,7 +76,7 @@ func TestTransferAPI(t *testing.T) { "currency": util.USD, }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user2.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user2.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil) @@ -113,7 +114,7 @@ func TestTransferAPI(t *testing.T) { "currency": util.USD, }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(db.Account{}, sql.ErrNoRows) @@ -133,7 +134,7 @@ func TestTransferAPI(t *testing.T) { "currency": util.USD, }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil) @@ -153,7 +154,7 @@ func TestTransferAPI(t *testing.T) { "currency": util.USD, }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user3.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user3.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account3.ID)).Times(1).Return(account3, nil) @@ -173,7 +174,7 @@ func TestTransferAPI(t *testing.T) { "currency": util.USD, }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil) @@ -193,7 +194,7 @@ func TestTransferAPI(t *testing.T) { "currency": "XYZ", }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0) @@ -212,7 +213,7 @@ func TestTransferAPI(t *testing.T) { "currency": util.USD, }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0) @@ -231,7 +232,7 @@ func TestTransferAPI(t *testing.T) { "currency": util.USD, }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(1).Return(db.Account{}, sql.ErrConnDone) @@ -250,7 +251,7 @@ func TestTransferAPI(t *testing.T) { "currency": util.USD, }, setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, role, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil) diff --git a/api/user.go b/api/user.go index e1ccd6e..5f4ca06 100644 --- a/api/user.go +++ b/api/user.go @@ -106,6 +106,7 @@ func (server *Server) loginUser(ctx *gin.Context) { accessToken, accessPayload, err := server.tokenMaker.CreateToken( user.Username, + user.Role, server.config.AccessTokenDuration, ) if err != nil { @@ -115,6 +116,7 @@ func (server *Server) loginUser(ctx *gin.Context) { refreshToken, refreshPayload, err := server.tokenMaker.CreateToken( user.Username, + user.Role, server.config.RefreshTokenDuration, ) if err != nil { diff --git a/db/migration/000005_add_role_to_users.down.sql b/db/migration/000005_add_role_to_users.down.sql new file mode 100644 index 0000000..20f88f8 --- /dev/null +++ b/db/migration/000005_add_role_to_users.down.sql @@ -0,0 +1 @@ +ALTER TABLE "users" DROP COLUMN "role"; \ No newline at end of file diff --git a/db/migration/000005_add_role_to_users.up.sql b/db/migration/000005_add_role_to_users.up.sql new file mode 100644 index 0000000..d93b626 --- /dev/null +++ b/db/migration/000005_add_role_to_users.up.sql @@ -0,0 +1 @@ +ALTER TABLE "users" ADD COLUMN "role" varchar NOT NULL DEFAULT 'depositor'; \ No newline at end of file diff --git a/db/sqlc/models.go b/db/sqlc/models.go index e9e2f7c..0c838f2 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -54,6 +54,7 @@ type User struct { PasswordChangedAt time.Time `json:"password_changed_at"` CreatedAt time.Time `json:"created_at"` IsEmailVerified bool `json:"is_email_verified"` + Role string `json:"role"` } type VerifyEmail struct { diff --git a/db/sqlc/user.sql.go b/db/sqlc/user.sql.go index 4180a9c..43c1951 100644 --- a/db/sqlc/user.sql.go +++ b/db/sqlc/user.sql.go @@ -20,7 +20,7 @@ INSERT INTO users ( ) VALUES ( $1, $2, $3, $4 ) -RETURNING username, hashed_password, full_name, email, password_changed_at, created_at, is_email_verified +RETURNING username, hashed_password, full_name, email, password_changed_at, created_at, is_email_verified, role ` type CreateUserParams struct { @@ -46,12 +46,13 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e &i.PasswordChangedAt, &i.CreatedAt, &i.IsEmailVerified, + &i.Role, ) return i, err } const getUser = `-- name: GetUser :one -SELECT username, hashed_password, full_name, email, password_changed_at, created_at, is_email_verified FROM users +SELECT username, hashed_password, full_name, email, password_changed_at, created_at, is_email_verified, role FROM users WHERE username = $1 LIMIT 1 ` @@ -66,6 +67,7 @@ func (q *Queries) GetUser(ctx context.Context, username string) (User, error) { &i.PasswordChangedAt, &i.CreatedAt, &i.IsEmailVerified, + &i.Role, ) return i, err } @@ -80,7 +82,7 @@ SET is_email_verified = COALESCE($5, is_email_verified) WHERE username = $6 -RETURNING username, hashed_password, full_name, email, password_changed_at, created_at, is_email_verified +RETURNING username, hashed_password, full_name, email, password_changed_at, created_at, is_email_verified, role ` type UpdateUserParams struct { @@ -110,6 +112,7 @@ func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, e &i.PasswordChangedAt, &i.CreatedAt, &i.IsEmailVerified, + &i.Role, ) return i, err } diff --git a/doc/db.dbml b/doc/db.dbml index 10603f4..c7fa377 100644 --- a/doc/db.dbml +++ b/doc/db.dbml @@ -7,6 +7,7 @@ Project simple_bank { Table users as U { username varchar [pk] + role varchar [not null, default: 'depositor'] hashed_password varchar [not null] full_name varchar [not null] email varchar [unique, not null] diff --git a/doc/schema.sql b/doc/schema.sql index dfdad74..fdbcc47 100644 --- a/doc/schema.sql +++ b/doc/schema.sql @@ -1,9 +1,10 @@ -- SQL dump generated using DBML (dbml.dbdiagram.io) -- Database: PostgreSQL --- Generated at: 2025-08-07T17:11:40.328Z +-- Generated at: 2025-08-28T16:39:27.644Z CREATE TABLE "users" ( "username" varchar PRIMARY KEY, + "role" varchar NOT NULL DEFAULT 'depositor', "hashed_password" varchar NOT NULL, "full_name" varchar NOT NULL, "email" varchar UNIQUE NOT NULL, diff --git a/gapi/authorization.go b/gapi/authorization.go index d82444b..f31bc79 100644 --- a/gapi/authorization.go +++ b/gapi/authorization.go @@ -14,7 +14,7 @@ const ( authorizationBearer = "bearer" ) -func (server *Server) authorizeUser(ctx context.Context) (*token.Payload, error) { +func (server *Server) authorizeUser(ctx context.Context, accessibleRoles []string) (*token.Payload, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { return nil, fmt.Errorf("missing metadata") @@ -42,5 +42,18 @@ func (server *Server) authorizeUser(ctx context.Context) (*token.Payload, error) return nil, fmt.Errorf("invalid access token: %s", err) } + if !hasPermission(payload.Role, accessibleRoles) { + return nil, fmt.Errorf("permission denied") + } + return payload, nil +} + +func hasPermission(userRole string, accessibleRoles []string) bool { + for _, role := range accessibleRoles { + if userRole == role { + return true + } + } + return false } \ No newline at end of file diff --git a/gapi/main_test.go b/gapi/main_test.go index 0a8cb2b..7ec9233 100644 --- a/gapi/main_test.go +++ b/gapi/main_test.go @@ -26,8 +26,8 @@ func newTestServer(t *testing.T, store db.Store, taskDistributor worker.TaskDist return server } -func newContextWithBearerToken(t *testing.T, tokenMaker token.Maker, username string, duration time.Duration) context.Context{ - accessToken, _, err := tokenMaker.CreateToken(username, duration) +func newContextWithBearerToken(t *testing.T, tokenMaker token.Maker, username string, role string, duration time.Duration) context.Context{ + accessToken, _, err := tokenMaker.CreateToken(username, role, duration) require.NoError(t, err) bearerToken := fmt.Sprintf("%s %s", authorizationBearer, accessToken) md := metadata.MD{ diff --git a/gapi/rpc_create_user_test.go b/gapi/rpc_create_user_test.go index b07be71..8595813 100644 --- a/gapi/rpc_create_user_test.go +++ b/gapi/rpc_create_user_test.go @@ -70,6 +70,7 @@ func randomUser(t *testing.T) (db.User, string) { user := db.User{ Username: util.RandomString(6), FullName: util.RandomOwner(), + Role: util.DepositorRole, Email: util.RandomEmail(), HashedPassword: hashedPassword, PasswordChangedAt: time.Now(), diff --git a/gapi/rpc_login_user.go b/gapi/rpc_login_user.go index d189d5a..5f8e465 100644 --- a/gapi/rpc_login_user.go +++ b/gapi/rpc_login_user.go @@ -35,6 +35,7 @@ func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) ( accessToken, accessPayload, err := server.tokenMaker.CreateToken( user.Username, + user.Role, server.config.AccessTokenDuration, ) if err != nil { @@ -43,6 +44,7 @@ func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) ( refreshToken, refreshPayload, err := server.tokenMaker.CreateToken( user.Username, + user.Role, server.config.RefreshTokenDuration, ) if err != nil { diff --git a/gapi/rpc_update_user.go b/gapi/rpc_update_user.go index 210c375..2eb0bb1 100644 --- a/gapi/rpc_update_user.go +++ b/gapi/rpc_update_user.go @@ -16,7 +16,7 @@ import ( ) func (server *Server)UpdateUser(ctx context.Context, req *pb.UpdateUserRequest) (*pb.UpdateUserResponse, error) { - authPayload, err := server.authorizeUser(ctx) + authPayload, err := server.authorizeUser(ctx, []string{util.BankerRole, util.DepositorRole}) if err != nil { return nil, unautheticatedError(err) } @@ -26,7 +26,7 @@ func (server *Server)UpdateUser(ctx context.Context, req *pb.UpdateUserRequest) return nil, invalidArgumentError(violations) } - if authPayload.Username != req.Username { + if authPayload.Role != util.BankerRole && authPayload.Username != req.Username { return nil, status.Errorf(codes.PermissionDenied, "cannot update other user's info") } diff --git a/gapi/rpc_update_user_test.go b/gapi/rpc_update_user_test.go index bbec706..43a2ac5 100644 --- a/gapi/rpc_update_user_test.go +++ b/gapi/rpc_update_user_test.go @@ -46,7 +46,7 @@ func TestUpdateUserAPI(t *testing.T) { Return(db.User{}, sql.ErrNoRows) }, buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context{ - return newContextWithBearerToken(t, tokenMaker, user.Username, time.Minute) + return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute) }, checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) { require.Error(t, err) @@ -90,7 +90,7 @@ func TestUpdateUserAPI(t *testing.T) { Return(updatedUser, nil) }, buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context{ - return newContextWithBearerToken(t, tokenMaker, user.Username, time.Minute) + return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute) }, checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) { require.NoError(t, err) @@ -114,7 +114,7 @@ func TestUpdateUserAPI(t *testing.T) { Times(0) }, buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context { - return newContextWithBearerToken(t, tokenMaker, user.Username, -time.Minute) + return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, -time.Minute) }, checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) { require.Error(t, err) @@ -158,7 +158,7 @@ func TestUpdateUserAPI(t *testing.T) { Times(0) }, buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context { - return newContextWithBearerToken(t, tokenMaker, user.Username, time.Minute) + return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute) }, checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) { require.Error(t, err) diff --git a/token/jwt_maker.go b/token/jwt_maker.go index 5a43ac1..e14f172 100644 --- a/token/jwt_maker.go +++ b/token/jwt_maker.go @@ -23,8 +23,8 @@ func NewJWTMaker(secretKey string) (Maker, error) { } //CreateToken creates a new token for a specific username and duration -func (maker *JWTMaker) CreateToken(username string, duration time.Duration) (string, *Payload, error) { - payload, err := NewPayload(username, duration) +func (maker *JWTMaker) CreateToken(username string, role string, duration time.Duration) (string, *Payload, error) { + payload, err := NewPayload(username, role, duration) if err != nil { return "", payload, err } diff --git a/token/jwt_maker_test.go b/token/jwt_maker_test.go index 2f84128..281d4d4 100644 --- a/token/jwt_maker_test.go +++ b/token/jwt_maker_test.go @@ -14,12 +14,13 @@ func TestJWTMaker(t *testing.T) { require.NoError(t, err) username := util.RandomOwner() + role := util.DepositorRole duration := time.Minute issuedAt := time.Now() expiredAt := issuedAt.Add(duration) - token, payload, err := maker.CreateToken(username, duration) + token, payload, err := maker.CreateToken(username, role, duration) require.NoError(t, err) require.NotEmpty(t, token) require.NotEmpty(t, payload) @@ -30,6 +31,7 @@ func TestJWTMaker(t *testing.T) { require.NotZero(t, payload.ID) require.Equal(t, username, payload.Username) + require.Equal(t, role, payload.Role) require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second) require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second) } @@ -38,7 +40,7 @@ func TestExpiredJWTToken(t *testing.T) { maker, err := NewJWTMaker(util.RandomString(32)) require.NoError(t, err) - token, payload, err := maker.CreateToken(util.RandomOwner(), -time.Minute) + token, payload, err := maker.CreateToken(util.RandomOwner(), util.DepositorRole, -time.Minute) require.NoError(t, err) require.NotEmpty(t, token) require.NotEmpty(t, payload) @@ -50,7 +52,7 @@ func TestExpiredJWTToken(t *testing.T) { } func TestInvalidJWTTokenAlgNone(t *testing.T) { - payload, err := NewPayload(util.RandomOwner(), time.Minute) + payload, err := NewPayload(util.RandomOwner(), util.DepositorRole, time.Minute) require.NoError(t, err) jwtToken := jwt.NewWithClaims(jwt.SigningMethodNone, payload) diff --git a/token/maker.go b/token/maker.go index 0bedb7f..45d270b 100644 --- a/token/maker.go +++ b/token/maker.go @@ -5,7 +5,7 @@ import "time" //Maker is an interface for managing tokens type Maker interface { //CreateToken creates a new token for a specific username and duration - CreateToken(username string, duration time.Duration) (string, *Payload, error) + CreateToken(username string, role string, duration time.Duration) (string, *Payload, error) //VerifyToken verifies if a token is valid or not VerifyToken(token string) (*Payload, error) diff --git a/token/paseto_maker.go b/token/paseto_maker.go index 41ddd09..3d852ac 100644 --- a/token/paseto_maker.go +++ b/token/paseto_maker.go @@ -28,8 +28,8 @@ func NewPasetoMaker (symmetricKey string) (Maker, error) { } //CreateToken creates a new token for a specific username and duration -func(maker *PasetoMaker) CreateToken(username string, duration time.Duration) (string, *Payload, error) { - payload, err := NewPayload(username, duration) +func(maker *PasetoMaker) CreateToken(username string, role string, duration time.Duration) (string, *Payload, error) { + payload, err := NewPayload(username, role, duration) if err != nil { return "", payload, err } diff --git a/token/paseto_maker_test.go b/token/paseto_maker_test.go index 256e28c..af684a6 100644 --- a/token/paseto_maker_test.go +++ b/token/paseto_maker_test.go @@ -13,12 +13,13 @@ func TestPasetoMaker(t *testing.T) { require.NoError(t, err) username := util.RandomOwner() + role := util.DepositorRole duration := time.Minute issuedAt := time.Now() expiredAt := issuedAt.Add(duration) - token, payload, err := maker.CreateToken(username, duration) + token, payload, err := maker.CreateToken(username, role, duration) require.NoError(t, err) require.NotEmpty(t, token) require.NotEmpty(t, payload) @@ -29,6 +30,7 @@ func TestPasetoMaker(t *testing.T) { require.NotZero(t, payload.ID) require.Equal(t, username, payload.Username) + require.Equal(t, role, payload.Role) require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second) require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second) } @@ -37,7 +39,7 @@ func TestExpiredPasetoToken(t *testing.T) { maker, err := NewPasetoMaker(util.RandomString(32)) require.NoError(t, err) - token, payload, err := maker.CreateToken(util.RandomOwner(), -time.Minute) + token, payload, err := maker.CreateToken(util.RandomOwner(), util.DepositorRole, -time.Minute) require.NoError(t, err) require.NotEmpty(t, token) require.NotEmpty(t, payload) diff --git a/token/payload.go b/token/payload.go index a97c5be..c2f77d0 100644 --- a/token/payload.go +++ b/token/payload.go @@ -16,12 +16,13 @@ var( type Payload struct { ID uuid.UUID `json:"id"` Username string `json:"username"` + Role string `json:"role"` IssuedAt time.Time `json:"issued_at"` ExpiredAt time.Time `json:"expired_at"` } //NewPayload creates a new token payload with a specific username and duration -func NewPayload(username string, duration time.Duration) (*Payload, error) { +func NewPayload(username string, role string, duration time.Duration) (*Payload, error) { tokenID, err := uuid.NewRandom() if err != nil { return nil, err @@ -30,6 +31,7 @@ func NewPayload(username string, duration time.Duration) (*Payload, error) { payload := &Payload{ ID: tokenID, Username: username, + Role: role, IssuedAt: time.Now(), ExpiredAt: time.Now().Add(duration), } diff --git a/util/role.go b/util/role.go new file mode 100644 index 0000000..d8ee7d7 --- /dev/null +++ b/util/role.go @@ -0,0 +1,6 @@ +package util + +const ( + DepositorRole = "depositor" + BankerRole = "banker" +) \ No newline at end of file