1# Licensed under the Apache License, Version 2.0 (the "License");
2# you may not use this file except in compliance with the License.
3# You may obtain a copy of the License at
4#
5#      http://www.apache.org/licenses/LICENSE-2.0
6#
7# Unless required by applicable law or agreed to in writing, software
8# distributed under the License is distributed on an "AS IS" BASIS,
9# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10# See the License for the specific language governing permissions and
11# limitations under the License.
12
"""Helper classes used for fake file system implementation."""
14import io
15import locale
16import os
17import platform
18import stat
19import sys
20import time
21from copy import copy
22from stat import S_IFLNK
23from typing import Union, Optional, Any, AnyStr, overload, cast
24
AnyString = Union[str, bytes]  # either flavour of string: text or bytes
AnyPath = Union[AnyStr, os.PathLike]  # a str/bytes path or a path-like object

IS_PYPY = platform.python_implementation() == "PyPy"
IS_WIN = sys.platform == "win32"
IN_DOCKER = os.path.exists("/.dockerenv")

# Octal permission masks for the mode bits.
PERM_READ = 0o400  # owner read bit
PERM_WRITE = 0o200  # owner write bit
PERM_EXE = 0o100  # owner execute bit
PERM_DEF = 0o777  # default permissions: rwx for owner, group and others
PERM_DEF_FILE = 0o666  # default permissions for regular files: rw for all
PERM_ALL = 0o7777  # every permission bit, including setuid/setgid/sticky

# Initial global user/group ids. Windows has no getuid()/getgid(), so fixed
# ids are used there; the literal `sys.platform` test also lets static type
# checkers see the POSIX-only calls as guarded.
if sys.platform != "win32":
    USER_ID = os.getuid()
    GROUP_ID = os.getgid()
else:
    USER_ID = 1
    GROUP_ID = 1
45
46
def get_uid() -> int:
    """Return the user id the fake file system currently acts as.

    This is the fake counterpart of ``os.getuid()``; the value can be
    changed with ``set_uid``.
    """
    current_uid: int = USER_ID
    return current_uid
50
51
def set_uid(uid: int) -> None:
    """Make the fake file system act as the user with id `uid`.

    The id is stored globally. It is used as ``st_uid`` for newly created
    files and to tell a normal user from the root user (uid 0), for whom
    some permission restrictions are ignored.

    Args:
        uid: (int) the user ID of the user calling the file system functions.
    """
    global USER_ID
    USER_ID = uid
62
63
def get_gid() -> int:
    """Return the group id the fake file system currently acts as.

    This is the fake counterpart of ``os.getgid()``; the value can be
    changed with ``set_gid``.
    """
    current_gid: int = GROUP_ID
    return current_gid
67
68
def set_gid(gid: int) -> None:
    """Make the fake file system act as a member of group `gid`.

    The id only becomes ``st_gid`` of newly created files; no permission
    checks depend on it.

    Args:
        gid: (int) the group ID of the user calling the file system functions.
    """
    global GROUP_ID
    GROUP_ID = gid
78
79
def reset_ids() -> None:
    """Restore the global user and group ids to their defaults.

    On POSIX systems the defaults are the real process ids; on Windows,
    which has no such ids, both are reset to 1.
    """
    if sys.platform != "win32":
        set_uid(os.getuid())
        set_gid(os.getgid())
    else:
        set_uid(1)
        set_gid(1)
88
89
def is_root() -> bool:
    """Tell whether the fake file system currently acts as the root user."""
    root_uid = 0
    return USER_ID == root_uid
93
94
def is_int_type(val: Any) -> bool:
    """Tell whether `val` is an ``int`` instance.

    ``bool`` values count as well, since ``bool`` subclasses ``int``.
    """
    result = isinstance(val, int)
    return result
98
99
def is_byte_string(val: Any) -> bool:
    """Tell whether `val` is a bytes-like object rather than a unicode string.

    The check is duck-typed: anything without an ``encode`` method counts
    as bytes-like.
    """
    can_encode = hasattr(val, "encode")
    return not can_encode
104
105
def is_unicode_string(val: Any) -> bool:
    """Tell whether `val` is a unicode string rather than a bytes-like object.

    The check is duck-typed: any object with an ``encode`` method counts
    as a unicode string.
    """
    can_encode = hasattr(val, "encode")
    return can_encode
110
111
@overload
def make_string_path(dir_name: AnyStr) -> AnyStr:
    ...


@overload
def make_string_path(dir_name: os.PathLike) -> str:
    ...


def make_string_path(dir_name: AnyPath) -> AnyStr:
    """Return `dir_name` as a plain ``str`` or ``bytes`` path.

    Path-like objects are converted with ``os.fspath``; ``str`` and
    ``bytes`` paths come back unchanged.
    """
    fs_path = os.fspath(dir_name)
    return cast(AnyStr, fs_path)  # pytype: disable=invalid-annotation
124
125
def to_string(path: Union[AnyStr, Union[str, bytes]]) -> str:
    """Return `path` as ``str``.

    A byte string is decoded with the locale's preferred encoding; a
    ``str`` is returned as-is.
    """
    if not isinstance(path, bytes):
        return path
    preferred = locale.getpreferredencoding(False)
    return path.decode(preferred)
132
133
def to_bytes(path: Union[AnyStr, Union[str, bytes]]) -> bytes:
    """Return `path` as ``bytes``.

    A ``str`` is encoded with the locale's preferred encoding; a byte
    string is returned as-is.
    """
    if not isinstance(path, str):
        return path
    preferred = locale.getpreferredencoding(False)
    return path.encode(preferred)
140
141
def join_strings(s1: AnyStr, s2: AnyStr) -> AnyStr:
    """Concatenate two strings of the same kind (both str or both bytes).

    This exists mainly so mypy can check the shared ``AnyStr`` type; it
    may be refactored.
    """
    combined = s1 + s2
    return combined
145
146
def real_encoding(encoding: Optional[str]) -> Optional[str]:
    """Map the pseudo-encoding ``"locale"`` to ``None``.

    From Python 3.10 on, ``io.text_encoding`` yields ``"locale"`` when no
    encoding was given; pyfakefs treats that as "no encoding". On older
    versions, and for every other value, `encoding` is returned unchanged.
    """
    if sys.version_info < (3, 10) or encoding != "locale":
        return encoding
    return None
154
155
def now() -> float:
    """Return the current time in seconds since the epoch, as ``time.time()``."""
    return time.time()
158
159
@overload
def matching_string(matched: bytes, string: AnyStr) -> bytes:
    ...


@overload
def matching_string(matched: str, string: AnyStr) -> str:
    ...


@overload
def matching_string(matched: AnyStr, string: None) -> None:
    ...


def matching_string(  # type: ignore[misc]
    matched: AnyStr, string: Optional[AnyStr]
) -> Optional[AnyString]:
    """Return `string` converted to bytes if `matched` is bytes.

    Only the case "`matched` is bytes, `string` is str" triggers a
    conversion: `string` is encoded with the locale's preferred encoding
    (`string` is assumed to be ASCII). ``None`` and every other combination
    are passed through unchanged.
    """
    if isinstance(string, str) and isinstance(matched, bytes):
        return string.encode(locale.getpreferredencoding(False))
    return string  # pytype: disable=bad-return-type
186
187
188class FakeStatResult:
189    """Mimics os.stat_result for use as return type of `stat()` and similar.
190    This is needed as `os.stat_result` has no possibility to set
191    nanosecond times directly.
192    """
193
194    def __init__(
195        self,
196        is_windows: bool,
197        user_id: int,
198        group_id: int,
199        initial_time: Optional[float] = None,
200    ):
201        self.st_mode: int = 0
202        self.st_ino: Optional[int] = None
203        self.st_dev: int = 0
204        self.st_nlink: int = 0
205        self.st_uid: int = user_id
206        self.st_gid: int = group_id
207        self._st_size: int = 0
208        self.is_windows: bool = is_windows
209        self._st_atime_ns: int = int((initial_time or 0) * 1e9)
210        self._st_mtime_ns: int = self._st_atime_ns
211        self._st_ctime_ns: int = self._st_atime_ns
212
213    def __eq__(self, other: Any) -> bool:
214        return (
215            isinstance(other, FakeStatResult)
216            and self._st_atime_ns == other._st_atime_ns
217            and self._st_ctime_ns == other._st_ctime_ns
218            and self._st_mtime_ns == other._st_mtime_ns
219            and self.st_size == other.st_size
220            and self.st_gid == other.st_gid
221            and self.st_uid == other.st_uid
222            and self.st_nlink == other.st_nlink
223            and self.st_dev == other.st_dev
224            and self.st_ino == other.st_ino
225            and self.st_mode == other.st_mode
226        )
227
228    def __ne__(self, other: Any) -> bool:
229        return not self == other
230
231    def copy(self) -> "FakeStatResult":
232        """Return a copy where the float usage is hard-coded to mimic the
233        behavior of the real os.stat_result.
234        """
235        stat_result = copy(self)
236        return stat_result
237
238    def set_from_stat_result(self, stat_result: os.stat_result) -> None:
239        """Set values from a real os.stat_result.
240        Note: values that are controlled by the fake filesystem are not set.
241        This includes st_ino, st_dev and st_nlink.
242        """
243        self.st_mode = stat_result.st_mode
244        self.st_uid = stat_result.st_uid
245        self.st_gid = stat_result.st_gid
246        self._st_size = stat_result.st_size
247        self._st_atime_ns = stat_result.st_atime_ns
248        self._st_mtime_ns = stat_result.st_mtime_ns
249        self._st_ctime_ns = stat_result.st_ctime_ns
250
251    @property
252    def st_ctime(self) -> Union[int, float]:
253        """Return the creation time in seconds."""
254        return self._st_ctime_ns / 1e9
255
256    @st_ctime.setter
257    def st_ctime(self, val: Union[int, float]) -> None:
258        """Set the creation time in seconds."""
259        self._st_ctime_ns = int(val * 1e9)
260
261    @property
262    def st_atime(self) -> Union[int, float]:
263        """Return the access time in seconds."""
264        return self._st_atime_ns / 1e9
265
266    @st_atime.setter
267    def st_atime(self, val: Union[int, float]) -> None:
268        """Set the access time in seconds."""
269        self._st_atime_ns = int(val * 1e9)
270
271    @property
272    def st_mtime(self) -> Union[int, float]:
273        """Return the modification time in seconds."""
274        return self._st_mtime_ns / 1e9
275
276    @st_mtime.setter
277    def st_mtime(self, val: Union[int, float]) -> None:
278        """Set the modification time in seconds."""
279        self._st_mtime_ns = int(val * 1e9)
280
281    @property
282    def st_size(self) -> int:
283        if self.st_mode & S_IFLNK == S_IFLNK and self.is_windows:
284            return 0
285        return self._st_size
286
287    @st_size.setter
288    def st_size(self, val: int) -> None:
289        self._st_size = val
290
291    @property
292    def st_blocks(self) -> int:
293        """Return the number of 512-byte blocks allocated for the file.
294        Assumes a page size of 4096 (matches most systems).
295        Ignores that this may not be available under some systems,
296        and that the result may differ if the file has holes.
297        """
298        if self.is_windows:
299            raise AttributeError("'os.stat_result' object has no attribute 'st_blocks'")
300        page_size = 4096
301        blocks_in_page = page_size // 512
302        pages = self._st_size // page_size
303        if self._st_size % page_size:
304            pages += 1
305        return pages * blocks_in_page
306
307    @property
308    def st_file_attributes(self) -> int:
309        if not self.is_windows:
310            raise AttributeError(
311                "module 'os.stat_result' " "has no attribute 'st_file_attributes'"
312            )
313        mode = 0
314        st_mode = self.st_mode
315        if st_mode & stat.S_IFDIR:
316            mode |= stat.FILE_ATTRIBUTE_DIRECTORY  # type:ignore[attr-defined]
317        if st_mode & stat.S_IFREG:
318            mode |= stat.FILE_ATTRIBUTE_NORMAL  # type:ignore[attr-defined]
319        if st_mode & (stat.S_IFCHR | stat.S_IFBLK):
320            mode |= stat.FILE_ATTRIBUTE_DEVICE  # type:ignore[attr-defined]
321        if st_mode & stat.S_IFLNK:
322            mode |= stat.FILE_ATTRIBUTE_REPARSE_POINT  # type:ignore
323        return mode
324
325    @property
326    def st_reparse_tag(self) -> int:
327        if not self.is_windows or sys.version_info < (3, 8):
328            raise AttributeError(
329                "module 'os.stat_result' " "has no attribute 'st_reparse_tag'"
330            )
331        if self.st_mode & stat.S_IFLNK:
332            return stat.IO_REPARSE_TAG_SYMLINK  # type: ignore[attr-defined]
333        return 0
334
335    def __getitem__(self, item: int) -> Optional[int]:
336        """Implement item access to mimic `os.stat_result` behavior."""
337        import stat
338
339        if item == stat.ST_MODE:
340            return self.st_mode
341        if item == stat.ST_INO:
342            return self.st_ino
343        if item == stat.ST_DEV:
344            return self.st_dev
345        if item == stat.ST_NLINK:
346            return self.st_nlink
347        if item == stat.ST_UID:
348            return self.st_uid
349        if item == stat.ST_GID:
350            return self.st_gid
351        if item == stat.ST_SIZE:
352            return self.st_size
353        if item == stat.ST_ATIME:
354            # item access always returns int for backward compatibility
355            return int(self.st_atime)
356        if item == stat.ST_MTIME:
357            return int(self.st_mtime)
358        if item == stat.ST_CTIME:
359            return int(self.st_ctime)
360        raise ValueError("Invalid item")
361
362    @property
363    def st_atime_ns(self) -> int:
364        """Return the access time in nanoseconds."""
365        return self._st_atime_ns
366
367    @st_atime_ns.setter
368    def st_atime_ns(self, val: int) -> None:
369        """Set the access time in nanoseconds."""
370        self._st_atime_ns = val
371
372    @property
373    def st_mtime_ns(self) -> int:
374        """Return the modification time in nanoseconds."""
375        return self._st_mtime_ns
376
377    @st_mtime_ns.setter
378    def st_mtime_ns(self, val: int) -> None:
379        """Set the modification time of the fake file in nanoseconds."""
380        self._st_mtime_ns = val
381
382    @property
383    def st_ctime_ns(self) -> int:
384        """Return the creation time in nanoseconds."""
385        return self._st_ctime_ns
386
387    @st_ctime_ns.setter
388    def st_ctime_ns(self, val: int) -> None:
389        """Set the creation time of the fake file in nanoseconds."""
390        self._st_ctime_ns = val
391
392
class BinaryBufferIO(io.BytesIO):
    """In-memory byte stream holding the contents of a fake binary file."""

    def __init__(self, contents: Optional[bytes]):
        # Falsy contents (None or empty) start an empty buffer.
        initial = contents if contents else b""
        super().__init__(initial)

    def putvalue(self, value: bytes) -> None:
        """Write `value` into the buffer at the current stream position."""
        self.write(value)
401
402
class TextBufferIO(io.TextIOWrapper):
    """Text stream layered over an in-memory byte buffer of file contents."""

    def __init__(
        self,
        contents: Optional[bytes] = None,
        newline: Optional[str] = None,
        encoding: Optional[str] = None,
        errors: str = "strict",
    ):
        # The raw byte buffer is kept so its bytes can be read or written
        # directly, bypassing the text layer.
        self._bytestream = io.BytesIO(contents if contents else b"")
        super().__init__(
            self._bytestream, encoding=encoding, errors=errors, newline=newline
        )

    def getvalue(self) -> bytes:
        """Return the raw (encoded) bytes held by the underlying buffer."""
        return self._bytestream.getvalue()

    def putvalue(self, value: bytes) -> None:
        """Write raw bytes into the underlying buffer at its current position."""
        self._bytestream.write(value)
421