Skip to content

Commit

Permalink
feat: add enforcer
Browse files Browse the repository at this point in the history
  • Loading branch information
YenchangChan committed May 27, 2024
1 parent 55c1fb6 commit 0a92b89
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 14 deletions.
22 changes: 14 additions & 8 deletions cmd/password/password.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,30 @@ package main
import (
"crypto/md5"
"fmt"
"github.com/housepower/ckman/common"
"golang.org/x/term"
"os"
"path"
"syscall"
)

"github.com/housepower/ckman/common"
"golang.org/x/term"
)

func main() {
fmt.Printf("Initiating the setup of password for reserved user %s\n", common.DefaultUserName)
fmt.Println(`Password must be at least 8 characters long.
Password must contain at least three character categories among the following:
* Uppercase characters (A-Z)
* Lowercase characters (a-z)
* Digits (0-9)
* Special characters (~!@#$%^&*_-+=|\(){}[]:;"'<>,.?/)`)

fmt.Printf("\nEnter password for [%s]: ", common.DefaultUserName)
fmt.Printf("\nEnter username(ckman/guest):")
var username string
fmt.Scanf("%s", &username)
if !common.UsernameInvalid(username) {
fmt.Printf("invalid username, expect %s or %s\n", common.DefaultAdminName, common.DefaultGuestName)
return
}
fmt.Printf("\nEnter password for [%s]: ", username)
bytePassword, err := term.ReadPassword(int(syscall.Stdin))
if err != nil {
fmt.Printf("\nEnter password fail: %v\n", err)
Expand All @@ -33,7 +39,7 @@ Password must contain at least three character categories among the following:
return
}

fmt.Printf("\nReenter password for [%s]: ", common.DefaultUserName)
fmt.Printf("\nReenter password for [%s]: ", username)
dupPassword, err := term.ReadPassword(int(syscall.Stdin))
if err != nil {
fmt.Printf("\nReenter password fail: %v\n", err)
Expand All @@ -52,7 +58,7 @@ Password must contain at least three character categories among the following:
return
}

passwordFile := path.Join(common.GetWorkDirectory(), "conf/password")
passwordFile := path.Join(common.GetWorkDirectory(), path.Join("conf", common.PasswordFile[username]))
fileFd, err := os.OpenFile(passwordFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
if err != nil {
fmt.Printf("\nOpen password file %s fail: %v\n", passwordFile, err)
Expand All @@ -65,5 +71,5 @@ Password must contain at least three character categories among the following:
return
}

fmt.Printf("\nSet password for [%s] success\n", common.DefaultUserName)
fmt.Printf("\nSet password for [%s] success\n", username)
}
12 changes: 11 additions & 1 deletion common/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,20 @@ import (
)

var (
DefaultUserName = "ckman"
DefaultAdminName = "ckman"
DefaultGuestName = "guest"
DefaultSigningKey = "change me"
)

var PasswordFile = map[string]string{
DefaultAdminName: "password",
DefaultGuestName: "guestpassword",
}

func UsernameInvalid(username string) bool {
return ArraySearch(username, []string{DefaultAdminName, DefaultGuestName})
}

type JWT struct {
SigningKey []byte
}
Expand Down
2 changes: 1 addition & 1 deletion common/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func TestJwt(t *testing.T) {
StandardClaims: jwt.StandardClaims{
IssuedAt: 1136160245,
},
Name: DefaultUserName,
Name: DefaultAdminName,
ClientIP: "172.16.144.1",
}
token, err := j.CreateToken(claims)
Expand Down
6 changes: 3 additions & 3 deletions controller/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ func (controller *UserController) Login(c *gin.Context) {
return
}

if req.Username != common.DefaultUserName {
if !common.UsernameInvalid(req.Username) {
controller.wrapfunc(c, model.E_USER_VERIFY_FAIL, nil)
return
}

passwordFile := path.Join(filepath.Dir(controller.config.ConfigFile), "password")
passwordFile := path.Join(filepath.Dir(controller.config.ConfigFile), common.PasswordFile[req.Username])
data, err := os.ReadFile(passwordFile)
if err != nil {
controller.wrapfunc(c, model.E_GET_USER_PASSWORD_FAIL, err)
Expand All @@ -70,7 +70,7 @@ func (controller *UserController) Login(c *gin.Context) {
IssuedAt: time.Now().Unix(),
// ExpiresAt: time.Now().Add(time.Second * time.Duration(d.config.Server.SessionTimeout)).Unix(),
},
Name: common.DefaultUserName,
Name: req.Username,
ClientIP: c.ClientIP(),
}
token, err := j.CreateToken(claims)
Expand Down
104 changes: 104 additions & 0 deletions server/enforce/enforce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package enforce

import (
"strings"
)

const (
ADMIN string = "ckman"
GUEST string = "guest"

POST string = "POST"
GET string = "GET"
PUT string = "PUT"
DELETE string = "DELETE"
)

type Policy struct {
User string
URL string
Method string
}

type Model struct {
Admin string
UrlPrefix []string
}

type Enforcer struct {
model Model
policies []Policy
}

var DefaultModel = Model{
Admin: ADMIN,
UrlPrefix: []string{
"/api/v1", "/api/v2",
},
}
var e *Enforcer

func init() {
e = &Enforcer{
model: DefaultModel,
policies: []Policy{
{GUEST, "/ck/cluster", GET},
{GUEST, "/ck/cluster/*", GET},
{GUEST, "/ck/table/*", GET},
{GUEST, "/ck/table/group_uniq_array/*", GET},
{GUEST, "/ck/query/*", GET},
{GUEST, "/ck/query_explain/*", GET},
{GUEST, "/ck/query_history/*", GET},
{GUEST, "/ck/table_lists/*", GET},
{GUEST, "/ck/table_schema/*", GET},
{GUEST, "/ck/get/*", GET},
{GUEST, "/ck/partition/*", GET},
{GUEST, "/ck/table_metric/*", GET},
{GUEST, "/ck/table_merges/*", GET},
{GUEST, "/ck/open_sessions/*", GET},
{GUEST, "/ck/slow_sessions/*", GET},
{GUEST, "/ck/ddl_queue/*", GET},
{GUEST, "/ck/node/log/*", POST},
{GUEST, "/ck/ping/*", POST},
{GUEST, "/ck/config/*", GET},
{GUEST, "/zk/status/*", GET},
{GUEST, "/zk/replicated_table/*", GET},
{GUEST, "/package", GET},
{GUEST, "/metric/query/*", GET},
{GUEST, "/metric/query_range/*", GET},
{GUEST, "/version", GET},
{GUEST, "/ui/schema", GET},
{GUEST, "/task/*", GET},
{GUEST, "/task/lists", GET},
{GUEST, "/task/running", GET},
},
}
}

func (e *Enforcer) Match(url1, url2 string) bool {
for _, prefix := range e.model.UrlPrefix {
if !strings.HasPrefix(url2, prefix) {
return false
}
url22 := strings.TrimPrefix(url2, prefix)
url11 := strings.TrimSuffix(url1, "*")
if strings.HasPrefix(url22, url11) {
return true
}
}

return false
}

func Enforce(username, url, method string) bool {
if username == e.model.Admin {
return true
}

for _, policy := range e.policies {
if policy.User == username && e.Match(policy.URL, url) && policy.Method == method {
return true
}
}
return false
}
19 changes: 18 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/housepower/ckman/model"
"github.com/housepower/ckman/repository"
"github.com/housepower/ckman/router"
"github.com/housepower/ckman/server/enforce"
"github.com/housepower/ckman/service/prometheus"
ginSwagger "github.com/swaggo/gin-swagger"
"github.com/swaggo/gin-swagger/swaggerFiles"
Expand Down Expand Up @@ -99,6 +100,7 @@ func (server *ApiServer) Start() error {
// add authenticate middleware for /api
groupApi.Use(ginJWTAuth())
groupApi.Use(ginRefreshTokenExpires())
groupApi.Use(ginEnforce())
groupApi.PUT("/logout", userController.Logout)
groupV1 := groupApi.Group("/v1")
router.InitRouterV1(groupV1, server.config, server.signal)
Expand Down Expand Up @@ -206,6 +208,8 @@ func ginJWTAuth() gin.HandlerFunc {
c.Abort()
return
}
//c.Set("username", userToken.UserId)
c.Set("username", common.DefaultAdminName)
return
}

Expand Down Expand Up @@ -257,8 +261,8 @@ func ginJWTAuth() gin.HandlerFunc {
if clientIp == claims.ClientIP {
c.Set("claims", claims)
c.Set("token", token)
c.Set("username", claims.Name)
return

}
}
if claims.ClientIP != c.ClientIP() {
Expand All @@ -270,6 +274,7 @@ func ginJWTAuth() gin.HandlerFunc {

c.Set("claims", claims)
c.Set("token", token)
c.Set("username", claims.Name)
}
}

Expand Down Expand Up @@ -313,3 +318,15 @@ func PromHttpSD(c *gin.Context) {
objs := prometheus.GetObjects(clusters)
c.JSON(http.StatusOK, objs[schema])
}

func ginEnforce() gin.HandlerFunc {
return func(c *gin.Context) {
username := c.GetString("username")
ok := enforce.Enforce(username, c.Request.URL.RequestURI(), c.Request.Method)
if !ok {
err := fmt.Errorf("permission denied: username [%s]", username)
router.WrapMsg(c, model.E_USER_VERIFY_FAIL, err)
c.Abort()
}
}
}

0 comments on commit 0a92b89

Please sign in to comment.