1 /* statement.c - the statement type
2 *
3 * Copyright (C) 2005-2010 Gerhard Häring <[email protected]>
4 *
5 * This file is part of pysqlite.
6 *
7 * This software is provided 'as-is', without any express or implied
8 * warranty. In no event will the authors be held liable for any damages
9 * arising from the use of this software.
10 *
11 * Permission is granted to anyone to use this software for any purpose,
12 * including commercial applications, and to alter it and redistribute it
13 * freely, subject to the following restrictions:
14 *
15 * 1. The origin of this software must not be misrepresented; you must not
16 * claim that you wrote the original software. If you use this software
17 * in a product, an acknowledgment in the product documentation would be
18 * appreciated but is not required.
19 * 2. Altered source versions must be plainly marked as such, and must not be
20 * misrepresented as being the original software.
21 * 3. This notice may not be removed or altered from any source distribution.
22 */
23
24 #include "connection.h"
25 #include "statement.h"
26 #include "util.h"
27
/* Forward declaration: lstrip_sql() is defined further down in this file
   but is already called by pysqlite_statement_create(). */
static const char *lstrip_sql(const char *sql);
30
31 pysqlite_Statement *
pysqlite_statement_create(pysqlite_Connection * connection,PyObject * sql)32 pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
33 {
34 pysqlite_state *state = connection->state;
35 assert(PyUnicode_Check(sql));
36 Py_ssize_t size;
37 const char *sql_cstr = PyUnicode_AsUTF8AndSize(sql, &size);
38 if (sql_cstr == NULL) {
39 return NULL;
40 }
41
42 sqlite3 *db = connection->db;
43 int max_length = sqlite3_limit(db, SQLITE_LIMIT_SQL_LENGTH, -1);
44 if (size > max_length) {
45 PyErr_SetString(connection->DataError,
46 "query string is too large");
47 return NULL;
48 }
49 if (strlen(sql_cstr) != (size_t)size) {
50 PyErr_SetString(connection->ProgrammingError,
51 "the query contains a null character");
52 return NULL;
53 }
54
55 sqlite3_stmt *stmt;
56 const char *tail;
57 int rc;
58 Py_BEGIN_ALLOW_THREADS
59 rc = sqlite3_prepare_v2(db, sql_cstr, (int)size + 1, &stmt, &tail);
60 Py_END_ALLOW_THREADS
61
62 if (rc != SQLITE_OK) {
63 _pysqlite_seterror(state, db);
64 return NULL;
65 }
66
67 if (lstrip_sql(tail) != NULL) {
68 PyErr_SetString(connection->ProgrammingError,
69 "You can only execute one statement at a time.");
70 goto error;
71 }
72
73 /* Determine if the statement is a DML statement.
74 SELECT is the only exception. See #9924. */
75 int is_dml = 0;
76 const char *p = lstrip_sql(sql_cstr);
77 if (p != NULL) {
78 is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
79 || (PyOS_strnicmp(p, "update", 6) == 0)
80 || (PyOS_strnicmp(p, "delete", 6) == 0)
81 || (PyOS_strnicmp(p, "replace", 7) == 0);
82 }
83
84 pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement,
85 state->StatementType);
86 if (self == NULL) {
87 goto error;
88 }
89
90 self->st = stmt;
91 self->in_use = 0;
92 self->is_dml = is_dml;
93
94 PyObject_GC_Track(self);
95 return self;
96
97 error:
98 (void)sqlite3_finalize(stmt);
99 return NULL;
100 }
101
102 static void
stmt_dealloc(pysqlite_Statement * self)103 stmt_dealloc(pysqlite_Statement *self)
104 {
105 PyTypeObject *tp = Py_TYPE(self);
106 PyObject_GC_UnTrack(self);
107 if (self->st) {
108 Py_BEGIN_ALLOW_THREADS
109 sqlite3_finalize(self->st);
110 Py_END_ALLOW_THREADS
111 self->st = 0;
112 }
113 tp->tp_free(self);
114 Py_DECREF(tp);
115 }
116
117 static int
stmt_traverse(pysqlite_Statement * self,visitproc visit,void * arg)118 stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg)
119 {
120 Py_VISIT(Py_TYPE(self));
121 return 0;
122 }
123
124 /*
125 * Strip leading whitespace and comments from incoming SQL (null terminated C
126 * string) and return a pointer to the first non-whitespace, non-comment
127 * character.
128 *
129 * This is used to check if somebody tries to execute more than one SQL query
130 * with one execute()/executemany() command, which the DB-API don't allow.
131 *
132 * It is also used to harden DML query detection.
133 */
134 static inline const char *
lstrip_sql(const char * sql)135 lstrip_sql(const char *sql)
136 {
137 // This loop is borrowed from the SQLite source code.
138 for (const char *pos = sql; *pos; pos++) {
139 switch (*pos) {
140 case ' ':
141 case '\t':
142 case '\f':
143 case '\n':
144 case '\r':
145 // Skip whitespace.
146 break;
147 case '-':
148 // Skip line comments.
149 if (pos[1] == '-') {
150 pos += 2;
151 while (pos[0] && pos[0] != '\n') {
152 pos++;
153 }
154 if (pos[0] == '\0') {
155 return NULL;
156 }
157 continue;
158 }
159 return pos;
160 case '/':
161 // Skip C style comments.
162 if (pos[1] == '*') {
163 pos += 2;
164 while (pos[0] && (pos[0] != '*' || pos[1] != '/')) {
165 pos++;
166 }
167 if (pos[0] == '\0') {
168 return NULL;
169 }
170 pos++;
171 continue;
172 }
173 return pos;
174 default:
175 return pos;
176 }
177 }
178
179 return NULL;
180 }
181
182 static PyType_Slot stmt_slots[] = {
183 {Py_tp_dealloc, stmt_dealloc},
184 {Py_tp_traverse, stmt_traverse},
185 {0, NULL},
186 };
187
188 static PyType_Spec stmt_spec = {
189 .name = MODULE_NAME ".Statement",
190 .basicsize = sizeof(pysqlite_Statement),
191 .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
192 Py_TPFLAGS_IMMUTABLETYPE | Py_TPFLAGS_DISALLOW_INSTANTIATION),
193 .slots = stmt_slots,
194 };
195
196 int
pysqlite_statement_setup_types(PyObject * module)197 pysqlite_statement_setup_types(PyObject *module)
198 {
199 PyObject *type = PyType_FromModuleAndSpec(module, &stmt_spec, NULL);
200 if (type == NULL) {
201 return -1;
202 }
203 pysqlite_state *state = pysqlite_get_state(module);
204 state->StatementType = (PyTypeObject *)type;
205 return 0;
206 }
207