diff --git a/drainer/config.go b/drainer/config.go
index c51eb9ca0..8664c5827 100644
--- a/drainer/config.go
+++ b/drainer/config.go
@@ -78,6 +78,11 @@ type SyncerConfig struct {
 	EnableDispatch  bool     `toml:"enable-dispatch" json:"enable-dispatch"`
 	SafeMode        bool     `toml:"safe-mode" json:"safe-mode"`
 	EnableCausality bool     `toml:"enable-detect" json:"enable-detect"`
+	PluginPath      string   `toml:"plugin-path" json:"plugin-path"`
+	PluginNames     []string `toml:"plugin-names" json:"plugin-names"`
+	SupportPlugin   bool     `toml:"support-plugin" json:"support-plugin"`
+	MarkDBName      string   `toml:"mark-db-name" json:"mark-db-name"`
+	MarkTableName   string   `toml:"mark-table-name" json:"mark-table-name"`
 }
 
 // RelayConfig is the Relay log's configuration.
@@ -158,6 +163,12 @@ func NewConfig() *Config {
 	fs.IntVar(&cfg.SyncedCheckTime, "synced-check-time", defaultSyncedCheckTime, "if we can't detect new binlog after many minute, we think the all binlog is all synced")
 	fs.StringVar(new(string), "log-rotate", "", "DEPRECATED")
 
+	fs.StringVar(&cfg.SyncerCfg.PluginPath, "plugin-path", "", "the directory that holds the plugin files")
+	fs.Var(newSliceNames([]string{}, &cfg.SyncerCfg.PluginNames), "plugin-names", "the comma-separated names of the plugins")
+	fs.BoolVar(&cfg.SyncerCfg.SupportPlugin, "support-plugin", false, "whether plugins are supported, default: false")
+	fs.StringVar(&cfg.SyncerCfg.MarkDBName, "mark-db-name", "rel", "the name of the mark database")
+	fs.StringVar(&cfg.SyncerCfg.MarkTableName, "mark-table-name", "_drainer_repl_mark", "the name of the mark table")
+
 	return cfg
 }
diff --git a/drainer/config_test.go b/drainer/config_test.go
index 93a4dd378..445615c4c 100644
--- a/drainer/config_test.go
+++ b/drainer/config_test.go
@@ -293,3 +293,75 @@ func (t *testKafkaSuite) TestConfigDestDBTypeKafka(c *C) {
 	c.Assert(cfg.SyncerCfg.To.KafkaVersion, Equals, defaultKafkaVersion)
 	c.Assert(cfg.SyncerCfg.To.KafkaMaxMessages, Equals, 1024)
 }
+
+func (t *testKafkaSuite) TestConfigPlugin(c *C) {
+	args := []string{}
+
+	cfg := NewConfig()
+	err := cfg.Parse(args)
+	c.Assert(err, IsNil)
+
+	c.Assert(len(cfg.SyncerCfg.PluginPath), Equals, 0)
+	c.Assert(len(cfg.SyncerCfg.PluginNames), Equals, 0)
+	c.Assert(cfg.SyncerCfg.SupportPlugin, Equals, false)
+	c.Assert(cfg.SyncerCfg.MarkDBName, Equals, "rel")
+	c.Assert(cfg.SyncerCfg.MarkTableName, Equals, "_drainer_repl_mark")
+
+	args = []string{
+		"-plugin-path", "/tmp/drainer/plugin",
+		"-plugin-names", "demo1",
+		"-support-plugin",
+		"-mark-db-name", "db1",
+		"-mark-table-name", "tb1",
+	}
+
+	cfg = NewConfig()
+	err = cfg.Parse(args)
+	c.Assert(err, IsNil)
+
+	c.Assert(cfg.SyncerCfg.PluginPath, Equals, "/tmp/drainer/plugin")
+	c.Assert(len(cfg.SyncerCfg.PluginNames), Equals, 1)
+	c.Assert(cfg.SyncerCfg.PluginNames[0], Equals, "demo1")
+	c.Assert(cfg.SyncerCfg.SupportPlugin, Equals, true)
+	c.Assert(cfg.SyncerCfg.MarkDBName, Equals, "db1")
+	c.Assert(cfg.SyncerCfg.MarkTableName, Equals, "tb1")
+
+	args = []string{
+		"-plugin-names", "demo1,demo2",
+		"-mark-db-name", "",
+		"-mark-table-name", "",
+	}
+
+	cfg = NewConfig()
+	err = cfg.Parse(args)
+	c.Assert(err, IsNil)
+
+	c.Assert(len(cfg.SyncerCfg.PluginNames), Equals, 2)
+	c.Assert(cfg.SyncerCfg.PluginNames[0], Equals, "demo1")
+	c.Assert(cfg.SyncerCfg.PluginNames[1], Equals, "demo2")
+	c.Assert(cfg.SyncerCfg.MarkDBName, Equals, "")
+	c.Assert(cfg.SyncerCfg.MarkTableName, Equals, "")
+
+	args = []string{
+		"-plugin-names", "",
+	}
+
+	cfg = NewConfig()
+	err = cfg.Parse(args)
+	c.Assert(err, IsNil)
+
+	c.Assert(len(cfg.SyncerCfg.PluginNames), Equals, 1)
+	c.Assert(cfg.SyncerCfg.PluginNames[0], Equals, "")
+
+	args = []string{
+		"-plugin-names", ",",
+	}
+
+	cfg = NewConfig()
+	err = cfg.Parse(args)
+	c.Assert(err, IsNil)
+
+	c.Assert(len(cfg.SyncerCfg.PluginNames), Equals, 2)
+	c.Assert(cfg.SyncerCfg.PluginNames[0], Equals, "")
+	c.Assert(cfg.SyncerCfg.PluginNames[1], Equals, "")
+}
diff --git a/drainer/isyncer.go b/drainer/isyncer.go
new file mode 100644
index 000000000..9e22fbb9f
--- /dev/null
+++ b/drainer/isyncer.go
@@ -0,0 +1,11 @@
+package drainer
+
+import (
+	"github.com/pingcap/tidb-binlog/drainer/loopbacksync"
+	"github.com/pingcap/tidb-binlog/pkg/loader"
+)
+
+// SyncerFilter is the interface implemented by syncer plugins.
+type SyncerFilter interface {
+	FilterTxn(txn *loader.Txn, info *loopbacksync.LoopBackSync) (bool, error)
+}
diff --git a/drainer/loopbacksync/loopbacksync.go b/drainer/loopbacksync/loopbacksync.go
index c30afcd55..1a59f071e 100644
--- a/drainer/loopbacksync/loopbacksync.go
+++ b/drainer/loopbacksync/loopbacksync.go
@@ -20,6 +20,7 @@ import (
 
 	"github.com/pingcap/errors"
 	"github.com/pingcap/log"
+	"github.com/pingcap/tidb-binlog/pkg/plugin"
 	"go.uber.org/zap"
 )
 
@@ -36,37 +37,58 @@ const (
 	ChannelInfo = "channel_info"
 )
 
-// CreateMarkTableDDL is the DDL to create the mark table.
-var CreateMarkTableDDL string = fmt.Sprintf("CREATE TABLE If Not Exists %s (%s bigint not null,%s bigint not null DEFAULT 0, %s bigint DEFAULT 0, %s varchar(64) ,PRIMARY KEY (%s,%s));", MarkTableName, ID, ChannelID, Val, ChannelInfo, ID, ChannelID)
-
-// CreateMarkDBDDL is DDL to create the database of mark table.
-var CreateMarkDBDDL = "create database IF NOT EXISTS retl;"
-
 //LoopBackSync loopback sync info
 type LoopBackSync struct {
 	ChannelID       int64
 	LoopbackControl bool
+	MarkDBName      string
+	MarkTableName   string
 	SyncDDL         bool
+	Index           int64
+	PluginPath      string
+	PluginNames     []string
+	Hooks           []*plugin.EventHooks
+	SupportPlugin   bool
+	RecordID        int
 }
 
 //NewLoopBackSyncInfo return LoopBackSyncInfo objec
-func NewLoopBackSyncInfo(ChannelID int64, LoopbackControl, SyncDDL bool) *LoopBackSync {
+func NewLoopBackSyncInfo(ChannelID int64, LoopbackControl, SyncDDL bool, path string, names []string, SupportPlug bool, mdbname, mtablename string) *LoopBackSync {
 	l := &LoopBackSync{
 		ChannelID:       ChannelID,
 		LoopbackControl: LoopbackControl,
 		SyncDDL:         SyncDDL,
+		Index:           0,
+		PluginPath:      path,
+		PluginNames:     names,
+		SupportPlugin:   SupportPlug,
+		MarkDBName:      strings.TrimSpace(mdbname),
+		MarkTableName:   strings.TrimSpace(mtablename),
 	}
+	if l.SupportPlugin {
+		l.Hooks = make([]*plugin.EventHooks, 4)
+		l.Hooks[plugin.SyncerFilter] = &plugin.EventHooks{}
+		l.Hooks[plugin.ExecutorExtend] = &plugin.EventHooks{}
+		l.Hooks[plugin.LoaderInit] = &plugin.EventHooks{}
+		l.Hooks[plugin.LoaderDestroy] = &plugin.EventHooks{}
+	}
 	return l
}
 
 // CreateMarkTable create the db and table if need.
-func CreateMarkTable(db *sql.DB) error {
-	_, err := db.Exec(CreateMarkDBDDL)
+func CreateMarkTable(db *sql.DB, mdbname, mtablename string) error {
+	var err error
+	var createMarkDBDDL = fmt.Sprintf("create database IF NOT EXISTS %s;", mdbname)
+	_, err = db.Exec(createMarkDBDDL)
 	if err != nil {
 		return errors.Annotate(err, "failed to create mark db")
 	}
-	_, err = db.Exec(CreateMarkTableDDL)
+	// createMarkTableDDL is the DDL to create the mark table.
+	var createMarkTableDDL string = fmt.Sprintf("CREATE TABLE If Not Exists %s.%s (%s bigint not null,%s bigint not null DEFAULT 0, %s bigint DEFAULT 0, %s varchar(64) ,PRIMARY KEY (%s,%s));", mdbname, mtablename, ID, ChannelID, Val, ChannelInfo, ID, ChannelID)
+	_, err = db.Exec(createMarkTableDDL)
 	if err != nil {
 		return errors.Annotate(err, "failed to create mark table")
 	}
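
For orientation, this is roughly what CreateMarkTable now renders with the default flag values (mark-db-name "rel", mark-table-name "_drainer_repl_mark"). Only ChannelInfo's value is visible in this diff; the other column constants (ID, ChannelID, Val) are assumed here to be "id", "channel_id", and "val":

    package main

    import "fmt"

    // Assumed values of the loopbacksync column-name constants; only
    // ChannelInfo appears in the diff, the rest are assumptions.
    const (
        ID          = "id"
        ChannelID   = "channel_id"
        Val         = "val"
        ChannelInfo = "channel_info"
    )

    func main() {
        mdbname, mtablename := "rel", "_drainer_repl_mark"
        // Print the two statements CreateMarkTable would execute.
        fmt.Printf("create database IF NOT EXISTS %s;\n", mdbname)
        fmt.Printf("CREATE TABLE If Not Exists %s.%s (%s bigint not null,%s bigint not null DEFAULT 0, %s bigint DEFAULT 0, %s varchar(64) ,PRIMARY KEY (%s,%s));\n",
            mdbname, mtablename, ID, ChannelID, Val, ChannelInfo, ID, ChannelID)
    }
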
diff --git a/drainer/loopbacksync/loopbacksync_test.go b/drainer/loopbacksync/loopbacksync_test.go
index cdb250602..ca3fb7145 100644
--- a/drainer/loopbacksync/loopbacksync_test.go
+++ b/drainer/loopbacksync/loopbacksync_test.go
@@ -15,6 +15,7 @@ package loopbacksync
 
 import (
 	"database/sql/driver"
+	"fmt"
 	"regexp"
 	"testing"
 
@@ -32,12 +33,18 @@ func (s *loopbackSuite) TestNewLoopBackSyncInfo(c *check.C) {
 	var ChannelID int64 = 1
	var LoopbackControl = true
 	var SyncDDL = false
-	l := NewLoopBackSyncInfo(ChannelID, LoopbackControl, SyncDDL)
+
+	l := NewLoopBackSyncInfo(ChannelID, LoopbackControl, SyncDDL, "", nil, false, "rel", "_drainer_repl_mark")
 	c.Assert(l, check.DeepEquals, &LoopBackSync{
 		ChannelID:       ChannelID,
 		LoopbackControl: LoopbackControl,
 		SyncDDL:         SyncDDL,
+		PluginPath:      "",
+		PluginNames:     nil,
+		SupportPlugin:   false,
+		MarkDBName:      "rel",
+		MarkTableName:   "_drainer_repl_mark",
 	})
 }
 
@@ -45,12 +52,15 @@ func (s *loopbackSuite) TestCreateMarkTable(c *check.C) {
 	db, mk, err := sqlmock.New()
 	c.Assert(err, check.IsNil)
 
+	CreateMarkDBDDL := "create database IF NOT EXISTS rel;"
+	CreateMarkTableDDL := fmt.Sprintf("CREATE TABLE If Not Exists %s.%s (%s bigint not null,%s bigint not null DEFAULT 0, %s bigint DEFAULT 0, %s varchar(64) ,PRIMARY KEY (%s,%s));", "rel", "_drainer_repl_mark", ID, ChannelID, Val, ChannelInfo, ID, ChannelID)
+
 	mk.ExpectExec(regexp.QuoteMeta(CreateMarkDBDDL)).
 		WillReturnResult(sqlmock.NewResult(0, 0))
 	mk.ExpectExec(regexp.QuoteMeta(CreateMarkTableDDL)).
 		WillReturnResult(sqlmock.NewResult(0, 0))
-	err = CreateMarkTable(db)
+	err = CreateMarkTable(db, "rel", "_drainer_repl_mark")
 	c.Assert(err, check.IsNil)
 
 	err = mk.ExpectationsWereMet()
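
Outside of tests, the new signature would be exercised roughly like this; a minimal sketch in which the DSN and the mysql driver import are illustrative assumptions, not part of this change:

    package main

    import (
        gosql "database/sql"
        "log"

        _ "github.com/go-sql-driver/mysql" // driver choice is an assumption
        "github.com/pingcap/tidb-binlog/drainer/loopbacksync"
    )

    func main() {
        // Connect to the downstream database; the DSN is a placeholder.
        db, err := gosql.Open("mysql", "root:@tcp(127.0.0.1:4000)/")
        if err != nil {
            log.Fatal(err)
        }
        defer db.Close()

        // Create the mark database and table with the default names drainer uses.
        if err := loopbacksync.CreateMarkTable(db, "rel", "_drainer_repl_mark"); err != nil {
            log.Fatal(err)
        }
    }
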
diff --git a/drainer/syncer.go b/drainer/syncer.go
index fe9104189..8e9badbe8 100644
--- a/drainer/syncer.go
+++ b/drainer/syncer.go
@@ -19,12 +19,12 @@ import (
 	"sync/atomic"
 	"time"
 
-	"github.com/pingcap/tidb-binlog/drainer/loopbacksync"
-	"github.com/pingcap/tidb-binlog/pkg/loader"
-
 	"github.com/pingcap/errors"
 	"github.com/pingcap/log"
 	"github.com/pingcap/parser/model"
+	"github.com/pingcap/tidb-binlog/drainer/loopbacksync"
+	"github.com/pingcap/tidb-binlog/pkg/loader"
+	"github.com/pingcap/tidb-binlog/pkg/plugin"
 	"go.uber.org/zap"
 
 	"github.com/pingcap/tidb-binlog/drainer/checkpoint"
@@ -77,8 +77,35 @@ func NewSyncer(cp checkpoint.CheckPoint, cfg *SyncerConfig, jobs []*model.Job) (
 		ignoreDBs = strings.Split(cfg.IgnoreSchemas, ",")
 	}
 	syncer.filter = filter.NewFilter(ignoreDBs, cfg.IgnoreTables, cfg.DoDBs, cfg.DoTables)
-	syncer.loopbackSync = loopbacksync.NewLoopBackSyncInfo(cfg.ChannelID, cfg.LoopbackControl, cfg.SyncDDL)
+	syncer.loopbackSync = loopbacksync.NewLoopBackSyncInfo(cfg.ChannelID, cfg.LoopbackControl, cfg.SyncDDL, cfg.PluginPath,
+		cfg.PluginNames, cfg.SupportPlugin, cfg.MarkDBName, cfg.MarkTableName)
+	if syncer.loopbackSync.SupportPlugin {
+		log.Info("Begin to load syncer plugins.")
+		for _, name := range syncer.loopbackSync.PluginNames {
+			n := strings.TrimSpace(name)
+			sym, err := plugin.LoadPlugin(syncer.loopbackSync.PluginPath, n)
+			if err != nil {
+				log.Error("Load plugin failed.", zap.String("plugin name", n),
+					zap.String("error", err.Error()))
+				continue
+			}
+			newPlugin, ok := sym.(func() interface{})
+			if !ok {
+				log.Error("Plugin does not export the expected factory function.", zap.String("plugin name", n), zap.String("type", "syncer plugin"))
+				continue
+			}
+			plg := newPlugin()
+			_, ok = plg.(SyncerFilter)
+			if !ok {
+				log.Info("SyncerFilter interface is not implemented.", zap.String("plugin name", n))
+			} else {
+				plugin.RegisterPlugin(syncer.loopbackSync.Hooks[plugin.SyncerFilter],
+					n, plg)
+				log.Info("Load plugin success.", zap.String("plugin name", n), zap.String("interface", "SyncerFilter"))
+			}
+		}
+	}
 	var err error
 	// create schema
 	syncer.schema, err = NewSchema(jobs, false)
@@ -357,8 +384,31 @@
 			err = errors.Annotate(err, "handlePreviousDDLJobIfNeed failed")
 			break ForLoop
 		}
+		var isFilterTransaction = false
 		var err1 error
+
+		if s.loopbackSync != nil && s.loopbackSync.SupportPlugin {
+			hook := s.loopbackSync.Hooks[plugin.SyncerFilter]
+			var txn *loader.Txn
+			txn, err1 = translator.TiBinlogToTxn(s.schema, "", "", binlog, preWrite, false)
+			if err1 != nil {
+				log.Warn("TiBinlogToTxn failed", zap.String("error", err1.Error()))
+				break ForLoop
+			}
+			hook.Range(func(k, val interface{}) bool {
+				c, ok := val.(SyncerFilter)
+				if !ok {
+					return true
+				}
+				isFilterTransaction, err1 = c.FilterTxn(txn, s.loopbackSync)
+				if isFilterTransaction || err1 != nil {
+					return false
+				}
+				return true
+			})
+			if err1 != nil {
+				log.Warn("FilterTxn return error", zap.String("error", err1.Error()))
+				break ForLoop
+			}
+		}
+
 		if s.loopbackSync != nil && s.loopbackSync.LoopbackControl {
 			isFilterTransaction, err1 = loopBackStatus(binlog, preWrite, s.schema, s.loopbackSync)
 			if err1 != nil {
@@ -414,6 +464,36 @@
 			break ForLoop
 		}
 
+		if s.loopbackSync != nil && s.loopbackSync.SupportPlugin {
+			var isFilterTransaction = false
+			var err1 error
+			txn := new(loader.Txn)
+			txn.DDL = &loader.DDL{
+				Database: schema,
+				Table:    table,
+				SQL:      string(binlog.GetDdlQuery()),
+			}
+			hook := s.loopbackSync.Hooks[plugin.SyncerFilter]
+			hook.Range(func(k, val interface{}) bool {
+				c, ok := val.(SyncerFilter)
+				if !ok {
+					return true
+				}
+				isFilterTransaction, err1 = c.FilterTxn(txn, s.loopbackSync)
+				if isFilterTransaction || err1 != nil {
+					return false
+				}
+				return true
+			})
+			if err1 != nil {
+				log.Warn("FilterTxn return error", zap.String("error", err1.Error()))
+				break ForLoop
+			}
+			if isFilterTransaction {
+				continue
+			}
+		}
+
 		if s.filter.SkipSchemaAndTable(schema, table) {
 			log.Info("skip ddl by filter", zap.String("schema", schema), zap.String("table", table),
 				zap.String("sql", sql), zap.Int64("commit ts", commitTS))
diff --git a/drainer/util.go b/drainer/util.go
index 9543cdf3c..872cd82d5 100644
--- a/drainer/util.go
+++ b/drainer/util.go
@@ -20,6 +20,7 @@ import (
 	"os"
 	"path"
 	"sort"
+	"strings"
 	"sync"
 
 	"github.com/Shopify/sarama"
@@ -191,3 +192,19 @@ func genDrainerID(listenAddr string) (string, error) {
 
 	return fmt.Sprintf("%s:%s", hostname, port), nil
 }
+
+type sliceNames []string
+
+func newSliceNames(vals []string, p *[]string) *sliceNames {
+	*p = vals
+	return (*sliceNames)(p)
+}
+
+func (s *sliceNames) Set(val string) error {
+	*s = sliceNames(strings.Split(val, ","))
+	return nil
+}
+
+func (s *sliceNames) Get() interface{} { return []string(*s) }
+
+func (s *sliceNames) String() string { return strings.Join([]string(*s), ",") }
diff --git a/pkg/loader/executor.go b/pkg/loader/executor.go
index 159ec7179..92485bba4 100644
--- a/pkg/loader/executor.go
+++ b/pkg/loader/executor.go
@@ -22,6 +22,7 @@ import (
 	"time"
 
 	"github.com/pingcap/tidb-binlog/drainer/loopbacksync"
+	"github.com/pingcap/tidb-binlog/pkg/plugin"
 	"github.com/pingcap/tidb/infoschema"
 
 	"github.com/pingcap/errors"
@@ -36,7 +37,6 @@ import (
 var (
 	defaultBatchSize   = 128
 	defaultWorkerCount = 16
-	index              int64
 )
 
 type executor struct {
@@ -54,7 +54,6 @@ func newExecutor(db *gosql.DB) *executor {
 		batchSize:   defaultBatchSize,
 		workerCount: defaultWorkerCount,
 	}
-
 	return exe
 }
 
@@ -88,14 +87,14 @@ func (e *executor) execTableBatchRetry(ctx context.Context, dmls []*DML, retryNu
 	return errors.Trace(err)
 }
 
-// a wrap of *sql.Tx with metrics
-type tx struct {
+// Tx is a wrap of *sql.Tx with metrics
+type Tx struct {
 	*gosql.Tx
 	queryHistogramVec *prometheus.HistogramVec
 }
 
 // wrap of sql.Tx.Exec()
-func (tx *tx) exec(query string, args ...interface{}) (gosql.Result, error) {
+func (tx *Tx) exec(query string, args ...interface{}) (gosql.Result, error) {
 	start := time.Now()
 	res, err := tx.Tx.Exec(query, args...)
 	if tx.queryHistogramVec != nil {
@@ -105,7 +104,7 @@ func (tx *tx) exec(query string, args ...interface{}) (gosql.Result, error) {
 	return res, err
 }
 
-func (tx *tx) autoRollbackExec(query string, args ...interface{}) (res gosql.Result, err error) {
+func (tx *Tx) autoRollbackExec(query string, args ...interface{}) (res gosql.Result, err error) {
 	res, err = tx.exec(query, args...)
 	if err != nil {
 		log.Error("Exec fail, will rollback", zap.String("query", query), zap.Reflect("args", args), zap.Error(err))
@@ -118,7 +117,7 @@ func (tx *tx) autoRollbackExec(query string, args ...interface{}) (res gosql.Res
 }
 
 // wrap of sql.Tx.Commit()
-func (tx *tx) commit() error {
+func (tx *Tx) commit() error {
 	start := time.Now()
 	err := tx.Tx.Commit()
 	if tx.queryHistogramVec != nil {
@@ -128,18 +127,14 @@ func (tx *tx) commit() error {
 	return errors.Trace(err)
 }
 
-func (e *executor) addIndex() int64 {
-	return atomic.AddInt64(&index, 1) % ((int64)(e.workerCount))
-}
-
 // return a wrap of sql.Tx
-func (e *executor) begin() (*tx, error) {
+func (e *executor) begin() (*Tx, error) {
 	sqlTx, err := e.db.Begin()
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
 
-	var tx = &tx{
+	var tx = &Tx{
 		Tx:                sqlTx,
 		queryHistogramVec: e.queryHistogramVec,
 	}
@@ -147,7 +142,7 @@ func (e *executor) begin() (*tx, error) {
 	if e.info != nil && e.info.LoopbackControl {
 		start := time.Now()
 
-		err = loopbacksync.UpdateMark(tx.Tx, e.addIndex(), e.info.ChannelID)
+		err = loopbacksync.UpdateMark(tx.Tx, atomic.AddInt64(&e.info.Index, 1)%((int64)(e.workerCount)), e.info.ChannelID)
 		if err != nil {
 			rerr := tx.Rollback()
 			if rerr != nil {
@@ -164,24 +159,48 @@
 	return tx, nil
 }
 
+// externPoint lets ExecutorExtend plugins rewrite the transaction and its DMLs.
+func (e *executor) externPoint(t *Tx, dmls []*DML) (*Tx, []*DML) {
+	hook := e.info.Hooks[plugin.ExecutorExtend]
+	hook.Range(func(k, val interface{}) bool {
+		c, ok := val.(ExecutorExtend)
+		if !ok {
+			// ignore entries of an unexpected type
+			return true
+		}
+		t, dmls = c.ExtendTxn(t, dmls, e.info)
+		return dmls != nil
+	})
+	return t, dmls
+}
+
 func (e *executor) bulkDelete(deletes []*DML) error {
 	if len(deletes) == 0 {
 		return nil
 	}
 
 	var sqls strings.Builder
-	argss := make([]interface{}, 0, len(deletes))
 
+	tx, err := e.begin()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if e.info != nil && e.info.SupportPlugin {
+		tx, deletes = e.externPoint(tx, deletes)
+		if len(deletes) == 0 {
+			// the plugins consumed every DML; commit the open transaction
+			return errors.Trace(tx.commit())
+		}
+	}
+
+	argss := make([]interface{}, 0, len(deletes))
 	for _, dml := range deletes {
 		sql, args := dml.sql()
 		sqls.WriteString(sql)
 		sqls.WriteByte(';')
 		argss = append(argss, args...)
 	}
-	tx, err := e.begin()
-	if err != nil {
-		return errors.Trace(err)
-	}
+
 	sql := sqls.String()
 	_, err = tx.autoRollbackExec(sql, argss...)
 	if err != nil {
@@ -197,6 +216,18 @@ func (e *executor) bulkReplace(inserts []*DML) error {
 		return nil
 	}
 
+	tx, err := e.begin()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if e.info != nil && e.info.SupportPlugin {
+		tx, inserts = e.externPoint(tx, inserts)
+		if len(inserts) == 0 {
+			// the plugins consumed every DML; commit the open transaction
+			return errors.Trace(tx.commit())
+		}
+	}
+
 	info := inserts[0].info
 
 	var builder strings.Builder
@@ -219,10 +250,7 @@ func (e *executor) bulkReplace(inserts []*DML) error {
 			args = append(args, v)
 		}
 	}
-	tx, err := e.begin()
-	if err != nil {
-		return errors.Trace(err)
-	}
+
 	_, err = tx.autoRollbackExec(builder.String(), args...)
 	if err != nil {
 		return errors.Trace(err)
 	}
@@ -351,6 +379,13 @@ func (e *executor) singleExec(dmls []*DML, safeMode bool) error {
 		return errors.Trace(err)
 	}
 
+	if e.info != nil && e.info.SupportPlugin {
+		tx, dmls = e.externPoint(tx, dmls)
+		if len(dmls) == 0 {
+			// the plugins consumed every DML; commit the open transaction
+			return errors.Trace(tx.commit())
+		}
+	}
+
 	for _, dml := range dmls {
 		if safeMode && dml.Tp == UpdateDMLType {
 			sql, args := dml.deleteSQL()
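
To illustrate the new extension point, here is a sketch of an ExecutorExtend plugin that drops DMLs aimed at the mark table before a batch is built. It assumes loader.DML exposes Database and Table fields; the package and type names are illustrative:

    package main

    import (
        "github.com/pingcap/tidb-binlog/drainer/loopbacksync"
        "github.com/pingcap/tidb-binlog/pkg/loader"
    )

    // MarkFilter drops DMLs that target the loopback mark table so they are
    // never replayed downstream; everything else passes through unchanged.
    type MarkFilter struct{}

    // ExtendTxn implements loader.ExecutorExtend.
    func (MarkFilter) ExtendTxn(tx *loader.Tx, dmls []*loader.DML, info *loopbacksync.LoopBackSync) (*loader.Tx, []*loader.DML) {
        kept := dmls[:0]
        for _, dml := range dmls {
            if dml.Database == info.MarkDBName && dml.Table == info.MarkTableName {
                continue // skip mark-table traffic
            }
            kept = append(kept, dml)
        }
        return tx, kept
    }

    // NewPlugin is the factory function looked up by plugin.LoadPlugin.
    func NewPlugin() interface{} {
        return MarkFilter{}
    }
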
diff --git a/pkg/loader/iloader.go b/pkg/loader/iloader.go
new file mode 100644
index 000000000..514c1cb73
--- /dev/null
+++ b/pkg/loader/iloader.go
@@ -0,0 +1,21 @@
+package loader
+
+import (
+	gosql "database/sql"
+	"github.com/pingcap/tidb-binlog/drainer/loopbacksync"
+)
+
+// ExecutorExtend is the interface for loader plugins that rewrite transactions
+type ExecutorExtend interface {
+	ExtendTxn(tx *Tx, dmls []*DML, info *loopbacksync.LoopBackSync) (*Tx, []*DML)
+}
+
+// Init is the interface for loader plugins that need setup work
+type Init interface {
+	LoaderInit(db *gosql.DB, info *loopbacksync.LoopBackSync) error
+}
+
+// Destroy is the interface for loader plugins that need teardown work
+type Destroy interface {
+	LoaderDestroy(db *gosql.DB, info *loopbacksync.LoopBackSync) error
+}
diff --git a/pkg/loader/load.go b/pkg/loader/load.go
index d9b0f1cae..66c39983c 100644
--- a/pkg/loader/load.go
+++ b/pkg/loader/load.go
@@ -17,14 +17,15 @@ import (
 	"context"
 	gosql "database/sql"
 	"fmt"
+	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
 
-	"github.com/pingcap/tidb-binlog/drainer/loopbacksync"
-
 	"github.com/pingcap/errors"
 	"github.com/pingcap/log"
+	"github.com/pingcap/tidb-binlog/drainer/loopbacksync"
+	"github.com/pingcap/tidb-binlog/pkg/plugin"
 	"github.com/pingcap/tidb-binlog/pkg/util"
 	"github.com/prometheus/client_golang/prometheus"
 	"go.uber.org/zap"
@@ -203,9 +204,57 @@ func NewLoader(db *gosql.DB, opt ...Option) (Loader, error) {
 		cancel:           cancel,
 	}
 
+	if s.loopBackSyncInfo != nil {
+		s.loopBackSyncInfo.RecordID = s.workerCount
+	}
+
 	db.SetMaxOpenConns(opts.workerCount)
 	db.SetMaxIdleConns(opts.workerCount)
 
+	if s.loopBackSyncInfo != nil && s.loopBackSyncInfo.SupportPlugin {
+		log.Info("Begin to load loader plugins.")
+		for _, name := range s.loopBackSyncInfo.PluginNames {
+			n := strings.TrimSpace(name)
+			sym, err := plugin.LoadPlugin(s.loopBackSyncInfo.PluginPath, n)
+			if err != nil {
+				log.Error("Load plugin failed.", zap.String("plugin name", n),
+					zap.String("error", err.Error()))
+				continue
+			}
+			newPlugin, ok := sym.(func() interface{})
+			if !ok {
+				log.Error("Plugin does not export the expected factory function.", zap.String("plugin name", n), zap.String("type", "loader plugin"))
+				continue
+			}
+
+			plg := newPlugin()
+			_, ok = plg.(ExecutorExtend)
+			if !ok {
+				log.Info("ExecutorExtend interface is not implemented.", zap.String("plugin name", n))
+			} else {
+				plugin.RegisterPlugin(s.loopBackSyncInfo.Hooks[plugin.ExecutorExtend],
+					n, plg)
+				log.Info("Load plugin success.", zap.String("plugin name", n), zap.String("interface", "ExecutorExtend"))
+			}
+
+			_, ok = plg.(Init)
+			if !ok {
+				log.Info("LoaderInit interface is not implemented.", zap.String("plugin name", n))
+			} else {
+				plugin.RegisterPlugin(s.loopBackSyncInfo.Hooks[plugin.LoaderInit],
+					n, plg)
+				log.Info("Load plugin success.", zap.String("plugin name", n), zap.String("interface", "LoaderInit"))
+			}
+
+			_, ok = plg.(Destroy)
+			if !ok {
+				log.Info("LoaderDestroy interface is not implemented.", zap.String("plugin name", n))
+			} else {
+				plugin.RegisterPlugin(s.loopBackSyncInfo.Hooks[plugin.LoaderDestroy],
+					n, plg)
+				log.Info("Load plugin success.", zap.String("plugin name", n), zap.String("interface", "LoaderDestroy"))
+			}
+		}
+	}
+
 	return s, nil
 }
@@ -494,7 +543,7 @@ func (s *loaderImpl) execDMLs(dmls []*DML) error {
 }
 
 func (s *loaderImpl) initMarkTable() error {
-	if err := loopbacksync.CreateMarkTable(s.db); err != nil {
+	if err := loopbacksync.CreateMarkTable(s.db, s.loopBackSyncInfo.MarkDBName, s.loopBackSyncInfo.MarkTableName); err != nil {
 		return errors.Trace(err)
 	}
 	return loopbacksync.InitMarkTableData(s.db, s.workerCount, s.loopBackSyncInfo.ChannelID)
@@ -507,6 +556,40 @@ func (s *loaderImpl) Run() error {
 		close(s.successTxn)
 	}()
 
+	defer func() {
+		if s.loopBackSyncInfo != nil && s.loopBackSyncInfo.SupportPlugin {
+			var err error
+			hook := s.loopBackSyncInfo.Hooks[plugin.LoaderDestroy]
+			hook.Range(func(k, val interface{}) bool {
+				c, ok := val.(Destroy)
+				if !ok {
+					return true
+				}
+				err = c.LoaderDestroy(s.db, s.loopBackSyncInfo)
+				return err == nil
+			})
+			if err != nil {
+				log.Error(errors.Trace(err).Error())
+			}
+		}
+	}()
+
+	var err error
+	if s.loopBackSyncInfo != nil && s.loopBackSyncInfo.SupportPlugin {
+		hook := s.loopBackSyncInfo.Hooks[plugin.LoaderInit]
+		hook.Range(func(k, val interface{}) bool {
+			c, ok := val.(Init)
+			if !ok {
+				return true
+			}
+			err = c.LoaderInit(s.db, s.loopBackSyncInfo)
+			return err == nil
+		})
+		if err != nil {
+			return errors.Trace(err)
+		}
+	}
+
 	if s.loopBackSyncInfo != nil && s.loopBackSyncInfo.LoopbackControl {
 		if err := s.initMarkTable(); err != nil {
 			return errors.Trace(err)
@@ -526,6 +609,7 @@ func (s *loaderImpl) Run() error {
 	input := txnManager.run()
 
 	for {
+
 		select {
 		case txn, ok := <-input:
 			if !ok {
@@ -535,7 +619,6 @@ func (s *loaderImpl) Run() error {
 				}
 				return nil
 			}
-
 			s.metricsInputTxn(txn)
 			txnManager.pop(txn)
 			if err := batch.put(txn); err != nil {
diff --git a/pkg/loader/util.go b/pkg/loader/util.go
index 20ef3dd0f..26f085dbf 100644
--- a/pkg/loader/util.go
+++ b/pkg/loader/util.go
@@ -113,6 +113,9 @@ func CreateDB(user string, password string, host string, port int, tls *tls.Conf
 }
 
 func quoteSchema(schema string, table string) string {
+	if len(schema) == 0 {
+		return fmt.Sprintf("`%s`", escapeName(table))
+	}
 	return fmt.Sprintf("`%s`.`%s`", escapeName(schema), escapeName(table))
 }
diff --git a/pkg/plugin/plugindemo/Makefile b/pkg/plugin/plugindemo/Makefile
new file mode 100644
index 000000000..cc0838a46
--- /dev/null
+++ b/pkg/plugin/plugindemo/Makefile
@@ -0,0 +1,3 @@
+plugin:
+	gofmt -w demo.go
+	go build -o demo.so -buildmode=plugin demo.go
diff --git a/pkg/plugin/plugindemo/demo.go b/pkg/plugin/plugindemo/demo.go
new file mode 100644
index 000000000..856d92c4b
--- /dev/null
+++ b/pkg/plugin/plugindemo/demo.go
@@ -0,0 +1,32 @@
+package main
+
+import (
+	"github.com/pingcap/log"
+	"github.com/pingcap/tidb-binlog/drainer/loopbacksync"
+	"github.com/pingcap/tidb-binlog/pkg/loader"
+)
+
+//PluginDemo is a demo struct
+type PluginDemo struct{}
+
+//ExtendTxn is one of the hooks
+func (pd PluginDemo) ExtendTxn(tx *loader.Tx, dmls []*loader.DML, info *loopbacksync.LoopBackSync) (*loader.Tx, []*loader.DML) {
+	// do something useful here
+	log.Info("i am ExtendTxn")
+	return nil, nil
+}
+
+//FilterTxn is one of the hooks
+func (pd PluginDemo) FilterTxn(txn *loader.Txn, info *loopbacksync.LoopBackSync) (bool, error) {
+	// do something useful here
+	log.Info("i am FilterTxn")
+	return true, nil
+}
+
+//NewPlugin is the factory function of the plugin
+func NewPlugin() interface{} {
+	return PluginDemo{}
+}
+
+var _ PluginDemo
+var _ = NewPlugin()
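
The demo also illustrates the loading contract used by both NewSyncer and NewLoader: a plugin must export a NewPlugin factory (FactorFunc) with the exact type func() interface{}, and the host probes the returned value for each hook interface. A standalone sketch of that handshake, with placeholder path and file names:

    package main

    import (
        "log"

        "github.com/pingcap/tidb-binlog/pkg/loader"
        "github.com/pingcap/tidb-binlog/pkg/plugin"
    )

    func main() {
        // Look up the NewPlugin symbol in /tmp/drainer/plugin/demo.so
        // (the directory and file name are placeholders).
        sym, err := plugin.LoadPlugin("/tmp/drainer/plugin", "demo.so")
        if err != nil {
            log.Fatal(err)
        }

        // The factory must have the exact type func() interface{}.
        newPlugin, ok := sym.(func() interface{})
        if !ok {
            log.Fatal("plugin does not export func NewPlugin() interface{}")
        }

        // Probe the instance for the hooks it implements and register them.
        plg := newPlugin()
        hooks := &plugin.EventHooks{}
        if _, ok := plg.(loader.ExecutorExtend); ok {
            plugin.RegisterPlugin(hooks, "demo", plg)
        }
        log.Printf("registered plugins: %v", hooks.GetAllPluginsName())
    }
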
diff --git a/pkg/plugin/plugins.go b/pkg/plugin/plugins.go
new file mode 100644
index 000000000..5f0a66596
--- /dev/null
+++ b/pkg/plugin/plugins.go
@@ -0,0 +1,69 @@
+package plugin
+
+import (
+	"fmt"
+	"plugin"
+	"sync"
+)
+
+//Kind is the plugin type that we currently support
+type Kind uint8
+
+const (
+	//SyncerFilter is one kind of plugin for the syncer
+	SyncerFilter Kind = iota
+	//ExecutorExtend is one kind of plugin for the loader
+	ExecutorExtend
+	//LoaderInit is one kind of plugin for the loader
+	LoaderInit
+	//LoaderDestroy is one kind of plugin for the loader
+	LoaderDestroy
+	//FactorFunc is the name of the factory function every plugin must export
+	FactorFunc = "NewPlugin"
+)
+
+//EventHooks is a concurrent map from plugin name to plugin instance
+type EventHooks struct {
+	sync.Map
+}
+
+func (ehs *EventHooks) setPlugin(name string, plg interface{}) *EventHooks {
+	if len(name) == 0 || ehs == nil {
+		return ehs
+	}
+	ehs.Store(name, plg)
+	return ehs
+}
+
+//GetAllPluginsName returns the names of all registered plugins
+func (ehs *EventHooks) GetAllPluginsName() []string {
+	if ehs == nil {
+		return nil
+	}
+	ns := make([]string, 0)
+	ehs.Range(func(k, v interface{}) bool {
+		name, ok := k.(string)
+		if !ok {
+			return true
+		}
+		ns = append(ns, name)
+		return true
+	})
+	return ns
+}
+
+//LoadPlugin loads the plugin named name from the directory path
+func LoadPlugin(path, name string) (plugin.Symbol, error) {
+	fp := path + "/" + name
+	p, err := plugin.Open(fp)
+	if err != nil {
+		return nil, fmt.Errorf("open %s failed. err: %s", fp, err.Error())
+	}
+
+	return p.Lookup(FactorFunc)
+}
+
+//RegisterPlugin registers a plugin into the given EventHooks
+func RegisterPlugin(ehs *EventHooks, name string, plg interface{}) {
+	ehs.setPlugin(name, plg)
+}
diff --git a/pkg/plugin/plugins_test.go b/pkg/plugin/plugins_test.go
new file mode 100644
index 000000000..89bc3c9c2
--- /dev/null
+++ b/pkg/plugin/plugins_test.go
@@ -0,0 +1,81 @@
+package plugin
+
+import (
+	"fmt"
+	"sort"
+	"testing"
+
+	"github.com/pingcap/check"
+)
+
+func Test(t *testing.T) { check.TestingT(t) }
+
+type PluginSuite struct {
+}
+
+var _ = check.Suite(&PluginSuite{})
+
+func (ps *PluginSuite) SetUpTest(c *check.C) {
+}
+
+func (ps *PluginSuite) TearDownTest(c *check.C) {
+}
+
+type ITest1 interface {
+	Do() int
+}
+type STest1 struct {
+	a int
+}
+
+func (s STest1) Do() int {
+	return s.a
+}
+
+func (ps *PluginSuite) TestRegisterPlugin(c *check.C) {
+	hook := &EventHooks{}
+	s1 := STest1{32}
+
+	RegisterPlugin(hook, "test1", s1)
+	p := hook.GetAllPluginsName()
+	c.Assert(len(p), check.Equals, 1)
+
+	RegisterPlugin(hook, "test1", s1)
+	p = hook.GetAllPluginsName()
+	c.Assert(len(p), check.Equals, 1)
+
+	s2 := STest1{64}
+	RegisterPlugin(hook, "test2", s2)
+	p = hook.GetAllPluginsName()
+	c.Assert(len(p), check.Equals, 2)
+	// sync.Map iteration order is unspecified, so sort before comparing.
+	sort.Strings(p)
+	c.Assert(p[0], check.Equals, "test1")
+	c.Assert(p[1], check.Equals, "test2")
+}
+
+func (ps *PluginSuite) TestTraversePlugin(c *check.C) {
+	hook := &EventHooks{}
+
+	s1 := STest1{32}
+	RegisterPlugin(hook, "test1", s1)
+
+	s2 := STest1{64}
+	RegisterPlugin(hook, "test2", s2)
+
+	s3 := STest1{128}
+	RegisterPlugin(hook, "test3", s3)
+
+	p := hook.GetAllPluginsName()
+	c.Assert(len(p), check.Equals, 3)
+
+	ret := 0
+	hook.Range(func(k, val interface{}) bool {
+		c, ok := val.(ITest1)
+		if !ok {
+			// ignore entries of an unexpected type
+			fmt.Printf("ok : %v\n", ok)
+			return true
+		}
+		ret += c.Do()
+		return true
+	})
+	c.Assert(ret, check.Equals, 32+64+128)
+}
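
For an end-to-end picture, here is a minimal syncer plugin sketch built on the interfaces above; it filters out DDL that targets the mark database, using the loader.Txn/loader.DDL fields exercised in drainer/syncer.go. The package would be compiled with go build -buildmode=plugin (as in the plugindemo Makefile) and enabled with the new flags, e.g. -support-plugin -plugin-path /tmp/drainer/plugin -plugin-names markfilter.so. File and type names here are illustrative, not part of the change:

    package main

    import (
        "github.com/pingcap/tidb-binlog/drainer/loopbacksync"
        "github.com/pingcap/tidb-binlog/pkg/loader"
    )

    // MarkDDLFilter skips any DDL transaction that touches the mark database,
    // so loopback bookkeeping is never replicated.
    type MarkDDLFilter struct{}

    // FilterTxn implements the drainer SyncerFilter interface: returning true
    // tells the syncer to drop the transaction.
    func (MarkDDLFilter) FilterTxn(txn *loader.Txn, info *loopbacksync.LoopBackSync) (bool, error) {
        if txn.DDL != nil && txn.DDL.Database == info.MarkDBName {
            return true, nil
        }
        return false, nil
    }

    // NewPlugin is the factory function (FactorFunc) the drainer looks up.
    func NewPlugin() interface{} {
        return MarkDDLFilter{}
    }
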