Skip to content

Commit c0b7c40

Browse files
authored
chore: move auth logic into middleware && naming the files better (#115)
* chore: move auth logic into middleware * refactor: move the project ownership into central place * refactor: give files better names
1 parent 7d27593 commit c0b7c40

File tree

7 files changed

+132
-90
lines changed

7 files changed

+132
-90
lines changed

shibuya/api/collection.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func getCollection(collectionID string) (*model.Collection, error) {
2525
}
2626

2727
func (s *ShibuyaAPI) collectionConfigGetHandler(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
28-
collection, err := checkCollectionOwnership(req, params)
28+
collection, err := hasCollectionOwnership(req, params)
2929
if err != nil {
3030
s.handleErrors(w, err)
3131
return

shibuya/api/errors.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,11 @@ func makeInternalServerError(message string) error {
3232
func makeInvalidResourceError(resource string) error {
3333
return fmt.Errorf("%winvalid %s", invalidRequestErr, resource)
3434
}
35+
36+
func makeProjectOwnershipError() error {
37+
return fmt.Errorf("%w%s", noPermissionErr, "You don't own the project")
38+
}
39+
40+
func makeCollectionOwnershipError() error {
41+
return fmt.Errorf("%w%s", noPermissionErr, "You don't own the collection")
42+
}

shibuya/api/main.go

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"io"
99
"net/http"
1010
"strconv"
11+
"strings"
1112
"time"
1213

1314
"github.com/julienschmidt/httprouter"
@@ -89,11 +90,7 @@ func (s *ShibuyaAPI) handleErrors(w http.ResponseWriter, err error) {
8990
}
9091

9192
func (s *ShibuyaAPI) projectsGetHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
92-
account := model.GetAccountBySession(r)
93-
if account == nil {
94-
s.makeFailMessage(w, "Need to login", http.StatusForbidden)
95-
return
96-
}
93+
account := r.Context().Value(accountKey).(*model.Account)
9794
qs := r.URL.Query()
9895
var includeCollections, includePlans bool
9996
var err error
@@ -145,11 +142,7 @@ func (s *ShibuyaAPI) projectUpdateHandler(w http.ResponseWriter, _ *http.Request
145142
}
146143

147144
func (s *ShibuyaAPI) projectCreateHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
148-
account := model.GetAccountBySession(r)
149-
if account == nil {
150-
s.handleErrors(w, makeLoginError())
151-
return
152-
}
145+
account := r.Context().Value(accountKey).(*model.Account)
153146
r.ParseForm()
154147
name := r.Form.Get("name")
155148
if name == "" {
@@ -191,18 +184,14 @@ func (s *ShibuyaAPI) projectCreateHandler(w http.ResponseWriter, r *http.Request
191184
}
192185

193186
func (s *ShibuyaAPI) projectDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
194-
account := model.GetAccountBySession(r)
195-
if account == nil {
196-
s.handleErrors(w, makeLoginError())
197-
return
198-
}
187+
account := r.Context().Value(accountKey).(*model.Account)
199188
project, err := getProject(params.ByName("project_id"))
200189
if err != nil {
201190
s.handleErrors(w, err)
202191
return
203192
}
204-
if _, ok := account.MLMap[project.Owner]; !ok {
205-
s.handleErrors(w, noPermissionErr)
193+
if r := hasProjectOwnership(project, account); !r {
194+
s.handleErrors(w, makeProjectOwnershipError())
206195
return
207196
}
208197
collectionIDs, err := project.GetCollections()
@@ -260,20 +249,16 @@ func (s *ShibuyaAPI) collectionAdminGetHandler(w http.ResponseWriter, r *http.Re
260249
}
261250

262251
func (s *ShibuyaAPI) planCreateHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
263-
account := model.GetAccountBySession(r)
264-
if account == nil {
265-
s.handleErrors(w, makeLoginError())
266-
return
267-
}
252+
account := r.Context().Value(accountKey).(*model.Account)
268253
r.ParseForm()
269254
projectID := r.Form.Get("project_id")
270255
project, err := getProject(projectID)
271256
if err != nil {
272257
s.handleErrors(w, err)
273258
return
274259
}
275-
if _, ok := account.MLMap[project.Owner]; !ok {
276-
s.handleErrors(w, makeNoPermissionErr("You don't own the project"))
260+
if r := hasProjectOwnership(project, account); !r {
261+
s.handleErrors(w, makeProjectOwnershipError())
277262
return
278263
}
279264
name := r.Form.Get("name")
@@ -294,11 +279,7 @@ func (s *ShibuyaAPI) planCreateHandler(w http.ResponseWriter, r *http.Request, _
294279
}
295280

296281
func (s *ShibuyaAPI) planDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
297-
account := model.GetAccountBySession(r)
298-
if account == nil {
299-
s.handleErrors(w, makeLoginError())
300-
return
301-
}
282+
account := r.Context().Value(accountKey).(*model.Account)
302283
plan, err := getPlan(params.ByName("plan_id"))
303284
if err != nil {
304285
s.handleErrors(w, err)
@@ -309,8 +290,8 @@ func (s *ShibuyaAPI) planDeleteHandler(w http.ResponseWriter, r *http.Request, p
309290
s.handleErrors(w, err)
310291
return
311292
}
312-
if _, ok := account.MLMap[project.Owner]; !ok {
313-
s.handleErrors(w, makeLoginError())
293+
if r := hasProjectOwnership(project, account); !r {
294+
s.handleErrors(w, makeProjectOwnershipError())
314295
return
315296
}
316297
using, err := plan.IsBeingUsed()
@@ -355,7 +336,7 @@ func (s *ShibuyaAPI) collectionFilesGetHandler(w http.ResponseWriter, _ *http.Re
355336
}
356337

357338
func (s *ShibuyaAPI) collectionFilesUploadHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
358-
collection, err := checkCollectionOwnership(r, params)
339+
collection, err := hasCollectionOwnership(r, params)
359340
if err != nil {
360341
s.handleErrors(w, err)
361342
return
@@ -375,7 +356,7 @@ func (s *ShibuyaAPI) collectionFilesUploadHandler(w http.ResponseWriter, r *http
375356
}
376357

377358
func (s *ShibuyaAPI) collectionFilesDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
378-
collection, err := checkCollectionOwnership(r, params)
359+
collection, err := hasCollectionOwnership(r, params)
379360
if err != nil {
380361
s.handleErrors(w, err)
381362
return
@@ -415,11 +396,7 @@ func (s *ShibuyaAPI) planFilesDeleteHandler(w http.ResponseWriter, r *http.Reque
415396
}
416397

417398
func (s *ShibuyaAPI) collectionCreateHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
418-
account := model.GetAccountBySession(r)
419-
if account == nil {
420-
s.handleErrors(w, makeLoginError())
421-
return
422-
}
399+
account := r.Context().Value(accountKey).(*model.Account)
423400
r.ParseForm()
424401
collectionName := r.Form.Get("name")
425402
if collectionName == "" {
@@ -432,8 +409,8 @@ func (s *ShibuyaAPI) collectionCreateHandler(w http.ResponseWriter, r *http.Requ
432409
s.handleErrors(w, err)
433410
return
434411
}
435-
if _, ok := account.MLMap[project.Owner]; !ok {
436-
s.handleErrors(w, makeNoPermissionErr("You don't have the permission"))
412+
if r := hasProjectOwnership(project, account); !r {
413+
s.handleErrors(w, makeProjectOwnershipError())
437414
return
438415
}
439416
collectionID, err := model.CreateCollection(collectionName, project.ID)
@@ -450,7 +427,7 @@ func (s *ShibuyaAPI) collectionCreateHandler(w http.ResponseWriter, r *http.Requ
450427
}
451428

452429
func (s *ShibuyaAPI) collectionDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
453-
collection, err := checkCollectionOwnership(r, params)
430+
collection, err := hasCollectionOwnership(r, params)
454431
if err != nil {
455432
s.handleErrors(w, err)
456433
return
@@ -480,7 +457,7 @@ func (s *ShibuyaAPI) collectionDeleteHandler(w http.ResponseWriter, r *http.Requ
480457
}
481458

482459
func (s *ShibuyaAPI) collectionGetHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
483-
collection, err := checkCollectionOwnership(r, params)
460+
collection, err := hasCollectionOwnership(r, params)
484461
if err != nil {
485462
s.handleErrors(w, err)
486463
return
@@ -519,7 +496,7 @@ func (s *ShibuyaAPI) collectionUpdateHandler(w http.ResponseWriter, _ *http.Requ
519496
}
520497

521498
func (s *ShibuyaAPI) collectionUploadHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
522-
collection, err := checkCollectionOwnership(r, params)
499+
collection, err := hasCollectionOwnership(r, params)
523500
if err != nil {
524501
s.handleErrors(w, err)
525502
return
@@ -613,7 +590,7 @@ func (s *ShibuyaAPI) collectionUploadHandler(w http.ResponseWriter, r *http.Requ
613590
}
614591

615592
func (s *ShibuyaAPI) collectionEnginesDetailHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
616-
collection, err := checkCollectionOwnership(r, params)
593+
collection, err := hasCollectionOwnership(r, params)
617594
if err != nil {
618595
s.handleErrors(w, err)
619596
return
@@ -627,7 +604,7 @@ func (s *ShibuyaAPI) collectionEnginesDetailHandler(w http.ResponseWriter, r *ht
627604
}
628605

629606
func (s *ShibuyaAPI) collectionDeploymentHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
630-
collection, err := checkCollectionOwnership(r, params)
607+
collection, err := hasCollectionOwnership(r, params)
631608
if err != nil {
632609
s.handleErrors(w, err)
633610
return
@@ -644,7 +621,7 @@ func (s *ShibuyaAPI) collectionDeploymentHandler(w http.ResponseWriter, r *http.
644621
}
645622

646623
func (s *ShibuyaAPI) collectionTriggerHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
647-
collection, err := checkCollectionOwnership(r, params)
624+
collection, err := hasCollectionOwnership(r, params)
648625
if err != nil {
649626
s.handleErrors(w, err)
650627
return
@@ -656,7 +633,7 @@ func (s *ShibuyaAPI) collectionTriggerHandler(w http.ResponseWriter, r *http.Req
656633
}
657634

658635
func (s *ShibuyaAPI) collectionTermHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
659-
collection, err := checkCollectionOwnership(r, params)
636+
collection, err := hasCollectionOwnership(r, params)
660637
if err != nil {
661638
s.handleErrors(w, err)
662639
return
@@ -668,7 +645,7 @@ func (s *ShibuyaAPI) collectionTermHandler(w http.ResponseWriter, r *http.Reques
668645
}
669646

670647
func (s *ShibuyaAPI) collectionStatusHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
671-
collection, err := checkCollectionOwnership(r, params)
648+
collection, err := hasCollectionOwnership(r, params)
672649
if err != nil {
673650
s.handleErrors(w, err)
674651
return
@@ -681,7 +658,7 @@ func (s *ShibuyaAPI) collectionStatusHandler(w http.ResponseWriter, r *http.Requ
681658
}
682659

683660
func (s *ShibuyaAPI) collectionPurgeHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
684-
collection, err := checkCollectionOwnership(r, params)
661+
collection, err := hasCollectionOwnership(r, params)
685662
if err != nil {
686663
s.handleErrors(w, err)
687664
return
@@ -714,7 +691,7 @@ func (s *ShibuyaAPI) planLogHandler(w http.ResponseWriter, r *http.Request, para
714691
}
715692

716693
func (s *ShibuyaAPI) streamCollectionMetrics(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
717-
collection, err := checkCollectionOwnership(r, params)
694+
collection, err := hasCollectionOwnership(r, params)
718695
if err != nil {
719696
s.handleErrors(w, err)
720697
return
@@ -789,7 +766,7 @@ type Route struct {
789766
type Routes []*Route
790767

791768
func (s *ShibuyaAPI) InitRoutes() Routes {
792-
return Routes{
769+
routes := Routes{
793770
&Route{"get_projects", "GET", "/api/projects", s.projectsGetHandler},
794771
&Route{"create_project", "POST", "/api/projects", s.projectCreateHandler},
795772
&Route{"delete_project", "DELETE", "/api/projects/:project_id", s.projectDeleteHandler},
@@ -833,4 +810,12 @@ func (s *ShibuyaAPI) InitRoutes() Routes {
833810

834811
&Route{"admin_collections", "GET", "/api/admin/collections", s.collectionAdminGetHandler},
835812
}
813+
for _, r := range routes {
814+
// TODO! We don't require auth for usage endpoint for now.
815+
if strings.Contains(r.Path, "usage") {
816+
continue
817+
}
818+
r.HandlerFunc = s.authRequired(r.HandlerFunc)
819+
}
820+
return routes
836821
}

shibuya/api/middlewares.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package api
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net/http"
7+
8+
"github.com/julienschmidt/httprouter"
9+
"github.com/rakutentech/shibuya/shibuya/model"
10+
)
11+
12+
const (
13+
accountKey = "account"
14+
)
15+
16+
func authWithSession(r *http.Request) (*model.Account, error) {
17+
account := model.GetAccountBySession(r)
18+
if account == nil {
19+
return nil, makeLoginError()
20+
}
21+
return account, nil
22+
}
23+
24+
// TODO add JWT token auth in the future
25+
func authWithToken(_ *http.Request) (*model.Account, error) {
26+
return nil, errors.New("No token presented")
27+
}
28+
29+
func (s *ShibuyaAPI) authRequired(next httprouter.Handle) httprouter.Handle {
30+
return httprouter.Handle(func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
31+
var account *model.Account
32+
var err error
33+
account, err = authWithSession(r)
34+
if err != nil {
35+
s.handleErrors(w, err)
36+
return
37+
}
38+
next(w, r.WithContext(context.WithValue(r.Context(), accountKey, account)), params)
39+
})
40+
}

shibuya/api/networkutils.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package api
2+
3+
import (
4+
"net/http"
5+
"strings"
6+
)
7+
8+
func retrieveClientIP(r *http.Request) string {
9+
t := r.Header.Get("x-forwarded-for")
10+
if t == "" {
11+
return r.RemoteAddr
12+
}
13+
return strings.Split(t, ",")[0]
14+
}

shibuya/api/ownership.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package api
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/julienschmidt/httprouter"
7+
"github.com/rakutentech/shibuya/shibuya/model"
8+
)
9+
10+
func hasProjectOwnership(project *model.Project, account *model.Account) bool {
11+
if _, ok := account.MLMap[project.Owner]; !ok {
12+
if !account.IsAdmin() {
13+
return false
14+
}
15+
}
16+
return true
17+
}
18+
19+
func hasCollectionOwnership(r *http.Request, params httprouter.Params) (*model.Collection, error) {
20+
collection, err := getCollection(params.ByName("collection_id"))
21+
if err != nil {
22+
return nil, err
23+
}
24+
account := r.Context().Value(accountKey).(*model.Account)
25+
project, err := model.GetProject(collection.ProjectID)
26+
if err != nil {
27+
return nil, err
28+
}
29+
if r := hasProjectOwnership(project, account); !r {
30+
return nil, makeCollectionOwnershipError()
31+
}
32+
return collection, nil
33+
}

0 commit comments

Comments
 (0)