'
+ assert t.render() == escaped_out
+ assert unicode(t.module) == escaped_out
+ assert escape(t.module) == escaped_out
+ assert t.module.say_hello('') == escaped_out
+ assert escape(t.module.say_hello('')) == escaped_out
+
+ def test_attr_filter(self):
+ env = SandboxedEnvironment()
+ tmpl = env.from_string('{{ cls|attr("__subclasses__")() }}')
+ self.assert_raises(SecurityError, tmpl.render, cls=int)
+
+ def test_binary_operator_intercepting(self):
+ def disable_op(left, right):
+ raise TemplateRuntimeError('that operator so does not work')
+ for expr, ctx, rv in ('1 + 2', {}, '3'), ('a + 2', {'a': 2}, '4'):
+ env = SandboxedEnvironment()
+ env.binop_table['+'] = disable_op
+ t = env.from_string('{{ %s }}' % expr)
+ assert t.render(ctx) == rv
+ env.intercepted_binops = frozenset(['+'])
+ t = env.from_string('{{ %s }}' % expr)
+ try:
+ t.render(ctx)
+ except TemplateRuntimeError, e:
+ pass
+ else:
+ self.fail('expected runtime error')
+
+ def test_unary_operator_intercepting(self):
+ def disable_op(arg):
+ raise TemplateRuntimeError('that operator so does not work')
+ for expr, ctx, rv in ('-1', {}, '-1'), ('-a', {'a': 2}, '-2'):
+ env = SandboxedEnvironment()
+ env.unop_table['-'] = disable_op
+ t = env.from_string('{{ %s }}' % expr)
+ assert t.render(ctx) == rv
+ env.intercepted_unops = frozenset(['-'])
+ t = env.from_string('{{ %s }}' % expr)
+ try:
+ t.render(ctx)
+ except TemplateRuntimeError, e:
+ pass
+ else:
+ self.fail('expected runtime error')
+
+
+def suite():
+ suite = unittest.TestSuite()
+ suite.addTest(unittest.makeSuite(SandboxTestCase))
+ return suite
diff --git a/libs/jinja2/testsuite/tests.py b/libs/jinja2/testsuite/tests.py
new file mode 100755
index 0000000..3ece7a8
--- /dev/null
+++ b/libs/jinja2/testsuite/tests.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+"""
+ jinja2.testsuite.tests
+ ~~~~~~~~~~~~~~~~~~~~~~
+
+ Who tests the tests?
+
+ :copyright: (c) 2010 by the Jinja Team.
+ :license: BSD, see LICENSE for more details.
+"""
+import unittest
+from jinja2.testsuite import JinjaTestCase
+
+from jinja2 import Markup, Environment
+
+env = Environment()
+
+
+class TestsTestCase(JinjaTestCase):
+
+ def test_defined(self):
+ tmpl = env.from_string('{{ missing is defined }}|{{ true is defined }}')
+ assert tmpl.render() == 'False|True'
+
+ def test_even(self):
+ tmpl = env.from_string('''{{ 1 is even }}|{{ 2 is even }}''')
+ assert tmpl.render() == 'False|True'
+
+ def test_odd(self):
+ tmpl = env.from_string('''{{ 1 is odd }}|{{ 2 is odd }}''')
+ assert tmpl.render() == 'True|False'
+
+ def test_lower(self):
+ tmpl = env.from_string('''{{ "foo" is lower }}|{{ "FOO" is lower }}''')
+ assert tmpl.render() == 'True|False'
+
+ def test_typechecks(self):
+ tmpl = env.from_string('''
+ {{ 42 is undefined }}
+ {{ 42 is defined }}
+ {{ 42 is none }}
+ {{ none is none }}
+ {{ 42 is number }}
+ {{ 42 is string }}
+ {{ "foo" is string }}
+ {{ "foo" is sequence }}
+ {{ [1] is sequence }}
+ {{ range is callable }}
+ {{ 42 is callable }}
+ {{ range(5) is iterable }}
+ {{ {} is mapping }}
+ {{ mydict is mapping }}
+ {{ [] is mapping }}
+ ''')
+ class MyDict(dict):
+ pass
+ assert tmpl.render(mydict=MyDict()).split() == [
+ 'False', 'True', 'False', 'True', 'True', 'False',
+ 'True', 'True', 'True', 'True', 'False', 'True',
+ 'True', 'True', 'False'
+ ]
+
+ def test_sequence(self):
+ tmpl = env.from_string(
+ '{{ [1, 2, 3] is sequence }}|'
+ '{{ "foo" is sequence }}|'
+ '{{ 42 is sequence }}'
+ )
+ assert tmpl.render() == 'True|True|False'
+
+ def test_upper(self):
+ tmpl = env.from_string('{{ "FOO" is upper }}|{{ "foo" is upper }}')
+ assert tmpl.render() == 'True|False'
+
+ def test_sameas(self):
+ tmpl = env.from_string('{{ foo is sameas false }}|'
+ '{{ 0 is sameas false }}')
+ assert tmpl.render(foo=False) == 'True|False'
+
+ def test_no_paren_for_arg1(self):
+ tmpl = env.from_string('{{ foo is sameas none }}')
+ assert tmpl.render(foo=None) == 'True'
+
+ def test_escaped(self):
+ env = Environment(autoescape=True)
+ tmpl = env.from_string('{{ x is escaped }}|{{ y is escaped }}')
+ assert tmpl.render(x='foo', y=Markup('foo')) == 'False|True'
+
+
+def suite():
+ suite = unittest.TestSuite()
+ suite.addTest(unittest.makeSuite(TestsTestCase))
+ return suite
diff --git a/libs/jinja2/testsuite/utils.py b/libs/jinja2/testsuite/utils.py
new file mode 100755
index 0000000..be2e902
--- /dev/null
+++ b/libs/jinja2/testsuite/utils.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+"""
+ jinja2.testsuite.utils
+ ~~~~~~~~~~~~~~~~~~~~~~
+
+ Tests utilities jinja uses.
+
+ :copyright: (c) 2010 by the Jinja Team.
+ :license: BSD, see LICENSE for more details.
+"""
+import gc
+import unittest
+
+import pickle
+
+from jinja2.testsuite import JinjaTestCase
+
+from jinja2.utils import LRUCache, escape, object_type_repr
+
+
+class LRUCacheTestCase(JinjaTestCase):
+
+ def test_simple(self):
+ d = LRUCache(3)
+ d["a"] = 1
+ d["b"] = 2
+ d["c"] = 3
+ d["a"]
+ d["d"] = 4
+ assert len(d) == 3
+ assert 'a' in d and 'c' in d and 'd' in d and 'b' not in d
+
+ def test_pickleable(self):
+ cache = LRUCache(2)
+ cache["foo"] = 42
+ cache["bar"] = 23
+ cache["foo"]
+
+ for protocol in range(3):
+ copy = pickle.loads(pickle.dumps(cache, protocol))
+ assert copy.capacity == cache.capacity
+ assert copy._mapping == cache._mapping
+ assert copy._queue == cache._queue
+
+
+class HelpersTestCase(JinjaTestCase):
+
+ def test_object_type_repr(self):
+ class X(object):
+ pass
+ self.assert_equal(object_type_repr(42), 'int object')
+ self.assert_equal(object_type_repr([]), 'list object')
+ self.assert_equal(object_type_repr(X()),
+ 'jinja2.testsuite.utils.X object')
+ self.assert_equal(object_type_repr(None), 'None')
+ self.assert_equal(object_type_repr(Ellipsis), 'Ellipsis')
+
+
+class MarkupLeakTestCase(JinjaTestCase):
+
+ def test_markup_leaks(self):
+ counts = set()
+ for count in xrange(20):
+ for item in xrange(1000):
+ escape("foo")
+ escape("")
+ escape(u"foo")
+ escape(u"")
+ counts.add(len(gc.get_objects()))
+ assert len(counts) == 1, 'ouch, c extension seems to leak objects'
+
+
+def suite():
+ suite = unittest.TestSuite()
+ suite.addTest(unittest.makeSuite(LRUCacheTestCase))
+ suite.addTest(unittest.makeSuite(HelpersTestCase))
+
+ # this test only tests the c extension
+ if not hasattr(escape, 'func_code'):
+ suite.addTest(unittest.makeSuite(MarkupLeakTestCase))
+
+ return suite
diff --git a/libs/jinja2/utils.py b/libs/jinja2/utils.py
index 1e0bb81..568c63f 100755
--- a/libs/jinja2/utils.py
+++ b/libs/jinja2/utils.py
@@ -67,7 +67,7 @@ except TypeError, _error:
del _test_gen_bug, _error
-# for python 2.x we create outselves a next() function that does the
+# for python 2.x we create ourselves a next() function that does the
# basics without exception catching.
try:
next = next
@@ -132,7 +132,7 @@ def contextfunction(f):
def evalcontextfunction(f):
- """This decoraotr can be used to mark a function or method as an eval
+ """This decorator can be used to mark a function or method as an eval
context callable. This is similar to the :func:`contextfunction`
but instead of passing the context, an evaluation context object is
passed. For more information about the eval context, see
@@ -195,7 +195,7 @@ def clear_caches():
def import_string(import_name, silent=False):
- """Imports an object based on a string. This use useful if you want to
+ """Imports an object based on a string. This is useful if you want to
use import paths as endpoints or something similar. An import path can
be specified either in dotted notation (``xml.sax.saxutils.escape``)
or with a colon as object delimiter (``xml.sax.saxutils:escape``).
@@ -412,7 +412,7 @@ class LRUCache(object):
return (self.capacity,)
def copy(self):
- """Return an shallow copy of the instance."""
+ """Return a shallow copy of the instance."""
rv = self.__class__(self.capacity)
rv._mapping.update(self._mapping)
rv._queue = deque(self._queue)
@@ -462,7 +462,7 @@ class LRUCache(object):
"""Get an item from the cache. Moves the item up so that it has the
highest priority then.
- Raise an `KeyError` if it does not exist.
+ Raise a `KeyError` if it does not exist.
"""
rv = self._mapping[key]
if self._queue[-1] != key:
@@ -497,7 +497,7 @@ class LRUCache(object):
def __delitem__(self, key):
"""Remove an item from the cache dict.
- Raise an `KeyError` if it does not exist.
+ Raise a `KeyError` if it does not exist.
"""
self._wlock.acquire()
try:
@@ -598,7 +598,7 @@ class Joiner(object):
# try markupsafe first, if that fails go with Jinja2's bundled version
# of markupsafe. Markupsafe was previously Jinja2's implementation of
-# the Markup object but was moved into a separate package in a patchleve
+# the Markup object but was moved into a separate package in a patchlevel
# release
try:
from markupsafe import Markup, escape, soft_unicode
diff --git a/libs/sqlalchemy/__init__.py b/libs/sqlalchemy/__init__.py
index ef5f385..03293b5 100644
--- a/libs/sqlalchemy/__init__.py
+++ b/libs/sqlalchemy/__init__.py
@@ -117,7 +117,7 @@ from sqlalchemy.engine import create_engine, engine_from_config
__all__ = sorted(name for name, obj in locals().items()
if not (name.startswith('_') or inspect.ismodule(obj)))
-__version__ = '0.7.5'
+__version__ = '0.7.6'
del inspect, sys
diff --git a/libs/sqlalchemy/cextension/processors.c b/libs/sqlalchemy/cextension/processors.c
index b539f68..427db5d 100644
--- a/libs/sqlalchemy/cextension/processors.c
+++ b/libs/sqlalchemy/cextension/processors.c
@@ -342,23 +342,18 @@ DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value)
if (value == Py_None)
Py_RETURN_NONE;
- if (PyFloat_CheckExact(value)) {
- /* Decimal does not accept float values directly */
- args = PyTuple_Pack(1, value);
- if (args == NULL)
- return NULL;
+ args = PyTuple_Pack(1, value);
+ if (args == NULL)
+ return NULL;
- str = PyString_Format(self->format, args);
- Py_DECREF(args);
- if (str == NULL)
- return NULL;
+ str = PyString_Format(self->format, args);
+ Py_DECREF(args);
+ if (str == NULL)
+ return NULL;
- result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
- Py_DECREF(str);
- return result;
- } else {
- return PyObject_CallFunctionObjArgs(self->type, value, NULL);
- }
+ result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
+ Py_DECREF(str);
+ return result;
}
static void
diff --git a/libs/sqlalchemy/cextension/resultproxy.c b/libs/sqlalchemy/cextension/resultproxy.c
index 64b6855..3494cca 100644
--- a/libs/sqlalchemy/cextension/resultproxy.c
+++ b/libs/sqlalchemy/cextension/resultproxy.c
@@ -246,6 +246,7 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
PyObject *exc_module, *exception;
char *cstr_key;
long index;
+ int key_fallback = 0;
if (PyInt_CheckExact(key)) {
index = PyInt_AS_LONG(key);
@@ -276,12 +277,17 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
"O", key);
if (record == NULL)
return NULL;
+ key_fallback = 1;
}
indexobject = PyTuple_GetItem(record, 2);
if (indexobject == NULL)
return NULL;
+ if (key_fallback) {
+ Py_DECREF(record);
+ }
+
if (indexobject == Py_None) {
exc_module = PyImport_ImportModule("sqlalchemy.exc");
if (exc_module == NULL)
@@ -347,7 +353,16 @@ BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
else
return tmp;
- return BaseRowProxy_subscript(self, name);
+ tmp = BaseRowProxy_subscript(self, name);
+ if (tmp == NULL && PyErr_ExceptionMatches(PyExc_KeyError)) {
+ PyErr_Format(
+ PyExc_AttributeError,
+ "Could not locate column in row for column '%.200s'",
+ PyString_AsString(name)
+ );
+ return NULL;
+ }
+ return tmp;
}
/***********************
diff --git a/libs/sqlalchemy/dialects/firebird/base.py b/libs/sqlalchemy/dialects/firebird/base.py
index 8cf2ded..031c689 100644
--- a/libs/sqlalchemy/dialects/firebird/base.py
+++ b/libs/sqlalchemy/dialects/firebird/base.py
@@ -215,7 +215,7 @@ class FBCompiler(sql.compiler.SQLCompiler):
# Override to not use the AS keyword which FB 1.5 does not like
if asfrom:
alias_name = isinstance(alias.name,
- expression._generated_label) and \
+ expression._truncated_label) and \
self._truncated_identifier("alias",
alias.name) or alias.name
diff --git a/libs/sqlalchemy/dialects/mssql/base.py b/libs/sqlalchemy/dialects/mssql/base.py
index f7c94aa..103b0a3 100644
--- a/libs/sqlalchemy/dialects/mssql/base.py
+++ b/libs/sqlalchemy/dialects/mssql/base.py
@@ -791,6 +791,9 @@ class MSSQLCompiler(compiler.SQLCompiler):
def get_from_hint_text(self, table, text):
return text
+ def get_crud_hint_text(self, table, text):
+ return text
+
def limit_clause(self, select):
# Limit in mssql is after the select keyword
return ""
@@ -949,6 +952,13 @@ class MSSQLCompiler(compiler.SQLCompiler):
]
return 'OUTPUT ' + ', '.join(columns)
+ def get_cte_preamble(self, recursive):
+ # SQL Server finds it too inconvenient to accept
+ # an entirely optional, SQL standard specified,
+ # "RECURSIVE" word with their "WITH",
+ # so here we go
+ return "WITH"
+
def label_select_column(self, select, column, asfrom):
if isinstance(column, expression.Function):
return column.label(None)
diff --git a/libs/sqlalchemy/dialects/mysql/base.py b/libs/sqlalchemy/dialects/mysql/base.py
index 6aa250d..d9ab5a3 100644
--- a/libs/sqlalchemy/dialects/mysql/base.py
+++ b/libs/sqlalchemy/dialects/mysql/base.py
@@ -84,6 +84,23 @@ all lower case both within SQLAlchemy as well as on the MySQL
database itself, especially if database reflection features are
to be used.
+Transaction Isolation Level
+---------------------------
+
+:func:`.create_engine` accepts an ``isolation_level``
+parameter which results in the command ``SET SESSION
+TRANSACTION ISOLATION LEVEL `` being invoked for
+every new connection. Valid values for this parameter are
+``READ COMMITTED``, ``READ UNCOMMITTED``,
+``REPEATABLE READ``, and ``SERIALIZABLE``::
+
+ engine = create_engine(
+ "mysql://scott:tiger@localhost/test",
+ isolation_level="READ UNCOMMITTED"
+ )
+
+(new in 0.7.6)
+
Keys
----
@@ -221,8 +238,29 @@ simply passed through to the underlying CREATE INDEX command, so it *must* be
an integer. MySQL only allows a length for an index if it is for a CHAR,
VARCHAR, TEXT, BINARY, VARBINARY and BLOB.
+Index Types
+~~~~~~~~~~~~~
+
+Some MySQL storage engines permit you to specify an index type when creating
+an index or primary key constraint. SQLAlchemy provides this feature via the
+``mysql_using`` parameter on :class:`.Index`::
+
+ Index('my_index', my_table.c.data, mysql_using='hash')
+
+As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`::
+
+ PrimaryKeyConstraint("data", mysql_using='hash')
+
+The value passed to the keyword argument will be simply passed through to the
+underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index
+type for your MySQL storage engine.
+
More information can be found at:
+
http://dev.mysql.com/doc/refman/5.0/en/create-index.html
+
+http://dev.mysql.com/doc/refman/5.0/en/create-table.html
+
"""
import datetime, inspect, re, sys
@@ -1331,7 +1369,8 @@ class MySQLCompiler(compiler.SQLCompiler):
return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw)
for t in [from_table] + list(extra_froms))
- def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
+ def update_from_clause(self, update_stmt, from_table,
+ extra_froms, from_hints, **kw):
return None
@@ -1421,35 +1460,50 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
table_opts.append(joiner.join((opt, arg)))
return ' '.join(table_opts)
+
def visit_create_index(self, create):
index = create.element
preparer = self.preparer
+ table = preparer.format_table(index.table)
+ columns = [preparer.quote(c.name, c.quote) for c in index.columns]
+ name = preparer.quote(
+ self._index_identifier(index.name),
+ index.quote)
+
text = "CREATE "
if index.unique:
text += "UNIQUE "
- text += "INDEX %s ON %s " \
- % (preparer.quote(self._index_identifier(index.name),
- index.quote),preparer.format_table(index.table))
+ text += "INDEX %s ON %s " % (name, table)
+
+ columns = ', '.join(columns)
if 'mysql_length' in index.kwargs:
length = index.kwargs['mysql_length']
+ text += "(%s(%d))" % (columns, length)
else:
- length = None
- if length is not None:
- text+= "(%s(%d))" \
- % (', '.join(preparer.quote(c.name, c.quote)
- for c in index.columns), length)
- else:
- text+= "(%s)" \
- % (', '.join(preparer.quote(c.name, c.quote)
- for c in index.columns))
+ text += "(%s)" % (columns)
+
+ if 'mysql_using' in index.kwargs:
+ using = index.kwargs['mysql_using']
+ text += " USING %s" % (preparer.quote(using, index.quote))
+
return text
+ def visit_primary_key_constraint(self, constraint):
+ text = super(MySQLDDLCompiler, self).\
+ visit_primary_key_constraint(constraint)
+ if "mysql_using" in constraint.kwargs:
+ using = constraint.kwargs['mysql_using']
+ text += " USING %s" % (
+ self.preparer.quote(using, constraint.quote))
+ return text
def visit_drop_index(self, drop):
index = drop.element
return "\nDROP INDEX %s ON %s" % \
- (self.preparer.quote(self._index_identifier(index.name), index.quote),
+ (self.preparer.quote(
+ self._index_identifier(index.name), index.quote
+ ),
self.preparer.format_table(index.table))
def visit_drop_constraint(self, drop):
@@ -1768,8 +1822,40 @@ class MySQLDialect(default.DefaultDialect):
_backslash_escapes = True
_server_ansiquotes = False
- def __init__(self, use_ansiquotes=None, **kwargs):
+ def __init__(self, use_ansiquotes=None, isolation_level=None, **kwargs):
default.DefaultDialect.__init__(self, **kwargs)
+ self.isolation_level = isolation_level
+
+ def on_connect(self):
+ if self.isolation_level is not None:
+ def connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+ return connect
+ else:
+ return None
+
+ _isolation_lookup = set(['SERIALIZABLE',
+ 'READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ'])
+
+ def set_isolation_level(self, connection, level):
+ level = level.replace('_', ' ')
+ if level not in self._isolation_lookup:
+ raise exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s" %
+ (level, self.name, ", ".join(self._isolation_lookup))
+ )
+ cursor = connection.cursor()
+ cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level)
+ cursor.execute("COMMIT")
+ cursor.close()
+
+ def get_isolation_level(self, connection):
+ cursor = connection.cursor()
+ cursor.execute('SELECT @@tx_isolation')
+ val = cursor.fetchone()[0]
+ cursor.close()
+ return val.upper().replace("-", " ")
def do_commit(self, connection):
"""Execute a COMMIT."""
diff --git a/libs/sqlalchemy/dialects/oracle/base.py b/libs/sqlalchemy/dialects/oracle/base.py
index 88e5062..dd761ae 100644
--- a/libs/sqlalchemy/dialects/oracle/base.py
+++ b/libs/sqlalchemy/dialects/oracle/base.py
@@ -158,7 +158,7 @@ RESERVED_WORDS = \
'AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS '\
'NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER '\
'CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR '\
- 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT'.split())
+ 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL'.split())
NO_ARG_FNS = set('UID CURRENT_DATE SYSDATE USER '
'CURRENT_TIME CURRENT_TIMESTAMP'.split())
@@ -309,6 +309,9 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
"",
)
+ def visit_LONG(self, type_):
+ return "LONG"
+
def visit_TIMESTAMP(self, type_):
if type_.timezone:
return "TIMESTAMP WITH TIME ZONE"
@@ -481,7 +484,7 @@ class OracleCompiler(compiler.SQLCompiler):
"""Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
if asfrom or ashint:
- alias_name = isinstance(alias.name, expression._generated_label) and \
+ alias_name = isinstance(alias.name, expression._truncated_label) and \
self._truncated_identifier("alias", alias.name) or alias.name
if ashint:
diff --git a/libs/sqlalchemy/dialects/oracle/cx_oracle.py b/libs/sqlalchemy/dialects/oracle/cx_oracle.py
index 64526d2..5001acc 100644
--- a/libs/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/libs/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -77,7 +77,7 @@ with this feature but it should be regarded as experimental.
Precision Numerics
------------------
-The SQLAlchemy dialect goes thorugh a lot of steps to ensure
+The SQLAlchemy dialect goes through a lot of steps to ensure
that decimal numbers are sent and received with full accuracy.
An "outputtypehandler" callable is associated with each
cx_oracle connection object which detects numeric types and
@@ -89,6 +89,21 @@ this behavior, and will coerce the ``Decimal`` to ``float`` if
the ``asdecimal`` flag is ``False`` (default on :class:`.Float`,
optional on :class:`.Numeric`).
+Because the handler coerces to ``Decimal`` in all cases first,
+the feature can detract significantly from performance.
+If precision numerics aren't required, the decimal handling
+can be disabled by passing the flag ``coerce_to_decimal=False``
+to :func:`.create_engine`::
+
+ engine = create_engine("oracle+cx_oracle://dsn",
+ coerce_to_decimal=False)
+
+The ``coerce_to_decimal`` flag is new in 0.7.6.
+
+Another alternative to performance is to use the
+`cdecimal `_ library;
+see :class:`.Numeric` for additional notes.
+
The handler attempts to use the "precision" and "scale"
attributes of the result set column to best determine if
subsequent incoming values should be received as ``Decimal`` as
@@ -468,6 +483,7 @@ class OracleDialect_cx_oracle(OracleDialect):
auto_convert_lobs=True,
threaded=True,
allow_twophase=True,
+ coerce_to_decimal=True,
arraysize=50, **kwargs):
OracleDialect.__init__(self, **kwargs)
self.threaded = threaded
@@ -491,7 +507,12 @@ class OracleDialect_cx_oracle(OracleDialect):
self._cx_oracle_unicode_types = types("UNICODE", "NCLOB")
self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB")
self.supports_unicode_binds = self.cx_oracle_ver >= (5, 0)
- self.supports_native_decimal = self.cx_oracle_ver >= (5, 0)
+
+ self.supports_native_decimal = (
+ self.cx_oracle_ver >= (5, 0) and
+ coerce_to_decimal
+ )
+
self._cx_oracle_native_nvarchar = self.cx_oracle_ver >= (5, 0)
if self.cx_oracle_ver is None:
@@ -603,7 +624,9 @@ class OracleDialect_cx_oracle(OracleDialect):
size, precision, scale):
# convert all NUMBER with precision + positive scale to Decimal
# this almost allows "native decimal" mode.
- if defaultType == cx_Oracle.NUMBER and precision and scale > 0:
+ if self.supports_native_decimal and \
+ defaultType == cx_Oracle.NUMBER and \
+ precision and scale > 0:
return cursor.var(
cx_Oracle.STRING,
255,
@@ -614,7 +637,8 @@ class OracleDialect_cx_oracle(OracleDialect):
# make a decision based on each value received - the type
# may change from row to row (!). This kills
# off "native decimal" mode, handlers still needed.
- elif defaultType == cx_Oracle.NUMBER \
+ elif self.supports_native_decimal and \
+ defaultType == cx_Oracle.NUMBER \
and not precision and scale <= 0:
return cursor.var(
cx_Oracle.STRING,
diff --git a/libs/sqlalchemy/dialects/postgresql/base.py b/libs/sqlalchemy/dialects/postgresql/base.py
index 69c11d8..c4c2bbd 100644
--- a/libs/sqlalchemy/dialects/postgresql/base.py
+++ b/libs/sqlalchemy/dialects/postgresql/base.py
@@ -47,9 +47,18 @@ Transaction Isolation Level
:func:`.create_engine` accepts an ``isolation_level`` parameter which results
in the command ``SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL
`` being invoked for every new connection. Valid values for this
-parameter are ``READ_COMMITTED``, ``READ_UNCOMMITTED``, ``REPEATABLE_READ``,
-and ``SERIALIZABLE``. Note that the psycopg2 dialect does *not* use this
-technique and uses psycopg2-specific APIs (see that dialect for details).
+parameter are ``READ COMMITTED``, ``READ UNCOMMITTED``, ``REPEATABLE READ``,
+and ``SERIALIZABLE``::
+
+ engine = create_engine(
+ "postgresql+pg8000://scott:tiger@localhost/test",
+ isolation_level="READ UNCOMMITTED"
+ )
+
+When using the psycopg2 dialect, a psycopg2-specific method of setting
+transaction isolation level is used, but the API of ``isolation_level``
+remains the same - see :ref:`psycopg2_isolation`.
+
Remote / Cross-Schema Table Introspection
-----------------------------------------
diff --git a/libs/sqlalchemy/dialects/postgresql/psycopg2.py b/libs/sqlalchemy/dialects/postgresql/psycopg2.py
index c66180f..5aa9397 100644
--- a/libs/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/libs/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -38,6 +38,26 @@ psycopg2-specific keyword arguments which are accepted by
* *use_native_unicode* - Enable the usage of Psycopg2 "native unicode" mode
per connection. True by default.
+Unix Domain Connections
+------------------------
+
+psycopg2 supports connecting via Unix domain connections. When the ``host``
+portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2,
+which specifies Unix-domain communication rather than TCP/IP communication::
+
+ create_engine("postgresql+psycopg2://user:password@/dbname")
+
+By default, the socket file used is to connect to a Unix-domain socket
+in ``/tmp``, or whatever socket directory was specified when PostgreSQL
+was built. This value can be overridden by passing a pathname to psycopg2,
+using ``host`` as an additional keyword argument::
+
+ create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql")
+
+See also:
+
+`PQconnectdbParams `_
+
Per-Statement/Connection Execution Options
-------------------------------------------
@@ -97,6 +117,8 @@ Transactions
The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
+.. _psycopg2_isolation:
+
Transaction Isolation Level
---------------------------
diff --git a/libs/sqlalchemy/dialects/sqlite/base.py b/libs/sqlalchemy/dialects/sqlite/base.py
index f9520af..10a0d88 100644
--- a/libs/sqlalchemy/dialects/sqlite/base.py
+++ b/libs/sqlalchemy/dialects/sqlite/base.py
@@ -441,20 +441,6 @@ class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
result = self.quote_schema(index.table.schema, index.table.quote_schema) + "." + result
return result
-class SQLiteExecutionContext(default.DefaultExecutionContext):
- def get_result_proxy(self):
- rp = base.ResultProxy(self)
- if rp._metadata:
- # adjust for dotted column names. SQLite
- # in the case of UNION may store col names as
- # "tablename.colname"
- # in cursor.description
- for colname in rp._metadata.keys:
- if "." in colname:
- trunc_col = colname.split(".")[1]
- rp._metadata._set_keymap_synonym(trunc_col, colname)
- return rp
-
class SQLiteDialect(default.DefaultDialect):
name = 'sqlite'
supports_alter = False
@@ -472,7 +458,6 @@ class SQLiteDialect(default.DefaultDialect):
ischema_names = ischema_names
colspecs = colspecs
isolation_level = None
- execution_ctx_cls = SQLiteExecutionContext
supports_cast = True
supports_default_values = True
@@ -540,6 +525,16 @@ class SQLiteDialect(default.DefaultDialect):
else:
return None
+ def _translate_colname(self, colname):
+ # adjust for dotted column names. SQLite
+ # in the case of UNION may store col names as
+ # "tablename.colname"
+ # in cursor.description
+ if "." in colname:
+ return colname.split(".")[1], colname
+ else:
+ return colname, None
+
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
if schema is not None:
diff --git a/libs/sqlalchemy/engine/__init__.py b/libs/sqlalchemy/engine/__init__.py
index 4fac3e5..23b4b0b 100644
--- a/libs/sqlalchemy/engine/__init__.py
+++ b/libs/sqlalchemy/engine/__init__.py
@@ -306,6 +306,12 @@ def create_engine(*args, **kwargs):
this is configurable with the MySQLDB connection itself and the
server configuration as well).
+ :param pool_reset_on_return='rollback': set the "reset on return"
+ behavior of the pool, which is whether ``rollback()``,
+ ``commit()``, or nothing is called upon connections
+ being returned to the pool. See the docstring for
+ ``reset_on_return`` at :class:`.Pool`. (new as of 0.7.6)
+
:param pool_timeout=30: number of seconds to wait before giving
up on getting a connection from the pool. This is only used
with :class:`~sqlalchemy.pool.QueuePool`.
diff --git a/libs/sqlalchemy/engine/base.py b/libs/sqlalchemy/engine/base.py
index db19fe7..d16fc9c 100644
--- a/libs/sqlalchemy/engine/base.py
+++ b/libs/sqlalchemy/engine/base.py
@@ -491,14 +491,23 @@ class Dialect(object):
raise NotImplementedError()
def do_executemany(self, cursor, statement, parameters, context=None):
- """Provide an implementation of *cursor.executemany(statement,
- parameters)*."""
+ """Provide an implementation of ``cursor.executemany(statement,
+ parameters)``."""
raise NotImplementedError()
def do_execute(self, cursor, statement, parameters, context=None):
- """Provide an implementation of *cursor.execute(statement,
- parameters)*."""
+ """Provide an implementation of ``cursor.execute(statement,
+ parameters)``."""
+
+ raise NotImplementedError()
+
+ def do_execute_no_params(self, cursor, statement, parameters, context=None):
+ """Provide an implementation of ``cursor.execute(statement)``.
+
+ The parameter collection should not be sent.
+
+ """
raise NotImplementedError()
@@ -777,12 +786,12 @@ class Connectable(object):
def connect(self, **kwargs):
"""Return a :class:`.Connection` object.
-
+
Depending on context, this may be ``self`` if this object
is already an instance of :class:`.Connection`, or a newly
procured :class:`.Connection` if this object is an instance
of :class:`.Engine`.
-
+
"""
def contextual_connect(self):
@@ -793,7 +802,7 @@ class Connectable(object):
is already an instance of :class:`.Connection`, or a newly
procured :class:`.Connection` if this object is an instance
of :class:`.Engine`.
-
+
"""
raise NotImplementedError()
@@ -904,6 +913,12 @@ class Connection(Connectable):
c.__dict__ = self.__dict__.copy()
return c
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
def execution_options(self, **opt):
""" Set non-SQL options for the connection which take effect
during execution.
@@ -940,7 +955,7 @@ class Connection(Connectable):
:param compiled_cache: Available on: Connection.
A dictionary where :class:`.Compiled` objects
will be cached when the :class:`.Connection` compiles a clause
- expression into a :class:`.Compiled` object.
+ expression into a :class:`.Compiled` object.
It is the user's responsibility to
manage the size of this dictionary, which will have keys
corresponding to the dialect, clause element, the column
@@ -953,7 +968,7 @@ class Connection(Connectable):
some operations, including flush operations. The caching
used by the ORM internally supersedes a cache dictionary
specified here.
-
+
:param isolation_level: Available on: Connection.
Set the transaction isolation level for
the lifespan of this connection. Valid values include
@@ -962,7 +977,7 @@ class Connection(Connectable):
database specific, including those for :ref:`sqlite_toplevel`,
:ref:`postgresql_toplevel` - see those dialect's documentation
for further info.
-
+
Note that this option necessarily affects the underying
DBAPI connection for the lifespan of the originating
:class:`.Connection`, and is not per-execution. This
@@ -970,6 +985,18 @@ class Connection(Connectable):
is returned to the connection pool, i.e.
the :meth:`.Connection.close` method is called.
+ :param no_parameters: When ``True``, if the final parameter
+ list or dictionary is totally empty, will invoke the
+ statement on the cursor as ``cursor.execute(statement)``,
+ not passing the parameter collection at all.
+ Some DBAPIs such as psycopg2 and mysql-python consider
+ percent signs as significant only when parameters are
+ present; this option allows code to generate SQL
+ containing percent signs (and possibly other characters)
+ that is neutral regarding whether it's executed by the DBAPI
+ or piped into a script that's later invoked by
+ command line tools. New in 0.7.6.
+
:param stream_results: Available on: Connection, statement.
Indicate to the dialect that results should be
"streamed" and not pre-buffered, if possible. This is a limitation
@@ -1113,17 +1140,35 @@ class Connection(Connectable):
def begin(self):
"""Begin a transaction and return a transaction handle.
-
+
The returned object is an instance of :class:`.Transaction`.
+ This object represents the "scope" of the transaction,
+ which completes when either the :meth:`.Transaction.rollback`
+ or :meth:`.Transaction.commit` method is called.
+
+ Nested calls to :meth:`.begin` on the same :class:`.Connection`
+ will return new :class:`.Transaction` objects that represent
+ an emulated transaction within the scope of the enclosing
+ transaction, that is::
+
+ trans = conn.begin() # outermost transaction
+ trans2 = conn.begin() # "nested"
+ trans2.commit() # does nothing
+ trans.commit() # actually commits
+
+ Calls to :meth:`.Transaction.commit` only have an effect
+ when invoked via the outermost :class:`.Transaction` object, though the
+ :meth:`.Transaction.rollback` method of any of the
+ :class:`.Transaction` objects will roll back the
+ transaction.
- Repeated calls to ``begin`` on the same Connection will create
- a lightweight, emulated nested transaction. Only the
- outermost transaction may ``commit``. Calls to ``commit`` on
- inner transactions are ignored. Any transaction in the
- hierarchy may ``rollback``, however.
+ See also:
- See also :meth:`.Connection.begin_nested`,
- :meth:`.Connection.begin_twophase`.
+ :meth:`.Connection.begin_nested` - use a SAVEPOINT
+
+ :meth:`.Connection.begin_twophase` - use a two phase /XID transaction
+
+ :meth:`.Engine.begin` - context manager available from :class:`.Engine`.
"""
@@ -1157,7 +1202,7 @@ class Connection(Connectable):
def begin_twophase(self, xid=None):
"""Begin a two-phase or XA transaction and return a transaction
handle.
-
+
The returned object is an instance of :class:`.TwoPhaseTransaction`,
which in addition to the methods provided by
:class:`.Transaction`, also provides a :meth:`~.TwoPhaseTransaction.prepare`
@@ -1302,7 +1347,7 @@ class Connection(Connectable):
def close(self):
"""Close this :class:`.Connection`.
-
+
This results in a release of the underlying database
resources, that is, the DBAPI connection referenced
internally. The DBAPI connection is typically restored
@@ -1313,7 +1358,7 @@ class Connection(Connectable):
the DBAPI connection's ``rollback()`` method, regardless
of any :class:`.Transaction` object that may be
outstanding with regards to this :class:`.Connection`.
-
+
After :meth:`~.Connection.close` is called, the
:class:`.Connection` is permanently in a closed state,
and will allow no further operations.
@@ -1354,24 +1399,24 @@ class Connection(Connectable):
* a :class:`.DDLElement` object
* a :class:`.DefaultGenerator` object
* a :class:`.Compiled` object
-
+
:param \*multiparams/\**params: represent bound parameter
values to be used in the execution. Typically,
the format is either a collection of one or more
dictionaries passed to \*multiparams::
-
+
conn.execute(
table.insert(),
{"id":1, "value":"v1"},
{"id":2, "value":"v2"}
)
-
+
...or individual key/values interpreted by \**params::
-
+
conn.execute(
table.insert(), id=1, value="v1"
)
-
+
In the case that a plain SQL string is passed, and the underlying
DBAPI accepts positional bind parameters, a collection of tuples
or individual values in \*multiparams may be passed::
@@ -1380,21 +1425,21 @@ class Connection(Connectable):
"INSERT INTO table (id, value) VALUES (?, ?)",
(1, "v1"), (2, "v2")
)
-
+
conn.execute(
"INSERT INTO table (id, value) VALUES (?, ?)",
1, "v1"
)
-
+
Note above, the usage of a question mark "?" or other
symbol is contingent upon the "paramstyle" accepted by the DBAPI
in use, which may be any of "qmark", "named", "pyformat", "format",
"numeric". See `pep-249 `_
for details on paramstyle.
-
+
To execute a textual SQL statement which uses bound parameters in a
DBAPI-agnostic way, use the :func:`~.expression.text` construct.
-
+
"""
for c in type(object).__mro__:
if c in Connection.executors:
@@ -1623,7 +1668,8 @@ class Connection(Connectable):
if self._echo:
self.engine.logger.info(statement)
- self.engine.logger.info("%r", sql_util._repr_params(parameters, batches=10))
+ self.engine.logger.info("%r",
+ sql_util._repr_params(parameters, batches=10))
try:
if context.executemany:
self.dialect.do_executemany(
@@ -1631,6 +1677,11 @@ class Connection(Connectable):
statement,
parameters,
context)
+ elif not parameters and context.no_parameters:
+ self.dialect.do_execute_no_params(
+ cursor,
+ statement,
+ context)
else:
self.dialect.do_execute(
cursor,
@@ -1845,33 +1896,41 @@ class Connection(Connectable):
"""Execute the given function within a transaction boundary.
The function is passed this :class:`.Connection`
- as the first argument, followed by the given \*args and \**kwargs.
-
- This is a shortcut for explicitly invoking
- :meth:`.Connection.begin`, calling :meth:`.Transaction.commit`
- upon success or :meth:`.Transaction.rollback` upon an
- exception raise::
+ as the first argument, followed by the given \*args and \**kwargs,
+ e.g.::
def do_something(conn, x, y):
conn.execute("some statement", {'x':x, 'y':y})
-
+
conn.transaction(do_something, 5, 10)
+
+ The operations inside the function are all invoked within the
+ context of a single :class:`.Transaction`.
+ Upon success, the transaction is committed. If an
+ exception is raised, the transaction is rolled back
+ before propagating the exception.
+
+ .. note::
+
+ The :meth:`.transaction` method is superseded by
+ the usage of the Python ``with:`` statement, which can
+ be used with :meth:`.Connection.begin`::
+
+ with conn.begin():
+ conn.execute("some statement", {'x':5, 'y':10})
+
+ As well as with :meth:`.Engine.begin`::
+
+ with engine.begin() as conn:
+ conn.execute("some statement", {'x':5, 'y':10})
- Note that context managers (i.e. the ``with`` statement)
- present a more modern way of accomplishing the above,
- using the :class:`.Transaction` object as a base::
+ See also:
- with conn.begin():
- conn.execute("some statement", {'x':5, 'y':10})
-
- One advantage to the :meth:`.Connection.transaction`
- method is that the same method is also available
- on :class:`.Engine` as :meth:`.Engine.transaction` -
- this method procures a :class:`.Connection` and then
- performs the same operation, allowing equivalent
- usage with either a :class:`.Connection` or :class:`.Engine`
- without needing to know what kind of object
- it is.
+ :meth:`.Engine.begin` - engine-level transactional
+ context
+
+ :meth:`.Engine.transaction` - engine-level version of
+ :meth:`.Connection.transaction`
"""
@@ -1887,15 +1946,15 @@ class Connection(Connectable):
def run_callable(self, callable_, *args, **kwargs):
"""Given a callable object or function, execute it, passing
a :class:`.Connection` as the first argument.
-
+
The given \*args and \**kwargs are passed subsequent
to the :class:`.Connection` argument.
-
+
This function, along with :meth:`.Engine.run_callable`,
allows a function to be run with a :class:`.Connection`
or :class:`.Engine` object without the need to know
which one is being dealt with.
-
+
"""
return callable_(self, *args, **kwargs)
@@ -1906,11 +1965,11 @@ class Connection(Connectable):
class Transaction(object):
"""Represent a database transaction in progress.
-
+
The :class:`.Transaction` object is procured by
calling the :meth:`~.Connection.begin` method of
:class:`.Connection`::
-
+
from sqlalchemy import create_engine
engine = create_engine("postgresql://scott:tiger@localhost/test")
connection = engine.connect()
@@ -1923,7 +1982,7 @@ class Transaction(object):
also implements a context manager interface so that
the Python ``with`` statement can be used with the
:meth:`.Connection.begin` method::
-
+
with connection.begin():
connection.execute("insert into x (a, b) values (1, 2)")
@@ -1931,7 +1990,7 @@ class Transaction(object):
See also: :meth:`.Connection.begin`, :meth:`.Connection.begin_twophase`,
:meth:`.Connection.begin_nested`.
-
+
.. index::
single: thread safety; Transaction
"""
@@ -2012,9 +2071,9 @@ class NestedTransaction(Transaction):
A new :class:`.NestedTransaction` object may be procured
using the :meth:`.Connection.begin_nested` method.
-
+
The interface is the same as that of :class:`.Transaction`.
-
+
"""
def __init__(self, connection, parent):
super(NestedTransaction, self).__init__(connection, parent)
@@ -2033,13 +2092,13 @@ class NestedTransaction(Transaction):
class TwoPhaseTransaction(Transaction):
"""Represent a two-phase transaction.
-
+
A new :class:`.TwoPhaseTransaction` object may be procured
using the :meth:`.Connection.begin_twophase` method.
-
+
The interface is the same as that of :class:`.Transaction`
with the addition of the :meth:`prepare` method.
-
+
"""
def __init__(self, connection, xid):
super(TwoPhaseTransaction, self).__init__(connection, None)
@@ -2049,9 +2108,9 @@ class TwoPhaseTransaction(Transaction):
def prepare(self):
"""Prepare this :class:`.TwoPhaseTransaction`.
-
+
After a PREPARE, the transaction can be committed.
-
+
"""
if not self._parent.is_active:
raise exc.InvalidRequestError("This transaction is inactive")
@@ -2075,11 +2134,11 @@ class Engine(Connectable, log.Identified):
:func:`~sqlalchemy.create_engine` function.
See also:
-
+
:ref:`engines_toplevel`
:ref:`connections_toplevel`
-
+
"""
_execution_options = util.immutabledict()
@@ -2115,13 +2174,13 @@ class Engine(Connectable, log.Identified):
def update_execution_options(self, **opt):
"""Update the default execution_options dictionary
of this :class:`.Engine`.
-
+
The given keys/values in \**opt are added to the
default execution options that will be used for
all connections. The initial contents of this dictionary
can be sent via the ``execution_options`` paramter
to :func:`.create_engine`.
-
+
See :meth:`.Connection.execution_options` for more
details on execution options.
@@ -2236,19 +2295,96 @@ class Engine(Connectable, log.Identified):
if connection is None:
conn.close()
+ class _trans_ctx(object):
+ def __init__(self, conn, transaction, close_with_result):
+ self.conn = conn
+ self.transaction = transaction
+ self.close_with_result = close_with_result
+
+ def __enter__(self):
+ return self.conn
+
+ def __exit__(self, type, value, traceback):
+ if type is not None:
+ self.transaction.rollback()
+ else:
+ self.transaction.commit()
+ if not self.close_with_result:
+ self.conn.close()
+
+ def begin(self, close_with_result=False):
+ """Return a context manager delivering a :class:`.Connection`
+ with a :class:`.Transaction` established.
+
+ E.g.::
+
+ with engine.begin() as conn:
+ conn.execute("insert into table (x, y, z) values (1, 2, 3)")
+ conn.execute("my_special_procedure(5)")
+
+ Upon successful operation, the :class:`.Transaction`
+ is committed. If an error is raised, the :class:`.Transaction`
+ is rolled back.
+
+ The ``close_with_result`` flag is normally ``False``, and indicates
+ that the :class:`.Connection` will be closed when the operation
+ is complete. When set to ``True``, it indicates the :class:`.Connection`
+ is in "single use" mode, where the :class:`.ResultProxy`
+ returned by the first call to :meth:`.Connection.execute` will
+ close the :class:`.Connection` when that :class:`.ResultProxy`
+ has exhausted all result rows.
+
+ New in 0.7.6.
+
+ See also:
+
+ :meth:`.Engine.connect` - procure a :class:`.Connection` from
+ an :class:`.Engine`.
+
+ :meth:`.Connection.begin` - start a :class:`.Transaction`
+ for a particular :class:`.Connection`.
+
+ """
+ conn = self.contextual_connect(close_with_result=close_with_result)
+ trans = conn.begin()
+ return Engine._trans_ctx(conn, trans, close_with_result)
+
def transaction(self, callable_, *args, **kwargs):
"""Execute the given function within a transaction boundary.
- The function is passed a newly procured
- :class:`.Connection` as the first argument, followed by
- the given \*args and \**kwargs. The :class:`.Connection`
- is then closed (returned to the pool) when the operation
- is complete.
+ The function is passed a :class:`.Connection` newly procured
+ from :meth:`.Engine.contextual_connect` as the first argument,
+ followed by the given \*args and \**kwargs.
+
+ e.g.::
+
+ def do_something(conn, x, y):
+ conn.execute("some statement", {'x':x, 'y':y})
+
+ engine.transaction(do_something, 5, 10)
+
+ The operations inside the function are all invoked within the
+ context of a single :class:`.Transaction`.
+ Upon success, the transaction is committed. If an
+ exception is raised, the transaction is rolled back
+ before propagating the exception.
+
+ .. note::
+
+ The :meth:`.transaction` method is superseded by
+ the usage of the Python ``with:`` statement, which can
+ be used with :meth:`.Engine.begin`::
+
+ with engine.begin() as conn:
+ conn.execute("some statement", {'x':5, 'y':10})
- This method can be used interchangeably with
- :meth:`.Connection.transaction`. See that method for
- more details on usage as well as a modern alternative
- using context managers (i.e. the ``with`` statement).
+ See also:
+
+ :meth:`.Engine.begin` - engine-level transactional
+ context
+
+ :meth:`.Connection.transaction` - connection-level version of
+ :meth:`.Engine.transaction`
"""
@@ -2261,15 +2397,15 @@ class Engine(Connectable, log.Identified):
def run_callable(self, callable_, *args, **kwargs):
"""Given a callable object or function, execute it, passing
a :class:`.Connection` as the first argument.
-
+
The given \*args and \**kwargs are passed subsequent
to the :class:`.Connection` argument.
-
+
This function, along with :meth:`.Connection.run_callable`,
allows a function to be run with a :class:`.Connection`
or :class:`.Engine` object without the need to know
which one is being dealt with.
-
+
"""
conn = self.contextual_connect()
try:
@@ -2390,19 +2526,19 @@ class Engine(Connectable, log.Identified):
def raw_connection(self):
"""Return a "raw" DBAPI connection from the connection pool.
-
+
The returned object is a proxied version of the DBAPI
connection object used by the underlying driver in use.
The object will have all the same behavior as the real DBAPI
connection, except that its ``close()`` method will result in the
connection being returned to the pool, rather than being closed
for real.
-
+
This method provides direct DBAPI connection access for
special situations. In most situations, the :class:`.Connection`
object should be used, which is procured using the
:meth:`.Engine.connect` method.
-
+
"""
return self.pool.unique_connection()
@@ -2487,7 +2623,6 @@ except ImportError:
def __getattr__(self, name):
try:
- # TODO: no test coverage here
return self[name]
except KeyError, e:
raise AttributeError(e.args[0])
@@ -2575,6 +2710,10 @@ class ResultMetaData(object):
context = parent.context
dialect = context.dialect
typemap = dialect.dbapi_type_map
+ translate_colname = dialect._translate_colname
+
+ # high precedence key values.
+ primary_keymap = {}
for i, rec in enumerate(metadata):
colname = rec[0]
@@ -2583,6 +2722,9 @@ class ResultMetaData(object):
if dialect.description_encoding:
colname = dialect._description_decoder(colname)
+ if translate_colname:
+ colname, untranslated = translate_colname(colname)
+
if context.result_map:
try:
name, obj, type_ = context.result_map[colname.lower()]
@@ -2600,15 +2742,17 @@ class ResultMetaData(object):
# indexes as keys. This is only needed for the Python version of
# RowProxy (the C version uses a faster path for integer indexes).
- keymap[i] = rec
-
- # Column names as keys
- if keymap.setdefault(name.lower(), rec) is not rec:
- # We do not raise an exception directly because several
- # columns colliding by name is not a problem as long as the
- # user does not try to access them (ie use an index directly,
- # or the more precise ColumnElement)
- keymap[name.lower()] = (processor, obj, None)
+ primary_keymap[i] = rec
+
+ # populate primary keymap, looking for conflicts.
+ if primary_keymap.setdefault(name.lower(), rec) is not rec:
+ # place a record that doesn't have the "index" - this
+ # is interpreted later as an AmbiguousColumnError,
+ # but only when actually accessed. Columns
+ # colliding by name is not a problem if those names
+ # aren't used; integer and ColumnElement access is always
+ # unambiguous.
+ primary_keymap[name.lower()] = (processor, obj, None)
if dialect.requires_name_normalize:
colname = dialect.normalize_name(colname)
@@ -2618,10 +2762,20 @@ class ResultMetaData(object):
for o in obj:
keymap[o] = rec
+ if translate_colname and \
+ untranslated:
+ keymap[untranslated] = rec
+
+ # overwrite keymap values with those of the
+ # high precedence keymap.
+ keymap.update(primary_keymap)
+
if parent._echo:
context.engine.logger.debug(
"Col %r", tuple(x[0] for x in metadata))
+ @util.pending_deprecation("0.8", "sqlite dialect uses "
+ "_translate_colname() now")
def _set_keymap_synonym(self, name, origname):
"""Set a synonym for the given name.
@@ -2647,7 +2801,7 @@ class ResultMetaData(object):
if key._label and key._label.lower() in map:
result = map[key._label.lower()]
elif hasattr(key, 'name') and key.name.lower() in map:
- # match is only on name.
+ # match is only on name.
result = map[key.name.lower()]
# search extra hard to make sure this
# isn't a column/label name overlap.
@@ -2800,7 +2954,7 @@ class ResultProxy(object):
@property
def returns_rows(self):
"""True if this :class:`.ResultProxy` returns rows.
-
+
I.e. if it is legal to call the methods
:meth:`~.ResultProxy.fetchone`,
:meth:`~.ResultProxy.fetchmany`
@@ -2814,12 +2968,12 @@ class ResultProxy(object):
"""True if this :class:`.ResultProxy` is the result
of a executing an expression language compiled
:func:`.expression.insert` construct.
-
+
When True, this implies that the
:attr:`inserted_primary_key` attribute is accessible,
assuming the statement did not include
a user defined "returning" construct.
-
+
"""
return self.context.isinsert
@@ -2867,7 +3021,7 @@ class ResultProxy(object):
@util.memoized_property
def inserted_primary_key(self):
"""Return the primary key for the row just inserted.
-
+
The return value is a list of scalar values
corresponding to the list of primary key columns
in the target table.
@@ -2875,7 +3029,7 @@ class ResultProxy(object):
This only applies to single row :func:`.insert`
constructs which did not explicitly specify
:meth:`.Insert.returning`.
-
+
Note that primary key columns which specify a
server_default clause,
or otherwise do not qualify as "autoincrement"
diff --git a/libs/sqlalchemy/engine/default.py b/libs/sqlalchemy/engine/default.py
index 73bd7fd..5c2d981 100644
--- a/libs/sqlalchemy/engine/default.py
+++ b/libs/sqlalchemy/engine/default.py
@@ -44,6 +44,7 @@ class DefaultDialect(base.Dialect):
postfetch_lastrowid = True
implicit_returning = False
+
supports_native_enum = False
supports_native_boolean = False
@@ -95,6 +96,10 @@ class DefaultDialect(base.Dialect):
# and denormalize_name() must be provided.
requires_name_normalize = False
+ # a hook for SQLite's translation of
+ # result column names
+ _translate_colname = None
+
reflection_options = ()
def __init__(self, convert_unicode=False, assert_unicode=False,
@@ -329,6 +334,9 @@ class DefaultDialect(base.Dialect):
def do_execute(self, cursor, statement, parameters, context=None):
cursor.execute(statement, parameters)
+ def do_execute_no_params(self, cursor, statement, context=None):
+ cursor.execute(statement)
+
def is_disconnect(self, e, connection, cursor):
return False
@@ -533,6 +541,10 @@ class DefaultExecutionContext(base.ExecutionContext):
return self
@util.memoized_property
+ def no_parameters(self):
+ return self.execution_options.get("no_parameters", False)
+
+ @util.memoized_property
def is_crud(self):
return self.isinsert or self.isupdate or self.isdelete
diff --git a/libs/sqlalchemy/engine/reflection.py b/libs/sqlalchemy/engine/reflection.py
index f5911f3..71d97e6 100644
--- a/libs/sqlalchemy/engine/reflection.py
+++ b/libs/sqlalchemy/engine/reflection.py
@@ -317,7 +317,7 @@ class Inspector(object):
info_cache=self.info_cache, **kw)
return indexes
- def reflecttable(self, table, include_columns, exclude_columns=None):
+ def reflecttable(self, table, include_columns, exclude_columns=()):
"""Given a Table object, load its internal constructs based on introspection.
This is the underlying method used by most dialects to produce
@@ -414,9 +414,12 @@ class Inspector(object):
# Primary keys
pk_cons = self.get_pk_constraint(table_name, schema, **tblkw)
if pk_cons:
+ pk_cols = [table.c[pk]
+ for pk in pk_cons['constrained_columns']
+ if pk in table.c and pk not in exclude_columns
+ ] + [pk for pk in table.primary_key if pk.key in exclude_columns]
primary_key_constraint = sa_schema.PrimaryKeyConstraint(name=pk_cons.get('name'),
- *[table.c[pk] for pk in pk_cons['constrained_columns']
- if pk in table.c]
+ *pk_cols
)
table.append_constraint(primary_key_constraint)
diff --git a/libs/sqlalchemy/engine/strategies.py b/libs/sqlalchemy/engine/strategies.py
index 7b2da68..4d5a4b3 100644
--- a/libs/sqlalchemy/engine/strategies.py
+++ b/libs/sqlalchemy/engine/strategies.py
@@ -108,7 +108,8 @@ class DefaultEngineStrategy(EngineStrategy):
'timeout': 'pool_timeout',
'recycle': 'pool_recycle',
'events':'pool_events',
- 'use_threadlocal':'pool_threadlocal'}
+ 'use_threadlocal':'pool_threadlocal',
+ 'reset_on_return':'pool_reset_on_return'}
for k in util.get_cls_kwargs(poolclass):
tk = translate.get(k, k)
if tk in kwargs:
@@ -226,6 +227,9 @@ class MockEngineStrategy(EngineStrategy):
def contextual_connect(self, **kwargs):
return self
+ def execution_options(self, **kw):
+ return self
+
def compiler(self, statement, parameters, **kwargs):
return self._dialect.compiler(
statement, parameters, engine=self, **kwargs)
diff --git a/libs/sqlalchemy/event.py b/libs/sqlalchemy/event.py
index 9cc3139..cd70b3a 100644
--- a/libs/sqlalchemy/event.py
+++ b/libs/sqlalchemy/event.py
@@ -13,12 +13,12 @@ NO_RETVAL = util.symbol('NO_RETVAL')
def listen(target, identifier, fn, *args, **kw):
"""Register a listener function for the given target.
-
+
e.g.::
-
+
from sqlalchemy import event
from sqlalchemy.schema import UniqueConstraint
-
+
def unique_constraint_name(const, table):
const.name = "uq_%s_%s" % (
table.name,
@@ -41,12 +41,12 @@ def listen(target, identifier, fn, *args, **kw):
def listens_for(target, identifier, *args, **kw):
"""Decorate a function as a listener for the given target + identifier.
-
+
e.g.::
-
+
from sqlalchemy import event
from sqlalchemy.schema import UniqueConstraint
-
+
@event.listens_for(UniqueConstraint, "after_parent_attach")
def unique_constraint_name(const, table):
const.name = "uq_%s_%s" % (
@@ -205,12 +205,14 @@ class _DispatchDescriptor(object):
def insert(self, obj, target, propagate):
assert isinstance(target, type), \
"Class-level Event targets must be classes."
-
stack = [target]
while stack:
cls = stack.pop(0)
stack.extend(cls.__subclasses__())
- self._clslevel[cls].insert(0, obj)
+ if cls is not target and cls not in self._clslevel:
+ self.update_subclass(cls)
+ else:
+ self._clslevel[cls].insert(0, obj)
def append(self, obj, target, propagate):
assert isinstance(target, type), \
@@ -220,7 +222,20 @@ class _DispatchDescriptor(object):
while stack:
cls = stack.pop(0)
stack.extend(cls.__subclasses__())
- self._clslevel[cls].append(obj)
+ if cls is not target and cls not in self._clslevel:
+ self.update_subclass(cls)
+ else:
+ self._clslevel[cls].append(obj)
+
+ def update_subclass(self, target):
+ clslevel = self._clslevel[target]
+ for cls in target.__mro__[1:]:
+ if cls in self._clslevel:
+ clslevel.extend([
+ fn for fn
+ in self._clslevel[cls]
+ if fn not in clslevel
+ ])
def remove(self, obj, target):
stack = [target]
@@ -252,6 +267,8 @@ class _ListenerCollection(object):
_exec_once = False
def __init__(self, parent, target_cls):
+ if target_cls not in parent._clslevel:
+ parent.update_subclass(target_cls)
self.parent_listeners = parent._clslevel[target_cls]
self.name = parent.__name__
self.listeners = []
diff --git a/libs/sqlalchemy/exc.py b/libs/sqlalchemy/exc.py
index 64f25a2..91ffc28 100644
--- a/libs/sqlalchemy/exc.py
+++ b/libs/sqlalchemy/exc.py
@@ -162,7 +162,7 @@ UnmappedColumnError = None
class StatementError(SQLAlchemyError):
"""An error occurred during execution of a SQL statement.
- :class:`.StatementError` wraps the exception raised
+ :class:`StatementError` wraps the exception raised
during execution, and features :attr:`.statement`
and :attr:`.params` attributes which supply context regarding
the specifics of the statement which had an issue.
@@ -172,6 +172,15 @@ class StatementError(SQLAlchemyError):
"""
+ statement = None
+ """The string SQL statement being invoked when this exception occurred."""
+
+ params = None
+ """The parameter list being used when this exception occurred."""
+
+ orig = None
+ """The DBAPI exception object."""
+
def __init__(self, message, statement, params, orig):
SQLAlchemyError.__init__(self, message)
self.statement = statement
@@ -192,21 +201,21 @@ class StatementError(SQLAlchemyError):
class DBAPIError(StatementError):
"""Raised when the execution of a database operation fails.
- ``DBAPIError`` wraps exceptions raised by the DB-API underlying the
+ Wraps exceptions raised by the DB-API underlying the
database operation. Driver-specific implementations of the standard
DB-API exception types are wrapped by matching sub-types of SQLAlchemy's
- ``DBAPIError`` when possible. DB-API's ``Error`` type maps to
- ``DBAPIError`` in SQLAlchemy, otherwise the names are identical. Note
+ :class:`DBAPIError` when possible. DB-API's ``Error`` type maps to
+ :class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note
that there is no guarantee that different DB-API implementations will
raise the same exception type for any given error condition.
- :class:`.DBAPIError` features :attr:`.statement`
- and :attr:`.params` attributes which supply context regarding
+ :class:`DBAPIError` features :attr:`~.StatementError.statement`
+ and :attr:`~.StatementError.params` attributes which supply context regarding
the specifics of the statement which had an issue, for the
typical case when the error was raised within the context of
emitting a SQL statement.
- The wrapped exception object is available in the :attr:`.orig` attribute.
+ The wrapped exception object is available in the :attr:`~.StatementError.orig` attribute.
Its type and properties are DB-API implementation specific.
"""
diff --git a/libs/sqlalchemy/ext/declarative.py b/libs/sqlalchemy/ext/declarative.py
index 891130a..faf575d 100755
--- a/libs/sqlalchemy/ext/declarative.py
+++ b/libs/sqlalchemy/ext/declarative.py
@@ -1213,6 +1213,12 @@ def _as_declarative(cls, classname, dict_):
del our_stuff[key]
cols = sorted(cols, key=lambda c:c._creation_order)
table = None
+
+ if hasattr(cls, '__table_cls__'):
+ table_cls = util.unbound_method_to_callable(cls.__table_cls__)
+ else:
+ table_cls = Table
+
if '__table__' not in dict_:
if tablename is not None:
@@ -1230,7 +1236,7 @@ def _as_declarative(cls, classname, dict_):
if autoload:
table_kw['autoload'] = True
- cls.__table__ = table = Table(tablename, cls.metadata,
+ cls.__table__ = table = table_cls(tablename, cls.metadata,
*(tuple(cols) + tuple(args)),
**table_kw)
else:
diff --git a/libs/sqlalchemy/ext/hybrid.py b/libs/sqlalchemy/ext/hybrid.py
index 086ec90..8734181 100644
--- a/libs/sqlalchemy/ext/hybrid.py
+++ b/libs/sqlalchemy/ext/hybrid.py
@@ -11,30 +11,30 @@ class level and at the instance level.
The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of method
decorator, is around 50 lines of code and has almost no dependencies on the rest
-of SQLAlchemy. It can in theory work with any class-level expression generator.
+of SQLAlchemy. It can, in theory, work with any descriptor-based expression
+system.
-Consider a table ``interval`` as below::
-
- from sqlalchemy import MetaData, Table, Column, Integer
-
- metadata = MetaData()
-
- interval_table = Table('interval', metadata,
- Column('id', Integer, primary_key=True),
- Column('start', Integer, nullable=False),
- Column('end', Integer, nullable=False)
- )
-
-We can define higher level functions on mapped classes that produce SQL
-expressions at the class level, and Python expression evaluation at the
-instance level. Below, each function decorated with :func:`.hybrid_method`
-or :func:`.hybrid_property` may receive ``self`` as an instance of the class,
-or as the class itself::
+Consider a mapping ``Interval``, representing integer ``start`` and ``end``
+values. We can define higher level functions on mapped classes that produce
+SQL expressions at the class level, and Python expression evaluation at the
+instance level. Below, each function decorated with :class:`.hybrid_method` or
+:class:`.hybrid_property` may receive ``self`` as an instance of the class, or
+as the class itself::
+ from sqlalchemy import Column, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.orm import Session, aliased
from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
- from sqlalchemy.orm import mapper, Session, aliased
+
+ Base = declarative_base()
+
+ class Interval(Base):
+ __tablename__ = 'interval'
+
+ id = Column(Integer, primary_key=True)
+ start = Column(Integer, nullable=False)
+ end = Column(Integer, nullable=False)
- class Interval(object):
def __init__(self, start, end):
self.start = start
self.end = end
@@ -51,8 +51,6 @@ or as the class itself::
def intersects(self, other):
return self.contains(other.start) | self.contains(other.end)
- mapper(Interval, interval_table)
-
Above, the ``length`` property returns the difference between the ``end`` and
``start`` attributes. With an instance of ``Interval``, this subtraction occurs
in Python, using normal Python descriptor mechanics::
@@ -60,10 +58,11 @@ in Python, using normal Python descriptor mechanics::
>>> i1 = Interval(5, 10)
>>> i1.length
5
-
-At the class level, the usual descriptor behavior of returning the descriptor
-itself is modified by :class:`.hybrid_property`, to instead evaluate the function
-body given the ``Interval`` class as the argument::
+
+When dealing with the ``Interval`` class itself, the :class:`.hybrid_property`
+descriptor evaluates the function body given the ``Interval`` class as
+the argument, which when evaluated with SQLAlchemy expression mechanics
+returns a new SQL expression::
>>> print Interval.length
interval."end" - interval.start
@@ -83,9 +82,10 @@ locate attributes, so can also be used with hybrid attributes::
FROM interval
WHERE interval."end" - interval.start = :param_1
-The ``contains()`` and ``intersects()`` methods are decorated with :class:`.hybrid_method`.
-This decorator applies the same idea to methods which accept
-zero or more arguments. The above methods return boolean values, and take advantage
+The ``Interval`` class example also illustrates two methods, ``contains()`` and ``intersects()``,
+decorated with :class:`.hybrid_method`.
+This decorator applies the same idea to methods that :class:`.hybrid_property` applies
+to attributes. The methods return boolean values, and take advantage
of the Python ``|`` and ``&`` bitwise operators to produce equivalent instance-level and
SQL expression-level boolean behavior::
@@ -368,7 +368,12 @@ SQL expression versus SQL expression::
>>> sw1 = aliased(SearchWord)
>>> sw2 = aliased(SearchWord)
- >>> print Session().query(sw1.word_insensitive, sw2.word_insensitive).filter(sw1.word_insensitive > sw2.word_insensitive)
+ >>> print Session().query(
+ ... sw1.word_insensitive,
+ ... sw2.word_insensitive).\\
+ ... filter(
+ ... sw1.word_insensitive > sw2.word_insensitive
+ ... )
SELECT lower(searchword_1.word) AS lower_1, lower(searchword_2.word) AS lower_2
FROM searchword AS searchword_1, searchword AS searchword_2
WHERE lower(searchword_1.word) > lower(searchword_2.word)
diff --git a/libs/sqlalchemy/ext/orderinglist.py b/libs/sqlalchemy/ext/orderinglist.py
index 9847861..3895725 100644
--- a/libs/sqlalchemy/ext/orderinglist.py
+++ b/libs/sqlalchemy/ext/orderinglist.py
@@ -184,12 +184,11 @@ class OrderingList(list):
This implementation relies on the list starting in the proper order,
so be **sure** to put an ``order_by`` on your relationship.
- ordering_attr
+ :param ordering_attr:
Name of the attribute that stores the object's order in the
relationship.
- ordering_func
- Optional. A function that maps the position in the Python list to a
+ :param ordering_func: Optional. A function that maps the position in the Python list to a
value to store in the ``ordering_attr``. Values returned are
usually (but need not be!) integers.
@@ -202,7 +201,7 @@ class OrderingList(list):
like stepped numbering, alphabetical and Fibonacci numbering, see
the unit tests.
- reorder_on_append
+ :param reorder_on_append:
Default False. When appending an object with an existing (non-None)
ordering value, that value will be left untouched unless
``reorder_on_append`` is true. This is an optimization to avoid a
diff --git a/libs/sqlalchemy/orm/collections.py b/libs/sqlalchemy/orm/collections.py
index 7872715..160fac8 100644
--- a/libs/sqlalchemy/orm/collections.py
+++ b/libs/sqlalchemy/orm/collections.py
@@ -112,12 +112,32 @@ from sqlalchemy.sql import expression
from sqlalchemy import schema, util, exc as sa_exc
+
__all__ = ['collection', 'collection_adapter',
'mapped_collection', 'column_mapped_collection',
'attribute_mapped_collection']
__instrumentation_mutex = util.threading.Lock()
+class _SerializableColumnGetter(object):
+ def __init__(self, colkeys):
+ self.colkeys = colkeys
+ self.composite = len(colkeys) > 1
+
+ def __reduce__(self):
+ return _SerializableColumnGetter, (self.colkeys,)
+
+ def __call__(self, value):
+ state = instance_state(value)
+ m = _state_mapper(state)
+ key = [m._get_state_attr_by_column(
+ state, state.dict,
+ m.mapped_table.columns[k])
+ for k in self.colkeys]
+ if self.composite:
+ return tuple(key)
+ else:
+ return key[0]
def column_mapped_collection(mapping_spec):
"""A dictionary-based collection type with column-based keying.
@@ -131,25 +151,27 @@ def column_mapped_collection(mapping_spec):
after a session flush.
"""
+ global _state_mapper, instance_state
from sqlalchemy.orm.util import _state_mapper
from sqlalchemy.orm.attributes import instance_state
- cols = [expression._only_column_elements(q, "mapping_spec")
- for q in util.to_list(mapping_spec)]
- if len(cols) == 1:
- def keyfunc(value):
- state = instance_state(value)
- m = _state_mapper(state)
- return m._get_state_attr_by_column(state, state.dict, cols[0])
- else:
- mapping_spec = tuple(cols)
- def keyfunc(value):
- state = instance_state(value)
- m = _state_mapper(state)
- return tuple(m._get_state_attr_by_column(state, state.dict, c)
- for c in mapping_spec)
+ cols = [c.key for c in [
+ expression._only_column_elements(q, "mapping_spec")
+ for q in util.to_list(mapping_spec)]]
+ keyfunc = _SerializableColumnGetter(cols)
return lambda: MappedCollection(keyfunc)
+class _SerializableAttrGetter(object):
+ def __init__(self, name):
+ self.name = name
+ self.getter = operator.attrgetter(name)
+
+ def __call__(self, target):
+ return self.getter(target)
+
+ def __reduce__(self):
+ return _SerializableAttrGetter, (self.name, )
+
def attribute_mapped_collection(attr_name):
"""A dictionary-based collection type with attribute-based keying.
@@ -163,7 +185,8 @@ def attribute_mapped_collection(attr_name):
after a session flush.
"""
- return lambda: MappedCollection(operator.attrgetter(attr_name))
+ getter = _SerializableAttrGetter(attr_name)
+ return lambda: MappedCollection(getter)
def mapped_collection(keyfunc):
@@ -814,6 +837,7 @@ def _instrument_class(cls):
methods[name] = None, None, after
# apply ABC auto-decoration to methods that need it
+
for method, decorator in decorators.items():
fn = getattr(cls, method, None)
if (fn and method not in methods and
@@ -1465,3 +1489,13 @@ class MappedCollection(dict):
incoming_key, value, new_key))
yield value
_convert = collection.converter(_convert)
+
+# ensure instrumentation is associated with
+# these built-in classes; if a user-defined class
+# subclasses these and uses @internally_instrumented,
+# the superclass is otherwise not instrumented.
+# see [ticket:2406].
+_instrument_class(MappedCollection)
+_instrument_class(InstrumentedList)
+_instrument_class(InstrumentedSet)
+
diff --git a/libs/sqlalchemy/orm/mapper.py b/libs/sqlalchemy/orm/mapper.py
index 4c952c1..e96b754 100644
--- a/libs/sqlalchemy/orm/mapper.py
+++ b/libs/sqlalchemy/orm/mapper.py
@@ -1452,12 +1452,19 @@ class Mapper(object):
return result
def _is_userland_descriptor(self, obj):
- return not isinstance(obj,
- (MapperProperty, attributes.QueryableAttribute)) and \
- hasattr(obj, '__get__') and not \
- isinstance(obj.__get__(None, obj),
- attributes.QueryableAttribute)
-
+ if isinstance(obj, (MapperProperty,
+ attributes.QueryableAttribute)):
+ return False
+ elif not hasattr(obj, '__get__'):
+ return False
+ else:
+ obj = util.unbound_method_to_callable(obj)
+ if isinstance(
+ obj.__get__(None, obj),
+ attributes.QueryableAttribute
+ ):
+ return False
+ return True
def _should_exclude(self, name, assigned_name, local, column):
"""determine whether a particular property should be implicitly
@@ -1875,501 +1882,6 @@ class Mapper(object):
self._memoized_values[key] = value = callable_()
return value
- def _post_update(self, states, uowtransaction, post_update_cols):
- """Issue UPDATE statements on behalf of a relationship() which
- specifies post_update.
-
- """
- cached_connections = util.PopulateDict(
- lambda conn:conn.execution_options(
- compiled_cache=self._compiled_cache
- ))
-
- # if session has a connection callable,
- # organize individual states with the connection
- # to use for update
- if uowtransaction.session.connection_callable:
- connection_callable = \
- uowtransaction.session.connection_callable
- else:
- connection = uowtransaction.transaction.connection(self)
- connection_callable = None
-
- tups = []
- for state in _sort_states(states):
- if connection_callable:
- conn = connection_callable(self, state.obj())
- else:
- conn = connection
-
- mapper = _state_mapper(state)
-
- tups.append((state, state.dict, mapper, conn))
-
- table_to_mapper = self._sorted_tables
-
- for table in table_to_mapper:
- update = []
-
- for state, state_dict, mapper, connection in tups:
- if table not in mapper._pks_by_table:
- continue
-
- pks = mapper._pks_by_table[table]
- params = {}
- hasdata = False
-
- for col in mapper._cols_by_table[table]:
- if col in pks:
- params[col._label] = \
- mapper._get_state_attr_by_column(
- state,
- state_dict, col)
- elif col in post_update_cols:
- prop = mapper._columntoproperty[col]
- history = attributes.get_state_history(
- state, prop.key,
- attributes.PASSIVE_NO_INITIALIZE)
- if history.added:
- value = history.added[0]
- params[col.key] = value
- hasdata = True
- if hasdata:
- update.append((state, state_dict, params, mapper,
- connection))
-
- if update:
- mapper = table_to_mapper[table]
-
- def update_stmt():
- clause = sql.and_()
-
- for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
-
- return table.update(clause)
-
- statement = self._memo(('post_update', table), update_stmt)
-
- # execute each UPDATE in the order according to the original
- # list of states to guarantee row access order, but
- # also group them into common (connection, cols) sets
- # to support executemany().
- for key, grouper in groupby(
- update, lambda rec: (rec[4], rec[2].keys())
- ):
- multiparams = [params for state, state_dict,
- params, mapper, conn in grouper]
- cached_connections[connection].\
- execute(statement, multiparams)
-
- def _save_obj(self, states, uowtransaction, single=False):
- """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
- of objects.
-
- This is called within the context of a UOWTransaction during a
- flush operation, given a list of states to be flushed. The
- base mapper in an inheritance hierarchy handles the inserts/
- updates for all descendant mappers.
-
- """
-
- # if batch=false, call _save_obj separately for each object
- if not single and not self.batch:
- for state in _sort_states(states):
- self._save_obj([state],
- uowtransaction,
- single=True)
- return
-
- # if session has a connection callable,
- # organize individual states with the connection
- # to use for insert/update
- if uowtransaction.session.connection_callable:
- connection_callable = \
- uowtransaction.session.connection_callable
- else:
- connection = uowtransaction.transaction.connection(self)
- connection_callable = None
-
- tups = []
-
- for state in _sort_states(states):
- if connection_callable:
- conn = connection_callable(self, state.obj())
- else:
- conn = connection
-
- has_identity = bool(state.key)
- mapper = _state_mapper(state)
- instance_key = state.key or mapper._identity_key_from_state(state)
-
- row_switch = None
-
- # call before_XXX extensions
- if not has_identity:
- mapper.dispatch.before_insert(mapper, conn, state)
- else:
- mapper.dispatch.before_update(mapper, conn, state)
-
- # detect if we have a "pending" instance (i.e. has
- # no instance_key attached to it), and another instance
- # with the same identity key already exists as persistent.
- # convert to an UPDATE if so.
- if not has_identity and \
- instance_key in uowtransaction.session.identity_map:
- instance = \
- uowtransaction.session.identity_map[instance_key]
- existing = attributes.instance_state(instance)
- if not uowtransaction.is_deleted(existing):
- raise orm_exc.FlushError(
- "New instance %s with identity key %s conflicts "
- "with persistent instance %s" %
- (state_str(state), instance_key,
- state_str(existing)))
-
- self._log_debug(
- "detected row switch for identity %s. "
- "will update %s, remove %s from "
- "transaction", instance_key,
- state_str(state), state_str(existing))
-
- # remove the "delete" flag from the existing element
- uowtransaction.remove_state_actions(existing)
- row_switch = existing
-
- tups.append(
- (state, state.dict, mapper, conn,
- has_identity, instance_key, row_switch)
- )
-
- # dictionary of connection->connection_with_cache_options.
- cached_connections = util.PopulateDict(
- lambda conn:conn.execution_options(
- compiled_cache=self._compiled_cache
- ))
-
- table_to_mapper = self._sorted_tables
-
- for table in table_to_mapper:
- insert = []
- update = []
-
- for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in tups:
- if table not in mapper._pks_by_table:
- continue
-
- pks = mapper._pks_by_table[table]
-
- isinsert = not has_identity and not row_switch
-
- params = {}
- value_params = {}
-
- if isinsert:
- has_all_pks = True
- for col in mapper._cols_by_table[table]:
- if col is mapper.version_id_col:
- params[col.key] = \
- mapper.version_id_generator(None)
- else:
- # pull straight from the dict for
- # pending objects
- prop = mapper._columntoproperty[col]
- value = state_dict.get(prop.key, None)
-
- if value is None:
- if col in pks:
- has_all_pks = False
- elif col.default is None and \
- col.server_default is None:
- params[col.key] = value
-
- elif isinstance(value, sql.ClauseElement):
- value_params[col] = value
- else:
- params[col.key] = value
-
- insert.append((state, state_dict, params, mapper,
- connection, value_params, has_all_pks))
- else:
- hasdata = False
- for col in mapper._cols_by_table[table]:
- if col is mapper.version_id_col:
- params[col._label] = \
- mapper._get_committed_state_attr_by_column(
- row_switch or state,
- row_switch and row_switch.dict
- or state_dict,
- col)
-
- prop = mapper._columntoproperty[col]
- history = attributes.get_state_history(
- state, prop.key,
- attributes.PASSIVE_NO_INITIALIZE
- )
- if history.added:
- params[col.key] = history.added[0]
- hasdata = True
- else:
- params[col.key] = \
- mapper.version_id_generator(
- params[col._label])
-
- # HACK: check for history, in case the
- # history is only
- # in a different table than the one
- # where the version_id_col is.
- for prop in mapper._columntoproperty.\
- itervalues():
- history = attributes.get_state_history(
- state, prop.key,
- attributes.PASSIVE_NO_INITIALIZE)
- if history.added:
- hasdata = True
- else:
- prop = mapper._columntoproperty[col]
- history = attributes.get_state_history(
- state, prop.key,
- attributes.PASSIVE_NO_INITIALIZE)
- if history.added:
- if isinstance(history.added[0],
- sql.ClauseElement):
- value_params[col] = history.added[0]
- else:
- value = history.added[0]
- params[col.key] = value
-
- if col in pks:
- if history.deleted and \
- not row_switch:
- # if passive_updates and sync detected
- # this was a pk->pk sync, use the new
- # value to locate the row, since the
- # DB would already have set this
- if ("pk_cascaded", state, col) in \
- uowtransaction.\
- attributes:
- value = history.added[0]
- params[col._label] = value
- else:
- # use the old value to
- # locate the row
- value = history.deleted[0]
- params[col._label] = value
- hasdata = True
- else:
- # row switch logic can reach us here
- # remove the pk from the update params
- # so the update doesn't
- # attempt to include the pk in the
- # update statement
- del params[col.key]
- value = history.added[0]
- params[col._label] = value
- if value is None and hasdata:
- raise sa_exc.FlushError(
- "Can't update table "
- "using NULL for primary key "
- "value")
- else:
- hasdata = True
- elif col in pks:
- value = state.manager[prop.key].\
- impl.get(state, state_dict)
- if value is None:
- raise sa_exc.FlushError(
- "Can't update table "
- "using NULL for primary "
- "key value")
- params[col._label] = value
- if hasdata:
- update.append((state, state_dict, params, mapper,
- connection, value_params))
-
- if update:
- mapper = table_to_mapper[table]
-
- needs_version_id = mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col)
-
- def update_stmt():
- clause = sql.and_()
-
- for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
-
- if needs_version_id:
- clause.clauses.append(mapper.version_id_col ==\
- sql.bindparam(mapper.version_id_col._label,
- type_=col.type))
-
- return table.update(clause)
-
- statement = self._memo(('update', table), update_stmt)
-
- rows = 0
- for state, state_dict, params, mapper, \
- connection, value_params in update:
-
- if value_params:
- c = connection.execute(
- statement.values(value_params),
- params)
- else:
- c = cached_connections[connection].\
- execute(statement, params)
-
- mapper._postfetch(
- uowtransaction,
- table,
- state,
- state_dict,
- c.context.prefetch_cols,
- c.context.postfetch_cols,
- c.context.compiled_parameters[0],
- value_params)
- rows += c.rowcount
-
- if connection.dialect.supports_sane_rowcount:
- if rows != len(update):
- raise orm_exc.StaleDataError(
- "UPDATE statement on table '%s' expected to update %d row(s); "
- "%d were matched." %
- (table.description, len(update), rows))
-
- elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description,
- stacklevel=12)
-
- if insert:
- statement = self._memo(('insert', table), table.insert)
-
- for (connection, pkeys, hasvalue, has_all_pks), \
- records in groupby(insert,
- lambda rec: (rec[4],
- rec[2].keys(),
- bool(rec[5]),
- rec[6])
- ):
- if has_all_pks and not hasvalue:
- records = list(records)
- multiparams = [rec[2] for rec in records]
- c = cached_connections[connection].\
- execute(statement, multiparams)
-
- for (state, state_dict, params, mapper,
- conn, value_params, has_all_pks), \
- last_inserted_params in \
- zip(records, c.context.compiled_parameters):
- mapper._postfetch(
- uowtransaction,
- table,
- state,
- state_dict,
- c.context.prefetch_cols,
- c.context.postfetch_cols,
- last_inserted_params,
- value_params)
-
- else:
- for state, state_dict, params, mapper, \
- connection, value_params, \
- has_all_pks in records:
-
- if value_params:
- result = connection.execute(
- statement.values(value_params),
- params)
- else:
- result = cached_connections[connection].\
- execute(statement, params)
-
- primary_key = result.context.inserted_primary_key
-
- if primary_key is not None:
- # set primary key attributes
- for pk, col in zip(primary_key,
- mapper._pks_by_table[table]):
- prop = mapper._columntoproperty[col]
- if state_dict.get(prop.key) is None:
- # TODO: would rather say:
- #state_dict[prop.key] = pk
- mapper._set_state_attr_by_column(
- state,
- state_dict,
- col, pk)
-
- mapper._postfetch(
- uowtransaction,
- table,
- state,
- state_dict,
- result.context.prefetch_cols,
- result.context.postfetch_cols,
- result.context.compiled_parameters[0],
- value_params)
-
-
- for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in tups:
-
- if mapper._readonly_props:
- readonly = state.unmodified_intersection(
- [p.key for p in mapper._readonly_props
- if p.expire_on_flush or p.key not in state.dict]
- )
- if readonly:
- state.expire_attributes(state.dict, readonly)
-
- # if eager_defaults option is enabled,
- # refresh whatever has been expired.
- if self.eager_defaults and state.unloaded:
- state.key = self._identity_key_from_state(state)
- uowtransaction.session.query(self)._load_on_ident(
- state.key, refresh_state=state,
- only_load_props=state.unloaded)
-
- # call after_XXX extensions
- if not has_identity:
- mapper.dispatch.after_insert(mapper, connection, state)
- else:
- mapper.dispatch.after_update(mapper, connection, state)
-
- def _postfetch(self, uowtransaction, table,
- state, dict_, prefetch_cols, postfetch_cols,
- params, value_params):
- """During a flush, expire attributes in need of newly
- persisted database state."""
-
- if self.version_id_col is not None:
- prefetch_cols = list(prefetch_cols) + [self.version_id_col]
-
- for c in prefetch_cols:
- if c.key in params and c in self._columntoproperty:
- self._set_state_attr_by_column(state, dict_, c, params[c.key])
-
- if postfetch_cols:
- state.expire_attributes(state.dict,
- [self._columntoproperty[c].key
- for c in postfetch_cols if c in
- self._columntoproperty]
- )
-
- # synchronize newly inserted ids from one table to the next
- # TODO: this still goes a little too often. would be nice to
- # have definitive list of "columns that changed" here
- for m, equated_pairs in self._table_to_equated[table]:
- sync.populate(state, m, state, m,
- equated_pairs,
- uowtransaction,
- self.passive_updates)
-
@util.memoized_property
def _table_to_equated(self):
"""memoized map of tables to collections of columns to be
@@ -2387,128 +1899,6 @@ class Mapper(object):
return result
- def _delete_obj(self, states, uowtransaction):
- """Issue ``DELETE`` statements for a list of objects.
-
- This is called within the context of a UOWTransaction during a
- flush operation.
-
- """
- if uowtransaction.session.connection_callable:
- connection_callable = \
- uowtransaction.session.connection_callable
- else:
- connection = uowtransaction.transaction.connection(self)
- connection_callable = None
-
- tups = []
- cached_connections = util.PopulateDict(
- lambda conn:conn.execution_options(
- compiled_cache=self._compiled_cache
- ))
-
- for state in _sort_states(states):
- mapper = _state_mapper(state)
-
- if connection_callable:
- conn = connection_callable(self, state.obj())
- else:
- conn = connection
-
- mapper.dispatch.before_delete(mapper, conn, state)
-
- tups.append((state,
- state.dict,
- _state_mapper(state),
- bool(state.key),
- conn))
-
- table_to_mapper = self._sorted_tables
-
- for table in reversed(table_to_mapper.keys()):
- delete = util.defaultdict(list)
- for state, state_dict, mapper, has_identity, connection in tups:
- if not has_identity or table not in mapper._pks_by_table:
- continue
-
- params = {}
- delete[connection].append(params)
- for col in mapper._pks_by_table[table]:
- params[col.key] = \
- value = \
- mapper._get_state_attr_by_column(
- state, state_dict, col)
- if value is None:
- raise sa_exc.FlushError(
- "Can't delete from table "
- "using NULL for primary "
- "key value")
-
- if mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col):
- params[mapper.version_id_col.key] = \
- mapper._get_committed_state_attr_by_column(
- state, state_dict,
- mapper.version_id_col)
-
- mapper = table_to_mapper[table]
- need_version_id = mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col)
-
- def delete_stmt():
- clause = sql.and_()
- for col in mapper._pks_by_table[table]:
- clause.clauses.append(
- col == sql.bindparam(col.key, type_=col.type))
-
- if need_version_id:
- clause.clauses.append(
- mapper.version_id_col ==
- sql.bindparam(
- mapper.version_id_col.key,
- type_=mapper.version_id_col.type
- )
- )
-
- return table.delete(clause)
-
- for connection, del_objects in delete.iteritems():
- statement = self._memo(('delete', table), delete_stmt)
- rows = -1
-
- connection = cached_connections[connection]
-
- if need_version_id and \
- not connection.dialect.supports_sane_multi_rowcount:
- # TODO: need test coverage for this [ticket:1761]
- if connection.dialect.supports_sane_rowcount:
- rows = 0
- # execute deletes individually so that versioned
- # rows can be verified
- for params in del_objects:
- c = connection.execute(statement, params)
- rows += c.rowcount
- else:
- util.warn(
- "Dialect %s does not support deleted rowcount "
- "- versioning cannot be verified." %
- connection.dialect.dialect_description,
- stacklevel=12)
- connection.execute(statement, del_objects)
- else:
- c = connection.execute(statement, del_objects)
- if connection.dialect.supports_sane_multi_rowcount:
- rows = c.rowcount
-
- if rows != -1 and rows != len(del_objects):
- raise orm_exc.StaleDataError(
- "DELETE statement on table '%s' expected to delete %d row(s); "
- "%d were matched." %
- (table.description, len(del_objects), c.rowcount)
- )
-
- for state, state_dict, mapper, has_identity, connection in tups:
- mapper.dispatch.after_delete(mapper, connection, state)
def _instance_processor(self, context, path, reduced_path, adapter,
polymorphic_from=None,
@@ -2518,6 +1908,12 @@ class Mapper(object):
"""Produce a mapper level row processor callable
which processes rows into mapped instances."""
+ # note that this method, most of which exists in a closure
+ # called _instance(), resists being broken out, as
+ # attempts to do so tend to add significant function
+ # call overhead. _instance() is the most
+ # performance-critical section in the whole ORM.
+
pk_cols = self.primary_key
if polymorphic_from or refresh_state:
@@ -2961,13 +2357,6 @@ def _event_on_resurrect(state):
state, state.dict, col, val)
-def _sort_states(states):
- pending = set(states)
- persistent = set(s for s in pending if s.key is not None)
- pending.difference_update(persistent)
- return sorted(pending, key=operator.attrgetter("insert_order")) + \
- sorted(persistent, key=lambda q:q.key[1])
-
class _ColumnMapping(util.py25_dict):
"""Error reporting helper for mapper._columntoproperty."""
diff --git a/libs/sqlalchemy/orm/persistence.py b/libs/sqlalchemy/orm/persistence.py
new file mode 100644
index 0000000..55b9bf8
--- /dev/null
+++ b/libs/sqlalchemy/orm/persistence.py
@@ -0,0 +1,777 @@
+# orm/persistence.py
+# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""private module containing functions used to emit INSERT, UPDATE
+and DELETE statements on behalf of a :class:`.Mapper` and its descending
+mappers.
+
+The functions here are called only by the unit of work functions
+in unitofwork.py.
+
+"""
+
+import operator
+from itertools import groupby
+
+from sqlalchemy import sql, util, exc as sa_exc
+from sqlalchemy.orm import attributes, sync, \
+ exc as orm_exc
+
+from sqlalchemy.orm.util import _state_mapper, state_str
+
+def save_obj(base_mapper, states, uowtransaction, single=False):
+ """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
+ of objects.
+
+ This is called within the context of a UOWTransaction during a
+ flush operation, given a list of states to be flushed. The
+ base mapper in an inheritance hierarchy handles the inserts/
+ updates for all descendant mappers.
+
+ """
+
+ # if batch=false, call _save_obj separately for each object
+ if not single and not base_mapper.batch:
+ for state in _sort_states(states):
+ save_obj(base_mapper, [state], uowtransaction, single=True)
+ return
+
+ states_to_insert, states_to_update = _organize_states_for_save(
+ base_mapper,
+ states,
+ uowtransaction)
+
+ cached_connections = _cached_connection_dict(base_mapper)
+
+ for table, mapper in base_mapper._sorted_tables.iteritems():
+ insert = _collect_insert_commands(base_mapper, uowtransaction,
+ table, states_to_insert)
+
+ update = _collect_update_commands(base_mapper, uowtransaction,
+ table, states_to_update)
+
+ if update:
+ _emit_update_statements(base_mapper, uowtransaction,
+ cached_connections,
+ mapper, table, update)
+
+ if insert:
+ _emit_insert_statements(base_mapper, uowtransaction,
+ cached_connections,
+ table, insert)
+
+ _finalize_insert_update_commands(base_mapper, uowtransaction,
+ states_to_insert, states_to_update)
+
+def post_update(base_mapper, states, uowtransaction, post_update_cols):
+ """Issue UPDATE statements on behalf of a relationship() which
+ specifies post_update.
+
+ """
+ cached_connections = _cached_connection_dict(base_mapper)
+
+ states_to_update = _organize_states_for_post_update(
+ base_mapper,
+ states, uowtransaction)
+
+
+ for table, mapper in base_mapper._sorted_tables.iteritems():
+ update = _collect_post_update_commands(base_mapper, uowtransaction,
+ table, states_to_update,
+ post_update_cols)
+
+ if update:
+ _emit_post_update_statements(base_mapper, uowtransaction,
+ cached_connections,
+ mapper, table, update)
+
+def delete_obj(base_mapper, states, uowtransaction):
+ """Issue ``DELETE`` statements for a list of objects.
+
+ This is called within the context of a UOWTransaction during a
+ flush operation.
+
+ """
+
+ cached_connections = _cached_connection_dict(base_mapper)
+
+ states_to_delete = _organize_states_for_delete(
+ base_mapper,
+ states,
+ uowtransaction)
+
+ table_to_mapper = base_mapper._sorted_tables
+
+ for table in reversed(table_to_mapper.keys()):
+ delete = _collect_delete_commands(base_mapper, uowtransaction,
+ table, states_to_delete)
+
+ mapper = table_to_mapper[table]
+
+ _emit_delete_statements(base_mapper, uowtransaction,
+ cached_connections, mapper, table, delete)
+
+ for state, state_dict, mapper, has_identity, connection \
+ in states_to_delete:
+ mapper.dispatch.after_delete(mapper, connection, state)
+
+def _organize_states_for_save(base_mapper, states, uowtransaction):
+ """Make an initial pass across a set of states for INSERT or
+ UPDATE.
+
+ This includes splitting out into distinct lists for
+ each, calling before_insert/before_update, obtaining
+ key information for each state including its dictionary,
+ mapper, the connection to use for the execution per state,
+ and the identity flag.
+
+ """
+
+ states_to_insert = []
+ states_to_update = []
+
+ for state, dict_, mapper, connection in _connections_for_states(
+ base_mapper, uowtransaction,
+ states):
+
+ has_identity = bool(state.key)
+ instance_key = state.key or mapper._identity_key_from_state(state)
+
+ row_switch = None
+
+ # call before_XXX extensions
+ if not has_identity:
+ mapper.dispatch.before_insert(mapper, connection, state)
+ else:
+ mapper.dispatch.before_update(mapper, connection, state)
+
+ # detect if we have a "pending" instance (i.e. has
+ # no instance_key attached to it), and another instance
+ # with the same identity key already exists as persistent.
+ # convert to an UPDATE if so.
+ if not has_identity and \
+ instance_key in uowtransaction.session.identity_map:
+ instance = \
+ uowtransaction.session.identity_map[instance_key]
+ existing = attributes.instance_state(instance)
+ if not uowtransaction.is_deleted(existing):
+ raise orm_exc.FlushError(
+ "New instance %s with identity key %s conflicts "
+ "with persistent instance %s" %
+ (state_str(state), instance_key,
+ state_str(existing)))
+
+ base_mapper._log_debug(
+ "detected row switch for identity %s. "
+ "will update %s, remove %s from "
+ "transaction", instance_key,
+ state_str(state), state_str(existing))
+
+ # remove the "delete" flag from the existing element
+ uowtransaction.remove_state_actions(existing)
+ row_switch = existing
+
+ if not has_identity and not row_switch:
+ states_to_insert.append(
+ (state, dict_, mapper, connection,
+ has_identity, instance_key, row_switch)
+ )
+ else:
+ states_to_update.append(
+ (state, dict_, mapper, connection,
+ has_identity, instance_key, row_switch)
+ )
+
+ return states_to_insert, states_to_update
+
+def _organize_states_for_post_update(base_mapper, states,
+ uowtransaction):
+ """Make an initial pass across a set of states for UPDATE
+ corresponding to post_update.
+
+ This includes obtaining key information for each state
+ including its dictionary, mapper, the connection to use for
+ the execution per state.
+
+ """
+ return list(_connections_for_states(base_mapper, uowtransaction,
+ states))
+
+def _organize_states_for_delete(base_mapper, states, uowtransaction):
+ """Make an initial pass across a set of states for DELETE.
+
+ This includes calling out before_delete and obtaining
+ key information for each state including its dictionary,
+ mapper, the connection to use for the execution per state.
+
+ """
+ states_to_delete = []
+
+ for state, dict_, mapper, connection in _connections_for_states(
+ base_mapper, uowtransaction,
+ states):
+
+ mapper.dispatch.before_delete(mapper, connection, state)
+
+ states_to_delete.append((state, dict_, mapper,
+ bool(state.key), connection))
+ return states_to_delete
+
+def _collect_insert_commands(base_mapper, uowtransaction, table,
+ states_to_insert):
+ """Identify sets of values to use in INSERT statements for a
+ list of states.
+
+ """
+ insert = []
+ for state, state_dict, mapper, connection, has_identity, \
+ instance_key, row_switch in states_to_insert:
+ if table not in mapper._pks_by_table:
+ continue
+
+ pks = mapper._pks_by_table[table]
+
+ params = {}
+ value_params = {}
+
+ has_all_pks = True
+ for col in mapper._cols_by_table[table]:
+ if col is mapper.version_id_col:
+ params[col.key] = mapper.version_id_generator(None)
+ else:
+ # pull straight from the dict for
+ # pending objects
+ prop = mapper._columntoproperty[col]
+ value = state_dict.get(prop.key, None)
+
+ if value is None:
+ if col in pks:
+ has_all_pks = False
+ elif col.default is None and \
+ col.server_default is None:
+ params[col.key] = value
+
+ elif isinstance(value, sql.ClauseElement):
+ value_params[col] = value
+ else:
+ params[col.key] = value
+
+ insert.append((state, state_dict, params, mapper,
+ connection, value_params, has_all_pks))
+ return insert
+
+def _collect_update_commands(base_mapper, uowtransaction,
+ table, states_to_update):
+ """Identify sets of values to use in UPDATE statements for a
+ list of states.
+
+ This function works intricately with the history system
+ to determine exactly what values should be updated
+ as well as how the row should be matched within an UPDATE
+ statement. Includes some tricky scenarios where the primary
+ key of an object might have been changed.
+
+ """
+
+ update = []
+ for state, state_dict, mapper, connection, has_identity, \
+ instance_key, row_switch in states_to_update:
+ if table not in mapper._pks_by_table:
+ continue
+
+ pks = mapper._pks_by_table[table]
+
+ params = {}
+ value_params = {}
+
+ hasdata = hasnull = False
+ for col in mapper._cols_by_table[table]:
+ if col is mapper.version_id_col:
+ params[col._label] = \
+ mapper._get_committed_state_attr_by_column(
+ row_switch or state,
+ row_switch and row_switch.dict
+ or state_dict,
+ col)
+
+ prop = mapper._columntoproperty[col]
+ history = attributes.get_state_history(
+ state, prop.key,
+ attributes.PASSIVE_NO_INITIALIZE
+ )
+ if history.added:
+ params[col.key] = history.added[0]
+ hasdata = True
+ else:
+ params[col.key] = mapper.version_id_generator(
+ params[col._label])
+
+ # HACK: check for history, in case the
+ # history is only
+ # in a different table than the one
+ # where the version_id_col is.
+ for prop in mapper._columntoproperty.itervalues():
+ history = attributes.get_state_history(
+ state, prop.key,
+ attributes.PASSIVE_NO_INITIALIZE)
+ if history.added:
+ hasdata = True
+ else:
+ prop = mapper._columntoproperty[col]
+ history = attributes.get_state_history(
+ state, prop.key,
+ attributes.PASSIVE_NO_INITIALIZE)
+ if history.added:
+ if isinstance(history.added[0],
+ sql.ClauseElement):
+ value_params[col] = history.added[0]
+ else:
+ value = history.added[0]
+ params[col.key] = value
+
+ if col in pks:
+ if history.deleted and \
+ not row_switch:
+ # if passive_updates and sync detected
+ # this was a pk->pk sync, use the new
+ # value to locate the row, since the
+ # DB would already have set this
+ if ("pk_cascaded", state, col) in \
+ uowtransaction.attributes:
+ value = history.added[0]
+ params[col._label] = value
+ else:
+ # use the old value to
+ # locate the row
+ value = history.deleted[0]
+ params[col._label] = value
+ hasdata = True
+ else:
+ # row switch logic can reach us here
+ # remove the pk from the update params
+ # so the update doesn't
+ # attempt to include the pk in the
+ # update statement
+ del params[col.key]
+ value = history.added[0]
+ params[col._label] = value
+ if value is None:
+ hasnull = True
+ else:
+ hasdata = True
+ elif col in pks:
+ value = state.manager[prop.key].impl.get(
+ state, state_dict)
+ if value is None:
+ hasnull = True
+ params[col._label] = value
+ if hasdata:
+ if hasnull:
+ raise sa_exc.FlushError(
+ "Can't update table "
+ "using NULL for primary "
+ "key value")
+ update.append((state, state_dict, params, mapper,
+ connection, value_params))
+ return update
+
+
+def _collect_post_update_commands(base_mapper, uowtransaction, table,
+ states_to_update, post_update_cols):
+ """Identify sets of values to use in UPDATE statements for a
+ list of states within a post_update operation.
+
+ """
+
+ update = []
+ for state, state_dict, mapper, connection in states_to_update:
+ if table not in mapper._pks_by_table:
+ continue
+ pks = mapper._pks_by_table[table]
+ params = {}
+ hasdata = False
+
+ for col in mapper._cols_by_table[table]:
+ if col in pks:
+ params[col._label] = \
+ mapper._get_state_attr_by_column(
+ state,
+ state_dict, col)
+ elif col in post_update_cols:
+ prop = mapper._columntoproperty[col]
+ history = attributes.get_state_history(
+ state, prop.key,
+ attributes.PASSIVE_NO_INITIALIZE)
+ if history.added:
+ value = history.added[0]
+ params[col.key] = value
+ hasdata = True
+ if hasdata:
+ update.append((state, state_dict, params, mapper,
+ connection))
+ return update
+
+def _collect_delete_commands(base_mapper, uowtransaction, table,
+ states_to_delete):
+ """Identify values to use in DELETE statements for a list of
+ states to be deleted."""
+
+ delete = util.defaultdict(list)
+
+ for state, state_dict, mapper, has_identity, connection \
+ in states_to_delete:
+ if not has_identity or table not in mapper._pks_by_table:
+ continue
+
+ params = {}
+ delete[connection].append(params)
+ for col in mapper._pks_by_table[table]:
+ params[col.key] = \
+ value = \
+ mapper._get_state_attr_by_column(
+ state, state_dict, col)
+ if value is None:
+ raise sa_exc.FlushError(
+ "Can't delete from table "
+ "using NULL for primary "
+ "key value")
+
+ if mapper.version_id_col is not None and \
+ table.c.contains_column(mapper.version_id_col):
+ params[mapper.version_id_col.key] = \
+ mapper._get_committed_state_attr_by_column(
+ state, state_dict,
+ mapper.version_id_col)
+ return delete
+
+
+def _emit_update_statements(base_mapper, uowtransaction,
+ cached_connections, mapper, table, update):
+ """Emit UPDATE statements corresponding to value lists collected
+ by _collect_update_commands()."""
+
+ needs_version_id = mapper.version_id_col is not None and \
+ table.c.contains_column(mapper.version_id_col)
+
+ def update_stmt():
+ clause = sql.and_()
+
+ for col in mapper._pks_by_table[table]:
+ clause.clauses.append(col == sql.bindparam(col._label,
+ type_=col.type))
+
+ if needs_version_id:
+ clause.clauses.append(mapper.version_id_col ==\
+ sql.bindparam(mapper.version_id_col._label,
+ type_=col.type))
+
+ return table.update(clause)
+
+ statement = base_mapper._memo(('update', table), update_stmt)
+
+ rows = 0
+ for state, state_dict, params, mapper, \
+ connection, value_params in update:
+
+ if value_params:
+ c = connection.execute(
+ statement.values(value_params),
+ params)
+ else:
+ c = cached_connections[connection].\
+ execute(statement, params)
+
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c.context.prefetch_cols,
+ c.context.postfetch_cols,
+ c.context.compiled_parameters[0],
+ value_params)
+ rows += c.rowcount
+
+ if connection.dialect.supports_sane_rowcount:
+ if rows != len(update):
+ raise orm_exc.StaleDataError(
+ "UPDATE statement on table '%s' expected to "
+ "update %d row(s); %d were matched." %
+ (table.description, len(update), rows))
+
+ elif needs_version_id:
+ util.warn("Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified." %
+ c.dialect.dialect_description,
+ stacklevel=12)
+
+def _emit_insert_statements(base_mapper, uowtransaction,
+ cached_connections, table, insert):
+ """Emit INSERT statements corresponding to value lists collected
+ by _collect_insert_commands()."""
+
+ statement = base_mapper._memo(('insert', table), table.insert)
+
+ for (connection, pkeys, hasvalue, has_all_pks), \
+ records in groupby(insert,
+ lambda rec: (rec[4],
+ rec[2].keys(),
+ bool(rec[5]),
+ rec[6])
+ ):
+ if has_all_pks and not hasvalue:
+ records = list(records)
+ multiparams = [rec[2] for rec in records]
+ c = cached_connections[connection].\
+ execute(statement, multiparams)
+
+ for (state, state_dict, params, mapper,
+ conn, value_params, has_all_pks), \
+ last_inserted_params in \
+ zip(records, c.context.compiled_parameters):
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c.context.prefetch_cols,
+ c.context.postfetch_cols,
+ last_inserted_params,
+ value_params)
+
+ else:
+ for state, state_dict, params, mapper, \
+ connection, value_params, \
+ has_all_pks in records:
+
+ if value_params:
+ result = connection.execute(
+ statement.values(value_params),
+ params)
+ else:
+ result = cached_connections[connection].\
+ execute(statement, params)
+
+ primary_key = result.context.inserted_primary_key
+
+ if primary_key is not None:
+ # set primary key attributes
+ for pk, col in zip(primary_key,
+ mapper._pks_by_table[table]):
+ prop = mapper._columntoproperty[col]
+ if state_dict.get(prop.key) is None:
+ # TODO: would rather say:
+ #state_dict[prop.key] = pk
+ mapper._set_state_attr_by_column(
+ state,
+ state_dict,
+ col, pk)
+
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ result.context.prefetch_cols,
+ result.context.postfetch_cols,
+ result.context.compiled_parameters[0],
+ value_params)
+
+
+
+def _emit_post_update_statements(base_mapper, uowtransaction,
+ cached_connections, mapper, table, update):
+ """Emit UPDATE statements corresponding to value lists collected
+ by _collect_post_update_commands()."""
+
+ def update_stmt():
+ clause = sql.and_()
+
+ for col in mapper._pks_by_table[table]:
+ clause.clauses.append(col == sql.bindparam(col._label,
+ type_=col.type))
+
+ return table.update(clause)
+
+ statement = base_mapper._memo(('post_update', table), update_stmt)
+
+ # execute each UPDATE in the order according to the original
+ # list of states to guarantee row access order, but
+ # also group them into common (connection, cols) sets
+ # to support executemany().
+ for key, grouper in groupby(
+ update, lambda rec: (rec[4], rec[2].keys())
+ ):
+ connection = key[0]
+ multiparams = [params for state, state_dict,
+ params, mapper, conn in grouper]
+ cached_connections[connection].\
+ execute(statement, multiparams)
+
+
+def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
+ mapper, table, delete):
+ """Emit DELETE statements corresponding to value lists collected
+ by _collect_delete_commands()."""
+
+ need_version_id = mapper.version_id_col is not None and \
+ table.c.contains_column(mapper.version_id_col)
+
+ def delete_stmt():
+ clause = sql.and_()
+ for col in mapper._pks_by_table[table]:
+ clause.clauses.append(
+ col == sql.bindparam(col.key, type_=col.type))
+
+ if need_version_id:
+ clause.clauses.append(
+ mapper.version_id_col ==
+ sql.bindparam(
+ mapper.version_id_col.key,
+ type_=mapper.version_id_col.type
+ )
+ )
+
+ return table.delete(clause)
+
+ for connection, del_objects in delete.iteritems():
+ statement = base_mapper._memo(('delete', table), delete_stmt)
+
+ connection = cached_connections[connection]
+
+ if need_version_id:
+ # TODO: need test coverage for this [ticket:1761]
+ if connection.dialect.supports_sane_rowcount:
+ rows = 0
+ # execute deletes individually so that versioned
+ # rows can be verified
+ for params in del_objects:
+ c = connection.execute(statement, params)
+ rows += c.rowcount
+ if rows != len(del_objects):
+ raise orm_exc.StaleDataError(
+ "DELETE statement on table '%s' expected to "
+ "delete %d row(s); %d were matched." %
+ (table.description, len(del_objects), c.rowcount)
+ )
+ else:
+ util.warn(
+ "Dialect %s does not support deleted rowcount "
+ "- versioning cannot be verified." %
+ connection.dialect.dialect_description,
+ stacklevel=12)
+ connection.execute(statement, del_objects)
+ else:
+ connection.execute(statement, del_objects)
+
+
+def _finalize_insert_update_commands(base_mapper, uowtransaction,
+ states_to_insert, states_to_update):
+ """finalize state on states that have been inserted or updated,
+ including calling after_insert/after_update events.
+
+ """
+ for state, state_dict, mapper, connection, has_identity, \
+ instance_key, row_switch in states_to_insert + \
+ states_to_update:
+
+ if mapper._readonly_props:
+ readonly = state.unmodified_intersection(
+ [p.key for p in mapper._readonly_props
+ if p.expire_on_flush or p.key not in state.dict]
+ )
+ if readonly:
+ state.expire_attributes(state.dict, readonly)
+
+ # if eager_defaults option is enabled,
+ # refresh whatever has been expired.
+ if base_mapper.eager_defaults and state.unloaded:
+ state.key = base_mapper._identity_key_from_state(state)
+ uowtransaction.session.query(base_mapper)._load_on_ident(
+ state.key, refresh_state=state,
+ only_load_props=state.unloaded)
+
+ # call after_XXX extensions
+ if not has_identity:
+ mapper.dispatch.after_insert(mapper, connection, state)
+ else:
+ mapper.dispatch.after_update(mapper, connection, state)
+
+def _postfetch(mapper, uowtransaction, table,
+ state, dict_, prefetch_cols, postfetch_cols,
+ params, value_params):
+ """Expire attributes in need of newly persisted database state,
+ after an INSERT or UPDATE statement has proceeded for that
+ state."""
+
+ if mapper.version_id_col is not None:
+ prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
+
+ for c in prefetch_cols:
+ if c.key in params and c in mapper._columntoproperty:
+ mapper._set_state_attr_by_column(state, dict_, c, params[c.key])
+
+ if postfetch_cols:
+ state.expire_attributes(state.dict,
+ [mapper._columntoproperty[c].key
+ for c in postfetch_cols if c in
+ mapper._columntoproperty]
+ )
+
+ # synchronize newly inserted ids from one table to the next
+ # TODO: this still goes a little too often. would be nice to
+ # have definitive list of "columns that changed" here
+ for m, equated_pairs in mapper._table_to_equated[table]:
+ sync.populate(state, m, state, m,
+ equated_pairs,
+ uowtransaction,
+ mapper.passive_updates)
+
+def _connections_for_states(base_mapper, uowtransaction, states):
+ """Return an iterator of (state, state.dict, mapper, connection).
+
+ The states are sorted according to _sort_states, then paired
+ with the connection they should be using for the given
+ unit of work transaction.
+
+ """
+ # if session has a connection callable,
+ # organize individual states with the connection
+ # to use for update
+ if uowtransaction.session.connection_callable:
+ connection_callable = \
+ uowtransaction.session.connection_callable
+ else:
+ connection = uowtransaction.transaction.connection(
+ base_mapper)
+ connection_callable = None
+
+ for state in _sort_states(states):
+ if connection_callable:
+ connection = connection_callable(base_mapper, state.obj())
+
+ mapper = _state_mapper(state)
+
+ yield state, state.dict, mapper, connection
+
+def _cached_connection_dict(base_mapper):
+ # dictionary of connection->connection_with_cache_options.
+ return util.PopulateDict(
+ lambda conn:conn.execution_options(
+ compiled_cache=base_mapper._compiled_cache
+ ))
+
+def _sort_states(states):
+ pending = set(states)
+ persistent = set(s for s in pending if s.key is not None)
+ pending.difference_update(persistent)
+ return sorted(pending, key=operator.attrgetter("insert_order")) + \
+ sorted(persistent, key=lambda q:q.key[1])
+
+
diff --git a/libs/sqlalchemy/orm/query.py b/libs/sqlalchemy/orm/query.py
index 9508cb5..aa3dd01 100644
--- a/libs/sqlalchemy/orm/query.py
+++ b/libs/sqlalchemy/orm/query.py
@@ -133,7 +133,7 @@ class Query(object):
with_polymorphic = mapper._with_polymorphic_mappers
if mapper.mapped_table not in \
self._polymorphic_adapters:
- self.__mapper_loads_polymorphically_with(mapper,
+ self._mapper_loads_polymorphically_with(mapper,
sql_util.ColumnAdapter(
selectable,
mapper._equivalent_columns))
@@ -150,7 +150,7 @@ class Query(object):
is_aliased_class, with_polymorphic)
ent.setup_entity(entity, *d[entity])
- def __mapper_loads_polymorphically_with(self, mapper, adapter):
+ def _mapper_loads_polymorphically_with(self, mapper, adapter):
for m2 in mapper._with_polymorphic_mappers:
self._polymorphic_adapters[m2] = adapter
for m in m2.iterate_to_root():
@@ -174,10 +174,6 @@ class Query(object):
self._from_obj_alias = sql_util.ColumnAdapter(
self._from_obj[0], equivs)
- def _get_polymorphic_adapter(self, entity, selectable):
- self.__mapper_loads_polymorphically_with(entity.mapper,
- sql_util.ColumnAdapter(selectable,
- entity.mapper._equivalent_columns))
def _reset_polymorphic_adapter(self, mapper):
for m2 in mapper._with_polymorphic_mappers:
@@ -276,6 +272,7 @@ class Query(object):
return self._select_from_entity or \
self._entity_zero().entity_zero
+
@property
def _mapper_entities(self):
# TODO: this is wrong, its hardcoded to "primary entity" when
@@ -324,13 +321,6 @@ class Query(object):
)
return self._entity_zero()
- def _generate_mapper_zero(self):
- if not getattr(self._entities[0], 'primary_entity', False):
- raise sa_exc.InvalidRequestError(
- "No primary mapper set up for this Query.")
- entity = self._entities[0]._clone()
- self._entities = [entity] + self._entities[1:]
- return entity
def __all_equivs(self):
equivs = {}
@@ -460,6 +450,62 @@ class Query(object):
"""
return self.enable_eagerloads(False).statement.alias(name=name)
+ def cte(self, name=None, recursive=False):
+ """Return the full SELECT statement represented by this :class:`.Query`
+ represented as a common table expression (CTE).
+
+ The :meth:`.Query.cte` method is new in 0.7.6.
+
+ Parameters and usage are the same as those of the
+ :meth:`._SelectBase.cte` method; see that method for
+ further details.
+
+ Here is the `Postgresql WITH
+ RECURSIVE example `_.
+ Note that, in this example, the ``included_parts`` cte and the ``incl_alias`` alias
+ of it are Core selectables, which
+ means the columns are accessed via the ``.c.`` attribute. The ``parts_alias``
+ object is an :func:`.orm.aliased` instance of the ``Part`` entity, so column-mapped
+ attributes are available directly::
+
+ from sqlalchemy.orm import aliased
+
+ class Part(Base):
+ __tablename__ = 'part'
+ part = Column(String, primary_key=True)
+ sub_part = Column(String, primary_key=True)
+ quantity = Column(Integer)
+
+ included_parts = session.query(
+ Part.sub_part,
+ Part.part,
+ Part.quantity).\\
+ filter(Part.part=="our part").\\
+ cte(name="included_parts", recursive=True)
+
+ incl_alias = aliased(included_parts, name="pr")
+ parts_alias = aliased(Part, name="p")
+ included_parts = included_parts.union_all(
+ session.query(
+ parts_alias.part,
+ parts_alias.sub_part,
+ parts_alias.quantity).\\
+ filter(parts_alias.part==incl_alias.c.sub_part)
+ )
+
+ q = session.query(
+ included_parts.c.sub_part,
+ func.sum(included_parts.c.quantity).label('total_quantity')
+ ).\\
+ group_by(included_parts.c.sub_part)
+
+ See also:
+
+ :meth:`._SelectBase.cte`
+
+ """
+ return self.enable_eagerloads(False).statement.cte(name=name, recursive=recursive)
+
def label(self, name):
"""Return the full SELECT statement represented by this :class:`.Query`, converted
to a scalar subquery with a label of the given name.
@@ -601,7 +647,12 @@ class Query(object):
such as concrete table mappers.
"""
- entity = self._generate_mapper_zero()
+
+ if not getattr(self._entities[0], 'primary_entity', False):
+ raise sa_exc.InvalidRequestError(
+ "No primary mapper set up for this Query.")
+ entity = self._entities[0]._clone()
+ self._entities = [entity] + self._entities[1:]
entity.set_with_polymorphic(self,
cls_or_mappers,
selectable=selectable,
@@ -1041,7 +1092,22 @@ class Query(object):
@_generative()
def with_lockmode(self, mode):
- """Return a new Query object with the specified locking mode."""
+ """Return a new Query object with the specified locking mode.
+
+ :param mode: a string representing the desired locking mode. A
+ corresponding value is passed to the ``for_update`` parameter of
+ :meth:`~sqlalchemy.sql.expression.select` when the query is
+ executed. Valid values are:
+
+ ``'update'`` - passes ``for_update=True``, which translates to
+ ``FOR UPDATE`` (standard SQL, supported by most dialects)
+
+ ``'update_nowait'`` - passes ``for_update='nowait'``, which
+ translates to ``FOR UPDATE NOWAIT`` (supported by Oracle)
+
+ ``'read'`` - passes ``for_update='read'``, which translates to
+ ``LOCK IN SHARE MODE`` (supported by MySQL).
+ """
self._lockmode = mode
@@ -1583,7 +1649,6 @@ class Query(object):
consistent format with which to form the actual JOIN constructs.
"""
- self._polymorphic_adapters = self._polymorphic_adapters.copy()
if not from_joinpoint:
self._reset_joinpoint()
@@ -1683,6 +1748,8 @@ class Query(object):
onclause, outerjoin, create_aliases, prop):
"""append a JOIN to the query's from clause."""
+ self._polymorphic_adapters = self._polymorphic_adapters.copy()
+
if left is None:
if self._from_obj:
left = self._from_obj[0]
@@ -1696,7 +1763,29 @@ class Query(object):
"are the same entity" %
(left, right))
- left_mapper, left_selectable, left_is_aliased = _entity_info(left)
+ right, right_is_aliased, onclause = self._prepare_right_side(
+ right, onclause,
+ outerjoin, create_aliases,
+ prop)
+
+ # if joining on a MapperProperty path,
+ # track the path to prevent redundant joins
+ if not create_aliases and prop:
+ self._update_joinpoint({
+ '_joinpoint_entity':right,
+ 'prev':((left, right, prop.key), self._joinpoint)
+ })
+ else:
+ self._joinpoint = {
+ '_joinpoint_entity':right
+ }
+
+ self._join_to_left(left, right,
+ right_is_aliased,
+ onclause, outerjoin)
+
+ def _prepare_right_side(self, right, onclause, outerjoin,
+ create_aliases, prop):
right_mapper, right_selectable, right_is_aliased = _entity_info(right)
if right_mapper:
@@ -1741,24 +1830,13 @@ class Query(object):
right = aliased(right)
need_adapter = True
- # if joining on a MapperProperty path,
- # track the path to prevent redundant joins
- if not create_aliases and prop:
- self._update_joinpoint({
- '_joinpoint_entity':right,
- 'prev':((left, right, prop.key), self._joinpoint)
- })
- else:
- self._joinpoint = {
- '_joinpoint_entity':right
- }
-
# if an alias() of the right side was generated here,
# apply an adapter to all subsequent filter() calls
# until reset_joinpoint() is called.
if need_adapter:
self._filter_aliases = ORMAdapter(right,
- equivalents=right_mapper and right_mapper._equivalent_columns or {},
+ equivalents=right_mapper and
+ right_mapper._equivalent_columns or {},
chain_to=self._filter_aliases)
# if the onclause is a ClauseElement, adapt it with any
@@ -1771,7 +1849,7 @@ class Query(object):
# ensure that columns retrieved from this target in the result
# set are also adapted.
if aliased_entity and not create_aliases:
- self.__mapper_loads_polymorphically_with(
+ self._mapper_loads_polymorphically_with(
right_mapper,
ORMAdapter(
right,
@@ -1779,6 +1857,11 @@ class Query(object):
)
)
+ return right, right_is_aliased, onclause
+
+ def _join_to_left(self, left, right, right_is_aliased, onclause, outerjoin):
+ left_mapper, left_selectable, left_is_aliased = _entity_info(left)
+
# this is an overly broad assumption here, but there's a
# very wide variety of situations where we rely upon orm.join's
# adaption to glue clauses together, with joined-table inheritance's
@@ -2959,7 +3042,9 @@ class _MapperEntity(_QueryEntity):
# with_polymorphic() can be applied to aliases
if not self.is_aliased_class:
self.selectable = from_obj
- self.adapter = query._get_polymorphic_adapter(self, from_obj)
+ query._mapper_loads_polymorphically_with(self.mapper,
+ sql_util.ColumnAdapter(from_obj,
+ self.mapper._equivalent_columns))
filter_fn = id
@@ -3086,8 +3171,9 @@ class _MapperEntity(_QueryEntity):
class _ColumnEntity(_QueryEntity):
"""Column/expression based entity."""
- def __init__(self, query, column):
+ def __init__(self, query, column, namespace=None):
self.expr = column
+ self.namespace = namespace
if isinstance(column, basestring):
column = sql.literal_column(column)
@@ -3106,7 +3192,7 @@ class _ColumnEntity(_QueryEntity):
for c in column._select_iterable:
if c is column:
break
- _ColumnEntity(query, c)
+ _ColumnEntity(query, c, namespace=column)
if c is not column:
return
@@ -3147,12 +3233,14 @@ class _ColumnEntity(_QueryEntity):
if self.entities:
self.entity_zero = list(self.entities)[0]
+ elif self.namespace is not None:
+ self.entity_zero = self.namespace
else:
self.entity_zero = None
@property
def entity_zero_or_selectable(self):
- if self.entity_zero:
+ if self.entity_zero is not None:
return self.entity_zero
elif self.actual_froms:
return list(self.actual_froms)[0]
diff --git a/libs/sqlalchemy/orm/scoping.py b/libs/sqlalchemy/orm/scoping.py
index ffc8ef4..3c1cd7f 100644
--- a/libs/sqlalchemy/orm/scoping.py
+++ b/libs/sqlalchemy/orm/scoping.py
@@ -41,8 +41,9 @@ class ScopedSession(object):
scope = kwargs.pop('scope', False)
if scope is not None:
if self.registry.has():
- raise sa_exc.InvalidRequestError("Scoped session is already present; "
- "no new arguments may be specified.")
+ raise sa_exc.InvalidRequestError(
+ "Scoped session is already present; "
+ "no new arguments may be specified.")
else:
sess = self.session_factory(**kwargs)
self.registry.set(sess)
@@ -70,8 +71,8 @@ class ScopedSession(object):
self.session_factory.configure(**kwargs)
def query_property(self, query_cls=None):
- """return a class property which produces a `Query` object against the
- class when called.
+ """return a class property which produces a `Query` object
+ against the class when called.
e.g.::
@@ -121,7 +122,8 @@ def makeprop(name):
def get(self):
return getattr(self.registry(), name)
return property(get, set)
-for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active', 'autoflush'):
+for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map',
+ 'is_active', 'autoflush', 'no_autoflush'):
setattr(ScopedSession, prop, makeprop(prop))
def clslevel(name):
diff --git a/libs/sqlalchemy/orm/session.py b/libs/sqlalchemy/orm/session.py
index 4299290..1477870 100644
--- a/libs/sqlalchemy/orm/session.py
+++ b/libs/sqlalchemy/orm/session.py
@@ -99,7 +99,7 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False,
kwargs.update(new_kwargs)
- return type("Session", (Sess, class_), {})
+ return type("SessionMaker", (Sess, class_), {})
class SessionTransaction(object):
@@ -978,6 +978,34 @@ class Session(object):
return self._query_cls(entities, self, **kwargs)
+ @property
+ @util.contextmanager
+ def no_autoflush(self):
+ """Return a context manager that disables autoflush.
+
+ e.g.::
+
+ with session.no_autoflush:
+
+ some_object = SomeClass()
+ session.add(some_object)
+ # won't autoflush
+ some_object.related_thing = session.query(SomeRelated).first()
+
+ Operations that proceed within the ``with:`` block
+ will not be subject to flushes occurring upon query
+ access. This is useful when initializing a series
+ of objects which involve existing database queries,
+ where the uncompleted object should not yet be flushed.
+
+ New in 0.7.6.
+
+ """
+ autoflush = self.autoflush
+ self.autoflush = False
+ yield self
+ self.autoflush = autoflush
+
def _autoflush(self):
if self.autoflush and not self._flushing:
self.flush()
@@ -1772,6 +1800,19 @@ class Session(object):
return self.transaction and self.transaction.is_active
+ identity_map = None
+ """A mapping of object identities to objects themselves.
+
+ Iterating through ``Session.identity_map.values()`` provides
+ access to the full set of persistent objects (i.e., those
+ that have row identity) currently in the session.
+
+ See also:
+
+ :func:`.identity_key` - operations involving identity keys.
+
+ """
+
@property
def _dirty_states(self):
"""The set of all persistent states considered dirty.
diff --git a/libs/sqlalchemy/orm/sync.py b/libs/sqlalchemy/orm/sync.py
index b016e81..a20e871 100644
--- a/libs/sqlalchemy/orm/sync.py
+++ b/libs/sqlalchemy/orm/sync.py
@@ -6,6 +6,7 @@
"""private module containing functions used for copying data
between instances based on join conditions.
+
"""
from sqlalchemy.orm import exc, util as mapperutil, attributes
diff --git a/libs/sqlalchemy/orm/unitofwork.py b/libs/sqlalchemy/orm/unitofwork.py
index 3cd0f15..8fc5f13 100644
--- a/libs/sqlalchemy/orm/unitofwork.py
+++ b/libs/sqlalchemy/orm/unitofwork.py
@@ -14,7 +14,7 @@ organizes them in order of dependency, and executes.
from sqlalchemy import util, event
from sqlalchemy.util import topological
-from sqlalchemy.orm import attributes, interfaces
+from sqlalchemy.orm import attributes, interfaces, persistence
from sqlalchemy.orm import util as mapperutil
session = util.importlater("sqlalchemy.orm", "session")
@@ -462,7 +462,7 @@ class IssuePostUpdate(PostSortRec):
states, cols = uow.post_update_states[self.mapper]
states = [s for s in states if uow.states[s][0] == self.isdelete]
- self.mapper._post_update(states, uow, cols)
+ persistence.post_update(self.mapper, states, uow, cols)
class SaveUpdateAll(PostSortRec):
def __init__(self, uow, mapper):
@@ -470,7 +470,7 @@ class SaveUpdateAll(PostSortRec):
assert mapper is mapper.base_mapper
def execute(self, uow):
- self.mapper._save_obj(
+ persistence.save_obj(self.mapper,
uow.states_for_mapper_hierarchy(self.mapper, False, False),
uow
)
@@ -493,7 +493,7 @@ class DeleteAll(PostSortRec):
assert mapper is mapper.base_mapper
def execute(self, uow):
- self.mapper._delete_obj(
+ persistence.delete_obj(self.mapper,
uow.states_for_mapper_hierarchy(self.mapper, True, False),
uow
)
@@ -551,7 +551,7 @@ class SaveUpdateState(PostSortRec):
if r.__class__ is cls_ and
r.mapper is mapper]
recs.difference_update(our_recs)
- mapper._save_obj(
+ persistence.save_obj(mapper,
[self.state] +
[r.state for r in our_recs],
uow)
@@ -575,7 +575,7 @@ class DeleteState(PostSortRec):
r.mapper is mapper]
recs.difference_update(our_recs)
states = [self.state] + [r.state for r in our_recs]
- mapper._delete_obj(
+ persistence.delete_obj(mapper,
[s for s in states if uow.states[s][0]],
uow)
diff --git a/libs/sqlalchemy/orm/util.py b/libs/sqlalchemy/orm/util.py
index 0cd5b05..0c5f203 100644
--- a/libs/sqlalchemy/orm/util.py
+++ b/libs/sqlalchemy/orm/util.py
@@ -11,6 +11,7 @@ from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE,\
PropComparator, MapperProperty
from sqlalchemy.orm import attributes, exc
import operator
+import re
mapperlib = util.importlater("sqlalchemy.orm", "mapperlib")
@@ -20,38 +21,52 @@ all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
_INSTRUMENTOR = ('mapper', 'instrumentor')
-class CascadeOptions(dict):
+class CascadeOptions(frozenset):
"""Keeps track of the options sent to relationship().cascade"""
- def __init__(self, arg=""):
- if not arg:
- values = set()
- else:
- values = set(c.strip() for c in arg.split(','))
-
- for name in ['save-update', 'delete', 'refresh-expire',
- 'merge', 'expunge']:
- boolean = name in values or 'all' in values
- setattr(self, name.replace('-', '_'), boolean)
- if boolean:
- self[name] = True
+ _add_w_all_cascades = all_cascades.difference([
+ 'all', 'none', 'delete-orphan'])
+ _allowed_cascades = all_cascades
+
+ def __new__(cls, arg):
+ values = set([
+ c for c
+ in re.split('\s*,\s*', arg or "")
+ if c
+ ])
+
+ if values.difference(cls._allowed_cascades):
+ raise sa_exc.ArgumentError(
+ "Invalid cascade option(s): %s" %
+ ", ".join([repr(x) for x in
+ sorted(
+ values.difference(cls._allowed_cascades)
+ )])
+ )
+
+ if "all" in values:
+ values.update(cls._add_w_all_cascades)
+ if "none" in values:
+ values.clear()
+ values.discard('all')
+
+ self = frozenset.__new__(CascadeOptions, values)
+ self.save_update = 'save-update' in values
+ self.delete = 'delete' in values
+ self.refresh_expire = 'refresh-expire' in values
+ self.merge = 'merge' in values
+ self.expunge = 'expunge' in values
self.delete_orphan = "delete-orphan" in values
- if self.delete_orphan:
- self['delete-orphan'] = True
if self.delete_orphan and not self.delete:
- util.warn("The 'delete-orphan' cascade option requires "
- "'delete'.")
-
- for x in values:
- if x not in all_cascades:
- raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x)
+ util.warn("The 'delete-orphan' cascade "
+ "option requires 'delete'.")
+ return self
def __repr__(self):
- return "CascadeOptions(%s)" % repr(",".join(
- [x for x in ['delete', 'save_update', 'merge', 'expunge',
- 'delete_orphan', 'refresh-expire']
- if getattr(self, x, False) is True]))
+ return "CascadeOptions(%r)" % (
+ ",".join([x for x in sorted(self)])
+ )
def _validator_events(desc, key, validator):
"""Runs a validation method on an attribute value to be set or appended."""
@@ -557,15 +572,20 @@ def _entity_descriptor(entity, key):
attribute.
"""
- if not isinstance(entity, (AliasedClass, type)):
- entity = entity.class_
+ if isinstance(entity, expression.FromClause):
+ description = entity
+ entity = entity.c
+ elif not isinstance(entity, (AliasedClass, type)):
+ description = entity = entity.class_
+ else:
+ description = entity
try:
return getattr(entity, key)
except AttributeError:
raise sa_exc.InvalidRequestError(
"Entity '%s' has no property '%s'" %
- (entity, key)
+ (description, key)
)
def _orm_columns(entity):
diff --git a/libs/sqlalchemy/pool.py b/libs/sqlalchemy/pool.py
index a615e8c..6254a4b 100644
--- a/libs/sqlalchemy/pool.py
+++ b/libs/sqlalchemy/pool.py
@@ -57,6 +57,10 @@ def clear_managers():
manager.close()
proxies.clear()
+reset_rollback = util.symbol('reset_rollback')
+reset_commit = util.symbol('reset_commit')
+reset_none = util.symbol('reset_none')
+
class Pool(log.Identified):
"""Abstract base class for connection pools."""
@@ -130,7 +134,17 @@ class Pool(log.Identified):
self._creator = creator
self._recycle = recycle
self._use_threadlocal = use_threadlocal
- self._reset_on_return = reset_on_return
+ if reset_on_return in ('rollback', True, reset_rollback):
+ self._reset_on_return = reset_rollback
+ elif reset_on_return in (None, False, reset_none):
+ self._reset_on_return = reset_none
+ elif reset_on_return in ('commit', reset_commit):
+ self._reset_on_return = reset_commit
+ else:
+ raise exc.ArgumentError(
+ "Invalid value for 'reset_on_return': %r"
+ % reset_on_return)
+
self.echo = echo
if _dispatch:
self.dispatch._update(_dispatch, only_propagate=False)
@@ -330,8 +344,10 @@ def _finalize_fairy(connection, connection_record, pool, ref, echo):
if connection is not None:
try:
- if pool._reset_on_return:
+ if pool._reset_on_return is reset_rollback:
connection.rollback()
+ elif pool._reset_on_return is reset_commit:
+ connection.commit()
# Immediately close detached instances
if connection_record is None:
connection.close()
@@ -624,11 +640,37 @@ class QueuePool(Pool):
:meth:`unique_connection` method is provided to bypass the
threadlocal behavior installed into :meth:`connect`.
- :param reset_on_return: If true, reset the database state of
- connections returned to the pool. This is typically a
- ROLLBACK to release locks and transaction resources.
- Disable at your own peril. Defaults to True.
-
+ :param reset_on_return: Determine steps to take on
+ connections as they are returned to the pool.
+ As of SQLAlchemy 0.7.6, reset_on_return can have any
+ of these values:
+
+ * 'rollback' - call rollback() on the connection,
+ to release locks and transaction resources.
+ This is the default value. The vast majority
+ of use cases should leave this value set.
+ * True - same as 'rollback', this is here for
+ backwards compatibility.
+ * 'commit' - call commit() on the connection,
+ to release locks and transaction resources.
+ A commit here may be desirable for databases that
+ cache query plans if a commit is emitted,
+ such as Microsoft SQL Server. However, this
+ value is more dangerous than 'rollback' because
+ any data changes present on the transaction
+ are committed unconditionally.
+ * None - don't do anything on the connection.
+ This setting should only be made on a database
+ that has no transaction support at all,
+ namely MySQL MyISAM. By not doing anything,
+ performance can be improved. This
+ setting should **never be selected** for a
+ database that supports transactions,
+ as it will lead to deadlocks and stale
+ state.
+ * False - same as None, this is here for
+ backwards compatibility.
+
:param listeners: A list of
:class:`~sqlalchemy.interfaces.PoolListener`-like objects or
dictionaries of callables that receive events when DB-API
diff --git a/libs/sqlalchemy/schema.py b/libs/sqlalchemy/schema.py
index f0a9297..d295143 100644
--- a/libs/sqlalchemy/schema.py
+++ b/libs/sqlalchemy/schema.py
@@ -80,6 +80,17 @@ def _get_table_key(name, schema):
else:
return schema + "." + name
+def _validate_dialect_kwargs(kwargs, name):
+ # validate remaining kwargs that they all specify DB prefixes
+ if len([k for k in kwargs
+ if not re.match(
+ r'^(?:%s)_' %
+ '|'.join(dialects.__all__), k
+ )
+ ]):
+ raise TypeError(
+ "Invalid argument(s) for %s: %r" % (name, kwargs.keys()))
+
class Table(SchemaItem, expression.TableClause):
"""Represent a table in a database.
@@ -369,9 +380,12 @@ class Table(SchemaItem, expression.TableClause):
# allow user-overrides
self._init_items(*args)
- def _autoload(self, metadata, autoload_with, include_columns, exclude_columns=None):
+ def _autoload(self, metadata, autoload_with, include_columns, exclude_columns=()):
if self.primary_key.columns:
- PrimaryKeyConstraint()._set_parent_with_dispatch(self)
+ PrimaryKeyConstraint(*[
+ c for c in self.primary_key.columns
+ if c.key in exclude_columns
+ ])._set_parent_with_dispatch(self)
if autoload_with:
autoload_with.run_callable(
@@ -424,7 +438,7 @@ class Table(SchemaItem, expression.TableClause):
if not autoload_replace:
exclude_columns = [c.name for c in self.c]
else:
- exclude_columns = None
+ exclude_columns = ()
self._autoload(self.metadata, autoload_with, include_columns, exclude_columns)
self._extra_kwargs(**kwargs)
@@ -432,14 +446,7 @@ class Table(SchemaItem, expression.TableClause):
def _extra_kwargs(self, **kwargs):
# validate remaining kwargs that they all specify DB prefixes
- if len([k for k in kwargs
- if not re.match(
- r'^(?:%s)_' %
- '|'.join(dialects.__all__), k
- )
- ]):
- raise TypeError(
- "Invalid argument(s) for Table: %r" % kwargs.keys())
+ _validate_dialect_kwargs(kwargs, "Table")
self.kwargs.update(kwargs)
def _init_collections(self):
@@ -1028,7 +1035,7 @@ class Column(SchemaItem, expression.ColumnClause):
"The 'index' keyword argument on Column is boolean only. "
"To create indexes with a specific name, create an "
"explicit Index object external to the Table.")
- Index(expression._generated_label('ix_%s' % self._label), self, unique=self.unique)
+ Index(expression._truncated_label('ix_%s' % self._label), self, unique=self.unique)
elif self.unique:
if isinstance(self.unique, basestring):
raise exc.ArgumentError(
@@ -1093,7 +1100,7 @@ class Column(SchemaItem, expression.ColumnClause):
"been assigned.")
try:
c = self._constructor(
- name or self.name,
+ expression._as_truncated(name or self.name),
self.type,
key = name or self.key,
primary_key = self.primary_key,
@@ -1119,6 +1126,8 @@ class Column(SchemaItem, expression.ColumnClause):
c.table = selectable
selectable._columns.add(c)
+ if selectable._is_clone_of is not None:
+ c._is_clone_of = selectable._is_clone_of.columns[c.name]
if self.primary_key:
selectable.primary_key.add(c)
c.dispatch.after_parent_attach(c, selectable)
@@ -1809,7 +1818,8 @@ class Constraint(SchemaItem):
__visit_name__ = 'constraint'
def __init__(self, name=None, deferrable=None, initially=None,
- _create_rule=None):
+ _create_rule=None,
+ **kw):
"""Create a SQL constraint.
:param name:
@@ -1839,6 +1849,10 @@ class Constraint(SchemaItem):
_create_rule is used by some types to create constraints.
Currently, its call signature is subject to change at any time.
+
+ :param \**kwargs:
+ Dialect-specific keyword parameters, see the documentation
+ for various dialects and constraints regarding options here.
"""
@@ -1847,6 +1861,8 @@ class Constraint(SchemaItem):
self.initially = initially
self._create_rule = _create_rule
util.set_creation_order(self)
+ _validate_dialect_kwargs(kw, self.__class__.__name__)
+ self.kwargs = kw
@property
def table(self):
@@ -2192,6 +2208,8 @@ class Index(ColumnCollectionMixin, SchemaItem):
self.table = None
# will call _set_parent() if table-bound column
# objects are present
+ if not columns:
+ util.warn("No column names or expressions given for Index.")
ColumnCollectionMixin.__init__(self, *columns)
self.name = name
self.unique = kw.pop('unique', False)
@@ -3004,9 +3022,11 @@ def _to_schema_column(element):
return element
def _to_schema_column_or_string(element):
- if hasattr(element, '__clause_element__'):
- element = element.__clause_element__()
- return element
+ if hasattr(element, '__clause_element__'):
+ element = element.__clause_element__()
+ if not isinstance(element, (basestring, expression.ColumnElement)):
+ raise exc.ArgumentError("Element %r is not a string name or column element" % element)
+ return element
class _CreateDropBase(DDLElement):
"""Base class for DDL constucts that represent CREATE and DROP or
diff --git a/libs/sqlalchemy/sql/compiler.py b/libs/sqlalchemy/sql/compiler.py
index b0a55b8..c5c6f9e 100644
--- a/libs/sqlalchemy/sql/compiler.py
+++ b/libs/sqlalchemy/sql/compiler.py
@@ -154,9 +154,10 @@ class _CompileLabel(visitors.Visitable):
__visit_name__ = 'label'
__slots__ = 'element', 'name'
- def __init__(self, col, name):
+ def __init__(self, col, name, alt_names=()):
self.element = col
self.name = name
+ self._alt_names = alt_names
@property
def proxy_set(self):
@@ -251,6 +252,10 @@ class SQLCompiler(engine.Compiled):
# column targeting
self.result_map = {}
+ # collect CTEs to tack on top of a SELECT
+ self.ctes = util.OrderedDict()
+ self.ctes_recursive = False
+
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
@@ -354,14 +359,16 @@ class SQLCompiler(engine.Compiled):
# or ORDER BY clause of a select. dialect-specific compilers
# can modify this behavior.
if within_columns_clause and not within_label_clause:
- if isinstance(label.name, sql._generated_label):
+ if isinstance(label.name, sql._truncated_label):
labelname = self._truncated_identifier("colident", label.name)
else:
labelname = label.name
if result_map is not None:
- result_map[labelname.lower()] = \
- (label.name, (label, label.element, labelname),\
+ result_map[labelname.lower()] = (
+ label.name,
+ (label, label.element, labelname, ) +
+ label._alt_names,
label.type)
return label.element._compiler_dispatch(self,
@@ -376,17 +383,19 @@ class SQLCompiler(engine.Compiled):
**kw)
def visit_column(self, column, result_map=None, **kwargs):
- name = column.name
+ name = orig_name = column.name
if name is None:
raise exc.CompileError("Cannot compile Column object until "
"it's 'name' is assigned.")
is_literal = column.is_literal
- if not is_literal and isinstance(name, sql._generated_label):
+ if not is_literal and isinstance(name, sql._truncated_label):
name = self._truncated_identifier("colident", name)
if result_map is not None:
- result_map[name.lower()] = (name, (column, ), column.type)
+ result_map[name.lower()] = (orig_name,
+ (column, name, column.key),
+ column.type)
if is_literal:
name = self.escape_literal_column(name)
@@ -404,7 +413,7 @@ class SQLCompiler(engine.Compiled):
else:
schema_prefix = ''
tablename = table.name
- if isinstance(tablename, sql._generated_label):
+ if isinstance(tablename, sql._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
return schema_prefix + \
@@ -646,7 +655,8 @@ class SQLCompiler(engine.Compiled):
if name in self.binds:
existing = self.binds[name]
if existing is not bindparam:
- if existing.unique or bindparam.unique:
+ if (existing.unique or bindparam.unique) and \
+ not existing.proxy_set.intersection(bindparam.proxy_set):
raise exc.CompileError(
"Bind parameter '%s' conflicts with "
"unique bind parameter of the same name" %
@@ -703,7 +713,7 @@ class SQLCompiler(engine.Compiled):
return self.bind_names[bindparam]
bind_name = bindparam.key
- if isinstance(bind_name, sql._generated_label):
+ if isinstance(bind_name, sql._truncated_label):
bind_name = self._truncated_identifier("bindparam", bind_name)
# add to bind_names for translation
@@ -715,7 +725,7 @@ class SQLCompiler(engine.Compiled):
if (ident_class, name) in self.truncated_names:
return self.truncated_names[(ident_class, name)]
- anonname = name % self.anon_map
+ anonname = name.apply_map(self.anon_map)
if len(anonname) > self.label_length:
counter = self.truncated_names.get(ident_class, 1)
@@ -744,10 +754,49 @@ class SQLCompiler(engine.Compiled):
else:
return self.bindtemplate % {'name':name}
+ def visit_cte(self, cte, asfrom=False, ashint=False,
+ fromhints=None, **kwargs):
+ if isinstance(cte.name, sql._truncated_label):
+ cte_name = self._truncated_identifier("alias", cte.name)
+ else:
+ cte_name = cte.name
+ if cte.cte_alias:
+ if isinstance(cte.cte_alias, sql._truncated_label):
+ cte_alias = self._truncated_identifier("alias", cte.cte_alias)
+ else:
+ cte_alias = cte.cte_alias
+ if not cte.cte_alias and cte not in self.ctes:
+ if cte.recursive:
+ self.ctes_recursive = True
+ text = self.preparer.format_alias(cte, cte_name)
+ if cte.recursive:
+ if isinstance(cte.original, sql.Select):
+ col_source = cte.original
+ elif isinstance(cte.original, sql.CompoundSelect):
+ col_source = cte.original.selects[0]
+ else:
+ assert False
+ recur_cols = [c.key for c in util.unique_list(col_source.inner_columns)
+ if c is not None]
+
+ text += "(%s)" % (", ".join(recur_cols))
+ text += " AS \n" + \
+ cte.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ )
+ self.ctes[cte] = text
+ if asfrom:
+ if cte.cte_alias:
+ text = self.preparer.format_alias(cte, cte_alias)
+ text += " AS " + cte_name
+ else:
+ return self.preparer.format_alias(cte, cte_name)
+ return text
+
def visit_alias(self, alias, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
- if isinstance(alias.name, sql._generated_label):
+ if isinstance(alias.name, sql._truncated_label):
alias_name = self._truncated_identifier("alias", alias.name)
else:
alias_name = alias.name
@@ -775,8 +824,14 @@ class SQLCompiler(engine.Compiled):
if isinstance(column, sql._Label):
return column
- elif select is not None and select.use_labels and column._label:
- return _CompileLabel(column, column._label)
+ elif select is not None and \
+ select.use_labels and \
+ column._label:
+ return _CompileLabel(
+ column,
+ column._label,
+ alt_names=(column._key_label, )
+ )
elif \
asfrom and \
@@ -784,7 +839,8 @@ class SQLCompiler(engine.Compiled):
not column.is_literal and \
column.table is not None and \
not isinstance(column.table, sql.Select):
- return _CompileLabel(column, sql._generated_label(column.name))
+ return _CompileLabel(column, sql._as_truncated(column.name),
+ alt_names=(column.key,))
elif not isinstance(column,
(sql._UnaryExpression, sql._TextClause)) \
and (not hasattr(column, 'name') or \
@@ -799,6 +855,9 @@ class SQLCompiler(engine.Compiled):
def get_from_hint_text(self, table, text):
return None
+ def get_crud_hint_text(self, table, text):
+ return None
+
def visit_select(self, select, asfrom=False, parens=True,
iswrapper=False, fromhints=None,
compound_index=1, **kwargs):
@@ -897,6 +956,15 @@ class SQLCompiler(engine.Compiled):
if select.for_update:
text += self.for_update_clause(select)
+ if self.ctes and \
+ compound_index==1 and not entry:
+ cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
+ cte_text += ", \n".join(
+ [txt for txt in self.ctes.values()]
+ )
+ cte_text += "\n "
+ text = cte_text + text
+
self.stack.pop(-1)
if asfrom and parens:
@@ -904,6 +972,12 @@ class SQLCompiler(engine.Compiled):
else:
return text
+ def get_cte_preamble(self, recursive):
+ if recursive:
+ return "WITH RECURSIVE"
+ else:
+ return "WITH"
+
def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just
before column list.
@@ -977,12 +1051,26 @@ class SQLCompiler(engine.Compiled):
text = "INSERT"
+
prefixes = [self.process(x) for x in insert_stmt._prefixes]
if prefixes:
text += " " + " ".join(prefixes)
text += " INTO " + preparer.format_table(insert_stmt.table)
+ if insert_stmt._hints:
+ dialect_hints = dict([
+ (table, hint_text)
+ for (table, dialect), hint_text in
+ insert_stmt._hints.items()
+ if dialect in ('*', self.dialect.name)
+ ])
+ if insert_stmt.table in dialect_hints:
+ text += " " + self.get_crud_hint_text(
+ insert_stmt.table,
+ dialect_hints[insert_stmt.table]
+ )
+
if colparams or not supports_default_values:
text += " (%s)" % ', '.join([preparer.format_column(c[0])
for c in colparams])
@@ -1014,21 +1102,25 @@ class SQLCompiler(engine.Compiled):
extra_froms, **kw):
"""Provide a hook to override the initial table clause
in an UPDATE statement.
-
+
MySQL overrides this.
"""
return self.preparer.format_table(from_table)
- def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
+ def update_from_clause(self, update_stmt,
+ from_table, extra_froms,
+ from_hints,
+ **kw):
"""Provide a hook to override the generation of an
UPDATE..FROM clause.
-
+
MySQL overrides this.
"""
return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True, **kw)
+ t._compiler_dispatch(self, asfrom=True,
+ fromhints=from_hints, **kw)
for t in extra_froms)
def visit_update(self, update_stmt, **kw):
@@ -1045,6 +1137,21 @@ class SQLCompiler(engine.Compiled):
update_stmt.table,
extra_froms, **kw)
+ if update_stmt._hints:
+ dialect_hints = dict([
+ (table, hint_text)
+ for (table, dialect), hint_text in
+ update_stmt._hints.items()
+ if dialect in ('*', self.dialect.name)
+ ])
+ if update_stmt.table in dialect_hints:
+ text += " " + self.get_crud_hint_text(
+ update_stmt.table,
+ dialect_hints[update_stmt.table]
+ )
+ else:
+ dialect_hints = None
+
text += ' SET '
if extra_froms and self.render_table_with_column_in_update_from:
text += ', '.join(
@@ -1067,7 +1174,8 @@ class SQLCompiler(engine.Compiled):
extra_from_text = self.update_from_clause(
update_stmt,
update_stmt.table,
- extra_froms, **kw)
+ extra_froms,
+ dialect_hints, **kw)
if extra_from_text:
text += " " + extra_from_text
@@ -1133,7 +1241,6 @@ class SQLCompiler(engine.Compiled):
for k, v in stmt.parameters.iteritems():
parameters.setdefault(sql._column_as_key(k), v)
-
# create a list of column assignment clauses as tuples
values = []
@@ -1192,7 +1299,7 @@ class SQLCompiler(engine.Compiled):
# "defaults", "primary key cols", etc.
for c in stmt.table.columns:
if c.key in parameters and c.key not in check_columns:
- value = parameters[c.key]
+ value = parameters.pop(c.key)
if sql._is_literal(value):
value = self._create_crud_bind_param(
c, value, required=value is required)
@@ -1288,6 +1395,17 @@ class SQLCompiler(engine.Compiled):
self.prefetch.append(c)
elif c.server_onupdate is not None:
self.postfetch.append(c)
+
+ if parameters and stmt.parameters:
+ check = set(parameters).intersection(
+ sql._column_as_key(k) for k in stmt.parameters
+ ).difference(check_columns)
+ if check:
+ util.warn(
+ "Unconsumed column names: %s" %
+ (", ".join(check))
+ )
+
return values
def visit_delete(self, delete_stmt):
@@ -1296,6 +1414,21 @@ class SQLCompiler(engine.Compiled):
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
+ if delete_stmt._hints:
+ dialect_hints = dict([
+ (table, hint_text)
+ for (table, dialect), hint_text in
+ delete_stmt._hints.items()
+ if dialect in ('*', self.dialect.name)
+ ])
+ if delete_stmt.table in dialect_hints:
+ text += " " + self.get_crud_hint_text(
+ delete_stmt.table,
+ dialect_hints[delete_stmt.table]
+ )
+ else:
+ dialect_hints = None
+
if delete_stmt._returning:
self.returning = delete_stmt._returning
if self.returning_precedes_values:
@@ -1445,7 +1578,7 @@ class DDLCompiler(engine.Compiled):
return "\nDROP TABLE " + self.preparer.format_table(drop.element)
def _index_identifier(self, ident):
- if isinstance(ident, sql._generated_label):
+ if isinstance(ident, sql._truncated_label):
max = self.dialect.max_index_name_length or \
self.dialect.max_identifier_length
if len(ident) > max:
diff --git a/libs/sqlalchemy/sql/expression.py b/libs/sqlalchemy/sql/expression.py
index bff086e..aa67f44 100644
--- a/libs/sqlalchemy/sql/expression.py
+++ b/libs/sqlalchemy/sql/expression.py
@@ -832,6 +832,14 @@ def tuple_(*expr):
[(1, 2), (5, 12), (10, 19)]
)
+ .. warning::
+
+ The composite IN construct is not supported by all backends,
+ and is currently known to work on Postgresql and MySQL,
+ but not SQLite. Unsupported backends will raise
+ a subclass of :class:`~sqlalchemy.exc.DBAPIError` when such
+ an expression is invoked.
+
"""
return _Tuple(*expr)
@@ -1275,14 +1283,48 @@ func = _FunctionGenerator()
# TODO: use UnaryExpression for this instead ?
modifier = _FunctionGenerator(group=False)
-class _generated_label(unicode):
- """A unicode subclass used to identify dynamically generated names."""
+class _truncated_label(unicode):
+ """A unicode subclass used to identify symbolic "
+ "names that may require truncation."""
+
+ def apply_map(self, map_):
+ return self
+
+# for backwards compatibility in case
+# someone is re-implementing the
+# _truncated_identifier() sequence in a custom
+# compiler
+_generated_label = _truncated_label
+
+class _anonymous_label(_truncated_label):
+ """A unicode subclass used to identify anonymously
+ generated names."""
+
+ def __add__(self, other):
+ return _anonymous_label(
+ unicode(self) +
+ unicode(other))
+
+ def __radd__(self, other):
+ return _anonymous_label(
+ unicode(other) +
+ unicode(self))
+
+ def apply_map(self, map_):
+ return self % map_
-def _escape_for_generated(x):
- if isinstance(x, _generated_label):
- return x
+def _as_truncated(value):
+ """coerce the given value to :class:`._truncated_label`.
+
+ Existing :class:`._truncated_label` and
+ :class:`._anonymous_label` objects are passed
+ unchanged.
+ """
+
+ if isinstance(value, _truncated_label):
+ return value
else:
- return x.replace('%', '%%')
+ return _truncated_label(value)
def _string_or_unprintable(element):
if isinstance(element, basestring):
@@ -1466,6 +1508,7 @@ class ClauseElement(Visitable):
supports_execution = False
_from_objects = []
bind = None
+ _is_clone_of = None
def _clone(self):
"""Create a shallow copy of this ClauseElement.
@@ -1514,7 +1557,7 @@ class ClauseElement(Visitable):
f = self
while f is not None:
s.add(f)
- f = getattr(f, '_is_clone_of', None)
+ f = f._is_clone_of
return s
def __getstate__(self):
@@ -2063,6 +2106,8 @@ class ColumnElement(ClauseElement, _CompareMixin):
foreign_keys = []
quote = None
_label = None
+ _key_label = None
+ _alt_names = ()
@property
def _select_iterable(self):
@@ -2109,9 +2154,14 @@ class ColumnElement(ClauseElement, _CompareMixin):
else:
key = name
- co = ColumnClause(name, selectable, type_=getattr(self,
+ co = ColumnClause(_as_truncated(name),
+ selectable,
+ type_=getattr(self,
'type', None))
co.proxies = [self]
+ if selectable._is_clone_of is not None:
+ co._is_clone_of = \
+ selectable._is_clone_of.columns[key]
selectable._columns[key] = co
return co
@@ -2157,7 +2207,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
expressions and function calls.
"""
- return _generated_label('%%(%d %s)s' % (id(self), getattr(self,
+ return _anonymous_label('%%(%d %s)s' % (id(self), getattr(self,
'name', 'anon')))
class ColumnCollection(util.OrderedProperties):
@@ -2420,6 +2470,13 @@ class FromClause(Selectable):
"""
+ def embedded(expanded_proxy_set, target_set):
+ for t in target_set.difference(expanded_proxy_set):
+ if not set(_expand_cloned([t])
+ ).intersection(expanded_proxy_set):
+ return False
+ return True
+
# dont dig around if the column is locally present
if self.c.contains_column(column):
return column
@@ -2427,10 +2484,10 @@ class FromClause(Selectable):
target_set = column.proxy_set
cols = self.c
for c in cols:
- i = target_set.intersection(itertools.chain(*[p._cloned_set
- for p in c.proxy_set]))
+ expanded_proxy_set = set(_expand_cloned(c.proxy_set))
+ i = target_set.intersection(expanded_proxy_set)
if i and (not require_embedded
- or c.proxy_set.issuperset(target_set)):
+ or embedded(expanded_proxy_set, target_set)):
if col is None:
# no corresponding column yet, pick this one.
@@ -2580,10 +2637,10 @@ class _BindParamClause(ColumnElement):
"""
if unique:
- self.key = _generated_label('%%(%d %s)s' % (id(self), key
+ self.key = _anonymous_label('%%(%d %s)s' % (id(self), key
or 'param'))
else:
- self.key = key or _generated_label('%%(%d param)s'
+ self.key = key or _anonymous_label('%%(%d param)s'
% id(self))
# identifiying key that won't change across
@@ -2631,14 +2688,14 @@ class _BindParamClause(ColumnElement):
def _clone(self):
c = ClauseElement._clone(self)
if self.unique:
- c.key = _generated_label('%%(%d %s)s' % (id(c), c._orig_key
+ c.key = _anonymous_label('%%(%d %s)s' % (id(c), c._orig_key
or 'param'))
return c
def _convert_to_unique(self):
if not self.unique:
self.unique = True
- self.key = _generated_label('%%(%d %s)s' % (id(self),
+ self.key = _anonymous_label('%%(%d %s)s' % (id(self),
self._orig_key or 'param'))
def compare(self, other, **kw):
@@ -3607,7 +3664,7 @@ class Alias(FromClause):
if name is None:
if self.original.named_with_column:
name = getattr(self.original, 'name', None)
- name = _generated_label('%%(%d %s)s' % (id(self), name
+ name = _anonymous_label('%%(%d %s)s' % (id(self), name
or 'anon'))
self.name = name
@@ -3662,6 +3719,47 @@ class Alias(FromClause):
def bind(self):
return self.element.bind
+class CTE(Alias):
+ """Represent a Common Table Expression.
+
+ The :class:`.CTE` object is obtained using the
+ :meth:`._SelectBase.cte` method from any selectable.
+ See that method for complete examples.
+
+ New in 0.7.6.
+
+ """
+ __visit_name__ = 'cte'
+ def __init__(self, selectable,
+ name=None,
+ recursive=False,
+ cte_alias=False):
+ self.recursive = recursive
+ self.cte_alias = cte_alias
+ super(CTE, self).__init__(selectable, name=name)
+
+ def alias(self, name=None):
+ return CTE(
+ self.original,
+ name=name,
+ recursive=self.recursive,
+ cte_alias = self.name
+ )
+
+ def union(self, other):
+ return CTE(
+ self.original.union(other),
+ name=self.name,
+ recursive=self.recursive
+ )
+
+ def union_all(self, other):
+ return CTE(
+ self.original.union_all(other),
+ name=self.name,
+ recursive=self.recursive
+ )
+
class _Grouping(ColumnElement):
"""Represent a grouping within a column expression"""
@@ -3807,9 +3905,12 @@ class _Label(ColumnElement):
def __init__(self, name, element, type_=None):
while isinstance(element, _Label):
element = element.element
- self.name = self.key = self._label = name \
- or _generated_label('%%(%d %s)s' % (id(self),
+ if name:
+ self.name = name
+ else:
+ self.name = _anonymous_label('%%(%d %s)s' % (id(self),
getattr(element, 'name', 'anon')))
+ self.key = self._label = self._key_label = self.name
self._element = element
self._type = type_
self.quote = element.quote
@@ -3957,7 +4058,17 @@ class ColumnClause(_Immutable, ColumnElement):
# end Py2K
@_memoized_property
+ def _key_label(self):
+ if self.key != self.name:
+ return self._gen_label(self.key)
+ else:
+ return self._label
+
+ @_memoized_property
def _label(self):
+ return self._gen_label(self.name)
+
+ def _gen_label(self, name):
t = self.table
if self.is_literal:
return None
@@ -3965,11 +4076,9 @@ class ColumnClause(_Immutable, ColumnElement):
elif t is not None and t.named_with_column:
if getattr(t, 'schema', None):
label = t.schema.replace('.', '_') + "_" + \
- _escape_for_generated(t.name) + "_" + \
- _escape_for_generated(self.name)
+ t.name + "_" + name
else:
- label = _escape_for_generated(t.name) + "_" + \
- _escape_for_generated(self.name)
+ label = t.name + "_" + name
# ensure the label name doesn't conflict with that
# of an existing column
@@ -3981,10 +4090,10 @@ class ColumnClause(_Immutable, ColumnElement):
counter += 1
label = _label
- return _generated_label(label)
+ return _as_truncated(label)
else:
- return self.name
+ return name
def label(self, name):
# currently, anonymous labels don't occur for
@@ -4010,12 +4119,15 @@ class ColumnClause(_Immutable, ColumnElement):
# otherwise its considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
c = self._constructor(
- name or self.name,
+ _as_truncated(name or self.name),
selectable=selectable,
type_=self.type,
is_literal=is_literal
)
c.proxies = [self]
+ if selectable._is_clone_of is not None:
+ c._is_clone_of = \
+ selectable._is_clone_of.columns[c.name]
if attach:
selectable._columns[c.name] = c
@@ -4218,6 +4330,125 @@ class _SelectBase(Executable, FromClause):
"""
return self.as_scalar().label(name)
+ def cte(self, name=None, recursive=False):
+ """Return a new :class:`.CTE`, or Common Table Expression instance.
+
+ Common table expressions are a SQL standard whereby SELECT
+ statements can draw upon secondary statements specified along
+ with the primary statement, using a clause called "WITH".
+ Special semantics regarding UNION can also be employed to
+ allow "recursive" queries, where a SELECT statement can draw
+ upon the set of rows that have previously been selected.
+
+ SQLAlchemy detects :class:`.CTE` objects, which are treated
+ similarly to :class:`.Alias` objects, as special elements
+ to be delivered to the FROM clause of the statement as well
+ as to a WITH clause at the top of the statement.
+
+ The :meth:`._SelectBase.cte` method is new in 0.7.6.
+
+ :param name: name given to the common table expression. Like
+ :meth:`._FromClause.alias`, the name can be left as ``None``
+ in which case an anonymous symbol will be used at query
+ compile time.
+ :param recursive: if ``True``, will render ``WITH RECURSIVE``.
+ A recursive common table expression is intended to be used in
+ conjunction with UNION ALL in order to derive rows
+ from those already selected.
+
+ The following examples illustrate two examples from
+ Postgresql's documentation at
+ http://www.postgresql.org/docs/8.4/static/queries-with.html.
+
+ Example 1, non recursive::
+
+ from sqlalchemy import Table, Column, String, Integer, MetaData, \\
+ select, func
+
+ metadata = MetaData()
+
+ orders = Table('orders', metadata,
+ Column('region', String),
+ Column('amount', Integer),
+ Column('product', String),
+ Column('quantity', Integer)
+ )
+
+ regional_sales = select([
+ orders.c.region,
+ func.sum(orders.c.amount).label('total_sales')
+ ]).group_by(orders.c.region).cte("regional_sales")
+
+
+ top_regions = select([regional_sales.c.region]).\\
+ where(
+ regional_sales.c.total_sales >
+ select([
+ func.sum(regional_sales.c.total_sales)/10
+ ])
+ ).cte("top_regions")
+
+ statement = select([
+ orders.c.region,
+ orders.c.product,
+ func.sum(orders.c.quantity).label("product_units"),
+ func.sum(orders.c.amount).label("product_sales")
+ ]).where(orders.c.region.in_(
+ select([top_regions.c.region])
+ )).group_by(orders.c.region, orders.c.product)
+
+ result = conn.execute(statement).fetchall()
+
+ Example 2, WITH RECURSIVE::
+
+ from sqlalchemy import Table, Column, String, Integer, MetaData, \\
+ select, func
+
+ metadata = MetaData()
+
+ parts = Table('parts', metadata,
+ Column('part', String),
+ Column('sub_part', String),
+ Column('quantity', Integer),
+ )
+
+ included_parts = select([
+ parts.c.sub_part,
+ parts.c.part,
+ parts.c.quantity]).\\
+ where(parts.c.part=='our part').\\
+ cte(recursive=True)
+
+
+ incl_alias = included_parts.alias()
+ parts_alias = parts.alias()
+ included_parts = included_parts.union_all(
+ select([
+ parts_alias.c.part,
+ parts_alias.c.sub_part,
+ parts_alias.c.quantity
+ ]).
+ where(parts_alias.c.part==incl_alias.c.sub_part)
+ )
+
+ statement = select([
+ included_parts.c.sub_part,
+ func.sum(included_parts.c.quantity).label('total_quantity')
+ ]).\
+ select_from(included_parts.join(parts,
+ included_parts.c.part==parts.c.part)).\\
+ group_by(included_parts.c.sub_part)
+
+ result = conn.execute(statement).fetchall()
+
+
+ See also:
+
+ :meth:`.orm.query.Query.cte` - ORM version of :meth:`._SelectBase.cte`.
+
+ """
+ return CTE(self, name=name, recursive=recursive)
+
@_generative
@util.deprecated('0.6',
message=":func:`.autocommit` is deprecated. Use "
@@ -4602,7 +4833,7 @@ class Select(_SelectBase):
The text of the hint is rendered in the appropriate
location for the database backend in use, relative
to the given :class:`.Table` or :class:`.Alias` passed as the
- *selectable* argument. The dialect implementation
+ ``selectable`` argument. The dialect implementation
typically uses Python string substitution syntax
with the token ``%(name)s`` to render the name of
the table or alias. E.g. when using Oracle, the
@@ -4999,7 +5230,9 @@ class Select(_SelectBase):
def _populate_column_collection(self):
for c in self.inner_columns:
if hasattr(c, '_make_proxy'):
- c._make_proxy(self, name=self.use_labels and c._label or None)
+ c._make_proxy(self,
+ name=self.use_labels
+ and c._label or None)
def self_group(self, against=None):
"""return a 'grouping' construct as per the ClauseElement
@@ -5086,6 +5319,7 @@ class UpdateBase(Executable, ClauseElement):
_execution_options = \
Executable._execution_options.union({'autocommit': True})
kwargs = util.immutabledict()
+ _hints = util.immutabledict()
def _process_colparams(self, parameters):
if isinstance(parameters, (list, tuple)):
@@ -5166,6 +5400,45 @@ class UpdateBase(Executable, ClauseElement):
"""
self._returning = cols
+ @_generative
+ def with_hint(self, text, selectable=None, dialect_name="*"):
+ """Add a table hint for a single table to this
+ INSERT/UPDATE/DELETE statement.
+
+ .. note::
+
+ :meth:`.UpdateBase.with_hint` currently applies only to
+ Microsoft SQL Server. For MySQL INSERT hints, use
+ :meth:`.Insert.prefix_with`. UPDATE/DELETE hints for
+ MySQL will be added in a future release.
+
+ The text of the hint is rendered in the appropriate
+ location for the database backend in use, relative
+ to the :class:`.Table` that is the subject of this
+ statement, or optionally to that of the given
+ :class:`.Table` passed as the ``selectable`` argument.
+
+ The ``dialect_name`` option will limit the rendering of a particular
+ hint to a particular backend. Such as, to add a hint
+ that only takes effect for SQL Server::
+
+ mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql")
+
+ New in 0.7.6.
+
+ :param text: Text of the hint.
+ :param selectable: optional :class:`.Table` that specifies
+ an element of the FROM clause within an UPDATE or DELETE
+ to be the subject of the hint - applies only to certain backends.
+ :param dialect_name: defaults to ``*``, if specified as the name
+ of a particular dialect, will apply these hints only when
+ that dialect is in use.
+ """
+ if selectable is None:
+ selectable = self.table
+
+ self._hints = self._hints.union({(selectable, dialect_name):text})
+
class ValuesBase(UpdateBase):
"""Supplies support for :meth:`.ValuesBase.values` to INSERT and UPDATE constructs."""
diff --git a/libs/sqlalchemy/sql/visitors.py b/libs/sqlalchemy/sql/visitors.py
index cdcf40a..5354fbc 100644
--- a/libs/sqlalchemy/sql/visitors.py
+++ b/libs/sqlalchemy/sql/visitors.py
@@ -34,11 +34,19 @@ __all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
'cloned_traverse', 'replacement_traverse']
class VisitableType(type):
- """Metaclass which checks for a `__visit_name__` attribute and
- applies `_compiler_dispatch` method to classes.
-
+ """Metaclass which assigns a `_compiler_dispatch` method to classes
+ having a `__visit_name__` attribute.
+
+ The _compiler_dispatch attribute becomes an instance method which
+ looks approximately like the following::
+
+ def _compiler_dispatch (self, visitor, **kw):
+ '''Look for an attribute named "visit_" + self.__visit_name__
+ on the visitor, and call it with the same kw params.'''
+ return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
+
+ Classes having no __visit_name__ attribute will remain unaffected.
"""
-
def __init__(cls, clsname, bases, clsdict):
if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'):
super(VisitableType, cls).__init__(clsname, bases, clsdict)
@@ -48,19 +56,31 @@ class VisitableType(type):
super(VisitableType, cls).__init__(clsname, bases, clsdict)
+
def _generate_dispatch(cls):
- # set up an optimized visit dispatch function
- # for use by the compiler
+ """Return an optimized visit dispatch function for the cls
+ for use by the compiler.
+ """
if '__visit_name__' in cls.__dict__:
visit_name = cls.__visit_name__
if isinstance(visit_name, str):
+ # There is an optimization opportunity here because the
+ # the string name of the class's __visit_name__ is known at
+ # this early stage (import time) so it can be pre-constructed.
getter = operator.attrgetter("visit_%s" % visit_name)
def _compiler_dispatch(self, visitor, **kw):
return getter(visitor)(self, **kw)
else:
+ # The optimization opportunity is lost for this case because the
+ # __visit_name__ is not yet a string. As a result, the visit
+ # string has to be recalculated with each compilation.
def _compiler_dispatch(self, visitor, **kw):
return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
+ _compiler_dispatch.__doc__ = \
+ """Look for an attribute named "visit_" + self.__visit_name__
+ on the visitor, and call it with the same kw params.
+ """
cls._compiler_dispatch = _compiler_dispatch
class Visitable(object):
diff --git a/libs/sqlalchemy/types.py b/libs/sqlalchemy/types.py
index 8c8e6eb..512ac62 100644
--- a/libs/sqlalchemy/types.py
+++ b/libs/sqlalchemy/types.py
@@ -397,7 +397,7 @@ class TypeDecorator(TypeEngine):
def copy(self):
return MyType(self.impl.length)
- The class-level "impl" variable is required, and can reference any
+ The class-level "impl" attribute is required, and can reference any
TypeEngine class. Alternatively, the load_dialect_impl() method
can be used to provide different type classes based on the dialect
given; in this case, the "impl" variable can reference
@@ -457,15 +457,19 @@ class TypeDecorator(TypeEngine):
Arguments sent here are passed to the constructor
of the class assigned to the ``impl`` class level attribute,
- where the ``self.impl`` attribute is assigned an instance
- of the implementation type. If ``impl`` at the class level
- is already an instance, then it's assigned to ``self.impl``
- as is.
+ assuming the ``impl`` is a callable, and the resulting
+ object is assigned to the ``self.impl`` instance attribute
+ (thus overriding the class attribute of the same name).
+
+ If the class level ``impl`` is not a callable (the unusual case),
+ it will be assigned to the same instance attribute 'as-is',
+ ignoring those arguments passed to the constructor.
Subclasses can override this to customize the generation
- of ``self.impl``.
+ of ``self.impl`` entirely.
"""
+
if not hasattr(self.__class__, 'impl'):
raise AssertionError("TypeDecorator implementations "
"require a class-level variable "
@@ -475,6 +479,9 @@ class TypeDecorator(TypeEngine):
def _gen_dialect_impl(self, dialect):
+ """
+ #todo
+ """
adapted = dialect.type_descriptor(self)
if adapted is not self:
return adapted
@@ -494,6 +501,9 @@ class TypeDecorator(TypeEngine):
@property
def _type_affinity(self):
+ """
+ #todo
+ """
return self.impl._type_affinity
def type_engine(self, dialect):
@@ -531,7 +541,6 @@ class TypeDecorator(TypeEngine):
def __getattr__(self, key):
"""Proxy all other undefined accessors to the underlying
implementation."""
-
return getattr(self.impl, key)
def process_bind_param(self, value, dialect):
@@ -542,29 +551,52 @@ class TypeDecorator(TypeEngine):
:class:`.TypeEngine` object, and from there to the
DBAPI ``execute()`` method.
- :param value: the value. Can be None.
+ The operation could be anything desired to perform custom
+ behavior, such as transforming or serializing data.
+ This could also be used as a hook for validating logic.
+
+ This operation should be designed with the reverse operation
+ in mind, which would be the process_result_value method of
+ this class.
+
+ :param value: Data to operate upon, of any type expected by
+ this method in the subclass. Can be ``None``.
:param dialect: the :class:`.Dialect` in use.
"""
+
raise NotImplementedError()
def process_result_value(self, value, dialect):
"""Receive a result-row column value to be converted.
+ Subclasses should implement this method to operate on data
+ fetched from the database.
+
Subclasses override this method to return the
value that should be passed back to the application,
given a value that is already processed by
the underlying :class:`.TypeEngine` object, originally
from the DBAPI cursor method ``fetchone()`` or similar.
- :param value: the value. Can be None.
+ The operation could be anything desired to perform custom
+ behavior, such as transforming or serializing data.
+ This could also be used as a hook for validating logic.
+
+ :param value: Data to operate upon, of any type expected by
+ this method in the subclass. Can be ``None``.
:param dialect: the :class:`.Dialect` in use.
+ This operation should be designed to be reversible by
+ the "process_bind_param" method of this class.
+
"""
+
raise NotImplementedError()
def bind_processor(self, dialect):
- """Provide a bound value processing function for the given :class:`.Dialect`.
+ """Provide a bound value processing function for the
+ given :class:`.Dialect`.
This is the method that fulfills the :class:`.TypeEngine`
contract for bound value conversion. :class:`.TypeDecorator`
@@ -575,6 +607,11 @@ class TypeDecorator(TypeEngine):
though its likely best to use :meth:`process_bind_param` so that
the processing provided by ``self.impl`` is maintained.
+ :param dialect: Dialect instance in use.
+
+ This method is the reverse counterpart to the
+ :meth:`result_processor` method of this class.
+
"""
if self.__class__.process_bind_param.func_code \
is not TypeDecorator.process_bind_param.func_code:
@@ -604,6 +641,12 @@ class TypeDecorator(TypeEngine):
though its likely best to use :meth:`process_result_value` so that
the processing provided by ``self.impl`` is maintained.
+ :param dialect: Dialect instance in use.
+ :param coltype: An SQLAlchemy data type
+
+ This method is the reverse counterpart to the
+ :meth:`bind_processor` method of this class.
+
"""
if self.__class__.process_result_value.func_code \
is not TypeDecorator.process_result_value.func_code:
@@ -654,6 +697,7 @@ class TypeDecorator(TypeEngine):
has local state that should be deep-copied.
"""
+
instance = self.__class__.__new__(self.__class__)
instance.__dict__.update(self.__dict__)
return instance
@@ -724,6 +768,9 @@ class TypeDecorator(TypeEngine):
return self.impl.is_mutable()
def _adapt_expression(self, op, othertype):
+ """
+ #todo
+ """
op, typ =self.impl._adapt_expression(op, othertype)
if typ is self.impl:
return op, self
diff --git a/libs/sqlalchemy/util/__init__.py b/libs/sqlalchemy/util/__init__.py
index 5712940..13914aa 100644
--- a/libs/sqlalchemy/util/__init__.py
+++ b/libs/sqlalchemy/util/__init__.py
@@ -7,7 +7,7 @@
from compat import callable, cmp, reduce, defaultdict, py25_dict, \
threading, py3k_warning, jython, pypy, win32, set_types, buffer, pickle, \
update_wrapper, partial, md5_hex, decode_slice, dottedgetter,\
- parse_qsl, any
+ parse_qsl, any, contextmanager
from _collections import NamedTuple, ImmutableContainer, immutabledict, \
Properties, OrderedProperties, ImmutableProperties, OrderedDict, \
diff --git a/libs/sqlalchemy/util/compat.py b/libs/sqlalchemy/util/compat.py
index 07652f3..99b92b1 100644
--- a/libs/sqlalchemy/util/compat.py
+++ b/libs/sqlalchemy/util/compat.py
@@ -57,6 +57,12 @@ buffer = buffer
# end Py2K
try:
+ from contextlib import contextmanager
+except ImportError:
+ def contextmanager(fn):
+ return fn
+
+try:
from functools import update_wrapper
except ImportError:
def update_wrapper(wrapper, wrapped,