Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Glorious SQL #54

Merged
merged 13 commits into from
Nov 24, 2023
Prev Previous commit
Next Next commit
fixes
subroseio committed Nov 23, 2023
commit 488705fb10e0e4b8ea657934e7520eb23db52100
179 changes: 76 additions & 103 deletions vault/sql.go
Original file line number Diff line number Diff line change
@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"reflect"
"strings"

"database/sql"
@@ -50,9 +49,9 @@ func (st *SqlStore) CreateSchemas() error {
tables := map[string]string{
"principals": "CREATE TABLE IF NOT EXISTS principals (username TEXT PRIMARY KEY, password TEXT, description TEXT)",
"policies": "CREATE TABLE IF NOT EXISTS policies (id TEXT, effect TEXT, actions TEXT[], resources TEXT[])",
"principal_policies": "CREATE TABLE IF NOT EXISTS principal_policies (username TEXT, policy_id TEXT)",
"tokens": "CREATE TABLE IF NOT EXISTS tokens (id TEXT, value TEXT)",
"collection_metadata": "CREATE TABLE IF NOT EXISTS collection_metadata (name TEXT, field_schema JSON)",
"principal_policies": "CREATE TABLE IF NOT EXISTS principal_policies (username TEXT, policy_id TEXT)",
}

for _, query := range tables {
@@ -65,68 +64,49 @@ func (st *SqlStore) CreateSchemas() error {
return nil
}

func (st SqlStore) createCollectionTable(ctx context.Context, c Collection) error {
// Define a dynamic struct based on the Fields of the collection
var dynamicStructFields []reflect.StructField

// Add an ID field to the struct
idField := reflect.StructField{
Name: "ID",
Type: reflect.TypeOf(""),
Tag: reflect.StructTag(`db:"id"`),
func (st *SqlStore) CreateCollection(ctx context.Context, c Collection) (string, error) {
tx, err := st.db.BeginTxx(ctx, nil)
if err != nil {
return "", err
}
dynamicStructFields = append(dynamicStructFields, idField)

for fieldName := range c.Fields {
exportedFieldName := strings.Title(fieldName)
structField := reflect.StructField{
Name: exportedFieldName,
Type: reflect.TypeOf(""), // Assuming all fields are strings for simplicity
Tag: reflect.StructTag(fmt.Sprintf(`db:"%s"`, fieldName)),
defer func() {
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err)
}
}
dynamicStructFields = append(dynamicStructFields, structField)
}

// dynamicStruct := reflect.StructOf(dynamicStructFields)
// dynamicStructPtr := reflect.New(dynamicStruct).Interface() // Create a pointer to a new instance of the dynamic struct

tableName := "collection_" + c.Name // Create a unique table name
}()

// Create the table using SQLX's MustExec with a pointer to the dynamic struct
// st.db.MustExecContext(ctx, "CREATE TABLE IF NOT EXISTS "+tableName+" (?)", dynamicStructPtr)
// Instead of using the dynamic struct directly, we will generate the SQL query manually
var queryBuilder strings.Builder
queryBuilder.WriteString("CREATE TABLE IF NOT EXISTS " + tableName + " (id TEXT")
for fieldName := range c.Fields {
queryBuilder.WriteString(", " + fieldName + " TEXT")
}
queryBuilder.WriteString(")")
st.db.MustExecContext(ctx, queryBuilder.String())

return nil
}

func (st SqlStore) CreateCollection(ctx context.Context, c Collection) (string, error) {
// Convert the Fields map to JSON for storing in the collection_metadata table
fieldSchema, err := json.Marshal(c.Fields)
if err != nil {
return "", err
}

// Create a new CollectionMetadata instance
collectionMetadata := CollectionMetadata{
Name: c.Name,
FieldSchema: fieldSchema,
}

// Save the collection metadata
_, err = st.db.NamedExecContext(ctx, "INSERT INTO collection_metadata (name, field_schema) VALUES (:name, :field_schema)", &collectionMetadata)
_, err = tx.NamedExecContext(ctx, "INSERT INTO collection_metadata (name, field_schema) VALUES (:name, :field_schema)", &collectionMetadata)
if err != nil {
return "", err
}

// Dynamically create a table for the collection
if err := st.createCollectionTable(ctx, c); err != nil {
tableName := "collection_" + c.Name
var queryBuilder strings.Builder
queryBuilder.WriteString("CREATE TABLE IF NOT EXISTS " + tableName + " (id TEXT PRIMARY KEY")
for fieldName := range c.Fields {
queryBuilder.WriteString(", " + fieldName + " TEXT")
}
queryBuilder.WriteString(")")
_, err = tx.ExecContext(ctx, queryBuilder.String())
if err != nil {
return "", err
}

err = tx.Commit()
if err != nil {
return "", err
}

@@ -164,39 +144,42 @@ func (st SqlStore) GetCollections(ctx context.Context) ([]string, error) {
}

func (st SqlStore) DeleteCollection(ctx context.Context, name string) error {
// Start a transaction
tx, err := st.db.BeginTxx(ctx, nil)
if err != nil {
return err
}

// Delete the collection metadata
defer func() {
if err != nil {
rbErr := tx.Rollback()
if rbErr != nil {
err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err)
}
}
}()

_, err = tx.ExecContext(ctx, "DELETE FROM collection_metadata WHERE name = $1", name)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
tx.Rollback()
return &NotFoundError{"collection", name}
}
tx.Rollback()
return err
}

// Delete the dynamic table
tableName := "collection_" + name
_, err = tx.ExecContext(ctx, "DROP TABLE IF EXISTS "+tableName)
if err != nil {
tx.Rollback()
return err
}

// Commit the transaction
err = tx.Commit()
if err != nil {
return err
}

return nil
}

func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, records []Record) ([]string, error) {
var collectionMetadata CollectionMetadata
err := st.db.GetContext(ctx, &collectionMetadata, "SELECT * FROM collection_metadata WHERE name = $1", collectionName)
@@ -213,7 +196,6 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec
return nil, err
}

// Create slice for field names and initialize placeholders
fieldNames := make([]string, 0, len(fields))
placeholders := make([]string, 0, len(fields))
idx := 2 // Start from 2 because $1 is reserved for recordId
@@ -224,15 +206,13 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec
idx++
}

// Prepare SQL statement
query := fmt.Sprintf("INSERT INTO collection_%s (id, %s) VALUES ($1, %s)", collectionName, strings.Join(fieldNames, ", "), strings.Join(placeholders, ", "))
stmt, err := st.db.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()

// Processing records
recordIds := make([]string, len(records))
for i, record := range records {
// Validate record fields
@@ -243,7 +223,6 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec
recordId := GenerateId()
recordIds[i] = recordId

// Prepare values for insertion
values := make([]interface{}, len(fields)+1)
values[0] = recordId
for j, fieldName := range fieldNames {
@@ -254,7 +233,6 @@ func (st SqlStore) CreateRecords(ctx context.Context, collectionName string, rec
}
}

// Execute the prepared statement
_, err = stmt.ExecContext(ctx, values...)
if err != nil {
return nil, err
@@ -279,21 +257,18 @@ func (st SqlStore) GetRecords(ctx context.Context, collectionName string, record
return nil, err
}

// Prepare the query
query, args, err := sqlx.In("SELECT * FROM collection_"+collectionName+" WHERE id IN (?)", recordIDs)
if err != nil {
return nil, err
}
query = st.db.Rebind(query)

// Execute the query
rows, err := st.db.QueryxContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()

// Process the results
records := make(map[string]*Record)
for rows.Next() {
recordMap := make(map[string]interface{})
@@ -412,9 +387,18 @@ func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) err
if err != nil {
return err
}

defer func() {
if err != nil {
rbErr := tx.Rollback()
if rbErr != nil {
err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err)
}
}
}()

_, err = tx.NamedExecContext(ctx, "INSERT INTO principals (username, password, description) VALUES (:username, :password, :description)", &principal)
if err != nil {
tx.Rollback()
if pqErr, ok := err.(*pq.Error); ok {
if pqErr.Code == "23505" {
return &ConflictError{principal.Username}
@@ -426,10 +410,10 @@ func (st SqlStore) CreatePrincipal(ctx context.Context, principal Principal) err
for _, policyId := range principal.Policies {
_, err = tx.ExecContext(ctx, "INSERT INTO principal_policies (username, policy_id) VALUES ($1, $2)", principal.Username, policyId)
if err != nil {
tx.Rollback()
return err
}
}

err = tx.Commit()
if err != nil {
return err
@@ -445,21 +429,28 @@ func (st SqlStore) DeletePrincipal(ctx context.Context, username string) error {
return err
}

// First, delete associations in the many-to-many join table
defer func() {
if p := recover(); p != nil {
if rbErr := tx.Rollback(); rbErr != nil {
err = fmt.Errorf("rollback failed: %v, after panic: %v", rbErr, p)
}
} else if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err)
}
}
}()

_, err = tx.ExecContext(ctx, "DELETE FROM principal_policies WHERE username = $1", username)
if err != nil {
tx.Rollback()
return err
}

// Now, delete the principal itself
_, err = tx.ExecContext(ctx, "DELETE FROM principals WHERE username = $1", username)
if err != nil {
tx.Rollback()
return err
}

// Commit the transaction
err = tx.Commit()
if err != nil {
return err
@@ -543,22 +534,26 @@ func (st SqlStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Poli
}

func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) {

// Start a transaction
tx, err := st.db.BeginTxx(ctx, nil)
if err != nil {
return "", err
}

defer func() {
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err)
}
}
}()

query := "INSERT INTO policies (id, effect, actions, resources) VALUES (:id, :effect, :actions, :resources)"
actions := make(pq.StringArray, len(p.Actions))
for i, action := range p.Actions {
actions[i] = string(action)
}
resources := make(pq.StringArray, len(p.Resources))
for i, resource := range p.Resources {
resources[i] = resource
}
copy(resources, p.Resources)
query, args, err := sqlx.Named(query, map[string]interface{}{
"id": p.PolicyId,
"effect": string(p.Effect),
@@ -567,20 +562,17 @@ func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) {
})

if err != nil {
tx.Rollback()
return "", err
}
query = tx.Rebind(query)
_, err = tx.ExecContext(ctx, query, args...)
if err != nil {
tx.Rollback()
if errors.Is(err, sql.ErrNoRows) {
return "", &ConflictError{p.PolicyId}
}
return "", err
}

// Commit the transaction
err = tx.Commit()
if err != nil {
return "", err
@@ -590,35 +582,35 @@ func (st SqlStore) CreatePolicy(ctx context.Context, p Policy) (string, error) {
}

func (st SqlStore) DeletePolicy(ctx context.Context, policyID string) error {
// Start a transaction
tx, err := st.db.BeginTxx(ctx, nil)
if err != nil {
return err
}
// Delete the policy itself
defer func() {
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
err = fmt.Errorf("rollback failed: %v, after error: %v", rbErr, err)
}
}
}()

result, err := tx.ExecContext(ctx, "DELETE FROM policies WHERE id = $1", policyID)
if err != nil {
tx.Rollback()
return err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
tx.Rollback()
return err
}
if rowsAffected == 0 {
tx.Rollback()
return &NotFoundError{"policy", policyID}
}

// Directly delete associations in the many-to-many join table
_, err = tx.ExecContext(ctx, "DELETE FROM principal_policies WHERE policy_id = $1", policyID)
if err != nil {
tx.Rollback()
return err
}

// Commit the transaction
err = tx.Commit()
if err != nil {
return err
@@ -656,28 +648,9 @@ func (st SqlStore) Flush(ctx context.Context) error {
return err
}
}
// Recreate schemas
err = st.CreateSchemas()
if err != nil {
return err
}
return nil
}

// func (st SqlStore) CreateCollection(ctx context.Context, c Collection) (string, error) {
// b, err := json.Marshal(c)
// if err != nil {
// return "", err
// }

// gormCol := DbCollection{Name: c.Name, Collection: datatypes.JSON(b)}
// result := st.db.Create(&gormCol)
// if result.Error != nil {
// switch result.Error {
// case gorm.ErrDuplicatedKey:
// return "", &ConflictError{c.Name}
// }
// }

// return c.Name, nil
// }