@@ -22,6 +22,7 @@ import (
22
22
"testing"
23
23
24
24
"github.com/stretchr/testify/assert"
25
+ "github.com/stretchr/testify/require"
25
26
26
27
"vitess.io/vitess/go/sqltypes"
27
28
"vitess.io/vitess/go/vt/vttablet/endtoend/framework"
@@ -97,9 +98,7 @@ func TestTableACL(t *testing.T) {
97
98
for _ , tcase := range execCases {
98
99
_ , err := client .Execute (tcase .query , nil )
99
100
if tcase .err == "" {
100
- if err != nil {
101
- t .Error (err )
102
- }
101
+ assert .NoError (t , err )
103
102
continue
104
103
}
105
104
assert .ErrorContains (t , err , tcase .err )
@@ -125,9 +124,7 @@ func TestTableACL(t *testing.T) {
125
124
for _ , tcase := range streamCases {
126
125
_ , err := client .StreamExecute (tcase .query , nil )
127
126
if tcase .err == "" {
128
- if err != nil {
129
- t .Error (err )
130
- }
127
+ assert .NoError (t , err )
131
128
continue
132
129
}
133
130
assert .ErrorContains (t , err , tcase .err )
@@ -145,28 +142,19 @@ var rulesJSON = []byte(`[{
145
142
}]` )
146
143
147
144
func TestQueryRules (t * testing.T ) {
148
- rules := rules .New ()
149
- err := rules .UnmarshalJSON (rulesJSON )
150
- if err != nil {
151
- t .Error (err )
152
- return
153
- }
154
- err = framework .Server .SetQueryRules ("endtoend" , rules )
155
- want := "Rule source identifier endtoend is not valid"
156
- if err == nil || err .Error () != want {
157
- t .Errorf ("Error: %v, want %s" , err , want )
158
- }
145
+ r := rules .New ()
146
+ err := r .UnmarshalJSON (rulesJSON )
147
+ require .NoError (t , err )
148
+ err = framework .Server .SetQueryRules ("endtoend" , r )
149
+ assert .EqualError (t , err , "Rule source identifier endtoend is not valid" )
159
150
160
151
framework .Server .RegisterQueryRuleSource ("endtoend" )
161
152
defer framework .Server .UnRegisterQueryRuleSource ("endtoend" )
162
- err = framework .Server .SetQueryRules ("endtoend" , rules )
163
- if err != nil {
164
- t .Error (err )
165
- return
166
- }
153
+ err = framework .Server .SetQueryRules ("endtoend" , r )
154
+ require .NoError (t , err )
167
155
168
156
rulesJSON := compacted (framework .FetchURL ("/debug/query_rules" ))
169
- want = compacted (`{
157
+ expectJson : = compacted (`{
170
158
"endtoend":[{
171
159
"Description": "disallow bindvar 'asdfg'",
172
160
"Name": "r1",
@@ -178,34 +166,20 @@ func TestQueryRules(t *testing.T) {
178
166
"Action": "FAIL"
179
167
}]
180
168
}` )
181
- if rulesJSON != want {
182
- t .Errorf ("/debug/query_rules:\n %v, want\n %s" , rulesJSON , want )
183
- }
184
-
169
+ assert .Equal (t , expectJson , rulesJSON , "/debug/query_rules" )
185
170
client := framework .NewClient ()
186
171
query := "select * from vitess_test where intval=:asdfg"
187
172
bv := map [string ]* querypb.BindVariable {"asdfg" : sqltypes .Int64BindVariable (1 )}
188
173
_ , err = client .Execute (query , bv )
189
- want = "disallowed due to rule: disallow bindvar 'asdfg' (CallerID: dev)"
190
- if err == nil || err .Error () != want {
191
- t .Errorf ("Error: %v, want %s" , err , want )
192
- }
174
+ errString := "disallowed due to rule: disallow bindvar 'asdfg' (CallerID: dev)"
175
+ assert .EqualError (t , err , errString )
193
176
_ , err = client .StreamExecute (query , bv )
194
- want = "disallowed due to rule: disallow bindvar 'asdfg' (CallerID: dev)"
195
- if err == nil || err .Error () != want {
196
- t .Errorf ("Error: %v, want %s" , err , want )
197
- }
177
+ assert .EqualError (t , err , errString )
198
178
199
179
err = framework .Server .SetQueryRules ("endtoend" , nil )
200
- if err != nil {
201
- t .Error (err )
202
- return
203
- }
180
+ require .NoError (t , err )
204
181
_ , err = client .Execute (query , bv )
205
- if err != nil {
206
- t .Error (err )
207
- return
208
- }
182
+ require .NoError (t , err )
209
183
}
210
184
211
185
func compacted (in string ) string {
0 commit comments