1# pysqlite2/test/userfunctions.py: tests for user-defined functions and 2# aggregates. 3# 4# Copyright (C) 2005-2007 Gerhard Häring <[email protected]> 5# 6# This file is part of pysqlite. 7# 8# This software is provided 'as-is', without any express or implied 9# warranty. In no event will the authors be held liable for any damages 10# arising from the use of this software. 11# 12# Permission is granted to anyone to use this software for any purpose, 13# including commercial applications, and to alter it and redistribute it 14# freely, subject to the following restrictions: 15# 16# 1. The origin of this software must not be misrepresented; you must not 17# claim that you wrote the original software. If you use this software 18# in a product, an acknowledgment in the product documentation would be 19# appreciated but is not required. 20# 2. Altered source versions must be plainly marked as such, and must not be 21# misrepresented as being the original software. 22# 3. This notice may not be removed or altered from any source distribution. 23 24import contextlib 25import functools 26import io 27import re 28import sys 29import unittest 30import sqlite3 as sqlite 31 32from unittest.mock import Mock, patch 33from test.support import bigmemtest, catch_unraisable_exception, gc_collect 34 35from test.test_sqlite3.test_dbapi import cx_limit 36 37 38def with_tracebacks(exc, regex="", name=""): 39 """Convenience decorator for testing callback tracebacks.""" 40 def decorator(func): 41 _regex = re.compile(regex) if regex else None 42 @functools.wraps(func) 43 def wrapper(self, *args, **kwargs): 44 with catch_unraisable_exception() as cm: 45 # First, run the test with traceback enabled. 46 with check_tracebacks(self, cm, exc, _regex, name): 47 func(self, *args, **kwargs) 48 49 # Then run the test with traceback disabled. 50 func(self, *args, **kwargs) 51 return wrapper 52 return decorator 53 54 55@contextlib.contextmanager 56def check_tracebacks(self, cm, exc, regex, obj_name): 57 """Convenience context manager for testing callback tracebacks.""" 58 sqlite.enable_callback_tracebacks(True) 59 try: 60 buf = io.StringIO() 61 with contextlib.redirect_stderr(buf): 62 yield 63 64 self.assertEqual(cm.unraisable.exc_type, exc) 65 if regex: 66 msg = str(cm.unraisable.exc_value) 67 self.assertIsNotNone(regex.search(msg)) 68 if obj_name: 69 self.assertEqual(cm.unraisable.object.__name__, obj_name) 70 finally: 71 sqlite.enable_callback_tracebacks(False) 72 73 74def func_returntext(): 75 return "foo" 76def func_returntextwithnull(): 77 return "1\x002" 78def func_returnunicode(): 79 return "bar" 80def func_returnint(): 81 return 42 82def func_returnfloat(): 83 return 3.14 84def func_returnnull(): 85 return None 86def func_returnblob(): 87 return b"blob" 88def func_returnlonglong(): 89 return 1<<31 90def func_raiseexception(): 91 5/0 92def func_memoryerror(): 93 raise MemoryError 94def func_overflowerror(): 95 raise OverflowError 96 97class AggrNoStep: 98 def __init__(self): 99 pass 100 101 def finalize(self): 102 return 1 103 104class AggrNoFinalize: 105 def __init__(self): 106 pass 107 108 def step(self, x): 109 pass 110 111class AggrExceptionInInit: 112 def __init__(self): 113 5/0 114 115 def step(self, x): 116 pass 117 118 def finalize(self): 119 pass 120 121class AggrExceptionInStep: 122 def __init__(self): 123 pass 124 125 def step(self, x): 126 5/0 127 128 def finalize(self): 129 return 42 130 131class AggrExceptionInFinalize: 132 def __init__(self): 133 pass 134 135 def step(self, x): 136 pass 137 138 def finalize(self): 139 5/0 140 141class AggrCheckType: 142 def __init__(self): 143 self.val = None 144 145 def step(self, whichType, val): 146 theType = {"str": str, "int": int, "float": float, "None": type(None), 147 "blob": bytes} 148 self.val = int(theType[whichType] is type(val)) 149 150 def finalize(self): 151 return self.val 152 153class AggrCheckTypes: 154 def __init__(self): 155 self.val = 0 156 157 def step(self, whichType, *vals): 158 theType = {"str": str, "int": int, "float": float, "None": type(None), 159 "blob": bytes} 160 for val in vals: 161 self.val += int(theType[whichType] is type(val)) 162 163 def finalize(self): 164 return self.val 165 166class AggrSum: 167 def __init__(self): 168 self.val = 0.0 169 170 def step(self, val): 171 self.val += val 172 173 def finalize(self): 174 return self.val 175 176class AggrText: 177 def __init__(self): 178 self.txt = "" 179 def step(self, txt): 180 self.txt = self.txt + txt 181 def finalize(self): 182 return self.txt 183 184 185class FunctionTests(unittest.TestCase): 186 def setUp(self): 187 self.con = sqlite.connect(":memory:") 188 189 self.con.create_function("returntext", 0, func_returntext) 190 self.con.create_function("returntextwithnull", 0, func_returntextwithnull) 191 self.con.create_function("returnunicode", 0, func_returnunicode) 192 self.con.create_function("returnint", 0, func_returnint) 193 self.con.create_function("returnfloat", 0, func_returnfloat) 194 self.con.create_function("returnnull", 0, func_returnnull) 195 self.con.create_function("returnblob", 0, func_returnblob) 196 self.con.create_function("returnlonglong", 0, func_returnlonglong) 197 self.con.create_function("returnnan", 0, lambda: float("nan")) 198 self.con.create_function("returntoolargeint", 0, lambda: 1 << 65) 199 self.con.create_function("return_noncont_blob", 0, 200 lambda: memoryview(b"blob")[::2]) 201 self.con.create_function("raiseexception", 0, func_raiseexception) 202 self.con.create_function("memoryerror", 0, func_memoryerror) 203 self.con.create_function("overflowerror", 0, func_overflowerror) 204 205 self.con.create_function("isblob", 1, lambda x: isinstance(x, bytes)) 206 self.con.create_function("isnone", 1, lambda x: x is None) 207 self.con.create_function("spam", -1, lambda *x: len(x)) 208 self.con.execute("create table test(t text)") 209 210 def tearDown(self): 211 self.con.close() 212 213 def test_func_error_on_create(self): 214 with self.assertRaises(sqlite.OperationalError): 215 self.con.create_function("bla", -100, lambda x: 2*x) 216 217 def test_func_too_many_args(self): 218 category = sqlite.SQLITE_LIMIT_FUNCTION_ARG 219 msg = "too many arguments on function" 220 with cx_limit(self.con, category=category, limit=1): 221 self.con.execute("select abs(-1)"); 222 with self.assertRaisesRegex(sqlite.OperationalError, msg): 223 self.con.execute("select max(1, 2)"); 224 225 def test_func_ref_count(self): 226 def getfunc(): 227 def f(): 228 return 1 229 return f 230 f = getfunc() 231 globals()["foo"] = f 232 # self.con.create_function("reftest", 0, getfunc()) 233 self.con.create_function("reftest", 0, f) 234 cur = self.con.cursor() 235 cur.execute("select reftest()") 236 237 def test_func_return_text(self): 238 cur = self.con.cursor() 239 cur.execute("select returntext()") 240 val = cur.fetchone()[0] 241 self.assertEqual(type(val), str) 242 self.assertEqual(val, "foo") 243 244 def test_func_return_text_with_null_char(self): 245 cur = self.con.cursor() 246 res = cur.execute("select returntextwithnull()").fetchone()[0] 247 self.assertEqual(type(res), str) 248 self.assertEqual(res, "1\x002") 249 250 def test_func_return_unicode(self): 251 cur = self.con.cursor() 252 cur.execute("select returnunicode()") 253 val = cur.fetchone()[0] 254 self.assertEqual(type(val), str) 255 self.assertEqual(val, "bar") 256 257 def test_func_return_int(self): 258 cur = self.con.cursor() 259 cur.execute("select returnint()") 260 val = cur.fetchone()[0] 261 self.assertEqual(type(val), int) 262 self.assertEqual(val, 42) 263 264 def test_func_return_float(self): 265 cur = self.con.cursor() 266 cur.execute("select returnfloat()") 267 val = cur.fetchone()[0] 268 self.assertEqual(type(val), float) 269 if val < 3.139 or val > 3.141: 270 self.fail("wrong value") 271 272 def test_func_return_null(self): 273 cur = self.con.cursor() 274 cur.execute("select returnnull()") 275 val = cur.fetchone()[0] 276 self.assertEqual(type(val), type(None)) 277 self.assertEqual(val, None) 278 279 def test_func_return_blob(self): 280 cur = self.con.cursor() 281 cur.execute("select returnblob()") 282 val = cur.fetchone()[0] 283 self.assertEqual(type(val), bytes) 284 self.assertEqual(val, b"blob") 285 286 def test_func_return_long_long(self): 287 cur = self.con.cursor() 288 cur.execute("select returnlonglong()") 289 val = cur.fetchone()[0] 290 self.assertEqual(val, 1<<31) 291 292 def test_func_return_nan(self): 293 cur = self.con.cursor() 294 cur.execute("select returnnan()") 295 self.assertIsNone(cur.fetchone()[0]) 296 297 def test_func_return_too_large_int(self): 298 cur = self.con.cursor() 299 self.assertRaisesRegex(sqlite.DataError, "string or blob too big", 300 self.con.execute, "select returntoolargeint()") 301 302 @with_tracebacks(ZeroDivisionError, name="func_raiseexception") 303 def test_func_exception(self): 304 cur = self.con.cursor() 305 with self.assertRaises(sqlite.OperationalError) as cm: 306 cur.execute("select raiseexception()") 307 cur.fetchone() 308 self.assertEqual(str(cm.exception), 'user-defined function raised exception') 309 310 @with_tracebacks(MemoryError, name="func_memoryerror") 311 def test_func_memory_error(self): 312 cur = self.con.cursor() 313 with self.assertRaises(MemoryError): 314 cur.execute("select memoryerror()") 315 cur.fetchone() 316 317 @with_tracebacks(OverflowError, name="func_overflowerror") 318 def test_func_overflow_error(self): 319 cur = self.con.cursor() 320 with self.assertRaises(sqlite.DataError): 321 cur.execute("select overflowerror()") 322 cur.fetchone() 323 324 def test_any_arguments(self): 325 cur = self.con.cursor() 326 cur.execute("select spam(?, ?)", (1, 2)) 327 val = cur.fetchone()[0] 328 self.assertEqual(val, 2) 329 330 def test_empty_blob(self): 331 cur = self.con.execute("select isblob(x'')") 332 self.assertTrue(cur.fetchone()[0]) 333 334 def test_nan_float(self): 335 cur = self.con.execute("select isnone(?)", (float("nan"),)) 336 # SQLite has no concept of nan; it is converted to NULL 337 self.assertTrue(cur.fetchone()[0]) 338 339 def test_too_large_int(self): 340 err = "Python int too large to convert to SQLite INTEGER" 341 self.assertRaisesRegex(OverflowError, err, self.con.execute, 342 "select spam(?)", (1 << 65,)) 343 344 def test_non_contiguous_blob(self): 345 self.assertRaisesRegex(BufferError, 346 "underlying buffer is not C-contiguous", 347 self.con.execute, "select spam(?)", 348 (memoryview(b"blob")[::2],)) 349 350 @with_tracebacks(BufferError, regex="buffer.*contiguous") 351 def test_return_non_contiguous_blob(self): 352 with self.assertRaises(sqlite.OperationalError): 353 cur = self.con.execute("select return_noncont_blob()") 354 cur.fetchone() 355 356 def test_param_surrogates(self): 357 self.assertRaisesRegex(UnicodeEncodeError, "surrogates not allowed", 358 self.con.execute, "select spam(?)", 359 ("\ud803\ude6d",)) 360 361 def test_func_params(self): 362 results = [] 363 def append_result(arg): 364 results.append((arg, type(arg))) 365 self.con.create_function("test_params", 1, append_result) 366 367 dataset = [ 368 (42, int), 369 (-1, int), 370 (1234567890123456789, int), 371 (4611686018427387905, int), # 63-bit int with non-zero low bits 372 (3.14, float), 373 (float('inf'), float), 374 ("text", str), 375 ("1\x002", str), 376 ("\u02e2q\u02e1\u2071\u1d57\u1d49", str), 377 (b"blob", bytes), 378 (bytearray(range(2)), bytes), 379 (memoryview(b"blob"), bytes), 380 (None, type(None)), 381 ] 382 for val, _ in dataset: 383 cur = self.con.execute("select test_params(?)", (val,)) 384 cur.fetchone() 385 self.assertEqual(dataset, results) 386 387 # Regarding deterministic functions: 388 # 389 # Between 3.8.3 and 3.15.0, deterministic functions were only used to 390 # optimize inner loops, so for those versions we can only test if the 391 # sqlite machinery has factored out a call or not. From 3.15.0 and onward, 392 # deterministic functions were permitted in WHERE clauses of partial 393 # indices, which allows testing based on syntax, iso. the query optimizer. 394 @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") 395 def test_func_non_deterministic(self): 396 mock = Mock(return_value=None) 397 self.con.create_function("nondeterministic", 0, mock, deterministic=False) 398 if sqlite.sqlite_version_info < (3, 15, 0): 399 self.con.execute("select nondeterministic() = nondeterministic()") 400 self.assertEqual(mock.call_count, 2) 401 else: 402 with self.assertRaises(sqlite.OperationalError): 403 self.con.execute("create index t on test(t) where nondeterministic() is not null") 404 405 @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") 406 def test_func_deterministic(self): 407 mock = Mock(return_value=None) 408 self.con.create_function("deterministic", 0, mock, deterministic=True) 409 if sqlite.sqlite_version_info < (3, 15, 0): 410 self.con.execute("select deterministic() = deterministic()") 411 self.assertEqual(mock.call_count, 1) 412 else: 413 try: 414 self.con.execute("create index t on test(t) where deterministic() is not null") 415 except sqlite.OperationalError: 416 self.fail("Unexpected failure while creating partial index") 417 418 @unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed") 419 def test_func_deterministic_not_supported(self): 420 with self.assertRaises(sqlite.NotSupportedError): 421 self.con.create_function("deterministic", 0, int, deterministic=True) 422 423 def test_func_deterministic_keyword_only(self): 424 with self.assertRaises(TypeError): 425 self.con.create_function("deterministic", 0, int, True) 426 427 def test_function_destructor_via_gc(self): 428 # See bpo-44304: The destructor of the user function can 429 # crash if is called without the GIL from the gc functions 430 dest = sqlite.connect(':memory:') 431 def md5sum(t): 432 return 433 434 dest.create_function("md5", 1, md5sum) 435 x = dest("create table lang (name, first_appeared)") 436 del md5sum, dest 437 438 y = [x] 439 y.append(y) 440 441 del x,y 442 gc_collect() 443 444 @with_tracebacks(OverflowError) 445 def test_func_return_too_large_int(self): 446 cur = self.con.cursor() 447 for value in 2**63, -2**63-1, 2**64: 448 self.con.create_function("largeint", 0, lambda value=value: value) 449 with self.assertRaises(sqlite.DataError): 450 cur.execute("select largeint()") 451 452 @with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr") 453 def test_func_return_text_with_surrogates(self): 454 cur = self.con.cursor() 455 self.con.create_function("pychr", 1, chr) 456 for value in 0xd8ff, 0xdcff: 457 with self.assertRaises(sqlite.OperationalError): 458 cur.execute("select pychr(?)", (value,)) 459 460 @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') 461 @bigmemtest(size=2**31, memuse=3, dry_run=False) 462 def test_func_return_too_large_text(self, size): 463 cur = self.con.cursor() 464 for size in 2**31-1, 2**31: 465 self.con.create_function("largetext", 0, lambda size=size: "b" * size) 466 with self.assertRaises(sqlite.DataError): 467 cur.execute("select largetext()") 468 469 @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') 470 @bigmemtest(size=2**31, memuse=2, dry_run=False) 471 def test_func_return_too_large_blob(self, size): 472 cur = self.con.cursor() 473 for size in 2**31-1, 2**31: 474 self.con.create_function("largeblob", 0, lambda size=size: b"b" * size) 475 with self.assertRaises(sqlite.DataError): 476 cur.execute("select largeblob()") 477 478 def test_func_return_illegal_value(self): 479 self.con.create_function("badreturn", 0, lambda: self) 480 msg = "user-defined function raised exception" 481 self.assertRaisesRegex(sqlite.OperationalError, msg, 482 self.con.execute, "select badreturn()") 483 484 485class WindowSumInt: 486 def __init__(self): 487 self.count = 0 488 489 def step(self, value): 490 self.count += value 491 492 def value(self): 493 return self.count 494 495 def inverse(self, value): 496 self.count -= value 497 498 def finalize(self): 499 return self.count 500 501class BadWindow(Exception): 502 pass 503 504 505@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0), 506 "Requires SQLite 3.25.0 or newer") 507class WindowFunctionTests(unittest.TestCase): 508 def setUp(self): 509 self.con = sqlite.connect(":memory:") 510 self.cur = self.con.cursor() 511 512 # Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc 513 values = [ 514 ("a", 4), 515 ("b", 5), 516 ("c", 3), 517 ("d", 8), 518 ("e", 1), 519 ] 520 with self.con: 521 self.con.execute("create table test(x, y)") 522 self.con.executemany("insert into test values(?, ?)", values) 523 self.expected = [ 524 ("a", 9), 525 ("b", 12), 526 ("c", 16), 527 ("d", 12), 528 ("e", 9), 529 ] 530 self.query = """ 531 select x, %s(y) over ( 532 order by x rows between 1 preceding and 1 following 533 ) as sum_y 534 from test order by x 535 """ 536 self.con.create_window_function("sumint", 1, WindowSumInt) 537 538 def test_win_sum_int(self): 539 self.cur.execute(self.query % "sumint") 540 self.assertEqual(self.cur.fetchall(), self.expected) 541 542 def test_win_error_on_create(self): 543 self.assertRaises(sqlite.ProgrammingError, 544 self.con.create_window_function, 545 "shouldfail", -100, WindowSumInt) 546 547 @with_tracebacks(BadWindow) 548 def test_win_exception_in_method(self): 549 for meth in "__init__", "step", "value", "inverse": 550 with self.subTest(meth=meth): 551 with patch.object(WindowSumInt, meth, side_effect=BadWindow): 552 name = f"exc_{meth}" 553 self.con.create_window_function(name, 1, WindowSumInt) 554 msg = f"'{meth}' method raised error" 555 with self.assertRaisesRegex(sqlite.OperationalError, msg): 556 self.cur.execute(self.query % name) 557 self.cur.fetchall() 558 559 @with_tracebacks(BadWindow) 560 def test_win_exception_in_finalize(self): 561 # Note: SQLite does not (as of version 3.38.0) propagate finalize 562 # callback errors to sqlite3_step(); this implies that OperationalError 563 # is _not_ raised. 564 with patch.object(WindowSumInt, "finalize", side_effect=BadWindow): 565 name = f"exception_in_finalize" 566 self.con.create_window_function(name, 1, WindowSumInt) 567 self.cur.execute(self.query % name) 568 self.cur.fetchall() 569 570 @with_tracebacks(AttributeError) 571 def test_win_missing_method(self): 572 class MissingValue: 573 def step(self, x): pass 574 def inverse(self, x): pass 575 def finalize(self): return 42 576 577 class MissingInverse: 578 def step(self, x): pass 579 def value(self): return 42 580 def finalize(self): return 42 581 582 class MissingStep: 583 def value(self): return 42 584 def inverse(self, x): pass 585 def finalize(self): return 42 586 587 dataset = ( 588 ("step", MissingStep), 589 ("value", MissingValue), 590 ("inverse", MissingInverse), 591 ) 592 for meth, cls in dataset: 593 with self.subTest(meth=meth, cls=cls): 594 name = f"exc_{meth}" 595 self.con.create_window_function(name, 1, cls) 596 with self.assertRaisesRegex(sqlite.OperationalError, 597 f"'{meth}' method not defined"): 598 self.cur.execute(self.query % name) 599 self.cur.fetchall() 600 601 @with_tracebacks(AttributeError) 602 def test_win_missing_finalize(self): 603 # Note: SQLite does not (as of version 3.38.0) propagate finalize 604 # callback errors to sqlite3_step(); this implies that OperationalError 605 # is _not_ raised. 606 class MissingFinalize: 607 def step(self, x): pass 608 def value(self): return 42 609 def inverse(self, x): pass 610 611 name = "missing_finalize" 612 self.con.create_window_function(name, 1, MissingFinalize) 613 self.cur.execute(self.query % name) 614 self.cur.fetchall() 615 616 def test_win_clear_function(self): 617 self.con.create_window_function("sumint", 1, None) 618 self.assertRaises(sqlite.OperationalError, self.cur.execute, 619 self.query % "sumint") 620 621 def test_win_redefine_function(self): 622 # Redefine WindowSumInt; adjust the expected results accordingly. 623 class Redefined(WindowSumInt): 624 def step(self, value): self.count += value * 2 625 def inverse(self, value): self.count -= value * 2 626 expected = [(v[0], v[1]*2) for v in self.expected] 627 628 self.con.create_window_function("sumint", 1, Redefined) 629 self.cur.execute(self.query % "sumint") 630 self.assertEqual(self.cur.fetchall(), expected) 631 632 def test_win_error_value_return(self): 633 class ErrorValueReturn: 634 def __init__(self): pass 635 def step(self, x): pass 636 def value(self): return 1 << 65 637 638 self.con.create_window_function("err_val_ret", 1, ErrorValueReturn) 639 self.assertRaisesRegex(sqlite.DataError, "string or blob too big", 640 self.cur.execute, self.query % "err_val_ret") 641 642 643class AggregateTests(unittest.TestCase): 644 def setUp(self): 645 self.con = sqlite.connect(":memory:") 646 cur = self.con.cursor() 647 cur.execute(""" 648 create table test( 649 t text, 650 i integer, 651 f float, 652 n, 653 b blob 654 ) 655 """) 656 cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)", 657 ("foo", 5, 3.14, None, memoryview(b"blob"),)) 658 659 self.con.create_aggregate("nostep", 1, AggrNoStep) 660 self.con.create_aggregate("nofinalize", 1, AggrNoFinalize) 661 self.con.create_aggregate("excInit", 1, AggrExceptionInInit) 662 self.con.create_aggregate("excStep", 1, AggrExceptionInStep) 663 self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize) 664 self.con.create_aggregate("checkType", 2, AggrCheckType) 665 self.con.create_aggregate("checkTypes", -1, AggrCheckTypes) 666 self.con.create_aggregate("mysum", 1, AggrSum) 667 self.con.create_aggregate("aggtxt", 1, AggrText) 668 669 def tearDown(self): 670 #self.cur.close() 671 #self.con.close() 672 pass 673 674 def test_aggr_error_on_create(self): 675 with self.assertRaises(sqlite.OperationalError): 676 self.con.create_function("bla", -100, AggrSum) 677 678 @with_tracebacks(AttributeError, name="AggrNoStep") 679 def test_aggr_no_step(self): 680 cur = self.con.cursor() 681 with self.assertRaises(sqlite.OperationalError) as cm: 682 cur.execute("select nostep(t) from test") 683 self.assertEqual(str(cm.exception), 684 "user-defined aggregate's 'step' method not defined") 685 686 def test_aggr_no_finalize(self): 687 cur = self.con.cursor() 688 msg = "user-defined aggregate's 'finalize' method not defined" 689 with self.assertRaisesRegex(sqlite.OperationalError, msg): 690 cur.execute("select nofinalize(t) from test") 691 val = cur.fetchone()[0] 692 693 @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit") 694 def test_aggr_exception_in_init(self): 695 cur = self.con.cursor() 696 with self.assertRaises(sqlite.OperationalError) as cm: 697 cur.execute("select excInit(t) from test") 698 val = cur.fetchone()[0] 699 self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error") 700 701 @with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep") 702 def test_aggr_exception_in_step(self): 703 cur = self.con.cursor() 704 with self.assertRaises(sqlite.OperationalError) as cm: 705 cur.execute("select excStep(t) from test") 706 val = cur.fetchone()[0] 707 self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error") 708 709 @with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize") 710 def test_aggr_exception_in_finalize(self): 711 cur = self.con.cursor() 712 with self.assertRaises(sqlite.OperationalError) as cm: 713 cur.execute("select excFinalize(t) from test") 714 val = cur.fetchone()[0] 715 self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") 716 717 def test_aggr_check_param_str(self): 718 cur = self.con.cursor() 719 cur.execute("select checkTypes('str', ?, ?)", ("foo", str())) 720 val = cur.fetchone()[0] 721 self.assertEqual(val, 2) 722 723 def test_aggr_check_param_int(self): 724 cur = self.con.cursor() 725 cur.execute("select checkType('int', ?)", (42,)) 726 val = cur.fetchone()[0] 727 self.assertEqual(val, 1) 728 729 def test_aggr_check_params_int(self): 730 cur = self.con.cursor() 731 cur.execute("select checkTypes('int', ?, ?)", (42, 24)) 732 val = cur.fetchone()[0] 733 self.assertEqual(val, 2) 734 735 def test_aggr_check_param_float(self): 736 cur = self.con.cursor() 737 cur.execute("select checkType('float', ?)", (3.14,)) 738 val = cur.fetchone()[0] 739 self.assertEqual(val, 1) 740 741 def test_aggr_check_param_none(self): 742 cur = self.con.cursor() 743 cur.execute("select checkType('None', ?)", (None,)) 744 val = cur.fetchone()[0] 745 self.assertEqual(val, 1) 746 747 def test_aggr_check_param_blob(self): 748 cur = self.con.cursor() 749 cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),)) 750 val = cur.fetchone()[0] 751 self.assertEqual(val, 1) 752 753 def test_aggr_check_aggr_sum(self): 754 cur = self.con.cursor() 755 cur.execute("delete from test") 756 cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)]) 757 cur.execute("select mysum(i) from test") 758 val = cur.fetchone()[0] 759 self.assertEqual(val, 60) 760 761 def test_aggr_no_match(self): 762 cur = self.con.execute("select mysum(i) from (select 1 as i) where i == 0") 763 val = cur.fetchone()[0] 764 self.assertIsNone(val) 765 766 def test_aggr_text(self): 767 cur = self.con.cursor() 768 for txt in ["foo", "1\x002"]: 769 with self.subTest(txt=txt): 770 cur.execute("select aggtxt(?) from test", (txt,)) 771 val = cur.fetchone()[0] 772 self.assertEqual(val, txt) 773 774 775class AuthorizerTests(unittest.TestCase): 776 @staticmethod 777 def authorizer_cb(action, arg1, arg2, dbname, source): 778 if action != sqlite.SQLITE_SELECT: 779 return sqlite.SQLITE_DENY 780 if arg2 == 'c2' or arg1 == 't2': 781 return sqlite.SQLITE_DENY 782 return sqlite.SQLITE_OK 783 784 def setUp(self): 785 self.con = sqlite.connect(":memory:") 786 self.con.executescript(""" 787 create table t1 (c1, c2); 788 create table t2 (c1, c2); 789 insert into t1 (c1, c2) values (1, 2); 790 insert into t2 (c1, c2) values (4, 5); 791 """) 792 793 # For our security test: 794 self.con.execute("select c2 from t2") 795 796 self.con.set_authorizer(self.authorizer_cb) 797 798 def tearDown(self): 799 pass 800 801 def test_table_access(self): 802 with self.assertRaises(sqlite.DatabaseError) as cm: 803 self.con.execute("select * from t2") 804 self.assertIn('prohibited', str(cm.exception)) 805 806 def test_column_access(self): 807 with self.assertRaises(sqlite.DatabaseError) as cm: 808 self.con.execute("select c2 from t1") 809 self.assertIn('prohibited', str(cm.exception)) 810 811 def test_clear_authorizer(self): 812 self.con.set_authorizer(None) 813 self.con.execute("select * from t2") 814 self.con.execute("select c2 from t1") 815 816 817class AuthorizerRaiseExceptionTests(AuthorizerTests): 818 @staticmethod 819 def authorizer_cb(action, arg1, arg2, dbname, source): 820 if action != sqlite.SQLITE_SELECT: 821 raise ValueError 822 if arg2 == 'c2' or arg1 == 't2': 823 raise ValueError 824 return sqlite.SQLITE_OK 825 826 @with_tracebacks(ValueError, name="authorizer_cb") 827 def test_table_access(self): 828 super().test_table_access() 829 830 @with_tracebacks(ValueError, name="authorizer_cb") 831 def test_column_access(self): 832 super().test_table_access() 833 834class AuthorizerIllegalTypeTests(AuthorizerTests): 835 @staticmethod 836 def authorizer_cb(action, arg1, arg2, dbname, source): 837 if action != sqlite.SQLITE_SELECT: 838 return 0.0 839 if arg2 == 'c2' or arg1 == 't2': 840 return 0.0 841 return sqlite.SQLITE_OK 842 843class AuthorizerLargeIntegerTests(AuthorizerTests): 844 @staticmethod 845 def authorizer_cb(action, arg1, arg2, dbname, source): 846 if action != sqlite.SQLITE_SELECT: 847 return 2**32 848 if arg2 == 'c2' or arg1 == 't2': 849 return 2**32 850 return sqlite.SQLITE_OK 851 852 853if __name__ == "__main__": 854 unittest.main() 855