1# XXX TypeErrors on calling handlers, or on bad return values from a
2# handler, are obscure and unhelpful.
3
4from io import BytesIO
5import os
6import platform
7import sys
8import sysconfig
9import unittest
10import traceback
11
12from xml.parsers import expat
13from xml.parsers.expat import errors
14
15from test.support import sortdict, is_emscripten, is_wasi
16
17
class SetAttributeTest(unittest.TestCase):
    """Check the settable boolean attributes of an expat parser object."""

    def setUp(self):
        self.parser = expat.ParserCreate(namespace_separator='!')

    def _check_boolean_flag(self, attr_name):
        """Verify that *attr_name* starts out False and stores bool(value)."""
        self.assertIs(getattr(self.parser, attr_name), False)
        # The trailing 0 checks that the flag can be switched off again.
        for value in (0, 1, 2, 0):
            setattr(self.parser, attr_name, value)
            self.assertIs(getattr(self.parser, attr_name), bool(value))

    def test_buffer_text(self):
        self._check_boolean_flag('buffer_text')

    def test_namespace_prefixes(self):
        self._check_boolean_flag('namespace_prefixes')

    def test_ordered_attributes(self):
        self._check_boolean_flag('ordered_attributes')

    def test_specified_attributes(self):
        self._check_boolean_flag('specified_attributes')

    def test_invalid_attributes(self):
        # An unknown attribute can be neither written nor read.
        with self.assertRaises(AttributeError):
            self.parser.returns_unicode = 1
        with self.assertRaises(AttributeError):
            self.parser.returns_unicode

        # Issue #25019: a non-string attribute name must raise TypeError.
        bad_name = range(0xF)
        self.assertRaises(TypeError, setattr, self.parser, bad_name, 0)
        self.assertRaises(TypeError, self.parser.__setattr__, bad_name, 0)
        self.assertRaises(TypeError, getattr, self.parser, bad_name)
56
57
# Sample XML document, as ISO-8859-1 encoded bytes.  It contains an XML
# declaration, a processing instruction, a comment, a DOCTYPE with an
# internal subset, a namespaced element, a CDATA section and entity
# references.
data = b'''\
<?xml version="1.0" encoding="iso-8859-1" standalone="no"?>
<?xml-stylesheet href="stylesheet.css"?>
<!-- comment data -->
<!DOCTYPE quotations SYSTEM "quotations.dtd" [
<!ELEMENT root ANY>
<!ATTLIST root attr1 CDATA #REQUIRED attr2 CDATA #IMPLIED>
<!NOTATION notation SYSTEM "notation.jpeg">
<!ENTITY acirc "&#226;">
<!ENTITY external_entity SYSTEM "entity.file">
<!ENTITY unparsed_entity SYSTEM "entity.file" NDATA notation>
%unparsed_entity;
]>

<root attr1="value1" attr2="value2&#8000;">
<myns:subelement xmlns:myns="http://www.python.org/namespace">
     Contents of subelements
</myns:subelement>
<sub2><![CDATA[contents of CDATA section]]></sub2>
&external_entity;
&skipped_entity;
\xb5
</root>
'''
82
83
84# Produce UTF-8 output
85class ParseTest(unittest.TestCase):
86    class Outputter:
87        def __init__(self):
88            self.out = []
89
90        def StartElementHandler(self, name, attrs):
91            self.out.append('Start element: ' + repr(name) + ' ' +
92                            sortdict(attrs))
93
94        def EndElementHandler(self, name):
95            self.out.append('End element: ' + repr(name))
96
97        def CharacterDataHandler(self, data):
98            data = data.strip()
99            if data:
100                self.out.append('Character data: ' + repr(data))
101
102        def ProcessingInstructionHandler(self, target, data):
103            self.out.append('PI: ' + repr(target) + ' ' + repr(data))
104
105        def StartNamespaceDeclHandler(self, prefix, uri):
106            self.out.append('NS decl: ' + repr(prefix) + ' ' + repr(uri))
107
108        def EndNamespaceDeclHandler(self, prefix):
109            self.out.append('End of NS decl: ' + repr(prefix))
110
111        def StartCdataSectionHandler(self):
112            self.out.append('Start of CDATA section')
113
114        def EndCdataSectionHandler(self):
115            self.out.append('End of CDATA section')
116
117        def CommentHandler(self, text):
118            self.out.append('Comment: ' + repr(text))
119
120        def NotationDeclHandler(self, *args):
121            name, base, sysid, pubid = args
122            self.out.append('Notation declared: %s' %(args,))
123
124        def UnparsedEntityDeclHandler(self, *args):
125            entityName, base, systemId, publicId, notationName = args
126            self.out.append('Unparsed entity decl: %s' %(args,))
127
128        def NotStandaloneHandler(self):
129            self.out.append('Not standalone')
130            return 1
131
132        def ExternalEntityRefHandler(self, *args):
133            context, base, sysId, pubId = args
134            self.out.append('External entity ref: %s' %(args[1:],))
135            return 1
136
137        def StartDoctypeDeclHandler(self, *args):
138            self.out.append(('Start doctype', args))
139            return 1
140
141        def EndDoctypeDeclHandler(self):
142            self.out.append("End doctype")
143            return 1
144
145        def EntityDeclHandler(self, *args):
146            self.out.append(('Entity declaration', args))
147            return 1
148
149        def XmlDeclHandler(self, *args):
150            self.out.append(('XML declaration', args))
151            return 1
152
153        def ElementDeclHandler(self, *args):
154            self.out.append(('Element declaration', args))
155            return 1
156
157        def AttlistDeclHandler(self, *args):
158            self.out.append(('Attribute list declaration', args))
159            return 1
160
161        def SkippedEntityHandler(self, *args):
162            self.out.append(("Skipped entity", args))
163            return 1
164
165        def DefaultHandler(self, userData):
166            pass
167
168        def DefaultHandlerExpand(self, userData):
169            pass
170
171    handler_names = [
172        'StartElementHandler', 'EndElementHandler', 'CharacterDataHandler',
173        'ProcessingInstructionHandler', 'UnparsedEntityDeclHandler',
174        'NotationDeclHandler', 'StartNamespaceDeclHandler',
175        'EndNamespaceDeclHandler', 'CommentHandler',
176        'StartCdataSectionHandler', 'EndCdataSectionHandler', 'DefaultHandler',
177        'DefaultHandlerExpand', 'NotStandaloneHandler',
178        'ExternalEntityRefHandler', 'StartDoctypeDeclHandler',
179        'EndDoctypeDeclHandler', 'EntityDeclHandler', 'XmlDeclHandler',
180        'ElementDeclHandler', 'AttlistDeclHandler', 'SkippedEntityHandler',
181        ]
182
183    def _hookup_callbacks(self, parser, handler):
184        """
185        Set each of the callbacks defined on handler and named in
186        self.handler_names on the given parser.
187        """
188        for name in self.handler_names:
189            setattr(parser, name, getattr(handler, name))
190
191    def _verify_parse_output(self, operations):
192        expected_operations = [
193            ('XML declaration', ('1.0', 'iso-8859-1', 0)),
194            'PI: \'xml-stylesheet\' \'href="stylesheet.css"\'',
195            "Comment: ' comment data '",
196            "Not standalone",
197            ("Start doctype", ('quotations', 'quotations.dtd', None, 1)),
198            ('Element declaration', ('root', (2, 0, None, ()))),
199            ('Attribute list declaration', ('root', 'attr1', 'CDATA', None,
200                1)),
201            ('Attribute list declaration', ('root', 'attr2', 'CDATA', None,
202                0)),
203            "Notation declared: ('notation', None, 'notation.jpeg', None)",
204            ('Entity declaration', ('acirc', 0, '\xe2', None, None, None, None)),
205            ('Entity declaration', ('external_entity', 0, None, None,
206                'entity.file', None, None)),
207            "Unparsed entity decl: ('unparsed_entity', None, 'entity.file', None, 'notation')",
208            "Not standalone",
209            "End doctype",
210            "Start element: 'root' {'attr1': 'value1', 'attr2': 'value2\u1f40'}",
211            "NS decl: 'myns' 'http://www.python.org/namespace'",
212            "Start element: 'http://www.python.org/namespace!subelement' {}",
213            "Character data: 'Contents of subelements'",
214            "End element: 'http://www.python.org/namespace!subelement'",
215            "End of NS decl: 'myns'",
216            "Start element: 'sub2' {}",
217            'Start of CDATA section',
218            "Character data: 'contents of CDATA section'",
219            'End of CDATA section',
220            "End element: 'sub2'",
221            "External entity ref: (None, 'entity.file', None)",
222            ('Skipped entity', ('skipped_entity', 0)),
223            "Character data: '\xb5'",
224            "End element: 'root'",
225        ]
226        for operation, expected_operation in zip(operations, expected_operations):
227            self.assertEqual(operation, expected_operation)
228
229    def test_parse_bytes(self):
230        out = self.Outputter()
231        parser = expat.ParserCreate(namespace_separator='!')
232        self._hookup_callbacks(parser, out)
233
234        parser.Parse(data, True)
235
236        operations = out.out
237        self._verify_parse_output(operations)
238        # Issue #6697.
239        self.assertRaises(AttributeError, getattr, parser, '\uD800')
240
241    def test_parse_str(self):
242        out = self.Outputter()
243        parser = expat.ParserCreate(namespace_separator='!')
244        self._hookup_callbacks(parser, out)
245
246        parser.Parse(data.decode('iso-8859-1'), True)
247
248        operations = out.out
249        self._verify_parse_output(operations)
250
251    def test_parse_file(self):
252        # Try parsing a file
253        out = self.Outputter()
254        parser = expat.ParserCreate(namespace_separator='!')
255        self._hookup_callbacks(parser, out)
256        file = BytesIO(data)
257
258        parser.ParseFile(file)
259
260        operations = out.out
261        self._verify_parse_output(operations)
262
263    def test_parse_again(self):
264        parser = expat.ParserCreate()
265        file = BytesIO(data)
266        parser.ParseFile(file)
267        # Issue 6676: ensure a meaningful exception is raised when attempting
268        # to parse more than one XML document per xmlparser instance,
269        # a limitation of the Expat library.
270        with self.assertRaises(expat.error) as cm:
271            parser.ParseFile(file)
272        self.assertEqual(expat.ErrorString(cm.exception.code),
273                          expat.errors.XML_ERROR_FINISHED)
274
class NamespaceSeparatorTest(unittest.TestCase):
    """Validation of the namespace_separator argument of ParserCreate()."""

    def test_legal(self):
        # Good values must not raise: the argument may be omitted, None, or
        # a single character.
        for kwargs in ({},
                       {'namespace_separator': None},
                       {'namespace_separator': ' '}):
            expat.ParserCreate(**kwargs)

    def test_illegal(self):
        # A separator of the wrong type is rejected with TypeError ...
        with self.assertRaises(TypeError) as caught:
            expat.ParserCreate(namespace_separator=42)
        self.assertEqual(
            str(caught.exception),
            "ParserCreate() argument 'namespace_separator' must be str or None, not int")

        # ... and one longer than a single character with ValueError.
        with self.assertRaises(ValueError) as caught:
            expat.ParserCreate(namespace_separator='too long')
        self.assertEqual(
            str(caught.exception),
            'namespace_separator must be at most one character, omitted, or None')

    def test_zero_length(self):
        # ParserCreate() needs to accept a namespace_separator of zero length
        # to satisfy the requirements of RDF applications that are required
        # to simply glue together the namespace URI and the localname.  Though
        # considered a wart of the RDF specifications, it needs to be supported.
        #
        # See XML-SIG mailing list thread starting with
        # http://mail.python.org/pipermail/xml-sig/2001-April/005202.html
        #
        empty_separator = ''
        expat.ParserCreate(namespace_separator=empty_separator)
308
309
310class InterningTest(unittest.TestCase):
311    def test(self):
312        # Test the interning machinery.
313        p = expat.ParserCreate()
314        L = []
315        def collector(name, *args):
316            L.append(name)
317        p.StartElementHandler = collector
318        p.EndElementHandler = collector
319        p.Parse(b"<e> <e/> <e></e> </e>", True)
320        tag = L[0]
321        self.assertEqual(len(L), 6)
322        for entry in L:
323            # L should have the same string repeated over and over.
324            self.assertTrue(tag is entry)
325
326    def test_issue9402(self):
327        # create an ExternalEntityParserCreate with buffer text
328        class ExternalOutputter:
329            def __init__(self, parser):
330                self.parser = parser
331                self.parser_result = None
332
333            def ExternalEntityRefHandler(self, context, base, sysId, pubId):
334                external_parser = self.parser.ExternalEntityParserCreate("")
335                self.parser_result = external_parser.Parse(b"", True)
336                return 1
337
338        parser = expat.ParserCreate(namespace_separator='!')
339        parser.buffer_text = 1
340        out = ExternalOutputter(parser)
341        parser.ExternalEntityRefHandler = out.ExternalEntityRefHandler
342        parser.Parse(data, True)
343        self.assertEqual(out.parser_result, 1)
344
345
class BufferTextTest(unittest.TestCase):
    """Check how the buffer_text attribute affects character data reports."""

    def setUp(self):
        # Handlers append what they receive to self.stuff, in call order.
        self.stuff = []
        self.parser = expat.ParserCreate()
        self.parser.buffer_text = 1
        self.parser.CharacterDataHandler = self.CharacterDataHandler

    def check(self, expected, label):
        """Assert that the recorded events equal *expected*.

        *label* is put at the start of the failure message.
        """
        # Build a real list for the message: a bare map() object would be
        # rendered as "<map object at 0x...>" instead of the expected values.
        self.assertEqual(self.stuff, expected,
                "%s\nstuff    = %r\nexpected = %r"
                % (label, self.stuff, [str(item) for item in expected]))

    def CharacterDataHandler(self, text):
        self.stuff.append(text)

    def StartElementHandler(self, name, attrs):
        self.stuff.append("<%s>" % name)
        # A "buffer-text" attribute in the document toggles buffering while
        # the parse is in progress.
        bt = attrs.get("buffer-text")
        if bt == "yes":
            self.parser.buffer_text = 1
        elif bt == "no":
            self.parser.buffer_text = 0

    def EndElementHandler(self, name):
        self.stuff.append("</%s>" % name)

    def CommentHandler(self, data):
        self.stuff.append("<!--%s-->" % data)

    def setHandlers(self, handlers=()):
        """Install the methods of this test case named in *handlers* on the
        parser.

        The default is an immutable empty tuple rather than a shared list.
        """
        for name in handlers:
            setattr(self.parser, name, getattr(self, name))

    def test_default_to_disabled(self):
        parser = expat.ParserCreate()
        self.assertFalse(parser.buffer_text)

    def test_buffering_enabled(self):
        # Make sure buffering is turned on
        self.assertTrue(self.parser.buffer_text)
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", True)
        self.assertEqual(self.stuff, ['123'],
                         "buffered text not properly collapsed")

    def test1(self):
        # XXX This test exposes more detail of Expat's text chunking than we
        # XXX like, but it tests what we need to concisely.
        self.setHandlers(["StartElementHandler"])
        self.parser.Parse(b"<a>1<b buffer-text='no'/>2\n3<c buffer-text='yes'/>4\n5</a>", True)
        self.assertEqual(self.stuff,
                         ["<a>", "1", "<b>", "2", "\n", "3", "<c>", "4\n5"],
                         "buffering control not reacting as expected")

    def test2(self):
        self.parser.Parse(b"<a>1<b/>&lt;2&gt;<c/>&#32;\n&#x20;3</a>", True)
        self.assertEqual(self.stuff, ["1<2> \n 3"],
                         "buffered text not properly collapsed")

    def test3(self):
        self.setHandlers(["StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", True)
        self.assertEqual(self.stuff, ["<a>", "1", "<b>", "2", "<c>", "3"],
                         "buffered text not properly split")

    def test4(self):
        self.setHandlers(["StartElementHandler", "EndElementHandler"])
        self.parser.CharacterDataHandler = None
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", True)
        self.assertEqual(self.stuff,
                         ["<a>", "<b>", "</b>", "<c>", "</c>", "</a>"])

    def test5(self):
        self.setHandlers(["StartElementHandler", "EndElementHandler"])
        self.parser.Parse(b"<a>1<b></b>2<c/>3</a>", True)
        self.assertEqual(self.stuff,
            ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3", "</a>"])

    def test6(self):
        self.setHandlers(["CommentHandler", "EndElementHandler",
                    "StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c></c>345</a> ", True)
        self.assertEqual(self.stuff,
            ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "345", "</a>"],
            "buffered text not properly split")

    def test7(self):
        self.setHandlers(["CommentHandler", "EndElementHandler",
                    "StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c></c>3<!--abc-->4<!--def-->5</a> ", True)
        self.assertEqual(self.stuff,
                         ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3",
                          "<!--abc-->", "4", "<!--def-->", "5", "</a>"],
                         "buffered text not properly split")
439
440
441# Test handling of exception from callback:
442class HandlerExceptionTest(unittest.TestCase):
443    def StartElementHandler(self, name, attrs):
444        raise RuntimeError(name)
445
446    def check_traceback_entry(self, entry, filename, funcname):
447        self.assertEqual(os.path.basename(entry[0]), filename)
448        self.assertEqual(entry[2], funcname)
449
450    def test_exception(self):
451        parser = expat.ParserCreate()
452        parser.StartElementHandler = self.StartElementHandler
453        try:
454            parser.Parse(b"<a><b><c/></b></a>", True)
455            self.fail()
456        except RuntimeError as e:
457            self.assertEqual(e.args[0], 'a',
458                             "Expected RuntimeError for element 'a', but" + \
459                             " found %r" % e.args[0])
460            # Check that the traceback contains the relevant line in pyexpat.c
461            entries = traceback.extract_tb(e.__traceback__)
462            self.assertEqual(len(entries), 3)
463            self.check_traceback_entry(entries[0],
464                                       "test_pyexpat.py", "test_exception")
465            self.check_traceback_entry(entries[1],
466                                       "pyexpat.c", "StartElement")
467            self.check_traceback_entry(entries[2],
468                                       "test_pyexpat.py", "StartElementHandler")
469            if (sysconfig.is_python_build()
470                and not (sys.platform == 'win32' and platform.machine() == 'ARM')
471                and not is_emscripten
472                and not is_wasi
473            ):
474                self.assertIn('call_with_frame("StartElement"', entries[1][3])
475
476
477# Test Current* members:
class PositionTest(unittest.TestCase):
    """Check the CurrentByteIndex/LineNumber/ColumnNumber parser members."""

    def StartElementHandler(self, name, attrs):
        self.check_pos('s')

    def EndElementHandler(self, name):
        self.check_pos('e')

    def check_pos(self, event):
        """Compare the parser's current position with the next expected one.

        *event* is 's' for a start tag and 'e' for an end tag.  self.upto is
        the index of the next entry of self.expected_list to be matched.
        """
        pos = (event,
               self.parser.CurrentByteIndex,
               self.parser.CurrentLineNumber,
               self.parser.CurrentColumnNumber)
        self.assertLess(self.upto, len(self.expected_list),
                        'too many parser events')
        expected = self.expected_list[self.upto]
        # The expected value comes first in the message, the actual one second.
        self.assertEqual(pos, expected,
                'Expected position %s, got position %s' % (expected, pos))
        self.upto += 1

    def test(self):
        self.parser = expat.ParserCreate()
        self.parser.StartElementHandler = self.StartElementHandler
        self.parser.EndElementHandler = self.EndElementHandler
        self.upto = 0
        # Entries are (event, byte index, line number, column number).
        self.expected_list = [('s', 0, 1, 0), ('s', 5, 2, 1), ('s', 11, 3, 2),
                              ('e', 15, 3, 6), ('e', 17, 4, 1), ('e', 22, 5, 0)]

        xml = b'<a>\n <b>\n  <c/>\n </b>\n</a>'
        self.parser.Parse(xml, True)
        # Without this check a parser reporting too few events would pass.
        self.assertEqual(self.upto, len(self.expected_list),
                         'too few parser events')
507
508
class sf1296433Test(unittest.TestCase):
    def test_parse_only_xml_data(self):
        # http://python.org/sf/1296433
        #
        xml = "<?xml version='1.0' encoding='iso8859'?><s>%s</s>" % ('a' * 1025)
        # this one doesn't crash
        #xml = "<?xml version='1.0'?><s>%s</s>" % ('a' * 10000)

        class SpecificException(Exception):
            pass

        def handler(text):
            raise SpecificException

        parser = expat.ParserCreate()
        parser.CharacterDataHandler = handler

        # Expect exactly the handler's exception.  Accepting any Exception
        # would also let an unrelated failure, such as a parse error, pass.
        self.assertRaises(SpecificException, parser.Parse,
                          xml.encode('iso8859'))
527
class ChardataBufferTest(unittest.TestCase):
    """
    Exercise the buffer_size attribute of the parser's chardata buffer.
    """

    def counting_handler(self, text):
        # Count the CharacterDataHandler invocations; the text is ignored.
        self.n += 1

    def small_buffer_test(self, buffer_len):
        """Parse buffer_len bytes of character data with a 1024-byte buffer
        and return how many times the character data handler ran."""
        document = (b"<?xml version='1.0' encoding='iso8859'?><s>"
                    + b'a' * buffer_len
                    + b'</s>')
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_size = 1024
        parser.buffer_text = 1

        self.n = 0
        parser.Parse(document)
        return self.n

    def test_1025_bytes(self):
        self.assertEqual(self.small_buffer_test(1025), 2)

    def test_1000_bytes(self):
        self.assertEqual(self.small_buffer_test(1000), 1)

    def test_wrong_size(self):
        parser = expat.ParserCreate()
        parser.buffer_text = 1
        # Non-positive sizes are rejected.
        for bad_size in (-1, 0):
            with self.assertRaises(ValueError):
                parser.buffer_size = bad_size
        with self.assertRaises((ValueError, OverflowError)):
            parser.buffer_size = sys.maxsize + 1
        with self.assertRaises(TypeError):
            parser.buffer_size = 512.0

    def test_unchanged_size(self):
        first_half = b"<?xml version='1.0' encoding='iso8859'?><s>" + b'a' * 512
        second_half = b'a' * 512 + b'</s>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_size = 512
        parser.buffer_text = 1

        # Feed 512 bytes of character data: the handler should be called
        # once.
        self.n = 0
        parser.Parse(first_half)
        self.assertEqual(self.n, 1)

        # Reassign to buffer_size, but assign the same size.
        parser.buffer_size = parser.buffer_size
        self.assertEqual(self.n, 1)

        # Try parsing rest of the document
        parser.Parse(second_half)
        self.assertEqual(self.n, 2)

    def test_disabling_buffer(self):
        opening = b"<?xml version='1.0' encoding='iso8859'?><a>" + b'a' * 512
        middle = b'b' * 1024
        closing = b'c' * 1024 + b'</a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 1024
        self.assertEqual(parser.buffer_size, 1024)

        # Parse one chunk of XML
        self.n = 0
        parser.Parse(opening, False)
        self.assertEqual(parser.buffer_size, 1024)
        self.assertEqual(self.n, 1)

        # Turn off buffering and parse the next chunk.
        parser.buffer_text = 0
        self.assertFalse(parser.buffer_text)
        self.assertEqual(parser.buffer_size, 1024)
        for _ in range(10):
            parser.Parse(middle, False)
        self.assertEqual(self.n, 11)

        # Turn buffering back on for the final chunk.
        parser.buffer_text = 1
        self.assertTrue(parser.buffer_text)
        self.assertEqual(parser.buffer_size, 1024)
        parser.Parse(closing, True)
        self.assertEqual(self.n, 12)

    def test_change_size_1(self):
        # Double the buffer size between two chunks.
        first_chunk = (b"<?xml version='1.0' encoding='iso8859'?><a><s>"
                       + b'a' * 1024)
        second_chunk = b'aaa</s><s>' + b'a' * 1025 + b'</s></a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 1024
        self.assertEqual(parser.buffer_size, 1024)

        self.n = 0
        parser.Parse(first_chunk, False)
        parser.buffer_size = parser.buffer_size * 2
        self.assertEqual(parser.buffer_size, 2048)
        parser.Parse(second_chunk, True)
        self.assertEqual(self.n, 2)

    def test_change_size_2(self):
        # Halve the buffer size between two chunks.
        first_chunk = (b"<?xml version='1.0' encoding='iso8859'?><a>a<s>"
                       + b'a' * 1023)
        second_chunk = b'aaa</s><s>' + b'a' * 1025 + b'</s></a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 2048
        self.assertEqual(parser.buffer_size, 2048)

        self.n = 0
        parser.Parse(first_chunk, False)
        parser.buffer_size //= 2
        self.assertEqual(parser.buffer_size, 1024)
        parser.Parse(second_chunk, True)
        self.assertEqual(self.n, 4)
649
class MalformedInputTest(unittest.TestCase):
    """Error reports for input that is not well-formed."""

    def test1(self):
        parser = expat.ParserCreate()
        with self.assertRaises(expat.ExpatError) as caught:
            parser.Parse(b"\0\r\n", True)
        self.assertEqual(str(caught.exception),
                         'unclosed token: line 2, column 0')

    def test2(self):
        # \xc2\x85 is UTF-8 encoded U+0085 (NEXT LINE)
        document = b"<?xml version\xc2\x85='1.0'?>\r\n"
        expected_message = (
            r'XML declaration not well-formed: line 1, column \d+')
        parser = expat.ParserCreate()
        with self.assertRaisesRegex(expat.ExpatError, expected_message):
            parser.Parse(document, True)
667
class ErrorMessageTest(unittest.TestCase):
    """Consistency of the xml.parsers.expat.errors lookup tables."""

    def test_codes(self):
        # verify mapping of errors.codes and errors.messages
        syntax_message = errors.XML_ERROR_SYNTAX
        syntax_code = errors.codes[syntax_message]
        self.assertEqual(syntax_message, errors.messages[syntax_code])

    def test_expaterror(self):
        # A lone '<' must be reported with the "unclosed token" error code.
        parser = expat.ParserCreate()
        with self.assertRaises(expat.ExpatError) as caught:
            parser.Parse(b'<', True)
        self.assertEqual(caught.exception.code,
                         errors.codes[errors.XML_ERROR_UNCLOSED_TOKEN])
683
684
class ForeignDTDTests(unittest.TestCase):
    """
    Tests for the UseForeignDTD method of expat parser objects.
    """

    def _external_entity_refs(self, document, *use_foreign_dtd_args):
        """Parse *document* after calling UseForeignDTD() with the given
        arguments and return the (public_id, system_id) pairs passed to
        ExternalEntityRefHandler."""
        recorded = []

        def resolve_entity(context, base, system_id, public_id):
            recorded.append((public_id, system_id))
            return 1

        parser = expat.ParserCreate()
        parser.UseForeignDTD(*use_foreign_dtd_args)
        parser.SetParamEntityParsing(expat.XML_PARAM_ENTITY_PARSING_ALWAYS)
        parser.ExternalEntityRefHandler = resolve_entity
        parser.Parse(document)
        return recorded

    def test_use_foreign_dtd(self):
        """
        If UseForeignDTD is passed True and a document without an external
        entity reference is parsed, ExternalEntityRefHandler is first called
        with None for the public and system ids.
        """
        document = b"<?xml version='1.0'?><element/>"
        self.assertEqual(self._external_entity_refs(document, True),
                         [(None, None)])

        # test UseForeignDTD() is equal to UseForeignDTD(True)
        self.assertEqual(self._external_entity_refs(document),
                         [(None, None)])

    def test_ignore_use_foreign_dtd(self):
        """
        If UseForeignDTD is passed True and a document with an external
        entity reference is parsed, ExternalEntityRefHandler is called with
        the public and system ids from the document.
        """
        document = (
            b"<?xml version='1.0'?><!DOCTYPE foo PUBLIC 'bar' 'baz'><element/>")
        self.assertEqual(self._external_entity_refs(document, True),
                         [("bar", "baz")])
735
736
# Run the test cases defined in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
739