1 // Copyright 2011 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "base/win/iat_patch_function.h"
6
7 #include "base/check_op.h"
8 #include "base/memory/raw_ptr_exclusion.h"
9 #include "base/notreached.h"
10 #include "base/win/patch_util.h"
11 #include "base/win/pe_image.h"
12
13 namespace base {
14 namespace win {
15
16 namespace {
17
18 struct InterceptFunctionInformation {
19 bool finished_operation;
20 const char* imported_from_module;
21 const char* function_name;
22 // RAW_PTR_EXCLUSION: #reinterpret-cast-trivial-type
23 RAW_PTR_EXCLUSION void* new_function;
24 RAW_PTR_EXCLUSION void** old_function;
25 RAW_PTR_EXCLUSION IMAGE_THUNK_DATA** iat_thunk;
26 DWORD return_code;
27 };
28
GetIATFunction(IMAGE_THUNK_DATA * iat_thunk)29 void* GetIATFunction(IMAGE_THUNK_DATA* iat_thunk) {
30 if (!iat_thunk) {
31 NOTREACHED();
32 return nullptr;
33 }
34
35 // Works around the 64 bit portability warning:
36 // The Function member inside IMAGE_THUNK_DATA is really a pointer
37 // to the IAT function. IMAGE_THUNK_DATA correctly maps to IMAGE_THUNK_DATA32
38 // or IMAGE_THUNK_DATA64 for correct pointer size.
39 union FunctionThunk {
40 IMAGE_THUNK_DATA thunk;
41 // This field is not a raw_ptr<> because it was filtered by the rewriter
42 // for: #union
43 RAW_PTR_EXCLUSION void* pointer;
44 } iat_function;
45
46 iat_function.thunk = *iat_thunk;
47 return iat_function.pointer;
48 }
49
InterceptEnumCallback(const base::win::PEImage & image,const char * module,DWORD ordinal,const char * name,DWORD hint,IMAGE_THUNK_DATA * iat,void * cookie)50 bool InterceptEnumCallback(const base::win::PEImage& image,
51 const char* module,
52 DWORD ordinal,
53 const char* name,
54 DWORD hint,
55 IMAGE_THUNK_DATA* iat,
56 void* cookie) {
57 InterceptFunctionInformation* intercept_information =
58 reinterpret_cast<InterceptFunctionInformation*>(cookie);
59
60 if (!intercept_information) {
61 NOTREACHED();
62 return false;
63 }
64
65 DCHECK(module);
66
67 if (name && (0 == lstrcmpiA(name, intercept_information->function_name))) {
68 // Save the old pointer.
69 if (intercept_information->old_function) {
70 *(intercept_information->old_function) = GetIATFunction(iat);
71 }
72
73 if (intercept_information->iat_thunk) {
74 *(intercept_information->iat_thunk) = iat;
75 }
76
77 // portability check
78 static_assert(
79 sizeof(iat->u1.Function) == sizeof(intercept_information->new_function),
80 "unknown IAT thunk format");
81
82 // Patch the function.
83 intercept_information->return_code = internal::ModifyCode(
84 &(iat->u1.Function), &(intercept_information->new_function),
85 sizeof(intercept_information->new_function));
86
87 // Terminate further enumeration.
88 intercept_information->finished_operation = true;
89 return false;
90 }
91
92 return true;
93 }
94
95 // Helper to intercept a function in an import table of a specific
96 // module.
97 //
98 // Arguments:
99 // module_handle Module to be intercepted
100 // imported_from_module Module that exports the symbol
101 // function_name Name of the API to be intercepted
102 // new_function Interceptor function
103 // old_function Receives the original function pointer
104 // iat_thunk Receives pointer to IAT_THUNK_DATA
105 // for the API from the import table.
106 //
107 // Returns: Returns NO_ERROR on success or Windows error code
108 // as defined in winerror.h
InterceptImportedFunction(HMODULE module_handle,const char * imported_from_module,const char * function_name,void * new_function,void ** old_function,IMAGE_THUNK_DATA ** iat_thunk)109 DWORD InterceptImportedFunction(HMODULE module_handle,
110 const char* imported_from_module,
111 const char* function_name,
112 void* new_function,
113 void** old_function,
114 IMAGE_THUNK_DATA** iat_thunk) {
115 if (!module_handle || !imported_from_module || !function_name ||
116 !new_function) {
117 NOTREACHED();
118 return ERROR_INVALID_PARAMETER;
119 }
120
121 base::win::PEImage target_image(module_handle);
122 if (!target_image.VerifyMagic()) {
123 NOTREACHED();
124 return ERROR_INVALID_PARAMETER;
125 }
126
127 InterceptFunctionInformation intercept_information = {false,
128 imported_from_module,
129 function_name,
130 new_function,
131 old_function,
132 iat_thunk,
133 ERROR_GEN_FAILURE};
134
135 // First go through the IAT. If we don't find the import we are looking
136 // for in IAT, search delay import table.
137 target_image.EnumAllImports(InterceptEnumCallback, &intercept_information,
138 imported_from_module);
139 if (!intercept_information.finished_operation) {
140 target_image.EnumAllDelayImports(
141 InterceptEnumCallback, &intercept_information, imported_from_module);
142 }
143
144 return intercept_information.return_code;
145 }
146
147 // Restore intercepted IAT entry with the original function.
148 //
149 // Arguments:
150 // intercept_function Interceptor function
151 // original_function Receives the original function pointer
152 //
153 // Returns: Returns NO_ERROR on success or Windows error code
154 // as defined in winerror.h
RestoreImportedFunction(void * intercept_function,void * original_function,IMAGE_THUNK_DATA * iat_thunk)155 DWORD RestoreImportedFunction(void* intercept_function,
156 void* original_function,
157 IMAGE_THUNK_DATA* iat_thunk) {
158 if (!intercept_function || !original_function || !iat_thunk) {
159 NOTREACHED();
160 return ERROR_INVALID_PARAMETER;
161 }
162
163 if (GetIATFunction(iat_thunk) != intercept_function) {
164 // Check if someone else has intercepted on top of us.
165 // We cannot unpatch in this case, just raise a red flag.
166 NOTREACHED();
167 return ERROR_INVALID_FUNCTION;
168 }
169
170 return internal::ModifyCode(&(iat_thunk->u1.Function), &original_function,
171 sizeof(original_function));
172 }
173
174 } // namespace
175
176 IATPatchFunction::IATPatchFunction() = default;
177
~IATPatchFunction()178 IATPatchFunction::~IATPatchFunction() {
179 if (intercept_function_) {
180 DWORD error = Unpatch();
181 DCHECK_EQ(static_cast<DWORD>(NO_ERROR), error);
182 }
183 }
184
Patch(const wchar_t * module,const char * imported_from_module,const char * function_name,void * new_function)185 DWORD IATPatchFunction::Patch(const wchar_t* module,
186 const char* imported_from_module,
187 const char* function_name,
188 void* new_function) {
189 HMODULE module_handle = LoadLibraryW(module);
190 if (!module_handle) {
191 NOTREACHED();
192 return GetLastError();
193 }
194
195 DWORD error = PatchFromModule(module_handle, imported_from_module,
196 function_name, new_function);
197 if (NO_ERROR == error) {
198 module_handle_ = module_handle;
199 } else {
200 FreeLibrary(module_handle);
201 }
202
203 return error;
204 }
205
PatchFromModule(HMODULE module,const char * imported_from_module,const char * function_name,void * new_function)206 DWORD IATPatchFunction::PatchFromModule(HMODULE module,
207 const char* imported_from_module,
208 const char* function_name,
209 void* new_function) {
210 DCHECK_EQ(nullptr, original_function_);
211 DCHECK_EQ(nullptr, iat_thunk_);
212 DCHECK_EQ(nullptr, intercept_function_);
213 DCHECK(module);
214
215 DWORD error = InterceptImportedFunction(
216 module, imported_from_module, function_name, new_function,
217 &original_function_.AsEphemeralRawAddr(),
218 &iat_thunk_.AsEphemeralRawAddr());
219
220 if (NO_ERROR == error) {
221 DCHECK_NE(original_function_, intercept_function_);
222 intercept_function_ = new_function;
223 }
224
225 return error;
226 }
227
Unpatch()228 DWORD IATPatchFunction::Unpatch() {
229 DWORD error = RestoreImportedFunction(intercept_function_, original_function_,
230 iat_thunk_);
231 DCHECK_EQ(static_cast<DWORD>(NO_ERROR), error);
232
233 // Hands off the intercept if we fail to unpatch.
234 // If IATPatchFunction::Unpatch fails during RestoreImportedFunction
235 // it means that we cannot safely unpatch the import address table
236 // patch. In this case its better to be hands off the intercept as
237 // trying to unpatch again in the destructor of IATPatchFunction is
238 // not going to be any safer
239 if (module_handle_)
240 FreeLibrary(module_handle_);
241 module_handle_ = nullptr;
242 intercept_function_ = nullptr;
243 original_function_ = nullptr;
244 iat_thunk_ = nullptr;
245
246 return error;
247 }
248
original_function() const249 void* IATPatchFunction::original_function() const {
250 DCHECK(is_patched());
251 return original_function_;
252 }
253
254 } // namespace win
255 } // namespace base
256