xref: /aosp_15_r20/external/angle/src/common/system_utils_win32.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // system_utils_win32.cpp: Implementation of OS-specific functions for Windows.
7 
8 #include "common/FastVector.h"
9 #include "system_utils.h"
10 
11 #include <array>
12 
13 // Must be included in this order.
14 // clang-format off
15 #include <windows.h>
16 #include <psapi.h>
17 // clang-format on
18 
19 namespace angle
20 {
UnsetEnvironmentVar(const char * variableName)21 bool UnsetEnvironmentVar(const char *variableName)
22 {
23     return (SetEnvironmentVariableW(Widen(variableName).c_str(), nullptr) == TRUE);
24 }
25 
SetEnvironmentVar(const char * variableName,const char * value)26 bool SetEnvironmentVar(const char *variableName, const char *value)
27 {
28     return (SetEnvironmentVariableW(Widen(variableName).c_str(), Widen(value).c_str()) == TRUE);
29 }
30 
GetEnvironmentVar(const char * variableName)31 std::string GetEnvironmentVar(const char *variableName)
32 {
33     std::wstring variableNameUtf16 = Widen(variableName);
34     FastVector<wchar_t, MAX_PATH> value;
35 
36     DWORD result;
37 
38     // First get the length of the variable, including the null terminator
39     result = GetEnvironmentVariableW(variableNameUtf16.c_str(), nullptr, 0);
40 
41     // Zero means the variable was not found, so return now.
42     if (result == 0)
43     {
44         return std::string();
45     }
46 
47     // Now size the vector to fit the data, and read the environment variable.
48     value.resize(result, 0);
49     result = GetEnvironmentVariableW(variableNameUtf16.c_str(), value.data(), result);
50 
51     return Narrow(value.data());
52 }
53 
OpenSystemLibraryWithExtensionAndGetError(const char * libraryName,SearchType searchType,std::string * errorOut)54 void *OpenSystemLibraryWithExtensionAndGetError(const char *libraryName,
55                                                 SearchType searchType,
56                                                 std::string *errorOut)
57 {
58     char buffer[MAX_PATH];
59     int ret = snprintf(buffer, MAX_PATH, "%s.%s", libraryName, GetSharedLibraryExtension());
60     if (ret <= 0 || ret >= MAX_PATH)
61     {
62         fprintf(stderr, "Error generating library path: 0x%x", ret);
63         return nullptr;
64     }
65 
66     HMODULE libraryModule = nullptr;
67 
68     switch (searchType)
69     {
70         case SearchType::ModuleDir:
71         {
72             std::string moduleRelativePath = ConcatenatePath(GetModuleDirectory(), libraryName);
73             libraryModule                  = LoadLibraryW(Widen(moduleRelativePath).c_str());
74             if (libraryModule == nullptr && errorOut)
75             {
76                 *errorOut = std::string("failed to load library (SearchType::ModuleDir) ") +
77                             moduleRelativePath;
78             }
79             break;
80         }
81 
82         case SearchType::SystemDir:
83         {
84             libraryModule =
85                 LoadLibraryExW(Widen(libraryName).c_str(), nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32);
86             if (libraryModule == nullptr && errorOut)
87             {
88                 *errorOut =
89                     std::string("failed to load library (SearchType::SystemDir) ") + libraryName;
90             }
91             break;
92         }
93 
94         case SearchType::AlreadyLoaded:
95         {
96             libraryModule = GetModuleHandleW(Widen(libraryName).c_str());
97             if (libraryModule == nullptr && errorOut)
98             {
99                 *errorOut = std::string("failed to load library (SearchType::AlreadyLoaded) ") +
100                             libraryName;
101             }
102             break;
103         }
104     }
105 
106     return reinterpret_cast<void *>(libraryModule);
107 }
108 
109 namespace
110 {
111 class Win32PageFaultHandler : public PageFaultHandler
112 {
113   public:
Win32PageFaultHandler(PageFaultCallback callback)114     Win32PageFaultHandler(PageFaultCallback callback) : PageFaultHandler(callback) {}
~Win32PageFaultHandler()115     ~Win32PageFaultHandler() override {}
116 
117     bool enable() override;
118     bool disable() override;
119 
120     LONG handle(PEXCEPTION_POINTERS pExceptionInfo);
121 
122   private:
123     void *mVectoredExceptionHandler = nullptr;
124 };
125 
126 Win32PageFaultHandler *gWin32PageFaultHandler = nullptr;
VectoredExceptionHandler(PEXCEPTION_POINTERS info)127 static LONG CALLBACK VectoredExceptionHandler(PEXCEPTION_POINTERS info)
128 {
129     return gWin32PageFaultHandler->handle(info);
130 }
131 
SetMemoryProtection(uintptr_t start,size_t size,DWORD protections)132 bool SetMemoryProtection(uintptr_t start, size_t size, DWORD protections)
133 {
134     DWORD oldProtect;
135     BOOL res = VirtualProtect(reinterpret_cast<LPVOID>(start), size, protections, &oldProtect);
136     if (!res)
137     {
138         DWORD lastError = GetLastError();
139         fprintf(stderr, "VirtualProtect failed: 0x%lx\n", lastError);
140         return false;
141     }
142 
143     return true;
144 }
145 
handle(PEXCEPTION_POINTERS info)146 LONG Win32PageFaultHandler::handle(PEXCEPTION_POINTERS info)
147 {
148     bool found = false;
149 
150     if (info->ExceptionRecord->ExceptionCode == EXCEPTION_ACCESS_VIOLATION &&
151         info->ExceptionRecord->NumberParameters >= 2 &&
152         info->ExceptionRecord->ExceptionInformation[0] == 1)
153     {
154         found = mCallback(static_cast<uintptr_t>(info->ExceptionRecord->ExceptionInformation[1])) ==
155                 PageFaultHandlerRangeType::InRange;
156     }
157 
158     if (found)
159     {
160         return EXCEPTION_CONTINUE_EXECUTION;
161     }
162     else
163     {
164         return EXCEPTION_CONTINUE_SEARCH;
165     }
166 }
167 
disable()168 bool Win32PageFaultHandler::disable()
169 {
170     if (mVectoredExceptionHandler)
171     {
172         ULONG res                 = RemoveVectoredExceptionHandler(mVectoredExceptionHandler);
173         mVectoredExceptionHandler = nullptr;
174         if (res == 0)
175         {
176             DWORD lastError = GetLastError();
177             fprintf(stderr, "RemoveVectoredExceptionHandler failed: 0x%lx\n", lastError);
178             return false;
179         }
180     }
181     return true;
182 }
183 
enable()184 bool Win32PageFaultHandler::enable()
185 {
186     if (mVectoredExceptionHandler)
187     {
188         return true;
189     }
190 
191     PVECTORED_EXCEPTION_HANDLER handler =
192         reinterpret_cast<PVECTORED_EXCEPTION_HANDLER>(&VectoredExceptionHandler);
193 
194     mVectoredExceptionHandler = AddVectoredExceptionHandler(1, handler);
195 
196     if (!mVectoredExceptionHandler)
197     {
198         DWORD lastError = GetLastError();
199         fprintf(stderr, "AddVectoredExceptionHandler failed: 0x%lx\n", lastError);
200         return false;
201     }
202     return true;
203 }
204 }  // namespace
205 
206 // Set write protection
ProtectMemory(uintptr_t start,size_t size)207 bool ProtectMemory(uintptr_t start, size_t size)
208 {
209     return SetMemoryProtection(start, size, PAGE_READONLY);
210 }
211 
212 // Allow reading and writing
UnprotectMemory(uintptr_t start,size_t size)213 bool UnprotectMemory(uintptr_t start, size_t size)
214 {
215     return SetMemoryProtection(start, size, PAGE_READWRITE);
216 }
217 
GetPageSize()218 size_t GetPageSize()
219 {
220     SYSTEM_INFO info;
221     GetSystemInfo(&info);
222     return static_cast<size_t>(info.dwPageSize);
223 }
224 
CreatePageFaultHandler(PageFaultCallback callback)225 PageFaultHandler *CreatePageFaultHandler(PageFaultCallback callback)
226 {
227     gWin32PageFaultHandler = new Win32PageFaultHandler(callback);
228     return gWin32PageFaultHandler;
229 }
230 
GetProcessMemoryUsageKB()231 uint64_t GetProcessMemoryUsageKB()
232 {
233     PROCESS_MEMORY_COUNTERS_EX pmc;
234     ::GetProcessMemoryInfo(::GetCurrentProcess(), reinterpret_cast<PROCESS_MEMORY_COUNTERS *>(&pmc),
235                            sizeof(pmc));
236     return static_cast<uint64_t>(pmc.PrivateUsage) / 1024ull;
237 }
238 }  // namespace angle
239