1 // Copyright 2017 The Chromium Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "base/win/com_init_util.h" 6 7 #include <windows.h> 8 9 #include <stdint.h> 10 #include <winternl.h> 11 12 #include "base/logging.h" 13 #include "base/notreached.h" 14 15 namespace base { 16 namespace win { 17 18 namespace { 19 20 #if DCHECK_IS_ON() 21 const char kComNotInitialized[] = "COM is not initialized on this thread."; 22 #endif // DCHECK_IS_ON() 23 24 // Derived from combase.dll. 25 struct OleTlsData { 26 enum ApartmentFlags { 27 LOGICAL_THREAD_REGISTERED = 0x2, 28 STA = 0x80, 29 MTA = 0x140, 30 }; 31 32 uintptr_t thread_base; 33 uintptr_t sm_allocator; 34 DWORD apartment_id; 35 DWORD apartment_flags; 36 // There are many more fields than this, but for our purposes, we only care 37 // about |apartment_flags|. Correctly declaring the previous types allows this 38 // to work between x86 and x64 builds. 39 }; 40 GetOleTlsData()41OleTlsData* GetOleTlsData() { 42 TEB* teb = NtCurrentTeb(); 43 return reinterpret_cast<OleTlsData*>(teb->ReservedForOle); 44 } 45 46 } // namespace 47 GetComApartmentTypeForThread()48ComApartmentType GetComApartmentTypeForThread() { 49 OleTlsData* ole_tls_data = GetOleTlsData(); 50 if (!ole_tls_data) 51 return ComApartmentType::NONE; 52 53 if (ole_tls_data->apartment_flags & OleTlsData::ApartmentFlags::STA) 54 return ComApartmentType::STA; 55 56 if ((ole_tls_data->apartment_flags & OleTlsData::ApartmentFlags::MTA) == 57 OleTlsData::ApartmentFlags::MTA) { 58 return ComApartmentType::MTA; 59 } 60 61 return ComApartmentType::NONE; 62 } 63 64 #if DCHECK_IS_ON() 65 AssertComInitialized(const char * message)66void AssertComInitialized(const char* message) { 67 if (GetComApartmentTypeForThread() != ComApartmentType::NONE) 68 return; 69 70 // COM worker threads don't always set up the apartment, but they do perform 71 // some thread registration, so we allow those. 72 OleTlsData* ole_tls_data = GetOleTlsData(); 73 if (ole_tls_data && (ole_tls_data->apartment_flags & 74 OleTlsData::ApartmentFlags::LOGICAL_THREAD_REGISTERED)) { 75 return; 76 } 77 78 NOTREACHED() << (message ? message : kComNotInitialized); 79 } 80 AssertComApartmentType(ComApartmentType apartment_type)81void AssertComApartmentType(ComApartmentType apartment_type) { 82 DCHECK_EQ(apartment_type, GetComApartmentTypeForThread()); 83 } 84 85 #endif // DCHECK_IS_ON() 86 87 } // namespace win 88 } // namespace base 89