1// Copyright 2017 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"go/ast"
9	"go/token"
10	"reflect"
11	"strings"
12)
13
14func init() {
15	register(cftypeFix)
16}
17
18var cftypeFix = fix{
19	name:     "cftype",
20	date:     "2017-09-27",
21	f:        cftypefix,
22	desc:     `Fixes initializers and casts of C.*Ref and JNI types`,
23	disabled: false,
24}
25
// cftypefix runs typefix on f for every type named C.<...>Ref, with the
// single exception of C.CFAllocatorRef.
//
// Such reference types used to be declared as
//
//	type CFTypeRef unsafe.Pointer
//
// and are now declared as
//
//	type CFTypeRef uintptr
//
// so nil values of these types must be written as 0. Some one-step pointer
// conversions between *unsafe.Pointer and pointers to these types also need
// an intermediate unsafe.Pointer. It reports whether f was modified.
func cftypefix(f *ast.File) bool {
	isRefType := func(name string) bool {
		if name == "C.CFAllocatorRef" {
			return false
		}
		return strings.HasPrefix(name, "C.") && strings.HasSuffix(name, "Ref")
	}
	return typefix(f, isRefType)
}
41
// typefix rewrites f for types whose name, as reported by typecheck, makes
// badType return true:
//
//  1. every nil identifier whose type is bad is replaced with the literal 0;
//  2. single-argument conversions (*pkg.T)(x) between *unsafe.Pointer and a
//     pointer to a bad type get an intermediate unsafe.Pointer(x).
//
// If imports(f, "C") reports false, f is left untouched. typefix reports
// whether f was modified.
func typefix(f *ast.File, badType func(string) bool) bool {
	if !imports(f, "C") {
		return false
	}
	typeof, _ := typecheck(&TypeConfig{}, f)

	// Step 1: find all the nils with the offending types and compute their replacements.
	badNils := typefixFindBadNils(f, typeof, badType)

	// Step 2: find all uses of the bad nils and replace them with 0.
	changed := typefixReplaceBadNils(f, badNils)

	// Step 3: fix up casts that are no longer valid in a single step.
	// This always runs, independently of whether steps 1-2 found anything.
	if typefixInsertUnsafeCasts(f, typeof, badType) {
		changed = true
	}
	return changed
}

// typefixFindBadNils maps each nil *ast.Ident in f whose type (per typeof)
// satisfies badType to a fresh integer literal 0 placed at the nil's position.
func typefixFindBadNils(f *ast.File, typeof map[any]string, badType func(string) bool) map[any]ast.Expr {
	badNils := map[any]ast.Expr{}
	walk(f, func(n any) {
		if id, ok := n.(*ast.Ident); ok && id.Name == "nil" && badType(typeof[n]) {
			badNils[n] = &ast.BasicLit{ValuePos: id.NamePos, Kind: token.INT, Value: "0"}
		}
	})
	return badNils
}

// typefixReplaceBadNils replaces every ast.Expr field and every []ast.Expr
// element in f that holds a key of badNils with the mapped replacement. It
// reports whether any replacement was made.
//
// There is no easy way to map from an ast.Expr to all the places that use
// it, so every struct node reached by walk is inspected via reflection.
func typefixReplaceBadNils(f *ast.File, badNils map[any]ast.Expr) bool {
	if len(badNils) == 0 {
		return false
	}
	exprType := reflect.TypeFor[ast.Expr]()
	exprSliceType := reflect.TypeFor[[]ast.Expr]()
	changed := false
	walk(f, func(n any) {
		if n == nil {
			return
		}
		v := reflect.ValueOf(n)
		if v.Kind() != reflect.Pointer || v.IsNil() {
			return
		}
		v = v.Elem()
		if v.Kind() != reflect.Struct {
			return
		}
		for i := range v.NumField() {
			// Named field (not f) so the *ast.File parameter is not shadowed.
			field := v.Field(i)
			switch field.Type() {
			case exprType:
				if r := badNils[field.Interface()]; r != nil {
					field.Set(reflect.ValueOf(r))
					changed = true
				}
			case exprSliceType:
				for j := range field.Len() {
					elem := field.Index(j)
					if r := badNils[elem.Interface()]; r != nil {
						elem.Set(reflect.ValueOf(r))
						changed = true
					}
				}
			}
		}
	})
	return changed
}

// typefixInsertUnsafeCasts handles casts that used to be allowed in a single
// step between *unsafe.Pointer and a pointer to a bad type, and now need
// unsafe.Pointer as an intermediate:
//
//	(*unsafe.Pointer)(x) where x is *bad -> (*unsafe.Pointer)(unsafe.Pointer(x))
//	(*bad)(x) where x is *unsafe.Pointer -> (*bad)(unsafe.Pointer(x))
//
// It reports whether any conversion was rewritten.
func typefixInsertUnsafeCasts(f *ast.File, typeof map[any]string, badType func(string) bool) bool {
	changed := false
	walk(f, func(n any) {
		// Match the shape (*pkg.Name)(x) with exactly one argument.
		call, ok := n.(*ast.CallExpr)
		if !ok || len(call.Args) != 1 {
			return
		}
		paren, ok := call.Fun.(*ast.ParenExpr)
		if !ok {
			return
		}
		star, ok := paren.X.(*ast.StarExpr)
		if !ok {
			return
		}
		sel, ok := star.X.(*ast.SelectorExpr)
		if !ok {
			return
		}
		pkg, ok := sel.X.(*ast.Ident)
		if !ok {
			return
		}
		dst := pkg.Name + "." + sel.Sel.Name
		src := typeof[call.Args[0]]
		if badType(dst) && src == "*unsafe.Pointer" ||
			dst == "unsafe.Pointer" && strings.HasPrefix(src, "*") && badType(src[1:]) {
			call.Args[0] = &ast.CallExpr{
				Fun:  &ast.SelectorExpr{X: &ast.Ident{Name: "unsafe"}, Sel: &ast.Ident{Name: "Pointer"}},
				Args: []ast.Expr{call.Args[0]},
			}
			changed = true
		}
	})
	return changed
}
148