Skip to content

Commit a77f892

Browse files
authored
feat: table prefix (#47)
Signed-off-by: abingcbc <abingcbc626@gmail.com>
1 parent 2ea99d3 commit a77f892

File tree

2 files changed

+65
-60
lines changed

2 files changed

+65
-60
lines changed

adapter.go

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ type Adapter struct {
5454
dbSpecified bool
5555
isFiltered bool
5656
engine *xorm.Engine
57+
tablePrefix string
5758
tableName string
5859
}
5960

@@ -112,11 +113,12 @@ func NewAdapter(driverName string, dataSourceName string, dbSpecified ...bool) (
112113
}
113114

114115
// NewAdapterWithTableName .
115-
func NewAdapterWithTableName(driverName string, dataSourceName string, tableName string, dbSpecified ...bool) (*Adapter, error) {
116+
func NewAdapterWithTableName(driverName string, dataSourceName string, tableName string, tablePrefix string, dbSpecified ...bool) (*Adapter, error) {
116117
a := &Adapter{
117118
driverName: driverName,
118119
dataSourceName: dataSourceName,
119120
tableName: tableName,
121+
tablePrefix: tablePrefix,
120122
}
121123

122124
if len(dbSpecified) == 0 {
@@ -154,10 +156,11 @@ func NewAdapterByEngine(engine *xorm.Engine) (*Adapter, error) {
154156
}
155157

156158
// NewAdapterByEngineWithTableName .
157-
func NewAdapterByEngineWithTableName(engine *xorm.Engine, tableName string) (*Adapter, error) {
159+
func NewAdapterByEngineWithTableName(engine *xorm.Engine, tableName string, tablePrefix string) (*Adapter, error) {
158160
a := &Adapter{
159-
engine: engine,
160-
tableName: tableName,
161+
engine: engine,
162+
tableName: tableName,
163+
tablePrefix: tablePrefix,
161164
}
162165

163166
err := a.createTable()
@@ -168,6 +171,13 @@ func NewAdapterByEngineWithTableName(engine *xorm.Engine, tableName string) (*Ad
168171
return a, nil
169172
}
170173

174+
func (a *Adapter) getFullTableName() string {
175+
if a.tablePrefix != "" {
176+
return a.tablePrefix + "_" + a.tableName
177+
}
178+
return a.tableName
179+
}
180+
171181
func (a *Adapter) createDatabase() error {
172182
var err error
173183
var engine *xorm.Engine
@@ -231,11 +241,11 @@ func (a *Adapter) open() error {
231241
}
232242

233243
func (a *Adapter) createTable() error {
234-
return a.engine.Sync2(&CasbinRule{tableName: a.tableName})
244+
return a.engine.Sync2(&CasbinRule{tableName: a.getFullTableName()})
235245
}
236246

237247
func (a *Adapter) dropTable() error {
238-
return a.engine.DropTables(&CasbinRule{tableName: a.tableName})
248+
return a.engine.DropTables(&CasbinRule{tableName: a.getFullTableName()})
239249
}
240250

241251
func loadPolicyLine(line *CasbinRule, model model.Model) {
@@ -263,7 +273,7 @@ func loadPolicyLine(line *CasbinRule, model model.Model) {
263273
func (a *Adapter) LoadPolicy(model model.Model) error {
264274
lines := make([]*CasbinRule, 0, 64)
265275

266-
if err := a.engine.Table(&CasbinRule{tableName: a.tableName}).Find(&lines); err != nil {
276+
if err := a.engine.Table(&CasbinRule{tableName: a.getFullTableName()}).Find(&lines); err != nil {
267277
return err
268278
}
269279

@@ -275,7 +285,7 @@ func (a *Adapter) LoadPolicy(model model.Model) error {
275285
}
276286

277287
func (a *Adapter) genPolicyLine(ptype string, rule []string) *CasbinRule {
278-
line := CasbinRule{PType: ptype, tableName: a.tableName}
288+
line := CasbinRule{PType: ptype, tableName: a.getFullTableName()}
279289

280290
l := len(rule)
281291
if l > 0 {
@@ -383,7 +393,7 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err
383393

384394
// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
385395
func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
386-
line := CasbinRule{PType: ptype, tableName: a.tableName}
396+
line := CasbinRule{PType: ptype, tableName: a.getFullTableName()}
387397

388398
idx := fieldIndex + len(fieldValues)
389399
if fieldIndex <= 0 && idx > 0 {
@@ -417,7 +427,7 @@ func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) erro
417427
}
418428

419429
lines := make([]*CasbinRule, 0, 64)
420-
if err := a.filterQuery(a.engine.NewSession(), filterValue).Table(&CasbinRule{tableName: a.tableName}).Find(&lines); err != nil {
430+
if err := a.filterQuery(a.engine.NewSession(), filterValue).Table(&CasbinRule{tableName: a.getFullTableName()}).Find(&lines); err != nil {
421431
return err
422432
}
423433

@@ -516,7 +526,7 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
516526
for _, newRule := range newPolicies {
517527
newP = append(newP, *a.genPolicyLine(ptype, newRule))
518528
}
519-
tx := a.engine.NewSession()
529+
tx := a.engine.NewSession().Table(&CasbinRule{tableName: a.getFullTableName()})
520530
defer tx.Close()
521531

522532
if err := tx.Begin(); err != nil {
@@ -528,7 +538,7 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
528538
if err := tx.Where(str, args...).Find(&oldP); err != nil {
529539
return nil, tx.Rollback()
530540
}
531-
if _, err := tx.Where(str.(string), args...).Delete(CasbinRule{}); err != nil {
541+
if _, err := tx.Where(str.(string), args...).Delete(&CasbinRule{tableName: a.getFullTableName()}); err != nil {
532542
return nil, tx.Rollback()
533543
}
534544
if _, err := tx.Insert(&newP[i]); err != nil {

adapter_test.go

Lines changed: 43 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
package xormadapter
1616

1717
import (
18-
"github.com/casbin/casbin/v2/util"
1918
"log"
2019
"strings"
2120
"testing"
2221

2322
"github.com/casbin/casbin/v2"
23+
"github.com/casbin/casbin/v2/util"
2424
_ "github.com/go-sql-driver/mysql"
2525
_ "github.com/lib/pq"
2626
)
@@ -45,20 +45,15 @@ func testGetPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
4545
}
4646
}
4747

48-
func initPolicy(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
48+
func initPolicy(t *testing.T, a *Adapter) {
4949
// Because the DB is empty at first,
5050
// so we need to load the policy from the file adapter (.CSV) first.
5151
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", "examples/rbac_policy.csv")
5252

53-
a, err := NewAdapter(driverName, dataSourceName, dbSpecified...)
54-
if err != nil {
55-
panic(err)
56-
}
57-
5853
// This is a trick to save the current policy to the DB.
5954
// We can't call e.SavePolicy() because the adapter in the enforcer is still the file adapter.
6055
// The current policy means the policy in the Casbin enforcer (aka in memory).
61-
err = a.SavePolicy(e.GetModel())
56+
err := a.SavePolicy(e.GetModel())
6257
if err != nil {
6358
panic(err)
6459
}
@@ -75,30 +70,28 @@ func initPolicy(t *testing.T, driverName string, dataSourceName string, dbSpecif
7570
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
7671
}
7772

78-
func testSaveLoad(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
73+
func testSaveLoad(t *testing.T, a *Adapter) {
7974
// Initialize some policy in DB.
80-
initPolicy(t, driverName, dataSourceName, dbSpecified...)
75+
initPolicy(t, a)
8176
// Note: you don't need to look at the above code
8277
// if you already have a working DB with policy inside.
8378

8479
// Now the DB has policy, so we can provide a normal use case.
8580
// Create an adapter and an enforcer.
8681
// NewEnforcer() will load the policy automatically.
87-
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
8882
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)
8983
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
9084
}
9185

92-
func testAutoSave(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
86+
func testAutoSave(t *testing.T, a *Adapter) {
9387
// Initialize some policy in DB.
94-
initPolicy(t, driverName, dataSourceName, dbSpecified...)
88+
initPolicy(t, a)
9589
// Note: you don't need to look at the above code
9690
// if you already have a working DB with policy inside.
9791

9892
// Now the DB has policy, so we can provide a normal use case.
9993
// Create an adapter and an enforcer.
10094
// NewEnforcer() will load the policy automatically.
101-
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
10295
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)
10396

10497
// AutoSave is enabled by default.
@@ -152,16 +145,15 @@ func testAutoSave(t *testing.T, driverName string, dataSourceName string, dbSpec
152145
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
153146
}
154147

155-
func testFilteredPolicy(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
148+
func testFilteredPolicy(t *testing.T, a *Adapter) {
156149
// Initialize some policy in DB.
157-
initPolicy(t, driverName, dataSourceName, dbSpecified...)
150+
initPolicy(t, a)
158151
// Note: you don't need to look at the above code
159152
// if you already have a working DB with policy inside.
160153

161154
// Now the DB has policy, so we can provide a normal use case.
162155
// Create an adapter and an enforcer.
163156
// NewEnforcer() will load the policy automatically.
164-
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
165157
e, _ := casbin.NewEnforcer("examples/rbac_model.conf")
166158
// Now set the adapter
167159
e.SetAdapter(a)
@@ -194,16 +186,15 @@ func testFilteredPolicy(t *testing.T, driverName string, dataSourceName string,
194186
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
195187
}
196188

197-
func testRemovePolicies(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
189+
func testRemovePolicies(t *testing.T, a *Adapter) {
198190
// Initialize some policy in DB.
199-
initPolicy(t, driverName, dataSourceName, dbSpecified...)
191+
initPolicy(t, a)
200192
// Note: you don't need to look at the above code
201193
// if you already have a working DB with policy inside.
202194

203195
// Now the DB has policy, so we can provide a normal use case.
204196
// Create an adapter and an enforcer.
205197
// NewEnforcer() will load the policy automatically.
206-
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
207198
e, _ := casbin.NewEnforcer("examples/rbac_model.conf")
208199

209200
// Now set the adapter
@@ -236,16 +227,15 @@ func testRemovePolicies(t *testing.T, driverName string, dataSourceName string,
236227
testGetPolicy(t, e, [][]string{{"max", "data1", "delete"}})
237228
}
238229

239-
func testAddPolicies(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
230+
func testAddPolicies(t *testing.T, a *Adapter) {
240231
// Initialize some policy in DB.
241-
initPolicy(t, driverName, dataSourceName, dbSpecified...)
232+
initPolicy(t, a)
242233
// Note: you don't need to look at the above code
243234
// if you already have a working DB with policy inside.
244235

245236
// Now the DB has policy, so we can provide a normal use case.
246237
// Create an adapter and an enforcer.
247238
// NewEnforcer() will load the policy automatically.
248-
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
249239
e, _ := casbin.NewEnforcer("examples/rbac_model.conf")
250240

251241
// Now set the adapter
@@ -268,16 +258,15 @@ func testAddPolicies(t *testing.T, driverName string, dataSourceName string, dbS
268258
testGetPolicy(t, e, [][]string{{"max", "data2", "read"}, {"max", "data1", "write"}})
269259
}
270260

271-
func testUpdatePolicies(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
261+
func testUpdatePolicies(t *testing.T, a *Adapter) {
272262
// Initialize some policy in DB.
273-
initPolicy(t, driverName, dataSourceName, dbSpecified...)
263+
initPolicy(t, a)
274264
// Note: you don't need to look at the above code
275265
// if you already have a working DB with policy inside.
276266

277267
// Now the DB has policy, so we can provide a normal use case.
278268
// Create an adapter and an enforcer.
279269
// NewEnforcer() will load the policy automatically.
280-
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
281270
e, _ := casbin.NewEnforcer("examples/rbac_model.conf")
282271

283272
// Now set the adapter
@@ -301,16 +290,15 @@ func testUpdatePolicies(t *testing.T, driverName string, dataSourceName string,
301290
testGetPolicy(t, e, [][]string{{"bob", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
302291
}
303292

304-
func testUpdateFilteredPolicies(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
293+
func testUpdateFilteredPolicies(t *testing.T, a *Adapter) {
305294
// Initialize some policy in DB.
306-
initPolicy(t, driverName, dataSourceName, dbSpecified...)
295+
initPolicy(t, a)
307296
// Note: you don't need to look at the above code
308297
// if you already have a working DB with policy inside.
309298

310299
// Now the DB has policy, so we can provide a normal use case.
311300
// Create an adapter and an enforcer.
312301
// NewEnforcer() will load the policy automatically.
313-
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
314302
e, _ := casbin.NewEnforcer("examples/rbac_model.conf")
315303

316304
// Now set the adapter
@@ -370,23 +358,30 @@ func TestAdapters(t *testing.T) {
370358
// You can also use the following way to use an existing DB "abc":
371359
// testSaveLoad(t, "mysql", "root:@tcp(127.0.0.1:3306)/abc", true)
372360

373-
testSaveLoad(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
374-
testSaveLoad(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")
375-
376-
testAutoSave(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
377-
testAutoSave(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")
378-
379-
testFilteredPolicy(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
380-
381-
testAddPolicies(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
382-
testAddPolicies(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")
383-
384-
testRemovePolicies(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
385-
testRemovePolicies(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")
386-
387-
testUpdatePolicies(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
388-
testUpdatePolicies(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")
389-
390-
testUpdateFilteredPolicies(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
391-
testUpdateFilteredPolicies(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")
361+
a, _ := NewAdapter("mysql", "root:@tcp(127.0.0.1:3306)/")
362+
testSaveLoad(t, a)
363+
testAutoSave(t, a)
364+
testFilteredPolicy(t, a)
365+
testAddPolicies(t, a)
366+
testRemovePolicies(t, a)
367+
testUpdatePolicies(t, a)
368+
testUpdateFilteredPolicies(t, a)
369+
370+
a, _ = NewAdapter("postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")
371+
testSaveLoad(t, a)
372+
testAutoSave(t, a)
373+
testFilteredPolicy(t, a)
374+
testAddPolicies(t, a)
375+
testRemovePolicies(t, a)
376+
testUpdatePolicies(t, a)
377+
testUpdateFilteredPolicies(t, a)
378+
379+
a, _ = NewAdapterWithTableName("mysql", "root:@tcp(127.0.0.1:3306)/", "test", "abc")
380+
testSaveLoad(t, a)
381+
testAutoSave(t, a)
382+
testFilteredPolicy(t, a)
383+
testAddPolicies(t, a)
384+
testRemovePolicies(t, a)
385+
testUpdatePolicies(t, a)
386+
testUpdateFilteredPolicies(t, a)
392387
}

0 commit comments

Comments
 (0)