1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Type conversions for Scan.
6
7package sql
8
9import (
10	"bytes"
11	"database/sql/driver"
12	"errors"
13	"fmt"
14	"reflect"
15	"strconv"
16	"time"
17	"unicode"
18	"unicode/utf8"
19	_ "unsafe" // for linkname
20)
21
var (
	// errNilPtr is returned when a Scan destination is a typed nil pointer.
	// It is intended to be embedded in a more descriptive error by callers.
	errNilPtr = errors.New("destination pointer is nil")
)
23
// describeNamedValue renders an argument for error messages: named
// arguments are shown as `with name "x"`, positional ones as `$N` where N
// is the 1-based ordinal.
func describeNamedValue(nv *driver.NamedValue) string {
	if nv.Name != "" {
		return fmt.Sprintf("with name %q", nv.Name)
	}
	return fmt.Sprintf("$%d", nv.Ordinal)
}
30
// validateNamedValueName accepts an empty name (a positional argument) or a
// name whose first rune is a Unicode letter. Any other name, including one
// that starts with invalid UTF-8, is rejected.
func validateNamedValueName(name string) error {
	if name == "" {
		return nil
	}
	if first, _ := utf8.DecodeRuneInString(name); !unicode.IsLetter(first) {
		return fmt.Errorf("name %q does not begin with a letter", name)
	}
	return nil
}
41
42// ccChecker wraps the driver.ColumnConverter and allows it to be used
43// as if it were a NamedValueChecker. If the driver ColumnConverter
44// is not present then the NamedValueChecker will return driver.ErrSkip.
45type ccChecker struct {
46	cci  driver.ColumnConverter
47	want int
48}
49
50func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
51	if c.cci == nil {
52		return driver.ErrSkip
53	}
54	// The column converter shouldn't be called on any index
55	// it isn't expecting. The final error will be thrown
56	// in the argument converter loop.
57	index := nv.Ordinal - 1
58	if c.want <= index {
59		return nil
60	}
61
62	// First, see if the value itself knows how to convert
63	// itself to a driver type. For example, a NullString
64	// struct changing into a string or nil.
65	if vr, ok := nv.Value.(driver.Valuer); ok {
66		sv, err := callValuerValue(vr)
67		if err != nil {
68			return err
69		}
70		if !driver.IsValue(sv) {
71			return fmt.Errorf("non-subset type %T returned from Value", sv)
72		}
73		nv.Value = sv
74	}
75
76	// Second, ask the column to sanity check itself. For
77	// example, drivers might use this to make sure that
78	// an int64 values being inserted into a 16-bit
79	// integer field is in range (before getting
80	// truncated), or that a nil can't go into a NOT NULL
81	// column before going across the network to get the
82	// same error.
83	var err error
84	arg := nv.Value
85	nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
86	if err != nil {
87		return err
88	}
89	if !driver.IsValue(nv.Value) {
90		return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
91	}
92	return nil
93}
94
// defaultCheckNamedValue adapts driver.DefaultParameterConverter to the
// CheckNamedValue signature of driver.NamedValueChecker. The converter's
// result is always stored back into nv.Value, even when it reports an error.
func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
	var converted driver.Value
	converted, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
	nv.Value = converted
	return err
}
102
103// driverArgsConnLocked converts arguments from callers of Stmt.Exec and
104// Stmt.Query into driver Values.
105//
106// The statement ds may be nil, if no statement is available.
107//
108// ci must be locked.
109func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) {
110	nvargs := make([]driver.NamedValue, len(args))
111
112	// -1 means the driver doesn't know how to count the number of
113	// placeholders, so we won't sanity check input here and instead let the
114	// driver deal with errors.
115	want := -1
116
117	var si driver.Stmt
118	var cc ccChecker
119	if ds != nil {
120		si = ds.si
121		want = ds.si.NumInput()
122		cc.want = want
123	}
124
125	// Check all types of interfaces from the start.
126	// Drivers may opt to use the NamedValueChecker for special
127	// argument types, then return driver.ErrSkip to pass it along
128	// to the column converter.
129	nvc, ok := si.(driver.NamedValueChecker)
130	if !ok {
131		nvc, _ = ci.(driver.NamedValueChecker)
132	}
133	cci, ok := si.(driver.ColumnConverter)
134	if ok {
135		cc.cci = cci
136	}
137
138	// Loop through all the arguments, checking each one.
139	// If no error is returned simply increment the index
140	// and continue. However, if driver.ErrRemoveArgument
141	// is returned the argument is not included in the query
142	// argument list.
143	var err error
144	var n int
145	for _, arg := range args {
146		nv := &nvargs[n]
147		if np, ok := arg.(NamedArg); ok {
148			if err = validateNamedValueName(np.Name); err != nil {
149				return nil, err
150			}
151			arg = np.Value
152			nv.Name = np.Name
153		}
154		nv.Ordinal = n + 1
155		nv.Value = arg
156
157		// Checking sequence has four routes:
158		// A: 1. Default
159		// B: 1. NamedValueChecker 2. Column Converter 3. Default
160		// C: 1. NamedValueChecker 3. Default
161		// D: 1. Column Converter 2. Default
162		//
163		// The only time a Column Converter is called is first
164		// or after NamedValueConverter. If first it is handled before
165		// the nextCheck label. Thus for repeats tries only when the
166		// NamedValueConverter is selected should the Column Converter
167		// be used in the retry.
168		checker := defaultCheckNamedValue
169		nextCC := false
170		switch {
171		case nvc != nil:
172			nextCC = cci != nil
173			checker = nvc.CheckNamedValue
174		case cci != nil:
175			checker = cc.CheckNamedValue
176		}
177
178	nextCheck:
179		err = checker(nv)
180		switch err {
181		case nil:
182			n++
183			continue
184		case driver.ErrRemoveArgument:
185			nvargs = nvargs[:len(nvargs)-1]
186			continue
187		case driver.ErrSkip:
188			if nextCC {
189				nextCC = false
190				checker = cc.CheckNamedValue
191			} else {
192				checker = defaultCheckNamedValue
193			}
194			goto nextCheck
195		default:
196			return nil, fmt.Errorf("sql: converting argument %s type: %w", describeNamedValue(nv), err)
197		}
198	}
199
200	// Check the length of arguments after conversion to allow for omitted
201	// arguments.
202	if want != -1 && len(nvargs) != want {
203		return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
204	}
205
206	return nvargs, nil
207}
208
209// convertAssign is the same as convertAssignRows, but without the optional
210// rows argument.
211//
212// convertAssign should be an internal detail,
213// but widely used packages access it using linkname.
214// Notable members of the hall of shame include:
215//   - ariga.io/entcache
216//
217// Do not remove or change the type signature.
218// See go.dev/issue/67401.
219//
220//go:linkname convertAssign
221func convertAssign(dest, src any) error {
222	return convertAssignRows(dest, src, nil)
223}
224
225// convertAssignRows copies to dest the value in src, converting it if possible.
226// An error is returned if the copy would result in loss of information.
227// dest should be a pointer type. If rows is passed in, the rows will
228// be used as the parent for any cursor values converted from a
229// driver.Rows to a *Rows.
230func convertAssignRows(dest, src any, rows *Rows) error {
231	// Common cases, without reflect.
232	switch s := src.(type) {
233	case string:
234		switch d := dest.(type) {
235		case *string:
236			if d == nil {
237				return errNilPtr
238			}
239			*d = s
240			return nil
241		case *[]byte:
242			if d == nil {
243				return errNilPtr
244			}
245			*d = []byte(s)
246			return nil
247		case *RawBytes:
248			if d == nil {
249				return errNilPtr
250			}
251			*d = rows.setrawbuf(append(rows.rawbuf(), s...))
252			return nil
253		}
254	case []byte:
255		switch d := dest.(type) {
256		case *string:
257			if d == nil {
258				return errNilPtr
259			}
260			*d = string(s)
261			return nil
262		case *any:
263			if d == nil {
264				return errNilPtr
265			}
266			*d = bytes.Clone(s)
267			return nil
268		case *[]byte:
269			if d == nil {
270				return errNilPtr
271			}
272			*d = bytes.Clone(s)
273			return nil
274		case *RawBytes:
275			if d == nil {
276				return errNilPtr
277			}
278			*d = s
279			return nil
280		}
281	case time.Time:
282		switch d := dest.(type) {
283		case *time.Time:
284			*d = s
285			return nil
286		case *string:
287			*d = s.Format(time.RFC3339Nano)
288			return nil
289		case *[]byte:
290			if d == nil {
291				return errNilPtr
292			}
293			*d = []byte(s.Format(time.RFC3339Nano))
294			return nil
295		case *RawBytes:
296			if d == nil {
297				return errNilPtr
298			}
299			*d = rows.setrawbuf(s.AppendFormat(rows.rawbuf(), time.RFC3339Nano))
300			return nil
301		}
302	case decimalDecompose:
303		switch d := dest.(type) {
304		case decimalCompose:
305			return d.Compose(s.Decompose(nil))
306		}
307	case nil:
308		switch d := dest.(type) {
309		case *any:
310			if d == nil {
311				return errNilPtr
312			}
313			*d = nil
314			return nil
315		case *[]byte:
316			if d == nil {
317				return errNilPtr
318			}
319			*d = nil
320			return nil
321		case *RawBytes:
322			if d == nil {
323				return errNilPtr
324			}
325			*d = nil
326			return nil
327		}
328	// The driver is returning a cursor the client may iterate over.
329	case driver.Rows:
330		switch d := dest.(type) {
331		case *Rows:
332			if d == nil {
333				return errNilPtr
334			}
335			if rows == nil {
336				return errors.New("invalid context to convert cursor rows, missing parent *Rows")
337			}
338			rows.closemu.Lock()
339			*d = Rows{
340				dc:          rows.dc,
341				releaseConn: func(error) {},
342				rowsi:       s,
343			}
344			// Chain the cancel function.
345			parentCancel := rows.cancel
346			rows.cancel = func() {
347				// When Rows.cancel is called, the closemu will be locked as well.
348				// So we can access rs.lasterr.
349				d.close(rows.lasterr)
350				if parentCancel != nil {
351					parentCancel()
352				}
353			}
354			rows.closemu.Unlock()
355			return nil
356		}
357	}
358
359	var sv reflect.Value
360
361	switch d := dest.(type) {
362	case *string:
363		sv = reflect.ValueOf(src)
364		switch sv.Kind() {
365		case reflect.Bool,
366			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
367			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
368			reflect.Float32, reflect.Float64:
369			*d = asString(src)
370			return nil
371		}
372	case *[]byte:
373		sv = reflect.ValueOf(src)
374		if b, ok := asBytes(nil, sv); ok {
375			*d = b
376			return nil
377		}
378	case *RawBytes:
379		sv = reflect.ValueOf(src)
380		if b, ok := asBytes(rows.rawbuf(), sv); ok {
381			*d = rows.setrawbuf(b)
382			return nil
383		}
384	case *bool:
385		bv, err := driver.Bool.ConvertValue(src)
386		if err == nil {
387			*d = bv.(bool)
388		}
389		return err
390	case *any:
391		*d = src
392		return nil
393	}
394
395	if scanner, ok := dest.(Scanner); ok {
396		return scanner.Scan(src)
397	}
398
399	dpv := reflect.ValueOf(dest)
400	if dpv.Kind() != reflect.Pointer {
401		return errors.New("destination not a pointer")
402	}
403	if dpv.IsNil() {
404		return errNilPtr
405	}
406
407	if !sv.IsValid() {
408		sv = reflect.ValueOf(src)
409	}
410
411	dv := reflect.Indirect(dpv)
412	if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
413		switch b := src.(type) {
414		case []byte:
415			dv.Set(reflect.ValueOf(bytes.Clone(b)))
416		default:
417			dv.Set(sv)
418		}
419		return nil
420	}
421
422	if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
423		dv.Set(sv.Convert(dv.Type()))
424		return nil
425	}
426
427	// The following conversions use a string value as an intermediate representation
428	// to convert between various numeric types.
429	//
430	// This also allows scanning into user defined types such as "type Int int64".
431	// For symmetry, also check for string destination types.
432	switch dv.Kind() {
433	case reflect.Pointer:
434		if src == nil {
435			dv.SetZero()
436			return nil
437		}
438		dv.Set(reflect.New(dv.Type().Elem()))
439		return convertAssignRows(dv.Interface(), src, rows)
440	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
441		if src == nil {
442			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
443		}
444		s := asString(src)
445		i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
446		if err != nil {
447			err = strconvErr(err)
448			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
449		}
450		dv.SetInt(i64)
451		return nil
452	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
453		if src == nil {
454			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
455		}
456		s := asString(src)
457		u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
458		if err != nil {
459			err = strconvErr(err)
460			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
461		}
462		dv.SetUint(u64)
463		return nil
464	case reflect.Float32, reflect.Float64:
465		if src == nil {
466			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
467		}
468		s := asString(src)
469		f64, err := strconv.ParseFloat(s, dv.Type().Bits())
470		if err != nil {
471			err = strconvErr(err)
472			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
473		}
474		dv.SetFloat(f64)
475		return nil
476	case reflect.String:
477		if src == nil {
478			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
479		}
480		switch v := src.(type) {
481		case string:
482			dv.SetString(v)
483			return nil
484		case []byte:
485			dv.SetString(string(v))
486			return nil
487		}
488	}
489
490	return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
491}
492
// strconvErr unwraps a *strconv.NumError to its underlying cause (such as
// strconv.ErrSyntax or strconv.ErrRange). Any other error, including nil, is
// returned unchanged.
func strconvErr(err error) error {
	switch numErr := err.(type) {
	case *strconv.NumError:
		return numErr.Err
	default:
		return err
	}
}
499
// asString formats src as text. Strings and byte slices are returned
// directly. Integer, unsigned, float and bool kinds (including named types
// with those kinds) use strconv, with floats in shortest 'g' form at their
// own precision. Everything else, including nil, falls back to fmt's %v.
func asString(src any) string {
	if str, isStr := src.(string); isStr {
		return str
	}
	if raw, isBytes := src.([]byte); isBytes {
		return string(raw)
	}

	rv := reflect.ValueOf(src)
	switch rv.Kind() {
	case reflect.Bool:
		return strconv.FormatBool(rv.Bool())
	case reflect.Float32, reflect.Float64:
		// Bits is 32 or 64 for these kinds, matching the value's precision.
		return strconv.FormatFloat(rv.Float(), 'g', -1, rv.Type().Bits())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(rv.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.FormatUint(rv.Uint(), 10)
	default:
		return fmt.Sprintf("%v", src)
	}
}
522
// asBytes appends a textual form of rv to buf and reports success. String,
// bool, integer, unsigned and float kinds are supported (floats in shortest
// 'g' form at their own precision). For any other kind it returns (nil,
// false) and buf is not returned.
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
	switch rv.Kind() {
	case reflect.String:
		return append(buf, rv.String()...), true
	case reflect.Bool:
		return strconv.AppendBool(buf, rv.Bool()), true
	case reflect.Float32, reflect.Float64:
		return strconv.AppendFloat(buf, rv.Float(), 'g', -1, rv.Type().Bits()), true
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.AppendInt(buf, rv.Int(), 10), true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.AppendUint(buf, rv.Uint(), 10), true
	default:
		return nil, false
	}
}
541
542var valuerReflectType = reflect.TypeFor[driver.Valuer]()
543
544// callValuerValue returns vr.Value(), with one exception:
545// If vr.Value is an auto-generated method on a pointer type and the
546// pointer is nil, it would panic at runtime in the panicwrap
547// method. Treat it like nil instead.
548// Issue 8415.
549//
550// This is so people can implement driver.Value on value types and
551// still use nil pointers to those types to mean nil/NULL, just like
552// string/*string.
553//
554// This function is mirrored in the database/sql/driver package.
555func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
556	if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer &&
557		rv.IsNil() &&
558		rv.Type().Elem().Implements(valuerReflectType) {
559		return nil, nil
560	}
561	return vr.Value()
562}
563
type (
	// decimal composes or decomposes a decimal value to and from individual parts.
	// There are four parts: a boolean negative flag, a form byte with three possible states
	// (finite=0, infinite=1, NaN=2), a base-2 big-endian integer
	// coefficient (also known as a significand) as a []byte, and an int32 exponent.
	// These are composed into a final value as "decimal = (neg) (form=finite) coefficient * 10 ^ exponent".
	// A zero length coefficient is a zero value.
	// The big-endian integer coefficient stores the most significant byte first (at coefficient[0]).
	// If the form is not finite the coefficient and exponent should be ignored.
	// The negative parameter may be set to true for any form, although implementations are not required
	// to respect the negative parameter in the non-finite form.
	//
	// Implementations may choose to set the negative parameter to true on a zero or NaN value,
	// but implementations that do not differentiate between negative and positive
	// zero or NaN values should ignore the negative parameter without error.
	// If an implementation does not support Infinity it may be converted into a NaN without error.
	// If a value is set that is larger than what is supported by an implementation,
	// an error must be returned.
	// Implementations must return an error if a NaN or Infinity is attempted to be set while neither
	// are supported.
	//
	// NOTE(kardianos): This is an experimental interface. See https://golang.org/issue/30870
	decimal interface {
		decimalDecompose
		decimalCompose
	}

	// decimalDecompose is the read half of decimal.
	decimalDecompose interface {
		// Decompose returns the internal decimal state in parts.
		// If the provided buf has sufficient capacity, buf may be returned as the coefficient with
		// the value set and length set as appropriate.
		Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
	}

	// decimalCompose is the write half of decimal.
	decimalCompose interface {
		// Compose sets the internal decimal value from parts. If the value cannot be
		// represented then an error should be returned.
		Compose(form byte, negative bool, coefficient []byte, exponent int32) error
	}
)
602