1 /* statement.c - the statement type
2  *
3  * Copyright (C) 2005-2010 Gerhard Häring <[email protected]>
4  *
5  * This file is part of pysqlite.
6  *
7  * This software is provided 'as-is', without any express or implied
8  * warranty.  In no event will the authors be held liable for any damages
9  * arising from the use of this software.
10  *
11  * Permission is granted to anyone to use this software for any purpose,
12  * including commercial applications, and to alter it and redistribute it
13  * freely, subject to the following restrictions:
14  *
15  * 1. The origin of this software must not be misrepresented; you must not
16  *    claim that you wrote the original software. If you use this software
17  *    in a product, an acknowledgment in the product documentation would be
18  *    appreciated but is not required.
19  * 2. Altered source versions must be plainly marked as such, and must not be
20  *    misrepresented as being the original software.
21  * 3. This notice may not be removed or altered from any source distribution.
22  */
23 
24 #include "connection.h"
25 #include "statement.h"
26 #include "util.h"
27 
28 /* prototypes */
29 static const char *lstrip_sql(const char *sql);
30 
31 pysqlite_Statement *
pysqlite_statement_create(pysqlite_Connection * connection,PyObject * sql)32 pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
33 {
34     pysqlite_state *state = connection->state;
35     assert(PyUnicode_Check(sql));
36     Py_ssize_t size;
37     const char *sql_cstr = PyUnicode_AsUTF8AndSize(sql, &size);
38     if (sql_cstr == NULL) {
39         return NULL;
40     }
41 
42     sqlite3 *db = connection->db;
43     int max_length = sqlite3_limit(db, SQLITE_LIMIT_SQL_LENGTH, -1);
44     if (size > max_length) {
45         PyErr_SetString(connection->DataError,
46                         "query string is too large");
47         return NULL;
48     }
49     if (strlen(sql_cstr) != (size_t)size) {
50         PyErr_SetString(connection->ProgrammingError,
51                         "the query contains a null character");
52         return NULL;
53     }
54 
55     sqlite3_stmt *stmt;
56     const char *tail;
57     int rc;
58     Py_BEGIN_ALLOW_THREADS
59     rc = sqlite3_prepare_v2(db, sql_cstr, (int)size + 1, &stmt, &tail);
60     Py_END_ALLOW_THREADS
61 
62     if (rc != SQLITE_OK) {
63         _pysqlite_seterror(state, db);
64         return NULL;
65     }
66 
67     if (lstrip_sql(tail) != NULL) {
68         PyErr_SetString(connection->ProgrammingError,
69                         "You can only execute one statement at a time.");
70         goto error;
71     }
72 
73     /* Determine if the statement is a DML statement.
74        SELECT is the only exception. See #9924. */
75     int is_dml = 0;
76     const char *p = lstrip_sql(sql_cstr);
77     if (p != NULL) {
78         is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
79                   || (PyOS_strnicmp(p, "update", 6) == 0)
80                   || (PyOS_strnicmp(p, "delete", 6) == 0)
81                   || (PyOS_strnicmp(p, "replace", 7) == 0);
82     }
83 
84     pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement,
85                                                state->StatementType);
86     if (self == NULL) {
87         goto error;
88     }
89 
90     self->st = stmt;
91     self->in_use = 0;
92     self->is_dml = is_dml;
93 
94     PyObject_GC_Track(self);
95     return self;
96 
97 error:
98     (void)sqlite3_finalize(stmt);
99     return NULL;
100 }
101 
102 static void
stmt_dealloc(pysqlite_Statement * self)103 stmt_dealloc(pysqlite_Statement *self)
104 {
105     PyTypeObject *tp = Py_TYPE(self);
106     PyObject_GC_UnTrack(self);
107     if (self->st) {
108         Py_BEGIN_ALLOW_THREADS
109         sqlite3_finalize(self->st);
110         Py_END_ALLOW_THREADS
111         self->st = 0;
112     }
113     tp->tp_free(self);
114     Py_DECREF(tp);
115 }
116 
117 static int
stmt_traverse(pysqlite_Statement * self,visitproc visit,void * arg)118 stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg)
119 {
120     Py_VISIT(Py_TYPE(self));
121     return 0;
122 }
123 
/*
 * Skip leading whitespace and SQL comments in a NUL-terminated C string.
 *
 * Returns a pointer into `sql` at the first character that is neither
 * whitespace nor part of a comment.  Returns NULL if there is no such
 * character, i.e. when the text:
 *   - is empty;
 *   - consists solely of whitespace and comments; or
 *   - ends inside an unterminated comment.
 *
 * Whitespace is space, \t, \f, \n and \r.  Comments are "--" line comments
 * (up to and including the next newline) and C-style block comments
 * (slash-star ... star-slash).
 *
 * This is used to reject execute()/executemany() strings that contain more
 * than one SQL statement, which the DB-API does not allow.  It is also used
 * to locate the first keyword when detecting DML statements.  The scanning
 * logic is adapted from the SQLite source code.
 */
static inline const char *
lstrip_sql(const char *sql)
{
    const char *cur = sql;

    while (*cur != '\0') {
        const char c = cur[0];

        if (c == ' ' || c == '\t' || c == '\f' || c == '\n' || c == '\r') {
            cur++;
        }
        else if (c == '-' && cur[1] == '-') {
            // Line comment: find the newline, then resume just after it.
            const char *end = cur + 2;
            while (*end != '\0' && *end != '\n') {
                end++;
            }
            if (*end == '\0') {
                return NULL;  // the comment runs to the end of the input
            }
            cur = end + 1;
        }
        else if (c == '/' && cur[1] == '*') {
            // Block comment: find the closing star-slash, then skip past it.
            const char *end = cur + 2;
            while (*end != '\0' && !(end[0] == '*' && end[1] == '/')) {
                end++;
            }
            if (*end == '\0') {
                return NULL;  // unterminated block comment
            }
            cur = end + 2;
        }
        else {
            return cur;
        }
    }

    return NULL;
}
181 
182 static PyType_Slot stmt_slots[] = {
183     {Py_tp_dealloc, stmt_dealloc},
184     {Py_tp_traverse, stmt_traverse},
185     {0, NULL},
186 };
187 
188 static PyType_Spec stmt_spec = {
189     .name = MODULE_NAME ".Statement",
190     .basicsize = sizeof(pysqlite_Statement),
191     .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
192               Py_TPFLAGS_IMMUTABLETYPE | Py_TPFLAGS_DISALLOW_INSTANTIATION),
193     .slots = stmt_slots,
194 };
195 
196 int
pysqlite_statement_setup_types(PyObject * module)197 pysqlite_statement_setup_types(PyObject *module)
198 {
199     PyObject *type = PyType_FromModuleAndSpec(module, &stmt_spec, NULL);
200     if (type == NULL) {
201         return -1;
202     }
203     pysqlite_state *state = pysqlite_get_state(module);
204     state->StatementType = (PyTypeObject *)type;
205     return 0;
206 }
207