diff --git a/auth/auth.go b/auth/auth.go index 7a255327..a4a7a0cf 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -172,6 +172,7 @@ func NewAuthHandlerFunc(ac Auth) (HandlerFunc, error) { return h, err } +// NoAuth returns a handler that does not perform any authentication. func NoAuth() (HandlerFunc, error) { return func(w http.ResponseWriter, r *http.Request) (context.Context, error) { return r.Context(), nil @@ -232,6 +233,7 @@ func NewAuth(ac Auth, log *zap.Logger, opt Options, hFn ...HandlerFunc) ( }, nil } +// SimpleHandler is a simple auth handler that sets the user ID, provider and role func SimpleHandler(ac Auth) (HandlerFunc, error) { return func(_ http.ResponseWriter, r *http.Request) (context.Context, error) { c := r.Context() @@ -257,6 +259,7 @@ func SimpleHandler(ac Auth) (HandlerFunc, error) { var Err401 = errors.New("401 unauthorized") +// HeaderHandler is a middleware that checks for a header value func HeaderHandler(ac Auth) (HandlerFunc, error) { hdr := ac.Header @@ -287,14 +290,17 @@ func HeaderHandler(ac Auth) (HandlerFunc, error) { }, nil } +// IsAuth returns true if the context contains a user ID func IsAuth(c context.Context) bool { return c != nil && c.Value(core.UserIDKey) != nil } +// UserID returns the user ID from the context func UserID(c context.Context) interface{} { return c.Value(core.UserIDKey) } +// UserIDInt returns the user ID from the context as an int func UserIDInt(c context.Context) int { v, ok := UserID(c).(string) if !ok { diff --git a/auth/internal/rails/auth.go b/auth/internal/rails/auth.go index cc0dc325..a84206f5 100644 --- a/auth/internal/rails/auth.go +++ b/auth/internal/rails/auth.go @@ -30,6 +30,7 @@ type Auth struct { AuthSalt string } +// NewAuth creates a new Auth instance func NewAuth(version, secret string) (*Auth, error) { ra := &Auth{ Secret: secret, @@ -60,6 +61,7 @@ func NewAuth(version, secret string) (*Auth, error) { return ra, nil } +// ParseCookie parses the rails cookie and returns the user ID func (ra Auth) ParseCookie(cookie string) (userID string, err error) { var dcookie []byte @@ -87,6 +89,7 @@ func (ra Auth) ParseCookie(cookie string) (userID string, err error) { return } +// ParseCookie parses the rails cookie and returns the user ID func ParseCookie(cookie string) (string, error) { if cookie[0] != '{' { return getUserId4([]byte(cookie)) @@ -95,6 +98,7 @@ func ParseCookie(cookie string) (string, error) { return getUserId([]byte(cookie)) } +// getUserId extracts the user ID from the session data func getUserId(data []byte) (userID string, err error) { var sessionData map[string]interface{} @@ -135,10 +139,11 @@ func getUserId(data []byte) (userID string, err error) { return } +// getUserId4 extracts the user ID from the session data func getUserId4(data []byte) (userID string, err error) { sessionData, err := marshal.CreateMarshalledObject(data).GetAsMap() if err != nil { - return + return "", err } wardenData, ok := sessionData["warden.user.user.key"] diff --git a/auth/internal/rails/cookie.go b/auth/internal/rails/cookie.go index 72c225cb..8bc7fc8f 100644 --- a/auth/internal/rails/cookie.go +++ b/auth/internal/rails/cookie.go @@ -12,6 +12,7 @@ import ( "golang.org/x/crypto/pbkdf2" ) +// parseCookie decrypts and parses a Rails session cookie func parseCookie(cookie, secretKeyBase, salt, signSalt string) ([]byte, error) { return session.DecryptSignedCookie( cookie, @@ -22,6 +23,7 @@ func parseCookie(cookie, secretKeyBase, salt, signSalt string) ([]byte, error) { // {"session_id":"a71d6ffcd4ed5572ea2097f569eb95ef","warden.user.user.key":[[2],"$2a$11$q9Br7m4wJxQvF11hAHvTZO"],"_csrf_token":"HsYgrD2YBaWAabOYceN0hluNRnGuz49XiplmMPt43aY="} +// parseCookie52 decrypts and parses a Rails 5.2+ session cookie func parseCookie52(cookie, secretKeyBase, authSalt string) ([]byte, error) { ecookie, err := url.QueryUnescape(cookie) if err != nil { diff --git a/auth/jwt.go b/auth/jwt.go index 63e9c9a4..c3743103 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -14,6 +14,8 @@ const ( authHeader = "Authorization" ) +// JwtHandler is a middleware that checks for a JWT token in the cookie or the +// authorization header. If the token is found, it is validated and the claims func JwtHandler(ac Auth) (HandlerFunc, error) { jwtProvider, err := provider.NewProvider(ac.JWT) if err != nil { diff --git a/auth/provider/auth0.go b/auth/provider/auth0.go index 1a3e07b8..9125a53d 100644 --- a/auth/provider/auth0.go +++ b/auth/provider/auth0.go @@ -15,6 +15,7 @@ type Auth0Provider struct { issuer string } +// NewAuth0Provider creates a new Auth0 JWT provider func NewAuth0Provider(config JWTConfig) (*Auth0Provider, error) { key, err := getKey(config) if err != nil { @@ -27,12 +28,14 @@ func NewAuth0Provider(config JWTConfig) (*Auth0Provider, error) { }, nil } +// KeyFunc returns a function that returns the key used to verify the JWT token func (p *Auth0Provider) KeyFunc() jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { return p.key, nil } } +// VerifyAudience checks if the audience claim is valid func (p *Auth0Provider) VerifyAudience(claims jwt.MapClaims) bool { if claims == nil { return false @@ -40,6 +43,7 @@ func (p *Auth0Provider) VerifyAudience(claims jwt.MapClaims) bool { return claims.VerifyAudience(p.aud, p.aud != "") } +// VerifyIssuer checks if the issuer claim is valid func (p *Auth0Provider) VerifyIssuer(claims jwt.MapClaims) bool { if claims == nil { return false @@ -47,6 +51,7 @@ func (p *Auth0Provider) VerifyIssuer(claims jwt.MapClaims) bool { return claims.VerifyIssuer(p.issuer, p.issuer != "") } +// SetContextValues sets the user ID and provider in the context func (p *Auth0Provider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) { if claims == nil { return ctx, errors.New("undefined claims") diff --git a/auth/provider/firebase.go b/auth/provider/firebase.go index fceab5a4..c646e830 100644 --- a/auth/provider/firebase.go +++ b/auth/provider/firebase.go @@ -35,6 +35,7 @@ type FirebaseProvider struct { issuer string } +// NewFirebaseProvider creates a new Firebase JWT provider func NewFirebaseProvider(config JWTConfig) (*FirebaseProvider, error) { issuer := config.Issuer if issuer == "" { @@ -46,10 +47,12 @@ func NewFirebaseProvider(config JWTConfig) (*FirebaseProvider, error) { }, nil } +// KeyFunc returns a function that returns the key used to verify the JWT token func (p *FirebaseProvider) KeyFunc() jwt.Keyfunc { return firebaseKeyFunction } +// VerifyAudience checks if the audience claim is valid func (p *FirebaseProvider) VerifyAudience(claims jwt.MapClaims) bool { if claims == nil { return false @@ -57,6 +60,7 @@ func (p *FirebaseProvider) VerifyAudience(claims jwt.MapClaims) bool { return claims.VerifyAudience(p.aud, p.aud != "") } +// VerifyIssuer checks if the issuer claim is valid func (p *FirebaseProvider) VerifyIssuer(claims jwt.MapClaims) bool { if claims == nil { return false @@ -64,6 +68,7 @@ func (p *FirebaseProvider) VerifyIssuer(claims jwt.MapClaims) bool { return claims.VerifyIssuer(p.issuer, p.issuer != "") } +// SetContextValues sets the user ID and provider in the context func (p *FirebaseProvider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) { if claims == nil { return ctx, errors.New("undefined claims") @@ -85,6 +90,7 @@ func (e *firebaseKeyError) Error() string { return e.Message + " " + e.Err.Error() } +// firebaseKeyFunction returns the public key used to verify the JWT token func firebaseKeyFunction(token *jwt.Token) (interface{}, error) { kid, ok := token.Header["kid"] diff --git a/auth/provider/generic.go b/auth/provider/generic.go index 6d9f0306..0a868154 100644 --- a/auth/provider/generic.go +++ b/auth/provider/generic.go @@ -14,6 +14,7 @@ type GenericProvider struct { issuer string } +// NewGenericProvider creates a new generic JWT provider func NewGenericProvider(config JWTConfig) (*GenericProvider, error) { key, err := getKey(config) if err != nil { @@ -26,12 +27,14 @@ func NewGenericProvider(config JWTConfig) (*GenericProvider, error) { }, nil } +// KeyFunc returns a function that returns the key used to verify the JWT token func (p *GenericProvider) KeyFunc() jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { return p.key, nil } } +// VerifyAudience verifies the audience claim of the JWT token func (p *GenericProvider) VerifyAudience(claims jwt.MapClaims) bool { if claims == nil { return false @@ -39,6 +42,7 @@ func (p *GenericProvider) VerifyAudience(claims jwt.MapClaims) bool { return claims.VerifyAudience(p.aud, p.aud != "") } +// VerifyIssuer verifies the issuer claim of the JWT token func (p *GenericProvider) VerifyIssuer(claims jwt.MapClaims) bool { if claims == nil { return false @@ -46,6 +50,7 @@ func (p *GenericProvider) VerifyIssuer(claims jwt.MapClaims) bool { return claims.VerifyIssuer(p.issuer, p.issuer != "") } +// SetContextValues sets the user ID and provider in the context func (p *GenericProvider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) { if claims == nil { return ctx, errors.New("undefined claims") diff --git a/auth/provider/jwks.go b/auth/provider/jwks.go index 9a12a77a..25f3fc55 100644 --- a/auth/provider/jwks.go +++ b/auth/provider/jwks.go @@ -19,6 +19,7 @@ type keychainCache struct { semaphore int32 } +// newKeychainCache creates a new KeychainCache func newKeychainCache(jwksURL string, refreshInterval, minRefreshInterval int) *keychainCache { ar := jwk.NewAutoRefresh(context.Background()) if refreshInterval > 0 { @@ -34,6 +35,7 @@ func newKeychainCache(jwksURL string, refreshInterval, minRefreshInterval int) * } } +// getKey returns the key from the cache func (k *keychainCache) getKey(kid string) (interface{}, error) { set, err := k.keyCache.Fetch(context.TODO(), k.jwksURL) if err != nil { @@ -89,6 +91,7 @@ type JWKSProvider struct { cache *keychainCache } +// NewJWKSProvider creates a new JWKSProvider func NewJWKSProvider(config JWTConfig) (*JWKSProvider, error) { if config.JWKSURL == "" { return nil, errors.New("undefined JWKSURL") @@ -100,6 +103,7 @@ func NewJWKSProvider(config JWTConfig) (*JWKSProvider, error) { }, nil } +// KeyFunc returns a function that returns the key used to verify the JWT token func (p *JWKSProvider) KeyFunc() jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { if token == nil { @@ -123,6 +127,7 @@ func (p *JWKSProvider) KeyFunc() jwt.Keyfunc { } } +// VerifyAudience checks if the audience claim is valid func (p *JWKSProvider) VerifyAudience(claims jwt.MapClaims) bool { if claims == nil { return false @@ -130,6 +135,7 @@ func (p *JWKSProvider) VerifyAudience(claims jwt.MapClaims) bool { return claims.VerifyAudience(p.aud, p.aud != "") } +// VerifyIssuer checks if the issuer claim is valid func (p *JWKSProvider) VerifyIssuer(claims jwt.MapClaims) bool { if claims == nil { return false @@ -137,6 +143,7 @@ func (p *JWKSProvider) VerifyIssuer(claims jwt.MapClaims) bool { return claims.VerifyIssuer(p.issuer, p.issuer != "") } +// SetContextValues sets the user ID and provider in the context func (p *JWKSProvider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) { if claims == nil { return ctx, errors.New("undefined claims") diff --git a/auth/provider/provider.go b/auth/provider/provider.go index d322b9ee..efafa2f7 100644 --- a/auth/provider/provider.go +++ b/auth/provider/provider.go @@ -51,6 +51,7 @@ type JWTProvider interface { SetContextValues(context.Context, jwt.MapClaims) (context.Context, error) } +// NewProvider creates a new JWT provider based on the config values func NewProvider(config JWTConfig) (JWTProvider, error) { switch config.Provider { case "auth0": @@ -64,20 +65,21 @@ func NewProvider(config JWTConfig) (JWTProvider, error) { } } +// getKey returns the key used to verify the JWT token func getKey(config JWTConfig) (interface{}, error) { var key interface{} var err error switch { case config.PubKey != "": - pk := []byte(config.PubKey) + pubKey := []byte(config.PubKey) switch config.PubKeyType { case "ecdsa": - key, err = jwt.ParseECPublicKeyFromPEM(pk) + key, err = jwt.ParseECPublicKeyFromPEM(pubKey) case "rsa": - key, err = jwt.ParseRSAPublicKeyFromPEM(pk) + key, err = jwt.ParseRSAPublicKeyFromPEM(pubKey) default: - key, err = jwt.ParseECPublicKeyFromPEM(pk) + key, err = jwt.ParseECPublicKeyFromPEM(pubKey) } if err != nil { return nil, err diff --git a/auth/rails.go b/auth/rails.go index 6f4611a4..0eddd4bd 100644 --- a/auth/rails.go +++ b/auth/rails.go @@ -15,6 +15,7 @@ import ( "github.com/gomodule/redigo/redis" ) +// RailsHandler returns a handler that authenticates using a Rails session cookie func RailsHandler(ac Auth) (HandlerFunc, error) { ru := ac.Rails.URL @@ -29,6 +30,7 @@ func RailsHandler(ac Auth) (HandlerFunc, error) { return RailsCookieHandler(ac) } +// RailsRedisHandler returns a handler that authenticates using a Rails session cookie func RailsRedisHandler(ac Auth) (HandlerFunc, error) { cookie := ac.Cookie @@ -95,6 +97,7 @@ func RailsRedisHandler(ac Auth) (HandlerFunc, error) { }, nil } +// RailsMemcacheHandler returns a handler that authenticates using a Rails session cookie func RailsMemcacheHandler(ac Auth) (HandlerFunc, error) { cookie := ac.Cookie @@ -138,6 +141,7 @@ func RailsMemcacheHandler(ac Auth) (HandlerFunc, error) { }, nil } +// RailsCookieHandler returns a handler that authenticates using a Rails session cookie func RailsCookieHandler(ac Auth) (HandlerFunc, error) { cookie := ac.Cookie if len(cookie) == 0 { @@ -168,6 +172,7 @@ func RailsCookieHandler(ac Auth) (HandlerFunc, error) { }, nil } +// railsAuth returns a new rails auth instance func railsAuth(ac Auth) (*rails.Auth, error) { secret := ac.Rails.SecretKeyBase if len(secret) == 0 { diff --git a/cmd/cmd.go b/cmd/cmd.go index 96ce0944..cb96f887 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -29,6 +29,7 @@ var ( cpath string ) +// Cmd is the entry point for the CLI func Cmd() { log = newLogger(false).Sugar() @@ -64,14 +65,12 @@ func Cmd() { } } +// setup is a helper function to read the config file func setup(cpath string) { if conf != nil { return } - setupAgain(cpath) -} -func setupAgain(cpath string) { cp, err := filepath.Abs(cpath) if err != nil { log.Fatal(err) @@ -83,6 +82,7 @@ func setupAgain(cpath string) { } } +// initDB is a helper function to initialize the database connection func initDB(openDB bool) { var err error @@ -97,6 +97,7 @@ func initDB(openDB bool) { dbOpened = openDB } +// newLogger creates a new logger func newLogger(json bool) *zap.Logger { econf := zapcore.EncoderConfig{ MessageKey: "msg", diff --git a/cmd/cmd_admin.go b/cmd/cmd_admin.go index c601d196..c0646630 100644 --- a/cmd/cmd_admin.go +++ b/cmd/cmd_admin.go @@ -16,6 +16,7 @@ var ( secret string ) +// deployCmd deploys a new config or rolls back the active config func deployCmd() *cobra.Command { c := &cobra.Command{ Use: "deploy", @@ -36,6 +37,7 @@ func deployCmd() *cobra.Command { return c } +// initCmd initializes the admin database func initCmd() *cobra.Command { c := &cobra.Command{ Use: "init", @@ -45,6 +47,7 @@ func initCmd() *cobra.Command { return c } +// cmdInit initializes the admin database func cmdInit(cmd *cobra.Command, args []string) { setup(cpath) initDB(true) @@ -56,6 +59,7 @@ func cmdInit(cmd *cobra.Command, args []string) { log.Infof("init successful: %s", name) } +// cmdDeploy deploys a new config func cmdDeploy(cmd *cobra.Command, args []string) { if host == "" { log.Fatalf("--host is a required argument") @@ -79,6 +83,7 @@ func cmdDeploy(cmd *cobra.Command, args []string) { } } +// cmdRollback rolls back the active config func cmdRollback(cmd *cobra.Command, args []string) { if host == "" { log.Fatalf("--host is a required argument") diff --git a/cmd/cmd_db.go b/cmd/cmd_db.go index f1d5e896..cba08c40 100644 --- a/cmd/cmd_db.go +++ b/cmd/cmd_db.go @@ -4,6 +4,7 @@ import ( "github.com/spf13/cobra" ) +// dbCmd creates the db command func dbCmd() *cobra.Command { c := &cobra.Command{ Use: "db", @@ -30,6 +31,7 @@ func dbCmd() *cobra.Command { return c } +// cmdDBSeed seeds the database func cmdDBSetup(cmd *cobra.Command, args []string) { setup(cpath) diff --git a/cmd/cmd_migrate.go b/cmd/cmd_migrate.go index fa211142..4cc5e18c 100644 --- a/cmd/cmd_migrate.go +++ b/cmd/cmd_migrate.go @@ -14,6 +14,7 @@ import ( "golang.org/x/text/language" ) +// This is the cobra CLI command for the migrate subcommand func migrateCmd() *cobra.Command { c := &cobra.Command{ Use: "migrate", @@ -69,6 +70,7 @@ var newMigrationText = `-- Write your migrate up statements here -- Then delete the separator line above. ` +// cmdDBMigrate is the main function for the migrate subcommand func cmdDBMigrate(cmd *cobra.Command, args []string) { doneSomething := false @@ -93,7 +95,7 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) { m.Data = getMigrationVars(conf) - err = m.LoadMigrations(conf.RelPath(conf.MigrationsPath)) + err = m.LoadMigrations(conf.AbsolutePath(conf.MigrationsPath)) if err != nil { log.Fatalf("Failed to load migrations: %s", err) } @@ -197,6 +199,7 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) { } } +// cmdMigrateStatus is the function for the migrate status subcommand func cmdMigrateStatus(cmd *cobra.Command, args []string) { setup(cpath) initDB(true) @@ -212,7 +215,7 @@ func cmdMigrateStatus(cmd *cobra.Command, args []string) { m.Data = getMigrationVars(conf) - err = m.LoadMigrations(conf.RelPath(conf.MigrationsPath)) + err = m.LoadMigrations(conf.AbsolutePath(conf.MigrationsPath)) if err != nil { log.Fatalf("Failed to load migrations: %s", err) } @@ -238,6 +241,7 @@ func cmdMigrateStatus(cmd *cobra.Command, args []string) { status, mver, len(m.Migrations), conf.DB.Host, conf.DB.DBName) } +// cmdMigrateNew is the function for the migrate new subcommand func cmdMigrateNew(cmd *cobra.Command, args []string) { if len(args) != 1 { cmd.Help() //nolint:errcheck @@ -248,7 +252,7 @@ func cmdMigrateNew(cmd *cobra.Command, args []string) { initDB(false) name := args[0] - migrationsPath := conf.RelPath(conf.MigrationsPath) + migrationsPath := conf.AbsolutePath(conf.MigrationsPath) m, err := migrate.FindMigrations(migrationsPath) if err != nil { @@ -306,6 +310,7 @@ func ExtractErrorLine(source string, position int) (ErrorLineExtract, error) { return ele, nil } +// getMigrationVars returns the variables to be used in the migration templates func getMigrationVars(c *serv.Config) map[string]interface{} { en := cases.Title(language.English) diff --git a/cmd/cmd_new.go b/cmd/cmd_new.go index b7d52a41..12df832a 100644 --- a/cmd/cmd_new.go +++ b/cmd/cmd_new.go @@ -17,6 +17,7 @@ import ( var dbURL string +// This is the cobra CLI command for the new subcommand func newCmd() *cobra.Command { c := &cobra.Command{ Use: "new ", @@ -29,6 +30,7 @@ func newCmd() *cobra.Command { return c } +// cmdNew is the handler for the new subcommand func cmdNew(cmd *cobra.Command, args []string) { if len(args) != 1 { cmd.Help() //nolint:errcheck diff --git a/cmd/cmd_secrets.go b/cmd/cmd_secrets.go index da78b3b7..4116f7fd 100644 --- a/cmd/cmd_secrets.go +++ b/cmd/cmd_secrets.go @@ -82,7 +82,7 @@ func cmdSecrets() *cobra.Command { } else { setup(cpath) if conf.SecretsFile != "" { - fileName, err = filepath.Abs(conf.RelPath(conf.SecretsFile)) + fileName, err = filepath.Abs(conf.AbsolutePath(conf.SecretsFile)) } } diff --git a/cmd/cmd_seed.go b/cmd/cmd_seed.go index 2d427fb5..2c099326 100644 --- a/cmd/cmd_seed.go +++ b/cmd/cmd_seed.go @@ -24,6 +24,7 @@ import ( "github.com/spf13/cobra" ) +// cmdSeed is the cobra CLI for the seed subcommand func cmdDBSeed(cmd *cobra.Command, args []string) { setup(cpath) initDB(true) @@ -49,6 +50,7 @@ func cmdDBSeed(cmd *cobra.Command, args []string) { log.Infof("Seed script completed") } +// compileAndRunJS compiles and runs the seed script func compileAndRunJS(seed string, db *sql.DB) error { b, err := os.ReadFile(seed) if err != nil { @@ -168,7 +170,7 @@ func compileAndRunJS(seed string, db *sql.DB) error { return err } -// func runFunc(call goja.FunctionCall) { +// graphQLFunc is a helper function to run a GraphQL query func graphQLFunc(gj *core.GraphJin, query string, data interface{}, opt map[string]string) map[string]interface{} { ct := context.Background() @@ -214,6 +216,7 @@ type csvSource struct { i int } +// NewCSVSource creates a new CSV source func NewCSVSource(filename string, sep rune) (*csvSource, error) { f, err := os.Open(filename) if err != nil { @@ -272,6 +275,7 @@ func (c *csvSource) Values() ([]interface{}, error) { return vals, nil } +// isDigit checks if a string is a digit func isDigit(v string) bool { for i := range v { if v[i] < '0' || v[i] > '9' { @@ -285,6 +289,7 @@ func (c *csvSource) Err() error { return nil } +// importCSV imports a CSV file into a table func importCSV(table, filename string, sep string, db *sql.DB) int64 { log.Infof("Seeding table: %s, From file: %s", table, filename) @@ -353,6 +358,7 @@ func logFunc(args ...interface{}) { } } +// avatarURL returns a random avatar URL func avatarURL(size int) string { if size == 0 { size = 200 diff --git a/cmd/cmd_serv.go b/cmd/cmd_serv.go index fc2bd161..98f130cc 100644 --- a/cmd/cmd_serv.go +++ b/cmd/cmd_serv.go @@ -7,6 +7,7 @@ import ( var deployActive bool +// servCmd is the cobra CLI command for the serve subcommand func servCmd() *cobra.Command { c := &cobra.Command{ Use: "serve", @@ -18,6 +19,7 @@ func servCmd() *cobra.Command { return c } +// cmdServ is the handler for the serve subcommand func cmdServ(*cobra.Command, []string) { setup(cpath) diff --git a/cmd/cmd_version.go b/cmd/cmd_version.go index 245d6c4f..8738a051 100644 --- a/cmd/cmd_version.go +++ b/cmd/cmd_version.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/cobra" ) +// This is cobra CLI command for the version subcommand func versionCmd() *cobra.Command { c := &cobra.Command{ Use: "version", @@ -16,10 +17,12 @@ func versionCmd() *cobra.Command { return c } +// cmdVersion is the handler for the version subcommand func cmdVersion(cmd *cobra.Command, args []string) { fmt.Printf("%s\n", BuildDetails()) } +// BuildDetails returns the version information func BuildDetails() string { if version == "" { return ` diff --git a/conf/config.go b/conf/config.go index 3db9f5b0..dcd12654 100644 --- a/conf/config.go +++ b/conf/config.go @@ -13,6 +13,7 @@ type configInfo struct { Inherits string } +// NewConfig creates a new config object func NewConfig(configPath, configFile string) (c *core.Config, err error) { fs := core.NewOsFS(configPath) if c, err = NewConfigWithFS(fs, configFile); err != nil { @@ -21,6 +22,7 @@ func NewConfig(configPath, configFile string) (c *core.Config, err error) { return } +// NewConfigWithFS creates a new config object using the provided filesystem func NewConfigWithFS(fs core.FS, configFile string) (c *core.Config, err error) { c = &core.Config{FS: fs} var ci configInfo @@ -47,6 +49,7 @@ func NewConfigWithFS(fs core.FS, configFile string) (c *core.Config, err error) return } +// readConfig reads the config file and unmarshals it into the provided struct func readConfig(fs core.FS, configFile string, v interface{}) (err error) { format := filepath.Ext(configFile) diff --git a/core/api.go b/core/api.go index 5fd647ef..b615b028 100644 --- a/core/api.go +++ b/core/api.go @@ -46,36 +46,36 @@ const ( // GraphJin struct is an instance of the GraphJin engine it holds all the required information like // datase schemas, relationships, etc that the GraphQL to SQL compiler would need to do it's job. -type graphjin struct { - conf *Config - db *sql.DB - log *_log.Logger - fs FS - trace Tracer - dbtype string - dbinfo *sdata.DBInfo - schema *sdata.DBSchema - allowList *allow.List - encKey [32]byte - encKeySet bool - cache Cache - queries sync.Map - roles map[string]*Role - roleStmt string - roleStmtMD psql.Metadata - tmap map[string]qcode.TConfig - rtmap map[string]ResolverFn - rmap map[string]resItem - abacEnabled bool - qc *qcode.Compiler - pc *psql.Compiler - subs sync.Map - prod bool - prodSec bool - namespace string - pf []byte - opts []Option - done chan bool +type GraphjinEngine struct { + conf *Config + db *sql.DB + log *_log.Logger + fs FS + trace Tracer + dbtype string + dbinfo *sdata.DBInfo + schema *sdata.DBSchema + allowList *allow.List + encryptionKey [32]byte + encryptionKeySet bool + cache Cache + queries sync.Map + roles map[string]*Role + roleStatement string + roleStatementMetadata psql.Metadata + tmap map[string]qcode.TConfig + rtmap map[string]ResolverFn + rmap map[string]resItem + abacEnabled bool + qcodeCompiler *qcode.Compiler + psqlCompiler *psql.Compiler + subs sync.Map + prod bool + prodSec bool + namespace string + printFormat []byte + opts []Option + done chan bool } type GraphJin struct { @@ -83,7 +83,7 @@ type GraphJin struct { done chan bool } -type Option func(*graphjin) error +type Option func(*GraphjinEngine) error // NewGraphJin creates the GraphJin struct, this involves querying the database to learn its // schemas and relationships @@ -104,6 +104,7 @@ func NewGraphJin(conf *Config, db *sql.DB, options ...Option) (g *GraphJin, err return } +// NewGraphJinWithFS creates the GraphJin struct, this involves querying the database to learn its func NewGraphJinWithFS(conf *Config, db *sql.DB, fs FS, options ...Option) (g *GraphJin, err error) { g = &GraphJin{done: make(chan bool)} if err = g.newGraphJin(conf, db, nil, fs, options...); err != nil { @@ -116,6 +117,7 @@ func NewGraphJinWithFS(conf *Config, db *sql.DB, fs FS, options ...Option) (g *G return } +// newGraphJinWithDBInfo creates the GraphJin struct, this involves querying the database to learn its // it all starts here func (g *GraphJin) newGraphJin(conf *Config, db *sql.DB, @@ -129,18 +131,18 @@ func (g *GraphJin) newGraphJin(conf *Config, t := time.Now() - gj := &graphjin{ - conf: conf, - db: db, - dbinfo: dbinfo, - log: _log.New(os.Stdout, "", 0), - prod: conf.Production, - prodSec: conf.Production, - pf: []byte(fmt.Sprintf("gj/%x:", t.Unix())), - opts: options, - fs: fs, - trace: &tracer{}, - done: g.done, + gj := &GraphjinEngine{ + conf: conf, + db: db, + dbinfo: dbinfo, + log: _log.New(os.Stdout, "", 0), + prod: conf.Production, + prodSec: conf.Production, + printFormat: []byte(fmt.Sprintf("gj/%x:", t.Unix())), + opts: options, + fs: fs, + trace: &tracer{}, + done: g.done, } if gj.conf.DisableProdSecurity { @@ -193,8 +195,8 @@ func (g *GraphJin) newGraphJin(conf *Config, if conf.SecretKey != "" { sk := sha256.Sum256([]byte(conf.SecretKey)) - gj.encKey = sk - gj.encKeySet = true + gj.encryptionKey = sk + gj.encryptionKeySet = true } g.Store(gj) @@ -202,28 +204,31 @@ func (g *GraphJin) newGraphJin(conf *Config, } func OptionSetNamespace(namespace string) Option { - return func(s *graphjin) error { + return func(s *GraphjinEngine) error { s.namespace = namespace return nil } } +// OptionSetFS sets the file system to be used by GraphJin func OptionSetFS(fs FS) Option { - return func(s *graphjin) error { + return func(s *GraphjinEngine) error { s.fs = fs return nil } } +// OptionSetTrace sets the tracer to be used by GraphJin func OptionSetTrace(trace Tracer) Option { - return func(s *graphjin) error { + return func(s *GraphjinEngine) error { s.trace = trace return nil } } +// OptionSetResolver sets the resolver function to be used by GraphJin func OptionSetResolver(name string, fn ResolverFn) Option { - return func(s *graphjin) error { + return func(s *GraphjinEngine) error { if s.rtmap == nil { s.rtmap = s.newRTMap() } @@ -242,8 +247,8 @@ type Error struct { // Result struct contains the output of the GraphQL function this includes resulting json from the // database query and any error information type Result struct { - ns string - op qcode.QType + namespace string + operation qcode.QType name string sql string role string @@ -256,8 +261,8 @@ type Result struct { // Extensions *extensions `json:"extensions,omitempty"` } -// ReqConfig is used to pass request specific config values to the GraphQL and Subscribe functions. Dynamic variables can be set here. -type ReqConfig struct { +// RequestConfig is used to pass request specific config values to the GraphQL and Subscribe functions. Dynamic variables can be set here. +type RequestConfig struct { ns *string // APQKey is set when using GraphJin with automatic persisted queries @@ -271,12 +276,12 @@ type ReqConfig struct { } // SetNamespace is used to set namespace requests within a single instance of GraphJin. For example queries with the same name -func (rc *ReqConfig) SetNamespace(ns string) { +func (rc *RequestConfig) SetNamespace(ns string) { rc.ns = &ns } // GetNamespace is used to get the namespace requests within a single instance of GraphJin -func (rc *ReqConfig) GetNamespace() (string, bool) { +func (rc *RequestConfig) GetNamespace() (string, bool) { if rc.ns != nil { return *rc.ns, true } @@ -294,9 +299,9 @@ func (rc *ReqConfig) GetNamespace() (string, bool) { func (g *GraphJin) GraphQL(c context.Context, query string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (res *Result, err error) { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) c1, span := gj.spanStart(c, "GraphJin Query") defer span.End() @@ -347,7 +352,7 @@ func (g *GraphJin) GraphQL(c context.Context, // if not production then save to allow list if !gj.prod && r.name != "IntrospectionQuery" { - if err = gj.saveToAllowList(resp.qc, resp.res.ns); err != nil { + if err = gj.saveToAllowList(resp.qc, resp.res.namespace); err != nil { return } } @@ -360,10 +365,10 @@ func (g *GraphJin) GraphQLTx(c context.Context, tx *sql.Tx, query string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (res *Result, err error) { if rc == nil { - rc = &ReqConfig{Tx: tx} + rc = &RequestConfig{Tx: tx} } else { rc.Tx = tx } @@ -375,9 +380,9 @@ func (g *GraphJin) GraphQLTx(c context.Context, func (g *GraphJin) GraphQLByName(c context.Context, name string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (res *Result, err error) { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) c1, span := gj.spanStart(c, "GraphJin Query") defer span.End() @@ -401,78 +406,79 @@ func (g *GraphJin) GraphQLByNameTx(c context.Context, tx *sql.Tx, name string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (res *Result, err error) { if rc == nil { - rc = &ReqConfig{Tx: tx} + rc = &RequestConfig{Tx: tx} } else { rc.Tx = tx } return g.GraphQLByName(c, name, vars, rc) } -type graphqlReq struct { - ns string - op qcode.QType - name string - query []byte - vars json.RawMessage - aschema map[string]json.RawMessage - rc *ReqConfig +type GraphqlReq struct { + namespace string + operation qcode.QType + name string + query []byte + vars json.RawMessage + aschema map[string]json.RawMessage + requestconfig *RequestConfig } -type graphqlResp struct { +type GraphqlResponse struct { res Result qc *qcode.QCode } -func (gj *graphjin) newGraphqlReq(rc *ReqConfig, +// newGraphqlReq creates a new GraphQL request +func (gj *GraphjinEngine) newGraphqlReq(rc *RequestConfig, op string, name string, query []byte, vars json.RawMessage, -) (r graphqlReq) { - r = graphqlReq{ - op: qcode.GetQTypeByName(op), - name: name, - query: query, - vars: vars, +) (r GraphqlReq) { + r = GraphqlReq{ + operation: qcode.GetQTypeByName(op), + name: name, + query: query, + vars: vars, } if rc != nil { - r.rc = rc + r.requestconfig = rc } if rc != nil && rc.ns != nil { - r.ns = *rc.ns + r.namespace = *rc.ns } else { - r.ns = gj.namespace + r.namespace = gj.namespace } return } // Set is used to set the namespace, operation type, name and query for the GraphQL request -func (r *graphqlReq) Set(item allow.Item) { - r.ns = item.Namespace - r.op = qcode.GetQTypeByName(item.Operation) +func (r *GraphqlReq) Set(item allow.Item) { + r.namespace = item.Namespace + r.operation = qcode.GetQTypeByName(item.Operation) r.name = item.Name r.query = item.Query r.aschema = item.ActionJSON } // GraphQL function is our main function it takes a GraphQL query compiles it -func (gj *graphjin) queryWithResult(c context.Context, r graphqlReq) (res *Result, err error) { +func (gj *GraphjinEngine) queryWithResult(c context.Context, r GraphqlReq) (res *Result, err error) { resp, err := gj.query(c, r) return &resp.res, err } // GraphQL function is our main function it takes a GraphQL query compiles it -func (gj *graphjin) query(c context.Context, r graphqlReq) ( - resp graphqlResp, err error, +func (gj *GraphjinEngine) query(c context.Context, r GraphqlReq) ( + resp GraphqlResponse, err error, ) { resp.res = Result{ - ns: r.ns, - op: r.op, - name: r.name, + namespace: r.namespace, + operation: r.operation, + name: r.name, } if !gj.prodSec && r.name == "IntrospectionQuery" { @@ -480,12 +486,12 @@ func (gj *graphjin) query(c context.Context, r graphqlReq) ( return } - if r.op == qcode.QTSubscription { + if r.operation == qcode.QTSubscription { err = errors.New("use 'core.Subscribe' for subscriptions") return } - if r.op == qcode.QTMutation && gj.schema.DBType() == "mysql" { + if r.operation == qcode.QTMutation && gj.schema.DBType() == "mysql" { err = errors.New("mysql: mutations not supported") return } @@ -519,15 +525,16 @@ func (g *GraphJin) Reload() error { return g.reload(nil) } +// reload redoes database discover and reinitializes GraphJin. func (g *GraphJin) reload(di *sdata.DBInfo) (err error) { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) err = g.newGraphJin(gj.conf, gj.db, di, gj.fs, gj.opts...) return } // IsProd return true for production mode or false for development mode func (g *GraphJin) IsProd() bool { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) return gj.prod } @@ -546,6 +553,7 @@ func Operation(query string) (h Header, err error) { return } +// getFS returns the file system to be used by GraphJin func getFS(conf *Config) (fs FS, err error) { if v, ok := conf.FS.(FS); ok { fs = v @@ -561,6 +569,7 @@ func getFS(conf *Config) (fs FS, err error) { return } +// newError creates a new error list func newError(err error) (errList []Error) { errList = []Error{{Message: err.Error()}} return diff --git a/core/args.go b/core/args.go index ae27b60d..755fa7d2 100644 --- a/core/args.go +++ b/core/args.go @@ -18,10 +18,10 @@ type args struct { cindx int // index of cursor arg } -func (gj *graphjin) argList(c context.Context, +func (gj *GraphjinEngine) argList(c context.Context, md psql.Metadata, fields map[string]json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, buildJSON bool, ) (ar args, err error) { ar = args{cindx: -1} diff --git a/core/cache.go b/core/cache.go index a605d13b..52a03d64 100644 --- a/core/cache.go +++ b/core/cache.go @@ -8,11 +8,13 @@ type Cache struct { cache *lru.TwoQueueCache } -func (gj *graphjin) initCache() (err error) { +// initCache initializes the cache +func (gj *GraphjinEngine) initCache() (err error) { gj.cache.cache, err = lru.New2Q(500) return } +// Get returns the value from the cache func (c Cache) Get(key string) (val []byte, fromCache bool) { if v, ok := c.cache.Get(key); ok { val = v.([]byte) @@ -21,6 +23,7 @@ func (c Cache) Get(key string) (val []byte, fromCache bool) { return } +// Set sets the value in the cache func (c Cache) Set(key string, val []byte) { c.cache.Add(key, val) } diff --git a/core/config.go b/core/config.go index e84f0011..c13e56ed 100644 --- a/core/config.go +++ b/core/config.go @@ -267,7 +267,7 @@ type ResolverReq struct { ID string Sel *qcode.Select Log *log.Logger - *ReqConfig + *RequestConfig } // AddRoleTable function is a helper function to make it easy to add per-table diff --git a/core/core.go b/core/core.go index 8febb823..b00ef3f4 100644 --- a/core/core.go +++ b/core/core.go @@ -56,7 +56,7 @@ const ( // Duration time.Duration `json:"duration"` // } -func (gj *graphjin) getIntroResult() (data json.RawMessage, err error) { +func (gj *GraphjinEngine) getIntroResult() (data json.RawMessage, err error) { var ok bool if data, ok = gj.cache.Get("_intro"); ok { return @@ -69,7 +69,7 @@ func (gj *graphjin) getIntroResult() (data json.RawMessage, err error) { } // Initializes the database discovery process on graphjin -func (gj *graphjin) initDiscover() (err error) { +func (gj *GraphjinEngine) initDiscover() (err error) { switch gj.conf.DBType { case "": gj.dbtype = "postgres" @@ -86,7 +86,7 @@ func (gj *graphjin) initDiscover() (err error) { } // Private method that does the actual database discovery for initDiscover -func (gj *graphjin) _initDiscover() (err error) { +func (gj *GraphjinEngine) _initDiscover() (err error) { if gj.prod && gj.conf.EnableSchema { b, err := gj.fs.Get("db.graphql") if err != nil { @@ -132,14 +132,14 @@ func (gj *graphjin) _initDiscover() (err error) { } // Initializes the database schema on graphjin -func (gj *graphjin) initSchema() error { +func (gj *GraphjinEngine) initSchema() error { if err := gj._initSchema(); err != nil { return fmt.Errorf("%s: %w", gj.dbtype, err) } return nil } -func (gj *graphjin) _initSchema() (err error) { +func (gj *GraphjinEngine) _initSchema() (err error) { if len(gj.dbinfo.Tables) == 0 { return fmt.Errorf("no tables found in database") } @@ -178,7 +178,7 @@ func (gj *graphjin) _initSchema() (err error) { return } -func (gj *graphjin) initIntro() (err error) { +func (gj *GraphjinEngine) initIntro() (err error) { if !gj.prod && gj.conf.EnableIntrospection { var introJSON json.RawMessage introJSON, err = gj.getIntroResult() @@ -194,7 +194,7 @@ func (gj *graphjin) initIntro() (err error) { } // Initializes the qcode compilers -func (gj *graphjin) initCompilers() (err error) { +func (gj *GraphjinEngine) initCompilers() (err error) { qcc := qcode.Config{ TConfig: gj.tmap, DefaultBlock: gj.conf.DefaultBlock, @@ -206,29 +206,29 @@ func (gj *graphjin) initCompilers() (err error) { Validators: valid.Validators, } - gj.qc, err = qcode.NewCompiler(gj.schema, qcc) + gj.qcodeCompiler, err = qcode.NewCompiler(gj.schema, qcc) if err != nil { return } - if err = addRoles(gj.conf, gj.qc); err != nil { + if err = addRoles(gj.conf, gj.qcodeCompiler); err != nil { return } - gj.pc = psql.NewCompiler(psql.Config{ + gj.psqlCompiler = psql.NewCompiler(psql.Config{ Vars: gj.conf.Vars, DBType: gj.schema.DBType(), DBVersion: gj.schema.DBVersion(), - SecPrefix: gj.pf, + SecPrefix: gj.printFormat, EnableCamelcase: gj.conf.EnableCamelcase, }) return } -func (gj *graphjin) executeRoleQuery(c context.Context, +func (gj *GraphjinEngine) executeRoleQuery(c context.Context, conn *sql.Conn, vmap map[string]json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (role string, err error) { if c.Value(UserIDKey) == nil { role = "anon" @@ -236,7 +236,7 @@ func (gj *graphjin) executeRoleQuery(c context.Context, } ar, err := gj.argList(c, - gj.roleStmtMD, + gj.roleStatementMetadata, vmap, rc, false) @@ -266,9 +266,9 @@ func (gj *graphjin) executeRoleQuery(c context.Context, err = retryOperation(c1, func() error { var row *sql.Row if rc != nil && rc.Tx != nil { - row = rc.Tx.QueryRowContext(c1, gj.roleStmt, ar.values...) + row = rc.Tx.QueryRowContext(c1, gj.roleStatement, ar.values...) } else { - row = conn.QueryRowContext(c1, gj.roleStmt, ar.values...) + row = conn.QueryRowContext(c1, gj.roleStatement, ar.values...) } return row.Scan(&role) }) @@ -283,7 +283,7 @@ func (gj *graphjin) executeRoleQuery(c context.Context, // Returns the operation type for the query result func (r *Result) Operation() OpType { - switch r.op { + switch r.operation { case qcode.QTQuery: return OpQuery @@ -297,12 +297,12 @@ func (r *Result) Operation() OpType { // Returns the namespace for the query result func (r *Result) Namespace() string { - return r.ns + return r.namespace } // Returns the operation name for the query result func (r *Result) OperationName() string { - return r.op.String() + return r.operation.String() } // Returns the query name for the query result @@ -368,6 +368,7 @@ func (r *Result) CacheControl() string { // append(c.res.Extensions.Tracing.Execution.Resolvers, tr) // } +// debugLogStmt logs the query statement for debugging func (s *gstate) debugLogStmt() { st := s.cs.st @@ -386,7 +387,7 @@ func (s *gstate) debugLogStmt() { } // Saved the query qcode to the allow list -func (gj *graphjin) saveToAllowList(qc *qcode.QCode, ns string) (err error) { +func (gj *GraphjinEngine) saveToAllowList(qc *qcode.QCode, ns string) (err error) { if gj.conf.DisableAllowList { return nil } @@ -416,7 +417,7 @@ func (gj *graphjin) saveToAllowList(qc *qcode.QCode, ns string) (err error) { } // Starts tracing with the given name -func (gj *graphjin) spanStart(c context.Context, name string) (context.Context, Spaner) { +func (gj *GraphjinEngine) spanStart(c context.Context, name string) (context.Context, Spaner) { return gj.trace.Start(c, name) } diff --git a/core/crypt.go b/core/crypt.go index bb8c9342..50277977 100644 --- a/core/crypt.go +++ b/core/crypt.go @@ -7,6 +7,11 @@ import ( "encoding/base64" ) +// encryptValues encrypts the values in the data using the given key +// data: the data to encrypt +// encPrefix: the prefix to search for the values to encrypt +// decPrefix: the prefix to replace the values with +// nonce: the nonce to use for encryption func encryptValues( data, encPrefix, decPrefix, nonce []byte, key [32]byte) ([]byte, error) { @@ -78,6 +83,10 @@ func encryptValues( return b.Bytes(), nil } +// decryptValues decrypts the values in the data using the given key +// data: the data to decrypt +// prefix: the prefix to search for the values to decrypt +// key: the key to use for decryption func decryptValues(data, prefix []byte, key [32]byte) ([]byte, error) { var s, e int if e = bytes.Index(data[s:], prefix); e == -1 { @@ -151,6 +160,7 @@ func decryptValues(data, prefix []byte, key [32]byte) ([]byte, error) { return b.Bytes(), nil } +// firstCursorValue returns the first cursor value in the data func firstCursorValue(data []byte, prefix []byte) []byte { var buf [100]byte pf := append(buf[:0], prefix...) diff --git a/core/gstate.go b/core/gstate.go index bb5ed17f..8d305bfd 100644 --- a/core/gstate.go +++ b/core/gstate.go @@ -16,8 +16,8 @@ import ( ) type gstate struct { - gj *graphjin - r graphqlReq + gj *GraphjinEngine + r GraphqlReq cs *cstate vmap map[string]json.RawMessage data []byte @@ -40,7 +40,7 @@ type stmt struct { sql string } -func newGState(c context.Context, gj *graphjin, r graphqlReq) (s gstate, err error) { +func newGState(c context.Context, gj *GraphjinEngine, r GraphqlReq) (s gstate, err error) { s.gj = gj s.r = r @@ -58,7 +58,7 @@ func newGState(c context.Context, gj *graphjin, r graphqlReq) (s gstate, err err // convert variable json to a go map also decrypted encrypted values if len(r.vars) != 0 { var vars json.RawMessage - vars, err = decryptValues(r.vars, decPrefix, s.gj.encKey) + vars, err = decryptValues(r.vars, decPrefix, s.gj.encryptionKey) if err != nil { return } @@ -113,16 +113,16 @@ func (s *gstate) compileQueryForRole() (err error) { vars = s.vmap } - if st.qc, err = s.gj.qc.Compile( + if st.qc, err = s.gj.qcodeCompiler.Compile( s.r.query, vars, s.role, - s.r.ns); err != nil { + s.r.namespace); err != nil { return } var w bytes.Buffer - if st.md, err = s.gj.pc.Compile(&w, st.qc); err != nil { + if st.md, err = s.gj.psqlCompiler.Compile(&w, st.qc); err != nil { return } @@ -254,7 +254,7 @@ func (s *gstate) execute(c context.Context, conn *sql.Conn) (err error) { if span.IsRecording() { span.SetAttributesString( - StringAttr{"query.namespace", s.r.ns}, + StringAttr{"query.namespace", s.r.namespace}, StringAttr{"query.operation", cs.st.qc.Type.String()}, StringAttr{"query.name", cs.st.qc.Name}, StringAttr{"query.role", cs.st.role}) @@ -270,25 +270,25 @@ func (s *gstate) execute(c context.Context, conn *sql.Conn) (err error) { s.dhash = sha256.Sum256(s.data) s.data, err = encryptValues(s.data, - s.gj.pf, decPrefix, s.dhash[:], s.gj.encKey) + s.gj.printFormat, decPrefix, s.dhash[:], s.gj.encryptionKey) return } func (s *gstate) executeRoleQuery(c context.Context, conn *sql.Conn) (err error) { - s.role, err = s.gj.executeRoleQuery(c, conn, s.vmap, s.r.rc) + s.role, err = s.gj.executeRoleQuery(c, conn, s.vmap, s.r.requestconfig) return } func (s *gstate) argList(c context.Context) (args args, err error) { - args, err = s.gj.argList(c, s.cs.st.md, s.vmap, s.r.rc, false) + args, err = s.gj.argList(c, s.cs.st.md, s.vmap, s.r.requestconfig, false) return } func (s *gstate) argListForSub(c context.Context, vmap map[string]json.RawMessage, ) (args args, err error) { - args, err = s.gj.argList(c, s.cs.st.md, vmap, s.r.rc, true) + args, err = s.gj.argList(c, s.cs.st.md, vmap, s.r.requestconfig, true) return } @@ -354,13 +354,13 @@ func (s *gstate) qcode() (qc *qcode.QCode) { } func (s *gstate) tx() (tx *sql.Tx) { - if s.r.rc != nil { - tx = s.r.rc.Tx + if s.r.requestconfig != nil { + tx = s.r.requestconfig.Tx } return } func (s *gstate) key() (key string) { - key = s.r.ns + s.r.name + s.role + key = s.r.namespace + s.r.name + s.role return } diff --git a/core/init.go b/core/init.go index aa20d408..9fcbb4a2 100644 --- a/core/init.go +++ b/core/init.go @@ -11,17 +11,17 @@ import ( ) // Initializes the graphjin instance with the config -func (gj *graphjin) initConfig() error { +func (gj *GraphjinEngine) initConfig() error { c := gj.conf - tm := make(map[string]struct{}) + tableMap := make(map[string]struct{}) - for _, t := range c.Tables { - k := t.Schema + t.Name - if _, ok := tm[k]; ok { - return fmt.Errorf("duplicate table found: %s", t.Name) + for _, table := range c.Tables { + k := table.Schema + table.Name + if _, ok := tableMap[k]; ok { + return fmt.Errorf("duplicate table found: %s", table.Name) } - tm[k] = struct{}{} + tableMap[k] = struct{}{} } for k, v := range c.Vars { @@ -84,7 +84,8 @@ func (gj *graphjin) initConfig() error { return nil } -func (gj *graphjin) addTableInfo(t Table) error { +// addTableInfo adds table info to the compiler +func (gj *GraphjinEngine) addTableInfo(t Table) error { obm := map[string][][2]string{} for k, ob := range t.OrderBy { @@ -103,6 +104,7 @@ func (gj *graphjin) addTableInfo(t Table) error { return nil } +// getDBTableAliases returns a map of table aliases func getDBTableAliases(c *Config) map[string][]string { m := make(map[string][]string, len(c.Tables)) @@ -116,7 +118,8 @@ func getDBTableAliases(c *Config) map[string][]string { return m } -func addTables(conf *Config, di *sdata.DBInfo) error { +// addTables adds tables to the database info +func addTables(conf *Config, dbInfo *sdata.DBInfo) error { var err error for _, t := range conf.Tables { @@ -126,13 +129,13 @@ func addTables(conf *Config, di *sdata.DBInfo) error { } switch t.Type { case "json", "jsonb": - err = addJsonTable(conf, di, t) + err = addJsonTable(conf, dbInfo, t) case "polymorphic": - err = addVirtualTable(conf, di, t) + err = addVirtualTable(conf, dbInfo, t) default: - err = updateTable(conf, di, t) + err = updateTable(conf, dbInfo, t) } if err != nil { @@ -143,14 +146,15 @@ func addTables(conf *Config, di *sdata.DBInfo) error { return nil } -func updateTable(conf *Config, di *sdata.DBInfo, t Table) error { - t1, err := di.GetTable(t.Schema, t.Name) +// updateTable updates the table info in the database info +func updateTable(conf *Config, dbInfo *sdata.DBInfo, table Table) error { + t1, err := dbInfo.GetTable(table.Schema, table.Name) if err != nil { return fmt.Errorf("table: %w", err) } - for _, c := range t.Columns { - c1, err := di.GetColumn(t.Schema, t.Name, c.Name) + for _, c := range table.Columns { + c1, err := dbInfo.GetColumn(table.Schema, table.Name, c.Name) if err != nil { return err } @@ -168,18 +172,19 @@ func updateTable(conf *Config, di *sdata.DBInfo, t Table) error { return nil } -func addJsonTable(conf *Config, di *sdata.DBInfo, t Table) error { +// addJsonTable adds a json table to the database info +func addJsonTable(conf *Config, dbInfo *sdata.DBInfo, table Table) error { // This is for jsonb column that want to be a table. - if t.Table == "" { - return fmt.Errorf("json table: set the 'table' for column '%s'", t.Name) + if table.Table == "" { + return fmt.Errorf("json table: set the 'table' for column '%s'", table.Name) } - bc, err := di.GetColumn(t.Schema, t.Table, t.Name) + bc, err := dbInfo.GetColumn(table.Schema, table.Table, table.Name) if err != nil { return fmt.Errorf("json table: %w", err) } - bt, err := di.GetTable(bc.Schema, bc.Table) + bt, err := dbInfo.GetTable(bc.Schema, bc.Table) if err != nil { return fmt.Errorf("json table: %w", err) } @@ -187,23 +192,23 @@ func addJsonTable(conf *Config, di *sdata.DBInfo, t Table) error { if bc.Type != "json" && bc.Type != "jsonb" { return fmt.Errorf( "json table: column '%s' in table '%s' is of type '%s'. Only JSON or JSONB is valid", - t.Name, t.Table, bc.Type) + table.Name, table.Table, bc.Type) } - columns := make([]sdata.DBColumn, 0, len(t.Columns)) + columns := make([]sdata.DBColumn, 0, len(table.Columns)) - for i := range t.Columns { - c := t.Columns[i] + for i := range table.Columns { + c := table.Columns[i] columns = append(columns, sdata.DBColumn{ ID: -1, Schema: bc.Schema, - Table: t.Name, + Table: table.Name, Name: c.Name, Type: c.Type, }) if c.Type == "" { return fmt.Errorf("json table: type parameter missing for column: %s.%s'", - t.Name, c.Name) + table.Name, c.Name) } } @@ -216,14 +221,15 @@ func addJsonTable(conf *Config, di *sdata.DBInfo, t Table) error { Type: bc.Type, } - nt := sdata.NewDBTable(bc.Schema, t.Name, bc.Type, columns) + nt := sdata.NewDBTable(bc.Schema, table.Name, bc.Type, columns) nt.PrimaryCol = col1 nt.SecondaryCol = bt.PrimaryCol - di.AddTable(nt) + dbInfo.AddTable(nt) return nil } +// addVirtualTable adds a virtual table to the database info func addVirtualTable(conf *Config, di *sdata.DBInfo, t Table) error { if len(t.Columns) == 0 { return fmt.Errorf("polymorphic table: no id column specified") @@ -249,6 +255,7 @@ func addVirtualTable(conf *Config, di *sdata.DBInfo, t Table) error { return nil } +// addForeignKeys adds foreign keys to the database info func addForeignKeys(conf *Config, di *sdata.DBInfo) error { for _, t := range conf.Tables { if t.Type == "polymorphic" { @@ -266,6 +273,7 @@ func addForeignKeys(conf *Config, di *sdata.DBInfo) error { return nil } +// addForeignKey adds a foreign key to the database info func addForeignKey(conf *Config, di *sdata.DBInfo, c Column, t Table) error { c1, err := di.GetColumn(t.Schema, t.Name, c.Name) if err != nil { @@ -310,6 +318,7 @@ func addForeignKey(conf *Config, di *sdata.DBInfo, c Column, t Table) error { return nil } +// addRoles adds roles to the compiler func addRoles(c *Config, qc *qcode.Compiler) error { for _, r := range c.Roles { for _, t := range r.Tables { @@ -322,6 +331,7 @@ func addRoles(c *Config, qc *qcode.Compiler) error { return nil } +// addRole adds a role to the compiler func addRole(qc *qcode.Compiler, r Role, t RoleTable, defaultBlock bool) error { ro := false // read-only @@ -392,10 +402,12 @@ func addRole(qc *qcode.Compiler, r Role, t RoleTable, defaultBlock bool) error { }) } +// GetTable returns a table from the role func (r *Role) GetTable(schema, name string) *RoleTable { return r.tm[name] } +// getFK returns the foreign key for the column func (c *Column) getFK(defaultSchema string) ([3]string, bool) { var ret [3]string var ok bool @@ -412,10 +424,12 @@ func (c *Column) getFK(defaultSchema string) ([3]string, bool) { return ret, ok } +// sanitize trims the value func sanitize(value string) string { return strings.TrimSpace(value) } +// isASCII checks if the string is ASCII func isASCII(s string) (int, bool) { for i := 0; i < len(s); i++ { if s[i] > unicode.MaxASCII { @@ -425,7 +439,8 @@ func isASCII(s string) (int, bool) { return -1, true } -func (gj *graphjin) initAllowList() (err error) { +// initAllowList initializes the allow list +func (gj *GraphjinEngine) initAllowList() (err error) { gj.allowList, err = allow.New( gj.log, gj.fs, diff --git a/core/internal/allow/allow.go b/core/internal/allow/allow.go index 3df9fb0b..b069bef9 100644 --- a/core/internal/allow/allow.go +++ b/core/internal/allow/allow.go @@ -22,7 +22,7 @@ type FS interface { var ErrUnknownGraphQLQuery = errors.New("unknown graphql query") const ( - queryPath = "/queries" + QUERY_PATH = "/queries" ) type Item struct { @@ -45,6 +45,7 @@ type List struct { fs FS } +// New creates a new allow list func New(log *_log.Logger, fs FS, readOnly bool) (al *List, err error) { if fs == nil { return nil, fmt.Errorf("no filesystem defined for the allow list") @@ -77,6 +78,7 @@ func New(log *_log.Logger, fs FS, readOnly bool) (al *List, err error) { return al, err } +// Set adds a new query to the allow list func (al *List) Set(item Item) error { if al.saveChan == nil { return errors.New("allow list is read-only") @@ -90,6 +92,7 @@ func (al *List) Set(item Item) error { return nil } +// GetByName returns a query by name func (al *List) GetByName(name string, useCache bool) (item Item, err error) { if useCache { if v, ok := al.cache.Get(name); ok { @@ -98,26 +101,27 @@ func (al *List) GetByName(name string, useCache bool) (item Item, err error) { } } - fp := filepath.Join(queryPath, name) + fp := filepath.Join(QUERY_PATH, name) var ok bool if ok, err = al.fs.Exists((fp + ".gql")); err != nil { return } else if ok { - item, err = al.get(queryPath, name, ".gql", useCache) + item, err = al.get(QUERY_PATH, name, ".gql", useCache) return } if ok, err = al.fs.Exists((fp + ".graphql")); err != nil { return } else if ok { - item, err = al.get(queryPath, name, ".gql", useCache) + item, err = al.get(QUERY_PATH, name, ".gql", useCache) } else { err = ErrUnknownGraphQLQuery } return } +// get returns a query by name func (al *List) get(queryPath, name, ext string, useCache bool) (item Item, err error) { queryNS, queryName := splitName(name) @@ -161,6 +165,7 @@ func (al *List) get(queryPath, name, ext string, useCache bool) (item Item, err return } +// save saves a query to the allow list func (al *List) save(item Item) (err error) { item.Name = strings.TrimSpace(item.Name) if item.Name == "" { @@ -170,6 +175,7 @@ func (al *List) save(item Item) (err error) { return al.saveItem(item) } +// saveItem saves a query to the allow list func (al *List) saveItem(item Item) (err error) { var queryFile string if item.Namespace != "" { @@ -196,7 +202,7 @@ func (al *List) saveItem(item Item) (err error) { fmap[fragFile] = struct{}{} } - ff := filepath.Join(queryPath, "fragments", (fragFile + ".gql")) + ff := filepath.Join(QUERY_PATH, "fragments", (fragFile + ".gql")) err = al.fs.Put(ff, []byte(f.Value)) if err != nil { return @@ -207,7 +213,7 @@ func (al *List) saveItem(item Item) (err error) { } buf.Write(bytes.TrimSpace(item.Query)) - qf := filepath.Join(queryPath, (queryFile + ".gql")) + qf := filepath.Join(QUERY_PATH, (queryFile + ".gql")) err = al.fs.Put(qf, bytes.TrimSpace(buf.Bytes())) if err != nil { return @@ -215,7 +221,7 @@ func (al *List) saveItem(item Item) (err error) { if len(item.ActionJSON) != 0 { var vars []byte - jf := filepath.Join(queryPath, (queryFile + ".json")) + jf := filepath.Join(QUERY_PATH, (queryFile + ".json")) vars, err = json.MarshalIndent(item.ActionJSON, "", " ") if err != nil { return @@ -225,12 +231,13 @@ func (al *List) saveItem(item Item) (err error) { return } -func splitName(v string) (string, string) { - i := strings.LastIndex(v, ".") +// splitName splits a name into namespace and name +func splitName(name string) (string, string) { + i := strings.LastIndex(name, ".") if i == -1 { - return "", v - } else if i < len(v)-1 { - return v[:i], v[(i + 1):] + return "", name + } else if i < len(name)-1 { + return name[:i], name[(i + 1):] } return "", "" } diff --git a/core/internal/allow/gql.go b/core/internal/allow/gql.go index 5ccf52b5..a18301f9 100644 --- a/core/internal/allow/gql.go +++ b/core/internal/allow/gql.go @@ -10,6 +10,7 @@ import ( var incRe = regexp.MustCompile(`(?m)#import \"(.+)\"`) +// readGQL reads a graphql file and resolves all imports func readGQL(fs FS, fname string) (gql []byte, err error) { var b bytes.Buffer @@ -28,6 +29,7 @@ func readGQL(fs FS, fname string) (gql []byte, err error) { return } +// parseGQL parses a graphql file and resolves all imports func parseGQL(fs FS, fname string, r io.Writer) (err error) { b, err := fs.Get(fname) if err != nil { diff --git a/core/internal/assert/assert.go b/core/internal/assert/assert.go index 73f889fc..f9377577 100644 --- a/core/internal/assert/assert.go +++ b/core/internal/assert/assert.go @@ -5,12 +5,14 @@ import ( "testing" ) +// Equals compares two values func Equals(t *testing.T, exp, got interface{}) { if !reflect.DeepEqual(exp, got) { t.Errorf("expected %v, got %v", exp, got) } } +// Empty checks if a slice is empty func Empty(t *testing.T, got interface{}) { val := reflect.ValueOf(got) if val.Kind() != reflect.Slice { @@ -24,12 +26,14 @@ func Empty(t *testing.T, got interface{}) { } } +// NoError checks if an error is nil func NoError(t *testing.T, err error) { if err != nil { t.Errorf("no errror expected, got %s", err.Error()) } } +// NoErrorFatal checks if an error is nil and fails the test func NoErrorFatal(t *testing.T, err error) { if err != nil { t.Fatalf("no errror expected, got %s", err.Error()) diff --git a/core/internal/graph/lex.go b/core/internal/graph/lex.go index bdb481cd..8257c6fc 100644 --- a/core/internal/graph/lex.go +++ b/core/internal/graph/lex.go @@ -79,8 +79,8 @@ var punctuators = map[rune]MType{ const eof = -1 -// stateFn represents the state of the scanner as a function that returns the next state. -type stateFn func(*lexer) stateFn +// StateFn represents the state of the scanner as a function that returns the next state. +type StateFn func(*lexer) StateFn // lexer holds the state of the scanner. type lexer struct { @@ -96,6 +96,7 @@ type lexer struct { var zeroLex = lexer{} +// Reset resets the lexer to scan a new input string. func (l *lexer) Reset() { *l = zeroLex } @@ -133,6 +134,7 @@ func (l *lexer) backup() { } } +// current returns the current bytes of the input. func (l *lexer) current() []byte { return l.input[l.start:l.pos] } @@ -151,6 +153,7 @@ func (l *lexer) emit(t MType) { l.start = l.pos } +// emitL passes an item back to the client and lowercases the value. func (l *lexer) emitL(t MType) { lowercase(l.current()) l.emit(t) @@ -199,7 +202,7 @@ func (l *lexer) acceptRun(valid []byte) { // errorf returns an error token and terminates the scan by passing // back a nil pointer that will be the next state, terminating l.nextItem. -func (l *lexer) errorf(format string, args ...interface{}) stateFn { +func (l *lexer) errorf(format string, args ...interface{}) StateFn { l.err = fmt.Errorf(format, args...) l.items = append(l.items, item{itemError, l.start, l.input[l.start:l.pos], l.line}) return nil @@ -233,7 +236,7 @@ func (l *lexer) run() { } // lexInsideAction scans the elements inside action delimiters. -func lexRoot(l *lexer) stateFn { +func lexRoot(l *lexer) StateFn { r := l.next() switch { @@ -287,7 +290,7 @@ func lexRoot(l *lexer) stateFn { } // lexName scans a name. -func lexName(l *lexer) stateFn { +func lexName(l *lexer) StateFn { for { r := l.next() @@ -317,7 +320,7 @@ func lexName(l *lexer) stateFn { } // lexString scans a string. -func lexString(l *lexer) stateFn { +func lexString(l *lexer) StateFn { if sr, ok := l.accept([]byte(quotesToken)); ok { l.ignore() @@ -351,7 +354,7 @@ func lexString(l *lexer) stateFn { // lexNumber scans a number: decimal and float. This isn't a perfect number scanner // for instance it accepts "." and "0x0.2" and "089" - but when it's wrong the input // is invalid and the parser (via strconv) should notice. -func lexNumber(l *lexer) stateFn { +func lexNumber(l *lexer) StateFn { if !l.scanNumber() { return l.errorf("bad number syntax: %q", l.input[l.start:l.pos]) } @@ -359,6 +362,7 @@ func lexNumber(l *lexer) stateFn { return lexRoot } +// scanNumber scans a number: decimal and float. func (l *lexer) scanNumber() bool { // Optional leading sign. l.accept(signsToken) @@ -391,14 +395,17 @@ func isAlphaNumeric(r rune) bool { return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r) } +// equals reports whether b is equal to val. func equals(b, val []byte) bool { return bytes.EqualFold(b, val) } +// contains reports whether b contains any of the chars. func contains(b []byte, chars string) bool { return bytes.ContainsAny(b, chars) } +// lowercase lowercases the bytes in b. func lowercase(b []byte) { for i := 0; i < len(b); i++ { if b[i] >= 'A' && b[i] <= 'Z' { @@ -407,6 +414,7 @@ func lowercase(b []byte) { } } +// String returns a string representation of the item. func (i item) String() string { var v string diff --git a/core/internal/graph/utils.go b/core/internal/graph/utils.go index 50470266..6b3f78cc 100644 --- a/core/internal/graph/utils.go +++ b/core/internal/graph/utils.go @@ -13,6 +13,7 @@ type FPInfo struct { Name string } +// FastParse parses the query and returns the operation type and name func FastParse(gql string) (h FPInfo, err error) { if gql == "" { return h, errors.New("query missing or empty") @@ -20,6 +21,7 @@ func FastParse(gql string) (h FPInfo, err error) { return fastParse(strings.NewReader(gql)) } +// FastParseBytes parses the query and returns the operation type and name func FastParseBytes(gql []byte) (h FPInfo, err error) { if len(gql) == 0 { return h, errors.New("query missing or empty") @@ -27,6 +29,7 @@ func FastParseBytes(gql []byte) (h FPInfo, err error) { return fastParse(bytes.NewReader(gql)) } +// fastParse parses the query and returns the operation type and name func fastParse(r io.Reader) (h FPInfo, err error) { var s scanner.Scanner s.Init(r) diff --git a/core/internal/sdata/dwg.go b/core/internal/sdata/dwg.go index 9aa208e8..feb20fb6 100644 --- a/core/internal/sdata/dwg.go +++ b/core/internal/sdata/dwg.go @@ -14,6 +14,7 @@ var ( ErrThoughNodeNotFound = errors.New("though node not found") ) +// TEdge represents a table edge for the graph type TEdge struct { From, To, Weight int32 @@ -24,32 +25,36 @@ type TEdge struct { name string } +// addNode adds a table node to the graph func (s *DBSchema) addNode(t DBTable) int32 { s.tables = append(s.tables, t) - n := s.rg.AddNode() + n := s.relationshipGraph.AddNode() s.tindex[(t.Schema + ":" + t.Name)] = nodeInfo{n} return n } +// addAliases adds table aliases to the graph func (s *DBSchema) addAliases(t DBTable, nodeID int32, aliases []string) { for _, al := range aliases { s.tindex[(t.Schema + ":" + al)] = nodeInfo{nodeID} - s.ai[al] = nodeInfo{nodeID} + s.tableAliasIndex[al] = nodeInfo{nodeID} } } +// GetAliases returns a map of table aliases func (s *DBSchema) GetAliases() map[string]DBTable { ts := make(map[string]DBTable) - for name, n := range s.ai { + for name, n := range s.tableAliasIndex { ts[name] = s.tables[int(n.nodeID)] } return ts } +// IsAlias checks if a table is an alias func (s *DBSchema) IsAlias(name string) bool { - _, ok := s.ai[name] + _, ok := s.tableAliasIndex[name] return ok } @@ -145,11 +150,11 @@ func (s *DBSchema) addToGraph( return err } - if err := s.rg.UpdateEdge(ln, rn, edgeID1, edgeID2); err != nil { + if err := s.relationshipGraph.UpdateEdge(ln, rn, edgeID1, edgeID2); err != nil { return err } - if err := s.rg.UpdateEdge(rn, ln, edgeID2, edgeID1); err != nil { + if err := s.relationshipGraph.UpdateEdge(rn, ln, edgeID2, edgeID1); err != nil { return err } @@ -166,10 +171,11 @@ func (s *DBSchema) addToGraph( return nil } +// addEdge creates a relationship between two tables func (s *DBSchema) addEdge(name string, edge TEdge, inSchema bool, ) (int32, error) { // add edge to graph - edgeID, err := s.rg.AddEdge(edge.From, edge.To, + edgeID, err := s.relationshipGraph.AddEdge(edge.From, edge.To, edge.Weight, edge.CName) if err != nil { return -1, err @@ -181,13 +187,14 @@ func (s *DBSchema) addEdge(name string, edge TEdge, inSchema bool, if inSchema { edge.name = name } - s.ae[edgeID] = edge + s.allEdges[edgeID] = edge return edgeID, nil } +// addEdgeInfo adds edge info to the index func (s *DBSchema) addEdgeInfo(k string, ei edgeInfo) { - if eiList, ok := s.ei[k]; ok { + if eiList, ok := s.edgesIndex[k]; ok { for i, v := range eiList { if v.nodeID != ei.nodeID { continue @@ -198,13 +205,14 @@ func (s *DBSchema) addEdgeInfo(k string, ei edgeInfo) { } } edgeIDs := append(v.edgeIDs, ei.edgeIDs[0]) - s.ei[k][i].edgeIDs = edgeIDs + s.edgesIndex[k][i].edgeIDs = edgeIDs return } } - s.ei[k] = append(s.ei[k], ei) + s.edgesIndex[k] = append(s.edgesIndex[k], ei) } +// Find returns a table by schema and name func (s *DBSchema) Find(schema, name string) (DBTable, error) { var t DBTable @@ -220,6 +228,7 @@ func (s *DBSchema) Find(schema, name string) (DBTable, error) { return s.tables[v.nodeID], nil } +// TPath represents a table path type TPath struct { Rel RelType LT DBTable @@ -228,13 +237,14 @@ type TPath struct { RC DBColumn } +// FindPath returns a path between two tables func (s *DBSchema) FindPath(from, to, through string) ([]TPath, error) { - fl, ok := s.ei[from] + fl, ok := s.edgesIndex[from] if !ok { return nil, ErrFromEdgeNotFound } - tl, ok := s.ei[to] + tl, ok := s.edgesIndex[to] if !ok { return nil, ErrToEdgeNotFound } @@ -250,7 +260,7 @@ func (s *DBSchema) FindPath(from, to, through string) ([]TPath, error) { path := []TPath{} for _, eid := range res.edges { - edge := s.ae[eid] + edge := s.allEdges[eid] path = append(path, TPath{ Rel: edge.Type, LT: edge.LT, @@ -265,11 +275,13 @@ func (s *DBSchema) FindPath(from, to, through string) ([]TPath, error) { return path, nil } +// graphResult represents a graph result type graphResult struct { from, to edgeInfo edges []int32 } +// between finds a path between two tables func (s *DBSchema) between(from, to []edgeInfo, through string) (res graphResult, err error) { // TODO: picking a path // 1. first look for a direct edge to other table @@ -288,13 +300,14 @@ func (s *DBSchema) between(from, to []edgeInfo, through string) (res graphResult return res, ErrPathNotFound } +// pickPath picks a path between two tables func (s *DBSchema) pickPath(from, to edgeInfo, through string) (res graphResult, err error) { res.from = from res.to = to fn := from.nodeID tn := to.nodeID - paths := s.rg.AllPaths(fn, tn) + paths := s.relationshipGraph.AllPaths(fn, tn) if through != "" { paths, err = s.pickThroughPath(paths, through) @@ -313,6 +326,7 @@ func (s *DBSchema) pickPath(from, to edgeInfo, through string) (res graphResult, return res, ErrPathNotFound } +// pickEdges picks edges between two tables func (s *DBSchema) pickEdges(path []int32, from, to edgeInfo) (edges []int32, allFound bool) { pathLen := len(path) peID := int32(-2) // must be -2 so does not match default -1 @@ -320,7 +334,7 @@ func (s *DBSchema) pickEdges(path []int32, from, to edgeInfo) (edges []int32, al for i := 1; i < pathLen; i++ { fn := path[i-1] tn := path[i] - lines := s.rg.GetEdges(fn, tn) + lines := s.relationshipGraph.GetEdges(fn, tn) // s.PrintLines(lines) @@ -354,6 +368,7 @@ func (s *DBSchema) pickEdges(path []int32, from, to edgeInfo) (edges []int32, al return } +// pickThroughPath picks a path through a node func (s *DBSchema) pickThroughPath(paths [][]int32, through string) ([][]int32, error) { var npaths [][]int32 @@ -376,6 +391,7 @@ func (s *DBSchema) pickThroughPath(paths [][]int32, through string) ([][]int32, return npaths, nil } +// pickLine picks a line between two tables func pickLine(lines []util.Edge, ei edgeInfo, peID int32) *util.Edge { for _, v := range lines { for _, eid := range ei.edgeIDs { @@ -387,6 +403,7 @@ func pickLine(lines []util.Edge, ei edgeInfo, peID int32) *util.Edge { return nil } +// PathToRel converts a table path to a relationship func PathToRel(p TPath) DBRel { return DBRel{ Type: p.Rel, @@ -395,6 +412,7 @@ func PathToRel(p TPath) DBRel { } } +// minWeightedLine returns the line with the minimum weight func minWeightedLine(lines []util.Edge, peID int32) *util.Edge { var min int32 = 100 var line *util.Edge @@ -413,9 +431,10 @@ func minWeightedLine(lines []util.Edge, peID int32) *util.Edge { return line } +// PrintLines prints the graph lines func (s *DBSchema) PrintLines(lines []util.Edge) { for _, v := range lines { - e := s.ae[v.ID] + e := s.allEdges[v.ID] f := s.tables[e.From] t := s.tables[e.To] @@ -425,6 +444,7 @@ func (s *DBSchema) PrintLines(lines []util.Edge) { fmt.Println("---") } +// PrintEdgeInfo prints edge info func (s *DBSchema) PrintEdgeInfo(e edgeInfo) { t := s.tables[e.nodeID] fmt.Printf("-- EdgeInfo %s %+v\n", t.Name, e.edgeIDs) @@ -434,6 +454,7 @@ func (s *DBSchema) PrintEdgeInfo(e edgeInfo) { // } } +// String returns a string representation of a table path func (tp *TPath) String() string { return fmt.Sprintf("(%s) %s ==> %s ==> (%s) %s", tp.LT.String(), tp.LC.String(), diff --git a/core/internal/sdata/schema.go b/core/internal/sdata/schema.go index 6e14bbf1..31c9fac2 100644 --- a/core/internal/sdata/schema.go +++ b/core/internal/sdata/schema.go @@ -19,18 +19,18 @@ type nodeInfo struct { } type DBSchema struct { - typ string // db type - ver int // db version - schema string // db schema - name string // db name - tables []DBTable // tables - vt map[string]VirtualTable // for polymorphic relationships - fm map[string]DBFunction // db functions - tindex map[string]nodeInfo // table index - ai map[string]nodeInfo // table alias index - ei map[string][]edgeInfo // edges index - ae map[int32]TEdge // all edges - rg *util.Graph // relationship graph + dbType string // db type + version int // db version + schema string // db schema + name string // db name + tables []DBTable // tables + virtualTables map[string]VirtualTable // for polymorphic relationships + dbFunctions map[string]DBFunction // db functions + tindex map[string]nodeInfo // table index + tableAliasIndex map[string]nodeInfo // table alias index + edgesIndex map[string][]edgeInfo // edges index + allEdges map[int32]TEdge // all edges + relationshipGraph *util.Graph // relationship graph } type RelType int @@ -46,39 +46,43 @@ const ( RelSkip ) +// DBRelLeft represents database information type DBRelLeft struct { Ti DBTable Col DBColumn } +// DBRelRight represents a database relationship type DBRelRight struct { VTable string Ti DBTable Col DBColumn } +// DBRel represents a database relationship type DBRel struct { Type RelType Left DBRelLeft Right DBRelRight } +// NewDBSchema creates a new database schema func NewDBSchema( info *DBInfo, aliases map[string][]string, ) (*DBSchema, error) { schema := &DBSchema{ - typ: info.Type, - ver: info.Version, - schema: info.Schema, - name: info.Name, - vt: make(map[string]VirtualTable), - fm: make(map[string]DBFunction), - tindex: make(map[string]nodeInfo), - ai: make(map[string]nodeInfo), - ei: make(map[string][]edgeInfo), - ae: make(map[int32]TEdge), - rg: util.NewGraph(), + dbType: info.Type, + version: info.Version, + schema: info.Schema, + name: info.Name, + virtualTables: make(map[string]VirtualTable), + dbFunctions: make(map[string]DBFunction), + tindex: make(map[string]nodeInfo), + tableAliasIndex: make(map[string]nodeInfo), + edgesIndex: make(map[string][]edgeInfo), + allEdges: make(map[int32]TEdge), + relationshipGraph: util.NewGraph(), } for _, t := range info.Tables { @@ -102,11 +106,11 @@ func NewDBSchema( // add aliases to edge index by duplicating for t, al := range aliases { for _, alias := range al { - if _, ok := schema.ei[alias]; ok { + if _, ok := schema.edgesIndex[alias]; ok { continue } - if e, ok := schema.ei[t]; ok { - schema.ei[alias] = e + if e, ok := schema.edgesIndex[t]; ok { + schema.edgesIndex[alias] = e } } } @@ -127,13 +131,14 @@ func NewDBSchema( // don't include functions that return records // as those are considered selector functions if f.Type != "record" { - schema.fm[f.Name] = info.Functions[k] + schema.dbFunctions[f.Name] = info.Functions[k] } } return schema, nil } +// addRels adds relationships to the schema func (s *DBSchema) addRels(t DBTable) error { var err error switch t.Type { @@ -152,6 +157,7 @@ func (s *DBSchema) addRels(t DBTable) error { return s.addColumnRels(t) } +// addJsonRel adds a json relationship to the schema func (s *DBSchema) addJsonRel(t DBTable) error { st, err := s.Find(t.SecondaryCol.Schema, t.SecondaryCol.Table) if err != nil { @@ -166,6 +172,7 @@ func (s *DBSchema) addJsonRel(t DBTable) error { return s.addToGraph(t, t.PrimaryCol, st, sc, RelEmbedded) } +// addPolymorphicRel adds a polymorphic relationship to the schema func (s *DBSchema) addPolymorphicRel(t DBTable) error { pt, err := s.Find(t.PrimaryCol.FKeySchema, t.PrimaryCol.FKeyTable) if err != nil { @@ -185,6 +192,7 @@ func (s *DBSchema) addPolymorphicRel(t DBTable) error { return s.addToGraph(t, t.PrimaryCol, pt, pc, RelPolymorphic) } +// addRemoteRel adds a remote relationship to the schema func (s *DBSchema) addRemoteRel(t DBTable) error { pt, err := s.Find(t.PrimaryCol.FKeySchema, t.PrimaryCol.FKeyTable) if err != nil { @@ -199,6 +207,7 @@ func (s *DBSchema) addRemoteRel(t DBTable) error { return s.addToGraph(t, t.PrimaryCol, pt, pc, RelRemote) } +// addColumnRels adds column relationships to the schema func (s *DBSchema) addColumnRels(t DBTable) error { var err error @@ -244,8 +253,9 @@ func (s *DBSchema) addColumnRels(t DBTable) error { return nil } +// addVirtual adds a virtual table to the schema func (s *DBSchema) addVirtual(vt VirtualTable) error { - s.vt[vt.Name] = vt + s.virtualTables[vt.Name] = vt for _, t := range s.tables { idCol, ok := t.getColumn(vt.IDColumn) @@ -298,22 +308,25 @@ func (s *DBSchema) addVirtual(vt VirtualTable) error { return nil } +// GetTables returns a table from the schema func (s *DBSchema) GetTables() []DBTable { return s.tables } +// RelNode represents a relationship node type RelNode struct { Name string Type RelType Table DBTable } +// GetFirstDegree returns the first degree relationships of a table func (s *DBSchema) GetFirstDegree(t DBTable) (items []RelNode, err error) { currNode, ok := s.tindex[(t.Schema + ":" + t.Name)] if !ok { return nil, fmt.Errorf("table not found: %s", t.String()) } - relatedNodes := s.rg.Connections(currNode.nodeID) + relatedNodes := s.relationshipGraph.Connections(currNode.nodeID) for _, id := range relatedNodes { v := s.getRelNodes(id, currNode.nodeID) items = append(items, v...) @@ -321,15 +334,16 @@ func (s *DBSchema) GetFirstDegree(t DBTable) (items []RelNode, err error) { return } +// GetSecondDegree returns the second degree relationships of a table func (s *DBSchema) GetSecondDegree(t DBTable) (items []RelNode, err error) { currNode, ok := s.tindex[(t.Schema + ":" + t.Name)] if !ok { return nil, fmt.Errorf("table not found: %s", t.String()) } - relatedNodes1 := s.rg.Connections(currNode.nodeID) + relatedNodes1 := s.relationshipGraph.Connections(currNode.nodeID) for _, id := range relatedNodes1 { - relatedNodes2 := s.rg.Connections(id) + relatedNodes2 := s.relationshipGraph.Connections(id) for _, id1 := range relatedNodes2 { v := s.getRelNodes(id1, id) items = append(items, v...) @@ -338,10 +352,11 @@ func (s *DBSchema) GetSecondDegree(t DBTable) (items []RelNode, err error) { return } +// getRelNodes returns the relationship nodes func (s *DBSchema) getRelNodes(fromID, toID int32) (items []RelNode) { - edges := s.rg.GetEdges(fromID, toID) + edges := s.relationshipGraph.GetEdges(fromID, toID) for _, e := range edges { - e1 := s.ae[e.ID] + e1 := s.allEdges[e.ID] if e1.name == "" { continue } @@ -351,6 +366,7 @@ func (s *DBSchema) getRelNodes(fromID, toID int32) (items []RelNode) { return } +// getColumn returns a column from a table func (ti *DBTable) getColumn(name string) (DBColumn, bool) { var c DBColumn if i, ok := ti.colMap[name]; ok { @@ -359,6 +375,7 @@ func (ti *DBTable) getColumn(name string) (DBColumn, bool) { return c, false } +// GetColumn returns a column from a table func (ti *DBTable) GetColumn(name string) (DBColumn, error) { c, ok := ti.getColumn(name) if ok { @@ -367,14 +384,17 @@ func (ti *DBTable) GetColumn(name string) (DBColumn, error) { return c, fmt.Errorf("column: '%s.%s' not found", ti.Name, name) } +// ColumnExists returns true if a column exists in a table func (ti *DBTable) ColumnExists(name string) (DBColumn, bool) { return ti.getColumn(name) } +// GetFunction returns a function from the schema func (s *DBSchema) GetFunctions() map[string]DBFunction { - return s.fm + return s.dbFunctions } +// GetRelName returns the relationship name func GetRelName(colName string) string { cn := colName @@ -397,18 +417,22 @@ func GetRelName(colName string) string { return cn } +// DBType returns the database type func (s *DBSchema) DBType() string { - return s.typ + return s.dbType } +// DBVersion returns the database version func (s *DBSchema) DBVersion() int { - return s.ver + return s.version } +// DBSchema returns the database schema func (s *DBSchema) DBSchema() string { return s.schema } +// DBName returns the database name func (s *DBSchema) DBName() string { return s.name } diff --git a/core/internal/sdata/strings.go b/core/internal/sdata/strings.go index 956ac112..5c45a848 100644 --- a/core/internal/sdata/strings.go +++ b/core/internal/sdata/strings.go @@ -5,10 +5,12 @@ import ( "strings" ) +// String returns a string representation of the DBTable func (ti *DBTable) String() string { return ti.Schema + "." + ti.Name } +// String returns a string representation of the DBColumn func (col DBColumn) String() string { var sb strings.Builder @@ -23,6 +25,7 @@ func (col DBColumn) String() string { return sb.String() } +// String returns a string representation of the DBFunction func (fn DBFunction) String() string { var sb strings.Builder @@ -50,6 +53,7 @@ func (fn DBFunction) String() string { return sb.String() } +// String returns a string representation of the DBRel func (re *DBRel) String() string { return fmt.Sprintf("'%s' --(%s)--> '%s'", re.Left.Col.String(), diff --git a/core/internal/sdata/tables.go b/core/internal/sdata/tables.go index 8be92cb9..cf397b9f 100644 --- a/core/internal/sdata/tables.go +++ b/core/internal/sdata/tables.go @@ -10,6 +10,7 @@ import ( "golang.org/x/sync/errgroup" ) +// DBInfo holds the database schema information type DBInfo struct { Type string Version int @@ -24,6 +25,7 @@ type DBInfo struct { hash int } +// DBTable holds the database table information type DBTable struct { Comment string Schema string @@ -38,6 +40,7 @@ type DBTable struct { colMap map[string]int } +// VirtualTable holds the virtual table information type VirtualTable struct { Name string IDColumn string @@ -45,6 +48,7 @@ type VirtualTable struct { FKeyColumn string } +// GetDBInfo returns the database schema information func GetDBInfo( db *sql.DB, dbType string, @@ -101,6 +105,7 @@ func GetDBInfo( return di, nil } +// NewDBInfo returns a new DBInfo object func NewDBInfo( dbType string, dbVersion int, @@ -176,6 +181,7 @@ func NewDBInfo( return di } +// NewDBTable returns a new DBTable object func NewDBTable(schema, name, _type string, cols []DBColumn) DBTable { ti := DBTable{ Schema: schema, @@ -202,6 +208,7 @@ func NewDBTable(schema, name, _type string, cols []DBColumn) DBTable { return ti } +// AddTable adds a table to the DBInfo object func (di *DBInfo) AddTable(t DBTable) { for i, c := range t.Columns { di.colMap[(c.Schema + ":" + c.Table + ":" + c.Name)] = i @@ -212,6 +219,7 @@ func (di *DBInfo) AddTable(t DBTable) { di.tableMap[(t.Schema + ":" + t.Name)] = i } +// GetTable returns a table from the DBInfo object func (di *DBInfo) GetColumn(schema, table, column string) (*DBColumn, error) { t, err := di.GetTable(schema, table) if err != nil { @@ -226,6 +234,7 @@ func (di *DBInfo) GetColumn(schema, table, column string) (*DBColumn, error) { return &t.Columns[cid], nil } +// GetTable returns a table from the DBInfo object func (di *DBInfo) GetTable(schema, table string) (*DBTable, error) { tid, ok := di.tableMap[(schema + ":" + table)] if !ok { @@ -235,6 +244,7 @@ func (di *DBInfo) GetTable(schema, table string) (*DBTable, error) { return &di.Tables[tid], nil } +// DBColumn returns the column as a string type DBColumn struct { Comment string ID int32 @@ -254,6 +264,7 @@ type DBColumn struct { Schema string } +// DiscoverColumns returns the columns of a table func DiscoverColumns(db *sql.DB, dbtype string, blockList []string) ([]DBColumn, error) { var sqlStmt string @@ -350,6 +361,7 @@ func DiscoverColumns(db *sql.DB, dbtype string, blockList []string) ([]DBColumn, return cols, nil } +// DBFunction holds the database function information type DBFunction struct { Comment string Schema string @@ -360,6 +372,7 @@ type DBFunction struct { Outputs []DBFuncParam } +// DBFuncParam holds the database function parameter information type DBFuncParam struct { ID int Name string @@ -367,6 +380,7 @@ type DBFuncParam struct { Array bool } +// DiscoverFunctions returns the functions of a database func DiscoverFunctions(db *sql.DB, dbtype string, blockList []string) ([]DBFunction, error) { var sqlStmt string @@ -423,6 +437,7 @@ func DiscoverFunctions(db *sql.DB, dbtype string, blockList []string) ([]DBFunct return funcs, nil } +// GetInput returns the input of a function func (fn *DBFunction) GetInput(name string) (ret DBFuncParam, err error) { for _, in := range fn.Inputs { if in.Name == name { @@ -432,10 +447,12 @@ func (fn *DBFunction) GetInput(name string) (ret DBFuncParam, err error) { return ret, fmt.Errorf("function input '%s' not found", name) } +// Hash returns the hash of the DBInfo object func (di *DBInfo) Hash() int { return di.hash } +// isInList checks if a value is in a list func isInList(val string, s []string) bool { for _, v := range s { regex := fmt.Sprintf("^%s$", v) diff --git a/core/internal/util/graph.go b/core/internal/util/graph.go index 67a9e7c7..b2a9a0c1 100644 --- a/core/internal/util/graph.go +++ b/core/internal/util/graph.go @@ -17,16 +17,19 @@ type Graph struct { graph [][]int32 } +// Create a new graph func NewGraph() *Graph { return &Graph{edges: make(map[[2]int32][]Edge)} } +// AddNode adds a new node to the graph func (g *Graph) AddNode() int32 { id := int32(len(g.graph)) g.graph = append(g.graph, []int32{}) return id } +// AddEdge adds a new edge to the graph func (g *Graph) AddEdge(from, to, weight int32, name string) (int32, error) { nl := int32(len(g.graph)) if from >= nl { @@ -55,6 +58,7 @@ func (g *Graph) AddEdge(from, to, weight int32, name string) (int32, error) { return id, nil } +// UpdateEdge updates the edge with the given ID func (g *Graph) UpdateEdge( from, to, edgeID, oppEdgeID int32, ) error { @@ -74,10 +78,12 @@ func (g *Graph) UpdateEdge( return fmt.Errorf("edge not found: %d", edgeID) } +// GetEdges returns all edges between the two nodes func (g *Graph) GetEdges(from, to int32) []Edge { return g.edges[[2]int32{from, to}] } +// AllPaths returns all paths between two nodes func (g *Graph) AllPaths(from, to int32) [][]int32 { var paths [][]int32 var limit int @@ -135,10 +141,12 @@ func (g *Graph) AllPaths(from, to int32) [][]int32 { return paths } +// Connections returns all connections for a given node func (g *Graph) Connections(n int32) []int32 { return g.graph[n] } +// equals checks if two slices are equal func equals(a, b []int32) bool { if len(a) != len(b) { return false diff --git a/core/internal/util/graph_test.go b/core/internal/util/graph_test.go index cbb8b417..caf1914d 100644 --- a/core/internal/util/graph_test.go +++ b/core/internal/util/graph_test.go @@ -53,7 +53,7 @@ func TestGraph1(t *testing.T) { }) edges := g.GetEdges(b, b) - assert.Equals(t, edges, []util.Edge{{13, 2, "test"}}) + assert.Equals(t, edges, []util.Edge{{ID: 13, OppID: 2, Weight: 0, Name: "test"}}) } /* diff --git a/core/intro.go b/core/intro.go index 713fb070..b409441c 100644 --- a/core/intro.go +++ b/core/intro.go @@ -198,7 +198,8 @@ type Introspection struct { result IntroResult } -func (gj *graphjin) introQuery() (result json.RawMessage, err error) { +// introQuery returns the introspection query result +func (gj *GraphjinEngine) introQuery() (result json.RawMessage, err error) { // Initialize the introscpection object in := Introspection{ @@ -216,6 +217,7 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { MutationType: &ShortFullType{Name: "Mutation"}, } + // Add the standard types // Add the standard types for _, v := range stdTypes { in.addType(v) @@ -228,6 +230,11 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { in.addExpTypes(v, "Int", newTypeRef("", "Int", nil)) in.addExpTypes(v, "Boolean", newTypeRef("", "Boolean", nil)) in.addExpTypes(v, "Float", newTypeRef("", "Float", nil)) + in.addExpTypes(v, "ID", newTypeRef("", "ID", nil)) + in.addExpTypes(v, "String", newTypeRef("", "String", nil)) + in.addExpTypes(v, "Int", newTypeRef("", "Int", nil)) + in.addExpTypes(v, "Boolean", newTypeRef("", "Boolean", nil)) + in.addExpTypes(v, "Float", newTypeRef("", "Float", nil)) // ListExpression Types v = append(expAll, expList...) @@ -235,14 +242,21 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { in.addExpTypes(v, "IntList", newTypeRef("", "Int", nil)) in.addExpTypes(v, "BooleanList", newTypeRef("", "Boolean", nil)) in.addExpTypes(v, "FloatList", newTypeRef("", "Float", nil)) + in.addExpTypes(v, "StringList", newTypeRef("", "String", nil)) + in.addExpTypes(v, "IntList", newTypeRef("", "Int", nil)) + in.addExpTypes(v, "BooleanList", newTypeRef("", "Boolean", nil)) + in.addExpTypes(v, "FloatList", newTypeRef("", "Float", nil)) v = append(expAll, expJSON...) in.addExpTypes(v, "JSON", newTypeRef("", "String", nil)) + in.addExpTypes(v, "JSON", newTypeRef("", "String", nil)) + // Add the roles // Add the roles in.addRolesEnumType(gj.roles) in.addTablesEnumType() + // Get all the alias and add to the schema // Get all the alias and add to the schema for alias, t := range in.schema.GetAliases() { if err = in.addTable(t, alias); err != nil { @@ -250,6 +264,7 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { } } + // Get all the tables and add to the schema // Get all the tables and add to the schema for _, t := range in.schema.GetTables() { if err = in.addTable(t, ""); err != nil { @@ -257,12 +272,14 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { } } + // Add the directives // Add the directives for _, dt := range dirTypes { in.addDirType(dt) } in.addDirValidateType() + // Add the types to the schema // Add the types to the schema for _, v := range in.types { in.result.Schema.Types = append(in.result.Schema.Types, v) @@ -272,6 +289,7 @@ func (gj *graphjin) introQuery() (result json.RawMessage, err error) { return } +// addTable adds a table to the introspection schema func (in *Introspection) addTable(table sdata.DBTable, alias string) (err error) { if table.Blocked || len(table.Columns) == 0 { return @@ -309,6 +327,7 @@ func (in *Introspection) addTable(table sdata.DBTable, alias string) (err error) return } +// addTypeTo adds a type to the introspection schema func (in *Introspection) addTypeTo(op string, ft FullType) { qt := in.types[op] qt.Fields = append(qt.Fields, FieldObject{ @@ -320,6 +339,7 @@ func (in *Introspection) addTypeTo(op string, ft FullType) { in.types[op] = qt } +// getName returns the name of the type func (in *Introspection) getName(name string) string { if in.camelCase { return util.ToCamel(name) @@ -328,6 +348,7 @@ func (in *Introspection) getName(name string) string { } } +// addExpTypes adds the expression types to the introspection schema func (in *Introspection) addExpTypes(exps []exp, name string, rt *TypeRef) { ft := FullType{ Kind: KIND_INPUT_OBJ, @@ -350,10 +371,12 @@ func (in *Introspection) addExpTypes(exps []exp, name string, rt *TypeRef) { in.addType(ft) } +// addTableType adds a table type to the introspection schema func (in *Introspection) addTableType(t sdata.DBTable, alias string) (ft FullType, err error) { return in.addTableTypeWithDepth(t, alias, 0) } +// addTableTypeWithDepth adds a table type with depth to the introspection schema func (in *Introspection) addTableTypeWithDepth( table sdata.DBTable, alias string, depth int, ) (ft FullType, err error) { @@ -463,6 +486,7 @@ func (in *Introspection) addTableTypeWithDepth( return } +// addColumnsEnumType adds an enum type for the columns of the table func (in *Introspection) addColumnsEnumType(t sdata.DBTable) (err error) { tableName := in.getName(t.Name) ft := FullType{ @@ -483,6 +507,7 @@ func (in *Introspection) addColumnsEnumType(t sdata.DBTable) (err error) { return } +// addTablesEnumType adds an enum type for the tables func (in *Introspection) addTablesEnumType() { ft := FullType{ Kind: KIND_ENUM, @@ -501,6 +526,7 @@ func (in *Introspection) addTablesEnumType() { in.addType(ft) } +// addRolesEnumType adds an enum type for the roles func (in *Introspection) addRolesEnumType(roles map[string]*Role) { ft := FullType{ Kind: KIND_ENUM, @@ -520,6 +546,7 @@ func (in *Introspection) addRolesEnumType(roles map[string]*Role) { in.addType(ft) } +// addOrderByType adds an order by type to the introspection schema func (in *Introspection) addOrderByType(t sdata.DBTable, ft *FullType) { ty := FullType{ Kind: KIND_INPUT_OBJ, @@ -539,6 +566,7 @@ func (in *Introspection) addOrderByType(t sdata.DBTable, ft *FullType) { ft.addArg("orderBy", newTypeRef("", (t.Name+SUFFIX_ORDER_BY), nil)) } +// addWhereType adds a where type to the introspection schema func (in *Introspection) addWhereType(table sdata.DBTable, ft *FullType) { tablename := (table.Name + SUFFIX_WHERE) ty := FullType{ @@ -664,6 +692,7 @@ func (in *Introspection) addInputType(table sdata.DBTable, ft FullType) (retFT F return } +// addTableArgsType adds the table arguments type to the introspection schema func (in *Introspection) addTableArgsType(table sdata.DBTable, ft *FullType) { if table.Type != "function" { return @@ -673,6 +702,7 @@ func (in *Introspection) addTableArgsType(table sdata.DBTable, ft *FullType) { ft.addArg("args", newTypeRef("", ty.Name, nil)) } +// addArgsType adds the arguments type to the introspection schema func (in *Introspection) addArgsType(table sdata.DBTable, fn sdata.DBFunction) (ft FullType) { ft = FullType{ Kind: "INPUT_OBJECT", @@ -705,6 +735,7 @@ func (in *Introspection) addArgsType(table sdata.DBTable, fn sdata.DBFunction) ( return } +// getColumnField returns the field object for the given column func (in *Introspection) getColumnField(column sdata.DBColumn) (field FieldObject, err error) { field.Args = []InputValue{} field.Name = in.getName(column.Name) @@ -736,6 +767,7 @@ func (in *Introspection) getColumnField(column sdata.DBColumn) (field FieldObjec return } +// getFunctionField returns the field object for the given function func (in *Introspection) getFunctionField(t sdata.DBTable, fn sdata.DBFunction) (f FieldObject) { f.Name = in.getName(fn.Name) f.Args = []InputValue{} @@ -761,6 +793,7 @@ func (in *Introspection) getFunctionField(t sdata.DBTable, fn sdata.DBFunction) return } +// getTableField returns the field object for the given table func (in *Introspection) getTableField(relNode sdata.RelNode) ( f FieldObject, skip bool, err error, ) { @@ -785,6 +818,7 @@ func (in *Introspection) getTableField(relNode sdata.RelNode) ( return } +// addDirType adds a directive type to the introspection schema func (in *Introspection) addDirType(dt dir) { d := DirectiveType{ Name: dt.name, @@ -805,6 +839,7 @@ func (in *Introspection) addDirType(dt dir) { in.result.Schema.Directives = append(in.result.Schema.Directives, d) } +// addDirValidateType adds a validate directive type to the introspection schema func (in *Introspection) addDirValidateType() { ft := FullType{ Kind: KIND_ENUM, @@ -848,6 +883,7 @@ func (in *Introspection) addDirValidateType() { in.result.Schema.Directives = append(in.result.Schema.Directives, d) } +// addArg adds an argument to the full type func (ft *FullType) addArg(name string, tr *TypeRef) { ft.InputFields = append(ft.InputFields, InputValue{ Name: name, @@ -855,6 +891,7 @@ func (ft *FullType) addArg(name string, tr *TypeRef) { }) } +// addOrReplaceArg adds or replaces an argument to the full type func (ft *FullType) addOrReplaceArg(name string, tr *TypeRef) { for i, a := range ft.InputFields { if a.Name == name { @@ -868,6 +905,7 @@ func (ft *FullType) addOrReplaceArg(name string, tr *TypeRef) { }) } +// addType adds a type to the introspection schema func (in *Introspection) addType(ft FullType) { in.types[ft.Name] = ft } diff --git a/core/osfs.go b/core/osfs.go index a8b576d6..9a8ed1c2 100644 --- a/core/osfs.go +++ b/core/osfs.go @@ -8,17 +8,20 @@ import ( ) type osFS struct { - bp string + basePath string } -func NewOsFS(basePath string) *osFS { return &osFS{bp: basePath} } +// NewOsFS creates a new OSFS instance +func NewOsFS(basePath string) *osFS { return &osFS{basePath: basePath} } +// Get returns the file content func (f *osFS) Get(path string) ([]byte, error) { - return os.ReadFile(filepath.Join(f.bp, path)) + return os.ReadFile(filepath.Join(f.basePath, path)) } +// Put writes the data to the file func (f *osFS) Put(path string, data []byte) (err error) { - path = filepath.Join(f.bp, path) + path = filepath.Join(f.basePath, path) dir := filepath.Dir(path) ok, err := f.exists(dir) @@ -32,12 +35,14 @@ func (f *osFS) Put(path string, data []byte) (err error) { return os.WriteFile(path, data, os.ModePerm) } +// Exists checks if the file exists func (f *osFS) Exists(path string) (ok bool, err error) { - path = filepath.Join(f.bp, path) + path = filepath.Join(f.basePath, path) ok, err = f.exists(path) return } +// Remove deletes the file func (f *osFS) exists(path string) (ok bool, err error) { if _, err = os.Stat(path); err == nil { ok = true diff --git a/core/remote_api.go b/core/remote_api.go index cee72c8c..73ef723f 100644 --- a/core/remote_api.go +++ b/core/remote_api.go @@ -11,7 +11,7 @@ import ( "github.com/dosco/graphjin/core/v3/internal/jsn" ) -// RemoteAPI struct defines a remote API endpoint +// remoteAPI struct defines a remote API endpoint type remoteAPI struct { httpClient *http.Client URL string @@ -26,6 +26,7 @@ type remoteHdrs struct { Value string } +// newRemoteAPI creates a new remote API endpoint func newRemoteAPI(v map[string]interface{}, httpClient *http.Client) (*remoteAPI, error) { ra := remoteAPI{ httpClient: httpClient, @@ -50,6 +51,7 @@ func newRemoteAPI(v map[string]interface{}, httpClient *http.Client) (*remoteAPI return &ra, nil } +// Resolve function resolves a remote API request func (r *remoteAPI) Resolve(c context.Context, rr ResolverReq) ([]byte, error) { uri := strings.ReplaceAll(r.URL, "$id", rr.ID) diff --git a/core/remote_join.go b/core/remote_join.go index d1abdb23..8013e94c 100644 --- a/core/remote_join.go +++ b/core/remote_join.go @@ -11,6 +11,7 @@ import ( "github.com/dosco/graphjin/core/v3/internal/qcode" ) +// execRemoteJoin fetches remote data for the marked insertion points func (s *gstate) execRemoteJoin(c context.Context) (err error) { // fetch the field name used within the db response json // that are used to mark insertion points and the mapping between @@ -41,6 +42,7 @@ func (s *gstate) execRemoteJoin(c context.Context) (err error) { return } +// resolveRemotes fetches remote data for the marked insertion points func (s *gstate) resolveRemotes( ctx context.Context, from []jsn.Field, @@ -85,7 +87,7 @@ func (s *gstate) resolveRemotes( ctx1, span := s.gj.spanStart(ctx, "Execute Remote Request") b, err := r.Fn.Resolve(ctx1, ResolverReq{ - ID: string(id), Sel: sel, Log: s.gj.log, ReqConfig: s.r.rc, + ID: string(id), Sel: sel, Log: s.gj.log, RequestConfig: s.r.requestconfig, }) if err != nil { cerr = fmt.Errorf("%s: %s", sel.Table, err) @@ -121,6 +123,7 @@ func (s *gstate) resolveRemotes( return to, cerr } +// parentFieldIds fetches the field name used within the db response json func (s *gstate) parentFieldIds() ([][]byte, map[string]*qcode.Select, error) { selects := s.cs.st.qc.Selects remotes := s.cs.st.qc.Remotes @@ -148,6 +151,7 @@ func (s *gstate) parentFieldIds() ([][]byte, map[string]*qcode.Select, error) { return fm, sm, nil } +// fieldsToList converts a list of qcode.Field to a list of strings func fieldsToList(fields []qcode.Field) []string { var f []string diff --git a/core/resolve.go b/core/resolve.go index 7ae844e7..678cc402 100644 --- a/core/resolve.go +++ b/core/resolve.go @@ -15,7 +15,8 @@ type resItem struct { Fn Resolver } -func (gj *graphjin) newRTMap() map[string]ResolverFn { +// newRTMap returns a map of resolver functions +func (gj *GraphjinEngine) newRTMap() map[string]ResolverFn { return map[string]ResolverFn{ "remote_api": func(v ResolverProps) (Resolver, error) { return newRemoteAPI(v, gj.trace.NewHTTPClient()) @@ -23,7 +24,8 @@ func (gj *graphjin) newRTMap() map[string]ResolverFn { } } -func (gj *graphjin) initResolvers() error { +// initResolvers initializes the resolvers +func (gj *GraphjinEngine) initResolvers() error { gj.rmap = make(map[string]resItem) if gj.rtmap == nil { @@ -42,7 +44,8 @@ func (gj *graphjin) initResolvers() error { return nil } -func (gj *graphjin) initRemote( +// initRemote initializes the remote resolver +func (gj *GraphjinEngine) initRemote( rc ResolverConfig, rtmap map[string]ResolverFn, ) error { // Defines the table column to be used as an id in the diff --git a/core/rolestmt.go b/core/rolestmt.go index 522d8b5e..61a59aa2 100644 --- a/core/rolestmt.go +++ b/core/rolestmt.go @@ -8,7 +8,7 @@ import ( ) // nolint:errcheck -func (gj *graphjin) prepareRoleStmt() error { +func (gj *GraphjinEngine) prepareRoleStmt() error { if !gj.abacEnabled { return nil } @@ -20,7 +20,7 @@ func (gj *graphjin) prepareRoleStmt() error { w := &bytes.Buffer{} io.WriteString(w, `SELECT (CASE WHEN EXISTS (`) - gj.pc.RenderVar(w, &gj.roleStmtMD, gj.conf.RolesQuery) + gj.psqlCompiler.RenderVar(w, &gj.roleStatementMetadata, gj.conf.RolesQuery) io.WriteString(w, `) THEN `) io.WriteString(w, `(SELECT (CASE`) @@ -36,7 +36,7 @@ func (gj *graphjin) prepareRoleStmt() error { } io.WriteString(w, ` ELSE 'user' END) FROM (`) - gj.pc.RenderVar(w, &gj.roleStmtMD, gj.conf.RolesQuery) + gj.psqlCompiler.RenderVar(w, &gj.roleStatementMetadata, gj.conf.RolesQuery) io.WriteString(w, `) AS _sg_auth_roles_query LIMIT 1) `) switch gj.dbtype { @@ -47,6 +47,6 @@ func (gj *graphjin) prepareRoleStmt() error { io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS _sg_auth_filler LIMIT 1; `) } - gj.roleStmt = w.String() + gj.roleStatement = w.String() return nil } diff --git a/core/schema.go b/core/schema.go index 938cd728..3ce5521d 100644 --- a/core/schema.go +++ b/core/schema.go @@ -82,6 +82,7 @@ type {{.Name}} {{end -}} ` +// writeSchema writes the schema to the given writer func writeSchema(s *sdata.DBInfo, out io.Writer) (err error) { fn := template.FuncMap{ "pascal": toPascalCase, @@ -104,6 +105,7 @@ func writeSchema(s *sdata.DBInfo, out io.Writer) (err error) { return } +// toPascalCase converts a string to pascal case func toPascalCase(text string) string { var sb strings.Builder for _, v := range strings.Fields(text) { @@ -115,6 +117,7 @@ func toPascalCase(text string) string { var dbTypeRe = regexp.MustCompile(`([a-zA-Z ]+)(\((.+)\))?`) +// parseDBType parses the db type string func parseDBType(name string) (res [2]string, err error) { v := dbTypeRe.FindStringSubmatch(name) if len(v) == 4 { diff --git a/core/subs.go b/core/subs.go index 6ee8520e..ed149a46 100644 --- a/core/subs.go +++ b/core/subs.go @@ -83,7 +83,7 @@ func (g *GraphJin) Subscribe( c context.Context, query string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (m *Member, err error) { // get the name, query vars h, err := graph.FastParse(query) @@ -91,7 +91,7 @@ func (g *GraphJin) Subscribe( return } - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) // create the request object r := gj.newGraphqlReq(rc, "subscription", h.Name, nil, vars) @@ -119,9 +119,9 @@ func (g *GraphJin) SubscribeByName( c context.Context, name string, vars json.RawMessage, - rc *ReqConfig, + rc *RequestConfig, ) (m *Member, err error) { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) item, err := gj.allowList.GetByName(name, gj.prod) if err != nil { @@ -134,15 +134,16 @@ func (g *GraphJin) SubscribeByName( return } -func (gj *graphjin) subscribe(c context.Context, r graphqlReq) ( +// subscribe function is called on the graphjin struct to subscribe to a query. +func (gj *GraphjinEngine) subscribe(c context.Context, r GraphqlReq) ( m *Member, err error, ) { - if r.op != qcode.QTSubscription { + if r.operation != qcode.QTSubscription { return nil, errors.New("subscription: not a subscription query") } // transactions not supported with subscriptions - if r.rc != nil && r.rc.Tx != nil { + if r.requestconfig != nil && r.requestconfig.Tx != nil { return nil, errors.New("subscription: database transactions not supported") } @@ -189,7 +190,7 @@ func (gj *graphjin) subscribe(c context.Context, r graphqlReq) ( } m = &Member{ - ns: r.ns, + ns: r.namespace, id: atomic.AddUint64(&sub.idgen, 1), Result: make(chan *Result, 10), sub: sub, @@ -206,13 +207,14 @@ func (gj *graphjin) subscribe(c context.Context, r graphqlReq) ( return } -func (gj *graphjin) initSub(c context.Context, sub *sub) (err error) { +// initSub function is called on the graphjin struct to initialize a subscription. +func (gj *GraphjinEngine) initSub(c context.Context, sub *sub) (err error) { if err = sub.s.compile(); err != nil { return } if !gj.prod { - err = gj.saveToAllowList(sub.s.cs.st.qc, sub.s.r.ns) + err = gj.saveToAllowList(sub.s.cs.st.qc, sub.s.r.namespace) if err != nil { return } @@ -226,7 +228,8 @@ func (gj *graphjin) initSub(c context.Context, sub *sub) (err error) { return } -func (gj *graphjin) subController(sub *sub) { +// subController function is called on the graphjin struct to control the subscription. +func (gj *GraphjinEngine) subController(sub *sub) { // remove subscription if controller exists defer gj.subs.Delete(sub.k) @@ -264,6 +267,7 @@ func (gj *graphjin) subController(sub *sub) { } } +// addMember function is called on the sub struct to add a member. func (s *sub) addMember(m *Member) error { mi := minfo{cindx: m.cindx} if mi.cindx != -1 { @@ -294,6 +298,7 @@ func (s *sub) addMember(m *Member) error { return nil } +// deleteMember function is called on the sub struct to delete a member. func (s *sub) deleteMember(m *Member) { i, ok := s.findByID(m.id) if !ok { @@ -313,6 +318,7 @@ func (s *sub) deleteMember(m *Member) { s.ids = s.ids[:len(s.ids)-1] } +// updateMember function is called on the sub struct to update a member. func (s *sub) updateMember(msg mmsg) error { i, ok := s.findByID(msg.id) if !ok { @@ -338,7 +344,8 @@ func (s *sub) updateMember(msg mmsg) error { return nil } -func (s *sub) fanOutJobs(gj *graphjin) { +// fanOutJobs function is called on the sub struct to fan out jobs. +func (s *sub) fanOutJobs(gj *GraphjinEngine) { switch { case len(s.ids) == 0: return @@ -355,7 +362,8 @@ func (s *sub) fanOutJobs(gj *graphjin) { } } -func (gj *graphjin) subCheckUpdates(sub *sub, mv mval, start int) { +// subCheckUpdates function is called on the graphjin struct to check updates. +func (gj *GraphjinEngine) subCheckUpdates(sub *sub, mv mval, start int) { // Do not use the `mval` embedded inside sub since // its not thread safe use the copy `mv mval`. @@ -433,7 +441,8 @@ func (gj *graphjin) subCheckUpdates(sub *sub, mv mval, start int) { } } -func (gj *graphjin) subFirstQuery(sub *sub, m *Member) (mmsg, error) { +// subFirstQuery function is called on the graphjin struct to get the first query. +func (gj *GraphjinEngine) subFirstQuery(sub *sub, m *Member) (mmsg, error) { c := context.Background() // when params are not available we use a more optimized @@ -473,7 +482,8 @@ func (gj *graphjin) subFirstQuery(sub *sub, m *Member) (mmsg, error) { return mm, err } -func (gj *graphjin) subNotifyMember(s *sub, mv mval, j int, js json.RawMessage) { +// subNotifyMember function is called on the graphjin struct to notify a member. +func (gj *GraphjinEngine) subNotifyMember(s *sub, mv mval, j int, js json.RawMessage) { _, err := gj.subNotifyMemberEx(s, mv.mi[j].dh, mv.mi[j].cindx, @@ -484,7 +494,8 @@ func (gj *graphjin) subNotifyMember(s *sub, mv mval, j int, js json.RawMessage) } } -func (gj *graphjin) subNotifyMemberEx(sub *sub, +// subNotifyMemberEx function is called on the graphjin struct to notify a member. +func (gj *GraphjinEngine) subNotifyMemberEx(sub *sub, dh [32]byte, cindx int, id uint64, rc chan *Result, js json.RawMessage, update bool, ) (mmsg, error) { mm := mmsg{id: id} @@ -496,15 +507,15 @@ func (gj *graphjin) subNotifyMemberEx(sub *sub, nonce := mm.dh - if cv := firstCursorValue(js, gj.pf); len(cv) != 0 { + if cv := firstCursorValue(js, gj.printFormat); len(cv) != 0 { mm.cursor = string(cv) } ejs, err := encryptValues(js, - gj.pf, + gj.printFormat, decPrefix, nonce[:], - gj.encKey) + gj.encryptionKey) if err != nil { return mm, fmt.Errorf(errSubs, "cursor", err) } @@ -520,11 +531,11 @@ func (gj *graphjin) subNotifyMemberEx(sub *sub, } res := &Result{ - op: qcode.QTQuery, - name: sub.s.r.name, - sql: sub.s.cs.st.sql, - role: sub.s.cs.st.role, - Data: ejs, + operation: qcode.QTQuery, + name: sub.s.r.name, + sql: sub.s.cs.st.sql, + role: sub.s.cs.st.role, + Data: ejs, } // if parameters exists then each response is unique @@ -538,6 +549,7 @@ func (gj *graphjin) subNotifyMemberEx(sub *sub, return mm, nil } +// renderSubWrap function is called on the graphjin struct to render a sub wrap. func renderSubWrap(st stmt, ct string) string { var w strings.Builder @@ -577,6 +589,7 @@ func renderSubWrap(st stmt, ct string) string { return w.String() } +// renderJSONArray function is called on the graphjin struct to render a json array. func renderJSONArray(v []json.RawMessage) json.RawMessage { w := bytes.Buffer{} w.WriteRune('[') @@ -590,6 +603,7 @@ func renderJSONArray(v []json.RawMessage) json.RawMessage { return json.RawMessage(w.Bytes()) } +// findByID function is called on the sub struct to find a member by id. func (s *sub) findByID(id uint64) (int, bool) { for i := range s.ids { if s.ids[i] == id { @@ -599,6 +613,7 @@ func (s *sub) findByID(id uint64) (int, bool) { return 0, false } +// Unsubscribe function is called on the member struct to unsubscribe. func (m *Member) Unsubscribe() { if m != nil && !m.done { m.sub.del <- m @@ -606,10 +621,12 @@ func (m *Member) Unsubscribe() { } } +// ID function is called on the member struct to get the id. func (m *Member) ID() uint64 { return m.id } +// String function is called on the member struct to get the string. func (m *Member) String() string { return strconv.Itoa(int(m.id)) } diff --git a/core/trace.go b/core/trace.go index af947ae0..c182beb7 100644 --- a/core/trace.go +++ b/core/trace.go @@ -21,17 +21,21 @@ type tracer struct{} type span struct{} +// Start starts a new trace span func (t *tracer) Start(c context.Context, name string) (context.Context, Spaner) { return c, &span{} } +// NewHTTPClient creates a new HTTP client func (t *tracer) NewHTTPClient() *http.Client { return &http.Client{} } +// End ends the span func (s *span) End() { } +// Error logs an error func (s *span) Error(err error) { } @@ -40,9 +44,11 @@ type StringAttr struct { Value string } +// IsRecording returns true if the span is recording func (s *span) IsRecording() bool { return false } +// SetAttributesString sets the attributes func (s *span) SetAttributesString(attrs ...StringAttr) { } diff --git a/core/watcher.go b/core/watcher.go index 3f25351b..f1bb1478 100644 --- a/core/watcher.go +++ b/core/watcher.go @@ -6,8 +6,9 @@ import ( "github.com/dosco/graphjin/core/v3/internal/sdata" ) +// initDBWatcher initializes the database schema watcher func (g *GraphJin) initDBWatcher() error { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) // no schema polling in production if gj.prod { @@ -30,12 +31,13 @@ func (g *GraphJin) initDBWatcher() error { return nil } +// startDBWatcher starts the database schema watcher func (g *GraphJin) startDBWatcher(ps time.Duration) { ticker := time.NewTicker(ps) defer ticker.Stop() for range ticker.C { - gj := g.Load().(*graphjin) + gj := g.Load().(*GraphjinEngine) latestDi, err := sdata.GetDBInfo( gj.db, diff --git a/go.work.sum b/go.work.sum index 3c2e02de..c887ea90 100644 --- a/go.work.sum +++ b/go.work.sum @@ -676,6 +676,7 @@ github.com/fatih/color v1.14.1/go.mod h1:2oHN61fhTpgcxD3TSWCgKDiH1+x4OiDVVGH8Wlg github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= github.com/frankban/quicktest v1.14.0/go.mod h1:NeW+ay9A/U67EYXNFA1nPE8e/tnQv/09mUdL/ijj8og= github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-asn1-ber/asn1-ber v1.5.1/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-asn1-ber/asn1-ber v1.5.5 h1:MNHlNMBDgEKD4TcKr36vQN68BA00aDfjIt3/bD50WnA= github.com/go-asn1-ber/asn1-ber v1.5.5/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= diff --git a/serv/admin.go b/serv/admin.go index 89847e26..bc41437a 100644 --- a/serv/admin.go +++ b/serv/admin.go @@ -12,10 +12,11 @@ import ( "time" ) -func adminDeployHandler(s1 *Service) http.Handler { +// adminDeployHandler handles the admin deploy endpoint +func adminDeployHandler(s1 *HttpService) http.Handler { h := func(w http.ResponseWriter, r *http.Request) { var req DeployReq - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) if !s.isAdminSecret(r) { authFail(w) @@ -57,9 +58,10 @@ func adminDeployHandler(s1 *Service) http.Handler { return http.HandlerFunc(h) } -func adminRollbackHandler(s1 *Service) http.Handler { +// adminRollbackHandler handles the admin rollback endpoint +func adminRollbackHandler(s1 *HttpService) http.Handler { h := func(w http.ResponseWriter, r *http.Request) { - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) if !s.isAdminSecret(r) { authFail(w) @@ -85,7 +87,8 @@ func adminRollbackHandler(s1 *Service) http.Handler { return http.HandlerFunc(h) } -func (s *service) isAdminSecret(r *http.Request) bool { +// adminConfigHandler handles the checking of the admin secret endpoint +func (s *graphjinService) isAdminSecret(r *http.Request) bool { atomic.AddInt32(&s.adminCount, 1) defer atomic.StoreInt32(&s.adminCount, 0) @@ -105,14 +108,17 @@ func (s *service) isAdminSecret(r *http.Request) bool { return (err == nil) && bytes.Equal(v1, s.asec[:]) } +// badReq sends a bad request response func badReq(w http.ResponseWriter, msg string) { http.Error(w, msg, http.StatusBadRequest) } +// intErr sends an internal server error response func intErr(w http.ResponseWriter, msg string) { http.Error(w, msg, http.StatusInternalServerError) } +// authFail sends an unauthorized response func authFail(w http.ResponseWriter) { http.Error(w, "auth failed", http.StatusUnauthorized) } diff --git a/serv/afero.go b/serv/afero.go index 10276062..428683c7 100644 --- a/serv/afero.go +++ b/serv/afero.go @@ -11,14 +11,17 @@ type aferoFS struct { fs afero.Fs } +// newAferoFS creates a new aferoFS instance func newAferoFS(fs afero.Fs, basePath string) *aferoFS { return &aferoFS{fs: afero.NewBasePathFs(fs, basePath)} } +// Get reads a file from the file system func (f *aferoFS) Get(path string) ([]byte, error) { return afero.ReadFile(f.fs, path) } +// Put writes a file to the file system func (f *aferoFS) Put(path string, data []byte) (err error) { dir := filepath.Dir(path) ok, err := f.Exists(dir) @@ -32,6 +35,7 @@ func (f *aferoFS) Put(path string, data []byte) (err error) { return afero.WriteFile(f.fs, path, data, os.ModePerm) } +// Exists checks if a file exists in the file system func (f *aferoFS) Exists(path string) (exists bool, err error) { return afero.Exists(f.fs, path) } diff --git a/serv/api.go b/serv/api.go index 184ac7b0..075e98cd 100644 --- a/serv/api.go +++ b/serv/api.go @@ -56,7 +56,7 @@ import ( "go.uber.org/zap/zapcore" ) -type Service struct { +type HttpService struct { atomic.Value opt []Option cpath string @@ -71,7 +71,7 @@ const ( type HookFn func(*core.Result) -type service struct { +type graphjinService struct { log *zap.SugaredLogger // logger zlog *zap.Logger // faster logger logLevel int // log level @@ -92,10 +92,10 @@ type service struct { tracer trace.Tracer } -type Option func(*service) error +type Option func(*graphjinService) error // NewGraphJinService a new service -func NewGraphJinService(conf *Config, options ...Option) (*Service, error) { +func NewGraphJinService(conf *Config, options ...Option) (*HttpService, error) { if conf.dirty { return nil, errors.New("do not re-use config object") } @@ -105,7 +105,7 @@ func NewGraphJinService(conf *Config, options ...Option) (*Service, error) { return nil, err } - s1 := &Service{opt: options, cpath: conf.Serv.ConfigPath} + s1 := &HttpService{opt: options, cpath: conf.Serv.ConfigPath} s1.Store(s) if s.conf.WatchAndReload { @@ -121,7 +121,7 @@ func NewGraphJinService(conf *Config, options ...Option) (*Service, error) { // OptionSetDB sets a new db client func OptionSetDB(db *sql.DB) Option { - return func(s *service) error { + return func(s *graphjinService) error { s.db = db return nil } @@ -129,7 +129,7 @@ func OptionSetDB(db *sql.DB) Option { // OptionSetHookFunc sets a function to be called on every request func OptionSetHookFunc(fn HookFn) Option { - return func(s *service) error { + return func(s *graphjinService) error { s.hook = fn return nil } @@ -137,7 +137,7 @@ func OptionSetHookFunc(fn HookFn) Option { // OptionSetNamespace sets service namespace func OptionSetNamespace(namespace string) Option { - return func(s *service) error { + return func(s *graphjinService) error { s.namespace = &namespace return nil } @@ -145,7 +145,7 @@ func OptionSetNamespace(namespace string) Option { // OptionSetFS sets service filesystem func OptionSetFS(fs core.FS) Option { - return func(s *service) error { + return func(s *graphjinService) error { s.fs = fs return nil } @@ -153,7 +153,7 @@ func OptionSetFS(fs core.FS) Option { // OptionSetZapLogger sets service structured logger func OptionSetZapLogger(zlog *zap.Logger) Option { - return func(s *service) error { + return func(s *graphjinService) error { s.zlog = zlog s.log = zlog.Sugar() return nil @@ -162,13 +162,14 @@ func OptionSetZapLogger(zlog *zap.Logger) Option { // OptionDeployActive caused the active config to be deployed on func OptionDeployActive() Option { - return func(s *service) error { + return func(s *graphjinService) error { s.deployActive = true return nil } } -func newGraphJinService(conf *Config, db *sql.DB, options ...Option) (*service, error) { +// newGraphJinService creates a new service +func newGraphJinService(conf *Config, db *sql.DB, options ...Option) (*graphjinService, error) { var err error if conf == nil { conf = &Config{Core: Core{Debug: true}} @@ -178,7 +179,7 @@ func newGraphJinService(conf *Config, db *sql.DB, options ...Option) (*service, prod := conf.Serv.Production conf.Core.Production = prod - s := &service{ + s := &graphjinService{ conf: conf, zlog: zlog, log: zlog.Sugar(), @@ -224,7 +225,8 @@ func newGraphJinService(conf *Config, db *sql.DB, options ...Option) (*service, return s, nil } -func (s *service) normalStart() error { +// normalStart starts the service in normal mode +func (s *graphjinService) normalStart() error { opts := []core.Option{ core.OptionSetFS(s.fs), core.OptionSetTrace(otelPlugin.NewTracerFrom(s.tracer)), @@ -238,7 +240,8 @@ func (s *service) normalStart() error { return err } -func (s *service) hotStart() error { +// hotStart starts the service in hot-deploy mode +func (s *graphjinService) hotStart() error { ab, err := fetchActiveBundle(s.db) if err != nil { if strings.Contains(err.Error(), "_graphjin.") { @@ -251,7 +254,7 @@ func (s *service) hotStart() error { return s.normalStart() } - cf := s.conf.vi.ConfigFileUsed() + cf := s.conf.viper.ConfigFileUsed() cf = filepath.Base(strings.TrimSuffix(cf, filepath.Ext(cf))) cf = filepath.Join("/", cf) @@ -283,9 +286,9 @@ func (s *service) hotStart() error { } // Deploy a new configuration -func (s *Service) Deploy(conf *Config, options ...Option) error { +func (s *HttpService) Deploy(conf *Config, options ...Option) error { var err error - os := s.Load().(*service) + os := s.Load().(*graphjinService) if conf == nil { return nil @@ -304,27 +307,28 @@ func (s *Service) Deploy(conf *Config, options ...Option) error { } // Start the service listening on the configured port -func (s *Service) Start() error { +func (s *HttpService) Start() error { startHTTP(s) return nil } // Attach route to the internal http service -func (s *Service) Attach(mux Mux) error { +func (s *HttpService) Attach(mux Mux) error { return s.attach(mux, nil) } // AttachWithNS a namespaced route to the internal http service -func (s *Service) AttachWithNS(mux Mux, namespace string) error { +func (s *HttpService) AttachWithNS(mux Mux, namespace string) error { return s.attach(mux, &namespace) } -func (s *Service) attach(mux Mux, ns *string) error { +// attach attaches the service to the router +func (s *HttpService) attach(mux Mux, ns *string) error { if _, err := routesHandler(s, mux, ns); err != nil { return err } - s1 := s.Load().(*service) + s1 := s.Load().(*graphjinService) ver := version dep := s1.conf.name @@ -356,26 +360,26 @@ func (s *Service) attach(mux Mux, ns *string) error { } // GraphQLis the http handler the GraphQL endpoint -func (s *Service) GraphQL(ah auth.HandlerFunc) http.Handler { +func (s *HttpService) GraphQL(ah auth.HandlerFunc) http.Handler { return s.apiHandler(nil, ah, false) } // GraphQLWithNS is the http handler the namespaced GraphQL endpoint -func (s *Service) GraphQLWithNS(ah auth.HandlerFunc, ns string) http.Handler { +func (s *HttpService) GraphQLWithNS(ah auth.HandlerFunc, ns string) http.Handler { return s.apiHandler(&ns, ah, false) } // REST is the http handler the REST endpoint -func (s *Service) REST(ah auth.HandlerFunc) http.Handler { +func (s *HttpService) REST(ah auth.HandlerFunc) http.Handler { return s.apiHandler(nil, ah, true) } // RESTWithNS is the http handler the namespaced REST endpoint -func (s *Service) RESTWithNS(ah auth.HandlerFunc, ns string) http.Handler { +func (s *HttpService) RESTWithNS(ah auth.HandlerFunc, ns string) http.Handler { return s.apiHandler(&ns, ah, true) } -func (s *Service) apiHandler(ns *string, ah auth.HandlerFunc, rest bool) http.Handler { +func (s *HttpService) apiHandler(ns *string, ah auth.HandlerFunc, rest bool) http.Handler { var h http.Handler if rest { h = s.apiV1Rest(ns, ah) @@ -386,32 +390,34 @@ func (s *Service) apiHandler(ns *string, ah auth.HandlerFunc, rest bool) http.Ha } // WebUI is the http handler the web ui endpoint -func (s *Service) WebUI(routePrefix, gqlEndpoint string) http.Handler { +func (s *HttpService) WebUI(routePrefix, gqlEndpoint string) http.Handler { return webuiHandler(routePrefix, gqlEndpoint) } // GetGraphJin fetching internal GraphJin core -func (s *Service) GetGraphJin() *core.GraphJin { - s1 := s.Load().(*service) +func (s *HttpService) GetGraphJin() *core.GraphJin { + s1 := s.Load().(*graphjinService) return s1.gj } // GetDB fetching internal db client -func (s *Service) GetDB() *sql.DB { - s1 := s.Load().(*service) +func (s *HttpService) GetDB() *sql.DB { + s1 := s.Load().(*graphjinService) return s1.db } // Reload re-runs database discover and reinitializes service. -func (s *Service) Reload() error { - s1 := s.Load().(*service) +func (s *HttpService) Reload() error { + s1 := s.Load().(*graphjinService) return s1.gj.Reload() } -func (s *service) spanStart(c context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { +// spanStart starts the tracer +func (s *graphjinService) spanStart(c context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { return s.tracer.Start(c, name, opts...) } +// spanError records an error in the span func spanError(span trace.Span, err error) { if span.IsRecording() { span.RecordError(err) diff --git a/serv/client.go b/serv/client.go index 2b8a0422..84ed0fd4 100644 --- a/serv/client.go +++ b/serv/client.go @@ -39,6 +39,7 @@ type Resp struct { Msg string } +// NewAdminClient creates a new admin client func NewAdminClient(host string, secret string) *Client { c := resty.New(). SetBaseURL(host). @@ -67,6 +68,7 @@ func NewAdminClient(host string, secret string) *Client { return &Client{c} } +// Deploy deploys the configuration to the server func (c *Client) Deploy(name, confPath string) (*Resp, error) { errMsg := "deploy failed: %w" @@ -85,6 +87,7 @@ func (c *Client) Deploy(name, confPath string) (*Resp, error) { return &Resp{Msg: string(res.Body())}, nil } +// Rollback rolls back the last deployment func (c *Client) Rollback() (*Resp, error) { errMsg := "rollback failed: %w" @@ -97,6 +100,7 @@ func (c *Client) Rollback() (*Resp, error) { return &Resp{Msg: string(res.Body())}, nil } +// buildBundle creates a zip archive of the configuration directory func buildBundle(confPath string) (string, error) { buf := bytes.Buffer{} z := zip.NewWriter(&buf) diff --git a/serv/config.go b/serv/config.go index 7ac30e19..eac4d580 100644 --- a/serv/config.go +++ b/serv/config.go @@ -37,7 +37,7 @@ type Config struct { hash string name string dirty bool - vi *viper.Viper + viper *viper.Viper } // Configuration for admin service @@ -239,23 +239,24 @@ func ReadInConfig(configFile string) (*Config, error) { // ReadInConfigFS is the same as ReadInConfig but it also takes a filesytem as an argument func ReadInConfigFS(configFile string, fs afero.Fs) (*Config, error) { - c, err := readInConfig(configFile, fs) + config, err := readInConfig(configFile, fs) if err != nil { return nil, err } - c1, err := setupSecrets(c, fs) + secrets, err := setupSecrets(config, fs) if err != nil { - return nil, fmt.Errorf("%w: %s", err, c.SecretsFile) + return nil, fmt.Errorf("%w: %s", err, config.SecretsFile) } - return c1, err + return secrets, err } +// setupSecrets function reads in the secrets file and merges the secrets into the config func setupSecrets(conf *Config, fs afero.Fs) (*Config, error) { if conf.SecretsFile == "" { return conf, nil } - secFile, err := filepath.Abs(conf.RelPath(conf.SecretsFile)) + secFile, err := filepath.Abs(conf.AbsolutePath(conf.SecretsFile)) if err != nil { return nil, err } @@ -267,15 +268,15 @@ func setupSecrets(conf *Config, fs afero.Fs) (*Config, error) { return nil, err } - for k, v := range newConf.secrets { - util.SetKeyValue(conf.vi, k, v) + for secretKey, secretValue := range newConf.secrets { + util.SetKeyValue(conf.viper, secretKey, secretValue) } if len(newConf.secrets) == 0 { return conf, nil } - if err := conf.vi.Unmarshal(&newConf); err != nil { + if err := conf.viper.Unmarshal(&newConf); err != nil { return nil, fmt.Errorf("failed to decode config, %v", err) } @@ -289,35 +290,37 @@ func setupSecrets(conf *Config, fs afero.Fs) (*Config, error) { return &newConf, nil } +// readInConfig function reads in the config file for the environment specified in the GO_ENV func readInConfig(configFile string, fs afero.Fs) (*Config, error) { cp := filepath.Dir(configFile) - vi := newViper(cp, filepath.Base(configFile)) + viper := newViper(cp, filepath.Base(configFile)) if fs != nil { - vi.SetFs(fs) + viper.SetFs(fs) } - if err := vi.ReadInConfig(); err != nil { + + if err := viper.ReadInConfig(); err != nil { return nil, err } - if pcf := vi.GetString("inherits"); pcf != "" { - cf := vi.ConfigFileUsed() - vi = newViper(cp, pcf) + if pcf := viper.GetString("inherits"); pcf != "" { + cf := viper.ConfigFileUsed() + viper = newViper(cp, pcf) if fs != nil { - vi.SetFs(fs) + viper.SetFs(fs) } - if err := vi.ReadInConfig(); err != nil { + if err := viper.ReadInConfig(); err != nil { return nil, err } - if v := vi.GetString("inherits"); v != "" { - return nil, fmt.Errorf("inherited config '%s' cannot itself inherit '%s'", pcf, v) + if value := viper.GetString("inherits"); value != "" { + return nil, fmt.Errorf("inherited config '%s' cannot itself inherit '%s'", pcf, value) } - vi.SetConfigFile(cf) + viper.SetConfigFile(cf) - if err := vi.MergeInConfig(); err != nil { + if err := viper.MergeInConfig(); err != nil { return nil, err } } @@ -325,20 +328,21 @@ func readInConfig(configFile string, fs afero.Fs) (*Config, error) { for _, e := range os.Environ() { if strings.HasPrefix(e, "GJ_") || strings.HasPrefix(e, "SJ_") { kv := strings.SplitN(e, "=", 2) - util.SetKeyValue(vi, kv[0], kv[1]) + util.SetKeyValue(viper, kv[0], kv[1]) } } - c := &Config{vi: vi} - c.Serv.ConfigPath = cp + config := &Config{viper: viper} + config.Serv.ConfigPath = cp - if err := vi.Unmarshal(c); err != nil { + if err := viper.Unmarshal(&config); err != nil { return nil, fmt.Errorf("failed to decode config, %v", err) } - return c, nil + return config, nil } +// NewConfig function creates a new GraphJin configuration from the provided config string func NewConfig(config, format string) (*Config, error) { if format == "" { format = "yaml" @@ -355,22 +359,23 @@ func NewConfig(config, format string) (*Config, error) { } } - vi := newViperWithDefaults() - vi.SetConfigType(format) + viper := newViperWithDefaults() + viper.SetConfigType(format) - if err := vi.ReadConfig(strings.NewReader(config)); err != nil { + if err := viper.ReadConfig(strings.NewReader(config)); err != nil { return nil, err } - c := &Config{vi: vi} + c := &Config{viper: viper} - if err := vi.Unmarshal(c); err != nil { + if err := viper.Unmarshal(&c); err != nil { return nil, fmt.Errorf("failed to decode config, %v", err) } return c, nil } +// newViperWithDefaults returns a new viper instance with the default settings func newViperWithDefaults() *viper.Viper { vi := viper.New() @@ -406,6 +411,7 @@ func newViperWithDefaults() *viper.Viper { return vi } +// newViper returns a new viper instance with the default settings func newViper(configPath, configFile string) *viper.Viper { vi := newViperWithDefaults() vi.SetConfigName(strings.TrimSuffix(configFile, filepath.Ext(configFile))) @@ -419,45 +425,54 @@ func newViper(configPath, configFile string) *viper.Viper { return vi } -func (c *Config) GetSecret(k string) (string, bool) { - v, ok := c.secrets[k] - return v, ok +// GetSecret returns the value of the secret key +// if it exists +func (c *Config) GetSecret(key string) (string, bool) { + value, ok := c.secrets[key] + return value, ok } -func (c *Config) GetSecretOrEnv(k string) string { - if v, ok := c.GetSecret(k); ok { - return v +// GetSecretOrEnv returns the value of the secret key if +// it exists or the value of the environment variable +func (c *Config) GetSecretOrEnv(key string) string { + if value, ok := c.GetSecret(key); ok { + return value } - return os.Getenv(k) + return os.Getenv(key) } // func (c *Config) telemetryEnabled() bool { // return c.Telemetry.Debug || c.Telemetry.Metrics.Exporter != "" || c.Telemetry.Tracing.Exporter != "" // } -func (c *Config) RelPath(p string) string { +// AbsolutePath returns the absolute path of the file +func (c *Config) AbsolutePath(p string) string { if filepath.IsAbs(p) { return p } return filepath.Join(c.Serv.ConfigPath, p) } +// SetHash sets the hash value of the configuration func (c *Config) SetHash(hash string) { c.hash = hash } +// SetName sets the name of the configuration func (c *Config) SetName(name string) { c.name = name } +// rateLimiterEnable returns true if the rate limiter is enabled func (c *Config) rateLimiterEnable() bool { return c.RateLimiter.Rate > 0 && c.RateLimiter.Bucket > 0 } +// GetConfigName returns the name of the configuration func GetConfigName() string { - ge := strings.TrimSpace(strings.ToLower(os.Getenv("GO_ENV"))) + goEnv := strings.TrimSpace(strings.ToLower(os.Getenv("GO_ENV"))) - switch ge { + switch goEnv { case "production", "prod": return "prod" @@ -471,6 +486,6 @@ func GetConfigName() string { return "dev" default: - return ge + return goEnv } } diff --git a/serv/db.go b/serv/db.go index f0e122c3..3de28687 100644 --- a/serv/db.go +++ b/serv/db.go @@ -34,10 +34,12 @@ type dbConf struct { connString string } +// Config holds the configuration for the service func NewDB(conf *Config, openDB bool, log *zap.SugaredLogger, fs core.FS) (*sql.DB, error) { return newDB(conf, openDB, false, log, fs) } +// newDB initializes the database func newDB( conf *Config, openDB, useTelemetry bool, @@ -97,43 +99,44 @@ func newDB( } } +// initPostgres initializes the postgres database func initPostgres(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, error) { - c := conf - config, _ := pgx.ParseConfig(c.DB.ConnString) - if c.DB.Host != "" { - config.Host = c.DB.Host + confCopy := conf + config, _ := pgx.ParseConfig(confCopy.DB.ConnString) + if confCopy.DB.Host != "" { + config.Host = confCopy.DB.Host } - if c.DB.Port != 0 { - config.Port = c.DB.Port + if confCopy.DB.Port != 0 { + config.Port = confCopy.DB.Port } - if c.DB.User != "" { - config.User = c.DB.User + if confCopy.DB.User != "" { + config.User = confCopy.DB.User } - if c.DB.Password != "" { - config.Password = c.DB.Password + if confCopy.DB.Password != "" { + config.Password = confCopy.DB.Password } if config.RuntimeParams == nil { config.RuntimeParams = map[string]string{} } - if c.DB.Schema != "" { - config.RuntimeParams["search_path"] = c.DB.Schema + if confCopy.DB.Schema != "" { + config.RuntimeParams["search_path"] = confCopy.DB.Schema } - if c.AppName != "" { - config.RuntimeParams["application_name"] = c.AppName + if confCopy.AppName != "" { + config.RuntimeParams["application_name"] = confCopy.AppName } - // if openDB { - config.Database = c.DB.DBName - // } + if openDB { + config.Database = confCopy.DB.DBName + } - if c.DB.EnableTLS { - if len(c.DB.ServerName) == 0 { + if confCopy.DB.EnableTLS { + if len(confCopy.DB.ServerName) == 0 { return nil, errors.New("tls: server_name is required") } - if len(c.DB.ServerCert) == 0 { + if len(confCopy.DB.ServerCert) == 0 { return nil, errors.New("tls: server_cert is required") } @@ -141,10 +144,10 @@ func initPostgres(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, var pem []byte var err error - if strings.Contains(c.DB.ServerCert, pemSig) { - pem = []byte(strings.ReplaceAll(c.DB.ServerCert, `\n`, "\n")) + if strings.Contains(confCopy.DB.ServerCert, pemSig) { + pem = []byte(strings.ReplaceAll(confCopy.DB.ServerCert, `\n`, "\n")) } else { - pem, err = fs.Get(c.DB.ServerCert) + pem, err = fs.Get(confCopy.DB.ServerCert) } if err != nil { @@ -158,24 +161,24 @@ func initPostgres(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, config.TLSConfig = &tls.Config{ MinVersion: tls.VersionTLS12, RootCAs: rootCertPool, - ServerName: c.DB.ServerName, + ServerName: confCopy.DB.ServerName, } - if len(c.DB.ClientCert) > 0 { - if len(c.DB.ClientKey) == 0 { + if len(confCopy.DB.ClientCert) > 0 { + if len(confCopy.DB.ClientKey) == 0 { return nil, errors.New("tls: client_key is required") } clientCert := make([]tls.Certificate, 0, 1) var certs tls.Certificate - if strings.Contains(c.DB.ClientCert, pemSig) { + if strings.Contains(confCopy.DB.ClientCert, pemSig) { certs, err = tls.X509KeyPair( - []byte(strings.ReplaceAll(c.DB.ClientCert, `\n`, "\n")), - []byte(strings.ReplaceAll(c.DB.ClientKey, `\n`, "\n")), + []byte(strings.ReplaceAll(confCopy.DB.ClientCert, `\n`, "\n")), + []byte(strings.ReplaceAll(confCopy.DB.ClientKey, `\n`, "\n")), ) } else { - certs, err = loadX509KeyPair(fs, c.DB.ClientCert, c.DB.ClientKey) + certs, err = loadX509KeyPair(fs, confCopy.DB.ClientCert, confCopy.DB.ClientKey) } if err != nil { @@ -190,6 +193,7 @@ func initPostgres(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, return &dbConf{"pgx", stdlib.RegisterConnConfig(config)}, nil } +// initMysql initializes the mysql database func initMysql(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, error) { var connString string c := conf @@ -207,6 +211,7 @@ func initMysql(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, er return &dbConf{"mysql", connString}, nil } +// loadX509KeyPair loads a X509 key pair from a file system func loadX509KeyPair(fs core.FS, certFile, keyFile string) ( cert tls.Certificate, err error, ) { diff --git a/serv/deploy.go b/serv/deploy.go index c340a94a..ae7a26b5 100644 --- a/serv/deploy.go +++ b/serv/deploy.go @@ -23,7 +23,8 @@ type depResp struct { name, pname string } -func (s *service) saveConfig(c context.Context, name, bundle string) (*depResp, error) { +// saveConfig saves the config to the database +func (s *graphjinService) saveConfig(c context.Context, name, bundle string) (*depResp, error) { var dres depResp zip, err := base64.StdEncoding.DecodeString(bundle) @@ -139,7 +140,8 @@ func (s *service) saveConfig(c context.Context, name, bundle string) (*depResp, return &dres, nil } -func (s *service) rollbackConfig(c context.Context) (*depResp, error) { +// rollbackConfig rolls back the config to the previous one +func (s *graphjinService) rollbackConfig(c context.Context) (*depResp, error) { var dres depResp opt := &sql.TxOptions{Isolation: sql.LevelSerializable} @@ -216,6 +218,7 @@ type adminParams struct { params map[string]string } +// getAdminParams fetches the admin params from the database func getAdminParams(tx *sql.Tx) (adminParams, error) { var ap adminParams @@ -259,14 +262,15 @@ func getAdminParams(tx *sql.Tx) (adminParams, error) { return ap, nil } -func startHotDeployWatcher(s1 *Service) error { +// startHotDeployWatcher starts the hot deploy watcher +func startHotDeployWatcher(s1 *HttpService) error { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() for range ticker.C { - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) - cf := s.conf.vi.ConfigFileUsed() + cf := s.conf.viper.ConfigFileUsed() cf = filepath.Join("/", filepath.Base(strings.TrimSuffix(cf, filepath.Ext(cf)))) var id int @@ -322,6 +326,7 @@ type activeBundle struct { name, hash, bundle string } +// fetchActiveBundle fetches the active bundle from the database func fetchActiveBundle(db *sql.DB) (*activeBundle, error) { var ab activeBundle @@ -346,7 +351,8 @@ func fetchActiveBundle(db *sql.DB) (*activeBundle, error) { return &ab, nil } -func deployBundle(s1 *Service, name, hash, confFile, bundle string) error { +// deployBundle deploys the bundle to the server +func deployBundle(s1 *HttpService, name, hash, confFile, bundle string) error { bfs, err := bundle2Fs(name, hash, confFile, bundle) if err != nil { return err @@ -360,6 +366,7 @@ type bundleFs struct { fs afero.Fs } +// bundle2Fs converts the bundle to a filesystem func bundle2Fs(name, hash, confFile, bundle string) (bundleFs, error) { var bfs bundleFs diff --git a/serv/filewatch.go b/serv/filewatch.go index 6546351e..29e0f36e 100644 --- a/serv/filewatch.go +++ b/serv/filewatch.go @@ -13,7 +13,8 @@ import ( "github.com/pkg/errors" ) -func startConfigWatcher(s1 *Service) error { +// startConfigWatcher watches for changes in the config file +func startConfigWatcher(s1 *HttpService) error { var watcher *fsnotify.Watcher var err error @@ -59,7 +60,7 @@ func startConfigWatcher(s1 *Service) error { } for { - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) select { case err := <-watcher.Errors: @@ -88,7 +89,7 @@ func startConfigWatcher(s1 *Service) error { } // Check if new config is valid - cf := s.conf.RelPath(GetConfigName()) + cf := s.conf.AbsolutePath(GetConfigName()) conf, err := readInConfig(cf, nil) if err != nil { s.log.Error(err) diff --git a/serv/health.go b/serv/health.go index ca7b967b..837df682 100644 --- a/serv/health.go +++ b/serv/health.go @@ -10,9 +10,10 @@ import ( var healthyResponse = []byte("All's Well") -func healthCheckHandler(s1 *Service) http.Handler { +// healthCheckHandler returns a handler that checks the health of the service +func healthCheckHandler(s1 *HttpService) http.Handler { h := func(w http.ResponseWriter, r *http.Request) { - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) c, cancel := context.WithTimeout(r.Context(), s.conf.DB.PingTimeout) defer cancel() diff --git a/serv/http.go b/serv/http.go index 2cbab5ee..474becc8 100644 --- a/serv/http.go +++ b/serv/http.go @@ -55,9 +55,10 @@ type errorResp struct { Errors []string `json:"errors"` } -func apiV1Handler(s1 *Service, ns *string, h http.Handler, ah auth.HandlerFunc) http.Handler { +// apiV1Handler is the main handler for all API requests +func apiV1Handler(s1 *HttpService, ns *string, h http.Handler, ah auth.HandlerFunc) http.Handler { var zlog *zap.Logger - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) if s.conf.Core.Debug { zlog = s.zlog @@ -107,14 +108,15 @@ func apiV1Handler(s1 *Service, ns *string, h http.Handler, ah auth.HandlerFunc) return h } -func (s1 *Service) apiV1GraphQL(ns *string, ah auth.HandlerFunc) http.Handler { +// apiV1GraphQLHandler handles the GraphQL API requests +func (s1 *HttpService) apiV1GraphQL(ns *string, ah auth.HandlerFunc) http.Handler { dtrace := otel.GetTextMapPropagator() h := func(w http.ResponseWriter, r *http.Request) { var err error start := time.Now() - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) w.Header().Set("Content-Type", "application/json") @@ -155,7 +157,7 @@ func (s1 *Service) apiV1GraphQL(ns *string, ah auth.HandlerFunc) http.Handler { return } - var rc core.ReqConfig + var rc core.RequestConfig if req.apqEnabled() { rc.APQKey = (req.OpName + req.Ext.Persisted.Sha256Hash) @@ -205,7 +207,8 @@ func (s1 *Service) apiV1GraphQL(ns *string, ah auth.HandlerFunc) http.Handler { return http.HandlerFunc(h) } -func (s1 *Service) apiV1Rest(ns *string, ah auth.HandlerFunc) http.Handler { +// apiV1Rest returns a handler that handles the REST API requests +func (s1 *HttpService) apiV1Rest(ns *string, ah auth.HandlerFunc) http.Handler { rLen := len(routeREST) dtrace := otel.GetTextMapPropagator() @@ -213,7 +216,7 @@ func (s1 *Service) apiV1Rest(ns *string, ah auth.HandlerFunc) http.Handler { var err error start := time.Now() - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) w.Header().Set("Content-Type", "application/json") @@ -255,7 +258,7 @@ func (s1 *Service) apiV1Rest(ns *string, ah auth.HandlerFunc) http.Handler { return } - var rc core.ReqConfig + var rc core.RequestConfig if rc.Vars == nil && len(s.conf.Core.HeaderVars) != 0 { rc.Vars = s.setHeaderVars(r) @@ -288,11 +291,12 @@ func (s1 *Service) apiV1Rest(ns *string, ah auth.HandlerFunc) http.Handler { return http.HandlerFunc(h) } -func (s *service) responseHandler(ct context.Context, +// responseHandler handles the response from the GraphQL API +func (s *graphjinService) responseHandler(ct context.Context, w http.ResponseWriter, r *http.Request, start time.Time, - rc core.ReqConfig, + rc core.RequestConfig, res *core.Result, err error, ) { @@ -330,7 +334,8 @@ func (s *service) responseHandler(ct context.Context, } } -func (s *service) reqLog(res *core.Result, rc core.ReqConfig, resTimeMs int64, err error) { +// reqLog logs the request details +func (s *graphjinService) reqLog(res *core.Result, rc core.RequestConfig, resTimeMs int64, err error) { var fields []zapcore.Field var sql string @@ -373,7 +378,8 @@ func (s *service) reqLog(res *core.Result, rc core.ReqConfig, resTimeMs int64, e } } -func (s *service) setHeaderVars(r *http.Request) map[string]interface{} { +// setHeaderVars sets the header variables +func (s *graphjinService) setHeaderVars(r *http.Request) map[string]interface{} { vars := make(map[string]interface{}) for k, v := range s.conf.Core.HeaderVars { vars[k] = func() string { @@ -386,11 +392,12 @@ func (s *service) setHeaderVars(r *http.Request) map[string]interface{} { return vars } +// apqEnabled checks if the APQ is enabled func (r gqlReq) apqEnabled() bool { return r.Ext.Persisted.Sha256Hash != "" } -// nolint:errcheck +// renderErr renders the error response func renderErr(w http.ResponseWriter, err error) { if err == errUnauthorized { w.WriteHeader(http.StatusUnauthorized) @@ -402,6 +409,7 @@ func renderErr(w http.ResponseWriter, err error) { } } +// parseBody parses the request body func parseBody(r *http.Request) ([]byte, error) { b, err := io.ReadAll(io.LimitReader(r.Body, maxReadBytes)) if err != nil { @@ -411,6 +419,7 @@ func parseBody(r *http.Request) ([]byte, error) { return b, nil } +// newDTrace creates a new DTrace func newDTrace(dtrace propagation.TextMapPropagator, r *http.Request) (context.Context, []trace.SpanStartOption) { ctx := dtrace.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) diff --git a/serv/init.go b/serv/init.go index 361e56de..2de22bc9 100644 --- a/serv/init.go +++ b/serv/init.go @@ -10,7 +10,8 @@ import ( "github.com/dosco/graphjin/core/v3" ) -func initLogLevel(s *service) { +// initLogLevel initializes the log level +func initLogLevel(s *graphjinService) { switch s.conf.LogLevel { case "debug": s.logLevel = logLevelDebug @@ -25,7 +26,8 @@ func initLogLevel(s *service) { } } -func validateConf(s *service) { +// validateConf validates the configuration +func validateConf(s *graphjinService) { var anonFound bool for _, r := range s.conf.Core.Roles { @@ -40,7 +42,8 @@ func validateConf(s *service) { } } -func (s *service) initFS() error { +// initFS initializes the file system +func (s *graphjinService) initFS() error { basePath, err := s.basePath() if err != nil { return err @@ -53,7 +56,8 @@ func (s *service) initFS() error { return nil } -func (s *service) initConfig() error { +// initConfig initializes the configuration +func (s *graphjinService) initConfig() error { c := s.conf c.dirty = true @@ -96,7 +100,8 @@ func (s *service) initConfig() error { return nil } -func (s *service) initDB() error { +// initDB initializes the database +func (s *graphjinService) initDB() error { var err error if s.db != nil { @@ -110,7 +115,8 @@ func (s *service) initDB() error { return nil } -func (s *service) basePath() (string, error) { +// basePath returns the base path +func (s *graphjinService) basePath() (string, error) { if s.conf.Serv.ConfigPath == "" { if cp, err := os.Getwd(); err == nil { return filepath.Join(cp, "config"), nil diff --git a/serv/internal/secrets/decrypt.go b/serv/internal/secrets/decrypt.go index f2b9923e..6dcf3579 100644 --- a/serv/internal/secrets/decrypt.go +++ b/serv/internal/secrets/decrypt.go @@ -22,6 +22,7 @@ type decryptOpts struct { KeyServices []keyservice.KeyServiceClient } +// decrypt decrypts the file at the given path using options passed. func decrypt(opts decryptOpts, fs afero.Fs) (decryptedFile []byte, err error) { tree, err := LoadEncryptedFileWithBugFixes(common.GenericDecryptOpts{ Cipher: opts.Cipher, @@ -84,6 +85,7 @@ func extract(tree *sops.Tree, path []interface{}, outputStore sops.Store) ([]byt return bytes, nil } +// LoadEncryptedFileWithBugFixes loads an encrypted file from the given path and applies bug fixes. func LoadEncryptedFileWithBugFixes( opts common.GenericDecryptOpts, fs afero.Fs) (*sops.Tree, error) { @@ -112,6 +114,7 @@ func LoadEncryptedFileWithBugFixes( return tree, nil } +// LoadEncryptedFile loads an encrypted file from the given path. func LoadEncryptedFile( loader sops.EncryptedFileLoader, inputPath string, diff --git a/serv/internal/secrets/edit.go b/serv/internal/secrets/edit.go index 7c03b5d2..c67726cd 100644 --- a/serv/internal/secrets/edit.go +++ b/serv/internal/secrets/edit.go @@ -56,6 +56,7 @@ GJ_ADMIN_SECRET_KEY: hotdeploy_admin_secret_key GJ_SECRET_KEY: graphjin_generic_secret_key GJ_AUTH_JWT_SECRET: jwt_auth_secret_key` +// editExample edits the example file func editExample(opts editExampleOpts) ([]byte, error) { branches, err := opts.InputStore.LoadPlainFile([]byte(fileBytes)) if err != nil { @@ -87,6 +88,7 @@ func editExample(opts editExampleOpts) ([]byte, error) { return editTree(opts.editOpts, &tree, dataKey) } +// edit edits the file at the given path using options passed. func edit(opts editOpts) ([]byte, error) { // Load the file tree, err := common.LoadEncryptedFileWithBugFixes(common.GenericDecryptOpts{ @@ -113,6 +115,7 @@ func edit(opts editOpts) ([]byte, error) { return editTree(opts, tree, dataKey) } +// editTree edits the tree using the options passed. func editTree(opts editOpts, tree *sops.Tree, dataKey []byte) ([]byte, error) { // Create temporary file for editing tmpdir, err := os.MkdirTemp("", "") @@ -180,6 +183,7 @@ func editTree(opts editOpts, tree *sops.Tree, dataKey []byte) ([]byte, error) { return encryptedFile, nil } +// runEditorUntilOk runs the editor until the file is saved and the hash is different func runEditorUntilOk(opts runEditorUntilOkOpts) error { for { err := runEditor(opts.TmpFile.Name()) @@ -240,6 +244,7 @@ func runEditorUntilOk(opts runEditorUntilOkOpts) error { return nil } +// hashFile returns the MD5 hash of the file at the given path func hashFile(filePath string) ([]byte, error) { var result []byte file, err := os.Open(filePath) @@ -254,6 +259,7 @@ func hashFile(filePath string) ([]byte, error) { return hash.Sum(result), nil } +// runEditor runs the editor func runEditor(path string) error { editor := os.Getenv("EDITOR") var cmd *exec.Cmd @@ -279,6 +285,7 @@ func runEditor(path string) error { return cmd.Run() } +// lookupAnyEditor looks up the first available editor func lookupAnyEditor(editorNames ...string) (editorPath string, err error) { for _, editorName := range editorNames { editorPath, err = exec.LookPath(editorName) diff --git a/serv/internal/secrets/init.go b/serv/internal/secrets/init.go index 9113af11..6efdfd9a 100644 --- a/serv/internal/secrets/init.go +++ b/serv/internal/secrets/init.go @@ -11,11 +11,8 @@ import ( "go.mozilla.org/sops/v3/stores/dotenv" ) +// Init reads the secrets from the given file and returns them as a map func Init(filename string, fs afero.Fs) (map[string]string, error) { - return initSecrets(filename, fs) -} - -func initSecrets(filename string, fs afero.Fs) (map[string]string, error) { var err error inputStore := common.DefaultStoreForPath(filename) diff --git a/serv/internal/secrets/rotate.go b/serv/internal/secrets/rotate.go index 846cfa3a..1a3a0512 100644 --- a/serv/internal/secrets/rotate.go +++ b/serv/internal/secrets/rotate.go @@ -24,6 +24,7 @@ type rotateOpts struct { KeyServices []keyservice.KeyServiceClient } +// rotate rotates the keys in the file at the given path using options passed. func rotate(opts rotateOpts) ([]byte, error) { tree, err := common.LoadEncryptedFileWithBugFixes(common.GenericDecryptOpts{ Cipher: opts.Cipher, diff --git a/serv/internal/secrets/run.go b/serv/internal/secrets/run.go index d2e6a420..8c8c24e0 100644 --- a/serv/internal/secrets/run.go +++ b/serv/internal/secrets/run.go @@ -21,6 +21,7 @@ type SecretArgs struct { KMS, KMSC, AWS, GCP, Azure, PGP string //nolint:golint,unused } +// SecretsCmd is the entry point for the secrets command func SecretsCmd(cmdName, fileName string, sa SecretArgs, args []string, log *zap.SugaredLogger) error { var err error @@ -148,6 +149,7 @@ func SecretsCmd(cmdName, fileName string, sa SecretArgs, args []string, log *zap return nil } +// keyGroups returns a slice of key groups based on the secret arguments func keyGroups(sa SecretArgs, file string) ([]sops.KeyGroup, error) { var kmsKeys []keys.MasterKey var pgpKeys []keys.MasterKey diff --git a/serv/internal/util/log.go b/serv/internal/util/log.go index 7c0c8774..262ea160 100644 --- a/serv/internal/util/log.go +++ b/serv/internal/util/log.go @@ -7,6 +7,8 @@ import ( "go.uber.org/zap/zapcore" ) +// NewLogger creates a new zap logger instance +// json - if true logs are in json format func NewLogger(json bool) *zap.Logger { econf := zapcore.EncoderConfig{ MessageKey: "msg", diff --git a/serv/internal/util/viper.go b/serv/internal/util/viper.go index 5621b87d..81c3ebc8 100644 --- a/serv/internal/util/viper.go +++ b/serv/internal/util/viper.go @@ -6,6 +6,7 @@ import ( "github.com/spf13/viper" ) +// SetKeyValue sets the value of a key in the viper config func SetKeyValue(vi *viper.Viper, key string, value interface{}) bool { if strings.HasPrefix(key, "GJ_") || strings.HasPrefix(key, "SG_") { key = key[3:] diff --git a/serv/iplimiter.go b/serv/iplimiter.go index 7c648d8b..513a0466 100644 --- a/serv/iplimiter.go +++ b/serv/iplimiter.go @@ -14,10 +14,12 @@ import ( var ipCache cache.Cache +// init initializes the cache func init() { ipCache, _ = cache.NewCache(cache.MaxKeys(10), cache.TTL(time.Minute*5)) } +// getIPLimiter returns the rate limiter for the given IP func getIPLimiter(ip string, limit float64, bucket int) *rate.Limiter { v, exists := ipCache.Get(ip) if !exists { @@ -29,11 +31,12 @@ func getIPLimiter(ip string, limit float64, bucket int) *rate.Limiter { return v.(*rate.Limiter) } -func rateLimiter(s1 *Service, h http.Handler) http.Handler { +// rateLimiter is a middleware that limits the number of requests per IP +func rateLimiter(s1 *HttpService, h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { var iph, ip string var err error - s := s1.Load().(*service) + s := s1.Load().(*graphjinService) if s.conf.RateLimiter.IPHeader != "" { iph = r.Header.Get(s.conf.RateLimiter.IPHeader) diff --git a/serv/migrate.go b/serv/migrate.go index e3aa920d..601c53eb 100644 --- a/serv/migrate.go +++ b/serv/migrate.go @@ -33,6 +33,7 @@ CREATE TABLE _graphjin.configs ( CREATE INDEX config_active ON _graphjin.configs (active); ` +// InitAdmin creates the admin tables func InitAdmin(db *sql.DB, dbtype string) error { c := context.Background() @@ -52,6 +53,7 @@ func InitAdmin(db *sql.DB, dbtype string) error { return nil } +// idColSql returns the id column sql func idColSql(dbtype string) string { switch dbtype { case "mysql": diff --git a/serv/routes.go b/serv/routes.go index db7bd1ca..cebd8f57 100644 --- a/serv/routes.go +++ b/serv/routes.go @@ -17,8 +17,9 @@ type Mux interface { ServeHTTP(http.ResponseWriter, *http.Request) } -func routesHandler(s1 *Service, mux Mux, ns *string) (http.Handler, error) { - s := s1.Load().(*service) +// routesHandler is the main handler for all routes +func routesHandler(s1 *HttpService, mux Mux, ns *string) (http.Handler, error) { + s := s1.Load().(*graphjinService) // Healthcheck API mux.Handle(healthRoute, healthCheckHandler(s1)) diff --git a/serv/secrets.go b/serv/secrets.go index 28c379e4..a8c3aff7 100644 --- a/serv/secrets.go +++ b/serv/secrets.go @@ -6,15 +6,18 @@ import ( "go.uber.org/zap" ) +// SecretArgs holds the arguments for the secrets command type SecretArgs struct { KMS, KMSC, AWS, GCP, Azure, PGP string } +// SecretsCmd runs the secrets command func SecretsCmd(cmdName, fileName string, sa SecretArgs, args []string, log *zap.SugaredLogger) error { return secrets.SecretsCmd( cmdName, fileName, secrets.SecretArgs(sa), args, log) } +// InitSecrets initializes the secrets from the secrets file func initSecrets(secFile string, fs afero.Fs) (map[string]string, error) { return secrets.Init(secFile, fs) } diff --git a/serv/serv.go b/serv/serv.go index 9d090358..26870ae8 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -20,8 +20,9 @@ const ( defaultHP = "0.0.0.0:8080" ) -func initConfigWatcher(s1 *Service) { - s := s1.Load().(*service) +// Initialize the watcher for the graphjin config file +func initConfigWatcher(s1 *HttpService) { + s := s1.Load().(*graphjinService) if s.conf.Serv.Production { return } @@ -34,8 +35,9 @@ func initConfigWatcher(s1 *Service) { }() } -func initHotDeployWatcher(s1 *Service) { - s := s1.Load().(*service) +// Initialize the hot deploy watcher +func initHotDeployWatcher(s1 *HttpService) { + s := s1.Load().(*graphjinService) go func() { err := startHotDeployWatcher(s1) if err != nil { @@ -44,8 +46,9 @@ func initHotDeployWatcher(s1 *Service) { }() } -func startHTTP(s1 *Service) { - s := s1.Load().(*service) +// Start the HTTP server +func startHTTP(s1 *HttpService) { + s := s1.Load().(*graphjinService) r := chi.NewRouter() routes, err := routesHandler(s1, r, s.namespace) @@ -125,6 +128,7 @@ func startHTTP(s1 *Service) { <-idleConnsClosed } +// Set the server header func setServerHeader(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Server", serverName) diff --git a/serv/telemetry.go b/serv/telemetry.go index 4386e6a2..58aa5d58 100644 --- a/serv/telemetry.go +++ b/serv/telemetry.go @@ -10,9 +10,10 @@ import ( semconv "go.opentelemetry.io/otel/semconv/v1.24.0" ) +// InitTelemetry initializes the OpenTelemetry SDK with the given exporter and service name. func InitTelemetry( - c context.Context, - exp trace.SpanExporter, + context context.Context, + exporter trace.SpanExporter, serviceName, serviceInstanceID string, ) error { r1 := resource.NewWithAttributes( @@ -27,7 +28,7 @@ func InitTelemetry( } provider := trace.NewTracerProvider( - trace.WithBatcher(exp), + trace.WithBatcher(exporter), trace.WithResource(r2), trace.WithSampler(trace.AlwaysSample()), ) diff --git a/serv/webui.go b/serv/webui.go index 38b27055..a991b4b4 100644 --- a/serv/webui.go +++ b/serv/webui.go @@ -10,6 +10,7 @@ import ( //go:embed web/build var webBuild embed.FS +// webuiHandler serves the web UI func webuiHandler(routePrefix string, gqlEndpoint string) http.Handler { webRoot, _ := fs.Sub(webBuild, "web/build") fs := http.FileServer(http.FS(webRoot)) diff --git a/serv/ws.go b/serv/ws.go index 2a6b5b74..f99a075c 100644 --- a/serv/ws.go +++ b/serv/ws.go @@ -74,7 +74,8 @@ type wsState struct { done chan bool } -func (s *service) apiV1Ws(w http.ResponseWriter, r *http.Request, ah auth.HandlerFunc) { +// apiV1Ws handles the websocket connection +func (s *graphjinService) apiV1Ws(w http.ResponseWriter, r *http.Request, ah auth.HandlerFunc) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { renderErr(w, err) @@ -127,7 +128,8 @@ type authHeaders struct { UserID interface{} `json:"X-User-ID"` } -func (s *service) subSwitch(wc *wsConn, req wsReq) (err error) { +// subSwitch handles the websocket message types +func (s *graphjinService) subSwitch(wc *wsConn, req wsReq) (err error) { switch req.Type { case "connection_init": if err = setHeaders(req, wc.r); err != nil { @@ -198,7 +200,8 @@ func (s *service) subSwitch(wc *wsConn, req wsReq) (err error) { return } -func (s *service) waitForData(wc *wsConn, st *wsState, useNext bool) { +// waitForData waits for data from the subscription +func (s *graphjinService) waitForData(wc *wsConn, st *wsState, useNext bool) { var buf bytes.Buffer var ptype string @@ -247,6 +250,7 @@ func (s *service) waitForData(wc *wsConn, st *wsState, useNext bool) { } } +// setHeaders sets the headers from the payload func setHeaders(req wsReq, r *http.Request) (err error) { if len(req.Payload) == 0 { return @@ -266,6 +270,7 @@ func setHeaders(req wsReq, r *http.Request) (err error) { return } +// sendError sends an error message to the client func sendError(wc *wsConn, id string, cerr error) (err error) { m := wsRes{ID: id, Type: "error"} m.Payload.Errors = []core.Error{{Message: cerr.Error()}} diff --git a/tests/core_test.go b/tests/core_test.go index cd9b3e6d..b1c8f2a3 100644 --- a/tests/core_test.go +++ b/tests/core_test.go @@ -65,7 +65,7 @@ func TestAPQ(t *testing.T) { return } - _, err = gj.GraphQL(context.Background(), gql, nil, &core.ReqConfig{ + _, err = gj.GraphQL(context.Background(), gql, nil, &core.RequestConfig{ APQKey: "getProducts", }) if err != nil { @@ -73,7 +73,7 @@ func TestAPQ(t *testing.T) { return } - res, err := gj.GraphQL(context.Background(), "", nil, &core.ReqConfig{ + res, err := gj.GraphQL(context.Background(), "", nil, &core.RequestConfig{ APQKey: "getProducts", }) if err != nil { @@ -199,7 +199,7 @@ func TestAllowListWithNamespace(t *testing.T) { return } - var rc core.ReqConfig + var rc core.RequestConfig rc.SetNamespace("api") _, err = gj2.GraphQL(context.Background(), gql2, nil, &rc)