diff --git a/libs/sqlalchemy/__init__.py b/libs/sqlalchemy/__init__.py index ef5f385..03293b5 100644 --- a/libs/sqlalchemy/__init__.py +++ b/libs/sqlalchemy/__init__.py @@ -117,7 +117,7 @@ from sqlalchemy.engine import create_engine, engine_from_config __all__ = sorted(name for name, obj in locals().items() if not (name.startswith('_') or inspect.ismodule(obj))) -__version__ = '0.7.5' +__version__ = '0.7.6' del inspect, sys diff --git a/libs/sqlalchemy/cextension/processors.c b/libs/sqlalchemy/cextension/processors.c index b539f68..427db5d 100644 --- a/libs/sqlalchemy/cextension/processors.c +++ b/libs/sqlalchemy/cextension/processors.c @@ -342,23 +342,18 @@ DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value) if (value == Py_None) Py_RETURN_NONE; - if (PyFloat_CheckExact(value)) { - /* Decimal does not accept float values directly */ - args = PyTuple_Pack(1, value); - if (args == NULL) - return NULL; + args = PyTuple_Pack(1, value); + if (args == NULL) + return NULL; - str = PyString_Format(self->format, args); - Py_DECREF(args); - if (str == NULL) - return NULL; + str = PyString_Format(self->format, args); + Py_DECREF(args); + if (str == NULL) + return NULL; - result = PyObject_CallFunctionObjArgs(self->type, str, NULL); - Py_DECREF(str); - return result; - } else { - return PyObject_CallFunctionObjArgs(self->type, value, NULL); - } + result = PyObject_CallFunctionObjArgs(self->type, str, NULL); + Py_DECREF(str); + return result; } static void diff --git a/libs/sqlalchemy/cextension/resultproxy.c b/libs/sqlalchemy/cextension/resultproxy.c index 64b6855..3494cca 100644 --- a/libs/sqlalchemy/cextension/resultproxy.c +++ b/libs/sqlalchemy/cextension/resultproxy.c @@ -246,6 +246,7 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key) PyObject *exc_module, *exception; char *cstr_key; long index; + int key_fallback = 0; if (PyInt_CheckExact(key)) { index = PyInt_AS_LONG(key); @@ -276,12 +277,17 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key) "O", key); if (record == NULL) return NULL; + key_fallback = 1; } indexobject = PyTuple_GetItem(record, 2); if (indexobject == NULL) return NULL; + if (key_fallback) { + Py_DECREF(record); + } + if (indexobject == Py_None) { exc_module = PyImport_ImportModule("sqlalchemy.exc"); if (exc_module == NULL) @@ -347,7 +353,16 @@ BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name) else return tmp; - return BaseRowProxy_subscript(self, name); + tmp = BaseRowProxy_subscript(self, name); + if (tmp == NULL && PyErr_ExceptionMatches(PyExc_KeyError)) { + PyErr_Format( + PyExc_AttributeError, + "Could not locate column in row for column '%.200s'", + PyString_AsString(name) + ); + return NULL; + } + return tmp; } /*********************** diff --git a/libs/sqlalchemy/dialects/firebird/base.py b/libs/sqlalchemy/dialects/firebird/base.py index 8cf2ded..031c689 100644 --- a/libs/sqlalchemy/dialects/firebird/base.py +++ b/libs/sqlalchemy/dialects/firebird/base.py @@ -215,7 +215,7 @@ class FBCompiler(sql.compiler.SQLCompiler): # Override to not use the AS keyword which FB 1.5 does not like if asfrom: alias_name = isinstance(alias.name, - expression._generated_label) and \ + expression._truncated_label) and \ self._truncated_identifier("alias", alias.name) or alias.name diff --git a/libs/sqlalchemy/dialects/mssql/base.py b/libs/sqlalchemy/dialects/mssql/base.py index f7c94aa..103b0a3 100644 --- a/libs/sqlalchemy/dialects/mssql/base.py +++ b/libs/sqlalchemy/dialects/mssql/base.py @@ -791,6 +791,9 @@ class MSSQLCompiler(compiler.SQLCompiler): def get_from_hint_text(self, table, text): return text + def get_crud_hint_text(self, table, text): + return text + def limit_clause(self, select): # Limit in mssql is after the select keyword return "" @@ -949,6 +952,13 @@ class MSSQLCompiler(compiler.SQLCompiler): ] return 'OUTPUT ' + ', '.join(columns) + def get_cte_preamble(self, recursive): + # SQL Server finds it too inconvenient to accept + # an entirely optional, SQL standard specified, + # "RECURSIVE" word with their "WITH", + # so here we go + return "WITH" + def label_select_column(self, select, column, asfrom): if isinstance(column, expression.Function): return column.label(None) diff --git a/libs/sqlalchemy/dialects/mysql/base.py b/libs/sqlalchemy/dialects/mysql/base.py index 6aa250d..d9ab5a3 100644 --- a/libs/sqlalchemy/dialects/mysql/base.py +++ b/libs/sqlalchemy/dialects/mysql/base.py @@ -84,6 +84,23 @@ all lower case both within SQLAlchemy as well as on the MySQL database itself, especially if database reflection features are to be used. +Transaction Isolation Level +--------------------------- + +:func:`.create_engine` accepts an ``isolation_level`` +parameter which results in the command ``SET SESSION +TRANSACTION ISOLATION LEVEL `` being invoked for +every new connection. Valid values for this parameter are +``READ COMMITTED``, ``READ UNCOMMITTED``, +``REPEATABLE READ``, and ``SERIALIZABLE``:: + + engine = create_engine( + "mysql://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED" + ) + +(new in 0.7.6) + Keys ---- @@ -221,8 +238,29 @@ simply passed through to the underlying CREATE INDEX command, so it *must* be an integer. MySQL only allows a length for an index if it is for a CHAR, VARCHAR, TEXT, BINARY, VARBINARY and BLOB. +Index Types +~~~~~~~~~~~~~ + +Some MySQL storage engines permit you to specify an index type when creating +an index or primary key constraint. SQLAlchemy provides this feature via the +``mysql_using`` parameter on :class:`.Index`:: + + Index('my_index', my_table.c.data, mysql_using='hash') + +As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`:: + + PrimaryKeyConstraint("data", mysql_using='hash') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index +type for your MySQL storage engine. + More information can be found at: + http://dev.mysql.com/doc/refman/5.0/en/create-index.html + +http://dev.mysql.com/doc/refman/5.0/en/create-table.html + """ import datetime, inspect, re, sys @@ -1331,7 +1369,8 @@ class MySQLCompiler(compiler.SQLCompiler): return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) for t in [from_table] + list(extra_froms)) - def update_from_clause(self, update_stmt, from_table, extra_froms, **kw): + def update_from_clause(self, update_stmt, from_table, + extra_froms, from_hints, **kw): return None @@ -1421,35 +1460,50 @@ class MySQLDDLCompiler(compiler.DDLCompiler): table_opts.append(joiner.join((opt, arg))) return ' '.join(table_opts) + def visit_create_index(self, create): index = create.element preparer = self.preparer + table = preparer.format_table(index.table) + columns = [preparer.quote(c.name, c.quote) for c in index.columns] + name = preparer.quote( + self._index_identifier(index.name), + index.quote) + text = "CREATE " if index.unique: text += "UNIQUE " - text += "INDEX %s ON %s " \ - % (preparer.quote(self._index_identifier(index.name), - index.quote),preparer.format_table(index.table)) + text += "INDEX %s ON %s " % (name, table) + + columns = ', '.join(columns) if 'mysql_length' in index.kwargs: length = index.kwargs['mysql_length'] + text += "(%s(%d))" % (columns, length) else: - length = None - if length is not None: - text+= "(%s(%d))" \ - % (', '.join(preparer.quote(c.name, c.quote) - for c in index.columns), length) - else: - text+= "(%s)" \ - % (', '.join(preparer.quote(c.name, c.quote) - for c in index.columns)) + text += "(%s)" % (columns) + + if 'mysql_using' in index.kwargs: + using = index.kwargs['mysql_using'] + text += " USING %s" % (preparer.quote(using, index.quote)) + return text + def visit_primary_key_constraint(self, constraint): + text = super(MySQLDDLCompiler, self).\ + visit_primary_key_constraint(constraint) + if "mysql_using" in constraint.kwargs: + using = constraint.kwargs['mysql_using'] + text += " USING %s" % ( + self.preparer.quote(using, constraint.quote)) + return text def visit_drop_index(self, drop): index = drop.element return "\nDROP INDEX %s ON %s" % \ - (self.preparer.quote(self._index_identifier(index.name), index.quote), + (self.preparer.quote( + self._index_identifier(index.name), index.quote + ), self.preparer.format_table(index.table)) def visit_drop_constraint(self, drop): @@ -1768,8 +1822,40 @@ class MySQLDialect(default.DefaultDialect): _backslash_escapes = True _server_ansiquotes = False - def __init__(self, use_ansiquotes=None, **kwargs): + def __init__(self, use_ansiquotes=None, isolation_level=None, **kwargs): default.DefaultDialect.__init__(self, **kwargs) + self.isolation_level = isolation_level + + def on_connect(self): + if self.isolation_level is not None: + def connect(conn): + self.set_isolation_level(conn, self.isolation_level) + return connect + else: + return None + + _isolation_lookup = set(['SERIALIZABLE', + 'READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ']) + + def set_isolation_level(self, connection, level): + level = level.replace('_', ' ') + if level not in self._isolation_lookup: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" % + (level, self.name, ", ".join(self._isolation_lookup)) + ) + cursor = connection.cursor() + cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level) + cursor.execute("COMMIT") + cursor.close() + + def get_isolation_level(self, connection): + cursor = connection.cursor() + cursor.execute('SELECT @@tx_isolation') + val = cursor.fetchone()[0] + cursor.close() + return val.upper().replace("-", " ") def do_commit(self, connection): """Execute a COMMIT.""" diff --git a/libs/sqlalchemy/dialects/oracle/base.py b/libs/sqlalchemy/dialects/oracle/base.py index 88e5062..dd761ae 100644 --- a/libs/sqlalchemy/dialects/oracle/base.py +++ b/libs/sqlalchemy/dialects/oracle/base.py @@ -158,7 +158,7 @@ RESERVED_WORDS = \ 'AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS '\ 'NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER '\ 'CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR '\ - 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT'.split()) + 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL'.split()) NO_ARG_FNS = set('UID CURRENT_DATE SYSDATE USER ' 'CURRENT_TIME CURRENT_TIMESTAMP'.split()) @@ -309,6 +309,9 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): "", ) + def visit_LONG(self, type_): + return "LONG" + def visit_TIMESTAMP(self, type_): if type_.timezone: return "TIMESTAMP WITH TIME ZONE" @@ -481,7 +484,7 @@ class OracleCompiler(compiler.SQLCompiler): """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" if asfrom or ashint: - alias_name = isinstance(alias.name, expression._generated_label) and \ + alias_name = isinstance(alias.name, expression._truncated_label) and \ self._truncated_identifier("alias", alias.name) or alias.name if ashint: diff --git a/libs/sqlalchemy/dialects/oracle/cx_oracle.py b/libs/sqlalchemy/dialects/oracle/cx_oracle.py index 64526d2..5001acc 100644 --- a/libs/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/libs/sqlalchemy/dialects/oracle/cx_oracle.py @@ -77,7 +77,7 @@ with this feature but it should be regarded as experimental. Precision Numerics ------------------ -The SQLAlchemy dialect goes thorugh a lot of steps to ensure +The SQLAlchemy dialect goes through a lot of steps to ensure that decimal numbers are sent and received with full accuracy. An "outputtypehandler" callable is associated with each cx_oracle connection object which detects numeric types and @@ -89,6 +89,21 @@ this behavior, and will coerce the ``Decimal`` to ``float`` if the ``asdecimal`` flag is ``False`` (default on :class:`.Float`, optional on :class:`.Numeric`). +Because the handler coerces to ``Decimal`` in all cases first, +the feature can detract significantly from performance. +If precision numerics aren't required, the decimal handling +can be disabled by passing the flag ``coerce_to_decimal=False`` +to :func:`.create_engine`:: + + engine = create_engine("oracle+cx_oracle://dsn", + coerce_to_decimal=False) + +The ``coerce_to_decimal`` flag is new in 0.7.6. + +Another alternative to performance is to use the +`cdecimal `_ library; +see :class:`.Numeric` for additional notes. + The handler attempts to use the "precision" and "scale" attributes of the result set column to best determine if subsequent incoming values should be received as ``Decimal`` as @@ -468,6 +483,7 @@ class OracleDialect_cx_oracle(OracleDialect): auto_convert_lobs=True, threaded=True, allow_twophase=True, + coerce_to_decimal=True, arraysize=50, **kwargs): OracleDialect.__init__(self, **kwargs) self.threaded = threaded @@ -491,7 +507,12 @@ class OracleDialect_cx_oracle(OracleDialect): self._cx_oracle_unicode_types = types("UNICODE", "NCLOB") self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB") self.supports_unicode_binds = self.cx_oracle_ver >= (5, 0) - self.supports_native_decimal = self.cx_oracle_ver >= (5, 0) + + self.supports_native_decimal = ( + self.cx_oracle_ver >= (5, 0) and + coerce_to_decimal + ) + self._cx_oracle_native_nvarchar = self.cx_oracle_ver >= (5, 0) if self.cx_oracle_ver is None: @@ -603,7 +624,9 @@ class OracleDialect_cx_oracle(OracleDialect): size, precision, scale): # convert all NUMBER with precision + positive scale to Decimal # this almost allows "native decimal" mode. - if defaultType == cx_Oracle.NUMBER and precision and scale > 0: + if self.supports_native_decimal and \ + defaultType == cx_Oracle.NUMBER and \ + precision and scale > 0: return cursor.var( cx_Oracle.STRING, 255, @@ -614,7 +637,8 @@ class OracleDialect_cx_oracle(OracleDialect): # make a decision based on each value received - the type # may change from row to row (!). This kills # off "native decimal" mode, handlers still needed. - elif defaultType == cx_Oracle.NUMBER \ + elif self.supports_native_decimal and \ + defaultType == cx_Oracle.NUMBER \ and not precision and scale <= 0: return cursor.var( cx_Oracle.STRING, diff --git a/libs/sqlalchemy/dialects/postgresql/base.py b/libs/sqlalchemy/dialects/postgresql/base.py index 69c11d8..c4c2bbd 100644 --- a/libs/sqlalchemy/dialects/postgresql/base.py +++ b/libs/sqlalchemy/dialects/postgresql/base.py @@ -47,9 +47,18 @@ Transaction Isolation Level :func:`.create_engine` accepts an ``isolation_level`` parameter which results in the command ``SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL `` being invoked for every new connection. Valid values for this -parameter are ``READ_COMMITTED``, ``READ_UNCOMMITTED``, ``REPEATABLE_READ``, -and ``SERIALIZABLE``. Note that the psycopg2 dialect does *not* use this -technique and uses psycopg2-specific APIs (see that dialect for details). +parameter are ``READ COMMITTED``, ``READ UNCOMMITTED``, ``REPEATABLE READ``, +and ``SERIALIZABLE``:: + + engine = create_engine( + "postgresql+pg8000://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED" + ) + +When using the psycopg2 dialect, a psycopg2-specific method of setting +transaction isolation level is used, but the API of ``isolation_level`` +remains the same - see :ref:`psycopg2_isolation`. + Remote / Cross-Schema Table Introspection ----------------------------------------- diff --git a/libs/sqlalchemy/dialects/postgresql/psycopg2.py b/libs/sqlalchemy/dialects/postgresql/psycopg2.py index c66180f..5aa9397 100644 --- a/libs/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/libs/sqlalchemy/dialects/postgresql/psycopg2.py @@ -38,6 +38,26 @@ psycopg2-specific keyword arguments which are accepted by * *use_native_unicode* - Enable the usage of Psycopg2 "native unicode" mode per connection. True by default. +Unix Domain Connections +------------------------ + +psycopg2 supports connecting via Unix domain connections. When the ``host`` +portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2, +which specifies Unix-domain communication rather than TCP/IP communication:: + + create_engine("postgresql+psycopg2://user:password@/dbname") + +By default, the socket file used is to connect to a Unix-domain socket +in ``/tmp``, or whatever socket directory was specified when PostgreSQL +was built. This value can be overridden by passing a pathname to psycopg2, +using ``host`` as an additional keyword argument:: + + create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql") + +See also: + +`PQconnectdbParams `_ + Per-Statement/Connection Execution Options ------------------------------------------- @@ -97,6 +117,8 @@ Transactions The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. +.. _psycopg2_isolation: + Transaction Isolation Level --------------------------- diff --git a/libs/sqlalchemy/dialects/sqlite/base.py b/libs/sqlalchemy/dialects/sqlite/base.py index f9520af..10a0d88 100644 --- a/libs/sqlalchemy/dialects/sqlite/base.py +++ b/libs/sqlalchemy/dialects/sqlite/base.py @@ -441,20 +441,6 @@ class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): result = self.quote_schema(index.table.schema, index.table.quote_schema) + "." + result return result -class SQLiteExecutionContext(default.DefaultExecutionContext): - def get_result_proxy(self): - rp = base.ResultProxy(self) - if rp._metadata: - # adjust for dotted column names. SQLite - # in the case of UNION may store col names as - # "tablename.colname" - # in cursor.description - for colname in rp._metadata.keys: - if "." in colname: - trunc_col = colname.split(".")[1] - rp._metadata._set_keymap_synonym(trunc_col, colname) - return rp - class SQLiteDialect(default.DefaultDialect): name = 'sqlite' supports_alter = False @@ -472,7 +458,6 @@ class SQLiteDialect(default.DefaultDialect): ischema_names = ischema_names colspecs = colspecs isolation_level = None - execution_ctx_cls = SQLiteExecutionContext supports_cast = True supports_default_values = True @@ -540,6 +525,16 @@ class SQLiteDialect(default.DefaultDialect): else: return None + def _translate_colname(self, colname): + # adjust for dotted column names. SQLite + # in the case of UNION may store col names as + # "tablename.colname" + # in cursor.description + if "." in colname: + return colname.split(".")[1], colname + else: + return colname, None + @reflection.cache def get_table_names(self, connection, schema=None, **kw): if schema is not None: diff --git a/libs/sqlalchemy/engine/__init__.py b/libs/sqlalchemy/engine/__init__.py index 4fac3e5..23b4b0b 100644 --- a/libs/sqlalchemy/engine/__init__.py +++ b/libs/sqlalchemy/engine/__init__.py @@ -306,6 +306,12 @@ def create_engine(*args, **kwargs): this is configurable with the MySQLDB connection itself and the server configuration as well). + :param pool_reset_on_return='rollback': set the "reset on return" + behavior of the pool, which is whether ``rollback()``, + ``commit()``, or nothing is called upon connections + being returned to the pool. See the docstring for + ``reset_on_return`` at :class:`.Pool`. (new as of 0.7.6) + :param pool_timeout=30: number of seconds to wait before giving up on getting a connection from the pool. This is only used with :class:`~sqlalchemy.pool.QueuePool`. diff --git a/libs/sqlalchemy/engine/base.py b/libs/sqlalchemy/engine/base.py index db19fe7..d16fc9c 100644 --- a/libs/sqlalchemy/engine/base.py +++ b/libs/sqlalchemy/engine/base.py @@ -491,14 +491,23 @@ class Dialect(object): raise NotImplementedError() def do_executemany(self, cursor, statement, parameters, context=None): - """Provide an implementation of *cursor.executemany(statement, - parameters)*.""" + """Provide an implementation of ``cursor.executemany(statement, + parameters)``.""" raise NotImplementedError() def do_execute(self, cursor, statement, parameters, context=None): - """Provide an implementation of *cursor.execute(statement, - parameters)*.""" + """Provide an implementation of ``cursor.execute(statement, + parameters)``.""" + + raise NotImplementedError() + + def do_execute_no_params(self, cursor, statement, parameters, context=None): + """Provide an implementation of ``cursor.execute(statement)``. + + The parameter collection should not be sent. + + """ raise NotImplementedError() @@ -777,12 +786,12 @@ class Connectable(object): def connect(self, **kwargs): """Return a :class:`.Connection` object. - + Depending on context, this may be ``self`` if this object is already an instance of :class:`.Connection`, or a newly procured :class:`.Connection` if this object is an instance of :class:`.Engine`. - + """ def contextual_connect(self): @@ -793,7 +802,7 @@ class Connectable(object): is already an instance of :class:`.Connection`, or a newly procured :class:`.Connection` if this object is an instance of :class:`.Engine`. - + """ raise NotImplementedError() @@ -904,6 +913,12 @@ class Connection(Connectable): c.__dict__ = self.__dict__.copy() return c + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + def execution_options(self, **opt): """ Set non-SQL options for the connection which take effect during execution. @@ -940,7 +955,7 @@ class Connection(Connectable): :param compiled_cache: Available on: Connection. A dictionary where :class:`.Compiled` objects will be cached when the :class:`.Connection` compiles a clause - expression into a :class:`.Compiled` object. + expression into a :class:`.Compiled` object. It is the user's responsibility to manage the size of this dictionary, which will have keys corresponding to the dialect, clause element, the column @@ -953,7 +968,7 @@ class Connection(Connectable): some operations, including flush operations. The caching used by the ORM internally supersedes a cache dictionary specified here. - + :param isolation_level: Available on: Connection. Set the transaction isolation level for the lifespan of this connection. Valid values include @@ -962,7 +977,7 @@ class Connection(Connectable): database specific, including those for :ref:`sqlite_toplevel`, :ref:`postgresql_toplevel` - see those dialect's documentation for further info. - + Note that this option necessarily affects the underying DBAPI connection for the lifespan of the originating :class:`.Connection`, and is not per-execution. This @@ -970,6 +985,18 @@ class Connection(Connectable): is returned to the connection pool, i.e. the :meth:`.Connection.close` method is called. + :param no_parameters: When ``True``, if the final parameter + list or dictionary is totally empty, will invoke the + statement on the cursor as ``cursor.execute(statement)``, + not passing the parameter collection at all. + Some DBAPIs such as psycopg2 and mysql-python consider + percent signs as significant only when parameters are + present; this option allows code to generate SQL + containing percent signs (and possibly other characters) + that is neutral regarding whether it's executed by the DBAPI + or piped into a script that's later invoked by + command line tools. New in 0.7.6. + :param stream_results: Available on: Connection, statement. Indicate to the dialect that results should be "streamed" and not pre-buffered, if possible. This is a limitation @@ -1113,17 +1140,35 @@ class Connection(Connectable): def begin(self): """Begin a transaction and return a transaction handle. - + The returned object is an instance of :class:`.Transaction`. + This object represents the "scope" of the transaction, + which completes when either the :meth:`.Transaction.rollback` + or :meth:`.Transaction.commit` method is called. + + Nested calls to :meth:`.begin` on the same :class:`.Connection` + will return new :class:`.Transaction` objects that represent + an emulated transaction within the scope of the enclosing + transaction, that is:: + + trans = conn.begin() # outermost transaction + trans2 = conn.begin() # "nested" + trans2.commit() # does nothing + trans.commit() # actually commits + + Calls to :meth:`.Transaction.commit` only have an effect + when invoked via the outermost :class:`.Transaction` object, though the + :meth:`.Transaction.rollback` method of any of the + :class:`.Transaction` objects will roll back the + transaction. - Repeated calls to ``begin`` on the same Connection will create - a lightweight, emulated nested transaction. Only the - outermost transaction may ``commit``. Calls to ``commit`` on - inner transactions are ignored. Any transaction in the - hierarchy may ``rollback``, however. + See also: - See also :meth:`.Connection.begin_nested`, - :meth:`.Connection.begin_twophase`. + :meth:`.Connection.begin_nested` - use a SAVEPOINT + + :meth:`.Connection.begin_twophase` - use a two phase /XID transaction + + :meth:`.Engine.begin` - context manager available from :class:`.Engine`. """ @@ -1157,7 +1202,7 @@ class Connection(Connectable): def begin_twophase(self, xid=None): """Begin a two-phase or XA transaction and return a transaction handle. - + The returned object is an instance of :class:`.TwoPhaseTransaction`, which in addition to the methods provided by :class:`.Transaction`, also provides a :meth:`~.TwoPhaseTransaction.prepare` @@ -1302,7 +1347,7 @@ class Connection(Connectable): def close(self): """Close this :class:`.Connection`. - + This results in a release of the underlying database resources, that is, the DBAPI connection referenced internally. The DBAPI connection is typically restored @@ -1313,7 +1358,7 @@ class Connection(Connectable): the DBAPI connection's ``rollback()`` method, regardless of any :class:`.Transaction` object that may be outstanding with regards to this :class:`.Connection`. - + After :meth:`~.Connection.close` is called, the :class:`.Connection` is permanently in a closed state, and will allow no further operations. @@ -1354,24 +1399,24 @@ class Connection(Connectable): * a :class:`.DDLElement` object * a :class:`.DefaultGenerator` object * a :class:`.Compiled` object - + :param \*multiparams/\**params: represent bound parameter values to be used in the execution. Typically, the format is either a collection of one or more dictionaries passed to \*multiparams:: - + conn.execute( table.insert(), {"id":1, "value":"v1"}, {"id":2, "value":"v2"} ) - + ...or individual key/values interpreted by \**params:: - + conn.execute( table.insert(), id=1, value="v1" ) - + In the case that a plain SQL string is passed, and the underlying DBAPI accepts positional bind parameters, a collection of tuples or individual values in \*multiparams may be passed:: @@ -1380,21 +1425,21 @@ class Connection(Connectable): "INSERT INTO table (id, value) VALUES (?, ?)", (1, "v1"), (2, "v2") ) - + conn.execute( "INSERT INTO table (id, value) VALUES (?, ?)", 1, "v1" ) - + Note above, the usage of a question mark "?" or other symbol is contingent upon the "paramstyle" accepted by the DBAPI in use, which may be any of "qmark", "named", "pyformat", "format", "numeric". See `pep-249 `_ for details on paramstyle. - + To execute a textual SQL statement which uses bound parameters in a DBAPI-agnostic way, use the :func:`~.expression.text` construct. - + """ for c in type(object).__mro__: if c in Connection.executors: @@ -1623,7 +1668,8 @@ class Connection(Connectable): if self._echo: self.engine.logger.info(statement) - self.engine.logger.info("%r", sql_util._repr_params(parameters, batches=10)) + self.engine.logger.info("%r", + sql_util._repr_params(parameters, batches=10)) try: if context.executemany: self.dialect.do_executemany( @@ -1631,6 +1677,11 @@ class Connection(Connectable): statement, parameters, context) + elif not parameters and context.no_parameters: + self.dialect.do_execute_no_params( + cursor, + statement, + context) else: self.dialect.do_execute( cursor, @@ -1845,33 +1896,41 @@ class Connection(Connectable): """Execute the given function within a transaction boundary. The function is passed this :class:`.Connection` - as the first argument, followed by the given \*args and \**kwargs. - - This is a shortcut for explicitly invoking - :meth:`.Connection.begin`, calling :meth:`.Transaction.commit` - upon success or :meth:`.Transaction.rollback` upon an - exception raise:: + as the first argument, followed by the given \*args and \**kwargs, + e.g.:: def do_something(conn, x, y): conn.execute("some statement", {'x':x, 'y':y}) - + conn.transaction(do_something, 5, 10) + + The operations inside the function are all invoked within the + context of a single :class:`.Transaction`. + Upon success, the transaction is committed. If an + exception is raised, the transaction is rolled back + before propagating the exception. + + .. note:: + + The :meth:`.transaction` method is superseded by + the usage of the Python ``with:`` statement, which can + be used with :meth:`.Connection.begin`:: + + with conn.begin(): + conn.execute("some statement", {'x':5, 'y':10}) + + As well as with :meth:`.Engine.begin`:: + + with engine.begin() as conn: + conn.execute("some statement", {'x':5, 'y':10}) - Note that context managers (i.e. the ``with`` statement) - present a more modern way of accomplishing the above, - using the :class:`.Transaction` object as a base:: + See also: - with conn.begin(): - conn.execute("some statement", {'x':5, 'y':10}) - - One advantage to the :meth:`.Connection.transaction` - method is that the same method is also available - on :class:`.Engine` as :meth:`.Engine.transaction` - - this method procures a :class:`.Connection` and then - performs the same operation, allowing equivalent - usage with either a :class:`.Connection` or :class:`.Engine` - without needing to know what kind of object - it is. + :meth:`.Engine.begin` - engine-level transactional + context + + :meth:`.Engine.transaction` - engine-level version of + :meth:`.Connection.transaction` """ @@ -1887,15 +1946,15 @@ class Connection(Connectable): def run_callable(self, callable_, *args, **kwargs): """Given a callable object or function, execute it, passing a :class:`.Connection` as the first argument. - + The given \*args and \**kwargs are passed subsequent to the :class:`.Connection` argument. - + This function, along with :meth:`.Engine.run_callable`, allows a function to be run with a :class:`.Connection` or :class:`.Engine` object without the need to know which one is being dealt with. - + """ return callable_(self, *args, **kwargs) @@ -1906,11 +1965,11 @@ class Connection(Connectable): class Transaction(object): """Represent a database transaction in progress. - + The :class:`.Transaction` object is procured by calling the :meth:`~.Connection.begin` method of :class:`.Connection`:: - + from sqlalchemy import create_engine engine = create_engine("postgresql://scott:tiger@localhost/test") connection = engine.connect() @@ -1923,7 +1982,7 @@ class Transaction(object): also implements a context manager interface so that the Python ``with`` statement can be used with the :meth:`.Connection.begin` method:: - + with connection.begin(): connection.execute("insert into x (a, b) values (1, 2)") @@ -1931,7 +1990,7 @@ class Transaction(object): See also: :meth:`.Connection.begin`, :meth:`.Connection.begin_twophase`, :meth:`.Connection.begin_nested`. - + .. index:: single: thread safety; Transaction """ @@ -2012,9 +2071,9 @@ class NestedTransaction(Transaction): A new :class:`.NestedTransaction` object may be procured using the :meth:`.Connection.begin_nested` method. - + The interface is the same as that of :class:`.Transaction`. - + """ def __init__(self, connection, parent): super(NestedTransaction, self).__init__(connection, parent) @@ -2033,13 +2092,13 @@ class NestedTransaction(Transaction): class TwoPhaseTransaction(Transaction): """Represent a two-phase transaction. - + A new :class:`.TwoPhaseTransaction` object may be procured using the :meth:`.Connection.begin_twophase` method. - + The interface is the same as that of :class:`.Transaction` with the addition of the :meth:`prepare` method. - + """ def __init__(self, connection, xid): super(TwoPhaseTransaction, self).__init__(connection, None) @@ -2049,9 +2108,9 @@ class TwoPhaseTransaction(Transaction): def prepare(self): """Prepare this :class:`.TwoPhaseTransaction`. - + After a PREPARE, the transaction can be committed. - + """ if not self._parent.is_active: raise exc.InvalidRequestError("This transaction is inactive") @@ -2075,11 +2134,11 @@ class Engine(Connectable, log.Identified): :func:`~sqlalchemy.create_engine` function. See also: - + :ref:`engines_toplevel` :ref:`connections_toplevel` - + """ _execution_options = util.immutabledict() @@ -2115,13 +2174,13 @@ class Engine(Connectable, log.Identified): def update_execution_options(self, **opt): """Update the default execution_options dictionary of this :class:`.Engine`. - + The given keys/values in \**opt are added to the default execution options that will be used for all connections. The initial contents of this dictionary can be sent via the ``execution_options`` paramter to :func:`.create_engine`. - + See :meth:`.Connection.execution_options` for more details on execution options. @@ -2236,19 +2295,96 @@ class Engine(Connectable, log.Identified): if connection is None: conn.close() + class _trans_ctx(object): + def __init__(self, conn, transaction, close_with_result): + self.conn = conn + self.transaction = transaction + self.close_with_result = close_with_result + + def __enter__(self): + return self.conn + + def __exit__(self, type, value, traceback): + if type is not None: + self.transaction.rollback() + else: + self.transaction.commit() + if not self.close_with_result: + self.conn.close() + + def begin(self, close_with_result=False): + """Return a context manager delivering a :class:`.Connection` + with a :class:`.Transaction` established. + + E.g.:: + + with engine.begin() as conn: + conn.execute("insert into table (x, y, z) values (1, 2, 3)") + conn.execute("my_special_procedure(5)") + + Upon successful operation, the :class:`.Transaction` + is committed. If an error is raised, the :class:`.Transaction` + is rolled back. + + The ``close_with_result`` flag is normally ``False``, and indicates + that the :class:`.Connection` will be closed when the operation + is complete. When set to ``True``, it indicates the :class:`.Connection` + is in "single use" mode, where the :class:`.ResultProxy` + returned by the first call to :meth:`.Connection.execute` will + close the :class:`.Connection` when that :class:`.ResultProxy` + has exhausted all result rows. + + New in 0.7.6. + + See also: + + :meth:`.Engine.connect` - procure a :class:`.Connection` from + an :class:`.Engine`. + + :meth:`.Connection.begin` - start a :class:`.Transaction` + for a particular :class:`.Connection`. + + """ + conn = self.contextual_connect(close_with_result=close_with_result) + trans = conn.begin() + return Engine._trans_ctx(conn, trans, close_with_result) + def transaction(self, callable_, *args, **kwargs): """Execute the given function within a transaction boundary. - The function is passed a newly procured - :class:`.Connection` as the first argument, followed by - the given \*args and \**kwargs. The :class:`.Connection` - is then closed (returned to the pool) when the operation - is complete. + The function is passed a :class:`.Connection` newly procured + from :meth:`.Engine.contextual_connect` as the first argument, + followed by the given \*args and \**kwargs. + + e.g.:: + + def do_something(conn, x, y): + conn.execute("some statement", {'x':x, 'y':y}) + + engine.transaction(do_something, 5, 10) + + The operations inside the function are all invoked within the + context of a single :class:`.Transaction`. + Upon success, the transaction is committed. If an + exception is raised, the transaction is rolled back + before propagating the exception. + + .. note:: + + The :meth:`.transaction` method is superseded by + the usage of the Python ``with:`` statement, which can + be used with :meth:`.Engine.begin`:: + + with engine.begin() as conn: + conn.execute("some statement", {'x':5, 'y':10}) - This method can be used interchangeably with - :meth:`.Connection.transaction`. See that method for - more details on usage as well as a modern alternative - using context managers (i.e. the ``with`` statement). + See also: + + :meth:`.Engine.begin` - engine-level transactional + context + + :meth:`.Connection.transaction` - connection-level version of + :meth:`.Engine.transaction` """ @@ -2261,15 +2397,15 @@ class Engine(Connectable, log.Identified): def run_callable(self, callable_, *args, **kwargs): """Given a callable object or function, execute it, passing a :class:`.Connection` as the first argument. - + The given \*args and \**kwargs are passed subsequent to the :class:`.Connection` argument. - + This function, along with :meth:`.Connection.run_callable`, allows a function to be run with a :class:`.Connection` or :class:`.Engine` object without the need to know which one is being dealt with. - + """ conn = self.contextual_connect() try: @@ -2390,19 +2526,19 @@ class Engine(Connectable, log.Identified): def raw_connection(self): """Return a "raw" DBAPI connection from the connection pool. - + The returned object is a proxied version of the DBAPI connection object used by the underlying driver in use. The object will have all the same behavior as the real DBAPI connection, except that its ``close()`` method will result in the connection being returned to the pool, rather than being closed for real. - + This method provides direct DBAPI connection access for special situations. In most situations, the :class:`.Connection` object should be used, which is procured using the :meth:`.Engine.connect` method. - + """ return self.pool.unique_connection() @@ -2487,7 +2623,6 @@ except ImportError: def __getattr__(self, name): try: - # TODO: no test coverage here return self[name] except KeyError, e: raise AttributeError(e.args[0]) @@ -2575,6 +2710,10 @@ class ResultMetaData(object): context = parent.context dialect = context.dialect typemap = dialect.dbapi_type_map + translate_colname = dialect._translate_colname + + # high precedence key values. + primary_keymap = {} for i, rec in enumerate(metadata): colname = rec[0] @@ -2583,6 +2722,9 @@ class ResultMetaData(object): if dialect.description_encoding: colname = dialect._description_decoder(colname) + if translate_colname: + colname, untranslated = translate_colname(colname) + if context.result_map: try: name, obj, type_ = context.result_map[colname.lower()] @@ -2600,15 +2742,17 @@ class ResultMetaData(object): # indexes as keys. This is only needed for the Python version of # RowProxy (the C version uses a faster path for integer indexes). - keymap[i] = rec - - # Column names as keys - if keymap.setdefault(name.lower(), rec) is not rec: - # We do not raise an exception directly because several - # columns colliding by name is not a problem as long as the - # user does not try to access them (ie use an index directly, - # or the more precise ColumnElement) - keymap[name.lower()] = (processor, obj, None) + primary_keymap[i] = rec + + # populate primary keymap, looking for conflicts. + if primary_keymap.setdefault(name.lower(), rec) is not rec: + # place a record that doesn't have the "index" - this + # is interpreted later as an AmbiguousColumnError, + # but only when actually accessed. Columns + # colliding by name is not a problem if those names + # aren't used; integer and ColumnElement access is always + # unambiguous. + primary_keymap[name.lower()] = (processor, obj, None) if dialect.requires_name_normalize: colname = dialect.normalize_name(colname) @@ -2618,10 +2762,20 @@ class ResultMetaData(object): for o in obj: keymap[o] = rec + if translate_colname and \ + untranslated: + keymap[untranslated] = rec + + # overwrite keymap values with those of the + # high precedence keymap. + keymap.update(primary_keymap) + if parent._echo: context.engine.logger.debug( "Col %r", tuple(x[0] for x in metadata)) + @util.pending_deprecation("0.8", "sqlite dialect uses " + "_translate_colname() now") def _set_keymap_synonym(self, name, origname): """Set a synonym for the given name. @@ -2647,7 +2801,7 @@ class ResultMetaData(object): if key._label and key._label.lower() in map: result = map[key._label.lower()] elif hasattr(key, 'name') and key.name.lower() in map: - # match is only on name. + # match is only on name. result = map[key.name.lower()] # search extra hard to make sure this # isn't a column/label name overlap. @@ -2800,7 +2954,7 @@ class ResultProxy(object): @property def returns_rows(self): """True if this :class:`.ResultProxy` returns rows. - + I.e. if it is legal to call the methods :meth:`~.ResultProxy.fetchone`, :meth:`~.ResultProxy.fetchmany` @@ -2814,12 +2968,12 @@ class ResultProxy(object): """True if this :class:`.ResultProxy` is the result of a executing an expression language compiled :func:`.expression.insert` construct. - + When True, this implies that the :attr:`inserted_primary_key` attribute is accessible, assuming the statement did not include a user defined "returning" construct. - + """ return self.context.isinsert @@ -2867,7 +3021,7 @@ class ResultProxy(object): @util.memoized_property def inserted_primary_key(self): """Return the primary key for the row just inserted. - + The return value is a list of scalar values corresponding to the list of primary key columns in the target table. @@ -2875,7 +3029,7 @@ class ResultProxy(object): This only applies to single row :func:`.insert` constructs which did not explicitly specify :meth:`.Insert.returning`. - + Note that primary key columns which specify a server_default clause, or otherwise do not qualify as "autoincrement" diff --git a/libs/sqlalchemy/engine/default.py b/libs/sqlalchemy/engine/default.py index 73bd7fd..5c2d981 100644 --- a/libs/sqlalchemy/engine/default.py +++ b/libs/sqlalchemy/engine/default.py @@ -44,6 +44,7 @@ class DefaultDialect(base.Dialect): postfetch_lastrowid = True implicit_returning = False + supports_native_enum = False supports_native_boolean = False @@ -95,6 +96,10 @@ class DefaultDialect(base.Dialect): # and denormalize_name() must be provided. requires_name_normalize = False + # a hook for SQLite's translation of + # result column names + _translate_colname = None + reflection_options = () def __init__(self, convert_unicode=False, assert_unicode=False, @@ -329,6 +334,9 @@ class DefaultDialect(base.Dialect): def do_execute(self, cursor, statement, parameters, context=None): cursor.execute(statement, parameters) + def do_execute_no_params(self, cursor, statement, context=None): + cursor.execute(statement) + def is_disconnect(self, e, connection, cursor): return False @@ -533,6 +541,10 @@ class DefaultExecutionContext(base.ExecutionContext): return self @util.memoized_property + def no_parameters(self): + return self.execution_options.get("no_parameters", False) + + @util.memoized_property def is_crud(self): return self.isinsert or self.isupdate or self.isdelete diff --git a/libs/sqlalchemy/engine/reflection.py b/libs/sqlalchemy/engine/reflection.py index f5911f3..71d97e6 100644 --- a/libs/sqlalchemy/engine/reflection.py +++ b/libs/sqlalchemy/engine/reflection.py @@ -317,7 +317,7 @@ class Inspector(object): info_cache=self.info_cache, **kw) return indexes - def reflecttable(self, table, include_columns, exclude_columns=None): + def reflecttable(self, table, include_columns, exclude_columns=()): """Given a Table object, load its internal constructs based on introspection. This is the underlying method used by most dialects to produce @@ -414,9 +414,12 @@ class Inspector(object): # Primary keys pk_cons = self.get_pk_constraint(table_name, schema, **tblkw) if pk_cons: + pk_cols = [table.c[pk] + for pk in pk_cons['constrained_columns'] + if pk in table.c and pk not in exclude_columns + ] + [pk for pk in table.primary_key if pk.key in exclude_columns] primary_key_constraint = sa_schema.PrimaryKeyConstraint(name=pk_cons.get('name'), - *[table.c[pk] for pk in pk_cons['constrained_columns'] - if pk in table.c] + *pk_cols ) table.append_constraint(primary_key_constraint) diff --git a/libs/sqlalchemy/engine/strategies.py b/libs/sqlalchemy/engine/strategies.py index 7b2da68..4d5a4b3 100644 --- a/libs/sqlalchemy/engine/strategies.py +++ b/libs/sqlalchemy/engine/strategies.py @@ -108,7 +108,8 @@ class DefaultEngineStrategy(EngineStrategy): 'timeout': 'pool_timeout', 'recycle': 'pool_recycle', 'events':'pool_events', - 'use_threadlocal':'pool_threadlocal'} + 'use_threadlocal':'pool_threadlocal', + 'reset_on_return':'pool_reset_on_return'} for k in util.get_cls_kwargs(poolclass): tk = translate.get(k, k) if tk in kwargs: @@ -226,6 +227,9 @@ class MockEngineStrategy(EngineStrategy): def contextual_connect(self, **kwargs): return self + def execution_options(self, **kw): + return self + def compiler(self, statement, parameters, **kwargs): return self._dialect.compiler( statement, parameters, engine=self, **kwargs) diff --git a/libs/sqlalchemy/event.py b/libs/sqlalchemy/event.py index 9cc3139..cd70b3a 100644 --- a/libs/sqlalchemy/event.py +++ b/libs/sqlalchemy/event.py @@ -13,12 +13,12 @@ NO_RETVAL = util.symbol('NO_RETVAL') def listen(target, identifier, fn, *args, **kw): """Register a listener function for the given target. - + e.g.:: - + from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint - + def unique_constraint_name(const, table): const.name = "uq_%s_%s" % ( table.name, @@ -41,12 +41,12 @@ def listen(target, identifier, fn, *args, **kw): def listens_for(target, identifier, *args, **kw): """Decorate a function as a listener for the given target + identifier. - + e.g.:: - + from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint - + @event.listens_for(UniqueConstraint, "after_parent_attach") def unique_constraint_name(const, table): const.name = "uq_%s_%s" % ( @@ -205,12 +205,14 @@ class _DispatchDescriptor(object): def insert(self, obj, target, propagate): assert isinstance(target, type), \ "Class-level Event targets must be classes." - stack = [target] while stack: cls = stack.pop(0) stack.extend(cls.__subclasses__()) - self._clslevel[cls].insert(0, obj) + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + self._clslevel[cls].insert(0, obj) def append(self, obj, target, propagate): assert isinstance(target, type), \ @@ -220,7 +222,20 @@ class _DispatchDescriptor(object): while stack: cls = stack.pop(0) stack.extend(cls.__subclasses__()) - self._clslevel[cls].append(obj) + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + self._clslevel[cls].append(obj) + + def update_subclass(self, target): + clslevel = self._clslevel[target] + for cls in target.__mro__[1:]: + if cls in self._clslevel: + clslevel.extend([ + fn for fn + in self._clslevel[cls] + if fn not in clslevel + ]) def remove(self, obj, target): stack = [target] @@ -252,6 +267,8 @@ class _ListenerCollection(object): _exec_once = False def __init__(self, parent, target_cls): + if target_cls not in parent._clslevel: + parent.update_subclass(target_cls) self.parent_listeners = parent._clslevel[target_cls] self.name = parent.__name__ self.listeners = [] diff --git a/libs/sqlalchemy/exc.py b/libs/sqlalchemy/exc.py index 64f25a2..91ffc28 100644 --- a/libs/sqlalchemy/exc.py +++ b/libs/sqlalchemy/exc.py @@ -162,7 +162,7 @@ UnmappedColumnError = None class StatementError(SQLAlchemyError): """An error occurred during execution of a SQL statement. - :class:`.StatementError` wraps the exception raised + :class:`StatementError` wraps the exception raised during execution, and features :attr:`.statement` and :attr:`.params` attributes which supply context regarding the specifics of the statement which had an issue. @@ -172,6 +172,15 @@ class StatementError(SQLAlchemyError): """ + statement = None + """The string SQL statement being invoked when this exception occurred.""" + + params = None + """The parameter list being used when this exception occurred.""" + + orig = None + """The DBAPI exception object.""" + def __init__(self, message, statement, params, orig): SQLAlchemyError.__init__(self, message) self.statement = statement @@ -192,21 +201,21 @@ class StatementError(SQLAlchemyError): class DBAPIError(StatementError): """Raised when the execution of a database operation fails. - ``DBAPIError`` wraps exceptions raised by the DB-API underlying the + Wraps exceptions raised by the DB-API underlying the database operation. Driver-specific implementations of the standard DB-API exception types are wrapped by matching sub-types of SQLAlchemy's - ``DBAPIError`` when possible. DB-API's ``Error`` type maps to - ``DBAPIError`` in SQLAlchemy, otherwise the names are identical. Note + :class:`DBAPIError` when possible. DB-API's ``Error`` type maps to + :class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note that there is no guarantee that different DB-API implementations will raise the same exception type for any given error condition. - :class:`.DBAPIError` features :attr:`.statement` - and :attr:`.params` attributes which supply context regarding + :class:`DBAPIError` features :attr:`~.StatementError.statement` + and :attr:`~.StatementError.params` attributes which supply context regarding the specifics of the statement which had an issue, for the typical case when the error was raised within the context of emitting a SQL statement. - The wrapped exception object is available in the :attr:`.orig` attribute. + The wrapped exception object is available in the :attr:`~.StatementError.orig` attribute. Its type and properties are DB-API implementation specific. """ diff --git a/libs/sqlalchemy/ext/declarative.py b/libs/sqlalchemy/ext/declarative.py index 891130a..faf575d 100755 --- a/libs/sqlalchemy/ext/declarative.py +++ b/libs/sqlalchemy/ext/declarative.py @@ -1213,6 +1213,12 @@ def _as_declarative(cls, classname, dict_): del our_stuff[key] cols = sorted(cols, key=lambda c:c._creation_order) table = None + + if hasattr(cls, '__table_cls__'): + table_cls = util.unbound_method_to_callable(cls.__table_cls__) + else: + table_cls = Table + if '__table__' not in dict_: if tablename is not None: @@ -1230,7 +1236,7 @@ def _as_declarative(cls, classname, dict_): if autoload: table_kw['autoload'] = True - cls.__table__ = table = Table(tablename, cls.metadata, + cls.__table__ = table = table_cls(tablename, cls.metadata, *(tuple(cols) + tuple(args)), **table_kw) else: diff --git a/libs/sqlalchemy/ext/hybrid.py b/libs/sqlalchemy/ext/hybrid.py index 086ec90..8734181 100644 --- a/libs/sqlalchemy/ext/hybrid.py +++ b/libs/sqlalchemy/ext/hybrid.py @@ -11,30 +11,30 @@ class level and at the instance level. The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of method decorator, is around 50 lines of code and has almost no dependencies on the rest -of SQLAlchemy. It can in theory work with any class-level expression generator. +of SQLAlchemy. It can, in theory, work with any descriptor-based expression +system. -Consider a table ``interval`` as below:: - - from sqlalchemy import MetaData, Table, Column, Integer - - metadata = MetaData() - - interval_table = Table('interval', metadata, - Column('id', Integer, primary_key=True), - Column('start', Integer, nullable=False), - Column('end', Integer, nullable=False) - ) - -We can define higher level functions on mapped classes that produce SQL -expressions at the class level, and Python expression evaluation at the -instance level. Below, each function decorated with :func:`.hybrid_method` -or :func:`.hybrid_property` may receive ``self`` as an instance of the class, -or as the class itself:: +Consider a mapping ``Interval``, representing integer ``start`` and ``end`` +values. We can define higher level functions on mapped classes that produce +SQL expressions at the class level, and Python expression evaluation at the +instance level. Below, each function decorated with :class:`.hybrid_method` or +:class:`.hybrid_property` may receive ``self`` as an instance of the class, or +as the class itself:: + from sqlalchemy import Column, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.orm import Session, aliased from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method - from sqlalchemy.orm import mapper, Session, aliased + + Base = declarative_base() + + class Interval(Base): + __tablename__ = 'interval' + + id = Column(Integer, primary_key=True) + start = Column(Integer, nullable=False) + end = Column(Integer, nullable=False) - class Interval(object): def __init__(self, start, end): self.start = start self.end = end @@ -51,8 +51,6 @@ or as the class itself:: def intersects(self, other): return self.contains(other.start) | self.contains(other.end) - mapper(Interval, interval_table) - Above, the ``length`` property returns the difference between the ``end`` and ``start`` attributes. With an instance of ``Interval``, this subtraction occurs in Python, using normal Python descriptor mechanics:: @@ -60,10 +58,11 @@ in Python, using normal Python descriptor mechanics:: >>> i1 = Interval(5, 10) >>> i1.length 5 - -At the class level, the usual descriptor behavior of returning the descriptor -itself is modified by :class:`.hybrid_property`, to instead evaluate the function -body given the ``Interval`` class as the argument:: + +When dealing with the ``Interval`` class itself, the :class:`.hybrid_property` +descriptor evaluates the function body given the ``Interval`` class as +the argument, which when evaluated with SQLAlchemy expression mechanics +returns a new SQL expression:: >>> print Interval.length interval."end" - interval.start @@ -83,9 +82,10 @@ locate attributes, so can also be used with hybrid attributes:: FROM interval WHERE interval."end" - interval.start = :param_1 -The ``contains()`` and ``intersects()`` methods are decorated with :class:`.hybrid_method`. -This decorator applies the same idea to methods which accept -zero or more arguments. The above methods return boolean values, and take advantage +The ``Interval`` class example also illustrates two methods, ``contains()`` and ``intersects()``, +decorated with :class:`.hybrid_method`. +This decorator applies the same idea to methods that :class:`.hybrid_property` applies +to attributes. The methods return boolean values, and take advantage of the Python ``|`` and ``&`` bitwise operators to produce equivalent instance-level and SQL expression-level boolean behavior:: @@ -368,7 +368,12 @@ SQL expression versus SQL expression:: >>> sw1 = aliased(SearchWord) >>> sw2 = aliased(SearchWord) - >>> print Session().query(sw1.word_insensitive, sw2.word_insensitive).filter(sw1.word_insensitive > sw2.word_insensitive) + >>> print Session().query( + ... sw1.word_insensitive, + ... sw2.word_insensitive).\\ + ... filter( + ... sw1.word_insensitive > sw2.word_insensitive + ... ) SELECT lower(searchword_1.word) AS lower_1, lower(searchword_2.word) AS lower_2 FROM searchword AS searchword_1, searchword AS searchword_2 WHERE lower(searchword_1.word) > lower(searchword_2.word) diff --git a/libs/sqlalchemy/ext/orderinglist.py b/libs/sqlalchemy/ext/orderinglist.py index 9847861..3895725 100644 --- a/libs/sqlalchemy/ext/orderinglist.py +++ b/libs/sqlalchemy/ext/orderinglist.py @@ -184,12 +184,11 @@ class OrderingList(list): This implementation relies on the list starting in the proper order, so be **sure** to put an ``order_by`` on your relationship. - ordering_attr + :param ordering_attr: Name of the attribute that stores the object's order in the relationship. - ordering_func - Optional. A function that maps the position in the Python list to a + :param ordering_func: Optional. A function that maps the position in the Python list to a value to store in the ``ordering_attr``. Values returned are usually (but need not be!) integers. @@ -202,7 +201,7 @@ class OrderingList(list): like stepped numbering, alphabetical and Fibonacci numbering, see the unit tests. - reorder_on_append + :param reorder_on_append: Default False. When appending an object with an existing (non-None) ordering value, that value will be left untouched unless ``reorder_on_append`` is true. This is an optimization to avoid a diff --git a/libs/sqlalchemy/orm/collections.py b/libs/sqlalchemy/orm/collections.py index 7872715..160fac8 100644 --- a/libs/sqlalchemy/orm/collections.py +++ b/libs/sqlalchemy/orm/collections.py @@ -112,12 +112,32 @@ from sqlalchemy.sql import expression from sqlalchemy import schema, util, exc as sa_exc + __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', 'attribute_mapped_collection'] __instrumentation_mutex = util.threading.Lock() +class _SerializableColumnGetter(object): + def __init__(self, colkeys): + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__(self): + return _SerializableColumnGetter, (self.colkeys,) + + def __call__(self, value): + state = instance_state(value) + m = _state_mapper(state) + key = [m._get_state_attr_by_column( + state, state.dict, + m.mapped_table.columns[k]) + for k in self.colkeys] + if self.composite: + return tuple(key) + else: + return key[0] def column_mapped_collection(mapping_spec): """A dictionary-based collection type with column-based keying. @@ -131,25 +151,27 @@ def column_mapped_collection(mapping_spec): after a session flush. """ + global _state_mapper, instance_state from sqlalchemy.orm.util import _state_mapper from sqlalchemy.orm.attributes import instance_state - cols = [expression._only_column_elements(q, "mapping_spec") - for q in util.to_list(mapping_spec)] - if len(cols) == 1: - def keyfunc(value): - state = instance_state(value) - m = _state_mapper(state) - return m._get_state_attr_by_column(state, state.dict, cols[0]) - else: - mapping_spec = tuple(cols) - def keyfunc(value): - state = instance_state(value) - m = _state_mapper(state) - return tuple(m._get_state_attr_by_column(state, state.dict, c) - for c in mapping_spec) + cols = [c.key for c in [ + expression._only_column_elements(q, "mapping_spec") + for q in util.to_list(mapping_spec)]] + keyfunc = _SerializableColumnGetter(cols) return lambda: MappedCollection(keyfunc) +class _SerializableAttrGetter(object): + def __init__(self, name): + self.name = name + self.getter = operator.attrgetter(name) + + def __call__(self, target): + return self.getter(target) + + def __reduce__(self): + return _SerializableAttrGetter, (self.name, ) + def attribute_mapped_collection(attr_name): """A dictionary-based collection type with attribute-based keying. @@ -163,7 +185,8 @@ def attribute_mapped_collection(attr_name): after a session flush. """ - return lambda: MappedCollection(operator.attrgetter(attr_name)) + getter = _SerializableAttrGetter(attr_name) + return lambda: MappedCollection(getter) def mapped_collection(keyfunc): @@ -814,6 +837,7 @@ def _instrument_class(cls): methods[name] = None, None, after # apply ABC auto-decoration to methods that need it + for method, decorator in decorators.items(): fn = getattr(cls, method, None) if (fn and method not in methods and @@ -1465,3 +1489,13 @@ class MappedCollection(dict): incoming_key, value, new_key)) yield value _convert = collection.converter(_convert) + +# ensure instrumentation is associated with +# these built-in classes; if a user-defined class +# subclasses these and uses @internally_instrumented, +# the superclass is otherwise not instrumented. +# see [ticket:2406]. +_instrument_class(MappedCollection) +_instrument_class(InstrumentedList) +_instrument_class(InstrumentedSet) + diff --git a/libs/sqlalchemy/orm/mapper.py b/libs/sqlalchemy/orm/mapper.py index 4c952c1..e96b754 100644 --- a/libs/sqlalchemy/orm/mapper.py +++ b/libs/sqlalchemy/orm/mapper.py @@ -1452,12 +1452,19 @@ class Mapper(object): return result def _is_userland_descriptor(self, obj): - return not isinstance(obj, - (MapperProperty, attributes.QueryableAttribute)) and \ - hasattr(obj, '__get__') and not \ - isinstance(obj.__get__(None, obj), - attributes.QueryableAttribute) - + if isinstance(obj, (MapperProperty, + attributes.QueryableAttribute)): + return False + elif not hasattr(obj, '__get__'): + return False + else: + obj = util.unbound_method_to_callable(obj) + if isinstance( + obj.__get__(None, obj), + attributes.QueryableAttribute + ): + return False + return True def _should_exclude(self, name, assigned_name, local, column): """determine whether a particular property should be implicitly @@ -1875,501 +1882,6 @@ class Mapper(object): self._memoized_values[key] = value = callable_() return value - def _post_update(self, states, uowtransaction, post_update_cols): - """Issue UPDATE statements on behalf of a relationship() which - specifies post_update. - - """ - cached_connections = util.PopulateDict( - lambda conn:conn.execution_options( - compiled_cache=self._compiled_cache - )) - - # if session has a connection callable, - # organize individual states with the connection - # to use for update - if uowtransaction.session.connection_callable: - connection_callable = \ - uowtransaction.session.connection_callable - else: - connection = uowtransaction.transaction.connection(self) - connection_callable = None - - tups = [] - for state in _sort_states(states): - if connection_callable: - conn = connection_callable(self, state.obj()) - else: - conn = connection - - mapper = _state_mapper(state) - - tups.append((state, state.dict, mapper, conn)) - - table_to_mapper = self._sorted_tables - - for table in table_to_mapper: - update = [] - - for state, state_dict, mapper, connection in tups: - if table not in mapper._pks_by_table: - continue - - pks = mapper._pks_by_table[table] - params = {} - hasdata = False - - for col in mapper._cols_by_table[table]: - if col in pks: - params[col._label] = \ - mapper._get_state_attr_by_column( - state, - state_dict, col) - elif col in post_update_cols: - prop = mapper._columntoproperty[col] - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE) - if history.added: - value = history.added[0] - params[col.key] = value - hasdata = True - if hasdata: - update.append((state, state_dict, params, mapper, - connection)) - - if update: - mapper = table_to_mapper[table] - - def update_stmt(): - clause = sql.and_() - - for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) - - return table.update(clause) - - statement = self._memo(('post_update', table), update_stmt) - - # execute each UPDATE in the order according to the original - # list of states to guarantee row access order, but - # also group them into common (connection, cols) sets - # to support executemany(). - for key, grouper in groupby( - update, lambda rec: (rec[4], rec[2].keys()) - ): - multiparams = [params for state, state_dict, - params, mapper, conn in grouper] - cached_connections[connection].\ - execute(statement, multiparams) - - def _save_obj(self, states, uowtransaction, single=False): - """Issue ``INSERT`` and/or ``UPDATE`` statements for a list - of objects. - - This is called within the context of a UOWTransaction during a - flush operation, given a list of states to be flushed. The - base mapper in an inheritance hierarchy handles the inserts/ - updates for all descendant mappers. - - """ - - # if batch=false, call _save_obj separately for each object - if not single and not self.batch: - for state in _sort_states(states): - self._save_obj([state], - uowtransaction, - single=True) - return - - # if session has a connection callable, - # organize individual states with the connection - # to use for insert/update - if uowtransaction.session.connection_callable: - connection_callable = \ - uowtransaction.session.connection_callable - else: - connection = uowtransaction.transaction.connection(self) - connection_callable = None - - tups = [] - - for state in _sort_states(states): - if connection_callable: - conn = connection_callable(self, state.obj()) - else: - conn = connection - - has_identity = bool(state.key) - mapper = _state_mapper(state) - instance_key = state.key or mapper._identity_key_from_state(state) - - row_switch = None - - # call before_XXX extensions - if not has_identity: - mapper.dispatch.before_insert(mapper, conn, state) - else: - mapper.dispatch.before_update(mapper, conn, state) - - # detect if we have a "pending" instance (i.e. has - # no instance_key attached to it), and another instance - # with the same identity key already exists as persistent. - # convert to an UPDATE if so. - if not has_identity and \ - instance_key in uowtransaction.session.identity_map: - instance = \ - uowtransaction.session.identity_map[instance_key] - existing = attributes.instance_state(instance) - if not uowtransaction.is_deleted(existing): - raise orm_exc.FlushError( - "New instance %s with identity key %s conflicts " - "with persistent instance %s" % - (state_str(state), instance_key, - state_str(existing))) - - self._log_debug( - "detected row switch for identity %s. " - "will update %s, remove %s from " - "transaction", instance_key, - state_str(state), state_str(existing)) - - # remove the "delete" flag from the existing element - uowtransaction.remove_state_actions(existing) - row_switch = existing - - tups.append( - (state, state.dict, mapper, conn, - has_identity, instance_key, row_switch) - ) - - # dictionary of connection->connection_with_cache_options. - cached_connections = util.PopulateDict( - lambda conn:conn.execution_options( - compiled_cache=self._compiled_cache - )) - - table_to_mapper = self._sorted_tables - - for table in table_to_mapper: - insert = [] - update = [] - - for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in tups: - if table not in mapper._pks_by_table: - continue - - pks = mapper._pks_by_table[table] - - isinsert = not has_identity and not row_switch - - params = {} - value_params = {} - - if isinsert: - has_all_pks = True - for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col: - params[col.key] = \ - mapper.version_id_generator(None) - else: - # pull straight from the dict for - # pending objects - prop = mapper._columntoproperty[col] - value = state_dict.get(prop.key, None) - - if value is None: - if col in pks: - has_all_pks = False - elif col.default is None and \ - col.server_default is None: - params[col.key] = value - - elif isinstance(value, sql.ClauseElement): - value_params[col] = value - else: - params[col.key] = value - - insert.append((state, state_dict, params, mapper, - connection, value_params, has_all_pks)) - else: - hasdata = False - for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col: - params[col._label] = \ - mapper._get_committed_state_attr_by_column( - row_switch or state, - row_switch and row_switch.dict - or state_dict, - col) - - prop = mapper._columntoproperty[col] - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE - ) - if history.added: - params[col.key] = history.added[0] - hasdata = True - else: - params[col.key] = \ - mapper.version_id_generator( - params[col._label]) - - # HACK: check for history, in case the - # history is only - # in a different table than the one - # where the version_id_col is. - for prop in mapper._columntoproperty.\ - itervalues(): - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE) - if history.added: - hasdata = True - else: - prop = mapper._columntoproperty[col] - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE) - if history.added: - if isinstance(history.added[0], - sql.ClauseElement): - value_params[col] = history.added[0] - else: - value = history.added[0] - params[col.key] = value - - if col in pks: - if history.deleted and \ - not row_switch: - # if passive_updates and sync detected - # this was a pk->pk sync, use the new - # value to locate the row, since the - # DB would already have set this - if ("pk_cascaded", state, col) in \ - uowtransaction.\ - attributes: - value = history.added[0] - params[col._label] = value - else: - # use the old value to - # locate the row - value = history.deleted[0] - params[col._label] = value - hasdata = True - else: - # row switch logic can reach us here - # remove the pk from the update params - # so the update doesn't - # attempt to include the pk in the - # update statement - del params[col.key] - value = history.added[0] - params[col._label] = value - if value is None and hasdata: - raise sa_exc.FlushError( - "Can't update table " - "using NULL for primary key " - "value") - else: - hasdata = True - elif col in pks: - value = state.manager[prop.key].\ - impl.get(state, state_dict) - if value is None: - raise sa_exc.FlushError( - "Can't update table " - "using NULL for primary " - "key value") - params[col._label] = value - if hasdata: - update.append((state, state_dict, params, mapper, - connection, value_params)) - - if update: - mapper = table_to_mapper[table] - - needs_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) - - def update_stmt(): - clause = sql.and_() - - for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) - - if needs_version_id: - clause.clauses.append(mapper.version_id_col ==\ - sql.bindparam(mapper.version_id_col._label, - type_=col.type)) - - return table.update(clause) - - statement = self._memo(('update', table), update_stmt) - - rows = 0 - for state, state_dict, params, mapper, \ - connection, value_params in update: - - if value_params: - c = connection.execute( - statement.values(value_params), - params) - else: - c = cached_connections[connection].\ - execute(statement, params) - - mapper._postfetch( - uowtransaction, - table, - state, - state_dict, - c.context.prefetch_cols, - c.context.postfetch_cols, - c.context.compiled_parameters[0], - value_params) - rows += c.rowcount - - if connection.dialect.supports_sane_rowcount: - if rows != len(update): - raise orm_exc.StaleDataError( - "UPDATE statement on table '%s' expected to update %d row(s); " - "%d were matched." % - (table.description, len(update), rows)) - - elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % - c.dialect.dialect_description, - stacklevel=12) - - if insert: - statement = self._memo(('insert', table), table.insert) - - for (connection, pkeys, hasvalue, has_all_pks), \ - records in groupby(insert, - lambda rec: (rec[4], - rec[2].keys(), - bool(rec[5]), - rec[6]) - ): - if has_all_pks and not hasvalue: - records = list(records) - multiparams = [rec[2] for rec in records] - c = cached_connections[connection].\ - execute(statement, multiparams) - - for (state, state_dict, params, mapper, - conn, value_params, has_all_pks), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): - mapper._postfetch( - uowtransaction, - table, - state, - state_dict, - c.context.prefetch_cols, - c.context.postfetch_cols, - last_inserted_params, - value_params) - - else: - for state, state_dict, params, mapper, \ - connection, value_params, \ - has_all_pks in records: - - if value_params: - result = connection.execute( - statement.values(value_params), - params) - else: - result = cached_connections[connection].\ - execute(statement, params) - - primary_key = result.context.inserted_primary_key - - if primary_key is not None: - # set primary key attributes - for pk, col in zip(primary_key, - mapper._pks_by_table[table]): - prop = mapper._columntoproperty[col] - if state_dict.get(prop.key) is None: - # TODO: would rather say: - #state_dict[prop.key] = pk - mapper._set_state_attr_by_column( - state, - state_dict, - col, pk) - - mapper._postfetch( - uowtransaction, - table, - state, - state_dict, - result.context.prefetch_cols, - result.context.postfetch_cols, - result.context.compiled_parameters[0], - value_params) - - - for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in tups: - - if mapper._readonly_props: - readonly = state.unmodified_intersection( - [p.key for p in mapper._readonly_props - if p.expire_on_flush or p.key not in state.dict] - ) - if readonly: - state.expire_attributes(state.dict, readonly) - - # if eager_defaults option is enabled, - # refresh whatever has been expired. - if self.eager_defaults and state.unloaded: - state.key = self._identity_key_from_state(state) - uowtransaction.session.query(self)._load_on_ident( - state.key, refresh_state=state, - only_load_props=state.unloaded) - - # call after_XXX extensions - if not has_identity: - mapper.dispatch.after_insert(mapper, connection, state) - else: - mapper.dispatch.after_update(mapper, connection, state) - - def _postfetch(self, uowtransaction, table, - state, dict_, prefetch_cols, postfetch_cols, - params, value_params): - """During a flush, expire attributes in need of newly - persisted database state.""" - - if self.version_id_col is not None: - prefetch_cols = list(prefetch_cols) + [self.version_id_col] - - for c in prefetch_cols: - if c.key in params and c in self._columntoproperty: - self._set_state_attr_by_column(state, dict_, c, params[c.key]) - - if postfetch_cols: - state.expire_attributes(state.dict, - [self._columntoproperty[c].key - for c in postfetch_cols if c in - self._columntoproperty] - ) - - # synchronize newly inserted ids from one table to the next - # TODO: this still goes a little too often. would be nice to - # have definitive list of "columns that changed" here - for m, equated_pairs in self._table_to_equated[table]: - sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - self.passive_updates) - @util.memoized_property def _table_to_equated(self): """memoized map of tables to collections of columns to be @@ -2387,128 +1899,6 @@ class Mapper(object): return result - def _delete_obj(self, states, uowtransaction): - """Issue ``DELETE`` statements for a list of objects. - - This is called within the context of a UOWTransaction during a - flush operation. - - """ - if uowtransaction.session.connection_callable: - connection_callable = \ - uowtransaction.session.connection_callable - else: - connection = uowtransaction.transaction.connection(self) - connection_callable = None - - tups = [] - cached_connections = util.PopulateDict( - lambda conn:conn.execution_options( - compiled_cache=self._compiled_cache - )) - - for state in _sort_states(states): - mapper = _state_mapper(state) - - if connection_callable: - conn = connection_callable(self, state.obj()) - else: - conn = connection - - mapper.dispatch.before_delete(mapper, conn, state) - - tups.append((state, - state.dict, - _state_mapper(state), - bool(state.key), - conn)) - - table_to_mapper = self._sorted_tables - - for table in reversed(table_to_mapper.keys()): - delete = util.defaultdict(list) - for state, state_dict, mapper, has_identity, connection in tups: - if not has_identity or table not in mapper._pks_by_table: - continue - - params = {} - delete[connection].append(params) - for col in mapper._pks_by_table[table]: - params[col.key] = \ - value = \ - mapper._get_state_attr_by_column( - state, state_dict, col) - if value is None: - raise sa_exc.FlushError( - "Can't delete from table " - "using NULL for primary " - "key value") - - if mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col): - params[mapper.version_id_col.key] = \ - mapper._get_committed_state_attr_by_column( - state, state_dict, - mapper.version_id_col) - - mapper = table_to_mapper[table] - need_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) - - def delete_stmt(): - clause = sql.and_() - for col in mapper._pks_by_table[table]: - clause.clauses.append( - col == sql.bindparam(col.key, type_=col.type)) - - if need_version_id: - clause.clauses.append( - mapper.version_id_col == - sql.bindparam( - mapper.version_id_col.key, - type_=mapper.version_id_col.type - ) - ) - - return table.delete(clause) - - for connection, del_objects in delete.iteritems(): - statement = self._memo(('delete', table), delete_stmt) - rows = -1 - - connection = cached_connections[connection] - - if need_version_id and \ - not connection.dialect.supports_sane_multi_rowcount: - # TODO: need test coverage for this [ticket:1761] - if connection.dialect.supports_sane_rowcount: - rows = 0 - # execute deletes individually so that versioned - # rows can be verified - for params in del_objects: - c = connection.execute(statement, params) - rows += c.rowcount - else: - util.warn( - "Dialect %s does not support deleted rowcount " - "- versioning cannot be verified." % - connection.dialect.dialect_description, - stacklevel=12) - connection.execute(statement, del_objects) - else: - c = connection.execute(statement, del_objects) - if connection.dialect.supports_sane_multi_rowcount: - rows = c.rowcount - - if rows != -1 and rows != len(del_objects): - raise orm_exc.StaleDataError( - "DELETE statement on table '%s' expected to delete %d row(s); " - "%d were matched." % - (table.description, len(del_objects), c.rowcount) - ) - - for state, state_dict, mapper, has_identity, connection in tups: - mapper.dispatch.after_delete(mapper, connection, state) def _instance_processor(self, context, path, reduced_path, adapter, polymorphic_from=None, @@ -2518,6 +1908,12 @@ class Mapper(object): """Produce a mapper level row processor callable which processes rows into mapped instances.""" + # note that this method, most of which exists in a closure + # called _instance(), resists being broken out, as + # attempts to do so tend to add significant function + # call overhead. _instance() is the most + # performance-critical section in the whole ORM. + pk_cols = self.primary_key if polymorphic_from or refresh_state: @@ -2961,13 +2357,6 @@ def _event_on_resurrect(state): state, state.dict, col, val) -def _sort_states(states): - pending = set(states) - persistent = set(s for s in pending if s.key is not None) - pending.difference_update(persistent) - return sorted(pending, key=operator.attrgetter("insert_order")) + \ - sorted(persistent, key=lambda q:q.key[1]) - class _ColumnMapping(util.py25_dict): """Error reporting helper for mapper._columntoproperty.""" diff --git a/libs/sqlalchemy/orm/persistence.py b/libs/sqlalchemy/orm/persistence.py new file mode 100644 index 0000000..55b9bf8 --- /dev/null +++ b/libs/sqlalchemy/orm/persistence.py @@ -0,0 +1,777 @@ +# orm/persistence.py +# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""private module containing functions used to emit INSERT, UPDATE +and DELETE statements on behalf of a :class:`.Mapper` and its descending +mappers. + +The functions here are called only by the unit of work functions +in unitofwork.py. + +""" + +import operator +from itertools import groupby + +from sqlalchemy import sql, util, exc as sa_exc +from sqlalchemy.orm import attributes, sync, \ + exc as orm_exc + +from sqlalchemy.orm.util import _state_mapper, state_str + +def save_obj(base_mapper, states, uowtransaction, single=False): + """Issue ``INSERT`` and/or ``UPDATE`` statements for a list + of objects. + + This is called within the context of a UOWTransaction during a + flush operation, given a list of states to be flushed. The + base mapper in an inheritance hierarchy handles the inserts/ + updates for all descendant mappers. + + """ + + # if batch=false, call _save_obj separately for each object + if not single and not base_mapper.batch: + for state in _sort_states(states): + save_obj(base_mapper, [state], uowtransaction, single=True) + return + + states_to_insert, states_to_update = _organize_states_for_save( + base_mapper, + states, + uowtransaction) + + cached_connections = _cached_connection_dict(base_mapper) + + for table, mapper in base_mapper._sorted_tables.iteritems(): + insert = _collect_insert_commands(base_mapper, uowtransaction, + table, states_to_insert) + + update = _collect_update_commands(base_mapper, uowtransaction, + table, states_to_update) + + if update: + _emit_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) + + if insert: + _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, + table, insert) + + _finalize_insert_update_commands(base_mapper, uowtransaction, + states_to_insert, states_to_update) + +def post_update(base_mapper, states, uowtransaction, post_update_cols): + """Issue UPDATE statements on behalf of a relationship() which + specifies post_update. + + """ + cached_connections = _cached_connection_dict(base_mapper) + + states_to_update = _organize_states_for_post_update( + base_mapper, + states, uowtransaction) + + + for table, mapper in base_mapper._sorted_tables.iteritems(): + update = _collect_post_update_commands(base_mapper, uowtransaction, + table, states_to_update, + post_update_cols) + + if update: + _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) + +def delete_obj(base_mapper, states, uowtransaction): + """Issue ``DELETE`` statements for a list of objects. + + This is called within the context of a UOWTransaction during a + flush operation. + + """ + + cached_connections = _cached_connection_dict(base_mapper) + + states_to_delete = _organize_states_for_delete( + base_mapper, + states, + uowtransaction) + + table_to_mapper = base_mapper._sorted_tables + + for table in reversed(table_to_mapper.keys()): + delete = _collect_delete_commands(base_mapper, uowtransaction, + table, states_to_delete) + + mapper = table_to_mapper[table] + + _emit_delete_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, delete) + + for state, state_dict, mapper, has_identity, connection \ + in states_to_delete: + mapper.dispatch.after_delete(mapper, connection, state) + +def _organize_states_for_save(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for INSERT or + UPDATE. + + This includes splitting out into distinct lists for + each, calling before_insert/before_update, obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state, + and the identity flag. + + """ + + states_to_insert = [] + states_to_update = [] + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, + states): + + has_identity = bool(state.key) + instance_key = state.key or mapper._identity_key_from_state(state) + + row_switch = None + + # call before_XXX extensions + if not has_identity: + mapper.dispatch.before_insert(mapper, connection, state) + else: + mapper.dispatch.before_update(mapper, connection, state) + + # detect if we have a "pending" instance (i.e. has + # no instance_key attached to it), and another instance + # with the same identity key already exists as persistent. + # convert to an UPDATE if so. + if not has_identity and \ + instance_key in uowtransaction.session.identity_map: + instance = \ + uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + if not uowtransaction.is_deleted(existing): + raise orm_exc.FlushError( + "New instance %s with identity key %s conflicts " + "with persistent instance %s" % + (state_str(state), instance_key, + state_str(existing))) + + base_mapper._log_debug( + "detected row switch for identity %s. " + "will update %s, remove %s from " + "transaction", instance_key, + state_str(state), state_str(existing)) + + # remove the "delete" flag from the existing element + uowtransaction.remove_state_actions(existing) + row_switch = existing + + if not has_identity and not row_switch: + states_to_insert.append( + (state, dict_, mapper, connection, + has_identity, instance_key, row_switch) + ) + else: + states_to_update.append( + (state, dict_, mapper, connection, + has_identity, instance_key, row_switch) + ) + + return states_to_insert, states_to_update + +def _organize_states_for_post_update(base_mapper, states, + uowtransaction): + """Make an initial pass across a set of states for UPDATE + corresponding to post_update. + + This includes obtaining key information for each state + including its dictionary, mapper, the connection to use for + the execution per state. + + """ + return list(_connections_for_states(base_mapper, uowtransaction, + states)) + +def _organize_states_for_delete(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for DELETE. + + This includes calling out before_delete and obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state. + + """ + states_to_delete = [] + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, + states): + + mapper.dispatch.before_delete(mapper, connection, state) + + states_to_delete.append((state, dict_, mapper, + bool(state.key), connection)) + return states_to_delete + +def _collect_insert_commands(base_mapper, uowtransaction, table, + states_to_insert): + """Identify sets of values to use in INSERT statements for a + list of states. + + """ + insert = [] + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_insert: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + params = {} + value_params = {} + + has_all_pks = True + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col.key] = mapper.version_id_generator(None) + else: + # pull straight from the dict for + # pending objects + prop = mapper._columntoproperty[col] + value = state_dict.get(prop.key, None) + + if value is None: + if col in pks: + has_all_pks = False + elif col.default is None and \ + col.server_default is None: + params[col.key] = value + + elif isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value + + insert.append((state, state_dict, params, mapper, + connection, value_params, has_all_pks)) + return insert + +def _collect_update_commands(base_mapper, uowtransaction, + table, states_to_update): + """Identify sets of values to use in UPDATE statements for a + list of states. + + This function works intricately with the history system + to determine exactly what values should be updated + as well as how the row should be matched within an UPDATE + statement. Includes some tricky scenarios where the primary + key of an object might have been changed. + + """ + + update = [] + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_update: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + params = {} + value_params = {} + + hasdata = hasnull = False + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col._label] = \ + mapper._get_committed_state_attr_by_column( + row_switch or state, + row_switch and row_switch.dict + or state_dict, + col) + + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE + ) + if history.added: + params[col.key] = history.added[0] + hasdata = True + else: + params[col.key] = mapper.version_id_generator( + params[col._label]) + + # HACK: check for history, in case the + # history is only + # in a different table than the one + # where the version_id_col is. + for prop in mapper._columntoproperty.itervalues(): + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + hasdata = True + else: + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + if isinstance(history.added[0], + sql.ClauseElement): + value_params[col] = history.added[0] + else: + value = history.added[0] + params[col.key] = value + + if col in pks: + if history.deleted and \ + not row_switch: + # if passive_updates and sync detected + # this was a pk->pk sync, use the new + # value to locate the row, since the + # DB would already have set this + if ("pk_cascaded", state, col) in \ + uowtransaction.attributes: + value = history.added[0] + params[col._label] = value + else: + # use the old value to + # locate the row + value = history.deleted[0] + params[col._label] = value + hasdata = True + else: + # row switch logic can reach us here + # remove the pk from the update params + # so the update doesn't + # attempt to include the pk in the + # update statement + del params[col.key] + value = history.added[0] + params[col._label] = value + if value is None: + hasnull = True + else: + hasdata = True + elif col in pks: + value = state.manager[prop.key].impl.get( + state, state_dict) + if value is None: + hasnull = True + params[col._label] = value + if hasdata: + if hasnull: + raise sa_exc.FlushError( + "Can't update table " + "using NULL for primary " + "key value") + update.append((state, state_dict, params, mapper, + connection, value_params)) + return update + + +def _collect_post_update_commands(base_mapper, uowtransaction, table, + states_to_update, post_update_cols): + """Identify sets of values to use in UPDATE statements for a + list of states within a post_update operation. + + """ + + update = [] + for state, state_dict, mapper, connection in states_to_update: + if table not in mapper._pks_by_table: + continue + pks = mapper._pks_by_table[table] + params = {} + hasdata = False + + for col in mapper._cols_by_table[table]: + if col in pks: + params[col._label] = \ + mapper._get_state_attr_by_column( + state, + state_dict, col) + elif col in post_update_cols: + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + value = history.added[0] + params[col.key] = value + hasdata = True + if hasdata: + update.append((state, state_dict, params, mapper, + connection)) + return update + +def _collect_delete_commands(base_mapper, uowtransaction, table, + states_to_delete): + """Identify values to use in DELETE statements for a list of + states to be deleted.""" + + delete = util.defaultdict(list) + + for state, state_dict, mapper, has_identity, connection \ + in states_to_delete: + if not has_identity or table not in mapper._pks_by_table: + continue + + params = {} + delete[connection].append(params) + for col in mapper._pks_by_table[table]: + params[col.key] = \ + value = \ + mapper._get_state_attr_by_column( + state, state_dict, col) + if value is None: + raise sa_exc.FlushError( + "Can't delete from table " + "using NULL for primary " + "key value") + + if mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col): + params[mapper.version_id_col.key] = \ + mapper._get_committed_state_attr_by_column( + state, state_dict, + mapper.version_id_col) + return delete + + +def _emit_update_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, update): + """Emit UPDATE statements corresponding to value lists collected + by _collect_update_commands().""" + + needs_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + def update_stmt(): + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + if needs_version_id: + clause.clauses.append(mapper.version_id_col ==\ + sql.bindparam(mapper.version_id_col._label, + type_=col.type)) + + return table.update(clause) + + statement = base_mapper._memo(('update', table), update_stmt) + + rows = 0 + for state, state_dict, params, mapper, \ + connection, value_params in update: + + if value_params: + c = connection.execute( + statement.values(value_params), + params) + else: + c = cached_connections[connection].\ + execute(statement, params) + + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c.context.prefetch_cols, + c.context.postfetch_cols, + c.context.compiled_parameters[0], + value_params) + rows += c.rowcount + + if connection.dialect.supports_sane_rowcount: + if rows != len(update): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." % + (table.description, len(update), rows)) + + elif needs_version_id: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." % + c.dialect.dialect_description, + stacklevel=12) + +def _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, table, insert): + """Emit INSERT statements corresponding to value lists collected + by _collect_insert_commands().""" + + statement = base_mapper._memo(('insert', table), table.insert) + + for (connection, pkeys, hasvalue, has_all_pks), \ + records in groupby(insert, + lambda rec: (rec[4], + rec[2].keys(), + bool(rec[5]), + rec[6]) + ): + if has_all_pks and not hasvalue: + records = list(records) + multiparams = [rec[2] for rec in records] + c = cached_connections[connection].\ + execute(statement, multiparams) + + for (state, state_dict, params, mapper, + conn, value_params, has_all_pks), \ + last_inserted_params in \ + zip(records, c.context.compiled_parameters): + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c.context.prefetch_cols, + c.context.postfetch_cols, + last_inserted_params, + value_params) + + else: + for state, state_dict, params, mapper, \ + connection, value_params, \ + has_all_pks in records: + + if value_params: + result = connection.execute( + statement.values(value_params), + params) + else: + result = cached_connections[connection].\ + execute(statement, params) + + primary_key = result.context.inserted_primary_key + + if primary_key is not None: + # set primary key attributes + for pk, col in zip(primary_key, + mapper._pks_by_table[table]): + prop = mapper._columntoproperty[col] + if state_dict.get(prop.key) is None: + # TODO: would rather say: + #state_dict[prop.key] = pk + mapper._set_state_attr_by_column( + state, + state_dict, + col, pk) + + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + result.context.prefetch_cols, + result.context.postfetch_cols, + result.context.compiled_parameters[0], + value_params) + + + +def _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, update): + """Emit UPDATE statements corresponding to value lists collected + by _collect_post_update_commands().""" + + def update_stmt(): + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + return table.update(clause) + + statement = base_mapper._memo(('post_update', table), update_stmt) + + # execute each UPDATE in the order according to the original + # list of states to guarantee row access order, but + # also group them into common (connection, cols) sets + # to support executemany(). + for key, grouper in groupby( + update, lambda rec: (rec[4], rec[2].keys()) + ): + connection = key[0] + multiparams = [params for state, state_dict, + params, mapper, conn in grouper] + cached_connections[connection].\ + execute(statement, multiparams) + + +def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, + mapper, table, delete): + """Emit DELETE statements corresponding to value lists collected + by _collect_delete_commands().""" + + need_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + def delete_stmt(): + clause = sql.and_() + for col in mapper._pks_by_table[table]: + clause.clauses.append( + col == sql.bindparam(col.key, type_=col.type)) + + if need_version_id: + clause.clauses.append( + mapper.version_id_col == + sql.bindparam( + mapper.version_id_col.key, + type_=mapper.version_id_col.type + ) + ) + + return table.delete(clause) + + for connection, del_objects in delete.iteritems(): + statement = base_mapper._memo(('delete', table), delete_stmt) + + connection = cached_connections[connection] + + if need_version_id: + # TODO: need test coverage for this [ticket:1761] + if connection.dialect.supports_sane_rowcount: + rows = 0 + # execute deletes individually so that versioned + # rows can be verified + for params in del_objects: + c = connection.execute(statement, params) + rows += c.rowcount + if rows != len(del_objects): + raise orm_exc.StaleDataError( + "DELETE statement on table '%s' expected to " + "delete %d row(s); %d were matched." % + (table.description, len(del_objects), c.rowcount) + ) + else: + util.warn( + "Dialect %s does not support deleted rowcount " + "- versioning cannot be verified." % + connection.dialect.dialect_description, + stacklevel=12) + connection.execute(statement, del_objects) + else: + connection.execute(statement, del_objects) + + +def _finalize_insert_update_commands(base_mapper, uowtransaction, + states_to_insert, states_to_update): + """finalize state on states that have been inserted or updated, + including calling after_insert/after_update events. + + """ + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_insert + \ + states_to_update: + + if mapper._readonly_props: + readonly = state.unmodified_intersection( + [p.key for p in mapper._readonly_props + if p.expire_on_flush or p.key not in state.dict] + ) + if readonly: + state.expire_attributes(state.dict, readonly) + + # if eager_defaults option is enabled, + # refresh whatever has been expired. + if base_mapper.eager_defaults and state.unloaded: + state.key = base_mapper._identity_key_from_state(state) + uowtransaction.session.query(base_mapper)._load_on_ident( + state.key, refresh_state=state, + only_load_props=state.unloaded) + + # call after_XXX extensions + if not has_identity: + mapper.dispatch.after_insert(mapper, connection, state) + else: + mapper.dispatch.after_update(mapper, connection, state) + +def _postfetch(mapper, uowtransaction, table, + state, dict_, prefetch_cols, postfetch_cols, + params, value_params): + """Expire attributes in need of newly persisted database state, + after an INSERT or UPDATE statement has proceeded for that + state.""" + + if mapper.version_id_col is not None: + prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + + for c in prefetch_cols: + if c.key in params and c in mapper._columntoproperty: + mapper._set_state_attr_by_column(state, dict_, c, params[c.key]) + + if postfetch_cols: + state.expire_attributes(state.dict, + [mapper._columntoproperty[c].key + for c in postfetch_cols if c in + mapper._columntoproperty] + ) + + # synchronize newly inserted ids from one table to the next + # TODO: this still goes a little too often. would be nice to + # have definitive list of "columns that changed" here + for m, equated_pairs in mapper._table_to_equated[table]: + sync.populate(state, m, state, m, + equated_pairs, + uowtransaction, + mapper.passive_updates) + +def _connections_for_states(base_mapper, uowtransaction, states): + """Return an iterator of (state, state.dict, mapper, connection). + + The states are sorted according to _sort_states, then paired + with the connection they should be using for the given + unit of work transaction. + + """ + # if session has a connection callable, + # organize individual states with the connection + # to use for update + if uowtransaction.session.connection_callable: + connection_callable = \ + uowtransaction.session.connection_callable + else: + connection = uowtransaction.transaction.connection( + base_mapper) + connection_callable = None + + for state in _sort_states(states): + if connection_callable: + connection = connection_callable(base_mapper, state.obj()) + + mapper = _state_mapper(state) + + yield state, state.dict, mapper, connection + +def _cached_connection_dict(base_mapper): + # dictionary of connection->connection_with_cache_options. + return util.PopulateDict( + lambda conn:conn.execution_options( + compiled_cache=base_mapper._compiled_cache + )) + +def _sort_states(states): + pending = set(states) + persistent = set(s for s in pending if s.key is not None) + pending.difference_update(persistent) + return sorted(pending, key=operator.attrgetter("insert_order")) + \ + sorted(persistent, key=lambda q:q.key[1]) + + diff --git a/libs/sqlalchemy/orm/query.py b/libs/sqlalchemy/orm/query.py index 9508cb5..aa3dd01 100644 --- a/libs/sqlalchemy/orm/query.py +++ b/libs/sqlalchemy/orm/query.py @@ -133,7 +133,7 @@ class Query(object): with_polymorphic = mapper._with_polymorphic_mappers if mapper.mapped_table not in \ self._polymorphic_adapters: - self.__mapper_loads_polymorphically_with(mapper, + self._mapper_loads_polymorphically_with(mapper, sql_util.ColumnAdapter( selectable, mapper._equivalent_columns)) @@ -150,7 +150,7 @@ class Query(object): is_aliased_class, with_polymorphic) ent.setup_entity(entity, *d[entity]) - def __mapper_loads_polymorphically_with(self, mapper, adapter): + def _mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers: self._polymorphic_adapters[m2] = adapter for m in m2.iterate_to_root(): @@ -174,10 +174,6 @@ class Query(object): self._from_obj_alias = sql_util.ColumnAdapter( self._from_obj[0], equivs) - def _get_polymorphic_adapter(self, entity, selectable): - self.__mapper_loads_polymorphically_with(entity.mapper, - sql_util.ColumnAdapter(selectable, - entity.mapper._equivalent_columns)) def _reset_polymorphic_adapter(self, mapper): for m2 in mapper._with_polymorphic_mappers: @@ -276,6 +272,7 @@ class Query(object): return self._select_from_entity or \ self._entity_zero().entity_zero + @property def _mapper_entities(self): # TODO: this is wrong, its hardcoded to "primary entity" when @@ -324,13 +321,6 @@ class Query(object): ) return self._entity_zero() - def _generate_mapper_zero(self): - if not getattr(self._entities[0], 'primary_entity', False): - raise sa_exc.InvalidRequestError( - "No primary mapper set up for this Query.") - entity = self._entities[0]._clone() - self._entities = [entity] + self._entities[1:] - return entity def __all_equivs(self): equivs = {} @@ -460,6 +450,62 @@ class Query(object): """ return self.enable_eagerloads(False).statement.alias(name=name) + def cte(self, name=None, recursive=False): + """Return the full SELECT statement represented by this :class:`.Query` + represented as a common table expression (CTE). + + The :meth:`.Query.cte` method is new in 0.7.6. + + Parameters and usage are the same as those of the + :meth:`._SelectBase.cte` method; see that method for + further details. + + Here is the `Postgresql WITH + RECURSIVE example `_. + Note that, in this example, the ``included_parts`` cte and the ``incl_alias`` alias + of it are Core selectables, which + means the columns are accessed via the ``.c.`` attribute. The ``parts_alias`` + object is an :func:`.orm.aliased` instance of the ``Part`` entity, so column-mapped + attributes are available directly:: + + from sqlalchemy.orm import aliased + + class Part(Base): + __tablename__ = 'part' + part = Column(String, primary_key=True) + sub_part = Column(String, primary_key=True) + quantity = Column(Integer) + + included_parts = session.query( + Part.sub_part, + Part.part, + Part.quantity).\\ + filter(Part.part=="our part").\\ + cte(name="included_parts", recursive=True) + + incl_alias = aliased(included_parts, name="pr") + parts_alias = aliased(Part, name="p") + included_parts = included_parts.union_all( + session.query( + parts_alias.part, + parts_alias.sub_part, + parts_alias.quantity).\\ + filter(parts_alias.part==incl_alias.c.sub_part) + ) + + q = session.query( + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label('total_quantity') + ).\\ + group_by(included_parts.c.sub_part) + + See also: + + :meth:`._SelectBase.cte` + + """ + return self.enable_eagerloads(False).statement.cte(name=name, recursive=recursive) + def label(self, name): """Return the full SELECT statement represented by this :class:`.Query`, converted to a scalar subquery with a label of the given name. @@ -601,7 +647,12 @@ class Query(object): such as concrete table mappers. """ - entity = self._generate_mapper_zero() + + if not getattr(self._entities[0], 'primary_entity', False): + raise sa_exc.InvalidRequestError( + "No primary mapper set up for this Query.") + entity = self._entities[0]._clone() + self._entities = [entity] + self._entities[1:] entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable, @@ -1041,7 +1092,22 @@ class Query(object): @_generative() def with_lockmode(self, mode): - """Return a new Query object with the specified locking mode.""" + """Return a new Query object with the specified locking mode. + + :param mode: a string representing the desired locking mode. A + corresponding value is passed to the ``for_update`` parameter of + :meth:`~sqlalchemy.sql.expression.select` when the query is + executed. Valid values are: + + ``'update'`` - passes ``for_update=True``, which translates to + ``FOR UPDATE`` (standard SQL, supported by most dialects) + + ``'update_nowait'`` - passes ``for_update='nowait'``, which + translates to ``FOR UPDATE NOWAIT`` (supported by Oracle) + + ``'read'`` - passes ``for_update='read'``, which translates to + ``LOCK IN SHARE MODE`` (supported by MySQL). + """ self._lockmode = mode @@ -1583,7 +1649,6 @@ class Query(object): consistent format with which to form the actual JOIN constructs. """ - self._polymorphic_adapters = self._polymorphic_adapters.copy() if not from_joinpoint: self._reset_joinpoint() @@ -1683,6 +1748,8 @@ class Query(object): onclause, outerjoin, create_aliases, prop): """append a JOIN to the query's from clause.""" + self._polymorphic_adapters = self._polymorphic_adapters.copy() + if left is None: if self._from_obj: left = self._from_obj[0] @@ -1696,7 +1763,29 @@ class Query(object): "are the same entity" % (left, right)) - left_mapper, left_selectable, left_is_aliased = _entity_info(left) + right, right_is_aliased, onclause = self._prepare_right_side( + right, onclause, + outerjoin, create_aliases, + prop) + + # if joining on a MapperProperty path, + # track the path to prevent redundant joins + if not create_aliases and prop: + self._update_joinpoint({ + '_joinpoint_entity':right, + 'prev':((left, right, prop.key), self._joinpoint) + }) + else: + self._joinpoint = { + '_joinpoint_entity':right + } + + self._join_to_left(left, right, + right_is_aliased, + onclause, outerjoin) + + def _prepare_right_side(self, right, onclause, outerjoin, + create_aliases, prop): right_mapper, right_selectable, right_is_aliased = _entity_info(right) if right_mapper: @@ -1741,24 +1830,13 @@ class Query(object): right = aliased(right) need_adapter = True - # if joining on a MapperProperty path, - # track the path to prevent redundant joins - if not create_aliases and prop: - self._update_joinpoint({ - '_joinpoint_entity':right, - 'prev':((left, right, prop.key), self._joinpoint) - }) - else: - self._joinpoint = { - '_joinpoint_entity':right - } - # if an alias() of the right side was generated here, # apply an adapter to all subsequent filter() calls # until reset_joinpoint() is called. if need_adapter: self._filter_aliases = ORMAdapter(right, - equivalents=right_mapper and right_mapper._equivalent_columns or {}, + equivalents=right_mapper and + right_mapper._equivalent_columns or {}, chain_to=self._filter_aliases) # if the onclause is a ClauseElement, adapt it with any @@ -1771,7 +1849,7 @@ class Query(object): # ensure that columns retrieved from this target in the result # set are also adapted. if aliased_entity and not create_aliases: - self.__mapper_loads_polymorphically_with( + self._mapper_loads_polymorphically_with( right_mapper, ORMAdapter( right, @@ -1779,6 +1857,11 @@ class Query(object): ) ) + return right, right_is_aliased, onclause + + def _join_to_left(self, left, right, right_is_aliased, onclause, outerjoin): + left_mapper, left_selectable, left_is_aliased = _entity_info(left) + # this is an overly broad assumption here, but there's a # very wide variety of situations where we rely upon orm.join's # adaption to glue clauses together, with joined-table inheritance's @@ -2959,7 +3042,9 @@ class _MapperEntity(_QueryEntity): # with_polymorphic() can be applied to aliases if not self.is_aliased_class: self.selectable = from_obj - self.adapter = query._get_polymorphic_adapter(self, from_obj) + query._mapper_loads_polymorphically_with(self.mapper, + sql_util.ColumnAdapter(from_obj, + self.mapper._equivalent_columns)) filter_fn = id @@ -3086,8 +3171,9 @@ class _MapperEntity(_QueryEntity): class _ColumnEntity(_QueryEntity): """Column/expression based entity.""" - def __init__(self, query, column): + def __init__(self, query, column, namespace=None): self.expr = column + self.namespace = namespace if isinstance(column, basestring): column = sql.literal_column(column) @@ -3106,7 +3192,7 @@ class _ColumnEntity(_QueryEntity): for c in column._select_iterable: if c is column: break - _ColumnEntity(query, c) + _ColumnEntity(query, c, namespace=column) if c is not column: return @@ -3147,12 +3233,14 @@ class _ColumnEntity(_QueryEntity): if self.entities: self.entity_zero = list(self.entities)[0] + elif self.namespace is not None: + self.entity_zero = self.namespace else: self.entity_zero = None @property def entity_zero_or_selectable(self): - if self.entity_zero: + if self.entity_zero is not None: return self.entity_zero elif self.actual_froms: return list(self.actual_froms)[0] diff --git a/libs/sqlalchemy/orm/scoping.py b/libs/sqlalchemy/orm/scoping.py index ffc8ef4..3c1cd7f 100644 --- a/libs/sqlalchemy/orm/scoping.py +++ b/libs/sqlalchemy/orm/scoping.py @@ -41,8 +41,9 @@ class ScopedSession(object): scope = kwargs.pop('scope', False) if scope is not None: if self.registry.has(): - raise sa_exc.InvalidRequestError("Scoped session is already present; " - "no new arguments may be specified.") + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified.") else: sess = self.session_factory(**kwargs) self.registry.set(sess) @@ -70,8 +71,8 @@ class ScopedSession(object): self.session_factory.configure(**kwargs) def query_property(self, query_cls=None): - """return a class property which produces a `Query` object against the - class when called. + """return a class property which produces a `Query` object + against the class when called. e.g.:: @@ -121,7 +122,8 @@ def makeprop(name): def get(self): return getattr(self.registry(), name) return property(get, set) -for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active', 'autoflush'): +for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', + 'is_active', 'autoflush', 'no_autoflush'): setattr(ScopedSession, prop, makeprop(prop)) def clslevel(name): diff --git a/libs/sqlalchemy/orm/session.py b/libs/sqlalchemy/orm/session.py index 4299290..1477870 100644 --- a/libs/sqlalchemy/orm/session.py +++ b/libs/sqlalchemy/orm/session.py @@ -99,7 +99,7 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, kwargs.update(new_kwargs) - return type("Session", (Sess, class_), {}) + return type("SessionMaker", (Sess, class_), {}) class SessionTransaction(object): @@ -978,6 +978,34 @@ class Session(object): return self._query_cls(entities, self, **kwargs) + @property + @util.contextmanager + def no_autoflush(self): + """Return a context manager that disables autoflush. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + New in 0.7.6. + + """ + autoflush = self.autoflush + self.autoflush = False + yield self + self.autoflush = autoflush + def _autoflush(self): if self.autoflush and not self._flushing: self.flush() @@ -1772,6 +1800,19 @@ class Session(object): return self.transaction and self.transaction.is_active + identity_map = None + """A mapping of object identities to objects themselves. + + Iterating through ``Session.identity_map.values()`` provides + access to the full set of persistent objects (i.e., those + that have row identity) currently in the session. + + See also: + + :func:`.identity_key` - operations involving identity keys. + + """ + @property def _dirty_states(self): """The set of all persistent states considered dirty. diff --git a/libs/sqlalchemy/orm/sync.py b/libs/sqlalchemy/orm/sync.py index b016e81..a20e871 100644 --- a/libs/sqlalchemy/orm/sync.py +++ b/libs/sqlalchemy/orm/sync.py @@ -6,6 +6,7 @@ """private module containing functions used for copying data between instances based on join conditions. + """ from sqlalchemy.orm import exc, util as mapperutil, attributes diff --git a/libs/sqlalchemy/orm/unitofwork.py b/libs/sqlalchemy/orm/unitofwork.py index 3cd0f15..8fc5f13 100644 --- a/libs/sqlalchemy/orm/unitofwork.py +++ b/libs/sqlalchemy/orm/unitofwork.py @@ -14,7 +14,7 @@ organizes them in order of dependency, and executes. from sqlalchemy import util, event from sqlalchemy.util import topological -from sqlalchemy.orm import attributes, interfaces +from sqlalchemy.orm import attributes, interfaces, persistence from sqlalchemy.orm import util as mapperutil session = util.importlater("sqlalchemy.orm", "session") @@ -462,7 +462,7 @@ class IssuePostUpdate(PostSortRec): states, cols = uow.post_update_states[self.mapper] states = [s for s in states if uow.states[s][0] == self.isdelete] - self.mapper._post_update(states, uow, cols) + persistence.post_update(self.mapper, states, uow, cols) class SaveUpdateAll(PostSortRec): def __init__(self, uow, mapper): @@ -470,7 +470,7 @@ class SaveUpdateAll(PostSortRec): assert mapper is mapper.base_mapper def execute(self, uow): - self.mapper._save_obj( + persistence.save_obj(self.mapper, uow.states_for_mapper_hierarchy(self.mapper, False, False), uow ) @@ -493,7 +493,7 @@ class DeleteAll(PostSortRec): assert mapper is mapper.base_mapper def execute(self, uow): - self.mapper._delete_obj( + persistence.delete_obj(self.mapper, uow.states_for_mapper_hierarchy(self.mapper, True, False), uow ) @@ -551,7 +551,7 @@ class SaveUpdateState(PostSortRec): if r.__class__ is cls_ and r.mapper is mapper] recs.difference_update(our_recs) - mapper._save_obj( + persistence.save_obj(mapper, [self.state] + [r.state for r in our_recs], uow) @@ -575,7 +575,7 @@ class DeleteState(PostSortRec): r.mapper is mapper] recs.difference_update(our_recs) states = [self.state] + [r.state for r in our_recs] - mapper._delete_obj( + persistence.delete_obj(mapper, [s for s in states if uow.states[s][0]], uow) diff --git a/libs/sqlalchemy/orm/util.py b/libs/sqlalchemy/orm/util.py index 0cd5b05..0c5f203 100644 --- a/libs/sqlalchemy/orm/util.py +++ b/libs/sqlalchemy/orm/util.py @@ -11,6 +11,7 @@ from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE,\ PropComparator, MapperProperty from sqlalchemy.orm import attributes, exc import operator +import re mapperlib = util.importlater("sqlalchemy.orm", "mapperlib") @@ -20,38 +21,52 @@ all_cascades = frozenset(("delete", "delete-orphan", "all", "merge", _INSTRUMENTOR = ('mapper', 'instrumentor') -class CascadeOptions(dict): +class CascadeOptions(frozenset): """Keeps track of the options sent to relationship().cascade""" - def __init__(self, arg=""): - if not arg: - values = set() - else: - values = set(c.strip() for c in arg.split(',')) - - for name in ['save-update', 'delete', 'refresh-expire', - 'merge', 'expunge']: - boolean = name in values or 'all' in values - setattr(self, name.replace('-', '_'), boolean) - if boolean: - self[name] = True + _add_w_all_cascades = all_cascades.difference([ + 'all', 'none', 'delete-orphan']) + _allowed_cascades = all_cascades + + def __new__(cls, arg): + values = set([ + c for c + in re.split('\s*,\s*', arg or "") + if c + ]) + + if values.difference(cls._allowed_cascades): + raise sa_exc.ArgumentError( + "Invalid cascade option(s): %s" % + ", ".join([repr(x) for x in + sorted( + values.difference(cls._allowed_cascades) + )]) + ) + + if "all" in values: + values.update(cls._add_w_all_cascades) + if "none" in values: + values.clear() + values.discard('all') + + self = frozenset.__new__(CascadeOptions, values) + self.save_update = 'save-update' in values + self.delete = 'delete' in values + self.refresh_expire = 'refresh-expire' in values + self.merge = 'merge' in values + self.expunge = 'expunge' in values self.delete_orphan = "delete-orphan" in values - if self.delete_orphan: - self['delete-orphan'] = True if self.delete_orphan and not self.delete: - util.warn("The 'delete-orphan' cascade option requires " - "'delete'.") - - for x in values: - if x not in all_cascades: - raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x) + util.warn("The 'delete-orphan' cascade " + "option requires 'delete'.") + return self def __repr__(self): - return "CascadeOptions(%s)" % repr(",".join( - [x for x in ['delete', 'save_update', 'merge', 'expunge', - 'delete_orphan', 'refresh-expire'] - if getattr(self, x, False) is True])) + return "CascadeOptions(%r)" % ( + ",".join([x for x in sorted(self)]) + ) def _validator_events(desc, key, validator): """Runs a validation method on an attribute value to be set or appended.""" @@ -557,15 +572,20 @@ def _entity_descriptor(entity, key): attribute. """ - if not isinstance(entity, (AliasedClass, type)): - entity = entity.class_ + if isinstance(entity, expression.FromClause): + description = entity + entity = entity.c + elif not isinstance(entity, (AliasedClass, type)): + description = entity = entity.class_ + else: + description = entity try: return getattr(entity, key) except AttributeError: raise sa_exc.InvalidRequestError( "Entity '%s' has no property '%s'" % - (entity, key) + (description, key) ) def _orm_columns(entity): diff --git a/libs/sqlalchemy/pool.py b/libs/sqlalchemy/pool.py index a615e8c..6254a4b 100644 --- a/libs/sqlalchemy/pool.py +++ b/libs/sqlalchemy/pool.py @@ -57,6 +57,10 @@ def clear_managers(): manager.close() proxies.clear() +reset_rollback = util.symbol('reset_rollback') +reset_commit = util.symbol('reset_commit') +reset_none = util.symbol('reset_none') + class Pool(log.Identified): """Abstract base class for connection pools.""" @@ -130,7 +134,17 @@ class Pool(log.Identified): self._creator = creator self._recycle = recycle self._use_threadlocal = use_threadlocal - self._reset_on_return = reset_on_return + if reset_on_return in ('rollback', True, reset_rollback): + self._reset_on_return = reset_rollback + elif reset_on_return in (None, False, reset_none): + self._reset_on_return = reset_none + elif reset_on_return in ('commit', reset_commit): + self._reset_on_return = reset_commit + else: + raise exc.ArgumentError( + "Invalid value for 'reset_on_return': %r" + % reset_on_return) + self.echo = echo if _dispatch: self.dispatch._update(_dispatch, only_propagate=False) @@ -330,8 +344,10 @@ def _finalize_fairy(connection, connection_record, pool, ref, echo): if connection is not None: try: - if pool._reset_on_return: + if pool._reset_on_return is reset_rollback: connection.rollback() + elif pool._reset_on_return is reset_commit: + connection.commit() # Immediately close detached instances if connection_record is None: connection.close() @@ -624,11 +640,37 @@ class QueuePool(Pool): :meth:`unique_connection` method is provided to bypass the threadlocal behavior installed into :meth:`connect`. - :param reset_on_return: If true, reset the database state of - connections returned to the pool. This is typically a - ROLLBACK to release locks and transaction resources. - Disable at your own peril. Defaults to True. - + :param reset_on_return: Determine steps to take on + connections as they are returned to the pool. + As of SQLAlchemy 0.7.6, reset_on_return can have any + of these values: + + * 'rollback' - call rollback() on the connection, + to release locks and transaction resources. + This is the default value. The vast majority + of use cases should leave this value set. + * True - same as 'rollback', this is here for + backwards compatibility. + * 'commit' - call commit() on the connection, + to release locks and transaction resources. + A commit here may be desirable for databases that + cache query plans if a commit is emitted, + such as Microsoft SQL Server. However, this + value is more dangerous than 'rollback' because + any data changes present on the transaction + are committed unconditionally. + * None - don't do anything on the connection. + This setting should only be made on a database + that has no transaction support at all, + namely MySQL MyISAM. By not doing anything, + performance can be improved. This + setting should **never be selected** for a + database that supports transactions, + as it will lead to deadlocks and stale + state. + * False - same as None, this is here for + backwards compatibility. + :param listeners: A list of :class:`~sqlalchemy.interfaces.PoolListener`-like objects or dictionaries of callables that receive events when DB-API diff --git a/libs/sqlalchemy/schema.py b/libs/sqlalchemy/schema.py index f0a9297..d295143 100644 --- a/libs/sqlalchemy/schema.py +++ b/libs/sqlalchemy/schema.py @@ -80,6 +80,17 @@ def _get_table_key(name, schema): else: return schema + "." + name +def _validate_dialect_kwargs(kwargs, name): + # validate remaining kwargs that they all specify DB prefixes + if len([k for k in kwargs + if not re.match( + r'^(?:%s)_' % + '|'.join(dialects.__all__), k + ) + ]): + raise TypeError( + "Invalid argument(s) for %s: %r" % (name, kwargs.keys())) + class Table(SchemaItem, expression.TableClause): """Represent a table in a database. @@ -369,9 +380,12 @@ class Table(SchemaItem, expression.TableClause): # allow user-overrides self._init_items(*args) - def _autoload(self, metadata, autoload_with, include_columns, exclude_columns=None): + def _autoload(self, metadata, autoload_with, include_columns, exclude_columns=()): if self.primary_key.columns: - PrimaryKeyConstraint()._set_parent_with_dispatch(self) + PrimaryKeyConstraint(*[ + c for c in self.primary_key.columns + if c.key in exclude_columns + ])._set_parent_with_dispatch(self) if autoload_with: autoload_with.run_callable( @@ -424,7 +438,7 @@ class Table(SchemaItem, expression.TableClause): if not autoload_replace: exclude_columns = [c.name for c in self.c] else: - exclude_columns = None + exclude_columns = () self._autoload(self.metadata, autoload_with, include_columns, exclude_columns) self._extra_kwargs(**kwargs) @@ -432,14 +446,7 @@ class Table(SchemaItem, expression.TableClause): def _extra_kwargs(self, **kwargs): # validate remaining kwargs that they all specify DB prefixes - if len([k for k in kwargs - if not re.match( - r'^(?:%s)_' % - '|'.join(dialects.__all__), k - ) - ]): - raise TypeError( - "Invalid argument(s) for Table: %r" % kwargs.keys()) + _validate_dialect_kwargs(kwargs, "Table") self.kwargs.update(kwargs) def _init_collections(self): @@ -1028,7 +1035,7 @@ class Column(SchemaItem, expression.ColumnClause): "The 'index' keyword argument on Column is boolean only. " "To create indexes with a specific name, create an " "explicit Index object external to the Table.") - Index(expression._generated_label('ix_%s' % self._label), self, unique=self.unique) + Index(expression._truncated_label('ix_%s' % self._label), self, unique=self.unique) elif self.unique: if isinstance(self.unique, basestring): raise exc.ArgumentError( @@ -1093,7 +1100,7 @@ class Column(SchemaItem, expression.ColumnClause): "been assigned.") try: c = self._constructor( - name or self.name, + expression._as_truncated(name or self.name), self.type, key = name or self.key, primary_key = self.primary_key, @@ -1119,6 +1126,8 @@ class Column(SchemaItem, expression.ColumnClause): c.table = selectable selectable._columns.add(c) + if selectable._is_clone_of is not None: + c._is_clone_of = selectable._is_clone_of.columns[c.name] if self.primary_key: selectable.primary_key.add(c) c.dispatch.after_parent_attach(c, selectable) @@ -1809,7 +1818,8 @@ class Constraint(SchemaItem): __visit_name__ = 'constraint' def __init__(self, name=None, deferrable=None, initially=None, - _create_rule=None): + _create_rule=None, + **kw): """Create a SQL constraint. :param name: @@ -1839,6 +1849,10 @@ class Constraint(SchemaItem): _create_rule is used by some types to create constraints. Currently, its call signature is subject to change at any time. + + :param \**kwargs: + Dialect-specific keyword parameters, see the documentation + for various dialects and constraints regarding options here. """ @@ -1847,6 +1861,8 @@ class Constraint(SchemaItem): self.initially = initially self._create_rule = _create_rule util.set_creation_order(self) + _validate_dialect_kwargs(kw, self.__class__.__name__) + self.kwargs = kw @property def table(self): @@ -2192,6 +2208,8 @@ class Index(ColumnCollectionMixin, SchemaItem): self.table = None # will call _set_parent() if table-bound column # objects are present + if not columns: + util.warn("No column names or expressions given for Index.") ColumnCollectionMixin.__init__(self, *columns) self.name = name self.unique = kw.pop('unique', False) @@ -3004,9 +3022,11 @@ def _to_schema_column(element): return element def _to_schema_column_or_string(element): - if hasattr(element, '__clause_element__'): - element = element.__clause_element__() - return element + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + if not isinstance(element, (basestring, expression.ColumnElement)): + raise exc.ArgumentError("Element %r is not a string name or column element" % element) + return element class _CreateDropBase(DDLElement): """Base class for DDL constucts that represent CREATE and DROP or diff --git a/libs/sqlalchemy/sql/compiler.py b/libs/sqlalchemy/sql/compiler.py index b0a55b8..c5c6f9e 100644 --- a/libs/sqlalchemy/sql/compiler.py +++ b/libs/sqlalchemy/sql/compiler.py @@ -154,9 +154,10 @@ class _CompileLabel(visitors.Visitable): __visit_name__ = 'label' __slots__ = 'element', 'name' - def __init__(self, col, name): + def __init__(self, col, name, alt_names=()): self.element = col self.name = name + self._alt_names = alt_names @property def proxy_set(self): @@ -251,6 +252,10 @@ class SQLCompiler(engine.Compiled): # column targeting self.result_map = {} + # collect CTEs to tack on top of a SELECT + self.ctes = util.OrderedDict() + self.ctes_recursive = False + # true if the paramstyle is positional self.positional = dialect.positional if self.positional: @@ -354,14 +359,16 @@ class SQLCompiler(engine.Compiled): # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. if within_columns_clause and not within_label_clause: - if isinstance(label.name, sql._generated_label): + if isinstance(label.name, sql._truncated_label): labelname = self._truncated_identifier("colident", label.name) else: labelname = label.name if result_map is not None: - result_map[labelname.lower()] = \ - (label.name, (label, label.element, labelname),\ + result_map[labelname.lower()] = ( + label.name, + (label, label.element, labelname, ) + + label._alt_names, label.type) return label.element._compiler_dispatch(self, @@ -376,17 +383,19 @@ class SQLCompiler(engine.Compiled): **kw) def visit_column(self, column, result_map=None, **kwargs): - name = column.name + name = orig_name = column.name if name is None: raise exc.CompileError("Cannot compile Column object until " "it's 'name' is assigned.") is_literal = column.is_literal - if not is_literal and isinstance(name, sql._generated_label): + if not is_literal and isinstance(name, sql._truncated_label): name = self._truncated_identifier("colident", name) if result_map is not None: - result_map[name.lower()] = (name, (column, ), column.type) + result_map[name.lower()] = (orig_name, + (column, name, column.key), + column.type) if is_literal: name = self.escape_literal_column(name) @@ -404,7 +413,7 @@ class SQLCompiler(engine.Compiled): else: schema_prefix = '' tablename = table.name - if isinstance(tablename, sql._generated_label): + if isinstance(tablename, sql._truncated_label): tablename = self._truncated_identifier("alias", tablename) return schema_prefix + \ @@ -646,7 +655,8 @@ class SQLCompiler(engine.Compiled): if name in self.binds: existing = self.binds[name] if existing is not bindparam: - if existing.unique or bindparam.unique: + if (existing.unique or bindparam.unique) and \ + not existing.proxy_set.intersection(bindparam.proxy_set): raise exc.CompileError( "Bind parameter '%s' conflicts with " "unique bind parameter of the same name" % @@ -703,7 +713,7 @@ class SQLCompiler(engine.Compiled): return self.bind_names[bindparam] bind_name = bindparam.key - if isinstance(bind_name, sql._generated_label): + if isinstance(bind_name, sql._truncated_label): bind_name = self._truncated_identifier("bindparam", bind_name) # add to bind_names for translation @@ -715,7 +725,7 @@ class SQLCompiler(engine.Compiled): if (ident_class, name) in self.truncated_names: return self.truncated_names[(ident_class, name)] - anonname = name % self.anon_map + anonname = name.apply_map(self.anon_map) if len(anonname) > self.label_length: counter = self.truncated_names.get(ident_class, 1) @@ -744,10 +754,49 @@ class SQLCompiler(engine.Compiled): else: return self.bindtemplate % {'name':name} + def visit_cte(self, cte, asfrom=False, ashint=False, + fromhints=None, **kwargs): + if isinstance(cte.name, sql._truncated_label): + cte_name = self._truncated_identifier("alias", cte.name) + else: + cte_name = cte.name + if cte.cte_alias: + if isinstance(cte.cte_alias, sql._truncated_label): + cte_alias = self._truncated_identifier("alias", cte.cte_alias) + else: + cte_alias = cte.cte_alias + if not cte.cte_alias and cte not in self.ctes: + if cte.recursive: + self.ctes_recursive = True + text = self.preparer.format_alias(cte, cte_name) + if cte.recursive: + if isinstance(cte.original, sql.Select): + col_source = cte.original + elif isinstance(cte.original, sql.CompoundSelect): + col_source = cte.original.selects[0] + else: + assert False + recur_cols = [c.key for c in util.unique_list(col_source.inner_columns) + if c is not None] + + text += "(%s)" % (", ".join(recur_cols)) + text += " AS \n" + \ + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) + self.ctes[cte] = text + if asfrom: + if cte.cte_alias: + text = self.preparer.format_alias(cte, cte_alias) + text += " AS " + cte_name + else: + return self.preparer.format_alias(cte, cte_name) + return text + def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs): if asfrom or ashint: - if isinstance(alias.name, sql._generated_label): + if isinstance(alias.name, sql._truncated_label): alias_name = self._truncated_identifier("alias", alias.name) else: alias_name = alias.name @@ -775,8 +824,14 @@ class SQLCompiler(engine.Compiled): if isinstance(column, sql._Label): return column - elif select is not None and select.use_labels and column._label: - return _CompileLabel(column, column._label) + elif select is not None and \ + select.use_labels and \ + column._label: + return _CompileLabel( + column, + column._label, + alt_names=(column._key_label, ) + ) elif \ asfrom and \ @@ -784,7 +839,8 @@ class SQLCompiler(engine.Compiled): not column.is_literal and \ column.table is not None and \ not isinstance(column.table, sql.Select): - return _CompileLabel(column, sql._generated_label(column.name)) + return _CompileLabel(column, sql._as_truncated(column.name), + alt_names=(column.key,)) elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) \ and (not hasattr(column, 'name') or \ @@ -799,6 +855,9 @@ class SQLCompiler(engine.Compiled): def get_from_hint_text(self, table, text): return None + def get_crud_hint_text(self, table, text): + return None + def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, fromhints=None, compound_index=1, **kwargs): @@ -897,6 +956,15 @@ class SQLCompiler(engine.Compiled): if select.for_update: text += self.for_update_clause(select) + if self.ctes and \ + compound_index==1 and not entry: + cte_text = self.get_cte_preamble(self.ctes_recursive) + " " + cte_text += ", \n".join( + [txt for txt in self.ctes.values()] + ) + cte_text += "\n " + text = cte_text + text + self.stack.pop(-1) if asfrom and parens: @@ -904,6 +972,12 @@ class SQLCompiler(engine.Compiled): else: return text + def get_cte_preamble(self, recursive): + if recursive: + return "WITH RECURSIVE" + else: + return "WITH" + def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list. @@ -977,12 +1051,26 @@ class SQLCompiler(engine.Compiled): text = "INSERT" + prefixes = [self.process(x) for x in insert_stmt._prefixes] if prefixes: text += " " + " ".join(prefixes) text += " INTO " + preparer.format_table(insert_stmt.table) + if insert_stmt._hints: + dialect_hints = dict([ + (table, hint_text) + for (table, dialect), hint_text in + insert_stmt._hints.items() + if dialect in ('*', self.dialect.name) + ]) + if insert_stmt.table in dialect_hints: + text += " " + self.get_crud_hint_text( + insert_stmt.table, + dialect_hints[insert_stmt.table] + ) + if colparams or not supports_default_values: text += " (%s)" % ', '.join([preparer.format_column(c[0]) for c in colparams]) @@ -1014,21 +1102,25 @@ class SQLCompiler(engine.Compiled): extra_froms, **kw): """Provide a hook to override the initial table clause in an UPDATE statement. - + MySQL overrides this. """ return self.preparer.format_table(from_table) - def update_from_clause(self, update_stmt, from_table, extra_froms, **kw): + def update_from_clause(self, update_stmt, + from_table, extra_froms, + from_hints, + **kw): """Provide a hook to override the generation of an UPDATE..FROM clause. - + MySQL overrides this. """ return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, **kw) + t._compiler_dispatch(self, asfrom=True, + fromhints=from_hints, **kw) for t in extra_froms) def visit_update(self, update_stmt, **kw): @@ -1045,6 +1137,21 @@ class SQLCompiler(engine.Compiled): update_stmt.table, extra_froms, **kw) + if update_stmt._hints: + dialect_hints = dict([ + (table, hint_text) + for (table, dialect), hint_text in + update_stmt._hints.items() + if dialect in ('*', self.dialect.name) + ]) + if update_stmt.table in dialect_hints: + text += " " + self.get_crud_hint_text( + update_stmt.table, + dialect_hints[update_stmt.table] + ) + else: + dialect_hints = None + text += ' SET ' if extra_froms and self.render_table_with_column_in_update_from: text += ', '.join( @@ -1067,7 +1174,8 @@ class SQLCompiler(engine.Compiled): extra_from_text = self.update_from_clause( update_stmt, update_stmt.table, - extra_froms, **kw) + extra_froms, + dialect_hints, **kw) if extra_from_text: text += " " + extra_from_text @@ -1133,7 +1241,6 @@ class SQLCompiler(engine.Compiled): for k, v in stmt.parameters.iteritems(): parameters.setdefault(sql._column_as_key(k), v) - # create a list of column assignment clauses as tuples values = [] @@ -1192,7 +1299,7 @@ class SQLCompiler(engine.Compiled): # "defaults", "primary key cols", etc. for c in stmt.table.columns: if c.key in parameters and c.key not in check_columns: - value = parameters[c.key] + value = parameters.pop(c.key) if sql._is_literal(value): value = self._create_crud_bind_param( c, value, required=value is required) @@ -1288,6 +1395,17 @@ class SQLCompiler(engine.Compiled): self.prefetch.append(c) elif c.server_onupdate is not None: self.postfetch.append(c) + + if parameters and stmt.parameters: + check = set(parameters).intersection( + sql._column_as_key(k) for k in stmt.parameters + ).difference(check_columns) + if check: + util.warn( + "Unconsumed column names: %s" % + (", ".join(check)) + ) + return values def visit_delete(self, delete_stmt): @@ -1296,6 +1414,21 @@ class SQLCompiler(engine.Compiled): text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) + if delete_stmt._hints: + dialect_hints = dict([ + (table, hint_text) + for (table, dialect), hint_text in + delete_stmt._hints.items() + if dialect in ('*', self.dialect.name) + ]) + if delete_stmt.table in dialect_hints: + text += " " + self.get_crud_hint_text( + delete_stmt.table, + dialect_hints[delete_stmt.table] + ) + else: + dialect_hints = None + if delete_stmt._returning: self.returning = delete_stmt._returning if self.returning_precedes_values: @@ -1445,7 +1578,7 @@ class DDLCompiler(engine.Compiled): return "\nDROP TABLE " + self.preparer.format_table(drop.element) def _index_identifier(self, ident): - if isinstance(ident, sql._generated_label): + if isinstance(ident, sql._truncated_label): max = self.dialect.max_index_name_length or \ self.dialect.max_identifier_length if len(ident) > max: diff --git a/libs/sqlalchemy/sql/expression.py b/libs/sqlalchemy/sql/expression.py index bff086e..aa67f44 100644 --- a/libs/sqlalchemy/sql/expression.py +++ b/libs/sqlalchemy/sql/expression.py @@ -832,6 +832,14 @@ def tuple_(*expr): [(1, 2), (5, 12), (10, 19)] ) + .. warning:: + + The composite IN construct is not supported by all backends, + and is currently known to work on Postgresql and MySQL, + but not SQLite. Unsupported backends will raise + a subclass of :class:`~sqlalchemy.exc.DBAPIError` when such + an expression is invoked. + """ return _Tuple(*expr) @@ -1275,14 +1283,48 @@ func = _FunctionGenerator() # TODO: use UnaryExpression for this instead ? modifier = _FunctionGenerator(group=False) -class _generated_label(unicode): - """A unicode subclass used to identify dynamically generated names.""" +class _truncated_label(unicode): + """A unicode subclass used to identify symbolic " + "names that may require truncation.""" + + def apply_map(self, map_): + return self + +# for backwards compatibility in case +# someone is re-implementing the +# _truncated_identifier() sequence in a custom +# compiler +_generated_label = _truncated_label + +class _anonymous_label(_truncated_label): + """A unicode subclass used to identify anonymously + generated names.""" + + def __add__(self, other): + return _anonymous_label( + unicode(self) + + unicode(other)) + + def __radd__(self, other): + return _anonymous_label( + unicode(other) + + unicode(self)) + + def apply_map(self, map_): + return self % map_ -def _escape_for_generated(x): - if isinstance(x, _generated_label): - return x +def _as_truncated(value): + """coerce the given value to :class:`._truncated_label`. + + Existing :class:`._truncated_label` and + :class:`._anonymous_label` objects are passed + unchanged. + """ + + if isinstance(value, _truncated_label): + return value else: - return x.replace('%', '%%') + return _truncated_label(value) def _string_or_unprintable(element): if isinstance(element, basestring): @@ -1466,6 +1508,7 @@ class ClauseElement(Visitable): supports_execution = False _from_objects = [] bind = None + _is_clone_of = None def _clone(self): """Create a shallow copy of this ClauseElement. @@ -1514,7 +1557,7 @@ class ClauseElement(Visitable): f = self while f is not None: s.add(f) - f = getattr(f, '_is_clone_of', None) + f = f._is_clone_of return s def __getstate__(self): @@ -2063,6 +2106,8 @@ class ColumnElement(ClauseElement, _CompareMixin): foreign_keys = [] quote = None _label = None + _key_label = None + _alt_names = () @property def _select_iterable(self): @@ -2109,9 +2154,14 @@ class ColumnElement(ClauseElement, _CompareMixin): else: key = name - co = ColumnClause(name, selectable, type_=getattr(self, + co = ColumnClause(_as_truncated(name), + selectable, + type_=getattr(self, 'type', None)) co.proxies = [self] + if selectable._is_clone_of is not None: + co._is_clone_of = \ + selectable._is_clone_of.columns[key] selectable._columns[key] = co return co @@ -2157,7 +2207,7 @@ class ColumnElement(ClauseElement, _CompareMixin): expressions and function calls. """ - return _generated_label('%%(%d %s)s' % (id(self), getattr(self, + return _anonymous_label('%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon'))) class ColumnCollection(util.OrderedProperties): @@ -2420,6 +2470,13 @@ class FromClause(Selectable): """ + def embedded(expanded_proxy_set, target_set): + for t in target_set.difference(expanded_proxy_set): + if not set(_expand_cloned([t]) + ).intersection(expanded_proxy_set): + return False + return True + # dont dig around if the column is locally present if self.c.contains_column(column): return column @@ -2427,10 +2484,10 @@ class FromClause(Selectable): target_set = column.proxy_set cols = self.c for c in cols: - i = target_set.intersection(itertools.chain(*[p._cloned_set - for p in c.proxy_set])) + expanded_proxy_set = set(_expand_cloned(c.proxy_set)) + i = target_set.intersection(expanded_proxy_set) if i and (not require_embedded - or c.proxy_set.issuperset(target_set)): + or embedded(expanded_proxy_set, target_set)): if col is None: # no corresponding column yet, pick this one. @@ -2580,10 +2637,10 @@ class _BindParamClause(ColumnElement): """ if unique: - self.key = _generated_label('%%(%d %s)s' % (id(self), key + self.key = _anonymous_label('%%(%d %s)s' % (id(self), key or 'param')) else: - self.key = key or _generated_label('%%(%d param)s' + self.key = key or _anonymous_label('%%(%d param)s' % id(self)) # identifiying key that won't change across @@ -2631,14 +2688,14 @@ class _BindParamClause(ColumnElement): def _clone(self): c = ClauseElement._clone(self) if self.unique: - c.key = _generated_label('%%(%d %s)s' % (id(c), c._orig_key + c.key = _anonymous_label('%%(%d %s)s' % (id(c), c._orig_key or 'param')) return c def _convert_to_unique(self): if not self.unique: self.unique = True - self.key = _generated_label('%%(%d %s)s' % (id(self), + self.key = _anonymous_label('%%(%d %s)s' % (id(self), self._orig_key or 'param')) def compare(self, other, **kw): @@ -3607,7 +3664,7 @@ class Alias(FromClause): if name is None: if self.original.named_with_column: name = getattr(self.original, 'name', None) - name = _generated_label('%%(%d %s)s' % (id(self), name + name = _anonymous_label('%%(%d %s)s' % (id(self), name or 'anon')) self.name = name @@ -3662,6 +3719,47 @@ class Alias(FromClause): def bind(self): return self.element.bind +class CTE(Alias): + """Represent a Common Table Expression. + + The :class:`.CTE` object is obtained using the + :meth:`._SelectBase.cte` method from any selectable. + See that method for complete examples. + + New in 0.7.6. + + """ + __visit_name__ = 'cte' + def __init__(self, selectable, + name=None, + recursive=False, + cte_alias=False): + self.recursive = recursive + self.cte_alias = cte_alias + super(CTE, self).__init__(selectable, name=name) + + def alias(self, name=None): + return CTE( + self.original, + name=name, + recursive=self.recursive, + cte_alias = self.name + ) + + def union(self, other): + return CTE( + self.original.union(other), + name=self.name, + recursive=self.recursive + ) + + def union_all(self, other): + return CTE( + self.original.union_all(other), + name=self.name, + recursive=self.recursive + ) + class _Grouping(ColumnElement): """Represent a grouping within a column expression""" @@ -3807,9 +3905,12 @@ class _Label(ColumnElement): def __init__(self, name, element, type_=None): while isinstance(element, _Label): element = element.element - self.name = self.key = self._label = name \ - or _generated_label('%%(%d %s)s' % (id(self), + if name: + self.name = name + else: + self.name = _anonymous_label('%%(%d %s)s' % (id(self), getattr(element, 'name', 'anon'))) + self.key = self._label = self._key_label = self.name self._element = element self._type = type_ self.quote = element.quote @@ -3957,7 +4058,17 @@ class ColumnClause(_Immutable, ColumnElement): # end Py2K @_memoized_property + def _key_label(self): + if self.key != self.name: + return self._gen_label(self.key) + else: + return self._label + + @_memoized_property def _label(self): + return self._gen_label(self.name) + + def _gen_label(self, name): t = self.table if self.is_literal: return None @@ -3965,11 +4076,9 @@ class ColumnClause(_Immutable, ColumnElement): elif t is not None and t.named_with_column: if getattr(t, 'schema', None): label = t.schema.replace('.', '_') + "_" + \ - _escape_for_generated(t.name) + "_" + \ - _escape_for_generated(self.name) + t.name + "_" + name else: - label = _escape_for_generated(t.name) + "_" + \ - _escape_for_generated(self.name) + label = t.name + "_" + name # ensure the label name doesn't conflict with that # of an existing column @@ -3981,10 +4090,10 @@ class ColumnClause(_Immutable, ColumnElement): counter += 1 label = _label - return _generated_label(label) + return _as_truncated(label) else: - return self.name + return name def label(self, name): # currently, anonymous labels don't occur for @@ -4010,12 +4119,15 @@ class ColumnClause(_Immutable, ColumnElement): # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) c = self._constructor( - name or self.name, + _as_truncated(name or self.name), selectable=selectable, type_=self.type, is_literal=is_literal ) c.proxies = [self] + if selectable._is_clone_of is not None: + c._is_clone_of = \ + selectable._is_clone_of.columns[c.name] if attach: selectable._columns[c.name] = c @@ -4218,6 +4330,125 @@ class _SelectBase(Executable, FromClause): """ return self.as_scalar().label(name) + def cte(self, name=None, recursive=False): + """Return a new :class:`.CTE`, or Common Table Expression instance. + + Common table expressions are a SQL standard whereby SELECT + statements can draw upon secondary statements specified along + with the primary statement, using a clause called "WITH". + Special semantics regarding UNION can also be employed to + allow "recursive" queries, where a SELECT statement can draw + upon the set of rows that have previously been selected. + + SQLAlchemy detects :class:`.CTE` objects, which are treated + similarly to :class:`.Alias` objects, as special elements + to be delivered to the FROM clause of the statement as well + as to a WITH clause at the top of the statement. + + The :meth:`._SelectBase.cte` method is new in 0.7.6. + + :param name: name given to the common table expression. Like + :meth:`._FromClause.alias`, the name can be left as ``None`` + in which case an anonymous symbol will be used at query + compile time. + :param recursive: if ``True``, will render ``WITH RECURSIVE``. + A recursive common table expression is intended to be used in + conjunction with UNION ALL in order to derive rows + from those already selected. + + The following examples illustrate two examples from + Postgresql's documentation at + http://www.postgresql.org/docs/8.4/static/queries-with.html. + + Example 1, non recursive:: + + from sqlalchemy import Table, Column, String, Integer, MetaData, \\ + select, func + + metadata = MetaData() + + orders = Table('orders', metadata, + Column('region', String), + Column('amount', Integer), + Column('product', String), + Column('quantity', Integer) + ) + + regional_sales = select([ + orders.c.region, + func.sum(orders.c.amount).label('total_sales') + ]).group_by(orders.c.region).cte("regional_sales") + + + top_regions = select([regional_sales.c.region]).\\ + where( + regional_sales.c.total_sales > + select([ + func.sum(regional_sales.c.total_sales)/10 + ]) + ).cte("top_regions") + + statement = select([ + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), + func.sum(orders.c.amount).label("product_sales") + ]).where(orders.c.region.in_( + select([top_regions.c.region]) + )).group_by(orders.c.region, orders.c.product) + + result = conn.execute(statement).fetchall() + + Example 2, WITH RECURSIVE:: + + from sqlalchemy import Table, Column, String, Integer, MetaData, \\ + select, func + + metadata = MetaData() + + parts = Table('parts', metadata, + Column('part', String), + Column('sub_part', String), + Column('quantity', Integer), + ) + + included_parts = select([ + parts.c.sub_part, + parts.c.part, + parts.c.quantity]).\\ + where(parts.c.part=='our part').\\ + cte(recursive=True) + + + incl_alias = included_parts.alias() + parts_alias = parts.alias() + included_parts = included_parts.union_all( + select([ + parts_alias.c.part, + parts_alias.c.sub_part, + parts_alias.c.quantity + ]). + where(parts_alias.c.part==incl_alias.c.sub_part) + ) + + statement = select([ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label('total_quantity') + ]).\ + select_from(included_parts.join(parts, + included_parts.c.part==parts.c.part)).\\ + group_by(included_parts.c.sub_part) + + result = conn.execute(statement).fetchall() + + + See also: + + :meth:`.orm.query.Query.cte` - ORM version of :meth:`._SelectBase.cte`. + + """ + return CTE(self, name=name, recursive=recursive) + @_generative @util.deprecated('0.6', message=":func:`.autocommit` is deprecated. Use " @@ -4602,7 +4833,7 @@ class Select(_SelectBase): The text of the hint is rendered in the appropriate location for the database backend in use, relative to the given :class:`.Table` or :class:`.Alias` passed as the - *selectable* argument. The dialect implementation + ``selectable`` argument. The dialect implementation typically uses Python string substitution syntax with the token ``%(name)s`` to render the name of the table or alias. E.g. when using Oracle, the @@ -4999,7 +5230,9 @@ class Select(_SelectBase): def _populate_column_collection(self): for c in self.inner_columns: if hasattr(c, '_make_proxy'): - c._make_proxy(self, name=self.use_labels and c._label or None) + c._make_proxy(self, + name=self.use_labels + and c._label or None) def self_group(self, against=None): """return a 'grouping' construct as per the ClauseElement @@ -5086,6 +5319,7 @@ class UpdateBase(Executable, ClauseElement): _execution_options = \ Executable._execution_options.union({'autocommit': True}) kwargs = util.immutabledict() + _hints = util.immutabledict() def _process_colparams(self, parameters): if isinstance(parameters, (list, tuple)): @@ -5166,6 +5400,45 @@ class UpdateBase(Executable, ClauseElement): """ self._returning = cols + @_generative + def with_hint(self, text, selectable=None, dialect_name="*"): + """Add a table hint for a single table to this + INSERT/UPDATE/DELETE statement. + + .. note:: + + :meth:`.UpdateBase.with_hint` currently applies only to + Microsoft SQL Server. For MySQL INSERT hints, use + :meth:`.Insert.prefix_with`. UPDATE/DELETE hints for + MySQL will be added in a future release. + + The text of the hint is rendered in the appropriate + location for the database backend in use, relative + to the :class:`.Table` that is the subject of this + statement, or optionally to that of the given + :class:`.Table` passed as the ``selectable`` argument. + + The ``dialect_name`` option will limit the rendering of a particular + hint to a particular backend. Such as, to add a hint + that only takes effect for SQL Server:: + + mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql") + + New in 0.7.6. + + :param text: Text of the hint. + :param selectable: optional :class:`.Table` that specifies + an element of the FROM clause within an UPDATE or DELETE + to be the subject of the hint - applies only to certain backends. + :param dialect_name: defaults to ``*``, if specified as the name + of a particular dialect, will apply these hints only when + that dialect is in use. + """ + if selectable is None: + selectable = self.table + + self._hints = self._hints.union({(selectable, dialect_name):text}) + class ValuesBase(UpdateBase): """Supplies support for :meth:`.ValuesBase.values` to INSERT and UPDATE constructs.""" diff --git a/libs/sqlalchemy/sql/visitors.py b/libs/sqlalchemy/sql/visitors.py index cdcf40a..5354fbc 100644 --- a/libs/sqlalchemy/sql/visitors.py +++ b/libs/sqlalchemy/sql/visitors.py @@ -34,11 +34,19 @@ __all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', 'cloned_traverse', 'replacement_traverse'] class VisitableType(type): - """Metaclass which checks for a `__visit_name__` attribute and - applies `_compiler_dispatch` method to classes. - + """Metaclass which assigns a `_compiler_dispatch` method to classes + having a `__visit_name__` attribute. + + The _compiler_dispatch attribute becomes an instance method which + looks approximately like the following:: + + def _compiler_dispatch (self, visitor, **kw): + '''Look for an attribute named "visit_" + self.__visit_name__ + on the visitor, and call it with the same kw params.''' + return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) + + Classes having no __visit_name__ attribute will remain unaffected. """ - def __init__(cls, clsname, bases, clsdict): if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'): super(VisitableType, cls).__init__(clsname, bases, clsdict) @@ -48,19 +56,31 @@ class VisitableType(type): super(VisitableType, cls).__init__(clsname, bases, clsdict) + def _generate_dispatch(cls): - # set up an optimized visit dispatch function - # for use by the compiler + """Return an optimized visit dispatch function for the cls + for use by the compiler. + """ if '__visit_name__' in cls.__dict__: visit_name = cls.__visit_name__ if isinstance(visit_name, str): + # There is an optimization opportunity here because the + # the string name of the class's __visit_name__ is known at + # this early stage (import time) so it can be pre-constructed. getter = operator.attrgetter("visit_%s" % visit_name) def _compiler_dispatch(self, visitor, **kw): return getter(visitor)(self, **kw) else: + # The optimization opportunity is lost for this case because the + # __visit_name__ is not yet a string. As a result, the visit + # string has to be recalculated with each compilation. def _compiler_dispatch(self, visitor, **kw): return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) + _compiler_dispatch.__doc__ = \ + """Look for an attribute named "visit_" + self.__visit_name__ + on the visitor, and call it with the same kw params. + """ cls._compiler_dispatch = _compiler_dispatch class Visitable(object): diff --git a/libs/sqlalchemy/types.py b/libs/sqlalchemy/types.py index 8c8e6eb..512ac62 100644 --- a/libs/sqlalchemy/types.py +++ b/libs/sqlalchemy/types.py @@ -397,7 +397,7 @@ class TypeDecorator(TypeEngine): def copy(self): return MyType(self.impl.length) - The class-level "impl" variable is required, and can reference any + The class-level "impl" attribute is required, and can reference any TypeEngine class. Alternatively, the load_dialect_impl() method can be used to provide different type classes based on the dialect given; in this case, the "impl" variable can reference @@ -457,15 +457,19 @@ class TypeDecorator(TypeEngine): Arguments sent here are passed to the constructor of the class assigned to the ``impl`` class level attribute, - where the ``self.impl`` attribute is assigned an instance - of the implementation type. If ``impl`` at the class level - is already an instance, then it's assigned to ``self.impl`` - as is. + assuming the ``impl`` is a callable, and the resulting + object is assigned to the ``self.impl`` instance attribute + (thus overriding the class attribute of the same name). + + If the class level ``impl`` is not a callable (the unusual case), + it will be assigned to the same instance attribute 'as-is', + ignoring those arguments passed to the constructor. Subclasses can override this to customize the generation - of ``self.impl``. + of ``self.impl`` entirely. """ + if not hasattr(self.__class__, 'impl'): raise AssertionError("TypeDecorator implementations " "require a class-level variable " @@ -475,6 +479,9 @@ class TypeDecorator(TypeEngine): def _gen_dialect_impl(self, dialect): + """ + #todo + """ adapted = dialect.type_descriptor(self) if adapted is not self: return adapted @@ -494,6 +501,9 @@ class TypeDecorator(TypeEngine): @property def _type_affinity(self): + """ + #todo + """ return self.impl._type_affinity def type_engine(self, dialect): @@ -531,7 +541,6 @@ class TypeDecorator(TypeEngine): def __getattr__(self, key): """Proxy all other undefined accessors to the underlying implementation.""" - return getattr(self.impl, key) def process_bind_param(self, value, dialect): @@ -542,29 +551,52 @@ class TypeDecorator(TypeEngine): :class:`.TypeEngine` object, and from there to the DBAPI ``execute()`` method. - :param value: the value. Can be None. + The operation could be anything desired to perform custom + behavior, such as transforming or serializing data. + This could also be used as a hook for validating logic. + + This operation should be designed with the reverse operation + in mind, which would be the process_result_value method of + this class. + + :param value: Data to operate upon, of any type expected by + this method in the subclass. Can be ``None``. :param dialect: the :class:`.Dialect` in use. """ + raise NotImplementedError() def process_result_value(self, value, dialect): """Receive a result-row column value to be converted. + Subclasses should implement this method to operate on data + fetched from the database. + Subclasses override this method to return the value that should be passed back to the application, given a value that is already processed by the underlying :class:`.TypeEngine` object, originally from the DBAPI cursor method ``fetchone()`` or similar. - :param value: the value. Can be None. + The operation could be anything desired to perform custom + behavior, such as transforming or serializing data. + This could also be used as a hook for validating logic. + + :param value: Data to operate upon, of any type expected by + this method in the subclass. Can be ``None``. :param dialect: the :class:`.Dialect` in use. + This operation should be designed to be reversible by + the "process_bind_param" method of this class. + """ + raise NotImplementedError() def bind_processor(self, dialect): - """Provide a bound value processing function for the given :class:`.Dialect`. + """Provide a bound value processing function for the + given :class:`.Dialect`. This is the method that fulfills the :class:`.TypeEngine` contract for bound value conversion. :class:`.TypeDecorator` @@ -575,6 +607,11 @@ class TypeDecorator(TypeEngine): though its likely best to use :meth:`process_bind_param` so that the processing provided by ``self.impl`` is maintained. + :param dialect: Dialect instance in use. + + This method is the reverse counterpart to the + :meth:`result_processor` method of this class. + """ if self.__class__.process_bind_param.func_code \ is not TypeDecorator.process_bind_param.func_code: @@ -604,6 +641,12 @@ class TypeDecorator(TypeEngine): though its likely best to use :meth:`process_result_value` so that the processing provided by ``self.impl`` is maintained. + :param dialect: Dialect instance in use. + :param coltype: An SQLAlchemy data type + + This method is the reverse counterpart to the + :meth:`bind_processor` method of this class. + """ if self.__class__.process_result_value.func_code \ is not TypeDecorator.process_result_value.func_code: @@ -654,6 +697,7 @@ class TypeDecorator(TypeEngine): has local state that should be deep-copied. """ + instance = self.__class__.__new__(self.__class__) instance.__dict__.update(self.__dict__) return instance @@ -724,6 +768,9 @@ class TypeDecorator(TypeEngine): return self.impl.is_mutable() def _adapt_expression(self, op, othertype): + """ + #todo + """ op, typ =self.impl._adapt_expression(op, othertype) if typ is self.impl: return op, self diff --git a/libs/sqlalchemy/util/__init__.py b/libs/sqlalchemy/util/__init__.py index 5712940..13914aa 100644 --- a/libs/sqlalchemy/util/__init__.py +++ b/libs/sqlalchemy/util/__init__.py @@ -7,7 +7,7 @@ from compat import callable, cmp, reduce, defaultdict, py25_dict, \ threading, py3k_warning, jython, pypy, win32, set_types, buffer, pickle, \ update_wrapper, partial, md5_hex, decode_slice, dottedgetter,\ - parse_qsl, any + parse_qsl, any, contextmanager from _collections import NamedTuple, ImmutableContainer, immutabledict, \ Properties, OrderedProperties, ImmutableProperties, OrderedDict, \ diff --git a/libs/sqlalchemy/util/compat.py b/libs/sqlalchemy/util/compat.py index 07652f3..99b92b1 100644 --- a/libs/sqlalchemy/util/compat.py +++ b/libs/sqlalchemy/util/compat.py @@ -57,6 +57,12 @@ buffer = buffer # end Py2K try: + from contextlib import contextmanager +except ImportError: + def contextmanager(fn): + return fn + +try: from functools import update_wrapper except ImportError: def update_wrapper(wrapper, wrapped,