1 // Copyright 2022 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "base/win/access_control_list.h"
6
7 #include <windows.h>
8
9 #include <aclapi.h>
10
11 #include <utility>
12 #include <vector>
13
14 #include "base/check.h"
15 #include "base/logging.h"
16 #include "base/notreached.h"
17 #include "base/numerics/checked_math.h"
18 #include "base/win/scoped_localalloc.h"
19
20 namespace base::win {
21
22 namespace {
23
AclToBuffer(const ACL * acl)24 std::unique_ptr<uint8_t[]> AclToBuffer(const ACL* acl) {
25 if (!acl) {
26 return nullptr;
27 }
28 size_t size = acl->AclSize;
29 DCHECK(size >= sizeof(*acl));
30 std::unique_ptr<uint8_t[]> ptr = std::make_unique<uint8_t[]>(size);
31 memcpy(ptr.get(), acl, size);
32 return ptr;
33 }
34
EmptyAclToBuffer()35 std::unique_ptr<uint8_t[]> EmptyAclToBuffer() {
36 ACL acl = {};
37 acl.AclRevision = ACL_REVISION;
38 acl.AclSize = static_cast<WORD>(sizeof(acl));
39 return AclToBuffer(&acl);
40 }
41
ConvertAccessMode(SecurityAccessMode access_mode)42 ACCESS_MODE ConvertAccessMode(SecurityAccessMode access_mode) {
43 switch (access_mode) {
44 case SecurityAccessMode::kGrant:
45 return GRANT_ACCESS;
46 case SecurityAccessMode::kSet:
47 return SET_ACCESS;
48 case SecurityAccessMode::kDeny:
49 return DENY_ACCESS;
50 case SecurityAccessMode::kRevoke:
51 return REVOKE_ACCESS;
52 }
53 }
54
AddACEToAcl(ACL * old_acl,const std::vector<ExplicitAccessEntry> & entries)55 std::unique_ptr<uint8_t[]> AddACEToAcl(
56 ACL* old_acl,
57 const std::vector<ExplicitAccessEntry>& entries) {
58 std::vector<EXPLICIT_ACCESS> access_entries(entries.size());
59 auto entries_interator = access_entries.begin();
60 for (const ExplicitAccessEntry& entry : entries) {
61 EXPLICIT_ACCESS& new_access = *entries_interator++;
62 new_access.grfAccessMode = ConvertAccessMode(entry.mode());
63 new_access.grfAccessPermissions = entry.access_mask();
64 new_access.grfInheritance = entry.inheritance();
65 ::BuildTrusteeWithSid(&new_access.Trustee, entry.sid().GetPSID());
66 }
67
68 PACL new_acl = nullptr;
69 DWORD error = ::SetEntriesInAcl(checked_cast<ULONG>(access_entries.size()),
70 access_entries.data(), old_acl, &new_acl);
71 if (error != ERROR_SUCCESS) {
72 ::SetLastError(error);
73 DPLOG(ERROR) << "Failed adding ACEs to ACL";
74 return nullptr;
75 }
76 auto new_acl_ptr = TakeLocalAlloc(new_acl);
77 return AclToBuffer(new_acl_ptr.get());
78 }
79
80 } // namespace
81
Clone() const82 ExplicitAccessEntry ExplicitAccessEntry::Clone() const {
83 return ExplicitAccessEntry{sid_, mode_, access_mask_, inheritance_};
84 }
85
ExplicitAccessEntry(const Sid & sid,SecurityAccessMode mode,DWORD access_mask,DWORD inheritance)86 ExplicitAccessEntry::ExplicitAccessEntry(const Sid& sid,
87 SecurityAccessMode mode,
88 DWORD access_mask,
89 DWORD inheritance)
90 : sid_(sid.Clone()),
91 mode_(mode),
92 access_mask_(access_mask),
93 inheritance_(inheritance) {}
94
ExplicitAccessEntry(WellKnownSid known_sid,SecurityAccessMode mode,DWORD access_mask,DWORD inheritance)95 ExplicitAccessEntry::ExplicitAccessEntry(WellKnownSid known_sid,
96 SecurityAccessMode mode,
97 DWORD access_mask,
98 DWORD inheritance)
99 : ExplicitAccessEntry(Sid(known_sid), mode, access_mask, inheritance) {}
100
101 ExplicitAccessEntry::ExplicitAccessEntry(ExplicitAccessEntry&&) = default;
102 ExplicitAccessEntry& ExplicitAccessEntry::operator=(ExplicitAccessEntry&&) =
103 default;
104 ExplicitAccessEntry::~ExplicitAccessEntry() = default;
105
FromPACL(ACL * acl)106 std::optional<AccessControlList> AccessControlList::FromPACL(ACL* acl) {
107 if (acl && !::IsValidAcl(acl)) {
108 ::SetLastError(ERROR_INVALID_ACL);
109 return std::nullopt;
110 }
111 return AccessControlList{acl};
112 }
113
FromMandatoryLabel(DWORD integrity_level,DWORD inheritance,DWORD mandatory_policy)114 std::optional<AccessControlList> AccessControlList::FromMandatoryLabel(
115 DWORD integrity_level,
116 DWORD inheritance,
117 DWORD mandatory_policy) {
118 Sid sid = Sid::FromIntegrityLevel(integrity_level);
119 // Get total ACL length. SYSTEM_MANDATORY_LABEL_ACE contains the first DWORD
120 // of the SID so remove it from total.
121 DWORD length = sizeof(ACL) + sizeof(SYSTEM_MANDATORY_LABEL_ACE) +
122 ::GetLengthSid(sid.GetPSID()) - sizeof(DWORD);
123 std::unique_ptr<uint8_t[]> sacl_ptr = std::make_unique<uint8_t[]>(length);
124 PACL sacl = reinterpret_cast<PACL>(sacl_ptr.get());
125
126 if (!::InitializeAcl(sacl, length, ACL_REVISION)) {
127 return std::nullopt;
128 }
129
130 if (!::AddMandatoryAce(sacl, ACL_REVISION, inheritance, mandatory_policy,
131 sid.GetPSID())) {
132 return std::nullopt;
133 }
134
135 DCHECK(::IsValidAcl(sacl));
136 AccessControlList ret;
137 ret.acl_ = std::move(sacl_ptr);
138 return ret;
139 }
140
AccessControlList()141 AccessControlList::AccessControlList() : acl_(EmptyAclToBuffer()) {}
142 AccessControlList::AccessControlList(AccessControlList&&) = default;
143 AccessControlList& AccessControlList::operator=(AccessControlList&&) = default;
144 AccessControlList::~AccessControlList() = default;
145
SetEntries(const std::vector<ExplicitAccessEntry> & entries)146 bool AccessControlList::SetEntries(
147 const std::vector<ExplicitAccessEntry>& entries) {
148 if (entries.empty())
149 return true;
150
151 std::unique_ptr<uint8_t[]> acl = AddACEToAcl(get(), entries);
152 if (!acl)
153 return false;
154
155 acl_ = std::move(acl);
156 return true;
157 }
158
SetEntry(const Sid & sid,SecurityAccessMode mode,DWORD access_mask,DWORD inheritance)159 bool AccessControlList::SetEntry(const Sid& sid,
160 SecurityAccessMode mode,
161 DWORD access_mask,
162 DWORD inheritance) {
163 std::vector<ExplicitAccessEntry> ace_list;
164 ace_list.emplace_back(sid, mode, access_mask, inheritance);
165 return SetEntries(ace_list);
166 }
167
Clone() const168 AccessControlList AccessControlList::Clone() const {
169 return AccessControlList{get()};
170 }
171
Clear()172 void AccessControlList::Clear() {
173 acl_ = EmptyAclToBuffer();
174 }
175
AccessControlList(const ACL * acl)176 AccessControlList::AccessControlList(const ACL* acl) : acl_(AclToBuffer(acl)) {}
177
178 } // namespace base::win
179