diff --git a/router/middleware.go b/router/middleware.go index 5ac6fef..21f5e27 100644 --- a/router/middleware.go +++ b/router/middleware.go @@ -1,8 +1,12 @@ package router import ( + "bytes" + "encoding/json" "errors" + "io/ioutil" "net/http" + "strconv" "github.com/labstack/echo" @@ -36,17 +40,39 @@ func MiddlewareAdmin(next echo.HandlerFunc) echo.HandlerFunc { } } -// MiddlewareItemSocial ItemがPersonalItemでない場合はAdmin以外を弾くmiddleware -func MiddlewareItemSocial(getItem func(c echo.Context) model.Item) func(next echo.HandlerFunc) echo.HandlerFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - item := getItem(c) - user := c.Get("user").(model.User) - if item.Type != model.PersonalItem && !user.Admin { - return c.NoContent(http.StatusForbidden) - } +// MiddlewareBodyItemSocial リクエストボディから取得したItemがPersonalItemでない場合はAdmin以外を弾くmiddleware +func MiddlewareBodyItemSocial(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + body, err := ioutil.ReadAll(c.Request().Body) + if err != nil { + return next(c) + } + c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(body)) + item := model.Item{} + if err = json.Unmarshal(body, &item); err != nil { + return next(c) + } + user := c.Get("user").(model.User) + if item.Type != model.PersonalItem && !user.Admin { + return c.NoContent(http.StatusForbidden) + } + return next(c) + } +} + +// MiddlewareParamItemSocial パラメータから取得したItemがPersonalItemでない場合はAdmin以外を弾くmiddleware +func MiddlewareParamItemSocial(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + itemID, err := strconv.Atoi(c.Param("id")) + if err != nil { return next(c) } + item, _ := model.GetItemByID(uint(itemID)) + user := c.Get("user").(model.User) + if item.Type != model.PersonalItem && !user.Admin { + return c.NoContent(http.StatusForbidden) + } + return next(c) } } diff --git a/router/router.go b/router/router.go index 071c689..c6cb375 100644 --- a/router/router.go +++ b/router/router.go @@ -1,14 +1,9 @@ package router import ( - "bytes" - "encoding/json" - "io/ioutil" "net/http" - "strconv" "github.com/labstack/echo/middleware" - "github.com/traPtitech/booQ/model" "github.com/labstack/echo" ) @@ -31,37 +26,12 @@ func SetupRouting(e *echo.Echo, client *UserProvider) { apiItems := api.Group("/items") { apiItems.GET("", GetItems) - apiItems.POST("", PostItems, MiddlewareItemSocial(func(c echo.Context) model.Item { - item := model.Item{} - body, err := ioutil.ReadAll(c.Request().Body) - if err != nil { - return model.Item{} - } - c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(body)) - if err = json.Unmarshal(body, &item); err != nil { - return model.Item{} - } - return item - })) + apiItems.POST("", PostItems, MiddlewareBodyItemSocial) apiItems.GET("/:id", GetItem) apiItems.PUT("/:id", PutItem) apiItems.DELETE("/:id", DeleteItem, MiddlewareAdmin) - apiItems.POST("/:id/owners", PostOwners, MiddlewareItemSocial(func(c echo.Context) model.Item { - itemID, err := strconv.Atoi(c.Param("id")) - if err != nil { - return model.Item{} - } - item, _ := model.GetItemByID(uint(itemID)) - return item - })) - apiItems.PUT("/:id/owners", PutOwners, MiddlewareItemSocial(func(c echo.Context) model.Item { - itemID, err := strconv.Atoi(c.Param("id")) - if err != nil { - return model.Item{} - } - item, _ := model.GetItemByID(uint(itemID)) - return item - })) + apiItems.POST("/:id/owners", PostOwners, MiddlewareParamItemSocial) + apiItems.PUT("/:id/owners", PutOwners, MiddlewareParamItemSocial) apiItems.POST("/:id/logs", PostLogs) apiItems.POST("/:id/comments", PostComments) apiItems.POST("/:id/likes", PostLikes)