Skip to content

Commit 1b9c96c

Browse files
authored
Add NoTable reference mode (#21)
* Add NoTable reference mode * Add Refs for multiple references * Fix lint
1 parent 66a3996 commit 1b9c96c

File tree

2 files changed

+120
-27
lines changed

2 files changed

+120
-27
lines changed

referencer.go

+95-26
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,10 @@ type Referencer struct {
6161
// Default QuoteNoop.
6262
IdentifierQuoter func(tableAndColumn ...string) string
6363

64-
refs map[interface{}]Quoted
65-
columnNames map[interface{}]string
66-
structColumns map[interface{}][]string
64+
refs map[interface{}]Quoted
65+
quotedCols map[interface{}]Quoted
66+
columnNames map[interface{}]string
67+
structRefs map[interface{}][]string
6768
}
6869

6970
// ColumnsOf makes a Mapper option to prefix columns with table alias.
@@ -93,6 +94,36 @@ func (r *Referencer) ColumnsOf(rowStructPtr interface{}) func(o *Options) {
9394
}
9495
}
9596

97+
// QuotedNoTable is a container of field pointer that should be referenced without table.
98+
type QuotedNoTable struct {
99+
ptr interface{}
100+
}
101+
102+
// NoTable enables references without table prefix.
103+
// So that `my_table`.`my_column` would be rendered as `my_column`.
104+
//
105+
// r.Ref(sqluct.NoTable(&row.MyColumn))
106+
// r.Fmt("%s = 1", sqluct.NoTable(&row.MyColumn))
107+
//
108+
// Such references may be useful for INSERT/UPDATE column expressions.
109+
func NoTable(ptr interface{}) QuotedNoTable {
110+
return QuotedNoTable{ptr: ptr}
111+
}
112+
113+
// NoTableAll enables references without table prefix for all field pointers.
114+
// It can be useful to prepare multiple variadic arguments.
115+
//
116+
// r.Fmt("ON CONFLICT(%s) DO UPDATE SET %s = excluded.%s, %s = excluded.%s",
117+
// sqluct.NoTableAll(&row.ID, &row.F1, &row.F1, &row.F2, &row.F3)...)
118+
func NoTableAll(ptrs ...interface{}) []interface{} {
119+
res := make([]interface{}, 0, len(ptrs))
120+
for _, ptr := range ptrs {
121+
res = append(res, NoTable(ptr))
122+
}
123+
124+
return res
125+
}
126+
96127
// AddTableAlias creates string references for row pointer and all suitable field pointers in it.
97128
//
98129
// Empty alias is not added to column reference.
@@ -106,37 +137,42 @@ func (r *Referencer) AddTableAlias(rowStructPtr interface{}, alias string) {
106137
r.refs = make(map[interface{}]Quoted, len(f)+1)
107138
}
108139

140+
if r.quotedCols == nil {
141+
r.quotedCols = make(map[interface{}]Quoted, len(f)+1)
142+
}
143+
109144
if r.columnNames == nil {
110145
r.columnNames = make(map[interface{}]string, len(f))
111146
}
112147

113-
if r.structColumns == nil {
114-
r.structColumns = make(map[interface{}][]string)
148+
if r.structRefs == nil {
149+
r.structRefs = make(map[interface{}][]string)
115150
}
116151

117152
if alias != "" {
118153
r.refs[rowStructPtr] = r.Q(alias)
119154
}
120155

121-
columns := make([]string, 0, len(f))
156+
refs := make([]string, 0, len(f))
122157

123158
for ptr, fieldName := range f {
124-
var col string
159+
var ref Quoted
125160

126161
if alias == "" {
127-
col = string(r.Q(fieldName))
162+
ref = r.Q(fieldName)
128163
} else {
129-
col = string(r.Q(alias, fieldName))
164+
ref = r.Q(alias, fieldName)
130165
}
131166

132-
columns = append(columns, col)
133-
r.refs[ptr] = Quoted(col)
167+
refs = append(refs, string(ref))
168+
r.refs[ptr] = ref
169+
r.quotedCols[ptr] = r.Q(fieldName)
134170
r.columnNames[ptr] = fieldName
135171
}
136172

137-
sort.Strings(columns)
173+
sort.Strings(refs)
138174

139-
r.structColumns[rowStructPtr] = columns
175+
r.structRefs[rowStructPtr] = refs
140176
}
141177

142178
// Quoted is a string that can be interpolated into an SQL statement as is.
@@ -155,11 +191,49 @@ func (r *Referencer) Q(tableAndColumn ...string) Quoted {
155191
//
156192
// It panics if pointer is unknown.
157193
func (r *Referencer) Ref(ptr interface{}) string {
158-
if ref, found := r.refs[ptr]; found {
159-
return string(ref)
194+
s, err := r.ref(ptr)
195+
if err != nil {
196+
panic(err)
160197
}
161198

162-
panic(errUnknownFieldOrRow)
199+
return s
200+
}
201+
202+
func (r *Referencer) ref(ptr interface{}) (string, error) {
203+
if q, ok := ptr.(Quoted); ok {
204+
return string(q), nil
205+
}
206+
207+
refs := r.refs
208+
209+
if nt, ok := ptr.(QuotedNoTable); ok {
210+
ptr = nt.ptr
211+
refs = r.quotedCols
212+
}
213+
214+
if ref, found := refs[ptr]; found {
215+
return string(ref), nil
216+
}
217+
218+
return "", errUnknownFieldOrRow
219+
}
220+
221+
// Refs returns reference strings for multiple field pointers.
222+
//
223+
// It panics if pointer is unknown.
224+
func (r *Referencer) Refs(ptrs ...interface{}) []string {
225+
args := make([]string, 0, len(ptrs))
226+
227+
for i, fieldPtr := range ptrs {
228+
ref, err := r.ref(fieldPtr)
229+
if err != nil {
230+
panic(fmt.Errorf("%w at position %d", err, i))
231+
}
232+
233+
args = append(args, ref)
234+
}
235+
236+
return args
163237
}
164238

165239
// Col returns unescaped column name for field pointer that was previously added with AddTableAlias.
@@ -181,25 +255,20 @@ func (r *Referencer) Fmt(format string, ptrs ...interface{}) string {
181255
args := make([]interface{}, 0, len(ptrs))
182256

183257
for i, fieldPtr := range ptrs {
184-
if q, ok := fieldPtr.(Quoted); ok {
185-
args = append(args, string(q))
186-
187-
continue
258+
ref, err := r.ref(fieldPtr)
259+
if err != nil {
260+
panic(fmt.Errorf("%w at position %d", err, i))
188261
}
189262

190-
if ref, found := r.refs[fieldPtr]; found {
191-
args = append(args, ref)
192-
} else {
193-
panic(fmt.Errorf("%w at position %d", errUnknownFieldOrRow, i))
194-
}
263+
args = append(args, ref)
195264
}
196265

197266
return fmt.Sprintf(format, args...)
198267
}
199268

200269
// Cols returns column references of a row structure.
201270
func (r *Referencer) Cols(ptr interface{}) []string {
202-
if cols, found := r.structColumns[ptr]; found {
271+
if cols, found := r.structRefs[ptr]; found {
203272
return cols
204273
}
205274

referencer_test.go

+25-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ func BenchmarkReferencer_Fmt_lite(b *testing.B) {
225225

226226
for i := 0; i < b.N; i++ {
227227
// Find direct reports that share same last name and manager is not named John.
228-
qb := squirrel.StatementBuilder.Select(rf.Fmt("%s, %s", &dr.ManagerID, &dr.EmployeeID)).
228+
qb := squirrel.StatementBuilder.Select(rf.Refs(&dr.ManagerID, &dr.EmployeeID)...).
229229
From(rf.Fmt("%s AS %s", rf.Q("users"), manager)).
230230
InnerJoin(rf.Fmt("%s AS %s ON %s = %s AND %s = %s",
231231
rf.Q("direct_reports"), dr,
@@ -258,3 +258,27 @@ func BenchmarkReferencer_Fmt_raw(b *testing.B) {
258258
}
259259
}
260260
}
261+
262+
func TestNoTable(t *testing.T) {
263+
ref := sqluct.Referencer{}
264+
ref.Mapper = &sqluct.Mapper{Dialect: sqluct.DialectSQLite3}
265+
ref.IdentifierQuoter = sqluct.QuoteBackticks
266+
267+
type User struct {
268+
ID int `db:"id"`
269+
FirstName string `db:"first_name"`
270+
LastName string `db:"last_name"`
271+
}
272+
273+
row := &User{}
274+
275+
ref.AddTableAlias(row, "users")
276+
277+
expr := ref.Fmt("ON CONFLICT(%s) DO UPDATE SET %s = excluded.%s, %s = excluded.%s",
278+
sqluct.NoTableAll(&row.ID, &row.FirstName, &row.FirstName, &row.LastName, &row.LastName)...)
279+
280+
assert.Equal(t, "ON CONFLICT(`id`) DO UPDATE SET `first_name` = excluded.`first_name`, `last_name` = excluded.`last_name`", expr)
281+
282+
assert.Equal(t, "`first_name`", ref.Ref(sqluct.NoTable(&row.FirstName)))
283+
assert.Equal(t, "`users`.`first_name`", ref.Ref(&row.FirstName))
284+
}

0 commit comments

Comments
 (0)