Update sqlalchemy

tags/build/2.0.0.pre1
Ruud, 13 years ago
commit 710c2aa05f
38 changed files (changed lines in parentheses):

  1. libs/sqlalchemy/__init__.py (2)
  2. libs/sqlalchemy/cextension/processors.c (25)
  3. libs/sqlalchemy/cextension/resultproxy.c (17)
  4. libs/sqlalchemy/dialects/firebird/base.py (2)
  5. libs/sqlalchemy/dialects/mssql/base.py (10)
  6. libs/sqlalchemy/dialects/mysql/base.py (116)
  7. libs/sqlalchemy/dialects/oracle/base.py (7)
  8. libs/sqlalchemy/dialects/oracle/cx_oracle.py (32)
  9. libs/sqlalchemy/dialects/postgresql/base.py (15)
  10. libs/sqlalchemy/dialects/postgresql/psycopg2.py (22)
  11. libs/sqlalchemy/dialects/sqlite/base.py (25)
  12. libs/sqlalchemy/engine/__init__.py (6)
  13. libs/sqlalchemy/engine/base.py (356)
  14. libs/sqlalchemy/engine/default.py (12)
  15. libs/sqlalchemy/engine/reflection.py (9)
  16. libs/sqlalchemy/engine/strategies.py (6)
  17. libs/sqlalchemy/event.py (35)
  18. libs/sqlalchemy/exc.py (23)
  19. libs/sqlalchemy/ext/declarative.py (8)
  20. libs/sqlalchemy/ext/hybrid.py (65)
  21. libs/sqlalchemy/ext/orderinglist.py (7)
  22. libs/sqlalchemy/orm/collections.py (64)
  23. libs/sqlalchemy/orm/mapper.py (649)
  24. libs/sqlalchemy/orm/persistence.py (777)
  25. libs/sqlalchemy/orm/query.py (158)
  26. libs/sqlalchemy/orm/scoping.py (12)
  27. libs/sqlalchemy/orm/session.py (43)
  28. libs/sqlalchemy/orm/sync.py (1)
  29. libs/sqlalchemy/orm/unitofwork.py (12)
  30. libs/sqlalchemy/orm/util.py (76)
  31. libs/sqlalchemy/pool.py (56)
  32. libs/sqlalchemy/schema.py (54)
  33. libs/sqlalchemy/sql/compiler.py (179)
  34. libs/sqlalchemy/sql/expression.py (329)
  35. libs/sqlalchemy/sql/visitors.py (32)
  36. libs/sqlalchemy/types.py (67)
  37. libs/sqlalchemy/util/__init__.py (2)
  38. libs/sqlalchemy/util/compat.py (6)

2
libs/sqlalchemy/__init__.py

@@ -117,7 +117,7 @@ from sqlalchemy.engine import create_engine, engine_from_config
 __all__ = sorted(name for name, obj in locals().items()
                  if not (name.startswith('_') or inspect.ismodule(obj)))
 
-__version__ = '0.7.5'
+__version__ = '0.7.6'
 
 del inspect, sys

25
libs/sqlalchemy/cextension/processors.c

@@ -342,23 +342,18 @@ DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value)
     if (value == Py_None)
         Py_RETURN_NONE;
 
-    if (PyFloat_CheckExact(value)) {
-        /* Decimal does not accept float values directly */
-        args = PyTuple_Pack(1, value);
-        if (args == NULL)
-            return NULL;
+    args = PyTuple_Pack(1, value);
+    if (args == NULL)
+        return NULL;
 
-        str = PyString_Format(self->format, args);
-        Py_DECREF(args);
-        if (str == NULL)
-            return NULL;
+    str = PyString_Format(self->format, args);
+    Py_DECREF(args);
+    if (str == NULL)
+        return NULL;
 
-        result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
-        Py_DECREF(str);
-        return result;
-    } else {
-        return PyObject_CallFunctionObjArgs(self->type, value, NULL);
-    }
+    result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
+    Py_DECREF(str);
+    return result;
 }
 
 static void

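The C change above removes the float-only branch: every non-None value is now run through the string format before being handed to the Decimal constructor. The pure-Python fallback in ``sqlalchemy/processors.py`` follows the same shape; a minimal sketch of that logic (names chosen for illustration)::

    from decimal import Decimal

    def to_decimal_processor_factory(target_class, scale=10):
        fstring = "%%.%df" % scale

        def process(value):
            if value is None:
                return None
            # format through a string first, since Decimal does not
            # accept float values directly on older Pythons
            return target_class(fstring % value)
        return process

    process = to_decimal_processor_factory(Decimal, 2)
    print(process(15.756))   # Decimal('15.76')
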
17
libs/sqlalchemy/cextension/resultproxy.c

@@ -246,6 +246,7 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
     PyObject *exc_module, *exception;
     char *cstr_key;
     long index;
+    int key_fallback = 0;
 
     if (PyInt_CheckExact(key)) {
         index = PyInt_AS_LONG(key);
@@ -276,12 +277,17 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
                 "O", key);
         if (record == NULL)
             return NULL;
+        key_fallback = 1;
     }
 
     indexobject = PyTuple_GetItem(record, 2);
     if (indexobject == NULL)
         return NULL;
 
+    if (key_fallback) {
+        Py_DECREF(record);
+    }
+
     if (indexobject == Py_None) {
         exc_module = PyImport_ImportModule("sqlalchemy.exc");
         if (exc_module == NULL)
@@ -347,7 +353,16 @@ BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
     else
         return tmp;
 
-    return BaseRowProxy_subscript(self, name);
+    tmp = BaseRowProxy_subscript(self, name);
+    if (tmp == NULL && PyErr_ExceptionMatches(PyExc_KeyError)) {
+        PyErr_Format(
+            PyExc_AttributeError,
+            "Could not locate column in row for column '%.200s'",
+            PyString_AsString(name)
+        );
+        return NULL;
+    }
+    return tmp;
 }
 
 /***********************

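The net effect of the ``BaseRowProxy_getattro`` change is that a missing column on a row raises ``AttributeError`` rather than leaking ``KeyError``, so ``getattr`` and ``hasattr`` behave as expected. A small sketch against an in-memory SQLite engine (illustrative only)::

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")
    row = engine.execute("select 1 as x").fetchone()

    print(row.x)                          # 1
    print(getattr(row, "missing", None))  # None; AttributeError is caught
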
2
libs/sqlalchemy/dialects/firebird/base.py

@@ -215,7 +215,7 @@ class FBCompiler(sql.compiler.SQLCompiler):
         # Override to not use the AS keyword which FB 1.5 does not like
         if asfrom:
             alias_name = isinstance(alias.name,
-                            expression._generated_label) and \
+                            expression._truncated_label) and \
                             self._truncated_identifier("alias",
                             alias.name) or alias.name

10
libs/sqlalchemy/dialects/mssql/base.py

@@ -791,6 +791,9 @@ class MSSQLCompiler(compiler.SQLCompiler):
     def get_from_hint_text(self, table, text):
         return text
 
+    def get_crud_hint_text(self, table, text):
+        return text
+
     def limit_clause(self, select):
         # Limit in mssql is after the select keyword
         return ""
@@ -949,6 +952,13 @@ class MSSQLCompiler(compiler.SQLCompiler):
                 ]
             return 'OUTPUT ' + ', '.join(columns)
 
+    def get_cte_preamble(self, recursive):
+        # SQL Server finds it too inconvenient to accept
+        # an entirely optional, SQL standard specified,
+        # "RECURSIVE" word with their "WITH",
+        # so here we go
+        return "WITH"
+
     def label_select_column(self, select, column, asfrom):
         if isinstance(column, expression.Function):
             return column.label(None)

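``get_cte_preamble`` ties into the CTE support that is new in this release: on SQL Server a recursive CTE is still introduced with plain ``WITH``, not ``WITH RECURSIVE``. A sketch, assuming the new ``cte()`` method on selectables (table and column names are made up)::

    from sqlalchemy import select, table, column
    from sqlalchemy.dialects import mssql

    parts = table('parts', column('part'), column('sub_part'))

    included = select([parts.c.sub_part]).\
        where(parts.c.part == 'wheel').\
        cte(name='included_parts', recursive=True)

    # renders "WITH included_parts ..." under the mssql dialect
    print(select([included.c.sub_part]).compile(dialect=mssql.dialect()))
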
116
libs/sqlalchemy/dialects/mysql/base.py

@@ -84,6 +84,23 @@ all lower case both within SQLAlchemy as well as on the MySQL
 database itself, especially if database reflection features are
 to be used.
 
+Transaction Isolation Level
+---------------------------
+
+:func:`.create_engine` accepts an ``isolation_level``
+parameter which results in the command ``SET SESSION
+TRANSACTION ISOLATION LEVEL <level>`` being invoked for
+every new connection. Valid values for this parameter are
+``READ COMMITTED``, ``READ UNCOMMITTED``,
+``REPEATABLE READ``, and ``SERIALIZABLE``::
+
+    engine = create_engine(
+                    "mysql://scott:tiger@localhost/test",
+                    isolation_level="READ UNCOMMITTED"
+                )
+
+(new in 0.7.6)
+
 Keys
 ----
@@ -221,8 +238,29 @@ simply passed through to the underlying CREATE INDEX command, so it *must* be
 an integer. MySQL only allows a length for an index if it is for a CHAR,
 VARCHAR, TEXT, BINARY, VARBINARY and BLOB.
 
+Index Types
+~~~~~~~~~~~~~
+
+Some MySQL storage engines permit you to specify an index type when creating
+an index or primary key constraint. SQLAlchemy provides this feature via the
+``mysql_using`` parameter on :class:`.Index`::
+
+    Index('my_index', my_table.c.data, mysql_using='hash')
+
+As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`::
+
+    PrimaryKeyConstraint("data", mysql_using='hash')
+
+The value passed to the keyword argument will be simply passed through to the
+underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index
+type for your MySQL storage engine.
+
 More information can be found at:
+
 http://dev.mysql.com/doc/refman/5.0/en/create-index.html
+
+http://dev.mysql.com/doc/refman/5.0/en/create-table.html
+
 """
 
 import datetime, inspect, re, sys
@@ -1331,7 +1369,8 @@ class MySQLCompiler(compiler.SQLCompiler):
         return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw)
                     for t in [from_table] + list(extra_froms))
 
-    def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
+    def update_from_clause(self, update_stmt, from_table,
+                            extra_froms, from_hints, **kw):
         return None
 
@@ -1421,35 +1460,50 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             table_opts.append(joiner.join((opt, arg)))
 
         return ' '.join(table_opts)
 
     def visit_create_index(self, create):
         index = create.element
         preparer = self.preparer
+        table = preparer.format_table(index.table)
+        columns = [preparer.quote(c.name, c.quote) for c in index.columns]
+        name = preparer.quote(
+                    self._index_identifier(index.name),
+                    index.quote)
+
         text = "CREATE "
         if index.unique:
             text += "UNIQUE "
-        text += "INDEX %s ON %s " \
-                % (preparer.quote(self._index_identifier(index.name),
-                    index.quote), preparer.format_table(index.table))
+        text += "INDEX %s ON %s " % (name, table)
+
+        columns = ', '.join(columns)
         if 'mysql_length' in index.kwargs:
             length = index.kwargs['mysql_length']
+            text += "(%s(%d))" % (columns, length)
         else:
-            length = None
-        if length is not None:
-            text += "(%s(%d))" \
-                    % (', '.join(preparer.quote(c.name, c.quote)
-                                 for c in index.columns), length)
-        else:
-            text += "(%s)" \
-                    % (', '.join(preparer.quote(c.name, c.quote)
-                                 for c in index.columns))
+            text += "(%s)" % (columns)
+
+        if 'mysql_using' in index.kwargs:
+            using = index.kwargs['mysql_using']
+            text += " USING %s" % (preparer.quote(using, index.quote))
+
         return text
 
+    def visit_primary_key_constraint(self, constraint):
+        text = super(MySQLDDLCompiler, self).\
+            visit_primary_key_constraint(constraint)
+        if "mysql_using" in constraint.kwargs:
+            using = constraint.kwargs['mysql_using']
+            text += " USING %s" % (
+                        self.preparer.quote(using, constraint.quote))
+        return text
+
     def visit_drop_index(self, drop):
         index = drop.element
 
         return "\nDROP INDEX %s ON %s" % \
-                    (self.preparer.quote(self._index_identifier(index.name), index.quote),
+                    (self.preparer.quote(
+                        self._index_identifier(index.name), index.quote
+                    ),
                      self.preparer.format_table(index.table))
 
     def visit_drop_constraint(self, drop):
@@ -1768,8 +1822,40 @@ class MySQLDialect(default.DefaultDialect):
     _backslash_escapes = True
     _server_ansiquotes = False
 
-    def __init__(self, use_ansiquotes=None, **kwargs):
+    def __init__(self, use_ansiquotes=None, isolation_level=None, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
+        self.isolation_level = isolation_level
+
+    def on_connect(self):
+        if self.isolation_level is not None:
+            def connect(conn):
+                self.set_isolation_level(conn, self.isolation_level)
+            return connect
+        else:
+            return None
+
+    _isolation_lookup = set(['SERIALIZABLE',
+                'READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ'])
+
+    def set_isolation_level(self, connection, level):
+        level = level.replace('_', ' ')
+        if level not in self._isolation_lookup:
+            raise exc.ArgumentError(
+                "Invalid value '%s' for isolation_level. "
+                "Valid isolation levels for %s are %s" %
+                (level, self.name, ", ".join(self._isolation_lookup))
+                )
+        cursor = connection.cursor()
+        cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level)
+        cursor.execute("COMMIT")
+        cursor.close()
+
+    def get_isolation_level(self, connection):
+        cursor = connection.cursor()
+        cursor.execute('SELECT @@tx_isolation')
+        val = cursor.fetchone()[0]
+        cursor.close()
+        return val.upper().replace("-", " ")
 
     def do_commit(self, connection):
         """Execute a COMMIT."""

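Taken together, the MySQL changes surface two user-visible options: ``isolation_level`` on :func:`.create_engine` and ``mysql_using`` on :class:`.Index` / :class:`.PrimaryKeyConstraint`. A combined sketch (hypothetical credentials and table)::

    from sqlalchemy import (create_engine, MetaData, Table, Column,
                            Integer, String, Index)

    engine = create_engine(
        "mysql://scott:tiger@localhost/test",
        isolation_level="READ COMMITTED")

    metadata = MetaData()
    kv = Table('kv', metadata,
               Column('id', Integer, primary_key=True),
               Column('data', String(255)))

    # passed through to CREATE INDEX ... USING BTREE
    Index('ix_kv_data', kv.c.data, mysql_using='btree')

    metadata.create_all(engine)
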
7
libs/sqlalchemy/dialects/oracle/base.py

@@ -158,7 +158,7 @@ RESERVED_WORDS = \
     'AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS '\
     'NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER '\
     'CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR '\
-    'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT'.split())
+    'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL'.split())
 
 NO_ARG_FNS = set('UID CURRENT_DATE SYSDATE USER '
                  'CURRENT_TIME CURRENT_TIMESTAMP'.split())
@@ -309,6 +309,9 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
                     "",
                 )
 
+    def visit_LONG(self, type_):
+        return "LONG"
+
     def visit_TIMESTAMP(self, type_):
         if type_.timezone:
             return "TIMESTAMP WITH TIME ZONE"
@@ -481,7 +484,7 @@ class OracleCompiler(compiler.SQLCompiler):
         """Oracle doesn't like ``FROM table AS alias``.  Is the AS standard SQL??"""
 
         if asfrom or ashint:
-            alias_name = isinstance(alias.name, expression._generated_label) and \
+            alias_name = isinstance(alias.name, expression._truncated_label) and \
                             self._truncated_identifier("alias", alias.name) or alias.name
 
             if ashint:

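The new ``visit_LONG`` hook suggests the dialect can now render Oracle's legacy ``LONG`` type in DDL. A sketch, assuming the type is exported by the dialect module::

    from sqlalchemy import Table, Column, Integer, MetaData
    from sqlalchemy.dialects import oracle
    from sqlalchemy.schema import CreateTable

    doc = Table('doc', MetaData(),
                Column('id', Integer, primary_key=True),
                Column('body', oracle.LONG))

    # the body column renders as LONG under the Oracle dialect
    print(CreateTable(doc).compile(dialect=oracle.dialect()))
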
32
libs/sqlalchemy/dialects/oracle/cx_oracle.py

@@ -77,7 +77,7 @@ with this feature but it should be regarded as experimental.
 Precision Numerics
 ------------------
 
-The SQLAlchemy dialect goes thorugh a lot of steps to ensure
+The SQLAlchemy dialect goes through a lot of steps to ensure
 that decimal numbers are sent and received with full accuracy.
 An "outputtypehandler" callable is associated with each
 cx_oracle connection object which detects numeric types and
@@ -89,6 +89,21 @@ this behavior, and will coerce the ``Decimal`` to ``float`` if
 the ``asdecimal`` flag is ``False`` (default on :class:`.Float`,
 optional on :class:`.Numeric`).
 
+Because the handler coerces to ``Decimal`` in all cases first,
+the feature can detract significantly from performance.
+If precision numerics aren't required, the decimal handling
+can be disabled by passing the flag ``coerce_to_decimal=False``
+to :func:`.create_engine`::
+
+    engine = create_engine("oracle+cx_oracle://dsn",
+                        coerce_to_decimal=False)
+
+The ``coerce_to_decimal`` flag is new in 0.7.6.
+
+Another alternative to performance is to use the
+`cdecimal <http://pypi.python.org/pypi/cdecimal/>`_ library;
+see :class:`.Numeric` for additional notes.
+
 The handler attempts to use the "precision" and "scale"
 attributes of the result set column to best determine if
 subsequent incoming values should be received as ``Decimal`` as
@@ -468,6 +483,7 @@ class OracleDialect_cx_oracle(OracleDialect):
                 auto_convert_lobs=True,
                 threaded=True,
                 allow_twophase=True,
+                coerce_to_decimal=True,
                 arraysize=50, **kwargs):
         OracleDialect.__init__(self, **kwargs)
         self.threaded = threaded
@@ -491,7 +507,12 @@ class OracleDialect_cx_oracle(OracleDialect):
             self._cx_oracle_unicode_types = types("UNICODE", "NCLOB")
             self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB")
             self.supports_unicode_binds = self.cx_oracle_ver >= (5, 0)
-            self.supports_native_decimal = self.cx_oracle_ver >= (5, 0)
+
+            self.supports_native_decimal = (
+                                    self.cx_oracle_ver >= (5, 0) and
+                                    coerce_to_decimal
+                                )
+
             self._cx_oracle_native_nvarchar = self.cx_oracle_ver >= (5, 0)
 
         if self.cx_oracle_ver is None:
@@ -603,7 +624,9 @@ class OracleDialect_cx_oracle(OracleDialect):
                             size, precision, scale):
             # convert all NUMBER with precision + positive scale to Decimal
             # this almost allows "native decimal" mode.
-            if defaultType == cx_Oracle.NUMBER and precision and scale > 0:
+            if self.supports_native_decimal and \
+                    defaultType == cx_Oracle.NUMBER and \
+                    precision and scale > 0:
                 return cursor.var(
                             cx_Oracle.STRING,
                             255,
@@ -614,7 +637,8 @@ class OracleDialect_cx_oracle(OracleDialect):
             # make a decision based on each value received - the type
             # may change from row to row (!).  This kills
             # off "native decimal" mode, handlers still needed.
-            elif defaultType == cx_Oracle.NUMBER \
+            elif self.supports_native_decimal and \
+                    defaultType == cx_Oracle.NUMBER \
                     and not precision and scale <= 0:
                 return cursor.var(
                             cx_Oracle.STRING,

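The effect of ``coerce_to_decimal=False`` is that the outputtypehandler branches above are skipped entirely and numeric values arrive as ``float``. A sketch with a hypothetical DSN::

    from sqlalchemy import create_engine

    engine = create_engine(
        "oracle+cx_oracle://scott:tiger@dsn",
        coerce_to_decimal=False)

    # returns a float such as 15.75 rather than Decimal('15.75')
    print(engine.execute("SELECT 15.75 FROM DUAL").scalar())
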
15
libs/sqlalchemy/dialects/postgresql/base.py

@@ -47,9 +47,18 @@ Transaction Isolation Level
 :func:`.create_engine` accepts an ``isolation_level`` parameter which results
 in the command ``SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL
 <level>`` being invoked for every new connection. Valid values for this
-parameter are ``READ_COMMITTED``, ``READ_UNCOMMITTED``, ``REPEATABLE_READ``,
-and ``SERIALIZABLE``.  Note that the psycopg2 dialect does *not* use this
-technique and uses psycopg2-specific APIs (see that dialect for details).
+parameter are ``READ COMMITTED``, ``READ UNCOMMITTED``, ``REPEATABLE READ``,
+and ``SERIALIZABLE``::
+
+    engine = create_engine(
+                    "postgresql+pg8000://scott:tiger@localhost/test",
+                    isolation_level="READ UNCOMMITTED"
+                )
+
+When using the psycopg2 dialect, a psycopg2-specific method of setting
+transaction isolation level is used, but the API of ``isolation_level``
+remains the same - see :ref:`psycopg2_isolation`.
 
 Remote / Cross-Schema Table Introspection
 -----------------------------------------

22
libs/sqlalchemy/dialects/postgresql/psycopg2.py

@@ -38,6 +38,26 @@ psycopg2-specific keyword arguments which are accepted by
 * *use_native_unicode* - Enable the usage of Psycopg2 "native unicode" mode
   per connection.  True by default.
 
+Unix Domain Connections
+------------------------
+
+psycopg2 supports connecting via Unix domain connections.   When the ``host``
+portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2,
+which specifies Unix-domain communication rather than TCP/IP communication::
+
+    create_engine("postgresql+psycopg2://user:password@/dbname")
+
+By default, the socket file used is to connect to a Unix-domain socket
+in ``/tmp``, or whatever socket directory was specified when PostgreSQL
+was built.  This value can be overridden by passing a pathname to psycopg2,
+using ``host`` as an additional keyword argument::
+
+    create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql")
+
+See also:
+
+`PQconnectdbParams <http://www.postgresql.org/docs/9.1/static/libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_
+
 Per-Statement/Connection Execution Options
 -------------------------------------------
@@ -97,6 +117,8 @@ Transactions
 
 The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
 
+.. _psycopg2_isolation:
+
 Transaction Isolation Level
 ---------------------------

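The two psycopg2 documentation additions compose naturally; a sketch with hypothetical credentials, connecting over a Unix-domain socket in a non-default directory while also setting the isolation level::

    from sqlalchemy import create_engine

    engine = create_engine(
        "postgresql+psycopg2://user:password@/dbname"
        "?host=/var/run/postgresql",
        isolation_level="READ UNCOMMITTED")
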
25
libs/sqlalchemy/dialects/sqlite/base.py

@@ -441,20 +441,6 @@ class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
             result = self.quote_schema(index.table.schema, index.table.quote_schema) + "." + result
         return result
 
-class SQLiteExecutionContext(default.DefaultExecutionContext):
-    def get_result_proxy(self):
-        rp = base.ResultProxy(self)
-        if rp._metadata:
-            # adjust for dotted column names.  SQLite
-            # in the case of UNION may store col names as
-            # "tablename.colname"
-            # in cursor.description
-            for colname in rp._metadata.keys:
-                if "." in colname:
-                    trunc_col = colname.split(".")[1]
-                    rp._metadata._set_keymap_synonym(trunc_col, colname)
-        return rp
-
 class SQLiteDialect(default.DefaultDialect):
     name = 'sqlite'
     supports_alter = False
@@ -472,7 +458,6 @@ class SQLiteDialect(default.DefaultDialect):
     ischema_names = ischema_names
     colspecs = colspecs
     isolation_level = None
-    execution_ctx_cls = SQLiteExecutionContext
 
     supports_cast = True
     supports_default_values = True
@@ -540,6 +525,16 @@ class SQLiteDialect(default.DefaultDialect):
         else:
             return None
 
+    def _translate_colname(self, colname):
+        # adjust for dotted column names.  SQLite
+        # in the case of UNION may store col names as
+        # "tablename.colname"
+        # in cursor.description
+        if "." in colname:
+            return colname.split(".")[1], colname
+        else:
+            return colname, None
+
     @reflection.cache
     def get_table_names(self, connection, schema=None, **kw):
         if schema is not None:

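The ``_translate_colname`` hook replaces the removed execution context by handling dotted names when result metadata is built. The underlying SQLite quirk can be seen with the stdlib driver alone; on affected SQLite builds the ``UNION`` below may describe its column as ``a.x`` rather than ``x``::

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("create table a (x)")
    conn.execute("create table b (x)")
    cur = conn.execute("select a.x from a union select b.x from b")

    # the hook maps both "x" and any raw "a.x" key to the same column
    print([d[0] for d in cur.description])
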
6
libs/sqlalchemy/engine/__init__.py

@@ -306,6 +306,12 @@ def create_engine(*args, **kwargs):
         this is configurable with the MySQLDB connection itself and the
         server configuration as well).
 
+    :param pool_reset_on_return='rollback': set the "reset on return"
+        behavior of the pool, which is whether ``rollback()``,
+        ``commit()``, or nothing is called upon connections
+        being returned to the pool.  See the docstring for
+        ``reset_on_return`` at :class:`.Pool`. (new as of 0.7.6)
+
     :param pool_timeout=30: number of seconds to wait before giving
         up on getting a connection from the pool. This is only used
         with :class:`~sqlalchemy.pool.QueuePool`.

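A sketch of the new pool flag as it appears at the :func:`.create_engine` level (hypothetical URL); the value is forwarded to the underlying :class:`.Pool` as ``reset_on_return``::

    from sqlalchemy import create_engine

    engine = create_engine(
        "postgresql://scott:tiger@localhost/test",
        pool_reset_on_return='commit')   # default is 'rollback'
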
356
libs/sqlalchemy/engine/base.py

@@ -491,14 +491,23 @@ class Dialect(object):
         raise NotImplementedError()
 
     def do_executemany(self, cursor, statement, parameters, context=None):
-        """Provide an implementation of *cursor.executemany(statement,
-        parameters)*."""
+        """Provide an implementation of ``cursor.executemany(statement,
+        parameters)``."""
 
         raise NotImplementedError()
 
     def do_execute(self, cursor, statement, parameters, context=None):
-        """Provide an implementation of *cursor.execute(statement,
-        parameters)*."""
+        """Provide an implementation of ``cursor.execute(statement,
+        parameters)``."""
+
+        raise NotImplementedError()
+
+    def do_execute_no_params(self, cursor, statement, parameters, context=None):
+        """Provide an implementation of ``cursor.execute(statement)``.
+
+        The parameter collection should not be sent.
+        """
 
         raise NotImplementedError()
 
@@ -777,12 +786,12 @@ class Connectable(object):
     def connect(self, **kwargs):
         """Return a :class:`.Connection` object.
 
         Depending on context, this may be ``self`` if this object
         is already an instance of :class:`.Connection`, or a newly
         procured :class:`.Connection` if this object is an instance
         of :class:`.Engine`.
 
         """
 
     def contextual_connect(self):
@@ -793,7 +802,7 @@ class Connectable(object):
         is already an instance of :class:`.Connection`, or a newly
         procured :class:`.Connection` if this object is an instance
         of :class:`.Engine`.
 
         """
 
         raise NotImplementedError()
 
@@ -904,6 +913,12 @@ class Connection(Connectable):
         c.__dict__ = self.__dict__.copy()
         return c
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.close()
+
     def execution_options(self, **opt):
         """ Set non-SQL options for the connection which take effect
         during execution.
 
@@ -940,7 +955,7 @@ class Connection(Connectable):
         :param compiled_cache: Available on: Connection.
           A dictionary where :class:`.Compiled` objects
           will be cached when the :class:`.Connection` compiles a clause
          expression into a :class:`.Compiled` object.
          It is the user's responsibility to
          manage the size of this dictionary, which will have keys
          corresponding to the dialect, clause element, the column
@@ -953,7 +968,7 @@ class Connection(Connectable):
          some operations, including flush operations.  The caching
          used by the ORM internally supersedes a cache dictionary
          specified here.
 
        :param isolation_level: Available on: Connection.
          Set the transaction isolation level for
          the lifespan of this connection.   Valid values include
@@ -962,7 +977,7 @@ class Connection(Connectable):
          database specific, including those for :ref:`sqlite_toplevel`,
          :ref:`postgresql_toplevel` - see those dialect's documentation
          for further info.
 
          Note that this option necessarily affects the underying
          DBAPI connection for the lifespan of the originating
          :class:`.Connection`, and is not per-execution. This
@@ -970,6 +985,18 @@ class Connection(Connectable):
          is returned to the connection pool, i.e.
          the :meth:`.Connection.close` method is called.
 
+        :param no_parameters: When ``True``, if the final parameter
+          list or dictionary is totally empty, will invoke the
+          statement on the cursor as ``cursor.execute(statement)``,
+          not passing the parameter collection at all.
+          Some DBAPIs such as psycopg2 and mysql-python consider
+          percent signs as significant only when parameters are
+          present; this option allows code to generate SQL
+          containing percent signs (and possibly other characters)
+          that is neutral regarding whether it's executed by the DBAPI
+          or piped into a script that's later invoked by
+          command line tools. New in 0.7.6.
+
         :param stream_results: Available on: Connection, statement.
          Indicate to the dialect that results should be
          "streamed" and not pre-buffered, if possible.  This is a limitation
 
@@ -1113,17 +1140,35 @@ class Connection(Connectable):
     def begin(self):
         """Begin a transaction and return a transaction handle.
 
         The returned object is an instance of :class:`.Transaction`.
+        This object represents the "scope" of the transaction,
+        which completes when either the :meth:`.Transaction.rollback`
+        or :meth:`.Transaction.commit` method is called.
+
+        Nested calls to :meth:`.begin` on the same :class:`.Connection`
+        will return new :class:`.Transaction` objects that represent
+        an emulated transaction within the scope of the enclosing
+        transaction, that is::
+
+            trans = conn.begin()   # outermost transaction
+            trans2 = conn.begin()  # "nested"
+            trans2.commit()        # does nothing
+            trans.commit()         # actually commits
+
+        Calls to :meth:`.Transaction.commit` only have an effect
+        when invoked via the outermost :class:`.Transaction` object, though the
+        :meth:`.Transaction.rollback` method of any of the
+        :class:`.Transaction` objects will roll back the
+        transaction.
 
-        Repeated calls to ``begin`` on the same Connection will create
-        a lightweight, emulated nested transaction.  Only the
-        outermost transaction may ``commit``.  Calls to ``commit`` on
-        inner transactions are ignored.  Any transaction in the
-        hierarchy may ``rollback``, however.
+        See also:
 
-        See also :meth:`.Connection.begin_nested`,
-        :meth:`.Connection.begin_twophase`.
+        :meth:`.Connection.begin_nested` - use a SAVEPOINT
+
+        :meth:`.Connection.begin_twophase` - use a two phase /XID transaction
+
+        :meth:`.Engine.begin` - context manager available from :class:`.Engine`.
 
         """
 
@@ -1157,7 +1202,7 @@ class Connection(Connectable):
     def begin_twophase(self, xid=None):
         """Begin a two-phase or XA transaction and return a transaction
         handle.
 
         The returned object is an instance of :class:`.TwoPhaseTransaction`,
         which in addition to the methods provided by
         :class:`.Transaction`, also provides a :meth:`~.TwoPhaseTransaction.prepare`
 
@@ -1302,7 +1347,7 @@ class Connection(Connectable):
     def close(self):
         """Close this :class:`.Connection`.
 
         This results in a release of the underlying database
         resources, that is, the DBAPI connection referenced
         internally. The DBAPI connection is typically restored
 
@@ -1313,7 +1358,7 @@ class Connection(Connectable):
         the DBAPI connection's ``rollback()`` method, regardless
         of any :class:`.Transaction` object that may be
         outstanding with regards to this :class:`.Connection`.
 
         After :meth:`~.Connection.close` is called, the
         :class:`.Connection` is permanently in a closed state,
         and will allow no further operations.
 
@@ -1354,24 +1399,24 @@ class Connection(Connectable):
         * a :class:`.DDLElement` object
         * a :class:`.DefaultGenerator` object
         * a :class:`.Compiled` object
 
         :param \*multiparams/\**params: represent bound parameter
          values to be used in the execution.   Typically,
          the format is either a collection of one or more
          dictionaries passed to \*multiparams::
 
              conn.execute(
                  table.insert(),
                  {"id":1, "value":"v1"},
                  {"id":2, "value":"v2"}
              )
 
          ...or individual key/values interpreted by \**params::
 
              conn.execute(
                  table.insert(), id=1, value="v1"
              )
 
         In the case that a plain SQL string is passed, and the underlying
         DBAPI accepts positional bind parameters, a collection of tuples
         or individual values in \*multiparams may be passed::
 
@@ -1380,21 +1425,21 @@ class Connection(Connectable):
                  "INSERT INTO table (id, value) VALUES (?, ?)",
                  (1, "v1"), (2, "v2")
              )
 
              conn.execute(
                  "INSERT INTO table (id, value) VALUES (?, ?)",
                  1, "v1"
              )
 
         Note above, the usage of a question mark "?" or other
         symbol is contingent upon the "paramstyle" accepted by the DBAPI
         in use, which may be any of "qmark", "named", "pyformat", "format",
         "numeric".   See `pep-249 <http://www.python.org/dev/peps/pep-0249/>`_
         for details on paramstyle.
 
         To execute a textual SQL statement which uses bound parameters in a
         DBAPI-agnostic way, use the :func:`~.expression.text` construct.
 
         """
         for c in type(object).__mro__:
             if c in Connection.executors:
 
@@ -1623,7 +1668,8 @@ class Connection(Connectable):
         if self._echo:
             self.engine.logger.info(statement)
-            self.engine.logger.info("%r", sql_util._repr_params(parameters, batches=10))
+            self.engine.logger.info("%r",
+                    sql_util._repr_params(parameters, batches=10))
 
         try:
             if context.executemany:
                 self.dialect.do_executemany(
@@ -1631,6 +1677,11 @@ class Connection(Connectable):
                                     statement,
                                     parameters,
                                     context)
+            elif not parameters and context.no_parameters:
+                self.dialect.do_execute_no_params(
+                                    cursor,
+                                    statement,
+                                    context)
             else:
                 self.dialect.do_execute(
                                     cursor,
 
@@ -1845,33 +1896,41 @@ class Connection(Connectable):
         """Execute the given function within a transaction boundary.
 
         The function is passed this :class:`.Connection`
-        as the first argument, followed by the given \*args and \**kwargs.
-
-        This is a shortcut for explicitly invoking
-        :meth:`.Connection.begin`, calling :meth:`.Transaction.commit`
-        upon success or :meth:`.Transaction.rollback` upon an
-        exception raise::
+        as the first argument, followed by the given \*args and \**kwargs,
+        e.g.::
 
             def do_something(conn, x, y):
                 conn.execute("some statement", {'x':x, 'y':y})
 
             conn.transaction(do_something, 5, 10)
 
-        Note that context managers (i.e. the ``with`` statement)
-        present a more modern way of accomplishing the above,
-        using the :class:`.Transaction` object as a base::
+        The operations inside the function are all invoked within the
+        context of a single :class:`.Transaction`.
+        Upon success, the transaction is committed.  If an
+        exception is raised, the transaction is rolled back
+        before propagating the exception.
 
-            with conn.begin():
-                conn.execute("some statement", {'x':5, 'y':10})
+        .. note::
 
-        One advantage to the :meth:`.Connection.transaction`
-        method is that the same method is also available
-        on :class:`.Engine` as :meth:`.Engine.transaction` -
-        this method procures a :class:`.Connection` and then
-        performs the same operation, allowing equivalent
-        usage with either a :class:`.Connection` or :class:`.Engine`
-        without needing to know what kind of object
-        it is.
+           The :meth:`.transaction` method is superseded by
+           the usage of the Python ``with:`` statement, which can
+           be used with :meth:`.Connection.begin`::
+
+               with conn.begin():
+                   conn.execute("some statement", {'x':5, 'y':10})
+
+           As well as with :meth:`.Engine.begin`::
+
+               with engine.begin() as conn:
+                   conn.execute("some statement", {'x':5, 'y':10})
+
+        See also:
+
+        :meth:`.Engine.begin` - engine-level transactional
+        context
+
+        :meth:`.Engine.transaction` - engine-level version of
+        :meth:`.Connection.transaction`
 
         """
 
@@ -1887,15 +1946,15 @@ class Connection(Connectable):
     def run_callable(self, callable_, *args, **kwargs):
         """Given a callable object or function, execute it, passing
         a :class:`.Connection` as the first argument.
 
         The given \*args and \**kwargs are passed subsequent
         to the :class:`.Connection` argument.
 
         This function, along with :meth:`.Engine.run_callable`,
         allows a function to be run with a :class:`.Connection`
         or :class:`.Engine` object without the need to know
         which one is being dealt with.
 
         """
         return callable_(self, *args, **kwargs)
 
@@ -1906,11 +1965,11 @@ class Connection(Connectable):
 class Transaction(object):
     """Represent a database transaction in progress.
 
     The :class:`.Transaction` object is procured by
     calling the :meth:`~.Connection.begin` method of
     :class:`.Connection`::
 
         from sqlalchemy import create_engine
         engine = create_engine("postgresql://scott:tiger@localhost/test")
         connection = engine.connect()
 
@@ -1923,7 +1982,7 @@ class Transaction(object):
     also implements a context manager interface so that
     the Python ``with`` statement can be used with the
     :meth:`.Connection.begin` method::
 
         with connection.begin():
             connection.execute("insert into x (a, b) values (1, 2)")
 
@@ -1931,7 +1990,7 @@ class Transaction(object):
     See also:  :meth:`.Connection.begin`, :meth:`.Connection.begin_twophase`,
     :meth:`.Connection.begin_nested`.
 
     .. index::
       single: thread safety; Transaction
     """
 
@@ -2012,9 +2071,9 @@ class NestedTransaction(Transaction):
     A new :class:`.NestedTransaction` object may be procured
     using the :meth:`.Connection.begin_nested` method.
 
     The interface is the same as that of :class:`.Transaction`.
 
     """
 
     def __init__(self, connection, parent):
         super(NestedTransaction, self).__init__(connection, parent)
 
@@ -2033,13 +2092,13 @@ class NestedTransaction(Transaction):
 class TwoPhaseTransaction(Transaction):
     """Represent a two-phase transaction.
 
     A new :class:`.TwoPhaseTransaction` object may be procured
     using the :meth:`.Connection.begin_twophase` method.
 
     The interface is the same as that of :class:`.Transaction`
     with the addition of the :meth:`prepare` method.
 
     """
 
     def __init__(self, connection, xid):
         super(TwoPhaseTransaction, self).__init__(connection, None)
 
@@ -2049,9 +2108,9 @@ class TwoPhaseTransaction(Transaction):
     def prepare(self):
         """Prepare this :class:`.TwoPhaseTransaction`.
 
         After a PREPARE, the transaction can be committed.
 
         """
         if not self._parent.is_active:
             raise exc.InvalidRequestError("This transaction is inactive")
 
@@ -2075,11 +2134,11 @@ class Engine(Connectable, log.Identified):
     :func:`~sqlalchemy.create_engine` function.
 
     See also:
 
     :ref:`engines_toplevel`
 
     :ref:`connections_toplevel`
 
     """
 
     _execution_options = util.immutabledict()
 
@@ -2115,13 +2174,13 @@ class Engine(Connectable, log.Identified):
     def update_execution_options(self, **opt):
         """Update the default execution_options dictionary
         of this :class:`.Engine`.
 
         The given keys/values in \**opt are added to the
         default execution options that will be used for
         all connections.  The initial contents of this dictionary
         can be sent via the ``execution_options`` paramter
         to :func:`.create_engine`.
 
         See :meth:`.Connection.execution_options` for more
         details on execution options.
 
@@ -2236,19 +2295,96 @@ class Engine(Connectable, log.Identified):
             if connection is None:
                 conn.close()
 
+    class _trans_ctx(object):
+        def __init__(self, conn, transaction, close_with_result):
+            self.conn = conn
+            self.transaction = transaction
+            self.close_with_result = close_with_result
+
+        def __enter__(self):
+            return self.conn
+
+        def __exit__(self, type, value, traceback):
+            if type is not None:
+                self.transaction.rollback()
+            else:
+                self.transaction.commit()
+            if not self.close_with_result:
+                self.conn.close()
+
+    def begin(self, close_with_result=False):
+        """Return a context manager delivering a :class:`.Connection`
+        with a :class:`.Transaction` established.
+
+        E.g.::
+
+            with engine.begin() as conn:
+                conn.execute("insert into table (x, y, z) values (1, 2, 3)")
+                conn.execute("my_special_procedure(5)")
+
+        Upon successful operation, the :class:`.Transaction`
+        is committed.  If an error is raised, the :class:`.Transaction`
+        is rolled back.
+
+        The ``close_with_result`` flag is normally ``False``, and indicates
+        that the :class:`.Connection` will be closed when the operation
+        is complete.   When set to ``True``, it indicates the :class:`.Connection`
+        is in "single use" mode, where the :class:`.ResultProxy`
+        returned by the first call to :meth:`.Connection.execute` will
+        close the :class:`.Connection` when that :class:`.ResultProxy`
+        has exhausted all result rows.
+
+        New in 0.7.6.
+
+        See also:
+
+        :meth:`.Engine.connect` - procure a :class:`.Connection` from
+        an :class:`.Engine`.
+
+        :meth:`.Connection.begin` - start a :class:`.Transaction`
+        for a particular :class:`.Connection`.
+
+        """
+        conn = self.contextual_connect(close_with_result=close_with_result)
+        trans = conn.begin()
+        return Engine._trans_ctx(conn, trans, close_with_result)
+
     def transaction(self, callable_, *args, **kwargs):
         """Execute the given function within a transaction boundary.
 
-        The function is passed a newly procured
-        :class:`.Connection` as the first argument, followed by
-        the given \*args and \**kwargs. The :class:`.Connection`
-        is then closed (returned to the pool) when the operation
-        is complete.
+        The function is passed a :class:`.Connection` newly procured
+        from :meth:`.Engine.contextual_connect` as the first argument,
+        followed by the given \*args and \**kwargs.
+
+        e.g.::
+
+            def do_something(conn, x, y):
+                conn.execute("some statement", {'x':x, 'y':y})
+
+            engine.transaction(do_something, 5, 10)
+
+        The operations inside the function are all invoked within the
+        context of a single :class:`.Transaction`.
+        Upon success, the transaction is committed.  If an
+        exception is raised, the transaction is rolled back
+        before propagating the exception.
+
+        .. note::
+
+           The :meth:`.transaction` method is superseded by
+           the usage of the Python ``with:`` statement, which can
+           be used with :meth:`.Engine.begin`::
+
+               with engine.begin() as conn:
+                   conn.execute("some statement", {'x':5, 'y':10})
 
-        This method can be used interchangeably with
-        :meth:`.Connection.transaction`. See that method for
-        more details on usage as well as a modern alternative
-        using context managers (i.e. the ``with`` statement).
+        See also:
+
+        :meth:`.Engine.begin` - engine-level transactional
+        context
+
+        :meth:`.Connection.transaction` - connection-level version of
+        :meth:`.Engine.transaction`
 
         """
 
@@ -2261,15 +2397,15 @@ class Engine(Connectable, log.Identified):
     def run_callable(self, callable_, *args, **kwargs):
         """Given a callable object or function, execute it, passing
         a :class:`.Connection` as the first argument.
 
         The given \*args and \**kwargs are passed subsequent
         to the :class:`.Connection` argument.
 
         This function, along with :meth:`.Connection.run_callable`,
         allows a function to be run with a :class:`.Connection`
         or :class:`.Engine` object without the need to know
         which one is being dealt with.
 
         """
         conn = self.contextual_connect()
         try:
 
@@ -2390,19 +2526,19 @@ class Engine(Connectable, log.Identified):
     def raw_connection(self):
         """Return a "raw" DBAPI connection from the connection pool.
 
         The returned object is a proxied version of the DBAPI
         connection object used by the underlying driver in use.
         The object will have all the same behavior as the real DBAPI
         connection, except that its ``close()`` method will result in the
         connection being returned to the pool, rather than being closed
         for real.
 
         This method provides direct DBAPI connection access for
         special situations.  In most situations, the :class:`.Connection`
         object should be used, which is procured using the
         :meth:`.Engine.connect` method.
 
         """
 
         return self.pool.unique_connection()
 
@@ -2487,7 +2623,6 @@ except ImportError:
         def __getattr__(self, name):
             try:
-                # TODO: no test coverage here
                 return self[name]
             except KeyError, e:
                 raise AttributeError(e.args[0])
 
@@ -2575,6 +2710,10 @@ class ResultMetaData(object):
         context = parent.context
         dialect = context.dialect
         typemap = dialect.dbapi_type_map
+        translate_colname = dialect._translate_colname
+
+        # high precedence key values.
+        primary_keymap = {}
 
         for i, rec in enumerate(metadata):
             colname = rec[0]
@@ -2583,6 +2722,9 @@ class ResultMetaData(object):
             if dialect.description_encoding:
                 colname = dialect._description_decoder(colname)
 
+            if translate_colname:
+                colname, untranslated = translate_colname(colname)
+
             if context.result_map:
                 try:
                     name, obj, type_ = context.result_map[colname.lower()]
@@ -2600,15 +2742,17 @@ class ResultMetaData(object):
             # indexes as keys. This is only needed for the Python version of
             # RowProxy (the C version uses a faster path for integer indexes).
-            keymap[i] = rec
+            primary_keymap[i] = rec
 
-            # Column names as keys
-            if keymap.setdefault(name.lower(), rec) is not rec:
-                # We do not raise an exception directly because several
-                # columns colliding by name is not a problem as long as the
-                # user does not try to access them (ie use an index directly,
-                # or the more precise ColumnElement)
-                keymap[name.lower()] = (processor, obj, None)
+            # populate primary keymap, looking for conflicts.
+            if primary_keymap.setdefault(name.lower(), rec) is not rec:
+                # place a record that doesn't have the "index" - this
+                # is interpreted later as an AmbiguousColumnError,
+                # but only when actually accessed.   Columns
+                # colliding by name is not a problem if those names
+                # aren't used; integer and ColumnElement access is always
+                # unambiguous.
+                primary_keymap[name.lower()] = (processor, obj, None)
 
             if dialect.requires_name_normalize:
                 colname = dialect.normalize_name(colname)
@@ -2618,10 +2762,20 @@ class ResultMetaData(object):
                 for o in obj:
                     keymap[o] = rec
 
+            if translate_colname and \
+                    untranslated:
+                keymap[untranslated] = rec
+
+        # overwrite keymap values with those of the
+        # high precedence keymap.
+        keymap.update(primary_keymap)
+
         if parent._echo:
             context.engine.logger.debug(
                 "Col %r", tuple(x[0] for x in metadata))
 
+    @util.pending_deprecation("0.8", "sqlite dialect uses "
+                    "_translate_colname() now")
     def _set_keymap_synonym(self, name, origname):
         """Set a synonym for the given name.
 
@@ -2647,7 +2801,7 @@ class ResultMetaData(object):
         if key._label and key._label.lower() in map:
             result = map[key._label.lower()]
         elif hasattr(key, 'name') and key.name.lower() in map:
             # match is only on name.
             result = map[key.name.lower()]
 
             # search extra hard to make sure this
             # isn't a column/label name overlap.
@@ -2800,7 +2954,7 @@ class ResultProxy(object):
     @property
     def returns_rows(self):
         """True if this :class:`.ResultProxy` returns rows.
 
         I.e. if it is legal to call the methods
         :meth:`~.ResultProxy.fetchone`,
         :meth:`~.ResultProxy.fetchmany`
 
@@ -2814,12 +2968,12 @@ class ResultProxy(object):
         """True if this :class:`.ResultProxy` is the result
         of a executing an expression language compiled
         :func:`.expression.insert` construct.
 
         When True, this implies that the
         :attr:`inserted_primary_key` attribute is accessible,
         assuming the statement did not include
         a user defined "returning" construct.
 
         """
         return self.context.isinsert
 
@@ -2867,7 +3021,7 @@ class ResultProxy(object):
     @util.memoized_property
     def inserted_primary_key(self):
         """Return the primary key for the row just inserted.
 
         The return value is a list of scalar values
         corresponding to the list of primary key columns
         in the target table.
 
@@ -2875,7 +3029,7 @@ class ResultProxy(object):
         This only applies to single row :func:`.insert`
         constructs which did not explicitly specify
         :meth:`.Insert.returning`.
 
         Note that primary key columns which specify a
         server_default clause,
         or otherwise do not qualify as "autoincrement"

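Two of the additions above are directly user-facing. ``Engine.begin()`` gives a transactional context manager, and the ``no_parameters`` execution option keeps literal percent signs intact on pyformat/format DBAPIs. A sketch using an in-memory SQLite engine::

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")

    # commits on success, rolls back if the block raises
    with engine.begin() as conn:
        conn.execute("create table t (x integer)")
        conn.execute("insert into t (x) values (1)")

    # the statement is sent as cursor.execute(statement), with no
    # parameter collection; relevant for DBAPIs that treat "%"
    # as a format character when parameters are present
    conn = engine.connect().execution_options(no_parameters=True)
    print(conn.execute("select 'a % sign'").scalar())
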
12
libs/sqlalchemy/engine/default.py

@ -44,6 +44,7 @@ class DefaultDialect(base.Dialect):

    postfetch_lastrowid = True
    implicit_returning = False
    supports_native_enum = False
    supports_native_boolean = False

@ -95,6 +96,10 @@ class DefaultDialect(base.Dialect):

    # and denormalize_name() must be provided.
    requires_name_normalize = False

+    # a hook for SQLite's translation of
+    # result column names
+    _translate_colname = None

    reflection_options = ()

    def __init__(self, convert_unicode=False, assert_unicode=False,

@ -329,6 +334,9 @@ class DefaultDialect(base.Dialect):

    def do_execute(self, cursor, statement, parameters, context=None):
        cursor.execute(statement, parameters)

+    def do_execute_no_params(self, cursor, statement, context=None):
+        cursor.execute(statement)

    def is_disconnect(self, e, connection, cursor):
        return False

@ -533,6 +541,10 @@ class DefaultExecutionContext(base.ExecutionContext):

        return self

+    @util.memoized_property
+    def no_parameters(self):
+        return self.execution_options.get("no_parameters", False)

    @util.memoized_property
    def is_crud(self):
        return self.isinsert or self.isupdate or self.isdelete
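The ``no_parameters`` execution option is what routes a statement into the new ``do_execute_no_params()`` path: when the option is set and no bind parameters are present, the cursor is invoked without a parameter collection at all. A sketch of the intended use (connection URL illustrative)::

    engine = create_engine('postgresql://scott:tiger@localhost/test')
    conn = engine.connect()

    # with no_parameters=True and an empty parameter list, the
    # statement reaches cursor.execute(statement) directly, so a
    # literal '%' is not mistaken for a DBAPI paramstyle marker
    result = conn.execution_options(no_parameters=True).\
        execute("SELECT name FROM users WHERE name LIKE 'ed%'")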

9
libs/sqlalchemy/engine/reflection.py

@ -317,7 +317,7 @@ class Inspector(object):

            info_cache=self.info_cache, **kw)
        return indexes

-    def reflecttable(self, table, include_columns, exclude_columns=None):
+    def reflecttable(self, table, include_columns, exclude_columns=()):
        """Given a Table object, load its internal constructs based on introspection.

        This is the underlying method used by most dialects to produce

@ -414,9 +414,12 @@ class Inspector(object):

        # Primary keys
        pk_cons = self.get_pk_constraint(table_name, schema, **tblkw)
        if pk_cons:
+            pk_cols = [table.c[pk]
+                for pk in pk_cons['constrained_columns']
+                if pk in table.c and pk not in exclude_columns
+            ] + [pk for pk in table.primary_key if pk.key in exclude_columns]
            primary_key_constraint = sa_schema.PrimaryKeyConstraint(name=pk_cons.get('name'),
-                *[table.c[pk] for pk in pk_cons['constrained_columns']
-                if pk in table.c]
+                *pk_cols
            )
            table.append_constraint(primary_key_constraint)
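``reflecttable()`` is usually reached through ``Table(..., autoload=True)``, but the Inspector call it makes above can also be used directly. A small sketch, with the database file and table name purely illustrative::

    from sqlalchemy import create_engine
    from sqlalchemy.engine import reflection

    engine = create_engine('sqlite:///mydb.db')
    insp = reflection.Inspector.from_engine(engine)

    # the same data reflecttable() consumes when building the
    # PrimaryKeyConstraint above
    pk_cons = insp.get_pk_constraint('users')
    print pk_cons['constrained_columns']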

6
libs/sqlalchemy/engine/strategies.py

@ -108,7 +108,8 @@ class DefaultEngineStrategy(EngineStrategy):

                     'timeout': 'pool_timeout',
                     'recycle': 'pool_recycle',
                     'events':'pool_events',
-                     'use_threadlocal':'pool_threadlocal'}
+                     'use_threadlocal':'pool_threadlocal',
+                     'reset_on_return':'pool_reset_on_return'}

            for k in util.get_cls_kwargs(poolclass):
                tk = translate.get(k, k)
                if tk in kwargs:

@ -226,6 +227,9 @@ class MockEngineStrategy(EngineStrategy):

        def contextual_connect(self, **kwargs):
            return self

+        def execution_options(self, **kw):
+            return self

        def compiler(self, statement, parameters, **kwargs):
            return self._dialect.compiler(
                statement, parameters, engine=self, **kwargs)
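The new translation entry makes the pool's ``reset_on_return`` behavior settable straight from ``create_engine()``. A sketch, assuming the string values accepted by the pool in this version (URL illustrative)::

    from sqlalchemy import create_engine

    # 'rollback' is the usual default; 'commit' or None alter what
    # happens when a connection is returned to the pool
    engine = create_engine('mysql://scott:tiger@localhost/test',
                        pool_reset_on_return='commit')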

35
libs/sqlalchemy/event.py

@ -13,12 +13,12 @@ NO_RETVAL = util.symbol('NO_RETVAL')

def listen(target, identifier, fn, *args, **kw):
    """Register a listener function for the given target.

    e.g.::

        from sqlalchemy import event
        from sqlalchemy.schema import UniqueConstraint

        def unique_constraint_name(const, table):
            const.name = "uq_%s_%s" % (
                table.name,

@ -41,12 +41,12 @@ def listen(target, identifier, fn, *args, **kw):

def listens_for(target, identifier, *args, **kw):
    """Decorate a function as a listener for the given target + identifier.

    e.g.::

        from sqlalchemy import event
        from sqlalchemy.schema import UniqueConstraint

        @event.listens_for(UniqueConstraint, "after_parent_attach")
        def unique_constraint_name(const, table):
            const.name = "uq_%s_%s" % (

@ -205,12 +205,14 @@ class _DispatchDescriptor(object):

    def insert(self, obj, target, propagate):
        assert isinstance(target, type), \
                "Class-level Event targets must be classes."

        stack = [target]
        while stack:
            cls = stack.pop(0)
            stack.extend(cls.__subclasses__())
-            self._clslevel[cls].insert(0, obj)
+            if cls is not target and cls not in self._clslevel:
+                self.update_subclass(cls)
+            else:
+                self._clslevel[cls].insert(0, obj)

    def append(self, obj, target, propagate):
        assert isinstance(target, type), \

@ -220,7 +222,20 @@ class _DispatchDescriptor(object):

        while stack:
            cls = stack.pop(0)
            stack.extend(cls.__subclasses__())
-            self._clslevel[cls].append(obj)
+            if cls is not target and cls not in self._clslevel:
+                self.update_subclass(cls)
+            else:
+                self._clslevel[cls].append(obj)

+    def update_subclass(self, target):
+        clslevel = self._clslevel[target]
+        for cls in target.__mro__[1:]:
+            if cls in self._clslevel:
+                clslevel.extend([
+                    fn for fn
+                    in self._clslevel[cls]
+                    if fn not in clslevel
+                ])

    def remove(self, obj, target):
        stack = [target]

@ -252,6 +267,8 @@ class _ListenerCollection(object):

    _exec_once = False

    def __init__(self, parent, target_cls):
+        if target_cls not in parent._clslevel:
+            parent.update_subclass(target_cls)
        self.parent_listeners = parent._clslevel[target_cls]
        self.name = parent.__name__
        self.listeners = []
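The ``update_subclass()`` bookkeeping above is what makes a listener registered on a base class take effect for subclasses whose dispatch collections are created lazily afterwards. A minimal sketch using the Pool events, which dispatch through this class-level machinery::

    from sqlalchemy import event
    from sqlalchemy.pool import Pool

    def my_on_checkout(dbapi_conn, connection_rec, connection_proxy):
        print "connection checked out"

    # registering on the Pool base class reaches QueuePool, NullPool,
    # etc.; update_subclass() keeps lazily-created subclass dispatch
    # collections populated with these base-class listeners
    event.listen(Pool, 'checkout', my_on_checkout)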

23
libs/sqlalchemy/exc.py

@ -162,7 +162,7 @@ UnmappedColumnError = None

class StatementError(SQLAlchemyError):
    """An error occurred during execution of a SQL statement.

-    :class:`.StatementError` wraps the exception raised
+    :class:`StatementError` wraps the exception raised
    during execution, and features :attr:`.statement`
    and :attr:`.params` attributes which supply context regarding
    the specifics of the statement which had an issue.

@ -172,6 +172,15 @@ class StatementError(SQLAlchemyError):

    """

+    statement = None
+    """The string SQL statement being invoked when this exception occurred."""
+
+    params = None
+    """The parameter list being used when this exception occurred."""
+
+    orig = None
+    """The DBAPI exception object."""

    def __init__(self, message, statement, params, orig):
        SQLAlchemyError.__init__(self, message)
        self.statement = statement

@ -192,21 +201,21 @@ class StatementError(SQLAlchemyError):

class DBAPIError(StatementError):
    """Raised when the execution of a database operation fails.

-    ``DBAPIError`` wraps exceptions raised by the DB-API underlying the
+    Wraps exceptions raised by the DB-API underlying the
    database operation. Driver-specific implementations of the standard
    DB-API exception types are wrapped by matching sub-types of SQLAlchemy's
-    ``DBAPIError`` when possible. DB-API's ``Error`` type maps to
-    ``DBAPIError`` in SQLAlchemy, otherwise the names are identical. Note
+    :class:`DBAPIError` when possible. DB-API's ``Error`` type maps to
+    :class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note
    that there is no guarantee that different DB-API implementations will
    raise the same exception type for any given error condition.

-    :class:`.DBAPIError` features :attr:`.statement`
-    and :attr:`.params` attributes which supply context regarding
+    :class:`DBAPIError` features :attr:`~.StatementError.statement`
+    and :attr:`~.StatementError.params` attributes which supply context regarding
    the specifics of the statement which had an issue, for the
    typical case when the error was raised within the context of
    emitting a SQL statement.

-    The wrapped exception object is available in the :attr:`.orig` attribute.
+    The wrapped exception object is available in the :attr:`~.StatementError.orig` attribute.
    Its type and properties are DB-API implementation specific.

    """

8
libs/sqlalchemy/ext/declarative.py

@ -1213,6 +1213,12 @@ def _as_declarative(cls, classname, dict_):

                del our_stuff[key]
    cols = sorted(cols, key=lambda c:c._creation_order)
    table = None

+    if hasattr(cls, '__table_cls__'):
+        table_cls = util.unbound_method_to_callable(cls.__table_cls__)
+    else:
+        table_cls = Table

    if '__table__' not in dict_:
        if tablename is not None:

@ -1230,7 +1236,7 @@ def _as_declarative(cls, classname, dict_):

            if autoload:
                table_kw['autoload'] = True

-            cls.__table__ = table = Table(tablename, cls.metadata,
+            cls.__table__ = table = table_cls(tablename, cls.metadata,
                                          *(tuple(cols) + tuple(args)),
                                          **table_kw)
        else:
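The new ``__table_cls__`` hook lets a declarative class substitute its own callable where :class:`.Table` would otherwise be invoked. A hedged sketch of one way to use it (the ``Widget`` model is illustrative)::

    from sqlalchemy import Table, Column, Integer
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Widget(Base):
        __tablename__ = 'widget'
        id = Column(Integer, primary_key=True)

        @classmethod
        def __table_cls__(cls, *args, **kw):
            # any callable accepting Table's arguments and returning
            # a Table object may stand in here
            return Table(*args, **kw)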

65
libs/sqlalchemy/ext/hybrid.py

@ -11,30 +11,30 @@ class level and at the instance level.

The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of method
decorator, is around 50 lines of code and has almost no dependencies on the rest
-of SQLAlchemy. It can in theory work with any class-level expression generator.
+of SQLAlchemy. It can, in theory, work with any descriptor-based expression
+system.

-Consider a table ``interval`` as below::
+Consider a mapping ``Interval``, representing integer ``start`` and ``end``
+values. We can define higher level functions on mapped classes that produce
+SQL expressions at the class level, and Python expression evaluation at the
+instance level. Below, each function decorated with :class:`.hybrid_method` or
+:class:`.hybrid_property` may receive ``self`` as an instance of the class, or
+as the class itself::

-    from sqlalchemy import MetaData, Table, Column, Integer
-
-    metadata = MetaData()
-
-    interval_table = Table('interval', metadata,
-        Column('id', Integer, primary_key=True),
-        Column('start', Integer, nullable=False),
-        Column('end', Integer, nullable=False)
-    )
-
-We can define higher level functions on mapped classes that produce SQL
-expressions at the class level, and Python expression evaluation at the
-instance level. Below, each function decorated with :func:`.hybrid_method`
-or :func:`.hybrid_property` may receive ``self`` as an instance of the class,
-or as the class itself::

+    from sqlalchemy import Column, Integer
+    from sqlalchemy.ext.declarative import declarative_base
+    from sqlalchemy.orm import Session, aliased
     from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
-    from sqlalchemy.orm import mapper, Session, aliased

+    Base = declarative_base()

+    class Interval(Base):
+        __tablename__ = 'interval'

+        id = Column(Integer, primary_key=True)
+        start = Column(Integer, nullable=False)
+        end = Column(Integer, nullable=False)

-    class Interval(object):
         def __init__(self, start, end):
             self.start = start
             self.end = end

@ -51,8 +51,6 @@ or as the class itself::

         def intersects(self, other):
             return self.contains(other.start) | self.contains(other.end)

-    mapper(Interval, interval_table)

Above, the ``length`` property returns the difference between the ``end`` and
``start`` attributes. With an instance of ``Interval``, this subtraction occurs
in Python, using normal Python descriptor mechanics::

@ -60,10 +58,11 @@ in Python, using normal Python descriptor mechanics::

    >>> i1 = Interval(5, 10)
    >>> i1.length
    5

-At the class level, the usual descriptor behavior of returning the descriptor
-itself is modified by :class:`.hybrid_property`, to instead evaluate the function
-body given the ``Interval`` class as the argument::
+When dealing with the ``Interval`` class itself, the :class:`.hybrid_property`
+descriptor evaluates the function body given the ``Interval`` class as
+the argument, which when evaluated with SQLAlchemy expression mechanics
+returns a new SQL expression::

    >>> print Interval.length
    interval."end" - interval.start

@ -83,9 +82,10 @@ locate attributes, so can also be used with hybrid attributes::

    FROM interval
    WHERE interval."end" - interval.start = :param_1

-The ``contains()`` and ``intersects()`` methods are decorated with :class:`.hybrid_method`.
-This decorator applies the same idea to methods which accept
-zero or more arguments. The above methods return boolean values, and take advantage
+The ``Interval`` class example also illustrates two methods, ``contains()`` and ``intersects()``,
+decorated with :class:`.hybrid_method`.
+This decorator applies the same idea to methods that :class:`.hybrid_property` applies
+to attributes. The methods return boolean values, and take advantage
of the Python ``|`` and ``&`` bitwise operators to produce equivalent instance-level and
SQL expression-level boolean behavior::

@ -368,7 +368,12 @@ SQL expression versus SQL expression::

    >>> sw1 = aliased(SearchWord)
    >>> sw2 = aliased(SearchWord)
-    >>> print Session().query(sw1.word_insensitive, sw2.word_insensitive).filter(sw1.word_insensitive > sw2.word_insensitive)
+    >>> print Session().query(
+    ...        sw1.word_insensitive,
+    ...        sw2.word_insensitive).\\
+    ...            filter(
+    ...                sw1.word_insensitive > sw2.word_insensitive
+    ...            )
    SELECT lower(searchword_1.word) AS lower_1, lower(searchword_2.word) AS lower_2
    FROM searchword AS searchword_1, searchword AS searchword_2
    WHERE lower(searchword_1.word) > lower(searchword_2.word)

7
libs/sqlalchemy/ext/orderinglist.py

@ -184,12 +184,11 @@ class OrderingList(list):

    This implementation relies on the list starting in the proper order,
    so be **sure** to put an ``order_by`` on your relationship.

-    ordering_attr
+    :param ordering_attr:
      Name of the attribute that stores the object's order in the
      relationship.

-    ordering_func
-      Optional. A function that maps the position in the Python list to a
+    :param ordering_func: Optional. A function that maps the position in the Python list to a
      value to store in the ``ordering_attr``. Values returned are
      usually (but need not be!) integers.

@ -202,7 +201,7 @@ class OrderingList(list):

      like stepped numbering, alphabetical and Fibonacci numbering, see
      the unit tests.

-    reorder_on_append
+    :param reorder_on_append:
      Default False. When appending an object with an existing (non-None)
      ordering value, that value will be left untouched unless
      ``reorder_on_append`` is true. This is an optimization to avoid a
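In practice these parameters are supplied through the ``ordering_list()`` factory, which builds an ``OrderingList`` for a relationship. A typical sketch, assuming a declarative ``Base`` and illustrative models::

    from sqlalchemy import Column, Integer, ForeignKey
    from sqlalchemy.orm import relationship
    from sqlalchemy.ext.orderinglist import ordering_list

    class Slide(Base):
        __tablename__ = 'slide'
        id = Column(Integer, primary_key=True)

        bullets = relationship("Bullet", order_by="Bullet.position",
                        collection_class=ordering_list('position'))

    class Bullet(Base):
        __tablename__ = 'bullet'
        id = Column(Integer, primary_key=True)
        slide_id = Column(Integer, ForeignKey('slide.id'))
        position = Column(Integer)

    # mutating the Python list renumbers each Bullet.position
    slide = Slide()
    slide.bullets.append(Bullet())
    slide.bullets.insert(0, Bullet())   # both positions are updated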

64
libs/sqlalchemy/orm/collections.py

@ -112,12 +112,32 @@ from sqlalchemy.sql import expression

from sqlalchemy import schema, util, exc as sa_exc

__all__ = ['collection', 'collection_adapter',
           'mapped_collection', 'column_mapped_collection',
           'attribute_mapped_collection']

__instrumentation_mutex = util.threading.Lock()

+class _SerializableColumnGetter(object):
+    def __init__(self, colkeys):
+        self.colkeys = colkeys
+        self.composite = len(colkeys) > 1
+
+    def __reduce__(self):
+        return _SerializableColumnGetter, (self.colkeys,)
+
+    def __call__(self, value):
+        state = instance_state(value)
+        m = _state_mapper(state)
+        key = [m._get_state_attr_by_column(
+                        state, state.dict,
+                        m.mapped_table.columns[k])
+                        for k in self.colkeys]
+        if self.composite:
+            return tuple(key)
+        else:
+            return key[0]

def column_mapped_collection(mapping_spec):
    """A dictionary-based collection type with column-based keying.

@ -131,25 +151,27 @@ def column_mapped_collection(mapping_spec):

    after a session flush.

    """
+    global _state_mapper, instance_state
    from sqlalchemy.orm.util import _state_mapper
    from sqlalchemy.orm.attributes import instance_state

-    cols = [expression._only_column_elements(q, "mapping_spec")
-                for q in util.to_list(mapping_spec)]
-    if len(cols) == 1:
-        def keyfunc(value):
-            state = instance_state(value)
-            m = _state_mapper(state)
-            return m._get_state_attr_by_column(state, state.dict, cols[0])
-    else:
-        mapping_spec = tuple(cols)
-        def keyfunc(value):
-            state = instance_state(value)
-            m = _state_mapper(state)
-            return tuple(m._get_state_attr_by_column(state, state.dict, c)
-                            for c in mapping_spec)
+    cols = [c.key for c in [
+        expression._only_column_elements(q, "mapping_spec")
+        for q in util.to_list(mapping_spec)]]
+    keyfunc = _SerializableColumnGetter(cols)
    return lambda: MappedCollection(keyfunc)

+class _SerializableAttrGetter(object):
+    def __init__(self, name):
+        self.name = name
+        self.getter = operator.attrgetter(name)
+
+    def __call__(self, target):
+        return self.getter(target)
+
+    def __reduce__(self):
+        return _SerializableAttrGetter, (self.name, )

def attribute_mapped_collection(attr_name):
    """A dictionary-based collection type with attribute-based keying.

@ -163,7 +185,8 @@ def attribute_mapped_collection(attr_name):

    after a session flush.

    """
-    return lambda: MappedCollection(operator.attrgetter(attr_name))
+    getter = _SerializableAttrGetter(attr_name)
+    return lambda: MappedCollection(getter)

def mapped_collection(keyfunc):

@ -814,6 +837,7 @@ def _instrument_class(cls):

            methods[name] = None, None, after

    # apply ABC auto-decoration to methods that need it
    for method, decorator in decorators.items():
        fn = getattr(cls, method, None)
        if (fn and method not in methods and

@ -1465,3 +1489,13 @@ class MappedCollection(dict):

                incoming_key, value, new_key))
            yield value
    _convert = collection.converter(_convert)

+# ensure instrumentation is associated with
+# these built-in classes; if a user-defined class
+# subclasses these and uses @internally_instrumented,
+# the superclass is otherwise not instrumented.
+# see [ticket:2406].
+_instrument_class(MappedCollection)
+_instrument_class(InstrumentedList)
+_instrument_class(InstrumentedSet)
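The point of ``_SerializableColumnGetter`` and ``_SerializableAttrGetter`` is that the collection key functions are now plain picklable objects rather than closures or bare ``operator.attrgetter`` instances. A sketch of the behavior this enables, assuming ``MappedCollection`` keeps its key function on ``self.keyfunc`` as in this version::

    import pickle
    from sqlalchemy.orm.collections import attribute_mapped_collection

    factory = attribute_mapped_collection('keyword')
    collection = factory()

    # the key function is a _SerializableAttrGetter; unlike a bare
    # operator.attrgetter, it survives pickling via __reduce__
    keyfunc = pickle.loads(pickle.dumps(collection.keyfunc))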

649
libs/sqlalchemy/orm/mapper.py

@ -1452,12 +1452,19 @@ class Mapper(object):

        return result

    def _is_userland_descriptor(self, obj):
-        return not isinstance(obj,
-            (MapperProperty, attributes.QueryableAttribute)) and \
-            hasattr(obj, '__get__') and not \
-            isinstance(obj.__get__(None, obj),
-                        attributes.QueryableAttribute)
+        if isinstance(obj, (MapperProperty,
+                            attributes.QueryableAttribute)):
+            return False
+        elif not hasattr(obj, '__get__'):
+            return False
+        else:
+            obj = util.unbound_method_to_callable(obj)
+            if isinstance(
+                obj.__get__(None, obj),
+                attributes.QueryableAttribute
+            ):
+                return False
+        return True

    def _should_exclude(self, name, assigned_name, local, column):
        """determine whether a particular property should be implicitly

@ -1875,501 +1882,6 @@ class Mapper(object):

        self._memoized_values[key] = value = callable_()
        return value
def _post_update(self, states, uowtransaction, post_update_cols):
"""Issue UPDATE statements on behalf of a relationship() which
specifies post_update.
"""
cached_connections = util.PopulateDict(
lambda conn:conn.execution_options(
compiled_cache=self._compiled_cache
))
# if session has a connection callable,
# organize individual states with the connection
# to use for update
if uowtransaction.session.connection_callable:
connection_callable = \
uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
tups = []
for state in _sort_states(states):
if connection_callable:
conn = connection_callable(self, state.obj())
else:
conn = connection
mapper = _state_mapper(state)
tups.append((state, state.dict, mapper, conn))
table_to_mapper = self._sorted_tables
for table in table_to_mapper:
update = []
for state, state_dict, mapper, connection in tups:
if table not in mapper._pks_by_table:
continue
pks = mapper._pks_by_table[table]
params = {}
hasdata = False
for col in mapper._cols_by_table[table]:
if col in pks:
params[col._label] = \
mapper._get_state_attr_by_column(
state,
state_dict, col)
elif col in post_update_cols:
prop = mapper._columntoproperty[col]
history = attributes.get_state_history(
state, prop.key,
attributes.PASSIVE_NO_INITIALIZE)
if history.added:
value = history.added[0]
params[col.key] = value
hasdata = True
if hasdata:
update.append((state, state_dict, params, mapper,
connection))
if update:
mapper = table_to_mapper[table]
def update_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label,
type_=col.type))
return table.update(clause)
statement = self._memo(('post_update', table), update_stmt)
# execute each UPDATE in the order according to the original
# list of states to guarantee row access order, but
# also group them into common (connection, cols) sets
# to support executemany().
for key, grouper in groupby(
update, lambda rec: (rec[4], rec[2].keys())
):
multiparams = [params for state, state_dict,
params, mapper, conn in grouper]
cached_connections[connection].\
execute(statement, multiparams)
def _save_obj(self, states, uowtransaction, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list
of objects.
This is called within the context of a UOWTransaction during a
flush operation, given a list of states to be flushed. The
base mapper in an inheritance hierarchy handles the inserts/
updates for all descendant mappers.
"""
# if batch=false, call _save_obj separately for each object
if not single and not self.batch:
for state in _sort_states(states):
self._save_obj([state],
uowtransaction,
single=True)
return
# if session has a connection callable,
# organize individual states with the connection
# to use for insert/update
if uowtransaction.session.connection_callable:
connection_callable = \
uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
tups = []
for state in _sort_states(states):
if connection_callable:
conn = connection_callable(self, state.obj())
else:
conn = connection
has_identity = bool(state.key)
mapper = _state_mapper(state)
instance_key = state.key or mapper._identity_key_from_state(state)
row_switch = None
# call before_XXX extensions
if not has_identity:
mapper.dispatch.before_insert(mapper, conn, state)
else:
mapper.dispatch.before_update(mapper, conn, state)
# detect if we have a "pending" instance (i.e. has
# no instance_key attached to it), and another instance
# with the same identity key already exists as persistent.
# convert to an UPDATE if so.
if not has_identity and \
instance_key in uowtransaction.session.identity_map:
instance = \
uowtransaction.session.identity_map[instance_key]
existing = attributes.instance_state(instance)
if not uowtransaction.is_deleted(existing):
raise orm_exc.FlushError(
"New instance %s with identity key %s conflicts "
"with persistent instance %s" %
(state_str(state), instance_key,
state_str(existing)))
self._log_debug(
"detected row switch for identity %s. "
"will update %s, remove %s from "
"transaction", instance_key,
state_str(state), state_str(existing))
# remove the "delete" flag from the existing element
uowtransaction.remove_state_actions(existing)
row_switch = existing
tups.append(
(state, state.dict, mapper, conn,
has_identity, instance_key, row_switch)
)
# dictionary of connection->connection_with_cache_options.
cached_connections = util.PopulateDict(
lambda conn:conn.execution_options(
compiled_cache=self._compiled_cache
))
table_to_mapper = self._sorted_tables
for table in table_to_mapper:
insert = []
update = []
for state, state_dict, mapper, connection, has_identity, \
instance_key, row_switch in tups:
if table not in mapper._pks_by_table:
continue
pks = mapper._pks_by_table[table]
isinsert = not has_identity and not row_switch
params = {}
value_params = {}
if isinsert:
has_all_pks = True
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col.key] = \
mapper.version_id_generator(None)
else:
# pull straight from the dict for
# pending objects
prop = mapper._columntoproperty[col]
value = state_dict.get(prop.key, None)
if value is None:
if col in pks:
has_all_pks = False
elif col.default is None and \
col.server_default is None:
params[col.key] = value
elif isinstance(value, sql.ClauseElement):
value_params[col] = value
else:
params[col.key] = value
insert.append((state, state_dict, params, mapper,
connection, value_params, has_all_pks))
else:
hasdata = False
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col._label] = \
mapper._get_committed_state_attr_by_column(
row_switch or state,
row_switch and row_switch.dict
or state_dict,
col)
prop = mapper._columntoproperty[col]
history = attributes.get_state_history(
state, prop.key,
attributes.PASSIVE_NO_INITIALIZE
)
if history.added:
params[col.key] = history.added[0]
hasdata = True
else:
params[col.key] = \
mapper.version_id_generator(
params[col._label])
# HACK: check for history, in case the
# history is only
# in a different table than the one
# where the version_id_col is.
for prop in mapper._columntoproperty.\
itervalues():
history = attributes.get_state_history(
state, prop.key,
attributes.PASSIVE_NO_INITIALIZE)
if history.added:
hasdata = True
else:
prop = mapper._columntoproperty[col]
history = attributes.get_state_history(
state, prop.key,
attributes.PASSIVE_NO_INITIALIZE)
if history.added:
if isinstance(history.added[0],
sql.ClauseElement):
value_params[col] = history.added[0]
else:
value = history.added[0]
params[col.key] = value
if col in pks:
if history.deleted and \
not row_switch:
# if passive_updates and sync detected
# this was a pk->pk sync, use the new
# value to locate the row, since the
# DB would already have set this
if ("pk_cascaded", state, col) in \
uowtransaction.\
attributes:
value = history.added[0]
params[col._label] = value
else:
# use the old value to
# locate the row
value = history.deleted[0]
params[col._label] = value
hasdata = True
else:
# row switch logic can reach us here
# remove the pk from the update params
# so the update doesn't
# attempt to include the pk in the
# update statement
del params[col.key]
value = history.added[0]
params[col._label] = value
if value is None and hasdata:
raise sa_exc.FlushError(
"Can't update table "
"using NULL for primary key "
"value")
else:
hasdata = True
elif col in pks:
value = state.manager[prop.key].\
impl.get(state, state_dict)
if value is None:
raise sa_exc.FlushError(
"Can't update table "
"using NULL for primary "
"key value")
params[col._label] = value
if hasdata:
update.append((state, state_dict, params, mapper,
connection, value_params))
if update:
mapper = table_to_mapper[table]
needs_version_id = mapper.version_id_col is not None and \
table.c.contains_column(mapper.version_id_col)
def update_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label,
type_=col.type))
if needs_version_id:
clause.clauses.append(mapper.version_id_col ==\
sql.bindparam(mapper.version_id_col._label,
type_=col.type))
return table.update(clause)
statement = self._memo(('update', table), update_stmt)
rows = 0
for state, state_dict, params, mapper, \
connection, value_params in update:
if value_params:
c = connection.execute(
statement.values(value_params),
params)
else:
c = cached_connections[connection].\
execute(statement, params)
mapper._postfetch(
uowtransaction,
table,
state,
state_dict,
c.context.prefetch_cols,
c.context.postfetch_cols,
c.context.compiled_parameters[0],
value_params)
rows += c.rowcount
if connection.dialect.supports_sane_rowcount:
if rows != len(update):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to update %d row(s); "
"%d were matched." %
(table.description, len(update), rows))
elif needs_version_id:
util.warn("Dialect %s does not support updated rowcount "
"- versioning cannot be verified." %
c.dialect.dialect_description,
stacklevel=12)
if insert:
statement = self._memo(('insert', table), table.insert)
for (connection, pkeys, hasvalue, has_all_pks), \
records in groupby(insert,
lambda rec: (rec[4],
rec[2].keys(),
bool(rec[5]),
rec[6])
):
if has_all_pks and not hasvalue:
records = list(records)
multiparams = [rec[2] for rec in records]
c = cached_connections[connection].\
execute(statement, multiparams)
for (state, state_dict, params, mapper,
conn, value_params, has_all_pks), \
last_inserted_params in \
zip(records, c.context.compiled_parameters):
mapper._postfetch(
uowtransaction,
table,
state,
state_dict,
c.context.prefetch_cols,
c.context.postfetch_cols,
last_inserted_params,
value_params)
else:
for state, state_dict, params, mapper, \
connection, value_params, \
has_all_pks in records:
if value_params:
result = connection.execute(
statement.values(value_params),
params)
else:
result = cached_connections[connection].\
execute(statement, params)
primary_key = result.context.inserted_primary_key
if primary_key is not None:
# set primary key attributes
for pk, col in zip(primary_key,
mapper._pks_by_table[table]):
prop = mapper._columntoproperty[col]
if state_dict.get(prop.key) is None:
# TODO: would rather say:
#state_dict[prop.key] = pk
mapper._set_state_attr_by_column(
state,
state_dict,
col, pk)
mapper._postfetch(
uowtransaction,
table,
state,
state_dict,
result.context.prefetch_cols,
result.context.postfetch_cols,
result.context.compiled_parameters[0],
value_params)
for state, state_dict, mapper, connection, has_identity, \
instance_key, row_switch in tups:
if mapper._readonly_props:
readonly = state.unmodified_intersection(
[p.key for p in mapper._readonly_props
if p.expire_on_flush or p.key not in state.dict]
)
if readonly:
state.expire_attributes(state.dict, readonly)
# if eager_defaults option is enabled,
# refresh whatever has been expired.
if self.eager_defaults and state.unloaded:
state.key = self._identity_key_from_state(state)
uowtransaction.session.query(self)._load_on_ident(
state.key, refresh_state=state,
only_load_props=state.unloaded)
# call after_XXX extensions
if not has_identity:
mapper.dispatch.after_insert(mapper, connection, state)
else:
mapper.dispatch.after_update(mapper, connection, state)
def _postfetch(self, uowtransaction, table,
state, dict_, prefetch_cols, postfetch_cols,
params, value_params):
"""During a flush, expire attributes in need of newly
persisted database state."""
if self.version_id_col is not None:
prefetch_cols = list(prefetch_cols) + [self.version_id_col]
for c in prefetch_cols:
if c.key in params and c in self._columntoproperty:
self._set_state_attr_by_column(state, dict_, c, params[c.key])
if postfetch_cols:
state.expire_attributes(state.dict,
[self._columntoproperty[c].key
for c in postfetch_cols if c in
self._columntoproperty]
)
# synchronize newly inserted ids from one table to the next
# TODO: this still goes a little too often. would be nice to
# have definitive list of "columns that changed" here
for m, equated_pairs in self._table_to_equated[table]:
sync.populate(state, m, state, m,
equated_pairs,
uowtransaction,
self.passive_updates)
    @util.memoized_property
    def _table_to_equated(self):
        """memoized map of tables to collections of columns to be

@ -2387,128 +1899,6 @@ class Mapper(object):

        return result
def _delete_obj(self, states, uowtransaction):
"""Issue ``DELETE`` statements for a list of objects.
This is called within the context of a UOWTransaction during a
flush operation.
"""
if uowtransaction.session.connection_callable:
connection_callable = \
uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
tups = []
cached_connections = util.PopulateDict(
lambda conn:conn.execution_options(
compiled_cache=self._compiled_cache
))
for state in _sort_states(states):
mapper = _state_mapper(state)
if connection_callable:
conn = connection_callable(self, state.obj())
else:
conn = connection
mapper.dispatch.before_delete(mapper, conn, state)
tups.append((state,
state.dict,
_state_mapper(state),
bool(state.key),
conn))
table_to_mapper = self._sorted_tables
for table in reversed(table_to_mapper.keys()):
delete = util.defaultdict(list)
for state, state_dict, mapper, has_identity, connection in tups:
if not has_identity or table not in mapper._pks_by_table:
continue
params = {}
delete[connection].append(params)
for col in mapper._pks_by_table[table]:
params[col.key] = \
value = \
mapper._get_state_attr_by_column(
state, state_dict, col)
if value is None:
raise sa_exc.FlushError(
"Can't delete from table "
"using NULL for primary "
"key value")
if mapper.version_id_col is not None and \
table.c.contains_column(mapper.version_id_col):
params[mapper.version_id_col.key] = \
mapper._get_committed_state_attr_by_column(
state, state_dict,
mapper.version_id_col)
mapper = table_to_mapper[table]
need_version_id = mapper.version_id_col is not None and \
table.c.contains_column(mapper.version_id_col)
def delete_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(
col == sql.bindparam(col.key, type_=col.type))
if need_version_id:
clause.clauses.append(
mapper.version_id_col ==
sql.bindparam(
mapper.version_id_col.key,
type_=mapper.version_id_col.type
)
)
return table.delete(clause)
for connection, del_objects in delete.iteritems():
statement = self._memo(('delete', table), delete_stmt)
rows = -1
connection = cached_connections[connection]
if need_version_id and \
not connection.dialect.supports_sane_multi_rowcount:
# TODO: need test coverage for this [ticket:1761]
if connection.dialect.supports_sane_rowcount:
rows = 0
# execute deletes individually so that versioned
# rows can be verified
for params in del_objects:
c = connection.execute(statement, params)
rows += c.rowcount
else:
util.warn(
"Dialect %s does not support deleted rowcount "
"- versioning cannot be verified." %
connection.dialect.dialect_description,
stacklevel=12)
connection.execute(statement, del_objects)
else:
c = connection.execute(statement, del_objects)
if connection.dialect.supports_sane_multi_rowcount:
rows = c.rowcount
if rows != -1 and rows != len(del_objects):
raise orm_exc.StaleDataError(
"DELETE statement on table '%s' expected to delete %d row(s); "
"%d were matched." %
(table.description, len(del_objects), c.rowcount)
)
for state, state_dict, mapper, has_identity, connection in tups:
mapper.dispatch.after_delete(mapper, connection, state)
    def _instance_processor(self, context, path, reduced_path, adapter,
                                polymorphic_from=None,

@ -2518,6 +1908,12 @@ class Mapper(object):

        """Produce a mapper level row processor callable
        which processes rows into mapped instances."""

+        # note that this method, most of which exists in a closure
+        # called _instance(), resists being broken out, as
+        # attempts to do so tend to add significant function
+        # call overhead. _instance() is the most
+        # performance-critical section in the whole ORM.

        pk_cols = self.primary_key

        if polymorphic_from or refresh_state:
@ -2961,13 +2357,6 @@ def _event_on_resurrect(state):

                state, state.dict, col, val)
def _sort_states(states):
pending = set(states)
persistent = set(s for s in pending if s.key is not None)
pending.difference_update(persistent)
return sorted(pending, key=operator.attrgetter("insert_order")) + \
sorted(persistent, key=lambda q:q.key[1])
class _ColumnMapping(util.py25_dict):
    """Error reporting helper for mapper._columntoproperty."""

777
libs/sqlalchemy/orm/persistence.py

@ -0,0 +1,777 @@
# orm/persistence.py
# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""private module containing functions used to emit INSERT, UPDATE
and DELETE statements on behalf of a :class:`.Mapper` and its descending
mappers.
The functions here are called only by the unit of work functions
in unitofwork.py.
"""
import operator
from itertools import groupby
from sqlalchemy import sql, util, exc as sa_exc
from sqlalchemy.orm import attributes, sync, \
exc as orm_exc
from sqlalchemy.orm.util import _state_mapper, state_str
def save_obj(base_mapper, states, uowtransaction, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list
of objects.
This is called within the context of a UOWTransaction during a
flush operation, given a list of states to be flushed. The
base mapper in an inheritance hierarchy handles the inserts/
updates for all descendant mappers.
"""
# if batch=false, call _save_obj separately for each object
if not single and not base_mapper.batch:
for state in _sort_states(states):
save_obj(base_mapper, [state], uowtransaction, single=True)
return
states_to_insert, states_to_update = _organize_states_for_save(
base_mapper,
states,
uowtransaction)
cached_connections = _cached_connection_dict(base_mapper)
for table, mapper in base_mapper._sorted_tables.iteritems():
insert = _collect_insert_commands(base_mapper, uowtransaction,
table, states_to_insert)
update = _collect_update_commands(base_mapper, uowtransaction,
table, states_to_update)
if update:
_emit_update_statements(base_mapper, uowtransaction,
cached_connections,
mapper, table, update)
if insert:
_emit_insert_statements(base_mapper, uowtransaction,
cached_connections,
table, insert)
_finalize_insert_update_commands(base_mapper, uowtransaction,
states_to_insert, states_to_update)
def post_update(base_mapper, states, uowtransaction, post_update_cols):
"""Issue UPDATE statements on behalf of a relationship() which
specifies post_update.
"""
cached_connections = _cached_connection_dict(base_mapper)
states_to_update = _organize_states_for_post_update(
base_mapper,
states, uowtransaction)
for table, mapper in base_mapper._sorted_tables.iteritems():
update = _collect_post_update_commands(base_mapper, uowtransaction,
table, states_to_update,
post_update_cols)
if update:
_emit_post_update_statements(base_mapper, uowtransaction,
cached_connections,
mapper, table, update)
def delete_obj(base_mapper, states, uowtransaction):
"""Issue ``DELETE`` statements for a list of objects.
This is called within the context of a UOWTransaction during a
flush operation.
"""
cached_connections = _cached_connection_dict(base_mapper)
states_to_delete = _organize_states_for_delete(
base_mapper,
states,
uowtransaction)
table_to_mapper = base_mapper._sorted_tables
for table in reversed(table_to_mapper.keys()):
delete = _collect_delete_commands(base_mapper, uowtransaction,
table, states_to_delete)
mapper = table_to_mapper[table]
_emit_delete_statements(base_mapper, uowtransaction,
cached_connections, mapper, table, delete)
for state, state_dict, mapper, has_identity, connection \
in states_to_delete:
mapper.dispatch.after_delete(mapper, connection, state)
def _organize_states_for_save(base_mapper, states, uowtransaction):
"""Make an initial pass across a set of states for INSERT or
UPDATE.
This includes splitting out into distinct lists for
each, calling before_insert/before_update, obtaining
key information for each state including its dictionary,
mapper, the connection to use for the execution per state,
and the identity flag.
"""
states_to_insert = []
states_to_update = []
for state, dict_, mapper, connection in _connections_for_states(
base_mapper, uowtransaction,
states):
has_identity = bool(state.key)
instance_key = state.key or mapper._identity_key_from_state(state)
row_switch = None
# call before_XXX extensions
if not has_identity:
mapper.dispatch.before_insert(mapper, connection, state)
else:
mapper.dispatch.before_update(mapper, connection, state)
# detect if we have a "pending" instance (i.e. has
# no instance_key attached to it), and another instance
# with the same identity key already exists as persistent.
# convert to an UPDATE if so.
if not has_identity and \
instance_key in uowtransaction.session.identity_map:
instance = \
uowtransaction.session.identity_map[instance_key]
existing = attributes.instance_state(instance)
if not uowtransaction.is_deleted(existing):
raise orm_exc.FlushError(
"New instance %s with identity key %s conflicts "
"with persistent instance %s" %
(state_str(state), instance_key,
state_str(existing)))
base_mapper._log_debug(
"detected row switch for identity %s. "
"will update %s, remove %s from "
"transaction", instance_key,
state_str(state), state_str(existing))
# remove the "delete" flag from the existing element
uowtransaction.remove_state_actions(existing)
row_switch = existing
if not has_identity and not row_switch:
states_to_insert.append(
(state, dict_, mapper, connection,
has_identity, instance_key, row_switch)
)
else:
states_to_update.append(
(state, dict_, mapper, connection,
has_identity, instance_key, row_switch)
)
return states_to_insert, states_to_update
def _organize_states_for_post_update(base_mapper, states,
uowtransaction):
"""Make an initial pass across a set of states for UPDATE
corresponding to post_update.
This includes obtaining key information for each state
including its dictionary, mapper, the connection to use for
the execution per state.
"""
return list(_connections_for_states(base_mapper, uowtransaction,
states))
def _organize_states_for_delete(base_mapper, states, uowtransaction):
"""Make an initial pass across a set of states for DELETE.
This includes calling out before_delete and obtaining
key information for each state including its dictionary,
mapper, the connection to use for the execution per state.
"""
states_to_delete = []
for state, dict_, mapper, connection in _connections_for_states(
base_mapper, uowtransaction,
states):
mapper.dispatch.before_delete(mapper, connection, state)
states_to_delete.append((state, dict_, mapper,
bool(state.key), connection))
return states_to_delete
def _collect_insert_commands(base_mapper, uowtransaction, table,
states_to_insert):
"""Identify sets of values to use in INSERT statements for a
list of states.
"""
insert = []
for state, state_dict, mapper, connection, has_identity, \
instance_key, row_switch in states_to_insert:
if table not in mapper._pks_by_table:
continue
pks = mapper._pks_by_table[table]
params = {}
value_params = {}
has_all_pks = True
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col.key] = mapper.version_id_generator(None)
else:
# pull straight from the dict for
# pending objects
prop = mapper._columntoproperty[col]
value = state_dict.get(prop.key, None)
if value is None:
if col in pks:
has_all_pks = False
elif col.default is None and \
col.server_default is None:
params[col.key] = value
elif isinstance(value, sql.ClauseElement):
value_params[col] = value
else:
params[col.key] = value
insert.append((state, state_dict, params, mapper,
connection, value_params, has_all_pks))
return insert
def _collect_update_commands(base_mapper, uowtransaction,
table, states_to_update):
"""Identify sets of values to use in UPDATE statements for a
list of states.
This function works intricately with the history system
to determine exactly what values should be updated
as well as how the row should be matched within an UPDATE
statement. Includes some tricky scenarios where the primary
key of an object might have been changed.
"""
update = []
for state, state_dict, mapper, connection, has_identity, \
instance_key, row_switch in states_to_update:
if table not in mapper._pks_by_table:
continue
pks = mapper._pks_by_table[table]
params = {}
value_params = {}
hasdata = hasnull = False
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col._label] = \
mapper._get_committed_state_attr_by_column(
row_switch or state,
row_switch and row_switch.dict
or state_dict,
col)
prop = mapper._columntoproperty[col]
history = attributes.get_state_history(
state, prop.key,
attributes.PASSIVE_NO_INITIALIZE
)
if history.added:
params[col.key] = history.added[0]
hasdata = True
else:
params[col.key] = mapper.version_id_generator(
params[col._label])
# HACK: check for history, in case the
# history is only
# in a different table than the one
# where the version_id_col is.
for prop in mapper._columntoproperty.itervalues():
history = attributes.get_state_history(
state, prop.key,
attributes.PASSIVE_NO_INITIALIZE)
if history.added:
hasdata = True
else:
prop = mapper._columntoproperty[col]
history = attributes.get_state_history(
state, prop.key,
attributes.PASSIVE_NO_INITIALIZE)
if history.added:
if isinstance(history.added[0],
sql.ClauseElement):
value_params[col] = history.added[0]
else:
value = history.added[0]
params[col.key] = value
if col in pks:
if history.deleted and \
not row_switch:
# if passive_updates and sync detected
# this was a pk->pk sync, use the new
# value to locate the row, since the
# DB would already have set this
if ("pk_cascaded", state, col) in \
uowtransaction.attributes:
value = history.added[0]
params[col._label] = value
else:
# use the old value to
# locate the row
value = history.deleted[0]
params[col._label] = value
hasdata = True
else:
# row switch logic can reach us here
# remove the pk from the update params
# so the update doesn't
# attempt to include the pk in the
# update statement
del params[col.key]
value = history.added[0]
params[col._label] = value
if value is None:
hasnull = True
else:
hasdata = True
elif col in pks:
value = state.manager[prop.key].impl.get(
state, state_dict)
if value is None:
hasnull = True
params[col._label] = value
if hasdata:
if hasnull:
raise sa_exc.FlushError(
"Can't update table "
"using NULL for primary "
"key value")
update.append((state, state_dict, params, mapper,
connection, value_params))
return update
def _collect_post_update_commands(base_mapper, uowtransaction, table,
states_to_update, post_update_cols):
"""Identify sets of values to use in UPDATE statements for a
list of states within a post_update operation.
"""
update = []
for state, state_dict, mapper, connection in states_to_update:
if table not in mapper._pks_by_table:
continue
pks = mapper._pks_by_table[table]
params = {}
hasdata = False
for col in mapper._cols_by_table[table]:
if col in pks:
params[col._label] = \
mapper._get_state_attr_by_column(
state,
state_dict, col)
elif col in post_update_cols:
prop = mapper._columntoproperty[col]
history = attributes.get_state_history(
state, prop.key,
attributes.PASSIVE_NO_INITIALIZE)
if history.added:
value = history.added[0]
params[col.key] = value
hasdata = True
if hasdata:
update.append((state, state_dict, params, mapper,
connection))
return update
def _collect_delete_commands(base_mapper, uowtransaction, table,
states_to_delete):
"""Identify values to use in DELETE statements for a list of
states to be deleted."""
delete = util.defaultdict(list)
for state, state_dict, mapper, has_identity, connection \
in states_to_delete:
if not has_identity or table not in mapper._pks_by_table:
continue
params = {}
delete[connection].append(params)
for col in mapper._pks_by_table[table]:
params[col.key] = \
value = \
mapper._get_state_attr_by_column(
state, state_dict, col)
if value is None:
raise sa_exc.FlushError(
"Can't delete from table "
"using NULL for primary "
"key value")
if mapper.version_id_col is not None and \
table.c.contains_column(mapper.version_id_col):
params[mapper.version_id_col.key] = \
mapper._get_committed_state_attr_by_column(
state, state_dict,
mapper.version_id_col)
return delete
def _emit_update_statements(base_mapper, uowtransaction,
cached_connections, mapper, table, update):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
needs_version_id = mapper.version_id_col is not None and \
table.c.contains_column(mapper.version_id_col)
def update_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label,
type_=col.type))
if needs_version_id:
clause.clauses.append(mapper.version_id_col ==\
sql.bindparam(mapper.version_id_col._label,
type_=col.type))
return table.update(clause)
statement = base_mapper._memo(('update', table), update_stmt)
rows = 0
for state, state_dict, params, mapper, \
connection, value_params in update:
if value_params:
c = connection.execute(
statement.values(value_params),
params)
else:
c = cached_connections[connection].\
execute(statement, params)
_postfetch(
mapper,
uowtransaction,
table,
state,
state_dict,
c.context.prefetch_cols,
c.context.postfetch_cols,
c.context.compiled_parameters[0],
value_params)
rows += c.rowcount
if connection.dialect.supports_sane_rowcount:
if rows != len(update):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to "
"update %d row(s); %d were matched." %
(table.description, len(update), rows))
elif needs_version_id:
util.warn("Dialect %s does not support updated rowcount "
"- versioning cannot be verified." %
c.dialect.dialect_description,
stacklevel=12)
def _emit_insert_statements(base_mapper, uowtransaction,
cached_connections, table, insert):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
statement = base_mapper._memo(('insert', table), table.insert)
for (connection, pkeys, hasvalue, has_all_pks), \
records in groupby(insert,
lambda rec: (rec[4],
rec[2].keys(),
bool(rec[5]),
rec[6])
):
if has_all_pks and not hasvalue:
records = list(records)
multiparams = [rec[2] for rec in records]
c = cached_connections[connection].\
execute(statement, multiparams)
for (state, state_dict, params, mapper,
conn, value_params, has_all_pks), \
last_inserted_params in \
zip(records, c.context.compiled_parameters):
_postfetch(
mapper,
uowtransaction,
table,
state,
state_dict,
c.context.prefetch_cols,
c.context.postfetch_cols,
last_inserted_params,
value_params)
else:
for state, state_dict, params, mapper, \
connection, value_params, \
has_all_pks in records:
if value_params:
result = connection.execute(
statement.values(value_params),
params)
else:
result = cached_connections[connection].\
execute(statement, params)
primary_key = result.context.inserted_primary_key
if primary_key is not None:
# set primary key attributes
for pk, col in zip(primary_key,
mapper._pks_by_table[table]):
prop = mapper._columntoproperty[col]
if state_dict.get(prop.key) is None:
# TODO: would rather say:
#state_dict[prop.key] = pk
mapper._set_state_attr_by_column(
state,
state_dict,
col, pk)
_postfetch(
mapper,
uowtransaction,
table,
state,
state_dict,
result.context.prefetch_cols,
result.context.postfetch_cols,
result.context.compiled_parameters[0],
value_params)
def _emit_post_update_statements(base_mapper, uowtransaction,
cached_connections, mapper, table, update):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_post_update_commands()."""
def update_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label,
type_=col.type))
return table.update(clause)
statement = base_mapper._memo(('post_update', table), update_stmt)
# execute each UPDATE in the order according to the original
# list of states to guarantee row access order, but
# also group them into common (connection, cols) sets
# to support executemany().
for key, grouper in groupby(
update, lambda rec: (rec[4], rec[2].keys())
):
connection = key[0]
multiparams = [params for state, state_dict,
params, mapper, conn in grouper]
cached_connections[connection].\
execute(statement, multiparams)
def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
mapper, table, delete):
"""Emit DELETE statements corresponding to value lists collected
by _collect_delete_commands()."""
need_version_id = mapper.version_id_col is not None and \
table.c.contains_column(mapper.version_id_col)
def delete_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(
col == sql.bindparam(col.key, type_=col.type))
if need_version_id:
clause.clauses.append(
mapper.version_id_col ==
sql.bindparam(
mapper.version_id_col.key,
type_=mapper.version_id_col.type
)
)
return table.delete(clause)
for connection, del_objects in delete.iteritems():
statement = base_mapper._memo(('delete', table), delete_stmt)
connection = cached_connections[connection]
if need_version_id:
# TODO: need test coverage for this [ticket:1761]
if connection.dialect.supports_sane_rowcount:
rows = 0
# execute deletes individually so that versioned
# rows can be verified
for params in del_objects:
c = connection.execute(statement, params)
rows += c.rowcount
if rows != len(del_objects):
raise orm_exc.StaleDataError(
"DELETE statement on table '%s' expected to "
"delete %d row(s); %d were matched." %
(table.description, len(del_objects), rows)
)
else:
util.warn(
"Dialect %s does not support deleted rowcount "
"- versioning cannot be verified." %
connection.dialect.dialect_description,
stacklevel=12)
connection.execute(statement, del_objects)
else:
connection.execute(statement, del_objects)
def _finalize_insert_update_commands(base_mapper, uowtransaction,
states_to_insert, states_to_update):
"""finalize state on states that have been inserted or updated,
including calling after_insert/after_update events.
"""
for state, state_dict, mapper, connection, has_identity, \
instance_key, row_switch in states_to_insert + \
states_to_update:
if mapper._readonly_props:
readonly = state.unmodified_intersection(
[p.key for p in mapper._readonly_props
if p.expire_on_flush or p.key not in state.dict]
)
if readonly:
state.expire_attributes(state.dict, readonly)
# if eager_defaults option is enabled,
# refresh whatever has been expired.
if base_mapper.eager_defaults and state.unloaded:
state.key = base_mapper._identity_key_from_state(state)
uowtransaction.session.query(base_mapper)._load_on_ident(
state.key, refresh_state=state,
only_load_props=state.unloaded)
# call after_XXX extensions
if not has_identity:
mapper.dispatch.after_insert(mapper, connection, state)
else:
mapper.dispatch.after_update(mapper, connection, state)
def _postfetch(mapper, uowtransaction, table,
state, dict_, prefetch_cols, postfetch_cols,
params, value_params):
"""Expire attributes in need of newly persisted database state,
after an INSERT or UPDATE statement has proceeded for that
state."""
if mapper.version_id_col is not None:
prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
for c in prefetch_cols:
if c.key in params and c in mapper._columntoproperty:
mapper._set_state_attr_by_column(state, dict_, c, params[c.key])
if postfetch_cols:
state.expire_attributes(state.dict,
[mapper._columntoproperty[c].key
for c in postfetch_cols if c in
mapper._columntoproperty]
)
# synchronize newly inserted ids from one table to the next
# TODO: this still goes a little too often. would be nice to
# have definitive list of "columns that changed" here
for m, equated_pairs in mapper._table_to_equated[table]:
sync.populate(state, m, state, m,
equated_pairs,
uowtransaction,
mapper.passive_updates)
def _connections_for_states(base_mapper, uowtransaction, states):
"""Return an iterator of (state, state.dict, mapper, connection).
The states are sorted according to _sort_states, then paired
with the connection they should be using for the given
unit of work transaction.
"""
# if session has a connection callable,
# organize individual states with the connection
# to use for update
if uowtransaction.session.connection_callable:
connection_callable = \
uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(
base_mapper)
connection_callable = None
for state in _sort_states(states):
if connection_callable:
connection = connection_callable(base_mapper, state.obj())
mapper = _state_mapper(state)
yield state, state.dict, mapper, connection
def _cached_connection_dict(base_mapper):
# dictionary of connection->connection_with_cache_options.
return util.PopulateDict(
lambda conn:conn.execution_options(
compiled_cache=base_mapper._compiled_cache
))
def _sort_states(states):
pending = set(states)
persistent = set(s for s in pending if s.key is not None)
pending.difference_update(persistent)
return sorted(pending, key=operator.attrgetter("insert_order")) + \
sorted(persistent, key=lambda q:q.key[1])
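The insert path above batches rows into a single executemany() whenever consecutive records share a connection, the same parameter keys, no literal value_params, and all primary keys present. A standalone sketch of that grouping idea follows; the record layout here is hypothetical and only illustrates why sorting before grouping enables executemany():

# Standalone sketch, not SQLAlchemy internals: records sharing a connection
# and parameter shape fold into one executemany() call.
from itertools import groupby

records = [
    ("conn1", {"id": 1, "name": "a"}),
    ("conn1", {"id": 2, "name": "b"}),   # same shape -> same batch
    ("conn2", {"id": 3, "name": "c"}),   # different connection -> new batch
]

# groupby() only merges adjacent records, which is why states are sorted first.
for (conn, keys), group in groupby(
        records, lambda rec: (rec[0], tuple(sorted(rec[1])))):
    multiparams = [rec[1] for rec in group]
    print("%s: executemany with %d parameter sets" % (conn, len(multiparams)))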

158
libs/sqlalchemy/orm/query.py

@ -133,7 +133,7 @@ class Query(object):
         with_polymorphic = mapper._with_polymorphic_mappers
         if mapper.mapped_table not in \
             self._polymorphic_adapters:
-            self.__mapper_loads_polymorphically_with(mapper,
+            self._mapper_loads_polymorphically_with(mapper,
                 sql_util.ColumnAdapter(
                     selectable,
                     mapper._equivalent_columns))
@ -150,7 +150,7 @@ class Query(object):
                 is_aliased_class, with_polymorphic)
             ent.setup_entity(entity, *d[entity])

-    def __mapper_loads_polymorphically_with(self, mapper, adapter):
+    def _mapper_loads_polymorphically_with(self, mapper, adapter):
         for m2 in mapper._with_polymorphic_mappers:
             self._polymorphic_adapters[m2] = adapter
             for m in m2.iterate_to_root():
@ -174,10 +174,6 @@ class Query(object):
             self._from_obj_alias = sql_util.ColumnAdapter(
                 self._from_obj[0], equivs)

-    def _get_polymorphic_adapter(self, entity, selectable):
-        self.__mapper_loads_polymorphically_with(entity.mapper,
-            sql_util.ColumnAdapter(selectable,
-                entity.mapper._equivalent_columns))
-
     def _reset_polymorphic_adapter(self, mapper):
         for m2 in mapper._with_polymorphic_mappers:
@ -276,6 +272,7 @@ class Query(object):
return self._select_from_entity or \ return self._select_from_entity or \
self._entity_zero().entity_zero self._entity_zero().entity_zero
@property @property
def _mapper_entities(self): def _mapper_entities(self):
# TODO: this is wrong, its hardcoded to "primary entity" when # TODO: this is wrong, its hardcoded to "primary entity" when
@ -324,13 +321,6 @@ class Query(object):
             )
         return self._entity_zero()

-    def _generate_mapper_zero(self):
-        if not getattr(self._entities[0], 'primary_entity', False):
-            raise sa_exc.InvalidRequestError(
-                "No primary mapper set up for this Query.")
-        entity = self._entities[0]._clone()
-        self._entities = [entity] + self._entities[1:]
-        return entity
-
     def __all_equivs(self):
         equivs = {}
@ -460,6 +450,62 @@ class Query(object):
""" """
return self.enable_eagerloads(False).statement.alias(name=name) return self.enable_eagerloads(False).statement.alias(name=name)
def cte(self, name=None, recursive=False):
"""Return the full SELECT statement represented by this :class:`.Query`
represented as a common table expression (CTE).
The :meth:`.Query.cte` method is new in 0.7.6.
Parameters and usage are the same as those of the
:meth:`._SelectBase.cte` method; see that method for
further details.
Here is the `Postgresql WITH
RECURSIVE example <http://www.postgresql.org/docs/8.4/static/queries-with.html>`_.
Note that, in this example, the ``included_parts`` cte and the ``incl_alias`` alias
of it are Core selectables, which
means the columns are accessed via the ``.c.`` attribute. The ``parts_alias``
object is an :func:`.orm.aliased` instance of the ``Part`` entity, so column-mapped
attributes are available directly::
from sqlalchemy.orm import aliased
class Part(Base):
__tablename__ = 'part'
part = Column(String, primary_key=True)
sub_part = Column(String, primary_key=True)
quantity = Column(Integer)
included_parts = session.query(
Part.sub_part,
Part.part,
Part.quantity).\\
filter(Part.part=="our part").\\
cte(name="included_parts", recursive=True)
incl_alias = aliased(included_parts, name="pr")
parts_alias = aliased(Part, name="p")
included_parts = included_parts.union_all(
session.query(
parts_alias.part,
parts_alias.sub_part,
parts_alias.quantity).\\
filter(parts_alias.part==incl_alias.c.sub_part)
)
q = session.query(
included_parts.c.sub_part,
func.sum(included_parts.c.quantity).label('total_quantity')
).\\
group_by(included_parts.c.sub_part)
See also:
:meth:`._SelectBase.cte`
"""
return self.enable_eagerloads(False).statement.cte(name=name, recursive=recursive)
def label(self, name): def label(self, name):
"""Return the full SELECT statement represented by this :class:`.Query`, converted """Return the full SELECT statement represented by this :class:`.Query`, converted
to a scalar subquery with a label of the given name. to a scalar subquery with a label of the given name.
@ -601,7 +647,12 @@ class Query(object):
         such as concrete table mappers.

         """
-        entity = self._generate_mapper_zero()
+        if not getattr(self._entities[0], 'primary_entity', False):
+            raise sa_exc.InvalidRequestError(
+                "No primary mapper set up for this Query.")
+        entity = self._entities[0]._clone()
+        self._entities = [entity] + self._entities[1:]
+
         entity.set_with_polymorphic(self,
             cls_or_mappers,
             selectable=selectable,
@ -1041,7 +1092,22 @@ class Query(object):
     @_generative()
     def with_lockmode(self, mode):
-        """Return a new Query object with the specified locking mode."""
+        """Return a new Query object with the specified locking mode.
:param mode: a string representing the desired locking mode. A
corresponding value is passed to the ``for_update`` parameter of
:meth:`~sqlalchemy.sql.expression.select` when the query is
executed. Valid values are:
``'update'`` - passes ``for_update=True``, which translates to
``FOR UPDATE`` (standard SQL, supported by most dialects)
``'update_nowait'`` - passes ``for_update='nowait'``, which
translates to ``FOR UPDATE NOWAIT`` (supported by Oracle)
``'read'`` - passes ``for_update='read'``, which translates to
``LOCK IN SHARE MODE`` (supported by MySQL).
"""
        self._lockmode = mode
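A minimal usage sketch for the expanded documentation above; User is a hypothetical mapped class assumed for illustration:

# Hypothetical mapped class 'User'; shows how each mode maps to SQL.
users = session.query(User).with_lockmode('update').all()
# emits: SELECT ... FROM users FOR UPDATE
rows = session.query(User).with_lockmode('read').all()
# MySQL: SELECT ... LOCK IN SHARE MODE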
@ -1583,7 +1649,6 @@ class Query(object):
consistent format with which to form the actual JOIN constructs. consistent format with which to form the actual JOIN constructs.
""" """
self._polymorphic_adapters = self._polymorphic_adapters.copy()
if not from_joinpoint: if not from_joinpoint:
self._reset_joinpoint() self._reset_joinpoint()
@ -1683,6 +1748,8 @@ class Query(object):
onclause, outerjoin, create_aliases, prop): onclause, outerjoin, create_aliases, prop):
"""append a JOIN to the query's from clause.""" """append a JOIN to the query's from clause."""
self._polymorphic_adapters = self._polymorphic_adapters.copy()
if left is None: if left is None:
if self._from_obj: if self._from_obj:
left = self._from_obj[0] left = self._from_obj[0]
@ -1696,7 +1763,29 @@ class Query(object):
"are the same entity" % "are the same entity" %
(left, right)) (left, right))
-        left_mapper, left_selectable, left_is_aliased = _entity_info(left)
+        right, right_is_aliased, onclause = self._prepare_right_side(
right, onclause,
outerjoin, create_aliases,
prop)
# if joining on a MapperProperty path,
# track the path to prevent redundant joins
if not create_aliases and prop:
self._update_joinpoint({
'_joinpoint_entity':right,
'prev':((left, right, prop.key), self._joinpoint)
})
else:
self._joinpoint = {
'_joinpoint_entity':right
}
self._join_to_left(left, right,
right_is_aliased,
onclause, outerjoin)
def _prepare_right_side(self, right, onclause, outerjoin,
create_aliases, prop):
        right_mapper, right_selectable, right_is_aliased = _entity_info(right)

        if right_mapper:
@ -1741,24 +1830,13 @@ class Query(object):
             right = aliased(right)
             need_adapter = True

-        # if joining on a MapperProperty path,
-        # track the path to prevent redundant joins
-        if not create_aliases and prop:
-            self._update_joinpoint({
-                '_joinpoint_entity':right,
-                'prev':((left, right, prop.key), self._joinpoint)
-            })
-        else:
-            self._joinpoint = {
-                '_joinpoint_entity':right
-            }
-
         # if an alias() of the right side was generated here,
         # apply an adapter to all subsequent filter() calls
         # until reset_joinpoint() is called.
         if need_adapter:
             self._filter_aliases = ORMAdapter(right,
-                equivalents=right_mapper and right_mapper._equivalent_columns or {},
+                equivalents=right_mapper and
+                    right_mapper._equivalent_columns or {},
                 chain_to=self._filter_aliases)
# if the onclause is a ClauseElement, adapt it with any # if the onclause is a ClauseElement, adapt it with any
@ -1771,7 +1849,7 @@ class Query(object):
         # ensure that columns retrieved from this target in the result
         # set are also adapted.
         if aliased_entity and not create_aliases:
-            self.__mapper_loads_polymorphically_with(
+            self._mapper_loads_polymorphically_with(
                 right_mapper,
                 ORMAdapter(
                     right,
@ -1779,6 +1857,11 @@ class Query(object):
) )
) )
return right, right_is_aliased, onclause
def _join_to_left(self, left, right, right_is_aliased, onclause, outerjoin):
left_mapper, left_selectable, left_is_aliased = _entity_info(left)
# this is an overly broad assumption here, but there's a # this is an overly broad assumption here, but there's a
# very wide variety of situations where we rely upon orm.join's # very wide variety of situations where we rely upon orm.join's
# adaption to glue clauses together, with joined-table inheritance's # adaption to glue clauses together, with joined-table inheritance's
@ -2959,7 +3042,9 @@ class _MapperEntity(_QueryEntity):
         # with_polymorphic() can be applied to aliases
         if not self.is_aliased_class:
             self.selectable = from_obj
-            self.adapter = query._get_polymorphic_adapter(self, from_obj)
+            query._mapper_loads_polymorphically_with(self.mapper,
+                sql_util.ColumnAdapter(from_obj,
+                    self.mapper._equivalent_columns))

         filter_fn = id
@ -3086,8 +3171,9 @@ class _MapperEntity(_QueryEntity):
 class _ColumnEntity(_QueryEntity):
     """Column/expression based entity."""

-    def __init__(self, query, column):
+    def __init__(self, query, column, namespace=None):
         self.expr = column
+        self.namespace = namespace

         if isinstance(column, basestring):
             column = sql.literal_column(column)
@ -3106,7 +3192,7 @@ class _ColumnEntity(_QueryEntity):
         for c in column._select_iterable:
             if c is column:
                 break
-            _ColumnEntity(query, c)
+            _ColumnEntity(query, c, namespace=column)

         if c is not column:
             return
@ -3147,12 +3233,14 @@ class _ColumnEntity(_QueryEntity):
         if self.entities:
             self.entity_zero = list(self.entities)[0]
+        elif self.namespace is not None:
+            self.entity_zero = self.namespace
         else:
             self.entity_zero = None

     @property
     def entity_zero_or_selectable(self):
-        if self.entity_zero:
+        if self.entity_zero is not None:
             return self.entity_zero
         elif self.actual_froms:
             return list(self.actual_froms)[0]

12
libs/sqlalchemy/orm/scoping.py

@ -41,8 +41,9 @@ class ScopedSession(object):
         scope = kwargs.pop('scope', False)
         if scope is not None:
             if self.registry.has():
-                raise sa_exc.InvalidRequestError("Scoped session is already present; "
-                    "no new arguments may be specified.")
+                raise sa_exc.InvalidRequestError(
+                    "Scoped session is already present; "
+                    "no new arguments may be specified.")
             else:
                 sess = self.session_factory(**kwargs)
                 self.registry.set(sess)
@ -70,8 +71,8 @@ class ScopedSession(object):
         self.session_factory.configure(**kwargs)

     def query_property(self, query_cls=None):
-        """return a class property which produces a `Query` object against the
-        class when called.
+        """return a class property which produces a `Query` object
+        against the class when called.

         e.g.::
@ -121,7 +122,8 @@ def makeprop(name):
     def get(self):
         return getattr(self.registry(), name)
     return property(get, set)

-for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active', 'autoflush'):
+for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map',
+             'is_active', 'autoflush', 'no_autoflush'):
    setattr(ScopedSession, prop, makeprop(prop))

def clslevel(name):

43
libs/sqlalchemy/orm/session.py

@ -99,7 +99,7 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False,
     kwargs.update(new_kwargs)

-    return type("Session", (Sess, class_), {})
+    return type("SessionMaker", (Sess, class_), {})

 class SessionTransaction(object):
@ -978,6 +978,34 @@ class Session(object):
         return self._query_cls(entities, self, **kwargs)
@property
@util.contextmanager
def no_autoflush(self):
"""Return a context manager that disables autoflush.
e.g.::
with session.no_autoflush:
some_object = SomeClass()
session.add(some_object)
# won't autoflush
some_object.related_thing = session.query(SomeRelated).first()
Operations that proceed within the ``with:`` block
will not be subject to flushes occurring upon query
access. This is useful when initializing a series
of objects which involve existing database queries,
where the uncompleted object should not yet be flushed.
New in 0.7.6.
"""
autoflush = self.autoflush
self.autoflush = False
yield self
self.autoflush = autoflush
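One caveat worth noting: the generator above restores the saved autoflush flag after the yield, but not in a try/finally, so an exception raised inside the with: block can leave autoflush disabled. A defensive caller can restore it explicitly; a sketch under that assumption, with hypothetical mapped classes:

# Defensive sketch: guarantee autoflush is restored even if the block raises.
saved = session.autoflush
try:
    with session.no_autoflush:
        session.add(SomeClass())                    # SomeClass is hypothetical
        related = session.query(SomeRelated).first()
finally:
    session.autoflush = saved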
     def _autoflush(self):
         if self.autoflush and not self._flushing:
             self.flush()
@ -1772,6 +1800,19 @@ class Session(object):
         return self.transaction and self.transaction.is_active
identity_map = None
"""A mapping of object identities to objects themselves.
Iterating through ``Session.identity_map.values()`` provides
access to the full set of persistent objects (i.e., those
that have row identity) currently in the session.
See also:
:func:`.identity_key` - operations involving identity keys.
"""
     @property
     def _dirty_states(self):
         """The set of all persistent states considered dirty.

1
libs/sqlalchemy/orm/sync.py

@ -6,6 +6,7 @@
"""private module containing functions used for copying data """private module containing functions used for copying data
between instances based on join conditions. between instances based on join conditions.
""" """
from sqlalchemy.orm import exc, util as mapperutil, attributes from sqlalchemy.orm import exc, util as mapperutil, attributes

12
libs/sqlalchemy/orm/unitofwork.py

@ -14,7 +14,7 @@ organizes them in order of dependency, and executes.
 from sqlalchemy import util, event
 from sqlalchemy.util import topological
-from sqlalchemy.orm import attributes, interfaces
+from sqlalchemy.orm import attributes, interfaces, persistence
 from sqlalchemy.orm import util as mapperutil

 session = util.importlater("sqlalchemy.orm", "session")
@ -462,7 +462,7 @@ class IssuePostUpdate(PostSortRec):
         states, cols = uow.post_update_states[self.mapper]
         states = [s for s in states if uow.states[s][0] == self.isdelete]

-        self.mapper._post_update(states, uow, cols)
+        persistence.post_update(self.mapper, states, uow, cols)

 class SaveUpdateAll(PostSortRec):
     def __init__(self, uow, mapper):
@ -470,7 +470,7 @@ class SaveUpdateAll(PostSortRec):
         assert mapper is mapper.base_mapper

     def execute(self, uow):
-        self.mapper._save_obj(
+        persistence.save_obj(self.mapper,
             uow.states_for_mapper_hierarchy(self.mapper, False, False),
             uow
         )
@ -493,7 +493,7 @@ class DeleteAll(PostSortRec):
         assert mapper is mapper.base_mapper

     def execute(self, uow):
-        self.mapper._delete_obj(
+        persistence.delete_obj(self.mapper,
             uow.states_for_mapper_hierarchy(self.mapper, True, False),
             uow
         )
@ -551,7 +551,7 @@ class SaveUpdateState(PostSortRec):
                 if r.__class__ is cls_ and
                     r.mapper is mapper]
         recs.difference_update(our_recs)
-        mapper._save_obj(
+        persistence.save_obj(mapper,
             [self.state] +
             [r.state for r in our_recs],
             uow)
@ -575,7 +575,7 @@ class DeleteState(PostSortRec):
                 r.mapper is mapper]
         recs.difference_update(our_recs)
         states = [self.state] + [r.state for r in our_recs]
-        mapper._delete_obj(
+        persistence.delete_obj(mapper,
             [s for s in states if uow.states[s][0]],
             uow)

76
libs/sqlalchemy/orm/util.py

@ -11,6 +11,7 @@ from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE,\
     PropComparator, MapperProperty
 from sqlalchemy.orm import attributes, exc
 import operator
+import re

 mapperlib = util.importlater("sqlalchemy.orm", "mapperlib")
@ -20,38 +21,52 @@ all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
 _INSTRUMENTOR = ('mapper', 'instrumentor')
-class CascadeOptions(dict):
+class CascadeOptions(frozenset):
     """Keeps track of the options sent to relationship().cascade"""

-    def __init__(self, arg=""):
-        if not arg:
-            values = set()
-        else:
-            values = set(c.strip() for c in arg.split(','))
-
-        for name in ['save-update', 'delete', 'refresh-expire',
-                     'merge', 'expunge']:
-            boolean = name in values or 'all' in values
-            setattr(self, name.replace('-', '_'), boolean)
-            if boolean:
-                self[name] = True
+    _add_w_all_cascades = all_cascades.difference([
+        'all', 'none', 'delete-orphan'])
+    _allowed_cascades = all_cascades
+
+    def __new__(cls, arg):
+        values = set([
+            c for c
+            in re.split('\s*,\s*', arg or "")
+            if c
+        ])
+
+        if values.difference(cls._allowed_cascades):
+            raise sa_exc.ArgumentError(
+                "Invalid cascade option(s): %s" %
+                ", ".join([repr(x) for x in
+                    sorted(
+                        values.difference(cls._allowed_cascades)
+                )])
+            )
+
+        if "all" in values:
+            values.update(cls._add_w_all_cascades)
+        if "none" in values:
+            values.clear()
+        values.discard('all')
+
+        self = frozenset.__new__(CascadeOptions, values)
+        self.save_update = 'save-update' in values
+        self.delete = 'delete' in values
+        self.refresh_expire = 'refresh-expire' in values
+        self.merge = 'merge' in values
+        self.expunge = 'expunge' in values
         self.delete_orphan = "delete-orphan" in values
-        if self.delete_orphan:
-            self['delete-orphan'] = True

         if self.delete_orphan and not self.delete:
-            util.warn("The 'delete-orphan' cascade option requires "
-                      "'delete'.")
-
-        for x in values:
-            if x not in all_cascades:
-                raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x)
+            util.warn("The 'delete-orphan' cascade "
+                      "option requires 'delete'.")
+        return self

     def __repr__(self):
-        return "CascadeOptions(%s)" % repr(",".join(
-            [x for x in ['delete', 'save_update', 'merge', 'expunge',
-             'delete_orphan', 'refresh-expire']
-            if getattr(self, x, False) is True]))
+        return "CascadeOptions(%r)" % (
+            ",".join([x for x in sorted(self)])
+        )
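Since CascadeOptions is now a frozenset subclass, plain set membership and the boolean attributes can both be checked directly. A small sketch against the constructor above:

# Sketch of the new frozenset-based behavior.
opts = CascadeOptions("all, delete-orphan")
assert opts.save_update and opts.delete and opts.delete_orphan
assert "delete-orphan" in opts            # set membership now works directly
# CascadeOptions("bogus") would raise
#   ArgumentError: Invalid cascade option(s): 'bogus'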
def _validator_events(desc, key, validator): def _validator_events(desc, key, validator):
"""Runs a validation method on an attribute value to be set or appended.""" """Runs a validation method on an attribute value to be set or appended."""
@ -557,15 +572,20 @@ def _entity_descriptor(entity, key):
     attribute.

     """
-    if not isinstance(entity, (AliasedClass, type)):
-        entity = entity.class_
+    if isinstance(entity, expression.FromClause):
+        description = entity
+        entity = entity.c
+    elif not isinstance(entity, (AliasedClass, type)):
+        description = entity = entity.class_
+    else:
+        description = entity

     try:
         return getattr(entity, key)
     except AttributeError:
         raise sa_exc.InvalidRequestError(
             "Entity '%s' has no property '%s'" %
-            (entity, key)
+            (description, key)
         )

 def _orm_columns(entity):

56
libs/sqlalchemy/pool.py

@ -57,6 +57,10 @@ def clear_managers():
             manager.close()
     proxies.clear()

+reset_rollback = util.symbol('reset_rollback')
+reset_commit = util.symbol('reset_commit')
+reset_none = util.symbol('reset_none')

 class Pool(log.Identified):
     """Abstract base class for connection pools."""
@ -130,7 +134,17 @@ class Pool(log.Identified):
         self._creator = creator
         self._recycle = recycle
         self._use_threadlocal = use_threadlocal
-        self._reset_on_return = reset_on_return
+        if reset_on_return in ('rollback', True, reset_rollback):
+            self._reset_on_return = reset_rollback
+        elif reset_on_return in (None, False, reset_none):
+            self._reset_on_return = reset_none
+        elif reset_on_return in ('commit', reset_commit):
+            self._reset_on_return = reset_commit
+        else:
+            raise exc.ArgumentError(
+                "Invalid value for 'reset_on_return': %r"
+                % reset_on_return)
         self.echo = echo
         if _dispatch:
             self.dispatch._update(_dispatch, only_propagate=False)
@ -330,8 +344,10 @@ def _finalize_fairy(connection, connection_record, pool, ref, echo):
     if connection is not None:
         try:
-            if pool._reset_on_return:
+            if pool._reset_on_return is reset_rollback:
                 connection.rollback()
+            elif pool._reset_on_return is reset_commit:
+                connection.commit()
             # Immediately close detached instances
             if connection_record is None:
                 connection.close()
@ -624,11 +640,37 @@ class QueuePool(Pool):
         :meth:`unique_connection` method is provided to bypass the
         threadlocal behavior installed into :meth:`connect`.

-        :param reset_on_return: If true, reset the database state of
-          connections returned to the pool. This is typically a
-          ROLLBACK to release locks and transaction resources.
-          Disable at your own peril. Defaults to True.
+        :param reset_on_return: Determine steps to take on
+          connections as they are returned to the pool.
+          As of SQLAlchemy 0.7.6, reset_on_return can have any
+          of these values:
* 'rollback' - call rollback() on the connection,
to release locks and transaction resources.
This is the default value. The vast majority
of use cases should leave this value set.
* True - same as 'rollback', this is here for
backwards compatibility.
* 'commit' - call commit() on the connection,
to release locks and transaction resources.
A commit here may be desirable for databases that
cache query plans if a commit is emitted,
such as Microsoft SQL Server. However, this
value is more dangerous than 'rollback' because
any data changes present on the transaction
are committed unconditionally.
* None - don't do anything on the connection.
This setting should only be made on a database
that has no transaction support at all,
namely MySQL MyISAM. By not doing anything,
performance can be improved. This
setting should **never be selected** for a
database that supports transactions,
as it will lead to deadlocks and stale
state.
* False - same as None, this is here for
backwards compatibility.
         :param listeners: A list of
           :class:`~sqlalchemy.interfaces.PoolListener`-like objects or
           dictionaries of callables that receive events when DB-API
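A minimal construction sketch for the new parameter values documented above; the SQLite in-memory creator is a stand-in for any DBAPI connect callable:

# Stand-in creator; reset_on_return accepts 'rollback' (default),
# 'commit', or None per the docstring above.
import sqlite3
from sqlalchemy.pool import QueuePool

pool = QueuePool(lambda: sqlite3.connect(':memory:'),
                 reset_on_return='commit')
conn = pool.connect()
conn.close()    # returned to the pool; commit() is issued per the setting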

54
libs/sqlalchemy/schema.py

@ -80,6 +80,17 @@ def _get_table_key(name, schema):
else: else:
return schema + "." + name return schema + "." + name
def _validate_dialect_kwargs(kwargs, name):
# validate remaining kwargs that they all specify DB prefixes
if len([k for k in kwargs
if not re.match(
r'^(?:%s)_' %
'|'.join(dialects.__all__), k
)
]):
raise TypeError(
"Invalid argument(s) for %s: %r" % (name, kwargs.keys()))
class Table(SchemaItem, expression.TableClause): class Table(SchemaItem, expression.TableClause):
"""Represent a table in a database. """Represent a table in a database.
@ -369,9 +380,12 @@ class Table(SchemaItem, expression.TableClause):
         # allow user-overrides
         self._init_items(*args)

-    def _autoload(self, metadata, autoload_with, include_columns, exclude_columns=None):
+    def _autoload(self, metadata, autoload_with, include_columns, exclude_columns=()):
         if self.primary_key.columns:
-            PrimaryKeyConstraint()._set_parent_with_dispatch(self)
+            PrimaryKeyConstraint(*[
+                c for c in self.primary_key.columns
+                if c.key in exclude_columns
+            ])._set_parent_with_dispatch(self)

         if autoload_with:
             autoload_with.run_callable(
@ -424,7 +438,7 @@ class Table(SchemaItem, expression.TableClause):
         if not autoload_replace:
             exclude_columns = [c.name for c in self.c]
         else:
-            exclude_columns = None
+            exclude_columns = ()
         self._autoload(self.metadata, autoload_with, include_columns, exclude_columns)

         self._extra_kwargs(**kwargs)
@ -432,14 +446,7 @@ class Table(SchemaItem, expression.TableClause):
     def _extra_kwargs(self, **kwargs):
         # validate remaining kwargs that they all specify DB prefixes
-        if len([k for k in kwargs
-            if not re.match(
-                r'^(?:%s)_' %
-                '|'.join(dialects.__all__), k
-            )
-        ]):
-            raise TypeError(
-                "Invalid argument(s) for Table: %r" % kwargs.keys())
+        _validate_dialect_kwargs(kwargs, "Table")
         self.kwargs.update(kwargs)

     def _init_collections(self):
@ -1028,7 +1035,7 @@ class Column(SchemaItem, expression.ColumnClause):
"The 'index' keyword argument on Column is boolean only. " "The 'index' keyword argument on Column is boolean only. "
"To create indexes with a specific name, create an " "To create indexes with a specific name, create an "
"explicit Index object external to the Table.") "explicit Index object external to the Table.")
-                Index(expression._generated_label('ix_%s' % self._label), self, unique=self.unique)
+                Index(expression._truncated_label('ix_%s' % self._label), self, unique=self.unique)
elif self.unique: elif self.unique:
if isinstance(self.unique, basestring): if isinstance(self.unique, basestring):
raise exc.ArgumentError( raise exc.ArgumentError(
@ -1093,7 +1100,7 @@ class Column(SchemaItem, expression.ColumnClause):
"been assigned.") "been assigned.")
try: try:
c = self._constructor( c = self._constructor(
-                name or self.name,
+                expression._as_truncated(name or self.name),
self.type, self.type,
key = name or self.key, key = name or self.key,
primary_key = self.primary_key, primary_key = self.primary_key,
@ -1119,6 +1126,8 @@ class Column(SchemaItem, expression.ColumnClause):
c.table = selectable c.table = selectable
selectable._columns.add(c) selectable._columns.add(c)
if selectable._is_clone_of is not None:
c._is_clone_of = selectable._is_clone_of.columns[c.name]
if self.primary_key: if self.primary_key:
selectable.primary_key.add(c) selectable.primary_key.add(c)
c.dispatch.after_parent_attach(c, selectable) c.dispatch.after_parent_attach(c, selectable)
@ -1809,7 +1818,8 @@ class Constraint(SchemaItem):
__visit_name__ = 'constraint' __visit_name__ = 'constraint'
def __init__(self, name=None, deferrable=None, initially=None, def __init__(self, name=None, deferrable=None, initially=None,
-                _create_rule=None):
+                _create_rule=None,
+                **kw):
"""Create a SQL constraint. """Create a SQL constraint.
:param name: :param name:
@ -1839,6 +1849,10 @@ class Constraint(SchemaItem):
_create_rule is used by some types to create constraints. _create_rule is used by some types to create constraints.
Currently, its call signature is subject to change at any time. Currently, its call signature is subject to change at any time.
:param \**kwargs:
Dialect-specific keyword parameters, see the documentation
for various dialects and constraints regarding options here.
""" """
@ -1847,6 +1861,8 @@ class Constraint(SchemaItem):
self.initially = initially self.initially = initially
self._create_rule = _create_rule self._create_rule = _create_rule
util.set_creation_order(self) util.set_creation_order(self)
_validate_dialect_kwargs(kw, self.__class__.__name__)
self.kwargs = kw
@property @property
def table(self): def table(self):
@ -2192,6 +2208,8 @@ class Index(ColumnCollectionMixin, SchemaItem):
self.table = None self.table = None
# will call _set_parent() if table-bound column # will call _set_parent() if table-bound column
# objects are present # objects are present
if not columns:
util.warn("No column names or expressions given for Index.")
ColumnCollectionMixin.__init__(self, *columns) ColumnCollectionMixin.__init__(self, *columns)
self.name = name self.name = name
self.unique = kw.pop('unique', False) self.unique = kw.pop('unique', False)
@ -3004,9 +3022,11 @@ def _to_schema_column(element):
return element return element
 def _to_schema_column_or_string(element):
     if hasattr(element, '__clause_element__'):
         element = element.__clause_element__()
-    return element
+    if not isinstance(element, (basestring, expression.ColumnElement)):
+        raise exc.ArgumentError("Element %r is not a string name or column element" % element)
+    return element
class _CreateDropBase(DDLElement): class _CreateDropBase(DDLElement):
"""Base class for DDL constucts that represent CREATE and DROP or """Base class for DDL constucts that represent CREATE and DROP or

179
libs/sqlalchemy/sql/compiler.py

@ -154,9 +154,10 @@ class _CompileLabel(visitors.Visitable):
     __visit_name__ = 'label'
     __slots__ = 'element', 'name'

-    def __init__(self, col, name):
+    def __init__(self, col, name, alt_names=()):
         self.element = col
         self.name = name
+        self._alt_names = alt_names
@property @property
def proxy_set(self): def proxy_set(self):
@ -251,6 +252,10 @@ class SQLCompiler(engine.Compiled):
# column targeting # column targeting
self.result_map = {} self.result_map = {}
# collect CTEs to tack on top of a SELECT
self.ctes = util.OrderedDict()
self.ctes_recursive = False
# true if the paramstyle is positional # true if the paramstyle is positional
self.positional = dialect.positional self.positional = dialect.positional
if self.positional: if self.positional:
@ -354,14 +359,16 @@ class SQLCompiler(engine.Compiled):
         # or ORDER BY clause of a select. dialect-specific compilers
         # can modify this behavior.
         if within_columns_clause and not within_label_clause:
-            if isinstance(label.name, sql._generated_label):
+            if isinstance(label.name, sql._truncated_label):
                 labelname = self._truncated_identifier("colident", label.name)
             else:
                 labelname = label.name

             if result_map is not None:
-                result_map[labelname.lower()] = \
-                    (label.name, (label, label.element, labelname),\
-                    label.type)
+                result_map[labelname.lower()] = (
+                    label.name,
+                    (label, label.element, labelname, ) +
+                        label._alt_names,
+                    label.type)

             return label.element._compiler_dispatch(self,
@ -376,17 +383,19 @@ class SQLCompiler(engine.Compiled):
**kw) **kw)
     def visit_column(self, column, result_map=None, **kwargs):
-        name = column.name
+        name = orig_name = column.name
         if name is None:
             raise exc.CompileError("Cannot compile Column object until "
                                    "it's 'name' is assigned.")

         is_literal = column.is_literal
-        if not is_literal and isinstance(name, sql._generated_label):
+        if not is_literal and isinstance(name, sql._truncated_label):
             name = self._truncated_identifier("colident", name)

         if result_map is not None:
-            result_map[name.lower()] = (name, (column, ), column.type)
+            result_map[name.lower()] = (orig_name,
+                                        (column, name, column.key),
+                                        column.type)
if is_literal: if is_literal:
name = self.escape_literal_column(name) name = self.escape_literal_column(name)
@ -404,7 +413,7 @@ class SQLCompiler(engine.Compiled):
else: else:
schema_prefix = '' schema_prefix = ''
tablename = table.name tablename = table.name
-        if isinstance(tablename, sql._generated_label):
+        if isinstance(tablename, sql._truncated_label):
tablename = self._truncated_identifier("alias", tablename) tablename = self._truncated_identifier("alias", tablename)
return schema_prefix + \ return schema_prefix + \
@ -646,7 +655,8 @@ class SQLCompiler(engine.Compiled):
if name in self.binds: if name in self.binds:
existing = self.binds[name] existing = self.binds[name]
if existing is not bindparam: if existing is not bindparam:
-                if existing.unique or bindparam.unique:
+                if (existing.unique or bindparam.unique) and \
+                        not existing.proxy_set.intersection(bindparam.proxy_set):
                     raise exc.CompileError(
"Bind parameter '%s' conflicts with " "Bind parameter '%s' conflicts with "
"unique bind parameter of the same name" % "unique bind parameter of the same name" %
@ -703,7 +713,7 @@ class SQLCompiler(engine.Compiled):
return self.bind_names[bindparam] return self.bind_names[bindparam]
bind_name = bindparam.key bind_name = bindparam.key
-        if isinstance(bind_name, sql._generated_label):
+        if isinstance(bind_name, sql._truncated_label):
bind_name = self._truncated_identifier("bindparam", bind_name) bind_name = self._truncated_identifier("bindparam", bind_name)
# add to bind_names for translation # add to bind_names for translation
@ -715,7 +725,7 @@ class SQLCompiler(engine.Compiled):
if (ident_class, name) in self.truncated_names: if (ident_class, name) in self.truncated_names:
return self.truncated_names[(ident_class, name)] return self.truncated_names[(ident_class, name)]
-        anonname = name % self.anon_map
+        anonname = name.apply_map(self.anon_map)
if len(anonname) > self.label_length: if len(anonname) > self.label_length:
counter = self.truncated_names.get(ident_class, 1) counter = self.truncated_names.get(ident_class, 1)
@ -744,10 +754,49 @@ class SQLCompiler(engine.Compiled):
else: else:
return self.bindtemplate % {'name':name} return self.bindtemplate % {'name':name}
def visit_cte(self, cte, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if isinstance(cte.name, sql._truncated_label):
cte_name = self._truncated_identifier("alias", cte.name)
else:
cte_name = cte.name
if cte.cte_alias:
if isinstance(cte.cte_alias, sql._truncated_label):
cte_alias = self._truncated_identifier("alias", cte.cte_alias)
else:
cte_alias = cte.cte_alias
if not cte.cte_alias and cte not in self.ctes:
if cte.recursive:
self.ctes_recursive = True
text = self.preparer.format_alias(cte, cte_name)
if cte.recursive:
if isinstance(cte.original, sql.Select):
col_source = cte.original
elif isinstance(cte.original, sql.CompoundSelect):
col_source = cte.original.selects[0]
else:
assert False
recur_cols = [c.key for c in util.unique_list(col_source.inner_columns)
if c is not None]
text += "(%s)" % (", ".join(recur_cols))
text += " AS \n" + \
cte.original._compiler_dispatch(
self, asfrom=True, **kwargs
)
self.ctes[cte] = text
if asfrom:
if cte.cte_alias:
text = self.preparer.format_alias(cte, cte_alias)
text += " AS " + cte_name
else:
return self.preparer.format_alias(cte, cte_name)
return text
def visit_alias(self, alias, asfrom=False, ashint=False, def visit_alias(self, alias, asfrom=False, ashint=False,
fromhints=None, **kwargs): fromhints=None, **kwargs):
if asfrom or ashint: if asfrom or ashint:
-            if isinstance(alias.name, sql._generated_label):
+            if isinstance(alias.name, sql._truncated_label):
alias_name = self._truncated_identifier("alias", alias.name) alias_name = self._truncated_identifier("alias", alias.name)
else: else:
alias_name = alias.name alias_name = alias.name
@ -775,8 +824,14 @@ class SQLCompiler(engine.Compiled):
if isinstance(column, sql._Label): if isinstance(column, sql._Label):
return column return column
-        elif select is not None and select.use_labels and column._label:
-            return _CompileLabel(column, column._label)
+        elif select is not None and \
+                select.use_labels and \
+                column._label:
+            return _CompileLabel(
+                column,
+                column._label,
+                alt_names=(column._key_label, )
+            )
elif \ elif \
asfrom and \ asfrom and \
@ -784,7 +839,8 @@ class SQLCompiler(engine.Compiled):
not column.is_literal and \ not column.is_literal and \
column.table is not None and \ column.table is not None and \
not isinstance(column.table, sql.Select): not isinstance(column.table, sql.Select):
-            return _CompileLabel(column, sql._generated_label(column.name))
+            return _CompileLabel(column, sql._as_truncated(column.name),
+                alt_names=(column.key,))
elif not isinstance(column, elif not isinstance(column,
(sql._UnaryExpression, sql._TextClause)) \ (sql._UnaryExpression, sql._TextClause)) \
and (not hasattr(column, 'name') or \ and (not hasattr(column, 'name') or \
@ -799,6 +855,9 @@ class SQLCompiler(engine.Compiled):
def get_from_hint_text(self, table, text): def get_from_hint_text(self, table, text):
return None return None
def get_crud_hint_text(self, table, text):
return None
def visit_select(self, select, asfrom=False, parens=True, def visit_select(self, select, asfrom=False, parens=True,
iswrapper=False, fromhints=None, iswrapper=False, fromhints=None,
compound_index=1, **kwargs): compound_index=1, **kwargs):
@ -897,6 +956,15 @@ class SQLCompiler(engine.Compiled):
if select.for_update: if select.for_update:
text += self.for_update_clause(select) text += self.for_update_clause(select)
if self.ctes and \
compound_index==1 and not entry:
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
cte_text += ", \n".join(
[txt for txt in self.ctes.values()]
)
cte_text += "\n "
text = cte_text + text
self.stack.pop(-1) self.stack.pop(-1)
if asfrom and parens: if asfrom and parens:
@ -904,6 +972,12 @@ class SQLCompiler(engine.Compiled):
else: else:
return text return text
def get_cte_preamble(self, recursive):
if recursive:
return "WITH RECURSIVE"
else:
return "WITH"
def get_select_precolumns(self, select): def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just """Called when building a ``SELECT`` statement, position is just
before column list. before column list.
@ -977,12 +1051,26 @@ class SQLCompiler(engine.Compiled):
text = "INSERT" text = "INSERT"
prefixes = [self.process(x) for x in insert_stmt._prefixes] prefixes = [self.process(x) for x in insert_stmt._prefixes]
if prefixes: if prefixes:
text += " " + " ".join(prefixes) text += " " + " ".join(prefixes)
text += " INTO " + preparer.format_table(insert_stmt.table) text += " INTO " + preparer.format_table(insert_stmt.table)
if insert_stmt._hints:
dialect_hints = dict([
(table, hint_text)
for (table, dialect), hint_text in
insert_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if insert_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
insert_stmt.table,
dialect_hints[insert_stmt.table]
)
if colparams or not supports_default_values: if colparams or not supports_default_values:
text += " (%s)" % ', '.join([preparer.format_column(c[0]) text += " (%s)" % ', '.join([preparer.format_column(c[0])
for c in colparams]) for c in colparams])
@ -1014,21 +1102,25 @@ class SQLCompiler(engine.Compiled):
extra_froms, **kw): extra_froms, **kw):
"""Provide a hook to override the initial table clause """Provide a hook to override the initial table clause
in an UPDATE statement. in an UPDATE statement.
MySQL overrides this. MySQL overrides this.
""" """
return self.preparer.format_table(from_table) return self.preparer.format_table(from_table)
-    def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
+    def update_from_clause(self, update_stmt,
+                            from_table, extra_froms,
+                            from_hints,
+                            **kw):
         """Provide a hook to override the generation of an
         UPDATE..FROM clause.

         MySQL overrides this.

         """
         return "FROM " + ', '.join(
-            t._compiler_dispatch(self, asfrom=True, **kw)
+            t._compiler_dispatch(self, asfrom=True,
+                fromhints=from_hints, **kw)
             for t in extra_froms)
def visit_update(self, update_stmt, **kw): def visit_update(self, update_stmt, **kw):
@ -1045,6 +1137,21 @@ class SQLCompiler(engine.Compiled):
update_stmt.table, update_stmt.table,
extra_froms, **kw) extra_froms, **kw)
if update_stmt._hints:
dialect_hints = dict([
(table, hint_text)
for (table, dialect), hint_text in
update_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if update_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
update_stmt.table,
dialect_hints[update_stmt.table]
)
else:
dialect_hints = None
text += ' SET ' text += ' SET '
if extra_froms and self.render_table_with_column_in_update_from: if extra_froms and self.render_table_with_column_in_update_from:
text += ', '.join( text += ', '.join(
@ -1067,7 +1174,8 @@ class SQLCompiler(engine.Compiled):
extra_from_text = self.update_from_clause( extra_from_text = self.update_from_clause(
update_stmt, update_stmt,
update_stmt.table, update_stmt.table,
-                extra_froms, **kw)
+                extra_froms,
+                dialect_hints, **kw)
if extra_from_text: if extra_from_text:
text += " " + extra_from_text text += " " + extra_from_text
@ -1133,7 +1241,6 @@ class SQLCompiler(engine.Compiled):
for k, v in stmt.parameters.iteritems(): for k, v in stmt.parameters.iteritems():
parameters.setdefault(sql._column_as_key(k), v) parameters.setdefault(sql._column_as_key(k), v)
# create a list of column assignment clauses as tuples # create a list of column assignment clauses as tuples
values = [] values = []
@ -1192,7 +1299,7 @@ class SQLCompiler(engine.Compiled):
# "defaults", "primary key cols", etc. # "defaults", "primary key cols", etc.
for c in stmt.table.columns: for c in stmt.table.columns:
if c.key in parameters and c.key not in check_columns: if c.key in parameters and c.key not in check_columns:
-                value = parameters[c.key]
+                value = parameters.pop(c.key)
if sql._is_literal(value): if sql._is_literal(value):
value = self._create_crud_bind_param( value = self._create_crud_bind_param(
c, value, required=value is required) c, value, required=value is required)
@ -1288,6 +1395,17 @@ class SQLCompiler(engine.Compiled):
self.prefetch.append(c) self.prefetch.append(c)
elif c.server_onupdate is not None: elif c.server_onupdate is not None:
self.postfetch.append(c) self.postfetch.append(c)
if parameters and stmt.parameters:
check = set(parameters).intersection(
sql._column_as_key(k) for k in stmt.parameters
).difference(check_columns)
if check:
util.warn(
"Unconsumed column names: %s" %
(", ".join(check))
)
        return values
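The new check above warns, rather than silently ignoring, names passed to insert()/update() that match no column. A sketch of what triggers it:

# 'nonexistent' matches no column on t, so compiling emits the warning
# "Unconsumed column names: nonexistent" instead of dropping it silently.
from sqlalchemy import Table, Column, Integer, MetaData

t = Table('t', MetaData(), Column('x', Integer))
stmt = t.update().values(x=5, nonexistent=7)
print(stmt)     # warning fires during statement compilation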
def visit_delete(self, delete_stmt): def visit_delete(self, delete_stmt):
@ -1296,6 +1414,21 @@ class SQLCompiler(engine.Compiled):
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
if delete_stmt._hints:
dialect_hints = dict([
(table, hint_text)
for (table, dialect), hint_text in
delete_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if delete_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
delete_stmt.table,
dialect_hints[delete_stmt.table]
)
else:
dialect_hints = None
if delete_stmt._returning: if delete_stmt._returning:
self.returning = delete_stmt._returning self.returning = delete_stmt._returning
if self.returning_precedes_values: if self.returning_precedes_values:
@ -1445,7 +1578,7 @@ class DDLCompiler(engine.Compiled):
return "\nDROP TABLE " + self.preparer.format_table(drop.element) return "\nDROP TABLE " + self.preparer.format_table(drop.element)
def _index_identifier(self, ident): def _index_identifier(self, ident):
-        if isinstance(ident, sql._generated_label):
+        if isinstance(ident, sql._truncated_label):
max = self.dialect.max_index_name_length or \ max = self.dialect.max_index_name_length or \
self.dialect.max_identifier_length self.dialect.max_identifier_length
if len(ident) > max: if len(ident) > max:

329
libs/sqlalchemy/sql/expression.py

@ -832,6 +832,14 @@ def tuple_(*expr):
[(1, 2), (5, 12), (10, 19)] [(1, 2), (5, 12), (10, 19)]
) )
.. warning::
The composite IN construct is not supported by all backends,
and is currently known to work on Postgresql and MySQL,
but not SQLite. Unsupported backends will raise
a subclass of :class:`~sqlalchemy.exc.DBAPIError` when such
an expression is invoked.
""" """
    return _Tuple(*expr)
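A usage sketch for the composite IN documented above:

# Composite IN: compares (a, b) pairs against a list of tuples.
from sqlalchemy import Table, Column, Integer, MetaData, tuple_, select

t = Table('t', MetaData(), Column('a', Integer), Column('b', Integer))
stmt = select([t]).where(
    tuple_(t.c.a, t.c.b).in_([(1, 2), (5, 12), (10, 19)]))
# works on Postgresql and MySQL; SQLite rejects it, per the warning above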
@ -1275,14 +1283,48 @@ func = _FunctionGenerator()
# TODO: use UnaryExpression for this instead ? # TODO: use UnaryExpression for this instead ?
modifier = _FunctionGenerator(group=False) modifier = _FunctionGenerator(group=False)
-class _generated_label(unicode):
-    """A unicode subclass used to identify dynamically generated names."""
+class _truncated_label(unicode):
+    """A unicode subclass used to identify symbolic "
+    "names that may require truncation."""
def apply_map(self, map_):
return self
# for backwards compatibility in case
# someone is re-implementing the
# _truncated_identifier() sequence in a custom
# compiler
_generated_label = _truncated_label
class _anonymous_label(_truncated_label):
"""A unicode subclass used to identify anonymously
generated names."""
def __add__(self, other):
return _anonymous_label(
unicode(self) +
unicode(other))
def __radd__(self, other):
return _anonymous_label(
unicode(other) +
unicode(self))
def apply_map(self, map_):
return self % map_
-def _escape_for_generated(x):
-    if isinstance(x, _generated_label):
-        return x
-    else:
-        return x.replace('%', '%%')
+def _as_truncated(value):
+    """coerce the given value to :class:`._truncated_label`.
+
+    Existing :class:`._truncated_label` and
+    :class:`._anonymous_label` objects are passed
+    unchanged.
+    """
+
+    if isinstance(value, _truncated_label):
+        return value
+    else:
+        return _truncated_label(value)
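The split matters inside _truncated_identifier(): plain truncated labels pass through apply_map() unchanged, while anonymous labels still interpolate the compiler's anon map. A small sketch against the classes defined above:

# _truncated_label: apply_map() is a no-op passthrough.
name = _truncated_label(u'a_rather_long_generated_column_name')
assert name.apply_map({}) is name

# _anonymous_label: apply_map() performs the old "name % map" substitution,
# and concatenation preserves the anonymous marker class.
anon = _anonymous_label(u'%(1 anon)s')
assert anon.apply_map({'1 anon': 'anon_1'}) == 'anon_1'
assert isinstance(anon + "_x", _anonymous_label)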
def _string_or_unprintable(element): def _string_or_unprintable(element):
if isinstance(element, basestring): if isinstance(element, basestring):
@ -1466,6 +1508,7 @@ class ClauseElement(Visitable):
supports_execution = False supports_execution = False
_from_objects = [] _from_objects = []
bind = None bind = None
_is_clone_of = None
def _clone(self): def _clone(self):
"""Create a shallow copy of this ClauseElement. """Create a shallow copy of this ClauseElement.
@ -1514,7 +1557,7 @@ class ClauseElement(Visitable):
f = self f = self
while f is not None: while f is not None:
s.add(f) s.add(f)
-            f = getattr(f, '_is_clone_of', None)
+            f = f._is_clone_of
return s return s
def __getstate__(self): def __getstate__(self):
@ -2063,6 +2106,8 @@ class ColumnElement(ClauseElement, _CompareMixin):
foreign_keys = [] foreign_keys = []
quote = None quote = None
_label = None _label = None
_key_label = None
_alt_names = ()
@property @property
def _select_iterable(self): def _select_iterable(self):
@ -2109,9 +2154,14 @@ class ColumnElement(ClauseElement, _CompareMixin):
else: else:
key = name key = name
-        co = ColumnClause(name, selectable, type_=getattr(self,
-            'type', None))
+        co = ColumnClause(_as_truncated(name),
+                          selectable,
+                          type_=getattr(self,
+                              'type', None))
         co.proxies = [self]
+        if selectable._is_clone_of is not None:
+            co._is_clone_of = \
+                selectable._is_clone_of.columns[key]
         selectable._columns[key] = co
         return co
@ -2157,7 +2207,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
expressions and function calls. expressions and function calls.
""" """
-        return _anonymous_label('%%(%d %s)s' % (id(self), getattr(self,
-            'name', 'anon')))
+        return _anonymous_label('%%(%d %s)s' % (id(self), getattr(self,
+            'name', 'anon')))
class ColumnCollection(util.OrderedProperties): class ColumnCollection(util.OrderedProperties):
@ -2420,6 +2470,13 @@ class FromClause(Selectable):
""" """
def embedded(expanded_proxy_set, target_set):
for t in target_set.difference(expanded_proxy_set):
if not set(_expand_cloned([t])
).intersection(expanded_proxy_set):
return False
return True
# dont dig around if the column is locally present # dont dig around if the column is locally present
if self.c.contains_column(column): if self.c.contains_column(column):
return column return column
@ -2427,10 +2484,10 @@ class FromClause(Selectable):
target_set = column.proxy_set target_set = column.proxy_set
cols = self.c cols = self.c
         for c in cols:
-            i = target_set.intersection(itertools.chain(*[p._cloned_set
-                for p in c.proxy_set]))
+            expanded_proxy_set = set(_expand_cloned(c.proxy_set))
+            i = target_set.intersection(expanded_proxy_set)
             if i and (not require_embedded
-                    or c.proxy_set.issuperset(target_set)):
+                    or embedded(expanded_proxy_set, target_set)):
if col is None: if col is None:
# no corresponding column yet, pick this one. # no corresponding column yet, pick this one.
@ -2580,10 +2637,10 @@ class _BindParamClause(ColumnElement):
""" """
         if unique:
-            self.key = _generated_label('%%(%d %s)s' % (id(self), key
-                or 'param'))
+            self.key = _anonymous_label('%%(%d %s)s' % (id(self), key
+                or 'param'))
         else:
-            self.key = key or _generated_label('%%(%d param)s'
-                % id(self))
+            self.key = key or _anonymous_label('%%(%d param)s'
+                % id(self))
# identifiying key that won't change across # identifiying key that won't change across
@ -2631,14 +2688,14 @@ class _BindParamClause(ColumnElement):
     def _clone(self):
         c = ClauseElement._clone(self)
         if self.unique:
-            c.key = _generated_label('%%(%d %s)s' % (id(c), c._orig_key
+            c.key = _anonymous_label('%%(%d %s)s' % (id(c), c._orig_key
                 or 'param'))
         return c

     def _convert_to_unique(self):
         if not self.unique:
             self.unique = True
-            self.key = _generated_label('%%(%d %s)s' % (id(self),
+            self.key = _anonymous_label('%%(%d %s)s' % (id(self),
                 self._orig_key or 'param'))
def compare(self, other, **kw): def compare(self, other, **kw):
@ -3607,7 +3664,7 @@ class Alias(FromClause):
if name is None: if name is None:
if self.original.named_with_column: if self.original.named_with_column:
name = getattr(self.original, 'name', None) name = getattr(self.original, 'name', None)
-                name = _generated_label('%%(%d %s)s' % (id(self), name
+                name = _anonymous_label('%%(%d %s)s' % (id(self), name
                     or 'anon'))
self.name = name self.name = name
@ -3662,6 +3719,47 @@ class Alias(FromClause):
def bind(self): def bind(self):
return self.element.bind return self.element.bind
class CTE(Alias):
"""Represent a Common Table Expression.
The :class:`.CTE` object is obtained using the
:meth:`._SelectBase.cte` method from any selectable.
See that method for complete examples.
New in 0.7.6.
"""
__visit_name__ = 'cte'
def __init__(self, selectable,
name=None,
recursive=False,
cte_alias=False):
self.recursive = recursive
self.cte_alias = cte_alias
super(CTE, self).__init__(selectable, name=name)
def alias(self, name=None):
return CTE(
self.original,
name=name,
recursive=self.recursive,
cte_alias=self.name
)
def union(self, other):
return CTE(
self.original.union(other),
name=self.name,
recursive=self.recursive
)
def union_all(self, other):
return CTE(
self.original.union_all(other),
name=self.name,
recursive=self.recursive
)
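A condensed sketch of the new class in use (table and column names here are invented, and the full walkthrough appears in the _SelectBase.cte docstring further down):

    from sqlalchemy import Table, Column, Integer, String, MetaData, \
        select, func

    metadata = MetaData()
    orders = Table('orders', metadata,
        Column('region', String),
        Column('amount', Integer),
    )

    # .cte() wraps the SELECT as a WITH-clause element; the result is
    # selected from like any Alias, and .alias() returns another CTE
    # (carrying cte_alias) rather than a plain Alias.
    regional = select([
        orders.c.region,
        func.sum(orders.c.amount).label('total')
    ]).group_by(orders.c.region).cte('regional')

    stmt = select([regional.c.region]).where(regional.c.total > 100)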
class _Grouping(ColumnElement): class _Grouping(ColumnElement):
"""Represent a grouping within a column expression""" """Represent a grouping within a column expression"""
@ -3807,9 +3905,12 @@ class _Label(ColumnElement):
def __init__(self, name, element, type_=None): def __init__(self, name, element, type_=None):
while isinstance(element, _Label): while isinstance(element, _Label):
element = element.element element = element.element
self.name = self.key = self._label = name \ if name:
or _generated_label('%%(%d %s)s' % (id(self), self.name = name
else:
self.name = _anonymous_label('%%(%d %s)s' % (id(self),
getattr(element, 'name', 'anon'))) getattr(element, 'name', 'anon')))
self.key = self._label = self._key_label = self.name
self._element = element self._element = element
self._type = type_ self._type = type_
self.quote = element.quote self.quote = element.quote
@ -3957,7 +4058,17 @@ class ColumnClause(_Immutable, ColumnElement):
# end Py2K # end Py2K
@_memoized_property @_memoized_property
def _key_label(self):
if self.key != self.name:
return self._gen_label(self.key)
else:
return self._label
@_memoized_property
def _label(self): def _label(self):
return self._gen_label(self.name)
def _gen_label(self, name):
t = self.table t = self.table
if self.is_literal: if self.is_literal:
return None return None
@ -3965,11 +4076,9 @@ class ColumnClause(_Immutable, ColumnElement):
elif t is not None and t.named_with_column: elif t is not None and t.named_with_column:
if getattr(t, 'schema', None): if getattr(t, 'schema', None):
label = t.schema.replace('.', '_') + "_" + \ label = t.schema.replace('.', '_') + "_" + \
_escape_for_generated(t.name) + "_" + \ t.name + "_" + name
_escape_for_generated(self.name)
else: else:
label = _escape_for_generated(t.name) + "_" + \ label = t.name + "_" + name
_escape_for_generated(self.name)
# ensure the label name doesn't conflict with that # ensure the label name doesn't conflict with that
# of an existing column # of an existing column
@ -3981,10 +4090,10 @@ class ColumnClause(_Immutable, ColumnElement):
counter += 1 counter += 1
label = _label label = _label
return _generated_label(label) return _as_truncated(label)
else: else:
return self.name return name
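The labeling scheme above can be observed directly (a small sketch; 'myschema' and 'mytable' are placeholder names):

    from sqlalchemy import Table, Column, Integer, MetaData

    m = MetaData()
    t = Table('mytable', m, Column('x', Integer), schema='myschema')

    # schema and table name are joined into the column's generated
    # label, with '.' in the schema replaced by '_' as shown above;
    # _as_truncated returns a str subclass, so plain comparison works
    assert t.c.x._label == 'myschema_mytable_x'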
def label(self, name): def label(self, name):
# currently, anonymous labels don't occur for # currently, anonymous labels don't occur for
@ -4010,12 +4119,15 @@ class ColumnClause(_Immutable, ColumnElement):
# otherwise it's considered to be a label # otherwise it's considered to be a label
is_literal = self.is_literal and (name is None or name == self.name) is_literal = self.is_literal and (name is None or name == self.name)
c = self._constructor( c = self._constructor(
name or self.name, _as_truncated(name or self.name),
selectable=selectable, selectable=selectable,
type_=self.type, type_=self.type,
is_literal=is_literal is_literal=is_literal
) )
c.proxies = [self] c.proxies = [self]
if selectable._is_clone_of is not None:
c._is_clone_of = \
selectable._is_clone_of.columns[c.name]
if attach: if attach:
selectable._columns[c.name] = c selectable._columns[c.name] = c
@ -4218,6 +4330,125 @@ class _SelectBase(Executable, FromClause):
""" """
return self.as_scalar().label(name) return self.as_scalar().label(name)
def cte(self, name=None, recursive=False):
"""Return a new :class:`.CTE`, or Common Table Expression instance.
Common table expressions are a SQL standard whereby SELECT
statements can draw upon secondary statements specified along
with the primary statement, using a clause called "WITH".
Special semantics regarding UNION can also be employed to
allow "recursive" queries, where a SELECT statement can draw
upon the set of rows that have previously been selected.
SQLAlchemy detects :class:`.CTE` objects, which are treated
similarly to :class:`.Alias` objects, as special elements
to be delivered to the FROM clause of the statement as well
as to a WITH clause at the top of the statement.
The :meth:`._SelectBase.cte` method is new in 0.7.6.
:param name: name given to the common table expression. Like
:meth:`.FromClause.alias`, the name can be left as ``None``
in which case an anonymous symbol will be used at query
compile time.
:param recursive: if ``True``, will render ``WITH RECURSIVE``.
A recursive common table expression is intended to be used in
conjunction with UNION ALL in order to derive rows
from those already selected.
The following two examples are adapted from
PostgreSQL's documentation at
http://www.postgresql.org/docs/8.4/static/queries-with.html.
Example 1, non-recursive::
from sqlalchemy import Table, Column, String, Integer, MetaData, \\
select, func
metadata = MetaData()
orders = Table('orders', metadata,
Column('region', String),
Column('amount', Integer),
Column('product', String),
Column('quantity', Integer)
)
regional_sales = select([
orders.c.region,
func.sum(orders.c.amount).label('total_sales')
]).group_by(orders.c.region).cte("regional_sales")
top_regions = select([regional_sales.c.region]).\\
where(
regional_sales.c.total_sales >
select([
func.sum(regional_sales.c.total_sales)/10
])
).cte("top_regions")
statement = select([
orders.c.region,
orders.c.product,
func.sum(orders.c.quantity).label("product_units"),
func.sum(orders.c.amount).label("product_sales")
]).where(orders.c.region.in_(
select([top_regions.c.region])
)).group_by(orders.c.region, orders.c.product)
result = conn.execute(statement).fetchall()
Example 2, WITH RECURSIVE::
from sqlalchemy import Table, Column, String, Integer, MetaData, \\
select, func
metadata = MetaData()
parts = Table('parts', metadata,
Column('part', String),
Column('sub_part', String),
Column('quantity', Integer),
)
included_parts = select([
parts.c.sub_part,
parts.c.part,
parts.c.quantity]).\\
where(parts.c.part=='our part').\\
cte(recursive=True)
incl_alias = included_parts.alias()
parts_alias = parts.alias()
included_parts = included_parts.union_all(
select([
parts_alias.c.part,
parts_alias.c.sub_part,
parts_alias.c.quantity
]).
where(parts_alias.c.part==incl_alias.c.sub_part)
)
statement = select([
included_parts.c.sub_part,
func.sum(included_parts.c.quantity).label('total_quantity')
]).\\
select_from(included_parts.join(parts,
included_parts.c.part==parts.c.part)).\\
group_by(included_parts.c.sub_part)
result = conn.execute(statement).fetchall()
See also:
:meth:`.orm.query.Query.cte` - ORM version of :meth:`._SelectBase.cte`.
"""
return CTE(self, name=name, recursive=recursive)
@_generative @_generative
@util.deprecated('0.6', @util.deprecated('0.6',
message=":func:`.autocommit` is deprecated. Use " message=":func:`.autocommit` is deprecated. Use "
@ -4602,7 +4833,7 @@ class Select(_SelectBase):
The text of the hint is rendered in the appropriate The text of the hint is rendered in the appropriate
location for the database backend in use, relative location for the database backend in use, relative
to the given :class:`.Table` or :class:`.Alias` passed as the to the given :class:`.Table` or :class:`.Alias` passed as the
*selectable* argument. The dialect implementation ``selectable`` argument. The dialect implementation
typically uses Python string substitution syntax typically uses Python string substitution syntax
with the token ``%(name)s`` to render the name of with the token ``%(name)s`` to render the name of
the table or alias. E.g. when using Oracle, the the table or alias. E.g. when using Oracle, the
@ -4999,7 +5230,9 @@ class Select(_SelectBase):
def _populate_column_collection(self): def _populate_column_collection(self):
for c in self.inner_columns: for c in self.inner_columns:
if hasattr(c, '_make_proxy'): if hasattr(c, '_make_proxy'):
c._make_proxy(self, name=self.use_labels and c._label or None) c._make_proxy(self,
name=self.use_labels
and c._label or None)
def self_group(self, against=None): def self_group(self, against=None):
"""return a 'grouping' construct as per the ClauseElement """return a 'grouping' construct as per the ClauseElement
@ -5086,6 +5319,7 @@ class UpdateBase(Executable, ClauseElement):
_execution_options = \ _execution_options = \
Executable._execution_options.union({'autocommit': True}) Executable._execution_options.union({'autocommit': True})
kwargs = util.immutabledict() kwargs = util.immutabledict()
_hints = util.immutabledict()
def _process_colparams(self, parameters): def _process_colparams(self, parameters):
if isinstance(parameters, (list, tuple)): if isinstance(parameters, (list, tuple)):
@ -5166,6 +5400,45 @@ class UpdateBase(Executable, ClauseElement):
""" """
self._returning = cols self._returning = cols
@_generative
def with_hint(self, text, selectable=None, dialect_name="*"):
"""Add a table hint for a single table to this
INSERT/UPDATE/DELETE statement.
.. note::
:meth:`.UpdateBase.with_hint` currently applies only to
Microsoft SQL Server. For MySQL INSERT hints, use
:meth:`.Insert.prefix_with`. UPDATE/DELETE hints for
MySQL will be added in a future release.
The text of the hint is rendered in the appropriate
location for the database backend in use, relative
to the :class:`.Table` that is the subject of this
statement, or optionally to that of the given
:class:`.Table` passed as the ``selectable`` argument.
The ``dialect_name`` option will limit the rendering of a particular
hint to a particular backend. For example, to add a hint
that only takes effect for SQL Server::
mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql")
New in 0.7.6.
:param text: Text of the hint.
:param selectable: optional :class:`.Table` that specifies
an element of the FROM clause within an UPDATE or DELETE
to be the subject of the hint - applies only to certain backends.
:param dialect_name: defaults to ``*``, if specified as the name
of a particular dialect, will apply these hints only when
that dialect is in use.
"""
if selectable is None:
selectable = self.table
self._hints = self._hints.union({(selectable, dialect_name): text})
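A hedged sketch of the new method applied to an UPDATE (the table and hint text are illustrative only):

    from sqlalchemy import Table, Column, Integer, MetaData

    m = MetaData()
    docs = Table('documents', m,
        Column('id', Integer),
        Column('flag', Integer),
    )

    # the hint is stored keyed by (table, dialect) and rendered only
    # when the mssql dialect compiles the statement; other backends
    # ignore it
    stmt = docs.update().\
            values(flag=1).\
            with_hint("WITH (ROWLOCK)", dialect_name="mssql")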
class ValuesBase(UpdateBase): class ValuesBase(UpdateBase):
"""Supplies support for :meth:`.ValuesBase.values` to INSERT and UPDATE constructs.""" """Supplies support for :meth:`.ValuesBase.values` to INSERT and UPDATE constructs."""

32
libs/sqlalchemy/sql/visitors.py

@ -34,11 +34,19 @@ __all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
'cloned_traverse', 'replacement_traverse'] 'cloned_traverse', 'replacement_traverse']
class VisitableType(type): class VisitableType(type):
"""Metaclass which checks for a `__visit_name__` attribute and """Metaclass which assigns a `_compiler_dispatch` method to classes
applies `_compiler_dispatch` method to classes. having a `__visit_name__` attribute.
The _compiler_dispatch attribute becomes an instance method which
looks approximately like the following::
def _compiler_dispatch(self, visitor, **kw):
'''Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.'''
return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
Classes having no __visit_name__ attribute will remain unaffected.
""" """
def __init__(cls, clsname, bases, clsdict): def __init__(cls, clsname, bases, clsdict):
if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'): if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'):
super(VisitableType, cls).__init__(clsname, bases, clsdict) super(VisitableType, cls).__init__(clsname, bases, clsdict)
@ -48,19 +56,31 @@ class VisitableType(type):
super(VisitableType, cls).__init__(clsname, bases, clsdict) super(VisitableType, cls).__init__(clsname, bases, clsdict)
def _generate_dispatch(cls): def _generate_dispatch(cls):
# set up an optimized visit dispatch function """Return an optimized visit dispatch function for the given cls
# for use by the compiler for use by the compiler.
"""
if '__visit_name__' in cls.__dict__: if '__visit_name__' in cls.__dict__:
visit_name = cls.__visit_name__ visit_name = cls.__visit_name__
if isinstance(visit_name, str): if isinstance(visit_name, str):
# There is an optimization opportunity here because the
# string name of the class's __visit_name__ is known at
# this early stage (import time) so it can be pre-constructed.
getter = operator.attrgetter("visit_%s" % visit_name) getter = operator.attrgetter("visit_%s" % visit_name)
def _compiler_dispatch(self, visitor, **kw): def _compiler_dispatch(self, visitor, **kw):
return getter(visitor)(self, **kw) return getter(visitor)(self, **kw)
else: else:
# The optimization opportunity is lost for this case because the
# __visit_name__ is not yet a string. As a result, the visit
# string has to be recalculated with each compilation.
def _compiler_dispatch(self, visitor, **kw): def _compiler_dispatch(self, visitor, **kw):
return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
_compiler_dispatch.__doc__ = \
"""Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.
"""
cls._compiler_dispatch = _compiler_dispatch cls._compiler_dispatch = _compiler_dispatch
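To make the dispatch concrete, a toy sketch (Widget and Printer are invented names for illustration; real clause elements acquire _compiler_dispatch the same way):

    from sqlalchemy.sql.visitors import Visitable

    class Widget(Visitable):
        # VisitableType sees the string at class-creation time and
        # attaches the pre-constructed attrgetter-based dispatcher
        __visit_name__ = 'widget'

    class Printer(object):
        def visit_widget(self, element, **kw):
            return "saw a widget"

    assert Widget()._compiler_dispatch(Printer()) == "saw a widget"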
class Visitable(object): class Visitable(object):

67
libs/sqlalchemy/types.py

@ -397,7 +397,7 @@ class TypeDecorator(TypeEngine):
def copy(self): def copy(self):
return MyType(self.impl.length) return MyType(self.impl.length)
The class-level "impl" variable is required, and can reference any The class-level "impl" attribute is required, and can reference any
TypeEngine class. Alternatively, the load_dialect_impl() method TypeEngine class. Alternatively, the load_dialect_impl() method
can be used to provide different type classes based on the dialect can be used to provide different type classes based on the dialect
given; in this case, the "impl" variable can reference given; in this case, the "impl" variable can reference
@ -457,15 +457,19 @@ class TypeDecorator(TypeEngine):
Arguments sent here are passed to the constructor Arguments sent here are passed to the constructor
of the class assigned to the ``impl`` class level attribute, of the class assigned to the ``impl`` class level attribute,
where the ``self.impl`` attribute is assigned an instance assuming the ``impl`` is a callable, and the resulting
of the implementation type. If ``impl`` at the class level object is assigned to the ``self.impl`` instance attribute
is already an instance, then it's assigned to ``self.impl`` (thus overriding the class attribute of the same name).
as is.
If the class level ``impl`` is not a callable (the unusual case),
it will be assigned to the same instance attribute 'as-is',
ignoring those arguments passed to the constructor.
Subclasses can override this to customize the generation Subclasses can override this to customize the generation
of ``self.impl``. of ``self.impl`` entirely.
""" """
if not hasattr(self.__class__, 'impl'): if not hasattr(self.__class__, 'impl'):
raise AssertionError("TypeDecorator implementations " raise AssertionError("TypeDecorator implementations "
"require a class-level variable " "require a class-level variable "
@ -475,6 +479,9 @@ class TypeDecorator(TypeEngine):
def _gen_dialect_impl(self, dialect): def _gen_dialect_impl(self, dialect):
"""
#todo
"""
adapted = dialect.type_descriptor(self) adapted = dialect.type_descriptor(self)
if adapted is not self: if adapted is not self:
return adapted return adapted
@ -494,6 +501,9 @@ class TypeDecorator(TypeEngine):
@property @property
def _type_affinity(self): def _type_affinity(self):
"""
#todo
"""
return self.impl._type_affinity return self.impl._type_affinity
def type_engine(self, dialect): def type_engine(self, dialect):
@ -531,7 +541,6 @@ class TypeDecorator(TypeEngine):
def __getattr__(self, key): def __getattr__(self, key):
"""Proxy all other undefined accessors to the underlying """Proxy all other undefined accessors to the underlying
implementation.""" implementation."""
return getattr(self.impl, key) return getattr(self.impl, key)
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
@ -542,29 +551,52 @@ class TypeDecorator(TypeEngine):
:class:`.TypeEngine` object, and from there to the :class:`.TypeEngine` object, and from there to the
DBAPI ``execute()`` method. DBAPI ``execute()`` method.
:param value: the value. Can be None. The operation could be anything desired to perform custom
behavior, such as transforming or serializing data.
This could also be used as a hook for validating logic.
This operation should be designed with the reverse operation
in mind, which would be the process_result_value method of
this class.
:param value: Data to operate upon, of any type expected by
this method in the subclass. Can be ``None``.
:param dialect: the :class:`.Dialect` in use. :param dialect: the :class:`.Dialect` in use.
""" """
raise NotImplementedError() raise NotImplementedError()
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
"""Receive a result-row column value to be converted. """Receive a result-row column value to be converted.
Subclasses should implement this method to operate on data
fetched from the database.
Subclasses override this method to return the Subclasses override this method to return the
value that should be passed back to the application, value that should be passed back to the application,
given a value that is already processed by given a value that is already processed by
the underlying :class:`.TypeEngine` object, originally the underlying :class:`.TypeEngine` object, originally
from the DBAPI cursor method ``fetchone()`` or similar. from the DBAPI cursor method ``fetchone()`` or similar.
:param value: the value. Can be None. The operation could be anything desired to perform custom
behavior, such as transforming or serializing data.
This could also be used as a hook for validating logic.
:param value: Data to operate upon, of any type expected by
this method in the subclass. Can be ``None``.
:param dialect: the :class:`.Dialect` in use. :param dialect: the :class:`.Dialect` in use.
This operation should be designed to be reversible by
the "process_bind_param" method of this class.
""" """
raise NotImplementedError() raise NotImplementedError()
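Putting the two hooks together, a minimal round-trip sketch (essentially the canonical JSON-column recipe; the class name and choice of VARCHAR are arbitrary):

    import json
    from sqlalchemy.types import TypeDecorator, VARCHAR

    class JSONEncoded(TypeDecorator):
        """Stores Python structures as JSON strings."""
        impl = VARCHAR

        def process_bind_param(self, value, dialect):
            # outbound: serialize before the value reaches the DBAPI
            if value is not None:
                value = json.dumps(value)
            return value

        def process_result_value(self, value, dialect):
            # inbound: the reverse operation, as the docstrings advise
            if value is not None:
                value = json.loads(value)
            return value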
def bind_processor(self, dialect): def bind_processor(self, dialect):
"""Provide a bound value processing function for the given :class:`.Dialect`. """Provide a bound value processing function for the
given :class:`.Dialect`.
This is the method that fulfills the :class:`.TypeEngine` This is the method that fulfills the :class:`.TypeEngine`
contract for bound value conversion. :class:`.TypeDecorator` contract for bound value conversion. :class:`.TypeDecorator`
@ -575,6 +607,11 @@ class TypeDecorator(TypeEngine):
though it's likely best to use :meth:`process_bind_param` so that though it's likely best to use :meth:`process_bind_param` so that
the processing provided by ``self.impl`` is maintained. the processing provided by ``self.impl`` is maintained.
:param dialect: Dialect instance in use.
This method is the reverse counterpart to the
:meth:`result_processor` method of this class.
""" """
if self.__class__.process_bind_param.func_code \ if self.__class__.process_bind_param.func_code \
is not TypeDecorator.process_bind_param.func_code: is not TypeDecorator.process_bind_param.func_code:
@ -604,6 +641,12 @@ class TypeDecorator(TypeEngine):
though it's likely best to use :meth:`process_result_value` so that though it's likely best to use :meth:`process_result_value` so that
the processing provided by ``self.impl`` is maintained. the processing provided by ``self.impl`` is maintained.
:param dialect: Dialect instance in use.
:param coltype: An SQLAlchemy data type
This method is the reverse counterpart to the
:meth:`bind_processor` method of this class.
""" """
if self.__class__.process_result_value.func_code \ if self.__class__.process_result_value.func_code \
is not TypeDecorator.process_result_value.func_code: is not TypeDecorator.process_result_value.func_code:
@ -654,6 +697,7 @@ class TypeDecorator(TypeEngine):
has local state that should be deep-copied. has local state that should be deep-copied.
""" """
instance = self.__class__.__new__(self.__class__) instance = self.__class__.__new__(self.__class__)
instance.__dict__.update(self.__dict__) instance.__dict__.update(self.__dict__)
return instance return instance
@ -724,6 +768,9 @@ class TypeDecorator(TypeEngine):
return self.impl.is_mutable() return self.impl.is_mutable()
def _adapt_expression(self, op, othertype): def _adapt_expression(self, op, othertype):
"""
#todo
"""
op, typ = self.impl._adapt_expression(op, othertype) op, typ = self.impl._adapt_expression(op, othertype)
if typ is self.impl: if typ is self.impl:
return op, self return op, self

2
libs/sqlalchemy/util/__init__.py

@ -7,7 +7,7 @@
from compat import callable, cmp, reduce, defaultdict, py25_dict, \ from compat import callable, cmp, reduce, defaultdict, py25_dict, \
threading, py3k_warning, jython, pypy, win32, set_types, buffer, pickle, \ threading, py3k_warning, jython, pypy, win32, set_types, buffer, pickle, \
update_wrapper, partial, md5_hex, decode_slice, dottedgetter,\ update_wrapper, partial, md5_hex, decode_slice, dottedgetter,\
parse_qsl, any parse_qsl, any, contextmanager
from _collections import NamedTuple, ImmutableContainer, immutabledict, \ from _collections import NamedTuple, ImmutableContainer, immutabledict, \
Properties, OrderedProperties, ImmutableProperties, OrderedDict, \ Properties, OrderedProperties, ImmutableProperties, OrderedDict, \

6
libs/sqlalchemy/util/compat.py

@ -57,6 +57,12 @@ buffer = buffer
# end Py2K # end Py2K
try: try:
from contextlib import contextmanager
except ImportError:
def contextmanager(fn):
return fn
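The pass-through fallback only matters on interpreters predating contextlib (Python < 2.5, which also lack the ``with`` statement); on anything newer the import succeeds and the decorator behaves normally, e.g. (transaction_scope is a hypothetical helper, not part of SQLAlchemy):

    from sqlalchemy.util import contextmanager

    @contextmanager
    def transaction_scope(conn):
        # begin/commit/rollback around the block using the public
        # Connection API; invented helper for illustration only
        trans = conn.begin()
        try:
            yield conn
            trans.commit()
        except:
            trans.rollback()
            raise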
try:
from functools import update_wrapper from functools import update_wrapper
except ImportError: except ImportError:
def update_wrapper(wrapper, wrapped, def update_wrapper(wrapper, wrapped,
