1# pysqlite2/test/userfunctions.py: tests for user-defined functions and
2#                                  aggregates.
3#
4# Copyright (C) 2005-2007 Gerhard Häring <[email protected]>
5#
6# This file is part of pysqlite.
7#
8# This software is provided 'as-is', without any express or implied
9# warranty.  In no event will the authors be held liable for any damages
10# arising from the use of this software.
11#
12# Permission is granted to anyone to use this software for any purpose,
13# including commercial applications, and to alter it and redistribute it
14# freely, subject to the following restrictions:
15#
16# 1. The origin of this software must not be misrepresented; you must not
17#    claim that you wrote the original software. If you use this software
18#    in a product, an acknowledgment in the product documentation would be
19#    appreciated but is not required.
20# 2. Altered source versions must be plainly marked as such, and must not be
21#    misrepresented as being the original software.
22# 3. This notice may not be removed or altered from any source distribution.
23
24import contextlib
25import functools
26import io
27import re
28import sys
29import unittest
30import sqlite3 as sqlite
31
32from unittest.mock import Mock, patch
33from test.support import bigmemtest, catch_unraisable_exception, gc_collect
34
35from test.test_sqlite3.test_dbapi import cx_limit
36
37
def with_tracebacks(exc, regex="", name=""):
    """Decorate a test method so it runs twice.

    The first run happens with sqlite3 callback tracebacks enabled, inside
    ``catch_unraisable_exception``; ``check_tracebacks`` then verifies that
    an unraisable exception of type *exc* was reported, optionally that its
    message matches *regex* and that the reporting object's ``__name__`` is
    *name*.  The second run repeats the test with tracebacks disabled.
    """
    def decorator(test_method):
        pattern = None
        if regex:
            pattern = re.compile(regex)

        @functools.wraps(test_method)
        def wrapper(self, *args, **kwargs):
            # Pass 1: tracebacks on; the unraisable report is inspected when
            # check_tracebacks exits (before the outer catcher is torn down).
            with catch_unraisable_exception() as unraisable_cm:
                with check_tracebacks(self, unraisable_cm, exc, pattern, name):
                    test_method(self, *args, **kwargs)
            # Pass 2: check_tracebacks has switched tracebacks off again.
            test_method(self, *args, **kwargs)
        return wrapper
    return decorator
53
54
55@contextlib.contextmanager
56def check_tracebacks(self, cm, exc, regex, obj_name):
57    """Convenience context manager for testing callback tracebacks."""
58    sqlite.enable_callback_tracebacks(True)
59    try:
60        buf = io.StringIO()
61        with contextlib.redirect_stderr(buf):
62            yield
63
64        self.assertEqual(cm.unraisable.exc_type, exc)
65        if regex:
66            msg = str(cm.unraisable.exc_value)
67            self.assertIsNotNone(regex.search(msg))
68        if obj_name:
69            self.assertEqual(cm.unraisable.object.__name__, obj_name)
70    finally:
71        sqlite.enable_callback_tracebacks(False)
72
73
def func_returntext():
    """Return a plain ASCII string."""
    return "foo"


def func_returntextwithnull():
    """Return a string containing an embedded NUL character."""
    return "1\x002"


def func_returnunicode():
    """Return a str value (tested for round-tripping as text)."""
    return "bar"


def func_returnint():
    """Return a small integer."""
    return 42


def func_returnfloat():
    """Return a float."""
    return 3.14


def func_returnnull():
    """Return None, which SQLite sees as NULL."""
    return None


def func_returnblob():
    """Return a bytes object."""
    return b"blob"


def func_returnlonglong():
    """Return 2**31, one past the largest signed 32-bit value."""
    return 2 ** 31


def func_raiseexception():
    """Fail with ZeroDivisionError."""
    return 5 / 0


def func_memoryerror():
    """Fail with MemoryError."""
    raise MemoryError()


def func_overflowerror():
    """Fail with OverflowError."""
    raise OverflowError()
96
class AggrNoStep:
    """Aggregate that deliberately lacks a ``step`` method."""

    def __init__(self):
        """Nothing to initialise."""

    def finalize(self):
        """Return a constant result."""
        return 1
103
class AggrNoFinalize:
    """Aggregate that deliberately lacks a ``finalize`` method."""

    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        """Accept and ignore *x*."""
110
class AggrExceptionInInit:
    """Aggregate whose constructor always fails with ZeroDivisionError."""

    def __init__(self):
        5 / 0  # deliberate failure

    def step(self, x):
        """Accept and ignore *x*."""

    def finalize(self):
        """Return None."""
120
class AggrExceptionInStep:
    """Aggregate whose ``step`` always fails with ZeroDivisionError."""

    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        5 / 0  # deliberate failure

    def finalize(self):
        """Return a constant result."""
        return 42
130
class AggrExceptionInFinalize:
    """Aggregate whose ``finalize`` always fails with ZeroDivisionError."""

    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        """Accept and ignore *x*."""

    def finalize(self):
        5 / 0  # deliberate failure
140
class AggrCheckType:
    """Aggregate reporting whether a value's exact type matches a tag.

    ``step`` takes a tag ("str", "int", "float", "None" or "blob") and a
    value.  ``finalize`` returns 1 if the value from the most recent ``step``
    had exactly the tagged type, 0 if it did not, and None if ``step`` was
    never called.
    """

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        expected = {
            "str": str,
            "int": int,
            "float": float,
            "None": type(None),
            "blob": bytes,
        }[whichType]
        self.val = 1 if type(val) is expected else 0

    def finalize(self):
        return self.val
152
class AggrCheckTypes:
    """Aggregate counting values whose exact type matches a tag.

    Each ``step`` call takes a tag ("str", "int", "float", "None" or
    "blob") and any number of values.  It adds the number of those values
    whose type is exactly the tagged type to a running total, which
    ``finalize`` returns.
    """

    def __init__(self):
        self.val = 0

    def step(self, whichType, *vals):
        kinds = {
            "str": str,
            "int": int,
            "float": float,
            "None": type(None),
            "blob": bytes,
        }
        # The tag is looked up per value, so an empty *vals* never touches it.
        self.val += sum(int(kinds[whichType] is type(v)) for v in vals)

    def finalize(self):
        return self.val
165
class AggrSum:
    """Aggregate summing its inputs, starting from the float 0.0."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val = self.val + val

    def finalize(self):
        return self.val
175
class AggrText:
    """Aggregate concatenating its text inputs in order."""

    def __init__(self):
        self.txt = ""

    def step(self, txt):
        self.txt += txt

    def finalize(self):
        return self.txt
183
184
class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined functions (Connection.create_function)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returntextwithnull", 0, func_returntextwithnull)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("returnnan", 0, lambda: float("nan"))
        self.con.create_function("returntoolargeint", 0, lambda: 1 << 65)
        self.con.create_function("return_noncont_blob", 0,
                                 lambda: memoryview(b"blob")[::2])
        self.con.create_function("raiseexception", 0, func_raiseexception)
        self.con.create_function("memoryerror", 0, func_memoryerror)
        self.con.create_function("overflowerror", 0, func_overflowerror)

        self.con.create_function("isblob", 1, lambda x: isinstance(x, bytes))
        self.con.create_function("isnone", 1, lambda x: x is None)
        self.con.create_function("spam", -1, lambda *x: len(x))
        self.con.execute("create table test(t text)")

    def tearDown(self):
        self.con.close()

    def test_func_error_on_create(self):
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def test_func_too_many_args(self):
        category = sqlite.SQLITE_LIMIT_FUNCTION_ARG
        msg = "too many arguments on function"
        with cx_limit(self.con, category=category, limit=1):
            self.con.execute("select abs(-1)")
            with self.assertRaisesRegex(sqlite.OperationalError, msg):
                self.con.execute("select max(1, 2)")

    def test_func_ref_count(self):
        def getfunc():
            def f():
                return 1
            return f
        f = getfunc()
        globals()["foo"] = f
        self.con.create_function("reftest", 0, f)
        cur = self.con.cursor()
        cur.execute("select reftest()")

    def test_func_return_text(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "foo")

    def test_func_return_text_with_null_char(self):
        cur = self.con.cursor()
        res = cur.execute("select returntextwithnull()").fetchone()[0]
        self.assertEqual(type(res), str)
        self.assertEqual(res, "1\x002")

    def test_func_return_unicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "bar")

    def test_func_return_int(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def test_func_return_float(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        self.assertAlmostEqual(val, 3.14, delta=0.001)

    def test_func_return_null(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), type(None))
        self.assertEqual(val, None)

    def test_func_return_blob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), bytes)
        self.assertEqual(val, b"blob")

    def test_func_return_long_long(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1<<31)

    def test_func_return_nan(self):
        cur = self.con.cursor()
        cur.execute("select returnnan()")
        self.assertIsNone(cur.fetchone()[0])

    def test_func_return_too_large_int_error_message(self):
        # This name must differ from the decorated
        # test_func_return_too_large_int further down. A second def with the
        # same name replaces the first in the class namespace, so this check
        # would never run.
        self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
                               self.con.execute, "select returntoolargeint()")

    @with_tracebacks(ZeroDivisionError, name="func_raiseexception")
    def test_func_exception(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(str(cm.exception), 'user-defined function raised exception')

    @with_tracebacks(MemoryError, name="func_memoryerror")
    def test_func_memory_error(self):
        cur = self.con.cursor()
        with self.assertRaises(MemoryError):
            cur.execute("select memoryerror()")
            cur.fetchone()

    @with_tracebacks(OverflowError, name="func_overflowerror")
    def test_func_overflow_error(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.DataError):
            cur.execute("select overflowerror()")
            cur.fetchone()

    def test_any_arguments(self):
        cur = self.con.cursor()
        cur.execute("select spam(?, ?)", (1, 2))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_empty_blob(self):
        cur = self.con.execute("select isblob(x'')")
        self.assertTrue(cur.fetchone()[0])

    def test_nan_float(self):
        cur = self.con.execute("select isnone(?)", (float("nan"),))
        # SQLite has no concept of nan; it is converted to NULL
        self.assertTrue(cur.fetchone()[0])

    def test_too_large_int(self):
        err = "Python int too large to convert to SQLite INTEGER"
        self.assertRaisesRegex(OverflowError, err, self.con.execute,
                               "select spam(?)", (1 << 65,))

    def test_non_contiguous_blob(self):
        self.assertRaisesRegex(BufferError,
                               "underlying buffer is not C-contiguous",
                               self.con.execute, "select spam(?)",
                               (memoryview(b"blob")[::2],))

    @with_tracebacks(BufferError, regex="buffer.*contiguous")
    def test_return_non_contiguous_blob(self):
        with self.assertRaises(sqlite.OperationalError):
            cur = self.con.execute("select return_noncont_blob()")
            cur.fetchone()

    def test_param_surrogates(self):
        self.assertRaisesRegex(UnicodeEncodeError, "surrogates not allowed",
                               self.con.execute, "select spam(?)",
                               ("\ud803\ude6d",))

    def test_func_params(self):
        results = []
        def append_result(arg):
            results.append((arg, type(arg)))
        self.con.create_function("test_params", 1, append_result)

        dataset = [
            (42, int),
            (-1, int),
            (1234567890123456789, int),
            (4611686018427387905, int),  # 63-bit int with non-zero low bits
            (3.14, float),
            (float('inf'), float),
            ("text", str),
            ("1\x002", str),
            ("\u02e2q\u02e1\u2071\u1d57\u1d49", str),
            (b"blob", bytes),
            (bytearray(range(2)), bytes),
            (memoryview(b"blob"), bytes),
            (None, type(None)),
        ]
        for val, _ in dataset:
            cur = self.con.execute("select test_params(?)", (val,))
            cur.fetchone()
        self.assertEqual(dataset, results)

    # Regarding deterministic functions:
    #
    # Between 3.8.3 and 3.15.0, deterministic functions were only used to
    # optimize inner loops, so for those versions we can only test if the
    # sqlite machinery has factored out a call or not. From 3.15.0 and onward,
    # deterministic functions were permitted in WHERE clauses of partial
    # indices, which allows testing based on syntax instead of relying on
    # the query optimizer.
    @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
    def test_func_non_deterministic(self):
        mock = Mock(return_value=None)
        self.con.create_function("nondeterministic", 0, mock, deterministic=False)
        if sqlite.sqlite_version_info < (3, 15, 0):
            self.con.execute("select nondeterministic() = nondeterministic()")
            self.assertEqual(mock.call_count, 2)
        else:
            with self.assertRaises(sqlite.OperationalError):
                self.con.execute("create index t on test(t) where nondeterministic() is not null")

    @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
    def test_func_deterministic(self):
        mock = Mock(return_value=None)
        self.con.create_function("deterministic", 0, mock, deterministic=True)
        if sqlite.sqlite_version_info < (3, 15, 0):
            self.con.execute("select deterministic() = deterministic()")
            self.assertEqual(mock.call_count, 1)
        else:
            try:
                self.con.execute("create index t on test(t) where deterministic() is not null")
            except sqlite.OperationalError:
                self.fail("Unexpected failure while creating partial index")

    @unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed")
    def test_func_deterministic_not_supported(self):
        with self.assertRaises(sqlite.NotSupportedError):
            self.con.create_function("deterministic", 0, int, deterministic=True)

    def test_func_deterministic_keyword_only(self):
        with self.assertRaises(TypeError):
            self.con.create_function("deterministic", 0, int, True)

    def test_function_destructor_via_gc(self):
        # See bpo-44304: The destructor of the user function can
        # crash if is called without the GIL from the gc functions
        dest = sqlite.connect(':memory:')
        def md5sum(t):
            return

        dest.create_function("md5", 1, md5sum)
        x = dest("create table lang (name, first_appeared)")
        del md5sum, dest

        # Put x into a reference cycle so only the cyclic gc can free it.
        y = [x]
        y.append(y)

        del x,y
        gc_collect()

    @with_tracebacks(OverflowError)
    def test_func_return_too_large_int(self):
        cur = self.con.cursor()
        for value in 2**63, -2**63-1, 2**64:
            self.con.create_function("largeint", 0, lambda value=value: value)
            with self.assertRaises(sqlite.DataError):
                cur.execute("select largeint()")

    @with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr")
    def test_func_return_text_with_surrogates(self):
        cur = self.con.cursor()
        self.con.create_function("pychr", 1, chr)
        for value in 0xd8ff, 0xdcff:
            with self.assertRaises(sqlite.OperationalError):
                cur.execute("select pychr(?)", (value,))

    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
    @bigmemtest(size=2**31, memuse=3, dry_run=False)
    def test_func_return_too_large_text(self, size):
        cur = self.con.cursor()
        for size in 2**31-1, 2**31:
            self.con.create_function("largetext", 0, lambda size=size: "b" * size)
            with self.assertRaises(sqlite.DataError):
                cur.execute("select largetext()")

    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
    @bigmemtest(size=2**31, memuse=2, dry_run=False)
    def test_func_return_too_large_blob(self, size):
        cur = self.con.cursor()
        for size in 2**31-1, 2**31:
            self.con.create_function("largeblob", 0, lambda size=size: b"b" * size)
            with self.assertRaises(sqlite.DataError):
                cur.execute("select largeblob()")

    def test_func_return_illegal_value(self):
        # A TestCase instance is not a value sqlite3 can convert to SQL.
        self.con.create_function("badreturn", 0, lambda: self)
        msg = "user-defined function raised exception"
        self.assertRaisesRegex(sqlite.OperationalError, msg,
                               self.con.execute, "select badreturn()")
483
484
class WindowSumInt:
    """Window aggregate maintaining a running sum over the current frame.

    ``step`` adds a value entering the frame, ``inverse`` subtracts one
    leaving it, and both ``value`` and ``finalize`` report the current total.
    """

    def __init__(self):
        self.count = 0

    def step(self, value):
        self.count = self.count + value

    def value(self):
        return self.count

    def inverse(self, value):
        self.count = self.count - value

    def finalize(self):
        return self.count
500
class BadWindow(Exception):
    """Sentinel exception the window-function tests inject as a side effect."""
503
504
@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0),
                 "Requires SQLite 3.25.0 or newer")
class WindowFunctionTests(unittest.TestCase):
    """Tests for aggregate window functions (Connection.create_window_function)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.cur = self.con.cursor()

        # Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc
        values = [
            ("a", 4),
            ("b", 5),
            ("c", 3),
            ("d", 8),
            ("e", 1),
        ]
        with self.con:
            self.con.execute("create table test(x, y)")
            self.con.executemany("insert into test values(?, ?)", values)
        # Each row's y summed over the frame [previous, current, next].
        self.expected = [
            ("a", 9),
            ("b", 12),
            ("c", 16),
            ("d", 12),
            ("e", 9),
        ]
        # Template: the window function name is substituted with ``%``.
        self.query = """
            select x, %s(y) over (
                order by x rows between 1 preceding and 1 following
            ) as sum_y
            from test order by x
        """
        self.con.create_window_function("sumint", 1, WindowSumInt)

    def tearDown(self):
        # Release the per-test connection; without this every test leaked it.
        self.con.close()

    def test_win_sum_int(self):
        self.cur.execute(self.query % "sumint")
        self.assertEqual(self.cur.fetchall(), self.expected)

    def test_win_error_on_create(self):
        self.assertRaises(sqlite.ProgrammingError,
                          self.con.create_window_function,
                          "shouldfail", -100, WindowSumInt)

    @with_tracebacks(BadWindow)
    def test_win_exception_in_method(self):
        for meth in "__init__", "step", "value", "inverse":
            with self.subTest(meth=meth):
                with patch.object(WindowSumInt, meth, side_effect=BadWindow):
                    name = f"exc_{meth}"
                    self.con.create_window_function(name, 1, WindowSumInt)
                    msg = f"'{meth}' method raised error"
                    with self.assertRaisesRegex(sqlite.OperationalError, msg):
                        self.cur.execute(self.query % name)
                        self.cur.fetchall()

    @with_tracebacks(BadWindow)
    def test_win_exception_in_finalize(self):
        # Note: SQLite does not (as of version 3.38.0) propagate finalize
        # callback errors to sqlite3_step(); this implies that OperationalError
        # is _not_ raised.
        with patch.object(WindowSumInt, "finalize", side_effect=BadWindow):
            name = "exception_in_finalize"
            self.con.create_window_function(name, 1, WindowSumInt)
            self.cur.execute(self.query % name)
            self.cur.fetchall()

    @with_tracebacks(AttributeError)
    def test_win_missing_method(self):
        class MissingValue:
            def step(self, x): pass
            def inverse(self, x): pass
            def finalize(self): return 42

        class MissingInverse:
            def step(self, x): pass
            def value(self): return 42
            def finalize(self): return 42

        class MissingStep:
            def value(self): return 42
            def inverse(self, x): pass
            def finalize(self): return 42

        dataset = (
            ("step", MissingStep),
            ("value", MissingValue),
            ("inverse", MissingInverse),
        )
        for meth, cls in dataset:
            with self.subTest(meth=meth, cls=cls):
                name = f"exc_{meth}"
                self.con.create_window_function(name, 1, cls)
                with self.assertRaisesRegex(sqlite.OperationalError,
                                            f"'{meth}' method not defined"):
                    self.cur.execute(self.query % name)
                    self.cur.fetchall()

    @with_tracebacks(AttributeError)
    def test_win_missing_finalize(self):
        # Note: SQLite does not (as of version 3.38.0) propagate finalize
        # callback errors to sqlite3_step(); this implies that OperationalError
        # is _not_ raised.
        class MissingFinalize:
            def step(self, x): pass
            def value(self): return 42
            def inverse(self, x): pass

        name = "missing_finalize"
        self.con.create_window_function(name, 1, MissingFinalize)
        self.cur.execute(self.query % name)
        self.cur.fetchall()

    def test_win_clear_function(self):
        # Passing None as the class unregisters the function.
        self.con.create_window_function("sumint", 1, None)
        self.assertRaises(sqlite.OperationalError, self.cur.execute,
                          self.query % "sumint")

    def test_win_redefine_function(self):
        # Redefine WindowSumInt; adjust the expected results accordingly.
        class Redefined(WindowSumInt):
            def step(self, value): self.count += value * 2
            def inverse(self, value): self.count -= value * 2
        expected = [(v[0], v[1]*2) for v in self.expected]

        self.con.create_window_function("sumint", 1, Redefined)
        self.cur.execute(self.query % "sumint")
        self.assertEqual(self.cur.fetchall(), expected)

    def test_win_error_value_return(self):
        class ErrorValueReturn:
            def __init__(self): pass
            def step(self, x): pass
            def value(self): return 1 << 65

        self.con.create_window_function("err_val_ret", 1, ErrorValueReturn)
        self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
                               self.cur.execute, self.query % "err_val_ret")
641
642
class AggregateTests(unittest.TestCase):
    """Tests for user-defined aggregates (Connection.create_aggregate)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, memoryview(b"blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
        self.con.create_aggregate("mysum", 1, AggrSum)
        self.con.create_aggregate("aggtxt", 1, AggrText)

    def tearDown(self):
        # Release the per-test connection (previously it was never closed).
        self.con.close()

    def test_aggr_error_on_create(self):
        # -100 is not a valid argument count. Exercise create_aggregate, the
        # API this class covers, rather than create_function.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    @with_tracebacks(AttributeError, name="AggrNoStep")
    def test_aggr_no_step(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select nostep(t) from test")
        self.assertEqual(str(cm.exception),
                         "user-defined aggregate's 'step' method not defined")

    def test_aggr_no_finalize(self):
        cur = self.con.cursor()
        msg = "user-defined aggregate's 'finalize' method not defined"
        with self.assertRaisesRegex(sqlite.OperationalError, msg):
            cur.execute("select nofinalize(t) from test")
            val = cur.fetchone()[0]

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
    def test_aggr_exception_in_init(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excInit(t) from test")
            val = cur.fetchone()[0]
        self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep")
    def test_aggr_exception_in_step(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excStep(t) from test")
            val = cur.fetchone()[0]
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize")
    def test_aggr_exception_in_finalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excFinalize(t) from test")
            val = cur.fetchone()[0]
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def test_aggr_check_param_str(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('str', ?, ?)", ("foo", str()))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_aggr_check_param_int(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_params_int(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('int', ?, ?)", (42, 24))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_aggr_check_param_float(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_param_none(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_param_blob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_aggr_sum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)

    def test_aggr_no_match(self):
        # Zero input rows: step is never called and the result is NULL.
        cur = self.con.execute("select mysum(i) from (select 1 as i) where i == 0")
        val = cur.fetchone()[0]
        self.assertIsNone(val)

    def test_aggr_text(self):
        cur = self.con.cursor()
        for txt in ["foo", "1\x002"]:
            with self.subTest(txt=txt):
                cur.execute("select aggtxt(?) from test", (txt,))
                val = cur.fetchone()[0]
                self.assertEqual(val, txt)
773
774
775class AuthorizerTests(unittest.TestCase):
776    @staticmethod
777    def authorizer_cb(action, arg1, arg2, dbname, source):
778        if action != sqlite.SQLITE_SELECT:
779            return sqlite.SQLITE_DENY
780        if arg2 == 'c2' or arg1 == 't2':
781            return sqlite.SQLITE_DENY
782        return sqlite.SQLITE_OK
783
784    def setUp(self):
785        self.con = sqlite.connect(":memory:")
786        self.con.executescript("""
787            create table t1 (c1, c2);
788            create table t2 (c1, c2);
789            insert into t1 (c1, c2) values (1, 2);
790            insert into t2 (c1, c2) values (4, 5);
791            """)
792
793        # For our security test:
794        self.con.execute("select c2 from t2")
795
796        self.con.set_authorizer(self.authorizer_cb)
797
798    def tearDown(self):
799        pass
800
801    def test_table_access(self):
802        with self.assertRaises(sqlite.DatabaseError) as cm:
803            self.con.execute("select * from t2")
804        self.assertIn('prohibited', str(cm.exception))
805
806    def test_column_access(self):
807        with self.assertRaises(sqlite.DatabaseError) as cm:
808            self.con.execute("select c2 from t1")
809        self.assertIn('prohibited', str(cm.exception))
810
811    def test_clear_authorizer(self):
812        self.con.set_authorizer(None)
813        self.con.execute("select * from t2")
814        self.con.execute("select c2 from t1")
815
816
class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Re-run the authorizer tests with a callback that raises ValueError.

    The callback raises wherever the base callback returns SQLITE_DENY, and
    the inherited tests still expect a "prohibited" DatabaseError.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            raise ValueError
        if arg2 == 'c2' or arg1 == 't2':
            raise ValueError
        return sqlite.SQLITE_OK

    @with_tracebacks(ValueError, name="authorizer_cb")
    def test_table_access(self):
        super().test_table_access()

    @with_tracebacks(ValueError, name="authorizer_cb")
    def test_column_access(self):
        # This previously delegated to test_table_access by mistake, so the
        # column-level denial was never exercised with a raising callback.
        super().test_column_access()
833
class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Re-run the authorizer tests with a callback returning a float.

    Where the base callback returns SQLITE_DENY, this one returns ``0.0``, a
    value of the wrong type. The inherited tests still expect a
    "prohibited" DatabaseError.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        rejected = (action != sqlite.SQLITE_SELECT
                    or arg2 == 'c2'
                    or arg1 == 't2')
        return 0.0 if rejected else sqlite.SQLITE_OK
842
class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Re-run the authorizer tests with a callback returning ``2**32``.

    Where the base callback returns SQLITE_DENY, this one returns an
    out-of-range integer. The inherited tests still expect a "prohibited"
    DatabaseError.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        allowed = (action == sqlite.SQLITE_SELECT
                   and not (arg2 == 'c2' or arg1 == 't2'))
        if allowed:
            return sqlite.SQLITE_OK
        return 1 << 32
851
852
if __name__ == "__main__":
    # Run this module's TestCases when the file is executed directly.
    unittest.main()
855