xref: /aosp_15_r20/external/cronet/base/win/com_init_check_hook.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2017 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "base/win/com_init_check_hook.h"
6 
7 #include <objbase.h>
8 
9 #include <windows.h>
10 
11 #include <stdint.h>
12 #include <string.h>
13 
14 #include <ostream>
15 #include <string>
16 
17 #include "base/notreached.h"
18 #include "base/strings/stringprintf.h"
19 #include "base/synchronization/lock.h"
20 #include "base/win/com_init_util.h"
21 #include "base/win/patch_util.h"
22 
23 namespace base {
24 namespace win {
25 
26 #if defined(COM_INIT_CHECK_HOOK_ENABLED)
27 
28 namespace {
29 
30 // Hotpatchable Microsoft x86 32-bit functions take one of two forms:
31 // Newer format:
32 // RelAddr  Binary     Instruction                 Remarks
33 //      -5  cc         int 3
34 //      -4  cc         int 3
35 //      -3  cc         int 3
36 //      -2  cc         int 3
37 //      -1  cc         int 3
38 //       0  8bff       mov edi,edi                 Actual entry point and no-op.
39 //       2  ...                                    Actual body.
40 //
41 // Older format:
42 // RelAddr  Binary     Instruction                 Remarks
43 //      -5  90         nop
44 //      -4  90         nop
45 //      -3  90         nop
46 //      -2  90         nop
47 //      -1  90         nop
48 //       0  8bff       mov edi,edi                 Actual entry point and no-op.
49 //       2  ...                                    Actual body.
50 //
// The "int 3" or nop sled and the entry-point no-op are critical: together
// they provide exactly enough room to patch in a 2-byte short jump backwards
// to -5, where a 5-byte relative 32-bit jump can reach about 2GB before or
// after the current address.
54 //
// To perform a hotpatch, we need to know both where we want to go and where
// we are now, because the final jump is relative. Say we want to jump to
// 0x12345678. Relative jumps are calculated from eip, which for our jump holds
// the address of the next instruction. The 5-byte jump at RelAddr -5 ends at
// RelAddr 0, so its displacement is measured from a base address of 0.
60 //
61 // Our patch will then look as follows:
62 // RelAddr  Binary     Instruction                 Remarks
63 //      -5  e978563412 jmp 0x12345678-(-0x5+0x5)   Note little-endian format.
64 //       0  ebf9       jmp -0x5-(0x0+0x2)          Goes to RelAddr -0x5.
65 //       2  ...                                    Actual body.
66 // Note: The jmp instructions above are structured as
67 //       Address(Destination)-(Address(jmp Instruction)+sizeof(jmp Instruction))
68 
// Byte-for-byte image of the 7-byte hotpatch described in the banner comment:
// a 5-byte "jmp rel32" written into the padding in front of the function,
// followed by a 2-byte "jmp rel8" written over the entry point's
// "mov edi,edi". Packing to 1-byte alignment removes the padding the compiler
// would otherwise insert before |relative_address|. Accessing that unaligned
// field costs a little, but named fields keep the patch layout readable.
#pragma pack(push, 1)
struct StructuredHotpatch {
  // Opcode 0xe9: jmp with a signed 32-bit displacement.
  unsigned char jmp_32_relative{0xe9};
  // Displacement of the 32-bit jump, measured from the end of that jump.
  int32_t relative_address{0};
  // Opcode 0xeb: jmp with a signed 8-bit displacement.
  unsigned char jmp_8_relative{0xeb};
  // 0xf9 == -7: from the end of this 2-byte jump back to the 32-bit jump.
  unsigned char back_address{0xf9};
};
#pragma pack(pop)

static_assert(sizeof(StructuredHotpatch) == 7,
              "Needs to be exactly 7 bytes for the hotpatch to work.");
84 
// Byte patterns compared against the 7 bytes that start 5 bytes before
// CoCreateInstance's entry point. Each is a 5-byte sled followed by the 2-byte
// "mov edi,edi" entry-point no-op (0x8b 0xff).

// Older format: sled of "nop" (0x90).
constexpr unsigned char g_hotpatch_placeholder_nop[] = {
    0x90, 0x90, 0x90, 0x90, 0x90,  // nop x5
    0x8b, 0xff,                    // mov edi,edi
};

// Newer format: sled of "int 3" (0xcc).
constexpr unsigned char g_hotpatch_placeholder_int3[] = {
    0xcc, 0xcc, 0xcc, 0xcc, 0xcc,  // int 3 x5
    0x8b, 0xff,                    // mov edi,edi
};

// http://crbug.com/1312659: apphelp shim whose sled is one "int 3" short
// (first byte is 0x00), leaving too little room for the 5-byte jump.
constexpr unsigned char g_hotpatch_placeholder_apphelp[] = {
    0x00, 0xcc, 0xcc, 0xcc, 0xcc,  // 0x00 + int 3 x4
    0x8b, 0xff,                    // mov edi,edi
};
96 
97 class HookManager {
98  public:
GetInstance()99   static HookManager* GetInstance() {
100     static auto* hook_manager = new HookManager();
101     return hook_manager;
102   }
103 
104   HookManager(const HookManager&) = delete;
105   HookManager& operator=(const HookManager&) = delete;
106 
RegisterHook()107   void RegisterHook() {
108     AutoLock auto_lock(lock_);
109     ++init_count_;
110     if (disabled_)
111       return;
112     if (init_count_ == 1)
113       WriteHook();
114   }
115 
UnregisterHook()116   void UnregisterHook() {
117     AutoLock auto_lock(lock_);
118     DCHECK_NE(0U, init_count_);
119     --init_count_;
120     if (disabled_)
121       return;
122     if (init_count_ == 0)
123       RevertHook();
124   }
125 
DisableCOMChecksForProcess()126   void DisableCOMChecksForProcess() {
127     AutoLock auto_lock(lock_);
128     if (disabled_)
129       return;
130     disabled_ = true;
131     if (init_count_ > 0)
132       RevertHook();
133   }
134 
135  private:
136   enum class HotpatchPlaceholderFormat {
137     // The hotpatch placeholder is currently unknown
138     UNKNOWN,
139     // The hotpatch placeholder used int 3's in the sled.
140     INT3,
141     // The hotpatch placeholder used nop's in the sled.
142     NOP,
143     // The hotpatch placeholder is an unusable apphelp shim.
144     APPHELP_SHIM,
145     // This function has already been patched by a different component.
146     EXTERNALLY_PATCHED,
147   };
148 
149   HookManager() = default;
150   ~HookManager() = default;
151 
WriteHook()152   void WriteHook() {
153     lock_.AssertAcquired();
154     DCHECK(!ole32_library_);
155     ole32_library_ = ::LoadLibrary(L"ole32.dll");
156 
157     if (!ole32_library_)
158       return;
159 
160     // See banner comment above why this subtracts 5 bytes.
161     co_create_instance_padded_address_ =
162         reinterpret_cast<uint32_t>(
163             GetProcAddress(ole32_library_, "CoCreateInstance")) -
164         5;
165 
166     // See banner comment above why this adds 7 bytes.
167     original_co_create_instance_body_function_ =
168         reinterpret_cast<decltype(original_co_create_instance_body_function_)>(
169             co_create_instance_padded_address_ + 7);
170 
171     uint32_t dchecked_co_create_instance_address =
172         reinterpret_cast<uint32_t>(&HookManager::DCheckedCoCreateInstance);
173     uint32_t jmp_offset_base_address = co_create_instance_padded_address_ + 5;
174     structured_hotpatch_.relative_address = static_cast<int32_t>(
175         dchecked_co_create_instance_address - jmp_offset_base_address);
176 
177     HotpatchPlaceholderFormat format = GetHotpatchPlaceholderFormat(
178         reinterpret_cast<const void*>(co_create_instance_padded_address_));
179     if (format == HotpatchPlaceholderFormat::UNKNOWN) {
180       NOTREACHED() << "Unrecognized hotpatch function format: "
181                    << FirstSevenBytesToString(
182                           co_create_instance_padded_address_);
183       return;
184     } else if (format == HotpatchPlaceholderFormat::EXTERNALLY_PATCHED) {
185       hotpatch_placeholder_format_ = format;
186       NOTREACHED() << "CoCreateInstance appears to be previously patched. <"
187                    << FirstSevenBytesToString(
188                           co_create_instance_padded_address_)
189                    << "> Attempted to write <"
190                    << FirstSevenBytesToString(
191                           reinterpret_cast<uint32_t>(&structured_hotpatch_))
192                    << ">";
193       return;
194     } else if (format == HotpatchPlaceholderFormat::APPHELP_SHIM) {
195       // The apphelp shim placeholder does not allocate enough bytes for a
196       // trampolined jump. In this case, we skip patching.
197       hotpatch_placeholder_format_ = format;
198       return;
199     }
200 
201     DCHECK_EQ(hotpatch_placeholder_format_, HotpatchPlaceholderFormat::UNKNOWN);
202     DWORD patch_result = internal::ModifyCode(
203         reinterpret_cast<void*>(co_create_instance_padded_address_),
204         reinterpret_cast<void*>(&structured_hotpatch_),
205         sizeof(structured_hotpatch_));
206     if (patch_result == NO_ERROR)
207       hotpatch_placeholder_format_ = format;
208   }
209 
RevertHook()210   void RevertHook() {
211     lock_.AssertAcquired();
212 
213     DWORD revert_result = NO_ERROR;
214     switch (hotpatch_placeholder_format_) {
215       case HotpatchPlaceholderFormat::INT3:
216         if (WasHotpatchChanged())
217           return;
218         revert_result = internal::ModifyCode(
219             reinterpret_cast<void*>(co_create_instance_padded_address_),
220             reinterpret_cast<const void*>(&g_hotpatch_placeholder_int3),
221             sizeof(g_hotpatch_placeholder_int3));
222         break;
223       case HotpatchPlaceholderFormat::NOP:
224         if (WasHotpatchChanged())
225           return;
226         revert_result = internal::ModifyCode(
227             reinterpret_cast<void*>(co_create_instance_padded_address_),
228             reinterpret_cast<const void*>(&g_hotpatch_placeholder_nop),
229             sizeof(g_hotpatch_placeholder_nop));
230         break;
231       case HotpatchPlaceholderFormat::EXTERNALLY_PATCHED:
232       case HotpatchPlaceholderFormat::APPHELP_SHIM:
233       case HotpatchPlaceholderFormat::UNKNOWN:
234         break;
235     }
236     DCHECK_EQ(revert_result, static_cast<DWORD>(NO_ERROR))
237         << "Failed to revert CoCreateInstance hot-patch";
238 
239     hotpatch_placeholder_format_ = HotpatchPlaceholderFormat::UNKNOWN;
240 
241     if (ole32_library_) {
242       ::FreeLibrary(ole32_library_);
243       ole32_library_ = nullptr;
244     }
245 
246     co_create_instance_padded_address_ = 0;
247     original_co_create_instance_body_function_ = nullptr;
248   }
249 
GetHotpatchPlaceholderFormat(const void * address)250   HotpatchPlaceholderFormat GetHotpatchPlaceholderFormat(const void* address) {
251     if (::memcmp(reinterpret_cast<void*>(co_create_instance_padded_address_),
252                  reinterpret_cast<const void*>(&g_hotpatch_placeholder_int3),
253                  sizeof(g_hotpatch_placeholder_int3)) == 0) {
254       return HotpatchPlaceholderFormat::INT3;
255     }
256 
257     if (::memcmp(reinterpret_cast<void*>(co_create_instance_padded_address_),
258                  reinterpret_cast<const void*>(&g_hotpatch_placeholder_nop),
259                  sizeof(g_hotpatch_placeholder_nop)) == 0) {
260       return HotpatchPlaceholderFormat::NOP;
261     }
262 
263     if (::memcmp(reinterpret_cast<void*>(co_create_instance_padded_address_),
264                  reinterpret_cast<const void*>(&g_hotpatch_placeholder_apphelp),
265                  sizeof(g_hotpatch_placeholder_apphelp)) == 0) {
266       return HotpatchPlaceholderFormat::APPHELP_SHIM;
267     }
268 
269     const unsigned char* instruction_bytes =
270         reinterpret_cast<const unsigned char*>(
271             co_create_instance_padded_address_);
272     const unsigned char entry_point_byte = instruction_bytes[5];
273     // Check for all of the common jmp opcodes.
274     if (entry_point_byte == 0xeb || entry_point_byte == 0xe9 ||
275         entry_point_byte == 0xff || entry_point_byte == 0xea) {
276       return HotpatchPlaceholderFormat::EXTERNALLY_PATCHED;
277     }
278 
279     return HotpatchPlaceholderFormat::UNKNOWN;
280   }
281 
WasHotpatchChanged()282   bool WasHotpatchChanged() {
283     if (::memcmp(reinterpret_cast<void*>(co_create_instance_padded_address_),
284                  reinterpret_cast<const void*>(&structured_hotpatch_),
285                  sizeof(structured_hotpatch_)) == 0) {
286       return false;
287     }
288 
289     NOTREACHED() << "CoCreateInstance patch overwritten. Expected: <"
290                  << FirstSevenBytesToString(co_create_instance_padded_address_)
291                  << ">, Actual: <"
292                  << FirstSevenBytesToString(
293                         reinterpret_cast<uint32_t>(&structured_hotpatch_))
294                  << ">";
295     return true;
296   }
297 
298   // Indirect call to original_co_create_instance_body_function_ triggers CFI
299   // so this function must have CFI disabled.
DCheckedCoCreateInstance(const CLSID & rclsid,IUnknown * pUnkOuter,DWORD dwClsContext,REFIID riid,void ** ppv)300   static DISABLE_CFI_ICALL HRESULT __stdcall DCheckedCoCreateInstance(
301       const CLSID& rclsid,
302       IUnknown* pUnkOuter,
303       DWORD dwClsContext,
304       REFIID riid,
305       void** ppv) {
306     // Chromium COM callers need to make sure that their thread is configured to
307     // process COM objects to avoid creating an implicit MTA or silently failing
308     // STA object creation call due to the SUCCEEDED() pattern for COM calls.
309     //
310     // If you hit this assert as part of migrating to the Task Scheduler,
311     // evaluate your threading guarantees and dispatch your work with
312     // base::ThreadPool::CreateCOMSTATaskRunner().
313     //
314     // If you need MTA support, ping //base/task/thread_pool/OWNERS.
315     AssertComInitialized(
316         "CoCreateInstance calls in Chromium require explicit COM "
317         "initialization via base::ThreadPool::CreateCOMSTATaskRunner() or "
318         "ScopedCOMInitializer. See the comment in DCheckedCoCreateInstance for "
319         "more details.");
320     return original_co_create_instance_body_function_(rclsid, pUnkOuter,
321                                                       dwClsContext, riid, ppv);
322   }
323 
324   // Returns the first 7 bytes in hex as a string at |address|.
FirstSevenBytesToString(uint32_t address)325   static std::string FirstSevenBytesToString(uint32_t address) {
326     const unsigned char* bytes =
327         reinterpret_cast<const unsigned char*>(address);
328     return base::StringPrintf("%02x %02x %02x %02x %02x %02x %02x", bytes[0],
329                               bytes[1], bytes[2], bytes[3], bytes[4], bytes[5],
330                               bytes[6]);
331   }
332 
333   // Synchronizes everything in this class.
334   base::Lock lock_;
335   size_t init_count_ = 0;
336   bool disabled_ = false;
337   HMODULE ole32_library_ = nullptr;
338   uint32_t co_create_instance_padded_address_ = 0;
339   HotpatchPlaceholderFormat hotpatch_placeholder_format_ =
340       HotpatchPlaceholderFormat::UNKNOWN;
341   StructuredHotpatch structured_hotpatch_;
342   static decltype(
343       ::CoCreateInstance)* original_co_create_instance_body_function_;
344 };
345 
346 decltype(::CoCreateInstance)*
347     HookManager::original_co_create_instance_body_function_ = nullptr;
348 
349 }  // namespace
350 
351 #endif  // defined(COM_INIT_CHECK_HOOK_ENABLED)
352 
ComInitCheckHook()353 ComInitCheckHook::ComInitCheckHook() {
354 #if defined(COM_INIT_CHECK_HOOK_ENABLED)
355   HookManager::GetInstance()->RegisterHook();
356 #endif  // defined(COM_INIT_CHECK_HOOK_ENABLED)
357 }
358 
~ComInitCheckHook()359 ComInitCheckHook::~ComInitCheckHook() {
360 #if defined(COM_INIT_CHECK_HOOK_ENABLED)
361   HookManager::GetInstance()->UnregisterHook();
362 #endif  // defined(COM_INIT_CHECK_HOOK_ENABLED)
363 }
364 
DisableCOMChecksForProcess()365 void ComInitCheckHook::DisableCOMChecksForProcess() {
366 #if defined(COM_INIT_CHECK_HOOK_ENABLED)
367   HookManager::GetInstance()->DisableCOMChecksForProcess();
368 #endif
369 }
370 
371 }  // namespace win
372 }  // namespace base
373