1import io
2import os
3import pathlib
4import unittest
5import warnings
6from test.support import findfile, warnings_helper
7from test.support.os_helper import TESTFN, unlink
8
9imghdr = warnings_helper.import_deprecated("imghdr")
10
11
# Fixture files looked up in the ``imghdrdata`` test-data directory, each
# paired with the format name that imghdr.what() is expected to report.
# Dicts preserve insertion order, so .items() yields the (filename, expected)
# pairs in exactly the order they are listed here.
TEST_FILES = tuple({
    'python.png': 'png',
    'python.gif': 'gif',
    'python.bmp': 'bmp',
    'python.ppm': 'ppm',
    'python.pgm': 'pgm',
    'python.pbm': 'pbm',
    'python.jpg': 'jpeg',
    'python-raw.jpg': 'jpeg',  # raw JPEG without JFIF/EXIF markers
    'python.ras': 'rast',
    'python.sgi': 'rgb',
    'python.tiff': 'tiff',
    'python.xbm': 'xbm',
    'python.webp': 'webp',
    'python.exr': 'exr',
}.items())
28
class UnseekableIO(io.FileIO):
    """A FileIO whose position can be neither queried nor changed.

    Both seek() and tell() raise io.UnsupportedOperation, so code under test
    that relies on random access fails loudly when handed this stream.
    """

    def seek(self, *_args, **_kwargs):
        # Accept any seek() call shape; every one of them is refused.
        raise io.UnsupportedOperation()

    def tell(self):
        raise io.UnsupportedOperation()
35
class TestImghdr(unittest.TestCase):
    """Tests for imghdr.what() on paths, streams, raw bytes and bad input."""

    @classmethod
    def setUpClass(cls):
        # One known-good PNG fixture (its path and its raw bytes) shared by
        # the tests that only need a single valid image.
        cls.testfile = findfile('python.png', subdir='imghdrdata')
        with open(cls.testfile, 'rb') as stream:
            cls.testdata = stream.read()

    def tearDown(self):
        # Several tests write TESTFN; remove it after each test.
        unlink(TESTFN)

    def test_data(self):
        # Each fixture must be recognised when passed as a path, as an open
        # binary stream, and as raw data (both bytes and bytearray).
        # subTest reports every failing fixture instead of stopping at the
        # first one.
        for filename, expected in TEST_FILES:
            with self.subTest(filename=filename):
                filename = findfile(filename, subdir='imghdrdata')
                self.assertEqual(imghdr.what(filename), expected)
                with open(filename, 'rb') as stream:
                    self.assertEqual(imghdr.what(stream), expected)
                with open(filename, 'rb') as stream:
                    data = stream.read()
                self.assertEqual(imghdr.what(None, data), expected)
                self.assertEqual(imghdr.what(None, bytearray(data)), expected)

    def test_pathlike_filename(self):
        # os.PathLike objects are accepted wherever a filename string is.
        for filename, expected in TEST_FILES:
            with self.subTest(filename=filename):
                filename = findfile(filename, subdir='imghdrdata')
                self.assertEqual(imghdr.what(pathlib.Path(filename)), expected)

    def test_register_test(self):
        # A detector appended to imghdr.tests is consulted by what(); the
        # cleanup pops it again so later tests see the original list.
        def test_jumbo(h, file):
            if h.startswith(b'eggs'):
                return 'ham'
        imghdr.tests.append(test_jumbo)
        self.addCleanup(imghdr.tests.pop)
        self.assertEqual(imghdr.what(None, b'eggs'), 'ham')

    def test_file_pos(self):
        # what() must detect an image that starts mid-stream and leave the
        # stream positioned where it found it.
        with open(TESTFN, 'wb') as stream:
            stream.write(b'ababagalamaga')
            pos = stream.tell()
            stream.write(self.testdata)
        with open(TESTFN, 'rb') as stream:
            stream.seek(pos)
            self.assertEqual(imghdr.what(stream), 'png')
            self.assertEqual(stream.tell(), pos)

    def test_bad_args(self):
        # Missing arguments, a non-bytes header, and objects that are neither
        # paths nor file-like streams are all rejected.
        with self.assertRaises(TypeError):
            imghdr.what()
        with self.assertRaises(AttributeError):
            imghdr.what(None)
        with self.assertRaises(TypeError):
            imghdr.what(self.testfile, 1)
        with self.assertRaises(AttributeError):
            imghdr.what(os.fsencode(self.testfile))
        with open(self.testfile, 'rb') as f:
            with self.assertRaises(AttributeError):
                imghdr.what(f.fileno())

    def test_invalid_headers(self):
        # Near-misses of real signatures must not be reported as any format.
        for header in (b'\211PN\r\n',
                       b'\001\331',
                       b'\x59\xA6',
                       b'cutecat',
                       b'000000JFI',
                       b'GIF80'):
            with self.subTest(header=header):
                self.assertIsNone(imghdr.what(None, header))

    def test_string_data(self):
        # Text (str) data, whether in a StringIO or passed directly, must be
        # rejected with TypeError rather than silently misdetected.
        with warnings.catch_warnings():
            # Comparing str against bytes signatures can emit BytesWarning
            # when the interpreter runs with -b; that is expected here.
            warnings.simplefilter("ignore", BytesWarning)
            for filename, _ in TEST_FILES:
                with self.subTest(filename=filename):
                    filename = findfile(filename, subdir='imghdrdata')
                    with open(filename, 'rb') as stream:
                        data = stream.read().decode('latin1')
                    with self.assertRaises(TypeError):
                        imghdr.what(io.StringIO(data))
                    with self.assertRaises(TypeError):
                        imghdr.what(None, data)

    def test_missing_file(self):
        with self.assertRaises(FileNotFoundError):
            imghdr.what('missing')

    def test_closed_file(self):
        # Both real files and in-memory streams raise ValueError once closed.
        stream = open(self.testfile, 'rb')
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)
        stream = io.BytesIO(self.testdata)
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)

    def test_unseekable(self):
        # what() needs tell()/seek() on streams; an unseekable stream makes
        # it propagate io.UnsupportedOperation.
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
        with UnseekableIO(TESTFN, 'rb') as stream:
            with self.assertRaises(io.UnsupportedOperation):
                imghdr.what(stream)

    def test_output_stream(self):
        # Reading from a write-only stream fails; io.UnsupportedOperation
        # raised by the read is a subclass of OSError.
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
            stream.seek(0)
            with self.assertRaises(OSError):
                imghdr.what(stream)
142
if __name__ == '__main__':
    # Executed directly as a script: discover and run the tests defined here.
    unittest.main()
145