1# Author: Paul Kippes <[email protected]>
2
3import unittest
4import sqlite3 as sqlite
5from .test_dbapi import memory_database
6
7
8class DumpTests(unittest.TestCase):
9    def setUp(self):
10        self.cx = sqlite.connect(":memory:")
11        self.cu = self.cx.cursor()
12
13    def tearDown(self):
14        self.cx.close()
15
16    def test_table_dump(self):
17        expected_sqls = [
18                """CREATE TABLE "index"("index" blob);"""
19                ,
20                """INSERT INTO "index" VALUES(X'01');"""
21                ,
22                """CREATE TABLE "quoted""table"("quoted""field" text);"""
23                ,
24                """INSERT INTO "quoted""table" VALUES('quoted''value');"""
25                ,
26                "CREATE TABLE t1(id integer primary key, s1 text, " \
27                "t1_i1 integer not null, i2 integer, unique (s1), " \
28                "constraint t1_idx1 unique (i2));"
29                ,
30                "INSERT INTO \"t1\" VALUES(1,'foo',10,20);"
31                ,
32                "INSERT INTO \"t1\" VALUES(2,'foo2',30,30);"
33                ,
34                "CREATE TABLE t2(id integer, t2_i1 integer, " \
35                "t2_i2 integer, primary key (id)," \
36                "foreign key(t2_i1) references t1(t1_i1));"
37                ,
38                "CREATE TRIGGER trigger_1 update of t1_i1 on t1 " \
39                "begin " \
40                "update t2 set t2_i1 = new.t1_i1 where t2_i1 = old.t1_i1; " \
41                "end;"
42                ,
43                "CREATE VIEW v1 as select * from t1 left join t2 " \
44                "using (id);"
45                ]
46        [self.cu.execute(s) for s in expected_sqls]
47        i = self.cx.iterdump()
48        actual_sqls = [s for s in i]
49        expected_sqls = ['BEGIN TRANSACTION;'] + expected_sqls + \
50            ['COMMIT;']
51        [self.assertEqual(expected_sqls[i], actual_sqls[i])
52            for i in range(len(expected_sqls))]
53
54    def test_dump_autoincrement(self):
55        expected = [
56            'CREATE TABLE "t1" (id integer primary key autoincrement);',
57            'INSERT INTO "t1" VALUES(NULL);',
58            'CREATE TABLE "t2" (id integer primary key autoincrement);',
59        ]
60        self.cu.executescript("".join(expected))
61
62        # the NULL value should now be automatically be set to 1
63        expected[1] = expected[1].replace("NULL", "1")
64        expected.insert(0, "BEGIN TRANSACTION;")
65        expected.extend([
66            'DELETE FROM "sqlite_sequence";',
67            'INSERT INTO "sqlite_sequence" VALUES(\'t1\',1);',
68            'COMMIT;',
69        ])
70
71        actual = [stmt for stmt in self.cx.iterdump()]
72        self.assertEqual(expected, actual)
73
74    def test_dump_autoincrement_create_new_db(self):
75        self.cu.execute("BEGIN TRANSACTION")
76        self.cu.execute("CREATE TABLE t1 (id integer primary key autoincrement)")
77        self.cu.execute("CREATE TABLE t2 (id integer primary key autoincrement)")
78        self.cu.executemany("INSERT INTO t1 VALUES(?)", ((None,) for _ in range(9)))
79        self.cu.executemany("INSERT INTO t2 VALUES(?)", ((None,) for _ in range(4)))
80        self.cx.commit()
81
82        with memory_database() as cx2:
83            query = "".join(self.cx.iterdump())
84            cx2.executescript(query)
85            cu2 = cx2.cursor()
86
87            dataset = (
88                ("t1", 9),
89                ("t2", 4),
90            )
91            for table, seq in dataset:
92                with self.subTest(table=table, seq=seq):
93                    res = cu2.execute("""
94                        SELECT "seq" FROM "sqlite_sequence" WHERE "name" == ?
95                    """, (table,))
96                    rows = res.fetchall()
97                    self.assertEqual(rows[0][0], seq)
98
99    def test_unorderable_row(self):
100        # iterdump() should be able to cope with unorderable row types (issue #15545)
101        class UnorderableRow:
102            def __init__(self, cursor, row):
103                self.row = row
104            def __getitem__(self, index):
105                return self.row[index]
106        self.cx.row_factory = UnorderableRow
107        CREATE_ALPHA = """CREATE TABLE "alpha" ("one");"""
108        CREATE_BETA = """CREATE TABLE "beta" ("two");"""
109        expected = [
110            "BEGIN TRANSACTION;",
111            CREATE_ALPHA,
112            CREATE_BETA,
113            "COMMIT;"
114            ]
115        self.cu.execute(CREATE_BETA)
116        self.cu.execute(CREATE_ALPHA)
117        got = list(self.cx.iterdump())
118        self.assertEqual(expected, got)
119
120
if __name__ == "__main__":
    # Executed as a script: hand this module (which is "__main__" here)
    # to the unittest command-line runner.
    unittest.main(module=__name__)
123