xref: /aosp_15_r20/prebuilts/build-tools/common/py3-stdlib/multiprocessing/shared_memory.py (revision cda5da8d549138a6648c5ee6d7a49cf8f4a657be)
1"""Provides shared memory for direct access across processes.
2
3The API of this package is currently provisional. Refer to the
4documentation for details.
5"""
6
7
# Public names exported by this module (what ``from ... import *`` binds).
__all__ = [ 'SharedMemory', 'ShareableList' ]
9
10
11from functools import partial
12import mmap
13import os
14import errno
15import struct
16import secrets
17import types
18
19if os.name == "nt":
20    import _winapi
21    _USE_POSIX = False
22else:
23    import _posixshmem
24    _USE_POSIX = True
25
26from . import resource_tracker
27
# shm_open() flags for "create it, and fail if it already exists".
_O_CREX = os.O_CREAT | os.O_EXCL

# Upper bound on generated block names: FreeBSD (and perhaps other BSDs)
# limit names to 14 characters.
_SHM_SAFE_NAME_LENGTH = 14

# Prefix for generated block names.  The POSIX form starts with '/', the
# same leading slash SharedMemory prepends to caller-supplied POSIX names.
_SHM_NAME_PREFIX = '/psm_' if _USE_POSIX else 'wnsm_'
38
39
def _make_filename():
    """Return a random name for a new shared memory object.

    The result is ``_SHM_NAME_PREFIX`` followed by random hex digits, with
    as many digits as fit in ``_SHM_SAFE_NAME_LENGTH`` characters.
    """
    # token_hex() emits two hex digits per random byte, so use half of the
    # room left over after the prefix.
    room_for_digits = _SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)
    random_byte_count = room_for_digits // 2
    assert random_byte_count >= 2, '_SHM_NAME_PREFIX too long'
    candidate = f"{_SHM_NAME_PREFIX}{secrets.token_hex(random_byte_count)}"
    assert len(candidate) <= _SHM_SAFE_NAME_LENGTH
    return candidate
48
49
class SharedMemory:
    """Creates a new shared memory block or attaches to an existing
    shared memory block.

    Every shared memory block is assigned a unique name.  This enables
    one process to create a shared memory block with a particular name
    so that a different process can attach to that same shared memory
    block using that same name.

    As a resource for sharing data across processes, shared memory blocks
    may outlive the original process that created them.  When one process
    no longer needs access to a shared memory block that might still be
    needed by other processes, the close() method should be called.
    When a shared memory block is no longer needed by any process, the
    unlink() method should be called to ensure proper cleanup."""

    # Defaults; enables close() and unlink() to run without errors.
    _name = None
    _fd = -1
    _mmap = None
    _buf = None
    _flags = os.O_RDWR
    _mode = 0o600
    _prepend_leading_slash = True if _USE_POSIX else False

    def __init__(self, name=None, create=False, size=0):
        """Create (``create=True``) or attach to the block called *name*.

        :param name: block name.  It may be ``None`` only when creating; a
            random name that is not already taken is then generated.
        :param create: create a new block (failing if *name* is already
            taken) instead of attaching to an existing one.
        :param size: requested size in bytes; must be > 0 when creating.
            The final ``size`` is taken from the mapped block itself, so
            when attaching the existing block's size is used.
        :raises ValueError: for a negative size, a zero size with
            ``create=True``, or ``name=None`` without ``create=True``.
        :raises FileExistsError: when creating a block whose name is taken.
        """
        if not size >= 0:
            raise ValueError("'size' must be a positive integer")
        if create:
            self._flags = _O_CREX | os.O_RDWR
            if size == 0:
                raise ValueError("'size' must be a positive number different from zero")
        if name is None and not self._flags & os.O_EXCL:
            raise ValueError("'name' can only be None if create=True")

        if _USE_POSIX:

            # POSIX Shared Memory

            if name is None:
                # Keep drawing random names until O_EXCL creation succeeds.
                while True:
                    name = _make_filename()
                    try:
                        self._fd = _posixshmem.shm_open(
                            name,
                            self._flags,
                            mode=self._mode
                        )
                    except FileExistsError:
                        continue
                    self._name = name
                    break
            else:
                name = "/" + name if self._prepend_leading_slash else name
                self._fd = _posixshmem.shm_open(
                    name,
                    self._flags,
                    mode=self._mode
                )
                self._name = name
            try:
                if create and size:
                    os.ftruncate(self._fd, size)
                stats = os.fstat(self._fd)
                size = stats.st_size
                self._mmap = mmap.mmap(self._fd, size)
            except BaseException:
                # Release our descriptor.  Destroy the block only if this
                # call created it: a failed attach must not delete a block
                # that other processes may be using.  The name has not been
                # registered with the resource tracker yet, so unlink the
                # segment directly rather than via unlink(), which would
                # also unregister it.
                self.close()
                if create:
                    try:
                        _posixshmem.shm_unlink(self._name)
                    except OSError:
                        # Best effort; the error being re-raised is the
                        # informative one.
                        pass
                raise

            resource_tracker.register(self._name, "shared_memory")

        else:

            # Windows Named Shared Memory

            if create:
                while True:
                    temp_name = _make_filename() if name is None else name
                    # Create and reserve shared memory block with this name
                    # until it can be attached to by mmap.
                    h_map = _winapi.CreateFileMapping(
                        _winapi.INVALID_HANDLE_VALUE,
                        _winapi.NULL,
                        _winapi.PAGE_READWRITE,
                        (size >> 32) & 0xFFFFFFFF,
                        size & 0xFFFFFFFF,
                        temp_name
                    )
                    try:
                        last_error_code = _winapi.GetLastError()
                        if last_error_code == _winapi.ERROR_ALREADY_EXISTS:
                            if name is not None:
                                raise FileExistsError(
                                    errno.EEXIST,
                                    os.strerror(errno.EEXIST),
                                    name,
                                    _winapi.ERROR_ALREADY_EXISTS
                                )
                            else:
                                # Random name collided; try another one.
                                continue
                        self._mmap = mmap.mmap(-1, size, tagname=temp_name)
                    finally:
                        _winapi.CloseHandle(h_map)
                    self._name = temp_name
                    break

            else:
                self._name = name
                # Dynamically determine the existing named shared memory
                # block's size which is likely a multiple of mmap.PAGESIZE.
                h_map = _winapi.OpenFileMapping(
                    _winapi.FILE_MAP_READ,
                    False,
                    name
                )
                try:
                    p_buf = _winapi.MapViewOfFile(
                        h_map,
                        _winapi.FILE_MAP_READ,
                        0,
                        0,
                        0
                    )
                finally:
                    _winapi.CloseHandle(h_map)
                try:
                    size = _winapi.VirtualQuerySize(p_buf)
                finally:
                    _winapi.UnmapViewOfFile(p_buf)
                self._mmap = mmap.mmap(-1, size, tagname=name)

        self._size = size
        self._buf = memoryview(self._mmap)

    def __del__(self):
        # Best-effort release of this instance's mapping and descriptor.
        try:
            self.close()
        except OSError:
            pass

    def __reduce__(self):
        # Unpickling attaches (create=False) to the same named block.
        return (
            self.__class__,
            (
                self.name,
                False,
                self.size,
            ),
        )

    def __repr__(self):
        return f'{self.__class__.__name__}({self.name!r}, size={self.size})'

    @property
    def buf(self):
        "A memoryview of contents of the shared memory block."
        return self._buf

    @property
    def name(self):
        "Unique name that identifies the shared memory block."
        reported_name = self._name
        # Hide the '/' that __init__ prepends to POSIX names.
        if _USE_POSIX and self._prepend_leading_slash:
            if self._name.startswith("/"):
                reported_name = self._name[1:]
        return reported_name

    @property
    def size(self):
        "Size in bytes."
        return self._size

    def close(self):
        """Closes access to the shared memory from this instance but does
        not destroy the shared memory block."""
        if self._buf is not None:
            self._buf.release()
            self._buf = None
        if self._mmap is not None:
            self._mmap.close()
            self._mmap = None
        if _USE_POSIX and self._fd >= 0:
            os.close(self._fd)
            self._fd = -1

    def unlink(self):
        """Requests that the underlying shared memory block be destroyed.

        In order to ensure proper cleanup of resources, unlink should be
        called once (and only once) across all processes which have access
        to the shared memory block.  On non-POSIX platforms this is a
        no-op."""
        if _USE_POSIX and self._name:
            _posixshmem.shm_unlink(self._name)
            resource_tracker.unregister(self._name, "shared_memory")
245
246
# Codec used by ShareableList to store str items and their struct format
# strings as bytes, and to decode them again when read back.
_encoding = "utf8"
248
249class ShareableList:
250    """Pattern for a mutable list-like object shareable via a shared
251    memory block.  It differs from the built-in list type in that these
252    lists can not change their overall length (i.e. no append, insert,
253    etc.)
254
255    Because values are packed into a memoryview as bytes, the struct
256    packing format for any storable value must require no more than 8
257    characters to describe its format."""
258
259    # The shared memory area is organized as follows:
260    # - 8 bytes: number of items (N) as a 64-bit integer
261    # - (N + 1) * 8 bytes: offsets of each element from the start of the
262    #                      data area
263    # - K bytes: the data area storing item values (with encoding and size
264    #            depending on their respective types)
265    # - N * 8 bytes: `struct` format string for each element
266    # - N bytes: index into _back_transforms_mapping for each element
267    #            (for reconstructing the corresponding Python value)
268    _types_mapping = {
269        int: "q",
270        float: "d",
271        bool: "xxxxxxx?",
272        str: "%ds",
273        bytes: "%ds",
274        None.__class__: "xxxxxx?x",
275    }
276    _alignment = 8
277    _back_transforms_mapping = {
278        0: lambda value: value,                   # int, float, bool
279        1: lambda value: value.rstrip(b'\x00').decode(_encoding),  # str
280        2: lambda value: value.rstrip(b'\x00'),   # bytes
281        3: lambda _value: None,                   # None
282    }
283
284    @staticmethod
285    def _extract_recreation_code(value):
286        """Used in concert with _back_transforms_mapping to convert values
287        into the appropriate Python objects when retrieving them from
288        the list as well as when storing them."""
289        if not isinstance(value, (str, bytes, None.__class__)):
290            return 0
291        elif isinstance(value, str):
292            return 1
293        elif isinstance(value, bytes):
294            return 2
295        else:
296            return 3  # NoneType
297
298    def __init__(self, sequence=None, *, name=None):
299        if name is None or sequence is not None:
300            sequence = sequence or ()
301            _formats = [
302                self._types_mapping[type(item)]
303                    if not isinstance(item, (str, bytes))
304                    else self._types_mapping[type(item)] % (
305                        self._alignment * (len(item) // self._alignment + 1),
306                    )
307                for item in sequence
308            ]
309            self._list_len = len(_formats)
310            assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len
311            offset = 0
312            # The offsets of each list element into the shared memory's
313            # data area (0 meaning the start of the data area, not the start
314            # of the shared memory area).
315            self._allocated_offsets = [0]
316            for fmt in _formats:
317                offset += self._alignment if fmt[-1] != "s" else int(fmt[:-1])
318                self._allocated_offsets.append(offset)
319            _recreation_codes = [
320                self._extract_recreation_code(item) for item in sequence
321            ]
322            requested_size = struct.calcsize(
323                "q" + self._format_size_metainfo +
324                "".join(_formats) +
325                self._format_packing_metainfo +
326                self._format_back_transform_codes
327            )
328
329            self.shm = SharedMemory(name, create=True, size=requested_size)
330        else:
331            self.shm = SharedMemory(name)
332
333        if sequence is not None:
334            _enc = _encoding
335            struct.pack_into(
336                "q" + self._format_size_metainfo,
337                self.shm.buf,
338                0,
339                self._list_len,
340                *(self._allocated_offsets)
341            )
342            struct.pack_into(
343                "".join(_formats),
344                self.shm.buf,
345                self._offset_data_start,
346                *(v.encode(_enc) if isinstance(v, str) else v for v in sequence)
347            )
348            struct.pack_into(
349                self._format_packing_metainfo,
350                self.shm.buf,
351                self._offset_packing_formats,
352                *(v.encode(_enc) for v in _formats)
353            )
354            struct.pack_into(
355                self._format_back_transform_codes,
356                self.shm.buf,
357                self._offset_back_transform_codes,
358                *(_recreation_codes)
359            )
360
361        else:
362            self._list_len = len(self)  # Obtains size from offset 0 in buffer.
363            self._allocated_offsets = list(
364                struct.unpack_from(
365                    self._format_size_metainfo,
366                    self.shm.buf,
367                    1 * 8
368                )
369            )
370
371    def _get_packing_format(self, position):
372        "Gets the packing format for a single value stored in the list."
373        position = position if position >= 0 else position + self._list_len
374        if (position >= self._list_len) or (self._list_len < 0):
375            raise IndexError("Requested position out of range.")
376
377        v = struct.unpack_from(
378            "8s",
379            self.shm.buf,
380            self._offset_packing_formats + position * 8
381        )[0]
382        fmt = v.rstrip(b'\x00')
383        fmt_as_str = fmt.decode(_encoding)
384
385        return fmt_as_str
386
387    def _get_back_transform(self, position):
388        "Gets the back transformation function for a single value."
389
390        if (position >= self._list_len) or (self._list_len < 0):
391            raise IndexError("Requested position out of range.")
392
393        transform_code = struct.unpack_from(
394            "b",
395            self.shm.buf,
396            self._offset_back_transform_codes + position
397        )[0]
398        transform_function = self._back_transforms_mapping[transform_code]
399
400        return transform_function
401
402    def _set_packing_format_and_transform(self, position, fmt_as_str, value):
403        """Sets the packing format and back transformation code for a
404        single value in the list at the specified position."""
405
406        if (position >= self._list_len) or (self._list_len < 0):
407            raise IndexError("Requested position out of range.")
408
409        struct.pack_into(
410            "8s",
411            self.shm.buf,
412            self._offset_packing_formats + position * 8,
413            fmt_as_str.encode(_encoding)
414        )
415
416        transform_code = self._extract_recreation_code(value)
417        struct.pack_into(
418            "b",
419            self.shm.buf,
420            self._offset_back_transform_codes + position,
421            transform_code
422        )
423
424    def __getitem__(self, position):
425        position = position if position >= 0 else position + self._list_len
426        try:
427            offset = self._offset_data_start + self._allocated_offsets[position]
428            (v,) = struct.unpack_from(
429                self._get_packing_format(position),
430                self.shm.buf,
431                offset
432            )
433        except IndexError:
434            raise IndexError("index out of range")
435
436        back_transform = self._get_back_transform(position)
437        v = back_transform(v)
438
439        return v
440
441    def __setitem__(self, position, value):
442        position = position if position >= 0 else position + self._list_len
443        try:
444            item_offset = self._allocated_offsets[position]
445            offset = self._offset_data_start + item_offset
446            current_format = self._get_packing_format(position)
447        except IndexError:
448            raise IndexError("assignment index out of range")
449
450        if not isinstance(value, (str, bytes)):
451            new_format = self._types_mapping[type(value)]
452            encoded_value = value
453        else:
454            allocated_length = self._allocated_offsets[position + 1] - item_offset
455
456            encoded_value = (value.encode(_encoding)
457                             if isinstance(value, str) else value)
458            if len(encoded_value) > allocated_length:
459                raise ValueError("bytes/str item exceeds available storage")
460            if current_format[-1] == "s":
461                new_format = current_format
462            else:
463                new_format = self._types_mapping[str] % (
464                    allocated_length,
465                )
466
467        self._set_packing_format_and_transform(
468            position,
469            new_format,
470            value
471        )
472        struct.pack_into(new_format, self.shm.buf, offset, encoded_value)
473
474    def __reduce__(self):
475        return partial(self.__class__, name=self.shm.name), ()
476
477    def __len__(self):
478        return struct.unpack_from("q", self.shm.buf, 0)[0]
479
480    def __repr__(self):
481        return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})'
482
483    @property
484    def format(self):
485        "The struct packing format used by all currently stored items."
486        return "".join(
487            self._get_packing_format(i) for i in range(self._list_len)
488        )
489
490    @property
491    def _format_size_metainfo(self):
492        "The struct packing format used for the items' storage offsets."
493        return "q" * (self._list_len + 1)
494
495    @property
496    def _format_packing_metainfo(self):
497        "The struct packing format used for the items' packing formats."
498        return "8s" * self._list_len
499
500    @property
501    def _format_back_transform_codes(self):
502        "The struct packing format used for the items' back transforms."
503        return "b" * self._list_len
504
505    @property
506    def _offset_data_start(self):
507        # - 8 bytes for the list length
508        # - (N + 1) * 8 bytes for the element offsets
509        return (self._list_len + 2) * 8
510
511    @property
512    def _offset_packing_formats(self):
513        return self._offset_data_start + self._allocated_offsets[-1]
514
515    @property
516    def _offset_back_transform_codes(self):
517        return self._offset_packing_formats + self._list_len * 8
518
519    def count(self, value):
520        "L.count(value) -> integer -- return number of occurrences of value."
521
522        return sum(value == entry for entry in self)
523
524    def index(self, value):
525        """L.index(value) -> integer -- return first index of value.
526        Raises ValueError if the value is not present."""
527
528        for position, entry in enumerate(self):
529            if value == entry:
530                return position
531        else:
532            raise ValueError(f"{value!r} not in this container")
533
534    __class_getitem__ = classmethod(types.GenericAlias)
535