@@ -6,7 +6,7 @@ use crate::errors::ApiError;
6
6
use postgres:: types:: ToSql ;
7
7
use regex:: Regex ;
8
8
9
- #[ derive( Debug ) ]
9
+ #[ derive( Debug , PartialEq ) ]
10
10
enum PreparedStatementValue {
11
11
String ( String ) ,
12
12
Int8 ( i64 ) ,
@@ -63,7 +63,7 @@ fn build_select_statement(query: Query) -> Result<(String, Vec<PreparedStatement
63
63
validate_sql_name ( column) ?;
64
64
}
65
65
66
- statement. push_str ( & format ! ( " DISTINCT ON ({}) " , distinct_columns. join( ", " ) ) ) ;
66
+ statement. push_str ( & format ! ( " DISTINCT ON ({})" , distinct_columns. join( ", " ) ) ) ;
67
67
}
68
68
69
69
// building prepared statement
@@ -127,8 +127,25 @@ fn build_select_statement(query: Query) -> Result<(String, Vec<PreparedStatement
127
127
. map ( |column_str_raw| String :: from ( column_str_raw. trim ( ) ) )
128
128
. collect ( ) ;
129
129
130
+ lazy_static ! {
131
+ static ref ORDER_DIRECTION_RE : Regex = Regex :: new( r" ASC| DESC" ) . unwrap( ) ;
132
+ }
133
+
130
134
for column in & columns {
131
- validate_sql_name ( column) ?;
135
+ if ORDER_DIRECTION_RE . is_match ( column) {
136
+ // we need to account for ASC and DESC directions
137
+ match ORDER_DIRECTION_RE . find ( column) {
138
+ Some ( order_direction_match) => {
139
+ let order_by_column = & column[ ..order_direction_match. start ( ) ] ;
140
+ validate_sql_name ( order_by_column) ?;
141
+ }
142
+ None => {
143
+ validate_sql_name ( column) ?;
144
+ }
145
+ }
146
+ } else {
147
+ validate_sql_name ( column) ?;
148
+ }
132
149
}
133
150
134
151
statement. push_str ( & format ! ( " ORDER BY {}" , columns. join( ", " ) ) ) ;
@@ -146,3 +163,254 @@ fn build_select_statement(query: Query) -> Result<(String, Vec<PreparedStatement
146
163
147
164
Ok ( ( statement, prepared_values) )
148
165
}
166
+
167
+ #[ cfg( test) ]
168
+ mod build_select_statement_tests {
169
+ use super :: super :: query_types:: { QueryParams , QueryTasks } ;
170
+ use super :: * ;
171
+ use pretty_assertions:: assert_eq;
172
+
173
+ #[ test]
174
+ fn basic_query ( ) {
175
+ let query = Query {
176
+ params : QueryParams {
177
+ columns : vec ! [ "id" . to_string( ) ] ,
178
+ conditions : None ,
179
+ distinct : None ,
180
+ limit : 100 ,
181
+ offset : 0 ,
182
+ order_by : None ,
183
+ prepared_values : None ,
184
+ table : "a_table" . to_string ( ) ,
185
+ } ,
186
+ task : QueryTasks :: GetAllTables ,
187
+ } ;
188
+
189
+ match build_select_statement ( query) {
190
+ Ok ( ( sql, _) ) => {
191
+ assert_eq ! ( & sql, "SELECT id FROM a_table LIMIT 100;" ) ;
192
+ }
193
+ Err ( e) => {
194
+ assert ! ( false , e) ;
195
+ }
196
+ } ;
197
+ }
198
+
199
+ #[ test]
200
+ fn multiple_columns ( ) {
201
+ let query = Query {
202
+ params : QueryParams {
203
+ columns : vec ! [ "id" . to_string( ) , "name" . to_string( ) ] ,
204
+ conditions : None ,
205
+ distinct : None ,
206
+ limit : 100 ,
207
+ offset : 0 ,
208
+ order_by : None ,
209
+ prepared_values : None ,
210
+ table : "a_table" . to_string ( ) ,
211
+ } ,
212
+ task : QueryTasks :: GetAllTables ,
213
+ } ;
214
+
215
+ match build_select_statement ( query) {
216
+ Ok ( ( sql, _) ) => {
217
+ assert_eq ! ( & sql, "SELECT id, name FROM a_table LIMIT 100;" ) ;
218
+ }
219
+ Err ( e) => {
220
+ assert ! ( false , e) ;
221
+ }
222
+ } ;
223
+ }
224
+
225
+ #[ test]
226
+ fn distinct ( ) {
227
+ let query = Query {
228
+ params : QueryParams {
229
+ columns : vec ! [ "id" . to_string( ) ] ,
230
+ conditions : None ,
231
+ distinct : Some ( "name, blah" . to_string ( ) ) ,
232
+ limit : 100 ,
233
+ offset : 0 ,
234
+ order_by : None ,
235
+ prepared_values : None ,
236
+ table : "a_table" . to_string ( ) ,
237
+ } ,
238
+ task : QueryTasks :: GetAllTables ,
239
+ } ;
240
+
241
+ match build_select_statement ( query) {
242
+ Ok ( ( sql, _) ) => {
243
+ assert_eq ! (
244
+ & sql,
245
+ "SELECT DISTINCT ON (name, blah) id FROM a_table LIMIT 100;"
246
+ ) ;
247
+ }
248
+ Err ( e) => {
249
+ assert ! ( false , e) ;
250
+ }
251
+ } ;
252
+ }
253
+
254
+ #[ test]
255
+ fn offset ( ) {
256
+ let query = Query {
257
+ params : QueryParams {
258
+ columns : vec ! [ "id" . to_string( ) ] ,
259
+ conditions : None ,
260
+ distinct : None ,
261
+ limit : 1000 ,
262
+ offset : 100 ,
263
+ order_by : None ,
264
+ prepared_values : None ,
265
+ table : "a_table" . to_string ( ) ,
266
+ } ,
267
+ task : QueryTasks :: GetAllTables ,
268
+ } ;
269
+
270
+ match build_select_statement ( query) {
271
+ Ok ( ( sql, _) ) => {
272
+ assert_eq ! ( & sql, "SELECT id FROM a_table LIMIT 1000 OFFSET 100;" ) ;
273
+ }
274
+ Err ( e) => {
275
+ assert ! ( false , e) ;
276
+ }
277
+ } ;
278
+ }
279
+
280
+ #[ test]
281
+ fn order_by ( ) {
282
+ let query = Query {
283
+ params : QueryParams {
284
+ columns : vec ! [ "id" . to_string( ) ] ,
285
+ conditions : None ,
286
+ distinct : None ,
287
+ limit : 1000 ,
288
+ offset : 0 ,
289
+ order_by : Some ( "name,test" . to_string ( ) ) ,
290
+ prepared_values : None ,
291
+ table : "a_table" . to_string ( ) ,
292
+ } ,
293
+ task : QueryTasks :: GetAllTables ,
294
+ } ;
295
+
296
+ match build_select_statement ( query) {
297
+ Ok ( ( sql, _) ) => {
298
+ assert_eq ! (
299
+ & sql,
300
+ "SELECT id FROM a_table ORDER BY name, test LIMIT 1000;"
301
+ ) ;
302
+ }
303
+ Err ( e) => {
304
+ assert ! ( false , e) ;
305
+ }
306
+ } ;
307
+ }
308
+
309
+ #[ test]
310
+ fn conditions ( ) {
311
+ let query = Query {
312
+ params : QueryParams {
313
+ columns : vec ! [ "id" . to_string( ) ] ,
314
+ conditions : Some ( "(id > 10 OR id < 20) AND name = 'test'" . to_string ( ) ) ,
315
+ distinct : None ,
316
+ limit : 10 ,
317
+ offset : 0 ,
318
+ order_by : None ,
319
+ prepared_values : None ,
320
+ table : "a_table" . to_string ( ) ,
321
+ } ,
322
+ task : QueryTasks :: GetAllTables ,
323
+ } ;
324
+
325
+ match build_select_statement ( query) {
326
+ Ok ( ( sql, _) ) => {
327
+ assert_eq ! (
328
+ & sql,
329
+ "SELECT id FROM a_table WHERE ((id > 10 OR id < 20) AND name = 'test') LIMIT 10;"
330
+ ) ;
331
+ }
332
+ Err ( e) => {
333
+ assert ! ( false , e) ;
334
+ }
335
+ } ;
336
+ }
337
+
338
+ #[ test]
339
+ fn prepared_values ( ) {
340
+ let query = Query {
341
+ params : QueryParams {
342
+ columns : vec ! [ "id" . to_string( ) ] ,
343
+ conditions : Some ( "(id > $1 OR id < $2) AND name = $3" . to_string ( ) ) ,
344
+ distinct : None ,
345
+ limit : 10 ,
346
+ offset : 0 ,
347
+ order_by : None ,
348
+ prepared_values : Some ( "10,20,'test'" . to_string ( ) ) ,
349
+ table : "a_table" . to_string ( ) ,
350
+ } ,
351
+ task : QueryTasks :: GetAllTables ,
352
+ } ;
353
+
354
+ match build_select_statement ( query) {
355
+ Ok ( ( sql, prepared_values) ) => {
356
+ assert_eq ! (
357
+ & sql,
358
+ "SELECT id FROM a_table WHERE ((id > $1 OR id < $2) AND name = $3) LIMIT 10;"
359
+ ) ;
360
+
361
+ assert_eq ! (
362
+ prepared_values,
363
+ vec![
364
+ PreparedStatementValue :: Int4 ( 10 ) ,
365
+ PreparedStatementValue :: Int4 ( 20 ) ,
366
+ PreparedStatementValue :: String ( "test" . to_string( ) ) ,
367
+ ]
368
+ ) ;
369
+ }
370
+ Err ( e) => {
371
+ assert ! ( false , e) ;
372
+ }
373
+ } ;
374
+ }
375
+
376
+ #[ test]
377
+ fn complex_query ( ) {
378
+ let query = Query {
379
+ params : QueryParams {
380
+ columns : vec ! [
381
+ "id" . to_string( ) ,
382
+ "test_bigint" . to_string( ) ,
383
+ "test_bigserial" . to_string( ) ,
384
+ ] ,
385
+ conditions : Some ( "id = $1 AND test_name = $2" . to_string ( ) ) ,
386
+ distinct : Some ( "test_date,test_timestamptz" . to_string ( ) ) ,
387
+ limit : 10000 ,
388
+ offset : 2000 ,
389
+ order_by : Some ( "due_date DESC" . to_string ( ) ) ,
390
+ prepared_values : Some ( "46327143679919107,'a name'" . to_string ( ) ) ,
391
+ table : "a_table" . to_string ( ) ,
392
+ } ,
393
+ task : QueryTasks :: GetAllTables ,
394
+ } ;
395
+
396
+ match build_select_statement ( query) {
397
+ Ok ( ( sql, prepared_values) ) => {
398
+ assert_eq ! (
399
+ & sql,
400
+ "SELECT DISTINCT ON (test_date, test_timestamptz) id, test_bigint, test_bigserial FROM a_table WHERE (id = $1 AND test_name = $2) ORDER BY due_date DESC LIMIT 10000 OFFSET 2000;"
401
+ ) ;
402
+
403
+ assert_eq ! (
404
+ prepared_values,
405
+ vec![
406
+ PreparedStatementValue :: Int8 ( 46_327_143_679_919_107i64 ) ,
407
+ PreparedStatementValue :: String ( "a name" . to_string( ) ) ,
408
+ ]
409
+ ) ;
410
+ }
411
+ Err ( e) => {
412
+ assert ! ( false , format!( "{}" , e) ) ;
413
+ }
414
+ } ;
415
+ }
416
+ }
0 commit comments