diff --git a/base.go b/base.go index 2bfe6f8..fd811c9 100644 --- a/base.go +++ b/base.go @@ -300,6 +300,15 @@ func (d *base) CreateTableSql(model *Model, ifNotExists bool) string { a = append(a, ", ") } } + if len(model.ForeignKeys) > 0 { + a = append(a, ", ") + } + for i, fk := range model.ForeignKeys { + a = append(a, d.Dialect.ForeignKey(fk)) + if i < len(model.ForeignKeys)-1 { + a = append(a, ", ") + } + } a = append(a, " )") return strings.Join(a, "") } @@ -434,3 +443,30 @@ func (d *base) KeywordPrimaryKey() string { func (d *base) KeywordAutoIncrement() string { return "AUTOINCREMENT" } + +func (d *base) ForeignKey(fk *ForeignKey) string { + return fmt.Sprintf( + "FOREIGN KEY (%v) REFERENCES %v(%v) ON UPDATE %v ON DELETE %v", + d.Dialect.Quote(fk.Column), + d.Dialect.Quote(fk.ReferenceTable), + d.Dialect.Quote(fk.ReferenceColumn), + d.Dialect.ReferentialAction(fk.OnUpdate), + d.Dialect.ReferentialAction(fk.OnDelete), + ) + +} + +func (d *base) ReferentialAction(ra ReferentialAction) string { + switch ra { + case Cascade: + return "CASCADE" + case Restrict: + return "RESTRICT" + case NoAction: + return "NO ACTION" + case SetNull: + return "SET NULL" + } + + return "NO ACTION" +} diff --git a/dialect.go b/dialect.go index 13edf59..b966ed7 100644 --- a/dialect.go +++ b/dialect.go @@ -130,4 +130,10 @@ type Dialect interface { // KeywordAutoIncrement returns the dialect specific keyword for 'AUTO_INCREMENT'. KeywordAutoIncrement() string + + // ForeignKey returns the dialect spefific foreign key constraint + ForeignKey(fk *ForeignKey) string + + // ReferentialAction returns the dialect spefific foreign key referntial action + ReferentialAction(ra ReferentialAction) string } diff --git a/hood.go b/hood.go index 149436d..8a83404 100644 --- a/hood.go +++ b/hood.go @@ -55,6 +55,19 @@ type ( // Indexes represents an array of indexes. Indexes []*Index + // ForeignKey represents a foreign key + ForeignKey struct { + Name string + Column string + ReferenceTable string + ReferenceColumn string + OnUpdate ReferentialAction + OnDelete ReferentialAction + } + + // ForeignKeys represents an array of foreign keys + ForeignKeys []*ForeignKey + // Created denotes a timestamp field that is automatically set on insert. Created struct { time.Time @@ -67,10 +80,11 @@ type ( // Model represents a parsed schema interface{}. Model struct { - Pk *ModelField - Table string - Fields []*ModelField - Indexes Indexes + Pk *ModelField + Table string + Fields []*ModelField + Indexes Indexes + ForeignKeys ForeignKeys } // ModelField represents a schema field of a parsed model. @@ -100,6 +114,11 @@ type ( Indexes(indexes *Indexes) } + // ForeignKeyed defines the foreign keys for a table. + ForeignKeyed interface { + ForeignKeys(foreignKeys *ForeignKeys) + } + // TODO: implement aggregate function types // // // Avg denotes an average aggregate function argument @@ -150,6 +169,15 @@ const ( type Join int +const ( + Cascade = ReferentialAction(iota) + Restrict + NoAction + SetNull +) + +type ReferentialAction int + // Add adds an index func (ix *Indexes) Add(name string, columns ...string) { *ix = append(*ix, &Index{Name: name, Columns: columns, Unique: false}) @@ -160,6 +188,11 @@ func (ix *Indexes) AddUnique(name string, columns ...string) { *ix = append(*ix, &Index{Name: name, Columns: columns, Unique: true}) } +// Add adds a foreign key +func (fk *ForeignKeys) Add(name string, column string, referenceTable string, referenceColumn string, onUpdate ReferentialAction, onDelete ReferentialAction) { + *fk = append(*fk, &ForeignKey{Name: name, Column: column, ReferenceTable: referenceTable, ReferenceColumn: referenceColumn, OnUpdate: onUpdate, OnDelete: onDelete}) +} + // Quote quotes the path using the given dialects Quote method func (p Path) Quote(d Dialect) string { sep := "." @@ -358,6 +391,18 @@ func (index *Index) GoDeclaration() string { ) } +func (foreignKey *ForeignKey) GoDeclaration() string { + return fmt.Sprintf( + "foreignKeys.Add(\"%s\", \"%s\", \"%s\", \"%s\", %s, %s)", + foreignKey.Name, + foreignKey.Column, + foreignKey.ReferenceTable, + foreignKey.ReferenceColumn, + foreignKey.OnUpdate, + foreignKey.OnDelete, + ) +} + func (model *Model) Validate() error { for _, field := range model.Fields { err := field.Validate() @@ -384,6 +429,15 @@ func (model *Model) GoDeclaration() string { } a = append(a, "}") } + if len(model.ForeignKeys) > 0 { + a = append(a, + fmt.Sprintf("\nfunc (table *%s) ForeignKeys(foreignKeys *hood.ForeignKeys) {", tableName), + ) + for _, fk := range model.ForeignKeys { + a = append(a, "\t"+fk.GoDeclaration()) + } + a = append(a, "}") + } return strings.Join(a, "\n") } @@ -1313,6 +1367,12 @@ func addIndexes(m *Model, f interface{}) { } } +func addForeignKeys(m *Model, f interface{}) { + if t, ok := f.(ForeignKeyed); ok { + t.ForeignKeys(&m.ForeignKeys) + } +} + func interfaceToModel(f interface{}) (*Model, error) { v := reflect.Indirect(reflect.ValueOf(f)) if v.Kind() != reflect.Struct { @@ -1327,6 +1387,7 @@ func interfaceToModel(f interface{}) (*Model, error) { } addFields(m, t, v) addIndexes(m, f) + addForeignKeys(m, f) return m, nil }