Skip to content

Commit

Permalink
feat: refactors in preparation for ogent server (#2668)
Browse files Browse the repository at this point in the history
  • Loading branch information
hspitzley-czi authored Nov 13, 2023
1 parent 013c622 commit 67578af
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 88 deletions.
2 changes: 1 addition & 1 deletion api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func exec(ctx context.Context) error {
logrus.Info("Sentry disabled for environment: ", cfg.Api.DeploymentStage)
}

return api.MakeApp(ctx, cfg).Listen()
return api.MakeFiberApp(ctx, cfg).Listen()
}

// @title Happy API
Expand Down
30 changes: 12 additions & 18 deletions api/pkg/api/app.go → api/pkg/api/app_fiber.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func MakeAPIApplication(cfg *setup.Configuration) *APIApplication {
}
}

func MakeApp(ctx context.Context, cfg *setup.Configuration) *APIApplication {
func MakeFiberApp(ctx context.Context, cfg *setup.Configuration) *APIApplication {
db := store.MakeDB(cfg.Database)
return MakeAppWithDB(ctx, cfg, db)
}
Expand Down Expand Up @@ -67,18 +67,8 @@ func MakeAppWithDB(ctx context.Context, cfg *setup.Configuration, db *store.DB)

v1 := apiApp.FiberApp.Group("/v1")
if *cfg.Auth.Enable {
verifiers := []request.OIDCVerifier{
request.MakeGithubVerifier("chanzuckerberg"),
}
for _, provider := range cfg.Auth.Providers {
verifier, err := request.MakeOIDCProvider(ctx, provider.IssuerURL, provider.ClientID, request.DefaultClaimsVerifier)
if err != nil {
logrus.Fatalf("failed to create OIDC verifier with error: %s", err.Error())
}
verifiers = append(verifiers, verifier)
}

v1.Use(request.MakeAuth(request.MakeMultiOIDCVerifier(verifiers...)))
verifier := request.MakeVerifierFromConfig(ctx, cfg)
v1.Use(request.MakeFiberAuthMiddleware(verifier))
}

v1.Use(fibersentry.New(fibersentry.Config{
Expand All @@ -87,11 +77,15 @@ func MakeAppWithDB(ctx context.Context, cfg *setup.Configuration, db *store.DB)
}))
v1.Use(func(c *fiber.Ctx) error {
user := sentry.User{}
if email := c.Locals(request.OIDCClaimsEmail{}); email != nil {
user.Email = email.(string)
}
if actor := c.Locals(request.OIDCClaimsGHActor{}); actor != nil {
user.Username = actor.(string)
oidcValues := c.Locals(request.OIDCAuthKey{})
if oidcValues != nil {
oidcValues := oidcValues.(*request.OIDCAuthValues)
if len(oidcValues.Email) > 0 {
user.Email = oidcValues.Email
}
if len(oidcValues.Actor) > 0 {
user.Username = oidcValues.Actor
}
}
sentry.ConfigureScope(func(scope *sentry.Scope) {
scope.SetUser(user)
Expand Down
2 changes: 1 addition & 1 deletion api/pkg/api/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func TestVersionCheckFail(t *testing.T) {
{
// unrestricted client without a version
userAgent: "happy-cli",
errorMessage: "expected version so be specified for happy-cli in the User-Agent header (format: happy-cli/<version>)",
errorMessage: "expected version to be specified for happy-cli in the User-Agent header (format: happy-cli/<version>)",
},
}

Expand Down
19 changes: 4 additions & 15 deletions api/pkg/api/config_group.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package api

import (
"regexp"
"strings"

"github.com/chanzuckerberg/happy/api/pkg/cmd"
"github.com/chanzuckerberg/happy/api/pkg/request"
"github.com/chanzuckerberg/happy/api/pkg/response"
Expand Down Expand Up @@ -103,7 +100,7 @@ func (c *ConfigHandler) configDumpHandler(ctx *fiber.Ctx) error {
// @Router /v1/config/copy [POST]
func (c *ConfigHandler) configCopyHandler(ctx *fiber.Ctx) error {
payload := getPayload[model.CopyAppConfigPayload](ctx)
payload.Key = standardizeKey(payload.Key)
payload.Key = request.StandardizeKey(payload.Key)

record, err := c.config.CopyAppConfig(&payload)
if err != nil {
Expand Down Expand Up @@ -182,7 +179,7 @@ func (c *ConfigHandler) getConfigsHandler(ctx *fiber.Ctx) error {
// @Router /v1/configs/ [POST]
func (c *ConfigHandler) postConfigsHandler(ctx *fiber.Ctx) error {
payload := getPayload[model.AppConfigPayload](ctx)
payload.Key = standardizeKey(payload.Key)
payload.Key = request.StandardizeKey(payload.Key)
record, err := c.config.SetConfigValue(&payload)
if err != nil {
return response.ServerErrorResponse(ctx, err.Error())
Expand All @@ -202,7 +199,7 @@ func (c *ConfigHandler) postConfigsHandler(ctx *fiber.Ctx) error {
func (c *ConfigHandler) getConfigByKeyHandler(ctx *fiber.Ctx) error {
payload := model.AppConfigLookupPayload{
AppMetadata: getPayload[model.AppMetadata](ctx),
ConfigKey: model.ConfigKey{Key: standardizeKey(ctx.Params("key"))},
ConfigKey: model.ConfigKey{Key: request.StandardizeKey(ctx.Params("key"))},
}
record, err := c.config.GetResolvedAppConfig(&payload)
if err != nil {
Expand All @@ -229,7 +226,7 @@ func (c *ConfigHandler) getConfigByKeyHandler(ctx *fiber.Ctx) error {
func (c *ConfigHandler) deleteConfigByKeyHandler(ctx *fiber.Ctx) error {
payload := model.AppConfigLookupPayload{
AppMetadata: getPayload[model.AppMetadata](ctx),
ConfigKey: model.ConfigKey{Key: standardizeKey(ctx.Params("key"))},
ConfigKey: model.ConfigKey{Key: request.StandardizeKey(ctx.Params("key"))},
}
record, err := c.config.DeleteAppConfig(&payload)
if err != nil {
Expand All @@ -252,11 +249,3 @@ func wrapAppConfigsWithCount(records []*model.AppConfig) model.WrappedAppConfigs
Count: len(records),
}
}

func standardizeKey(key string) string {
key = strings.ToUpper(key)

// replace all non-alphanumeric characters with _
regex := regexp.MustCompile("[^A-Z0-9]")
return regex.ReplaceAllString(key, "_")
}
42 changes: 10 additions & 32 deletions api/pkg/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ func MakeConfig(db *store.DB) Config {
}

func MakeAppConfigFromEnt(in *ent.AppConfig) *model.AppConfig {
if in == nil {
return nil
}
deletedAt := gorm.DeletedAt{
Valid: false,
}
Expand Down Expand Up @@ -122,28 +125,14 @@ func (c *dbConfig) GetAppConfigsForEnv(payload *model.AppMetadata) ([]*model.Res

// Returns resolved stack-level configs for the given app, env, and stack (with overrides applied)
func (c *dbConfig) GetAppConfigsForStack(payload *model.AppMetadata) ([]*model.ResolvedAppConfig, error) {
// get all appconfigs for the app/env and order by key, then by stack DESC. Take the first item for each key
db := c.DB.GetDB()
records, err := appEnvScopedQuery(db.AppConfig, payload).
Where(appconfig.StackIn(payload.Stack, "")).
Order(ent.Asc(appconfig.FieldKey), ent.Desc(appconfig.FieldStack)).
All(context.Background())
records, err := c.DB.ListAppConfigsForStack(context.Background(), payload.AppName, payload.Environment, payload.Stack)
if err != nil {
return nil, err
}

// we'll get at most 2 config records for each key (one for env and one for stack), so we'll use a map to dedupe
// and select the stack record if it exists (since we order by stack DESC) and the env record otherwise
resolvedMap := map[string]*ent.AppConfig{}
for _, record := range records {
if _, ok := resolvedMap[record.Key]; !ok {
resolvedMap[record.Key] = record
}
}

results := []*model.ResolvedAppConfig{}
for _, record := range resolvedMap {
results = append(results, &model.ResolvedAppConfig{AppConfig: *MakeAppConfigFromEnt(record), Source: record.Source.String()})
results := make([]*model.ResolvedAppConfig, len(records))
for idx, record := range records {
results[idx] = &model.ResolvedAppConfig{AppConfig: *MakeAppConfigFromEnt(record), Source: record.Source.String()}
}

return results, nil
Expand All @@ -167,24 +156,13 @@ func rollback(tx *ent.Tx, err error) error {
}

func (c *dbConfig) GetResolvedAppConfig(payload *model.AppConfigLookupPayload) (*model.ResolvedAppConfig, error) {
db := c.DB.GetDB()
records, err := appEnvScopedQuery(db.AppConfig, &payload.AppMetadata).
Where(
appconfig.Key(payload.Key),
appconfig.StackIn(payload.Stack, ""),
).
Order(ent.Desc(appconfig.FieldStack)).
All(context.Background())
record, err := c.DB.ReadAppConfig(context.Background(), payload.AppName, payload.Environment, payload.Stack, payload.Key)
if err != nil {
return nil, errors.Wrap(err, "[GetResolvedAppConfig] unable to query app configs")
return nil, err
}

if len(records) == 0 {
if record == nil {
return nil, nil
}

// at most 2 records are defined and since we order by stack DESC, the first record is the stack-specific one if it exists
record := records[0]
return &model.ResolvedAppConfig{AppConfig: *MakeAppConfigFromEnt(record), Source: record.Source.String()}, nil
}

Expand Down
56 changes: 39 additions & 17 deletions api/pkg/request/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/chanzuckerberg/happy/api/pkg/response"
"github.com/chanzuckerberg/happy/api/pkg/setup"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gofiber/fiber/v2"
"github.com/hashicorp/go-multierror"
Expand Down Expand Up @@ -167,49 +168,70 @@ func MakeOIDCProvider(ctx context.Context, issuerURL, clientID string, claimsVer
}, nil
}

type OIDCSubjectKey struct{}
type OIDCClaimsGHActor struct{}
type OIDCClaimsEmail struct{}
type OIDCAuthKey struct{}

func validateAuthHeader(c *fiber.Ctx, authHeader string, verifier OIDCVerifier) error {
type OIDCAuthValues struct {
Subject string
Email string
Actor string
}

func ValidateAuthHeader(ctx context.Context, authHeader string, verifier OIDCVerifier) (*OIDCAuthValues, error) {
rawIDToken := stripBearerPrefixFromTokenString(authHeader)
token, err := verifier.Verify(c.Context(), rawIDToken)
token, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
return errors.Wrap(err, "unable to verify ID token")
return nil, errors.Wrap(err, "unable to verify ID token")
}

var claims struct {
Email string `json:"email"`
GithubActor string `json:"actor"`
Email string `json:"email"`
Actor string `json:"actor"`
}
err = token.Claims(&claims)
if err != nil {
return err
return nil, err
}
if claims.Email == "" && claims.GithubActor == "" {
if claims.Email == "" && claims.Actor == "" {
// TODO: can't throw an error here because it breaks TFE runs, log the issue for now
// return errors.New("ID token didn't have email or actor claims")
logrus.Warn("ID token didn't have email or actor claims")
}

c.Locals(OIDCSubjectKey{}, token.Subject)
c.Locals(OIDCClaimsGHActor{}, claims.GithubActor)
c.Locals(OIDCClaimsEmail{}, claims.Email)
return &OIDCAuthValues{
Subject: token.Subject,
Email: claims.Email,
Actor: claims.Actor,
}, nil
}

return nil
func MakeVerifierFromConfig(ctx context.Context, cfg *setup.Configuration) OIDCVerifier {
verifiers := []OIDCVerifier{
MakeGithubVerifier("chanzuckerberg"),
}
for _, provider := range cfg.Auth.Providers {
verifier, err := MakeOIDCProvider(ctx, provider.IssuerURL, provider.ClientID, DefaultClaimsVerifier)
if err != nil {
logrus.Fatalf("failed to create OIDC verifier with error: %s", err.Error())
}
verifiers = append(verifiers, verifier)
}

return MakeMultiOIDCVerifier(verifiers...)
}

func MakeAuth(verifier OIDCVerifier) fiber.Handler {
func MakeFiberAuthMiddleware(verifier OIDCVerifier) fiber.Handler {
return func(c *fiber.Ctx) error {
authHeader := c.GetReqHeaders()["Authorization"]
authHeader := c.GetReqHeaders()[fiber.HeaderAuthorization]
if len(authHeader) <= 0 {
return response.AuthErrorResponse(c, "missing auth header")
}

err := validateAuthHeader(c, authHeader, verifier)
oidcValues, err := ValidateAuthHeader(c.Context(), authHeader, verifier)
if err != nil {
return response.AuthErrorResponse(c, err.Error())
}

c.Locals(OIDCAuthKey{}, oidcValues)
return c.Next()
}
}
4 changes: 2 additions & 2 deletions api/pkg/request/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestValidateAuthHeaderNoErrors(t *testing.T) {
t.Parallel()
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
err := validateAuthHeader(ctx, tc.authHeader, tc.verifier)
_, err := ValidateAuthHeader(ctx.Context(), tc.authHeader, tc.verifier)
r.NoError(err)
})
}
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestValidateAuthHeaderErrors(t *testing.T) {
t.Parallel()
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
err := validateAuthHeader(ctx, tc.authHeader, tc.verifier)
_, err := ValidateAuthHeader(ctx.Context(), tc.authHeader, tc.verifier)
r.Error(err)
})
}
Expand Down
14 changes: 14 additions & 0 deletions api/pkg/request/format.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package request

import (
"regexp"
"strings"
)

func StandardizeKey(key string) string {
key = strings.ToUpper(key)

// replace all non-alphanumeric characters with _
regex := regexp.MustCompile("[^A-Z0-9]")
return regex.ReplaceAllString(key, "_")
}
4 changes: 2 additions & 2 deletions api/pkg/request/version_check_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ var (

func init() {
MinimumVersions = map[string]string{
"happy-cli": "0.53.6",
"happy-cli": "0.90.0",
"happy-provider": "0.52.0",
}

Expand Down Expand Up @@ -50,7 +50,7 @@ func validateUserAgentVersion(userAgent string) error {
}

if len(clientVersionParts) < 2 {
return errors.Errorf("expected version so be specified for %s in the User-Agent header (format: %s/<version>)", client, client)
return errors.Errorf("expected version to be specified for %s in the User-Agent header (format: %s/<version>)", client, client)
}

versionStr := clientVersionParts[1]
Expand Down
Loading

0 comments on commit 67578af

Please sign in to comment.