diff --git a/CouchPotato.py b/CouchPotato.py index 6c3a8de..dd1e654 100755 --- a/CouchPotato.py +++ b/CouchPotato.py @@ -3,6 +3,7 @@ from logging import handlers from os.path import dirname import logging import os +import select import signal import socket import subprocess @@ -121,6 +122,8 @@ if __name__ == '__main__': l.run() except KeyboardInterrupt: pass + except select.error: + pass except SystemExit: raise except socket.error as (nr, msg): diff --git a/README.md b/README.md index ea03f1e..2eb0f2f 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Windows: * Install [PyWin32 2.7](http://sourceforge.net/projects/pywin32/files/pywin32/Build%20217/) and [GIT](http://git-scm.com/) * If you come and ask on the forums 'why directory selection no work?', I will kill a kitten, also this is because you need PyWin32 * Open up `Git Bash` (or CMD) and go to the folder you want to install CP. Something like Program Files. -* Run `git clone https://RuudBurger@github.com/RuudBurger/CouchPotatoServer.git`. +* Run `git clone https://github.com/RuudBurger/CouchPotatoServer.git`. * You can now start CP via `CouchPotatoServer\CouchPotato.py` to start OSx: @@ -23,14 +23,14 @@ OSx: * Install [GIT](http://git-scm.com/) * Open up `Terminal` * Go to your App folder `cd /Applications` -* Run `git clone https://RuudBurger@github.com/RuudBurger/CouchPotatoServer.git` +* Run `git clone https://github.com/RuudBurger/CouchPotatoServer.git` * Then do `python CouchPotatoServer/CouchPotato.py` Linux (ubuntu / debian): * Install [GIT](http://git-scm.com/) with `apt-get install git-core` * 'cd' to the folder of your choosing. -* Run `git clone https://RuudBurger@github.com/RuudBurger/CouchPotatoServer.git` +* Run `git clone https://github.com/RuudBurger/CouchPotatoServer.git` * Then do `python CouchPotatoServer/CouchPotato.py` to start * To run on boot copy the init script. `cp CouchPotatoServer/init/ubuntu /etc/init.d/couchpotato` * Change the paths inside the init script. 
`nano /etc/init.d/couchpotato` diff --git a/couchpotato/__init__.py b/couchpotato/__init__.py index 57d7d33..3e28363 100644 --- a/couchpotato/__init__.py +++ b/couchpotato/__init__.py @@ -24,11 +24,7 @@ web = Blueprint('web', __name__) def get_session(engine = None): - engine = engine if engine else get_engine() - return scoped_session(sessionmaker(bind = engine)) - -def get_engine(): - return create_engine(Env.get('db_path') + '?check_same_thread=False', echo = False) + return Env.getSession(engine) def addView(route, func, static = False): web.add_url_rule(route + ('' if static else '/'), endpoint = route if route else 'index', view_func = func) diff --git a/couchpotato/api.py b/couchpotato/api.py index 934fa94..b1dee1b 100644 --- a/couchpotato/api.py +++ b/couchpotato/api.py @@ -6,8 +6,8 @@ api = Blueprint('api', __name__) api_docs = {} api_docs_missing = [] -def addApiView(route, func, static = False, docs = None): - api.add_url_rule(route + ('' if static else '/'), endpoint = route.replace('.', '::') if route else 'index', view_func = func) +def addApiView(route, func, static = False, docs = None, **kwargs): + api.add_url_rule(route + ('' if static else '/'), endpoint = route.replace('.', '::') if route else 'index', view_func = func, **kwargs) if docs: api_docs[route[4:] if route[0:4] == 'api.' else route] = docs else: diff --git a/couchpotato/core/_base/_core/main.py b/couchpotato/core/_base/_core/main.py index 81ccac2..a496df6 100644 --- a/couchpotato/core/_base/_core/main.py +++ b/couchpotato/core/_base/_core/main.py @@ -47,7 +47,6 @@ class Core(Plugin): addEvent('setting.save.core.password', self.md5Password) addEvent('setting.save.core.api_key', self.checkApikey) - self.removeRestartFile() def md5Password(self, value): return md5(value) if value else '' @@ -119,9 +118,6 @@ class Core(Plugin): time.sleep(1) - if restart: - self.createFile(self.restartFilePath(), 'This is the most suckiest way to register if CP is restarted. 
Ever...') - log.debug('Save to shutdown/restart') try: @@ -133,15 +129,6 @@ class Core(Plugin): fireEvent('app.after_shutdown', restart = restart) - def removeRestartFile(self): - try: - os.remove(self.restartFilePath()) - except: - pass - - def restartFilePath(self): - return os.path.join(Env.get('data_dir'), 'restart') - def launchBrowser(self): if Env.setting('launch_browser'): diff --git a/couchpotato/core/_base/scheduler/main.py b/couchpotato/core/_base/scheduler/main.py index fb3b01e..d09efed 100644 --- a/couchpotato/core/_base/scheduler/main.py +++ b/couchpotato/core/_base/scheduler/main.py @@ -15,8 +15,6 @@ class Scheduler(Plugin): def __init__(self): - logging.getLogger('apscheduler').setLevel(logging.ERROR) - addEvent('schedule.cron', self.cron) addEvent('schedule.interval', self.interval) addEvent('schedule.start', self.start) diff --git a/couchpotato/core/_base/updater/main.py b/couchpotato/core/_base/updater/main.py index 1750de5..7bcfeff 100644 --- a/couchpotato/core/_base/updater/main.py +++ b/couchpotato/core/_base/updater/main.py @@ -232,6 +232,7 @@ class SourceUpdater(BaseUpdater): # Extract tar = tarfile.open(destination) tar.extractall(path = extracted_path) + tar.close() os.remove(destination) self.replaceWith(os.path.join(extracted_path, os.listdir(extracted_path)[0])) diff --git a/couchpotato/core/helpers/variable.py b/couchpotato/core/helpers/variable.py index 1f5f76a..b398d30 100644 --- a/couchpotato/core/helpers/variable.py +++ b/couchpotato/core/helpers/variable.py @@ -59,6 +59,9 @@ def flattenList(l): def md5(text): return hashlib.md5(text).hexdigest() +def sha1(text): + return hashlib.sha1(text).hexdigest() + def getExt(filename): return os.path.splitext(filename)[1][1:] diff --git a/couchpotato/core/notifications/base.py b/couchpotato/core/notifications/base.py index 254059e..2663c33 100644 --- a/couchpotato/core/notifications/base.py +++ b/couchpotato/core/notifications/base.py @@ -3,13 +3,14 @@ from couchpotato.core.event import 
addEvent from couchpotato.core.helpers.request import jsonified from couchpotato.core.logger import CPLog from couchpotato.core.plugins.base import Plugin +from couchpotato.environment import Env log = CPLog(__name__) class Notification(Plugin): - default_title = 'CouchPotato' + default_title = Env.get('appname') test_message = 'ZOMG Lazors Pewpewpew!' listen_to = ['movie.downloaded', 'movie.snatched', 'updater.available'] @@ -29,11 +30,11 @@ class Notification(Plugin): def notify(message, data): if not self.conf('on_snatch', default = True) and listener == 'movie.snatched': return - return self.notify(message = message, data = data) + return self.notify(message = message, data = data, listener = listener) return notify - def notify(self, message = '', data = {}): + def notify(self, message = '', data = {}, listener = None): pass def test(self): @@ -44,7 +45,8 @@ class Notification(Plugin): success = self.notify( message = self.test_message, - data = {} + data = {}, + listener = 'test' ) return jsonified({'success': success}) diff --git a/couchpotato/core/notifications/boxcar/main.py b/couchpotato/core/notifications/boxcar/main.py index cbf907f..3135614 100644 --- a/couchpotato/core/notifications/boxcar/main.py +++ b/couchpotato/core/notifications/boxcar/main.py @@ -10,7 +10,7 @@ class Boxcar(Notification): url = 'https://boxcar.io/devices/providers/7MNNXY3UIzVBwvzkKwkC/notifications' - def notify(self, message = '', data = {}): + def notify(self, message = '', data = {}, listener = None): if self.isDisabled(): return try: diff --git a/couchpotato/core/notifications/core/main.py b/couchpotato/core/notifications/core/main.py index 7081740..243f621 100644 --- a/couchpotato/core/notifications/core/main.py +++ b/couchpotato/core/notifications/core/main.py @@ -65,6 +65,7 @@ class CoreNotifier(Notification): q.update({Notif.read: True}) db.commit() + db.close() return jsonified({ 'success': True @@ -90,16 +91,19 @@ class CoreNotifier(Notification): ndict['type'] = 
'notification' notifications.append(ndict) + db.close() return jsonified({ 'success': True, 'empty': len(notifications) == 0, 'notifications': notifications }) - def notify(self, message = '', data = {}): + def notify(self, message = '', data = {}, listener = None): db = get_session() + data['notification_type'] = listener if listener else 'unknown' + n = Notif( message = toUnicode(message), data = data @@ -112,7 +116,8 @@ class CoreNotifier(Notification): ndict['time'] = time.time() self.messages.append(ndict) - db.remove() + db.close() + return True def frontend(self, type = 'notification', data = {}): self.messages.append({ @@ -141,6 +146,8 @@ class CoreNotifier(Notification): ndict['type'] = 'notification' messages.append(ndict) + db.close() + self.messages = [] return jsonified({ 'success': True, diff --git a/couchpotato/core/notifications/growl/main.py b/couchpotato/core/notifications/growl/main.py index 06accc8..b98888e 100644 --- a/couchpotato/core/notifications/growl/main.py +++ b/couchpotato/core/notifications/growl/main.py @@ -1,6 +1,7 @@ from couchpotato.core.event import fireEvent from couchpotato.core.logger import CPLog from couchpotato.core.notifications.base import Notification +from couchpotato.environment import Env from gntp import notifier import logging import traceback @@ -15,8 +16,6 @@ class Growl(Notification): def __init__(self): super(Growl, self).__init__() - logging.getLogger('gntp').setLevel(logging.WARNING) - if self.isEnabled(): self.register() @@ -29,7 +28,7 @@ class Growl(Notification): port = self.conf('port') self.growl = notifier.GrowlNotifier( - applicationName = 'CouchPotato', + applicationName = Env.get('appname'), notifications = ["Updates"], defaultNotifications = ["Updates"], applicationIcon = '%s/static/images/couch.png' % fireEvent('app.api_url', single = True), @@ -42,7 +41,7 @@ class Growl(Notification): except: log.error('Failed register of growl: %s' % traceback.format_exc()) - def notify(self, message = '', data = 
{}): + def notify(self, message = '', data = {}, listener = None): if self.isDisabled(): return self.register() diff --git a/couchpotato/core/notifications/history/main.py b/couchpotato/core/notifications/history/main.py index f235fd0..a4ad974 100644 --- a/couchpotato/core/notifications/history/main.py +++ b/couchpotato/core/notifications/history/main.py @@ -12,7 +12,7 @@ class History(Notification): listen_to = ['movie.downloaded', 'movie.snatched', 'renamer.canceled'] - def notify(self, message = '', data = {}): + def notify(self, message = '', data = {}, listener = None): db = get_session() history = Hist( @@ -22,3 +22,6 @@ class History(Notification): ) db.add(history) db.commit() + db.close() + + return True diff --git a/couchpotato/core/notifications/notifo/main.py b/couchpotato/core/notifications/notifo/main.py index d4baaf6..e372f28 100644 --- a/couchpotato/core/notifications/notifo/main.py +++ b/couchpotato/core/notifications/notifo/main.py @@ -12,7 +12,7 @@ class Notifo(Notification): url = 'https://api.notifo.com/v1/send_notification' - def notify(self, message = '', data = {}): + def notify(self, message = '', data = {}, listener = None): if self.isDisabled(): return try: diff --git a/couchpotato/core/notifications/notifymyandroid/main.py b/couchpotato/core/notifications/notifymyandroid/main.py index 2f84317..195278e 100644 --- a/couchpotato/core/notifications/notifymyandroid/main.py +++ b/couchpotato/core/notifications/notifymyandroid/main.py @@ -7,7 +7,7 @@ log = CPLog(__name__) class NotifyMyAndroid(Notification): - def notify(self, message = '', data = {}): + def notify(self, message = '', data = {}, listener = None): if self.isDisabled(): return nma = pynma.PyNMA() diff --git a/couchpotato/core/notifications/notifymywp/main.py b/couchpotato/core/notifications/notifymywp/main.py index 203fdad..17252c1 100644 --- a/couchpotato/core/notifications/notifymywp/main.py +++ b/couchpotato/core/notifications/notifymywp/main.py @@ -7,7 +7,7 @@ log = 
CPLog(__name__) class NotifyMyWP(Notification): - def notify(self, message = '', data = {}): + def notify(self, message = '', data = {}, listener = None): if self.isDisabled(): return keys = [x.strip() for x in self.conf('api_key').split(',')] diff --git a/couchpotato/core/notifications/prowl/main.py b/couchpotato/core/notifications/prowl/main.py index 44aaa9c..5f24d4e 100644 --- a/couchpotato/core/notifications/prowl/main.py +++ b/couchpotato/core/notifications/prowl/main.py @@ -8,7 +8,7 @@ log = CPLog(__name__) class Prowl(Notification): - def notify(self, message = '', data = {}): + def notify(self, message = '', data = {}, listener = None): if self.isDisabled(): return http_handler = HTTPSConnection('api.prowlapp.com') diff --git a/couchpotato/core/notifications/pushover/main.py b/couchpotato/core/notifications/pushover/main.py index d531f2b..be99df1 100644 --- a/couchpotato/core/notifications/pushover/main.py +++ b/couchpotato/core/notifications/pushover/main.py @@ -10,7 +10,7 @@ class Pushover(Notification): app_token = 'YkxHMYDZp285L265L3IwH3LmzkTaCy' - def notify(self, message = '', data = {}): + def notify(self, message = '', data = {}, listener = None): if self.isDisabled(): return http_handler = HTTPSConnection("api.pushover.net:443") diff --git a/couchpotato/core/notifications/twitter/main.py b/couchpotato/core/notifications/twitter/main.py index 700f7fa..b956dd0 100644 --- a/couchpotato/core/notifications/twitter/main.py +++ b/couchpotato/core/notifications/twitter/main.py @@ -31,7 +31,7 @@ class Twitter(Notification): addApiView('notify.%s.auth_url' % self.getName().lower(), self.getAuthorizationUrl) addApiView('notify.%s.credentials' % self.getName().lower(), self.getCredentials) - def notify(self, message = '', data = {}): + def notify(self, message = '', data = {}, listener = None): if self.isDisabled(): return api = Api(self.consumer_key, self.consumer_secret, self.conf('access_token_key'), self.conf('access_token_secret')) diff --git 
a/couchpotato/core/notifications/xbmc/main.py b/couchpotato/core/notifications/xbmc/main.py index cbc81df..eff591a 100644 --- a/couchpotato/core/notifications/xbmc/main.py +++ b/couchpotato/core/notifications/xbmc/main.py @@ -10,7 +10,7 @@ class XBMC(Notification): listen_to = ['movie.downloaded'] - def notify(self, message = '', data = {}): + def notify(self, message = '', data = {}, listener = None): if self.isDisabled(): return for host in [x.strip() for x in self.conf('host').split(",")]: diff --git a/couchpotato/core/plugins/file/main.py b/couchpotato/core/plugins/file/main.py index e0f0240..994da1e 100644 --- a/couchpotato/core/plugins/file/main.py +++ b/couchpotato/core/plugins/file/main.py @@ -87,7 +87,7 @@ class FileManager(Plugin): db.commit() type_dict = ft.to_dict() - db.remove() + db.close() return type_dict def getTypes(self): @@ -100,4 +100,5 @@ class FileManager(Plugin): for type_object in results: types.append(type_object.to_dict()) + db.close() return types diff --git a/couchpotato/core/plugins/library/main.py b/couchpotato/core/plugins/library/main.py index 1a4880c..bc38857 100644 --- a/couchpotato/core/plugins/library/main.py +++ b/couchpotato/core/plugins/library/main.py @@ -51,7 +51,10 @@ class LibraryPlugin(Plugin): handle = fireEventAsync if update_after is 'async' else fireEvent handle('library.update', identifier = l.identifier, default_title = toUnicode(attrs.get('title', ''))) - return l.to_dict(self.default_dict) + library_dict = l.to_dict(self.default_dict) + + db.close() + return library_dict def update(self, identifier, default_title = '', force = False): @@ -68,7 +71,13 @@ class LibraryPlugin(Plugin): do_update = False else: info = fireEvent('movie.info', merge = True, identifier = identifier) - del info['in_wanted'], info['in_library'] # Don't need those here + + # Don't need those here + try: del info['in_wanted'] + except: pass + try: del info['in_library'] + except: pass + if not info or len(info) == 0: log.error('Could not 
update, no movie info to work with: %s' % identifier) return False @@ -121,6 +130,7 @@ class LibraryPlugin(Plugin): fireEvent('library.update_finish', data = library_dict) + db.close() return library_dict def updateReleaseDate(self, identifier): @@ -134,7 +144,7 @@ class LibraryPlugin(Plugin): db.commit() dates = library.info.get('release_date', {}) - db.remove() + db.close() return dates diff --git a/couchpotato/core/plugins/movie/main.py b/couchpotato/core/plugins/movie/main.py index fbad471..52e6edc 100644 --- a/couchpotato/core/plugins/movie/main.py +++ b/couchpotato/core/plugins/movie/main.py @@ -6,7 +6,7 @@ from couchpotato.core.helpers.request import getParams, jsonified, getParam from couchpotato.core.helpers.variable import getImdb from couchpotato.core.logger import CPLog from couchpotato.core.plugins.base import Plugin -from couchpotato.core.settings.model import Movie, Library, LibraryTitle +from couchpotato.core.settings.model import Library, LibraryTitle, Movie from couchpotato.environment import Env from sqlalchemy.orm import joinedload_all from sqlalchemy.sql.expression import or_, asc, not_ @@ -60,7 +60,7 @@ class MoviePlugin(Plugin): addApiView('movie.refresh', self.refresh, docs = { 'desc': 'Refresh a movie by id', 'params': { - 'id': {'desc': 'The id of the movie that needs to be refreshed'}, + 'id': {'desc': 'Movie ID(s) you want to refresh.', 'type': 'int (comma separated)'}, } }) addApiView('movie.available_chars', self.charView) @@ -109,10 +109,12 @@ class MoviePlugin(Plugin): db = get_session() m = db.query(Movie).filter_by(id = movie_id).first() + results = None if m: - return m.to_dict(self.default_dict) + results = m.to_dict(self.default_dict) - return None + db.close() + return results def list(self, status = ['active'], limit_offset = None, starts_with = None, search = None): @@ -175,6 +177,7 @@ class MoviePlugin(Plugin): }) movies.append(temp) + db.close() return movies def availableChars(self, status = ['active']): @@ -200,6 +203,7 
@@ class MoviePlugin(Plugin): if char not in chars: chars += char + db.close() return chars def listView(self): @@ -232,20 +236,21 @@ class MoviePlugin(Plugin): def refresh(self): - params = getParams() db = get_session() - movie = db.query(Movie).filter_by(id = params.get('id')).first() + for id in getParam('id').split(','): + movie = db.query(Movie).filter_by(id = id).first() - # Get current selected title - default_title = '' - for title in movie.library.titles: - if title.default: default_title = title.title + # Get current selected title + default_title = '' + for title in movie.library.titles: + if title.default: default_title = title.title - if movie: - fireEventAsync('library.update', identifier = movie.library.identifier, default_title = default_title, force = True) - fireEventAsync('searcher.single', movie.to_dict(self.default_dict)) + if movie: + fireEventAsync('library.update', identifier = movie.library.identifier, default_title = default_title, force = True) + fireEventAsync('searcher.single', movie.to_dict(self.default_dict)) + db.close() return jsonified({ 'success': True, }) @@ -270,7 +275,7 @@ class MoviePlugin(Plugin): 'movies': movies, }) - def add(self, params = {}, force_readd = True): + def add(self, params = {}, force_readd = True, search_after = True): library = fireEvent('library.add', single = True, attrs = params, update_after = False) @@ -316,9 +321,10 @@ class MoviePlugin(Plugin): movie_dict = m.to_dict(self.default_dict) - if force_readd or do_search: + if (force_readd or do_search) and search_after: fireEventAsync('searcher.single', movie_dict) + db.close() return movie_dict @@ -365,6 +371,7 @@ class MoviePlugin(Plugin): movie_dict = m.to_dict(self.default_dict) fireEventAsync('searcher.single', movie_dict) + db.close() return jsonified({ 'success': True, }) @@ -419,6 +426,7 @@ class MoviePlugin(Plugin): else: fireEvent('movie.restatus', movie.id, single = True) + db.close() return True def restatus(self, movie_id): @@ -429,6 +437,9 
@@ class MoviePlugin(Plugin): db = get_session() m = db.query(Movie).filter_by(id = movie_id).first() + if not m: + log.debug('Can\'t restatus movie, doesn\'t seem to exist.') + return False log.debug('Changing status for %s' % (m.library.titles[0].title)) if not m.profile: @@ -444,3 +455,6 @@ class MoviePlugin(Plugin): m.status_id = active_status.get('id') if move_to_wanted else done_status.get('id') db.commit() + db.close() + + return True diff --git a/couchpotato/core/plugins/movie/static/list.js b/couchpotato/core/plugins/movie/static/list.js index becdc14..23bdcf8 100644 --- a/couchpotato/core/plugins/movie/static/list.js +++ b/couchpotato/core/plugins/movie/static/list.js @@ -144,6 +144,15 @@ var MovieList = new Class({ 'click': self.deleteSelected.bind(self) } }) + ), + new Element('div.refresh').adopt( + new Element('span[text=or]'), + new Element('a.button.green', { + 'text': 'Refresh', + 'events': { + 'click': self.refreshSelected.bind(self) + } + }) ) ) ).inject(self.el, 'top'); @@ -245,8 +254,8 @@ var MovieList = new Class({ var self = this; var ids = self.getSelectedMovies() - var qObj = new Question('Are you sure you want to delete the selected movies?', 'Items using this profile, will be set to the default quality.', [{ - 'text': 'Yes, delete them', + var qObj = new Question('Are you sure you want to delete '+ids.length+' movie'+ (ids.length != 1 ? 's' : '') +'?', 'If you do, you won\'t be able to watch them, as they won\'t get downloaded!', [{ + 'text': 'Yes, delete '+(ids.length != 1 ? 
'them' : 'it'), 'class': 'delete', 'events': { 'click': function(e){ @@ -292,6 +301,17 @@ var MovieList = new Class({ }); }, + refreshSelected: function(){ + var self = this; + var ids = self.getSelectedMovies() + + Api.request('movie.refresh', { + 'data': { + 'id': ids.join(','), + } + }); + }, + getSelectedMovies: function(){ var self = this; diff --git a/couchpotato/core/plugins/movie/static/movie.css b/couchpotato/core/plugins/movie/static/movie.css index a6f7e4c..6365f79 100644 --- a/couchpotato/core/plugins/movie/static/movie.css +++ b/couchpotato/core/plugins/movie/static/movie.css @@ -456,11 +456,13 @@ padding: 3px 7px; } + .movies .alph_nav .mass_edit_form .refresh, .movies .alph_nav .mass_edit_form .delete { float: left; padding: 8px 0 0 8px; } + .movies .alph_nav .mass_edit_form .refresh span, .movies .alph_nav .mass_edit_form .delete span { margin: 0 10px 0 0; } diff --git a/couchpotato/core/plugins/movie/static/movie.js b/couchpotato/core/plugins/movie/static/movie.js index 4ce2eab..39b12ee 100644 --- a/couchpotato/core/plugins/movie/static/movie.js +++ b/couchpotato/core/plugins/movie/static/movie.js @@ -69,7 +69,7 @@ var Movie = new Class({ self.profile.getTypes().each(function(type){ var q = self.addQuality(type.quality_id || type.get('quality_id')); - if(type.finish || type.get('finish')) + if(type.finish == true || type.get('finish')) q.addClass('finish'); }); @@ -82,7 +82,7 @@ var Movie = new Class({ if(!q && (status.identifier == 'snatched' || status.identifier == 'done')) var q = self.addQuality(release.quality_id) - if (q) + if (status && q) q.addClass(status.identifier); }); diff --git a/couchpotato/core/plugins/profile/main.py b/couchpotato/core/plugins/profile/main.py index 0ba0637..77ad02a 100644 --- a/couchpotato/core/plugins/profile/main.py +++ b/couchpotato/core/plugins/profile/main.py @@ -47,6 +47,7 @@ class ProfilePlugin(Plugin): for profile in profiles: temp.append(profile.to_dict(self.to_dict)) + db.close() return temp def 
save(self): @@ -83,6 +84,7 @@ class ProfilePlugin(Plugin): profile_dict = p.to_dict(self.to_dict) + db.close() return jsonified({ 'success': True, 'profile': profile_dict @@ -92,8 +94,10 @@ class ProfilePlugin(Plugin): db = get_session() default = db.query(Profile).first() + default_dict = default.to_dict(self.to_dict) + db.close() - return default.to_dict(self.to_dict) + return default_dict def saveOrder(self): @@ -109,6 +113,7 @@ class ProfilePlugin(Plugin): order += 1 db.commit() + db.close() return jsonified({ 'success': True @@ -133,6 +138,8 @@ class ProfilePlugin(Plugin): message = 'Failed deleting Profile: %s' % e log.error(message) + db.close() + return jsonified({ 'success': success, 'message': message @@ -180,4 +187,5 @@ class ProfilePlugin(Plugin): order += 1 + db.close() return True diff --git a/couchpotato/core/plugins/quality/main.py b/couchpotato/core/plugins/quality/main.py index b962dcd..682866f 100644 --- a/couchpotato/core/plugins/quality/main.py +++ b/couchpotato/core/plugins/quality/main.py @@ -21,7 +21,7 @@ class QualityPlugin(Plugin): {'identifier': '720p', 'hd': True, 'size': (3500, 10000), 'label': '720P', 'width': 1280, 'alternative': [], 'allow': [], 'ext':['mkv', 'm2ts', 'ts']}, {'identifier': 'brrip', 'hd': True, 'size': (700, 7000), 'label': 'BR-Rip', 'alternative': ['bdrip'], 'allow': ['720p'], 'ext':['avi']}, {'identifier': 'dvdr', 'size': (3000, 10000), 'label': 'DVD-R', 'alternative': [], 'allow': [], 'ext':['iso', 'img'], 'tags': ['pal', 'ntsc', 'video_ts', 'audio_ts']}, - {'identifier': 'dvdrip', 'size': (600, 2400), 'label': 'DVD-Rip', 'alternative': ['dvdrip'], 'allow': [], 'ext':['avi', 'mpg', 'mpeg']}, + {'identifier': 'dvdrip', 'size': (600, 2400), 'label': 'DVD-Rip', 'width': 720, 'alternative': ['dvdrip'], 'allow': [], 'ext':['avi', 'mpg', 'mpeg']}, {'identifier': 'scr', 'size': (600, 1600), 'label': 'Screener', 'alternative': ['screener', 'dvdscr', 'ppvrip'], 'allow': ['dvdr', 'dvd'], 'ext':['avi', 'mpg', 'mpeg']}, 
{'identifier': 'r5', 'size': (600, 1000), 'label': 'R5', 'alternative': [], 'allow': ['dvdr'], 'ext':['avi', 'mpg', 'mpeg']}, {'identifier': 'tc', 'size': (600, 1000), 'label': 'TeleCine', 'alternative': ['telecine'], 'allow': [], 'ext':['avi', 'mpg', 'mpeg']}, @@ -68,6 +68,7 @@ class QualityPlugin(Plugin): q = mergeDicts(self.getQuality(quality.identifier), quality.to_dict()) temp.append(q) + db.close() return temp def single(self, identifier = ''): @@ -79,6 +80,7 @@ class QualityPlugin(Plugin): if quality: quality_dict = dict(self.getQuality(quality.identifier), **quality.to_dict()) + db.close() return quality_dict def getQuality(self, identifier): @@ -98,6 +100,7 @@ class QualityPlugin(Plugin): setattr(quality, params.get('value_type'), params.get('value')) db.commit() + db.close() return jsonified({ 'success': True }) @@ -149,9 +152,10 @@ class QualityPlugin(Plugin): order += 1 db.commit() + db.close() return True - def guess(self, files, extra = {}, loose = False): + def guess(self, files, extra = {}): # Create hash for cache hash = md5(str([f.replace('.' 
+ getExt(f), '') for f in files])) @@ -182,25 +186,25 @@ class QualityPlugin(Plugin): log.debug('Found %s via tag %s in %s' % (quality['identifier'], quality.get('tags'), cur_file)) return self.setCache(hash, quality) - # Check on unreliable stuff - if loose: + # Try again with loose testing + quality = self.guessLoose(hash, extra = extra) + if quality: + return self.setCache(hash, quality) - # Last check on resolution only - if quality.get('width', 480) == extra.get('resolution_width', 0): - log.debug('Found %s via resolution_width: %s == %s' % (quality['identifier'], quality.get('width', 480), extra.get('resolution_width', 0))) - return self.setCache(hash, quality) + log.debug('Could not identify quality for: %s' % files) + return None - # Check extension + filesize - if list(set(quality.get('ext', [])) & set(words)) and size >= quality['size_min'] and size <= quality['size_max']: - log.debug('Found %s via ext and filesize %s in %s' % (quality['identifier'], quality.get('ext'), words)) - return self.setCache(hash, quality) + def guessLoose(self, hash, extra): + for quality in self.all(): - # Try again with loose testing - if not loose: - quality = self.guess(files, extra = extra, loose = True) - if quality: + # Last check on resolution only + if quality.get('width', 480) == extra.get('resolution_width', 0): + log.debug('Found %s via resolution_width: %s == %s' % (quality['identifier'], quality.get('width', 480), extra.get('resolution_width', 0))) return self.setCache(hash, quality) - log.debug('Could not identify quality for: %s' % files) + if 480 <= extra.get('resolution_width', 0) <= 720: + log.debug('Found as dvdrip') + return self.setCache(hash, self.single('dvdrip')) + return None diff --git a/couchpotato/core/plugins/release/main.py b/couchpotato/core/plugins/release/main.py index 3a0f204..9364a0e 100644 --- a/couchpotato/core/plugins/release/main.py +++ b/couchpotato/core/plugins/release/main.py @@ -83,7 +83,9 @@ class Release(Plugin): 
fireEvent('movie.restatus', movie.id) - db.remove() + db.close() + + return True def saveFile(self, filepath, type = 'unknown', include_media_info = False): @@ -107,6 +109,7 @@ rel.delete() db.commit() + db.close() return jsonified({ 'success': True }) @@ -123,6 +126,7 @@ rel.status_id = available_status.get('id') if rel.status_id is ignored_status.get('id') else ignored_status.get('id') db.commit() + db.close() return jsonified({ 'success': True }) @@ -149,12 +153,14 @@ 'files': {} }), manual = True) + db.close() return jsonified({ 'success': True }) else: log.error('Couldn\'t find release with id: %s' % id) + db.close() return jsonified({ 'success': False }) diff --git a/couchpotato/core/plugins/renamer/main.py b/couchpotato/core/plugins/renamer/main.py index 672ad8e..ac6b88b 100644 --- a/couchpotato/core/plugins/renamer/main.py +++ b/couchpotato/core/plugins/renamer/main.py @@ -6,7 +6,7 @@ from couchpotato.core.helpers.request import jsonified from couchpotato.core.helpers.variable import getExt, mergeDicts from couchpotato.core.logger import CPLog from couchpotato.core.plugins.base import Plugin -from couchpotato.core.settings.model import Library, File +from couchpotato.core.settings.model import Library, File, Profile from couchpotato.environment import Env import os import re @@ -67,6 +67,12 @@ class Renamer(Plugin): nfo_name = self.conf('nfo_name') separator = self.conf('separator') + # Statuses + done_status = fireEvent('status.get', 'done', single = True) + active_status = fireEvent('status.get', 'active', single = True) + downloaded_status = fireEvent('status.get', 'downloaded', single = True) + snatched_status = fireEvent('status.get', 'snatched', single = True) + db = get_session() for group_identifier in groups: @@ -185,7 +191,7 @@ class Renamer(Plugin): break if not found: - log.error('Could not determin dvd structure for: %s' % current_file) + log.error('Could not determine dvd 
structure for: %s' % current_file) # Do rename others else: @@ -240,10 +246,15 @@ class Renamer(Plugin): cd += 1 # Before renaming, remove the lower quality files - library = db.query(Library).filter_by(identifier = group['library']['identifier']).first() - done_status = fireEvent('status.get', 'done', single = True) - active_status = fireEvent('status.get', 'active', single = True) + remove_leftovers = True + + # Add it to the wanted list before we continue + if len(library.movies) == 0: + profile = db.query(Profile).filter_by(core = True, label = group['meta_data']['quality']['label']).first() + fireEvent('movie.add', params = {'identifier': group['library']['identifier'], 'profile_id': profile.id}, search_after = False) + db.expire_all() + library = db.query(Library).filter_by(identifier = group['library']['identifier']).first() for movie in library.movies: @@ -293,14 +304,25 @@ class Renamer(Plugin): # Notify on rename fail download_message = 'Renaming of %s (%s) canceled, exists in %s already.' 
% (movie.library.titles[0].title, group['meta_data']['quality']['label'], release.quality.label) fireEvent('movie.renaming.canceled', message = download_message, data = group) + remove_leftovers = False break + elif release.status_id is snatched_status.get('id'): + print release.quality.label, group['meta_data']['quality']['label'] + if release.quality.id is group['meta_data']['quality']['id']: + log.debug('Marking release as downloaded') + release.status_id = downloaded_status.get('id') + db.commit() # Remove leftover files - if self.conf('cleanup') and not self.conf('move_leftover'): + if self.conf('cleanup') and not self.conf('move_leftover') and remove_leftovers: log.debug('Removing leftover files') for current_file in group['files']['leftover']: remove_files.append(current_file) + elif not remove_leftovers: # Don't remove anything + remove_files = [] + + continue # Rename all files marked group['renamed_files'] = [] @@ -356,6 +378,7 @@ class Renamer(Plugin): if self.shuttingDown(): break + db.close() self.renaming_started = False def getRenameExtras(self, extra_type = '', replacements = {}, folder_name = '', file_name = '', destination = '', group = {}, current_file = ''): diff --git a/couchpotato/core/plugins/scanner/main.py b/couchpotato/core/plugins/scanner/main.py index 358d992..46eccea 100644 --- a/couchpotato/core/plugins/scanner/main.py +++ b/couchpotato/core/plugins/scanner/main.py @@ -8,17 +8,13 @@ from couchpotato.core.settings.model import File from couchpotato.environment import Env from enzyme.exceptions import NoParserError, ParseError from guessit import guess_movie_info -from subliminal.videos import scan +from subliminal.videos import scan, Video import enzyme -import logging import os import re import time import traceback -enzyme_logger = logging.getLogger('enzyme') -enzyme_logger.setLevel(logging.INFO) - log = CPLog(__name__) @@ -97,10 +93,6 @@ class Scanner(Plugin): addEvent('rename.after', after_rename) - # Disable lib logging - 
logging.getLogger('guessit').setLevel(logging.ERROR) - logging.getLogger('subliminal').setLevel(logging.ERROR) - def scanFilesToLibrary(self, folder = None, files = None): groups = self.scan(folder = folder, files = files) @@ -109,12 +101,12 @@ class Scanner(Plugin): if group['library']: fireEvent('release.add', group = group) - def scanFolderToLibrary(self, folder = None, newer_than = None): + def scanFolderToLibrary(self, folder = None, newer_than = None, simple = True): if not os.path.isdir(folder): return - groups = self.scan(folder = folder) + groups = self.scan(folder = folder, simple = simple) added_identifier = [] while True and not self.shuttingDown(): @@ -135,7 +127,7 @@ class Scanner(Plugin): return added_identifier - def scan(self, folder = None, files = []): + def scan(self, folder = None, files = [], simple = False): if not folder or not os.path.isdir(folder): log.error('Folder doesn\'t exists: %s' % folder) @@ -299,7 +291,7 @@ class Scanner(Plugin): group['meta_data'] = self.getMetaData(group) # Subtitle meta - group['subtitle_language'] = self.getSubtitleLanguage(group) + group['subtitle_language'] = self.getSubtitleLanguage(group) if not simple else {} # Get parent dir from movie files for movie_file in group['files']['movie']: @@ -328,7 +320,7 @@ class Scanner(Plugin): # Determine movie group['library'] = self.determineMovie(group) if not group['library']: - log.error('Unable to determin movie: %s' % group['identifiers']) + log.error('Unable to determine movie: %s' % group['identifiers']) processed_movies[identifier] = group @@ -400,7 +392,9 @@ class Scanner(Plugin): scan_result = [] for p in paths: if not group['is_dvd']: - scan_result.extend(scan(p)) + video = Video.from_path(p) + video_result = [(video, video.scan())] + scan_result.extend(video_result) for video, detected_subtitles in scan_result: for s in detected_subtitles: @@ -461,7 +455,7 @@ class Scanner(Plugin): break except: pass - db.remove() + db.close() # Search based on 
OpenSubtitleHash if not imdb_id and not group['is_dvd']: @@ -482,7 +476,7 @@ class Scanner(Plugin): try: filename = list(group['files'].get('movie'))[0] except: filename = None - name_year = self.getReleaseNameYear(identifier, file_name = filename) + name_year = self.getReleaseNameYear(identifier, file_name = filename if not group['is_dvd'] else None) if name_year.get('name') and name_year.get('year'): movie = fireEvent('movie.search', q = '%(name)s %(year)s' % name_year, merge = True, limit = 1) diff --git a/couchpotato/core/plugins/searcher/main.py b/couchpotato/core/plugins/searcher/main.py index e2642dc..c27e101 100644 --- a/couchpotato/core/plugins/searcher/main.py +++ b/couchpotato/core/plugins/searcher/main.py @@ -61,10 +61,17 @@ class Searcher(Plugin): if self.shuttingDown(): break + db.close() self.in_progress = False def single(self, movie): + done_status = fireEvent('status.get', 'done', single = True) + + if not movie['profile'] or movie['status_id'] == done_status.get('id'): + log.debug('Movie doesn\'t have a profile or already done, assuming in manage tab.') + return + db = get_session() pre_releases = fireEvent('quality.pre_releases', single = True) @@ -141,7 +148,7 @@ class Searcher(Plugin): if self.shuttingDown(): break - db.remove() + db.close() return False def download(self, data, movie, manual = False): @@ -184,6 +191,7 @@ class Searcher(Plugin): except Exception, e: log.error('Failed marking movie finished: %s %s' % (e, traceback.format_exc())) + db.close() return True log.info('Tried to download, but none of the downloaders are enabled') diff --git a/couchpotato/core/plugins/status/main.py b/couchpotato/core/plugins/status/main.py index 9912fee..1683972 100644 --- a/couchpotato/core/plugins/status/main.py +++ b/couchpotato/core/plugins/status/main.py @@ -48,7 +48,10 @@ class StatusPlugin(Plugin): def getById(self, id): db = get_session() status = db.query(Status).filter_by(id = id).first() - return status.to_dict() + status_dict = 
status.to_dict() + db.close() + + return status_dict def all(self): @@ -61,6 +64,7 @@ class StatusPlugin(Plugin): s = status.to_dict() temp.append(s) + db.close() return temp def add(self, identifier): @@ -78,6 +82,7 @@ class StatusPlugin(Plugin): status_dict = s.to_dict() + db.close() return status_dict def fill(self): @@ -97,3 +102,5 @@ class StatusPlugin(Plugin): s.label = toUnicode(label) db.commit() + db.close() + diff --git a/couchpotato/core/plugins/subtitle/main.py b/couchpotato/core/plugins/subtitle/main.py index da86d49..56ff39d 100644 --- a/couchpotato/core/plugins/subtitle/main.py +++ b/couchpotato/core/plugins/subtitle/main.py @@ -38,6 +38,8 @@ class Subtitle(Plugin): # get subtitles for those files subliminal.list_subtitles(files, cache_dir = Env.get('cache_dir'), multi = True, languages = self.getLanguages(), services = self.services) + db.close() + def searchSingle(self, group): if self.isDisabled(): return diff --git a/couchpotato/core/plugins/userscript/iframe.html b/couchpotato/core/plugins/userscript/iframe.html deleted file mode 100644 index 322407f..0000000 --- a/couchpotato/core/plugins/userscript/iframe.html +++ /dev/null @@ -1,15 +0,0 @@ - - - - - - - - - - - \ No newline at end of file diff --git a/couchpotato/core/plugins/v1importer/__init__.py b/couchpotato/core/plugins/v1importer/__init__.py new file mode 100644 index 0000000..40c1434 --- /dev/null +++ b/couchpotato/core/plugins/v1importer/__init__.py @@ -0,0 +1,6 @@ +from .main import V1Importer + +def start(): + return V1Importer() + +config = [] diff --git a/couchpotato/core/plugins/v1importer/form.html b/couchpotato/core/plugins/v1importer/form.html new file mode 100644 index 0000000..e27d1c7 --- /dev/null +++ b/couchpotato/core/plugins/v1importer/form.html @@ -0,0 +1,30 @@ + + + + + + + + + + + + + {% if message: %} + {{ message }} + {% else: %} +
+ +
+ {% endif %} + + \ No newline at end of file diff --git a/couchpotato/core/plugins/v1importer/main.py b/couchpotato/core/plugins/v1importer/main.py new file mode 100644 index 0000000..08f8ba9 --- /dev/null +++ b/couchpotato/core/plugins/v1importer/main.py @@ -0,0 +1,56 @@ +from couchpotato.api import addApiView +from couchpotato.core.event import fireEventAsync +from couchpotato.core.helpers.variable import getImdb +from couchpotato.core.logger import CPLog +from couchpotato.core.plugins.base import Plugin +from couchpotato.environment import Env +from flask.globals import request +from flask.helpers import url_for +import os + +log = CPLog(__name__) + + +class V1Importer(Plugin): + + def __init__(self): + addApiView('v1.import', self.fromOld, methods = ['GET', 'POST']) + + def fromOld(self): + + if request.method != 'POST': + return self.renderTemplate(__file__, 'form.html', url_for = url_for) + + file = request.files['old_db'] + + uploaded_file = os.path.join(Env.get('cache_dir'), 'v1_database.db') + + if os.path.isfile(uploaded_file): + os.remove(uploaded_file) + + file.save(uploaded_file) + + try: + import sqlite3 + conn = sqlite3.connect(uploaded_file) + + wanted = [] + + t = ('want',) + cur = conn.execute('SELECT status, imdb FROM Movie WHERE status=?', t) + for row in cur: + status, imdb = row + if getImdb(imdb): + wanted.append(imdb) + conn.close() + + wanted = set(wanted) + for imdb in wanted: + fireEventAsync('movie.add', {'identifier': imdb}, search_after = False) + + message = 'Successfully imported %s movie(s)' % len(wanted) + except Exception, e: + message = 'Failed: %s' % e + + return self.renderTemplate(__file__, 'form.html', url_for = url_for, message = message) + diff --git a/couchpotato/core/plugins/wizard/static/wizard.js b/couchpotato/core/plugins/wizard/static/wizard.js index e1e07a8..a4438cb 100644 --- a/couchpotato/core/plugins/wizard/static/wizard.js +++ b/couchpotato/core/plugins/wizard/static/wizard.js @@ -8,8 +8,28 @@ Page.Wizard = new 
Class({ headers: { 'welcome': { - 'title': 'Welcome to CouchPotato', - 'description': 'To get started, fill in each of the following settings as much as your can.' + 'title': 'Welcome to the new CouchPotato', + 'description': 'To get started, fill in each of the following settings as much as your can.
Maybe first start with importing your movies from the previous CouchPotato', + 'content': new Element('div', { + 'styles': { + 'margin': '0 0 0 30px' + } + }).adopt( + new Element('div', { + 'html': 'Select the data.db. It should be in your CouchPotato root directory.' + }), + self.import_iframe = new Element('iframe', { + 'styles': { + 'height': 40, + 'width': 300, + 'border': 0, + 'overflow': 'hidden' + } + }) + ), + 'event': function(){ + self.import_iframe.set('src', Api.createUrl('v1.import')) + } }, 'general': { 'title': 'General', @@ -105,7 +125,7 @@ Page.Wizard = new Class({ 'text': self.headers[group].title }), self.headers[group].description ? new Element('span.description', { - 'text': self.headers[group].description + 'html': self.headers[group].description }) : null, self.headers[group].content ? self.headers[group].content : null ).inject(form); @@ -132,6 +152,9 @@ Page.Wizard = new Class({ }) ).inject(tabs); } + + if(self.headers[group] && self.headers[group].event) + self.headers[group].event.call() }); // Remove toggle diff --git a/couchpotato/core/providers/automation/trakt/__init__.py b/couchpotato/core/providers/automation/trakt/__init__.py index 88ab454..fca7af3 100644 --- a/couchpotato/core/providers/automation/trakt/__init__.py +++ b/couchpotato/core/providers/automation/trakt/__init__.py @@ -25,6 +25,12 @@ config = [{ 'name': 'automation_username', 'label': 'Username', }, + { + 'name': 'automation_password', + 'label': 'Password', + 'type': 'password', + 'description': 'When you have "Protect my data" checked on trakt.', + }, ], }, ], diff --git a/couchpotato/core/providers/automation/trakt/main.py b/couchpotato/core/providers/automation/trakt/main.py index 06a4872..e623b8f 100644 --- a/couchpotato/core/providers/automation/trakt/main.py +++ b/couchpotato/core/providers/automation/trakt/main.py @@ -1,6 +1,8 @@ -from couchpotato.core.helpers.variable import md5 +from couchpotato.core.event import addEvent +from 
couchpotato.core.helpers.variable import md5, sha1 from couchpotato.core.logger import CPLog from couchpotato.core.providers.automation.base import Automation +import base64 import json log = CPLog(__name__) @@ -13,6 +15,14 @@ class Trakt(Automation): 'watchlist': 'user/watchlist/movies.json/%s/', } + def __init__(self): + super(Trakt, self).__init__() + + addEvent('setting.save.trakt.automation_password', self.sha1Password) + + def sha1Password(self, value): + return sha1(value) if value else '' + def getIMDBids(self): if self.isDisabled(): @@ -31,6 +41,13 @@ class Trakt(Automation): def call(self, method_url): + if self.conf('automation_password'): + headers = { + 'Authorization': "Basic %s" % base64.encodestring('%s:%s' % (self.conf('automation_username'), self.conf('automation_password')))[:-1] + } + else: + headers = {} + cache_key = 'trakt.%s' % md5(method_url) - json_string = self.getCache(cache_key, self.urls['base'] + method_url) + json_string = self.getCache(cache_key, self.urls['base'] + method_url, headers = headers) return json.loads(json_string) diff --git a/couchpotato/core/providers/movie/_modifier/main.py b/couchpotato/core/providers/movie/_modifier/main.py index 527c7b9..c19af75 100644 --- a/couchpotato/core/providers/movie/_modifier/main.py +++ b/couchpotato/core/providers/movie/_modifier/main.py @@ -44,8 +44,8 @@ class MovieResultModifier(Plugin): } # Add release info from current library + db = get_session() try: - db = get_session() l = db.query(Library).filter_by(identifier = imdb).first() if l: @@ -63,6 +63,7 @@ class MovieResultModifier(Plugin): except: log.error('Tried getting more info on searched movies: %s' % traceback.format_exc()) + db.close() return temp def checkLibrary(self, result): diff --git a/couchpotato/core/providers/movie/couchpotatoapi/main.py b/couchpotato/core/providers/movie/couchpotatoapi/main.py index 26e42df..3e438f2 100644 --- a/couchpotato/core/providers/movie/couchpotatoapi/main.py +++ 
b/couchpotato/core/providers/movie/couchpotatoapi/main.py @@ -59,6 +59,7 @@ class CouchPotatoApi(MovieProvider): db = get_session() active_movies = db.query(Movie).filter(Movie.status.has(identifier = 'active')).all() movies = [x.library.identifier for x in active_movies] + db.close() suggestions = self.suggest(movies, ignore) diff --git a/couchpotato/core/providers/nzb/newznab/main.py b/couchpotato/core/providers/nzb/newznab/main.py index 57691b7..0d648a9 100644 --- a/couchpotato/core/providers/nzb/newznab/main.py +++ b/couchpotato/core/providers/nzb/newznab/main.py @@ -4,6 +4,7 @@ from couchpotato.core.helpers.rss import RSS from couchpotato.core.helpers.variable import cleanHost from couchpotato.core.logger import CPLog from couchpotato.core.providers.nzb.base import NZBProvider +from couchpotato.environment import Env from dateutil.parser import parse import time import xml.etree.ElementTree as XMLTree @@ -99,7 +100,7 @@ class Newznab(NZBProvider, RSS): def createItems(self, url, cache_key, host, single_cat = False, movie = None, quality = None, for_feed = False): results = [] - data = self.getCache(cache_key, url) + data = self.getCache(cache_key, url, cache_timeout = 1800, headers = {'User-Agent': Env.getIdentifier()}) if data: try: try: diff --git a/couchpotato/core/providers/nzb/nzbmatrix/main.py b/couchpotato/core/providers/nzb/nzbmatrix/main.py index e2b426b..2a6e5a8 100644 --- a/couchpotato/core/providers/nzb/nzbmatrix/main.py +++ b/couchpotato/core/providers/nzb/nzbmatrix/main.py @@ -51,7 +51,7 @@ class NZBMatrix(NZBProvider, RSS): cache_key = 'nzbmatrix.%s.%s' % (movie['library'].get('identifier'), cat_ids) single_cat = True - data = self.getCache(cache_key, url, cache_timeout = 1800, headers = {'User-Agent': 'CouchPotato'}) + data = self.getCache(cache_key, url, cache_timeout = 1800, headers = {'User-Agent': Env.getIdentifier()}) if data: try: try: diff --git a/couchpotato/core/providers/trailer/hdtrailers/main.py 
b/couchpotato/core/providers/trailer/hdtrailers/main.py index 2698070..03fa910 100644 --- a/couchpotato/core/providers/trailer/hdtrailers/main.py +++ b/couchpotato/core/providers/trailer/hdtrailers/main.py @@ -23,8 +23,10 @@ class HDTrailers(TrailerProvider): url = self.urls['api'] % self.movieUrlName(movie_name) data = self.getCache('hdtrailers.%s' % group['library']['identifier'], url) + result_data = {'480p':[], '720p':[], '1080p':[]} - result_data = {} + if not data: + return result_data did_alternative = False for provider in self.providers: diff --git a/couchpotato/core/settings/__init__.py b/couchpotato/core/settings/__init__.py index 48c7c32..985515a 100644 --- a/couchpotato/core/settings/__init__.py +++ b/couchpotato/core/settings/__init__.py @@ -197,11 +197,15 @@ class Settings(object): from couchpotato import get_session db = get_session() + prop = None try: - prop = db.query(Properties).filter_by(identifier = identifier).first() - return prop.value if prop else None + propert = db.query(Properties).filter_by(identifier = identifier).first() + prop = propert.value except: - return None + pass + + db.close() + return prop def setProperty(self, identifier, value = ''): from couchpotato import get_session @@ -217,3 +221,4 @@ class Settings(object): p.value = toUnicode(value) db.commit() + db.close() diff --git a/couchpotato/core/settings/model.py b/couchpotato/core/settings/model.py index 9a909be..950a67a 100644 --- a/couchpotato/core/settings/model.py +++ b/couchpotato/core/settings/model.py @@ -238,7 +238,15 @@ class Properties(Entity): def setup(): """Setup the database and create the tables that don't exists yet""" from elixir import setup_all, create_all - from couchpotato import get_engine + from couchpotato.environment import Env + + engine = Env.getEngine() setup_all() - create_all(get_engine()) + create_all(engine) + + try: + engine.execute("PRAGMA journal_mode = WAL") + engine.execute("PRAGMA temp_store = MEMORY") + except: + pass diff --git 
a/couchpotato/environment.py b/couchpotato/environment.py index 1ac184f..e42cbad 100644 --- a/couchpotato/environment.py +++ b/couchpotato/environment.py @@ -1,10 +1,15 @@ from couchpotato.core.event import fireEvent, addEvent from couchpotato.core.loader import Loader from couchpotato.core.settings import Settings +from sqlalchemy.engine import create_engine +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm.session import sessionmaker import os class Env(object): + _appname = 'CouchPotato' + ''' Environment variables ''' _encoding = '' _uses_git = False @@ -18,6 +23,7 @@ class Env(object): _quiet = False _deamonize = False _desktop = None + _session = None ''' Data paths and directories ''' _app_dir = "" @@ -47,6 +53,22 @@ class Env(object): return setattr(Env, '_' + attr, value) @staticmethod + def getSession(engine = None): + existing_session = Env.get('session') + if existing_session: + return existing_session + + engine = Env.getEngine() + session = scoped_session(sessionmaker(bind = engine)) + Env.set('session', session) + + return session + + @staticmethod + def getEngine(): + return create_engine(Env.get('db_path'), echo = False) + + @staticmethod def setting(attr, section = 'core', value = None, default = '', type = None): s = Env.get('settings') @@ -96,3 +118,7 @@ class Env(object): return '%d %s' % (os.getpid(), '(%d)' % parent if parent and parent > 1 else '') except: return 0 + + @staticmethod + def getIdentifier(): + return '%s %s' % (Env.get('appname'), fireEvent('app.version', single = True)) diff --git a/couchpotato/runner.py b/couchpotato/runner.py index 0be3bdb..e8f2fb0 100644 --- a/couchpotato/runner.py +++ b/couchpotato/runner.py @@ -9,6 +9,7 @@ import atexit import locale import logging import os.path +import shutil import sys import time import warnings @@ -58,13 +59,45 @@ def runCouchPotato(options, base_path, args, data_dir = None, log_dir = None, En if not encoding or encoding in ('ANSI_X3.4-1968', 'US-ASCII', 'ASCII'): 
encoding = 'UTF-8' + # Do db stuff + db_path = os.path.join(data_dir, 'couchpotato.db') + + # Backup before start and cleanup old databases + new_backup = os.path.join(data_dir, 'db_backup', str(int(time.time()))) + + # Create path and copy + if not os.path.isdir(new_backup): os.makedirs(new_backup) + src_files = [options.config_file, db_path, db_path + '-shm', db_path + '-wal'] + for src_file in src_files: + if os.path.isfile(src_file): + shutil.copy2(src_file, os.path.join(new_backup, os.path.basename(src_file))) + + # Remove older backups, keep backups 3 days or at least 3 + backups = [] + for directory in os.listdir(os.path.dirname(new_backup)): + backup = os.path.join(os.path.dirname(new_backup), directory) + if os.path.isdir(backup): + backups.append(backup) + + total_backups = len(backups) + for backup in backups: + if total_backups > 3: + if int(os.path.basename(backup)) < time.time() - 259200: + for src_file in src_files: + b_file = os.path.join(backup, os.path.basename(src_file)) + if os.path.isfile(b_file): + os.remove(b_file) + os.rmdir(backup) + total_backups -= 1 + + # Register environment settings Env.set('encoding', encoding) Env.set('uses_git', not options.nogit) Env.set('app_dir', base_path) Env.set('data_dir', data_dir) Env.set('log_path', os.path.join(log_dir, 'CouchPotato.log')) - Env.set('db_path', 'sqlite:///' + os.path.join(data_dir, 'couchpotato.db')) + Env.set('db_path', 'sqlite:///' + db_path) Env.set('cache_dir', os.path.join(data_dir, 'cache')) Env.set('cache', FileSystemCache(os.path.join(Env.get('cache_dir'), 'python'))) Env.set('console_log', options.console_log) @@ -83,12 +116,16 @@ def runCouchPotato(options, base_path, args, data_dir = None, log_dir = None, En if not development: atexit.register(cleanup) + # Disable logging for some modules + for logger_name in ['enzyme', 'guessit', 'subliminal', 'apscheduler']: + logging.getLogger(logger_name).setLevel(logging.ERROR) + + for logger_name in ['gntp', 'werkzeug', 'migrate']: + 
logging.getLogger(logger_name).setLevel(logging.WARNING) + # Use reloader reloader = debug is True and development and not Env.get('desktop') and not options.daemon - # Disable server access log - logging.getLogger('werkzeug').setLevel(logging.WARNING) - # Only run once when debugging fire_load = False if os.environ.get('WERKZEUG_RUN_MAIN') or not reloader: @@ -130,12 +167,11 @@ def runCouchPotato(options, base_path, args, data_dir = None, log_dir = None, En # Load migrations initialize = True db = Env.get('db_path') - if os.path.isfile(db.replace('sqlite:///', '')): + if os.path.isfile(db_path): initialize = False from migrate.versioning.api import version_control, db_version, version, upgrade repo = os.path.join(base_path, 'couchpotato', 'core', 'migration') - logging.getLogger('migrate').setLevel(logging.WARNING) # Disable logging for migration latest_db_version = version(repo) try: diff --git a/couchpotato/static/scripts/couchpotato.js b/couchpotato/static/scripts/couchpotato.js index 451e434..e2c23b8 100644 --- a/couchpotato/static/scripts/couchpotato.js +++ b/couchpotato/static/scripts/couchpotato.js @@ -84,6 +84,10 @@ var CouchPotato = new Class({ 'events': { 'click': self.checkForUpdate.bind(self) } + }), + new Element('a', { + 'text': 'Run install wizard', + 'href': App.createUrl('wizard') })].each(function(a){ self.block.more.addLink(a) }) diff --git a/couchpotato/static/scripts/page/settings.js b/couchpotato/static/scripts/page/settings.js index a33d774..20ca457 100644 --- a/couchpotato/static/scripts/page/settings.js +++ b/couchpotato/static/scripts/page/settings.js @@ -70,10 +70,16 @@ Page.Settings = new Class({ t.tab[a](c); t.subtabs[subtab].tab[a](c); t.subtabs[subtab].content[a](c); + + if(!hide) + t.subtabs[subtab].content.fireEvent('activate'); } else { t.tab[a](c); t.content[a](c); + + if(!hide) + t.content.fireEvent('activate'); } return t @@ -869,6 +875,7 @@ Option.Choice = new Class({ afterInject: function(){ var self = this; + self.tags = []; 
self.replaceInput(); self.select = new Element('select').adopt( @@ -941,6 +948,10 @@ Option.Choice = new Class({ self.reset(); } }); + + // Calc width on show + var input_group = self.tag_input.getParent('.tab_content'); + input_group.addEvent('activate', self.setAllWidth.bind(self)); }, addLastTag: function(){ @@ -952,10 +963,8 @@ Option.Choice = new Class({ var self = this; tag = new Option.Choice.Tag(tag, { 'onChange': self.setOrder.bind(self), - 'onFocus': self.activate.bind(self), 'onBlur': function(){ self.addLastTag(); - self.deactivate(); } }); $(tag).inject(self.tag_input); @@ -965,6 +974,8 @@ Option.Choice = new Class({ else (function(){ tag.setWidth(); }).delay(10, self); + self.tags.include(tag); + return tag; }, @@ -979,6 +990,7 @@ Option.Choice = new Class({ self.input.set('value', value); self.input.fireEvent('change'); + self.setAllWidth(); }, addSelection: function(){ @@ -987,6 +999,7 @@ Option.Choice = new Class({ var tag = self.addTag(self.el.getElement('.selection input').get('value')); self.sortable.addItems($(tag)); self.setOrder(); + self.setAllWidth(); }, reset: function(){ @@ -996,14 +1009,14 @@ Option.Choice = new Class({ self.sortable.detach(); self.replaceInput(); + self.setAllWidth(); }, - activate: function(){ - - }, - - deactivate: function(){ - + setAllWidth: function(){ + var self = this; + self.tags.each(function(tag){ + tag.setWidth.delay(10, tag); + }); } }); @@ -1032,6 +1045,9 @@ Option.Choice.Tag = new Class({ self.el = new Element('li', { 'class': self.is_choice ? 'choice' : '', + 'styles': { + 'border': 0 + }, 'events': { 'mouseover': !self.is_choice ? self.fireEvent.bind(self, 'focus') : function(){} } @@ -1039,6 +1055,9 @@ Option.Choice.Tag = new Class({ self.input = new Element(self.is_choice ? 'span' : 'input', { 'text': self.tag, 'value': self.tag, + 'styles': { + 'width': 0 + }, 'events': { 'keyup': self.is_choice ? 
null : function(){ self.setWidth(); diff --git a/couchpotato/static/scripts/page/wanted.js b/couchpotato/static/scripts/page/wanted.js index b5a82cf..1f811c6 100644 --- a/couchpotato/static/scripts/page/wanted.js +++ b/couchpotato/static/scripts/page/wanted.js @@ -211,6 +211,7 @@ window.addEvent('domready', function(){ }, 'onComplete': function(){ movie.set('tween', { + 'duration': 300, 'onComplete': function(){ movie.destroy(); } diff --git a/couchpotato/static/style/page/settings.css b/couchpotato/static/style/page/settings.css index 4d157f5..44390f9 100644 --- a/couchpotato/static/style/page/settings.css +++ b/couchpotato/static/style/page/settings.css @@ -362,28 +362,28 @@ border-radius: 2px; } .page .tag_input > ul:hover > li.choice { - background: url('../images/sprite.png') no-repeat 94% -53px, -webkit-gradient( + background: url('../../images/sprite.png') no-repeat 94% -53px, -webkit-gradient( linear, left bottom, left top, color-stop(0, rgba(255,255,255,0.1)), color-stop(1, rgba(255,255,255,0.3)) ); - background: url('../images/sprite.png') no-repeat 94% -53px, -moz-linear-gradient( + background: url('../../images/sprite.png') no-repeat 94% -53px, -moz-linear-gradient( center top, rgba(255,255,255,0.3) 0%, rgba(255,255,255,0.1) 100% ); } .page .tag_input > ul > li.choice:hover { - background: url('../images/sprite.png') no-repeat 94% -53px, -webkit-gradient( + background: url('../../images/sprite.png') no-repeat 94% -53px, -webkit-gradient( linear, left bottom, left top, color-stop(0, #406db8), color-stop(1, #5b9bd1) ); - background: url('../images/sprite.png') no-repeat 94% -53px, -moz-linear-gradient( + background: url('../../images/sprite.png') no-repeat 94% -53px, -moz-linear-gradient( center top, #5b9bd1 0%, #406db8 100% diff --git a/libs/axl/axel.py b/libs/axl/axel.py index 86921a9..8e9b607 100644 --- a/libs/axl/axel.py +++ b/libs/axl/axel.py @@ -164,7 +164,12 @@ class Event(object): if not self.asynchronous: self.queue.join() - return self.result or 
None + res = self.result or None + + # Cleanup + self.result = {} + + return res def count(self): """ Returns the count of registered handlers """ diff --git a/libs/guessit/ISO-3166-1_utf8.txt b/libs/guessit/ISO-3166-1_utf8.txt new file mode 100644 index 0000000..7022040 --- /dev/null +++ b/libs/guessit/ISO-3166-1_utf8.txt @@ -0,0 +1,249 @@ +Afghanistan|AF|AFG|004|ISO 3166-2:AF +Åland Islands|AX|ALA|248|ISO 3166-2:AX +Albania|AL|ALB|008|ISO 3166-2:AL +Algeria|DZ|DZA|012|ISO 3166-2:DZ +American Samoa|AS|ASM|016|ISO 3166-2:AS +Andorra|AD|AND|020|ISO 3166-2:AD +Angola|AO|AGO|024|ISO 3166-2:AO +Anguilla|AI|AIA|660|ISO 3166-2:AI +Antarctica|AQ|ATA|010|ISO 3166-2:AQ +Antigua and Barbuda|AG|ATG|028|ISO 3166-2:AG +Argentina|AR|ARG|032|ISO 3166-2:AR +Armenia|AM|ARM|051|ISO 3166-2:AM +Aruba|AW|ABW|533|ISO 3166-2:AW +Australia|AU|AUS|036|ISO 3166-2:AU +Austria|AT|AUT|040|ISO 3166-2:AT +Azerbaijan|AZ|AZE|031|ISO 3166-2:AZ +Bahamas|BS|BHS|044|ISO 3166-2:BS +Bahrain|BH|BHR|048|ISO 3166-2:BH +Bangladesh|BD|BGD|050|ISO 3166-2:BD +Barbados|BB|BRB|052|ISO 3166-2:BB +Belarus|BY|BLR|112|ISO 3166-2:BY +Belgium|BE|BEL|056|ISO 3166-2:BE +Belize|BZ|BLZ|084|ISO 3166-2:BZ +Benin|BJ|BEN|204|ISO 3166-2:BJ +Bermuda|BM|BMU|060|ISO 3166-2:BM +Bhutan|BT|BTN|064|ISO 3166-2:BT +Bolivia, Plurinational State of|BO|BOL|068|ISO 3166-2:BO +Bonaire, Sint Eustatius and Saba|BQ|BES|535|ISO 3166-2:BQ +Bosnia and Herzegovina|BA|BIH|070|ISO 3166-2:BA +Botswana|BW|BWA|072|ISO 3166-2:BW +Bouvet Island|BV|BVT|074|ISO 3166-2:BV +Brazil|BR|BRA|076|ISO 3166-2:BR +British Indian Ocean Territory|IO|IOT|086|ISO 3166-2:IO +Brunei Darussalam|BN|BRN|096|ISO 3166-2:BN +Bulgaria|BG|BGR|100|ISO 3166-2:BG +Burkina Faso|BF|BFA|854|ISO 3166-2:BF +Burundi|BI|BDI|108|ISO 3166-2:BI +Cambodia|KH|KHM|116|ISO 3166-2:KH +Cameroon|CM|CMR|120|ISO 3166-2:CM +Canada|CA|CAN|124|ISO 3166-2:CA +Cape Verde|CV|CPV|132|ISO 3166-2:CV +Cayman Islands|KY|CYM|136|ISO 3166-2:KY +Central African Republic|CF|CAF|140|ISO 3166-2:CF +Chad|TD|TCD|148|ISO 
3166-2:TD +Chile|CL|CHL|152|ISO 3166-2:CL +China|CN|CHN|156|ISO 3166-2:CN +Christmas Island|CX|CXR|162|ISO 3166-2:CX +Cocos (Keeling) Islands|CC|CCK|166|ISO 3166-2:CC +Colombia|CO|COL|170|ISO 3166-2:CO +Comoros|KM|COM|174|ISO 3166-2:KM +Congo|CG|COG|178|ISO 3166-2:CG +Congo, the Democratic Republic of the|CD|COD|180|ISO 3166-2:CD +Cook Islands|CK|COK|184|ISO 3166-2:CK +Costa Rica|CR|CRI|188|ISO 3166-2:CR +Côte d'Ivoire|CI|CIV|384|ISO 3166-2:CI +Croatia|HR|HRV|191|ISO 3166-2:HR +Cuba|CU|CUB|192|ISO 3166-2:CU +Curaçao|CW|CUW|531|ISO 3166-2:CW +Cyprus|CY|CYP|196|ISO 3166-2:CY +Czech Republic|CZ|CZE|203|ISO 3166-2:CZ +Denmark|DK|DNK|208|ISO 3166-2:DK +Djibouti|DJ|DJI|262|ISO 3166-2:DJ +Dominica|DM|DMA|212|ISO 3166-2:DM +Dominican Republic|DO|DOM|214|ISO 3166-2:DO +Ecuador|EC|ECU|218|ISO 3166-2:EC +Egypt|EG|EGY|818|ISO 3166-2:EG +El Salvador|SV|SLV|222|ISO 3166-2:SV +Equatorial Guinea|GQ|GNQ|226|ISO 3166-2:GQ +Eritrea|ER|ERI|232|ISO 3166-2:ER +Estonia|EE|EST|233|ISO 3166-2:EE +Ethiopia|ET|ETH|231|ISO 3166-2:ET +Falkland Islands (Malvinas|FK|FLK|238|ISO 3166-2:FK +Faroe Islands|FO|FRO|234|ISO 3166-2:FO +Fiji|FJ|FJI|242|ISO 3166-2:FJ +Finland|FI|FIN|246|ISO 3166-2:FI +France|FR|FRA|250|ISO 3166-2:FR +French Guiana|GF|GUF|254|ISO 3166-2:GF +French Polynesia|PF|PYF|258|ISO 3166-2:PF +French Southern Territories|TF|ATF|260|ISO 3166-2:TF +Gabon|GA|GAB|266|ISO 3166-2:GA +Gambia|GM|GMB|270|ISO 3166-2:GM +Georgia|GE|GEO|268|ISO 3166-2:GE +Germany|DE|DEU|276|ISO 3166-2:DE +Ghana|GH|GHA|288|ISO 3166-2:GH +Gibraltar|GI|GIB|292|ISO 3166-2:GI +Greece|GR|GRC|300|ISO 3166-2:GR +Greenland|GL|GRL|304|ISO 3166-2:GL +Grenada|GD|GRD|308|ISO 3166-2:GD +Guadeloupe|GP|GLP|312|ISO 3166-2:GP +Guam|GU|GUM|316|ISO 3166-2:GU +Guatemala|GT|GTM|320|ISO 3166-2:GT +Guernsey|GG|GGY|831|ISO 3166-2:GG +Guinea|GN|GIN|324|ISO 3166-2:GN +Guinea-Bissau|GW|GNB|624|ISO 3166-2:GW +Guyana|GY|GUY|328|ISO 3166-2:GY +Haiti|HT|HTI|332|ISO 3166-2:HT +Heard Island and McDonald Islands|HM|HMD|334|ISO 3166-2:HM +Holy See 
(Vatican City State|VA|VAT|336|ISO 3166-2:VA +Honduras|HN|HND|340|ISO 3166-2:HN +Hong Kong|HK|HKG|344|ISO 3166-2:HK +Hungary|HU|HUN|348|ISO 3166-2:HU +Iceland|IS|ISL|352|ISO 3166-2:IS +India|IN|IND|356|ISO 3166-2:IN +Indonesia|ID|IDN|360|ISO 3166-2:ID +Iran, Islamic Republic of|IR|IRN|364|ISO 3166-2:IR +Iraq|IQ|IRQ|368|ISO 3166-2:IQ +Ireland|IE|IRL|372|ISO 3166-2:IE +Isle of Man|IM|IMN|833|ISO 3166-2:IM +Israel|IL|ISR|376|ISO 3166-2:IL +Italy|IT|ITA|380|ISO 3166-2:IT +Jamaica|JM|JAM|388|ISO 3166-2:JM +Japan|JP|JPN|392|ISO 3166-2:JP +Jersey|JE|JEY|832|ISO 3166-2:JE +Jordan|JO|JOR|400|ISO 3166-2:JO +Kazakhstan|KZ|KAZ|398|ISO 3166-2:KZ +Kenya|KE|KEN|404|ISO 3166-2:KE +Kiribati|KI|KIR|296|ISO 3166-2:KI +Korea, Democratic People's Republic of|KP|PRK|408|ISO 3166-2:KP +Korea, Republic of|KR|KOR|410|ISO 3166-2:KR +Kuwait|KW|KWT|414|ISO 3166-2:KW +Kyrgyzstan|KG|KGZ|417|ISO 3166-2:KG +Lao People's Democratic Republic|LA|LAO|418|ISO 3166-2:LA +Latvia|LV|LVA|428|ISO 3166-2:LV +Lebanon|LB|LBN|422|ISO 3166-2:LB +Lesotho|LS|LSO|426|ISO 3166-2:LS +Liberia|LR|LBR|430|ISO 3166-2:LR +Libya|LY|LBY|434|ISO 3166-2:LY +Liechtenstein|LI|LIE|438|ISO 3166-2:LI +Lithuania|LT|LTU|440|ISO 3166-2:LT +Luxembourg|LU|LUX|442|ISO 3166-2:LU +Macao|MO|MAC|446|ISO 3166-2:MO +Macedonia, the former Yugoslav Republic of|MK|MKD|807|ISO 3166-2:MK +Madagascar|MG|MDG|450|ISO 3166-2:MG +Malawi|MW|MWI|454|ISO 3166-2:MW +Malaysia|MY|MYS|458|ISO 3166-2:MY +Maldives|MV|MDV|462|ISO 3166-2:MV +Mali|ML|MLI|466|ISO 3166-2:ML +Malta|MT|MLT|470|ISO 3166-2:MT +Marshall Islands|MH|MHL|584|ISO 3166-2:MH +Martinique|MQ|MTQ|474|ISO 3166-2:MQ +Mauritania|MR|MRT|478|ISO 3166-2:MR +Mauritius|MU|MUS|480|ISO 3166-2:MU +Mayotte|YT|MYT|175|ISO 3166-2:YT +Mexico|MX|MEX|484|ISO 3166-2:MX +Micronesia, Federated States of|FM|FSM|583|ISO 3166-2:FM +Moldova, Republic of|MD|MDA|498|ISO 3166-2:MD +Monaco|MC|MCO|492|ISO 3166-2:MC +Mongolia|MN|MNG|496|ISO 3166-2:MN +Montenegro|ME|MNE|499|ISO 3166-2:ME +Montserrat|MS|MSR|500|ISO 3166-2:MS 
+Morocco|MA|MAR|504|ISO 3166-2:MA +Mozambique|MZ|MOZ|508|ISO 3166-2:MZ +Myanmar|MM|MMR|104|ISO 3166-2:MM +Namibia|NA|NAM|516|ISO 3166-2:NA +Nauru|NR|NRU|520|ISO 3166-2:NR +Nepal|NP|NPL|524|ISO 3166-2:NP +Netherlands|NL|NLD|528|ISO 3166-2:NL +New Caledonia|NC|NCL|540|ISO 3166-2:NC +New Zealand|NZ|NZL|554|ISO 3166-2:NZ +Nicaragua|NI|NIC|558|ISO 3166-2:NI +Niger|NE|NER|562|ISO 3166-2:NE +Nigeria|NG|NGA|566|ISO 3166-2:NG +Niue|NU|NIU|570|ISO 3166-2:NU +Norfolk Island|NF|NFK|574|ISO 3166-2:NF +Northern Mariana Islands|MP|MNP|580|ISO 3166-2:MP +Norway|NO|NOR|578|ISO 3166-2:NO +Oman|OM|OMN|512|ISO 3166-2:OM +Pakistan|PK|PAK|586|ISO 3166-2:PK +Palau|PW|PLW|585|ISO 3166-2:PW +Palestinian Territory, Occupied|PS|PSE|275|ISO 3166-2:PS +Panama|PA|PAN|591|ISO 3166-2:PA +Papua New Guinea|PG|PNG|598|ISO 3166-2:PG +Paraguay|PY|PRY|600|ISO 3166-2:PY +Peru|PE|PER|604|ISO 3166-2:PE +Philippines|PH|PHL|608|ISO 3166-2:PH +Pitcairn|PN|PCN|612|ISO 3166-2:PN +Poland|PL|POL|616|ISO 3166-2:PL +Portugal|PT|PRT|620|ISO 3166-2:PT +Puerto Rico|PR|PRI|630|ISO 3166-2:PR +Qatar|QA|QAT|634|ISO 3166-2:QA +Réunion|RE|REU|638|ISO 3166-2:RE +Romania|RO|ROU|642|ISO 3166-2:RO +Russian Federation|RU|RUS|643|ISO 3166-2:RU +Rwanda|RW|RWA|646|ISO 3166-2:RW +Saint Barthélemy|BL|BLM|652|ISO 3166-2:BL +Saint Helena, Ascension and Tristan da Cunha|SH|SHN|654|ISO 3166-2:SH +Saint Kitts and Nevis|KN|KNA|659|ISO 3166-2:KN +Saint Lucia|LC|LCA|662|ISO 3166-2:LC +Saint Martin (French part|MF|MAF|663|ISO 3166-2:MF +Saint Pierre and Miquelon|PM|SPM|666|ISO 3166-2:PM +Saint Vincent and the Grenadines|VC|VCT|670|ISO 3166-2:VC +Samoa|WS|WSM|882|ISO 3166-2:WS +San Marino|SM|SMR|674|ISO 3166-2:SM +Sao Tome and Principe|ST|STP|678|ISO 3166-2:ST +Saudi Arabia|SA|SAU|682|ISO 3166-2:SA +Senegal|SN|SEN|686|ISO 3166-2:SN +Serbia|RS|SRB|688|ISO 3166-2:RS +Seychelles|SC|SYC|690|ISO 3166-2:SC +Sierra Leone|SL|SLE|694|ISO 3166-2:SL +Singapore|SG|SGP|702|ISO 3166-2:SG +Sint Maarten (Dutch part|SX|SXM|534|ISO 3166-2:SX 
+Slovakia|SK|SVK|703|ISO 3166-2:SK +Slovenia|SI|SVN|705|ISO 3166-2:SI +Solomon Islands|SB|SLB|090|ISO 3166-2:SB +Somalia|SO|SOM|706|ISO 3166-2:SO +South Africa|ZA|ZAF|710|ISO 3166-2:ZA +South Georgia and the South Sandwich Islands|GS|SGS|239|ISO 3166-2:GS +South Sudan|SS|SSD|728|ISO 3166-2:SS +Spain|ES|ESP|724|ISO 3166-2:ES +Sri Lanka|LK|LKA|144|ISO 3166-2:LK +Sudan|SD|SDN|729|ISO 3166-2:SD +Suriname|SR|SUR|740|ISO 3166-2:SR +Svalbard and Jan Mayen|SJ|SJM|744|ISO 3166-2:SJ +Swaziland|SZ|SWZ|748|ISO 3166-2:SZ +Sweden|SE|SWE|752|ISO 3166-2:SE +Switzerland|CH|CHE|756|ISO 3166-2:CH +Syrian Arab Republic|SY|SYR|760|ISO 3166-2:SY +Taiwan, Province of China|TW|TWN|158|ISO 3166-2:TW +Tajikistan|TJ|TJK|762|ISO 3166-2:TJ +Tanzania, United Republic of|TZ|TZA|834|ISO 3166-2:TZ +Thailand|TH|THA|764|ISO 3166-2:TH +Timor-Leste|TL|TLS|626|ISO 3166-2:TL +Togo|TG|TGO|768|ISO 3166-2:TG +Tokelau|TK|TKL|772|ISO 3166-2:TK +Tonga|TO|TON|776|ISO 3166-2:TO +Trinidad and Tobago|TT|TTO|780|ISO 3166-2:TT +Tunisia|TN|TUN|788|ISO 3166-2:TN +Turkey|TR|TUR|792|ISO 3166-2:TR +Turkmenistan|TM|TKM|795|ISO 3166-2:TM +Turks and Caicos Islands|TC|TCA|796|ISO 3166-2:TC +Tuvalu|TV|TUV|798|ISO 3166-2:TV +Uganda|UG|UGA|800|ISO 3166-2:UG +Ukraine|UA|UKR|804|ISO 3166-2:UA +United Arab Emirates|AE|ARE|784|ISO 3166-2:AE +United Kingdom|GB|GBR|826|ISO 3166-2:GB +United States|US|USA|840|ISO 3166-2:US +United States Minor Outlying Islands|UM|UMI|581|ISO 3166-2:UM +Uruguay|UY|URY|858|ISO 3166-2:UY +Uzbekistan|UZ|UZB|860|ISO 3166-2:UZ +Vanuatu|VU|VUT|548|ISO 3166-2:VU +Venezuela, Bolivarian Republic of|VE|VEN|862|ISO 3166-2:VE +Viet Nam|VN|VNM|704|ISO 3166-2:VN +Virgin Islands, British|VG|VGB|092|ISO 3166-2:VG +Virgin Islands, U.S|VI|VIR|850|ISO 3166-2:VI +Wallis and Futuna|WF|WLF|876|ISO 3166-2:WF +Western Sahara|EH|ESH|732|ISO 3166-2:EH +Yemen|YE|YEM|887|ISO 3166-2:YE +Zambia|ZM|ZMB|894|ISO 3166-2:ZM +Zimbabwe|ZW|ZWE|716|ISO 3166-2:ZW diff --git a/libs/guessit/__init__.py b/libs/guessit/__init__.py index 
a86f71b..9c7c9d0 100644 --- a/libs/guessit/__init__.py +++ b/libs/guessit/__init__.py @@ -18,7 +18,7 @@ # along with this program. If not, see . # -__version__ = '0.3.1' +__version__ = '0.4' __all__ = ['Guess', 'Language', 'guess_file_info', 'guess_video_info', 'guess_movie_info', 'guess_episode_info'] @@ -29,7 +29,7 @@ from guessit.language import Language from guessit.matcher import IterativeMatcher import logging -log = logging.getLogger("guessit") +log = logging.getLogger(__name__) class NullHandler(logging.Handler): diff --git a/libs/guessit/country.py b/libs/guessit/country.py new file mode 100644 index 0000000..f529728 --- /dev/null +++ b/libs/guessit/country.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# GuessIt - A library for guessing information from filenames +# Copyright (c) 2012 Nicolas Wack +# +# GuessIt is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# GuessIt is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see . +# + +from __future__ import unicode_literals +from guessit import fileutils +import logging + +log = logging.getLogger(__name__) + + +# parsed from http://en.wikipedia.org/wiki/ISO_3166-1 +# +# Description of the fields: +# "An English name, an alpha-2 code (when given), +# an alpha-3 code (when given), a numeric code, and an ISO 31666-2 code +# are all separated by pipe (|) characters." 
+_iso3166_contents = fileutils.load_file_in_same_dir(__file__, + 'ISO-3166-1_utf8.txt').decode('utf-8') + +country_matrix = [ l.strip().split('|') + for l in _iso3166_contents.strip().split('\n') ] + +country_matrix += [ [ 'Unknown', 'un', 'unk', '', '' ], + [ 'Latin America', '', 'lat', '', '' ] + ] + +country_to_alpha3 = dict((c[0].lower(), c[2].lower()) for c in country_matrix) +country_to_alpha3.update(dict((c[1].lower(), c[2].lower()) for c in country_matrix)) +country_to_alpha3.update(dict((c[2].lower(), c[2].lower()) for c in country_matrix)) + +# add here exceptions / non ISO representations +# Note: remember to put those exceptions in lower-case, they won't work otherwise +country_to_alpha3.update({ 'latinoamérica': 'lat', + 'brazilian': 'bra', + 'españa': 'esp', + 'uk': 'gbr' + }) + +country_alpha3_to_en_name = dict((c[2].lower(), c[0]) for c in country_matrix) +country_alpha3_to_alpha2 = dict((c[2].lower(), c[1].lower()) for c in country_matrix) + + + +class Country(object): + """This class represents a country. + + You can initialize it with pretty much anything, as it knows conversion + from ISO-3166 2-letter and 3-letter codes, and an English name. 
+ """ + + def __init__(self, country, strict=False): + self.alpha3 = country_to_alpha3.get(country.lower()) + + if self.alpha3 is None and strict: + msg = 'The given string "%s" could not be identified as a country' + raise ValueError(msg % country) + + if self.alpha3 is None: + self.alpha3 = 'unk' + + + @property + def alpha2(self): + return country_alpha3_to_alpha2[self.alpha3] + + @property + def english_name(self): + return country_alpha3_to_en_name[self.alpha3] + + def __hash__(self): + return hash(self.alpha3) + + def __eq__(self, other): + if isinstance(other, Country): + return self.alpha3 == other.alpha3 + + if isinstance(other, basestring): + try: + return self == Country(other) + except ValueError: + return False + + return False + + def __ne__(self, other): + return not self == other + + def __unicode__(self): + return self.english_name + + def __str__(self): + return unicode(self).encode('utf-8') + + def __repr__(self): + return 'Country(%s)' % self.english_name + diff --git a/libs/guessit/fileutils.py b/libs/guessit/fileutils.py index e146390..bd315bc 100644 --- a/libs/guessit/fileutils.py +++ b/libs/guessit/fileutils.py @@ -51,8 +51,8 @@ def split_path(path): if head == '/' and tail == '': return ['/'] + result - # on Windows, the root folder is a drive letter (eg: 'C:\') - if len(head) == 3 and head[1:] == ':\\' and tail == '': + # on Windows, the root folder is a drive letter (eg: 'C:\') or for shares \\ + if ((len(head) == 3 and head[1:] == ':\\') or (len(head) == 2 and head == '\\\\')) and tail == '': return [head] + result if head == '' and tail == '': diff --git a/libs/guessit/guess.py b/libs/guessit/guess.py index 9950a12..e25ca1f 100644 --- a/libs/guessit/guess.py +++ b/libs/guessit/guess.py @@ -22,7 +22,7 @@ import json import datetime import logging -log = logging.getLogger("guessit.guess") +log = logging.getLogger(__name__) class Guess(dict): diff --git a/libs/guessit/language.py b/libs/guessit/language.py index 777d0e2..b043346 100644 --- 
a/libs/guessit/language.py +++ b/libs/guessit/language.py @@ -18,10 +18,18 @@ # along with this program. If not, see . # +from __future__ import unicode_literals from guessit import fileutils +from guessit.country import Country +import re import logging -log = logging.getLogger('guessit.language') +__all__ = [ 'is_iso_language', 'is_language', 'lang_set', 'Language', + 'ALL_LANGUAGES', 'ALL_LANGUAGES_NAMES', 'search_language' ] + + +log = logging.getLogger(__name__) + # downloaded from http://www.loc.gov/standards/iso639-2/ISO-639-2_utf-8.txt # @@ -30,9 +38,23 @@ log = logging.getLogger('guessit.language') # an alpha-2 code (when given), an English name, and a French name of a language # are all separated by pipe (|) characters." _iso639_contents = fileutils.load_file_in_same_dir(__file__, - 'ISO-639-2_utf-8.txt') -language_matrix = [ l.strip().decode('utf-8').split('|') - for l in _iso639_contents.split('\n') ] + 'ISO-639-2_utf-8.txt').decode('utf-8') + +# drop the BOM from the beginning of the file +_iso639_contents = _iso639_contents[1:] + +language_matrix = [ l.strip().split('|') + for l in _iso639_contents.strip().split('\n') ] + +language_matrix += [ [ 'unk', '', 'un', 'Unknown', 'inconnu' ] ] + + +# remove unused languages that shadow other common ones with a non-official form +for lang in language_matrix: + if (lang[2] == 'se' or # Northern Sami shadows Swedish + lang[2] == 'br'): # Breton shadows Brazilian + language_matrix.remove(lang) + lng3 = frozenset(l[0] for l in language_matrix if l[0]) lng3term = frozenset(l[1] for l in language_matrix if l[1]) @@ -63,54 +85,126 @@ lng_fr_name_to_lng3 = dict((fr_name.lower(), l[0]) for l in language_matrix if l[4] for fr_name in l[4].split('; ')) +# contains a list of exceptions: strings that should be parsed as a language +# but which are not in an ISO form +lng_exceptions = { 'gr': ('gre', None), + 'greek': ('gre', None), + 'esp': ('spa', None), + 'español': ('spa', None), + 'se': ('swe', None), + 'po': ('pt', 
'br'), + 'pob': ('pt', 'br'), + 'br': ('pt', 'br'), + 'brazilian': ('pt', 'br'), + 'català': ('cat', None), + 'cz': ('cze', None), + 'ua': ('ukr', None), + 'cn': ('chi', None), + 'chs': ('chi', None), + 'jp': ('jpn', None) + } + + +def is_iso_language(language): + return language.lower() in lng_all_names def is_language(language): - return language.lower() in lng_all_names + return is_iso_language(language) or language in lng_exceptions + +def lang_set(languages, strict=False): + """Return a set of guessit.Language created from their given string + representation. + + if strict is True, then this will raise an exception if any language + could not be identified. + """ + return set(Language(l, strict=strict) for l in languages) class Language(object): """This class represents a human language. - You can initialize it with pretty much everything, as it knows conversion + You can initialize it with pretty much anything, as it knows conversion from ISO-639 2-letter and 3-letter codes, English and French names. + You can also distinguish languages for specific countries, such as + Portuguese and Brazilian Portuguese. 
+ >>> Language('fr') Language(French) - >>> Language('eng').french_name() + >>> Language('eng').french_name u'anglais' + + >>> Language('pt(br)').country.english_name + u'Brazil' + + >>> Language('Español (Latinoamérica)').country.english_name + u'Latin America' + + >>> Language('Spanish (Latin America)') == Language('Español (Latinoamérica)') + True + + >>> Language('zz', strict=False).english_name + u'Unknown' """ - def __init__(self, language): - lang = None - language = language.lower() + + _with_country_regexp = re.compile('(.*)\((.*)\)') + + def __init__(self, language, country=None, strict=False): + language = language.strip().lower() + if isinstance(language, str): + language = language.decode('utf-8') + with_country = Language._with_country_regexp.match(language) + if with_country: + self.lang = Language(with_country.group(1)).lang + self.country = Country(with_country.group(2)) + return + + self.lang = None + self.country = Country(country) if country else None + if len(language) == 2: - lang = lng2_to_lng3.get(language) + self.lang = lng2_to_lng3.get(language) elif len(language) == 3: - lang = (language - if language in lng3 - else lng3term_to_lng3.get(language)) + self.lang = (language + if language in lng3 + else lng3term_to_lng3.get(language)) else: - lang = (lng_en_name_to_lng3.get(language) or - lng_fr_name_to_lng3.get(language)) + self.lang = (lng_en_name_to_lng3.get(language) or + lng_fr_name_to_lng3.get(language)) - if lang is None: - msg = 'The given string "%s" could not be identified as a language' - raise ValueError(msg % language) + if self.lang is None and language in lng_exceptions: + lang, country = lng_exceptions[language] + self.lang = Language(lang).alpha3 + self.country = Country(country) if country else None - self.lang = lang + msg = 'The given string "%s" could not be identified as a language' % language - def lng2(self): + if self.lang is None and strict: + raise ValueError(msg) + + if self.lang is None: + log.debug(msg) + 
self.lang = 'unk' + + @property + def alpha2(self): return lng3_to_lng2[self.lang] - def lng3(self): + @property + def alpha3(self): return self.lang - def lng3term(self): + @property + def alpha3term(self): return lng3_to_lng3term[self.lang] + @property def english_name(self): return lng3_to_lng_en_name[self.lang] + @property def french_name(self): return lng3_to_lng_fr_name[self.lang] @@ -132,15 +226,27 @@ class Language(object): def __ne__(self, other): return not self == other + def __nonzero__(self): + return self.lang != 'unk' + def __unicode__(self): - return lng3_to_lng_en_name[self.lang] + if self.country: + return '%s(%s)' % (self.english_name, self.country.alpha2) + else: + return self.english_name def __str__(self): return unicode(self).encode('utf-8') def __repr__(self): - return 'Language(%s)' % self + if self.country: + return 'Language(%s, country=%s)' % (self.english_name, self.country) + else: + return 'Language(%s)' % self.english_name + +ALL_LANGUAGES = frozenset(Language(lng) for lng in lng_all_names) - frozenset([Language('unk')]) +ALL_LANGUAGES_NAMES = lng_all_names def search_language(string, lang_filter=None): """Looks for language patterns, and if found return the language object, @@ -177,7 +283,7 @@ def search_language(string, lang_filter=None): sep = r'[](){} \._-+' if lang_filter: - lang_filter = set(Language(l) for l in lang_filter) + lang_filter = lang_set(lang_filter) slow = ' %s ' % string.lower() confidence = 1.0 # for all of them diff --git a/libs/guessit/matcher.py b/libs/guessit/matcher.py index cac172d..b0a5040 100644 --- a/libs/guessit/matcher.py +++ b/libs/guessit/matcher.py @@ -25,7 +25,7 @@ from guessit.guess import (merge_similar_guesses, merge_all, import copy import logging -log = logging.getLogger("guessit.matcher") +log = logging.getLogger(__name__) class IterativeMatcher(object): @@ -105,7 +105,7 @@ class IterativeMatcher(object): 'guess_release_group', 'guess_properties', 'guess_weak_episodes_rexps', 
'guess_language'] else: - strategy = ['guess_date', 'guess_year', 'guess_video_rexps', + strategy = ['guess_date', 'guess_video_rexps', 'guess_website', 'guess_release_group', 'guess_properties', 'guess_language'] @@ -125,6 +125,7 @@ class IterativeMatcher(object): if mtree.guess['type'] in ('episode', 'episodesubtitle'): apply_transfo('guess_episode_info_from_position') else: + apply_transfo('guess_year') apply_transfo('guess_movie_title_from_position') # 6- perform some post-processing steps diff --git a/libs/guessit/matchtree.py b/libs/guessit/matchtree.py index 634cbf7..466e0bb 100644 --- a/libs/guessit/matchtree.py +++ b/libs/guessit/matchtree.py @@ -23,7 +23,7 @@ from guessit.textutils import clean_string, str_fill, to_utf8 from guessit.patterns import group_delimiters import logging -log = logging.getLogger("guessit.matchtree") +log = logging.getLogger(__name__) class BaseMatchTree(object): diff --git a/libs/guessit/patterns.py b/libs/guessit/patterns.py index 4125fb7..4223585 100755 --- a/libs/guessit/patterns.py +++ b/libs/guessit/patterns.py @@ -22,8 +22,9 @@ subtitle_exts = [ 'srt', 'idx', 'sub', 'ssa', 'txt' ] -video_exts = [ 'avi', 'mkv', 'mpg', 'mp4', 'm4v', 'mov', 'ogg', 'ogm', 'ogv', - 'wmv', 'divx' ] +video_exts = ['3g2', '3gp', '3gp2', 'asf', 'avi', 'divx', 'flv', 'm4v', 'mk2', + 'mka', 'mkv', 'mov', 'mp4', 'mp4a', 'mpeg', 'mpg', 'ogg', 'ogm', + 'ogv', 'qt', 'ra', 'ram', 'rm', 'ts', 'wav', 'webm', 'wma', 'wmv'] group_delimiters = [ '()', '[]', '{}' ] @@ -62,6 +63,8 @@ weak_episode_rexps = [ # ... 213 or 0106 ... # ... 2x13 ... (sep + r'[^0-9](?P[0-9]{1,2})\.(?P[0-9]{2})[^0-9]' + sep, (1, -1)), + # ... e13 ... 
for a mini-series without a season number + (r'e(?P[0-9]{1,4})[^0-9]', (0, -1)), ] non_episode_title = [ 'extras', 'rip' ] diff --git a/libs/guessit/transfo/__init__.py b/libs/guessit/transfo/__init__.py index eb72beb..1bdd09b 100644 --- a/libs/guessit/transfo/__init__.py +++ b/libs/guessit/transfo/__init__.py @@ -23,7 +23,7 @@ from guessit.patterns import canonical_form from guessit.textutils import clean_string import logging -log = logging.getLogger('guessit.transfo') +log = logging.getLogger(__name__) def found_property(node, name, confidence): diff --git a/libs/guessit/transfo/guess_bonus_features.py b/libs/guessit/transfo/guess_bonus_features.py index dcb90b3..73fc7b4 100644 --- a/libs/guessit/transfo/guess_bonus_features.py +++ b/libs/guessit/transfo/guess_bonus_features.py @@ -21,7 +21,7 @@ from guessit.transfo import found_property import logging -log = logging.getLogger("guessit.transfo.guess_bonus_features") +log = logging.getLogger(__name__) def process(mtree): diff --git a/libs/guessit/transfo/guess_date.py b/libs/guessit/transfo/guess_date.py index c72d66a..ded8094 100644 --- a/libs/guessit/transfo/guess_date.py +++ b/libs/guessit/transfo/guess_date.py @@ -22,7 +22,7 @@ from guessit.transfo import SingleNodeGuesser from guessit.date import search_date import logging -log = logging.getLogger("guessit.transfo.guess_date") +log = logging.getLogger(__name__) def guess_date(string): diff --git a/libs/guessit/transfo/guess_episode_info_from_position.py b/libs/guessit/transfo/guess_episode_info_from_position.py index fe1a752..7b4f43f 100644 --- a/libs/guessit/transfo/guess_episode_info_from_position.py +++ b/libs/guessit/transfo/guess_episode_info_from_position.py @@ -22,7 +22,7 @@ from guessit.transfo import found_property from guessit.patterns import non_episode_title, unlikely_series import logging -log = logging.getLogger("guessit.transfo.guess_episode_info_from_position") +log = logging.getLogger(__name__) def match_from_epnum_position(mtree, node): @@ 
-112,6 +112,9 @@ def process(mtree): if len(title_candidates) >= 2: found_property(title_candidates[0], 'series', 0.4) found_property(title_candidates[1], 'title', 0.4) + elif len(title_candidates) == 1: + # but if there's only one candidate, it's probably the series name + found_property(title_candidates[0], 'series', 0.4) # if we only have 1 remaining valid group in the folder containing the # file, then it's likely that it is the series name diff --git a/libs/guessit/transfo/guess_episodes_rexps.py b/libs/guessit/transfo/guess_episodes_rexps.py index 46dbc59..dfaa944 100644 --- a/libs/guessit/transfo/guess_episodes_rexps.py +++ b/libs/guessit/transfo/guess_episodes_rexps.py @@ -24,7 +24,7 @@ from guessit.patterns import episode_rexps import re import logging -log = logging.getLogger("guessit.transfo.guess_episodes_rexps") +log = logging.getLogger(__name__) def guess_episodes_rexps(string): diff --git a/libs/guessit/transfo/guess_filetype.py b/libs/guessit/transfo/guess_filetype.py index 32bdc13..bf0a80a 100644 --- a/libs/guessit/transfo/guess_filetype.py +++ b/libs/guessit/transfo/guess_filetype.py @@ -26,7 +26,7 @@ import re import mimetypes import logging -log = logging.getLogger("guessit.transfo.guess_filetype") +log = logging.getLogger(__name__) def guess_filetype(filename, filetype): diff --git a/libs/guessit/transfo/guess_language.py b/libs/guessit/transfo/guess_language.py index 62f47d8..aa1431b 100644 --- a/libs/guessit/transfo/guess_language.py +++ b/libs/guessit/transfo/guess_language.py @@ -24,7 +24,7 @@ from guessit.language import search_language from guessit.textutils import clean_string import logging -log = logging.getLogger("guessit.transfo.guess_language") +log = logging.getLogger(__name__) def guess_language(string): diff --git a/libs/guessit/transfo/guess_movie_title_from_position.py b/libs/guessit/transfo/guess_movie_title_from_position.py index dea56d6..55289c8 100644 --- a/libs/guessit/transfo/guess_movie_title_from_position.py +++ 
b/libs/guessit/transfo/guess_movie_title_from_position.py @@ -21,7 +21,7 @@ from guessit import Guess import logging -log = logging.getLogger("guessit.transfo.guess_movie_title_from_position") +log = logging.getLogger(__name__) def process(mtree): diff --git a/libs/guessit/transfo/guess_properties.py b/libs/guessit/transfo/guess_properties.py index 3822d22..02d0cad 100644 --- a/libs/guessit/transfo/guess_properties.py +++ b/libs/guessit/transfo/guess_properties.py @@ -22,7 +22,7 @@ from guessit.transfo import SingleNodeGuesser from guessit.patterns import find_properties import logging -log = logging.getLogger("guessit.transfo.guess_properties") +log = logging.getLogger(__name__) def guess_properties(string): diff --git a/libs/guessit/transfo/guess_release_group.py b/libs/guessit/transfo/guess_release_group.py index 9ec609d..54a7148 100644 --- a/libs/guessit/transfo/guess_release_group.py +++ b/libs/guessit/transfo/guess_release_group.py @@ -22,7 +22,7 @@ from guessit.transfo import SingleNodeGuesser import re import logging -log = logging.getLogger("guessit.transfo.guess_release_group") +log = logging.getLogger(__name__) def guess_release_group(string): diff --git a/libs/guessit/transfo/guess_video_rexps.py b/libs/guessit/transfo/guess_video_rexps.py index 36723c8..697a6af 100644 --- a/libs/guessit/transfo/guess_video_rexps.py +++ b/libs/guessit/transfo/guess_video_rexps.py @@ -24,7 +24,7 @@ from guessit.patterns import video_rexps, sep import re import logging -log = logging.getLogger("guessit.transfo.guess_video_rexps") +log = logging.getLogger(__name__) def guess_video_rexps(string): diff --git a/libs/guessit/transfo/guess_weak_episodes_rexps.py b/libs/guessit/transfo/guess_weak_episodes_rexps.py index 8fffe17..57c9f44 100644 --- a/libs/guessit/transfo/guess_weak_episodes_rexps.py +++ b/libs/guessit/transfo/guess_weak_episodes_rexps.py @@ -24,7 +24,7 @@ from guessit.patterns import weak_episode_rexps import re import logging -log = 
logging.getLogger("guessit.transfo.guess_weak_episodes_rexps") +log = logging.getLogger(__name__) def guess_weak_episodes_rexps(string, node): diff --git a/libs/guessit/transfo/guess_website.py b/libs/guessit/transfo/guess_website.py index a169f97..638f7d2 100644 --- a/libs/guessit/transfo/guess_website.py +++ b/libs/guessit/transfo/guess_website.py @@ -22,7 +22,7 @@ from guessit.transfo import SingleNodeGuesser from guessit.patterns import websites import logging -log = logging.getLogger("guessit.transfo.guess_website") +log = logging.getLogger(__name__) def guess_website(string): diff --git a/libs/guessit/transfo/guess_year.py b/libs/guessit/transfo/guess_year.py index 7a47ecf..7a90111 100644 --- a/libs/guessit/transfo/guess_year.py +++ b/libs/guessit/transfo/guess_year.py @@ -22,7 +22,7 @@ from guessit.transfo import SingleNodeGuesser from guessit.date import search_year import logging -log = logging.getLogger("guessit.transfo.guess_year") +log = logging.getLogger(__name__) def guess_year(string): diff --git a/libs/guessit/transfo/post_process.py b/libs/guessit/transfo/post_process.py index 0b5a4df..f08bbb2 100644 --- a/libs/guessit/transfo/post_process.py +++ b/libs/guessit/transfo/post_process.py @@ -21,7 +21,7 @@ from guessit.patterns import subtitle_exts import logging -log = logging.getLogger("guessit.transfo.post_process") +log = logging.getLogger(__name__) def process(mtree): diff --git a/libs/guessit/transfo/split_explicit_groups.py b/libs/guessit/transfo/split_explicit_groups.py index 797a886..f99ff19 100644 --- a/libs/guessit/transfo/split_explicit_groups.py +++ b/libs/guessit/transfo/split_explicit_groups.py @@ -22,7 +22,7 @@ from guessit.textutils import find_first_level_groups from guessit.patterns import group_delimiters import logging -log = logging.getLogger("guessit.transfo.split_explicit_groups") +log = logging.getLogger(__name__) def process(mtree): diff --git a/libs/guessit/transfo/split_on_dash.py b/libs/guessit/transfo/split_on_dash.py 
index fc10c49..0f2c34b 100644 --- a/libs/guessit/transfo/split_on_dash.py +++ b/libs/guessit/transfo/split_on_dash.py @@ -22,7 +22,7 @@ from guessit.patterns import sep import re import logging -log = logging.getLogger("guessit.transfo.split_on_dash") +log = logging.getLogger(__name__) def process(mtree): diff --git a/libs/guessit/transfo/split_path_components.py b/libs/guessit/transfo/split_path_components.py index 0f8d1a5..9f7ec9b 100644 --- a/libs/guessit/transfo/split_path_components.py +++ b/libs/guessit/transfo/split_path_components.py @@ -22,7 +22,7 @@ from guessit import fileutils import os.path import logging -log = logging.getLogger("guessit.transfo.split_path_components") +log = logging.getLogger(__name__) def process(mtree): diff --git a/libs/jinja2/exceptions.py b/libs/jinja2/exceptions.py index 771f6a8..841aabb 100755 --- a/libs/jinja2/exceptions.py +++ b/libs/jinja2/exceptions.py @@ -62,7 +62,7 @@ class TemplatesNotFound(TemplateNotFound): def __init__(self, names=(), message=None): if message is None: - message = u'non of the templates given were found: ' + \ + message = u'none of the templates given were found: ' + \ u', '.join(map(unicode, names)) TemplateNotFound.__init__(self, names and names[-1] or None, message) self.templates = list(names) diff --git a/libs/jinja2/filters.py b/libs/jinja2/filters.py index 8dd6ff0..8fef6ea 100755 --- a/libs/jinja2/filters.py +++ b/libs/jinja2/filters.py @@ -176,7 +176,12 @@ def do_title(s): """Return a titlecased version of the value. I.e. words will start with uppercase letters, all remaining characters are lowercase. """ - return soft_unicode(s).title() + rv = [] + for item in re.compile(r'([-\s]+)(?u)').split(s): + if not item: + continue + rv.append(item[0].upper() + item[1:]) + return ''.join(rv) def do_dictsort(value, case_sensitive=False, by='key'): @@ -578,7 +583,7 @@ def do_batch(value, linecount, fill_with=None): A filter that batches items. It works pretty much like `slice` just the other way round. 
It returns a list of lists with the given number of items. If you provide a second parameter this - is used to fill missing items. See this example: + is used to fill up missing items. See this example: .. sourcecode:: html+jinja diff --git a/libs/jinja2/testsuite/__init__.py b/libs/jinja2/testsuite/__init__.py new file mode 100755 index 0000000..1f10ef6 --- /dev/null +++ b/libs/jinja2/testsuite/__init__.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite + ~~~~~~~~~~~~~~~~ + + All the unittests of Jinja2. These tests can be executed by + either running run-tests.py using multiple Python versions at + the same time. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import os +import re +import sys +import unittest +from traceback import format_exception +from jinja2 import loaders + + +here = os.path.dirname(os.path.abspath(__file__)) + +dict_loader = loaders.DictLoader({ + 'justdict.html': 'FOO' +}) +package_loader = loaders.PackageLoader('jinja2.testsuite.res', 'templates') +filesystem_loader = loaders.FileSystemLoader(here + '/res/templates') +function_loader = loaders.FunctionLoader({'justfunction.html': 'FOO'}.get) +choice_loader = loaders.ChoiceLoader([dict_loader, package_loader]) +prefix_loader = loaders.PrefixLoader({ + 'a': filesystem_loader, + 'b': dict_loader +}) + + +class JinjaTestCase(unittest.TestCase): + + ### use only these methods for testing. If you need standard + ### unittest method, wrap them! 
+ + def setup(self): + pass + + def teardown(self): + pass + + def setUp(self): + self.setup() + + def tearDown(self): + self.teardown() + + def assert_equal(self, a, b): + return self.assertEqual(a, b) + + def assert_raises(self, *args, **kwargs): + return self.assertRaises(*args, **kwargs) + + def assert_traceback_matches(self, callback, expected_tb): + try: + callback() + except Exception, e: + tb = format_exception(*sys.exc_info()) + if re.search(expected_tb.strip(), ''.join(tb)) is None: + raise self.fail('Traceback did not match:\n\n%s\nexpected:\n%s' + % (''.join(tb), expected_tb)) + else: + self.fail('Expected exception') + + +def suite(): + from jinja2.testsuite import ext, filters, tests, core_tags, \ + loader, inheritance, imports, lexnparse, security, api, \ + regression, debug, utils, doctests + suite = unittest.TestSuite() + suite.addTest(ext.suite()) + suite.addTest(filters.suite()) + suite.addTest(tests.suite()) + suite.addTest(core_tags.suite()) + suite.addTest(loader.suite()) + suite.addTest(inheritance.suite()) + suite.addTest(imports.suite()) + suite.addTest(lexnparse.suite()) + suite.addTest(security.suite()) + suite.addTest(api.suite()) + suite.addTest(regression.suite()) + suite.addTest(debug.suite()) + suite.addTest(utils.suite()) + + # doctests will not run on python 3 currently. Too many issues + # with that, do not test that on that platform. + if sys.version_info < (3, 0): + suite.addTest(doctests.suite()) + + return suite diff --git a/libs/jinja2/testsuite/api.py b/libs/jinja2/testsuite/api.py new file mode 100755 index 0000000..c8f9634 --- /dev/null +++ b/libs/jinja2/testsuite/api.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.api + ~~~~~~~~~~~~~~~~~~~~ + + Tests the public API and related stuff. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. 
+""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment, Undefined, DebugUndefined, \ + StrictUndefined, UndefinedError, meta, \ + is_undefined, Template, DictLoader +from jinja2.utils import Cycler + +env = Environment() + + +class ExtendedAPITestCase(JinjaTestCase): + + def test_item_and_attribute(self): + from jinja2.sandbox import SandboxedEnvironment + + for env in Environment(), SandboxedEnvironment(): + # the |list is necessary for python3 + tmpl = env.from_string('{{ foo.items()|list }}') + assert tmpl.render(foo={'items': 42}) == "[('items', 42)]" + tmpl = env.from_string('{{ foo|attr("items")()|list }}') + assert tmpl.render(foo={'items': 42}) == "[('items', 42)]" + tmpl = env.from_string('{{ foo["items"] }}') + assert tmpl.render(foo={'items': 42}) == '42' + + def test_finalizer(self): + def finalize_none_empty(value): + if value is None: + value = u'' + return value + env = Environment(finalize=finalize_none_empty) + tmpl = env.from_string('{% for item in seq %}|{{ item }}{% endfor %}') + assert tmpl.render(seq=(None, 1, "foo")) == '||1|foo' + tmpl = env.from_string('<{{ none }}>') + assert tmpl.render() == '<>' + + def test_cycler(self): + items = 1, 2, 3 + c = Cycler(*items) + for item in items + items: + assert c.current == item + assert c.next() == item + c.next() + assert c.current == 2 + c.reset() + assert c.current == 1 + + def test_expressions(self): + expr = env.compile_expression("foo") + assert expr() is None + assert expr(foo=42) == 42 + expr2 = env.compile_expression("foo", undefined_to_none=False) + assert is_undefined(expr2()) + + expr = env.compile_expression("42 + foo") + assert expr(foo=42) == 84 + + def test_template_passthrough(self): + t = Template('Content') + assert env.get_template(t) is t + assert env.select_template([t]) is t + assert env.get_or_select_template([t]) is t + assert env.get_or_select_template(t) is t + + def test_autoescape_autoselect(self): + def 
select_autoescape(name): + if name is None or '.' not in name: + return False + return name.endswith('.html') + env = Environment(autoescape=select_autoescape, + loader=DictLoader({ + 'test.txt': '{{ foo }}', + 'test.html': '{{ foo }}' + })) + t = env.get_template('test.txt') + assert t.render(foo='') == '' + t = env.get_template('test.html') + assert t.render(foo='') == '<foo>' + t = env.from_string('{{ foo }}') + assert t.render(foo='') == '' + + +class MetaTestCase(JinjaTestCase): + + def test_find_undeclared_variables(self): + ast = env.parse('{% set foo = 42 %}{{ bar + foo }}') + x = meta.find_undeclared_variables(ast) + assert x == set(['bar']) + + ast = env.parse('{% set foo = 42 %}{{ bar + foo }}' + '{% macro meh(x) %}{{ x }}{% endmacro %}' + '{% for item in seq %}{{ muh(item) + meh(seq) }}{% endfor %}') + x = meta.find_undeclared_variables(ast) + assert x == set(['bar', 'seq', 'muh']) + + def test_find_refererenced_templates(self): + ast = env.parse('{% extends "layout.html" %}{% include helper %}') + i = meta.find_referenced_templates(ast) + assert i.next() == 'layout.html' + assert i.next() is None + assert list(i) == [] + + ast = env.parse('{% extends "layout.html" %}' + '{% from "test.html" import a, b as c %}' + '{% import "meh.html" as meh %}' + '{% include "muh.html" %}') + i = meta.find_referenced_templates(ast) + assert list(i) == ['layout.html', 'test.html', 'meh.html', 'muh.html'] + + def test_find_included_templates(self): + ast = env.parse('{% include ["foo.html", "bar.html"] %}') + i = meta.find_referenced_templates(ast) + assert list(i) == ['foo.html', 'bar.html'] + + ast = env.parse('{% include ("foo.html", "bar.html") %}') + i = meta.find_referenced_templates(ast) + assert list(i) == ['foo.html', 'bar.html'] + + ast = env.parse('{% include ["foo.html", "bar.html", foo] %}') + i = meta.find_referenced_templates(ast) + assert list(i) == ['foo.html', 'bar.html', None] + + ast = env.parse('{% include ("foo.html", "bar.html", foo) %}') + i = 
meta.find_referenced_templates(ast) + assert list(i) == ['foo.html', 'bar.html', None] + + +class StreamingTestCase(JinjaTestCase): + + def test_basic_streaming(self): + tmpl = env.from_string("
    {% for item in seq %}
  • {{ loop.index " + "}} - {{ item }}
  • {%- endfor %}
") + stream = tmpl.stream(seq=range(4)) + self.assert_equal(stream.next(), '
    ') + self.assert_equal(stream.next(), '
  • 1 - 0
  • ') + self.assert_equal(stream.next(), '
  • 2 - 1
  • ') + self.assert_equal(stream.next(), '
  • 3 - 2
  • ') + self.assert_equal(stream.next(), '
  • 4 - 3
  • ') + self.assert_equal(stream.next(), '
') + + def test_buffered_streaming(self): + tmpl = env.from_string("
    {% for item in seq %}
  • {{ loop.index " + "}} - {{ item }}
  • {%- endfor %}
") + stream = tmpl.stream(seq=range(4)) + stream.enable_buffering(size=3) + self.assert_equal(stream.next(), u'
  • 1 - 0
  • 2 - 1
  • ') + self.assert_equal(stream.next(), u'
  • 3 - 2
  • 4 - 3
') + + def test_streaming_behavior(self): + tmpl = env.from_string("") + stream = tmpl.stream() + assert not stream.buffered + stream.enable_buffering(20) + assert stream.buffered + stream.disable_buffering() + assert not stream.buffered + + +class UndefinedTestCase(JinjaTestCase): + + def test_stopiteration_is_undefined(self): + def test(): + raise StopIteration() + t = Template('A{{ test() }}B') + assert t.render(test=test) == 'AB' + t = Template('A{{ test().missingattribute }}B') + self.assert_raises(UndefinedError, t.render, test=test) + + def test_undefined_and_special_attributes(self): + try: + Undefined('Foo').__dict__ + except AttributeError: + pass + else: + assert False, "Expected actual attribute error" + + def test_default_undefined(self): + env = Environment(undefined=Undefined) + self.assert_equal(env.from_string('{{ missing }}').render(), u'') + self.assert_raises(UndefinedError, + env.from_string('{{ missing.attribute }}').render) + self.assert_equal(env.from_string('{{ missing|list }}').render(), '[]') + self.assert_equal(env.from_string('{{ missing is not defined }}').render(), 'True') + self.assert_equal(env.from_string('{{ foo.missing }}').render(foo=42), '') + self.assert_equal(env.from_string('{{ not missing }}').render(), 'True') + + def test_debug_undefined(self): + env = Environment(undefined=DebugUndefined) + self.assert_equal(env.from_string('{{ missing }}').render(), '{{ missing }}') + self.assert_raises(UndefinedError, + env.from_string('{{ missing.attribute }}').render) + self.assert_equal(env.from_string('{{ missing|list }}').render(), '[]') + self.assert_equal(env.from_string('{{ missing is not defined }}').render(), 'True') + self.assert_equal(env.from_string('{{ foo.missing }}').render(foo=42), + u"{{ no such element: int object['missing'] }}") + self.assert_equal(env.from_string('{{ not missing }}').render(), 'True') + + def test_strict_undefined(self): + env = Environment(undefined=StrictUndefined) + 
self.assert_raises(UndefinedError, env.from_string('{{ missing }}').render) + self.assert_raises(UndefinedError, env.from_string('{{ missing.attribute }}').render) + self.assert_raises(UndefinedError, env.from_string('{{ missing|list }}').render) + self.assert_equal(env.from_string('{{ missing is not defined }}').render(), 'True') + self.assert_raises(UndefinedError, env.from_string('{{ foo.missing }}').render, foo=42) + self.assert_raises(UndefinedError, env.from_string('{{ not missing }}').render) + + def test_indexing_gives_undefined(self): + t = Template("{{ var[42].foo }}") + self.assert_raises(UndefinedError, t.render, var=0) + + def test_none_gives_proper_error(self): + try: + Environment().getattr(None, 'split')() + except UndefinedError, e: + assert e.message == "'None' has no attribute 'split'" + else: + assert False, 'expected exception' + + def test_object_repr(self): + try: + Undefined(obj=42, name='upper')() + except UndefinedError, e: + assert e.message == "'int object' has no attribute 'upper'" + else: + assert False, 'expected exception' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(ExtendedAPITestCase)) + suite.addTest(unittest.makeSuite(MetaTestCase)) + suite.addTest(unittest.makeSuite(StreamingTestCase)) + suite.addTest(unittest.makeSuite(UndefinedTestCase)) + return suite diff --git a/libs/jinja2/testsuite/core_tags.py b/libs/jinja2/testsuite/core_tags.py new file mode 100755 index 0000000..2b5f580 --- /dev/null +++ b/libs/jinja2/testsuite/core_tags.py @@ -0,0 +1,285 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.core_tags + ~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Test the core tags like for and if. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. 
+""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment, TemplateSyntaxError, UndefinedError, \ + DictLoader + +env = Environment() + + +class ForLoopTestCase(JinjaTestCase): + + def test_simple(self): + tmpl = env.from_string('{% for item in seq %}{{ item }}{% endfor %}') + assert tmpl.render(seq=range(10)) == '0123456789' + + def test_else(self): + tmpl = env.from_string('{% for item in seq %}XXX{% else %}...{% endfor %}') + assert tmpl.render() == '...' + + def test_empty_blocks(self): + tmpl = env.from_string('<{% for item in seq %}{% else %}{% endfor %}>') + assert tmpl.render() == '<>' + + def test_context_vars(self): + tmpl = env.from_string('''{% for item in seq -%} + {{ loop.index }}|{{ loop.index0 }}|{{ loop.revindex }}|{{ + loop.revindex0 }}|{{ loop.first }}|{{ loop.last }}|{{ + loop.length }}###{% endfor %}''') + one, two, _ = tmpl.render(seq=[0, 1]).split('###') + (one_index, one_index0, one_revindex, one_revindex0, one_first, + one_last, one_length) = one.split('|') + (two_index, two_index0, two_revindex, two_revindex0, two_first, + two_last, two_length) = two.split('|') + + assert int(one_index) == 1 and int(two_index) == 2 + assert int(one_index0) == 0 and int(two_index0) == 1 + assert int(one_revindex) == 2 and int(two_revindex) == 1 + assert int(one_revindex0) == 1 and int(two_revindex0) == 0 + assert one_first == 'True' and two_first == 'False' + assert one_last == 'False' and two_last == 'True' + assert one_length == two_length == '2' + + def test_cycling(self): + tmpl = env.from_string('''{% for item in seq %}{{ + loop.cycle('<1>', '<2>') }}{% endfor %}{% + for item in seq %}{{ loop.cycle(*through) }}{% endfor %}''') + output = tmpl.render(seq=range(4), through=('<1>', '<2>')) + assert output == '<1><2>' * 4 + + def test_scope(self): + tmpl = env.from_string('{% for item in seq %}{% endfor %}{{ item }}') + output = tmpl.render(seq=range(10)) + assert not output + + def test_varlen(self): + def 
inner(): + for item in range(5): + yield item + tmpl = env.from_string('{% for item in iter %}{{ item }}{% endfor %}') + output = tmpl.render(iter=inner()) + assert output == '01234' + + def test_noniter(self): + tmpl = env.from_string('{% for item in none %}...{% endfor %}') + self.assert_raises(TypeError, tmpl.render) + + def test_recursive(self): + tmpl = env.from_string('''{% for item in seq recursive -%} + [{{ item.a }}{% if item.b %}<{{ loop(item.b) }}>{% endif %}] + {%- endfor %}''') + assert tmpl.render(seq=[ + dict(a=1, b=[dict(a=1), dict(a=2)]), + dict(a=2, b=[dict(a=1), dict(a=2)]), + dict(a=3, b=[dict(a='a')]) + ]) == '[1<[1][2]>][2<[1][2]>][3<[a]>]' + + def test_looploop(self): + tmpl = env.from_string('''{% for row in table %} + {%- set rowloop = loop -%} + {% for cell in row -%} + [{{ rowloop.index }}|{{ loop.index }}] + {%- endfor %} + {%- endfor %}''') + assert tmpl.render(table=['ab', 'cd']) == '[1|1][1|2][2|1][2|2]' + + def test_reversed_bug(self): + tmpl = env.from_string('{% for i in items %}{{ i }}' + '{% if not loop.last %}' + ',{% endif %}{% endfor %}') + assert tmpl.render(items=reversed([3, 2, 1])) == '1,2,3' + + def test_loop_errors(self): + tmpl = env.from_string('''{% for item in [1] if loop.index + == 0 %}...{% endfor %}''') + self.assert_raises(UndefinedError, tmpl.render) + tmpl = env.from_string('''{% for item in [] %}...{% else + %}{{ loop }}{% endfor %}''') + assert tmpl.render() == '' + + def test_loop_filter(self): + tmpl = env.from_string('{% for item in range(10) if item ' + 'is even %}[{{ item }}]{% endfor %}') + assert tmpl.render() == '[0][2][4][6][8]' + tmpl = env.from_string(''' + {%- for item in range(10) if item is even %}[{{ + loop.index }}:{{ item }}]{% endfor %}''') + assert tmpl.render() == '[1:0][2:2][3:4][4:6][5:8]' + + def test_loop_unassignable(self): + self.assert_raises(TemplateSyntaxError, env.from_string, + '{% for loop in seq %}...{% endfor %}') + + def test_scoped_special_var(self): + t = 
env.from_string('{% for s in seq %}[{{ loop.first }}{% for c in s %}' + '|{{ loop.first }}{% endfor %}]{% endfor %}') + assert t.render(seq=('ab', 'cd')) == '[True|True|False][False|True|False]' + + def test_scoped_loop_var(self): + t = env.from_string('{% for x in seq %}{{ loop.first }}' + '{% for y in seq %}{% endfor %}{% endfor %}') + assert t.render(seq='ab') == 'TrueFalse' + t = env.from_string('{% for x in seq %}{% for y in seq %}' + '{{ loop.first }}{% endfor %}{% endfor %}') + assert t.render(seq='ab') == 'TrueFalseTrueFalse' + + def test_recursive_empty_loop_iter(self): + t = env.from_string(''' + {%- for item in foo recursive -%}{%- endfor -%} + ''') + assert t.render(dict(foo=[])) == '' + + def test_call_in_loop(self): + t = env.from_string(''' + {%- macro do_something() -%} + [{{ caller() }}] + {%- endmacro %} + + {%- for i in [1, 2, 3] %} + {%- call do_something() -%} + {{ i }} + {%- endcall %} + {%- endfor -%} + ''') + assert t.render() == '[1][2][3]' + + def test_scoping_bug(self): + t = env.from_string(''' + {%- for item in foo %}...{{ item }}...{% endfor %} + {%- macro item(a) %}...{{ a }}...{% endmacro %} + {{- item(2) -}} + ''') + assert t.render(foo=(1,)) == '...1......2...' + + def test_unpacking(self): + tmpl = env.from_string('{% for a, b, c in [[1, 2, 3]] %}' + '{{ a }}|{{ b }}|{{ c }}{% endfor %}') + assert tmpl.render() == '1|2|3' + + +class IfConditionTestCase(JinjaTestCase): + + def test_simple(self): + tmpl = env.from_string('''{% if true %}...{% endif %}''') + assert tmpl.render() == '...' + + def test_elif(self): + tmpl = env.from_string('''{% if false %}XXX{% elif true + %}...{% else %}XXX{% endif %}''') + assert tmpl.render() == '...' + + def test_else(self): + tmpl = env.from_string('{% if false %}XXX{% else %}...{% endif %}') + assert tmpl.render() == '...' 
+ + def test_empty(self): + tmpl = env.from_string('[{% if true %}{% else %}{% endif %}]') + assert tmpl.render() == '[]' + + def test_complete(self): + tmpl = env.from_string('{% if a %}A{% elif b %}B{% elif c == d %}' + 'C{% else %}D{% endif %}') + assert tmpl.render(a=0, b=False, c=42, d=42.0) == 'C' + + def test_no_scope(self): + tmpl = env.from_string('{% if a %}{% set foo = 1 %}{% endif %}{{ foo }}') + assert tmpl.render(a=True) == '1' + tmpl = env.from_string('{% if true %}{% set foo = 1 %}{% endif %}{{ foo }}') + assert tmpl.render() == '1' + + +class MacrosTestCase(JinjaTestCase): + env = Environment(trim_blocks=True) + + def test_simple(self): + tmpl = self.env.from_string('''\ +{% macro say_hello(name) %}Hello {{ name }}!{% endmacro %} +{{ say_hello('Peter') }}''') + assert tmpl.render() == 'Hello Peter!' + + def test_scoping(self): + tmpl = self.env.from_string('''\ +{% macro level1(data1) %} +{% macro level2(data2) %}{{ data1 }}|{{ data2 }}{% endmacro %} +{{ level2('bar') }}{% endmacro %} +{{ level1('foo') }}''') + assert tmpl.render() == 'foo|bar' + + def test_arguments(self): + tmpl = self.env.from_string('''\ +{% macro m(a, b, c='c', d='d') %}{{ a }}|{{ b }}|{{ c }}|{{ d }}{% endmacro %} +{{ m() }}|{{ m('a') }}|{{ m('a', 'b') }}|{{ m(1, 2, 3) }}''') + assert tmpl.render() == '||c|d|a||c|d|a|b|c|d|1|2|3|d' + + def test_varargs(self): + tmpl = self.env.from_string('''\ +{% macro test() %}{{ varargs|join('|') }}{% endmacro %}\ +{{ test(1, 2, 3) }}''') + assert tmpl.render() == '1|2|3' + + def test_simple_call(self): + tmpl = self.env.from_string('''\ +{% macro test() %}[[{{ caller() }}]]{% endmacro %}\ +{% call test() %}data{% endcall %}''') + assert tmpl.render() == '[[data]]' + + def test_complex_call(self): + tmpl = self.env.from_string('''\ +{% macro test() %}[[{{ caller('data') }}]]{% endmacro %}\ +{% call(data) test() %}{{ data }}{% endcall %}''') + assert tmpl.render() == '[[data]]' + + def test_caller_undefined(self): + tmpl = 
self.env.from_string('''\ +{% set caller = 42 %}\ +{% macro test() %}{{ caller is not defined }}{% endmacro %}\ +{{ test() }}''') + assert tmpl.render() == 'True' + + def test_include(self): + self.env = Environment(loader=DictLoader({'include': + '{% macro test(foo) %}[{{ foo }}]{% endmacro %}'})) + tmpl = self.env.from_string('{% from "include" import test %}{{ test("foo") }}') + assert tmpl.render() == '[foo]' + + def test_macro_api(self): + tmpl = self.env.from_string('{% macro foo(a, b) %}{% endmacro %}' + '{% macro bar() %}{{ varargs }}{{ kwargs }}{% endmacro %}' + '{% macro baz() %}{{ caller() }}{% endmacro %}') + assert tmpl.module.foo.arguments == ('a', 'b') + assert tmpl.module.foo.defaults == () + assert tmpl.module.foo.name == 'foo' + assert not tmpl.module.foo.caller + assert not tmpl.module.foo.catch_kwargs + assert not tmpl.module.foo.catch_varargs + assert tmpl.module.bar.arguments == () + assert tmpl.module.bar.defaults == () + assert not tmpl.module.bar.caller + assert tmpl.module.bar.catch_kwargs + assert tmpl.module.bar.catch_varargs + assert tmpl.module.baz.caller + + def test_callself(self): + tmpl = self.env.from_string('{% macro foo(x) %}{{ x }}{% if x > 1 %}|' + '{{ foo(x - 1) }}{% endif %}{% endmacro %}' + '{{ foo(5) }}') + assert tmpl.render() == '5|4|3|2|1' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(ForLoopTestCase)) + suite.addTest(unittest.makeSuite(IfConditionTestCase)) + suite.addTest(unittest.makeSuite(MacrosTestCase)) + return suite diff --git a/libs/jinja2/testsuite/debug.py b/libs/jinja2/testsuite/debug.py new file mode 100755 index 0000000..7552dec --- /dev/null +++ b/libs/jinja2/testsuite/debug.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.debug + ~~~~~~~~~~~~~~~~~~~~~~ + + Tests the debug system. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. 
+""" +import sys +import unittest + +from jinja2.testsuite import JinjaTestCase, filesystem_loader + +from jinja2 import Environment, TemplateSyntaxError + +env = Environment(loader=filesystem_loader) + + +class DebugTestCase(JinjaTestCase): + + if sys.version_info[:2] != (2, 4): + def test_runtime_error(self): + def test(): + tmpl.render(fail=lambda: 1 / 0) + tmpl = env.get_template('broken.html') + self.assert_traceback_matches(test, r''' + File ".*?broken.html", line 2, in (top-level template code|) + \{\{ fail\(\) \}\} + File ".*?debug.pyc?", line \d+, in + tmpl\.render\(fail=lambda: 1 / 0\) +ZeroDivisionError: (int(eger)? )?division (or modulo )?by zero +''') + + def test_syntax_error(self): + # XXX: the .*? is necessary for python3 which does not hide + # some of the stack frames we don't want to show. Not sure + # what's up with that, but that is not that critical. Should + # be fixed though. + self.assert_traceback_matches(lambda: env.get_template('syntaxerror.html'), r'''(?sm) + File ".*?syntaxerror.html", line 4, in (template|) + \{% endif %\}.*? +(jinja2\.exceptions\.)?TemplateSyntaxError: Encountered unknown tag 'endif'. Jinja was looking for the following tags: 'endfor' or 'else'. The innermost block that needs to be closed is 'for'. + ''') + + def test_regular_syntax_error(self): + def test(): + raise TemplateSyntaxError('wtf', 42) + self.assert_traceback_matches(test, r''' + File ".*debug.pyc?", line \d+, in test + raise TemplateSyntaxError\('wtf', 42\) +(jinja2\.exceptions\.)?TemplateSyntaxError: wtf + line 42''') + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(DebugTestCase)) + return suite diff --git a/libs/jinja2/testsuite/doctests.py b/libs/jinja2/testsuite/doctests.py new file mode 100755 index 0000000..616d3b6 --- /dev/null +++ b/libs/jinja2/testsuite/doctests.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.doctests + ~~~~~~~~~~~~~~~~~~~~~~~~~ + + The doctests. 
Collects all tests we want to test from + the Jinja modules. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest +import doctest + + +def suite(): + from jinja2 import utils, sandbox, runtime, meta, loaders, \ + ext, environment, bccache, nodes + suite = unittest.TestSuite() + suite.addTest(doctest.DocTestSuite(utils)) + suite.addTest(doctest.DocTestSuite(sandbox)) + suite.addTest(doctest.DocTestSuite(runtime)) + suite.addTest(doctest.DocTestSuite(meta)) + suite.addTest(doctest.DocTestSuite(loaders)) + suite.addTest(doctest.DocTestSuite(ext)) + suite.addTest(doctest.DocTestSuite(environment)) + suite.addTest(doctest.DocTestSuite(bccache)) + suite.addTest(doctest.DocTestSuite(nodes)) + return suite diff --git a/libs/jinja2/testsuite/ext.py b/libs/jinja2/testsuite/ext.py new file mode 100755 index 0000000..6ca6c22 --- /dev/null +++ b/libs/jinja2/testsuite/ext.py @@ -0,0 +1,455 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.ext + ~~~~~~~~~~~~~~~~~~~~ + + Tests for the extensions. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. 
+""" +import re +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment, DictLoader, contextfunction, nodes +from jinja2.exceptions import TemplateAssertionError +from jinja2.ext import Extension +from jinja2.lexer import Token, count_newlines +from jinja2.utils import next + +# 2.x / 3.x +try: + from io import BytesIO +except ImportError: + from StringIO import StringIO as BytesIO + + +importable_object = 23 + +_gettext_re = re.compile(r'_\((.*?)\)(?s)') + + +i18n_templates = { + 'master.html': '{{ page_title|default(_("missing")) }}' + '{% block body %}{% endblock %}', + 'child.html': '{% extends "master.html" %}{% block body %}' + '{% trans %}watch out{% endtrans %}{% endblock %}', + 'plural.html': '{% trans user_count %}One user online{% pluralize %}' + '{{ user_count }} users online{% endtrans %}', + 'stringformat.html': '{{ _("User: %(num)s")|format(num=user_count) }}' +} + +newstyle_i18n_templates = { + 'master.html': '{{ page_title|default(_("missing")) }}' + '{% block body %}{% endblock %}', + 'child.html': '{% extends "master.html" %}{% block body %}' + '{% trans %}watch out{% endtrans %}{% endblock %}', + 'plural.html': '{% trans user_count %}One user online{% pluralize %}' + '{{ user_count }} users online{% endtrans %}', + 'stringformat.html': '{{ _("User: %(num)s", num=user_count) }}', + 'ngettext.html': '{{ ngettext("%(num)s apple", "%(num)s apples", apples) }}', + 'ngettext_long.html': '{% trans num=apples %}{{ num }} apple{% pluralize %}' + '{{ num }} apples{% endtrans %}', + 'transvars1.html': '{% trans %}User: {{ num }}{% endtrans %}', + 'transvars2.html': '{% trans num=count %}User: {{ num }}{% endtrans %}', + 'transvars3.html': '{% trans count=num %}User: {{ count }}{% endtrans %}', + 'novars.html': '{% trans %}%(hello)s{% endtrans %}', + 'vars.html': '{% trans %}{{ foo }}%(foo)s{% endtrans %}', + 'explicitvars.html': '{% trans foo="42" %}%(foo)s{% endtrans %}' +} + + +languages = { + 'de': { + 'missing': 
u'fehlend', + 'watch out': u'pass auf', + 'One user online': u'Ein Benutzer online', + '%(user_count)s users online': u'%(user_count)s Benutzer online', + 'User: %(num)s': u'Benutzer: %(num)s', + 'User: %(count)s': u'Benutzer: %(count)s', + '%(num)s apple': u'%(num)s Apfel', + '%(num)s apples': u'%(num)s Äpfel' + } +} + + +@contextfunction +def gettext(context, string): + language = context.get('LANGUAGE', 'en') + return languages.get(language, {}).get(string, string) + + +@contextfunction +def ngettext(context, s, p, n): + language = context.get('LANGUAGE', 'en') + if n != 1: + return languages.get(language, {}).get(p, p) + return languages.get(language, {}).get(s, s) + + +i18n_env = Environment( + loader=DictLoader(i18n_templates), + extensions=['jinja2.ext.i18n'] +) +i18n_env.globals.update({ + '_': gettext, + 'gettext': gettext, + 'ngettext': ngettext +}) + +newstyle_i18n_env = Environment( + loader=DictLoader(newstyle_i18n_templates), + extensions=['jinja2.ext.i18n'] +) +newstyle_i18n_env.install_gettext_callables(gettext, ngettext, newstyle=True) + +class TestExtension(Extension): + tags = set(['test']) + ext_attr = 42 + + def parse(self, parser): + return nodes.Output([self.call_method('_dump', [ + nodes.EnvironmentAttribute('sandboxed'), + self.attr('ext_attr'), + nodes.ImportedName(__name__ + '.importable_object'), + nodes.ContextReference() + ])]).set_lineno(next(parser.stream).lineno) + + def _dump(self, sandboxed, ext_attr, imported_object, context): + return '%s|%s|%s|%s' % ( + sandboxed, + ext_attr, + imported_object, + context.blocks + ) + + +class PreprocessorExtension(Extension): + + def preprocess(self, source, name, filename=None): + return source.replace('[[TEST]]', '({{ foo }})') + + +class StreamFilterExtension(Extension): + + def filter_stream(self, stream): + for token in stream: + if token.type == 'data': + for t in self.interpolate(token): + yield t + else: + yield token + + def interpolate(self, token): + pos = 0 + end = len(token.value) 
+ lineno = token.lineno + while 1: + match = _gettext_re.search(token.value, pos) + if match is None: + break + value = token.value[pos:match.start()] + if value: + yield Token(lineno, 'data', value) + lineno += count_newlines(token.value) + yield Token(lineno, 'variable_begin', None) + yield Token(lineno, 'name', 'gettext') + yield Token(lineno, 'lparen', None) + yield Token(lineno, 'string', match.group(1)) + yield Token(lineno, 'rparen', None) + yield Token(lineno, 'variable_end', None) + pos = match.end() + if pos < end: + yield Token(lineno, 'data', token.value[pos:]) + + +class ExtensionsTestCase(JinjaTestCase): + + def test_extend_late(self): + env = Environment() + env.add_extension('jinja2.ext.autoescape') + t = env.from_string('{% autoescape true %}{{ "" }}{% endautoescape %}') + assert t.render() == '<test>' + + def test_loop_controls(self): + env = Environment(extensions=['jinja2.ext.loopcontrols']) + + tmpl = env.from_string(''' + {%- for item in [1, 2, 3, 4] %} + {%- if item % 2 == 0 %}{% continue %}{% endif -%} + {{ item }} + {%- endfor %}''') + assert tmpl.render() == '13' + + tmpl = env.from_string(''' + {%- for item in [1, 2, 3, 4] %} + {%- if item > 2 %}{% break %}{% endif -%} + {{ item }} + {%- endfor %}''') + assert tmpl.render() == '12' + + def test_do(self): + env = Environment(extensions=['jinja2.ext.do']) + tmpl = env.from_string(''' + {%- set items = [] %} + {%- for char in "foo" %} + {%- do items.append(loop.index0 ~ char) %} + {%- endfor %}{{ items|join(', ') }}''') + assert tmpl.render() == '0f, 1o, 2o' + + def test_with(self): + env = Environment(extensions=['jinja2.ext.with_']) + tmpl = env.from_string('''\ + {% with a=42, b=23 -%} + {{ a }} = {{ b }} + {% endwith -%} + {{ a }} = {{ b }}\ + ''') + assert [x.strip() for x in tmpl.render(a=1, b=2).splitlines()] \ + == ['42 = 23', '1 = 2'] + + def test_extension_nodes(self): + env = Environment(extensions=[TestExtension]) + tmpl = env.from_string('{% test %}') + assert tmpl.render() == 
'False|42|23|{}' + + def test_identifier(self): + assert TestExtension.identifier == __name__ + '.TestExtension' + + def test_rebinding(self): + original = Environment(extensions=[TestExtension]) + overlay = original.overlay() + for env in original, overlay: + for ext in env.extensions.itervalues(): + assert ext.environment is env + + def test_preprocessor_extension(self): + env = Environment(extensions=[PreprocessorExtension]) + tmpl = env.from_string('{[[TEST]]}') + assert tmpl.render(foo=42) == '{(42)}' + + def test_streamfilter_extension(self): + env = Environment(extensions=[StreamFilterExtension]) + env.globals['gettext'] = lambda x: x.upper() + tmpl = env.from_string('Foo _(bar) Baz') + out = tmpl.render() + assert out == 'Foo BAR Baz' + + def test_extension_ordering(self): + class T1(Extension): + priority = 1 + class T2(Extension): + priority = 2 + env = Environment(extensions=[T1, T2]) + ext = list(env.iter_extensions()) + assert ext[0].__class__ is T1 + assert ext[1].__class__ is T2 + + +class InternationalizationTestCase(JinjaTestCase): + + def test_trans(self): + tmpl = i18n_env.get_template('child.html') + assert tmpl.render(LANGUAGE='de') == 'fehlendpass auf' + + def test_trans_plural(self): + tmpl = i18n_env.get_template('plural.html') + assert tmpl.render(LANGUAGE='de', user_count=1) == 'Ein Benutzer online' + assert tmpl.render(LANGUAGE='de', user_count=2) == '2 Benutzer online' + + def test_complex_plural(self): + tmpl = i18n_env.from_string('{% trans foo=42, count=2 %}{{ count }} item{% ' + 'pluralize count %}{{ count }} items{% endtrans %}') + assert tmpl.render() == '2 items' + self.assert_raises(TemplateAssertionError, i18n_env.from_string, + '{% trans foo %}...{% pluralize bar %}...{% endtrans %}') + + def test_trans_stringformatting(self): + tmpl = i18n_env.get_template('stringformat.html') + assert tmpl.render(LANGUAGE='de', user_count=5) == 'Benutzer: 5' + + def test_extract(self): + from jinja2.ext import babel_extract + source = 
BytesIO(''' + {{ gettext('Hello World') }} + {% trans %}Hello World{% endtrans %} + {% trans %}{{ users }} user{% pluralize %}{{ users }} users{% endtrans %} + '''.encode('ascii')) # make python 3 happy + assert list(babel_extract(source, ('gettext', 'ngettext', '_'), [], {})) == [ + (2, 'gettext', u'Hello World', []), + (3, 'gettext', u'Hello World', []), + (4, 'ngettext', (u'%(users)s user', u'%(users)s users', None), []) + ] + + def test_comment_extract(self): + from jinja2.ext import babel_extract + source = BytesIO(''' + {# trans first #} + {{ gettext('Hello World') }} + {% trans %}Hello World{% endtrans %}{# trans second #} + {#: third #} + {% trans %}{{ users }} user{% pluralize %}{{ users }} users{% endtrans %} + '''.encode('utf-8')) # make python 3 happy + assert list(babel_extract(source, ('gettext', 'ngettext', '_'), ['trans', ':'], {})) == [ + (3, 'gettext', u'Hello World', ['first']), + (4, 'gettext', u'Hello World', ['second']), + (6, 'ngettext', (u'%(users)s user', u'%(users)s users', None), ['third']) + ] + + +class NewstyleInternationalizationTestCase(JinjaTestCase): + + def test_trans(self): + tmpl = newstyle_i18n_env.get_template('child.html') + assert tmpl.render(LANGUAGE='de') == 'fehlendpass auf' + + def test_trans_plural(self): + tmpl = newstyle_i18n_env.get_template('plural.html') + assert tmpl.render(LANGUAGE='de', user_count=1) == 'Ein Benutzer online' + assert tmpl.render(LANGUAGE='de', user_count=2) == '2 Benutzer online' + + def test_complex_plural(self): + tmpl = newstyle_i18n_env.from_string('{% trans foo=42, count=2 %}{{ count }} item{% ' + 'pluralize count %}{{ count }} items{% endtrans %}') + assert tmpl.render() == '2 items' + self.assert_raises(TemplateAssertionError, i18n_env.from_string, + '{% trans foo %}...{% pluralize bar %}...{% endtrans %}') + + def test_trans_stringformatting(self): + tmpl = newstyle_i18n_env.get_template('stringformat.html') + assert tmpl.render(LANGUAGE='de', user_count=5) == 'Benutzer: 5' + + def 
test_newstyle_plural(self): + tmpl = newstyle_i18n_env.get_template('ngettext.html') + assert tmpl.render(LANGUAGE='de', apples=1) == '1 Apfel' + assert tmpl.render(LANGUAGE='de', apples=5) == u'5 Äpfel' + + def test_autoescape_support(self): + env = Environment(extensions=['jinja2.ext.autoescape', + 'jinja2.ext.i18n']) + env.install_gettext_callables(lambda x: u'Wert: %(name)s', + lambda s, p, n: s, newstyle=True) + t = env.from_string('{% autoescape ae %}{{ gettext("foo", name=' + '"") }}{% endautoescape %}') + assert t.render(ae=True) == 'Wert: <test>' + assert t.render(ae=False) == 'Wert: ' + + def test_num_used_twice(self): + tmpl = newstyle_i18n_env.get_template('ngettext_long.html') + assert tmpl.render(apples=5, LANGUAGE='de') == u'5 Äpfel' + + def test_num_called_num(self): + source = newstyle_i18n_env.compile(''' + {% trans num=3 %}{{ num }} apple{% pluralize + %}{{ num }} apples{% endtrans %} + ''', raw=True) + # quite hacky, but the only way to properly test that. The idea is + # that the generated code does not pass num twice (although that + # would work) for better performance. 
This only works on the + # newstyle gettext of course + assert re.search(r"l_ngettext, u?'\%\(num\)s apple', u?'\%\(num\)s " + r"apples', 3", source) is not None + + def test_trans_vars(self): + t1 = newstyle_i18n_env.get_template('transvars1.html') + t2 = newstyle_i18n_env.get_template('transvars2.html') + t3 = newstyle_i18n_env.get_template('transvars3.html') + assert t1.render(num=1, LANGUAGE='de') == 'Benutzer: 1' + assert t2.render(count=23, LANGUAGE='de') == 'Benutzer: 23' + assert t3.render(num=42, LANGUAGE='de') == 'Benutzer: 42' + + def test_novars_vars_escaping(self): + t = newstyle_i18n_env.get_template('novars.html') + assert t.render() == '%(hello)s' + t = newstyle_i18n_env.get_template('vars.html') + assert t.render(foo='42') == '42%(foo)s' + t = newstyle_i18n_env.get_template('explicitvars.html') + assert t.render() == '%(foo)s' + + +class AutoEscapeTestCase(JinjaTestCase): + + def test_scoped_setting(self): + env = Environment(extensions=['jinja2.ext.autoescape'], + autoescape=True) + tmpl = env.from_string(''' + {{ "" }} + {% autoescape false %} + {{ "" }} + {% endautoescape %} + {{ "" }} + ''') + assert tmpl.render().split() == \ + [u'<HelloWorld>', u'', u'<HelloWorld>'] + + env = Environment(extensions=['jinja2.ext.autoescape'], + autoescape=False) + tmpl = env.from_string(''' + {{ "" }} + {% autoescape true %} + {{ "" }} + {% endautoescape %} + {{ "" }} + ''') + assert tmpl.render().split() == \ + [u'', u'<HelloWorld>', u''] + + def test_nonvolatile(self): + env = Environment(extensions=['jinja2.ext.autoescape'], + autoescape=True) + tmpl = env.from_string('{{ {"foo": ""}|xmlattr|escape }}') + assert tmpl.render() == ' foo="<test>"' + tmpl = env.from_string('{% autoescape false %}{{ {"foo": ""}' + '|xmlattr|escape }}{% endautoescape %}') + assert tmpl.render() == ' foo="&lt;test&gt;"' + + def test_volatile(self): + env = Environment(extensions=['jinja2.ext.autoescape'], + autoescape=True) + tmpl = env.from_string('{% autoescape foo %}{{ {"foo": 
""}' + '|xmlattr|escape }}{% endautoescape %}') + assert tmpl.render(foo=False) == ' foo="&lt;test&gt;"' + assert tmpl.render(foo=True) == ' foo="<test>"' + + def test_scoping(self): + env = Environment(extensions=['jinja2.ext.autoescape']) + tmpl = env.from_string('{% autoescape true %}{% set x = "" %}{{ x }}' + '{% endautoescape %}{{ x }}{{ "" }}') + assert tmpl.render(x=1) == '<x>1' + + def test_volatile_scoping(self): + env = Environment(extensions=['jinja2.ext.autoescape']) + tmplsource = ''' + {% autoescape val %} + {% macro foo(x) %} + [{{ x }}] + {% endmacro %} + {{ foo().__class__.__name__ }} + {% endautoescape %} + {{ '' }} + ''' + tmpl = env.from_string(tmplsource) + assert tmpl.render(val=True).split()[0] == 'Markup' + assert tmpl.render(val=False).split()[0] == unicode.__name__ + + # looking at the source we should see there in raw + # (and then escaped as well) + env = Environment(extensions=['jinja2.ext.autoescape']) + pysource = env.compile(tmplsource, raw=True) + assert '\\n' in pysource + + env = Environment(extensions=['jinja2.ext.autoescape'], + autoescape=True) + pysource = env.compile(tmplsource, raw=True) + assert '<testing>\\n' in pysource + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(ExtensionsTestCase)) + suite.addTest(unittest.makeSuite(InternationalizationTestCase)) + suite.addTest(unittest.makeSuite(NewstyleInternationalizationTestCase)) + suite.addTest(unittest.makeSuite(AutoEscapeTestCase)) + return suite diff --git a/libs/jinja2/testsuite/filters.py b/libs/jinja2/testsuite/filters.py new file mode 100755 index 0000000..b037e24 --- /dev/null +++ b/libs/jinja2/testsuite/filters.py @@ -0,0 +1,396 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.filters + ~~~~~~~~~~~~~~~~~~~~~~~~ + + Tests for the jinja filters. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. 
+""" +import unittest +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Markup, Environment + +env = Environment() + + +class FilterTestCase(JinjaTestCase): + + def test_capitalize(self): + tmpl = env.from_string('{{ "foo bar"|capitalize }}') + assert tmpl.render() == 'Foo bar' + + def test_center(self): + tmpl = env.from_string('{{ "foo"|center(9) }}') + assert tmpl.render() == ' foo ' + + def test_default(self): + tmpl = env.from_string( + "{{ missing|default('no') }}|{{ false|default('no') }}|" + "{{ false|default('no', true) }}|{{ given|default('no') }}" + ) + assert tmpl.render(given='yes') == 'no|False|no|yes' + + def test_dictsort(self): + tmpl = env.from_string( + '{{ foo|dictsort }}|' + '{{ foo|dictsort(true) }}|' + '{{ foo|dictsort(false, "value") }}' + ) + out = tmpl.render(foo={"aa": 0, "b": 1, "c": 2, "AB": 3}) + assert out == ("[('aa', 0), ('AB', 3), ('b', 1), ('c', 2)]|" + "[('AB', 3), ('aa', 0), ('b', 1), ('c', 2)]|" + "[('aa', 0), ('b', 1), ('c', 2), ('AB', 3)]") + + def test_batch(self): + tmpl = env.from_string("{{ foo|batch(3)|list }}|" + "{{ foo|batch(3, 'X')|list }}") + out = tmpl.render(foo=range(10)) + assert out == ("[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]|" + "[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 'X', 'X']]") + + def test_slice(self): + tmpl = env.from_string('{{ foo|slice(3)|list }}|' + '{{ foo|slice(3, "X")|list }}') + out = tmpl.render(foo=range(10)) + assert out == ("[[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]|" + "[[0, 1, 2, 3], [4, 5, 6, 'X'], [7, 8, 9, 'X']]") + + def test_escape(self): + tmpl = env.from_string('''{{ '<">&'|escape }}''') + out = tmpl.render() + assert out == '<">&' + + def test_striptags(self): + tmpl = env.from_string('''{{ foo|striptags }}''') + out = tmpl.render(foo='

just a small \n ' + 'example link

\n

to a webpage

' + '') + assert out == 'just a small example link to a webpage' + + def test_filesizeformat(self): + tmpl = env.from_string( + '{{ 100|filesizeformat }}|' + '{{ 1000|filesizeformat }}|' + '{{ 1000000|filesizeformat }}|' + '{{ 1000000000|filesizeformat }}|' + '{{ 1000000000000|filesizeformat }}|' + '{{ 100|filesizeformat(true) }}|' + '{{ 1000|filesizeformat(true) }}|' + '{{ 1000000|filesizeformat(true) }}|' + '{{ 1000000000|filesizeformat(true) }}|' + '{{ 1000000000000|filesizeformat(true) }}' + ) + out = tmpl.render() + self.assert_equal(out, ( + '100 Bytes|1.0 kB|1.0 MB|1.0 GB|1.0 TB|100 Bytes|' + '1000 Bytes|976.6 KiB|953.7 MiB|931.3 GiB' + )) + + def test_filesizeformat_issue59(self): + tmpl = env.from_string( + '{{ 300|filesizeformat }}|' + '{{ 3000|filesizeformat }}|' + '{{ 3000000|filesizeformat }}|' + '{{ 3000000000|filesizeformat }}|' + '{{ 3000000000000|filesizeformat }}|' + '{{ 300|filesizeformat(true) }}|' + '{{ 3000|filesizeformat(true) }}|' + '{{ 3000000|filesizeformat(true) }}' + ) + out = tmpl.render() + self.assert_equal(out, ( + '300 Bytes|3.0 kB|3.0 MB|3.0 GB|3.0 TB|300 Bytes|' + '2.9 KiB|2.9 MiB' + )) + + + def test_first(self): + tmpl = env.from_string('{{ foo|first }}') + out = tmpl.render(foo=range(10)) + assert out == '0' + + def test_float(self): + tmpl = env.from_string('{{ "42"|float }}|' + '{{ "ajsghasjgd"|float }}|' + '{{ "32.32"|float }}') + out = tmpl.render() + assert out == '42.0|0.0|32.32' + + def test_format(self): + tmpl = env.from_string('''{{ "%s|%s"|format("a", "b") }}''') + out = tmpl.render() + assert out == 'a|b' + + def test_indent(self): + tmpl = env.from_string('{{ foo|indent(2) }}|{{ foo|indent(2, true) }}') + text = '\n'.join([' '.join(['foo', 'bar'] * 2)] * 2) + out = tmpl.render(foo=text) + assert out == ('foo bar foo bar\n foo bar foo bar| ' + 'foo bar foo bar\n foo bar foo bar') + + def test_int(self): + tmpl = env.from_string('{{ "42"|int }}|{{ "ajsghasjgd"|int }}|' + '{{ "32.32"|int }}') + out = tmpl.render() + 
assert out == '42|0|32' + + def test_join(self): + tmpl = env.from_string('{{ [1, 2, 3]|join("|") }}') + out = tmpl.render() + assert out == '1|2|3' + + env2 = Environment(autoescape=True) + tmpl = env2.from_string('{{ ["", "foo"|safe]|join }}') + assert tmpl.render() == '<foo>foo' + + def test_join_attribute(self): + class User(object): + def __init__(self, username): + self.username = username + tmpl = env.from_string('''{{ users|join(', ', 'username') }}''') + assert tmpl.render(users=map(User, ['foo', 'bar'])) == 'foo, bar' + + def test_last(self): + tmpl = env.from_string('''{{ foo|last }}''') + out = tmpl.render(foo=range(10)) + assert out == '9' + + def test_length(self): + tmpl = env.from_string('''{{ "hello world"|length }}''') + out = tmpl.render() + assert out == '11' + + def test_lower(self): + tmpl = env.from_string('''{{ "FOO"|lower }}''') + out = tmpl.render() + assert out == 'foo' + + def test_pprint(self): + from pprint import pformat + tmpl = env.from_string('''{{ data|pprint }}''') + data = range(1000) + assert tmpl.render(data=data) == pformat(data) + + def test_random(self): + tmpl = env.from_string('''{{ seq|random }}''') + seq = range(100) + for _ in range(10): + assert int(tmpl.render(seq=seq)) in seq + + def test_reverse(self): + tmpl = env.from_string('{{ "foobar"|reverse|join }}|' + '{{ [1, 2, 3]|reverse|list }}') + assert tmpl.render() == 'raboof|[3, 2, 1]' + + def test_string(self): + x = [1, 2, 3, 4, 5] + tmpl = env.from_string('''{{ obj|string }}''') + assert tmpl.render(obj=x) == unicode(x) + + def test_title(self): + tmpl = env.from_string('''{{ "foo bar"|title }}''') + assert tmpl.render() == "Foo Bar" + tmpl = env.from_string('''{{ "foo's bar"|title }}''') + assert tmpl.render() == "Foo's Bar" + tmpl = env.from_string('''{{ "foo bar"|title }}''') + assert tmpl.render() == "Foo Bar" + tmpl = env.from_string('''{{ "f bar f"|title }}''') + assert tmpl.render() == "F Bar F" + tmpl = env.from_string('''{{ "foo-bar"|title }}''') + 
assert tmpl.render() == "Foo-Bar" + tmpl = env.from_string('''{{ "foo\tbar"|title }}''') + assert tmpl.render() == "Foo\tBar" + + def test_truncate(self): + tmpl = env.from_string( + '{{ data|truncate(15, true, ">>>") }}|' + '{{ data|truncate(15, false, ">>>") }}|' + '{{ smalldata|truncate(15) }}' + ) + out = tmpl.render(data='foobar baz bar' * 1000, + smalldata='foobar baz bar') + assert out == 'foobar baz barf>>>|foobar baz >>>|foobar baz bar' + + def test_upper(self): + tmpl = env.from_string('{{ "foo"|upper }}') + assert tmpl.render() == 'FOO' + + def test_urlize(self): + tmpl = env.from_string('{{ "foo http://www.example.com/ bar"|urlize }}') + assert tmpl.render() == 'foo '\ + 'http://www.example.com/ bar' + + def test_wordcount(self): + tmpl = env.from_string('{{ "foo bar baz"|wordcount }}') + assert tmpl.render() == '3' + + def test_block(self): + tmpl = env.from_string('{% filter lower|escape %}{% endfilter %}') + assert tmpl.render() == '<hehe>' + + def test_chaining(self): + tmpl = env.from_string('''{{ ['', '']|first|upper|escape }}''') + assert tmpl.render() == '<FOO>' + + def test_sum(self): + tmpl = env.from_string('''{{ [1, 2, 3, 4, 5, 6]|sum }}''') + assert tmpl.render() == '21' + + def test_sum_attributes(self): + tmpl = env.from_string('''{{ values|sum('value') }}''') + assert tmpl.render(values=[ + {'value': 23}, + {'value': 1}, + {'value': 18}, + ]) == '42' + + def test_sum_attributes_nested(self): + tmpl = env.from_string('''{{ values|sum('real.value') }}''') + assert tmpl.render(values=[ + {'real': {'value': 23}}, + {'real': {'value': 1}}, + {'real': {'value': 18}}, + ]) == '42' + + def test_abs(self): + tmpl = env.from_string('''{{ -1|abs }}|{{ 1|abs }}''') + assert tmpl.render() == '1|1', tmpl.render() + + def test_round_positive(self): + tmpl = env.from_string('{{ 2.7|round }}|{{ 2.1|round }}|' + "{{ 2.1234|round(3, 'floor') }}|" + "{{ 2.1|round(0, 'ceil') }}") + assert tmpl.render() == '3.0|2.0|2.123|3.0', tmpl.render() + + def 
test_round_negative(self): + tmpl = env.from_string('{{ 21.3|round(-1)}}|' + "{{ 21.3|round(-1, 'ceil')}}|" + "{{ 21.3|round(-1, 'floor')}}") + assert tmpl.render() == '20.0|30.0|20.0',tmpl.render() + + def test_xmlattr(self): + tmpl = env.from_string("{{ {'foo': 42, 'bar': 23, 'fish': none, " + "'spam': missing, 'blub:blub': ''}|xmlattr }}") + out = tmpl.render().split() + assert len(out) == 3 + assert 'foo="42"' in out + assert 'bar="23"' in out + assert 'blub:blub="<?>"' in out + + def test_sort1(self): + tmpl = env.from_string('{{ [2, 3, 1]|sort }}|{{ [2, 3, 1]|sort(true) }}') + assert tmpl.render() == '[1, 2, 3]|[3, 2, 1]' + + def test_sort2(self): + tmpl = env.from_string('{{ "".join(["c", "A", "b", "D"]|sort) }}') + assert tmpl.render() == 'AbcD' + + def test_sort3(self): + tmpl = env.from_string('''{{ ['foo', 'Bar', 'blah']|sort }}''') + assert tmpl.render() == "['Bar', 'blah', 'foo']" + + def test_sort4(self): + class Magic(object): + def __init__(self, value): + self.value = value + def __unicode__(self): + return unicode(self.value) + tmpl = env.from_string('''{{ items|sort(attribute='value')|join }}''') + assert tmpl.render(items=map(Magic, [3, 2, 4, 1])) == '1234' + + def test_groupby(self): + tmpl = env.from_string(''' + {%- for grouper, list in [{'foo': 1, 'bar': 2}, + {'foo': 2, 'bar': 3}, + {'foo': 1, 'bar': 1}, + {'foo': 3, 'bar': 4}]|groupby('foo') -%} + {{ grouper }}{% for x in list %}: {{ x.foo }}, {{ x.bar }}{% endfor %}| + {%- endfor %}''') + assert tmpl.render().split('|') == [ + "1: 1, 2: 1, 1", + "2: 2, 3", + "3: 3, 4", + "" + ] + + def test_groupby_tuple_index(self): + tmpl = env.from_string(''' + {%- for grouper, list in [('a', 1), ('a', 2), ('b', 1)]|groupby(0) -%} + {{ grouper }}{% for x in list %}:{{ x.1 }}{% endfor %}| + {%- endfor %}''') + assert tmpl.render() == 'a:1:2|b:1|' + + def test_groupby_multidot(self): + class Date(object): + def __init__(self, day, month, year): + self.day = day + self.month = month + self.year = year + 
class Article(object): + def __init__(self, title, *date): + self.date = Date(*date) + self.title = title + articles = [ + Article('aha', 1, 1, 1970), + Article('interesting', 2, 1, 1970), + Article('really?', 3, 1, 1970), + Article('totally not', 1, 1, 1971) + ] + tmpl = env.from_string(''' + {%- for year, list in articles|groupby('date.year') -%} + {{ year }}{% for x in list %}[{{ x.title }}]{% endfor %}| + {%- endfor %}''') + assert tmpl.render(articles=articles).split('|') == [ + '1970[aha][interesting][really?]', + '1971[totally not]', + '' + ] + + def test_filtertag(self): + tmpl = env.from_string("{% filter upper|replace('FOO', 'foo') %}" + "foobar{% endfilter %}") + assert tmpl.render() == 'fooBAR' + + def test_replace(self): + env = Environment() + tmpl = env.from_string('{{ string|replace("o", 42) }}') + assert tmpl.render(string='') == '' + env = Environment(autoescape=True) + tmpl = env.from_string('{{ string|replace("o", 42) }}') + assert tmpl.render(string='') == '<f4242>' + tmpl = env.from_string('{{ string|replace("<", 42) }}') + assert tmpl.render(string='') == '42foo>' + tmpl = env.from_string('{{ string|replace("o", ">x<") }}') + assert tmpl.render(string=Markup('foo')) == 'f>x<>x<' + + def test_forceescape(self): + tmpl = env.from_string('{{ x|forceescape }}') + assert tmpl.render(x=Markup('
')) == u'<div />' + + def test_safe(self): + env = Environment(autoescape=True) + tmpl = env.from_string('{{ "
foo
"|safe }}') + assert tmpl.render() == '
foo
' + tmpl = env.from_string('{{ "
foo
" }}') + assert tmpl.render() == '<div>foo</div>' + + def test_urlencode(self): + env = Environment(autoescape=True) + tmpl = env.from_string('{{ "Hello, world!"|urlencode }}') + assert tmpl.render() == 'Hello%2C%20world%21' + tmpl = env.from_string('{{ o|urlencode }}') + assert tmpl.render(o=u"Hello, world\u203d") == "Hello%2C%20world%E2%80%BD" + assert tmpl.render(o=(("f", 1),)) == "f=1" + assert tmpl.render(o=(('f', 1), ("z", 2))) == "f=1&z=2" + assert tmpl.render(o=((u"\u203d", 1),)) == "%E2%80%BD=1" + assert tmpl.render(o={u"\u203d": 1}) == "%E2%80%BD=1" + assert tmpl.render(o={0: 1}) == "0=1" + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(FilterTestCase)) + return suite diff --git a/libs/jinja2/testsuite/imports.py b/libs/jinja2/testsuite/imports.py new file mode 100755 index 0000000..1cb12cb --- /dev/null +++ b/libs/jinja2/testsuite/imports.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.imports + ~~~~~~~~~~~~~~~~~~~~~~~~ + + Tests the import features (with includes). + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. 
+""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment, DictLoader +from jinja2.exceptions import TemplateNotFound, TemplatesNotFound + + +test_env = Environment(loader=DictLoader(dict( + module='{% macro test() %}[{{ foo }}|{{ bar }}]{% endmacro %}', + header='[{{ foo }}|{{ 23 }}]', + o_printer='({{ o }})' +))) +test_env.globals['bar'] = 23 + + +class ImportsTestCase(JinjaTestCase): + + def test_context_imports(self): + t = test_env.from_string('{% import "module" as m %}{{ m.test() }}') + assert t.render(foo=42) == '[|23]' + t = test_env.from_string('{% import "module" as m without context %}{{ m.test() }}') + assert t.render(foo=42) == '[|23]' + t = test_env.from_string('{% import "module" as m with context %}{{ m.test() }}') + assert t.render(foo=42) == '[42|23]' + t = test_env.from_string('{% from "module" import test %}{{ test() }}') + assert t.render(foo=42) == '[|23]' + t = test_env.from_string('{% from "module" import test without context %}{{ test() }}') + assert t.render(foo=42) == '[|23]' + t = test_env.from_string('{% from "module" import test with context %}{{ test() }}') + assert t.render(foo=42) == '[42|23]' + + def test_trailing_comma(self): + test_env.from_string('{% from "foo" import bar, baz with context %}') + test_env.from_string('{% from "foo" import bar, baz, with context %}') + test_env.from_string('{% from "foo" import bar, with context %}') + test_env.from_string('{% from "foo" import bar, with, context %}') + test_env.from_string('{% from "foo" import bar, with with context %}') + + def test_exports(self): + m = test_env.from_string(''' + {% macro toplevel() %}...{% endmacro %} + {% macro __private() %}...{% endmacro %} + {% set variable = 42 %} + {% for item in [1] %} + {% macro notthere() %}{% endmacro %} + {% endfor %} + ''').module + assert m.toplevel() == '...' 
+ assert not hasattr(m, '__missing') + assert m.variable == 42 + assert not hasattr(m, 'notthere') + + +class IncludesTestCase(JinjaTestCase): + + def test_context_include(self): + t = test_env.from_string('{% include "header" %}') + assert t.render(foo=42) == '[42|23]' + t = test_env.from_string('{% include "header" with context %}') + assert t.render(foo=42) == '[42|23]' + t = test_env.from_string('{% include "header" without context %}') + assert t.render(foo=42) == '[|23]' + + def test_choice_includes(self): + t = test_env.from_string('{% include ["missing", "header"] %}') + assert t.render(foo=42) == '[42|23]' + + t = test_env.from_string('{% include ["missing", "missing2"] ignore missing %}') + assert t.render(foo=42) == '' + + t = test_env.from_string('{% include ["missing", "missing2"] %}') + self.assert_raises(TemplateNotFound, t.render) + try: + t.render() + except TemplatesNotFound, e: + assert e.templates == ['missing', 'missing2'] + assert e.name == 'missing2' + else: + assert False, 'thou shalt raise' + + def test_includes(t, **ctx): + ctx['foo'] = 42 + assert t.render(ctx) == '[42|23]' + + t = test_env.from_string('{% include ["missing", "header"] %}') + test_includes(t) + t = test_env.from_string('{% include x %}') + test_includes(t, x=['missing', 'header']) + t = test_env.from_string('{% include [x, "header"] %}') + test_includes(t, x='missing') + t = test_env.from_string('{% include x %}') + test_includes(t, x='header') + t = test_env.from_string('{% include x %}') + test_includes(t, x='header') + t = test_env.from_string('{% include [x] %}') + test_includes(t, x='header') + + def test_include_ignoring_missing(self): + t = test_env.from_string('{% include "missing" %}') + self.assert_raises(TemplateNotFound, t.render) + for extra in '', 'with context', 'without context': + t = test_env.from_string('{% include "missing" ignore missing ' + + extra + ' %}') + assert t.render() == '' + + def test_context_include_with_overrides(self): + env = 
Environment(loader=DictLoader(dict( + main="{% for item in [1, 2, 3] %}{% include 'item' %}{% endfor %}", + item="{{ item }}" + ))) + assert env.get_template("main").render() == "123" + + def test_unoptimized_scopes(self): + t = test_env.from_string(""" + {% macro outer(o) %} + {% macro inner() %} + {% include "o_printer" %} + {% endmacro %} + {{ inner() }} + {% endmacro %} + {{ outer("FOO") }} + """) + assert t.render().strip() == '(FOO)' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(ImportsTestCase)) + suite.addTest(unittest.makeSuite(IncludesTestCase)) + return suite diff --git a/libs/jinja2/testsuite/inheritance.py b/libs/jinja2/testsuite/inheritance.py new file mode 100755 index 0000000..355aa0c --- /dev/null +++ b/libs/jinja2/testsuite/inheritance.py @@ -0,0 +1,227 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.inheritance + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Tests the template inheritance feature. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. 
+""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment, DictLoader + + +LAYOUTTEMPLATE = '''\ +|{% block block1 %}block 1 from layout{% endblock %} +|{% block block2 %}block 2 from layout{% endblock %} +|{% block block3 %} +{% block block4 %}nested block 4 from layout{% endblock %} +{% endblock %}|''' + +LEVEL1TEMPLATE = '''\ +{% extends "layout" %} +{% block block1 %}block 1 from level1{% endblock %}''' + +LEVEL2TEMPLATE = '''\ +{% extends "level1" %} +{% block block2 %}{% block block5 %}nested block 5 from level2{% +endblock %}{% endblock %}''' + +LEVEL3TEMPLATE = '''\ +{% extends "level2" %} +{% block block5 %}block 5 from level3{% endblock %} +{% block block4 %}block 4 from level3{% endblock %} +''' + +LEVEL4TEMPLATE = '''\ +{% extends "level3" %} +{% block block3 %}block 3 from level4{% endblock %} +''' + +WORKINGTEMPLATE = '''\ +{% extends "layout" %} +{% block block1 %} + {% if false %} + {% block block2 %} + this should workd + {% endblock %} + {% endif %} +{% endblock %} +''' + +env = Environment(loader=DictLoader({ + 'layout': LAYOUTTEMPLATE, + 'level1': LEVEL1TEMPLATE, + 'level2': LEVEL2TEMPLATE, + 'level3': LEVEL3TEMPLATE, + 'level4': LEVEL4TEMPLATE, + 'working': WORKINGTEMPLATE +}), trim_blocks=True) + + +class InheritanceTestCase(JinjaTestCase): + + def test_layout(self): + tmpl = env.get_template('layout') + assert tmpl.render() == ('|block 1 from layout|block 2 from ' + 'layout|nested block 4 from layout|') + + def test_level1(self): + tmpl = env.get_template('level1') + assert tmpl.render() == ('|block 1 from level1|block 2 from ' + 'layout|nested block 4 from layout|') + + def test_level2(self): + tmpl = env.get_template('level2') + assert tmpl.render() == ('|block 1 from level1|nested block 5 from ' + 'level2|nested block 4 from layout|') + + def test_level3(self): + tmpl = env.get_template('level3') + assert tmpl.render() == ('|block 1 from level1|block 5 from level3|' + 'block 4 from level3|') + + 
def test_level4(sel): + tmpl = env.get_template('level4') + assert tmpl.render() == ('|block 1 from level1|block 5 from ' + 'level3|block 3 from level4|') + + def test_super(self): + env = Environment(loader=DictLoader({ + 'a': '{% block intro %}INTRO{% endblock %}|' + 'BEFORE|{% block data %}INNER{% endblock %}|AFTER', + 'b': '{% extends "a" %}{% block data %}({{ ' + 'super() }}){% endblock %}', + 'c': '{% extends "b" %}{% block intro %}--{{ ' + 'super() }}--{% endblock %}\n{% block data ' + '%}[{{ super() }}]{% endblock %}' + })) + tmpl = env.get_template('c') + assert tmpl.render() == '--INTRO--|BEFORE|[(INNER)]|AFTER' + + def test_working(self): + tmpl = env.get_template('working') + + def test_reuse_blocks(self): + tmpl = env.from_string('{{ self.foo() }}|{% block foo %}42' + '{% endblock %}|{{ self.foo() }}') + assert tmpl.render() == '42|42|42' + + def test_preserve_blocks(self): + env = Environment(loader=DictLoader({ + 'a': '{% if false %}{% block x %}A{% endblock %}{% endif %}{{ self.x() }}', + 'b': '{% extends "a" %}{% block x %}B{{ super() }}{% endblock %}' + })) + tmpl = env.get_template('b') + assert tmpl.render() == 'BA' + + def test_dynamic_inheritance(self): + env = Environment(loader=DictLoader({ + 'master1': 'MASTER1{% block x %}{% endblock %}', + 'master2': 'MASTER2{% block x %}{% endblock %}', + 'child': '{% extends master %}{% block x %}CHILD{% endblock %}' + })) + tmpl = env.get_template('child') + for m in range(1, 3): + assert tmpl.render(master='master%d' % m) == 'MASTER%dCHILD' % m + + def test_multi_inheritance(self): + env = Environment(loader=DictLoader({ + 'master1': 'MASTER1{% block x %}{% endblock %}', + 'master2': 'MASTER2{% block x %}{% endblock %}', + 'child': '''{% if master %}{% extends master %}{% else %}{% extends + 'master1' %}{% endif %}{% block x %}CHILD{% endblock %}''' + })) + tmpl = env.get_template('child') + assert tmpl.render(master='master2') == 'MASTER2CHILD' + assert tmpl.render(master='master1') == 'MASTER1CHILD' 
+ assert tmpl.render() == 'MASTER1CHILD' + + def test_scoped_block(self): + env = Environment(loader=DictLoader({ + 'master.html': '{% for item in seq %}[{% block item scoped %}' + '{% endblock %}]{% endfor %}' + })) + t = env.from_string('{% extends "master.html" %}{% block item %}' + '{{ item }}{% endblock %}') + assert t.render(seq=range(5)) == '[0][1][2][3][4]' + + def test_super_in_scoped_block(self): + env = Environment(loader=DictLoader({ + 'master.html': '{% for item in seq %}[{% block item scoped %}' + '{{ item }}{% endblock %}]{% endfor %}' + })) + t = env.from_string('{% extends "master.html" %}{% block item %}' + '{{ super() }}|{{ item * 2 }}{% endblock %}') + assert t.render(seq=range(5)) == '[0|0][1|2][2|4][3|6][4|8]' + + def test_scoped_block_after_inheritance(self): + env = Environment(loader=DictLoader({ + 'layout.html': ''' + {% block useless %}{% endblock %} + ''', + 'index.html': ''' + {%- extends 'layout.html' %} + {% from 'helpers.html' import foo with context %} + {% block useless %} + {% for x in [1, 2, 3] %} + {% block testing scoped %} + {{ foo(x) }} + {% endblock %} + {% endfor %} + {% endblock %} + ''', + 'helpers.html': ''' + {% macro foo(x) %}{{ the_foo + x }}{% endmacro %} + ''' + })) + rv = env.get_template('index.html').render(the_foo=42).split() + assert rv == ['43', '44', '45'] + + +class BugFixTestCase(JinjaTestCase): + + def test_fixed_macro_scoping_bug(self): + assert Environment(loader=DictLoader({ + 'test.html': '''\ + {% extends 'details.html' %} + + {% macro my_macro() %} + my_macro + {% endmacro %} + + {% block inner_box %} + {{ my_macro() }} + {% endblock %} + ''', + 'details.html': '''\ + {% extends 'standard.html' %} + + {% macro my_macro() %} + my_macro + {% endmacro %} + + {% block content %} + {% block outer_box %} + outer_box + {% block inner_box %} + inner_box + {% endblock %} + {% endblock %} + {% endblock %} + ''', + 'standard.html': ''' + {% block content %} {% endblock %} + ''' + 
})).get_template("test.html").render().split() == [u'outer_box', u'my_macro'] + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(InheritanceTestCase)) + suite.addTest(unittest.makeSuite(BugFixTestCase)) + return suite diff --git a/libs/jinja2/testsuite/lexnparse.py b/libs/jinja2/testsuite/lexnparse.py new file mode 100755 index 0000000..77b76ec --- /dev/null +++ b/libs/jinja2/testsuite/lexnparse.py @@ -0,0 +1,387 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.lexnparse + ~~~~~~~~~~~~~~~~~~~~~~~~~~ + + All the unittests regarding lexing, parsing and syntax. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import sys +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment, Template, TemplateSyntaxError, \ + UndefinedError, nodes + +env = Environment() + + +# how does a string look like in jinja syntax? +if sys.version_info < (3, 0): + def jinja_string_repr(string): + return repr(string)[1:] +else: + jinja_string_repr = repr + + +class LexerTestCase(JinjaTestCase): + + def test_raw1(self): + tmpl = env.from_string('{% raw %}foo{% endraw %}|' + '{%raw%}{{ bar }}|{% baz %}{% endraw %}') + assert tmpl.render() == 'foo|{{ bar }}|{% baz %}' + + def test_raw2(self): + tmpl = env.from_string('1 {%- raw -%} 2 {%- endraw -%} 3') + assert tmpl.render() == '123' + + def test_balancing(self): + env = Environment('{%', '%}', '${', '}') + tmpl = env.from_string('''{% for item in seq + %}${{'foo': item}|upper}{% endfor %}''') + assert tmpl.render(seq=range(3)) == "{'FOO': 0}{'FOO': 1}{'FOO': 2}" + + def test_comments(self): + env = Environment('', '{', '}') + tmpl = env.from_string('''\ +
    + +
  • {item}
  • + +
''') + assert tmpl.render(seq=range(3)) == ("
    \n
  • 0
  • \n " + "
  • 1
  • \n
  • 2
  • \n
") + + def test_string_escapes(self): + for char in u'\0', u'\u2668', u'\xe4', u'\t', u'\r', u'\n': + tmpl = env.from_string('{{ %s }}' % jinja_string_repr(char)) + assert tmpl.render() == char + assert env.from_string('{{ "\N{HOT SPRINGS}" }}').render() == u'\u2668' + + def test_bytefallback(self): + from pprint import pformat + tmpl = env.from_string(u'''{{ 'foo'|pprint }}|{{ 'bär'|pprint }}''') + assert tmpl.render() == pformat('foo') + '|' + pformat(u'bär') + + def test_operators(self): + from jinja2.lexer import operators + for test, expect in operators.iteritems(): + if test in '([{}])': + continue + stream = env.lexer.tokenize('{{ %s }}' % test) + stream.next() + assert stream.current.type == expect + + def test_normalizing(self): + for seq in '\r', '\r\n', '\n': + env = Environment(newline_sequence=seq) + tmpl = env.from_string('1\n2\r\n3\n4\n') + result = tmpl.render() + assert result.replace(seq, 'X') == '1X2X3X4' + + +class ParserTestCase(JinjaTestCase): + + def test_php_syntax(self): + env = Environment('', '', '') + tmpl = env.from_string('''\ +\ + + +''') + assert tmpl.render(seq=range(5)) == '01234' + + def test_erb_syntax(self): + env = Environment('<%', '%>', '<%=', '%>', '<%#', '%>') + tmpl = env.from_string('''\ +<%# I'm a comment, I'm not interesting %>\ +<% for item in seq -%> + <%= item %> +<%- endfor %>''') + assert tmpl.render(seq=range(5)) == '01234' + + def test_comment_syntax(self): + env = Environment('', '${', '}', '') + tmpl = env.from_string('''\ +\ + + ${item} +''') + assert tmpl.render(seq=range(5)) == '01234' + + def test_balancing(self): + tmpl = env.from_string('''{{{'foo':'bar'}.foo}}''') + assert tmpl.render() == 'bar' + + def test_start_comment(self): + tmpl = env.from_string('''{# foo comment +and bar comment #} +{% macro blub() %}foo{% endmacro %} +{{ blub() }}''') + assert tmpl.render().strip() == 'foo' + + def test_line_syntax(self): + env = Environment('<%', '%>', '${', '}', '<%#', '%>', '%') + tmpl = env.from_string('''\ 
+<%# regular comment %> +% for item in seq: + ${item} +% endfor''') + assert [int(x.strip()) for x in tmpl.render(seq=range(5)).split()] == \ + range(5) + + env = Environment('<%', '%>', '${', '}', '<%#', '%>', '%', '##') + tmpl = env.from_string('''\ +<%# regular comment %> +% for item in seq: + ${item} ## the rest of the stuff +% endfor''') + assert [int(x.strip()) for x in tmpl.render(seq=range(5)).split()] == \ + range(5) + + def test_line_syntax_priority(self): + # XXX: why is the whitespace there in front of the newline? + env = Environment('{%', '%}', '${', '}', '/*', '*/', '##', '#') + tmpl = env.from_string('''\ +/* ignore me. + I'm a multiline comment */ +## for item in seq: +* ${item} # this is just extra stuff +## endfor''') + assert tmpl.render(seq=[1, 2]).strip() == '* 1\n* 2' + env = Environment('{%', '%}', '${', '}', '/*', '*/', '#', '##') + tmpl = env.from_string('''\ +/* ignore me. + I'm a multiline comment */ +# for item in seq: +* ${item} ## this is just extra stuff + ## extra stuff i just want to ignore +# endfor''') + assert tmpl.render(seq=[1, 2]).strip() == '* 1\n\n* 2' + + def test_error_messages(self): + def assert_error(code, expected): + try: + Template(code) + except TemplateSyntaxError, e: + assert str(e) == expected, 'unexpected error message' + else: + assert False, 'that was supposed to be an error' + + assert_error('{% for item in seq %}...{% endif %}', + "Encountered unknown tag 'endif'. Jinja was looking " + "for the following tags: 'endfor' or 'else'. The " + "innermost block that needs to be closed is 'for'.") + assert_error('{% if foo %}{% for item in seq %}...{% endfor %}{% endfor %}', + "Encountered unknown tag 'endfor'. Jinja was looking for " + "the following tags: 'elif' or 'else' or 'endif'. The " + "innermost block that needs to be closed is 'if'.") + assert_error('{% if foo %}', + "Unexpected end of template. Jinja was looking for the " + "following tags: 'elif' or 'else' or 'endif'. 
The " + "innermost block that needs to be closed is 'if'.") + assert_error('{% for item in seq %}', + "Unexpected end of template. Jinja was looking for the " + "following tags: 'endfor' or 'else'. The innermost block " + "that needs to be closed is 'for'.") + assert_error('{% block foo-bar-baz %}', + "Block names in Jinja have to be valid Python identifiers " + "and may not contain hyphens, use an underscore instead.") + assert_error('{% unknown_tag %}', + "Encountered unknown tag 'unknown_tag'.") + + +class SyntaxTestCase(JinjaTestCase): + + def test_call(self): + env = Environment() + env.globals['foo'] = lambda a, b, c, e, g: a + b + c + e + g + tmpl = env.from_string("{{ foo('a', c='d', e='f', *['b'], **{'g': 'h'}) }}") + assert tmpl.render() == 'abdfh' + + def test_slicing(self): + tmpl = env.from_string('{{ [1, 2, 3][:] }}|{{ [1, 2, 3][::-1] }}') + assert tmpl.render() == '[1, 2, 3]|[3, 2, 1]' + + def test_attr(self): + tmpl = env.from_string("{{ foo.bar }}|{{ foo['bar'] }}") + assert tmpl.render(foo={'bar': 42}) == '42|42' + + def test_subscript(self): + tmpl = env.from_string("{{ foo[0] }}|{{ foo[-1] }}") + assert tmpl.render(foo=[0, 1, 2]) == '0|2' + + def test_tuple(self): + tmpl = env.from_string('{{ () }}|{{ (1,) }}|{{ (1, 2) }}') + assert tmpl.render() == '()|(1,)|(1, 2)' + + def test_math(self): + tmpl = env.from_string('{{ (1 + 1 * 2) - 3 / 2 }}|{{ 2**3 }}') + assert tmpl.render() == '1.5|8' + + def test_div(self): + tmpl = env.from_string('{{ 3 // 2 }}|{{ 3 / 2 }}|{{ 3 % 2 }}') + assert tmpl.render() == '1|1.5|1' + + def test_unary(self): + tmpl = env.from_string('{{ +3 }}|{{ -3 }}') + assert tmpl.render() == '3|-3' + + def test_concat(self): + tmpl = env.from_string("{{ [1, 2] ~ 'foo' }}") + assert tmpl.render() == '[1, 2]foo' + + def test_compare(self): + tmpl = env.from_string('{{ 1 > 0 }}|{{ 1 >= 1 }}|{{ 2 < 3 }}|' + '{{ 2 == 2 }}|{{ 1 <= 1 }}') + assert tmpl.render() == 'True|True|True|True|True' + + def test_inop(self): + tmpl = 
env.from_string('{{ 1 in [1, 2, 3] }}|{{ 1 not in [1, 2, 3] }}') + assert tmpl.render() == 'True|False' + + def test_literals(self): + tmpl = env.from_string('{{ [] }}|{{ {} }}|{{ () }}') + assert tmpl.render().lower() == '[]|{}|()' + + def test_bool(self): + tmpl = env.from_string('{{ true and false }}|{{ false ' + 'or true }}|{{ not false }}') + assert tmpl.render() == 'False|True|True' + + def test_grouping(self): + tmpl = env.from_string('{{ (true and false) or (false and true) and not false }}') + assert tmpl.render() == 'False' + + def test_django_attr(self): + tmpl = env.from_string('{{ [1, 2, 3].0 }}|{{ [[1]].0.0 }}') + assert tmpl.render() == '1|1' + + def test_conditional_expression(self): + tmpl = env.from_string('''{{ 0 if true else 1 }}''') + assert tmpl.render() == '0' + + def test_short_conditional_expression(self): + tmpl = env.from_string('<{{ 1 if false }}>') + assert tmpl.render() == '<>' + + tmpl = env.from_string('<{{ (1 if false).bar }}>') + self.assert_raises(UndefinedError, tmpl.render) + + def test_filter_priority(self): + tmpl = env.from_string('{{ "foo"|upper + "bar"|upper }}') + assert tmpl.render() == 'FOOBAR' + + def test_function_calls(self): + tests = [ + (True, '*foo, bar'), + (True, '*foo, *bar'), + (True, '*foo, bar=42'), + (True, '**foo, *bar'), + (True, '**foo, bar'), + (False, 'foo, bar'), + (False, 'foo, bar=42'), + (False, 'foo, bar=23, *args'), + (False, 'a, b=c, *d, **e'), + (False, '*foo, **bar') + ] + for should_fail, sig in tests: + if should_fail: + self.assert_raises(TemplateSyntaxError, + env.from_string, '{{ foo(%s) }}' % sig) + else: + env.from_string('foo(%s)' % sig) + + def test_tuple_expr(self): + for tmpl in [ + '{{ () }}', + '{{ (1, 2) }}', + '{{ (1, 2,) }}', + '{{ 1, }}', + '{{ 1, 2 }}', + '{% for foo, bar in seq %}...{% endfor %}', + '{% for x in foo, bar %}...{% endfor %}', + '{% for x in foo, %}...{% endfor %}' + ]: + assert env.from_string(tmpl) + + def test_trailing_comma(self): + tmpl = 
env.from_string('{{ (1, 2,) }}|{{ [1, 2,] }}|{{ {1: 2,} }}') + assert tmpl.render().lower() == '(1, 2)|[1, 2]|{1: 2}' + + def test_block_end_name(self): + env.from_string('{% block foo %}...{% endblock foo %}') + self.assert_raises(TemplateSyntaxError, env.from_string, + '{% block x %}{% endblock y %}') + + def test_constant_casing(self): + for const in True, False, None: + tmpl = env.from_string('{{ %s }}|{{ %s }}|{{ %s }}' % ( + str(const), str(const).lower(), str(const).upper() + )) + assert tmpl.render() == '%s|%s|' % (const, const) + + def test_test_chaining(self): + self.assert_raises(TemplateSyntaxError, env.from_string, + '{{ foo is string is sequence }}') + assert env.from_string('{{ 42 is string or 42 is number }}' + ).render() == 'True' + + def test_string_concatenation(self): + tmpl = env.from_string('{{ "foo" "bar" "baz" }}') + assert tmpl.render() == 'foobarbaz' + + def test_notin(self): + bar = xrange(100) + tmpl = env.from_string('''{{ not 42 in bar }}''') + assert tmpl.render(bar=bar) == unicode(not 42 in bar) + + def test_implicit_subscribed_tuple(self): + class Foo(object): + def __getitem__(self, x): + return x + t = env.from_string('{{ foo[1, 2] }}') + assert t.render(foo=Foo()) == u'(1, 2)' + + def test_raw2(self): + tmpl = env.from_string('{% raw %}{{ FOO }} and {% BAR %}{% endraw %}') + assert tmpl.render() == '{{ FOO }} and {% BAR %}' + + def test_const(self): + tmpl = env.from_string('{{ true }}|{{ false }}|{{ none }}|' + '{{ none is defined }}|{{ missing is defined }}') + assert tmpl.render() == 'True|False|None|True|False' + + def test_neg_filter_priority(self): + node = env.parse('{{ -1|foo }}') + assert isinstance(node.body[0].nodes[0], nodes.Filter) + assert isinstance(node.body[0].nodes[0].node, nodes.Neg) + + def test_const_assign(self): + constass1 = '''{% set true = 42 %}''' + constass2 = '''{% for none in seq %}{% endfor %}''' + for tmpl in constass1, constass2: + self.assert_raises(TemplateSyntaxError, env.from_string, tmpl) + + 
def test_localset(self): + tmpl = env.from_string('''{% set foo = 0 %}\ +{% for item in [1, 2] %}{% set foo = 1 %}{% endfor %}\ +{{ foo }}''') + assert tmpl.render() == '0' + + def test_parse_unary(self): + tmpl = env.from_string('{{ -foo["bar"] }}') + assert tmpl.render(foo={'bar': 42}) == '-42' + tmpl = env.from_string('{{ -foo["bar"]|abs }}') + assert tmpl.render(foo={'bar': 42}) == '42' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(LexerTestCase)) + suite.addTest(unittest.makeSuite(ParserTestCase)) + suite.addTest(unittest.makeSuite(SyntaxTestCase)) + return suite diff --git a/libs/jinja2/testsuite/loader.py b/libs/jinja2/testsuite/loader.py new file mode 100755 index 0000000..f62ec92 --- /dev/null +++ b/libs/jinja2/testsuite/loader.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.loader + ~~~~~~~~~~~~~~~~~~~~~~~ + + Test the loaders. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import os +import sys +import tempfile +import shutil +import unittest + +from jinja2.testsuite import JinjaTestCase, dict_loader, \ + package_loader, filesystem_loader, function_loader, \ + choice_loader, prefix_loader + +from jinja2 import Environment, loaders +from jinja2.loaders import split_template_path +from jinja2.exceptions import TemplateNotFound + + +class LoaderTestCase(JinjaTestCase): + + def test_dict_loader(self): + env = Environment(loader=dict_loader) + tmpl = env.get_template('justdict.html') + assert tmpl.render().strip() == 'FOO' + self.assert_raises(TemplateNotFound, env.get_template, 'missing.html') + + def test_package_loader(self): + env = Environment(loader=package_loader) + tmpl = env.get_template('test.html') + assert tmpl.render().strip() == 'BAR' + self.assert_raises(TemplateNotFound, env.get_template, 'missing.html') + + def test_filesystem_loader(self): + env = Environment(loader=filesystem_loader) + tmpl = env.get_template('test.html') + assert 
tmpl.render().strip() == 'BAR' + tmpl = env.get_template('foo/test.html') + assert tmpl.render().strip() == 'FOO' + self.assert_raises(TemplateNotFound, env.get_template, 'missing.html') + + def test_choice_loader(self): + env = Environment(loader=choice_loader) + tmpl = env.get_template('justdict.html') + assert tmpl.render().strip() == 'FOO' + tmpl = env.get_template('test.html') + assert tmpl.render().strip() == 'BAR' + self.assert_raises(TemplateNotFound, env.get_template, 'missing.html') + + def test_function_loader(self): + env = Environment(loader=function_loader) + tmpl = env.get_template('justfunction.html') + assert tmpl.render().strip() == 'FOO' + self.assert_raises(TemplateNotFound, env.get_template, 'missing.html') + + def test_prefix_loader(self): + env = Environment(loader=prefix_loader) + tmpl = env.get_template('a/test.html') + assert tmpl.render().strip() == 'BAR' + tmpl = env.get_template('b/justdict.html') + assert tmpl.render().strip() == 'FOO' + self.assert_raises(TemplateNotFound, env.get_template, 'missing') + + def test_caching(self): + changed = False + class TestLoader(loaders.BaseLoader): + def get_source(self, environment, template): + return u'foo', None, lambda: not changed + env = Environment(loader=TestLoader(), cache_size=-1) + tmpl = env.get_template('template') + assert tmpl is env.get_template('template') + changed = True + assert tmpl is not env.get_template('template') + changed = False + + env = Environment(loader=TestLoader(), cache_size=0) + assert env.get_template('template') \ + is not env.get_template('template') + + env = Environment(loader=TestLoader(), cache_size=2) + t1 = env.get_template('one') + t2 = env.get_template('two') + assert t2 is env.get_template('two') + assert t1 is env.get_template('one') + t3 = env.get_template('three') + assert 'one' in env.cache + assert 'two' not in env.cache + assert 'three' in env.cache + + def test_split_template_path(self): + assert split_template_path('foo/bar') == ['foo', 
'bar'] + assert split_template_path('./foo/bar') == ['foo', 'bar'] + self.assert_raises(TemplateNotFound, split_template_path, '../foo') + + +class ModuleLoaderTestCase(JinjaTestCase): + archive = None + + def compile_down(self, zip='deflated', py_compile=False): + super(ModuleLoaderTestCase, self).setup() + log = [] + self.reg_env = Environment(loader=prefix_loader) + if zip is not None: + self.archive = tempfile.mkstemp(suffix='.zip')[1] + else: + self.archive = tempfile.mkdtemp() + self.reg_env.compile_templates(self.archive, zip=zip, + log_function=log.append, + py_compile=py_compile) + self.mod_env = Environment(loader=loaders.ModuleLoader(self.archive)) + return ''.join(log) + + def teardown(self): + super(ModuleLoaderTestCase, self).teardown() + if hasattr(self, 'mod_env'): + if os.path.isfile(self.archive): + os.remove(self.archive) + else: + shutil.rmtree(self.archive) + self.archive = None + + def test_log(self): + log = self.compile_down() + assert 'Compiled "a/foo/test.html" as ' \ + 'tmpl_a790caf9d669e39ea4d280d597ec891c4ef0404a' in log + assert 'Finished compiling templates' in log + assert 'Could not compile "a/syntaxerror.html": ' \ + 'Encountered unknown tag \'endif\'' in log + + def _test_common(self): + tmpl1 = self.reg_env.get_template('a/test.html') + tmpl2 = self.mod_env.get_template('a/test.html') + assert tmpl1.render() == tmpl2.render() + + tmpl1 = self.reg_env.get_template('b/justdict.html') + tmpl2 = self.mod_env.get_template('b/justdict.html') + assert tmpl1.render() == tmpl2.render() + + def test_deflated_zip_compile(self): + self.compile_down(zip='deflated') + self._test_common() + + def test_stored_zip_compile(self): + self.compile_down(zip='stored') + self._test_common() + + def test_filesystem_compile(self): + self.compile_down(zip=None) + self._test_common() + + def test_weak_references(self): + self.compile_down() + tmpl = self.mod_env.get_template('a/test.html') + key = loaders.ModuleLoader.get_template_key('a/test.html') + name 
= self.mod_env.loader.module.__name__ + + assert hasattr(self.mod_env.loader.module, key) + assert name in sys.modules + + # unset all, ensure the module is gone from sys.modules + self.mod_env = tmpl = None + + try: + import gc + gc.collect() + except: + pass + + assert name not in sys.modules + + def test_byte_compilation(self): + log = self.compile_down(py_compile=True) + assert 'Byte-compiled "a/test.html"' in log + tmpl1 = self.mod_env.get_template('a/test.html') + mod = self.mod_env.loader.module. \ + tmpl_3c4ddf650c1a73df961a6d3d2ce2752f1b8fd490 + assert mod.__file__.endswith('.pyc') + + def test_choice_loader(self): + log = self.compile_down(py_compile=True) + assert 'Byte-compiled "a/test.html"' in log + + self.mod_env.loader = loaders.ChoiceLoader([ + self.mod_env.loader, + loaders.DictLoader({'DICT_SOURCE': 'DICT_TEMPLATE'}) + ]) + + tmpl1 = self.mod_env.get_template('a/test.html') + self.assert_equal(tmpl1.render(), 'BAR') + tmpl2 = self.mod_env.get_template('DICT_SOURCE') + self.assert_equal(tmpl2.render(), 'DICT_TEMPLATE') + + def test_prefix_loader(self): + log = self.compile_down(py_compile=True) + assert 'Byte-compiled "a/test.html"' in log + + self.mod_env.loader = loaders.PrefixLoader({ + 'MOD': self.mod_env.loader, + 'DICT': loaders.DictLoader({'test.html': 'DICT_TEMPLATE'}) + }) + + tmpl1 = self.mod_env.get_template('MOD/a/test.html') + self.assert_equal(tmpl1.render(), 'BAR') + tmpl2 = self.mod_env.get_template('DICT/test.html') + self.assert_equal(tmpl2.render(), 'DICT_TEMPLATE') + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(LoaderTestCase)) + suite.addTest(unittest.makeSuite(ModuleLoaderTestCase)) + return suite diff --git a/libs/jinja2/testsuite/regression.py b/libs/jinja2/testsuite/regression.py new file mode 100755 index 0000000..4db9076 --- /dev/null +++ b/libs/jinja2/testsuite/regression.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.regression + ~~~~~~~~~~~~~~~~~~~~~~~~~~~ + 
+ Tests corner cases and bugs. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Template, Environment, DictLoader, TemplateSyntaxError, \ + TemplateNotFound, PrefixLoader + +env = Environment() + + +class CornerTestCase(JinjaTestCase): + + def test_assigned_scoping(self): + t = env.from_string(''' + {%- for item in (1, 2, 3, 4) -%} + [{{ item }}] + {%- endfor %} + {{- item -}} + ''') + assert t.render(item=42) == '[1][2][3][4]42' + + t = env.from_string(''' + {%- for item in (1, 2, 3, 4) -%} + [{{ item }}] + {%- endfor %} + {%- set item = 42 %} + {{- item -}} + ''') + assert t.render() == '[1][2][3][4]42' + + t = env.from_string(''' + {%- set item = 42 %} + {%- for item in (1, 2, 3, 4) -%} + [{{ item }}] + {%- endfor %} + {{- item -}} + ''') + assert t.render() == '[1][2][3][4]42' + + def test_closure_scoping(self): + t = env.from_string(''' + {%- set wrapper = "" %} + {%- for item in (1, 2, 3, 4) %} + {%- macro wrapper() %}[{{ item }}]{% endmacro %} + {{- wrapper() }} + {%- endfor %} + {{- wrapper -}} + ''') + assert t.render() == '[1][2][3][4]' + + t = env.from_string(''' + {%- for item in (1, 2, 3, 4) %} + {%- macro wrapper() %}[{{ item }}]{% endmacro %} + {{- wrapper() }} + {%- endfor %} + {%- set wrapper = "" %} + {{- wrapper -}} + ''') + assert t.render() == '[1][2][3][4]' + + t = env.from_string(''' + {%- for item in (1, 2, 3, 4) %} + {%- macro wrapper() %}[{{ item }}]{% endmacro %} + {{- wrapper() }} + {%- endfor %} + {{- wrapper -}} + ''') + assert t.render(wrapper=23) == '[1][2][3][4]23' + + +class BugTestCase(JinjaTestCase): + + def test_keyword_folding(self): + env = Environment() + env.filters['testing'] = lambda value, some: value + some + assert env.from_string("{{ 'test'|testing(some='stuff') }}") \ + .render() == 'teststuff' + + def test_extends_output_bugs(self): + env = Environment(loader=DictLoader({ + 
'parent.html': '(({% block title %}{% endblock %}))' + })) + + t = env.from_string('{% if expr %}{% extends "parent.html" %}{% endif %}' + '[[{% block title %}title{% endblock %}]]' + '{% for item in [1, 2, 3] %}({{ item }}){% endfor %}') + assert t.render(expr=False) == '[[title]](1)(2)(3)' + assert t.render(expr=True) == '((title))' + + def test_urlize_filter_escaping(self): + tmpl = env.from_string('{{ "http://www.example.org/http://www.example.org/<foo' + + def test_loop_call_loop(self): + tmpl = env.from_string(''' + + {% macro test() %} + {{ caller() }} + {% endmacro %} + + {% for num1 in range(5) %} + {% call test() %} + {% for num2 in range(10) %} + {{ loop.index }} + {% endfor %} + {% endcall %} + {% endfor %} + + ''') + + assert tmpl.render().split() == map(unicode, range(1, 11)) * 5 + + def test_weird_inline_comment(self): + env = Environment(line_statement_prefix='%') + self.assert_raises(TemplateSyntaxError, env.from_string, + '% for item in seq {# missing #}\n...% endfor') + + def test_old_macro_loop_scoping_bug(self): + tmpl = env.from_string('{% for i in (1, 2) %}{{ i }}{% endfor %}' + '{% macro i() %}3{% endmacro %}{{ i() }}') + assert tmpl.render() == '123' + + def test_partial_conditional_assignments(self): + tmpl = env.from_string('{% if b %}{% set a = 42 %}{% endif %}{{ a }}') + assert tmpl.render(a=23) == '23' + assert tmpl.render(b=True) == '42' + + def test_stacked_locals_scoping_bug(self): + env = Environment(line_statement_prefix='#') + t = env.from_string('''\ +# for j in [1, 2]: +# set x = 1 +# for i in [1, 2]: +# print x +# if i % 2 == 0: +# set x = x + 1 +# endif +# endfor +# endfor +# if a +# print 'A' +# elif b +# print 'B' +# elif c == d +# print 'C' +# else +# print 'D' +# endif + ''') + assert t.render(a=0, b=False, c=42, d=42.0) == '1111C' + + def test_stacked_locals_scoping_bug_twoframe(self): + t = Template(''' + {% set x = 1 %} + {% for item in foo %} + {% if item == 1 %} + {% set x = 2 %} + {% endif %} + {% endfor %} + {{ x 
}} + ''') + rv = t.render(foo=[1]).strip() + assert rv == u'1' + + def test_call_with_args(self): + t = Template("""{% macro dump_users(users) -%} +
    + {%- for user in users -%} +
  • {{ user.username|e }}

    {{ caller(user) }}
  • + {%- endfor -%} +
+ {%- endmacro -%} + + {% call(user) dump_users(list_of_user) -%} +
+
Realname
+
{{ user.realname|e }}
+
Description
+
{{ user.description }}
+
+ {% endcall %}""") + + assert [x.strip() for x in t.render(list_of_user=[{ + 'username':'apo', + 'realname':'something else', + 'description':'test' + }]).splitlines()] == [ + u'
  • apo

    ', + u'
    Realname
    ', + u'
    something else
    ', + u'
    Description
    ', + u'
    test
    ', + u'
    ', + u'
' + ] + + def test_empty_if_condition_fails(self): + self.assert_raises(TemplateSyntaxError, Template, '{% if %}....{% endif %}') + self.assert_raises(TemplateSyntaxError, Template, '{% if foo %}...{% elif %}...{% endif %}') + self.assert_raises(TemplateSyntaxError, Template, '{% for x in %}..{% endfor %}') + + def test_recursive_loop_bug(self): + tpl1 = Template(""" + {% for p in foo recursive%} + {{p.bar}} + {% for f in p.fields recursive%} + {{f.baz}} + {{p.bar}} + {% if f.rec %} + {{ loop(f.sub) }} + {% endif %} + {% endfor %} + {% endfor %} + """) + + tpl2 = Template(""" + {% for p in foo%} + {{p.bar}} + {% for f in p.fields recursive%} + {{f.baz}} + {{p.bar}} + {% if f.rec %} + {{ loop(f.sub) }} + {% endif %} + {% endfor %} + {% endfor %} + """) + + def test_correct_prefix_loader_name(self): + env = Environment(loader=PrefixLoader({ + 'foo': DictLoader({}) + })) + try: + env.get_template('foo/bar.html') + except TemplateNotFound, e: + assert e.name == 'foo/bar.html' + else: + assert False, 'expected error here' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(CornerTestCase)) + suite.addTest(unittest.makeSuite(BugTestCase)) + return suite diff --git a/libs/jinja2/testsuite/res/__init__.py b/libs/jinja2/testsuite/res/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/libs/jinja2/testsuite/res/templates/broken.html b/libs/jinja2/testsuite/res/templates/broken.html new file mode 100755 index 0000000..77669fa --- /dev/null +++ b/libs/jinja2/testsuite/res/templates/broken.html @@ -0,0 +1,3 @@ +Before +{{ fail() }} +After diff --git a/libs/jinja2/testsuite/res/templates/foo/test.html b/libs/jinja2/testsuite/res/templates/foo/test.html new file mode 100755 index 0000000..b7d6715 --- /dev/null +++ b/libs/jinja2/testsuite/res/templates/foo/test.html @@ -0,0 +1 @@ +FOO diff --git a/libs/jinja2/testsuite/res/templates/syntaxerror.html b/libs/jinja2/testsuite/res/templates/syntaxerror.html new file mode 100755 index 
0000000..f21b817 --- /dev/null +++ b/libs/jinja2/testsuite/res/templates/syntaxerror.html @@ -0,0 +1,4 @@ +Foo +{% for item in broken %} + ... +{% endif %} diff --git a/libs/jinja2/testsuite/res/templates/test.html b/libs/jinja2/testsuite/res/templates/test.html new file mode 100755 index 0000000..ba578e4 --- /dev/null +++ b/libs/jinja2/testsuite/res/templates/test.html @@ -0,0 +1 @@ +BAR diff --git a/libs/jinja2/testsuite/security.py b/libs/jinja2/testsuite/security.py new file mode 100755 index 0000000..4518eac --- /dev/null +++ b/libs/jinja2/testsuite/security.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.security + ~~~~~~~~~~~~~~~~~~~~~~~~~ + + Checks the sandbox and other security features. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment +from jinja2.sandbox import SandboxedEnvironment, \ + ImmutableSandboxedEnvironment, unsafe +from jinja2 import Markup, escape +from jinja2.exceptions import SecurityError, TemplateSyntaxError, \ + TemplateRuntimeError + + +class PrivateStuff(object): + + def bar(self): + return 23 + + @unsafe + def foo(self): + return 42 + + def __repr__(self): + return 'PrivateStuff' + + +class PublicStuff(object): + bar = lambda self: 23 + _foo = lambda self: 42 + + def __repr__(self): + return 'PublicStuff' + + +class SandboxTestCase(JinjaTestCase): + + def test_unsafe(self): + env = SandboxedEnvironment() + self.assert_raises(SecurityError, env.from_string("{{ foo.foo() }}").render, + foo=PrivateStuff()) + self.assert_equal(env.from_string("{{ foo.bar() }}").render(foo=PrivateStuff()), '23') + + self.assert_raises(SecurityError, env.from_string("{{ foo._foo() }}").render, + foo=PublicStuff()) + self.assert_equal(env.from_string("{{ foo.bar() }}").render(foo=PublicStuff()), '23') + self.assert_equal(env.from_string("{{ foo.__class__ }}").render(foo=42), '') + 
self.assert_equal(env.from_string("{{ foo.func_code }}").render(foo=lambda:None), '') + # security error comes from __class__ already. + self.assert_raises(SecurityError, env.from_string( + "{{ foo.__class__.__subclasses__() }}").render, foo=42) + + def test_immutable_environment(self): + env = ImmutableSandboxedEnvironment() + self.assert_raises(SecurityError, env.from_string( + '{{ [].append(23) }}').render) + self.assert_raises(SecurityError, env.from_string( + '{{ {1:2}.clear() }}').render) + + def test_restricted(self): + env = SandboxedEnvironment() + self.assert_raises(TemplateSyntaxError, env.from_string, + "{% for item.attribute in seq %}...{% endfor %}") + self.assert_raises(TemplateSyntaxError, env.from_string, + "{% for foo, bar.baz in seq %}...{% endfor %}") + + def test_markup_operations(self): + # adding two strings should escape the unsafe one + unsafe = '' + safe = Markup('username') + assert unsafe + safe == unicode(escape(unsafe)) + unicode(safe) + + # string interpolations are safe to use too + assert Markup('%s') % '' == \ + '<bad user>' + assert Markup('%(username)s') % { + 'username': '' + } == '<bad user>' + + # an escaped object is markup too + assert type(Markup('foo') + 'bar') is Markup + + # and it implements __html__ by returning itself + x = Markup("foo") + assert x.__html__() is x + + # it also knows how to treat __html__ objects + class Foo(object): + def __html__(self): + return 'awesome' + def __unicode__(self): + return 'awesome' + assert Markup(Foo()) == 'awesome' + assert Markup('%s') % Foo() == \ + 'awesome' + + # escaping and unescaping + assert escape('"<>&\'') == '"<>&'' + assert Markup("Foo & Bar").striptags() == "Foo & Bar" + assert Markup("<test>").unescape() == "" + + def test_template_data(self): + env = Environment(autoescape=True) + t = env.from_string('{% macro say_hello(name) %}' + '

Hello {{ name }}!

{% endmacro %}' + '{{ say_hello("foo") }}') + escaped_out = '

Hello <blink>foo</blink>!

' + assert t.render() == escaped_out + assert unicode(t.module) == escaped_out + assert escape(t.module) == escaped_out + assert t.module.say_hello('foo') == escaped_out + assert escape(t.module.say_hello('foo')) == escaped_out + + def test_attr_filter(self): + env = SandboxedEnvironment() + tmpl = env.from_string('{{ cls|attr("__subclasses__")() }}') + self.assert_raises(SecurityError, tmpl.render, cls=int) + + def test_binary_operator_intercepting(self): + def disable_op(left, right): + raise TemplateRuntimeError('that operator so does not work') + for expr, ctx, rv in ('1 + 2', {}, '3'), ('a + 2', {'a': 2}, '4'): + env = SandboxedEnvironment() + env.binop_table['+'] = disable_op + t = env.from_string('{{ %s }}' % expr) + assert t.render(ctx) == rv + env.intercepted_binops = frozenset(['+']) + t = env.from_string('{{ %s }}' % expr) + try: + t.render(ctx) + except TemplateRuntimeError, e: + pass + else: + self.fail('expected runtime error') + + def test_unary_operator_intercepting(self): + def disable_op(arg): + raise TemplateRuntimeError('that operator so does not work') + for expr, ctx, rv in ('-1', {}, '-1'), ('-a', {'a': 2}, '-2'): + env = SandboxedEnvironment() + env.unop_table['-'] = disable_op + t = env.from_string('{{ %s }}' % expr) + assert t.render(ctx) == rv + env.intercepted_unops = frozenset(['-']) + t = env.from_string('{{ %s }}' % expr) + try: + t.render(ctx) + except TemplateRuntimeError, e: + pass + else: + self.fail('expected runtime error') + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(SandboxTestCase)) + return suite diff --git a/libs/jinja2/testsuite/tests.py b/libs/jinja2/testsuite/tests.py new file mode 100755 index 0000000..3ece7a8 --- /dev/null +++ b/libs/jinja2/testsuite/tests.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.tests + ~~~~~~~~~~~~~~~~~~~~~~ + + Who tests the tests? + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. 
+""" +import unittest +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Markup, Environment + +env = Environment() + + +class TestsTestCase(JinjaTestCase): + + def test_defined(self): + tmpl = env.from_string('{{ missing is defined }}|{{ true is defined }}') + assert tmpl.render() == 'False|True' + + def test_even(self): + tmpl = env.from_string('''{{ 1 is even }}|{{ 2 is even }}''') + assert tmpl.render() == 'False|True' + + def test_odd(self): + tmpl = env.from_string('''{{ 1 is odd }}|{{ 2 is odd }}''') + assert tmpl.render() == 'True|False' + + def test_lower(self): + tmpl = env.from_string('''{{ "foo" is lower }}|{{ "FOO" is lower }}''') + assert tmpl.render() == 'True|False' + + def test_typechecks(self): + tmpl = env.from_string(''' + {{ 42 is undefined }} + {{ 42 is defined }} + {{ 42 is none }} + {{ none is none }} + {{ 42 is number }} + {{ 42 is string }} + {{ "foo" is string }} + {{ "foo" is sequence }} + {{ [1] is sequence }} + {{ range is callable }} + {{ 42 is callable }} + {{ range(5) is iterable }} + {{ {} is mapping }} + {{ mydict is mapping }} + {{ [] is mapping }} + ''') + class MyDict(dict): + pass + assert tmpl.render(mydict=MyDict()).split() == [ + 'False', 'True', 'False', 'True', 'True', 'False', + 'True', 'True', 'True', 'True', 'False', 'True', + 'True', 'True', 'False' + ] + + def test_sequence(self): + tmpl = env.from_string( + '{{ [1, 2, 3] is sequence }}|' + '{{ "foo" is sequence }}|' + '{{ 42 is sequence }}' + ) + assert tmpl.render() == 'True|True|False' + + def test_upper(self): + tmpl = env.from_string('{{ "FOO" is upper }}|{{ "foo" is upper }}') + assert tmpl.render() == 'True|False' + + def test_sameas(self): + tmpl = env.from_string('{{ foo is sameas false }}|' + '{{ 0 is sameas false }}') + assert tmpl.render(foo=False) == 'True|False' + + def test_no_paren_for_arg1(self): + tmpl = env.from_string('{{ foo is sameas none }}') + assert tmpl.render(foo=None) == 'True' + + def test_escaped(self): + env = 
Environment(autoescape=True) + tmpl = env.from_string('{{ x is escaped }}|{{ y is escaped }}') + assert tmpl.render(x='foo', y=Markup('foo')) == 'False|True' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestsTestCase)) + return suite diff --git a/libs/jinja2/testsuite/utils.py b/libs/jinja2/testsuite/utils.py new file mode 100755 index 0000000..be2e902 --- /dev/null +++ b/libs/jinja2/testsuite/utils.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.utils + ~~~~~~~~~~~~~~~~~~~~~~ + + Tests utilities jinja uses. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import gc +import unittest + +import pickle + +from jinja2.testsuite import JinjaTestCase + +from jinja2.utils import LRUCache, escape, object_type_repr + + +class LRUCacheTestCase(JinjaTestCase): + + def test_simple(self): + d = LRUCache(3) + d["a"] = 1 + d["b"] = 2 + d["c"] = 3 + d["a"] + d["d"] = 4 + assert len(d) == 3 + assert 'a' in d and 'c' in d and 'd' in d and 'b' not in d + + def test_pickleable(self): + cache = LRUCache(2) + cache["foo"] = 42 + cache["bar"] = 23 + cache["foo"] + + for protocol in range(3): + copy = pickle.loads(pickle.dumps(cache, protocol)) + assert copy.capacity == cache.capacity + assert copy._mapping == cache._mapping + assert copy._queue == cache._queue + + +class HelpersTestCase(JinjaTestCase): + + def test_object_type_repr(self): + class X(object): + pass + self.assert_equal(object_type_repr(42), 'int object') + self.assert_equal(object_type_repr([]), 'list object') + self.assert_equal(object_type_repr(X()), + 'jinja2.testsuite.utils.X object') + self.assert_equal(object_type_repr(None), 'None') + self.assert_equal(object_type_repr(Ellipsis), 'Ellipsis') + + +class MarkupLeakTestCase(JinjaTestCase): + + def test_markup_leaks(self): + counts = set() + for count in xrange(20): + for item in xrange(1000): + escape("foo") + escape("") + escape(u"foo") + escape(u"") + 
counts.add(len(gc.get_objects())) + assert len(counts) == 1, 'ouch, c extension seems to leak objects' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(LRUCacheTestCase)) + suite.addTest(unittest.makeSuite(HelpersTestCase)) + + # this test only tests the c extension + if not hasattr(escape, 'func_code'): + suite.addTest(unittest.makeSuite(MarkupLeakTestCase)) + + return suite diff --git a/libs/jinja2/utils.py b/libs/jinja2/utils.py index 1e0bb81..568c63f 100755 --- a/libs/jinja2/utils.py +++ b/libs/jinja2/utils.py @@ -67,7 +67,7 @@ except TypeError, _error: del _test_gen_bug, _error -# for python 2.x we create outselves a next() function that does the +# for python 2.x we create ourselves a next() function that does the # basics without exception catching. try: next = next @@ -132,7 +132,7 @@ def contextfunction(f): def evalcontextfunction(f): - """This decoraotr can be used to mark a function or method as an eval + """This decorator can be used to mark a function or method as an eval context callable. This is similar to the :func:`contextfunction` but instead of passing the context, an evaluation context object is passed. For more information about the eval context, see @@ -195,7 +195,7 @@ def clear_caches(): def import_string(import_name, silent=False): - """Imports an object based on a string. This use useful if you want to + """Imports an object based on a string. This is useful if you want to use import paths as endpoints or something similar. An import path can be specified either in dotted notation (``xml.sax.saxutils.escape``) or with a colon as object delimiter (``xml.sax.saxutils:escape``). 
@@ -412,7 +412,7 @@ class LRUCache(object): return (self.capacity,) def copy(self): - """Return an shallow copy of the instance.""" + """Return a shallow copy of the instance.""" rv = self.__class__(self.capacity) rv._mapping.update(self._mapping) rv._queue = deque(self._queue) @@ -462,7 +462,7 @@ class LRUCache(object): """Get an item from the cache. Moves the item up so that it has the highest priority then. - Raise an `KeyError` if it does not exist. + Raise a `KeyError` if it does not exist. """ rv = self._mapping[key] if self._queue[-1] != key: @@ -497,7 +497,7 @@ class LRUCache(object): def __delitem__(self, key): """Remove an item from the cache dict. - Raise an `KeyError` if it does not exist. + Raise a `KeyError` if it does not exist. """ self._wlock.acquire() try: @@ -598,7 +598,7 @@ class Joiner(object): # try markupsafe first, if that fails go with Jinja2's bundled version # of markupsafe. Markupsafe was previously Jinja2's implementation of -# the Markup object but was moved into a separate package in a patchleve +# the Markup object but was moved into a separate package in a patchlevel # release try: from markupsafe import Markup, escape, soft_unicode diff --git a/libs/sqlalchemy/__init__.py b/libs/sqlalchemy/__init__.py index ef5f385..03293b5 100644 --- a/libs/sqlalchemy/__init__.py +++ b/libs/sqlalchemy/__init__.py @@ -117,7 +117,7 @@ from sqlalchemy.engine import create_engine, engine_from_config __all__ = sorted(name for name, obj in locals().items() if not (name.startswith('_') or inspect.ismodule(obj))) -__version__ = '0.7.5' +__version__ = '0.7.6' del inspect, sys diff --git a/libs/sqlalchemy/cextension/processors.c b/libs/sqlalchemy/cextension/processors.c index b539f68..427db5d 100644 --- a/libs/sqlalchemy/cextension/processors.c +++ b/libs/sqlalchemy/cextension/processors.c @@ -342,23 +342,18 @@ DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value) if (value == Py_None) Py_RETURN_NONE; - if (PyFloat_CheckExact(value)) 
{ - /* Decimal does not accept float values directly */ - args = PyTuple_Pack(1, value); - if (args == NULL) - return NULL; + args = PyTuple_Pack(1, value); + if (args == NULL) + return NULL; - str = PyString_Format(self->format, args); - Py_DECREF(args); - if (str == NULL) - return NULL; + str = PyString_Format(self->format, args); + Py_DECREF(args); + if (str == NULL) + return NULL; - result = PyObject_CallFunctionObjArgs(self->type, str, NULL); - Py_DECREF(str); - return result; - } else { - return PyObject_CallFunctionObjArgs(self->type, value, NULL); - } + result = PyObject_CallFunctionObjArgs(self->type, str, NULL); + Py_DECREF(str); + return result; } static void diff --git a/libs/sqlalchemy/cextension/resultproxy.c b/libs/sqlalchemy/cextension/resultproxy.c index 64b6855..3494cca 100644 --- a/libs/sqlalchemy/cextension/resultproxy.c +++ b/libs/sqlalchemy/cextension/resultproxy.c @@ -246,6 +246,7 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key) PyObject *exc_module, *exception; char *cstr_key; long index; + int key_fallback = 0; if (PyInt_CheckExact(key)) { index = PyInt_AS_LONG(key); @@ -276,12 +277,17 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key) "O", key); if (record == NULL) return NULL; + key_fallback = 1; } indexobject = PyTuple_GetItem(record, 2); if (indexobject == NULL) return NULL; + if (key_fallback) { + Py_DECREF(record); + } + if (indexobject == Py_None) { exc_module = PyImport_ImportModule("sqlalchemy.exc"); if (exc_module == NULL) @@ -347,7 +353,16 @@ BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name) else return tmp; - return BaseRowProxy_subscript(self, name); + tmp = BaseRowProxy_subscript(self, name); + if (tmp == NULL && PyErr_ExceptionMatches(PyExc_KeyError)) { + PyErr_Format( + PyExc_AttributeError, + "Could not locate column in row for column '%.200s'", + PyString_AsString(name) + ); + return NULL; + } + return tmp; } /*********************** diff --git a/libs/sqlalchemy/dialects/firebird/base.py 
b/libs/sqlalchemy/dialects/firebird/base.py index 8cf2ded..031c689 100644 --- a/libs/sqlalchemy/dialects/firebird/base.py +++ b/libs/sqlalchemy/dialects/firebird/base.py @@ -215,7 +215,7 @@ class FBCompiler(sql.compiler.SQLCompiler): # Override to not use the AS keyword which FB 1.5 does not like if asfrom: alias_name = isinstance(alias.name, - expression._generated_label) and \ + expression._truncated_label) and \ self._truncated_identifier("alias", alias.name) or alias.name diff --git a/libs/sqlalchemy/dialects/mssql/base.py b/libs/sqlalchemy/dialects/mssql/base.py index f7c94aa..103b0a3 100644 --- a/libs/sqlalchemy/dialects/mssql/base.py +++ b/libs/sqlalchemy/dialects/mssql/base.py @@ -791,6 +791,9 @@ class MSSQLCompiler(compiler.SQLCompiler): def get_from_hint_text(self, table, text): return text + def get_crud_hint_text(self, table, text): + return text + def limit_clause(self, select): # Limit in mssql is after the select keyword return "" @@ -949,6 +952,13 @@ class MSSQLCompiler(compiler.SQLCompiler): ] return 'OUTPUT ' + ', '.join(columns) + def get_cte_preamble(self, recursive): + # SQL Server finds it too inconvenient to accept + # an entirely optional, SQL standard specified, + # "RECURSIVE" word with their "WITH", + # so here we go + return "WITH" + def label_select_column(self, select, column, asfrom): if isinstance(column, expression.Function): return column.label(None) diff --git a/libs/sqlalchemy/dialects/mysql/base.py b/libs/sqlalchemy/dialects/mysql/base.py index 6aa250d..d9ab5a3 100644 --- a/libs/sqlalchemy/dialects/mysql/base.py +++ b/libs/sqlalchemy/dialects/mysql/base.py @@ -84,6 +84,23 @@ all lower case both within SQLAlchemy as well as on the MySQL database itself, especially if database reflection features are to be used. 
+Transaction Isolation Level +--------------------------- + +:func:`.create_engine` accepts an ``isolation_level`` +parameter which results in the command ``SET SESSION +TRANSACTION ISOLATION LEVEL `` being invoked for +every new connection. Valid values for this parameter are +``READ COMMITTED``, ``READ UNCOMMITTED``, +``REPEATABLE READ``, and ``SERIALIZABLE``:: + + engine = create_engine( + "mysql://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED" + ) + +(new in 0.7.6) + Keys ---- @@ -221,8 +238,29 @@ simply passed through to the underlying CREATE INDEX command, so it *must* be an integer. MySQL only allows a length for an index if it is for a CHAR, VARCHAR, TEXT, BINARY, VARBINARY and BLOB. +Index Types +~~~~~~~~~~~~~ + +Some MySQL storage engines permit you to specify an index type when creating +an index or primary key constraint. SQLAlchemy provides this feature via the +``mysql_using`` parameter on :class:`.Index`:: + + Index('my_index', my_table.c.data, mysql_using='hash') + +As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`:: + + PrimaryKeyConstraint("data", mysql_using='hash') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index +type for your MySQL storage engine. 
+ More information can be found at: + http://dev.mysql.com/doc/refman/5.0/en/create-index.html + +http://dev.mysql.com/doc/refman/5.0/en/create-table.html + """ import datetime, inspect, re, sys @@ -1331,7 +1369,8 @@ class MySQLCompiler(compiler.SQLCompiler): return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) for t in [from_table] + list(extra_froms)) - def update_from_clause(self, update_stmt, from_table, extra_froms, **kw): + def update_from_clause(self, update_stmt, from_table, + extra_froms, from_hints, **kw): return None @@ -1421,35 +1460,50 @@ class MySQLDDLCompiler(compiler.DDLCompiler): table_opts.append(joiner.join((opt, arg))) return ' '.join(table_opts) + def visit_create_index(self, create): index = create.element preparer = self.preparer + table = preparer.format_table(index.table) + columns = [preparer.quote(c.name, c.quote) for c in index.columns] + name = preparer.quote( + self._index_identifier(index.name), + index.quote) + text = "CREATE " if index.unique: text += "UNIQUE " - text += "INDEX %s ON %s " \ - % (preparer.quote(self._index_identifier(index.name), - index.quote),preparer.format_table(index.table)) + text += "INDEX %s ON %s " % (name, table) + + columns = ', '.join(columns) if 'mysql_length' in index.kwargs: length = index.kwargs['mysql_length'] + text += "(%s(%d))" % (columns, length) else: - length = None - if length is not None: - text+= "(%s(%d))" \ - % (', '.join(preparer.quote(c.name, c.quote) - for c in index.columns), length) - else: - text+= "(%s)" \ - % (', '.join(preparer.quote(c.name, c.quote) - for c in index.columns)) + text += "(%s)" % (columns) + + if 'mysql_using' in index.kwargs: + using = index.kwargs['mysql_using'] + text += " USING %s" % (preparer.quote(using, index.quote)) + return text + def visit_primary_key_constraint(self, constraint): + text = super(MySQLDDLCompiler, self).\ + visit_primary_key_constraint(constraint) + if "mysql_using" in constraint.kwargs: + using = constraint.kwargs['mysql_using'] 
+ text += " USING %s" % ( + self.preparer.quote(using, constraint.quote)) + return text def visit_drop_index(self, drop): index = drop.element return "\nDROP INDEX %s ON %s" % \ - (self.preparer.quote(self._index_identifier(index.name), index.quote), + (self.preparer.quote( + self._index_identifier(index.name), index.quote + ), self.preparer.format_table(index.table)) def visit_drop_constraint(self, drop): @@ -1768,8 +1822,40 @@ class MySQLDialect(default.DefaultDialect): _backslash_escapes = True _server_ansiquotes = False - def __init__(self, use_ansiquotes=None, **kwargs): + def __init__(self, use_ansiquotes=None, isolation_level=None, **kwargs): default.DefaultDialect.__init__(self, **kwargs) + self.isolation_level = isolation_level + + def on_connect(self): + if self.isolation_level is not None: + def connect(conn): + self.set_isolation_level(conn, self.isolation_level) + return connect + else: + return None + + _isolation_lookup = set(['SERIALIZABLE', + 'READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ']) + + def set_isolation_level(self, connection, level): + level = level.replace('_', ' ') + if level not in self._isolation_lookup: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. 
" + "Valid isolation levels for %s are %s" % + (level, self.name, ", ".join(self._isolation_lookup)) + ) + cursor = connection.cursor() + cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level) + cursor.execute("COMMIT") + cursor.close() + + def get_isolation_level(self, connection): + cursor = connection.cursor() + cursor.execute('SELECT @@tx_isolation') + val = cursor.fetchone()[0] + cursor.close() + return val.upper().replace("-", " ") def do_commit(self, connection): """Execute a COMMIT.""" diff --git a/libs/sqlalchemy/dialects/oracle/base.py b/libs/sqlalchemy/dialects/oracle/base.py index 88e5062..dd761ae 100644 --- a/libs/sqlalchemy/dialects/oracle/base.py +++ b/libs/sqlalchemy/dialects/oracle/base.py @@ -158,7 +158,7 @@ RESERVED_WORDS = \ 'AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS '\ 'NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER '\ 'CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR '\ - 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT'.split()) + 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL'.split()) NO_ARG_FNS = set('UID CURRENT_DATE SYSDATE USER ' 'CURRENT_TIME CURRENT_TIMESTAMP'.split()) @@ -309,6 +309,9 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): "", ) + def visit_LONG(self, type_): + return "LONG" + def visit_TIMESTAMP(self, type_): if type_.timezone: return "TIMESTAMP WITH TIME ZONE" @@ -481,7 +484,7 @@ class OracleCompiler(compiler.SQLCompiler): """Oracle doesn't like ``FROM table AS alias``. 
Is the AS standard SQL??""" if asfrom or ashint: - alias_name = isinstance(alias.name, expression._generated_label) and \ + alias_name = isinstance(alias.name, expression._truncated_label) and \ self._truncated_identifier("alias", alias.name) or alias.name if ashint: diff --git a/libs/sqlalchemy/dialects/oracle/cx_oracle.py b/libs/sqlalchemy/dialects/oracle/cx_oracle.py index 64526d2..5001acc 100644 --- a/libs/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/libs/sqlalchemy/dialects/oracle/cx_oracle.py @@ -77,7 +77,7 @@ with this feature but it should be regarded as experimental. Precision Numerics ------------------ -The SQLAlchemy dialect goes thorugh a lot of steps to ensure +The SQLAlchemy dialect goes through a lot of steps to ensure that decimal numbers are sent and received with full accuracy. An "outputtypehandler" callable is associated with each cx_oracle connection object which detects numeric types and @@ -89,6 +89,21 @@ this behavior, and will coerce the ``Decimal`` to ``float`` if the ``asdecimal`` flag is ``False`` (default on :class:`.Float`, optional on :class:`.Numeric`). +Because the handler coerces to ``Decimal`` in all cases first, +the feature can detract significantly from performance. +If precision numerics aren't required, the decimal handling +can be disabled by passing the flag ``coerce_to_decimal=False`` +to :func:`.create_engine`:: + + engine = create_engine("oracle+cx_oracle://dsn", + coerce_to_decimal=False) + +The ``coerce_to_decimal`` flag is new in 0.7.6. + +Another alternative to performance is to use the +`cdecimal `_ library; +see :class:`.Numeric` for additional notes. 
+ The handler attempts to use the "precision" and "scale" attributes of the result set column to best determine if subsequent incoming values should be received as ``Decimal`` as @@ -468,6 +483,7 @@ class OracleDialect_cx_oracle(OracleDialect): auto_convert_lobs=True, threaded=True, allow_twophase=True, + coerce_to_decimal=True, arraysize=50, **kwargs): OracleDialect.__init__(self, **kwargs) self.threaded = threaded @@ -491,7 +507,12 @@ class OracleDialect_cx_oracle(OracleDialect): self._cx_oracle_unicode_types = types("UNICODE", "NCLOB") self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB") self.supports_unicode_binds = self.cx_oracle_ver >= (5, 0) - self.supports_native_decimal = self.cx_oracle_ver >= (5, 0) + + self.supports_native_decimal = ( + self.cx_oracle_ver >= (5, 0) and + coerce_to_decimal + ) + self._cx_oracle_native_nvarchar = self.cx_oracle_ver >= (5, 0) if self.cx_oracle_ver is None: @@ -603,7 +624,9 @@ class OracleDialect_cx_oracle(OracleDialect): size, precision, scale): # convert all NUMBER with precision + positive scale to Decimal # this almost allows "native decimal" mode. - if defaultType == cx_Oracle.NUMBER and precision and scale > 0: + if self.supports_native_decimal and \ + defaultType == cx_Oracle.NUMBER and \ + precision and scale > 0: return cursor.var( cx_Oracle.STRING, 255, @@ -614,7 +637,8 @@ class OracleDialect_cx_oracle(OracleDialect): # make a decision based on each value received - the type # may change from row to row (!). This kills # off "native decimal" mode, handlers still needed. 
- elif defaultType == cx_Oracle.NUMBER \ + elif self.supports_native_decimal and \ + defaultType == cx_Oracle.NUMBER \ and not precision and scale <= 0: return cursor.var( cx_Oracle.STRING, diff --git a/libs/sqlalchemy/dialects/postgresql/base.py b/libs/sqlalchemy/dialects/postgresql/base.py index 69c11d8..c4c2bbd 100644 --- a/libs/sqlalchemy/dialects/postgresql/base.py +++ b/libs/sqlalchemy/dialects/postgresql/base.py @@ -47,9 +47,18 @@ Transaction Isolation Level :func:`.create_engine` accepts an ``isolation_level`` parameter which results in the command ``SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL `` being invoked for every new connection. Valid values for this -parameter are ``READ_COMMITTED``, ``READ_UNCOMMITTED``, ``REPEATABLE_READ``, -and ``SERIALIZABLE``. Note that the psycopg2 dialect does *not* use this -technique and uses psycopg2-specific APIs (see that dialect for details). +parameter are ``READ COMMITTED``, ``READ UNCOMMITTED``, ``REPEATABLE READ``, +and ``SERIALIZABLE``:: + + engine = create_engine( + "postgresql+pg8000://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED" + ) + +When using the psycopg2 dialect, a psycopg2-specific method of setting +transaction isolation level is used, but the API of ``isolation_level`` +remains the same - see :ref:`psycopg2_isolation`. + Remote / Cross-Schema Table Introspection ----------------------------------------- diff --git a/libs/sqlalchemy/dialects/postgresql/psycopg2.py b/libs/sqlalchemy/dialects/postgresql/psycopg2.py index c66180f..5aa9397 100644 --- a/libs/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/libs/sqlalchemy/dialects/postgresql/psycopg2.py @@ -38,6 +38,26 @@ psycopg2-specific keyword arguments which are accepted by * *use_native_unicode* - Enable the usage of Psycopg2 "native unicode" mode per connection. True by default. +Unix Domain Connections +------------------------ + +psycopg2 supports connecting via Unix domain connections. 
When the ``host`` +portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2, +which specifies Unix-domain communication rather than TCP/IP communication:: + + create_engine("postgresql+psycopg2://user:password@/dbname") + +By default, the socket file used is to connect to a Unix-domain socket +in ``/tmp``, or whatever socket directory was specified when PostgreSQL +was built. This value can be overridden by passing a pathname to psycopg2, +using ``host`` as an additional keyword argument:: + + create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql") + +See also: + +`PQconnectdbParams `_ + Per-Statement/Connection Execution Options ------------------------------------------- @@ -97,6 +117,8 @@ Transactions The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. +.. _psycopg2_isolation: + Transaction Isolation Level --------------------------- diff --git a/libs/sqlalchemy/dialects/sqlite/base.py b/libs/sqlalchemy/dialects/sqlite/base.py index f9520af..10a0d88 100644 --- a/libs/sqlalchemy/dialects/sqlite/base.py +++ b/libs/sqlalchemy/dialects/sqlite/base.py @@ -441,20 +441,6 @@ class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): result = self.quote_schema(index.table.schema, index.table.quote_schema) + "." + result return result -class SQLiteExecutionContext(default.DefaultExecutionContext): - def get_result_proxy(self): - rp = base.ResultProxy(self) - if rp._metadata: - # adjust for dotted column names. SQLite - # in the case of UNION may store col names as - # "tablename.colname" - # in cursor.description - for colname in rp._metadata.keys: - if "." 
in colname: - trunc_col = colname.split(".")[1] - rp._metadata._set_keymap_synonym(trunc_col, colname) - return rp - class SQLiteDialect(default.DefaultDialect): name = 'sqlite' supports_alter = False @@ -472,7 +458,6 @@ class SQLiteDialect(default.DefaultDialect): ischema_names = ischema_names colspecs = colspecs isolation_level = None - execution_ctx_cls = SQLiteExecutionContext supports_cast = True supports_default_values = True @@ -540,6 +525,16 @@ class SQLiteDialect(default.DefaultDialect): else: return None + def _translate_colname(self, colname): + # adjust for dotted column names. SQLite + # in the case of UNION may store col names as + # "tablename.colname" + # in cursor.description + if "." in colname: + return colname.split(".")[1], colname + else: + return colname, None + @reflection.cache def get_table_names(self, connection, schema=None, **kw): if schema is not None: diff --git a/libs/sqlalchemy/engine/__init__.py b/libs/sqlalchemy/engine/__init__.py index 4fac3e5..23b4b0b 100644 --- a/libs/sqlalchemy/engine/__init__.py +++ b/libs/sqlalchemy/engine/__init__.py @@ -306,6 +306,12 @@ def create_engine(*args, **kwargs): this is configurable with the MySQLDB connection itself and the server configuration as well). + :param pool_reset_on_return='rollback': set the "reset on return" + behavior of the pool, which is whether ``rollback()``, + ``commit()``, or nothing is called upon connections + being returned to the pool. See the docstring for + ``reset_on_return`` at :class:`.Pool`. (new as of 0.7.6) + :param pool_timeout=30: number of seconds to wait before giving up on getting a connection from the pool. This is only used with :class:`~sqlalchemy.pool.QueuePool`. 
diff --git a/libs/sqlalchemy/engine/base.py b/libs/sqlalchemy/engine/base.py index db19fe7..d16fc9c 100644 --- a/libs/sqlalchemy/engine/base.py +++ b/libs/sqlalchemy/engine/base.py @@ -491,14 +491,23 @@ class Dialect(object): raise NotImplementedError() def do_executemany(self, cursor, statement, parameters, context=None): - """Provide an implementation of *cursor.executemany(statement, - parameters)*.""" + """Provide an implementation of ``cursor.executemany(statement, + parameters)``.""" raise NotImplementedError() def do_execute(self, cursor, statement, parameters, context=None): - """Provide an implementation of *cursor.execute(statement, - parameters)*.""" + """Provide an implementation of ``cursor.execute(statement, + parameters)``.""" + + raise NotImplementedError() + + def do_execute_no_params(self, cursor, statement, parameters, context=None): + """Provide an implementation of ``cursor.execute(statement)``. + + The parameter collection should not be sent. + + """ raise NotImplementedError() @@ -777,12 +786,12 @@ class Connectable(object): def connect(self, **kwargs): """Return a :class:`.Connection` object. - + Depending on context, this may be ``self`` if this object is already an instance of :class:`.Connection`, or a newly procured :class:`.Connection` if this object is an instance of :class:`.Engine`. - + """ def contextual_connect(self): @@ -793,7 +802,7 @@ class Connectable(object): is already an instance of :class:`.Connection`, or a newly procured :class:`.Connection` if this object is an instance of :class:`.Engine`. - + """ raise NotImplementedError() @@ -904,6 +913,12 @@ class Connection(Connectable): c.__dict__ = self.__dict__.copy() return c + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + def execution_options(self, **opt): """ Set non-SQL options for the connection which take effect during execution. 
@@ -940,7 +955,7 @@ class Connection(Connectable): :param compiled_cache: Available on: Connection. A dictionary where :class:`.Compiled` objects will be cached when the :class:`.Connection` compiles a clause - expression into a :class:`.Compiled` object. + expression into a :class:`.Compiled` object. It is the user's responsibility to manage the size of this dictionary, which will have keys corresponding to the dialect, clause element, the column @@ -953,7 +968,7 @@ class Connection(Connectable): some operations, including flush operations. The caching used by the ORM internally supersedes a cache dictionary specified here. - + :param isolation_level: Available on: Connection. Set the transaction isolation level for the lifespan of this connection. Valid values include @@ -962,7 +977,7 @@ class Connection(Connectable): database specific, including those for :ref:`sqlite_toplevel`, :ref:`postgresql_toplevel` - see those dialect's documentation for further info. - + Note that this option necessarily affects the underying DBAPI connection for the lifespan of the originating :class:`.Connection`, and is not per-execution. This @@ -970,6 +985,18 @@ class Connection(Connectable): is returned to the connection pool, i.e. the :meth:`.Connection.close` method is called. + :param no_parameters: When ``True``, if the final parameter + list or dictionary is totally empty, will invoke the + statement on the cursor as ``cursor.execute(statement)``, + not passing the parameter collection at all. + Some DBAPIs such as psycopg2 and mysql-python consider + percent signs as significant only when parameters are + present; this option allows code to generate SQL + containing percent signs (and possibly other characters) + that is neutral regarding whether it's executed by the DBAPI + or piped into a script that's later invoked by + command line tools. New in 0.7.6. + :param stream_results: Available on: Connection, statement. 
Indicate to the dialect that results should be "streamed" and not pre-buffered, if possible. This is a limitation @@ -1113,17 +1140,35 @@ class Connection(Connectable): def begin(self): """Begin a transaction and return a transaction handle. - + The returned object is an instance of :class:`.Transaction`. + This object represents the "scope" of the transaction, + which completes when either the :meth:`.Transaction.rollback` + or :meth:`.Transaction.commit` method is called. + + Nested calls to :meth:`.begin` on the same :class:`.Connection` + will return new :class:`.Transaction` objects that represent + an emulated transaction within the scope of the enclosing + transaction, that is:: + + trans = conn.begin() # outermost transaction + trans2 = conn.begin() # "nested" + trans2.commit() # does nothing + trans.commit() # actually commits + + Calls to :meth:`.Transaction.commit` only have an effect + when invoked via the outermost :class:`.Transaction` object, though the + :meth:`.Transaction.rollback` method of any of the + :class:`.Transaction` objects will roll back the + transaction. - Repeated calls to ``begin`` on the same Connection will create - a lightweight, emulated nested transaction. Only the - outermost transaction may ``commit``. Calls to ``commit`` on - inner transactions are ignored. Any transaction in the - hierarchy may ``rollback``, however. + See also: - See also :meth:`.Connection.begin_nested`, - :meth:`.Connection.begin_twophase`. + :meth:`.Connection.begin_nested` - use a SAVEPOINT + + :meth:`.Connection.begin_twophase` - use a two phase /XID transaction + + :meth:`.Engine.begin` - context manager available from :class:`.Engine`. """ @@ -1157,7 +1202,7 @@ class Connection(Connectable): def begin_twophase(self, xid=None): """Begin a two-phase or XA transaction and return a transaction handle. 
- + The returned object is an instance of :class:`.TwoPhaseTransaction`, which in addition to the methods provided by :class:`.Transaction`, also provides a :meth:`~.TwoPhaseTransaction.prepare` @@ -1302,7 +1347,7 @@ class Connection(Connectable): def close(self): """Close this :class:`.Connection`. - + This results in a release of the underlying database resources, that is, the DBAPI connection referenced internally. The DBAPI connection is typically restored @@ -1313,7 +1358,7 @@ class Connection(Connectable): the DBAPI connection's ``rollback()`` method, regardless of any :class:`.Transaction` object that may be outstanding with regards to this :class:`.Connection`. - + After :meth:`~.Connection.close` is called, the :class:`.Connection` is permanently in a closed state, and will allow no further operations. @@ -1354,24 +1399,24 @@ class Connection(Connectable): * a :class:`.DDLElement` object * a :class:`.DefaultGenerator` object * a :class:`.Compiled` object - + :param \*multiparams/\**params: represent bound parameter values to be used in the execution. Typically, the format is either a collection of one or more dictionaries passed to \*multiparams:: - + conn.execute( table.insert(), {"id":1, "value":"v1"}, {"id":2, "value":"v2"} ) - + ...or individual key/values interpreted by \**params:: - + conn.execute( table.insert(), id=1, value="v1" ) - + In the case that a plain SQL string is passed, and the underlying DBAPI accepts positional bind parameters, a collection of tuples or individual values in \*multiparams may be passed:: @@ -1380,21 +1425,21 @@ class Connection(Connectable): "INSERT INTO table (id, value) VALUES (?, ?)", (1, "v1"), (2, "v2") ) - + conn.execute( "INSERT INTO table (id, value) VALUES (?, ?)", 1, "v1" ) - + Note above, the usage of a question mark "?" or other symbol is contingent upon the "paramstyle" accepted by the DBAPI in use, which may be any of "qmark", "named", "pyformat", "format", "numeric". 
See `pep-249 `_ for details on paramstyle. - + To execute a textual SQL statement which uses bound parameters in a DBAPI-agnostic way, use the :func:`~.expression.text` construct. - + """ for c in type(object).__mro__: if c in Connection.executors: @@ -1623,7 +1668,8 @@ class Connection(Connectable): if self._echo: self.engine.logger.info(statement) - self.engine.logger.info("%r", sql_util._repr_params(parameters, batches=10)) + self.engine.logger.info("%r", + sql_util._repr_params(parameters, batches=10)) try: if context.executemany: self.dialect.do_executemany( @@ -1631,6 +1677,11 @@ class Connection(Connectable): statement, parameters, context) + elif not parameters and context.no_parameters: + self.dialect.do_execute_no_params( + cursor, + statement, + context) else: self.dialect.do_execute( cursor, @@ -1845,33 +1896,41 @@ class Connection(Connectable): """Execute the given function within a transaction boundary. The function is passed this :class:`.Connection` - as the first argument, followed by the given \*args and \**kwargs. - - This is a shortcut for explicitly invoking - :meth:`.Connection.begin`, calling :meth:`.Transaction.commit` - upon success or :meth:`.Transaction.rollback` upon an - exception raise:: + as the first argument, followed by the given \*args and \**kwargs, + e.g.:: def do_something(conn, x, y): conn.execute("some statement", {'x':x, 'y':y}) - + conn.transaction(do_something, 5, 10) + + The operations inside the function are all invoked within the + context of a single :class:`.Transaction`. + Upon success, the transaction is committed. If an + exception is raised, the transaction is rolled back + before propagating the exception. + + .. 
note:: + + The :meth:`.transaction` method is superseded by + the usage of the Python ``with:`` statement, which can + be used with :meth:`.Connection.begin`:: + + with conn.begin(): + conn.execute("some statement", {'x':5, 'y':10}) + + As well as with :meth:`.Engine.begin`:: + + with engine.begin() as conn: + conn.execute("some statement", {'x':5, 'y':10}) - Note that context managers (i.e. the ``with`` statement) - present a more modern way of accomplishing the above, - using the :class:`.Transaction` object as a base:: + See also: - with conn.begin(): - conn.execute("some statement", {'x':5, 'y':10}) - - One advantage to the :meth:`.Connection.transaction` - method is that the same method is also available - on :class:`.Engine` as :meth:`.Engine.transaction` - - this method procures a :class:`.Connection` and then - performs the same operation, allowing equivalent - usage with either a :class:`.Connection` or :class:`.Engine` - without needing to know what kind of object - it is. + :meth:`.Engine.begin` - engine-level transactional + context + + :meth:`.Engine.transaction` - engine-level version of + :meth:`.Connection.transaction` """ @@ -1887,15 +1946,15 @@ class Connection(Connectable): def run_callable(self, callable_, *args, **kwargs): """Given a callable object or function, execute it, passing a :class:`.Connection` as the first argument. - + The given \*args and \**kwargs are passed subsequent to the :class:`.Connection` argument. - + This function, along with :meth:`.Engine.run_callable`, allows a function to be run with a :class:`.Connection` or :class:`.Engine` object without the need to know which one is being dealt with. - + """ return callable_(self, *args, **kwargs) @@ -1906,11 +1965,11 @@ class Connection(Connectable): class Transaction(object): """Represent a database transaction in progress. 
- + The :class:`.Transaction` object is procured by calling the :meth:`~.Connection.begin` method of :class:`.Connection`:: - + from sqlalchemy import create_engine engine = create_engine("postgresql://scott:tiger@localhost/test") connection = engine.connect() @@ -1923,7 +1982,7 @@ class Transaction(object): also implements a context manager interface so that the Python ``with`` statement can be used with the :meth:`.Connection.begin` method:: - + with connection.begin(): connection.execute("insert into x (a, b) values (1, 2)") @@ -1931,7 +1990,7 @@ class Transaction(object): See also: :meth:`.Connection.begin`, :meth:`.Connection.begin_twophase`, :meth:`.Connection.begin_nested`. - + .. index:: single: thread safety; Transaction """ @@ -2012,9 +2071,9 @@ class NestedTransaction(Transaction): A new :class:`.NestedTransaction` object may be procured using the :meth:`.Connection.begin_nested` method. - + The interface is the same as that of :class:`.Transaction`. - + """ def __init__(self, connection, parent): super(NestedTransaction, self).__init__(connection, parent) @@ -2033,13 +2092,13 @@ class NestedTransaction(Transaction): class TwoPhaseTransaction(Transaction): """Represent a two-phase transaction. - + A new :class:`.TwoPhaseTransaction` object may be procured using the :meth:`.Connection.begin_twophase` method. - + The interface is the same as that of :class:`.Transaction` with the addition of the :meth:`prepare` method. - + """ def __init__(self, connection, xid): super(TwoPhaseTransaction, self).__init__(connection, None) @@ -2049,9 +2108,9 @@ class TwoPhaseTransaction(Transaction): def prepare(self): """Prepare this :class:`.TwoPhaseTransaction`. - + After a PREPARE, the transaction can be committed. - + """ if not self._parent.is_active: raise exc.InvalidRequestError("This transaction is inactive") @@ -2075,11 +2134,11 @@ class Engine(Connectable, log.Identified): :func:`~sqlalchemy.create_engine` function. 
See also: - + :ref:`engines_toplevel` :ref:`connections_toplevel` - + """ _execution_options = util.immutabledict() @@ -2115,13 +2174,13 @@ class Engine(Connectable, log.Identified): def update_execution_options(self, **opt): """Update the default execution_options dictionary of this :class:`.Engine`. - + The given keys/values in \**opt are added to the default execution options that will be used for all connections. The initial contents of this dictionary can be sent via the ``execution_options`` paramter to :func:`.create_engine`. - + See :meth:`.Connection.execution_options` for more details on execution options. @@ -2236,19 +2295,96 @@ class Engine(Connectable, log.Identified): if connection is None: conn.close() + class _trans_ctx(object): + def __init__(self, conn, transaction, close_with_result): + self.conn = conn + self.transaction = transaction + self.close_with_result = close_with_result + + def __enter__(self): + return self.conn + + def __exit__(self, type, value, traceback): + if type is not None: + self.transaction.rollback() + else: + self.transaction.commit() + if not self.close_with_result: + self.conn.close() + + def begin(self, close_with_result=False): + """Return a context manager delivering a :class:`.Connection` + with a :class:`.Transaction` established. + + E.g.:: + + with engine.begin() as conn: + conn.execute("insert into table (x, y, z) values (1, 2, 3)") + conn.execute("my_special_procedure(5)") + + Upon successful operation, the :class:`.Transaction` + is committed. If an error is raised, the :class:`.Transaction` + is rolled back. + + The ``close_with_result`` flag is normally ``False``, and indicates + that the :class:`.Connection` will be closed when the operation + is complete. 
When set to ``True``, it indicates the :class:`.Connection` + is in "single use" mode, where the :class:`.ResultProxy` + returned by the first call to :meth:`.Connection.execute` will + close the :class:`.Connection` when that :class:`.ResultProxy` + has exhausted all result rows. + + New in 0.7.6. + + See also: + + :meth:`.Engine.connect` - procure a :class:`.Connection` from + an :class:`.Engine`. + + :meth:`.Connection.begin` - start a :class:`.Transaction` + for a particular :class:`.Connection`. + + """ + conn = self.contextual_connect(close_with_result=close_with_result) + trans = conn.begin() + return Engine._trans_ctx(conn, trans, close_with_result) + def transaction(self, callable_, *args, **kwargs): """Execute the given function within a transaction boundary. - The function is passed a newly procured - :class:`.Connection` as the first argument, followed by - the given \*args and \**kwargs. The :class:`.Connection` - is then closed (returned to the pool) when the operation - is complete. + The function is passed a :class:`.Connection` newly procured + from :meth:`.Engine.contextual_connect` as the first argument, + followed by the given \*args and \**kwargs. + + e.g.:: + + def do_something(conn, x, y): + conn.execute("some statement", {'x':x, 'y':y}) + + engine.transaction(do_something, 5, 10) + + The operations inside the function are all invoked within the + context of a single :class:`.Transaction`. + Upon success, the transaction is committed. If an + exception is raised, the transaction is rolled back + before propagating the exception. + + .. note:: + + The :meth:`.transaction` method is superseded by + the usage of the Python ``with:`` statement, which can + be used with :meth:`.Engine.begin`:: + + with engine.begin() as conn: + conn.execute("some statement", {'x':5, 'y':10}) - This method can be used interchangeably with - :meth:`.Connection.transaction`. 
See that method for - more details on usage as well as a modern alternative - using context managers (i.e. the ``with`` statement). + See also: + + :meth:`.Engine.begin` - engine-level transactional + context + + :meth:`.Connection.transaction` - connection-level version of + :meth:`.Engine.transaction` """ @@ -2261,15 +2397,15 @@ class Engine(Connectable, log.Identified): def run_callable(self, callable_, *args, **kwargs): """Given a callable object or function, execute it, passing a :class:`.Connection` as the first argument. - + The given \*args and \**kwargs are passed subsequent to the :class:`.Connection` argument. - + This function, along with :meth:`.Connection.run_callable`, allows a function to be run with a :class:`.Connection` or :class:`.Engine` object without the need to know which one is being dealt with. - + """ conn = self.contextual_connect() try: @@ -2390,19 +2526,19 @@ class Engine(Connectable, log.Identified): def raw_connection(self): """Return a "raw" DBAPI connection from the connection pool. - + The returned object is a proxied version of the DBAPI connection object used by the underlying driver in use. The object will have all the same behavior as the real DBAPI connection, except that its ``close()`` method will result in the connection being returned to the pool, rather than being closed for real. - + This method provides direct DBAPI connection access for special situations. In most situations, the :class:`.Connection` object should be used, which is procured using the :meth:`.Engine.connect` method. - + """ return self.pool.unique_connection() @@ -2487,7 +2623,6 @@ except ImportError: def __getattr__(self, name): try: - # TODO: no test coverage here return self[name] except KeyError, e: raise AttributeError(e.args[0]) @@ -2575,6 +2710,10 @@ class ResultMetaData(object): context = parent.context dialect = context.dialect typemap = dialect.dbapi_type_map + translate_colname = dialect._translate_colname + + # high precedence key values. 
+ primary_keymap = {} for i, rec in enumerate(metadata): colname = rec[0] @@ -2583,6 +2722,9 @@ class ResultMetaData(object): if dialect.description_encoding: colname = dialect._description_decoder(colname) + if translate_colname: + colname, untranslated = translate_colname(colname) + if context.result_map: try: name, obj, type_ = context.result_map[colname.lower()] @@ -2600,15 +2742,17 @@ class ResultMetaData(object): # indexes as keys. This is only needed for the Python version of # RowProxy (the C version uses a faster path for integer indexes). - keymap[i] = rec - - # Column names as keys - if keymap.setdefault(name.lower(), rec) is not rec: - # We do not raise an exception directly because several - # columns colliding by name is not a problem as long as the - # user does not try to access them (ie use an index directly, - # or the more precise ColumnElement) - keymap[name.lower()] = (processor, obj, None) + primary_keymap[i] = rec + + # populate primary keymap, looking for conflicts. + if primary_keymap.setdefault(name.lower(), rec) is not rec: + # place a record that doesn't have the "index" - this + # is interpreted later as an AmbiguousColumnError, + # but only when actually accessed. Columns + # colliding by name is not a problem if those names + # aren't used; integer and ColumnElement access is always + # unambiguous. + primary_keymap[name.lower()] = (processor, obj, None) if dialect.requires_name_normalize: colname = dialect.normalize_name(colname) @@ -2618,10 +2762,20 @@ class ResultMetaData(object): for o in obj: keymap[o] = rec + if translate_colname and \ + untranslated: + keymap[untranslated] = rec + + # overwrite keymap values with those of the + # high precedence keymap. 
+ keymap.update(primary_keymap) + if parent._echo: context.engine.logger.debug( "Col %r", tuple(x[0] for x in metadata)) + @util.pending_deprecation("0.8", "sqlite dialect uses " + "_translate_colname() now") def _set_keymap_synonym(self, name, origname): """Set a synonym for the given name. @@ -2647,7 +2801,7 @@ class ResultMetaData(object): if key._label and key._label.lower() in map: result = map[key._label.lower()] elif hasattr(key, 'name') and key.name.lower() in map: - # match is only on name. + # match is only on name. result = map[key.name.lower()] # search extra hard to make sure this # isn't a column/label name overlap. @@ -2800,7 +2954,7 @@ class ResultProxy(object): @property def returns_rows(self): """True if this :class:`.ResultProxy` returns rows. - + I.e. if it is legal to call the methods :meth:`~.ResultProxy.fetchone`, :meth:`~.ResultProxy.fetchmany` @@ -2814,12 +2968,12 @@ class ResultProxy(object): """True if this :class:`.ResultProxy` is the result of a executing an expression language compiled :func:`.expression.insert` construct. - + When True, this implies that the :attr:`inserted_primary_key` attribute is accessible, assuming the statement did not include a user defined "returning" construct. - + """ return self.context.isinsert @@ -2867,7 +3021,7 @@ class ResultProxy(object): @util.memoized_property def inserted_primary_key(self): """Return the primary key for the row just inserted. - + The return value is a list of scalar values corresponding to the list of primary key columns in the target table. @@ -2875,7 +3029,7 @@ class ResultProxy(object): This only applies to single row :func:`.insert` constructs which did not explicitly specify :meth:`.Insert.returning`. 
- + Note that primary key columns which specify a server_default clause, or otherwise do not qualify as "autoincrement" diff --git a/libs/sqlalchemy/engine/default.py b/libs/sqlalchemy/engine/default.py index 73bd7fd..5c2d981 100644 --- a/libs/sqlalchemy/engine/default.py +++ b/libs/sqlalchemy/engine/default.py @@ -44,6 +44,7 @@ class DefaultDialect(base.Dialect): postfetch_lastrowid = True implicit_returning = False + supports_native_enum = False supports_native_boolean = False @@ -95,6 +96,10 @@ class DefaultDialect(base.Dialect): # and denormalize_name() must be provided. requires_name_normalize = False + # a hook for SQLite's translation of + # result column names + _translate_colname = None + reflection_options = () def __init__(self, convert_unicode=False, assert_unicode=False, @@ -329,6 +334,9 @@ class DefaultDialect(base.Dialect): def do_execute(self, cursor, statement, parameters, context=None): cursor.execute(statement, parameters) + def do_execute_no_params(self, cursor, statement, context=None): + cursor.execute(statement) + def is_disconnect(self, e, connection, cursor): return False @@ -533,6 +541,10 @@ class DefaultExecutionContext(base.ExecutionContext): return self @util.memoized_property + def no_parameters(self): + return self.execution_options.get("no_parameters", False) + + @util.memoized_property def is_crud(self): return self.isinsert or self.isupdate or self.isdelete diff --git a/libs/sqlalchemy/engine/reflection.py b/libs/sqlalchemy/engine/reflection.py index f5911f3..71d97e6 100644 --- a/libs/sqlalchemy/engine/reflection.py +++ b/libs/sqlalchemy/engine/reflection.py @@ -317,7 +317,7 @@ class Inspector(object): info_cache=self.info_cache, **kw) return indexes - def reflecttable(self, table, include_columns, exclude_columns=None): + def reflecttable(self, table, include_columns, exclude_columns=()): """Given a Table object, load its internal constructs based on introspection. 
This is the underlying method used by most dialects to produce @@ -414,9 +414,12 @@ class Inspector(object): # Primary keys pk_cons = self.get_pk_constraint(table_name, schema, **tblkw) if pk_cons: + pk_cols = [table.c[pk] + for pk in pk_cons['constrained_columns'] + if pk in table.c and pk not in exclude_columns + ] + [pk for pk in table.primary_key if pk.key in exclude_columns] primary_key_constraint = sa_schema.PrimaryKeyConstraint(name=pk_cons.get('name'), - *[table.c[pk] for pk in pk_cons['constrained_columns'] - if pk in table.c] + *pk_cols ) table.append_constraint(primary_key_constraint) diff --git a/libs/sqlalchemy/engine/strategies.py b/libs/sqlalchemy/engine/strategies.py index 7b2da68..4d5a4b3 100644 --- a/libs/sqlalchemy/engine/strategies.py +++ b/libs/sqlalchemy/engine/strategies.py @@ -108,7 +108,8 @@ class DefaultEngineStrategy(EngineStrategy): 'timeout': 'pool_timeout', 'recycle': 'pool_recycle', 'events':'pool_events', - 'use_threadlocal':'pool_threadlocal'} + 'use_threadlocal':'pool_threadlocal', + 'reset_on_return':'pool_reset_on_return'} for k in util.get_cls_kwargs(poolclass): tk = translate.get(k, k) if tk in kwargs: @@ -226,6 +227,9 @@ class MockEngineStrategy(EngineStrategy): def contextual_connect(self, **kwargs): return self + def execution_options(self, **kw): + return self + def compiler(self, statement, parameters, **kwargs): return self._dialect.compiler( statement, parameters, engine=self, **kwargs) diff --git a/libs/sqlalchemy/event.py b/libs/sqlalchemy/event.py index 9cc3139..cd70b3a 100644 --- a/libs/sqlalchemy/event.py +++ b/libs/sqlalchemy/event.py @@ -13,12 +13,12 @@ NO_RETVAL = util.symbol('NO_RETVAL') def listen(target, identifier, fn, *args, **kw): """Register a listener function for the given target. 
- + e.g.:: - + from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint - + def unique_constraint_name(const, table): const.name = "uq_%s_%s" % ( table.name, @@ -41,12 +41,12 @@ def listen(target, identifier, fn, *args, **kw): def listens_for(target, identifier, *args, **kw): """Decorate a function as a listener for the given target + identifier. - + e.g.:: - + from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint - + @event.listens_for(UniqueConstraint, "after_parent_attach") def unique_constraint_name(const, table): const.name = "uq_%s_%s" % ( @@ -205,12 +205,14 @@ class _DispatchDescriptor(object): def insert(self, obj, target, propagate): assert isinstance(target, type), \ "Class-level Event targets must be classes." - stack = [target] while stack: cls = stack.pop(0) stack.extend(cls.__subclasses__()) - self._clslevel[cls].insert(0, obj) + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + self._clslevel[cls].insert(0, obj) def append(self, obj, target, propagate): assert isinstance(target, type), \ @@ -220,7 +222,20 @@ class _DispatchDescriptor(object): while stack: cls = stack.pop(0) stack.extend(cls.__subclasses__()) - self._clslevel[cls].append(obj) + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + self._clslevel[cls].append(obj) + + def update_subclass(self, target): + clslevel = self._clslevel[target] + for cls in target.__mro__[1:]: + if cls in self._clslevel: + clslevel.extend([ + fn for fn + in self._clslevel[cls] + if fn not in clslevel + ]) def remove(self, obj, target): stack = [target] @@ -252,6 +267,8 @@ class _ListenerCollection(object): _exec_once = False def __init__(self, parent, target_cls): + if target_cls not in parent._clslevel: + parent.update_subclass(target_cls) self.parent_listeners = parent._clslevel[target_cls] self.name = parent.__name__ self.listeners = [] diff --git a/libs/sqlalchemy/exc.py 
b/libs/sqlalchemy/exc.py index 64f25a2..91ffc28 100644 --- a/libs/sqlalchemy/exc.py +++ b/libs/sqlalchemy/exc.py @@ -162,7 +162,7 @@ UnmappedColumnError = None class StatementError(SQLAlchemyError): """An error occurred during execution of a SQL statement. - :class:`.StatementError` wraps the exception raised + :class:`StatementError` wraps the exception raised during execution, and features :attr:`.statement` and :attr:`.params` attributes which supply context regarding the specifics of the statement which had an issue. @@ -172,6 +172,15 @@ class StatementError(SQLAlchemyError): """ + statement = None + """The string SQL statement being invoked when this exception occurred.""" + + params = None + """The parameter list being used when this exception occurred.""" + + orig = None + """The DBAPI exception object.""" + def __init__(self, message, statement, params, orig): SQLAlchemyError.__init__(self, message) self.statement = statement @@ -192,21 +201,21 @@ class StatementError(SQLAlchemyError): class DBAPIError(StatementError): """Raised when the execution of a database operation fails. - ``DBAPIError`` wraps exceptions raised by the DB-API underlying the + Wraps exceptions raised by the DB-API underlying the database operation. Driver-specific implementations of the standard DB-API exception types are wrapped by matching sub-types of SQLAlchemy's - ``DBAPIError`` when possible. DB-API's ``Error`` type maps to - ``DBAPIError`` in SQLAlchemy, otherwise the names are identical. Note + :class:`DBAPIError` when possible. DB-API's ``Error`` type maps to + :class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note that there is no guarantee that different DB-API implementations will raise the same exception type for any given error condition. 
- :class:`.DBAPIError` features :attr:`.statement` - and :attr:`.params` attributes which supply context regarding + :class:`DBAPIError` features :attr:`~.StatementError.statement` + and :attr:`~.StatementError.params` attributes which supply context regarding the specifics of the statement which had an issue, for the typical case when the error was raised within the context of emitting a SQL statement. - The wrapped exception object is available in the :attr:`.orig` attribute. + The wrapped exception object is available in the :attr:`~.StatementError.orig` attribute. Its type and properties are DB-API implementation specific. """ diff --git a/libs/sqlalchemy/ext/declarative.py b/libs/sqlalchemy/ext/declarative.py index 891130a..faf575d 100755 --- a/libs/sqlalchemy/ext/declarative.py +++ b/libs/sqlalchemy/ext/declarative.py @@ -1213,6 +1213,12 @@ def _as_declarative(cls, classname, dict_): del our_stuff[key] cols = sorted(cols, key=lambda c:c._creation_order) table = None + + if hasattr(cls, '__table_cls__'): + table_cls = util.unbound_method_to_callable(cls.__table_cls__) + else: + table_cls = Table + if '__table__' not in dict_: if tablename is not None: @@ -1230,7 +1236,7 @@ def _as_declarative(cls, classname, dict_): if autoload: table_kw['autoload'] = True - cls.__table__ = table = Table(tablename, cls.metadata, + cls.__table__ = table = table_cls(tablename, cls.metadata, *(tuple(cols) + tuple(args)), **table_kw) else: diff --git a/libs/sqlalchemy/ext/hybrid.py b/libs/sqlalchemy/ext/hybrid.py index 086ec90..8734181 100644 --- a/libs/sqlalchemy/ext/hybrid.py +++ b/libs/sqlalchemy/ext/hybrid.py @@ -11,30 +11,30 @@ class level and at the instance level. The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of method decorator, is around 50 lines of code and has almost no dependencies on the rest -of SQLAlchemy. It can in theory work with any class-level expression generator. +of SQLAlchemy. 
It can, in theory, work with any descriptor-based expression +system. -Consider a table ``interval`` as below:: - - from sqlalchemy import MetaData, Table, Column, Integer - - metadata = MetaData() - - interval_table = Table('interval', metadata, - Column('id', Integer, primary_key=True), - Column('start', Integer, nullable=False), - Column('end', Integer, nullable=False) - ) - -We can define higher level functions on mapped classes that produce SQL -expressions at the class level, and Python expression evaluation at the -instance level. Below, each function decorated with :func:`.hybrid_method` -or :func:`.hybrid_property` may receive ``self`` as an instance of the class, -or as the class itself:: +Consider a mapping ``Interval``, representing integer ``start`` and ``end`` +values. We can define higher level functions on mapped classes that produce +SQL expressions at the class level, and Python expression evaluation at the +instance level. Below, each function decorated with :class:`.hybrid_method` or +:class:`.hybrid_property` may receive ``self`` as an instance of the class, or +as the class itself:: + from sqlalchemy import Column, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.orm import Session, aliased from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method - from sqlalchemy.orm import mapper, Session, aliased + + Base = declarative_base() + + class Interval(Base): + __tablename__ = 'interval' + + id = Column(Integer, primary_key=True) + start = Column(Integer, nullable=False) + end = Column(Integer, nullable=False) - class Interval(object): def __init__(self, start, end): self.start = start self.end = end @@ -51,8 +51,6 @@ or as the class itself:: def intersects(self, other): return self.contains(other.start) | self.contains(other.end) - mapper(Interval, interval_table) - Above, the ``length`` property returns the difference between the ``end`` and ``start`` attributes. 
With an instance of ``Interval``, this subtraction occurs in Python, using normal Python descriptor mechanics:: @@ -60,10 +58,11 @@ in Python, using normal Python descriptor mechanics:: >>> i1 = Interval(5, 10) >>> i1.length 5 - -At the class level, the usual descriptor behavior of returning the descriptor -itself is modified by :class:`.hybrid_property`, to instead evaluate the function -body given the ``Interval`` class as the argument:: + +When dealing with the ``Interval`` class itself, the :class:`.hybrid_property` +descriptor evaluates the function body given the ``Interval`` class as +the argument, which when evaluated with SQLAlchemy expression mechanics +returns a new SQL expression:: >>> print Interval.length interval."end" - interval.start @@ -83,9 +82,10 @@ locate attributes, so can also be used with hybrid attributes:: FROM interval WHERE interval."end" - interval.start = :param_1 -The ``contains()`` and ``intersects()`` methods are decorated with :class:`.hybrid_method`. -This decorator applies the same idea to methods which accept -zero or more arguments. The above methods return boolean values, and take advantage +The ``Interval`` class example also illustrates two methods, ``contains()`` and ``intersects()``, +decorated with :class:`.hybrid_method`. +This decorator applies the same idea to methods that :class:`.hybrid_property` applies +to attributes. The methods return boolean values, and take advantage of the Python ``|`` and ``&`` bitwise operators to produce equivalent instance-level and SQL expression-level boolean behavior:: @@ -368,7 +368,12 @@ SQL expression versus SQL expression:: >>> sw1 = aliased(SearchWord) >>> sw2 = aliased(SearchWord) - >>> print Session().query(sw1.word_insensitive, sw2.word_insensitive).filter(sw1.word_insensitive > sw2.word_insensitive) + >>> print Session().query( + ... sw1.word_insensitive, + ... sw2.word_insensitive).\\ + ... filter( + ... sw1.word_insensitive > sw2.word_insensitive + ... 
) SELECT lower(searchword_1.word) AS lower_1, lower(searchword_2.word) AS lower_2 FROM searchword AS searchword_1, searchword AS searchword_2 WHERE lower(searchword_1.word) > lower(searchword_2.word) diff --git a/libs/sqlalchemy/ext/orderinglist.py b/libs/sqlalchemy/ext/orderinglist.py index 9847861..3895725 100644 --- a/libs/sqlalchemy/ext/orderinglist.py +++ b/libs/sqlalchemy/ext/orderinglist.py @@ -184,12 +184,11 @@ class OrderingList(list): This implementation relies on the list starting in the proper order, so be **sure** to put an ``order_by`` on your relationship. - ordering_attr + :param ordering_attr: Name of the attribute that stores the object's order in the relationship. - ordering_func - Optional. A function that maps the position in the Python list to a + :param ordering_func: Optional. A function that maps the position in the Python list to a value to store in the ``ordering_attr``. Values returned are usually (but need not be!) integers. @@ -202,7 +201,7 @@ class OrderingList(list): like stepped numbering, alphabetical and Fibonacci numbering, see the unit tests. - reorder_on_append + :param reorder_on_append: Default False. When appending an object with an existing (non-None) ordering value, that value will be left untouched unless ``reorder_on_append`` is true. 
This is an optimization to avoid a diff --git a/libs/sqlalchemy/orm/collections.py b/libs/sqlalchemy/orm/collections.py index 7872715..160fac8 100644 --- a/libs/sqlalchemy/orm/collections.py +++ b/libs/sqlalchemy/orm/collections.py @@ -112,12 +112,32 @@ from sqlalchemy.sql import expression from sqlalchemy import schema, util, exc as sa_exc + __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', 'attribute_mapped_collection'] __instrumentation_mutex = util.threading.Lock() +class _SerializableColumnGetter(object): + def __init__(self, colkeys): + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__(self): + return _SerializableColumnGetter, (self.colkeys,) + + def __call__(self, value): + state = instance_state(value) + m = _state_mapper(state) + key = [m._get_state_attr_by_column( + state, state.dict, + m.mapped_table.columns[k]) + for k in self.colkeys] + if self.composite: + return tuple(key) + else: + return key[0] def column_mapped_collection(mapping_spec): """A dictionary-based collection type with column-based keying. @@ -131,25 +151,27 @@ def column_mapped_collection(mapping_spec): after a session flush. 
""" + global _state_mapper, instance_state from sqlalchemy.orm.util import _state_mapper from sqlalchemy.orm.attributes import instance_state - cols = [expression._only_column_elements(q, "mapping_spec") - for q in util.to_list(mapping_spec)] - if len(cols) == 1: - def keyfunc(value): - state = instance_state(value) - m = _state_mapper(state) - return m._get_state_attr_by_column(state, state.dict, cols[0]) - else: - mapping_spec = tuple(cols) - def keyfunc(value): - state = instance_state(value) - m = _state_mapper(state) - return tuple(m._get_state_attr_by_column(state, state.dict, c) - for c in mapping_spec) + cols = [c.key for c in [ + expression._only_column_elements(q, "mapping_spec") + for q in util.to_list(mapping_spec)]] + keyfunc = _SerializableColumnGetter(cols) return lambda: MappedCollection(keyfunc) +class _SerializableAttrGetter(object): + def __init__(self, name): + self.name = name + self.getter = operator.attrgetter(name) + + def __call__(self, target): + return self.getter(target) + + def __reduce__(self): + return _SerializableAttrGetter, (self.name, ) + def attribute_mapped_collection(attr_name): """A dictionary-based collection type with attribute-based keying. @@ -163,7 +185,8 @@ def attribute_mapped_collection(attr_name): after a session flush. 
""" - return lambda: MappedCollection(operator.attrgetter(attr_name)) + getter = _SerializableAttrGetter(attr_name) + return lambda: MappedCollection(getter) def mapped_collection(keyfunc): @@ -814,6 +837,7 @@ def _instrument_class(cls): methods[name] = None, None, after # apply ABC auto-decoration to methods that need it + for method, decorator in decorators.items(): fn = getattr(cls, method, None) if (fn and method not in methods and @@ -1465,3 +1489,13 @@ class MappedCollection(dict): incoming_key, value, new_key)) yield value _convert = collection.converter(_convert) + +# ensure instrumentation is associated with +# these built-in classes; if a user-defined class +# subclasses these and uses @internally_instrumented, +# the superclass is otherwise not instrumented. +# see [ticket:2406]. +_instrument_class(MappedCollection) +_instrument_class(InstrumentedList) +_instrument_class(InstrumentedSet) + diff --git a/libs/sqlalchemy/orm/mapper.py b/libs/sqlalchemy/orm/mapper.py index 4c952c1..e96b754 100644 --- a/libs/sqlalchemy/orm/mapper.py +++ b/libs/sqlalchemy/orm/mapper.py @@ -1452,12 +1452,19 @@ class Mapper(object): return result def _is_userland_descriptor(self, obj): - return not isinstance(obj, - (MapperProperty, attributes.QueryableAttribute)) and \ - hasattr(obj, '__get__') and not \ - isinstance(obj.__get__(None, obj), - attributes.QueryableAttribute) - + if isinstance(obj, (MapperProperty, + attributes.QueryableAttribute)): + return False + elif not hasattr(obj, '__get__'): + return False + else: + obj = util.unbound_method_to_callable(obj) + if isinstance( + obj.__get__(None, obj), + attributes.QueryableAttribute + ): + return False + return True def _should_exclude(self, name, assigned_name, local, column): """determine whether a particular property should be implicitly @@ -1875,501 +1882,6 @@ class Mapper(object): self._memoized_values[key] = value = callable_() return value - def _post_update(self, states, uowtransaction, post_update_cols): - """Issue 
UPDATE statements on behalf of a relationship() which - specifies post_update. - - """ - cached_connections = util.PopulateDict( - lambda conn:conn.execution_options( - compiled_cache=self._compiled_cache - )) - - # if session has a connection callable, - # organize individual states with the connection - # to use for update - if uowtransaction.session.connection_callable: - connection_callable = \ - uowtransaction.session.connection_callable - else: - connection = uowtransaction.transaction.connection(self) - connection_callable = None - - tups = [] - for state in _sort_states(states): - if connection_callable: - conn = connection_callable(self, state.obj()) - else: - conn = connection - - mapper = _state_mapper(state) - - tups.append((state, state.dict, mapper, conn)) - - table_to_mapper = self._sorted_tables - - for table in table_to_mapper: - update = [] - - for state, state_dict, mapper, connection in tups: - if table not in mapper._pks_by_table: - continue - - pks = mapper._pks_by_table[table] - params = {} - hasdata = False - - for col in mapper._cols_by_table[table]: - if col in pks: - params[col._label] = \ - mapper._get_state_attr_by_column( - state, - state_dict, col) - elif col in post_update_cols: - prop = mapper._columntoproperty[col] - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE) - if history.added: - value = history.added[0] - params[col.key] = value - hasdata = True - if hasdata: - update.append((state, state_dict, params, mapper, - connection)) - - if update: - mapper = table_to_mapper[table] - - def update_stmt(): - clause = sql.and_() - - for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) - - return table.update(clause) - - statement = self._memo(('post_update', table), update_stmt) - - # execute each UPDATE in the order according to the original - # list of states to guarantee row access order, but - # also group them into common 
(connection, cols) sets - # to support executemany(). - for key, grouper in groupby( - update, lambda rec: (rec[4], rec[2].keys()) - ): - multiparams = [params for state, state_dict, - params, mapper, conn in grouper] - cached_connections[connection].\ - execute(statement, multiparams) - - def _save_obj(self, states, uowtransaction, single=False): - """Issue ``INSERT`` and/or ``UPDATE`` statements for a list - of objects. - - This is called within the context of a UOWTransaction during a - flush operation, given a list of states to be flushed. The - base mapper in an inheritance hierarchy handles the inserts/ - updates for all descendant mappers. - - """ - - # if batch=false, call _save_obj separately for each object - if not single and not self.batch: - for state in _sort_states(states): - self._save_obj([state], - uowtransaction, - single=True) - return - - # if session has a connection callable, - # organize individual states with the connection - # to use for insert/update - if uowtransaction.session.connection_callable: - connection_callable = \ - uowtransaction.session.connection_callable - else: - connection = uowtransaction.transaction.connection(self) - connection_callable = None - - tups = [] - - for state in _sort_states(states): - if connection_callable: - conn = connection_callable(self, state.obj()) - else: - conn = connection - - has_identity = bool(state.key) - mapper = _state_mapper(state) - instance_key = state.key or mapper._identity_key_from_state(state) - - row_switch = None - - # call before_XXX extensions - if not has_identity: - mapper.dispatch.before_insert(mapper, conn, state) - else: - mapper.dispatch.before_update(mapper, conn, state) - - # detect if we have a "pending" instance (i.e. has - # no instance_key attached to it), and another instance - # with the same identity key already exists as persistent. - # convert to an UPDATE if so. 
- if not has_identity and \ - instance_key in uowtransaction.session.identity_map: - instance = \ - uowtransaction.session.identity_map[instance_key] - existing = attributes.instance_state(instance) - if not uowtransaction.is_deleted(existing): - raise orm_exc.FlushError( - "New instance %s with identity key %s conflicts " - "with persistent instance %s" % - (state_str(state), instance_key, - state_str(existing))) - - self._log_debug( - "detected row switch for identity %s. " - "will update %s, remove %s from " - "transaction", instance_key, - state_str(state), state_str(existing)) - - # remove the "delete" flag from the existing element - uowtransaction.remove_state_actions(existing) - row_switch = existing - - tups.append( - (state, state.dict, mapper, conn, - has_identity, instance_key, row_switch) - ) - - # dictionary of connection->connection_with_cache_options. - cached_connections = util.PopulateDict( - lambda conn:conn.execution_options( - compiled_cache=self._compiled_cache - )) - - table_to_mapper = self._sorted_tables - - for table in table_to_mapper: - insert = [] - update = [] - - for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in tups: - if table not in mapper._pks_by_table: - continue - - pks = mapper._pks_by_table[table] - - isinsert = not has_identity and not row_switch - - params = {} - value_params = {} - - if isinsert: - has_all_pks = True - for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col: - params[col.key] = \ - mapper.version_id_generator(None) - else: - # pull straight from the dict for - # pending objects - prop = mapper._columntoproperty[col] - value = state_dict.get(prop.key, None) - - if value is None: - if col in pks: - has_all_pks = False - elif col.default is None and \ - col.server_default is None: - params[col.key] = value - - elif isinstance(value, sql.ClauseElement): - value_params[col] = value - else: - params[col.key] = value - - insert.append((state, state_dict, 
params, mapper, - connection, value_params, has_all_pks)) - else: - hasdata = False - for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col: - params[col._label] = \ - mapper._get_committed_state_attr_by_column( - row_switch or state, - row_switch and row_switch.dict - or state_dict, - col) - - prop = mapper._columntoproperty[col] - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE - ) - if history.added: - params[col.key] = history.added[0] - hasdata = True - else: - params[col.key] = \ - mapper.version_id_generator( - params[col._label]) - - # HACK: check for history, in case the - # history is only - # in a different table than the one - # where the version_id_col is. - for prop in mapper._columntoproperty.\ - itervalues(): - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE) - if history.added: - hasdata = True - else: - prop = mapper._columntoproperty[col] - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE) - if history.added: - if isinstance(history.added[0], - sql.ClauseElement): - value_params[col] = history.added[0] - else: - value = history.added[0] - params[col.key] = value - - if col in pks: - if history.deleted and \ - not row_switch: - # if passive_updates and sync detected - # this was a pk->pk sync, use the new - # value to locate the row, since the - # DB would already have set this - if ("pk_cascaded", state, col) in \ - uowtransaction.\ - attributes: - value = history.added[0] - params[col._label] = value - else: - # use the old value to - # locate the row - value = history.deleted[0] - params[col._label] = value - hasdata = True - else: - # row switch logic can reach us here - # remove the pk from the update params - # so the update doesn't - # attempt to include the pk in the - # update statement - del params[col.key] - value = history.added[0] - params[col._label] = value - if value is None 
and hasdata: - raise sa_exc.FlushError( - "Can't update table " - "using NULL for primary key " - "value") - else: - hasdata = True - elif col in pks: - value = state.manager[prop.key].\ - impl.get(state, state_dict) - if value is None: - raise sa_exc.FlushError( - "Can't update table " - "using NULL for primary " - "key value") - params[col._label] = value - if hasdata: - update.append((state, state_dict, params, mapper, - connection, value_params)) - - if update: - mapper = table_to_mapper[table] - - needs_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) - - def update_stmt(): - clause = sql.and_() - - for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) - - if needs_version_id: - clause.clauses.append(mapper.version_id_col ==\ - sql.bindparam(mapper.version_id_col._label, - type_=col.type)) - - return table.update(clause) - - statement = self._memo(('update', table), update_stmt) - - rows = 0 - for state, state_dict, params, mapper, \ - connection, value_params in update: - - if value_params: - c = connection.execute( - statement.values(value_params), - params) - else: - c = cached_connections[connection].\ - execute(statement, params) - - mapper._postfetch( - uowtransaction, - table, - state, - state_dict, - c.context.prefetch_cols, - c.context.postfetch_cols, - c.context.compiled_parameters[0], - value_params) - rows += c.rowcount - - if connection.dialect.supports_sane_rowcount: - if rows != len(update): - raise orm_exc.StaleDataError( - "UPDATE statement on table '%s' expected to update %d row(s); " - "%d were matched." % - (table.description, len(update), rows)) - - elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." 
% - c.dialect.dialect_description, - stacklevel=12) - - if insert: - statement = self._memo(('insert', table), table.insert) - - for (connection, pkeys, hasvalue, has_all_pks), \ - records in groupby(insert, - lambda rec: (rec[4], - rec[2].keys(), - bool(rec[5]), - rec[6]) - ): - if has_all_pks and not hasvalue: - records = list(records) - multiparams = [rec[2] for rec in records] - c = cached_connections[connection].\ - execute(statement, multiparams) - - for (state, state_dict, params, mapper, - conn, value_params, has_all_pks), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): - mapper._postfetch( - uowtransaction, - table, - state, - state_dict, - c.context.prefetch_cols, - c.context.postfetch_cols, - last_inserted_params, - value_params) - - else: - for state, state_dict, params, mapper, \ - connection, value_params, \ - has_all_pks in records: - - if value_params: - result = connection.execute( - statement.values(value_params), - params) - else: - result = cached_connections[connection].\ - execute(statement, params) - - primary_key = result.context.inserted_primary_key - - if primary_key is not None: - # set primary key attributes - for pk, col in zip(primary_key, - mapper._pks_by_table[table]): - prop = mapper._columntoproperty[col] - if state_dict.get(prop.key) is None: - # TODO: would rather say: - #state_dict[prop.key] = pk - mapper._set_state_attr_by_column( - state, - state_dict, - col, pk) - - mapper._postfetch( - uowtransaction, - table, - state, - state_dict, - result.context.prefetch_cols, - result.context.postfetch_cols, - result.context.compiled_parameters[0], - value_params) - - - for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in tups: - - if mapper._readonly_props: - readonly = state.unmodified_intersection( - [p.key for p in mapper._readonly_props - if p.expire_on_flush or p.key not in state.dict] - ) - if readonly: - state.expire_attributes(state.dict, readonly) - - # if 
eager_defaults option is enabled, - # refresh whatever has been expired. - if self.eager_defaults and state.unloaded: - state.key = self._identity_key_from_state(state) - uowtransaction.session.query(self)._load_on_ident( - state.key, refresh_state=state, - only_load_props=state.unloaded) - - # call after_XXX extensions - if not has_identity: - mapper.dispatch.after_insert(mapper, connection, state) - else: - mapper.dispatch.after_update(mapper, connection, state) - - def _postfetch(self, uowtransaction, table, - state, dict_, prefetch_cols, postfetch_cols, - params, value_params): - """During a flush, expire attributes in need of newly - persisted database state.""" - - if self.version_id_col is not None: - prefetch_cols = list(prefetch_cols) + [self.version_id_col] - - for c in prefetch_cols: - if c.key in params and c in self._columntoproperty: - self._set_state_attr_by_column(state, dict_, c, params[c.key]) - - if postfetch_cols: - state.expire_attributes(state.dict, - [self._columntoproperty[c].key - for c in postfetch_cols if c in - self._columntoproperty] - ) - - # synchronize newly inserted ids from one table to the next - # TODO: this still goes a little too often. would be nice to - # have definitive list of "columns that changed" here - for m, equated_pairs in self._table_to_equated[table]: - sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - self.passive_updates) - @util.memoized_property def _table_to_equated(self): """memoized map of tables to collections of columns to be @@ -2387,128 +1899,6 @@ class Mapper(object): return result - def _delete_obj(self, states, uowtransaction): - """Issue ``DELETE`` statements for a list of objects. - - This is called within the context of a UOWTransaction during a - flush operation. 
- - """ - if uowtransaction.session.connection_callable: - connection_callable = \ - uowtransaction.session.connection_callable - else: - connection = uowtransaction.transaction.connection(self) - connection_callable = None - - tups = [] - cached_connections = util.PopulateDict( - lambda conn:conn.execution_options( - compiled_cache=self._compiled_cache - )) - - for state in _sort_states(states): - mapper = _state_mapper(state) - - if connection_callable: - conn = connection_callable(self, state.obj()) - else: - conn = connection - - mapper.dispatch.before_delete(mapper, conn, state) - - tups.append((state, - state.dict, - _state_mapper(state), - bool(state.key), - conn)) - - table_to_mapper = self._sorted_tables - - for table in reversed(table_to_mapper.keys()): - delete = util.defaultdict(list) - for state, state_dict, mapper, has_identity, connection in tups: - if not has_identity or table not in mapper._pks_by_table: - continue - - params = {} - delete[connection].append(params) - for col in mapper._pks_by_table[table]: - params[col.key] = \ - value = \ - mapper._get_state_attr_by_column( - state, state_dict, col) - if value is None: - raise sa_exc.FlushError( - "Can't delete from table " - "using NULL for primary " - "key value") - - if mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col): - params[mapper.version_id_col.key] = \ - mapper._get_committed_state_attr_by_column( - state, state_dict, - mapper.version_id_col) - - mapper = table_to_mapper[table] - need_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) - - def delete_stmt(): - clause = sql.and_() - for col in mapper._pks_by_table[table]: - clause.clauses.append( - col == sql.bindparam(col.key, type_=col.type)) - - if need_version_id: - clause.clauses.append( - mapper.version_id_col == - sql.bindparam( - mapper.version_id_col.key, - type_=mapper.version_id_col.type - ) - ) - - return table.delete(clause) - - for 
connection, del_objects in delete.iteritems(): - statement = self._memo(('delete', table), delete_stmt) - rows = -1 - - connection = cached_connections[connection] - - if need_version_id and \ - not connection.dialect.supports_sane_multi_rowcount: - # TODO: need test coverage for this [ticket:1761] - if connection.dialect.supports_sane_rowcount: - rows = 0 - # execute deletes individually so that versioned - # rows can be verified - for params in del_objects: - c = connection.execute(statement, params) - rows += c.rowcount - else: - util.warn( - "Dialect %s does not support deleted rowcount " - "- versioning cannot be verified." % - connection.dialect.dialect_description, - stacklevel=12) - connection.execute(statement, del_objects) - else: - c = connection.execute(statement, del_objects) - if connection.dialect.supports_sane_multi_rowcount: - rows = c.rowcount - - if rows != -1 and rows != len(del_objects): - raise orm_exc.StaleDataError( - "DELETE statement on table '%s' expected to delete %d row(s); " - "%d were matched." % - (table.description, len(del_objects), c.rowcount) - ) - - for state, state_dict, mapper, has_identity, connection in tups: - mapper.dispatch.after_delete(mapper, connection, state) def _instance_processor(self, context, path, reduced_path, adapter, polymorphic_from=None, @@ -2518,6 +1908,12 @@ class Mapper(object): """Produce a mapper level row processor callable which processes rows into mapped instances.""" + # note that this method, most of which exists in a closure + # called _instance(), resists being broken out, as + # attempts to do so tend to add significant function + # call overhead. _instance() is the most + # performance-critical section in the whole ORM. 
+ pk_cols = self.primary_key if polymorphic_from or refresh_state: @@ -2961,13 +2357,6 @@ def _event_on_resurrect(state): state, state.dict, col, val) -def _sort_states(states): - pending = set(states) - persistent = set(s for s in pending if s.key is not None) - pending.difference_update(persistent) - return sorted(pending, key=operator.attrgetter("insert_order")) + \ - sorted(persistent, key=lambda q:q.key[1]) - class _ColumnMapping(util.py25_dict): """Error reporting helper for mapper._columntoproperty.""" diff --git a/libs/sqlalchemy/orm/persistence.py b/libs/sqlalchemy/orm/persistence.py new file mode 100644 index 0000000..55b9bf8 --- /dev/null +++ b/libs/sqlalchemy/orm/persistence.py @@ -0,0 +1,777 @@ +# orm/persistence.py +# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""private module containing functions used to emit INSERT, UPDATE +and DELETE statements on behalf of a :class:`.Mapper` and its descending +mappers. + +The functions here are called only by the unit of work functions +in unitofwork.py. + +""" + +import operator +from itertools import groupby + +from sqlalchemy import sql, util, exc as sa_exc +from sqlalchemy.orm import attributes, sync, \ + exc as orm_exc + +from sqlalchemy.orm.util import _state_mapper, state_str + +def save_obj(base_mapper, states, uowtransaction, single=False): + """Issue ``INSERT`` and/or ``UPDATE`` statements for a list + of objects. + + This is called within the context of a UOWTransaction during a + flush operation, given a list of states to be flushed. The + base mapper in an inheritance hierarchy handles the inserts/ + updates for all descendant mappers. 
+ + """ + + # if batch=false, call _save_obj separately for each object + if not single and not base_mapper.batch: + for state in _sort_states(states): + save_obj(base_mapper, [state], uowtransaction, single=True) + return + + states_to_insert, states_to_update = _organize_states_for_save( + base_mapper, + states, + uowtransaction) + + cached_connections = _cached_connection_dict(base_mapper) + + for table, mapper in base_mapper._sorted_tables.iteritems(): + insert = _collect_insert_commands(base_mapper, uowtransaction, + table, states_to_insert) + + update = _collect_update_commands(base_mapper, uowtransaction, + table, states_to_update) + + if update: + _emit_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) + + if insert: + _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, + table, insert) + + _finalize_insert_update_commands(base_mapper, uowtransaction, + states_to_insert, states_to_update) + +def post_update(base_mapper, states, uowtransaction, post_update_cols): + """Issue UPDATE statements on behalf of a relationship() which + specifies post_update. + + """ + cached_connections = _cached_connection_dict(base_mapper) + + states_to_update = _organize_states_for_post_update( + base_mapper, + states, uowtransaction) + + + for table, mapper in base_mapper._sorted_tables.iteritems(): + update = _collect_post_update_commands(base_mapper, uowtransaction, + table, states_to_update, + post_update_cols) + + if update: + _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) + +def delete_obj(base_mapper, states, uowtransaction): + """Issue ``DELETE`` statements for a list of objects. + + This is called within the context of a UOWTransaction during a + flush operation. 
+ + """ + + cached_connections = _cached_connection_dict(base_mapper) + + states_to_delete = _organize_states_for_delete( + base_mapper, + states, + uowtransaction) + + table_to_mapper = base_mapper._sorted_tables + + for table in reversed(table_to_mapper.keys()): + delete = _collect_delete_commands(base_mapper, uowtransaction, + table, states_to_delete) + + mapper = table_to_mapper[table] + + _emit_delete_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, delete) + + for state, state_dict, mapper, has_identity, connection \ + in states_to_delete: + mapper.dispatch.after_delete(mapper, connection, state) + +def _organize_states_for_save(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for INSERT or + UPDATE. + + This includes splitting out into distinct lists for + each, calling before_insert/before_update, obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state, + and the identity flag. + + """ + + states_to_insert = [] + states_to_update = [] + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, + states): + + has_identity = bool(state.key) + instance_key = state.key or mapper._identity_key_from_state(state) + + row_switch = None + + # call before_XXX extensions + if not has_identity: + mapper.dispatch.before_insert(mapper, connection, state) + else: + mapper.dispatch.before_update(mapper, connection, state) + + # detect if we have a "pending" instance (i.e. has + # no instance_key attached to it), and another instance + # with the same identity key already exists as persistent. + # convert to an UPDATE if so. 
+ if not has_identity and \ + instance_key in uowtransaction.session.identity_map: + instance = \ + uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + if not uowtransaction.is_deleted(existing): + raise orm_exc.FlushError( + "New instance %s with identity key %s conflicts " + "with persistent instance %s" % + (state_str(state), instance_key, + state_str(existing))) + + base_mapper._log_debug( + "detected row switch for identity %s. " + "will update %s, remove %s from " + "transaction", instance_key, + state_str(state), state_str(existing)) + + # remove the "delete" flag from the existing element + uowtransaction.remove_state_actions(existing) + row_switch = existing + + if not has_identity and not row_switch: + states_to_insert.append( + (state, dict_, mapper, connection, + has_identity, instance_key, row_switch) + ) + else: + states_to_update.append( + (state, dict_, mapper, connection, + has_identity, instance_key, row_switch) + ) + + return states_to_insert, states_to_update + +def _organize_states_for_post_update(base_mapper, states, + uowtransaction): + """Make an initial pass across a set of states for UPDATE + corresponding to post_update. + + This includes obtaining key information for each state + including its dictionary, mapper, the connection to use for + the execution per state. + + """ + return list(_connections_for_states(base_mapper, uowtransaction, + states)) + +def _organize_states_for_delete(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for DELETE. + + This includes calling out before_delete and obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state. 
+ + """ + states_to_delete = [] + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, + states): + + mapper.dispatch.before_delete(mapper, connection, state) + + states_to_delete.append((state, dict_, mapper, + bool(state.key), connection)) + return states_to_delete + +def _collect_insert_commands(base_mapper, uowtransaction, table, + states_to_insert): + """Identify sets of values to use in INSERT statements for a + list of states. + + """ + insert = [] + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_insert: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + params = {} + value_params = {} + + has_all_pks = True + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col.key] = mapper.version_id_generator(None) + else: + # pull straight from the dict for + # pending objects + prop = mapper._columntoproperty[col] + value = state_dict.get(prop.key, None) + + if value is None: + if col in pks: + has_all_pks = False + elif col.default is None and \ + col.server_default is None: + params[col.key] = value + + elif isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value + + insert.append((state, state_dict, params, mapper, + connection, value_params, has_all_pks)) + return insert + +def _collect_update_commands(base_mapper, uowtransaction, + table, states_to_update): + """Identify sets of values to use in UPDATE statements for a + list of states. + + This function works intricately with the history system + to determine exactly what values should be updated + as well as how the row should be matched within an UPDATE + statement. Includes some tricky scenarios where the primary + key of an object might have been changed. 
+ + """ + + update = [] + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_update: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + params = {} + value_params = {} + + hasdata = hasnull = False + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col._label] = \ + mapper._get_committed_state_attr_by_column( + row_switch or state, + row_switch and row_switch.dict + or state_dict, + col) + + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE + ) + if history.added: + params[col.key] = history.added[0] + hasdata = True + else: + params[col.key] = mapper.version_id_generator( + params[col._label]) + + # HACK: check for history, in case the + # history is only + # in a different table than the one + # where the version_id_col is. + for prop in mapper._columntoproperty.itervalues(): + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + hasdata = True + else: + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + if isinstance(history.added[0], + sql.ClauseElement): + value_params[col] = history.added[0] + else: + value = history.added[0] + params[col.key] = value + + if col in pks: + if history.deleted and \ + not row_switch: + # if passive_updates and sync detected + # this was a pk->pk sync, use the new + # value to locate the row, since the + # DB would already have set this + if ("pk_cascaded", state, col) in \ + uowtransaction.attributes: + value = history.added[0] + params[col._label] = value + else: + # use the old value to + # locate the row + value = history.deleted[0] + params[col._label] = value + hasdata = True + else: + # row switch logic can reach us here + # remove the pk from the 
update params + # so the update doesn't + # attempt to include the pk in the + # update statement + del params[col.key] + value = history.added[0] + params[col._label] = value + if value is None: + hasnull = True + else: + hasdata = True + elif col in pks: + value = state.manager[prop.key].impl.get( + state, state_dict) + if value is None: + hasnull = True + params[col._label] = value + if hasdata: + if hasnull: + raise sa_exc.FlushError( + "Can't update table " + "using NULL for primary " + "key value") + update.append((state, state_dict, params, mapper, + connection, value_params)) + return update + + +def _collect_post_update_commands(base_mapper, uowtransaction, table, + states_to_update, post_update_cols): + """Identify sets of values to use in UPDATE statements for a + list of states within a post_update operation. + + """ + + update = [] + for state, state_dict, mapper, connection in states_to_update: + if table not in mapper._pks_by_table: + continue + pks = mapper._pks_by_table[table] + params = {} + hasdata = False + + for col in mapper._cols_by_table[table]: + if col in pks: + params[col._label] = \ + mapper._get_state_attr_by_column( + state, + state_dict, col) + elif col in post_update_cols: + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + value = history.added[0] + params[col.key] = value + hasdata = True + if hasdata: + update.append((state, state_dict, params, mapper, + connection)) + return update + +def _collect_delete_commands(base_mapper, uowtransaction, table, + states_to_delete): + """Identify values to use in DELETE statements for a list of + states to be deleted.""" + + delete = util.defaultdict(list) + + for state, state_dict, mapper, has_identity, connection \ + in states_to_delete: + if not has_identity or table not in mapper._pks_by_table: + continue + + params = {} + delete[connection].append(params) + for col in 
mapper._pks_by_table[table]: + params[col.key] = \ + value = \ + mapper._get_state_attr_by_column( + state, state_dict, col) + if value is None: + raise sa_exc.FlushError( + "Can't delete from table " + "using NULL for primary " + "key value") + + if mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col): + params[mapper.version_id_col.key] = \ + mapper._get_committed_state_attr_by_column( + state, state_dict, + mapper.version_id_col) + return delete + + +def _emit_update_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, update): + """Emit UPDATE statements corresponding to value lists collected + by _collect_update_commands().""" + + needs_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + def update_stmt(): + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + if needs_version_id: + clause.clauses.append(mapper.version_id_col ==\ + sql.bindparam(mapper.version_id_col._label, + type_=col.type)) + + return table.update(clause) + + statement = base_mapper._memo(('update', table), update_stmt) + + rows = 0 + for state, state_dict, params, mapper, \ + connection, value_params in update: + + if value_params: + c = connection.execute( + statement.values(value_params), + params) + else: + c = cached_connections[connection].\ + execute(statement, params) + + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c.context.prefetch_cols, + c.context.postfetch_cols, + c.context.compiled_parameters[0], + value_params) + rows += c.rowcount + + if connection.dialect.supports_sane_rowcount: + if rows != len(update): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." 
% + (table.description, len(update), rows)) + + elif needs_version_id: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." % + c.dialect.dialect_description, + stacklevel=12) + +def _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, table, insert): + """Emit INSERT statements corresponding to value lists collected + by _collect_insert_commands().""" + + statement = base_mapper._memo(('insert', table), table.insert) + + for (connection, pkeys, hasvalue, has_all_pks), \ + records in groupby(insert, + lambda rec: (rec[4], + rec[2].keys(), + bool(rec[5]), + rec[6]) + ): + if has_all_pks and not hasvalue: + records = list(records) + multiparams = [rec[2] for rec in records] + c = cached_connections[connection].\ + execute(statement, multiparams) + + for (state, state_dict, params, mapper, + conn, value_params, has_all_pks), \ + last_inserted_params in \ + zip(records, c.context.compiled_parameters): + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c.context.prefetch_cols, + c.context.postfetch_cols, + last_inserted_params, + value_params) + + else: + for state, state_dict, params, mapper, \ + connection, value_params, \ + has_all_pks in records: + + if value_params: + result = connection.execute( + statement.values(value_params), + params) + else: + result = cached_connections[connection].\ + execute(statement, params) + + primary_key = result.context.inserted_primary_key + + if primary_key is not None: + # set primary key attributes + for pk, col in zip(primary_key, + mapper._pks_by_table[table]): + prop = mapper._columntoproperty[col] + if state_dict.get(prop.key) is None: + # TODO: would rather say: + #state_dict[prop.key] = pk + mapper._set_state_attr_by_column( + state, + state_dict, + col, pk) + + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + result.context.prefetch_cols, + result.context.postfetch_cols, + 
result.context.compiled_parameters[0], + value_params) + + + +def _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, update): + """Emit UPDATE statements corresponding to value lists collected + by _collect_post_update_commands().""" + + def update_stmt(): + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + return table.update(clause) + + statement = base_mapper._memo(('post_update', table), update_stmt) + + # execute each UPDATE in the order according to the original + # list of states to guarantee row access order, but + # also group them into common (connection, cols) sets + # to support executemany(). + for key, grouper in groupby( + update, lambda rec: (rec[4], rec[2].keys()) + ): + connection = key[0] + multiparams = [params for state, state_dict, + params, mapper, conn in grouper] + cached_connections[connection].\ + execute(statement, multiparams) + + +def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, + mapper, table, delete): + """Emit DELETE statements corresponding to value lists collected + by _collect_delete_commands().""" + + need_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + def delete_stmt(): + clause = sql.and_() + for col in mapper._pks_by_table[table]: + clause.clauses.append( + col == sql.bindparam(col.key, type_=col.type)) + + if need_version_id: + clause.clauses.append( + mapper.version_id_col == + sql.bindparam( + mapper.version_id_col.key, + type_=mapper.version_id_col.type + ) + ) + + return table.delete(clause) + + for connection, del_objects in delete.iteritems(): + statement = base_mapper._memo(('delete', table), delete_stmt) + + connection = cached_connections[connection] + + if need_version_id: + # TODO: need test coverage for this [ticket:1761] + if connection.dialect.supports_sane_rowcount: + rows = 0 + # execute 
deletes individually so that versioned + # rows can be verified + for params in del_objects: + c = connection.execute(statement, params) + rows += c.rowcount + if rows != len(del_objects): + raise orm_exc.StaleDataError( + "DELETE statement on table '%s' expected to " + "delete %d row(s); %d were matched." % + (table.description, len(del_objects), c.rowcount) + ) + else: + util.warn( + "Dialect %s does not support deleted rowcount " + "- versioning cannot be verified." % + connection.dialect.dialect_description, + stacklevel=12) + connection.execute(statement, del_objects) + else: + connection.execute(statement, del_objects) + + +def _finalize_insert_update_commands(base_mapper, uowtransaction, + states_to_insert, states_to_update): + """finalize state on states that have been inserted or updated, + including calling after_insert/after_update events. + + """ + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_insert + \ + states_to_update: + + if mapper._readonly_props: + readonly = state.unmodified_intersection( + [p.key for p in mapper._readonly_props + if p.expire_on_flush or p.key not in state.dict] + ) + if readonly: + state.expire_attributes(state.dict, readonly) + + # if eager_defaults option is enabled, + # refresh whatever has been expired. 
+ if base_mapper.eager_defaults and state.unloaded: + state.key = base_mapper._identity_key_from_state(state) + uowtransaction.session.query(base_mapper)._load_on_ident( + state.key, refresh_state=state, + only_load_props=state.unloaded) + + # call after_XXX extensions + if not has_identity: + mapper.dispatch.after_insert(mapper, connection, state) + else: + mapper.dispatch.after_update(mapper, connection, state) + +def _postfetch(mapper, uowtransaction, table, + state, dict_, prefetch_cols, postfetch_cols, + params, value_params): + """Expire attributes in need of newly persisted database state, + after an INSERT or UPDATE statement has proceeded for that + state.""" + + if mapper.version_id_col is not None: + prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + + for c in prefetch_cols: + if c.key in params and c in mapper._columntoproperty: + mapper._set_state_attr_by_column(state, dict_, c, params[c.key]) + + if postfetch_cols: + state.expire_attributes(state.dict, + [mapper._columntoproperty[c].key + for c in postfetch_cols if c in + mapper._columntoproperty] + ) + + # synchronize newly inserted ids from one table to the next + # TODO: this still goes a little too often. would be nice to + # have definitive list of "columns that changed" here + for m, equated_pairs in mapper._table_to_equated[table]: + sync.populate(state, m, state, m, + equated_pairs, + uowtransaction, + mapper.passive_updates) + +def _connections_for_states(base_mapper, uowtransaction, states): + """Return an iterator of (state, state.dict, mapper, connection). + + The states are sorted according to _sort_states, then paired + with the connection they should be using for the given + unit of work transaction. 
+ + """ + # if session has a connection callable, + # organize individual states with the connection + # to use for update + if uowtransaction.session.connection_callable: + connection_callable = \ + uowtransaction.session.connection_callable + else: + connection = uowtransaction.transaction.connection( + base_mapper) + connection_callable = None + + for state in _sort_states(states): + if connection_callable: + connection = connection_callable(base_mapper, state.obj()) + + mapper = _state_mapper(state) + + yield state, state.dict, mapper, connection + +def _cached_connection_dict(base_mapper): + # dictionary of connection->connection_with_cache_options. + return util.PopulateDict( + lambda conn:conn.execution_options( + compiled_cache=base_mapper._compiled_cache + )) + +def _sort_states(states): + pending = set(states) + persistent = set(s for s in pending if s.key is not None) + pending.difference_update(persistent) + return sorted(pending, key=operator.attrgetter("insert_order")) + \ + sorted(persistent, key=lambda q:q.key[1]) + + diff --git a/libs/sqlalchemy/orm/query.py b/libs/sqlalchemy/orm/query.py index 9508cb5..aa3dd01 100644 --- a/libs/sqlalchemy/orm/query.py +++ b/libs/sqlalchemy/orm/query.py @@ -133,7 +133,7 @@ class Query(object): with_polymorphic = mapper._with_polymorphic_mappers if mapper.mapped_table not in \ self._polymorphic_adapters: - self.__mapper_loads_polymorphically_with(mapper, + self._mapper_loads_polymorphically_with(mapper, sql_util.ColumnAdapter( selectable, mapper._equivalent_columns)) @@ -150,7 +150,7 @@ class Query(object): is_aliased_class, with_polymorphic) ent.setup_entity(entity, *d[entity]) - def __mapper_loads_polymorphically_with(self, mapper, adapter): + def _mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers: self._polymorphic_adapters[m2] = adapter for m in m2.iterate_to_root(): @@ -174,10 +174,6 @@ class Query(object): self._from_obj_alias = sql_util.ColumnAdapter( 
self._from_obj[0], equivs) - def _get_polymorphic_adapter(self, entity, selectable): - self.__mapper_loads_polymorphically_with(entity.mapper, - sql_util.ColumnAdapter(selectable, - entity.mapper._equivalent_columns)) def _reset_polymorphic_adapter(self, mapper): for m2 in mapper._with_polymorphic_mappers: @@ -276,6 +272,7 @@ class Query(object): return self._select_from_entity or \ self._entity_zero().entity_zero + @property def _mapper_entities(self): # TODO: this is wrong, its hardcoded to "primary entity" when @@ -324,13 +321,6 @@ class Query(object): ) return self._entity_zero() - def _generate_mapper_zero(self): - if not getattr(self._entities[0], 'primary_entity', False): - raise sa_exc.InvalidRequestError( - "No primary mapper set up for this Query.") - entity = self._entities[0]._clone() - self._entities = [entity] + self._entities[1:] - return entity def __all_equivs(self): equivs = {} @@ -460,6 +450,62 @@ class Query(object): """ return self.enable_eagerloads(False).statement.alias(name=name) + def cte(self, name=None, recursive=False): + """Return the full SELECT statement represented by this :class:`.Query` + represented as a common table expression (CTE). + + The :meth:`.Query.cte` method is new in 0.7.6. + + Parameters and usage are the same as those of the + :meth:`._SelectBase.cte` method; see that method for + further details. + + Here is the `Postgresql WITH + RECURSIVE example `_. + Note that, in this example, the ``included_parts`` cte and the ``incl_alias`` alias + of it are Core selectables, which + means the columns are accessed via the ``.c.`` attribute. 
The ``parts_alias`` + object is an :func:`.orm.aliased` instance of the ``Part`` entity, so column-mapped + attributes are available directly:: + + from sqlalchemy.orm import aliased + + class Part(Base): + __tablename__ = 'part' + part = Column(String, primary_key=True) + sub_part = Column(String, primary_key=True) + quantity = Column(Integer) + + included_parts = session.query( + Part.sub_part, + Part.part, + Part.quantity).\\ + filter(Part.part=="our part").\\ + cte(name="included_parts", recursive=True) + + incl_alias = aliased(included_parts, name="pr") + parts_alias = aliased(Part, name="p") + included_parts = included_parts.union_all( + session.query( + parts_alias.part, + parts_alias.sub_part, + parts_alias.quantity).\\ + filter(parts_alias.part==incl_alias.c.sub_part) + ) + + q = session.query( + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label('total_quantity') + ).\\ + group_by(included_parts.c.sub_part) + + See also: + + :meth:`._SelectBase.cte` + + """ + return self.enable_eagerloads(False).statement.cte(name=name, recursive=recursive) + def label(self, name): """Return the full SELECT statement represented by this :class:`.Query`, converted to a scalar subquery with a label of the given name. @@ -601,7 +647,12 @@ class Query(object): such as concrete table mappers. """ - entity = self._generate_mapper_zero() + + if not getattr(self._entities[0], 'primary_entity', False): + raise sa_exc.InvalidRequestError( + "No primary mapper set up for this Query.") + entity = self._entities[0]._clone() + self._entities = [entity] + self._entities[1:] entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable, @@ -1041,7 +1092,22 @@ class Query(object): @_generative() def with_lockmode(self, mode): - """Return a new Query object with the specified locking mode.""" + """Return a new Query object with the specified locking mode. + + :param mode: a string representing the desired locking mode. 
A + corresponding value is passed to the ``for_update`` parameter of + :meth:`~sqlalchemy.sql.expression.select` when the query is + executed. Valid values are: + + ``'update'`` - passes ``for_update=True``, which translates to + ``FOR UPDATE`` (standard SQL, supported by most dialects) + + ``'update_nowait'`` - passes ``for_update='nowait'``, which + translates to ``FOR UPDATE NOWAIT`` (supported by Oracle) + + ``'read'`` - passes ``for_update='read'``, which translates to + ``LOCK IN SHARE MODE`` (supported by MySQL). + """ self._lockmode = mode @@ -1583,7 +1649,6 @@ class Query(object): consistent format with which to form the actual JOIN constructs. """ - self._polymorphic_adapters = self._polymorphic_adapters.copy() if not from_joinpoint: self._reset_joinpoint() @@ -1683,6 +1748,8 @@ class Query(object): onclause, outerjoin, create_aliases, prop): """append a JOIN to the query's from clause.""" + self._polymorphic_adapters = self._polymorphic_adapters.copy() + if left is None: if self._from_obj: left = self._from_obj[0] @@ -1696,7 +1763,29 @@ class Query(object): "are the same entity" % (left, right)) - left_mapper, left_selectable, left_is_aliased = _entity_info(left) + right, right_is_aliased, onclause = self._prepare_right_side( + right, onclause, + outerjoin, create_aliases, + prop) + + # if joining on a MapperProperty path, + # track the path to prevent redundant joins + if not create_aliases and prop: + self._update_joinpoint({ + '_joinpoint_entity':right, + 'prev':((left, right, prop.key), self._joinpoint) + }) + else: + self._joinpoint = { + '_joinpoint_entity':right + } + + self._join_to_left(left, right, + right_is_aliased, + onclause, outerjoin) + + def _prepare_right_side(self, right, onclause, outerjoin, + create_aliases, prop): right_mapper, right_selectable, right_is_aliased = _entity_info(right) if right_mapper: @@ -1741,24 +1830,13 @@ class Query(object): right = aliased(right) need_adapter = True - # if joining on a MapperProperty path, - # 
track the path to prevent redundant joins - if not create_aliases and prop: - self._update_joinpoint({ - '_joinpoint_entity':right, - 'prev':((left, right, prop.key), self._joinpoint) - }) - else: - self._joinpoint = { - '_joinpoint_entity':right - } - # if an alias() of the right side was generated here, # apply an adapter to all subsequent filter() calls # until reset_joinpoint() is called. if need_adapter: self._filter_aliases = ORMAdapter(right, - equivalents=right_mapper and right_mapper._equivalent_columns or {}, + equivalents=right_mapper and + right_mapper._equivalent_columns or {}, chain_to=self._filter_aliases) # if the onclause is a ClauseElement, adapt it with any @@ -1771,7 +1849,7 @@ class Query(object): # ensure that columns retrieved from this target in the result # set are also adapted. if aliased_entity and not create_aliases: - self.__mapper_loads_polymorphically_with( + self._mapper_loads_polymorphically_with( right_mapper, ORMAdapter( right, @@ -1779,6 +1857,11 @@ class Query(object): ) ) + return right, right_is_aliased, onclause + + def _join_to_left(self, left, right, right_is_aliased, onclause, outerjoin): + left_mapper, left_selectable, left_is_aliased = _entity_info(left) + # this is an overly broad assumption here, but there's a # very wide variety of situations where we rely upon orm.join's # adaption to glue clauses together, with joined-table inheritance's @@ -2959,7 +3042,9 @@ class _MapperEntity(_QueryEntity): # with_polymorphic() can be applied to aliases if not self.is_aliased_class: self.selectable = from_obj - self.adapter = query._get_polymorphic_adapter(self, from_obj) + query._mapper_loads_polymorphically_with(self.mapper, + sql_util.ColumnAdapter(from_obj, + self.mapper._equivalent_columns)) filter_fn = id @@ -3086,8 +3171,9 @@ class _MapperEntity(_QueryEntity): class _ColumnEntity(_QueryEntity): """Column/expression based entity.""" - def __init__(self, query, column): + def __init__(self, query, column, namespace=None): 
self.expr = column + self.namespace = namespace if isinstance(column, basestring): column = sql.literal_column(column) @@ -3106,7 +3192,7 @@ class _ColumnEntity(_QueryEntity): for c in column._select_iterable: if c is column: break - _ColumnEntity(query, c) + _ColumnEntity(query, c, namespace=column) if c is not column: return @@ -3147,12 +3233,14 @@ class _ColumnEntity(_QueryEntity): if self.entities: self.entity_zero = list(self.entities)[0] + elif self.namespace is not None: + self.entity_zero = self.namespace else: self.entity_zero = None @property def entity_zero_or_selectable(self): - if self.entity_zero: + if self.entity_zero is not None: return self.entity_zero elif self.actual_froms: return list(self.actual_froms)[0] diff --git a/libs/sqlalchemy/orm/scoping.py b/libs/sqlalchemy/orm/scoping.py index ffc8ef4..3c1cd7f 100644 --- a/libs/sqlalchemy/orm/scoping.py +++ b/libs/sqlalchemy/orm/scoping.py @@ -41,8 +41,9 @@ class ScopedSession(object): scope = kwargs.pop('scope', False) if scope is not None: if self.registry.has(): - raise sa_exc.InvalidRequestError("Scoped session is already present; " - "no new arguments may be specified.") + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified.") else: sess = self.session_factory(**kwargs) self.registry.set(sess) @@ -70,8 +71,8 @@ class ScopedSession(object): self.session_factory.configure(**kwargs) def query_property(self, query_cls=None): - """return a class property which produces a `Query` object against the - class when called. + """return a class property which produces a `Query` object + against the class when called. 
e.g.:: @@ -121,7 +122,8 @@ def makeprop(name): def get(self): return getattr(self.registry(), name) return property(get, set) -for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active', 'autoflush'): +for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', + 'is_active', 'autoflush', 'no_autoflush'): setattr(ScopedSession, prop, makeprop(prop)) def clslevel(name): diff --git a/libs/sqlalchemy/orm/session.py b/libs/sqlalchemy/orm/session.py index 4299290..1477870 100644 --- a/libs/sqlalchemy/orm/session.py +++ b/libs/sqlalchemy/orm/session.py @@ -99,7 +99,7 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, kwargs.update(new_kwargs) - return type("Session", (Sess, class_), {}) + return type("SessionMaker", (Sess, class_), {}) class SessionTransaction(object): @@ -978,6 +978,34 @@ class Session(object): return self._query_cls(entities, self, **kwargs) + @property + @util.contextmanager + def no_autoflush(self): + """Return a context manager that disables autoflush. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + New in 0.7.6. + + """ + autoflush = self.autoflush + self.autoflush = False + yield self + self.autoflush = autoflush + def _autoflush(self): if self.autoflush and not self._flushing: self.flush() @@ -1772,6 +1800,19 @@ class Session(object): return self.transaction and self.transaction.is_active + identity_map = None + """A mapping of object identities to objects themselves. 
+ + Iterating through ``Session.identity_map.values()`` provides + access to the full set of persistent objects (i.e., those + that have row identity) currently in the session. + + See also: + + :func:`.identity_key` - operations involving identity keys. + + """ + @property def _dirty_states(self): """The set of all persistent states considered dirty. diff --git a/libs/sqlalchemy/orm/sync.py b/libs/sqlalchemy/orm/sync.py index b016e81..a20e871 100644 --- a/libs/sqlalchemy/orm/sync.py +++ b/libs/sqlalchemy/orm/sync.py @@ -6,6 +6,7 @@ """private module containing functions used for copying data between instances based on join conditions. + """ from sqlalchemy.orm import exc, util as mapperutil, attributes diff --git a/libs/sqlalchemy/orm/unitofwork.py b/libs/sqlalchemy/orm/unitofwork.py index 3cd0f15..8fc5f13 100644 --- a/libs/sqlalchemy/orm/unitofwork.py +++ b/libs/sqlalchemy/orm/unitofwork.py @@ -14,7 +14,7 @@ organizes them in order of dependency, and executes. from sqlalchemy import util, event from sqlalchemy.util import topological -from sqlalchemy.orm import attributes, interfaces +from sqlalchemy.orm import attributes, interfaces, persistence from sqlalchemy.orm import util as mapperutil session = util.importlater("sqlalchemy.orm", "session") @@ -462,7 +462,7 @@ class IssuePostUpdate(PostSortRec): states, cols = uow.post_update_states[self.mapper] states = [s for s in states if uow.states[s][0] == self.isdelete] - self.mapper._post_update(states, uow, cols) + persistence.post_update(self.mapper, states, uow, cols) class SaveUpdateAll(PostSortRec): def __init__(self, uow, mapper): @@ -470,7 +470,7 @@ class SaveUpdateAll(PostSortRec): assert mapper is mapper.base_mapper def execute(self, uow): - self.mapper._save_obj( + persistence.save_obj(self.mapper, uow.states_for_mapper_hierarchy(self.mapper, False, False), uow ) @@ -493,7 +493,7 @@ class DeleteAll(PostSortRec): assert mapper is mapper.base_mapper def execute(self, uow): - self.mapper._delete_obj( + 
persistence.delete_obj(self.mapper, uow.states_for_mapper_hierarchy(self.mapper, True, False), uow ) @@ -551,7 +551,7 @@ class SaveUpdateState(PostSortRec): if r.__class__ is cls_ and r.mapper is mapper] recs.difference_update(our_recs) - mapper._save_obj( + persistence.save_obj(mapper, [self.state] + [r.state for r in our_recs], uow) @@ -575,7 +575,7 @@ class DeleteState(PostSortRec): r.mapper is mapper] recs.difference_update(our_recs) states = [self.state] + [r.state for r in our_recs] - mapper._delete_obj( + persistence.delete_obj(mapper, [s for s in states if uow.states[s][0]], uow) diff --git a/libs/sqlalchemy/orm/util.py b/libs/sqlalchemy/orm/util.py index 0cd5b05..0c5f203 100644 --- a/libs/sqlalchemy/orm/util.py +++ b/libs/sqlalchemy/orm/util.py @@ -11,6 +11,7 @@ from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE,\ PropComparator, MapperProperty from sqlalchemy.orm import attributes, exc import operator +import re mapperlib = util.importlater("sqlalchemy.orm", "mapperlib") @@ -20,38 +21,52 @@ all_cascades = frozenset(("delete", "delete-orphan", "all", "merge", _INSTRUMENTOR = ('mapper', 'instrumentor') -class CascadeOptions(dict): +class CascadeOptions(frozenset): """Keeps track of the options sent to relationship().cascade""" - def __init__(self, arg=""): - if not arg: - values = set() - else: - values = set(c.strip() for c in arg.split(',')) - - for name in ['save-update', 'delete', 'refresh-expire', - 'merge', 'expunge']: - boolean = name in values or 'all' in values - setattr(self, name.replace('-', '_'), boolean) - if boolean: - self[name] = True + _add_w_all_cascades = all_cascades.difference([ + 'all', 'none', 'delete-orphan']) + _allowed_cascades = all_cascades + + def __new__(cls, arg): + values = set([ + c for c + in re.split('\s*,\s*', arg or "") + if c + ]) + + if values.difference(cls._allowed_cascades): + raise sa_exc.ArgumentError( + "Invalid cascade option(s): %s" % + ", ".join([repr(x) for x in + sorted( + 
values.difference(cls._allowed_cascades) + )]) + ) + + if "all" in values: + values.update(cls._add_w_all_cascades) + if "none" in values: + values.clear() + values.discard('all') + + self = frozenset.__new__(CascadeOptions, values) + self.save_update = 'save-update' in values + self.delete = 'delete' in values + self.refresh_expire = 'refresh-expire' in values + self.merge = 'merge' in values + self.expunge = 'expunge' in values self.delete_orphan = "delete-orphan" in values - if self.delete_orphan: - self['delete-orphan'] = True if self.delete_orphan and not self.delete: - util.warn("The 'delete-orphan' cascade option requires " - "'delete'.") - - for x in values: - if x not in all_cascades: - raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x) + util.warn("The 'delete-orphan' cascade " + "option requires 'delete'.") + return self def __repr__(self): - return "CascadeOptions(%s)" % repr(",".join( - [x for x in ['delete', 'save_update', 'merge', 'expunge', - 'delete_orphan', 'refresh-expire'] - if getattr(self, x, False) is True])) + return "CascadeOptions(%r)" % ( + ",".join([x for x in sorted(self)]) + ) def _validator_events(desc, key, validator): """Runs a validation method on an attribute value to be set or appended.""" @@ -557,15 +572,20 @@ def _entity_descriptor(entity, key): attribute. 
""" - if not isinstance(entity, (AliasedClass, type)): - entity = entity.class_ + if isinstance(entity, expression.FromClause): + description = entity + entity = entity.c + elif not isinstance(entity, (AliasedClass, type)): + description = entity = entity.class_ + else: + description = entity try: return getattr(entity, key) except AttributeError: raise sa_exc.InvalidRequestError( "Entity '%s' has no property '%s'" % - (entity, key) + (description, key) ) def _orm_columns(entity): diff --git a/libs/sqlalchemy/pool.py b/libs/sqlalchemy/pool.py index a615e8c..6254a4b 100644 --- a/libs/sqlalchemy/pool.py +++ b/libs/sqlalchemy/pool.py @@ -57,6 +57,10 @@ def clear_managers(): manager.close() proxies.clear() +reset_rollback = util.symbol('reset_rollback') +reset_commit = util.symbol('reset_commit') +reset_none = util.symbol('reset_none') + class Pool(log.Identified): """Abstract base class for connection pools.""" @@ -130,7 +134,17 @@ class Pool(log.Identified): self._creator = creator self._recycle = recycle self._use_threadlocal = use_threadlocal - self._reset_on_return = reset_on_return + if reset_on_return in ('rollback', True, reset_rollback): + self._reset_on_return = reset_rollback + elif reset_on_return in (None, False, reset_none): + self._reset_on_return = reset_none + elif reset_on_return in ('commit', reset_commit): + self._reset_on_return = reset_commit + else: + raise exc.ArgumentError( + "Invalid value for 'reset_on_return': %r" + % reset_on_return) + self.echo = echo if _dispatch: self.dispatch._update(_dispatch, only_propagate=False) @@ -330,8 +344,10 @@ def _finalize_fairy(connection, connection_record, pool, ref, echo): if connection is not None: try: - if pool._reset_on_return: + if pool._reset_on_return is reset_rollback: connection.rollback() + elif pool._reset_on_return is reset_commit: + connection.commit() # Immediately close detached instances if connection_record is None: connection.close() @@ -624,11 +640,37 @@ class QueuePool(Pool): 
:meth:`unique_connection` method is provided to bypass the threadlocal behavior installed into :meth:`connect`. - :param reset_on_return: If true, reset the database state of - connections returned to the pool. This is typically a - ROLLBACK to release locks and transaction resources. - Disable at your own peril. Defaults to True. - + :param reset_on_return: Determine steps to take on + connections as they are returned to the pool. + As of SQLAlchemy 0.7.6, reset_on_return can have any + of these values: + + * 'rollback' - call rollback() on the connection, + to release locks and transaction resources. + This is the default value. The vast majority + of use cases should leave this value set. + * True - same as 'rollback', this is here for + backwards compatibility. + * 'commit' - call commit() on the connection, + to release locks and transaction resources. + A commit here may be desirable for databases that + cache query plans if a commit is emitted, + such as Microsoft SQL Server. However, this + value is more dangerous than 'rollback' because + any data changes present on the transaction + are committed unconditionally. + * None - don't do anything on the connection. + This setting should only be made on a database + that has no transaction support at all, + namely MySQL MyISAM. By not doing anything, + performance can be improved. This + setting should **never be selected** for a + database that supports transactions, + as it will lead to deadlocks and stale + state. + * False - same as None, this is here for + backwards compatibility. + :param listeners: A list of :class:`~sqlalchemy.interfaces.PoolListener`-like objects or dictionaries of callables that receive events when DB-API diff --git a/libs/sqlalchemy/schema.py b/libs/sqlalchemy/schema.py index f0a9297..d295143 100644 --- a/libs/sqlalchemy/schema.py +++ b/libs/sqlalchemy/schema.py @@ -80,6 +80,17 @@ def _get_table_key(name, schema): else: return schema + "." 
+ name +def _validate_dialect_kwargs(kwargs, name): + # validate remaining kwargs that they all specify DB prefixes + if len([k for k in kwargs + if not re.match( + r'^(?:%s)_' % + '|'.join(dialects.__all__), k + ) + ]): + raise TypeError( + "Invalid argument(s) for %s: %r" % (name, kwargs.keys())) + class Table(SchemaItem, expression.TableClause): """Represent a table in a database. @@ -369,9 +380,12 @@ class Table(SchemaItem, expression.TableClause): # allow user-overrides self._init_items(*args) - def _autoload(self, metadata, autoload_with, include_columns, exclude_columns=None): + def _autoload(self, metadata, autoload_with, include_columns, exclude_columns=()): if self.primary_key.columns: - PrimaryKeyConstraint()._set_parent_with_dispatch(self) + PrimaryKeyConstraint(*[ + c for c in self.primary_key.columns + if c.key in exclude_columns + ])._set_parent_with_dispatch(self) if autoload_with: autoload_with.run_callable( @@ -424,7 +438,7 @@ class Table(SchemaItem, expression.TableClause): if not autoload_replace: exclude_columns = [c.name for c in self.c] else: - exclude_columns = None + exclude_columns = () self._autoload(self.metadata, autoload_with, include_columns, exclude_columns) self._extra_kwargs(**kwargs) @@ -432,14 +446,7 @@ class Table(SchemaItem, expression.TableClause): def _extra_kwargs(self, **kwargs): # validate remaining kwargs that they all specify DB prefixes - if len([k for k in kwargs - if not re.match( - r'^(?:%s)_' % - '|'.join(dialects.__all__), k - ) - ]): - raise TypeError( - "Invalid argument(s) for Table: %r" % kwargs.keys()) + _validate_dialect_kwargs(kwargs, "Table") self.kwargs.update(kwargs) def _init_collections(self): @@ -1028,7 +1035,7 @@ class Column(SchemaItem, expression.ColumnClause): "The 'index' keyword argument on Column is boolean only. 
" "To create indexes with a specific name, create an " "explicit Index object external to the Table.") - Index(expression._generated_label('ix_%s' % self._label), self, unique=self.unique) + Index(expression._truncated_label('ix_%s' % self._label), self, unique=self.unique) elif self.unique: if isinstance(self.unique, basestring): raise exc.ArgumentError( @@ -1093,7 +1100,7 @@ class Column(SchemaItem, expression.ColumnClause): "been assigned.") try: c = self._constructor( - name or self.name, + expression._as_truncated(name or self.name), self.type, key = name or self.key, primary_key = self.primary_key, @@ -1119,6 +1126,8 @@ class Column(SchemaItem, expression.ColumnClause): c.table = selectable selectable._columns.add(c) + if selectable._is_clone_of is not None: + c._is_clone_of = selectable._is_clone_of.columns[c.name] if self.primary_key: selectable.primary_key.add(c) c.dispatch.after_parent_attach(c, selectable) @@ -1809,7 +1818,8 @@ class Constraint(SchemaItem): __visit_name__ = 'constraint' def __init__(self, name=None, deferrable=None, initially=None, - _create_rule=None): + _create_rule=None, + **kw): """Create a SQL constraint. :param name: @@ -1839,6 +1849,10 @@ class Constraint(SchemaItem): _create_rule is used by some types to create constraints. Currently, its call signature is subject to change at any time. + + :param \**kwargs: + Dialect-specific keyword parameters, see the documentation + for various dialects and constraints regarding options here. 
""" @@ -1847,6 +1861,8 @@ class Constraint(SchemaItem): self.initially = initially self._create_rule = _create_rule util.set_creation_order(self) + _validate_dialect_kwargs(kw, self.__class__.__name__) + self.kwargs = kw @property def table(self): @@ -2192,6 +2208,8 @@ class Index(ColumnCollectionMixin, SchemaItem): self.table = None # will call _set_parent() if table-bound column # objects are present + if not columns: + util.warn("No column names or expressions given for Index.") ColumnCollectionMixin.__init__(self, *columns) self.name = name self.unique = kw.pop('unique', False) @@ -3004,9 +3022,11 @@ def _to_schema_column(element): return element def _to_schema_column_or_string(element): - if hasattr(element, '__clause_element__'): - element = element.__clause_element__() - return element + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + if not isinstance(element, (basestring, expression.ColumnElement)): + raise exc.ArgumentError("Element %r is not a string name or column element" % element) + return element class _CreateDropBase(DDLElement): """Base class for DDL constucts that represent CREATE and DROP or diff --git a/libs/sqlalchemy/sql/compiler.py b/libs/sqlalchemy/sql/compiler.py index b0a55b8..c5c6f9e 100644 --- a/libs/sqlalchemy/sql/compiler.py +++ b/libs/sqlalchemy/sql/compiler.py @@ -154,9 +154,10 @@ class _CompileLabel(visitors.Visitable): __visit_name__ = 'label' __slots__ = 'element', 'name' - def __init__(self, col, name): + def __init__(self, col, name, alt_names=()): self.element = col self.name = name + self._alt_names = alt_names @property def proxy_set(self): @@ -251,6 +252,10 @@ class SQLCompiler(engine.Compiled): # column targeting self.result_map = {} + # collect CTEs to tack on top of a SELECT + self.ctes = util.OrderedDict() + self.ctes_recursive = False + # true if the paramstyle is positional self.positional = dialect.positional if self.positional: @@ -354,14 +359,16 @@ class 
SQLCompiler(engine.Compiled): # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. if within_columns_clause and not within_label_clause: - if isinstance(label.name, sql._generated_label): + if isinstance(label.name, sql._truncated_label): labelname = self._truncated_identifier("colident", label.name) else: labelname = label.name if result_map is not None: - result_map[labelname.lower()] = \ - (label.name, (label, label.element, labelname),\ + result_map[labelname.lower()] = ( + label.name, + (label, label.element, labelname, ) + + label._alt_names, label.type) return label.element._compiler_dispatch(self, @@ -376,17 +383,19 @@ class SQLCompiler(engine.Compiled): **kw) def visit_column(self, column, result_map=None, **kwargs): - name = column.name + name = orig_name = column.name if name is None: raise exc.CompileError("Cannot compile Column object until " "it's 'name' is assigned.") is_literal = column.is_literal - if not is_literal and isinstance(name, sql._generated_label): + if not is_literal and isinstance(name, sql._truncated_label): name = self._truncated_identifier("colident", name) if result_map is not None: - result_map[name.lower()] = (name, (column, ), column.type) + result_map[name.lower()] = (orig_name, + (column, name, column.key), + column.type) if is_literal: name = self.escape_literal_column(name) @@ -404,7 +413,7 @@ class SQLCompiler(engine.Compiled): else: schema_prefix = '' tablename = table.name - if isinstance(tablename, sql._generated_label): + if isinstance(tablename, sql._truncated_label): tablename = self._truncated_identifier("alias", tablename) return schema_prefix + \ @@ -646,7 +655,8 @@ class SQLCompiler(engine.Compiled): if name in self.binds: existing = self.binds[name] if existing is not bindparam: - if existing.unique or bindparam.unique: + if (existing.unique or bindparam.unique) and \ + not existing.proxy_set.intersection(bindparam.proxy_set): raise exc.CompileError( "Bind parameter '%s' 
conflicts with " "unique bind parameter of the same name" % @@ -703,7 +713,7 @@ class SQLCompiler(engine.Compiled): return self.bind_names[bindparam] bind_name = bindparam.key - if isinstance(bind_name, sql._generated_label): + if isinstance(bind_name, sql._truncated_label): bind_name = self._truncated_identifier("bindparam", bind_name) # add to bind_names for translation @@ -715,7 +725,7 @@ class SQLCompiler(engine.Compiled): if (ident_class, name) in self.truncated_names: return self.truncated_names[(ident_class, name)] - anonname = name % self.anon_map + anonname = name.apply_map(self.anon_map) if len(anonname) > self.label_length: counter = self.truncated_names.get(ident_class, 1) @@ -744,10 +754,49 @@ class SQLCompiler(engine.Compiled): else: return self.bindtemplate % {'name':name} + def visit_cte(self, cte, asfrom=False, ashint=False, + fromhints=None, **kwargs): + if isinstance(cte.name, sql._truncated_label): + cte_name = self._truncated_identifier("alias", cte.name) + else: + cte_name = cte.name + if cte.cte_alias: + if isinstance(cte.cte_alias, sql._truncated_label): + cte_alias = self._truncated_identifier("alias", cte.cte_alias) + else: + cte_alias = cte.cte_alias + if not cte.cte_alias and cte not in self.ctes: + if cte.recursive: + self.ctes_recursive = True + text = self.preparer.format_alias(cte, cte_name) + if cte.recursive: + if isinstance(cte.original, sql.Select): + col_source = cte.original + elif isinstance(cte.original, sql.CompoundSelect): + col_source = cte.original.selects[0] + else: + assert False + recur_cols = [c.key for c in util.unique_list(col_source.inner_columns) + if c is not None] + + text += "(%s)" % (", ".join(recur_cols)) + text += " AS \n" + \ + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) + self.ctes[cte] = text + if asfrom: + if cte.cte_alias: + text = self.preparer.format_alias(cte, cte_alias) + text += " AS " + cte_name + else: + return self.preparer.format_alias(cte, cte_name) + return text + def 
visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs): if asfrom or ashint: - if isinstance(alias.name, sql._generated_label): + if isinstance(alias.name, sql._truncated_label): alias_name = self._truncated_identifier("alias", alias.name) else: alias_name = alias.name @@ -775,8 +824,14 @@ class SQLCompiler(engine.Compiled): if isinstance(column, sql._Label): return column - elif select is not None and select.use_labels and column._label: - return _CompileLabel(column, column._label) + elif select is not None and \ + select.use_labels and \ + column._label: + return _CompileLabel( + column, + column._label, + alt_names=(column._key_label, ) + ) elif \ asfrom and \ @@ -784,7 +839,8 @@ class SQLCompiler(engine.Compiled): not column.is_literal and \ column.table is not None and \ not isinstance(column.table, sql.Select): - return _CompileLabel(column, sql._generated_label(column.name)) + return _CompileLabel(column, sql._as_truncated(column.name), + alt_names=(column.key,)) elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) \ and (not hasattr(column, 'name') or \ @@ -799,6 +855,9 @@ class SQLCompiler(engine.Compiled): def get_from_hint_text(self, table, text): return None + def get_crud_hint_text(self, table, text): + return None + def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, fromhints=None, compound_index=1, **kwargs): @@ -897,6 +956,15 @@ class SQLCompiler(engine.Compiled): if select.for_update: text += self.for_update_clause(select) + if self.ctes and \ + compound_index==1 and not entry: + cte_text = self.get_cte_preamble(self.ctes_recursive) + " " + cte_text += ", \n".join( + [txt for txt in self.ctes.values()] + ) + cte_text += "\n " + text = cte_text + text + self.stack.pop(-1) if asfrom and parens: @@ -904,6 +972,12 @@ class SQLCompiler(engine.Compiled): else: return text + def get_cte_preamble(self, recursive): + if recursive: + return "WITH RECURSIVE" + else: + return "WITH" + def 
get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list. @@ -977,12 +1051,26 @@ class SQLCompiler(engine.Compiled): text = "INSERT" + prefixes = [self.process(x) for x in insert_stmt._prefixes] if prefixes: text += " " + " ".join(prefixes) text += " INTO " + preparer.format_table(insert_stmt.table) + if insert_stmt._hints: + dialect_hints = dict([ + (table, hint_text) + for (table, dialect), hint_text in + insert_stmt._hints.items() + if dialect in ('*', self.dialect.name) + ]) + if insert_stmt.table in dialect_hints: + text += " " + self.get_crud_hint_text( + insert_stmt.table, + dialect_hints[insert_stmt.table] + ) + if colparams or not supports_default_values: text += " (%s)" % ', '.join([preparer.format_column(c[0]) for c in colparams]) @@ -1014,21 +1102,25 @@ class SQLCompiler(engine.Compiled): extra_froms, **kw): """Provide a hook to override the initial table clause in an UPDATE statement. - + MySQL overrides this. """ return self.preparer.format_table(from_table) - def update_from_clause(self, update_stmt, from_table, extra_froms, **kw): + def update_from_clause(self, update_stmt, + from_table, extra_froms, + from_hints, + **kw): """Provide a hook to override the generation of an UPDATE..FROM clause. - + MySQL overrides this. 
""" return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, **kw) + t._compiler_dispatch(self, asfrom=True, + fromhints=from_hints, **kw) for t in extra_froms) def visit_update(self, update_stmt, **kw): @@ -1045,6 +1137,21 @@ class SQLCompiler(engine.Compiled): update_stmt.table, extra_froms, **kw) + if update_stmt._hints: + dialect_hints = dict([ + (table, hint_text) + for (table, dialect), hint_text in + update_stmt._hints.items() + if dialect in ('*', self.dialect.name) + ]) + if update_stmt.table in dialect_hints: + text += " " + self.get_crud_hint_text( + update_stmt.table, + dialect_hints[update_stmt.table] + ) + else: + dialect_hints = None + text += ' SET ' if extra_froms and self.render_table_with_column_in_update_from: text += ', '.join( @@ -1067,7 +1174,8 @@ class SQLCompiler(engine.Compiled): extra_from_text = self.update_from_clause( update_stmt, update_stmt.table, - extra_froms, **kw) + extra_froms, + dialect_hints, **kw) if extra_from_text: text += " " + extra_from_text @@ -1133,7 +1241,6 @@ class SQLCompiler(engine.Compiled): for k, v in stmt.parameters.iteritems(): parameters.setdefault(sql._column_as_key(k), v) - # create a list of column assignment clauses as tuples values = [] @@ -1192,7 +1299,7 @@ class SQLCompiler(engine.Compiled): # "defaults", "primary key cols", etc. 
for c in stmt.table.columns: if c.key in parameters and c.key not in check_columns: - value = parameters[c.key] + value = parameters.pop(c.key) if sql._is_literal(value): value = self._create_crud_bind_param( c, value, required=value is required) @@ -1288,6 +1395,17 @@ class SQLCompiler(engine.Compiled): self.prefetch.append(c) elif c.server_onupdate is not None: self.postfetch.append(c) + + if parameters and stmt.parameters: + check = set(parameters).intersection( + sql._column_as_key(k) for k in stmt.parameters + ).difference(check_columns) + if check: + util.warn( + "Unconsumed column names: %s" % + (", ".join(check)) + ) + return values def visit_delete(self, delete_stmt): @@ -1296,6 +1414,21 @@ class SQLCompiler(engine.Compiled): text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) + if delete_stmt._hints: + dialect_hints = dict([ + (table, hint_text) + for (table, dialect), hint_text in + delete_stmt._hints.items() + if dialect in ('*', self.dialect.name) + ]) + if delete_stmt.table in dialect_hints: + text += " " + self.get_crud_hint_text( + delete_stmt.table, + dialect_hints[delete_stmt.table] + ) + else: + dialect_hints = None + if delete_stmt._returning: self.returning = delete_stmt._returning if self.returning_precedes_values: @@ -1445,7 +1578,7 @@ class DDLCompiler(engine.Compiled): return "\nDROP TABLE " + self.preparer.format_table(drop.element) def _index_identifier(self, ident): - if isinstance(ident, sql._generated_label): + if isinstance(ident, sql._truncated_label): max = self.dialect.max_index_name_length or \ self.dialect.max_identifier_length if len(ident) > max: diff --git a/libs/sqlalchemy/sql/expression.py b/libs/sqlalchemy/sql/expression.py index bff086e..aa67f44 100644 --- a/libs/sqlalchemy/sql/expression.py +++ b/libs/sqlalchemy/sql/expression.py @@ -832,6 +832,14 @@ def tuple_(*expr): [(1, 2), (5, 12), (10, 19)] ) + .. 
warning:: + + The composite IN construct is not supported by all backends, + and is currently known to work on Postgresql and MySQL, + but not SQLite. Unsupported backends will raise + a subclass of :class:`~sqlalchemy.exc.DBAPIError` when such + an expression is invoked. + """ return _Tuple(*expr) @@ -1275,14 +1283,48 @@ func = _FunctionGenerator() # TODO: use UnaryExpression for this instead ? modifier = _FunctionGenerator(group=False) -class _generated_label(unicode): - """A unicode subclass used to identify dynamically generated names.""" +class _truncated_label(unicode): + """A unicode subclass used to identify symbolic " + "names that may require truncation.""" + + def apply_map(self, map_): + return self + +# for backwards compatibility in case +# someone is re-implementing the +# _truncated_identifier() sequence in a custom +# compiler +_generated_label = _truncated_label + +class _anonymous_label(_truncated_label): + """A unicode subclass used to identify anonymously + generated names.""" + + def __add__(self, other): + return _anonymous_label( + unicode(self) + + unicode(other)) + + def __radd__(self, other): + return _anonymous_label( + unicode(other) + + unicode(self)) + + def apply_map(self, map_): + return self % map_ -def _escape_for_generated(x): - if isinstance(x, _generated_label): - return x +def _as_truncated(value): + """coerce the given value to :class:`._truncated_label`. + + Existing :class:`._truncated_label` and + :class:`._anonymous_label` objects are passed + unchanged. + """ + + if isinstance(value, _truncated_label): + return value else: - return x.replace('%', '%%') + return _truncated_label(value) def _string_or_unprintable(element): if isinstance(element, basestring): @@ -1466,6 +1508,7 @@ class ClauseElement(Visitable): supports_execution = False _from_objects = [] bind = None + _is_clone_of = None def _clone(self): """Create a shallow copy of this ClauseElement. 
@@ -1514,7 +1557,7 @@ class ClauseElement(Visitable): f = self while f is not None: s.add(f) - f = getattr(f, '_is_clone_of', None) + f = f._is_clone_of return s def __getstate__(self): @@ -2063,6 +2106,8 @@ class ColumnElement(ClauseElement, _CompareMixin): foreign_keys = [] quote = None _label = None + _key_label = None + _alt_names = () @property def _select_iterable(self): @@ -2109,9 +2154,14 @@ class ColumnElement(ClauseElement, _CompareMixin): else: key = name - co = ColumnClause(name, selectable, type_=getattr(self, + co = ColumnClause(_as_truncated(name), + selectable, + type_=getattr(self, 'type', None)) co.proxies = [self] + if selectable._is_clone_of is not None: + co._is_clone_of = \ + selectable._is_clone_of.columns[key] selectable._columns[key] = co return co @@ -2157,7 +2207,7 @@ class ColumnElement(ClauseElement, _CompareMixin): expressions and function calls. """ - return _generated_label('%%(%d %s)s' % (id(self), getattr(self, + return _anonymous_label('%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon'))) class ColumnCollection(util.OrderedProperties): @@ -2420,6 +2470,13 @@ class FromClause(Selectable): """ + def embedded(expanded_proxy_set, target_set): + for t in target_set.difference(expanded_proxy_set): + if not set(_expand_cloned([t]) + ).intersection(expanded_proxy_set): + return False + return True + # dont dig around if the column is locally present if self.c.contains_column(column): return column @@ -2427,10 +2484,10 @@ class FromClause(Selectable): target_set = column.proxy_set cols = self.c for c in cols: - i = target_set.intersection(itertools.chain(*[p._cloned_set - for p in c.proxy_set])) + expanded_proxy_set = set(_expand_cloned(c.proxy_set)) + i = target_set.intersection(expanded_proxy_set) if i and (not require_embedded - or c.proxy_set.issuperset(target_set)): + or embedded(expanded_proxy_set, target_set)): if col is None: # no corresponding column yet, pick this one. 
@@ -2580,10 +2637,10 @@ class _BindParamClause(ColumnElement): """ if unique: - self.key = _generated_label('%%(%d %s)s' % (id(self), key + self.key = _anonymous_label('%%(%d %s)s' % (id(self), key or 'param')) else: - self.key = key or _generated_label('%%(%d param)s' + self.key = key or _anonymous_label('%%(%d param)s' % id(self)) # identifiying key that won't change across @@ -2631,14 +2688,14 @@ class _BindParamClause(ColumnElement): def _clone(self): c = ClauseElement._clone(self) if self.unique: - c.key = _generated_label('%%(%d %s)s' % (id(c), c._orig_key + c.key = _anonymous_label('%%(%d %s)s' % (id(c), c._orig_key or 'param')) return c def _convert_to_unique(self): if not self.unique: self.unique = True - self.key = _generated_label('%%(%d %s)s' % (id(self), + self.key = _anonymous_label('%%(%d %s)s' % (id(self), self._orig_key or 'param')) def compare(self, other, **kw): @@ -3607,7 +3664,7 @@ class Alias(FromClause): if name is None: if self.original.named_with_column: name = getattr(self.original, 'name', None) - name = _generated_label('%%(%d %s)s' % (id(self), name + name = _anonymous_label('%%(%d %s)s' % (id(self), name or 'anon')) self.name = name @@ -3662,6 +3719,47 @@ class Alias(FromClause): def bind(self): return self.element.bind +class CTE(Alias): + """Represent a Common Table Expression. + + The :class:`.CTE` object is obtained using the + :meth:`._SelectBase.cte` method from any selectable. + See that method for complete examples. + + New in 0.7.6. 
+ + """ + __visit_name__ = 'cte' + def __init__(self, selectable, + name=None, + recursive=False, + cte_alias=False): + self.recursive = recursive + self.cte_alias = cte_alias + super(CTE, self).__init__(selectable, name=name) + + def alias(self, name=None): + return CTE( + self.original, + name=name, + recursive=self.recursive, + cte_alias = self.name + ) + + def union(self, other): + return CTE( + self.original.union(other), + name=self.name, + recursive=self.recursive + ) + + def union_all(self, other): + return CTE( + self.original.union_all(other), + name=self.name, + recursive=self.recursive + ) + class _Grouping(ColumnElement): """Represent a grouping within a column expression""" @@ -3807,9 +3905,12 @@ class _Label(ColumnElement): def __init__(self, name, element, type_=None): while isinstance(element, _Label): element = element.element - self.name = self.key = self._label = name \ - or _generated_label('%%(%d %s)s' % (id(self), + if name: + self.name = name + else: + self.name = _anonymous_label('%%(%d %s)s' % (id(self), getattr(element, 'name', 'anon'))) + self.key = self._label = self._key_label = self.name self._element = element self._type = type_ self.quote = element.quote @@ -3957,7 +4058,17 @@ class ColumnClause(_Immutable, ColumnElement): # end Py2K @_memoized_property + def _key_label(self): + if self.key != self.name: + return self._gen_label(self.key) + else: + return self._label + + @_memoized_property def _label(self): + return self._gen_label(self.name) + + def _gen_label(self, name): t = self.table if self.is_literal: return None @@ -3965,11 +4076,9 @@ class ColumnClause(_Immutable, ColumnElement): elif t is not None and t.named_with_column: if getattr(t, 'schema', None): label = t.schema.replace('.', '_') + "_" + \ - _escape_for_generated(t.name) + "_" + \ - _escape_for_generated(self.name) + t.name + "_" + name else: - label = _escape_for_generated(t.name) + "_" + \ - _escape_for_generated(self.name) + label = t.name + "_" + name # ensure 
the label name doesn't conflict with that # of an existing column @@ -3981,10 +4090,10 @@ class ColumnClause(_Immutable, ColumnElement): counter += 1 label = _label - return _generated_label(label) + return _as_truncated(label) else: - return self.name + return name def label(self, name): # currently, anonymous labels don't occur for @@ -4010,12 +4119,15 @@ class ColumnClause(_Immutable, ColumnElement): # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) c = self._constructor( - name or self.name, + _as_truncated(name or self.name), selectable=selectable, type_=self.type, is_literal=is_literal ) c.proxies = [self] + if selectable._is_clone_of is not None: + c._is_clone_of = \ + selectable._is_clone_of.columns[c.name] if attach: selectable._columns[c.name] = c @@ -4218,6 +4330,125 @@ class _SelectBase(Executable, FromClause): """ return self.as_scalar().label(name) + def cte(self, name=None, recursive=False): + """Return a new :class:`.CTE`, or Common Table Expression instance. + + Common table expressions are a SQL standard whereby SELECT + statements can draw upon secondary statements specified along + with the primary statement, using a clause called "WITH". + Special semantics regarding UNION can also be employed to + allow "recursive" queries, where a SELECT statement can draw + upon the set of rows that have previously been selected. + + SQLAlchemy detects :class:`.CTE` objects, which are treated + similarly to :class:`.Alias` objects, as special elements + to be delivered to the FROM clause of the statement as well + as to a WITH clause at the top of the statement. + + The :meth:`._SelectBase.cte` method is new in 0.7.6. + + :param name: name given to the common table expression. Like + :meth:`._FromClause.alias`, the name can be left as ``None`` + in which case an anonymous symbol will be used at query + compile time. + :param recursive: if ``True``, will render ``WITH RECURSIVE``. 
+ A recursive common table expression is intended to be used in + conjunction with UNION ALL in order to derive rows + from those already selected. + + The following examples illustrate two examples from + Postgresql's documentation at + http://www.postgresql.org/docs/8.4/static/queries-with.html. + + Example 1, non recursive:: + + from sqlalchemy import Table, Column, String, Integer, MetaData, \\ + select, func + + metadata = MetaData() + + orders = Table('orders', metadata, + Column('region', String), + Column('amount', Integer), + Column('product', String), + Column('quantity', Integer) + ) + + regional_sales = select([ + orders.c.region, + func.sum(orders.c.amount).label('total_sales') + ]).group_by(orders.c.region).cte("regional_sales") + + + top_regions = select([regional_sales.c.region]).\\ + where( + regional_sales.c.total_sales > + select([ + func.sum(regional_sales.c.total_sales)/10 + ]) + ).cte("top_regions") + + statement = select([ + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), + func.sum(orders.c.amount).label("product_sales") + ]).where(orders.c.region.in_( + select([top_regions.c.region]) + )).group_by(orders.c.region, orders.c.product) + + result = conn.execute(statement).fetchall() + + Example 2, WITH RECURSIVE:: + + from sqlalchemy import Table, Column, String, Integer, MetaData, \\ + select, func + + metadata = MetaData() + + parts = Table('parts', metadata, + Column('part', String), + Column('sub_part', String), + Column('quantity', Integer), + ) + + included_parts = select([ + parts.c.sub_part, + parts.c.part, + parts.c.quantity]).\\ + where(parts.c.part=='our part').\\ + cte(recursive=True) + + + incl_alias = included_parts.alias() + parts_alias = parts.alias() + included_parts = included_parts.union_all( + select([ + parts_alias.c.part, + parts_alias.c.sub_part, + parts_alias.c.quantity + ]). 
+ where(parts_alias.c.part==incl_alias.c.sub_part) + ) + + statement = select([ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label('total_quantity') + ]).\ + select_from(included_parts.join(parts, + included_parts.c.part==parts.c.part)).\\ + group_by(included_parts.c.sub_part) + + result = conn.execute(statement).fetchall() + + + See also: + + :meth:`.orm.query.Query.cte` - ORM version of :meth:`._SelectBase.cte`. + + """ + return CTE(self, name=name, recursive=recursive) + @_generative @util.deprecated('0.6', message=":func:`.autocommit` is deprecated. Use " @@ -4602,7 +4833,7 @@ class Select(_SelectBase): The text of the hint is rendered in the appropriate location for the database backend in use, relative to the given :class:`.Table` or :class:`.Alias` passed as the - *selectable* argument. The dialect implementation + ``selectable`` argument. The dialect implementation typically uses Python string substitution syntax with the token ``%(name)s`` to render the name of the table or alias. E.g. when using Oracle, the @@ -4999,7 +5230,9 @@ class Select(_SelectBase): def _populate_column_collection(self): for c in self.inner_columns: if hasattr(c, '_make_proxy'): - c._make_proxy(self, name=self.use_labels and c._label or None) + c._make_proxy(self, + name=self.use_labels + and c._label or None) def self_group(self, against=None): """return a 'grouping' construct as per the ClauseElement @@ -5086,6 +5319,7 @@ class UpdateBase(Executable, ClauseElement): _execution_options = \ Executable._execution_options.union({'autocommit': True}) kwargs = util.immutabledict() + _hints = util.immutabledict() def _process_colparams(self, parameters): if isinstance(parameters, (list, tuple)): @@ -5166,6 +5400,45 @@ class UpdateBase(Executable, ClauseElement): """ self._returning = cols + @_generative + def with_hint(self, text, selectable=None, dialect_name="*"): + """Add a table hint for a single table to this + INSERT/UPDATE/DELETE statement. + + .. 
note:: + + :meth:`.UpdateBase.with_hint` currently applies only to + Microsoft SQL Server. For MySQL INSERT hints, use + :meth:`.Insert.prefix_with`. UPDATE/DELETE hints for + MySQL will be added in a future release. + + The text of the hint is rendered in the appropriate + location for the database backend in use, relative + to the :class:`.Table` that is the subject of this + statement, or optionally to that of the given + :class:`.Table` passed as the ``selectable`` argument. + + The ``dialect_name`` option will limit the rendering of a particular + hint to a particular backend. Such as, to add a hint + that only takes effect for SQL Server:: + + mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql") + + New in 0.7.6. + + :param text: Text of the hint. + :param selectable: optional :class:`.Table` that specifies + an element of the FROM clause within an UPDATE or DELETE + to be the subject of the hint - applies only to certain backends. + :param dialect_name: defaults to ``*``, if specified as the name + of a particular dialect, will apply these hints only when + that dialect is in use. + """ + if selectable is None: + selectable = self.table + + self._hints = self._hints.union({(selectable, dialect_name):text}) + class ValuesBase(UpdateBase): """Supplies support for :meth:`.ValuesBase.values` to INSERT and UPDATE constructs.""" diff --git a/libs/sqlalchemy/sql/visitors.py b/libs/sqlalchemy/sql/visitors.py index cdcf40a..5354fbc 100644 --- a/libs/sqlalchemy/sql/visitors.py +++ b/libs/sqlalchemy/sql/visitors.py @@ -34,11 +34,19 @@ __all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', 'cloned_traverse', 'replacement_traverse'] class VisitableType(type): - """Metaclass which checks for a `__visit_name__` attribute and - applies `_compiler_dispatch` method to classes. - + """Metaclass which assigns a `_compiler_dispatch` method to classes + having a `__visit_name__` attribute. 
+ + The _compiler_dispatch attribute becomes an instance method which + looks approximately like the following:: + + def _compiler_dispatch (self, visitor, **kw): + '''Look for an attribute named "visit_" + self.__visit_name__ + on the visitor, and call it with the same kw params.''' + return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) + + Classes having no __visit_name__ attribute will remain unaffected. """ - def __init__(cls, clsname, bases, clsdict): if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'): super(VisitableType, cls).__init__(clsname, bases, clsdict) @@ -48,19 +56,31 @@ class VisitableType(type): super(VisitableType, cls).__init__(clsname, bases, clsdict) + def _generate_dispatch(cls): - # set up an optimized visit dispatch function - # for use by the compiler + """Return an optimized visit dispatch function for the cls + for use by the compiler. + """ if '__visit_name__' in cls.__dict__: visit_name = cls.__visit_name__ if isinstance(visit_name, str): + # There is an optimization opportunity here because the + # the string name of the class's __visit_name__ is known at + # this early stage (import time) so it can be pre-constructed. getter = operator.attrgetter("visit_%s" % visit_name) def _compiler_dispatch(self, visitor, **kw): return getter(visitor)(self, **kw) else: + # The optimization opportunity is lost for this case because the + # __visit_name__ is not yet a string. As a result, the visit + # string has to be recalculated with each compilation. def _compiler_dispatch(self, visitor, **kw): return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) + _compiler_dispatch.__doc__ = \ + """Look for an attribute named "visit_" + self.__visit_name__ + on the visitor, and call it with the same kw params. 
+ """ cls._compiler_dispatch = _compiler_dispatch class Visitable(object): diff --git a/libs/sqlalchemy/types.py b/libs/sqlalchemy/types.py index 8c8e6eb..512ac62 100644 --- a/libs/sqlalchemy/types.py +++ b/libs/sqlalchemy/types.py @@ -397,7 +397,7 @@ class TypeDecorator(TypeEngine): def copy(self): return MyType(self.impl.length) - The class-level "impl" variable is required, and can reference any + The class-level "impl" attribute is required, and can reference any TypeEngine class. Alternatively, the load_dialect_impl() method can be used to provide different type classes based on the dialect given; in this case, the "impl" variable can reference @@ -457,15 +457,19 @@ class TypeDecorator(TypeEngine): Arguments sent here are passed to the constructor of the class assigned to the ``impl`` class level attribute, - where the ``self.impl`` attribute is assigned an instance - of the implementation type. If ``impl`` at the class level - is already an instance, then it's assigned to ``self.impl`` - as is. + assuming the ``impl`` is a callable, and the resulting + object is assigned to the ``self.impl`` instance attribute + (thus overriding the class attribute of the same name). + + If the class level ``impl`` is not a callable (the unusual case), + it will be assigned to the same instance attribute 'as-is', + ignoring those arguments passed to the constructor. Subclasses can override this to customize the generation - of ``self.impl``. + of ``self.impl`` entirely. 
""" + if not hasattr(self.__class__, 'impl'): raise AssertionError("TypeDecorator implementations " "require a class-level variable " @@ -475,6 +479,9 @@ class TypeDecorator(TypeEngine): def _gen_dialect_impl(self, dialect): + """ + #todo + """ adapted = dialect.type_descriptor(self) if adapted is not self: return adapted @@ -494,6 +501,9 @@ class TypeDecorator(TypeEngine): @property def _type_affinity(self): + """ + #todo + """ return self.impl._type_affinity def type_engine(self, dialect): @@ -531,7 +541,6 @@ class TypeDecorator(TypeEngine): def __getattr__(self, key): """Proxy all other undefined accessors to the underlying implementation.""" - return getattr(self.impl, key) def process_bind_param(self, value, dialect): @@ -542,29 +551,52 @@ class TypeDecorator(TypeEngine): :class:`.TypeEngine` object, and from there to the DBAPI ``execute()`` method. - :param value: the value. Can be None. + The operation could be anything desired to perform custom + behavior, such as transforming or serializing data. + This could also be used as a hook for validating logic. + + This operation should be designed with the reverse operation + in mind, which would be the process_result_value method of + this class. + + :param value: Data to operate upon, of any type expected by + this method in the subclass. Can be ``None``. :param dialect: the :class:`.Dialect` in use. """ + raise NotImplementedError() def process_result_value(self, value, dialect): """Receive a result-row column value to be converted. + Subclasses should implement this method to operate on data + fetched from the database. + Subclasses override this method to return the value that should be passed back to the application, given a value that is already processed by the underlying :class:`.TypeEngine` object, originally from the DBAPI cursor method ``fetchone()`` or similar. - :param value: the value. Can be None. 
+ The operation could be anything desired to perform custom + behavior, such as transforming or serializing data. + This could also be used as a hook for validating logic. + + :param value: Data to operate upon, of any type expected by + this method in the subclass. Can be ``None``. :param dialect: the :class:`.Dialect` in use. + This operation should be designed to be reversible by + the "process_bind_param" method of this class. + """ + raise NotImplementedError() def bind_processor(self, dialect): - """Provide a bound value processing function for the given :class:`.Dialect`. + """Provide a bound value processing function for the + given :class:`.Dialect`. This is the method that fulfills the :class:`.TypeEngine` contract for bound value conversion. :class:`.TypeDecorator` @@ -575,6 +607,11 @@ class TypeDecorator(TypeEngine): though its likely best to use :meth:`process_bind_param` so that the processing provided by ``self.impl`` is maintained. + :param dialect: Dialect instance in use. + + This method is the reverse counterpart to the + :meth:`result_processor` method of this class. + """ if self.__class__.process_bind_param.func_code \ is not TypeDecorator.process_bind_param.func_code: @@ -604,6 +641,12 @@ class TypeDecorator(TypeEngine): though its likely best to use :meth:`process_result_value` so that the processing provided by ``self.impl`` is maintained. + :param dialect: Dialect instance in use. + :param coltype: An SQLAlchemy data type + + This method is the reverse counterpart to the + :meth:`bind_processor` method of this class. + """ if self.__class__.process_result_value.func_code \ is not TypeDecorator.process_result_value.func_code: @@ -654,6 +697,7 @@ class TypeDecorator(TypeEngine): has local state that should be deep-copied. 
""" + instance = self.__class__.__new__(self.__class__) instance.__dict__.update(self.__dict__) return instance @@ -724,6 +768,9 @@ class TypeDecorator(TypeEngine): return self.impl.is_mutable() def _adapt_expression(self, op, othertype): + """ + #todo + """ op, typ =self.impl._adapt_expression(op, othertype) if typ is self.impl: return op, self diff --git a/libs/sqlalchemy/util/__init__.py b/libs/sqlalchemy/util/__init__.py index 5712940..13914aa 100644 --- a/libs/sqlalchemy/util/__init__.py +++ b/libs/sqlalchemy/util/__init__.py @@ -7,7 +7,7 @@ from compat import callable, cmp, reduce, defaultdict, py25_dict, \ threading, py3k_warning, jython, pypy, win32, set_types, buffer, pickle, \ update_wrapper, partial, md5_hex, decode_slice, dottedgetter,\ - parse_qsl, any + parse_qsl, any, contextmanager from _collections import NamedTuple, ImmutableContainer, immutabledict, \ Properties, OrderedProperties, ImmutableProperties, OrderedDict, \ diff --git a/libs/sqlalchemy/util/compat.py b/libs/sqlalchemy/util/compat.py index 07652f3..99b92b1 100644 --- a/libs/sqlalchemy/util/compat.py +++ b/libs/sqlalchemy/util/compat.py @@ -57,6 +57,12 @@ buffer = buffer # end Py2K try: + from contextlib import contextmanager +except ImportError: + def contextmanager(fn): + return fn + +try: from functools import update_wrapper except ImportError: def update_wrapper(wrapper, wrapped,