xref: /aosp_15_r20/external/cronet/base/win/com_init_util.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2017 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "base/win/com_init_util.h"
6 
7 #include <windows.h>
8 
9 #include <stdint.h>
10 #include <winternl.h>
11 
12 #include "base/logging.h"
13 #include "base/notreached.h"
14 
15 namespace base {
16 namespace win {
17 
18 namespace {
19 
20 #if DCHECK_IS_ON()
// Default failure text used by AssertComInitialized() when its caller does not
// supply a message of its own.
constexpr char kComNotInitialized[] = "COM is not initialized on this thread.";
22 #endif  // DCHECK_IS_ON()
23 
24 // Derived from combase.dll.
25 struct OleTlsData {
26   enum ApartmentFlags {
27     LOGICAL_THREAD_REGISTERED = 0x2,
28     STA = 0x80,
29     MTA = 0x140,
30   };
31 
32   uintptr_t thread_base;
33   uintptr_t sm_allocator;
34   DWORD apartment_id;
35   DWORD apartment_flags;
36   // There are many more fields than this, but for our purposes, we only care
37   // about |apartment_flags|. Correctly declaring the previous types allows this
38   // to work between x86 and x64 builds.
39 };
40 
GetOleTlsData()41 OleTlsData* GetOleTlsData() {
42   TEB* teb = NtCurrentTeb();
43   return reinterpret_cast<OleTlsData*>(teb->ReservedForOle);
44 }
45 
46 }  // namespace
47 
GetComApartmentTypeForThread()48 ComApartmentType GetComApartmentTypeForThread() {
49   OleTlsData* ole_tls_data = GetOleTlsData();
50   if (!ole_tls_data)
51     return ComApartmentType::NONE;
52 
53   if (ole_tls_data->apartment_flags & OleTlsData::ApartmentFlags::STA)
54     return ComApartmentType::STA;
55 
56   if ((ole_tls_data->apartment_flags & OleTlsData::ApartmentFlags::MTA) ==
57       OleTlsData::ApartmentFlags::MTA) {
58     return ComApartmentType::MTA;
59   }
60 
61   return ComApartmentType::NONE;
62 }
63 
64 #if DCHECK_IS_ON()
65 
AssertComInitialized(const char * message)66 void AssertComInitialized(const char* message) {
67   if (GetComApartmentTypeForThread() != ComApartmentType::NONE)
68     return;
69 
70   // COM worker threads don't always set up the apartment, but they do perform
71   // some thread registration, so we allow those.
72   OleTlsData* ole_tls_data = GetOleTlsData();
73   if (ole_tls_data && (ole_tls_data->apartment_flags &
74                        OleTlsData::ApartmentFlags::LOGICAL_THREAD_REGISTERED)) {
75     return;
76   }
77 
78   NOTREACHED() << (message ? message : kComNotInitialized);
79 }
80 
// DCHECK-only assertion that the calling thread's COM apartment, as reported
// by GetComApartmentTypeForThread(), equals |apartment_type|.
void AssertComApartmentType(ComApartmentType apartment_type) {
  DCHECK_EQ(apartment_type, GetComApartmentTypeForThread());
}
84 
85 #endif  // DCHECK_IS_ON()
86 
87 }  // namespace win
88 }  // namespace base
89