Skip to content

Commit

Permalink
PTEUDO-1691: remove database name from dbclient (#345)
Browse files Browse the repository at this point in the history
* PTEUDO-1691: remove database name from dbclient

* PTEUDO-1691: checking whether the db exists or not

* PTEUDO-1691: removed old function

* PTEUDO-1691: renamed function name to a go standard

* PTEUDO-1691: renamed receivers to pc

* PTEUDO-1691: rollback, renamed receivers to pc

* PTEUDO-1691: added admin db connection

* PTEUDO-1691: fix client test

* fix bugs in mutating webhooks

---------

Co-authored-by: Drew Wells <dwells@infoblox.com>
  • Loading branch information
leandrorichardtoledo and drewwells authored Oct 26, 2024
1 parent 25d037a commit 61da8a4
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 37 deletions.
2 changes: 1 addition & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func main() {
flag.StringVar(&metricsDepYamlPath, "metrics-dep-yaml", "/config/postgres-exporter/deployment.yaml", "path to the metrics deployment yaml")
flag.StringVar(&metricsConfigYamlPath, "metrics-config-yaml", "/config/postgres-exporter/config.yaml", "path to the metrics config yaml")
flag.BoolVar(&enableDBProxyWebhook, "enable-db-proxy", false, "Enable DB Proxy webhook. See docs for usage: https://infobloxopen.github.io/db-controller/#quick-start")
flag.BoolVar(&enableDBProxyWebhook, "enable-deprecation-conversion-webhook", false, "Enable conversion of deprecated pods using dbproxy and/or dsnexec annotations")
flag.BoolVar(&enableDeprecatedConversionWebhook, "enable-deprecation-conversion-webhook", false, "Enable conversion of deprecated pods using dbproxy and/or dsnexec annotations")

opts := zap.Options{
Development: true,
Expand Down
2 changes: 0 additions & 2 deletions helm/db-controller/minikube.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ controllerConfig:
zapLogger:
develMode: true
level: debug
dbproxy:
enabled: true

# Block XRDs and XRs in manual deploys
xrd:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ metadata:
annotations:
cert-manager.io/inject-ca-from: {{ .Release.Namespace }}/{{ include "db-controller.fullname" . }}-webhook
webhooks:
{{- if .Values.deprecationConversionWebhook.enabled }}
- clientConfig:
service:
name: {{ include "db-controller.fullname" . }}
Expand Down Expand Up @@ -34,6 +35,7 @@ webhooks:
scope: "Namespaced"
sideEffects: None
timeoutSeconds: 10
{{- end }}
- clientConfig:
service:
name: {{ include "db-controller.fullname" . }}
Expand Down
90 changes: 60 additions & 30 deletions pkg/dbclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package dbclient
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
"strings"
Expand All @@ -27,17 +26,18 @@ var defaultExtensions = []string{"citext", "uuid-ossp",
"plpgsql", "hll"}

var specialExtensionsMap = map[string]func(*client, string, string) error{
"pg_partman": (*client).pg_partman,
"pg_cron": (*client).pg_cron,
"pg_partman": (*client).CreatePgPartmanExtension,
"pg_cron": (*client).CreatePgCronExtension,
}

type client struct {
// cloud is `aws` or `gcp`
cloud string
dbType string
dbURL string
DB *sql.DB
log logr.Logger
cloud string
dbType string
dbURL string
DB *sql.DB
adminDB *sql.DB
log logr.Logger
}

func (p *client) GetDB() *sql.DB {
Expand Down Expand Up @@ -78,7 +78,7 @@ func New(cfg Config) (*client, error) {
return newPostgresClient(context.TODO(), cfg)
}

// creates postgres client
// newPostgresClient creates a new client for postgres.
func newPostgresClient(ctx context.Context, cfg Config) (*client, error) {

if cfg.Cloud == "" {
Expand All @@ -98,12 +98,32 @@ func newPostgresClient(ctx context.Context, cfg Config) (*client, error) {
if err != nil {
return nil, err
}

// create a new connection to the database to run admin commands.

dsnURL, err := url.Parse(authedDSN)
if err != nil {
return nil, fmt.Errorf("could not parse DSN: %w", err)
}
dsnURL.Path = "/postgres"

adminDB, err := sql.Open(PostgresType, dsnURL.String())
if err != nil {
return nil, err
}

if err := adminDB.PingContext(ctx); err != nil {
adminDB.Close()
return nil, fmt.Errorf("could not connect to admin database: %w", err)
}

return &client{
cloud: cfg.Cloud,
dbType: PostgresType,
DB: db,
log: cfg.Log,
dbURL: cfg.DSN,
cloud: cfg.Cloud,
dbType: PostgresType,
DB: db,
adminDB: adminDB,
log: cfg.Log,
dbURL: cfg.DSN,
}, nil
}

Expand All @@ -126,40 +146,50 @@ func (pc *client) SanitizeDSN() string {
return u.Redacted()
}

// CreateDataBase implements typo func name incase anybody is using it
func (pc *client) CreateDataBase(name string) (bool, error) {
pc.log.Error(errors.New("CreateDataBase called, use CreateDatabase"), "old_interface_called")
return pc.CreateDatabase(name)
}

// unit test override
var getDefaulExtensions = func() []string {
return defaultExtensions
}

// CreateDatabase creates a database if it does not exist.
func (pc *client) CreateDatabase(dbName string) (bool, error) {
var exists bool
db := pc.DB
log := pc.log.WithValues("database", dbName, "dsn", pc.SanitizeDSN())

log.Info("pinging database", "ping", db.Ping())

err := db.QueryRow(`SELECT EXISTS(SELECT datname FROM pg_catalog.pg_database WHERE datname = $1)`, dbName).Scan(&exists)
query := `SELECT EXISTS(SELECT datname FROM pg_catalog.pg_database WHERE datname = $1)`
err := pc.adminDB.QueryRow(query, dbName).Scan(&exists)
if err != nil {
// TODO: use error codes provided by the pq driver.
if strings.Contains(err.Error(), "does not exist") {
return pc.createDatabase(dbName, log)
}

log.Error(err, "could not query for database name", "query", fmt.Sprintf(`SELECT EXISTS(SELECT datname FROM pg_catalog.pg_database WHERE datname = '%s')`, dbName))
metrics.DBProvisioningErrors.WithLabelValues("read error")
metrics.DBProvisioningErrors.WithLabelValues("read error").Inc()
return false, err
}

if exists {
return false, nil
}

// create the database
if _, err := db.Exec(fmt.Sprintf("create database %s", pq.QuoteIdentifier(dbName))); err != nil {
return pc.createDatabase(dbName, log)
}

func (pc *client) createDatabase(dbName string, log logr.Logger) (bool, error) {
_, err := pc.adminDB.Exec(fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
log.Error(err, "could not create database")
metrics.DBProvisioningErrors.WithLabelValues("create error")
metrics.DBProvisioningErrors.WithLabelValues("create error").Inc()
return false, err
}

if err := pc.DB.Ping(); err != nil {
log.Error(err, "could not connect to database")
metrics.DBProvisioningErrors.WithLabelValues("ping error").Inc()
return false, err
}

metrics.DBCreated.Inc()
return true, nil
}
Expand Down Expand Up @@ -218,7 +248,7 @@ func (pc *client) CreateSpecialExtensions(dbName string, role string) error {
return nil
}

func (pc *client) pg_cron(dbName string, role string) error {
func (pc *client) CreatePgCronExtension(dbName string, role string) error {
// create extension pg_cron and grant usage to public
db, err := pc.getDB(dbName)
if err != nil {
Expand All @@ -243,7 +273,7 @@ func (pc *client) pg_cron(dbName string, role string) error {
return nil
}

func (pc *client) pg_partman(dbName string, role string) error {
func (pc *client) CreatePgPartmanExtension(dbName string, role string) error {

db, err := pc.getDB(dbName)
if err != nil {
Expand Down
9 changes: 5 additions & 4 deletions pkg/dbclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ func TestPostgresClientOperations(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pc := &client{
dbType: "postgres",
dbURL: dsn,
DB: db,
log: NewTestLogger(t),
dbType: "postgres",
dbURL: dsn,
DB: db,
adminDB: db,
log: NewTestLogger(t),
}

got, err := pc.CreateDatabase(tt.args.dbName)
Expand Down

0 comments on commit 61da8a4

Please sign in to comment.