Skip to content

Commit f950002

Browse files
committed
Propogate context parameter
1 parent 8c2fddb commit f950002

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+173
-189
lines changed

driver/rows.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func (r *Rows) convert(col int, v driver.Value) interface{} {
114114
}
115115
}
116116

117-
sqlValue, _, err := r.cols[col].Type.Convert(ctx, v)
117+
sqlValue, _, err := r.cols[col].Type.Convert(r.ctx, v)
118118
if err != nil {
119119
break
120120
}

driver/value.go

-43
Original file line numberDiff line numberDiff line change
@@ -16,56 +16,13 @@ package driver
1616

1717
import (
1818
"database/sql/driver"
19-
"errors"
20-
"fmt"
2119
"strconv"
2220
"time"
2321

2422
"github.com/dolthub/vitess/go/sqltypes"
2523
"github.com/dolthub/vitess/go/vt/sqlparser"
26-
27-
"github.com/dolthub/go-mysql-server/sql"
28-
"github.com/dolthub/go-mysql-server/sql/expression"
29-
"github.com/dolthub/go-mysql-server/sql/types"
3024
)
3125

32-
// ErrUnsupportedType is returned when a query argument of an unsupported type is passed to a statement
33-
var ErrUnsupportedType = errors.New("unsupported type")
34-
35-
func valueToExpr(v driver.Value) (sql.Expression, error) {
36-
if v == nil {
37-
return expression.NewLiteral(nil, types.Null), nil
38-
}
39-
40-
var typ sql.Type
41-
var err error
42-
switch v := v.(type) {
43-
case int64:
44-
typ = types.Int64
45-
case float64:
46-
typ = types.Float64
47-
case bool:
48-
typ = types.Boolean
49-
case []byte:
50-
typ, err = types.CreateBinary(sqltypes.Blob, int64(len(v)))
51-
case string:
52-
typ, err = types.CreateStringWithDefaults(sqltypes.Text, int64(len(v)))
53-
case time.Time:
54-
typ = types.Datetime
55-
default:
56-
return nil, fmt.Errorf("%w: %T", ErrUnsupportedType, v)
57-
}
58-
if err != nil {
59-
return nil, err
60-
}
61-
62-
c, _, err := typ.Convert(ctx, v)
63-
if err != nil {
64-
return nil, err
65-
}
66-
return expression.NewLiteral(c, typ), nil
67-
}
68-
6926
func valuesToBindings(vals []driver.Value) (map[string]sqlparser.Expr, error) {
7027
if len(vals) == 0 {
7128
return nil, nil

engine.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ func clearWarnings(ctx *sql.Context, node sql.Node) {
273273
}
274274
}
275275

276-
func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sql.Expression, error) {
276+
func bindingsToExprs(ctx *sql.Context, bindings map[string]*querypb.BindVariable) (map[string]sql.Expression, error) {
277277
res := make(map[string]sql.Expression, len(bindings))
278278
for k, v := range bindings {
279279
v, err := sqltypes.NewValue(v.Type, v.Value)

engine_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,10 @@ func TestBindingsToExprs(t *testing.T) {
139139
},
140140
}
141141

142+
ctx := sql.NewEmptyContext()
142143
for _, c := range cases {
143144
t.Run(c.Name, func(t *testing.T) {
144-
res, err := bindingsToExprs(c.Bindings)
145+
res, err := bindingsToExprs(ctx, c.Bindings)
145146
if !c.Err {
146147
require.NoError(t, err)
147148
require.Equal(t, c.Result, res)

enginetest/enginetests.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -2082,7 +2082,7 @@ func TestUserPrivileges(t *testing.T, harness ClientHarness) {
20822082
// See the comment on QuickPrivilegeTest for a more in-depth explanation, but essentially we treat
20832083
// nil in script.Expected as matching "any" non-error result.
20842084
if script.Expected != nil && (rows != nil || len(script.Expected) != 0) {
2085-
CheckResults(t, harness, script.Expected, nil, sch, rows, lastQuery, engine)
2085+
CheckResults(ctx, t, harness, script.Expected, nil, sch, rows, lastQuery, engine)
20862086
}
20872087
})
20882088
}

enginetest/evaluation.go

+11-39
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ func testQueryWithContext(
412412
require.NoError(err, "Unexpected error for query %s: %s", q, err)
413413

414414
if expected != nil {
415-
checkResults(t, harness, expected, expectedCols, sch, rows, q, e, wrapBehavior)
415+
checkResults(t, ctx, harness, expected, expectedCols, sch, rows, q, e, wrapBehavior)
416416
}
417417

418418
require.Equal(
@@ -455,7 +455,7 @@ func TestQueryWithIndexCheck(t *testing.T, ctx *sql.Context, e QueryEngine, harn
455455
require.NoError(err, "Unexpected error for query %s: %s", q, err)
456456

457457
if expected != nil {
458-
CheckResults(t, harness, expected, expectedCols, sch, rows, q, e)
458+
CheckResults(ctx, t, harness, expected, expectedCols, sch, rows, q, e)
459459
}
460460

461461
require.Equal(
@@ -514,41 +514,22 @@ func TestPreparedQueryWithContext(t *testing.T, ctx *sql.Context, e QueryEngine,
514514

515515
if expected != nil {
516516
// TODO fix expected cols for prepared?
517-
CheckResults(t, h, expected, expectedCols, sch, rows, q, e)
517+
CheckResults(ctx, t, h, expected, expectedCols, sch, rows, q, e)
518518
}
519519

520520
require.Equal(0, ctx.Memory.NumCaches())
521521
validateEngine(t, ctx, h, e)
522522
}
523523

524-
func CheckResults(
525-
t *testing.T,
526-
h Harness,
527-
expected []sql.Row,
528-
expectedCols []*sql.Column,
529-
sch sql.Schema,
530-
rows []sql.Row,
531-
q string,
532-
e QueryEngine,
533-
) {
534-
checkResults(t, h, expected, expectedCols, sch, rows, q, e, queries.WrapBehavior_Unwrap)
524+
func CheckResults(ctx *sql.Context, t *testing.T, h Harness, expected []sql.Row, expectedCols []*sql.Column, sch sql.Schema, rows []sql.Row, q string, e QueryEngine) {
525+
checkResults(t, ctx, h, expected, expectedCols, sch, rows, q, e, queries.WrapBehavior_Unwrap)
535526
}
536527

537-
func checkResults(
538-
t *testing.T,
539-
h Harness,
540-
expected []sql.Row,
541-
expectedCols []*sql.Column,
542-
sch sql.Schema,
543-
rows []sql.Row,
544-
q string,
545-
e QueryEngine,
546-
wrapBehavior queries.WrapBehavior,
547-
) {
528+
func checkResults(t *testing.T, ctx *sql.Context, h Harness, expected []sql.Row, expectedCols []*sql.Column, sch sql.Schema, rows []sql.Row, q string, e QueryEngine, wrapBehavior queries.WrapBehavior) {
548529
if reh, ok := h.(ResultEvaluationHarness); ok {
549530
reh.EvaluateQueryResults(t, expected, expectedCols, sch, rows, q, wrapBehavior)
550531
} else {
551-
checkResultsDefault(t, expected, expectedCols, sch, rows, q, e, wrapBehavior)
532+
checkResultsDefault(t, ctx, expected, expectedCols, sch, rows, q, e, wrapBehavior)
552533
}
553534
}
554535

@@ -688,7 +669,7 @@ type CustomValueValidator interface {
688669
// toSQL converts the given expected value into appropriate type of given column.
689670
// |isZeroTime| is true if the query is any `SHOW` statement, except for `SHOW EVENTS`.
690671
// This is set earlier in `checkResult()` method.
691-
func toSQL(c *sql.Column, expected any, isZeroTime bool) (any, error) {
672+
func toSQL(ctx *sql.Context, c *sql.Column, expected any, isZeroTime bool) (any, error) {
692673
_, isTime := expected.(time.Time)
693674
_, isStr := expected.(string)
694675
// cases where we don't want the result value to be converted
@@ -705,16 +686,7 @@ func toSQL(c *sql.Column, expected any, isZeroTime bool) (any, error) {
705686
// don't implement ResultEvaluationHarness. All numerical values are widened to their widest type before comparison.
706687
// Based on the value of |unwrapValues|, this either normalized wrapped values by unwrapping them, or replaces them
707688
// with their hash so the test caller can assert that the values are wrapped and have a certain hash.
708-
func checkResultsDefault(
709-
t *testing.T,
710-
expected []sql.Row,
711-
expectedCols []*sql.Column,
712-
sch sql.Schema,
713-
rows []sql.Row,
714-
q string,
715-
e QueryEngine,
716-
wrapBehavior queries.WrapBehavior,
717-
) {
689+
func checkResultsDefault(t *testing.T, ctx *sql.Context, expected []sql.Row, expectedCols []*sql.Column, sch sql.Schema, rows []sql.Row, q string, e QueryEngine, wrapBehavior queries.WrapBehavior) {
718690
widenedRows := WidenRows(t, sch, rows)
719691
widenedExpected := WidenRows(t, sch, expected)
720692

@@ -804,7 +776,7 @@ func checkResultsDefault(
804776
} else {
805777
// this attempts to do what `rowToSQL()` method in `handler.go` on expected row
806778
// because over the wire values gets converted to SQL values depending on the column types.
807-
convertedExpected, err := toSQL(sch[j], widenedExpected[i][j], setZeroTime)
779+
convertedExpected, err := toSQL(ctx, sch[j], widenedExpected[i][j], setZeroTime)
808780
require.NoError(t, err)
809781
widenedExpected[i][j] = convertedExpected
810782
}
@@ -1120,7 +1092,7 @@ func AssertWarningAndTestQuery(
11201092
}
11211093

11221094
if !skipResultsCheck {
1123-
CheckResults(t, harness, expected, expectedCols, sch, rows, query, e)
1095+
CheckResults(ctx, t, harness, expected, expectedCols, sch, rows, query, e)
11241096
}
11251097
validateEngine(t, ctx, harness, e)
11261098
}

enginetest/join_planning_tests.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1900,7 +1900,7 @@ func evalJoinCorrectness(t *testing.T, harness Harness, e QueryEngine, name, q s
19001900
require.NoError(t, err, "Unexpected error for query %s: %s", q, err)
19011901

19021902
if exp != nil {
1903-
CheckResults(t, harness, exp, nil, sch, rows, q, e)
1903+
CheckResults(ctx, t, harness, exp, nil, sch, rows, q, e)
19041904
}
19051905

19061906
require.Equal(t, 0, ctx.Memory.NumCaches())

enginetest/queries/insert_queries.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import (
2424
"github.com/dolthub/go-mysql-server/sql/types"
2525
)
2626

27+
var sqlCtx = sql.NewEmptyContext()
28+
2729
var InsertQueries = []WriteQueryTest{
2830
{
2931
WriteQuery: "INSERT INTO keyless VALUES ();",
@@ -113,7 +115,7 @@ var InsertQueries = []WriteQueryTest{
113115
int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64),
114116
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
115117
float32(math.MaxFloat32), float64(math.MaxFloat64),
116-
sql.MustConvert(types.Timestamp.Convert(ctx, "2037-04-05 12:51:36")), sql.MustConvert(types.Date.Convert(ctx, "2231-11-07")),
118+
sql.MustConvert(types.Timestamp.Convert(sqlCtx, "2037-04-05 12:51:36")), sql.MustConvert(types.Date.Convert(sqlCtx, "2231-11-07")),
117119
"random text", sql.True, types.MustJSON(`{"key":"value"}`), []byte("blobdata"), "v1", "v2",
118120
}},
119121
},
@@ -131,7 +133,7 @@ var InsertQueries = []WriteQueryTest{
131133
int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64),
132134
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
133135
float32(math.MaxFloat32), float64(math.MaxFloat64),
134-
sql.MustConvert(types.Timestamp.Convert(ctx, "2037-04-05 12:51:36")), sql.MustConvert(types.Date.Convert(ctx, "2231-11-07")),
136+
sql.MustConvert(types.Timestamp.Convert(sqlCtx, "2037-04-05 12:51:36")), sql.MustConvert(types.Date.Convert(sqlCtx, "2231-11-07")),
135137
"random text", sql.True, types.MustJSON(`{"key":"value"}`), []byte("blobdata"), "v1", "v2",
136138
}},
137139
},
@@ -188,7 +190,7 @@ var InsertQueries = []WriteQueryTest{
188190
int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
189191
uint8(0), uint16(0), uint32(0), uint64(0),
190192
float32(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
191-
sql.MustConvert(types.Timestamp.Convert(ctx, "2037-04-05 12:51:36")), types.Date.Zero(),
193+
sql.MustConvert(types.Timestamp.Convert(sqlCtx, "2037-04-05 12:51:36")), types.Date.Zero(),
192194
"", sql.False, types.MustJSON(`""`), []byte(""), "v1", "v2",
193195
}},
194196
},
@@ -209,7 +211,7 @@ var InsertQueries = []WriteQueryTest{
209211
WriteQuery: `INSERT INTO typestable (id, ti, da) VALUES (999, '2021-09-1', '2021-9-01');`,
210212
ExpectedWriteResult: []sql.Row{{types.NewOkResult(1)}},
211213
SelectQuery: "SELECT id, ti, da FROM typestable WHERE id = 999;",
212-
ExpectedSelect: []sql.Row{{int64(999), sql.MustConvert(types.Timestamp.Convert(ctx, "2021-09-01")), sql.MustConvert(types.Date.Convert(ctx, "2021-09-01"))}},
214+
ExpectedSelect: []sql.Row{{int64(999), sql.MustConvert(types.Timestamp.Convert(sqlCtx, "2021-09-01")), sql.MustConvert(types.Date.Convert(sqlCtx, "2021-09-01"))}},
213215
},
214216
{
215217
WriteQuery: `INSERT INTO typestable SET id=999, i8=null, i16=null, i32=null, i64=null, u8=null, u16=null, u32=null, u64=null,

enginetest/queries/replace_queries.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ var ReplaceQueries = []WriteQueryTest{
8686
int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64),
8787
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
8888
float32(math.MaxFloat32), float64(math.MaxFloat64),
89-
sql.MustConvert(types.Timestamp.Convert(ctx, "2037-04-05 12:51:36")), sql.MustConvert(types.Date.Convert(ctx, "2231-11-07")),
89+
sql.MustConvert(types.Timestamp.Convert(sqlCtx, "2037-04-05 12:51:36")), sql.MustConvert(types.Date.Convert(sqlCtx, "2231-11-07")),
9090
"random text", sql.True, types.MustJSON(`{"key":"value"}`), []byte("blobdata"), "v1", "v2",
9191
}},
9292
},
@@ -104,7 +104,7 @@ var ReplaceQueries = []WriteQueryTest{
104104
int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64),
105105
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
106106
float32(math.MaxFloat32), float64(math.MaxFloat64),
107-
sql.MustConvert(types.Timestamp.Convert(ctx, "2037-04-05 12:51:36")), sql.MustConvert(types.Date.Convert(ctx, "2231-11-07")),
107+
sql.MustConvert(types.Timestamp.Convert(sqlCtx, "2037-04-05 12:51:36")), sql.MustConvert(types.Date.Convert(sqlCtx, "2231-11-07")),
108108
"random text", sql.True, types.MustJSON(`{"key":"value"}`), []byte("blobdata"), "v1", "v2",
109109
}},
110110
},

enginetest/queries/update_queries.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ var UpdateTests = []WriteQueryTest{
125125
uint64(9),
126126
float32(10),
127127
float64(11),
128-
sql.MustConvert(types.Timestamp.Convert(ctx, "2020-03-06 00:00:00")),
129-
sql.MustConvert(types.Date.Convert(ctx, "2019-12-31")),
128+
sql.MustConvert(types.Timestamp.Convert(sqlCtx, "2020-03-06 00:00:00")),
129+
sql.MustConvert(types.Date.Convert(sqlCtx, "2019-12-31")),
130130
"fourteen",
131131
0,
132132
nil,
@@ -149,8 +149,8 @@ var UpdateTests = []WriteQueryTest{
149149
uint64(9),
150150
float32(10),
151151
float64(11),
152-
sql.MustConvert(types.Timestamp.Convert(ctx, "2020-03-06 00:00:00")),
153-
sql.MustConvert(types.Date.Convert(ctx, "2020-03-06")),
152+
sql.MustConvert(types.Timestamp.Convert(sqlCtx, "2020-03-06 00:00:00")),
153+
sql.MustConvert(types.Date.Convert(sqlCtx, "2020-03-06")),
154154
"fourteen",
155155
0,
156156
nil,

enginetest/server_engine.go

+9-9
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, pa
190190
// However, Dolt supports, but not go-sql-driver client
191191
switch parsed.(type) {
192192
case *sqlparser.Load, *sqlparser.Execute, *sqlparser.Prepare:
193-
return s.queryOrExec(nil, parsed, query, []any{})
193+
return s.queryOrExec(ctx, nil, parsed, query, []any{})
194194
}
195195

196196
stmt, err := s.conn.Prepare(query)
@@ -203,7 +203,7 @@ func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, pa
203203
return nil, nil, nil, err
204204
}
205205

206-
return s.queryOrExec(stmt, parsed, query, args)
206+
return s.queryOrExec(ctx, stmt, parsed, query, args)
207207
}
208208

209209
// queryOrExec function use `query()` or `exec()` method of go-sql-driver depending on the sql parser plan.
@@ -212,7 +212,7 @@ func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, pa
212212
// TODO: for `EXECUTE` and `CALL` statements, it can be either query or exec depending on the statement that prepared or stored procedure holds.
213213
//
214214
// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, there will be no result.
215-
func (s *ServerQueryEngine) queryOrExec(stmt *gosql.Stmt, parsed sqlparser.Statement, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
215+
func (s *ServerQueryEngine) queryOrExec(ctx *sql.Context, stmt *gosql.Stmt, parsed sqlparser.Statement, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
216216
var err error
217217
switch parsed.(type) {
218218
// TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned.
@@ -226,7 +226,7 @@ func (s *ServerQueryEngine) queryOrExec(stmt *gosql.Stmt, parsed sqlparser.State
226226
if err != nil {
227227
return nil, nil, nil, trimMySQLErrCodePrefix(err)
228228
}
229-
return convertRowsResult(rows)
229+
return convertRowsResult(ctx, rows)
230230
default:
231231
var res gosql.Result
232232
if stmt != nil {
@@ -280,21 +280,21 @@ func convertExecResult(exec gosql.Result) (sql.Schema, sql.RowIter, *sql.QueryFl
280280
return types.OkResultSchema, sql.RowsToRowIter(sql.NewRow(okResult)), nil, nil
281281
}
282282

283-
func convertRowsResult(rows *gosql.Rows) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
283+
func convertRowsResult(ctx *sql.Context, rows *gosql.Rows) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
284284
sch, err := schemaForRows(rows)
285285
if err != nil {
286286
return nil, nil, nil, err
287287
}
288288

289-
rowIter, err := rowIterForGoSqlRows(sch, rows)
289+
rowIter, err := rowIterForGoSqlRows(ctx, sch, rows)
290290
if err != nil {
291291
return nil, nil, nil, err
292292
}
293293

294294
return sch, rowIter, nil, nil
295295
}
296296

297-
func rowIterForGoSqlRows(sch sql.Schema, rows *gosql.Rows) (sql.RowIter, error) {
297+
func rowIterForGoSqlRows(ctx *sql.Context, sch sql.Schema, rows *gosql.Rows) (sql.RowIter, error) {
298298
result := make([]sql.Row, 0)
299299
r, err := emptyRowForSchema(sch)
300300
if err != nil {
@@ -312,7 +312,7 @@ func rowIterForGoSqlRows(sch sql.Schema, rows *gosql.Rows) (sql.RowIter, error)
312312
return nil, err
313313
}
314314

315-
row = convertValue(sch, row)
315+
row = convertValue(ctx, sch, row)
316316

317317
result = append(result, row)
318318
}
@@ -322,7 +322,7 @@ func rowIterForGoSqlRows(sch sql.Schema, rows *gosql.Rows) (sql.RowIter, error)
322322

323323
// convertValue converts the row value scanned from go sql driver client to type that we expect.
324324
// This method helps with testing existing enginetests that expects specific type as returned value.
325-
func convertValue(sch sql.Schema, row sql.Row) sql.Row {
325+
func convertValue(ctx *sql.Context, sch sql.Schema, row sql.Row) sql.Row {
326326
for i, col := range sch {
327327
switch col.Type.Type() {
328328
case query.Type_GEOMETRY:

enginetest/spatial_index_tests.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ func evalSpatialIndexPlanCorrectness(t *testing.T, harness Harness, e QueryEngin
427427
require.NoError(t, err, "Unexpected error for q %s: %s", q, err)
428428

429429
if exp != nil {
430-
CheckResults(t, harness, exp, nil, sch, rows, q, e)
430+
CheckResults(ctx, t, harness, exp, nil, sch, rows, q, e)
431431
}
432432

433433
require.Equal(t, 0, ctx.Memory.NumCaches())

memory/exponential_dist_table.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func (s ExponentialDistTable) UnderlyingTable() sql.Table {
2828
return s
2929
}
3030

31-
func (s ExponentialDistTable) NewInstance(_ *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
31+
func (s ExponentialDistTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
3232
if len(args) != 3 {
3333
return nil, fmt.Errorf("exponential_dist table expects 2 arguments: (cols, rows, lambda)")
3434
}

0 commit comments

Comments
 (0)