541 changed files with 173394 additions and 4 deletions
@@ -1,3 +0,0 @@
[submodule "libs"]
	path = libs
	url = git://github.com/CouchPotato/Dependencies.git
@@ -0,0 +1,4 @@
Dependencies
============

Holds all dependencies that are required by CouchPotato.
@@ -0,0 +1 @@

@@ -0,0 +1,176 @@
"""
This module contains the expressions applicable for CronTrigger's fields.
"""
from calendar import monthrange
import re

from apscheduler.util import asint

__all__ = ('AllExpression', 'RangeExpression', 'WeekdayRangeExpression',
           'WeekdayPositionExpression')

WEEKDAYS = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun']


class AllExpression(object):
    value_re = re.compile(r'\*(?:/(?P<step>\d+))?$')

    def __init__(self, step=None):
        self.step = asint(step)
        if self.step == 0:
            raise ValueError('Increment must be higher than 0')

    def get_next_value(self, date, field):
        start = field.get_value(date)
        minval = field.get_min(date)
        maxval = field.get_max(date)
        start = max(start, minval)

        if not self.step:
            next = start
        else:
            distance_to_next = (self.step - (start - minval)) % self.step
            next = start + distance_to_next

        if next <= maxval:
            return next

    def __str__(self):
        if self.step:
            return '*/%d' % self.step
        return '*'

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, self.step)


class RangeExpression(AllExpression):
    value_re = re.compile(
        r'(?P<first>\d+)(?:-(?P<last>\d+))?(?:/(?P<step>\d+))?$')

    def __init__(self, first, last=None, step=None):
        AllExpression.__init__(self, step)
        first = asint(first)
        last = asint(last)
        if last is None and step is None:
            last = first
        if last is not None and first > last:
            raise ValueError('The minimum value in a range must not be '
                             'higher than the maximum')
        self.first = first
        self.last = last

    def get_next_value(self, date, field):
        start = field.get_value(date)
        minval = field.get_min(date)
        maxval = field.get_max(date)

        # Apply range limits
        minval = max(minval, self.first)
        if self.last is not None:
            maxval = min(maxval, self.last)
        start = max(start, minval)

        if not self.step:
            next = start
        else:
            distance_to_next = (self.step - (start - minval)) % self.step
            next = start + distance_to_next

        if next <= maxval:
            return next

    def __str__(self):
        if self.last != self.first and self.last is not None:
            range = '%d-%d' % (self.first, self.last)
        else:
            range = str(self.first)

        if self.step:
            return '%s/%d' % (range, self.step)
        return range

    def __repr__(self):
        args = [str(self.first)]
        if self.last != self.first and self.last is not None or self.step:
            args.append(str(self.last))
        if self.step:
            args.append(str(self.step))
        return "%s(%s)" % (self.__class__.__name__, ', '.join(args))


class WeekdayRangeExpression(RangeExpression):
    value_re = re.compile(r'(?P<first>[a-z]+)(?:-(?P<last>[a-z]+))?',
                          re.IGNORECASE)

    def __init__(self, first, last=None):
        try:
            first_num = WEEKDAYS.index(first.lower())
        except ValueError:
            raise ValueError('Invalid weekday name "%s"' % first)

        if last:
            try:
                last_num = WEEKDAYS.index(last.lower())
            except ValueError:
                raise ValueError('Invalid weekday name "%s"' % last)
        else:
            last_num = None

        RangeExpression.__init__(self, first_num, last_num)

    def __str__(self):
        if self.last != self.first and self.last is not None:
            return '%s-%s' % (WEEKDAYS[self.first], WEEKDAYS[self.last])
        return WEEKDAYS[self.first]

    def __repr__(self):
        args = ["'%s'" % WEEKDAYS[self.first]]
        if self.last != self.first and self.last is not None:
            args.append("'%s'" % WEEKDAYS[self.last])
        return "%s(%s)" % (self.__class__.__name__, ', '.join(args))


class WeekdayPositionExpression(AllExpression):
    options = ['1st', '2nd', '3rd', '4th', '5th', 'last']
    value_re = re.compile(r'(?P<option_name>%s) +(?P<weekday_name>(?:\d+|\w+))'
                          % '|'.join(options), re.IGNORECASE)

    def __init__(self, option_name, weekday_name):
        try:
            self.option_num = self.options.index(option_name.lower())
        except ValueError:
            raise ValueError('Invalid weekday position "%s"' % option_name)

        try:
            self.weekday = WEEKDAYS.index(weekday_name.lower())
        except ValueError:
            raise ValueError('Invalid weekday name "%s"' % weekday_name)

    def get_next_value(self, date, field):
        # Figure out the weekday of the month's first day and the number
        # of days in that month
        first_day_wday, last_day = monthrange(date.year, date.month)

        # Calculate which day of the month is the first of the target weekdays
        first_hit_day = self.weekday - first_day_wday + 1
        if first_hit_day <= 0:
            first_hit_day += 7

        # Calculate what day of the month the target weekday would be
        if self.option_num < 5:
            target_day = first_hit_day + self.option_num * 7
        else:
            target_day = first_hit_day + ((last_day - first_hit_day) / 7) * 7

        if target_day <= last_day and target_day >= date.day:
            return target_day

    def __str__(self):
        return '%s %s' % (self.options[self.option_num],
                          WEEKDAYS[self.weekday])

    def __repr__(self):
        return "%s('%s', '%s')" % (self.__class__.__name__,
                                   self.options[self.option_num],
                                   WEEKDAYS[self.weekday])
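
A quick illustration of the expression classes above (a minimal sketch; the
apscheduler import path follows this diff, and get_next_value additionally
needs one of the field objects defined in the next module):

    from apscheduler.expressions import AllExpression, RangeExpression

    expr = RangeExpression(10, 30, 5)   # equivalent to the string '10-30/5'
    print str(expr)                     # -> '10-30/5'
    print repr(AllExpression(2))        # -> 'AllExpression(2)'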
@@ -0,0 +1,92 @@
"""
Fields represent :class:`~apscheduler.triggers.CronTrigger` options which map
to :class:`~datetime.datetime` fields.
"""
from calendar import monthrange

from apscheduler.expressions import *

__all__ = ('BaseField', 'WeekField', 'DayOfMonthField', 'DayOfWeekField')

MIN_VALUES = {'year': 1970, 'month': 1, 'day': 1, 'week': 1,
              'day_of_week': 0, 'hour': 0, 'minute': 0, 'second': 0}
MAX_VALUES = {'year': 2 ** 63, 'month': 12, 'day': 31, 'week': 53,
              'day_of_week': 6, 'hour': 23, 'minute': 59, 'second': 59}


class BaseField(object):
    REAL = True
    COMPILERS = [AllExpression, RangeExpression]

    def __init__(self, name, exprs):
        self.name = name
        self.compile_expressions(exprs)

    def get_min(self, dateval):
        return MIN_VALUES[self.name]

    def get_max(self, dateval):
        return MAX_VALUES[self.name]

    def get_value(self, dateval):
        return getattr(dateval, self.name)

    def get_next_value(self, dateval):
        smallest = None
        for expr in self.expressions:
            value = expr.get_next_value(dateval, self)
            if smallest is None or (value is not None and value < smallest):
                smallest = value

        return smallest

    def compile_expressions(self, exprs):
        self.expressions = []

        # Split a comma-separated expression list, if any
        exprs = str(exprs).strip()
        if ',' in exprs:
            for expr in exprs.split(','):
                self.compile_expression(expr)
        else:
            self.compile_expression(exprs)

    def compile_expression(self, expr):
        for compiler in self.COMPILERS:
            match = compiler.value_re.match(expr)
            if match:
                compiled_expr = compiler(**match.groupdict())
                self.expressions.append(compiled_expr)
                return

        raise ValueError('Unrecognized expression "%s" for field "%s"' %
                         (expr, self.name))

    def __str__(self):
        expr_strings = (str(e) for e in self.expressions)
        return ','.join(expr_strings)

    def __repr__(self):
        return "%s('%s', '%s')" % (self.__class__.__name__, self.name,
                                   str(self))


class WeekField(BaseField):
    REAL = False

    def get_value(self, dateval):
        return dateval.isocalendar()[1]


class DayOfMonthField(BaseField):
    COMPILERS = BaseField.COMPILERS + [WeekdayPositionExpression]

    def get_max(self, dateval):
        return monthrange(dateval.year, dateval.month)[1]


class DayOfWeekField(BaseField):
    REAL = False
    COMPILERS = BaseField.COMPILERS + [WeekdayRangeExpression]

    def get_value(self, dateval):
        return dateval.weekday()
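
A small usage sketch of the field classes above (values follow directly from
the code; import path as in this diff):

    from datetime import datetime
    from apscheduler.fields import BaseField

    field = BaseField('minute', '10-30/5')
    print field                                               # -> '10-30/5'
    # At 12:13 the next minute matching 10-30/5 is 15:
    print field.get_next_value(datetime(2010, 1, 1, 12, 13))  # -> 15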
@@ -0,0 +1,407 @@
"""
This module is the main part of the library, and is the only module that
regular users should be concerned with.
"""
from threading import Thread, Event, Lock
from datetime import datetime, timedelta
from logging import getLogger
import os

from apscheduler.util import time_difference, asbool
from apscheduler.triggers import DateTrigger, IntervalTrigger, CronTrigger


logger = getLogger(__name__)


class Job(object):
    """
    Represents a task scheduled in the scheduler.
    """

    def __init__(self, trigger, func, args, kwargs):
        self.thread = None
        self.trigger = trigger
        self.func = func
        self.args = args
        self.kwargs = kwargs
        if hasattr(func, '__name__'):
            self.name = func.__name__
        else:
            self.name = str(func)

    def run(self):
        """
        Starts the execution of this job in a separate thread.
        """
        if self.thread and self.thread.isAlive():
            logger.info('Skipping run of job %s (previously triggered '
                        'instance is still running)', self)
        else:
            self.thread = Thread(target=self.run_in_thread)
            self.thread.setDaemon(False)
            self.thread.start()

    def run_in_thread(self):
        """
        Runs the associated callable.
        This method is executed in a dedicated thread.
        """
        try:
            self.func(*self.args, **self.kwargs)
        except:
            logger.exception('Error executing job "%s"', self)
            raise

    def __str__(self):
        return '%s: %s' % (self.name, repr(self.trigger))

    def __repr__(self):
        return '%s(%s, %s)' % (self.__class__.__name__, self.name,
                               repr(self.trigger))


class SchedulerShutdownError(Exception):
    """
    Thrown when attempting to use the scheduler after it's been shut down.
    """

    def __init__(self):
        Exception.__init__(self, 'Scheduler has already been shut down')


class SchedulerAlreadyRunningError(Exception):
    """
    Thrown when attempting to start the scheduler, but it's already running.
    """

    def __init__(self):
        Exception.__init__(self, 'Scheduler is already running')


class Scheduler(object):
    """
    This class is responsible for scheduling jobs and triggering
    their execution.
    """

    stopped = False
    thread = None
    misfire_grace_time = 1
    daemonic = True

    def __init__(self, **config):
        self.jobs = []
        self.jobs_lock = Lock()
        self.wakeup = Event()
        self.configure(config)

    def configure(self, config):
        """
        Updates the configuration with the given options.
        """
        for key, val in config.items():
            if key.startswith('apscheduler.'):
                key = key[12:]
            if key == 'misfire_grace_time':
                self.misfire_grace_time = int(val)
            elif key == 'daemonic':
                self.daemonic = asbool(val)
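
        # Example (illustrative): keys may come from an ini-style dict and may
        # carry the 'apscheduler.' prefix; values are coerced via int()/asbool():
        #     sched = Scheduler(**{'apscheduler.misfire_grace_time': '5',
        #                          'daemonic': 'false'})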

    def start(self):
        """
        Starts the scheduler in a new thread.
        """
        if self.thread and self.thread.isAlive():
            raise SchedulerAlreadyRunningError

        self.stopped = False
        self.thread = Thread(target=self.run, name='APScheduler')
        self.thread.setDaemon(self.daemonic)
        self.thread.start()
        logger.info('Scheduler started')

    def shutdown(self, timeout=0):
        """
        Shuts down the scheduler and terminates the thread.
        Does not terminate any currently running jobs.

        :param timeout: time (in seconds) to wait for the scheduler thread to
            terminate, 0 to wait forever, None to skip waiting
        """
        if self.stopped or not self.thread.isAlive():
            return

        logger.info('Scheduler shutting down')
        self.stopped = True
        self.wakeup.set()
        if timeout is not None:
            self.thread.join(timeout)
        self.jobs = []

    def cron_schedule(self, year='*', month='*', day='*', week='*',
                      day_of_week='*', hour='*', minute='*', second='*',
                      args=None, kwargs=None):
        """
        Decorator that causes its host function to be scheduled
        according to the given parameters.
        This decorator does not wrap its host function.
        The scheduled function will be called without any arguments.
        See :meth:`add_cron_job` for more information.
        """
        def inner(func):
            self.add_cron_job(func, year, month, day, week, day_of_week, hour,
                              minute, second, args, kwargs)
            return func
        return inner

    def interval_schedule(self, weeks=0, days=0, hours=0, minutes=0, seconds=0,
                          start_date=None, repeat=0, args=None, kwargs=None):
        """
        Decorator that causes its host function to be scheduled
        for execution on specified intervals.
        This decorator does not wrap its host function.
        The scheduled function will be called without any arguments.
        Note that the default repeat value is 0, which means to repeat forever.
        See :meth:`add_interval_job` for more information.
        """
        def inner(func):
            self.add_interval_job(func, weeks, days, hours, minutes, seconds,
                                  start_date, repeat, args, kwargs)
            return func
        return inner

    def _add_job(self, trigger, func, args, kwargs):
        """
        Adds a Job to the job list and notifies the scheduler thread.

        :param trigger: trigger for the given callable
        :param args: list of positional arguments to call func with
        :param kwargs: dict of keyword arguments to call func with
        :return: the scheduled job
        :rtype: Job
        """
        if self.stopped:
            raise SchedulerShutdownError
        if not hasattr(func, '__call__'):
            raise TypeError('func must be callable')

        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}

        job = Job(trigger, func, args, kwargs)
        self.jobs_lock.acquire()
        try:
            self.jobs.append(job)
        finally:
            self.jobs_lock.release()
        logger.info('Added job "%s"', job)

        # Notify the scheduler about the new job
        self.wakeup.set()

        return job

    def add_date_job(self, func, date, args=None, kwargs=None):
        """
        Adds a job to be completed on a specific date and time.

        :param func: callable to run
        :param date: the date/time to run the job at
        :param args: positional arguments to call func with
        :param kwargs: keyword arguments to call func with
        """
        trigger = DateTrigger(date)
        return self._add_job(trigger, func, args, kwargs)

    def add_interval_job(self, func, weeks=0, days=0, hours=0, minutes=0,
                         seconds=0, start_date=None, repeat=0, args=None,
                         kwargs=None):
        """
        Adds a job to be completed on specified intervals.

        :param func: callable to run
        :param weeks: number of weeks to wait
        :param days: number of days to wait
        :param hours: number of hours to wait
        :param minutes: number of minutes to wait
        :param seconds: number of seconds to wait
        :param start_date: when to first execute the job and start the
            counter (default is after the given interval)
        :param repeat: number of times the job will be run (0 = repeat
            indefinitely)
        :param args: list of positional arguments to call func with
        :param kwargs: dict of keyword arguments to call func with
        """
        interval = timedelta(weeks=weeks, days=days, hours=hours,
                             minutes=minutes, seconds=seconds)
        trigger = IntervalTrigger(interval, repeat, start_date)
        return self._add_job(trigger, func, args, kwargs)

    def add_cron_job(self, func, year='*', month='*', day='*', week='*',
                     day_of_week='*', hour='*', minute='*', second='*',
                     args=None, kwargs=None):
        """
        Adds a job to be completed on times that match the given expressions.

        :param func: callable to run
        :param year: year to run on
        :param month: month to run on (1 = January)
        :param day: day of month to run on
        :param week: week of the year to run on
        :param day_of_week: weekday to run on (0 = Monday)
        :param hour: hour to run on
        :param minute: minute to run on
        :param second: second to run on
        :param args: list of positional arguments to call func with
        :param kwargs: dict of keyword arguments to call func with
        :return: the scheduled job
        :rtype: Job
        """
        trigger = CronTrigger(year=year, month=month, day=day, week=week,
                              day_of_week=day_of_week, hour=hour,
                              minute=minute, second=second)
        return self._add_job(trigger, func, args, kwargs)

    def is_job_active(self, job):
        """
        Determines if the given job is still on the job list.

        :return: True if the job is still active, False if not
        """
        self.jobs_lock.acquire()
        try:
            return job in self.jobs
        finally:
            self.jobs_lock.release()

    def unschedule_job(self, job):
        """
        Removes a job, preventing it from being fired any more.
        """
        self.jobs_lock.acquire()
        try:
            self.jobs.remove(job)
        finally:
            self.jobs_lock.release()
        logger.info('Removed job "%s"', job)
        self.wakeup.set()

    def unschedule_func(self, func):
        """
        Removes all jobs that would execute the given function.
        """
        self.jobs_lock.acquire()
        try:
            remove_list = [job for job in self.jobs if job.func == func]
            for job in remove_list:
                self.jobs.remove(job)
                logger.info('Removed job "%s"', job)
        finally:
            self.jobs_lock.release()

        # Have the scheduler calculate a new wakeup time
        self.wakeup.set()

    def dump_jobs(self):
        """
        Gives a textual listing of all jobs currently scheduled on this
        scheduler.

        :rtype: str
        """
        job_strs = []
        now = datetime.now()
        self.jobs_lock.acquire()
        try:
            for job in self.jobs:
                next_fire_time = job.trigger.get_next_fire_time(now)
                job_str = '%s (next fire time: %s)' % (str(job),
                                                       next_fire_time)
                job_strs.append(job_str)
        finally:
            self.jobs_lock.release()

        if job_strs:
            return os.linesep.join(job_strs)
        return 'No jobs currently scheduled.'

    def _get_next_wakeup_time(self, now):
        """
        Determines the time of the next job execution, and removes finished
        jobs.

        :param now: the result of datetime.now(), generated elsewhere for
            consistency.
        """
        next_wakeup = None
        finished_jobs = []

        self.jobs_lock.acquire()
        try:
            for job in self.jobs:
                next_run = job.trigger.get_next_fire_time(now)
                if next_run is None:
                    finished_jobs.append(job)
                elif next_run and (next_wakeup is None or
                                   next_run < next_wakeup):
                    next_wakeup = next_run

            # Clear out any finished jobs
            for job in finished_jobs:
                self.jobs.remove(job)
                logger.info('Removed finished job "%s"', job)
        finally:
            self.jobs_lock.release()

        return next_wakeup

    def _get_current_jobs(self):
        """
        Determines which jobs should be executed right now.
        """
        current_jobs = []
        now = datetime.now()
        start = now - timedelta(seconds=self.misfire_grace_time)

        self.jobs_lock.acquire()
        try:
            for job in self.jobs:
                next_run = job.trigger.get_next_fire_time(start)
                if next_run:
                    time_diff = time_difference(now, next_run)
                    if next_run < now and time_diff <= self.misfire_grace_time:
                        current_jobs.append(job)
        finally:
            self.jobs_lock.release()

        return current_jobs

    def run(self):
        """
        Runs the main loop of the scheduler.
        """
        self.wakeup.clear()
        while not self.stopped:
            # Execute any jobs scheduled to be run right now
            for job in self._get_current_jobs():
                logger.debug('Executing job "%s"', job)
                job.run()

            # Figure out when the next job should be run, and
            # adjust the wait time accordingly
            now = datetime.now()
            next_wakeup_time = self._get_next_wakeup_time(now)

            # Sleep until the next job is scheduled to be run,
            # or a new job is added, or the scheduler is stopped
            if next_wakeup_time is not None:
                wait_seconds = time_difference(next_wakeup_time, now)
                logger.debug('Next wakeup is due at %s (in %f seconds)',
                             next_wakeup_time, wait_seconds)
                self.wakeup.wait(wait_seconds)
            else:
                logger.debug('No jobs; waiting until a job is added')
                self.wakeup.wait()
            self.wakeup.clear()
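
Taken together, the Scheduler API above supports both direct calls and the
decorator forms; a minimal usage sketch (illustrative only, import path as in
this diff):

    from apscheduler.scheduler import Scheduler

    sched = Scheduler(daemonic=False)
    sched.start()

    def alarm():
        print 'Wake up!'

    # Run alarm() every 30 seconds, at most 10 times
    sched.add_interval_job(alarm, seconds=30, repeat=10)

    # Run backup() at 00:30 every day, via the decorator form
    @sched.cron_schedule(hour='0', minute='30')
    def backup():
        print 'Backing up...'

    print sched.dump_jobs()
    sched.shutdown()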
@@ -0,0 +1,171 @@
"""
Triggers determine the times when a job should be executed.
"""
from datetime import datetime, timedelta
from math import ceil

from apscheduler.fields import *
from apscheduler.util import *

__all__ = ('CronTrigger', 'DateTrigger', 'IntervalTrigger')


class CronTrigger(object):
    FIELD_NAMES = ('year', 'month', 'day', 'week', 'day_of_week', 'hour',
                   'minute', 'second')
    FIELDS_MAP = {'year': BaseField,
                  'month': BaseField,
                  'week': WeekField,
                  'day': DayOfMonthField,
                  'day_of_week': DayOfWeekField,
                  'hour': BaseField,
                  'minute': BaseField,
                  'second': BaseField}

    def __init__(self, **values):
        self.fields = []
        for field_name in self.FIELD_NAMES:
            exprs = values.get(field_name) or '*'
            field_class = self.FIELDS_MAP[field_name]
            field = field_class(field_name, exprs)
            self.fields.append(field)

    def _increment_field_value(self, dateval, fieldnum):
        """
        Increments the designated field and resets all less significant fields
        to their minimum values.

        :type dateval: datetime
        :type fieldnum: int
        :rtype: tuple
        :return: a tuple containing the new date, and the number of the field
            that was actually incremented
        """
        i = 0
        values = {}
        while i < len(self.fields):
            field = self.fields[i]
            if not field.REAL:
                if i == fieldnum:
                    fieldnum -= 1
                    i -= 1
                else:
                    i += 1
                continue

            if i < fieldnum:
                values[field.name] = field.get_value(dateval)
                i += 1
            elif i > fieldnum:
                values[field.name] = field.get_min(dateval)
                i += 1
            else:
                value = field.get_value(dateval)
                maxval = field.get_max(dateval)
                if value == maxval:
                    fieldnum -= 1
                    i -= 1
                else:
                    values[field.name] = value + 1
                    i += 1

        return datetime(**values), fieldnum

    def _set_field_value(self, dateval, fieldnum, new_value):
        values = {}
        for i, field in enumerate(self.fields):
            if field.REAL:
                if i < fieldnum:
                    values[field.name] = field.get_value(dateval)
                elif i > fieldnum:
                    values[field.name] = field.get_min(dateval)
                else:
                    values[field.name] = new_value

        return datetime(**values)

    def get_next_fire_time(self, start_date):
        next_date = datetime_ceil(start_date)
        fieldnum = 0
        while 0 <= fieldnum < len(self.fields):
            field = self.fields[fieldnum]
            curr_value = field.get_value(next_date)
            next_value = field.get_next_value(next_date)

            if next_value is None:
                # No valid value was found
                next_date, fieldnum = self._increment_field_value(next_date,
                                                                  fieldnum - 1)
            elif next_value > curr_value:
                # A valid value higher than the starting value was found
                if field.REAL:
                    next_date = self._set_field_value(next_date, fieldnum,
                                                      next_value)
                    fieldnum += 1
                else:
                    next_date, fieldnum = self._increment_field_value(next_date,
                                                                      fieldnum)
            else:
                # A valid value was found, no changes necessary
                fieldnum += 1

        if fieldnum >= 0:
            return next_date

    def __repr__(self):
        field_reprs = ("%s='%s'" % (f.name, str(f)) for f in self.fields
                       if str(f) != '*')
        return '%s(%s)' % (self.__class__.__name__, ', '.join(field_reprs))


class DateTrigger(object):
    def __init__(self, run_date):
        self.run_date = convert_to_datetime(run_date)

    def get_next_fire_time(self, start_date):
        if self.run_date >= start_date:
            return self.run_date

    def __repr__(self):
        return '%s(%s)' % (self.__class__.__name__, repr(self.run_date))


class IntervalTrigger(object):
    def __init__(self, interval, repeat, start_date=None):
        if not isinstance(interval, timedelta):
            raise TypeError('interval must be a timedelta')
        if repeat < 0:
            raise ValueError('Illegal value for repeat; expected >= 0, '
                             'received %s' % repeat)

        self.interval = interval
        self.interval_length = timedelta_seconds(self.interval)
        if self.interval_length == 0:
            self.interval = timedelta(seconds=1)
            self.interval_length = 1
        self.repeat = repeat
        if start_date is None:
            self.first_fire_date = datetime.now() + self.interval
        else:
            self.first_fire_date = convert_to_datetime(start_date)
        self.first_fire_date -= timedelta(
            microseconds=self.first_fire_date.microsecond)
        if repeat > 0:
            self.last_fire_date = self.first_fire_date + interval * (repeat - 1)
        else:
            self.last_fire_date = None

    def get_next_fire_time(self, start_date):
        if start_date < self.first_fire_date:
            return self.first_fire_date
        if self.last_fire_date and start_date > self.last_fire_date:
            return None
        timediff_seconds = timedelta_seconds(start_date - self.first_fire_date)
        next_interval_num = int(ceil(timediff_seconds / self.interval_length))
        return self.first_fire_date + self.interval * next_interval_num

    def __repr__(self):
        return "%s(interval=%s, repeat=%d, start_date=%s)" % (
            self.__class__.__name__, repr(self.interval), self.repeat,
            repr(self.first_fire_date))
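
A short illustration of the trigger classes above (a sketch; the expected
values were worked out by hand from the code):

    from datetime import datetime, timedelta
    from apscheduler.triggers import CronTrigger, IntervalTrigger

    trigger = CronTrigger(hour='2', minute='30')
    print trigger.get_next_fire_time(datetime(2010, 1, 1, 12, 0))
    # -> 2010-01-02 02:30:00 (12:00 is already past 02:30, so it rolls
    #    over to 02:30 on the next day)

    trigger = IntervalTrigger(timedelta(hours=1), 2,
                              start_date=datetime(2010, 1, 1, 0, 0))
    print trigger.get_next_fire_time(datetime(2010, 1, 1, 0, 30))
    # -> 2010-01-01 01:00:00 (the second and final run)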
@@ -0,0 +1,91 @@
"""
This module contains several handy functions primarily meant for internal use.
"""

from datetime import date, datetime, timedelta
from time import mktime

__all__ = ('asint', 'asbool', 'convert_to_datetime', 'timedelta_seconds',
           'time_difference', 'datetime_ceil')


def asint(text):
    """
    Safely converts a string to an integer, returning None if the string
    is None.

    :type text: str
    :rtype: int
    """
    if text is not None:
        return int(text)


def asbool(obj):
    """
    Interprets an object as a boolean value.

    :rtype: bool
    """
    if isinstance(obj, str):
        obj = obj.strip().lower()
        if obj in ('true', 'yes', 'on', 'y', 't', '1'):
            return True
        if obj in ('false', 'no', 'off', 'n', 'f', '0'):
            return False
        raise ValueError('Unable to interpret value "%s" as boolean' % obj)
    return bool(obj)


def convert_to_datetime(dateval):
    """
    Converts a date object to a datetime object.
    If an actual datetime object is passed, it is returned unmodified.

    :type dateval: date
    :rtype: datetime
    """
    if isinstance(dateval, datetime):
        return dateval
    elif isinstance(dateval, date):
        return datetime.fromordinal(dateval.toordinal())
    raise TypeError('Expected date, got %s instead' % type(dateval))


def timedelta_seconds(delta):
    """
    Converts the given timedelta to seconds.

    :type delta: timedelta
    :rtype: float
    """
    return delta.days * 24 * 60 * 60 + delta.seconds + \
           delta.microseconds / 1000000.0


def time_difference(date1, date2):
    """
    Returns the time difference in seconds between the given two
    datetime objects. The difference is calculated as: date1 - date2.

    :param date1: the later datetime
    :type date1: datetime
    :param date2: the earlier datetime
    :type date2: datetime
    :rtype: int
    """
    later = mktime(date1.timetuple())
    earlier = mktime(date2.timetuple())
    return int(later - earlier)


def datetime_ceil(dateval):
    """
    Rounds the given datetime object upwards to the next whole second.

    :type dateval: datetime
    """
    if dateval.microsecond > 0:
        return dateval + timedelta(seconds=1,
                                   microseconds=-dateval.microsecond)
    return dateval
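
Quick checks of the helpers above (values follow directly from the code):

    from datetime import datetime, timedelta
    from apscheduler.util import asbool, timedelta_seconds, datetime_ceil

    print asbool('Yes')                                         # -> True
    print timedelta_seconds(timedelta(minutes=2, seconds=30))   # -> 150.0
    print datetime_ceil(datetime(2010, 1, 1, 12, 0, 0, 5000))
    # -> 2010-01-01 12:00:01 (rounded up to the next whole second)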
File diff suppressed because it is too large
@@ -0,0 +1,13 @@
# __init__.py
#
# Copyright (C) 2010 Adrian Cristea adrian dot cristea at gmail dotcom
#
# This module is part of Axel and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import inspect
from .axel import *
__all__ = sorted(name for name, obj in locals().items()
                 if not (name.startswith('_') or inspect.ismodule(obj)))
__all__.append('axel')
del inspect
@@ -0,0 +1,325 @@
# axel.py
#
# Copyright (C) 2010 Adrian Cristea adrian dot cristea at gmail dotcom
#
# Based on an idea by Peter Thatcher, found on
# http://www.valuedlessons.com/2008/04/events-in-python.html
#
# This module is part of Axel and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
#
# Source: http://pypi.python.org/pypi/axel
# Docs:   http://packages.python.org/axel

import sys, threading, Queue

class Event(object):
    """
    Event object inspired by C# events. Handlers can be registered and
    unregistered using += and -= operators. Execution and result are
    influenced by the arguments passed to the constructor and += method.

        from axel import Event

        event = Event()
        def on_event(*args, **kwargs):
            return (args, kwargs)

        event += on_event                    # handler registration
        print(event(10, 20, y=30))
        >> ((True, ((10, 20), {'y': 30}), <function on_event at 0x00BAA270>),)

        event -= on_event                    # handler is unregistered
        print(event(10, 20, y=30))
        >> None

        class Mouse(object):
            def __init__(self):
                self.click = Event(self)
                self.click += self.on_click  # handler registration

            def on_click(self, sender, *args, **kwargs):
                assert isinstance(sender, Mouse), 'Wrong sender'
                return (args, kwargs)

        mouse = Mouse()
        print(mouse.click(10, 20))
        >> ((True, ((10, 20), {}),
        >>  <bound method Mouse.on_click of <__main__.Mouse object at 0x00B6F470>>),)

        mouse.click -= mouse.on_click        # handler is unregistered
        print(mouse.click(10, 20))
        >> None
    """

    def __init__(self, sender=None, asynch=False, exc_info=False,
                 lock=None, threads=3, traceback=False):
        """ Creates an event

        asynch
            if True, handlers are executed asynchronously
        exc_info
            if True, result will contain sys.exc_info()[:2] on error
        lock
            threading.RLock used to synchronize execution
        sender
            event's sender. The sender is passed as the first argument to the
            handler, but only if it is not None. In that case the handler must
            have a placeholder in its arguments to receive the sender
        threads
            maximum number of threads that will be started
        traceback
            if True, the execution result will contain sys.exc_info()
            on error. exc_info must also be True to get the traceback

        hash = hash(handler)

        Handlers are stored in a dictionary that has as keys the handler's hash
            handlers = {
                hash : (handler, memoize, timeout),
                hash : (handler, memoize, timeout), ...
            }
        The execution result is cached using the following structure
            memoize = {
                hash : ((args, kwargs, result), (args, kwargs, result), ...),
                hash : ((args, kwargs, result), ...), ...
            }
        The execution result is returned as a tuple having this structure
            exec_result = (
                (True, result, handler),        # on success
                (False, error_info, handler),   # on error
                (None, None, handler), ...      # asynchronous execution
            )
        """
        self.asynchronous = asynch
        self.exc_info = exc_info
        self.lock = lock
        self.sender = sender
        self.threads = threads
        self.traceback = traceback
        self.handlers = {}
        self.memoize = {}

    def handle(self, handler):
        """ Registers a handler. The handler can be transmitted together
        with two arguments as a list or dictionary. The arguments are:

        memoize
            if True, the execution result will be cached in self.memoize
        timeout
            will allocate a predefined time interval for the execution

        If arguments are provided as a list, they are considered to have
        this sequence: (handler, memoize, timeout)

        Examples:
            event += handler
            event += (handler, True, 1.5)
            event += {'handler':handler, 'memoize':True, 'timeout':1.5}
        """
        handler_, memoize, timeout = self._extract(handler)
        self.handlers[hash(handler_)] = (handler_, memoize, timeout)
        return self

    def unhandle(self, handler):
        """ Unregisters a handler """
        handler_, memoize, timeout = self._extract(handler)
        key = hash(handler_)
        if key not in self.handlers:
            raise ValueError('Handler "%s" was not found' % str(handler_))
        del self.handlers[key]
        return self

    def fire(self, *args, **kwargs):
        """ Stores all registered handlers in a queue for processing """
        self.queue = Queue.Queue()
        self.result = []

        if self.handlers:
            max_threads = self._threads()

            for i in range(max_threads):
                t = threading.Thread(target=self._execute,
                                     args=args, kwargs=kwargs)
                t.daemon = True
                t.start()

            for handler in self.handlers:
                self.queue.put(handler)

                if self.asynchronous:
                    handler_, memoize, timeout = self.handlers[handler]
                    self.result.append((None, None, handler_))

            if not self.asynchronous:
                self.queue.join()

        return tuple(self.result) or None

    def count(self):
        """ Returns the count of registered handlers """
        return len(self.handlers)

    def clear(self):
        """ Discards all registered handlers and cached results """
        self.handlers.clear()
        self.memoize.clear()

    def _execute(self, *args, **kwargs):
        """ Executes all handlers stored in the queue """
        while True:
            try:
                handler, memoize, timeout = self.handlers[self.queue.get()]

                if isinstance(self.lock, threading._RLock):
                    self.lock.acquire()     # synchronization

                try:
                    r = self._memoize(memoize, timeout, handler, *args, **kwargs)
                    if not self.asynchronous:
                        self.result.append(tuple(r))

                except Exception:
                    if not self.asynchronous:
                        self.result.append((False, self._error(sys.exc_info()),
                                            handler))
                finally:
                    if isinstance(self.lock, threading._RLock):
                        self.lock.release()

                    if not self.asynchronous:
                        self.queue.task_done()

            except Queue.Empty:
                break

    def _extract(self, queue_item):
        """ Extracts a handler and handler's arguments that can be provided
        as list or dictionary. If arguments are provided as list, they are
        considered to have this sequence: (handler, memoize, timeout)

        Examples:
            event += handler
            event += (handler, True, 1.5)
            event += {'handler':handler, 'memoize':True, 'timeout':1.5}
        """
        assert queue_item, 'Invalid list of arguments'
        handler = None
        memoize = False
        timeout = 0

        if not isinstance(queue_item, (list, tuple, dict)):
            handler = queue_item
        elif isinstance(queue_item, (list, tuple)):
            if len(queue_item) == 3:
                handler, memoize, timeout = queue_item
            elif len(queue_item) == 2:
                handler, memoize = queue_item
            elif len(queue_item) == 1:
                handler = queue_item[0]
        elif isinstance(queue_item, dict):
            handler = queue_item.get('handler')
            memoize = queue_item.get('memoize', False)
            timeout = queue_item.get('timeout', 0)
        return (handler, bool(memoize), float(timeout))

    def _memoize(self, memoize, timeout, handler, *args, **kwargs):
        """ Caches the execution result of successful executions

        hash = hash(handler)
        memoize = {
            hash : ((args, kwargs, result), (args, kwargs, result), ...),
            hash : ((args, kwargs, result), ...), ...
        }
        """
        if not isinstance(handler, Event) and self.sender is not None:
            args = list(args)[:]
            args.insert(0, self.sender)

        if not memoize:
            if timeout <= 0:    # no time restriction
                return [True, handler(*args, **kwargs), handler]

            result = self._timeout(timeout, handler, *args, **kwargs)
            if isinstance(result, tuple) and len(result) == 3:
                if isinstance(result[1], Exception):    # error occurred
                    return [False, self._error(result), handler]
            return [True, result, handler]
        else:
            hash_ = hash(handler)
            if hash_ in self.memoize:
                for args_, kwargs_, result in self.memoize[hash_]:
                    if args_ == args and kwargs_ == kwargs:
                        return [True, result, handler]

            if timeout <= 0:    # no time restriction
                result = handler(*args, **kwargs)
            else:
                result = self._timeout(timeout, handler, *args, **kwargs)
                if isinstance(result, tuple) and len(result) == 3:
                    if isinstance(result[1], Exception):    # error occurred
                        return [False, self._error(result), handler]

            lock = threading.RLock()
            lock.acquire()
            try:
                if hash_ not in self.memoize:
                    self.memoize[hash_] = []
                self.memoize[hash_].append((args, kwargs, result))
                return [True, result, handler]
            finally:
                lock.release()

    def _timeout(self, timeout, handler, *args, **kwargs):
        """ Controls the time allocated for the execution of a method """
        t = spawn_thread(target=handler, args=args, kwargs=kwargs)
        t.daemon = True
        t.start()
        t.join(timeout)

        if not t.is_alive():
            if t.exc_info:
                return t.exc_info
            return t.result
        else:
            try:
                msg = '[%s] Execution was forcefully terminated'
                raise RuntimeError(msg % t.name)
            except:
                return sys.exc_info()

    def _threads(self):
        """ Calculates the maximum number of threads that will be started """
        if self.threads < len(self.handlers):
            return self.threads
        return len(self.handlers)

    def _error(self, exc_info):
        """ Retrieves the error info """
        if self.exc_info:
            if self.traceback:
                return exc_info
            return exc_info[:2]
        return exc_info[1]

    __iadd__ = handle
    __isub__ = unhandle
    __call__ = fire
    __len__ = count

class spawn_thread(threading.Thread):
    """ Spawns a new thread and returns the execution result """

    def __init__(self, target, args=(), kwargs={}, default=None):
        threading.Thread.__init__(self)
        self._target = target
        self._args = args
        self._kwargs = kwargs
        self.result = default
        self.exc_info = None

    def run(self):
        try:
            self.result = self._target(*self._args, **self._kwargs)
        except:
            self.exc_info = sys.exc_info()
        finally:
            del self._target, self._args, self._kwargs
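
The registration variants documented in handle() can be exercised like this
(a small sketch; the fetch handler is hypothetical):

    from axel import Event

    def fetch(url):
        return 'payload for %s' % url

    event = Event()
    event += (fetch, True, 2.0)     # memoize the result, allow 2.0s per call
    result = event('http://example.com')
    print result[0][:2]             # -> (True, 'payload for http://example.com')
    event -= fetch                  # handler is unregistered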
@@ -0,0 +1,181 @@
## {{{ http://code.activestate.com/recipes/278731/ (r6)
"""Disk And Execution MONitor (Daemon)

Configurable daemon behaviors:

    1.) The current working directory set to the "/" directory.
    2.) The current file creation mode mask set to 0.
    3.) Close all open files (1024).
    4.) Redirect standard I/O streams to "/dev/null".

A failed call to fork() now raises an exception.

References:
    1) Advanced Programming in the Unix Environment: W. Richard Stevens
    2) Unix Programming Frequently Asked Questions:
       http://www.erlenstar.demon.co.uk/unix/faq_toc.html
"""

__author__ = "Chad J. Schroeder"
__copyright__ = "Copyright (C) 2005 Chad J. Schroeder"

__revision__ = "$Id$"
__version__ = "0.2"

# Standard Python modules.
import os       # Miscellaneous OS interfaces.
import sys      # System-specific parameters and functions.

# Default daemon parameters.
# File mode creation mask of the daemon.
UMASK = 0

# Default working directory for the daemon.
WORKDIR = "/"

# Default maximum for the number of available file descriptors.
MAXFD = 1024

# The standard I/O file descriptors are redirected to /dev/null by default.
if (hasattr(os, "devnull")):
    REDIRECT_TO = os.devnull
else:
    REDIRECT_TO = "/dev/null"

def createDaemon():
    """Detach a process from the controlling terminal and run it in the
    background as a daemon.
    """

    try:
        # Fork a child process so the parent can exit. This returns control to
        # the command-line or shell. It also guarantees that the child will not
        # be a process group leader, since the child receives a new process ID
        # and inherits the parent's process group ID. This step is required
        # to ensure that the next call to os.setsid is successful.
        pid = os.fork()
    except OSError, e:
        raise Exception, "%s [%d]" % (e.strerror, e.errno)

    if (pid == 0):      # The first child.
        # To become the session leader of this new session and the process group
        # leader of the new process group, we call os.setsid(). The process is
        # also guaranteed not to have a controlling terminal.
        os.setsid()

        # Is ignoring SIGHUP necessary?
        #
        # It's often suggested that the SIGHUP signal should be ignored before
        # the second fork to avoid premature termination of the process. The
        # reason is that when the first child terminates, all processes, e.g.
        # the second child, in the orphaned group will be sent a SIGHUP.
        #
        # "However, as part of the session management system, there are exactly
        # two cases where SIGHUP is sent on the death of a process:
        #
        #   1) When the process that dies is the session leader of a session that
        #      is attached to a terminal device, SIGHUP is sent to all processes
        #      in the foreground process group of that terminal device.
        #   2) When the death of a process causes a process group to become
        #      orphaned, and one or more processes in the orphaned group are
        #      stopped, then SIGHUP and SIGCONT are sent to all members of the
        #      orphaned group." [2]
        #
        # The first case can be ignored since the child is guaranteed not to have
        # a controlling terminal. The second case isn't so easy to dismiss.
        # The process group is orphaned when the first child terminates and
        # POSIX.1 requires that every STOPPED process in an orphaned process
        # group be sent a SIGHUP signal followed by a SIGCONT signal. Since the
        # second child is not STOPPED though, we can safely forego ignoring the
        # SIGHUP signal. In any case, there are no ill-effects if it is ignored.
        #
        # import signal          # Set handlers for asynchronous events.
        # signal.signal(signal.SIGHUP, signal.SIG_IGN)

        try:
            # Fork a second child and exit immediately to prevent zombies. This
            # causes the second child process to be orphaned, making the init
            # process responsible for its cleanup. And, since the first child is
            # a session leader without a controlling terminal, it's possible for
            # it to acquire one by opening a terminal in the future (System V-
            # based systems). This second fork guarantees that the child is no
            # longer a session leader, preventing the daemon from ever acquiring
            # a controlling terminal.
            pid = os.fork()     # Fork a second child.
        except OSError, e:
            raise Exception, "%s [%d]" % (e.strerror, e.errno)

        if (pid == 0):  # The second child.
            # Since the current working directory may be a mounted filesystem, we
            # avoid the issue of not being able to unmount the filesystem at
            # shutdown time by changing it to the root directory.
            os.chdir(WORKDIR)
            # We probably don't want the file mode creation mask inherited from
            # the parent, so we give the child complete control over permissions.
            os.umask(UMASK)
        else:
            # exit() or _exit()? See below.
            os._exit(0)  # Exit parent (the first child) of the second child.
    else:
        # exit() or _exit()?
        # _exit is like exit(), but it doesn't call any functions registered
        # with atexit (and on_exit) or any registered signal handlers. It also
        # closes any open file descriptors. Using exit() may cause all stdio
        # streams to be flushed twice and any temporary files may be unexpectedly
        # removed. It's therefore recommended that child branches of a fork()
        # and the parent branch(es) of a daemon use _exit().
        os._exit(0)     # Exit parent of the first child.

    # Close all open file descriptors. This prevents the child from keeping
    # open any file descriptors inherited from the parent. There is a variety
    # of methods to accomplish this task. Three are listed below.
    #
    # Try the system configuration variable, SC_OPEN_MAX, to obtain the maximum
    # number of open file descriptors to close. If it doesn't exist, use
    # the default value (configurable).
    #
    # try:
    #     maxfd = os.sysconf("SC_OPEN_MAX")
    # except (AttributeError, ValueError):
    #     maxfd = MAXFD
    #
    # OR
    #
    # if (os.sysconf_names.has_key("SC_OPEN_MAX")):
    #     maxfd = os.sysconf("SC_OPEN_MAX")
    # else:
    #     maxfd = MAXFD
    #
    # OR
    #
    # Use the getrlimit method to retrieve the maximum file descriptor number
    # that can be opened by this process. If there is no limit on the
    # resource, use the default value.
    #
    import resource             # Resource usage information.
    maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
    if (maxfd == resource.RLIM_INFINITY):
        maxfd = MAXFD

    # Iterate through and close all file descriptors.
    for fd in range(0, maxfd):
        try:
            os.close(fd)
        except OSError:  # ERROR, fd wasn't open to begin with (ignored)
            pass

    # Redirect the standard I/O file descriptors to the specified file. Since
    # the daemon has no controlling terminal, most daemons redirect stdin,
    # stdout, and stderr to /dev/null. This is done to prevent side-effects
    # from reads and writes to the standard I/O file descriptors.

    # This call to open is guaranteed to return the lowest file descriptor,
    # which will be 0 (stdin), since it was closed above.
    os.open(REDIRECT_TO, os.O_RDWR)     # standard input (0)

    # Duplicate standard input to standard output and standard error.
    os.dup2(0, 1)       # standard output (1)
    os.dup2(0, 2)       # standard error (2)

    return(0)

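A minimal usage sketch for the recipe above (assuming the module is saved as
daemon.py; Python 2, Unix only):

    import os
    from daemon import createDaemon

    createDaemon()

    # From here on the process is detached and its standard streams point at
    # /dev/null, so write any output to a log file instead.
    log = open('/tmp/daemon-example.log', 'a')
    log.write('daemon running with pid %d\n' % os.getpid())
    log.close()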
@ -0,0 +1,209 @@ |
|||
########################## LICENCE ############################### |
|||
## |
|||
## Copyright (c) 2005-2011, Michele Simionato |
|||
## All rights reserved. |
|||
## |
|||
## Redistributions of source code must retain the above copyright |
|||
## notice, this list of conditions and the following disclaimer. |
|||
## Redistributions in bytecode form must reproduce the above copyright |
|||
## notice, this list of conditions and the following disclaimer in |
|||
## the documentation and/or other materials provided with the |
|||
## distribution. |
|||
|
|||
## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
|||
## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
|||
## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
|||
## A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
|||
## HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, |
|||
## INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, |
|||
## BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS |
|||
## OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND |
|||
## ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR |
|||
## TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE |
|||
## USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH |
|||
## DAMAGE. |
|||
|
|||
""" |
|||
Decorator module, see http://pypi.python.org/pypi/decorator |
|||
for the documentation. |
|||
""" |
|||
|
|||
__version__ = '3.3.0' |
|||
|
|||
__all__ = ["decorator", "FunctionMaker", "partial"] |
|||
|
|||
import sys, re, inspect |
|||
|
|||
try: |
|||
from functools import partial |
|||
except ImportError: # for Python version < 2.5 |
|||
class partial(object): |
|||
"A simple replacement of functools.partial" |
|||
def __init__(self, func, *args, **kw): |
|||
self.func = func |
|||
self.args = args |
|||
self.keywords = kw |
|||
def __call__(self, *otherargs, **otherkw): |
|||
kw = self.keywords.copy() |
|||
kw.update(otherkw) |
|||
return self.func(*(self.args + otherargs), **kw) |
|||
|
|||
if sys.version >= '3': |
|||
from inspect import getfullargspec |
|||
else: |
|||
class getfullargspec(object): |
|||
"A quick and dirty replacement for getfullargspec for Python 2.X" |
|||
def __init__(self, f): |
|||
self.args, self.varargs, self.varkw, self.defaults = \ |
|||
inspect.getargspec(f) |
|||
self.kwonlyargs = [] |
|||
self.kwonlydefaults = None |
|||
self.annotations = getattr(f, '__annotations__', {}) |
|||
def __iter__(self): |
|||
yield self.args |
|||
yield self.varargs |
|||
yield self.varkw |
|||
yield self.defaults |
|||
|
|||
DEF = re.compile('\s*def\s*([_\w][_\w\d]*)\s*\(') |
|||
|
|||
# basic functionality |
|||
class FunctionMaker(object): |
|||
""" |
|||
An object with the ability to create functions with a given signature. |
|||
It has attributes name, doc, module, signature, defaults, dict and |
|||
methods update and make. |
|||
""" |
|||
def __init__(self, func=None, name=None, signature=None, |
|||
defaults=None, doc=None, module=None, funcdict=None): |
|||
self.shortsignature = signature |
|||
if func: |
|||
# func can be a class or a callable, but not an instance method |
|||
self.name = func.__name__ |
|||
if self.name == '<lambda>': # small hack for lambda functions |
|||
self.name = '_lambda_' |
|||
self.doc = func.__doc__ |
|||
self.module = func.__module__ |
|||
if inspect.isfunction(func): |
|||
argspec = getfullargspec(func) |
|||
for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', |
|||
'kwonlydefaults', 'annotations'): |
|||
setattr(self, a, getattr(argspec, a)) |
|||
for i, arg in enumerate(self.args): |
|||
setattr(self, 'arg%d' % i, arg) |
|||
self.signature = inspect.formatargspec( |
|||
formatvalue=lambda val: "", *argspec)[1:-1] |
|||
allargs = list(self.args) |
|||
if self.varargs: |
|||
allargs.append('*' + self.varargs) |
|||
if self.varkw: |
|||
allargs.append('**' + self.varkw) |
|||
try: |
|||
self.shortsignature = ', '.join(allargs) |
|||
except TypeError: # exotic signature, valid only in Python 2.X |
|||
self.shortsignature = self.signature |
|||
self.dict = func.__dict__.copy() |
|||
# func=None happens when decorating a caller |
|||
if name: |
|||
self.name = name |
|||
if signature is not None: |
|||
self.signature = signature |
|||
if defaults: |
|||
self.defaults = defaults |
|||
if doc: |
|||
self.doc = doc |
|||
if module: |
|||
self.module = module |
|||
if funcdict: |
|||
self.dict = funcdict |
|||
# check existence required attributes |
|||
assert hasattr(self, 'name') |
|||
if not hasattr(self, 'signature'): |
|||
raise TypeError('You are decorating a non function: %s' % func) |
|||
|
|||
def update(self, func, **kw): |
|||
"Update the signature of func with the data in self" |
|||
func.__name__ = self.name |
|||
func.__doc__ = getattr(self, 'doc', None) |
|||
func.__dict__ = getattr(self, 'dict', {}) |
|||
func.func_defaults = getattr(self, 'defaults', ()) |
|||
callermodule = sys._getframe(3).f_globals.get('__name__', '?') |
|||
func.__module__ = getattr(self, 'module', callermodule) |
|||
func.__dict__.update(kw) |
|||
|
|||
def make(self, src_templ, evaldict=None, addsource=False, **attrs): |
|||
"Make a new function from a given template and update the signature" |
|||
src = src_templ % vars(self) # expand name and signature |
|||
evaldict = evaldict or {} |
|||
mo = DEF.match(src) |
|||
if mo is None: |
|||
raise SyntaxError('not a valid function template\n%s' % src) |
|||
name = mo.group(1) # extract the function name |
|||
names = set([name] + [arg.strip(' *') for arg in |
|||
self.shortsignature.split(',')]) |
|||
for n in names: |
|||
if n in ('_func_', '_call_'): |
|||
raise NameError('%s is overridden in\n%s' % (n, src)) |
|||
if not src.endswith('\n'): # add a newline just for safety |
|||
src += '\n' # this is needed in old versions of Python |
|||
try: |
|||
code = compile(src, '<string>', 'single') |
|||
# print >> sys.stderr, 'Compiling %s' % src |
|||
exec code in evaldict |
|||
except: |
|||
print >> sys.stderr, 'Error in generated code:' |
|||
print >> sys.stderr, src |
|||
raise |
|||
func = evaldict[name] |
|||
if addsource: |
|||
attrs['__source__'] = src |
|||
self.update(func, **attrs) |
|||
return func |
|||
|
|||
@classmethod |
|||
def create(cls, obj, body, evaldict, defaults=None, |
|||
doc=None, module=None, addsource=True,**attrs): |
|||
""" |
|||
Create a function from the strings name, signature and body. |
|||
evaldict is the evaluation dictionary. If addsource is true an attribute |
|||
__source__ is added to the result. The attributes attrs are added, |
|||
if any. |
|||
""" |
|||
if isinstance(obj, str): # "name(signature)" |
|||
name, rest = obj.strip().split('(', 1) |
|||
signature = rest[:-1] #strip a right parens |
|||
func = None |
|||
else: # a function |
|||
name = None |
|||
signature = None |
|||
func = obj |
|||
self = cls(func, name, signature, defaults, doc, module) |
|||
ibody = '\n'.join(' ' + line for line in body.splitlines()) |
|||
return self.make('def %(name)s(%(signature)s):\n' + ibody, |
|||
evaldict, addsource, **attrs) |
|||
|
|||
def decorator(caller, func=None): |
|||
""" |
|||
decorator(caller) converts a caller function into a decorator; |
|||
decorator(caller, func) decorates a function using a caller. |
|||
""" |
|||
if func is not None: # returns a decorated function |
|||
evaldict = func.func_globals.copy() |
|||
evaldict['_call_'] = caller |
|||
evaldict['_func_'] = func |
|||
return FunctionMaker.create( |
|||
func, "return _call_(_func_, %(shortsignature)s)", |
|||
evaldict, undecorated=func) |
|||
else: # returns a decorator |
|||
if isinstance(caller, partial): |
|||
return partial(decorator, caller) |
|||
# otherwise assume caller is a function |
|||
first = inspect.getargspec(caller)[0][0] # first arg |
|||
evaldict = caller.func_globals.copy() |
|||
evaldict['_call_'] = caller |
|||
evaldict['decorator'] = decorator |
|||
return FunctionMaker.create( |
|||
'%s(%s)' % (caller.__name__, first), |
|||
'return decorator(_call_, %s)' % first, |
|||
evaldict, undecorated=caller, |
|||
doc=caller.__doc__, module=caller.__module__) |
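# Usage sketch (added for exposition): the point of decorator() over a
# hand-written closure is that the decorated function keeps its original
# signature.
if __name__ == '__main__':
    @decorator
    def trace(f, *args, **kw):
        print >> sys.stderr, 'calling %s' % f.__name__
        return f(*args, **kw)

    @trace
    def add(x, y):
        return x + y

    assert add(1, 2) == 3
    # a naive wrapper would report (*args, **kw) here
    assert inspect.getargspec(add)[0] == ['x', 'y']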
@ -0,0 +1,119 @@ |
|||
''' |
|||
Elixir package |
|||
|
|||
A declarative layer on top of the `SQLAlchemy library |
|||
<http://www.sqlalchemy.org/>`_. It is a fairly thin wrapper, which provides |
|||
the ability to create simple Python classes that map directly to relational |
|||
database tables (this pattern is often referred to as the Active Record design |
|||
pattern), providing many of the benefits of traditional databases |
|||
without losing the convenience of Python objects. |
|||
|
|||
Elixir is intended to replace the ActiveMapper SQLAlchemy extension and the |
|||
TurboEntity project. It does not intend to replace SQLAlchemy's core features; |
|||
instead, it focuses on providing a simpler syntax for defining model objects |
|||
when you do not need the full expressiveness of SQLAlchemy's manual mapper |
|||
definitions. |
|||
''' |
|||
|
|||
try: |
|||
set |
|||
except NameError: |
|||
from sets import Set as set |
|||
|
|||
import sqlalchemy |
|||
from sqlalchemy.types import * |
|||
|
|||
from elixir.options import using_options, using_table_options, \ |
|||
using_mapper_options, options_defaults, \ |
|||
using_options_defaults |
|||
from elixir.entity import Entity, EntityBase, EntityMeta, EntityDescriptor, \ |
|||
setup_entities, cleanup_entities |
|||
from elixir.fields import has_field, Field |
|||
from elixir.relationships import belongs_to, has_one, has_many, \ |
|||
has_and_belongs_to_many, \ |
|||
ManyToOne, OneToOne, OneToMany, ManyToMany |
|||
from elixir.properties import has_property, GenericProperty, ColumnProperty, \ |
|||
Synonym |
|||
from elixir.statements import Statement |
|||
from elixir.collection import EntityCollection, GlobalEntityCollection |
|||
|
|||
|
|||
__version__ = '0.7.1' |
|||
|
|||
__all__ = ['Entity', 'EntityBase', 'EntityMeta', 'EntityCollection', |
|||
'entities', |
|||
'Field', 'has_field', |
|||
'has_property', 'GenericProperty', 'ColumnProperty', 'Synonym', |
|||
'belongs_to', 'has_one', 'has_many', 'has_and_belongs_to_many', |
|||
'ManyToOne', 'OneToOne', 'OneToMany', 'ManyToMany', |
|||
'using_options', 'using_table_options', 'using_mapper_options', |
|||
'options_defaults', 'using_options_defaults', |
|||
'metadata', 'session', |
|||
'create_all', 'drop_all', |
|||
'setup_all', 'cleanup_all', |
|||
'setup_entities', 'cleanup_entities'] + \ |
|||
sqlalchemy.types.__all__ |
|||
|
|||
__doc_all__ = ['create_all', 'drop_all', |
|||
'setup_all', 'cleanup_all', |
|||
'metadata', 'session'] |
|||
|
|||
# default session |
|||
session = sqlalchemy.orm.scoped_session(sqlalchemy.orm.sessionmaker()) |
|||
|
|||
# default metadata |
|||
metadata = sqlalchemy.MetaData() |
|||
|
|||
metadatas = set() |
|||
|
|||
# default entity collection |
|||
entities = GlobalEntityCollection() |
|||
|
|||
|
|||
def create_all(*args, **kwargs): |
|||
'''Create the necessary tables for all declared entities''' |
|||
for md in metadatas: |
|||
md.create_all(*args, **kwargs) |
|||
|
|||
|
|||
def drop_all(*args, **kwargs): |
|||
'''Drop tables for all declared entities''' |
|||
for md in metadatas: |
|||
md.drop_all(*args, **kwargs) |
|||
|
|||
|
|||
def setup_all(create_tables=False, *args, **kwargs): |
|||
'''Set up the table and mapper of all entities in the default entity |
|||
collection. |
|||
|
|||
This is called automatically if any entity of the collection is configured |
|||
with the `autosetup` option and is first accessed, instantiated (called), |
|||
or when the create_all method of a metadata containing tables from any of |
|||
those entities is called. |
|||
''' |
|||
setup_entities(entities) |
|||
|
|||
# issue the "CREATE" SQL statements |
|||
if create_tables: |
|||
create_all(*args, **kwargs) |
|||
|
|||
|
|||
def cleanup_all(drop_tables=False, *args, **kwargs): |
|||
'''Clear all mappers, clear the session, and clear all metadatas. |
|||
Optionally drops the tables. |
|||
''' |
|||
session.close() |
|||
|
|||
cleanup_entities(entities) |
|||
|
|||
sqlalchemy.orm.clear_mappers() |
|||
entities.clear() |
|||
|
|||
if drop_tables: |
|||
drop_all(*args, **kwargs) |
|||
|
|||
for md in metadatas: |
|||
md.clear() |
|||
metadatas.clear() |
|||
|
|||
|
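# End-to-end sketch (added for exposition; the entity, field and database
# URI below are assumptions, not part of this module):
if __name__ == '__main__':
    metadata.bind = 'sqlite://'       # in-memory database

    class Person(Entity):
        name = Field(Unicode(64))

    setup_all(create_tables=True)     # map entities and CREATE the tables
    Person(name=u'Jane')              # saved to the session on init
    session.commit()
    assert Person.query.count() == 1
    cleanup_all(drop_tables=True)     # DROP the tables and clear mappers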
@ -0,0 +1,127 @@ |
|||
''' |
|||
Default entity collection implementation |
|||
''' |
|||
import sys |
|||
import re |
|||
|
|||
from elixir.py23compat import rsplit |
|||
|
|||
class BaseCollection(list): |
|||
def __init__(self, entities=None): |
|||
list.__init__(self) |
|||
if entities is not None: |
|||
self.extend(entities) |
|||
|
|||
def extend(self, entities): |
|||
for e in entities: |
|||
self.append(e) |
|||
|
|||
def clear(self): |
|||
del self[:] |
|||
|
|||
def resolve_absolute(self, key, full_path, entity=None, root=None): |
|||
if root is None: |
|||
root = entity._descriptor.resolve_root |
|||
if root: |
|||
full_path = '%s.%s' % (root, full_path) |
|||
module_path, classname = rsplit(full_path, '.', 1) |
|||
module = sys.modules[module_path] |
|||
res = getattr(module, classname, None) |
|||
if res is None: |
|||
if entity is not None: |
|||
raise Exception("Couldn't resolve target '%s' <%s> in '%s'!" |
|||
% (key, full_path, entity.__name__)) |
|||
else: |
|||
raise Exception("Couldn't resolve target '%s' <%s>!" |
|||
% (key, full_path)) |
|||
return res |
|||
|
|||
def __getattr__(self, key): |
|||
return self.resolve(key) |
|||
|
|||
# default entity collection |
|||
class GlobalEntityCollection(BaseCollection): |
|||
def __init__(self, entities=None): |
|||
# _entities is a dict of entities keyed on their name. |
|||
self._entities = {} |
|||
super(GlobalEntityCollection, self).__init__(entities) |
|||
|
|||
def append(self, entity): |
|||
''' |
|||
Add an entity to the collection. |
|||
''' |
|||
super(GlobalEntityCollection, self).append(entity) |
|||
|
|||
existing_entities = self._entities.setdefault(entity.__name__, []) |
|||
existing_entities.append(entity) |
|||
|
|||
def resolve(self, key, entity=None): |
|||
''' |
|||
Resolve a key to an Entity. The optional `entity` argument is the |
|||
"source" entity when resolving relationship targets. |
|||
''' |
|||
# Do we have a fully qualified entity name? |
|||
if '.' in key: |
|||
return self.resolve_absolute(key, key, entity) |
|||
else: |
|||
# Otherwise we look in the entities of this collection |
|||
res = self._entities.get(key, None) |
|||
if res is None: |
|||
if entity: |
|||
raise Exception("Couldn't resolve target '%s' in '%s'" |
|||
% (key, entity.__name__)) |
|||
else: |
|||
raise Exception("This collection does not contain any " |
|||
"entity corresponding to the key '%s'!" |
|||
% key) |
|||
elif len(res) > 1: |
|||
raise Exception("'%s' resolves to several entities, you should" |
|||
" use the full path (including the full module" |
|||
" name) to that entity." % key) |
|||
else: |
|||
return res[0] |
|||
|
|||
def clear(self): |
|||
self._entities = {} |
|||
super(GlobalEntityCollection, self).clear() |
|||
|
|||
# backward compatible name |
|||
EntityCollection = GlobalEntityCollection |
|||
|
|||
_leading_dots = re.compile('^([.]*).*$') |
|||
|
|||
class RelativeEntityCollection(BaseCollection): |
|||
# the entity=None does not make any sense with a relative entity collection |
|||
def resolve(self, key, entity): |
|||
''' |
|||
Resolve a key to an Entity. The optional `entity` argument is the |
|||
"source" entity when resolving relationship targets. |
|||
''' |
|||
full_path = key |
|||
|
|||
if '.' not in key or key.startswith('.'): |
|||
# relative target |
|||
|
|||
# any leading dot is stripped and with each dot removed, |
|||
# the entity_module is stripped of one more chunk (starting with |
|||
# the last one). |
|||
num_dots = _leading_dots.match(full_path).end(1) |
|||
full_path = full_path[num_dots:] |
|||
chunks = entity.__module__.split('.') |
|||
chunkstokeep = len(chunks) - num_dots |
|||
if chunkstokeep < 0: |
|||
raise Exception("Couldn't resolve relative target " |
|||
"'%s' relative to '%s'" % (key, entity.__module__)) |
|||
entity_module = '.'.join(chunks[:chunkstokeep]) |
|||
|
|||
if entity_module and entity_module != '__main__': |
|||
full_path = '%s.%s' % (entity_module, full_path) |
|||
|
|||
root = '' |
|||
else: |
|||
root = None |
|||
return self.resolve_absolute(key, full_path, entity, root=root) |
|||
|
|||
def __getattr__(self, key): |
|||
raise NotImplementedError |
|||
|
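# Resolution sketch (added for exposition; module and entity names are
# made up). For a source entity defined in module 'proj.model.sub':
#
#     resolve('Other', entity)    # looks up proj.model.sub.Other
#     resolve('.Other', entity)   # one dot -> proj.model.Other
#     resolve('..Other', entity)  # two dots -> proj.Other
#
# A dotted key that does not start with a dot, such as 'proj.lib.Other',
# is treated as an absolute path.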
File diff suppressed because it is too large
@ -0,0 +1,30 @@ |
|||
__all__ = [ |
|||
'before_insert', |
|||
'after_insert', |
|||
'before_update', |
|||
'after_update', |
|||
'before_delete', |
|||
'after_delete', |
|||
'reconstructor' |
|||
] |
|||
|
|||
def create_decorator(event_name): |
|||
def decorator(func): |
|||
if not hasattr(func, '_elixir_events'): |
|||
func._elixir_events = [] |
|||
func._elixir_events.append(event_name) |
|||
return func |
|||
return decorator |
|||
|
|||
before_insert = create_decorator('before_insert') |
|||
after_insert = create_decorator('after_insert') |
|||
before_update = create_decorator('before_update') |
|||
after_update = create_decorator('after_update') |
|||
before_delete = create_decorator('before_delete') |
|||
after_delete = create_decorator('after_delete') |
|||
try: |
|||
from sqlalchemy.orm import reconstructor |
|||
except ImportError: |
|||
def reconstructor(func): |
|||
raise Exception('The reconstructor method decorator is only ' |
|||
'available with SQLAlchemy 0.5 and later') |
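# Usage sketch (added for exposition; the entity below is an assumption):
# decorated methods are collected by Elixir and wired to the matching
# SQLAlchemy mapper event.
#
#     class Order(Entity):
#         total = Field(Numeric)
#
#         @before_insert
#         def ensure_total(self):
#             if self.total is None:
#                 self.total = 0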
@ -0,0 +1,5 @@ |
|||
''' |
|||
Ext package |
|||
|
|||
Additional Elixir statements and functionality. |
|||
''' |
@ -0,0 +1,234 @@ |
|||
''' |
|||
Associable Elixir Statement Generator |
|||
|
|||
========== |
|||
Associable |
|||
========== |
|||
|
|||
About Polymorphic Associations |
|||
------------------------------ |
|||
|
|||
A frequent pattern in database schemas is the has_and_belongs_to_many, or a |
|||
many-to-many table. Quite often multiple tables will refer to a single one, |
|||
creating quite a few many-to-many intermediate tables. |
|||
|
|||
Polymorphic associations reduce the number of many-to-many tables by setting |
|||
up a table that allows relations to any other table in the database, and |
|||
relates it to the associable table. Some implementations of this layout do |
|||
not enforce referential integrity with database foreign key constraints; |
|||
this implementation uses an additional many-to-many table with foreign key |
|||
constraints to avoid that problem. |
|||
|
|||
.. note:: |
|||
SQLite does not support foreign key constraints, so referential integrity |
|||
can only be enforced using database backends with such support. |
|||
|
|||
Elixir Statement Generator for Polymorphic Associations |
|||
------------------------------------------------------- |
|||
|
|||
The ``associable`` function generates the intermediary tables for an Elixir |
|||
entity that should be associable with other Elixir entities and returns an |
|||
Elixir Statement for use with them. This automates the process of creating the |
|||
polymorphic association tables and ensuring their referential integrity. |
|||
|
|||
Matching select_XXX and select_by_XXX methods are also added to the associated |
|||
entity, which allow queries to be run for the associated objects. |
|||
|
|||
Example usage: |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
class Tag(Entity): |
|||
name = Field(Unicode) |
|||
|
|||
acts_as_taggable = associable(Tag) |
|||
|
|||
class Entry(Entity): |
|||
title = Field(Unicode) |
|||
acts_as_taggable('tags') |
|||
|
|||
class Article(Entity): |
|||
title = Field(Unicode) |
|||
acts_as_taggable('tags') |
|||
|
|||
Or if one of the entities being associated should only have a single member of |
|||
the associated table: |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
class Address(Entity): |
|||
street = Field(String(130)) |
|||
city = Field(String(100)) |
|||
|
|||
is_addressable = associable(Address, 'addresses') |
|||
|
|||
class Person(Entity): |
|||
name = Field(Unicode) |
|||
orders = OneToMany('Order') |
|||
is_addressable() |
|||
|
|||
class Order(Entity): |
|||
order_num = Field(primary_key=True) |
|||
item_count = Field(Integer) |
|||
person = ManyToOne('Person') |
|||
is_addressable('address', uselist=False) |
|||
|
|||
home = Address(street='123 Elm St.', city='Spooksville') |
|||
user = Person(name='Jane Doe') |
|||
user.addresses.append(home) |
|||
|
|||
neworder = Order(item_count=4) |
|||
neworder.address = home |
|||
user.orders.append(neworder) |
|||
|
|||
# Queries using the added helpers |
|||
Person.select_by_addresses(city='Cupertino') |
|||
Person.select_addresses(and_(Address.c.street=='132 Elm St', |
|||
Address.c.city=='Smallville')) |
|||
|
|||
Statement Options |
|||
----------------- |
|||
|
|||
The generated Elixir Statement has several options available: |
|||
|
|||
+---------------+-------------------------------------------------------------+ |
|||
| Option Name | Description | |
|||
+===============+=============================================================+ |
|||
| ``name`` | Specify a custom name for the Entity attribute. This is | |
|||
| | used to declare the attribute used to access the associated | |
|||
| | table values. Otherwise, the name will use the plural_name | |
|||
| | provided to the associable call. | |
|||
+---------------+-------------------------------------------------------------+ |
|||
| ``uselist`` | Whether or not the associated table should be represented | |
|||
| | as a list, or a single property. It should be set to False | |
|||
| | when the entity should only have a single associated | |
|||
| | entity. Defaults to True. | |
|||
+---------------+-------------------------------------------------------------+ |
|||
| ``lazy`` | Whether the associated entity objects should be loaded | |
|||
| | lazily rather than eagerly along with their parent. | |
|||
| | Defaults to True. | |
|||
+---------------+-------------------------------------------------------------+ |
|||
''' |
|||
from elixir.statements import Statement |
|||
import sqlalchemy as sa |
|||
|
|||
__doc_all__ = ['associable'] |
|||
|
|||
|
|||
def associable(assoc_entity, plural_name=None, lazy=True): |
|||
''' |
|||
Generate an associable Elixir Statement |
|||
''' |
|||
interface_name = assoc_entity._descriptor.tablename |
|||
able_name = interface_name + 'able' |
|||
|
|||
if plural_name: |
|||
attr_name = "%s_rel" % plural_name |
|||
else: |
|||
plural_name = interface_name |
|||
attr_name = "%s_rel" % interface_name |
|||
|
|||
class GenericAssoc(object): |
|||
|
|||
def __init__(self, tablename): |
|||
self.type = tablename |
|||
|
|||
#TODO: inherit from entity builder |
|||
class Associable(object): |
|||
"""An associable Elixir Statement object""" |
|||
|
|||
def __init__(self, entity, name=None, uselist=True, lazy=True): |
|||
self.entity = entity |
|||
self.lazy = lazy |
|||
self.uselist = uselist |
|||
|
|||
if name is None: |
|||
self.name = plural_name |
|||
else: |
|||
self.name = name |
|||
|
|||
def after_table(self): |
|||
col = sa.Column('%s_assoc_id' % interface_name, sa.Integer, |
|||
sa.ForeignKey('%s.id' % able_name)) |
|||
self.entity._descriptor.add_column(col) |
|||
|
|||
if not hasattr(assoc_entity, '_assoc_table'): |
|||
metadata = assoc_entity._descriptor.metadata |
|||
association_table = sa.Table("%s" % able_name, metadata, |
|||
sa.Column('id', sa.Integer, primary_key=True), |
|||
sa.Column('type', sa.String(40), nullable=False), |
|||
) |
|||
tablename = "%s_to_%s" % (able_name, interface_name) |
|||
association_to_table = sa.Table(tablename, metadata, |
|||
sa.Column('assoc_id', sa.Integer, |
|||
sa.ForeignKey(association_table.c.id, |
|||
ondelete="CASCADE"), |
|||
primary_key=True), |
|||
#FIXME: this assumes a single id col |
|||
sa.Column('%s_id' % interface_name, sa.Integer, |
|||
sa.ForeignKey(assoc_entity.table.c.id, |
|||
ondelete="RESTRICT"), |
|||
primary_key=True), |
|||
) |
|||
|
|||
assoc_entity._assoc_table = association_table |
|||
assoc_entity._assoc_to_table = association_to_table |
|||
|
|||
def after_mapper(self): |
|||
if not hasattr(assoc_entity, '_assoc_mapper'): |
|||
assoc_entity._assoc_mapper = sa.orm.mapper( |
|||
GenericAssoc, assoc_entity._assoc_table, properties={ |
|||
'targets': sa.orm.relation( |
|||
assoc_entity, |
|||
secondary=assoc_entity._assoc_to_table, |
|||
lazy=lazy, backref='associations', |
|||
order_by=assoc_entity.mapper.order_by) |
|||
}) |
|||
|
|||
entity = self.entity |
|||
entity.mapper.add_property( |
|||
attr_name, |
|||
sa.orm.relation(GenericAssoc, lazy=self.lazy, |
|||
backref='_backref_%s' % entity.table.name) |
|||
) |
|||
|
|||
if self.uselist: |
|||
def get(self): |
|||
if getattr(self, attr_name) is None: |
|||
setattr(self, attr_name, |
|||
GenericAssoc(entity.table.name)) |
|||
return getattr(self, attr_name).targets |
|||
setattr(entity, self.name, property(get)) |
|||
else: |
|||
# scalar based property decorator |
|||
def get(self): |
|||
attr = getattr(self, attr_name) |
|||
if attr is not None: |
|||
return attr.targets[0] |
|||
else: |
|||
return None |
|||
def set(self, value): |
|||
if getattr(self, attr_name) is None: |
|||
setattr(self, attr_name, |
|||
GenericAssoc(entity.table.name)) |
|||
getattr(self, attr_name).targets = [value] |
|||
setattr(entity, self.name, property(get, set)) |
|||
|
|||
# self.name is both set via mapper synonym and the python |
|||
# property, but that's how synonym properties work. |
|||
# the synonym property must be added after the "real" property, |
|||
# otherwise it breaks when using SQLAlchemy > 0.4.1 |
|||
entity.mapper.add_property(self.name, sa.orm.synonym(attr_name)) |
|||
|
|||
# add helper methods |
|||
def select_by(cls, **kwargs): |
|||
return cls.query.join([attr_name, 'targets']) \ |
|||
.filter_by(**kwargs).all() |
|||
setattr(entity, 'select_by_%s' % self.name, classmethod(select_by)) |
|||
|
|||
def select(cls, *args, **kwargs): |
|||
return cls.query.join([attr_name, 'targets']) \ |
|||
.filter(*args, **kwargs).all() |
|||
setattr(entity, 'select_%s' % self.name, classmethod(select)) |
|||
|
|||
return Statement(Associable) |
@ -0,0 +1,124 @@ |
|||
''' |
|||
An encryption plugin for Elixir utilizing the excellent PyCrypto library, which |
|||
can be downloaded here: http://www.amk.ca/python/code/crypto |
|||
|
|||
Values for columns that are specified to be encrypted will be transparently |
|||
encrypted and safely encoded for storage in a unicode column, using the |
|||
Blowfish cipher with a specified "secret" which can be passed into |
|||
the plugin at class declaration time. |
|||
|
|||
Example usage: |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
from elixir import * |
|||
from elixir.ext.encrypted import acts_as_encrypted |
|||
|
|||
class Person(Entity): |
|||
name = Field(Unicode) |
|||
password = Field(Unicode) |
|||
ssn = Field(Unicode) |
|||
acts_as_encrypted(for_fields=['password', 'ssn'], |
|||
with_secret='secret') |
|||
|
|||
The above Person entity will automatically encrypt and decrypt the password and |
|||
ssn columns on save, update, and load. Different secrets can be specified on |
|||
an entity by entity basis, for added security. |
|||
|
|||
**Important note**: instance attributes are encrypted in-place. This means that |
|||
if one of the encrypted attributes of an instance is accessed after the |
|||
instance has been flushed to the database (and thus encrypted), the value for |
|||
that attribute will be encrypted in the in-memory object in addition to the |
|||
database row. |
|||
''' |
|||
|
|||
from Crypto.Cipher import Blowfish |
|||
from elixir.statements import Statement |
|||
from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, EXT_STOP |
|||
|
|||
try: |
|||
from sqlalchemy.orm import EXT_PASS |
|||
SA05orlater = False |
|||
except ImportError: |
|||
SA05orlater = True |
|||
|
|||
__all__ = ['acts_as_encrypted'] |
|||
__doc_all__ = [] |
|||
|
|||
|
|||
# |
|||
# encryption and decryption functions |
|||
# |
|||
|
|||
def encrypt_value(value, secret): |
|||
return Blowfish.new(secret, Blowfish.MODE_CFB) \ |
|||
.encrypt(value).encode('string_escape') |
|||
|
|||
def decrypt_value(value, secret): |
|||
return Blowfish.new(secret, Blowfish.MODE_CFB) \ |
|||
.decrypt(value.decode('string_escape')) |
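# Roundtrip sketch (added for exposition; the value and secret are made
# up): the two helpers above are inverses for a given secret.
#
#     ct = encrypt_value('123-45-6789', 's3cret')
#     assert decrypt_value(ct, 's3cret') == '123-45-6789'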
|||
|
|||
|
|||
# |
|||
# acts_as_encrypted statement |
|||
# |
|||
|
|||
class ActsAsEncrypted(object): |
|||
|
|||
def __init__(self, entity, for_fields=[], with_secret='abcdef'): |
|||
|
|||
def perform_encryption(instance, encrypt=True): |
|||
encrypted = getattr(instance, '_elixir_encrypted', None) |
|||
if encrypted is encrypt: |
|||
# skipping encryption or decryption, as it is already done |
|||
return |
|||
else: |
|||
# marking instance as already encrypted/decrypted |
|||
instance._elixir_encrypted = encrypt |
|||
|
|||
if encrypt: |
|||
func = encrypt_value |
|||
else: |
|||
func = decrypt_value |
|||
|
|||
for column_name in for_fields: |
|||
current_value = getattr(instance, column_name) |
|||
if current_value: |
|||
setattr(instance, column_name, |
|||
func(current_value, with_secret)) |
|||
|
|||
def perform_decryption(instance): |
|||
perform_encryption(instance, encrypt=False) |
|||
|
|||
class EncryptedMapperExtension(MapperExtension): |
|||
|
|||
def before_insert(self, mapper, connection, instance): |
|||
perform_encryption(instance) |
|||
return EXT_CONTINUE |
|||
|
|||
def before_update(self, mapper, connection, instance): |
|||
perform_encryption(instance) |
|||
return EXT_CONTINUE |
|||
|
|||
if SA05orlater: |
|||
def reconstruct_instance(self, mapper, instance): |
|||
perform_decryption(instance) |
|||
# no special return value is required for |
|||
# reconstruct_instance, but you never know... |
|||
return EXT_CONTINUE |
|||
else: |
|||
def populate_instance(self, mapper, selectcontext, row, |
|||
instance, *args, **kwargs): |
|||
mapper.populate_instance(selectcontext, instance, row, |
|||
*args, **kwargs) |
|||
perform_decryption(instance) |
|||
# EXT_STOP because we already did populate the instance and |
|||
# the normal processing should not happen |
|||
return EXT_STOP |
|||
|
|||
# make sure that the entity's mapper has our mapper extension |
|||
entity._descriptor.add_mapper_extension(EncryptedMapperExtension()) |
|||
|
|||
|
|||
acts_as_encrypted = Statement(ActsAsEncrypted) |
|||
|
@ -0,0 +1,251 @@ |
|||
''' |
|||
This extension is DEPRECATED. Please use the orderinglist SQLAlchemy |
|||
extension instead. |
|||
|
|||
For details: |
|||
http://www.sqlalchemy.org/docs/05/reference/ext/orderinglist.html |
|||
|
|||
For an Elixir example: |
|||
http://elixir.ematia.de/trac/wiki/Recipes/UsingEntityForOrderedList |
|||
or |
|||
http://elixir.ematia.de/trac/browser/elixir/0.7.0/tests/test_o2m.py#L155 |
|||
|
|||
|
|||
|
|||
An ordered-list plugin for Elixir that lets an entity be |
|||
managed in a list-like way. Much inspiration comes from the Ruby on Rails |
|||
acts_as_list plugin, which is currently more full-featured than this plugin. |
|||
|
|||
Once you flag an entity with an `acts_as_list()` statement, a column will be |
|||
added to the entity called `position` which will be an integer column that is |
|||
managed for you by the plugin. You can pass an alternative column name to |
|||
the plugin using the `column_name` keyword argument. |
|||
|
|||
In addition, your entity will get a series of new methods attached to it, |
|||
including: |
|||
|
|||
+----------------------+------------------------------------------------------+ |
|||
| Method Name | Description | |
|||
+======================+======================================================+ |
|||
| ``move_lower`` | Move the item lower in the list | |
|||
+----------------------+------------------------------------------------------+ |
|||
| ``move_higher`` | Move the item higher in the list | |
|||
+----------------------+------------------------------------------------------+ |
|||
| ``move_to_bottom`` | Move the item to the bottom of the list | |
|||
+----------------------+------------------------------------------------------+ |
|||
| ``move_to_top`` | Move the item to the top of the list | |
|||
+----------------------+------------------------------------------------------+ |
|||
| ``move_to`` | Move the item to a specific position in the list | |
|||
+----------------------+------------------------------------------------------+ |
|||
|
|||
|
|||
Sometimes, your entities that represent list items will be a part of different |
|||
lists. To implement this behavior, simply pass the `acts_as_list` statement a |
|||
callable that returns a "qualifier" SQLAlchemy expression. This expression will |
|||
be added to the generated WHERE clauses used by the plugin. |
|||
|
|||
Example model usage: |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
from elixir import * |
|||
from elixir.ext.list import acts_as_list |
|||
|
|||
class ToDo(Entity): |
|||
subject = Field(String(128)) |
|||
owner = ManyToOne('Person') |
|||
|
|||
def qualify(self): |
|||
return ToDo.owner_id == self.owner_id |
|||
|
|||
acts_as_list(qualifier=qualify) |
|||
|
|||
class Person(Entity): |
|||
name = Field(String(64)) |
|||
todos = OneToMany('ToDo', order_by='position') |
|||
|
|||
|
|||
The above example can then be used to manage ordered todo lists for people. |
|||
Note that you must set the `order_by` property on the `Person.todos` relation in |
|||
order for the relation to respect the ordering. Here is an example of using |
|||
this model in practice: |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
p = Person.query.filter_by(name='Jonathan').one() |
|||
p.todos.append(ToDo(subject='Three')) |
|||
p.todos.append(ToDo(subject='Two')) |
|||
p.todos.append(ToDo(subject='One')) |
|||
session.commit(); session.clear() |
|||
|
|||
p = Person.query.filter_by(name='Jonathan').one() |
|||
p.todos[0].move_to_bottom() |
|||
p.todos[2].move_to_top() |
|||
session.commit(); session.clear() |
|||
|
|||
p = Person.query.filter_by(name='Jonathan').one() |
|||
assert p.todos[0].subject == 'One' |
|||
assert p.todos[1].subject == 'Two' |
|||
assert p.todos[2].subject == 'Three' |
|||
|
|||
|
|||
For more examples, refer to the unit tests for this plugin. |
|||
''' |
|||
|
|||
from elixir.statements import Statement |
|||
from elixir.events import before_insert, before_delete |
|||
from sqlalchemy import Column, Integer, select, func, literal, and_ |
|||
import warnings |
|||
|
|||
__all__ = ['acts_as_list'] |
|||
__doc_all__ = [] |
|||
|
|||
|
|||
def get_entity_where(instance): |
|||
clauses = [] |
|||
for column in instance.table.primary_key.columns: |
|||
instance_value = getattr(instance, column.name) |
|||
clauses.append(column == instance_value) |
|||
return and_(*clauses) |
|||
|
|||
|
|||
class ListEntityBuilder(object): |
|||
|
|||
def __init__(self, entity, qualifier=None, column_name='position'): |
|||
warnings.warn("The act_as_list extension is deprecated. Please use " |
|||
"SQLAlchemy's orderinglist extension instead", |
|||
DeprecationWarning, stacklevel=6) |
|||
self.entity = entity |
|||
self.qualifier_method = qualifier |
|||
self.column_name = column_name |
|||
|
|||
def create_non_pk_cols(self): |
|||
if self.entity._descriptor.autoload: |
|||
for c in self.entity.table.c: |
|||
if c.name == self.column_name: |
|||
self.position_column = c |
|||
if not hasattr(self, 'position_column'): |
|||
raise Exception( |
|||
"Could not find column '%s' in autoloaded table '%s', " |
|||
"needed by entity '%s'." % (self.column_name, |
|||
self.entity.table.name, self.entity.__name__)) |
|||
else: |
|||
self.position_column = Column(self.column_name, Integer) |
|||
self.entity._descriptor.add_column(self.position_column) |
|||
|
|||
def after_table(self): |
|||
position_column = self.position_column |
|||
position_column_name = self.column_name |
|||
|
|||
qualifier_method = self.qualifier_method |
|||
if not qualifier_method: |
|||
qualifier_method = lambda self: None |
|||
|
|||
def _init_position(self): |
|||
s = select( |
|||
[(func.max(position_column)+1).label('value')], |
|||
qualifier_method(self) |
|||
).union( |
|||
select([literal(1).label('value')]) |
|||
) |
|||
a = s.alias() |
|||
# we use a second func.max to get the maximum between 1 and the |
|||
# real max position if any exist |
|||
setattr(self, position_column_name, select([func.max(a.c.value)])) |
|||
|
|||
# Note that this method could be rewritten more simply like below, |
|||
# but because this extension is going to be deprecated anyway, |
|||
# I don't want to risk breaking something I don't want to maintain. |
|||
# setattr(self, position_column_name, select( |
|||
# [func.coalesce(func.max(position_column), 0) + 1], |
|||
# qualifier_method(self) |
|||
# )) |
|||
_init_position = before_insert(_init_position) |
|||
|
|||
def _shift_items(self): |
|||
self.table.update( |
|||
and_( |
|||
position_column > getattr(self, position_column_name), |
|||
qualifier_method(self) |
|||
), |
|||
values={ |
|||
position_column : position_column - 1 |
|||
} |
|||
).execute() |
|||
_shift_items = before_delete(_shift_items) |
|||
|
|||
def move_to_bottom(self): |
|||
# move the items that were below this item up one |
|||
self.table.update( |
|||
and_( |
|||
position_column >= getattr(self, position_column_name), |
|||
qualifier_method(self) |
|||
), |
|||
values = { |
|||
position_column : position_column - 1 |
|||
} |
|||
).execute() |
|||
|
|||
# move this item to the max position |
|||
# MySQL does not support the correlated subquery, so we need to |
|||
# execute the query (through scalar()). See ticket #34. |
|||
self.table.update( |
|||
get_entity_where(self), |
|||
values={ |
|||
position_column : select( |
|||
[func.max(position_column) + 1], |
|||
qualifier_method(self) |
|||
).scalar() |
|||
} |
|||
).execute() |
|||
|
|||
def move_to_top(self): |
|||
self.move_to(1) |
|||
|
|||
def move_to(self, position): |
|||
current_position = getattr(self, position_column_name) |
|||
|
|||
# determine which direction we're moving |
|||
if position < current_position: |
|||
where = and_( |
|||
position <= position_column, |
|||
position_column < current_position, |
|||
qualifier_method(self) |
|||
) |
|||
modifier = 1 |
|||
elif position > current_position: |
|||
where = and_( |
|||
current_position < position_column, |
|||
position_column <= position, |
|||
qualifier_method(self) |
|||
) |
|||
modifier = -1 |
|||
else: |
|||
# the item is already at the requested position: nothing to do |
|||
return |
|||
|
|||
# shift the items in between the current and new positions |
|||
self.table.update(where, values = { |
|||
position_column : position_column + modifier |
|||
}).execute() |
|||
|
|||
# update this item's position to the desired position |
|||
self.table.update(get_entity_where(self)) \ |
|||
.execute(**{position_column_name: position}) |
|||
|
|||
def move_lower(self): |
|||
# equivalent to, for example: p.todos.insert(x + 1, p.todos.pop(x)) |
|||
self.move_to(getattr(self, position_column_name) + 1) |
|||
|
|||
def move_higher(self): |
|||
self.move_to(getattr(self, position_column_name) - 1) |
|||
|
|||
|
|||
# attach new methods to entity |
|||
self.entity._init_position = _init_position |
|||
self.entity._shift_items = _shift_items |
|||
self.entity.move_lower = move_lower |
|||
self.entity.move_higher = move_higher |
|||
self.entity.move_to_bottom = move_to_bottom |
|||
self.entity.move_to_top = move_to_top |
|||
self.entity.move_to = move_to |
|||
|
|||
|
|||
acts_as_list = Statement(ListEntityBuilder) |
@ -0,0 +1,106 @@ |
|||
''' |
|||
DDL statements for Elixir. |
|||
|
|||
Entities having the perform_ddl statement will automatically execute the |
|||
given DDL statement at the given moment: either before or after the table |
|||
creation in SQL. |
|||
|
|||
The 'when' argument can be either 'before-create' or 'after-create'. |
|||
The 'statement' argument can be one of: |
|||
|
|||
- a single string statement |
|||
- a list of string statements, in which case, each of them will be executed |
|||
in turn. |
|||
- a callable which should take no argument and return either a single string |
|||
or a list of strings. |
|||
|
|||
In each string statement, you may use the special '%(fullname)s' construct, |
|||
which will be replaced with the real table name (including the schema), in |
|||
case it is unknown to you. The self-explanatory '%(table)s' and '%(schema)s' |
|||
constructs may also be used here. |
|||
|
|||
You would use this extension to handle non-Elixir SQL statements, like triggers |
|||
etc. |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
class Movie(Entity): |
|||
title = Field(Unicode(30), primary_key=True) |
|||
year = Field(Integer) |
|||
|
|||
perform_ddl('after-create', |
|||
"insert into %(fullname)s values ('Alien', 1979)") |
|||
|
|||
preload_data is a more specific statement meant to preload data in your |
|||
entity table from a list of tuples (of field values for each row). |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
class Movie(Entity): |
|||
title = Field(Unicode(30), primary_key=True) |
|||
year = Field(Integer) |
|||
|
|||
preload_data(('title', 'year'), |
|||
[(u'Alien', 1979), (u'Star Wars', 1977)]) |
|||
preload_data(('year', 'title'), |
|||
[(1982, u'Blade Runner')]) |
|||
preload_data(data=[(u'Batman', 1966)]) |
|||
''' |
|||
|
|||
from elixir.statements import Statement |
|||
from elixir.properties import EntityBuilder |
|||
from sqlalchemy import DDL |
|||
|
|||
__all__ = ['perform_ddl', 'preload_data'] |
|||
__doc_all__ = [] |
|||
|
|||
# |
|||
# the perform_ddl statement |
|||
# |
|||
class PerformDDLEntityBuilder(EntityBuilder): |
|||
|
|||
def __init__(self, entity, when, statement, on=None, context=None): |
|||
self.entity = entity |
|||
self.when = when |
|||
self.statement = statement |
|||
self.on = on |
|||
self.context = context |
|||
|
|||
def after_table(self): |
|||
statement = self.statement |
|||
if hasattr(statement, '__call__'): |
|||
statement = statement() |
|||
if not isinstance(statement, list): |
|||
statement = [statement] |
|||
for s in statement: |
|||
ddl = DDL(s, self.on, self.context) |
|||
ddl.execute_at(self.when, self.entity.table) |
|||
|
|||
perform_ddl = Statement(PerformDDLEntityBuilder) |
|||
|
|||
# |
|||
# the preload_data statement |
|||
# |
|||
class PreloadDataEntityBuilder(EntityBuilder): |
|||
|
|||
def __init__(self, entity, columns=None, data=None): |
|||
self.entity = entity |
|||
self.columns = columns |
|||
self.data = data |
|||
|
|||
def after_table(self): |
|||
all_columns = [col.name for col in self.entity.table.columns] |
|||
def onload(event, schema_item, connection): |
|||
columns = self.columns |
|||
if columns is None: |
|||
columns = all_columns |
|||
data = self.data |
|||
if hasattr(data, '__call__'): |
|||
data = data() |
|||
insert = schema_item.insert() |
|||
connection.execute(insert, |
|||
[dict(zip(columns, values)) for values in data]) |
|||
|
|||
self.entity.table.append_ddl_listener('after-create', onload) |
|||
|
|||
preload_data = Statement(PreloadDataEntityBuilder) |
|||
|
@ -0,0 +1,288 @@ |
|||
''' |
|||
A versioning plugin for Elixir. |
|||
|
|||
Entities that are marked as versioned with the `acts_as_versioned` statement |
|||
will automatically have a history table created and a timestamp and version |
|||
column added to their tables. In addition, versioned entities are provided |
|||
with four new methods: revert, revert_to, compare_with and get_as_of, and one |
|||
new attribute: versions. Entities with compound primary keys are supported. |
|||
|
|||
The `versions` attribute will contain a list of previous versions of the |
|||
instance, in increasing version number order. |
|||
|
|||
The `get_as_of` method will retrieve a previous version of the instance "as of" |
|||
a specified datetime. If the current version is the most recent, it will be |
|||
returned. |
|||
|
|||
The `revert` method will rollback the current instance to its previous version, |
|||
if possible. Once reverted, the current instance will be expired from the |
|||
session, and you will need to fetch it again to retrieve the now reverted |
|||
instance. |
|||
|
|||
The `revert_to` method will rollback the current instance to the specified |
|||
version number, if possible. Once reverted, the current instance will be expired |
|||
from the session, and you will need to fetch it again to retrieve the now |
|||
reverted instance. |
|||
|
|||
The `compare_with` method will compare the instance with a previous version. A |
|||
dictionary will be returned with each field difference as an element in the |
|||
dictionary where the key is the field name and the value is a tuple of the |
|||
format (current_value, version_value). Version instances also have a |
|||
`compare_with` method so that two versions can be compared. |
|||
|
|||
Also included in the module is an `after_revert` decorator that can be used to |
|||
decorate methods on the versioned entity that will be called following that |
|||
instance being reverted. |
|||
|
|||
The acts_as_versioned statement also accepts an optional `ignore` argument |
|||
that consists of a list of strings, specifying names of fields. Changes in |
|||
those fields will not result in a version increment. In addition, you can |
|||
pass in an optional `check_concurrent` argument, which will use SQLAlchemy's |
|||
built-in optimistic concurrency mechanisms. |
|||
|
|||
Note that relationships that are stored in mapping tables will not be included |
|||
as part of the versioning process, and will need to be handled manually. Only |
|||
values within the entity's main table will be versioned into the history table. |
|||
''' |
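# Usage sketch (added for exposition; the entity and values are made up,
# not part of this module):
#
#     class Page(Entity):
#         body = Field(Unicode)
#         acts_as_versioned()
#
#     page = Page(body=u'first')
#     session.commit()
#     page.body = u'second'
#     session.commit()
#
#     assert page.version == 2
#     diff = page.compare_with(page.versions[0])
#     assert diff['body'] == (u'second', u'first')
#     page.revert()    # back to u'first'; re-fetch the expired instance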
|||
|
|||
from datetime import datetime |
|||
import inspect |
|||
|
|||
from sqlalchemy import Table, Column, and_, desc |
|||
from sqlalchemy.orm import mapper, MapperExtension, EXT_CONTINUE, \ |
|||
object_session |
|||
|
|||
from elixir import Integer, DateTime |
|||
from elixir.statements import Statement |
|||
from elixir.properties import EntityBuilder |
|||
from elixir.entity import getmembers |
|||
|
|||
__all__ = ['acts_as_versioned', 'after_revert'] |
|||
__doc_all__ = [] |
|||
|
|||
# |
|||
# utility functions |
|||
# |
|||
|
|||
def get_entity_where(instance): |
|||
clauses = [] |
|||
for column in instance.table.primary_key.columns: |
|||
instance_value = getattr(instance, column.name) |
|||
clauses.append(column==instance_value) |
|||
return and_(*clauses) |
|||
|
|||
|
|||
def get_history_where(instance): |
|||
clauses = [] |
|||
history_columns = instance.__history_table__.primary_key.columns |
|||
for column in instance.table.primary_key.columns: |
|||
instance_value = getattr(instance, column.name) |
|||
history_column = getattr(history_columns, column.name) |
|||
clauses.append(history_column==instance_value) |
|||
return and_(*clauses) |
|||
|
|||
|
|||
# |
|||
# a mapper extension to track versions on insert, update, and delete |
|||
# |
|||
|
|||
class VersionedMapperExtension(MapperExtension): |
|||
def before_insert(self, mapper, connection, instance): |
|||
version_colname, timestamp_colname = \ |
|||
instance.__class__.__versioned_column_names__ |
|||
setattr(instance, version_colname, 1) |
|||
setattr(instance, timestamp_colname, datetime.now()) |
|||
return EXT_CONTINUE |
|||
|
|||
def before_update(self, mapper, connection, instance): |
|||
old_values = instance.table.select(get_entity_where(instance)) \ |
|||
.execute().fetchone() |
|||
|
|||
# SA might've flagged this for an update even though it didn't change. |
|||
# This occurs when a relation is updated, thus marking this instance |
|||
# for a save/update operation. We check here against the last version |
|||
# to ensure we really should save this version and update the version |
|||
# data. |
|||
ignored = instance.__class__.__ignored_fields__ |
|||
version_colname, timestamp_colname = \ |
|||
instance.__class__.__versioned_column_names__ |
|||
for key in instance.table.c.keys(): |
|||
if key in ignored: |
|||
continue |
|||
if getattr(instance, key) != old_values[key]: |
|||
# the instance was really updated, so we create a new version |
|||
dict_values = dict(old_values.items()) |
|||
connection.execute( |
|||
instance.__class__.__history_table__.insert(), dict_values) |
|||
old_version = getattr(instance, version_colname) |
|||
setattr(instance, version_colname, old_version + 1) |
|||
setattr(instance, timestamp_colname, datetime.now()) |
|||
break |
|||
|
|||
return EXT_CONTINUE |
|||
|
|||
def before_delete(self, mapper, connection, instance): |
|||
connection.execute(instance.__history_table__.delete( |
|||
get_history_where(instance) |
|||
)) |
|||
return EXT_CONTINUE |
|||
|
|||
|
|||
versioned_mapper_extension = VersionedMapperExtension() |
|||
|
|||
|
|||
# |
|||
# the acts_as_versioned statement |
|||
# |
|||
|
|||
class VersionedEntityBuilder(EntityBuilder): |
|||
|
|||
def __init__(self, entity, ignore=None, check_concurrent=False, |
|||
column_names=None): |
|||
self.entity = entity |
|||
self.add_mapper_extension(versioned_mapper_extension) |
|||
#TODO: we should rather check that the version_id_col isn't set |
|||
# externally |
|||
self.check_concurrent = check_concurrent |
|||
|
|||
# Changes in these fields will be ignored |
|||
if column_names is None: |
|||
column_names = ['version', 'timestamp'] |
|||
entity.__versioned_column_names__ = column_names |
|||
if ignore is None: |
|||
ignore = [] |
|||
ignore.extend(column_names) |
|||
entity.__ignored_fields__ = ignore |
|||
|
|||
def create_non_pk_cols(self): |
|||
# add a version column to the entity, along with a timestamp |
|||
version_colname, timestamp_colname = \ |
|||
self.entity.__versioned_column_names__ |
|||
#XXX: fail in case the columns already exist? |
|||
#col_names = [col.name for col in self.entity._descriptor.columns] |
|||
#if version_colname not in col_names: |
|||
self.add_table_column(Column(version_colname, Integer)) |
|||
#if timestamp_colname not in col_names: |
|||
self.add_table_column(Column(timestamp_colname, DateTime)) |
|||
|
|||
# add a concurrent_version column to the entity, if required |
|||
if self.check_concurrent: |
|||
self.entity._descriptor.version_id_col = 'concurrent_version' |
|||
|
|||
# we copy columns from the main entity table, so we need it to exist first |
|||
def after_table(self): |
|||
entity = self.entity |
|||
version_colname, timestamp_colname = \ |
|||
entity.__versioned_column_names__ |
|||
|
|||
# look for events |
|||
after_revert_events = [] |
|||
for name, func in getmembers(entity, inspect.ismethod): |
|||
if getattr(func, '_elixir_after_revert', False): |
|||
after_revert_events.append(func) |
|||
|
|||
# create a history table for the entity |
|||
skipped_columns = [version_colname] |
|||
if self.check_concurrent: |
|||
skipped_columns.append('concurrent_version') |
|||
|
|||
columns = [ |
|||
column.copy() for column in entity.table.c |
|||
if column.name not in skipped_columns |
|||
] |
|||
columns.append(Column(version_colname, Integer, primary_key=True)) |
|||
table = Table(entity.table.name + '_history', entity.table.metadata, |
|||
*columns |
|||
) |
|||
entity.__history_table__ = table |
|||
|
|||
# create an object that represents a version of this entity |
|||
class Version(object): |
|||
pass |
|||
|
|||
# map the version class to the history table for this entity |
|||
Version.__name__ = entity.__name__ + 'Version' |
|||
Version.__versioned_entity__ = entity |
|||
mapper(Version, entity.__history_table__) |
|||
|
|||
version_col = getattr(table.c, version_colname) |
|||
timestamp_col = getattr(table.c, timestamp_colname) |
|||
|
|||
# attach utility methods and properties to the entity |
|||
def get_versions(self): |
|||
v = object_session(self).query(Version) \ |
|||
.filter(get_history_where(self)) \ |
|||
.order_by(version_col) \ |
|||
.all() |
|||
# history contains all the previous records. |
|||
# Add the current one to the list to get all the versions |
|||
v.append(self) |
|||
return v |
|||
|
|||
def get_as_of(self, dt): |
|||
# if our current version's timestamp is older than the passed-in |
|||
# datetime, then the current version is the version "as of" that time |
|||
if getattr(self, timestamp_colname) < dt: |
|||
return self |
|||
|
|||
# otherwise, we need to look to the history table to get our |
|||
# older version |
|||
sess = object_session(self) |
|||
query = sess.query(Version) \ |
|||
.filter(and_(get_history_where(self), |
|||
timestamp_col <= dt)) \ |
|||
.order_by(desc(timestamp_col)).limit(1) |
|||
return query.first() |
|||
|
|||
def revert_to(self, to_version): |
|||
if isinstance(to_version, Version): |
|||
to_version = getattr(to_version, version_colname) |
|||
|
|||
old_version = table.select(and_( |
|||
get_history_where(self), |
|||
version_col == to_version |
|||
)).execute().fetchone() |
|||
|
|||
entity.table.update(get_entity_where(self)).execute( |
|||
dict(old_version.items()) |
|||
) |
|||
|
|||
table.delete(and_(get_history_where(self), |
|||
version_col >= to_version)).execute() |
|||
self.expire() |
|||
for event in after_revert_events: |
|||
event(self) |
|||
|
|||
def revert(self): |
|||
assert getattr(self, version_colname) > 1 |
|||
self.revert_to(getattr(self, version_colname) - 1) |
|||
|
|||
def compare_with(self, version): |
|||
differences = {} |
|||
for column in self.table.c: |
|||
if column.name in (version_colname, 'concurrent_version'): |
|||
continue |
|||
this = getattr(self, column.name) |
|||
that = getattr(version, column.name) |
|||
if this != that: |
|||
differences[column.name] = (this, that) |
|||
return differences |
|||
|
|||
entity.versions = property(get_versions) |
|||
entity.get_as_of = get_as_of |
|||
entity.revert_to = revert_to |
|||
entity.revert = revert |
|||
entity.compare_with = compare_with |
|||
Version.compare_with = compare_with |
|||
|
|||
acts_as_versioned = Statement(VersionedEntityBuilder) |
|||
|
|||
|
|||
def after_revert(func): |
|||
""" |
|||
Decorator for watching for revert events. |
|||
""" |
|||
func._elixir_after_revert = True |
|||
return func |
|||
|
|||
|
@ -0,0 +1,191 @@ |
|||
''' |
|||
This module provides support for defining the fields (columns) of your |
|||
entities. Elixir currently supports two syntaxes to do so: the default |
|||
`Attribute-based syntax`_ as well as the has_field_ DSL statement. |
|||
|
|||
Attribute-based syntax |
|||
---------------------- |
|||
|
|||
Here is a quick example of how to use the attribute-based syntax. |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
class Person(Entity): |
|||
id = Field(Integer, primary_key=True) |
|||
name = Field(String(50), required=True) |
|||
ssn = Field(String(50), unique=True) |
|||
biography = Field(Text) |
|||
join_date = Field(DateTime, default=datetime.datetime.now) |
|||
photo = Field(Binary, deferred=True) |
|||
_email = Field(String(20), colname='email', synonym='email') |
|||
|
|||
def _set_email(self, email): |
|||
self._email = email |
|||
def _get_email(self): |
|||
return self._email |
|||
email = property(_get_email, _set_email) |
|||
|
|||
|
|||
The Field class takes one mandatory argument, which is its type. Please refer |
|||
to SQLAlchemy documentation for a list of `types supported by SQLAlchemy |
|||
<http://www.sqlalchemy.org/docs/05/reference/sqlalchemy/types.html>`_. |
|||
|
|||
Following that first mandatory argument, fields can take any number of |
|||
optional keyword arguments. Please note that all the **arguments** that are |
|||
**not specifically processed by Elixir**, as mentioned in the documentation |
|||
below, **are passed on to the SQLAlchemy ``Column`` object**. Please refer to |
|||
the `SQLAlchemy Column object's documentation |
|||
<http://www.sqlalchemy.org/docs/05/reference/sqlalchemy/schema.html |
|||
#sqlalchemy.schema.Column>`_ for more details about other |
|||
supported keyword arguments. |
|||
|
|||
The following Elixir-specific arguments are supported: |
|||
|
|||
+-------------------+---------------------------------------------------------+ |
|||
| Argument Name | Description | |
|||
+===================+=========================================================+ |
|||
| ``required`` | Specify whether or not this field can be set to None | |
|||
| | (left without a value). Defaults to ``False``, unless | |
|||
| | the field is a primary key. | |
|||
+-------------------+---------------------------------------------------------+ |
|||
| ``colname`` | Specify a custom name for the column of this field. By | |
|||
| | default the column will have the same name as the | |
|||
| | attribute. | |
|||
+-------------------+---------------------------------------------------------+ |
|||
| ``deferred`` | Specify whether this particular column should be | |
|||
| | fetched by default (along with the other columns) when | |
|||
| | an instance of the entity is fetched from the database | |
|||
| | or rather only later on when this particular column is | |
|||
| | first referenced. This can be useful when one wants to | |
|||
| | avoid loading a large text or binary field into memory | |
|||
| | when it's not needed. Individual columns can be lazy | |
|||
| | loaded by themselves (by using ``deferred=True``) | |
|||
| | or placed into groups that lazy-load together (by using | |
|||
| | ``deferred`` = `"group_name"`). | |
|||
+-------------------+---------------------------------------------------------+ |
|||
| ``synonym`` | Specify a synonym name for this field. The field will | |
|||
| | also be usable under that name in keyword-based Query | |
|||
| | functions such as filter_by. The Synonym class (see the | |
|||
| | `properties` module) provides a similar functionality | |
|||
| | with an (arguably) nicer syntax, but a limited scope. | |
|||
+-------------------+---------------------------------------------------------+ |
|||
|
|||
has_field |
|||
--------- |
|||
|
|||
The `has_field` statement allows you to define fields one at a time. |
|||
|
|||
The first argument is the name of the field, the second is its type. Following |
|||
these, any number of keyword arguments can be specified for additional |
|||
behavior. The following arguments are supported: |
|||
|
|||
+-------------------+---------------------------------------------------------+ |
|||
| Argument Name | Description | |
|||
+===================+=========================================================+ |
|||
| ``through`` | Specify a relation name to go through. This field will | |
|||
| | not exist as a column on the database but will be a | |
|||
| | property which automatically proxy values to the | |
|||
| | ``attribute`` attribute of the object pointed to by the | |
|||
| | relation. If the ``attribute`` argument is not present, | |
|||
| | the name of the current field will be used. In a | |
|||
| | has_field statement, you can only proxy through a | |
|||
| | belongs_to or a has_one relationship. | |
|||
+-------------------+---------------------------------------------------------+ |
|||
| ``attribute`` | Name of the "endpoint" attribute to proxy to. This | |
|||
| | should only be used in combination with the ``through`` | |
|||
| | argument. | |
|||
+-------------------+---------------------------------------------------------+ |
|||
|
|||
|
|||
Here is a quick example of how to use ``has_field``. |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
class Person(Entity): |
|||
has_field('id', Integer, primary_key=True) |
|||
has_field('name', String(50)) |
|||
''' |
|||
from sqlalchemy import Column |
|||
from sqlalchemy.orm import deferred, synonym |
|||
from sqlalchemy.ext.associationproxy import association_proxy |
|||
|
|||
from elixir.statements import ClassMutator |
|||
from elixir.properties import Property |
|||
|
|||
__doc_all__ = ['Field'] |
|||
|
|||
|
|||
class Field(Property): |
|||
''' |
|||
Represents the definition of a 'field' on an entity. |
|||
|
|||
This class represents a column on the table where the entity is stored. |
|||
''' |
|||
|
|||
def __init__(self, type, *args, **kwargs): |
|||
super(Field, self).__init__() |
|||
|
|||
self.colname = kwargs.pop('colname', None) |
|||
self.synonym = kwargs.pop('synonym', None) |
|||
self.deferred = kwargs.pop('deferred', False) |
|||
if 'required' in kwargs: |
|||
kwargs['nullable'] = not kwargs.pop('required') |
|||
self.type = type |
|||
self.primary_key = kwargs.get('primary_key', False) |
|||
|
|||
self.column = None |
|||
self.property = None |
|||
|
|||
self.args = args |
|||
self.kwargs = kwargs |
|||
|
|||
def attach(self, entity, name): |
|||
# If no colname was defined (through the 'colname' kwarg), set |
|||
# it to the name of the attr. |
|||
if self.colname is None: |
|||
self.colname = name |
|||
super(Field, self).attach(entity, name) |
|||
|
|||
def create_pk_cols(self): |
|||
if self.primary_key: |
|||
self.create_col() |
|||
|
|||
def create_non_pk_cols(self): |
|||
if not self.primary_key: |
|||
self.create_col() |
|||
|
|||
def create_col(self): |
|||
self.column = Column(self.colname, self.type, |
|||
*self.args, **self.kwargs) |
|||
self.add_table_column(self.column) |
|||
|
|||
def create_properties(self): |
|||
if self.deferred: |
|||
group = None |
|||
if isinstance(self.deferred, basestring): |
|||
group = self.deferred |
|||
self.property = deferred(self.column, group=group) |
|||
elif self.name != self.colname: |
|||
# if the property name is different from the column name, we need |
|||
# to add an explicit property (otherwise nothing is needed as it's |
|||
# done automatically by SA) |
|||
self.property = self.column |
|||
|
|||
if self.property is not None: |
|||
self.add_mapper_property(self.name, self.property) |
|||
|
|||
if self.synonym: |
|||
self.add_mapper_property(self.synonym, synonym(self.name)) |
|||
|
|||
|
|||
def has_field_handler(entity, name, *args, **kwargs): |
|||
if 'through' in kwargs: |
|||
setattr(entity, name, |
|||
association_proxy(kwargs.pop('through'), |
|||
kwargs.pop('attribute', name), |
|||
**kwargs)) |
|||
return |
|||
field = Field(*args, **kwargs) |
|||
field.attach(entity, name) |
|||
|
|||
has_field = ClassMutator(has_field_handler) |
@ -0,0 +1,285 @@ |
|||
''' |
|||
This module provides support for defining several options on your Elixir |
|||
entities. There are three different kinds of options that can be set |
|||
up, and for this there are three different statements: using_options_, |
|||
using_table_options_ and using_mapper_options_. |
|||
|
|||
Alternatively, these options can be set on all Elixir entities by modifying |
|||
the `options_defaults` dictionary before defining any entity. |
|||
|
|||
`using_options` |
|||
--------------- |
|||
The 'using_options' DSL statement allows you to set up some additional |
|||
behaviors on your model objects, including table names, ordering, and |
|||
more. To specify an option, simply supply the option as a keyword |
|||
argument onto the statement, as follows: |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
class Person(Entity): |
|||
name = Field(Unicode(64)) |
|||
|
|||
using_options(shortnames=True, order_by='name') |
|||
|
|||
The list of supported arguments are as follows: |
|||
|
|||
+---------------------+-------------------------------------------------------+ |
|||
| Option Name | Description | |
|||
+=====================+=======================================================+ |
|||
| ``inheritance`` | Specify the type of inheritance this entity must use. | |
|||
| | It can be one of ``single``, ``concrete`` or | |
|||
| | ``multi``. Defaults to ``single``. | |
|||
| | Note that polymorphic concrete inheritance is | |
|||
| | currently not implemented. See: | |
|||
| | http://www.sqlalchemy.org/docs/05/mappers.html | |
|||
| | #mapping-class-inheritance-hierarchies for an | |
|||
| | explanation of the different kinds of inheritances. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``abstract`` | Set ``abstract=True`` to declare an abstract entity. | |
|||
| | Abstract base classes are useful when you want to put | |
|||
| | some common information into a number of other | |
|||
| | entities. An abstract entity will not be used to | |
|||
| | create any database table. Instead, when it is used | |
|||
| | as a base class for another entity, its fields will | |
|||
| | be added to those of the child class. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``polymorphic`` | Whether the inheritance should be polymorphic or not. | |
|||
| | Defaults to ``True``. The column used to store the | |
|||
| | type of each row is named "row_type" by default. You | |
|||
| | can change this by passing the desired name for the | |
|||
| | column to this argument. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``identity`` | Specify a custom polymorphic identity. When using | |
|||
| | polymorphic inheritance, this value (usually a | |
|||
| | string) will represent this particular entity (class) | |
|||
| | . It will be used to differentiate it from other | |
|||
| | entities (classes) in your inheritance hierarchy when | |
|||
| | loading from the database instances of different | |
|||
| | entities in that hierarchy at the same time. | |
|||
| | This value will be stored by default in the | |
|||
| | "row_type" column of the entity's table (see above). | |
|||
| | You can either provide a | |
|||
| | plain string or a callable. The callable will be | |
|||
| | given the entity (ie class) as argument and must | |
|||
| | return a value (usually a string) representing the | |
|||
| | polymorphic identity of that entity. | |
|||
| | By default, this value is automatically generated: it | |
|||
| | is the name of the entity lower-cased. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``metadata`` | Specify a custom MetaData for this entity. | |
|||
| | By default, entities use the global | |
|||
| | ``elixir.metadata``. | |
|||
| | This option can also be set for all entities of a | |
|||
| | module by setting the ``__metadata__`` attribute of | |
|||
| | that module. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``autoload`` | Automatically load column definitions from the | |
|||
| | existing database table. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``tablename`` | Specify a custom tablename. You can either provide a | |
|||
| | plain string or a callable. The callable will be | |
|||
| | given the entity (ie class) as argument and must | |
|||
| | return a string representing the name of the table | |
|||
| | for that entity. By default, the tablename is | |
|||
| | automatically generated: it is a concatenation of the | |
|||
| | full module-path to the entity and the entity (class) | |
|||
| | name itself. The result is lower-cased and separated | |
|||
| | by underscores ("_"), eg.: for an entity named | |
|||
| | "MyEntity" in the module "project1.model", the | |
|||
| | generated table name will be | |
|||
| | "project1_model_myentity". | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``shortnames`` | Specify whether or not the automatically generated | |
|||
| | table names include the full module-path | |
|||
| | to the entity. If ``shortnames`` is ``True``, only | |
|||
| | the entity name is used. Defaults to ``False``. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``auto_primarykey`` | If given as string, it will represent the | |
|||
| | auto-primary-key's column name. If this option | |
|||
| | is True, it will allow auto-creation of a primary | |
|||
| | key if there's no primary key defined for the | |
|||
| | corresponding entity. If this option is False, | |
|||
| | it will disallow auto-creation of a primary key. | |
|||
| | Defaults to ``True``. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``version_id_col`` | If this option is True, it will create a version | |
|||
| | column automatically using the default name. If given | |
|||
| | as string, it will create the column using that name. | |
|||
| | This can be used to prevent concurrent modifications | |
|||
| | to the entity's table rows (i.e. it will raise an | |
|||
| | exception if it happens). Defaults to ``False``. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``order_by`` | How to order select results. Either a string or a | |
|||
| | list of strings, composed of the field name, | |
|||
| | optionally preceded by a minus (for descending order).| |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``session`` | Specify a custom contextual session for this entity. | |
|||
| | By default, entities use the global | |
|||
| | ``elixir.session``. | |
|||
| | This option takes a ``ScopedSession`` object or | |
|||
| | ``None``. In the latter case your entity will be | |
|||
| | mapped using a non-contextual mapper which requires | |
|||
| | manual session management, as seen in pure SQLAlchemy.| |
|||
| | This option can also be set for all entities of a | |
|||
| | module by setting the ``__session__`` attribute of | |
|||
| | that module. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``autosetup`` | DEPRECATED. Specify whether that entity will contain | |
|||
| | automatic setup triggers, that is, whether this | |
|||
| | entity will be automatically set up (along with all | |
|||
| | other entities which were already declared) if any of | |
|||
| | the following conditions happens: one of its | |
|||
| | attributes ('c', 'table', 'mapper' or 'query') is | |
|||
| | accessed, the entity is instantiated (called), or the | |
|||
| | create_all method of this entity's metadata is | |
|||
| | called. Defaults to ``False``. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
| ``allowcoloverride``| Specify whether it is allowed to override columns. | |
|||
| | By default, Elixir forbids you to add a column to an | |
|||
| | entity's table which already exists in that table. If | |
|||
| | you set this option to ``True`` it will skip that | |
|||
| | check. Use with care as it is easy to shoot oneself | |
|||
| | in the foot when overriding columns. | |
|||
+---------------------+-------------------------------------------------------+ |
|||
|
|||
For examples, please refer to the examples and unit tests. |
|||
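As an illustration only (this sketch is not part of the diff; the entity |
|||
and option values are invented, but each option used is documented in the |
|||
table above): |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
    class Document(Entity): |
|||
        title = Field(Unicode(100)) |
|||
|
|||
        using_options(shortnames=True, |
|||
                      order_by='-title', |
|||
                      version_id_col=True) |
|||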
|
|||
`using_table_options` |
|||
--------------------- |
|||
The 'using_table_options' DSL statement allows you to set up some |
|||
additional options on your entity table. It is meant only to handle the |
|||
options which are not supported directly by the 'using_options' statement. |
|||
Unlike the 'using_options' statement, these options are passed |
|||
directly to the underlying SQLAlchemy Table object (both non-keyword arguments |
|||
and keyword arguments) without any processing. |
|||
|
|||
For further information, please refer to the `SQLAlchemy table's documentation |
|||
<http://www.sqlalchemy.org/docs/05/reference/sqlalchemy/schema.html |
|||
#sqlalchemy.schema.Table>`_. |
|||
|
|||
You might also be interested in the section about `constraints |
|||
<http://www.sqlalchemy.org/docs/05/metadata.html |
|||
#defining-constraints-and-indexes>`_. |
|||
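As a minimal sketch (not part of the diff; it assumes a MySQL backend for |
|||
the ``mysql_engine`` keyword and imports ``UniqueConstraint`` from |
|||
SQLAlchemy): |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
    from sqlalchemy import UniqueConstraint |
|||
|
|||
    class Person(Entity): |
|||
        firstname = Field(Unicode(30)) |
|||
        surname = Field(Unicode(30)) |
|||
|
|||
        using_table_options(UniqueConstraint('firstname', 'surname'), |
|||
                            mysql_engine='InnoDB') |
|||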
|
|||
`using_mapper_options` |
|||
---------------------- |
|||
The 'using_mapper_options' DSL statement allows you to set up some |
|||
additional options on your entity mapper. It is meant only to handle the |
|||
options which are not supported directly by the 'using_options' statement. |
|||
Unlike the 'using_options' statement, these options are passed |
|||
directly to the underlying SQLAlchemy mapper (as keyword arguments) |
|||
without any processing. |
|||
|
|||
For further information, please refer to the `SQLAlchemy mapper |
|||
function's documentation |
|||
<http://www.sqlalchemy.org/docs/05/reference/orm/mapping.html |
|||
#sqlalchemy.orm.mapper>`_. |
|||
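A minimal sketch (not part of the diff; ``batch=False`` is simply one |
|||
keyword accepted by the SQLAlchemy ``mapper()`` function of that era, |
|||
shown here for illustration): |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
    class Person(Entity): |
|||
        name = Field(Unicode(30)) |
|||
|
|||
        using_mapper_options(batch=False) |
|||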
|
|||
`using_options_defaults` |
|||
------------------------ |
|||
The 'using_options_defaults' DSL statement allows you to set up some |
|||
default options on a custom base class. These will be used as the default value |
|||
for options of all its subclasses. Note that any option not set within the |
|||
using_options_defaults (nor specifically on a particular Entity) will use the |
|||
global defaults, so you don't have to provide a default value for all options, |
|||
but only those you want to change. Please also note that this statement does |
|||
not work on normal entities, and the normal using_options statement does not |
|||
work on base classes (because normal options do not and should not propagate to |
|||
the child classes). |
|||
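A minimal sketch (not part of the diff; it assumes a custom base class |
|||
built with ``EntityMeta``, and the names are invented): |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
    class MyBase(object): |
|||
        __metaclass__ = EntityMeta |
|||
        using_options_defaults(shortnames=True) |
|||
|
|||
    class Person(MyBase): |
|||
        name = Field(Unicode(30)) |
|||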
''' |
|||
|
|||
from sqlalchemy import Integer, String |
|||
|
|||
from elixir.statements import ClassMutator |
|||
|
|||
__doc_all__ = ['options_defaults'] |
|||
|
|||
OLD_M2MCOL_NAMEFORMAT = "%(tablename)s_%(key)s%(numifself)s" |
|||
ALTERNATE_M2MCOL_NAMEFORMAT = "%(inversename)s_%(key)s" |
|||
|
|||
def default_m2m_column_formatter(data): |
|||
if data['selfref']: |
|||
return ALTERNATE_M2MCOL_NAMEFORMAT % data |
|||
else: |
|||
return OLD_M2MCOL_NAMEFORMAT % data |
|||
|
|||
NEW_M2MCOL_NAMEFORMAT = default_m2m_column_formatter |
|||
|
|||
# format constants |
|||
FKCOL_NAMEFORMAT = "%(relname)s_%(key)s" |
|||
M2MCOL_NAMEFORMAT = NEW_M2MCOL_NAMEFORMAT |
|||
CONSTRAINT_NAMEFORMAT = "%(tablename)s_%(colnames)s_fk" |
|||
MULTIINHERITANCECOL_NAMEFORMAT = "%(entity)s_%(key)s" |
|||
|
|||
# other global constants |
|||
DEFAULT_AUTO_PRIMARYKEY_NAME = "id" |
|||
DEFAULT_AUTO_PRIMARYKEY_TYPE = Integer |
|||
DEFAULT_VERSION_ID_COL_NAME = "row_version" |
|||
DEFAULT_POLYMORPHIC_COL_NAME = "row_type" |
|||
POLYMORPHIC_COL_SIZE = 40 |
|||
POLYMORPHIC_COL_TYPE = String(POLYMORPHIC_COL_SIZE) |
|||
|
|||
# debugging/migration help |
|||
MIGRATION_TO_07_AID = False |
|||
|
|||
# |
|||
options_defaults = dict( |
|||
abstract=False, |
|||
autosetup=False, |
|||
inheritance='single', |
|||
polymorphic=True, |
|||
identity=None, |
|||
autoload=False, |
|||
tablename=None, |
|||
shortnames=False, |
|||
auto_primarykey=True, |
|||
version_id_col=False, |
|||
allowcoloverride=False, |
|||
order_by=None, |
|||
resolve_root=None, |
|||
mapper_options={}, |
|||
table_options={} |
|||
) |
|||
|
|||
valid_options = options_defaults.keys() + [ |
|||
'metadata', |
|||
'session', |
|||
'collection' |
|||
] |
|||
|
|||
|
|||
def using_options_defaults_handler(entity, **kwargs): |
|||
for kwarg in kwargs: |
|||
if kwarg not in valid_options: |
|||
raise Exception("'%s' is not a valid option for Elixir entities." |
|||
% kwarg) |
|||
|
|||
# We use __dict__ instead of hasattr to avoid picking up the attribute on |
|||
# the parent, which would update the parent's dict instead of creating a |
|||
# local one. |
|||
if not entity.__dict__.get('options_defaults'): |
|||
entity.options_defaults = {} |
|||
entity.options_defaults.update(kwargs) |
|||
|
|||
|
|||
def using_options_handler(entity, *args, **kwargs): |
|||
for kwarg in kwargs: |
|||
if kwarg in valid_options: |
|||
setattr(entity._descriptor, kwarg, kwargs[kwarg]) |
|||
else: |
|||
raise Exception("'%s' is not a valid option for Elixir entities." |
|||
% kwarg) |
|||
|
|||
|
|||
def using_table_options_handler(entity, *args, **kwargs): |
|||
entity._descriptor.table_args.extend(list(args)) |
|||
entity._descriptor.table_options.update(kwargs) |
|||
|
|||
|
|||
def using_mapper_options_handler(entity, *args, **kwargs): |
|||
entity._descriptor.mapper_options.update(kwargs) |
|||
|
|||
|
|||
using_options_defaults = ClassMutator(using_options_defaults_handler) |
|||
using_options = ClassMutator(using_options_handler) |
|||
using_table_options = ClassMutator(using_table_options_handler) |
|||
using_mapper_options = ClassMutator(using_mapper_options_handler) |
@ -0,0 +1,244 @@ |
|||
''' |
|||
This module provides support for defining properties on your entities. It |
|||
provides the `Property` class, which acts as a building block for common |
|||
properties such as fields and relationships (for those, please consult the |
|||
corresponding modules), as well as some more specialized properties, |
|||
such as `ColumnProperty` and `Synonym`. It also provides the GenericProperty |
|||
class which allows you to wrap any SQLAlchemy property, and its DSL-syntax |
|||
equivalent: has_property_. |
|||
|
|||
`has_property` |
|||
-------------- |
|||
The ``has_property`` statement allows you to define properties which rely on |
|||
their entity's table (and columns) being defined before they can be declared |
|||
themselves. The `has_property` statement takes two arguments: first the name of |
|||
the property to be defined and second a function (often given as an anonymous |
|||
lambda) taking one argument and returning the desired SQLAlchemy property. That |
|||
function will be called whenever the entity table is completely defined, and |
|||
will be given the .c attribute of the entity as argument (as a way to access |
|||
the entity columns). |
|||
|
|||
Here is a quick example of how to use ``has_property``. |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
class OrderLine(Entity): |
|||
has_field('quantity', Float) |
|||
has_field('unit_price', Float) |
|||
has_property('price', |
|||
lambda c: column_property( |
|||
(c.quantity * c.unit_price).label('price'))) |
|||
''' |
|||
|
|||
from elixir.statements import PropertyStatement |
|||
from sqlalchemy.orm import column_property, synonym |
|||
|
|||
__doc_all__ = ['EntityBuilder', 'Property', 'GenericProperty', |
|||
'ColumnProperty'] |
|||
|
|||
class EntityBuilder(object): |
|||
''' |
|||
Abstract base class for all entity builders. An Entity builder is a class |
|||
of objects which can be added to an Entity (usually by using special |
|||
properties or statements) to "build" that entity. Building an entity means |
|||
adding columns to its "main" table, creating other tables, adding |
|||
properties to its mapper, and so on. To do so, an EntityBuilder must |
|||
override the corresponding method(s). This ensures the different |
|||
operations happen in the correct order (for example, that the table is |
|||
fully created before the mapper that uses it is defined). |
|||
''' |
|||
def create_pk_cols(self): |
|||
pass |
|||
|
|||
def create_non_pk_cols(self): |
|||
pass |
|||
|
|||
def before_table(self): |
|||
pass |
|||
|
|||
def create_tables(self): |
|||
''' |
|||
Subclasses may override this method to create tables. |
|||
''' |
|||
|
|||
def after_table(self): |
|||
pass |
|||
|
|||
def create_properties(self): |
|||
''' |
|||
Subclasses may override this method to add properties to the involved |
|||
entity. |
|||
''' |
|||
|
|||
def before_mapper(self): |
|||
pass |
|||
|
|||
def after_mapper(self): |
|||
pass |
|||
|
|||
def finalize(self): |
|||
pass |
|||
|
|||
# helper methods |
|||
def add_table_column(self, column): |
|||
self.entity._descriptor.add_column(column) |
|||
|
|||
def add_mapper_property(self, name, prop): |
|||
self.entity._descriptor.add_property(name, prop) |
|||
|
|||
def add_mapper_extension(self, ext): |
|||
self.entity._descriptor.add_mapper_extension(ext) |
|||
|
|||
|
|||
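# Illustration only (not in the original source): a builder subclass can |
|||
# hook one of the phases above, e.g. adding a column during the |
|||
# column-creation phase (Column and DateTime would come from sqlalchemy): |
|||
# |
|||
#     class AddTimestamp(EntityBuilder): |
|||
#         def __init__(self, entity): |
|||
#             self.entity = entity |
|||
# |
|||
#         def create_non_pk_cols(self): |
|||
#             self.add_table_column(Column('created_at', DateTime)) |
|||
|
|||
|
|||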
class CounterMeta(type): |
|||
''' |
|||
A simple metaclass which adds a ``_counter`` attribute to the instances of |
|||
the classes it is used on. This counter is simply incremented for each new |
|||
instance. |
|||
''' |
|||
counter = 0 |
|||
|
|||
def __call__(self, *args, **kwargs): |
|||
instance = type.__call__(self, *args, **kwargs) |
|||
instance._counter = CounterMeta.counter |
|||
CounterMeta.counter += 1 |
|||
return instance |
|||
|
|||
|
|||
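# Illustration only (not in the original source): instances record their |
|||
# declaration order, e.g. |
|||
# |
|||
#     p1 = Property(); p2 = Property() |
|||
#     assert p1._counter < p2._counter |
|||
# |
|||
# which is what lets properties be attached in declaration order. |
|||
|
|||
|
|||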
class Property(EntityBuilder): |
|||
''' |
|||
Abstract base class for all properties of an Entity. |
|||
''' |
|||
__metaclass__ = CounterMeta |
|||
|
|||
def __init__(self, *args, **kwargs): |
|||
self.entity = None |
|||
self.name = None |
|||
|
|||
def attach(self, entity, name): |
|||
"""Attach this property to its entity, using 'name' as name. |
|||
|
|||
Properties will be attached in the order they were declared. |
|||
""" |
|||
self.entity = entity |
|||
self.name = name |
|||
|
|||
# register this property as a builder |
|||
entity._descriptor.builders.append(self) |
|||
|
|||
def __repr__(self): |
|||
return "Property(%s, %s)" % (self.name, self.entity) |
|||
|
|||
|
|||
class GenericProperty(Property): |
|||
''' |
|||
Generic catch-all class to wrap an SQLAlchemy property. |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
class OrderLine(Entity): |
|||
quantity = Field(Float) |
|||
unit_price = Field(Numeric) |
|||
price = GenericProperty(lambda c: column_property( |
|||
(c.quantity * c.unit_price).label('price'))) |
|||
''' |
|||
|
|||
def __init__(self, prop, *args, **kwargs): |
|||
super(GenericProperty, self).__init__(*args, **kwargs) |
|||
self.prop = prop |
|||
#XXX: move this to Property? |
|||
self.args = args |
|||
self.kwargs = kwargs |
|||
|
|||
def create_properties(self): |
|||
if hasattr(self.prop, '__call__'): |
|||
prop_value = self.prop(self.entity.table.c) |
|||
else: |
|||
prop_value = self.prop |
|||
prop_value = self.evaluate_property(prop_value) |
|||
self.add_mapper_property(self.name, prop_value) |
|||
|
|||
def evaluate_property(self, prop): |
|||
if self.args or self.kwargs: |
|||
raise Exception('superfluous arguments passed to GenericProperty') |
|||
return prop |
|||
|
|||
|
|||
class ColumnProperty(GenericProperty): |
|||
''' |
|||
A specialized form of the GenericProperty to generate SQLAlchemy |
|||
``column_property``'s. |
|||
|
|||
It takes a function (often given as an anonymous lambda) as its first |
|||
argument. Other arguments and keyword arguments are forwarded to the |
|||
column_property construct. That first-argument function must accept exactly |
|||
one argument and must return the desired (scalar-returning) SQLAlchemy |
|||
ClauseElement. |
|||
|
|||
The function will be called whenever the entity table is completely |
|||
defined, and will be given |
|||
the .c attribute of the table of the entity as argument (as a way to |
|||
access the entity columns). The ColumnProperty will first wrap your |
|||
ClauseElement in an |
|||
"empty" label (ie it will be labelled automatically during queries), |
|||
then wrap that in a column_property. |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
class OrderLine(Entity): |
|||
quantity = Field(Float) |
|||
unit_price = Field(Numeric) |
|||
price = ColumnProperty(lambda c: c.quantity * c.unit_price, |
|||
deferred=True) |
|||
|
|||
Please look at the `corresponding SQLAlchemy |
|||
documentation <http://www.sqlalchemy.org/docs/05/mappers.html |
|||
#sql-expressions-as-mapped-attributes>`_ for details. |
|||
''' |
|||
|
|||
def evaluate_property(self, prop): |
|||
return column_property(prop.label(None), *self.args, **self.kwargs) |
|||
|
|||
|
|||
class Synonym(GenericProperty): |
|||
''' |
|||
This class represents a synonym property of another property (column, ...) |
|||
of an entity. As opposed to the `synonym` kwarg to the Field class (which |
|||
shares the same goal), this class can be used to define a synonym of a |
|||
property defined in a parent class (of the current class). On the other |
|||
hand, it cannot define a synonym for the purpose of using a standard python |
|||
property in queries. See the Field class for details on that usage. |
|||
|
|||
.. sourcecode:: python |
|||
|
|||
class Person(Entity): |
|||
name = Field(String(30)) |
|||
primary_email = Field(String(100)) |
|||
email_address = Synonym('primary_email') |
|||
|
|||
class User(Person): |
|||
user_name = Synonym('name') |
|||
password = Field(String(20)) |
|||
''' |
|||
|
|||
def evaluate_property(self, prop): |
|||
return synonym(prop, *self.args, **self.kwargs) |
|||
|
|||
#class Composite(GenericProperty): |
|||
# def __init__(self, prop): |
|||
# super(GenericProperty, self).__init__() |
|||
# self.prop = prop |
|||
|
|||
# def evaluate_property(self, prop): |
|||
# return composite(prop.label(self.name)) |
|||
|
|||
#start = Composite(Point, lambda c: (c.x1, c.y1)) |
|||
|
|||
#mapper(Vertex, vertices, properties={ |
|||
# 'start':composite(Point, vertices.c.x1, vertices.c.y1), |
|||
# 'end':composite(Point, vertices.c.x2, vertices.c.y2) |
|||
#}) |
|||
|
|||
|
|||
has_property = PropertyStatement(GenericProperty) |
|||
|
@ -0,0 +1,73 @@ |
|||
# Some helper functions to get by without Python 2.4 |
|||
|
|||
# set |
|||
try: |
|||
set = set |
|||
except NameError: |
|||
from sets import Set as set |
|||
|
|||
orig_cmp = cmp |
|||
# [].sort |
|||
def sort_list(l, cmp=None, key=None, reverse=False): |
|||
try: |
|||
l.sort(cmp, key, reverse) |
|||
except TypeError, e: |
|||
if not str(e).startswith('sort expected at most 1 arguments'): |
|||
raise |
|||
if cmp is None: |
|||
cmp = orig_cmp |
|||
if key is not None: |
|||
# the cmp=cmp parameter is required to get the original comparator |
|||
# into the lambda namespace |
|||
cmp = lambda self, other, cmp=cmp: cmp(key(self), key(other)) |
|||
if reverse: |
|||
cmp = lambda self, other, cmp=cmp: -cmp(self,other) |
|||
l.sort(cmp) |
|||
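|
|||
# Illustration only (not in the original source): both code paths give |
|||
# the same result, e.g. |
|||
# |
|||
#     words = ['pear', 'fig', 'banana'] |
|||
#     sort_list(words, key=len, reverse=True) |
|||
#     # words == ['banana', 'pear', 'fig'] |
|||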
|
|||
# sorted |
|||
try: |
|||
sorted = sorted |
|||
except NameError: |
|||
# global name 'sorted' doesn't exist in Python 2.3 |
|||
# this provides a poor-man's emulation of the sorted built-in method |
|||
def sorted(l, cmp=None, key=None, reverse=False): |
|||
sorted_list = list(l) |
|||
sort_list(sorted_list, cmp, key, reverse) |
|||
return sorted_list |
|||
|
|||
# rsplit |
|||
try: |
|||
''.rsplit |
|||
def rsplit(s, delim, maxsplit): |
|||
return s.rsplit(delim, maxsplit) |
|||
|
|||
except AttributeError: |
|||
def rsplit(s, delim, maxsplit): |
|||
"""Return a list of the words of the string s, scanning s |
|||
from the end. To all intents and purposes, the resulting |
|||
list of words is the same as returned by split(), except |
|||
when the optional third argument maxsplit is explicitly |
|||
specified and nonzero. When maxsplit is nonzero, at most |
|||
maxsplit number of splits - the rightmost ones - occur, |
|||
and the remainder of the string is returned as the first |
|||
element of the list (thus, the list will have at most |
|||
maxsplit+1 elements). New in version 2.4. |
|||
>>> rsplit('foo.bar.baz', '.', 0) |
|||
['foo.bar.baz'] |
|||
>>> rsplit('foo.bar.baz', '.', 1) |
|||
['foo.bar', 'baz'] |
|||
>>> rsplit('foo.bar.baz', '.', 2) |
|||
['foo', 'bar', 'baz'] |
|||
>>> rsplit('foo.bar.baz', '.', 99) |
|||
['foo', 'bar', 'baz'] |
|||
""" |
|||
assert maxsplit >= 0 |
|||
|
|||
if maxsplit == 0: return [s] |
|||
|
|||
# the following lines perform the function, but inefficiently. |
|||
# This may be adequate for compatibility purposes |
|||
items = s.split(delim) |
|||
if maxsplit < len(items): |
|||
items[:-maxsplit] = [delim.join(items[:-maxsplit])] |
|||
return items |
File diff suppressed because it is too large
@ -0,0 +1,59 @@ |
|||
import sys |
|||
|
|||
MUTATORS = '__elixir_mutators__' |
|||
|
|||
class ClassMutator(object): |
|||
''' |
|||
DSL-style syntax |
|||
|
|||
A ``ClassMutator`` object represents a DSL term. |
|||
''' |
|||
|
|||
def __init__(self, handler): |
|||
''' |
|||
Create a new ClassMutator, using the `handler` callable to process it |
|||
when the time comes. |
|||
''' |
|||
self.handler = handler |
|||
|
|||
# called when a mutator (eg. "has_field(...)") is parsed |
|||
def __call__(self, *args, **kwargs): |
|||
# self in this case is the "generic" mutator (eg "has_field") |
|||
|
|||
# jam this mutator into the class's mutator list |
|||
class_locals = sys._getframe(1).f_locals |
|||
mutators = class_locals.setdefault(MUTATORS, []) |
|||
mutators.append((self, args, kwargs)) |
|||
|
|||
def process(self, entity, *args, **kwargs): |
|||
''' |
|||
Process one mutator. This version simply calls the handler callable, |
|||
but another mutator (sub)class could do more processing. |
|||
''' |
|||
self.handler(entity, *args, **kwargs) |
|||
|
|||
|
|||
#TODO: move this to the super class (to be created here) of EntityMeta |
|||
def process_mutators(entity): |
|||
''' |
|||
Apply all mutators of the given entity. That is, loop over all mutators |
|||
in the class's mutator list and process them. |
|||
''' |
|||
# we don't use getattr here, so as not to inherit the parent's mutators |
|||
# inadvertently if the current entity hasn't defined any mutator. |
|||
mutators = entity.__dict__.get(MUTATORS, []) |
|||
for mutator, args, kwargs in mutators: |
|||
mutator.process(entity, *args, **kwargs) |
|||
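|
|||
# Illustration only (not in the original source): a DSL term built on |
|||
# ClassMutator only records itself at class-definition time; the real |
|||
# work happens when process_mutators() replays the recorded calls: |
|||
# |
|||
#     has_field = ClassMutator(has_field_handler) |
|||
# |
|||
#     class Person(object): |
|||
#         has_field('name', Unicode(30))  # appended to __elixir_mutators__ |
|||
# |
|||
#     process_mutators(Person)  # handlers actually run here |
|||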
|
|||
class Statement(ClassMutator): |
|||
|
|||
def process(self, entity, *args, **kwargs): |
|||
builder = self.handler(entity, *args, **kwargs) |
|||
entity._descriptor.builders.append(builder) |
|||
|
|||
class PropertyStatement(ClassMutator): |
|||
|
|||
def process(self, entity, name, *args, **kwargs): |
|||
prop = self.handler(*args, **kwargs) |
|||
prop.attach(entity, name) |
|||
|
@ -0,0 +1,34 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask |
|||
~~~~~ |
|||
|
|||
A microframework based on Werkzeug. It's extensively documented |
|||
and follows best practice patterns. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
|
|||
# utilities we import from Werkzeug and Jinja2 that are unused |
|||
# in the module but are exported as public interface. |
|||
from werkzeug import abort, redirect |
|||
from jinja2 import Markup, escape |
|||
|
|||
from .app import Flask, Request, Response |
|||
from .config import Config |
|||
from .helpers import url_for, jsonify, json_available, flash, \ |
|||
send_file, send_from_directory, get_flashed_messages, \ |
|||
get_template_attribute, make_response |
|||
from .globals import current_app, g, request, session, _request_ctx_stack |
|||
from .module import Module |
|||
from .templating import render_template, render_template_string |
|||
from .session import Session |
|||
|
|||
# the signals |
|||
from .signals import signals_available, template_rendered, request_started, \ |
|||
request_finished, got_request_exception |
|||
|
|||
# only import json if it's available |
|||
if json_available: |
|||
from .helpers import json |
@ -0,0 +1,965 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.app |
|||
~~~~~~~~~ |
|||
|
|||
This module implements the central WSGI application object. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
|
|||
from __future__ import with_statement |
|||
|
|||
from threading import Lock |
|||
from datetime import timedelta, datetime |
|||
from itertools import chain |
|||
|
|||
from jinja2 import Environment |
|||
|
|||
from werkzeug import ImmutableDict |
|||
from werkzeug.routing import Map, Rule |
|||
from werkzeug.exceptions import HTTPException, InternalServerError, \ |
|||
MethodNotAllowed |
|||
|
|||
from .helpers import _PackageBoundObject, url_for, get_flashed_messages, \ |
|||
_tojson_filter, _endpoint_from_view_func |
|||
from .wrappers import Request, Response |
|||
from .config import ConfigAttribute, Config |
|||
from .ctx import _RequestContext |
|||
from .globals import _request_ctx_stack, request |
|||
from .session import Session, _NullSession |
|||
from .module import _ModuleSetupState |
|||
from .templating import _DispatchingJinjaLoader, \ |
|||
_default_template_ctx_processor |
|||
from .signals import request_started, request_finished, got_request_exception |
|||
|
|||
# a lock used for logger initialization |
|||
_logger_lock = Lock() |
|||
|
|||
|
|||
class Flask(_PackageBoundObject): |
|||
"""The flask object implements a WSGI application and acts as the central |
|||
object. It is passed the name of the module or package of the |
|||
application. Once it is created it will act as a central registry for |
|||
the view functions, the URL rules, template configuration and much more. |
|||
|
|||
The name of the package is used to resolve resources from inside the |
|||
package or the folder the module is contained in depending on if the |
|||
package parameter resolves to an actual python package (a folder with |
|||
an `__init__.py` file inside) or a standard module (just a `.py` file). |
|||
|
|||
For more information about resource loading, see :func:`open_resource`. |
|||
|
|||
Usually you create a :class:`Flask` instance in your main module or |
|||
in the `__init__.py` file of your package like this:: |
|||
|
|||
from flask import Flask |
|||
app = Flask(__name__) |
|||
|
|||
.. admonition:: About the First Parameter |
|||
|
|||
The idea of the first parameter is to give Flask an idea what |
|||
belongs to your application. This name is used to find resources |
|||
on the file system, can be used by extensions to improve debugging |
|||
information and a lot more. |
|||
|
|||
So it's important what you provide there. If you are using a single |
|||
module, `__name__` is always the correct value. If you however are |
|||
using a package, it's usually recommended to hardcode the name of |
|||
your package there. |
|||
|
|||
For example if your application is defined in `yourapplication/app.py` |
|||
you should create it with one of the two versions below:: |
|||
|
|||
app = Flask('yourapplication') |
|||
app = Flask(__name__.split('.')[0]) |
|||
|
|||
Why is that? The application will work even with `__name__`, thanks |
|||
to how resources are looked up. However it will make debugging more |
|||
painful. Certain extensions can make assumptions based on the |
|||
import name of your application. For example the Flask-SQLAlchemy |
|||
extension will look for the code in your application that triggered |
|||
an SQL query in debug mode. If the import name is not properly set |
|||
up, that debugging information is lost. (For example it would only |
|||
pick up SQL queries in `yourapplication.app` and not |
|||
`yourapplication.views.frontend`) |
|||
|
|||
.. versionadded:: 0.5 |
|||
The `static_path` parameter was added. |
|||
|
|||
:param import_name: the name of the application package |
|||
:param static_path: can be used to specify a different path for the |
|||
static files on the web. Defaults to ``/static``. |
|||
This does not affect the folder the files are served |
|||
*from*. |
|||
""" |
|||
|
|||
#: The class that is used for request objects. See :class:`~flask.Request` |
|||
#: for more information. |
|||
request_class = Request |
|||
|
|||
#: The class that is used for response objects. See |
|||
#: :class:`~flask.Response` for more information. |
|||
response_class = Response |
|||
|
|||
#: Path for the static files. If you don't want to use static files |
|||
#: you can set this value to `None` in which case no URL rule is added |
|||
#: and the development server will no longer serve any static files. |
|||
#: |
|||
#: This is the default used for application and modules unless a |
|||
#: different value is passed to the constructor. |
|||
static_path = '/static' |
|||
|
|||
#: The debug flag. Set this to `True` to enable debugging of the |
|||
#: application. In debug mode the debugger will kick in when an unhandled |
|||
#: exception occurs and the integrated server will automatically reload |
|||
#: the application if changes in the code are detected. |
|||
#: |
|||
#: This attribute can also be configured from the config with the `DEBUG` |
|||
#: configuration key. Defaults to `False`. |
|||
debug = ConfigAttribute('DEBUG') |
|||
|
|||
#: The testing flag. Set this to `True` to enable the test mode of |
|||
#: Flask extensions (and in the future probably also Flask itself). |
|||
#: For example this might activate unittest helpers that have an |
|||
#: additional runtime cost which should not be enabled by default. |
|||
#: |
|||
#: This attribute can also be configured from the config with the |
|||
#: `TESTING` configuration key. Defaults to `False`. |
|||
testing = ConfigAttribute('TESTING') |
|||
|
|||
#: If a secret key is set, cryptographic components can use this to |
|||
#: sign cookies and other things. Set this to a complex random value |
|||
#: when you want to use the secure cookie for instance. |
|||
#: |
|||
#: This attribute can also be configured from the config with the |
|||
#: `SECRET_KEY` configuration key. Defaults to `None`. |
|||
secret_key = ConfigAttribute('SECRET_KEY') |
|||
|
|||
#: The secure cookie uses this for the name of the session cookie. |
|||
#: |
|||
#: This attribute can also be configured from the config with the |
|||
#: `SESSION_COOKIE_NAME` configuration key. Defaults to ``'session'`` |
|||
session_cookie_name = ConfigAttribute('SESSION_COOKIE_NAME') |
|||
|
|||
#: A :class:`~datetime.timedelta` which is used to set the expiration |
|||
#: date of a permanent session. The default is 31 days which makes a |
|||
#: permanent session survive for roughly one month. |
|||
#: |
|||
#: This attribute can also be configured from the config with the |
|||
#: `PERMANENT_SESSION_LIFETIME` configuration key. Defaults to |
|||
#: ``timedelta(days=31)`` |
|||
permanent_session_lifetime = ConfigAttribute('PERMANENT_SESSION_LIFETIME') |
|||
|
|||
#: Enable this if you want to use the X-Sendfile feature. Keep in |
|||
#: mind that the server has to support this. This only affects files |
|||
#: sent with the :func:`send_file` method. |
|||
#: |
|||
#: .. versionadded:: 0.2 |
|||
#: |
|||
#: This attribute can also be configured from the config with the |
|||
#: `USE_X_SENDFILE` configuration key. Defaults to `False`. |
|||
use_x_sendfile = ConfigAttribute('USE_X_SENDFILE') |
|||
|
|||
#: The name of the logger to use. By default the logger name is the |
|||
#: package name passed to the constructor. |
|||
#: |
|||
#: .. versionadded:: 0.4 |
|||
logger_name = ConfigAttribute('LOGGER_NAME') |
|||
|
|||
#: The logging format used for the debug logger. This is only used when |
|||
#: the application is in debug mode, otherwise the attached logging |
|||
#: handler does the formatting. |
|||
#: |
|||
#: .. versionadded:: 0.3 |
|||
debug_log_format = ( |
|||
'-' * 80 + '\n' + |
|||
'%(levelname)s in %(module)s [%(pathname)s:%(lineno)d]:\n' + |
|||
'%(message)s\n' + |
|||
'-' * 80 |
|||
) |
|||
|
|||
#: Options that are passed directly to the Jinja2 environment. |
|||
jinja_options = ImmutableDict( |
|||
extensions=['jinja2.ext.autoescape', 'jinja2.ext.with_'] |
|||
) |
|||
|
|||
#: Default configuration parameters. |
|||
default_config = ImmutableDict({ |
|||
'DEBUG': False, |
|||
'TESTING': False, |
|||
'PROPAGATE_EXCEPTIONS': None, |
|||
'SECRET_KEY': None, |
|||
'SESSION_COOKIE_NAME': 'session', |
|||
'PERMANENT_SESSION_LIFETIME': timedelta(days=31), |
|||
'USE_X_SENDFILE': False, |
|||
'LOGGER_NAME': None, |
|||
'SERVER_NAME': None, |
|||
'MAX_CONTENT_LENGTH': None |
|||
}) |
|||
|
|||
#: the test client that is used when `test_client` is used. |
|||
#: |
|||
#: .. versionadded:: 0.7 |
|||
test_client_class = None |
|||
|
|||
def __init__(self, import_name, static_path=None): |
|||
_PackageBoundObject.__init__(self, import_name) |
|||
if static_path is not None: |
|||
self.static_path = static_path |
|||
|
|||
#: The configuration dictionary as :class:`Config`. This behaves |
|||
#: exactly like a regular dictionary but supports additional methods |
|||
#: to load a config from files. |
|||
self.config = Config(self.root_path, self.default_config) |
|||
|
|||
#: Prepare the deferred setup of the logger. |
|||
self._logger = None |
|||
self.logger_name = self.import_name |
|||
|
|||
#: A dictionary of all view functions registered. The keys will |
|||
#: be function names which are also used to generate URLs and |
|||
#: the values are the function objects themselves. |
|||
#: To register a view function, use the :meth:`route` decorator. |
|||
self.view_functions = {} |
|||
|
|||
#: A dictionary of all registered error handlers. The key is |
|||
#: the error code as an integer, the value the function that |
|||
#: should handle that error. |
|||
#: To register an error handler, use the :meth:`errorhandler` |
|||
#: decorator. |
|||
self.error_handlers = {} |
|||
|
|||
#: A dictionary with lists of functions that should be called at the |
|||
#: beginning of the request. The key of the dictionary is the name of |
|||
#: the module this function is active for, `None` for all requests. |
|||
#: This can for example be used to open database connections or |
|||
#: get hold of the currently logged in user. To register a |
|||
#: function here, use the :meth:`before_request` decorator. |
|||
self.before_request_funcs = {} |
|||
|
|||
#: A dictionary with lists of functions that should be called after |
|||
#: each request. The key of the dictionary is the name of the module |
|||
#: this function is active for, `None` for all requests. This can for |
|||
#: example be used to close database connections or get hold of the |
|||
#: currently logged in user. To register a function here, use the |
|||
#: :meth:`after_request` decorator. |
|||
self.after_request_funcs = {} |
|||
|
|||
#: A dictionary with lists of functions that are called without argument |
|||
#: to populate the template context. The key of the dictionary is the |
|||
#: name of the module this function is active for, `None` for all |
|||
#: requests. Each returns a dictionary that the template context is |
|||
#: updated with. To register a function here, use the |
|||
#: :meth:`context_processor` decorator. |
|||
self.template_context_processors = { |
|||
None: [_default_template_ctx_processor] |
|||
} |
|||
|
|||
#: all the loaded modules in a dictionary by name. |
|||
#: |
|||
#: .. versionadded:: 0.5 |
|||
self.modules = {} |
|||
|
|||
#: a place where extensions can store application specific state. For |
|||
#: example this is where an extension could store database engines and |
|||
#: similar things. For backwards compatibility extensions should register |
|||
#: themselves like this:: |
|||
#: |
|||
#: if not hasattr(app, 'extensions'): |
|||
#: app.extensions = {} |
|||
#: app.extensions['extensionname'] = SomeObject() |
|||
#: |
|||
#: The key must match the name of the `flaskext` module. For example in |
|||
#: case of a "Flask-Foo" extension in `flaskext.foo`, the key would be |
|||
#: ``'foo'``. |
|||
#: |
|||
#: .. versionadded:: 0.7 |
|||
self.extensions = {} |
|||
|
|||
#: The :class:`~werkzeug.routing.Map` for this instance. You can use |
|||
#: this to change the routing converters after the class was created |
|||
#: but before any routes are connected. Example:: |
|||
#: |
|||
#: from werkzeug.routing import BaseConverter |
|||
#: |
|||
#: class ListConverter(BaseConverter): |
|||
#: def to_python(self, value): |
|||
#: return value.split(',') |
|||
#: def to_url(self, values): |
|||
#: return ','.join(BaseConverter.to_url(value) |
|||
#: for value in values) |
|||
#: |
|||
#: app = Flask(__name__) |
|||
#: app.url_map.converters['list'] = ListConverter |
|||
self.url_map = Map() |
|||
|
|||
# register the static folder for the application. Do that even |
|||
# if the folder does not exist. First of all it might be created |
|||
# while the server is running (usually happens during development) |
|||
# but also because Google App Engine stores static files somewhere |
|||
# else when mapped with the .yml file. |
|||
self.add_url_rule(self.static_path + '/<path:filename>', |
|||
endpoint='static', |
|||
view_func=self.send_static_file) |
|||
|
|||
#: The Jinja2 environment. It is created from the |
|||
#: :attr:`jinja_options`. |
|||
self.jinja_env = self.create_jinja_environment() |
|||
self.init_jinja_globals() |
|||
|
|||
@property |
|||
def propagate_exceptions(self): |
|||
"""Returns the value of the `PROPAGATE_EXCEPTIONS` configuration |
|||
value in case it's set, otherwise a sensible default is returned. |
|||
|
|||
.. versionadded:: 0.7 |
|||
""" |
|||
rv = self.config['PROPAGATE_EXCEPTIONS'] |
|||
if rv is not None: |
|||
return rv |
|||
return self.testing or self.debug |
|||
|
|||
@property |
|||
def logger(self): |
|||
"""A :class:`logging.Logger` object for this application. The |
|||
default configuration is to log to stderr if the application is |
|||
in debug mode. This logger can be used to (surprise) log messages. |
|||
Here some examples:: |
|||
|
|||
app.logger.debug('A value for debugging') |
|||
app.logger.warning('A warning occurred (%d apples)', 42) |
|||
app.logger.error('An error occurred') |
|||
|
|||
.. versionadded:: 0.3 |
|||
""" |
|||
if self._logger and self._logger.name == self.logger_name: |
|||
return self._logger |
|||
with _logger_lock: |
|||
if self._logger and self._logger.name == self.logger_name: |
|||
return self._logger |
|||
from flask.logging import create_logger |
|||
self._logger = rv = create_logger(self) |
|||
return rv |
|||
|
|||
def create_jinja_environment(self): |
|||
"""Creates the Jinja2 environment based on :attr:`jinja_options` |
|||
and :meth:`select_jinja_autoescape`. |
|||
|
|||
.. versionadded:: 0.5 |
|||
""" |
|||
options = dict(self.jinja_options) |
|||
if 'autoescape' not in options: |
|||
options['autoescape'] = self.select_jinja_autoescape |
|||
return Environment(loader=_DispatchingJinjaLoader(self), **options) |
|||
|
|||
def init_jinja_globals(self): |
|||
"""Called directly after the environment was created to inject |
|||
some defaults (like `url_for`, `get_flashed_messages` and the |
|||
`tojson` filter). |
|||
|
|||
.. versionadded:: 0.5 |
|||
""" |
|||
self.jinja_env.globals.update( |
|||
url_for=url_for, |
|||
get_flashed_messages=get_flashed_messages |
|||
) |
|||
self.jinja_env.filters['tojson'] = _tojson_filter |
|||
|
|||
def select_jinja_autoescape(self, filename): |
|||
"""Returns `True` if autoescaping should be active for the given |
|||
template name. |
|||
|
|||
.. versionadded:: 0.5 |
|||
""" |
|||
if filename is None: |
|||
return False |
|||
return filename.endswith(('.html', '.htm', '.xml', '.xhtml')) |
|||
|
|||
def update_template_context(self, context): |
|||
"""Update the template context with some commonly used variables. |
|||
This injects request, session, config and g into the template |
|||
context as well as everything template context processors want |
|||
to inject. Note that as of Flask 0.6, the original values |
|||
in the context will not be overridden if a context processor |
|||
decides to return a value with the same key. |
|||
|
|||
:param context: the context as a dictionary that is updated in place |
|||
to add extra variables. |
|||
""" |
|||
funcs = self.template_context_processors[None] |
|||
mod = _request_ctx_stack.top.request.module |
|||
if mod is not None and mod in self.template_context_processors: |
|||
funcs = chain(funcs, self.template_context_processors[mod]) |
|||
orig_ctx = context.copy() |
|||
for func in funcs: |
|||
context.update(func()) |
|||
# make sure the original values win. This makes it possible to |
|||
# add new variables in context processors more easily without breaking |
|||
# existing views. |
|||
context.update(orig_ctx) |
|||
|
|||
def run(self, host='127.0.0.1', port=5000, **options): |
|||
"""Runs the application on a local development server. If the |
|||
:attr:`debug` flag is set the server will automatically reload |
|||
for code changes and show a debugger in case an exception happened. |
|||
|
|||
If you want to run the application in debug mode, but disable the |
|||
code execution on the interactive debugger, you can pass |
|||
``use_evalex=False`` as parameter. This will keep the debugger's |
|||
traceback screen active, but disable code execution. |
|||
|
|||
.. admonition:: Keep in Mind |
|||
|
|||
Flask will suppress any server error with a generic error page |
|||
unless it is in debug mode. As such, to enable just the |
|||
interactive debugger without the code reloading, you have to |
|||
invoke :meth:`run` with ``debug=True`` and ``use_reloader=False``. |
|||
Setting ``use_debugger`` to `True` without being in debug mode |
|||
won't catch any exceptions because there won't be any to |
|||
catch. |
|||
|
|||
:param host: the hostname to listen on. Set this to ``'0.0.0.0'`` |
|||
to have the server available externally as well. |
|||
:param port: the port of the webserver |
|||
:param options: the options to be forwarded to the underlying |
|||
Werkzeug server. See :func:`werkzeug.run_simple` |
|||
for more information. |
|||
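Example (illustrative only):: |
|||
|
|||
    app.run(debug=True, use_reloader=False) |
|||
|
|||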
""" |
|||
from werkzeug import run_simple |
|||
if 'debug' in options: |
|||
self.debug = options.pop('debug') |
|||
options.setdefault('use_reloader', self.debug) |
|||
options.setdefault('use_debugger', self.debug) |
|||
return run_simple(host, port, self, **options) |
|||
|
|||
def test_client(self, use_cookies=True): |
|||
"""Creates a test client for this application. For information |
|||
about unit testing head over to :ref:`testing`. |
|||
|
|||
The test client can be used in a `with` block to defer the closing down |
|||
of the context until the end of the `with` block. This is useful if |
|||
you want to access the context locals for testing:: |
|||
|
|||
with app.test_client() as c: |
|||
rv = c.get('/?vodka=42') |
|||
assert request.args['vodka'] == '42' |
|||
|
|||
.. versionchanged:: 0.4 |
|||
added support for `with` block usage for the client. |
|||
|
|||
.. versionadded:: 0.7 |
|||
The `use_cookies` parameter was added as well as the ability |
|||
to override the client to be used by setting the |
|||
:attr:`test_client_class` attribute. |
|||
""" |
|||
cls = self.test_client_class |
|||
if cls is None: |
|||
from flask.testing import FlaskClient as cls |
|||
return cls(self, self.response_class, use_cookies=use_cookies) |
|||
|
|||
def open_session(self, request): |
|||
"""Creates or opens a new session. Default implementation stores all |
|||
session data in a signed cookie. This requires that the |
|||
:attr:`secret_key` is set. |
|||
|
|||
:param request: an instance of :attr:`request_class`. |
|||
""" |
|||
key = self.secret_key |
|||
if key is not None: |
|||
return Session.load_cookie(request, self.session_cookie_name, |
|||
secret_key=key) |
|||
|
|||
def save_session(self, session, response): |
|||
"""Saves the session if it needs updates. For the default |
|||
implementation, check :meth:`open_session`. |
|||
|
|||
:param session: the session to be saved (a |
|||
:class:`~werkzeug.contrib.securecookie.SecureCookie` |
|||
object) |
|||
:param response: an instance of :attr:`response_class` |
|||
""" |
|||
expires = domain = None |
|||
if session.permanent: |
|||
expires = datetime.utcnow() + self.permanent_session_lifetime |
|||
if self.config['SERVER_NAME'] is not None: |
|||
domain = '.' + self.config['SERVER_NAME'] |
|||
session.save_cookie(response, self.session_cookie_name, |
|||
expires=expires, httponly=True, domain=domain) |
|||
|
|||
def register_module(self, module, **options): |
|||
"""Registers a module with this application. The keyword argument |
|||
of this function are the same as the ones for the constructor of the |
|||
:class:`Module` class and will override the values of the module if |
|||
provided. |
|||
""" |
|||
options.setdefault('url_prefix', module.url_prefix) |
|||
options.setdefault('subdomain', module.subdomain) |
|||
self.view_functions.update(module.view_functions) |
|||
state = _ModuleSetupState(self, **options) |
|||
for func in module._register_events: |
|||
func(state) |
|||
|
|||
def add_url_rule(self, rule, endpoint=None, view_func=None, **options): |
|||
"""Connects a URL rule. Works exactly like the :meth:`route` |
|||
decorator. If a view_func is provided it will be registered with the |
|||
endpoint. |
|||
|
|||
Basically this example:: |
|||
|
|||
@app.route('/') |
|||
def index(): |
|||
pass |
|||
|
|||
Is equivalent to the following:: |
|||
|
|||
def index(): |
|||
pass |
|||
app.add_url_rule('/', 'index', index) |
|||
|
|||
If the view_func is not provided you will need to connect the endpoint |
|||
to a view function like so:: |
|||
|
|||
app.view_functions['index'] = index |
|||
|
|||
.. versionchanged:: 0.2 |
|||
`view_func` parameter added. |
|||
|
|||
.. versionchanged:: 0.6 |
|||
`OPTIONS` is added automatically as method. |
|||
|
|||
:param rule: the URL rule as string |
|||
:param endpoint: the endpoint for the registered URL rule. Flask |
|||
itself assumes the name of the view function as |
|||
endpoint |
|||
:param view_func: the function to call when serving a request to the |
|||
provided endpoint |
|||
:param options: the options to be forwarded to the underlying |
|||
:class:`~werkzeug.routing.Rule` object. A notable |
|||
difference from Werkzeug is the handling of the |
|||
``methods`` option: it is a list of methods this rule |
|||
should be limited to (`GET`, `POST` etc.). By default a rule |
|||
just listens for `GET` (and implicitly `HEAD`). |
|||
Starting with Flask 0.6, `OPTIONS` is implicitly |
|||
added and handled by the standard request handling. |
|||
""" |
|||
if endpoint is None: |
|||
endpoint = _endpoint_from_view_func(view_func) |
|||
options['endpoint'] = endpoint |
|||
methods = options.pop('methods', ('GET',)) |
|||
provide_automatic_options = False |
|||
if 'OPTIONS' not in methods: |
|||
methods = tuple(methods) + ('OPTIONS',) |
|||
provide_automatic_options = True |
|||
rule = Rule(rule, methods=methods, **options) |
|||
rule.provide_automatic_options = provide_automatic_options |
|||
self.url_map.add(rule) |
|||
if view_func is not None: |
|||
self.view_functions[endpoint] = view_func |
|||
|
|||
def route(self, rule, **options): |
|||
"""A decorator that is used to register a view function for a |
|||
given URL rule. Example:: |
|||
|
|||
@app.route('/') |
|||
def index(): |
|||
return 'Hello World' |
|||
|
|||
Variable parts in the route can be specified with angle |
|||
brackets (``/user/<username>``). By default a variable part |
|||
in the URL accepts any string without a slash; however, a different |
|||
converter can be specified as well by using ``<converter:name>``. |
|||
|
|||
Variable parts are passed to the view function as keyword |
|||
arguments. |
|||
|
|||
The following converters are possible: |
|||
|
|||
=========== =========================================== |
|||
`int` accepts integers |
|||
`float` like `int` but for floating point values |
|||
`path` like the default but also accepts slashes |
|||
=========== =========================================== |
|||
|
|||
Here some examples:: |
|||
|
|||
@app.route('/') |
|||
def index(): |
|||
pass |
|||
|
|||
@app.route('/<username>') |
|||
def show_user(username): |
|||
pass |
|||
|
|||
@app.route('/post/<int:post_id>') |
|||
def show_post(post_id): |
|||
pass |
|||
|
|||
An important detail to keep in mind is how Flask deals with trailing |
|||
slashes. The idea is to keep each URL unique so the following rules |
|||
apply: |
|||
|
|||
1. If a rule ends with a slash and is requested without a slash |
|||
by the user, the user is automatically redirected to the same |
|||
page with a trailing slash attached. |
|||
2. If a rule does not end with a trailing slash and the user requests |
|||
the page with a trailing slash, a 404 not found is raised. |
|||
|
|||
This is consistent with how web servers deal with static files. This |
|||
also makes it possible to use relative link targets safely. |
|||
|
|||
The :meth:`route` decorator accepts a couple of other arguments |
|||
as well: |
|||
|
|||
:param rule: the URL rule as string |
|||
:param methods: a list of methods this rule should be limited |
|||
to (`GET`, `POST` etc.). By default a rule |
|||
just listens for `GET` (and implicitly `HEAD`). |
|||
Starting with Flask 0.6, `OPTIONS` is implicitly |
|||
added and handled by the standard request handling. |
|||
:param subdomain: specifies the rule for the subdomain in case |
|||
subdomain matching is in use. |
|||
:param strict_slashes: can be used to disable the strict slashes |
|||
setting for this rule. See above. |
|||
:param options: other options to be forwarded to the underlying |
|||
:class:`~werkzeug.routing.Rule` object. |
|||
""" |
|||
def decorator(f): |
|||
self.add_url_rule(rule, None, f, **options) |
|||
return f |
|||
return decorator |
|||
|
|||
|
|||
def endpoint(self, endpoint): |
|||
"""A decorator to register a function as an endpoint. |
|||
Example:: |
|||
|
|||
@app.endpoint('example.endpoint') |
|||
def example(): |
|||
return "example" |
|||
|
|||
:param endpoint: the name of the endpoint |
|||
""" |
|||
def decorator(f): |
|||
self.view_functions[endpoint] = f |
|||
return f |
|||
return decorator |
|||
|
|||
def errorhandler(self, code): |
|||
"""A decorator that is used to register a function give a given |
|||
error code. Example:: |
|||
|
|||
@app.errorhandler(404) |
|||
def page_not_found(error): |
|||
return 'This page does not exist', 404 |
|||
|
|||
You can also register a function as error handler without using |
|||
the :meth:`errorhandler` decorator. The following example is |
|||
equivalent to the one above:: |
|||
|
|||
def page_not_found(error): |
|||
return 'This page does not exist', 404 |
|||
app.error_handlers[404] = page_not_found |
|||
|
|||
:param code: the code as integer for the handler |
|||
""" |
|||
def decorator(f): |
|||
self.error_handlers[code] = f |
|||
return f |
|||
return decorator |
|||
|
|||
def template_filter(self, name=None): |
|||
"""A decorator that is used to register custom template filter. |
|||
You can specify a name for the filter, otherwise the function |
|||
name will be used. Example:: |
|||
|
|||
@app.template_filter() |
|||
def reverse(s): |
|||
return s[::-1] |
|||
|
|||
:param name: the optional name of the filter, otherwise the |
|||
function name will be used. |
|||
""" |
|||
def decorator(f): |
|||
self.jinja_env.filters[name or f.__name__] = f |
|||
return f |
|||
return decorator |
|||
|
|||
def before_request(self, f): |
|||
"""Registers a function to run before each request.""" |
|||
self.before_request_funcs.setdefault(None, []).append(f) |
|||
return f |
|||
|
|||
def after_request(self, f): |
|||
"""Register a function to be run after each request.""" |
|||
self.after_request_funcs.setdefault(None, []).append(f) |
|||
return f |
|||
|
|||
def context_processor(self, f): |
|||
"""Registers a template context processor function.""" |
|||
self.template_context_processors[None].append(f) |
|||
return f |
|||
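# Illustration only (not in the original source): |
|||
# |
|||
#     @app.context_processor |
|||
#     def inject_now(): |
|||
#         return dict(now=datetime.utcnow()) |
|||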
|
|||
def handle_http_exception(self, e): |
|||
"""Handles an HTTP exception. By default this will invoke the |
|||
registered error handlers and fall back to returning the |
|||
exception as response. |
|||
|
|||
.. versionadded:: 0.3 |
|||
""" |
|||
handler = self.error_handlers.get(e.code) |
|||
if handler is None: |
|||
return e |
|||
return handler(e) |
|||
|
|||
def handle_exception(self, e): |
|||
"""Default exception handling that kicks in when an exception |
|||
occurs that is not caught. In debug mode the exception will |
|||
be re-raised immediately, otherwise it is logged and the handler |
|||
for a 500 internal server error is used. If no such handler |
|||
exists, a default 500 internal server error message is displayed. |
|||
|
|||
.. versionadded:: 0.3 |
|||
""" |
|||
got_request_exception.send(self, exception=e) |
|||
handler = self.error_handlers.get(500) |
|||
if self.propagate_exceptions: |
|||
raise |
|||
self.logger.exception('Exception on %s [%s]' % ( |
|||
request.path, |
|||
request.method |
|||
)) |
|||
if handler is None: |
|||
return InternalServerError() |
|||
return handler(e) |
|||
|
|||
def dispatch_request(self): |
|||
"""Does the request dispatching. Matches the URL and returns the |
|||
return value of the view or error handler. This does not have to |
|||
be a response object. In order to convert the return value to a |
|||
proper response object, call :func:`make_response`. |
|||
""" |
|||
req = _request_ctx_stack.top.request |
|||
try: |
|||
if req.routing_exception is not None: |
|||
raise req.routing_exception |
|||
rule = req.url_rule |
|||
# if we provide automatic options for this URL and the |
|||
# request came with the OPTIONS method, reply automatically |
|||
if getattr(rule, 'provide_automatic_options', False) \ |
|||
and req.method == 'OPTIONS': |
|||
return self.make_default_options_response() |
|||
# otherwise dispatch to the handler for that endpoint |
|||
return self.view_functions[rule.endpoint](**req.view_args) |
|||
except HTTPException, e: |
|||
return self.handle_http_exception(e) |
|||
|
|||
def make_default_options_response(self): |
|||
"""This method is called to create the default `OPTIONS` response. |
|||
This can be changed through subclassing to change the default |
|||
behaviour of `OPTIONS` responses. |
|||
|
|||
.. versionadded:: 0.7 |
|||
""" |
|||
# This would be nicer in Werkzeug 0.7, which however currently |
|||
# is not released. Werkzeug 0.7 provides a method called |
|||
# allowed_methods() that returns all methods that are valid for |
|||
# a given path. |
|||
methods = [] |
|||
try: |
|||
_request_ctx_stack.top.url_adapter.match(method='--') |
|||
except MethodNotAllowed, e: |
|||
methods = e.valid_methods |
|||
except HTTPException, e: |
|||
pass |
|||
rv = self.response_class() |
|||
rv.allow.update(methods) |
|||
return rv |
|||
|
|||
def make_response(self, rv): |
|||
"""Converts the return value from a view function to a real |
|||
response object that is an instance of :attr:`response_class`. |
|||
|
|||
The following types are allowed for `rv`: |
|||
|
|||
.. tabularcolumns:: |p{3.5cm}|p{9.5cm}| |
|||
|
|||
======================= =========================================== |
|||
:attr:`response_class` the object is returned unchanged |
|||
:class:`str` a response object is created with the |
|||
string as body |
|||
:class:`unicode` a response object is created with the |
|||
string encoded to utf-8 as body |
|||
:class:`tuple` the response object is created with the |
|||
contents of the tuple as arguments |
|||
a WSGI function the function is called as WSGI application |
|||
and buffered as response object |
|||
======================= =========================================== |
|||
|
|||
:param rv: the return value from the view function |
|||
""" |
|||
if rv is None: |
|||
raise ValueError('View function did not return a response') |
|||
if isinstance(rv, self.response_class): |
|||
return rv |
|||
if isinstance(rv, basestring): |
|||
return self.response_class(rv) |
|||
if isinstance(rv, tuple): |
|||
return self.response_class(*rv) |
|||
return self.response_class.force_type(rv, request.environ) |
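|||
# Per the conversion rules above, all of these view return values |
|||
# are valid (a sketch; assumes an `app` instance): |
|||
# |
|||
#     return 'Hello'                     # str body, default status |
|||
#     return 'Not here', 404             # tuple -> (body, status) |
|||
#     return app.response_class('raw')   # returned unchanged |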
|||
|
|||
def create_url_adapter(self, request): |
|||
"""Creates a URL adapter for the given request. The URL adapter |
|||
is created at a point where the request context is not yet set up |
|||
so the request is passed explicitly. |
|||
|
|||
.. versionadded:: 0.6 |
|||
""" |
|||
return self.url_map.bind_to_environ(request.environ, |
|||
server_name=self.config['SERVER_NAME']) |
|||
|
|||
def preprocess_request(self): |
|||
"""Called before the actual request dispatching and will |
|||
call every function decorated with :meth:`before_request`. |
|||
If any of these functions returns a value, it is handled as |
|||
if it was the return value from the view and further |
|||
request handling is stopped. |
|||
""" |
|||
funcs = self.before_request_funcs.get(None, ()) |
|||
mod = request.module |
|||
if mod and mod in self.before_request_funcs: |
|||
funcs = chain(funcs, self.before_request_funcs[mod]) |
|||
for func in funcs: |
|||
rv = func() |
|||
if rv is not None: |
|||
return rv |
|||
|
|||
def process_response(self, response): |
|||
"""Can be overridden in order to modify the response object |
|||
before it's sent to the WSGI server. By default this will |
|||
call all the :meth:`after_request` decorated functions. |
|||
|
|||
.. versionchanged:: 0.5 |
|||
As of Flask 0.5 the functions registered for after request |
|||
execution are called in reverse order of registration. |
|||
|
|||
:param response: a :attr:`response_class` object. |
|||
:return: a new response object or the same one; it has to be an |
|||
instance of :attr:`response_class`. |
|||
""" |
|||
ctx = _request_ctx_stack.top |
|||
mod = ctx.request.module |
|||
if not isinstance(ctx.session, _NullSession): |
|||
self.save_session(ctx.session, response) |
|||
funcs = () |
|||
if mod and mod in self.after_request_funcs: |
|||
funcs = reversed(self.after_request_funcs[mod]) |
|||
if None in self.after_request_funcs: |
|||
funcs = chain(funcs, reversed(self.after_request_funcs[None])) |
|||
for handler in funcs: |
|||
response = handler(response) |
|||
return response |
|||
|
|||
def request_context(self, environ): |
|||
"""Creates a request context from the given environment and binds |
|||
it to the current context. This must be used in combination with |
|||
the `with` statement because the request is only bound to the |
|||
current context for the duration of the `with` block. |
|||
|
|||
Example usage:: |
|||
|
|||
with app.request_context(environ): |
|||
do_something_with(request) |
|||
|
|||
The object returned can also be used without the `with` statement |
|||
which is useful for working in the shell. The example above is |
|||
doing exactly the same as this code:: |
|||
|
|||
ctx = app.request_context(environ) |
|||
ctx.push() |
|||
try: |
|||
do_something_with(request) |
|||
finally: |
|||
ctx.pop() |
|||
|
|||
The big advantage of this approach is that you can use it without |
|||
the try/finally statement in a shell for interactive testing: |
|||
|
|||
>>> ctx = app.test_request_context() |
|||
>>> ctx.push() |
|||
>>> request.path |
|||
u'/' |
|||
>>> ctx.pop() |
|||
|
|||
.. versionchanged:: 0.3 |
|||
Added support for non-with statement usage and `with` statement |
|||
is now passed the ctx object. |
|||
|
|||
:param environ: a WSGI environment |
|||
""" |
|||
return _RequestContext(self, environ) |
|||
|
|||
def test_request_context(self, *args, **kwargs): |
|||
"""Creates a WSGI environment from the given values (see |
|||
:func:`werkzeug.create_environ` for more information; this |
|||
function accepts the same arguments). |
|||
""" |
|||
from werkzeug import create_environ |
|||
environ_overrides = kwargs.setdefault('environ_overrides', {}) |
|||
if self.config.get('SERVER_NAME'): |
|||
server_name = self.config.get('SERVER_NAME') |
|||
if ':' not in server_name: |
|||
http_host, http_port = server_name, '80' |
|||
else: |
|||
http_host, http_port = server_name.split(':', 1) |
|||
|
|||
environ_overrides.setdefault('SERVER_NAME', server_name) |
|||
environ_overrides.setdefault('HTTP_HOST', server_name) |
|||
environ_overrides.setdefault('SERVER_PORT', http_port) |
|||
return self.request_context(create_environ(*args, **kwargs)) |
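|||
# A short usage sketch for tests and shell sessions (assumes an |
|||
# `app` instance and the `request` proxy): |
|||
# |
|||
#     with app.test_request_context('/hello', method='POST'): |
|||
#         assert request.path == '/hello' |
|||
#         assert request.method == 'POST' |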
|||
|
|||
def wsgi_app(self, environ, start_response): |
|||
"""The actual WSGI application. This is not implemented in |
|||
`__call__` so that middlewares can be applied without losing a |
|||
reference to the class. So instead of doing this:: |
|||
|
|||
app = MyMiddleware(app) |
|||
|
|||
It's a better idea to do this instead:: |
|||
|
|||
app.wsgi_app = MyMiddleware(app.wsgi_app) |
|||
|
|||
Then you still have the original application object around and |
|||
can continue to call methods on it. |
|||
|
|||
.. versionchanged:: 0.4 |
|||
The :meth:`after_request` functions are now called even if an |
|||
error handler took over request processing. This ensures that |
|||
even if an exception happens, databases have the chance to |
|||
properly close the connection. |
|||
|
|||
:param environ: a WSGI environment |
|||
:param start_response: a callable accepting a status code, |
|||
a list of headers and an optional |
|||
exception context to start the response |
|||
""" |
|||
with self.request_context(environ): |
|||
try: |
|||
request_started.send(self) |
|||
rv = self.preprocess_request() |
|||
if rv is None: |
|||
rv = self.dispatch_request() |
|||
response = self.make_response(rv) |
|||
except Exception, e: |
|||
response = self.make_response(self.handle_exception(e)) |
|||
try: |
|||
response = self.process_response(response) |
|||
except Exception, e: |
|||
response = self.make_response(self.handle_exception(e)) |
|||
request_finished.send(self, response=response) |
|||
return response(environ, start_response) |
|||
|
|||
def __call__(self, environ, start_response): |
|||
"""Shortcut for :attr:`wsgi_app`.""" |
|||
return self.wsgi_app(environ, start_response) |
@ -0,0 +1,157 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.config |
|||
~~~~~~~~~~~~ |
|||
|
|||
Implements the configuration related objects. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
|
|||
from __future__ import with_statement |
|||
|
|||
import imp |
|||
import os |
|||
import sys |
|||
|
|||
from werkzeug import import_string |
|||
|
|||
|
|||
class ConfigAttribute(object): |
|||
"""Makes an attribute forward to the config""" |
|||
|
|||
def __init__(self, name): |
|||
self.__name__ = name |
|||
|
|||
def __get__(self, obj, type=None): |
|||
if obj is None: |
|||
return self |
|||
return obj.config[self.__name__] |
|||
|
|||
def __set__(self, obj, value): |
|||
obj.config[self.__name__] = value |
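|||
# How this descriptor is typically used (a sketch): expose a config |
|||
# key as an attribute on the application object. |
|||
# |
|||
#     class Flask(_PackageBoundObject): |
|||
#         debug = ConfigAttribute('DEBUG') |
|||
# |
|||
#     app.debug = True   # same as app.config['DEBUG'] = True |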
|||
|
|||
|
|||
class Config(dict): |
|||
"""Works exactly like a dict but provides ways to fill it from files |
|||
or special dictionaries. There are two common patterns to populate the |
|||
config. |
|||
|
|||
Either you can fill the config from a config file:: |
|||
|
|||
app.config.from_pyfile('yourconfig.cfg') |
|||
|
|||
Or alternatively you can define the configuration options in the |
|||
module that calls :meth:`from_object` or provide an import path to |
|||
a module that should be loaded. It is also possible to tell it to |
|||
use the same module and with that provide the configuration values |
|||
just before the call:: |
|||
|
|||
DEBUG = True |
|||
SECRET_KEY = 'development key' |
|||
app.config.from_object(__name__) |
|||
|
|||
In both cases (loading from any Python file or loading from modules), |
|||
only uppercase keys are added to the config. This makes it possible to use |
|||
lowercase values in the config file for temporary values that are not added |
|||
to the config or to define the config keys in the same file that implements |
|||
the application. |
|||
|
|||
Probably the most interesting way to load configurations is from an |
|||
environment variable pointing to a file:: |
|||
|
|||
app.config.from_envvar('YOURAPPLICATION_SETTINGS') |
|||
|
|||
In this case before launching the application you have to set this |
|||
environment variable to the file you want to use. On Linux and OS X |
|||
use the export statement:: |
|||
|
|||
export YOURAPPLICATION_SETTINGS='/path/to/config/file' |
|||
|
|||
On Windows use `set` instead. |
|||
|
|||
:param root_path: path to which files are read relative from. When the |
|||
config object is created by the application, this is |
|||
the application's :attr:`~flask.Flask.root_path`. |
|||
:param defaults: an optional dictionary of default values |
|||
""" |
|||
|
|||
def __init__(self, root_path, defaults=None): |
|||
dict.__init__(self, defaults or {}) |
|||
self.root_path = root_path |
|||
|
|||
def from_envvar(self, variable_name, silent=False): |
|||
"""Loads a configuration from an environment variable pointing to |
|||
a configuration file. This basically is just a shortcut with nicer |
|||
error messages for this line of code:: |
|||
|
|||
app.config.from_pyfile(os.environ['YOURAPPLICATION_SETTINGS']) |
|||
|
|||
:param variable_name: name of the environment variable |
|||
:param silent: set to `True` if you want silent failing for missing |
|||
files. |
|||
:return: bool. `True` if able to load config, `False` otherwise. |
|||
""" |
|||
rv = os.environ.get(variable_name) |
|||
if not rv: |
|||
if silent: |
|||
return False |
|||
raise RuntimeError('The environment variable %r is not set ' |
|||
'and as such configuration could not be ' |
|||
'loaded. Set this variable and make it ' |
|||
'point to a configuration file' % |
|||
variable_name) |
|||
self.from_pyfile(rv) |
|||
return True |
|||
|
|||
def from_pyfile(self, filename): |
|||
"""Updates the values in the config from a Python file. This function |
|||
behaves as if the file was imported as module with the |
|||
:meth:`from_object` function. |
|||
|
|||
:param filename: the filename of the config. This can either be an |
|||
absolute filename or a filename relative to the |
|||
root path. |
|||
""" |
|||
filename = os.path.join(self.root_path, filename) |
|||
d = imp.new_module('config') |
|||
d.__file__ = filename |
|||
try: |
|||
execfile(filename, d.__dict__) |
|||
except IOError, e: |
|||
e.strerror = 'Unable to load configuration file (%s)' % e.strerror |
|||
raise |
|||
self.from_object(d) |
|||
|
|||
def from_object(self, obj): |
|||
"""Updates the values from the given object. An object can be of one |
|||
of the following two types: |
|||
|
|||
- a string: in this case the object with that name will be imported |
|||
- an actual object reference: that object is used directly |
|||
|
|||
Objects are usually either modules or classes. |
|||
|
|||
Just the uppercase variables in that object are stored in the |
|||
config. Example usage:: |
|||
|
|||
app.config.from_object('yourapplication.default_config') |
|||
from yourapplication import default_config |
|||
app.config.from_object(default_config) |
|||
|
|||
You should not use this function to load the actual configuration but |
|||
rather configuration defaults. The actual config should be loaded |
|||
with :meth:`from_pyfile` and ideally from a location not within the |
|||
package because the package might be installed system wide. |
|||
|
|||
:param obj: an import name or object |
|||
""" |
|||
if isinstance(obj, basestring): |
|||
obj = import_string(obj) |
|||
for key in dir(obj): |
|||
if key.isupper(): |
|||
self[key] = getattr(obj, key) |
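|||
# A sketch of the recommended split (names hypothetical): defaults |
|||
# from an importable object, deployment overrides from a file named |
|||
# by an environment variable. |
|||
# |
|||
#     app.config.from_object('yourapplication.default_config') |
|||
#     app.config.from_envvar('YOURAPPLICATION_SETTINGS', silent=True) |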
|||
|
|||
def __repr__(self): |
|||
return '<%s %s>' % (self.__class__.__name__, dict.__repr__(self)) |
@ -0,0 +1,66 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.ctx |
|||
~~~~~~~~~ |
|||
|
|||
Implements the objects required to keep the context. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
|
|||
from werkzeug.exceptions import HTTPException |
|||
|
|||
from .globals import _request_ctx_stack |
|||
from .session import _NullSession |
|||
|
|||
|
|||
class _RequestGlobals(object): |
|||
pass |
|||
|
|||
|
|||
class _RequestContext(object): |
|||
"""The request context contains all request relevant information. It is |
|||
created at the beginning of the request and pushed to the |
|||
`_request_ctx_stack` and removed at the end of it. It will create the |
|||
URL adapter and request object for the WSGI environment provided. |
|||
""" |
|||
|
|||
def __init__(self, app, environ): |
|||
self.app = app |
|||
self.request = app.request_class(environ) |
|||
self.url_adapter = app.create_url_adapter(self.request) |
|||
self.session = app.open_session(self.request) |
|||
if self.session is None: |
|||
self.session = _NullSession() |
|||
self.g = _RequestGlobals() |
|||
self.flashes = None |
|||
|
|||
try: |
|||
url_rule, self.request.view_args = \ |
|||
self.url_adapter.match(return_rule=True) |
|||
self.request.url_rule = url_rule |
|||
except HTTPException, e: |
|||
self.request.routing_exception = e |
|||
|
|||
def push(self): |
|||
"""Binds the request context.""" |
|||
_request_ctx_stack.push(self) |
|||
|
|||
def pop(self): |
|||
"""Pops the request context.""" |
|||
_request_ctx_stack.pop() |
|||
|
|||
def __enter__(self): |
|||
self.push() |
|||
return self |
|||
|
|||
def __exit__(self, exc_type, exc_value, tb): |
|||
# do not pop the request stack if we are in debug mode and an |
|||
# exception happened. This will allow the debugger to still |
|||
# access the request object in the interactive shell. Furthermore |
|||
# the context can be forcibly kept alive for the test client. |
|||
# See flask.testing for how this works. |
|||
if not self.request.environ.get('flask._preserve_context') and \ |
|||
(tb is None or not self.app.debug): |
|||
self.pop() |
@ -0,0 +1,27 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.globals |
|||
~~~~~~~~~~~~~ |
|||
|
|||
Defines all the global objects that are proxies to the current |
|||
active context. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
|
|||
from functools import partial |
|||
from werkzeug import LocalStack, LocalProxy |
|||
|
|||
def _lookup_object(name): |
|||
top = _request_ctx_stack.top |
|||
if top is None: |
|||
raise RuntimeError('working outside of request context') |
|||
return getattr(top, name) |
|||
|
|||
# context locals |
|||
_request_ctx_stack = LocalStack() |
|||
current_app = LocalProxy(partial(_lookup_object, 'app')) |
|||
request = LocalProxy(partial(_lookup_object, 'request')) |
|||
session = LocalProxy(partial(_lookup_object, 'session')) |
|||
g = LocalProxy(partial(_lookup_object, 'g')) |
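|||
# These proxies resolve against the topmost request context, so they |
|||
# only work while a context is pushed (a sketch; assumes an `app`): |
|||
# |
|||
#     with app.test_request_context('/'): |
|||
#         request.path        # works, returns u'/' |
|||
#     request.path            # raises RuntimeError outside a context |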
@ -0,0 +1,496 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.helpers |
|||
~~~~~~~~~~~~~ |
|||
|
|||
Implements various helpers. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
|
|||
import os |
|||
import sys |
|||
import posixpath |
|||
import mimetypes |
|||
from time import time |
|||
from zlib import adler32 |
|||
|
|||
# try to load the best simplejson implementation available. If JSON |
|||
# is not installed, we add a failing class. |
|||
json_available = True |
|||
json = None |
|||
try: |
|||
import simplejson as json |
|||
except ImportError: |
|||
try: |
|||
import json |
|||
except ImportError: |
|||
try: |
|||
# Google App Engine offers simplejson via Django |
|||
from django.utils import simplejson as json |
|||
except ImportError: |
|||
json_available = False |
|||
|
|||
|
|||
from werkzeug import Headers, wrap_file, cached_property |
|||
from werkzeug.exceptions import NotFound |
|||
|
|||
from jinja2 import FileSystemLoader |
|||
|
|||
from .globals import session, _request_ctx_stack, current_app, request |
|||
|
|||
|
|||
def _assert_have_json(): |
|||
"""Helper function that fails if JSON is unavailable.""" |
|||
if not json_available: |
|||
raise RuntimeError('simplejson not installed') |
|||
|
|||
# figure out if simplejson escapes slashes. This behaviour was changed |
|||
# from one version to another without reason. |
|||
if not json_available or '\\/' not in json.dumps('/'): |
|||
|
|||
def _tojson_filter(*args, **kwargs): |
|||
if __debug__: |
|||
_assert_have_json() |
|||
return json.dumps(*args, **kwargs).replace('/', '\\/') |
|||
else: |
|||
_tojson_filter = json.dumps |
|||
|
|||
|
|||
# what separators does this operating system provide that are not a slash? |
|||
# this is used by the send_from_directory function to ensure that nobody is |
|||
# able to access files from outside the filesystem. |
|||
_os_alt_seps = list(sep for sep in [os.path.sep, os.path.altsep] |
|||
if sep not in (None, '/')) |
|||
|
|||
|
|||
def _endpoint_from_view_func(view_func): |
|||
"""Internal helper that returns the default endpoint for a given |
|||
function. This always is the function name. |
|||
""" |
|||
assert view_func is not None, 'expected view func if endpoint ' \ |
|||
'is not provided.' |
|||
return view_func.__name__ |
|||
|
|||
|
|||
def jsonify(*args, **kwargs): |
|||
"""Creates a :class:`~flask.Response` with the JSON representation of |
|||
the given arguments with an `application/json` mimetype. The arguments |
|||
to this function are the same as to the :class:`dict` constructor. |
|||
|
|||
Example usage:: |
|||
|
|||
@app.route('/_get_current_user') |
|||
def get_current_user(): |
|||
return jsonify(username=g.user.username, |
|||
email=g.user.email, |
|||
id=g.user.id) |
|||
|
|||
This will send a JSON response like this to the browser:: |
|||
|
|||
{ |
|||
"username": "admin", |
|||
"email": "admin@localhost", |
|||
"id": 42 |
|||
} |
|||
|
|||
This requires Python 2.6 or an installed version of simplejson. For |
|||
security reasons only objects are supported at the top level. For more |
|||
information about this, have a look at :ref:`json-security`. |
|||
|
|||
.. versionadded:: 0.2 |
|||
""" |
|||
if __debug__: |
|||
_assert_have_json() |
|||
return current_app.response_class(json.dumps(dict(*args, **kwargs), |
|||
indent=None if request.is_xhr else 2), mimetype='application/json') |
|||
|
|||
|
|||
def make_response(*args): |
|||
"""Sometimes it is necessary to set additional headers in a view. Because |
|||
views do not have to return response objects but can return a value that |
|||
is converted into a response object by Flask itself, it becomes tricky to |
|||
add headers to it. This function can be called instead of using a return |
|||
and you will get a response object which you can use to attach headers. |
|||
|
|||
If view looked like this and you want to add a new header:: |
|||
|
|||
def index(): |
|||
return render_template('index.html', foo=42) |
|||
|
|||
You can now do something like this:: |
|||
|
|||
def index(): |
|||
response = make_response(render_template('index.html', foo=42)) |
|||
response.headers['X-Parachutes'] = 'parachutes are cool' |
|||
return response |
|||
|
|||
This function accepts the very same arguments you can return from a |
|||
view function. This for example creates a response with a 404 error |
|||
code:: |
|||
|
|||
response = make_response(render_template('not_found.html'), 404) |
|||
|
|||
Internally this function does the following things: |
|||
|
|||
- if no arguments are passed, it creates a new response object |
|||
- if one argument is passed, :meth:`flask.Flask.make_response` |
|||
is invoked with it. |
|||
- if more than one argument is passed, the arguments are passed |
|||
to the :meth:`flask.Flask.make_response` function as tuple. |
|||
|
|||
.. versionadded:: 0.6 |
|||
""" |
|||
if not args: |
|||
return current_app.response_class() |
|||
if len(args) == 1: |
|||
args = args[0] |
|||
return current_app.make_response(args) |
|||
|
|||
|
|||
def url_for(endpoint, **values): |
|||
"""Generates a URL to the given endpoint with the method provided. |
|||
The endpoint is relative to the active module if modules are in use. |
|||
|
|||
Here are some examples: |
|||
|
|||
==================== ======================= ============================= |
|||
Active Module Target Endpoint Target Function |
|||
==================== ======================= ============================= |
|||
`None` ``'index'`` `index` of the application |
|||
`None` ``'.index'`` `index` of the application |
|||
``'admin'`` ``'index'`` `index` of the `admin` module |
|||
any ``'.index'`` `index` of the application |
|||
any ``'admin.index'`` `index` of the `admin` module |
|||
==================== ======================= ============================= |
|||
|
|||
Variable arguments that are unknown to the target endpoint are appended |
|||
to the generated URL as query arguments. |
|||
|
|||
For more information, head over to the :ref:`Quickstart <url-building>`. |
|||
|
|||
:param endpoint: the endpoint of the URL (name of the function) |
|||
:param values: the variable arguments of the URL rule |
|||
:param _external: if set to `True`, an absolute URL is generated. |
|||
""" |
|||
ctx = _request_ctx_stack.top |
|||
if '.' not in endpoint: |
|||
mod = ctx.request.module |
|||
if mod is not None: |
|||
endpoint = mod + '.' + endpoint |
|||
elif endpoint.startswith('.'): |
|||
endpoint = endpoint[1:] |
|||
external = values.pop('_external', False) |
|||
return ctx.url_adapter.build(endpoint, values, force_external=external) |
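|||
# Building URLs per the table above (a sketch; assumes an `index` |
|||
# view on the application and an `admin` module are registered): |
|||
# |
|||
#     url_for('index')           # e.g. '/' |
|||
#     url_for('admin.index')     # e.g. '/admin/' |
|||
#     url_for('index', page=2)   # unknown args become ?page=2 |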
|||
|
|||
|
|||
def get_template_attribute(template_name, attribute): |
|||
"""Loads a macro (or variable) a template exports. This can be used to |
|||
invoke a macro from within Python code. If you for example have a |
|||
template named `_cider.html` with the following contents: |
|||
|
|||
.. sourcecode:: html+jinja |
|||
|
|||
{% macro hello(name) %}Hello {{ name }}!{% endmacro %} |
|||
|
|||
You can access this from Python code like this:: |
|||
|
|||
hello = get_template_attribute('_cider.html', 'hello') |
|||
return hello('World') |
|||
|
|||
.. versionadded:: 0.2 |
|||
|
|||
:param template_name: the name of the template |
|||
:param attribute: the name of the variable or macro to access |
|||
""" |
|||
return getattr(current_app.jinja_env.get_template(template_name).module, |
|||
attribute) |
|||
|
|||
|
|||
def flash(message, category='message'): |
|||
"""Flashes a message to the next request. In order to remove the |
|||
flashed message from the session and to display it to the user, |
|||
the template has to call :func:`get_flashed_messages`. |
|||
|
|||
.. versionchanged:: 0.3 |
|||
`category` parameter added. |
|||
|
|||
:param message: the message to be flashed. |
|||
:param category: the category for the message. The following values |
|||
are recommended: ``'message'`` for any kind of message, |
|||
``'error'`` for errors, ``'info'`` for information |
|||
messages and ``'warning'`` for warnings. However any |
|||
kind of string can be used as category. |
|||
""" |
|||
session.setdefault('_flashes', []).append((category, message)) |
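|||
# Typical flow (a sketch; hypothetical view code): flash in the |
|||
# handling view, then let the next request's template pull the |
|||
# message via get_flashed_messages(). |
|||
# |
|||
#     flash('Profile saved', 'info') |
|||
#     return redirect(url_for('index')) |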
|||
|
|||
|
|||
def get_flashed_messages(with_categories=False): |
|||
"""Pulls all flashed messages from the session and returns them. |
|||
Further calls in the same request to the function will return |
|||
the same messages. By default just the messages are returned, |
|||
but when `with_categories` is set to `True`, the return value will |
|||
be a list of tuples in the form ``(category, message)`` instead. |
|||
|
|||
Example usage: |
|||
|
|||
.. sourcecode:: html+jinja |
|||
|
|||
{% for category, msg in get_flashed_messages(with_categories=true) %} |
|||
<p class=flash-{{ category }}>{{ msg }} |
|||
{% endfor %} |
|||
|
|||
.. versionchanged:: 0.3 |
|||
`with_categories` parameter added. |
|||
|
|||
:param with_categories: set to `True` to also receive categories. |
|||
""" |
|||
flashes = _request_ctx_stack.top.flashes |
|||
if flashes is None: |
|||
_request_ctx_stack.top.flashes = flashes = session.pop('_flashes', []) |
|||
if not with_categories: |
|||
return [x[1] for x in flashes] |
|||
return flashes |
|||
|
|||
|
|||
def send_file(filename_or_fp, mimetype=None, as_attachment=False, |
|||
attachment_filename=None, add_etags=True, |
|||
cache_timeout=60 * 60 * 12, conditional=False): |
|||
"""Sends the contents of a file to the client. This will use the |
|||
most efficient method available and configured. By default it will |
|||
try to use the WSGI server's file_wrapper support. Alternatively |
|||
you can set the application's :attr:`~Flask.use_x_sendfile` attribute |
|||
to ``True`` to directly emit an `X-Sendfile` header. This however |
|||
requires support of the underlying webserver for `X-Sendfile`. |
|||
|
|||
By default it will try to guess the mimetype for you, but you can |
|||
also explicitly provide one. For extra security you probably want |
|||
to send certain files as attachment (HTML for instance). The mimetype |
|||
guessing requires a `filename` or an `attachment_filename` to be |
|||
provided. |
|||
|
|||
Please never pass filenames to this function from user sources without |
|||
checking them first. Something like this is usually sufficient to |
|||
avoid security problems:: |
|||
|
|||
if '..' in filename or filename.startswith('/'): |
|||
abort(404) |
|||
|
|||
.. versionadded:: 0.2 |
|||
|
|||
.. versionadded:: 0.5 |
|||
The `add_etags`, `cache_timeout` and `conditional` parameters were |
|||
added. The default behaviour is now to attach etags. |
|||
|
|||
.. versionchanged:: 0.7 |
|||
mimetype guessing and etag support for file objects was |
|||
deprecated because it was unreliable. Pass a filename if you are |
|||
able to, otherwise attach an etag yourself. This functionality |
|||
will be removed in Flask 1.0 |
|||
|
|||
:param filename_or_fp: the filename of the file to send. This is |
|||
relative to the :attr:`~Flask.root_path` if a |
|||
relative path is specified. |
|||
Alternatively a file object might be provided |
|||
in which case `X-Sendfile` might not work and the |
|||
traditional method is used as a fallback. Make sure |
|||
that the file pointer is positioned at the start |
|||
of data to send before calling :func:`send_file`. |
|||
:param mimetype: the mimetype of the file if provided, otherwise |
|||
auto detection happens. |
|||
:param as_attachment: set to `True` if you want to send this file with |
|||
a ``Content-Disposition: attachment`` header. |
|||
:param attachment_filename: the filename for the attachment if it |
|||
differs from the file's filename. |
|||
:param add_etags: set to `False` to disable attaching of etags. |
|||
:param conditional: set to `True` to enable conditional responses. |
|||
:param cache_timeout: the timeout in seconds for the headers. |
|||
""" |
|||
mtime = None |
|||
if isinstance(filename_or_fp, basestring): |
|||
filename = filename_or_fp |
|||
file = None |
|||
else: |
|||
from warnings import warn |
|||
file = filename_or_fp |
|||
filename = getattr(file, 'name', None) |
|||
|
|||
# XXX: this behaviour is now deprecated because it was unreliable. |
|||
# removed in Flask 1.0 |
|||
if not attachment_filename and not mimetype \ |
|||
and isinstance(filename, basestring): |
|||
warn(DeprecationWarning('The filename support for file objects ' |
|||
'passed to send_file is now deprecated. Pass an ' |
|||
'attachment_filename if you want mimetypes to be guessed.'), |
|||
stacklevel=2) |
|||
if add_etags: |
|||
warn(DeprecationWarning('In future flask releases etags will no ' |
|||
'longer be generated for file objects passed to the send_file ' |
|||
'function because this behaviour was unreliable. Pass ' |
|||
'filenames instead if possible, otherwise attach an etag ' |
|||
'yourself based on another value'), stacklevel=2) |
|||
|
|||
if filename is not None: |
|||
if not os.path.isabs(filename): |
|||
filename = os.path.join(current_app.root_path, filename) |
|||
if mimetype is None and (filename or attachment_filename): |
|||
mimetype = mimetypes.guess_type(filename or attachment_filename)[0] |
|||
if mimetype is None: |
|||
mimetype = 'application/octet-stream' |
|||
|
|||
headers = Headers() |
|||
if as_attachment: |
|||
if attachment_filename is None: |
|||
if filename is None: |
|||
raise TypeError('filename unavailable, required for ' |
|||
'sending as attachment') |
|||
attachment_filename = os.path.basename(filename) |
|||
headers.add('Content-Disposition', 'attachment', |
|||
filename=attachment_filename) |
|||
|
|||
if current_app.use_x_sendfile and filename: |
|||
if file is not None: |
|||
file.close() |
|||
headers['X-Sendfile'] = filename |
|||
data = None |
|||
else: |
|||
if file is None: |
|||
file = open(filename, 'rb') |
|||
mtime = os.path.getmtime(filename) |
|||
data = wrap_file(request.environ, file) |
|||
|
|||
rv = current_app.response_class(data, mimetype=mimetype, headers=headers, |
|||
direct_passthrough=True) |
|||
|
|||
# if we know the file modification date, we can store it as |
|||
# the time of the last modification. |
|||
if mtime is not None: |
|||
rv.last_modified = int(mtime) |
|||
|
|||
rv.cache_control.public = True |
|||
if cache_timeout: |
|||
rv.cache_control.max_age = cache_timeout |
|||
rv.expires = int(time() + cache_timeout) |
|||
|
|||
if add_etags and filename is not None: |
|||
rv.set_etag('flask-%s-%s-%s' % ( |
|||
os.path.getmtime(filename), |
|||
os.path.getsize(filename), |
|||
adler32(filename) & 0xffffffff |
|||
)) |
|||
if conditional: |
|||
rv = rv.make_conditional(request) |
|||
# make sure we don't send x-sendfile for servers that |
|||
# ignore the 304 status code for x-sendfile. |
|||
if rv.status_code == 304: |
|||
rv.headers.pop('x-sendfile', None) |
|||
return rv |
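|||
# A usage sketch (hypothetical route and path): the mimetype is |
|||
# guessed from the filename and the file is offered as a download. |
|||
# |
|||
#     @app.route('/manual') |
|||
#     def manual(): |
|||
#         return send_file('static/manual.pdf', as_attachment=True) |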
|||
|
|||
|
|||
def send_from_directory(directory, filename, **options): |
|||
"""Send a file from a given directory with :func:`send_file`. This |
|||
is a secure way to quickly expose static files from an upload folder |
|||
or something similar. |
|||
|
|||
Example usage:: |
|||
|
|||
@app.route('/uploads/<path:filename>') |
|||
def download_file(filename): |
|||
return send_from_directory(app.config['UPLOAD_FOLDER'], |
|||
filename, as_attachment=True) |
|||
|
|||
.. admonition:: Sending files and Performance |
|||
|
|||
It is strongly recommended to activate either `X-Sendfile` support in |
|||
your webserver or (if no authentication happens) to tell the webserver |
|||
to serve files for the given path on its own without calling into the |
|||
web application for improved performance. |
|||
|
|||
.. versionadded:: 0.5 |
|||
|
|||
:param directory: the directory where all the files are stored. |
|||
:param filename: the filename relative to that directory to |
|||
download. |
|||
:param options: optional keyword arguments that are directly |
|||
forwarded to :func:`send_file`. |
|||
""" |
|||
filename = posixpath.normpath(filename) |
|||
for sep in _os_alt_seps: |
|||
if sep in filename: |
|||
raise NotFound() |
|||
if os.path.isabs(filename) or filename.startswith('../'): |
|||
raise NotFound() |
|||
filename = os.path.join(directory, filename) |
|||
if not os.path.isfile(filename): |
|||
raise NotFound() |
|||
return send_file(filename, conditional=True, **options) |
|||
|
|||
|
|||
def _get_package_path(name): |
|||
"""Returns the path to a package or cwd if that cannot be found.""" |
|||
try: |
|||
return os.path.abspath(os.path.dirname(sys.modules[name].__file__)) |
|||
except (KeyError, AttributeError): |
|||
return os.getcwd() |
|||
|
|||
|
|||
class _PackageBoundObject(object): |
|||
|
|||
def __init__(self, import_name): |
|||
#: The name of the package or module. Do not change this once |
|||
#: it was set by the constructor. |
|||
self.import_name = import_name |
|||
|
|||
#: Where is the app root located? |
|||
self.root_path = _get_package_path(self.import_name) |
|||
|
|||
@property |
|||
def has_static_folder(self): |
|||
"""This is `True` if the package bound object's container has a |
|||
folder named ``'static'``. |
|||
|
|||
.. versionadded:: 0.5 |
|||
""" |
|||
return os.path.isdir(os.path.join(self.root_path, 'static')) |
|||
|
|||
@cached_property |
|||
def jinja_loader(self): |
|||
"""The Jinja loader for this package bound object. |
|||
|
|||
.. versionadded:: 0.5 |
|||
""" |
|||
return FileSystemLoader(os.path.join(self.root_path, 'templates')) |
|||
|
|||
def send_static_file(self, filename): |
|||
"""Function used internally to send static files from the static |
|||
folder to the browser. |
|||
|
|||
.. versionadded:: 0.5 |
|||
""" |
|||
return send_from_directory(os.path.join(self.root_path, 'static'), |
|||
filename) |
|||
|
|||
def open_resource(self, resource): |
|||
"""Opens a resource from the application's resource folder. To see |
|||
how this works, consider the following folder structure:: |
|||
|
|||
/myapplication.py |
|||
/schema.sql |
|||
/static |
|||
/style.css |
|||
/templates |
|||
/layout.html |
|||
/index.html |
|||
|
|||
If you want to open the `schema.sql` file you would do the |
|||
following:: |
|||
|
|||
with app.open_resource('schema.sql') as f: |
|||
contents = f.read() |
|||
do_something_with(contents) |
|||
|
|||
:param resource: the name of the resource. To access resources within |
|||
subfolders use forward slashes as separator. |
|||
""" |
|||
return open(os.path.join(self.root_path, resource), 'rb') |
@ -0,0 +1,42 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.logging |
|||
~~~~~~~~~~~~~ |
|||
|
|||
Implements the logging support for Flask. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
|
|||
from __future__ import absolute_import |
|||
|
|||
from logging import getLogger, StreamHandler, Formatter, Logger, DEBUG |
|||
|
|||
|
|||
def create_logger(app): |
|||
"""Creates a logger for the given application. This logger works |
|||
similar to a regular Python logger but changes the effective logging |
|||
level based on the application's debug flag. Furthermore this |
|||
function also removes all attached handlers in case there was a |
|||
logger with the log name before. |
|||
""" |
|||
|
|||
class DebugLogger(Logger): |
|||
def getEffectiveLevel(x): |
|||
return DEBUG if app.debug else Logger.getEffectiveLevel(x) |
|||
|
|||
class DebugHandler(StreamHandler): |
|||
def emit(x, record): |
|||
if app.debug: StreamHandler.emit(x, record) |
|||
|
|||
handler = DebugHandler() |
|||
handler.setLevel(DEBUG) |
|||
handler.setFormatter(Formatter(app.debug_log_format)) |
|||
logger = getLogger(app.logger_name) |
|||
# just in case that was not a new logger, get rid of all the handlers |
|||
# already attached to it. |
|||
del logger.handlers[:] |
|||
logger.__class__ = DebugLogger |
|||
logger.addHandler(handler) |
|||
return logger |
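|||
# The result of this factory is exposed as `app.logger` (a sketch): |
|||
# debug records are always emitted in debug mode, otherwise the |
|||
# configured level applies. |
|||
# |
|||
#     app.logger.debug('only visible when app.debug is True') |
|||
#     app.logger.error('always visible') |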
@ -0,0 +1,230 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.module |
|||
~~~~~~~~~~~~ |
|||
|
|||
Implements a class that represents module blueprints. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
|
|||
from .helpers import _PackageBoundObject, _endpoint_from_view_func |
|||
|
|||
|
|||
def _register_module(module, static_path): |
|||
"""Internal helper function that returns a function for recording |
|||
that registers the `send_static_file` function for the module on |
|||
the application if necessary. It also registers the module on |
|||
the application. |
|||
""" |
|||
def _register(state): |
|||
state.app.modules[module.name] = module |
|||
# do not register the rule if the static folder of the |
|||
# module is the same as the one from the application. |
|||
if state.app.root_path == module.root_path: |
|||
return |
|||
path = static_path |
|||
if path is None: |
|||
path = state.app.static_path |
|||
if state.url_prefix: |
|||
path = state.url_prefix + path |
|||
state.app.add_url_rule(path + '/<path:filename>', |
|||
endpoint='%s.static' % module.name, |
|||
view_func=module.send_static_file, |
|||
subdomain=state.subdomain) |
|||
return _register |
|||
|
|||
|
|||
class _ModuleSetupState(object): |
|||
|
|||
def __init__(self, app, url_prefix=None, subdomain=None): |
|||
self.app = app |
|||
self.url_prefix = url_prefix |
|||
self.subdomain = subdomain |
|||
|
|||
|
|||
class Module(_PackageBoundObject): |
|||
"""Container object that enables pluggable applications. A module can |
|||
be used to organize larger applications. They represent blueprints that, |
|||
in combination with a :class:`Flask` object are used to create a large |
|||
application. |
|||
|
|||
A module is like an application bound to an `import_name`. Multiple |
|||
modules can share the same import names, but in that case a `name` has |
|||
to be provided to keep them apart. If different import names are used, |
|||
the rightmost part of the import name is used as name. |
|||
|
|||
Here's an example structure for a larger application:: |
|||
|
|||
/myapplication |
|||
/__init__.py |
|||
/views |
|||
/__init__.py |
|||
/admin.py |
|||
/frontend.py |
|||
|
|||
The `myapplication/__init__.py` can look like this:: |
|||
|
|||
from flask import Flask |
|||
from myapplication.views.admin import admin |
|||
from myapplication.views.frontend import frontend |
|||
|
|||
app = Flask(__name__) |
|||
app.register_module(admin, url_prefix='/admin') |
|||
app.register_module(frontend) |
|||
|
|||
And here's an example view module (`myapplication/views/admin.py`):: |
|||
|
|||
from flask import Module |
|||
|
|||
admin = Module(__name__) |
|||
|
|||
@admin.route('/') |
|||
def index(): |
|||
pass |
|||
|
|||
@admin.route('/login') |
|||
def login(): |
|||
pass |
|||
|
|||
For a gentle introduction to modules, check out the |
|||
:ref:`working-with-modules` section. |
|||
|
|||
.. versionadded:: 0.5 |
|||
The `static_path` parameter was added and it's now possible for |
|||
modules to refer to their own templates and static files. See |
|||
:ref:`modules-and-resources` for more information. |
|||
|
|||
.. versionadded:: 0.6 |
|||
The `subdomain` parameter was added. |
|||
|
|||
:param import_name: the name of the Python package or module |
|||
implementing this :class:`Module`. |
|||
:param name: the internal short name for the module. Unless specified, |
|||
the rightmost part of the import name is used. |
|||
:param url_prefix: an optional string that is used to prefix all the |
|||
URL rules of this module. This can also be specified |
|||
when registering the module with the application. |
|||
:param subdomain: used to set the subdomain setting for URL rules that |
|||
do not have a subdomain setting set. |
|||
:param static_path: can be used to specify a different path for the |
|||
static files on the web. Defaults to ``/static``. |
|||
This does not affect the folder the files are served |
|||
*from*. |
|||
""" |
|||
|
|||
def __init__(self, import_name, name=None, url_prefix=None, |
|||
static_path=None, subdomain=None): |
|||
if name is None: |
|||
assert '.' in import_name, 'name required if package name ' \ |
|||
'does not point to a submodule' |
|||
name = import_name.rsplit('.', 1)[1] |
|||
_PackageBoundObject.__init__(self, import_name) |
|||
self.name = name |
|||
self.url_prefix = url_prefix |
|||
self.subdomain = subdomain |
|||
self.view_functions = {} |
|||
self._register_events = [_register_module(self, static_path)] |
|||
|
|||
def route(self, rule, **options): |
|||
"""Like :meth:`Flask.route` but for a module. The endpoint for the |
|||
:func:`url_for` function is prefixed with the name of the module. |
|||
""" |
|||
def decorator(f): |
|||
self.add_url_rule(rule, f.__name__, f, **options) |
|||
return f |
|||
return decorator |
|||
|
|||
def add_url_rule(self, rule, endpoint=None, view_func=None, **options): |
|||
"""Like :meth:`Flask.add_url_rule` but for a module. The endpoint for |
|||
the :func:`url_for` function is prefixed with the name of the module. |
|||
|
|||
.. versionchanged:: 0.6 |
|||
The `endpoint` argument is now optional and will default to the |
|||
function name, to be consistent with the method of the same name |
|||
on the application object. |
|||
""" |
|||
def register_rule(state): |
|||
the_rule = rule |
|||
if state.url_prefix: |
|||
the_rule = state.url_prefix + rule |
|||
options.setdefault('subdomain', state.subdomain) |
|||
the_endpoint = endpoint |
|||
if the_endpoint is None: |
|||
the_endpoint = _endpoint_from_view_func(view_func) |
|||
state.app.add_url_rule(the_rule, '%s.%s' % (self.name, |
|||
the_endpoint), |
|||
view_func, **options) |
|||
self._record(register_rule) |
|||
|
|||
def endpoint(self, endpoint): |
|||
"""Like :meth:`Flask.endpoint` but for a module.""" |
|||
def decorator(f): |
|||
self.view_functions[endpoint] = f |
|||
return f |
|||
return decorator |
|||
|
|||
def before_request(self, f): |
|||
"""Like :meth:`Flask.before_request` but for a module. This function |
|||
is only executed before each request that is handled by a function of |
|||
that module. |
|||
""" |
|||
self._record(lambda s: s.app.before_request_funcs |
|||
.setdefault(self.name, []).append(f)) |
|||
return f |
|||
|
|||
def before_app_request(self, f): |
|||
"""Like :meth:`Flask.before_request`. Such a function is executed |
|||
before each request, even if outside of a module. |
|||
""" |
|||
self._record(lambda s: s.app.before_request_funcs |
|||
.setdefault(None, []).append(f)) |
|||
return f |
|||
|
|||
def after_request(self, f): |
|||
"""Like :meth:`Flask.after_request` but for a module. This function |
|||
is only executed after each request that is handled by a function of |
|||
that module. |
|||
""" |
|||
self._record(lambda s: s.app.after_request_funcs |
|||
.setdefault(self.name, []).append(f)) |
|||
return f |
|||
|
|||
def after_app_request(self, f): |
|||
"""Like :meth:`Flask.after_request` but for a module. Such a function |
|||
is executed after each request, even if outside of the module. |
|||
""" |
|||
self._record(lambda s: s.app.after_request_funcs |
|||
.setdefault(None, []).append(f)) |
|||
return f |
|||
|
|||
def context_processor(self, f): |
|||
"""Like :meth:`Flask.context_processor` but for a module. This |
|||
function is only executed for requests handled by a module. |
|||
""" |
|||
self._record(lambda s: s.app.template_context_processors |
|||
.setdefault(self.name, []).append(f)) |
|||
return f |
|||
|
|||
def app_context_processor(self, f): |
|||
"""Like :meth:`Flask.context_processor` but for a module. Such a |
|||
function is executed for each request, even if outside of the module. |
|||
""" |
|||
self._record(lambda s: s.app.template_context_processors |
|||
.setdefault(None, []).append(f)) |
|||
return f |
|||
|
|||
def app_errorhandler(self, code): |
|||
"""Like :meth:`Flask.errorhandler` but for a module. This |
|||
handler is used for all requests, even if outside of the module. |
|||
|
|||
.. versionadded:: 0.4 |
|||
""" |
|||
def decorator(f): |
|||
self._record(lambda s: s.app.errorhandler(code)(f)) |
|||
return f |
|||
return decorator |
|||
|
|||
def _record(self, func): |
|||
self._register_events.append(func) |
@ -0,0 +1,43 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.session |
|||
~~~~~~~~~~~~~ |
|||
|
|||
Implements cookie based sessions based on Werkzeug's secure cookie |
|||
system. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
|
|||
from werkzeug.contrib.securecookie import SecureCookie |
|||
|
|||
|
|||
class Session(SecureCookie): |
|||
"""Expands the session with support for switching between permanent |
|||
and non-permanent sessions. |
|||
""" |
|||
|
|||
def _get_permanent(self): |
|||
return self.get('_permanent', False) |
|||
|
|||
def _set_permanent(self, value): |
|||
self['_permanent'] = bool(value) |
|||
|
|||
permanent = property(_get_permanent, _set_permanent) |
|||
del _get_permanent, _set_permanent |
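|||
# Usage sketch (via the `session` proxy): marking the session |
|||
# permanent gives the cookie an expiry date instead of ending it |
|||
# with the browser session. |
|||
# |
|||
#     session.permanent = True |
|||
#     session['user_id'] = user_id |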
|||
|
|||
|
|||
class _NullSession(Session): |
|||
"""Class used to generate nicer error messages if sessions are not |
|||
available. Will still allow read-only access to the empty session |
|||
but fail on setting. |
|||
""" |
|||
|
|||
def _fail(self, *args, **kwargs): |
|||
raise RuntimeError('the session is unavailable because no secret ' |
|||
'key was set. Set the secret_key on the ' |
|||
'application to something unique and secret.') |
|||
__setitem__ = __delitem__ = clear = pop = popitem = \ |
|||
update = setdefault = _fail |
|||
del _fail |
@ -0,0 +1,50 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.signals |
|||
~~~~~~~~~~~~~ |
|||
|
|||
Implements signals based on blinker if available, otherwise |
|||
falls back silently to a no-op implementation. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
signals_available = False |
|||
try: |
|||
from blinker import Namespace |
|||
signals_available = True |
|||
except ImportError: |
|||
class Namespace(object): |
|||
def signal(self, name, doc=None): |
|||
return _FakeSignal(name, doc) |
|||
|
|||
class _FakeSignal(object): |
|||
"""If blinker is unavailable, create a fake class with the same |
|||
interface that allows sending of signals but will fail with an |
|||
error on anything else. On send it simply ignores the |
|||
arguments and does nothing. |
|||
""" |
|||
|
|||
def __init__(self, name, doc=None): |
|||
self.name = name |
|||
self.__doc__ = doc |
|||
def _fail(self, *args, **kwargs): |
|||
raise RuntimeError('signalling support is unavailable ' |
|||
'because the blinker library is ' |
|||
'not installed.') |
|||
send = lambda *a, **kw: None |
|||
connect = disconnect = has_receivers_for = receivers_for = \ |
|||
temporarily_connected_to = _fail |
|||
del _fail |
|||
|
|||
# the namespace for code signals. If you are not flask code, do |
|||
# not put signals in here. Create your own namespace instead. |
|||
_signals = Namespace() |
|||
|
|||
|
|||
# core signals. For usage examples grep the sourcecode or consult |
|||
# the API documentation in docs/api.rst as well as docs/signals.rst |
|||
template_rendered = _signals.signal('template-rendered') |
|||
request_started = _signals.signal('request-started') |
|||
request_finished = _signals.signal('request-finished') |
|||
got_request_exception = _signals.signal('got-request-exception') |
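|||
# Sketch of subscribing to one of these signals (requires blinker; |
|||
# hypothetical callback name, assumes an `app` instance): |
|||
# |
|||
#     def log_template(sender, template, context): |
|||
#         sender.logger.debug('rendered %s' % template.name) |
|||
# |
|||
#     template_rendered.connect(log_template, app) |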
@ -0,0 +1,100 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.templating |
|||
~~~~~~~~~~~~~~~~ |
|||
|
|||
Implements the bridge to Jinja2. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
import posixpath |
|||
from jinja2 import BaseLoader, TemplateNotFound |
|||
|
|||
from .globals import _request_ctx_stack |
|||
from .signals import template_rendered |
|||
|
|||
|
|||
def _default_template_ctx_processor(): |
|||
"""Default template context processor. Injects `request`, |
|||
`session` and `g`. |
|||
""" |
|||
reqctx = _request_ctx_stack.top |
|||
return dict( |
|||
config=reqctx.app.config, |
|||
request=reqctx.request, |
|||
session=reqctx.session, |
|||
g=reqctx.g |
|||
) |
|||
|
|||
|
|||
class _DispatchingJinjaLoader(BaseLoader): |
|||
"""A loader that looks for templates in the application and all |
|||
the module folders. |
|||
""" |
|||
|
|||
def __init__(self, app): |
|||
self.app = app |
|||
|
|||
def get_source(self, environment, template): |
|||
template = posixpath.normpath(template) |
|||
if template.startswith('../'): |
|||
raise TemplateNotFound(template) |
|||
loader = None |
|||
try: |
|||
module, name = template.split('/', 1) |
|||
loader = self.app.modules[module].jinja_loader |
|||
except (ValueError, KeyError): |
|||
pass |
|||
# if there was a module and it has a loader, try this first |
|||
if loader is not None: |
|||
try: |
|||
return loader.get_source(environment, name) |
|||
except TemplateNotFound: |
|||
pass |
|||
# fall back to application loader if module failed |
|||
return self.app.jinja_loader.get_source(environment, template) |
|||
|
|||
def list_templates(self): |
|||
result = self.app.jinja_loader.list_templates() |
|||
for name, module in self.app.modules.iteritems(): |
|||
if module.jinja_loader is not None: |
|||
for template in module.jinja_loader.list_templates(): |
|||
result.append('%s/%s' % (name, template)) |
|||
return result |
|||
|
|||
|
|||
def _render(template, context, app): |
|||
"""Renders the template and fires the signal""" |
|||
rv = template.render(context) |
|||
template_rendered.send(app, template=template, context=context) |
|||
return rv |
|||
|
|||
|
|||
def render_template(template_name, **context): |
|||
"""Renders a template from the template folder with the given |
|||
context. |
|||
|
|||
:param template_name: the name of the template to be rendered |
|||
:param context: the variables that should be available in the |
|||
context of the template. |
|||
""" |
|||
ctx = _request_ctx_stack.top |
|||
ctx.app.update_template_context(context) |
|||
return _render(ctx.app.jinja_env.get_template(template_name), |
|||
context, ctx.app) |
|||
|
|||
|
|||
def render_template_string(source, **context): |
|||
"""Renders a template from the given template source string |
|||
with the given context. |
|||
|
|||
:param source: the source code of the template to be |
|||
rendered |
|||
:param context: the variables that should be available in the |
|||
context of the template. |
|||
""" |
|||
ctx = _request_ctx_stack.top |
|||
ctx.app.update_template_context(context) |
|||
return _render(ctx.app.jinja_env.from_string(source), |
|||
context, ctx.app) |
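|||
# The two rendering entry points side by side (a sketch; assumes a |
|||
# pushed request context and a `hello.html` template): |
|||
# |
|||
#     render_template('hello.html', name='World') |
|||
#     render_template_string('Hello {{ name }}!', name='World') |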
@ -0,0 +1,67 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.testing |
|||
~~~~~~~~~~~~~ |
|||
|
|||
Implements test support helpers. This module is lazily imported |
|||
and usually not used in production environments. |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
|
|||
from werkzeug import Client, EnvironBuilder |
|||
from flask import _request_ctx_stack |
|||
|
|||
|
|||
class FlaskClient(Client): |
|||
"""Works like a regular Werkzeug test client but has some |
|||
knowledge about how Flask works to defer the cleanup of the |
|||
request context stack to the end of the ``with`` block when used |
|||
in a ``with`` statement. |
|||
""" |
|||
|
|||
preserve_context = context_preserved = False |
|||
|
|||
def open(self, *args, **kwargs): |
|||
if self.context_preserved: |
|||
_request_ctx_stack.pop() |
|||
self.context_preserved = False |
|||
kwargs.setdefault('environ_overrides', {}) \ |
|||
['flask._preserve_context'] = self.preserve_context |
|||
|
|||
as_tuple = kwargs.pop('as_tuple', False) |
|||
buffered = kwargs.pop('buffered', False) |
|||
follow_redirects = kwargs.pop('follow_redirects', False) |
|||
|
|||
builder = EnvironBuilder(*args, **kwargs) |
|||
|
|||
if self.application.config.get('SERVER_NAME'): |
|||
server_name = self.application.config.get('SERVER_NAME') |
|||
if ':' not in server_name: |
|||
http_host, http_port = server_name, None |
|||
else: |
|||
http_host, http_port = server_name.split(':', 1) |
|||
if builder.base_url == 'http://localhost/': |
|||
# Default Generated Base URL |
|||
if http_port is not None: |
|||
builder.host = http_host + ':' + http_port |
|||
else: |
|||
builder.host = http_host |
|||
old = _request_ctx_stack.top |
|||
try: |
|||
return Client.open(self, builder, |
|||
as_tuple=as_tuple, |
|||
buffered=buffered, |
|||
follow_redirects=follow_redirects) |
|||
finally: |
|||
self.context_preserved = _request_ctx_stack.top is not old |
|||
|
|||
def __enter__(self): |
|||
self.preserve_context = True |
|||
return self |
|||
|
|||
def __exit__(self, exc_type, exc_value, tb): |
|||
self.preserve_context = False |
|||
if self.context_preserved: |
|||
_request_ctx_stack.pop() |
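|||
# Usage sketch: inside the `with` block the context of the last |
|||
# request is preserved, so the `request` proxy is still usable. |
|||
# |
|||
#     with app.test_client() as c: |
|||
#         rv = c.get('/?answer=42') |
|||
#         assert request.args['answer'] == '42' |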
@ -0,0 +1,88 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
flask.wrappers |
|||
~~~~~~~~~~~~~~ |
|||
|
|||
Implements the WSGI wrappers (request and response). |
|||
|
|||
:copyright: (c) 2010 by Armin Ronacher. |
|||
:license: BSD, see LICENSE for more details. |
|||
""" |
|||
|
|||
from werkzeug import Request as RequestBase, Response as ResponseBase, \ |
|||
cached_property |
|||
|
|||
from .helpers import json, _assert_have_json |
|||
from .globals import _request_ctx_stack |
|||
|
|||
|
|||
class Request(RequestBase): |
|||
"""The request object used by default in Flask. Remembers the |
|||
matched endpoint and view arguments. |
|||
|
|||
It is what ends up as :class:`~flask.request`. If you want to replace |
|||
the request object used you can subclass this and set |
|||
:attr:`~flask.Flask.request_class` to your subclass. |
|||
""" |
|||
|
|||
#: the internal URL rule that matched the request. This can be |
|||
#: useful to inspect which methods are allowed for the URL from |
|||
#: a before/after handler (``request.url_rule.methods``) etc. |
|||
#: |
|||
#: .. versionadded:: 0.6 |
|||
url_rule = None |
|||
|
|||
#: a dict of view arguments that matched the request. If an exception |
|||
#: happened when matching, this will be `None`. |
|||
view_args = None |
|||
|
|||
#: if matching the URL failed, this is the exception that will be |
|||
#: raised / was raised as part of the request handling. This is |
|||
#: usually a :exc:`~werkzeug.exceptions.NotFound` exception or |
|||
#: something similar. |
|||
routing_exception = None |
|||
|
|||
@property |
|||
def max_content_length(self): |
|||
"""Read-only view of the `MAX_CONTENT_LENGTH` config key.""" |
|||
ctx = _request_ctx_stack.top |
|||
if ctx is not None: |
|||
return ctx.app.config['MAX_CONTENT_LENGTH'] |
|||
|
|||
@property |
|||
def endpoint(self): |
|||
"""The endpoint that matched the request. This in combination with |
|||
:attr:`view_args` can be used to reconstruct the same or a |
|||
modified URL. If an exception happened when matching, this will |
|||
be `None`. |
|||
""" |
|||
if self.url_rule is not None: |
|||
return self.url_rule.endpoint |
|||
|
|||
@property |
|||
def module(self): |
|||
"""The name of the current module""" |
|||
if self.url_rule and '.' in self.url_rule.endpoint: |
|||
return self.url_rule.endpoint.rsplit('.', 1)[0] |
|||
|
|||
@cached_property |
|||
def json(self): |
|||
"""If the mimetype is `application/json` this will contain the |
|||
parsed JSON data. |
|||
""" |
|||
if __debug__: |
|||
_assert_have_json() |
|||
if self.mimetype == 'application/json': |
|||
return json.loads(self.data) |
|||
|
|||
|
|||
class Response(ResponseBase): |
|||
"""The response object that is used by default in Flask. Works like the |
|||
response object from Werkzeug but is set to have an HTML mimetype by |
|||
default. Quite often you don't have to create this object yourself because |
|||
:meth:`~flask.Flask.make_response` will take care of that for you. |
|||
|
|||
If you want to replace the response object used you can subclass this and |
|||
set :attr:`~flask.Flask.response_class` to your subclass. |
|||
""" |
|||
default_mimetype = 'text/html' |
@ -0,0 +1,206 @@ |
|||
from hachoir_parser import createParser |
|||
from hachoir_metadata import extractMetadata |
|||
from hachoir_core.cmd_line import unicodeFilename |
|||
|
|||
import datetime |
|||
import json |
|||
import sys |
|||
import re |
|||
|
|||
|
|||
def getMetadata(filename): |
|||
filename, realname = unicodeFilename(filename), filename |
|||
parser = createParser(filename, realname) |
|||
try: |
|||
metadata = extractMetadata(parser) |
|||
except Exception: |
|||
return None |
|||
|
|||
if metadata is not None: |
|||
metadata = metadata.exportPlaintext() |
|||
return metadata |
|||
return None |
|||
|
|||
def parseMetadata(meta, jsonsafe=True): |
|||
''' |
|||
Return a dict of section headings like 'Video stream' or 'Audio stream'. Each key will have a list of dicts. |
|||
This supports multiple video/audio/subtitle/whatever streams per stream type. Each element in the list of streams |
|||
will be a dict with keys like 'Image height' and 'Compression', i.e. anything that hachoir is able to extract. |
|||
|
|||
An example output: |
|||
{'Audio stream': [{u'Channel': u'6', |
|||
u'Compression': u'A_AC3', |
|||
u'Sample rate': u'48.0 kHz'}], |
|||
u'Common': [{u'Creation date': u'2008-03-20 09:09:43', |
|||
u'Duration': u'1 hour 40 min 6 sec', |
|||
u'Endianness': u'Big endian', |
|||
u'MIME type': u'video/x-matroska', |
|||
u'Producer': u'libebml v0.7.7 + libmatroska v0.8.1'}], |
|||
'Video stream': [{u'Compression': u'V_MPEG4/ISO/AVC', |
|||
u'Image height': u'688 pixels', |
|||
u'Image width': u'1280 pixels', |
|||
u'Language': u'English'}]} |
|||
''' |
|||
if not meta: |
|||
return |
|||
sections = {} |
|||
what = [] |
|||
for line in meta: |
|||
# If the line doesn't start with "- ", it is a section heading |
|||
if line[:2] != "- ": |
|||
section = line.strip(":").lower() |
|||
|
|||
# Collapse numbered stream headings (e.g. "audio stream #1") into one key... |
|||
search = re.search(r'#\d+\Z', section) |
|||
if search: |
|||
section = re.sub(search.group(), '', section).strip() |
|||
|
|||
if section not in sections: |
|||
sections[section] = [dict()] |
|||
else: |
|||
sections[section].append(dict()) |
|||
else: |
|||
#This isn't a section heading, so we put it in the last section heading we found. |
|||
#meta always starts out with a section heading so 'section' will always be defined |
|||
i = line.find(":") |
|||
key = line[2:i].lower() |
|||
value = _parseValue(section, key, line[i+2:]) |
|||
|
|||
if value is None: |
|||
value = line[i+2:] |
|||
|
|||
if jsonsafe: |
|||
try: |
|||
json.dumps(value)  # probe only: can this value be serialized? |
|||
except TypeError: |
|||
value = str(value) |
|||
|
|||
sections[section][-1][key] = value |
|||
|
|||
|
|||
|
|||
return sections |
|||
|
|||
def _parseValue(section, key, value, jsonsafe = True): |
|||
''' |
|||
Tediously check all the types that we know about (checked over 7k videos to find these) |
|||
and convert them to python native types. |
|||
|
|||
If jsonsafe is True, json-unfriendly types like datetime are converted into json-friendly ones. |
|||
''' |
|||
|
|||
date_search = re.search("\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d", value) |
|||
|
|||
if key == 'bit rate': |
|||
ret = _parseBitRate(value.lower()) |
|||
elif key == 'bits/sample' or key == 'bits/pixel': |
|||
try: |
|||
bits = int(value.split()[0]) |
|||
ret = bits |
|||
except: |
|||
ret = None |
|||
elif key == 'channel': |
|||
if value == 'stereo': |
|||
ret = 2 |
|||
elif value == 'mono': |
|||
ret = 1 |
|||
else: |
|||
try: |
|||
channels = int(value) |
|||
ret = channels |
|||
except: |
|||
ret = None |
|||
elif key == 'compression': |
|||
ret = _parseCompression(value) |
|||
elif key == 'compression rate': |
|||
try: |
|||
ret = float(value.split('x')[0]) |
|||
except: |
|||
ret = None |
|||
elif key == 'duration': |
|||
try: |
|||
ret = _parseDuration(value) |
|||
except: |
|||
ret = None |
|||
elif key == 'sample rate': |
|||
try: |
|||
ret = float(value.split()[0]) * 1000 |
|||
except: |
|||
ret = None |
|||
elif key == 'frame rate': |
|||
try: |
|||
ret = float(value.split()[0]) |
|||
except: |
|||
ret = None |
|||
elif key == 'image height' or key == 'image width': |
|||
pixels = re.match("(?P<pixels>\d{1,4}) pixel", value) |
|||
if pixels: |
|||
ret = int(pixels.group('pixels')) |
|||
else: |
|||
ret = None |
|||
elif date_search: |
|||
try: |
|||
ret = datetime.datetime.strptime(date_search.group(), "%Y-%m-%d %H:%M:%S") |
|||
except: |
|||
ret = None |
|||
else: |
|||
#If it's something we don't know about... |
|||
ret = None |
|||
|
|||
return ret |
|||
|
|||
def _parseDuration(value): |
|||
t = re.search(r"((?P<hour>\d+) hour(s|))? ?((?P<min>\d+) min)? ?((?P<sec>\d+) sec)? ?((?P<ms>\d+) ms)?", value) |
|||
if t: |
|||
hour = 0 if not t.group('hour') else int(t.group('hour')) |
|||
min = 0 if not t.group('min') else int(t.group('min')) |
|||
sec = 0 if not t.group('sec') else int(t.group('sec')) |
|||
ms = 0 if not t.group('ms') else int(t.group('ms')) |
|||
return datetime.timedelta(hours = hour, minutes = min, seconds = sec, milliseconds = ms) |
|||
|
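# Worked example (assumed input, in the docstring's format above): |
#   _parseDuration(u'1 hour 40 min 6 sec') == datetime.timedelta(seconds=6006) |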
|||
def _parseCompression(value): |
|||
codecs = { |
|||
'v_mpeg4/iso/avc': 'AVC', |
|||
'x264': 'AVC', |
|||
'divx': 'divx', |
|||
'xvid': 'xvid', |
|||
'v_ms/vfw/fourcc': 'vfw', |
|||
'vorbis': 'vorbis', |
|||
'mpeg layer 3': 'mp3', |
|||
'a_dts': 'DTS', |
|||
'a_aac': 'AAC', |
|||
'a_truehd': 'TRUEHD', |
|||
'microsoft mpeg': 'MPEG', |
|||
'ac3': 'AC3', |
|||
'wvc1': 'WVC1', |
|||
'pulse code modulation': 'PCM', |
|||
'pcm': 'PCM', |
|||
'windows media audio': 'WMA', |
|||
'windows media video': 'WMV', |
|||
's_text/ascii': 'ASCII', |
|||
's_text/utf8': 'UTF8', |
|||
's_text/ssa': 'SSA', |
|||
's_text/ass': 'ASS' |
|||
} |
|||
for codec in codecs: |
|||
if codec in value.lower(): |
|||
return codecs[codec] |
|||
|
|||
|
|||
def _parseBitRate(value): |
|||
try: |
|||
bitrate = float(value.split()[0]) |
|||
except: |
|||
return None |
|||
|
|||
if 'kbit' in value.lower(): |
|||
multi = 1000 |
|||
elif 'mbit' in value.lower(): |
|||
multi = 1000 * 1000 |
|||
else: |
|||
return None |
|||
|
|||
return bitrate * multi |
|||
|
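# Example invocation (script and media file names are hypothetical): |
#   python metadata.py /path/to/movie.mkv |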
|||
print json.dumps(parseMetadata(getMetadata(sys.argv[1]))) |
@ -0,0 +1,2 @@ |
|||
from hachoir_core.version import VERSION as __version__, PACKAGE, WEBSITE, LICENSE |
|||
|
@ -0,0 +1,210 @@ |
|||
from hachoir_core.tools import humanDurationNanosec |
|||
from hachoir_core.i18n import _ |
|||
from math import floor |
|||
from time import time |
|||
|
|||
class BenchmarkError(Exception): |
|||
""" |
|||
Error during a benchmark; use str(err) to format it as a string. |
|||
""" |
|||
def __init__(self, message): |
|||
Exception.__init__(self, |
|||
"Benchmark internal error: %s" % message) |
|||
|
|||
class BenchmarkStat: |
|||
""" |
|||
Benchmark statistics. This class automatically computes minimum value, |
|||
maximum value and sum of all values. |
|||
|
|||
Methods: |
|||
- append(value): append a value |
|||
- getMin(): minimum value |
|||
- getMax(): maximum value |
|||
- getSum(): sum of all values |
|||
- __len__(): get number of elements |
|||
- __nonzero__(): is the stat non-empty? |
|||
""" |
|||
def __init__(self): |
|||
self._values = [] |
|||
|
|||
def append(self, value): |
|||
self._values.append(value) |
|||
try: |
|||
self._min = min(self._min, value) |
|||
self._max = max(self._max, value) |
|||
self._sum += value |
|||
except AttributeError: |
|||
self._min = value |
|||
self._max = value |
|||
self._sum = value |
|||
|
|||
def __len__(self): |
|||
return len(self._values) |
|||
|
|||
def __nonzero__(self): |
|||
return bool(self._values) |
|||
|
|||
def getMin(self): |
|||
return self._min |
|||
|
|||
def getMax(self): |
|||
return self._max |
|||
|
|||
def getSum(self): |
|||
return self._sum |
|||
|
|||
class Benchmark: |
|||
def __init__(self, max_time=5.0, |
|||
min_count=5, max_count=None, progress_time=1.0): |
|||
""" |
|||
Constructor: |
|||
- max_time: Maximum wanted duration of the whole benchmark |
|||
(default: 5 seconds, minimum: 1 second). |
|||
- min_count: Minimum number of function calls to get good statistics |
|||
(default: 5, minimum: 1). |
|||
- progress_time: Time between each "progress" message |
|||
(default: 1 second, minimum: 250 ms). |
|||
- max_count: Maximum number of function calls (default: no limit). |
|||
- verbose: Is verbose? (default: False) |
|||
- disable_gc: Disable garbage collector? (default: False) |
|||
""" |
|||
self.max_time = max(max_time, 1.0) |
|||
self.min_count = max(min_count, 1) |
|||
self.max_count = max_count |
|||
self.progress_time = max(progress_time, 0.25) |
|||
self.verbose = False |
|||
self.disable_gc = False |
|||
|
|||
def formatTime(self, value): |
|||
""" |
|||
Format a time delta to string: use humanDurationNanosec() |
|||
""" |
|||
return humanDurationNanosec(value * 1000000000) |
|||
|
|||
def displayStat(self, stat): |
|||
""" |
|||
Display statistics to stdout: |
|||
- best time (minimum) |
|||
- average time (arithmetic average) |
|||
- worst time (maximum) |
|||
- total time (sum) |
|||
|
|||
Use the arithmetic average instead of the geometric average because |
|||
geometric fails if any value is zero (returns zero) and also |
|||
because floating point multiplication loses precision with many |
|||
values. |
|||
""" |
|||
average = stat.getSum() / len(stat) |
|||
values = (stat.getMin(), average, stat.getMax(), stat.getSum()) |
|||
values = tuple(self.formatTime(value) for value in values) |
|||
print _("Benchmark: best=%s average=%s worst=%s total=%s") \ |
|||
% values |
|||
|
|||
def _runOnce(self, func, args, kw): |
|||
before = time() |
|||
func(*args, **kw) |
|||
after = time() |
|||
return after - before |
|||
|
|||
def _run(self, func, args, kw): |
|||
""" |
|||
Call func(*args, **kw) as many times as needed to get |
|||
good statistics. Algorithm: |
|||
- call the function once |
|||
- compute needed number of calls |
|||
- and then call function N times |
|||
|
|||
To compute number of calls, parameters are: |
|||
- time of first function call |
|||
- minimum number of calls (min_count attribute) |
|||
- maximum test time (max_time attribute) |
|||
|
|||
Note: the number of calls is only approximate. |
|||
""" |
|||
# First call of the benchmark |
|||
stat = BenchmarkStat() |
|||
diff = self._runOnce(func, args, kw) |
|||
best = diff |
|||
stat.append(diff) |
|||
total_time = diff |
|||
|
|||
# Compute needed number of calls |
|||
count = int(floor(self.max_time / diff)) |
|||
count = max(count, self.min_count) |
|||
if self.max_count: |
|||
count = min(count, self.max_count) |
|||
|
|||
# No further calls needed? Just exit |
|||
if count == 1: |
|||
return stat |
|||
estimate = diff * count |
|||
if self.verbose: |
|||
print _("Run benchmark: %s calls (estimate: %s)") \ |
|||
% (count, self.formatTime(estimate)) |
|||
|
|||
display_progress = self.verbose and (1.0 <= estimate) |
|||
total_count = 1 |
|||
while total_count < count: |
|||
# Run benchmark and display each result |
|||
if display_progress: |
|||
print _("Result %s/%s: %s (best: %s)") % \ |
|||
(total_count, count, |
|||
self.formatTime(diff), self.formatTime(best)) |
|||
part = count - total_count |
|||
|
|||
# Will the remaining calls take longer than the progress interval? |
|||
average = total_time / total_count |
|||
if self.progress_time < part * average: |
|||
part = max( int(self.progress_time / average), 1) |
|||
for index in xrange(part): |
|||
diff = self._runOnce(func, args, kw) |
|||
stat.append(diff) |
|||
total_time += diff |
|||
best = min(diff, best) |
|||
total_count += part |
|||
if display_progress: |
|||
print _("Result %s/%s: %s (best: %s)") % \ |
|||
(count, count, |
|||
self.formatTime(diff), self.formatTime(best)) |
|||
return stat |
|||
|
|||
def validateStat(self, stat): |
|||
""" |
|||
Check statistics and raise a BenchmarkError if they are invalid. |
|||
Examples of tests: reject an empty stat, reject a stat with only null values. |
|||
""" |
|||
if not stat: |
|||
raise BenchmarkError("empty statistics") |
|||
if not stat.getSum(): |
|||
raise BenchmarkError("nul statistics") |
|||
|
|||
def run(self, func, *args, **kw): |
|||
""" |
|||
Run function func(*args, **kw), validate statistics, |
|||
and display the result on stdout. |
|||
|
|||
Disables the garbage collector if asked to. |
|||
""" |
|||
|
|||
# Disable the garbage collector if asked and if it exists |
|||
# (Jython 2.2 doesn't have it, for example) |
|||
if self.disable_gc: |
|||
try: |
|||
import gc |
|||
except ImportError: |
|||
self.disable_gc = False |
|||
if self.disable_gc: |
|||
gc_enabled = gc.isenabled() |
|||
gc.disable() |
|||
else: |
|||
gc_enabled = False |
|||
|
|||
# Run the benchmark |
|||
stat = self._run(func, args, kw) |
|||
if gc_enabled: |
|||
gc.enable() |
|||
|
|||
# Validate and display stats |
|||
self.validateStat(stat) |
|||
self.displayStat(stat) |
|||
|
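# A minimal usage sketch (assumed caller code, not part of this module): |
# |
#     bench = Benchmark(max_time=2.0) |
#     bench.verbose = True |
#     bench.run(sum, xrange(10000))  # prints best/average/worst/total times |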
@ -0,0 +1,277 @@ |
|||
""" |
|||
Utilities to convert integers and binary strings to binary (number), binary |
|||
string, number, hexadecimal, etc. |
|||
""" |
|||
|
|||
from hachoir_core.endian import BIG_ENDIAN, LITTLE_ENDIAN |
|||
from hachoir_core.compatibility import reversed |
|||
from itertools import chain, repeat |
|||
from struct import calcsize, unpack, error as struct_error |
|||
|
|||
def swap16(value): |
|||
""" |
|||
Swap bytes between big and little endian for a 16-bit integer. |
|||
|
|||
>>> "%x" % swap16(0x1234) |
|||
'3412' |
|||
""" |
|||
return (value & 0xFF) << 8 | (value >> 8) |
|||
|
|||
def swap32(value): |
|||
""" |
|||
Swap bytes between big and little endian for a 32-bit integer. |
|||
|
|||
>>> "%x" % swap32(0x12345678) |
|||
'78563412' |
|||
""" |
|||
value = long(value) |
|||
return ((value & 0x000000FFL) << 24) \ |
|||
| ((value & 0x0000FF00L) << 8) \ |
|||
| ((value & 0x00FF0000L) >> 8) \ |
|||
| ((value & 0xFF000000L) >> 24) |
|||
|
|||
def bin2long(text, endian): |
|||
""" |
|||
Convert binary number written in a string into an integer. |
|||
Skip characters other than "0" and "1". |
|||
|
|||
>>> bin2long("110", BIG_ENDIAN) |
|||
6 |
|||
>>> bin2long("110", LITTLE_ENDIAN) |
|||
3 |
|||
>>> bin2long("11 00", LITTLE_ENDIAN) |
|||
3 |
|||
""" |
|||
assert endian in (LITTLE_ENDIAN, BIG_ENDIAN) |
|||
bits = [ (ord(character)-ord("0")) \ |
|||
for character in text if character in "01" ] |
|||
assert len(bits) != 0 |
|||
if endian is not BIG_ENDIAN: |
|||
bits = reversed(bits) |
|||
value = 0 |
|||
for bit in bits: |
|||
value *= 2 |
|||
value += bit |
|||
return value |
|||
|
|||
def str2hex(value, prefix="", glue=u"", format="%02X"): |
|||
r""" |
|||
Convert a binary string to hexadecimal (base 16). |
|||
|
|||
>>> str2hex("ABC") |
|||
u'414243' |
|||
>>> str2hex("\xF0\xAF", glue=" ") |
|||
u'F0 AF' |
|||
>>> str2hex("ABC", prefix="0x") |
|||
u'0x414243' |
|||
>>> str2hex("ABC", format=r"\x%02X") |
|||
u'\\x41\\x42\\x43' |
|||
""" |
|||
if isinstance(glue, str): |
|||
glue = unicode(glue) |
|||
if 0 < len(prefix): |
|||
text = [prefix] |
|||
else: |
|||
text = [] |
|||
for character in value: |
|||
text.append(format % ord(character)) |
|||
return glue.join(text) |
|||
|
|||
def countBits(value): |
|||
""" |
|||
Count number of bits needed to store a (positive) integer number. |
|||
|
|||
>>> countBits(0) |
|||
1 |
|||
>>> countBits(1000) |
|||
10 |
|||
>>> countBits(44100) |
|||
16 |
|||
>>> countBits(18446744073709551615) |
|||
64 |
|||
""" |
|||
assert 0 <= value |
|||
count = 1 |
|||
bits = 1 |
|||
while (1 << bits) <= value: |
|||
count += bits |
|||
value >>= bits |
|||
bits <<= 1 |
|||
while 2 <= value: |
|||
if bits != 1: |
|||
bits >>= 1 |
|||
else: |
|||
bits -= 1 |
|||
while (1 << bits) <= value: |
|||
count += bits |
|||
value >>= bits |
|||
return count |
|||
|
|||
def byte2bin(number, classic_mode=True): |
|||
""" |
|||
Convert a byte (integer in 0..255 range) to a binary string. |
|||
If classic_mode is true (default value), reverse bits. |
|||
|
|||
>>> byte2bin(10) |
|||
'00001010' |
|||
>>> byte2bin(10, False) |
|||
'01010000' |
|||
""" |
|||
text = "" |
|||
for i in range(0, 8): |
|||
if classic_mode: |
|||
mask = 1 << (7-i) |
|||
else: |
|||
mask = 1 << i |
|||
if (number & mask) == mask: |
|||
text += "1" |
|||
else: |
|||
text += "0" |
|||
return text |
|||
|
|||
def long2raw(value, endian, size=None): |
|||
r""" |
|||
Convert a number (positive, and non-zero unless size is given) to a raw byte string. |
|||
If size is given, pad with null bytes up to size bytes. |
|||
|
|||
>>> long2raw(0x1219, BIG_ENDIAN) |
|||
'\x12\x19' |
|||
>>> long2raw(0x1219, BIG_ENDIAN, 4) # 32 bits |
|||
'\x00\x00\x12\x19' |
|||
>>> long2raw(0x1219, LITTLE_ENDIAN, 4) # 32 bits |
|||
'\x19\x12\x00\x00' |
|||
""" |
|||
assert (not size and 0 < value) or (0 <= value) |
|||
assert endian in (LITTLE_ENDIAN, BIG_ENDIAN) |
|||
text = [] |
|||
while value != 0 or not text:  # always emit at least one byte |
|||
byte = value % 256 |
|||
text.append( chr(byte) ) |
|||
value >>= 8 |
|||
if size: |
|||
need = max(size - len(text), 0) |
|||
else: |
|||
need = 0 |
|||
if need: |
|||
if endian is BIG_ENDIAN: |
|||
text = chain(repeat("\0", need), reversed(text)) |
|||
else: |
|||
text = chain(text, repeat("\0", need)) |
|||
else: |
|||
if endian is BIG_ENDIAN: |
|||
text = reversed(text) |
|||
return "".join(text) |
|||
|
|||
def long2bin(size, value, endian, classic_mode=False): |
|||
""" |
|||
Convert a number into bits (in a string): |
|||
- size: size in bits of the number |
|||
- value: positive (or zero) number |
|||
- endian: BIG_ENDIAN (most significant bit first) |
|||
or LITTLE_ENDIAN (least significant bit first) |
|||
- classic_mode (default: False): reverse each packet of 8 bits |
|||
|
|||
>>> long2bin(16, 1+4 + (1+8)*256, BIG_ENDIAN) |
|||
'10100000 10010000' |
|||
>>> long2bin(16, 1+4 + (1+8)*256, BIG_ENDIAN, True) |
|||
'00000101 00001001' |
|||
>>> long2bin(16, 1+4 + (1+8)*256, LITTLE_ENDIAN) |
|||
'00001001 00000101' |
|||
>>> long2bin(16, 1+4 + (1+8)*256, LITTLE_ENDIAN, True) |
|||
'10010000 10100000' |
|||
""" |
|||
text = "" |
|||
assert endian in (LITTLE_ENDIAN, BIG_ENDIAN) |
|||
assert 0 <= value |
|||
for index in xrange(size): |
|||
if (value & 1) == 1: |
|||
text += "1" |
|||
else: |
|||
text += "0" |
|||
value >>= 1 |
|||
if endian is LITTLE_ENDIAN: |
|||
text = text[::-1] |
|||
result = "" |
|||
while len(text) != 0: |
|||
if len(result) != 0: |
|||
result += " " |
|||
if classic_mode: |
|||
result += text[7::-1] |
|||
else: |
|||
result += text[:8] |
|||
text = text[8:] |
|||
return result |
|||
|
|||
def str2bin(value, classic_mode=True): |
|||
r""" |
|||
Convert a byte string to a string of binary digits. |
|||
If classic_mode is true (default value), reverse bits. |
|||
|
|||
>>> str2bin("\x03\xFF") |
|||
'00000011 11111111' |
|||
>>> str2bin("\x03\xFF", False) |
|||
'11000000 11111111' |
|||
""" |
|||
text = "" |
|||
for character in value: |
|||
if text != "": |
|||
text += " " |
|||
byte = ord(character) |
|||
text += byte2bin(byte, classic_mode) |
|||
return text |
|||
|
|||
def _createStructFormat(): |
|||
""" |
|||
Create a dictionary (endian, size_byte) => struct format used |
|||
by str2long() to convert raw data to positive integer. |
|||
""" |
|||
format = { |
|||
BIG_ENDIAN: {}, |
|||
LITTLE_ENDIAN: {}, |
|||
} |
|||
for struct_format in "BHILQ": |
|||
try: |
|||
size = calcsize(struct_format) |
|||
format[BIG_ENDIAN][size] = '>%s' % struct_format |
|||
format[LITTLE_ENDIAN][size] = '<%s' % struct_format |
|||
except struct_error: |
|||
pass |
|||
return format |
|||
_struct_format = _createStructFormat() |
|||
|
|||
def str2long(data, endian): |
|||
r""" |
|||
Convert raw data (type 'str') into a long integer. |
|||
|
|||
>>> chr(str2long('*', BIG_ENDIAN)) |
|||
'*' |
|||
>>> str2long("\x00\x01\x02\x03", BIG_ENDIAN) == 0x10203 |
|||
True |
|||
>>> str2long("\x2a\x10", LITTLE_ENDIAN) == 0x102a |
|||
True |
|||
>>> str2long("\xff\x14\x2a\x10", BIG_ENDIAN) == 0xff142a10 |
|||
True |
|||
>>> str2long("\x00\x01\x02\x03", LITTLE_ENDIAN) == 0x3020100 |
|||
True |
|||
>>> str2long("\xff\x14\x2a\x10\xab\x00\xd9\x0e", BIG_ENDIAN) == 0xff142a10ab00d90e |
|||
True |
|||
>>> str2long("\xff\xff\xff\xff\xff\xff\xff\xff", BIG_ENDIAN) == (2**64-1) |
|||
True |
|||
""" |
|||
assert 1 <= len(data) <= 32 # arbitrary limit: 256 bits |
|||
try: |
|||
return unpack(_struct_format[endian][len(data)], data)[0] |
|||
except KeyError: |
|||
pass |
|||
|
|||
assert endian in (BIG_ENDIAN, LITTLE_ENDIAN) |
|||
shift = 0 |
|||
value = 0 |
|||
if endian is BIG_ENDIAN: |
|||
data = reversed(data) |
|||
for character in data: |
|||
byte = ord(character) |
|||
value += (byte << shift) |
|||
shift += 8 |
|||
return value |
|||
|
@ -0,0 +1,43 @@ |
|||
from optparse import OptionGroup |
|||
from hachoir_core.log import log |
|||
from hachoir_core.i18n import _, getTerminalCharset |
|||
from hachoir_core.tools import makePrintable |
|||
import hachoir_core.config as config |
|||
|
|||
def getHachoirOptions(parser): |
|||
""" |
|||
Create an option group (type optparse.OptionGroup) of Hachoir |
|||
library options. |
|||
""" |
|||
def setLogFilename(*args): |
|||
log.setFilename(args[2]) |
|||
|
|||
common = OptionGroup(parser, _("Hachoir library"), \ |
|||
"Configure Hachoir library") |
|||
common.add_option("--verbose", help=_("Verbose mode"), |
|||
default=False, action="store_true") |
|||
common.add_option("--log", help=_("Write log in a file"), |
|||
type="string", action="callback", callback=setLogFilename) |
|||
common.add_option("--quiet", help=_("Quiet mode (don't display warning)"), |
|||
default=False, action="store_true") |
|||
common.add_option("--debug", help=_("Debug mode"), |
|||
default=False, action="store_true") |
|||
return common |
|||
|
|||
def configureHachoir(option): |
|||
# Configure Hachoir using "option" (value from optparse) |
|||
if option.quiet: |
|||
config.quiet = True |
|||
if option.verbose: |
|||
config.verbose = True |
|||
if option.debug: |
|||
config.debug = True |
|||
|
|||
def unicodeFilename(filename, charset=None): |
|||
if not charset: |
|||
charset = getTerminalCharset() |
|||
try: |
|||
return unicode(filename, charset) |
|||
except UnicodeDecodeError: |
|||
return makePrintable(filename, charset, to_unicode=True) |
|||
|
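# A minimal wiring sketch (assumed caller code, not part of this module): |
# |
#     from optparse import OptionParser |
#     parser = OptionParser(usage="%prog [options] file") |
#     parser.add_option_group(getHachoirOptions(parser)) |
#     values, arguments = parser.parse_args() |
#     configureHachoir(values) |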
@ -0,0 +1,185 @@ |
|||
""" |
|||
Compatibility constants and functions. This module works on Python 1.5 to 2.5. |
|||
|
|||
This module provides: |
|||
- True and False constants ; |
|||
- any() and all() function ; |
|||
- has_yield and has_slice values ; |
|||
- isinstance() with Python 2.3 behaviour ; |
|||
- reversed() and sorted() function. |
|||
|
|||
|
|||
True and False constants |
|||
======================== |
|||
|
|||
Truth constants: True is yes (one) and False is no (zero). |
|||
|
|||
>>> int(True), int(False) # int value |
|||
(1, 0) |
|||
>>> int(False | True) # or binary operator |
|||
1 |
|||
>>> int(True & False) # and binary operator |
|||
0 |
|||
>>> int(not(True) == False) # not operator |
|||
1 |
|||
|
|||
Warning: on Python older than 2.3, True and False are aliases for |
|||
the numbers 1 and 0, so "print True" will display 1 and not True. |
|||
|
|||
|
|||
any() function |
|||
============== |
|||
|
|||
any() returns True if at least one item is True, and False otherwise. |
|||
|
|||
>>> any([False, True]) |
|||
True |
|||
>>> any([True, True]) |
|||
True |
|||
>>> any([False, False]) |
|||
False |
|||
|
|||
|
|||
all() function |
|||
============== |
|||
|
|||
all() returns True if all items are True, and False otherwise. |
|||
This function just applies the binary and operator (&) to all values. |
|||
|
|||
>>> all([True, True]) |
|||
True |
|||
>>> all([False, True]) |
|||
False |
|||
>>> all([False, False]) |
|||
False |
|||
|
|||
|
|||
has_yield boolean |
|||
================= |
|||
|
|||
has_yield: boolean which indicates if the interpreter supports the yield |
|||
keyword, available since Python 2.2 (with "from __future__ import generators"). |
|||
|
|||
|
|||
has_slice boolean |
|||
================= |
|||
|
|||
has_slice: boolean which indicates if the interpreter supports slices with step |
|||
argument or not. slice with step is available since Python 2.3. |
|||
|
|||
|
|||
reversed() and sorted() function |
|||
================================ |
|||
|
|||
The reversed() and sorted() functions were introduced in Python 2.4. |
|||
reversed() should return an iterator, but the fallback in this module may return a list. |
|||
|
|||
>>> data = list("cab") |
|||
>>> list(sorted(data)) |
|||
['a', 'b', 'c'] |
|||
>>> list(reversed("abc")) |
|||
['c', 'b', 'a'] |
|||
""" |
|||
|
|||
import copy |
|||
import operator |
|||
|
|||
# --- True and False constants from Python 2.0 --- |
|||
# --- Warning: for Python < 2.3, they are aliases for 1 and 0 --- |
|||
try: |
|||
True = True |
|||
False = False |
|||
except NameError: |
|||
True = 1 |
|||
False = 0 |
|||
|
|||
# --- any() from Python 2.5 --- |
|||
try: |
|||
from __builtin__ import any |
|||
except ImportError: |
|||
def any(items): |
|||
for item in items: |
|||
if item: |
|||
return True |
|||
return False |
|||
|
|||
# --- all() from Python 2.5 --- |
|||
try: |
|||
from __builtin__ import all |
|||
except ImportError: |
|||
def all(items): |
|||
return reduce(operator.__and__, items) |
|||
|
|||
# --- test if interpreter supports yield keyword --- |
|||
try: |
|||
eval(compile(""" |
|||
from __future__ import generators |
|||
|
|||
def gen(): |
|||
yield 1 |
|||
yield 2 |
|||
|
|||
if list(gen()) != [1, 2]: |
|||
raise KeyError("42") |
|||
""", "<string>", "exec")) |
|||
except (KeyError, SyntaxError): |
|||
has_yield = False |
|||
else: |
|||
has_yield = True |
|||
|
|||
# --- test if interpreter supports slices (with step argument) --- |
|||
try: |
|||
has_slice = eval('"abc"[::-1] == "cba"') |
|||
except (TypeError, SyntaxError): |
|||
has_slice = False |
|||
|
|||
# --- isinstance with isinstance Python 2.3 behaviour (arg 2 is a type) --- |
|||
try: |
|||
if isinstance(1, int): |
|||
from __builtin__ import isinstance |
|||
except TypeError: |
|||
print "Redef isinstance" |
|||
def isinstance20(a, typea): |
|||
if type(typea) != type(type): |
|||
raise TypeError("TypeError: isinstance() arg 2 must be a class, type, or tuple of classes and types") |
|||
return type(a) == typea  # exact type match (this fallback ignores subclasses) |
|||
isinstance = isinstance20 |
|||
|
|||
# --- reversed() from Python 2.4 --- |
|||
try: |
|||
from __builtin__ import reversed |
|||
except ImportError: |
|||
# if hasYield() == "ok": |
|||
# code = """ |
|||
#def reversed(data): |
|||
# for index in xrange(len(data)-1, -1, -1): |
|||
# yield data[index]; |
|||
#reversed""" |
|||
# reversed = eval(compile(code, "<string>", "exec")) |
|||
if has_slice: |
|||
def reversed(data): |
|||
if not isinstance(data, list): |
|||
data = list(data) |
|||
return data[::-1] |
|||
else: |
|||
def reversed(data): |
|||
if not isinstance(data, list): |
|||
data = list(data) |
|||
reversed_data = [] |
|||
for index in xrange(len(data)-1, -1, -1): |
|||
reversed_data.append(data[index]) |
|||
return reversed_data |
|||
|
|||
# --- sorted() from Python 2.4 --- |
|||
try: |
|||
from __builtin__ import sorted |
|||
except ImportError: |
|||
def sorted(data): |
|||
sorted_data = copy.copy(data) |
|||
sorted_data.sort() |
|||
return sorted_data |
|||
|
|||
__all__ = ("True", "False", |
|||
"any", "all", "has_yield", "has_slice", |
|||
"isinstance", "reversed", "sorted") |
|||
|
@ -0,0 +1,29 @@ |
|||
""" |
|||
Configuration of Hachoir |
|||
""" |
|||
|
|||
import os |
|||
|
|||
# UI: display options |
|||
max_string_length = 40 # Max. length in characters of GenericString.display |
|||
max_byte_length = 14 # Max. length in bytes of RawBytes.display |
|||
max_bit_length = 256 # Max. length in bits of RawBits.display |
|||
unicode_stdout = True # Replace stdout and stderr with Unicode compatible objects |
|||
# Disable it for readline or ipython |
|||
|
|||
# Global options |
|||
debug = False # Display a lot of information useful for debugging |
|||
verbose = False # Display more information |
|||
quiet = False # Don't display warnings |
|||
|
|||
# Use internationalization and localization (gettext)? |
|||
if os.name == "nt": |
|||
# TODO: Remove this hack and make i18n works on Windows :-) |
|||
use_i18n = False |
|||
else: |
|||
use_i18n = True |
|||
|
|||
# Parser global options |
|||
autofix = True # Enable Autofix? see hachoir_core.field.GenericFieldSet |
|||
check_padding_pattern = True # Check padding fields pattern? |
|||
|
@ -0,0 +1,183 @@ |
|||
""" |
|||
Dictionary classes which preserve the insertion order of values. |
|||
""" |
|||
|
|||
from hachoir_core.error import HachoirError |
|||
from hachoir_core.i18n import _ |
|||
|
|||
class UniqKeyError(HachoirError): |
|||
""" |
|||
Error raised when a value is set while the key already exists in a |
|||
dictionary. |
|||
""" |
|||
pass |
|||
|
|||
class Dict(object): |
|||
""" |
|||
This class works like the classic Python dict() but has an important method: |
|||
__iter__(), which allows iterating over the dictionary _values_ (and not |
|||
the keys, as Python's dict does). |
|||
""" |
|||
def __init__(self, values=None): |
|||
self._index = {} # key => index |
|||
self._key_list = [] # index => key |
|||
self._value_list = [] # index => value |
|||
if values: |
|||
for key, value in values: |
|||
self.append(key,value) |
|||
|
|||
def _getValues(self): |
|||
return self._value_list |
|||
values = property(_getValues) |
|||
|
|||
def index(self, key): |
|||
""" |
|||
Search a value by its key and return its index. |
|||
Returns None if the key doesn't exist. |
|||
|
|||
>>> d=Dict( (("two", "deux"), ("one", "un")) ) |
|||
>>> d.index("two") |
|||
0 |
|||
>>> d.index("one") |
|||
1 |
|||
>>> d.index("three") is None |
|||
True |
|||
""" |
|||
return self._index.get(key) |
|||
|
|||
def __getitem__(self, key): |
|||
""" |
|||
Get item with specified key. |
|||
To get a value by its index, use mydict.values[index]. |
|||
|
|||
>>> d=Dict( (("two", "deux"), ("one", "un")) ) |
|||
>>> d["one"] |
|||
'un' |
|||
""" |
|||
return self._value_list[self._index[key]] |
|||
|
|||
def __setitem__(self, key, value): |
|||
self._value_list[self._index[key]] = value |
|||
|
|||
def append(self, key, value): |
|||
""" |
|||
Append new value |
|||
""" |
|||
if key in self._index: |
|||
raise UniqKeyError(_("Key '%s' already exists") % key) |
|||
self._index[key] = len(self._value_list) |
|||
self._key_list.append(key) |
|||
self._value_list.append(value) |
|||
|
|||
def __len__(self): |
|||
return len(self._value_list) |
|||
|
|||
def __contains__(self, key): |
|||
return key in self._index |
|||
|
|||
def __iter__(self): |
|||
return iter(self._value_list) |
|||
|
|||
def iteritems(self): |
|||
""" |
|||
Create a generator to iterate on: (key, value). |
|||
|
|||
>>> d=Dict( (("two", "deux"), ("one", "un")) ) |
|||
>>> for key, value in d.iteritems(): |
|||
... print "%r: %r" % (key, value) |
|||
... |
|||
'two': 'deux' |
|||
'one': 'un' |
|||
""" |
|||
for index in xrange(len(self)): |
|||
yield (self._key_list[index], self._value_list[index]) |
|||
|
|||
def itervalues(self): |
|||
""" |
|||
Create an iterator on values |
|||
""" |
|||
return iter(self._value_list) |
|||
|
|||
def iterkeys(self): |
|||
""" |
|||
Create an iterator on keys |
|||
""" |
|||
return iter(self._key_list) |
|||
|
|||
def replace(self, oldkey, newkey, new_value): |
|||
""" |
|||
Replace an existing value with another one |
|||
|
|||
>>> d=Dict( (("two", "deux"), ("one", "un")) ) |
|||
>>> d.replace("one", "three", 3) |
|||
>>> d |
|||
{'two': 'deux', 'three': 3} |
|||
|
|||
You can also use the classic form: |
|||
|
|||
>>> d['three'] = 4 |
|||
>>> d |
|||
{'two': 'deux', 'three': 4} |
|||
""" |
|||
index = self._index[oldkey] |
|||
self._value_list[index] = new_value |
|||
if oldkey != newkey: |
|||
del self._index[oldkey] |
|||
self._index[newkey] = index |
|||
self._key_list[index] = newkey |
|||
|
|||
def __delitem__(self, index): |
|||
""" |
|||
Delete item at position index. May raise IndexError. |
|||
|
|||
>>> d=Dict( ((6, 'six'), (9, 'neuf'), (4, 'quatre')) ) |
|||
>>> del d[1] |
|||
>>> d |
|||
{6: 'six', 4: 'quatre'} |
|||
""" |
|||
if index < 0: |
|||
index += len(self._value_list) |
|||
if not (0 <= index < len(self._value_list)): |
|||
raise IndexError(_("list assignment index out of range (%s/%s)") |
|||
% (index, len(self._value_list))) |
|||
del self._value_list[index] |
|||
del self._key_list[index] |
|||
|
|||
# First loop which may alter self._index |
|||
for key, item_index in self._index.iteritems(): |
|||
if item_index == index: |
|||
del self._index[key] |
|||
break |
|||
|
|||
# Second loop update indexes |
|||
for key, item_index in self._index.iteritems(): |
|||
if index < item_index: |
|||
self._index[key] -= 1 |
|||
|
|||
def insert(self, index, key, value): |
|||
""" |
|||
Insert an item at specified position index. |
|||
|
|||
>>> d=Dict( ((6, 'six'), (9, 'neuf'), (4, 'quatre')) ) |
|||
>>> d.insert(1, '40', 'quarante') |
|||
>>> d |
|||
{6: 'six', '40': 'quarante', 9: 'neuf', 4: 'quatre'} |
|||
""" |
|||
if key in self: |
|||
raise UniqKeyError(_("Insert error: key '%s' already exists") % key) |
|||
_index = index |
|||
if index < 0: |
|||
index += len(self._value_list) |
|||
if not(0 <= index <= len(self._value_list)): |
|||
raise IndexError(_("Insert error: index '%s' is invalid") % _index) |
|||
for item_key, item_index in self._index.iteritems(): |
|||
if item_index >= index: |
|||
self._index[item_key] += 1 |
|||
self._index[key] = index |
|||
self._key_list.insert(index, key) |
|||
self._value_list.insert(index, value) |
|||
|
|||
def __repr__(self): |
|||
items = ( "%r: %r" % (key, value) for key, value in self.iteritems() ) |
|||
return "{%s}" % ", ".join(items) |
|||
|
@ -0,0 +1,15 @@ |
|||
""" |
|||
Constant values about endian. |
|||
""" |
|||
|
|||
from hachoir_core.i18n import _ |
|||
|
|||
BIG_ENDIAN = "ABCD" |
|||
LITTLE_ENDIAN = "DCBA" |
|||
NETWORK_ENDIAN = BIG_ENDIAN |
|||
|
|||
endian_name = { |
|||
BIG_ENDIAN: _("Big endian"), |
|||
LITTLE_ENDIAN: _("Little endian"), |
|||
} |
|||
|
@ -0,0 +1,45 @@ |
|||
""" |
|||
Functions to display an error (error, warning or information) message. |
|||
""" |
|||
|
|||
from hachoir_core.log import log |
|||
from hachoir_core.tools import makePrintable |
|||
import sys, traceback |
|||
|
|||
def getBacktrace(empty="Empty backtrace."): |
|||
""" |
|||
Try to get backtrace as string. |
|||
Returns "Error while trying to get backtrace" on failure. |
|||
""" |
|||
try: |
|||
info = sys.exc_info() |
|||
trace = traceback.format_exception(*info) |
|||
sys.exc_clear() |
|||
if trace[0] != "None\n": |
|||
return "".join(trace) |
|||
except: |
|||
# No i18n here (imagine if i18n function calls error...) |
|||
return "Error while trying to get backtrace" |
|||
return empty |
|||
|
|||
class HachoirError(Exception): |
|||
""" |
|||
Parent of all errors in Hachoir library |
|||
""" |
|||
def __init__(self, message): |
|||
message_bytes = makePrintable(message, "ASCII") |
|||
Exception.__init__(self, message_bytes) |
|||
self.text = message |
|||
|
|||
def __unicode__(self): |
|||
return self.text |
|||
|
|||
# Error classes which may be raised by Hachoir core |
|||
# FIXME: Add EnvironmentError (IOError or OSError) and AssertionError? |
|||
# FIXME: Remove ArithmeticError and RuntimeError? |
|||
HACHOIR_ERRORS = (HachoirError, LookupError, NameError, AttributeError, |
|||
TypeError, ValueError, ArithmeticError, RuntimeError) |
|||
|
|||
info = log.info |
|||
warning = log.warning |
|||
error = log.error |
@ -0,0 +1,26 @@ |
|||
class EventHandler(object): |
|||
""" |
|||
Class to connect events to event handlers. |
|||
""" |
|||
|
|||
def __init__(self): |
|||
self.handlers = {} |
|||
|
|||
def connect(self, event_name, handler): |
|||
""" |
|||
Connect an event handler to an event. Append it to handlers list. |
|||
""" |
|||
try: |
|||
self.handlers[event_name].append(handler) |
|||
except KeyError: |
|||
self.handlers[event_name] = [handler] |
|||
|
|||
def raiseEvent(self, event_name, *args): |
|||
""" |
|||
Raise an event: call each handler connected to this event_name. |
|||
""" |
|||
if event_name not in self.handlers: |
|||
return |
|||
for handler in self.handlers[event_name]: |
|||
handler(*args) |
|||
|
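# A minimal usage sketch (assumed caller code; 'some_field' is hypothetical): |
# |
#     events = EventHandler() |
#     events.connect("field-value-changed", lambda field: None) |
#     events.raiseEvent("field-value-changed", some_field) |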
@ -0,0 +1,59 @@ |
|||
# Field classes |
|||
from hachoir_core.field.field import Field, FieldError, MissingField, joinPath |
|||
from hachoir_core.field.bit_field import Bit, Bits, RawBits |
|||
from hachoir_core.field.byte_field import Bytes, RawBytes |
|||
from hachoir_core.field.sub_file import SubFile, CompressedField |
|||
from hachoir_core.field.character import Character |
|||
from hachoir_core.field.integer import ( |
|||
Int8, Int16, Int24, Int32, Int64, |
|||
UInt8, UInt16, UInt24, UInt32, UInt64, |
|||
GenericInteger) |
|||
from hachoir_core.field.enum import Enum |
|||
from hachoir_core.field.string_field import (GenericString, |
|||
String, CString, UnixLine, |
|||
PascalString8, PascalString16, PascalString32) |
|||
from hachoir_core.field.padding import (PaddingBits, PaddingBytes, |
|||
NullBits, NullBytes) |
|||
|
|||
# Functions |
|||
from hachoir_core.field.helper import (isString, isInteger, |
|||
createPaddingField, createNullField, createRawField, |
|||
writeIntoFile, createOrphanField) |
|||
|
|||
# FieldSet classes |
|||
from hachoir_core.field.fake_array import FakeArray |
|||
from hachoir_core.field.basic_field_set import (BasicFieldSet, |
|||
ParserError, MatchError) |
|||
from hachoir_core.field.generic_field_set import GenericFieldSet |
|||
from hachoir_core.field.seekable_field_set import SeekableFieldSet, RootSeekableFieldSet |
|||
from hachoir_core.field.field_set import FieldSet |
|||
from hachoir_core.field.static_field_set import StaticFieldSet |
|||
from hachoir_core.field.parser import Parser |
|||
from hachoir_core.field.vector import GenericVector, UserVector |
|||
|
|||
# Complex types |
|||
from hachoir_core.field.float import Float32, Float64, Float80 |
|||
from hachoir_core.field.timestamp import (GenericTimestamp, |
|||
TimestampUnix32, TimestampUnix64, TimestampMac32, TimestampUUID60, TimestampWin64, |
|||
DateTimeMSDOS32, TimeDateMSDOS32, TimedeltaWin64) |
|||
|
|||
# Special Field classes |
|||
from hachoir_core.field.link import Link, Fragment |
|||
|
|||
available_types = ( |
|||
Bit, Bits, RawBits, |
|||
Bytes, RawBytes, |
|||
SubFile, |
|||
Character, |
|||
Int8, Int16, Int24, Int32, Int64, |
|||
UInt8, UInt16, UInt24, UInt32, UInt64, |
|||
String, CString, UnixLine, |
|||
PascalString8, PascalString16, PascalString32, |
|||
Float32, Float64, |
|||
PaddingBits, PaddingBytes, |
|||
NullBits, NullBytes, |
|||
TimestampUnix32, TimestampMac32, TimestampWin64, |
|||
DateTimeMSDOS32, TimeDateMSDOS32, |
|||
# GenericInteger, GenericString, |
|||
) |
|||
|
@ -0,0 +1,147 @@ |
|||
from hachoir_core.field import Field, FieldError |
|||
from hachoir_core.stream import InputStream |
|||
from hachoir_core.endian import BIG_ENDIAN, LITTLE_ENDIAN |
|||
from hachoir_core.event_handler import EventHandler |
|||
|
|||
class ParserError(FieldError): |
|||
""" |
|||
Error raised by a field set. |
|||
|
|||
@see: L{FieldError} |
|||
""" |
|||
pass |
|||
|
|||
class MatchError(FieldError): |
|||
""" |
|||
Error raised by a field set when the stream content doesn't |
|||
match the file format. |
|||
|
|||
@see: L{FieldError} |
|||
""" |
|||
pass |
|||
|
|||
class BasicFieldSet(Field): |
|||
_event_handler = None |
|||
is_field_set = True |
|||
endian = None |
|||
|
|||
def __init__(self, parent, name, stream, description, size): |
|||
# Sanity checks (preconditions) |
|||
assert not parent or issubclass(parent.__class__, BasicFieldSet) |
|||
assert issubclass(stream.__class__, InputStream) |
|||
|
|||
# Set field set size |
|||
if size is None and self.static_size: |
|||
assert isinstance(self.static_size, (int, long)) |
|||
size = self.static_size |
|||
|
|||
# Set Field attributes |
|||
self._parent = parent |
|||
self._name = name |
|||
self._size = size |
|||
self._description = description |
|||
self.stream = stream |
|||
self._field_array_count = {} |
|||
|
|||
# Set endian |
|||
if not self.endian: |
|||
assert parent and parent.endian |
|||
self.endian = parent.endian |
|||
|
|||
if parent: |
|||
# This field set has a parent: it sits below the root |
|||
self._address = parent.nextFieldAddress() |
|||
self.root = parent.root |
|||
assert id(self.stream) == id(parent.stream) |
|||
else: |
|||
# This field set is the root |
|||
self._address = 0 |
|||
self.root = self |
|||
self._global_event_handler = None |
|||
|
|||
# Sanity checks (post-conditions) |
|||
assert self.endian in (BIG_ENDIAN, LITTLE_ENDIAN) |
|||
if (self._size is not None) and (self._size <= 0): |
|||
raise ParserError("Invalid parser '%s' size: %s" % (self.path, self._size)) |
|||
|
|||
def reset(self): |
|||
self._field_array_count = {} |
|||
|
|||
def createValue(self): |
|||
return None |
|||
|
|||
def connectEvent(self, event_name, handler, local=True): |
|||
assert event_name in ( |
|||
# Callback prototype: def f(field) |
|||
# Called when new value is already set |
|||
"field-value-changed", |
|||
|
|||
# Callback prototype: def f(field) |
|||
# Called when field size is already set |
|||
"field-resized", |
|||
|
|||
# A new field has been inserted in the field set |
|||
# Callback prototype: def f(index, new_field) |
|||
"field-inserted", |
|||
|
|||
# Callback prototype: def f(old_field, new_field) |
|||
# Called when new field is already in field set |
|||
"field-replaced", |
|||
|
|||
# Callback prototype: def f(field, new_value) |
|||
# Called to ask to set new value |
|||
"set-field-value" |
|||
), "Event name %r is invalid" % event_name |
|||
if local: |
|||
if self._event_handler is None: |
|||
self._event_handler = EventHandler() |
|||
self._event_handler.connect(event_name, handler) |
|||
else: |
|||
if self.root._global_event_handler is None: |
|||
self.root._global_event_handler = EventHandler() |
|||
self.root._global_event_handler.connect(event_name, handler) |
|||
|
|||
def raiseEvent(self, event_name, *args): |
|||
# Transfer event to local listeners |
|||
if self._event_handler is not None: |
|||
self._event_handler.raiseEvent(event_name, *args) |
|||
|
|||
# Transfer event to global listeners |
|||
if self.root._global_event_handler is not None: |
|||
self.root._global_event_handler.raiseEvent(event_name, *args) |
|||
|
|||
def setUniqueFieldName(self, field): |
|||
key = field._name[:-2] |
|||
try: |
|||
self._field_array_count[key] += 1 |
|||
except KeyError: |
|||
self._field_array_count[key] = 0 |
|||
field._name = key + "[%u]" % self._field_array_count[key] |
|||
|
|||
def readFirstFields(self, number): |
|||
""" |
|||
Read the first "number" fields if they are not read yet. |
|||
|
|||
Returns the number of newly added fields. |
|||
""" |
|||
number = number - self.current_length |
|||
if 0 < number: |
|||
return self.readMoreFields(number) |
|||
else: |
|||
return 0 |
|||
|
|||
def createFields(self): |
|||
raise NotImplementedError() |
|||
def __iter__(self): |
|||
raise NotImplementedError() |
|||
def __len__(self): |
|||
raise NotImplementedError() |
|||
def getField(self, key, const=True): |
|||
raise NotImplementedError() |
|||
def nextFieldAddress(self): |
|||
raise NotImplementedError() |
|||
def getFieldIndex(self, field): |
|||
raise NotImplementedError() |
|||
def readMoreFields(self, number): |
|||
raise NotImplementedError() |
|||
|
@ -0,0 +1,68 @@ |
|||
""" |
|||
Bit sized classes: |
|||
- Bit: Single bit, value is False or True ; |
|||
- Bits: Integer with a size in bits ; |
|||
- RawBits: unknown content with a size in bits. |
|||
""" |
|||
|
|||
from hachoir_core.field import Field |
|||
from hachoir_core.i18n import _ |
|||
from hachoir_core import config |
|||
|
|||
class RawBits(Field): |
|||
""" |
|||
Unknown content with a size in bits. |
|||
""" |
|||
static_size = staticmethod(lambda *args, **kw: args[1]) |
|||
|
|||
def __init__(self, parent, name, size, description=None): |
|||
""" |
|||
Constructor: see L{Field.__init__} for parameter description |
|||
""" |
|||
Field.__init__(self, parent, name, size, description) |
|||
|
|||
def hasValue(self): |
|||
return True |
|||
|
|||
def createValue(self): |
|||
return self._parent.stream.readBits( |
|||
self.absolute_address, self._size, self._parent.endian) |
|||
|
|||
def createDisplay(self): |
|||
if self._size < config.max_bit_length: |
|||
return unicode(self.value) |
|||
else: |
|||
return _("<%s size=%u>" % |
|||
(self.__class__.__name__, self._size)) |
|||
createRawDisplay = createDisplay |
|||
|
|||
class Bits(RawBits): |
|||
""" |
|||
Positive integer with a size in bits |
|||
|
|||
@see: L{Bit} |
|||
@see: L{RawBits} |
|||
""" |
|||
pass |
|||
|
|||
class Bit(RawBits): |
|||
""" |
|||
Single bit: value can be False or True, and size is exactly one bit. |
|||
|
|||
@see: L{Bits} |
|||
""" |
|||
static_size = 1 |
|||
|
|||
def __init__(self, parent, name, description=None): |
|||
""" |
|||
Constructor: see L{Field.__init__} for parameter description |
|||
""" |
|||
RawBits.__init__(self, parent, name, 1, description=description) |
|||
|
|||
def createValue(self): |
|||
return 1 == self._parent.stream.readBits( |
|||
self.absolute_address, 1, self._parent.endian) |
|||
|
|||
def createRawDisplay(self): |
|||
return unicode(int(self.value)) |
|||
|
@ -0,0 +1,73 @@ |
|||
""" |
|||
Very basic field: raw content with a size in bytes. Use this class for |
|||
unknown content. |
|||
""" |
|||
|
|||
from hachoir_core.field import Field, FieldError |
|||
from hachoir_core.tools import makePrintable |
|||
from hachoir_core.bits import str2hex |
|||
from hachoir_core import config |
|||
|
|||
MAX_LENGTH = (2**64) |
|||
|
|||
class RawBytes(Field): |
|||
""" |
|||
Byte vector of unknown content |
|||
|
|||
@see: L{Bytes} |
|||
""" |
|||
static_size = staticmethod(lambda *args, **kw: args[1]*8) |
|||
|
|||
def __init__(self, parent, name, length, description="Raw data"): |
|||
assert issubclass(parent.__class__, Field) |
|||
if not(0 < length <= MAX_LENGTH): |
|||
raise FieldError("Invalid RawBytes length (%s)!" % length) |
|||
Field.__init__(self, parent, name, length*8, description) |
|||
self._display = None |
|||
|
|||
def _createDisplay(self, human): |
|||
max_bytes = config.max_byte_length |
|||
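# Field._getValue caches its result by replacing itself with a lambda; |
# if that happened, the value is already available in memory |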
if type(self._getValue) is type(lambda: None): |
|||
display = self.value[:max_bytes] |
|||
else: |
|||
if self._display is None: |
|||
address = self.absolute_address |
|||
length = min(self._size / 8, max_bytes) |
|||
self._display = self._parent.stream.readBytes(address, length) |
|||
display = self._display |
|||
truncated = (8 * len(display) < self._size) |
|||
if human: |
|||
if truncated: |
|||
display += "(...)" |
|||
return makePrintable(display, "latin-1", quote='"', to_unicode=True) |
|||
else: |
|||
display = str2hex(display, format=r"\x%02x") |
|||
if truncated: |
|||
return '"%s(...)"' % display |
|||
else: |
|||
return '"%s"' % display |
|||
|
|||
def createDisplay(self): |
|||
return self._createDisplay(True) |
|||
|
|||
def createRawDisplay(self): |
|||
return self._createDisplay(False) |
|||
|
|||
def hasValue(self): |
|||
return True |
|||
|
|||
def createValue(self): |
|||
assert (self._size % 8) == 0 |
|||
if self._display: |
|||
self._display = None |
|||
return self._parent.stream.readBytes( |
|||
self.absolute_address, self._size / 8) |
|||
|
|||
class Bytes(RawBytes): |
|||
""" |
|||
Byte vector: can be used for magic number or GUID/UUID for example. |
|||
|
|||
@see: L{RawBytes} |
|||
""" |
|||
pass |
|||
|
@ -0,0 +1,27 @@ |
|||
""" |
|||
Character field class: an 8-bit character |
|||
""" |
|||
|
|||
from hachoir_core.field import Bits |
|||
from hachoir_core.endian import BIG_ENDIAN |
|||
from hachoir_core.tools import makePrintable |
|||
|
|||
class Character(Bits): |
|||
""" |
|||
An 8-bit character, using the ASCII charset for the display attribute. |
|||
""" |
|||
static_size = 8 |
|||
|
|||
def __init__(self, parent, name, description=None): |
|||
Bits.__init__(self, parent, name, 8, description=description) |
|||
|
|||
def createValue(self): |
|||
return chr(self._parent.stream.readBits( |
|||
self.absolute_address, 8, BIG_ENDIAN)) |
|||
|
|||
def createRawDisplay(self): |
|||
return unicode(Bits.createValue(self)) |
|||
|
|||
def createDisplay(self): |
|||
return makePrintable(self.value, "ASCII", quote="'", to_unicode=True) |
|||
|
@ -0,0 +1,26 @@ |
|||
def Enum(field, enum, key_func=None): |
|||
""" |
|||
Enum is an adapter to another field: it will just change its display |
|||
attribute. It uses a dictionary to associate a value to another. |
|||
|
|||
key_func is an optional function with prototype "def func(key)->key" |
|||
which is called to transform key. |
|||
""" |
|||
display = field.createDisplay |
|||
if key_func: |
|||
def createDisplay(): |
|||
try: |
|||
key = key_func(field.value) |
|||
return enum[key] |
|||
except LookupError: |
|||
return display() |
|||
else: |
|||
def createDisplay(): |
|||
try: |
|||
return enum[field.value] |
|||
except LookupError: |
|||
return display() |
|||
field.createDisplay = createDisplay |
|||
field.getEnum = lambda: enum |
|||
return field |
|||
|
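# A minimal usage sketch (assumed field context; names are illustrative): |
# |
#     field = Enum(UInt8(parent, "codec"), {0: u"none", 1: u"mp3"}) |
#     field.display  # -> u"mp3" when field.value == 1, else the raw display |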
@ -0,0 +1,81 @@ |
|||
import itertools |
|||
from hachoir_core.field import MissingField |
|||
|
|||
class FakeArray: |
|||
""" |
|||
Simulate an array for GenericFieldSet.array(): fieldset.array("item")[0] is |
|||
equivalent to fieldset["item[0]"]. |
|||
|
|||
It's possible to iterate over the items using:: |
|||
|
|||
for element in fieldset.array("item"): |
|||
... |
|||
|
|||
And to get array size using len(fieldset.array("item")). |
|||
""" |
|||
def __init__(self, fieldset, name): |
|||
pos = name.rfind("/") |
|||
if pos != -1: |
|||
self.fieldset = fieldset[name[:pos]] |
|||
self.name = name[pos+1:] |
|||
else: |
|||
self.fieldset = fieldset |
|||
self.name = name |
|||
self._format = "%s[%%u]" % self.name |
|||
self._cache = {} |
|||
self._known_size = False |
|||
self._max_index = -1 |
|||
|
|||
def __nonzero__(self): |
|||
"Is the array empty or not?" |
|||
if self._cache: |
|||
return True |
|||
else: |
|||
return (0 in self) |
|||
|
|||
def __len__(self): |
|||
"Number of fields in the array" |
|||
total = self._max_index+1 |
|||
if not self._known_size: |
|||
for index in itertools.count(total): |
|||
try: |
|||
field = self[index] |
|||
total += 1 |
|||
except MissingField: |
|||
break |
|||
return total |
|||
|
|||
def __contains__(self, index): |
|||
try: |
|||
field = self[index] |
|||
return True |
|||
except MissingField: |
|||
return False |
|||
|
|||
def __getitem__(self, index): |
|||
""" |
|||
Get a field of the array. Returns a field, or raise MissingField |
|||
exception if the field doesn't exist. |
|||
""" |
|||
try: |
|||
value = self._cache[index] |
|||
except KeyError: |
|||
try: |
|||
value = self.fieldset[self._format % index] |
|||
except MissingField: |
|||
self._known_size = True |
|||
raise |
|||
self._cache[index] = value |
|||
self._max_index = max(index, self._max_index) |
|||
return value |
|||
|
|||
def __iter__(self): |
|||
""" |
|||
Iterate in the fields in their index order: field[0], field[1], ... |
|||
""" |
|||
for index in itertools.count(0): |
|||
try: |
|||
yield self[index] |
|||
except MissingField: |
|||
raise StopIteration() |
|||
|
@ -0,0 +1,262 @@ |
|||
""" |
|||
Parent of all (field) classes in Hachoir: Field. |
|||
""" |
|||
|
|||
from hachoir_core.compatibility import reversed |
|||
from hachoir_core.stream import InputFieldStream |
|||
from hachoir_core.error import HachoirError, HACHOIR_ERRORS |
|||
from hachoir_core.log import Logger |
|||
from hachoir_core.i18n import _ |
|||
from hachoir_core.tools import makePrintable |
|||
from weakref import ref as weakref_ref |
|||
|
|||
class FieldError(HachoirError): |
|||
""" |
|||
Error raised by a L{Field}. |
|||
|
|||
@see: L{HachoirError} |
|||
""" |
|||
pass |
|||
|
|||
def joinPath(path, name): |
|||
if path != "/": |
|||
return "/".join((path, name)) |
|||
else: |
|||
return "/%s" % name |
|||
|
|||
class MissingField(KeyError, FieldError): |
|||
def __init__(self, field, key): |
|||
KeyError.__init__(self) |
|||
self.field = field |
|||
self.key = key |
|||
|
|||
def __str__(self): |
|||
return 'Can\'t get field "%s" from %s' % (self.key, self.field.path) |
|||
|
|||
def __unicode__(self): |
|||
return u'Can\'t get field "%s" from %s' % (self.key, self.field.path) |
|||
|
|||
class Field(Logger): |
|||
# static size can have three different values: None (no static size), an |
|||
# integer (number of bits), or a function which returns an integer. |
|||
# |
|||
# This function receives exactly the same arguments as the constructor |
|||
# except the first one. Example of such a function: |
|||
# static_size = staticmethod(lambda *args, **kw: args[1]) |
|||
static_size = None |
|||
|
|||
# Indicate if this field contains other fields (is a field set) or not |
|||
is_field_set = False |
|||
|
|||
def __init__(self, parent, name, size=None, description=None): |
|||
""" |
|||
Set default class attributes; the field address is taken from the |
|||
parent field set. |
|||
|
|||
@param parent: Parent field of this field |
|||
@type parent: L{Field}|None |
|||
@param name: Name of the field, have to be unique in parent. If it ends |
|||
with "[]", end will be replaced with "[new_id]" (eg. "raw[]" |
|||
becomes "raw[0]", next will be "raw[1]", and then "raw[2]", etc.) |
|||
@type name: str |
|||
@param size: Size of the field in bit (can be None, so it |
|||
will be computed later) |
|||
@type size: int|None |
|||
@param address: Address in bit relative to the parent absolute address |
|||
@type address: int|None |
|||
@param description: Optional string description |
|||
@type description: str|None |
|||
""" |
|||
assert issubclass(parent.__class__, Field) |
|||
assert (size is None) or (0 <= size) |
|||
self._parent = parent |
|||
if not name: |
|||
raise ValueError("empty field name") |
|||
self._name = name |
|||
self._address = parent.nextFieldAddress() |
|||
self._size = size |
|||
self._description = description |
|||
|
|||
def _logger(self): |
|||
return self.path |
|||
|
|||
def createDescription(self): |
|||
return "" |
|||
def _getDescription(self): |
|||
if self._description is None: |
|||
try: |
|||
self._description = self.createDescription() |
|||
if isinstance(self._description, str): |
|||
self._description = makePrintable( |
|||
self._description, "ISO-8859-1", to_unicode=True) |
|||
except HACHOIR_ERRORS, err: |
|||
self.error("Error getting description: " + unicode(err)) |
|||
self._description = "" |
|||
return self._description |
|||
description = property(_getDescription, |
|||
doc="Description of the field (string)") |
|||
|
|||
def __str__(self): |
|||
return self.display |
|||
def __unicode__(self): |
|||
return self.display |
|||
def __repr__(self): |
|||
return "<%s path=%r, address=%s, size=%s>" % ( |
|||
self.__class__.__name__, self.path, self._address, self._size) |
|||
|
|||
def hasValue(self): |
|||
return self._getValue() is not None |
|||
def createValue(self): |
|||
raise NotImplementedError() |
|||
def _getValue(self): |
|||
try: |
|||
value = self.createValue() |
|||
except HACHOIR_ERRORS, err: |
|||
self.error(_("Unable to create value: %s") % unicode(err)) |
|||
value = None |
|||
self._getValue = lambda: value |
|||
return value |
|||
value = property(lambda self: self._getValue(), doc="Value of field") |
|||
|
|||
def _getParent(self): |
|||
return self._parent |
|||
parent = property(_getParent, doc="Parent of this field") |
|||
|
|||
def createDisplay(self): |
|||
return unicode(self.value) |
|||
def _getDisplay(self): |
|||
if not hasattr(self, "_Field__display"): |
|||
try: |
|||
self.__display = self.createDisplay() |
|||
except HACHOIR_ERRORS, err: |
|||
self.error("Unable to create display: %s" % err) |
|||
self.__display = u"" |
|||
return self.__display |
|||
display = property(lambda self: self._getDisplay(), |
|||
doc="Short (unicode) string which represents field content") |
|||
|
|||
def createRawDisplay(self): |
|||
value = self.value |
|||
if isinstance(value, str): |
|||
return makePrintable(value, "ASCII", to_unicode=True) |
|||
else: |
|||
return unicode(value) |
|||
def _getRawDisplay(self): |
|||
if not hasattr(self, "_Field__raw_display"): |
|||
try: |
|||
self.__raw_display = self.createRawDisplay() |
|||
except HACHOIR_ERRORS, err: |
|||
self.error("Unable to create raw display: %s" % err) |
|||
self.__raw_display = u"" |
|||
return self.__raw_display |
|||
raw_display = property(lambda self: self._getRawDisplay(), |
|||
doc="(Unicode) string which represents raw field content") |
|||
|
|||
def _getName(self): |
|||
return self._name |
|||
name = property(_getName, |
|||
doc="Field name (unique in its parent field set list)") |
|||
|
|||
def _getIndex(self): |
|||
if not self._parent: |
|||
return None |
|||
return self._parent.getFieldIndex(self) |
|||
index = property(_getIndex) |
|||
|
|||
def _getPath(self): |
|||
if not self._parent: |
|||
return '/' |
|||
names = [] |
|||
field = self |
|||
while field is not None: |
|||
names.append(field._name) |
|||
field = field._parent |
|||
names[-1] = '' |
|||
return '/'.join(reversed(names)) |
|||
path = property(_getPath, |
|||
doc="Full path of the field starting at root field") |
|||
|
|||
def _getAddress(self): |
|||
return self._address |
|||
address = property(_getAddress, |
|||
doc="Relative address in bit to parent address") |
|||
|
|||
def _getAbsoluteAddress(self): |
|||
address = self._address |
|||
current = self._parent |
|||
while current: |
|||
address += current._address |
|||
current = current._parent |
|||
return address |
|||
absolute_address = property(_getAbsoluteAddress, |
|||
doc="Absolute address (from stream beginning) in bit") |
|||
|
|||
def _getSize(self): |
|||
return self._size |
|||
size = property(_getSize, doc="Content size in bit") |
|||
|
|||
def _getField(self, name, const): |
|||
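# The base implementation only resolves dot paths ("..", "...", ...); |
# real child names are resolved by field set subclasses |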
if name.strip("."): |
|||
return None |
|||
field = self |
|||
for index in xrange(1, len(name)): |
|||
field = field._parent |
|||
if field is None: |
|||
break |
|||
return field |
|||
|
|||
def getField(self, key, const=True): |
|||
if key: |
|||
if key[0] == "/": |
|||
if self._parent: |
|||
current = self._parent.root |
|||
else: |
|||
current = self |
|||
if len(key) == 1: |
|||
return current |
|||
key = key[1:] |
|||
else: |
|||
current = self |
|||
for part in key.split("/"): |
|||
field = current._getField(part, const) |
|||
if field is None: |
|||
raise MissingField(current, part) |
|||
current = field |
|||
return current |
|||
raise KeyError("Key must not be an empty string!") |
|||
|
|||
def __getitem__(self, key): |
|||
return self.getField(key, False) |
|||
|
|||
def __contains__(self, key): |
|||
try: |
|||
return self.getField(key, False) is not None |
|||
except FieldError: |
|||
return False |
|||
|
|||
def _createInputStream(self, **args): |
|||
assert self._parent |
|||
return InputFieldStream(self, **args) |
|||
def getSubIStream(self): |
|||
if hasattr(self, "_sub_istream"): |
|||
stream = self._sub_istream() |
|||
else: |
|||
stream = None |
|||
if stream is None: |
|||
stream = self._createInputStream() |
|||
self._sub_istream = weakref_ref(stream) |
|||
return stream |
|||
def setSubIStream(self, createInputStream): |
|||
cis = self._createInputStream |
|||
self._createInputStream = lambda **args: createInputStream(cis, **args) |
|||
|
|||
def __nonzero__(self): |
|||
""" |
|||
Method called by code like "if field: (...)". |
|||
Always returns True |
|||
""" |
|||
return True |
|||
|
|||
def getFieldType(self): |
|||
return self.__class__.__name__ |
|||
|
@ -0,0 +1,7 @@ |
|||
from hachoir_core.field import BasicFieldSet, GenericFieldSet |
|||
|
|||
class FieldSet(GenericFieldSet): |
|||
def __init__(self, parent, name, *args, **kw): |
|||
assert issubclass(parent.__class__, BasicFieldSet) |
|||
GenericFieldSet.__init__(self, parent, name, parent.stream, *args, **kw) |
|||
|
@ -0,0 +1,99 @@ |
|||
from hachoir_core.field import Bit, Bits, FieldSet |
|||
from hachoir_core.endian import BIG_ENDIAN, LITTLE_ENDIAN |
|||
import struct |
|||
|
|||
# Make sure that we use right struct types |
|||
assert struct.calcsize("f") == 4 |
|||
assert struct.calcsize("d") == 8 |
|||
assert struct.unpack("<d", "\x1f\x85\xebQ\xb8\x1e\t@")[0] == 3.14 |
|||
assert struct.unpack(">d", "\xc0\0\0\0\0\0\0\0")[0] == -2.0 |
|||
|
|||
class FloatMantissa(Bits): |
|||
def createValue(self): |
|||
value = Bits.createValue(self) |
|||
return 1 + float(value) / (2 ** self.size) |
|||
|
|||
def createRawDisplay(self): |
|||
return unicode(Bits.createValue(self)) |
|||
|
|||
class FloatExponent(Bits): |
|||
def __init__(self, parent, name, size): |
|||
Bits.__init__(self, parent, name, size) |
|||
self.bias = 2 ** (size-1) - 1 |
|||
|
|||
def createValue(self): |
|||
return Bits.createValue(self) - self.bias |
|||
|
|||
def createRawDisplay(self): |
|||
return unicode(self.value + self.bias) |
|||
|
|||
def floatFactory(name, format, mantissa_bits, exponent_bits, doc): |
|||
size = 1 + mantissa_bits + exponent_bits |
|||
|
|||
class Float(FieldSet): |
|||
static_size = size |
|||
__doc__ = doc |
|||
|
|||
def __init__(self, parent, name, description=None): |
|||
assert parent.endian in (BIG_ENDIAN, LITTLE_ENDIAN) |
|||
FieldSet.__init__(self, parent, name, description, size) |
|||
if format: |
|||
if self._parent.endian == BIG_ENDIAN: |
|||
self.struct_format = ">"+format |
|||
else: |
|||
self.struct_format = "<"+format |
|||
else: |
|||
self.struct_format = None |
|||
|
|||
def createValue(self): |
|||
""" |
|||
Create float value: use struct.unpack() when it's possible |
|||
(32 and 64-bit float) or compute it with: 
|||
mantissa * (2.0 ** exponent) |
|||
|
|||
This computation may raise an OverflowError. |
|||
""" |
|||
if self.struct_format: |
|||
raw = self._parent.stream.readBytes( |
|||
self.absolute_address, self._size//8) |
|||
try: |
|||
return struct.unpack(self.struct_format, raw)[0] |
|||
except struct.error, err: |
|||
raise ValueError("[%s] conversion error: %s" % |
|||
(self.__class__.__name__, err)) |
|||
else: |
|||
try: |
|||
value = self["mantissa"].value * (2.0 ** float(self["exponent"].value)) |
|||
if self["negative"].value: |
|||
return -(value) |
|||
else: |
|||
return value |
|||
except OverflowError: |
|||
raise ValueError("[%s] floating point overflow" % |
|||
self.__class__.__name__) |
|||
|
|||
def createFields(self): |
|||
yield Bit(self, "negative") |
|||
yield FloatExponent(self, "exponent", exponent_bits) |
|||
if 64 <= mantissa_bits: |
|||
yield Bit(self, "one") |
|||
yield FloatMantissa(self, "mantissa", mantissa_bits-1) |
|||
else: |
|||
yield FloatMantissa(self, "mantissa", mantissa_bits) |
|||
|
|||
cls = Float |
|||
cls.__name__ = name |
|||
return cls |
|||
|
|||
# 32-bit float (standard: IEEE 754/854) |
|||
Float32 = floatFactory("Float32", "f", 23, 8, |
|||
"Floating point number: format IEEE 754 int 32 bit") |
|||
|
|||
# 64-bit float (standard: IEEE 754/854) |
|||
Float64 = floatFactory("Float64", "d", 52, 11, |
|||
"Floating point number: format IEEE 754 in 64 bit") |
|||
|
|||
# 80-bit float (standard: IEEE 754/854) |
|||
Float80 = floatFactory("Float80", None, 64, 15, |
|||
"Floating point number: format IEEE 754 in 80 bit") |
|||
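# A standalone check, not part of the original module, of the fallback
# formula used by Float.createValue() when struct.unpack() cannot help
# (the 80-bit format): value = mantissa * 2.0 ** exponent, with the bias
# and mantissa formulas of FloatExponent and FloatMantissa above.
# Decoding the 32-bit pattern of 0.5 by hand:
import struct

raw = struct.unpack(">I", struct.pack(">f", 0.5))[0]
sign = raw >> 31
raw_exponent = (raw >> 23) & 0xFF
fraction = raw & 0x7FFFFF
bias = 2 ** (8 - 1) - 1                      # 127, as in FloatExponent
mantissa = 1 + float(fraction) / (2 ** 23)   # as in FloatMantissa
value = mantissa * (2.0 ** (raw_exponent - bias))
if sign:
    value = -value
assert value == 0.5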
|
@ -0,0 +1,532 @@ |
|||
from hachoir_core.field import (MissingField, BasicFieldSet, Field, ParserError, |
|||
createRawField, createNullField, createPaddingField, FakeArray) |
|||
from hachoir_core.dict import Dict, UniqKeyError |
|||
from hachoir_core.error import HACHOIR_ERRORS |
|||
from hachoir_core.tools import lowerBound |
|||
import hachoir_core.config as config |
|||
|
|||
class GenericFieldSet(BasicFieldSet): |
|||
""" |
|||
Ordered list of fields. Use operator [] to access fields using their |
|||
name (field names are unique in a field set, but not in the whole |
|||
document). |
|||
|
|||
Class attributes: |
|||
- endian: Byte order (L{BIG_ENDIAN} or L{LITTLE_ENDIAN}). Optional if the 
|||
field set has a parent ; |
|||
- static_size: (optional) Size of FieldSet in bits. This attribute should |
|||
be used in parser of constant size. |
|||
|
|||
Instance attributes/methods: |
|||
- _fields: Ordered dictionary of all fields, may be incomplete 
|||
because it is fed only when a field is requested ; 
|||
- stream: Input stream used to feed fields' value |
|||
- root: The root of all field sets ; |
|||
- __len__(): Number of fields, may need to create field set ; |
|||
- __getitem__(): Get a field by its name or its path. 
|||
|
|||
And attributes inherited from Field class: |
|||
- parent: Parent field (may be None if it's the root) ; |
|||
- name: Field name (unique in parent field set) ; |
|||
- value: The field set ; |
|||
- address: Field address (in bits) relative to parent ; |
|||
- description: A string describing the content (can be None) ; |
|||
- size: Size of field set in bits, may need to create field set. |
|||
|
|||
Event handling: |
|||
- "connectEvent": Connect an handler to an event ; |
|||
- "raiseEvent": Raise an event. |
|||
|
|||
To implement a new field set (see the illustrative sketch at the end of this module), you need to: 
|||
- create a class which inherits from FieldSet ; 
|||
- write createFields() method using lines like: |
|||
yield Class(self, "name", ...) ; |
|||
- and maybe set endian and static_size class attributes. |
|||
""" |
|||
|
|||
_current_size = 0 |
|||
|
|||
def __init__(self, parent, name, stream, description=None, size=None): |
|||
""" |
|||
Constructor |
|||
@param parent: Parent field set, None for root parser |
|||
@param name: Name of the field, which has to be unique in its parent. If it ends 
|||
with "[]", end will be replaced with "[new_id]" (eg. "raw[]" |
|||
becomes "raw[0]", next will be "raw[1]", and then "raw[2]", etc.) |
|||
@type name: str |
|||
@param stream: Input stream from which data are read |
|||
@type stream: L{InputStream} |
|||
@param description: Optional string description |
|||
@type description: str|None |
|||
@param size: Size in bits. If it's None, size will be computed. You |
|||
can also set size with class attribute static_size |
|||
""" |
|||
BasicFieldSet.__init__(self, parent, name, stream, description, size) |
|||
self._fields = Dict() |
|||
self._field_generator = self.createFields() |
|||
self._array_cache = {} |
|||
self.__is_feeding = False |
|||
|
|||
def array(self, key): |
|||
try: |
|||
return self._array_cache[key] |
|||
except KeyError: |
|||
array = FakeArray(self, key) |
|||
self._array_cache[key] = array |
|||
return self._array_cache[key] |
|||
|
|||
def reset(self): |
|||
""" |
|||
Reset a field set: |
|||
* clear fields ; |
|||
* restart field generator ; |
|||
* set current size to zero ; |
|||
* clear field array count. |
|||
|
|||
But keep: name, value, description and size. |
|||
""" |
|||
BasicFieldSet.reset(self) |
|||
self._fields = Dict() |
|||
self._field_generator = self.createFields() |
|||
self._current_size = 0 |
|||
self._array_cache = {} |
|||
|
|||
def __str__(self): |
|||
return '<%s path=%s, current_size=%s, current length=%s>' % \ |
|||
(self.__class__.__name__, self.path, self._current_size, len(self._fields)) |
|||
|
|||
def __len__(self): |
|||
""" |
|||
Returns number of fields, may need to create all fields |
|||
if it's not done yet. |
|||
""" |
|||
if self._field_generator is not None: |
|||
self._feedAll() |
|||
return len(self._fields) |
|||
|
|||
def _getCurrentLength(self): |
|||
return len(self._fields) |
|||
current_length = property(_getCurrentLength) |
|||
|
|||
def _getSize(self): |
|||
if self._size is None: |
|||
self._feedAll() |
|||
return self._size |
|||
size = property(_getSize, doc="Size in bits, may create all fields to get size") |
|||
|
|||
def _getCurrentSize(self): |
|||
assert not(self.done) |
|||
return self._current_size |
|||
current_size = property(_getCurrentSize) |
|||
|
|||
eof = property(lambda self: self._checkSize(self._current_size + 1, True) < 0) |
|||
|
|||
def _checkSize(self, size, strict): |
|||
field = self |
|||
while field._size is None: |
|||
if not field._parent: |
|||
assert self.stream.size is None |
|||
if not strict: |
|||
return None |
|||
if self.stream.sizeGe(size): |
|||
return 0 |
|||
break |
|||
size += field._address |
|||
field = field._parent |
|||
return field._size - size |
|||
|
|||
autofix = property(lambda self: self.root.autofix) |
|||
|
|||
def _addField(self, field): |
|||
""" |
|||
Add a field to the field set: |
|||
* add it into _fields |
|||
* update _current_size |
|||
|
|||
May raise a StopIteration() on error |
|||
""" |
|||
if not issubclass(field.__class__, Field): |
|||
raise ParserError("Field type (%s) is not a subclass of 'Field'!" |
|||
% field.__class__.__name__) |
|||
assert isinstance(field._name, str) |
|||
if field._name.endswith("[]"): |
|||
self.setUniqueFieldName(field) |
|||
if config.debug: |
|||
self.info("[+] DBG: _addField(%s)" % field.name) |
|||
|
|||
# required for the msoffice parser |
|||
if field._address != self._current_size: |
|||
self.warning("Fix address of %s to %s (was %s)" % |
|||
(field.path, self._current_size, field._address)) |
|||
field._address = self._current_size |
|||
|
|||
ask_stop = False |
|||
# Compute field size and check that there is enough place for it |
|||
self.__is_feeding = True |
|||
try: |
|||
field_size = field.size |
|||
except HACHOIR_ERRORS, err: |
|||
if field.is_field_set and field.current_length and field.eof: |
|||
self.warning("Error when getting size of '%s': %s" % (field.name, err)) |
|||
field._stopFeeding() |
|||
ask_stop = True |
|||
else: |
|||
self.warning("Error when getting size of '%s': delete it" % field.name) |
|||
self.__is_feeding = False |
|||
raise |
|||
self.__is_feeding = False |
|||
|
|||
# No more place? |
|||
dsize = self._checkSize(field._address + field.size, False) |
|||
if (dsize is not None and dsize < 0) or (field.is_field_set and field.size <= 0): |
|||
if self.autofix and self._current_size: |
|||
self._fixFieldSize(field, field.size + dsize) |
|||
else: |
|||
raise ParserError("Field %s is too large!" % field.path) |
|||
|
|||
self._current_size += field.size |
|||
try: |
|||
self._fields.append(field._name, field) |
|||
except UniqKeyError, err: |
|||
self.warning("Duplicate field name " + unicode(err)) |
|||
field._name += "[]" |
|||
self.setUniqueFieldName(field) |
|||
self._fields.append(field._name, field) |
|||
if ask_stop: |
|||
raise StopIteration() |
|||
|
|||
def _fixFieldSize(self, field, new_size): |
|||
if new_size > 0: |
|||
if field.is_field_set and 0 < field.size: |
|||
field._truncate(new_size) |
|||
return |
|||
|
|||
# Don't add the field <=> delete item |
|||
if self._size is None: |
|||
self._size = self._current_size + new_size |
|||
self.warning("[Autofix] Delete '%s' (too large)" % field.path) |
|||
raise StopIteration() |
|||
|
|||
def _getField(self, name, const): |
|||
field = Field._getField(self, name, const) |
|||
if field is None: |
|||
if name in self._fields: |
|||
field = self._fields[name] |
|||
elif self._field_generator is not None and not const: |
|||
field = self._feedUntil(name) |
|||
return field |
|||
|
|||
def getField(self, key, const=True): |
|||
if isinstance(key, (int, long)): |
|||
if key < 0: |
|||
raise KeyError("Key must be positive!") |
|||
if not const: |
|||
self.readFirstFields(key+1) |
|||
if len(self._fields.values) <= key: |
|||
raise MissingField(self, key) |
|||
return self._fields.values[key] |
|||
return Field.getField(self, key, const) |
|||
|
|||
def _truncate(self, size): |
|||
assert size > 0 |
|||
if size < self._current_size: |
|||
self._size = size |
|||
while True: |
|||
field = self._fields.values[-1] |
|||
if field._address < size: |
|||
break |
|||
del self._fields[-1] |
|||
self._current_size = field._address |
|||
size -= field._address |
|||
if size < field._size: |
|||
if field.is_field_set: |
|||
field._truncate(size) |
|||
else: |
|||
del self._fields[-1] |
|||
field = createRawField(self, size, "raw[]") |
|||
self._fields.append(field._name, field) |
|||
self._current_size = self._size |
|||
else: |
|||
assert size < self._size or self._size is None |
|||
self._size = size |
|||
if self._size == self._current_size: |
|||
self._field_generator = None |
|||
|
|||
def _deleteField(self, index): |
|||
field = self._fields.values[index] |
|||
size = field.size |
|||
self._current_size -= size |
|||
del self._fields[index] |
|||
return field |
|||
|
|||
def _fixLastField(self): |
|||
""" |
|||
Try to fix last field when we know current field set size. |
|||
Returns the newly added field if any, or None. 
|||
""" |
|||
assert self._size is not None |
|||
|
|||
# Stop parser |
|||
message = ["stop parser"] |
|||
self._field_generator = None |
|||
|
|||
# If last field is too big, delete it |
|||
while self._size < self._current_size: |
|||
field = self._deleteField(len(self._fields)-1) |
|||
message.append("delete field %s" % field.path) |
|||
assert self._current_size <= self._size |
|||
|
|||
# If the current size is smaller: add a raw field 
|||
size = self._size - self._current_size |
|||
if size: |
|||
field = createRawField(self, size, "raw[]") |
|||
message.append("add padding") |
|||
self._current_size += field.size |
|||
self._fields.append(field._name, field) |
|||
else: |
|||
field = None |
|||
message = ", ".join(message) |
|||
self.warning("[Autofix] Fix parser error: " + message) |
|||
assert self._current_size == self._size |
|||
return field |
|||
|
|||
def _stopFeeding(self): |
|||
new_field = None |
|||
if self._size is None: |
|||
if self._parent: |
|||
self._size = self._current_size |
|||
elif self._size != self._current_size: |
|||
if self.autofix: |
|||
new_field = self._fixLastField() |
|||
else: |
|||
raise ParserError("Invalid parser \"%s\" size!" % self.path) |
|||
self._field_generator = None |
|||
return new_field |
|||
|
|||
def _fixFeedError(self, exception): |
|||
""" |
|||
Try to fix a feeding error. Returns False if error can't be fixed, |
|||
otherwise returns new field if any, or None. |
|||
""" |
|||
if self._size is None or not self.autofix: |
|||
return False |
|||
self.warning(unicode(exception)) |
|||
return self._fixLastField() |
|||
|
|||
def _feedUntil(self, field_name): |
|||
""" |
|||
Return the field if it was found, otherwise None 
|||
""" |
|||
if self.__is_feeding \ |
|||
or (self._field_generator and self._field_generator.gi_running): |
|||
self.warning("Unable to get %s (and generator is already running)" |
|||
% field_name) |
|||
return None |
|||
try: |
|||
while True: |
|||
field = self._field_generator.next() |
|||
self._addField(field) |
|||
if field.name == field_name: |
|||
return field |
|||
except HACHOIR_ERRORS, err: |
|||
if self._fixFeedError(err) is False: |
|||
raise |
|||
except StopIteration: |
|||
self._stopFeeding() |
|||
return None |
|||
|
|||
def readMoreFields(self, number): |
|||
""" |
|||
Read at most 'number' more fields, or do nothing if parsing is done. 
|||
|
|||
Returns the number of newly added fields. 
|||
""" |
|||
if self._field_generator is None: |
|||
return 0 |
|||
oldlen = len(self._fields) |
|||
try: |
|||
for index in xrange(number): |
|||
self._addField( self._field_generator.next() ) |
|||
except HACHOIR_ERRORS, err: |
|||
if self._fixFeedError(err) is False: |
|||
raise |
|||
except StopIteration: |
|||
self._stopFeeding() |
|||
return len(self._fields) - oldlen |
|||
|
|||
def _feedAll(self): |
|||
if self._field_generator is None: |
|||
return |
|||
try: |
|||
while True: |
|||
field = self._field_generator.next() |
|||
self._addField(field) |
|||
except HACHOIR_ERRORS, err: |
|||
if self._fixFeedError(err) is False: |
|||
raise |
|||
except StopIteration: |
|||
self._stopFeeding() |
|||
|
|||
def __iter__(self): |
|||
""" |
|||
Create a generator to iterate on each field, may create new |
|||
fields when needed |
|||
""" |
|||
try: |
|||
done = 0 |
|||
while True: |
|||
if done == len(self._fields): |
|||
if self._field_generator is None: |
|||
break |
|||
self._addField( self._field_generator.next() ) |
|||
for field in self._fields.values[done:]: |
|||
yield field |
|||
done += 1 |
|||
except HACHOIR_ERRORS, err: |
|||
field = self._fixFeedError(err) |
|||
if isinstance(field, Field): |
|||
yield field |
|||
elif hasattr(field, '__iter__'): |
|||
for f in field: |
|||
yield f |
|||
elif field is False: |
|||
raise |
|||
except StopIteration: |
|||
field = self._stopFeeding() |
|||
if isinstance(field, Field): |
|||
yield field |
|||
elif hasattr(field, '__iter__'): |
|||
for f in field: |
|||
yield f |
|||
|
|||
def _isDone(self): |
|||
return (self._field_generator is None) |
|||
done = property(_isDone, doc="Boolean to know if parsing is done or not") |
|||
|
|||
# |
|||
# FieldSet_SeekUtility |
|||
# |
|||
def seekBit(self, address, name="padding[]", |
|||
description=None, relative=True, null=False): |
|||
""" |
|||
Create a field to seek to specified address, |
|||
or None if it's not needed. |
|||
|
|||
May raise a ParserError exception if the address is invalid. 
|||
""" |
|||
if relative: |
|||
nbits = address - self._current_size |
|||
else: |
|||
nbits = address - (self.absolute_address + self._current_size) |
|||
if nbits < 0: |
|||
raise ParserError("Seek error, unable to go back!") |
|||
if 0 < nbits: |
|||
if null: |
|||
return createNullField(self, nbits, name, description) |
|||
else: |
|||
return createPaddingField(self, nbits, name, description) |
|||
else: |
|||
return None |
|||
|
|||
def seekByte(self, address, name="padding[]", description=None, relative=True, null=False): |
|||
""" |
|||
Same as seekBit(), but with address in byte. |
|||
""" |
|||
return self.seekBit(address * 8, name, description, relative, null=null) |
|||
|
|||
# |
|||
# RandomAccessFieldSet |
|||
# |
|||
def replaceField(self, name, new_fields): |
|||
# TODO: Check in self and not self.field |
|||
# Problem is that "generator is already executing" |
|||
if name not in self._fields: |
|||
raise ParserError("Unable to replace %s: field doesn't exist!" % name) |
|||
assert 1 <= len(new_fields) |
|||
old_field = self[name] |
|||
total_size = sum( (field.size for field in new_fields) ) |
|||
if old_field.size != total_size: |
|||
raise ParserError("Unable to replace %s: " |
|||
"new field(s) hasn't same size (%u bits instead of %u bits)!" |
|||
% (name, total_size, old_field.size)) |
|||
field = new_fields[0] |
|||
if field._name.endswith("[]"): |
|||
self.setUniqueFieldName(field) |
|||
field._address = old_field.address |
|||
if field.name != name and field.name in self._fields: |
|||
raise ParserError( |
|||
"Unable to replace %s: name \"%s\" is already used!" |
|||
% (name, field.name)) |
|||
self._fields.replace(name, field.name, field) |
|||
self.raiseEvent("field-replaced", old_field, field) |
|||
if 1 < len(new_fields): |
|||
index = self._fields.index(new_fields[0].name)+1 |
|||
address = field.address + field.size |
|||
for field in new_fields[1:]: |
|||
if field._name.endswith("[]"): |
|||
self.setUniqueFieldName(field) |
|||
field._address = address |
|||
if field.name in self._fields: |
|||
raise ParserError( |
|||
"Unable to replace %s: name \"%s\" is already used!" |
|||
% (name, field.name)) |
|||
self._fields.insert(index, field.name, field) |
|||
self.raiseEvent("field-inserted", index, field) |
|||
index += 1 |
|||
address += field.size |
|||
|
|||
def getFieldByAddress(self, address, feed=True): |
|||
""" |
|||
Search a field by address; if feed is True, create all remaining fields first, otherwise only search in existing fields. 
|||
""" |
|||
if feed and self._field_generator is not None: |
|||
self._feedAll() |
|||
if address < self._current_size: |
|||
i = lowerBound(self._fields.values, lambda x: x.address + x.size <= address) |
|||
if i is not None: |
|||
return self._fields.values[i] |
|||
return None |
|||
|
|||
def writeFieldsIn(self, old_field, address, new_fields): |
|||
""" |
|||
Can only write in existing fields (address < self._current_size) |
|||
""" |
|||
|
|||
# Check size |
|||
total_size = sum( field.size for field in new_fields ) |
|||
if old_field.size < total_size: |
|||
raise ParserError( \ |
|||
"Unable to write fields at address %s " \ |
|||
"(too big)!" % (address)) |
|||
|
|||
# Need padding before? |
|||
replace = [] |
|||
size = address - old_field.address |
|||
assert 0 <= size |
|||
if 0 < size: |
|||
padding = createPaddingField(self, size) |
|||
padding._address = old_field.address |
|||
replace.append(padding) |
|||
|
|||
# Set fields address |
|||
for field in new_fields: |
|||
field._address = address |
|||
address += field.size |
|||
replace.append(field) |
|||
|
|||
# Need padding after? |
|||
size = (old_field.address + old_field.size) - address |
|||
assert 0 <= size |
|||
if 0 < size: |
|||
padding = createPaddingField(self, size) |
|||
padding._address = address |
|||
replace.append(padding) |
|||
|
|||
self.replaceField(old_field.name, replace) |
|||
|
|||
def nextFieldAddress(self): |
|||
return self._current_size |
|||
|
|||
def getFieldIndex(self, field): |
|||
return self._fields.index(field._name) |
|||
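# An illustrative sketch, not part of the original module: the FieldSet
# subclass announced in the GenericFieldSet docstring above, written as
# user-level code. All names below (ExampleHeader, "magic", "version",
# "flags") are invented for the example. createFields() yields one field
# per line; seekByte() emits an auto-named padding field when bytes have
# to be skipped.
from hachoir_core.field import FieldSet, String, UInt8, UInt32

class ExampleHeader(FieldSet):
    static_size = 16 * 8                  # optional fixed size, in bits

    def createFields(self):
        yield String(self, "magic", 4, charset="ASCII")
        yield UInt32(self, "version")
        yield UInt8(self, "flags[]")      # "[]" -> "flags[0]", "flags[1]", ...
        yield UInt8(self, "flags[]")
        padding = self.seekByte(16)       # pad up to byte 16, if needed
        if padding:
            yield padding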
|
@ -0,0 +1,57 @@ |
|||
from hachoir_core.field import (FieldError, |
|||
RawBits, RawBytes, |
|||
PaddingBits, PaddingBytes, |
|||
NullBits, NullBytes, |
|||
GenericString, GenericInteger) |
|||
from hachoir_core.stream import FileOutputStream |
|||
|
|||
def createRawField(parent, size, name="raw[]", description=None): |
|||
if size <= 0: |
|||
raise FieldError("Unable to create raw field of %s bits" % size) |
|||
if (size % 8) == 0: |
|||
return RawBytes(parent, name, size/8, description) |
|||
else: |
|||
return RawBits(parent, name, size, description) |
|||
|
|||
def createPaddingField(parent, nbits, name="padding[]", description=None): |
|||
if nbits <= 0: |
|||
raise FieldError("Unable to create padding of %s bits" % nbits) |
|||
if (nbits % 8) == 0: |
|||
return PaddingBytes(parent, name, nbits/8, description) |
|||
else: |
|||
return PaddingBits(parent, name, nbits, description) |
|||
|
|||
def createNullField(parent, nbits, name="padding[]", description=None): |
|||
if nbits <= 0: |
|||
raise FieldError("Unable to create null padding of %s bits" % nbits) |
|||
if (nbits % 8) == 0: |
|||
return NullBytes(parent, name, nbits/8, description) |
|||
else: |
|||
return NullBits(parent, name, nbits, description) |
|||
|
|||
def isString(field): |
|||
return issubclass(field.__class__, GenericString) |
|||
|
|||
def isInteger(field): |
|||
return issubclass(field.__class__, GenericInteger) |
|||
|
|||
def writeIntoFile(fieldset, filename): |
|||
output = FileOutputStream(filename) |
|||
fieldset.writeInto(output) |
|||
|
|||
def createOrphanField(fieldset, address, field_cls, *args, **kw): |
|||
""" |
|||
Create an orphan field at specified address: |
|||
field_cls(fieldset, *args, **kw) |
|||
|
|||
The field uses the fieldset properties but it isn't added to the |
|||
field set. |
|||
""" |
|||
save_size = fieldset._current_size |
|||
try: |
|||
fieldset._current_size = address |
|||
field = field_cls(fieldset, *args, **kw) |
|||
finally: |
|||
fieldset._current_size = save_size |
|||
return field |
|||
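# An illustrative usage sketch, not part of the original module; "root"
# and the bit address below are assumptions for the example.
# createOrphanField() parses a field at an arbitrary address without
# adding it to the parent, which is useful to peek ahead before
# committing to a layout:
#
#     from hachoir_core.field import UInt32
#     peek = createOrphanField(root, 64, UInt32, "peek")   # read at bit 64
#     print peek.value      # root's fields and current size are untouched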
|
@ -0,0 +1,44 @@ |
|||
""" |
|||
Integer field classes: |
|||
- UInt8, UInt16, UInt24, UInt32, UInt64: unsigned integers of 8, 16, 24, 32, 64 bits ; 
|||
- Int8, Int16, Int24, Int32, Int64: signed integers of 8, 16, 24, 32, 64 bits. 
|||
""" |
|||
|
|||
from hachoir_core.field import Bits, FieldError |
|||
|
|||
class GenericInteger(Bits): |
|||
""" |
|||
Generic integer class used to generate other classes. |
|||
""" |
|||
def __init__(self, parent, name, signed, size, description=None): |
|||
if not (8 <= size <= 256): |
|||
raise FieldError("Invalid integer size (%s): have to be in 8..256" % size) |
|||
Bits.__init__(self, parent, name, size, description) |
|||
self.signed = signed |
|||
|
|||
def createValue(self): |
|||
return self._parent.stream.readInteger( |
|||
self.absolute_address, self.signed, self._size, self._parent.endian) |
|||
|
|||
def integerFactory(name, is_signed, size, doc): |
|||
class Integer(GenericInteger): |
|||
__doc__ = doc |
|||
static_size = size |
|||
def __init__(self, parent, name, description=None): |
|||
GenericInteger.__init__(self, parent, name, is_signed, size, description) |
|||
cls = Integer |
|||
cls.__name__ = name |
|||
return cls |
|||
|
|||
UInt8 = integerFactory("UInt8", False, 8, "Unsigned integer of 8 bits") |
|||
UInt16 = integerFactory("UInt16", False, 16, "Unsigned integer of 16 bits") |
|||
UInt24 = integerFactory("UInt24", False, 24, "Unsigned integer of 24 bits") |
|||
UInt32 = integerFactory("UInt32", False, 32, "Unsigned integer of 32 bits") |
|||
UInt64 = integerFactory("UInt64", False, 64, "Unsigned integer of 64 bits") |
|||
|
|||
Int8 = integerFactory("Int8", True, 8, "Signed integer of 8 bits") |
|||
Int16 = integerFactory("Int16", True, 16, "Signed integer of 16 bits") |
|||
Int24 = integerFactory("Int24", True, 24, "Signed integer of 24 bits") |
|||
Int32 = integerFactory("Int32", True, 32, "Signed integer of 32 bits") |
|||
Int64 = integerFactory("Int64", True, 64, "Signed integer of 64 bits") |
|||
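# A standalone sketch, not part of the original module: the generated
# classes are plain Field subclasses, and the endian used by readInteger()
# comes from the parent, so the same two bytes decode differently under a
# LITTLE_ENDIAN and a BIG_ENDIAN parent. The plain-Python equivalent of
# the 16-bit unsigned decode:
_raw = "\x34\x12"
_little = ord(_raw[0]) | (ord(_raw[1]) << 8)    # LITTLE_ENDIAN parent
_big = (ord(_raw[0]) << 8) | ord(_raw[1])       # BIG_ENDIAN parent
assert _little == 0x1234 and _big == 0x3412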
|
@ -0,0 +1,109 @@ |
|||
from hachoir_core.field import Field, FieldSet, ParserError, Bytes, MissingField |
|||
from hachoir_core.stream import FragmentedStream |
|||
|
|||
|
|||
class Link(Field): |
|||
def __init__(self, parent, name, *args, **kw): |
|||
Field.__init__(self, parent, name, 0, *args, **kw) |
|||
|
|||
def hasValue(self): |
|||
return True |
|||
|
|||
def createValue(self): |
|||
return self._parent[self.display] |
|||
|
|||
def createDisplay(self): |
|||
value = self.value |
|||
if value is None: |
|||
return "<%s>" % MissingField.__name__ |
|||
return value.path |
|||
|
|||
def _getField(self, name, const): |
|||
target = self.value |
|||
assert self != target |
|||
return target._getField(name, const) |
|||
|
|||
|
|||
class Fragments: |
|||
def __init__(self, first): |
|||
self.first = first |
|||
|
|||
def __iter__(self): |
|||
fragment = self.first |
|||
while fragment is not None: |
|||
data = fragment.getData() |
|||
yield data and data.size |
|||
fragment = fragment.next |
|||
|
|||
|
|||
class Fragment(FieldSet): |
|||
_first = None |
|||
|
|||
def __init__(self, *args, **kw): |
|||
FieldSet.__init__(self, *args, **kw) |
|||
self._field_generator = self._createFields(self._field_generator) |
|||
if self.__class__.createFields == Fragment.createFields: |
|||
self._getData = lambda: self |
|||
|
|||
def getData(self): |
|||
try: |
|||
return self._getData() |
|||
except MissingField, e: |
|||
self.error(str(e)) |
|||
return None |
|||
|
|||
def setLinks(self, first, next=None): |
|||
self._first = first or self |
|||
self._next = next |
|||
self._feedLinks = lambda: self |
|||
return self |
|||
|
|||
def _feedLinks(self): |
|||
while self._first is None and self.readMoreFields(1): |
|||
pass |
|||
if self._first is None: |
|||
raise ParserError("first is None") |
|||
return self |
|||
first = property(lambda self: self._feedLinks()._first) |
|||
|
|||
def _getNext(self): |
|||
next = self._feedLinks()._next |
|||
if callable(next): |
|||
self._next = next = next() |
|||
return next |
|||
next = property(_getNext) |
|||
|
|||
def _createInputStream(self, **args): |
|||
first = self.first |
|||
if first is self and hasattr(first, "_getData"): |
|||
return FragmentedStream(first, packets=Fragments(first), **args) |
|||
return FieldSet._createInputStream(self, **args) |
|||
|
|||
def _createFields(self, field_generator): |
|||
if self._first is None: |
|||
for field in field_generator: |
|||
if self._first is not None: |
|||
break |
|||
yield field |
|||
else: |
|||
raise ParserError("Fragment.setLinks not called") |
|||
else: |
|||
field = None |
|||
if self._first is not self: |
|||
link = Link(self, "first", None) |
|||
link._getValue = lambda: self._first |
|||
yield link |
|||
if self._next: |
|||
link = Link(self, "next", None) |
|||
link.createValue = self._getNext |
|||
yield link |
|||
if field: |
|||
yield field |
|||
for field in field_generator: |
|||
yield field |
|||
|
|||
def createFields(self): |
|||
if self._size is None: |
|||
self._size = self._getSize() |
|||
yield Bytes(self, "data", self._size/8) |
|||
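# A standalone sketch, not part of the original module, of the
# linked-fragment protocol above: each fragment exposes getData() and a
# "next" attribute, and Fragments (defined above) walks the chain
# yielding each fragment's size so FragmentedStream can reassemble the
# payload. Plain objects stand in for real Fragment fields here.
class _FakeFragment(object):
    def __init__(self, size):
        self.size = size          # size in bits, like Field.size
        self.next = None
    def getData(self):
        return self               # a real fragment may return None on error

_a = _FakeFragment(80)
_b = _FakeFragment(40)
_a.next = _b                      # chain: _a -> _b
assert list(Fragments(_a)) == [80, 40]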
|
@ -0,0 +1,82 @@ |
|||
from hachoir_core.field import BasicFieldSet, GenericFieldSet, ParserError, createRawField |
|||
from hachoir_core.error import HACHOIR_ERRORS |
|||
|
|||
# getgaps(int, int, [listof (int, int)]) -> generator of (int, int) |
|||
# Gets all the gaps not covered by a block in `blocks` from `start` for `length` units. |
|||
def getgaps(start, length, blocks): |
|||
''' |
|||
Example: |
|||
>>> list(getgaps(0, 20, [(15,3), (6,2), (6,2), (1,2), (2,3), (11,2), (9,5)])) |
|||
[(0, 1), (5, 1), (8, 1), (14, 1), (18, 2)] |
|||
''' |
|||
# done this way to avoid mutating the original |
|||
blocks = sorted(blocks, key=lambda b: b[0]) |
|||
end = start+length |
|||
for s, l in blocks: |
|||
if s > start: |
|||
yield (start, s-start) |
|||
start = s |
|||
if s+l > start: |
|||
start = s+l |
|||
if start < end: |
|||
yield (start, end-start) |
|||
|
|||
class NewRootSeekableFieldSet(GenericFieldSet): |
|||
def seekBit(self, address, relative=True): |
|||
if not relative: |
|||
address -= self.absolute_address |
|||
if address < 0: |
|||
raise ParserError("Seek below field set start (%s.%s)" % divmod(address, 8)) |
|||
self._current_size = address |
|||
return None |
|||
|
|||
def seekByte(self, address, relative=True): |
|||
return self.seekBit(address*8, relative) |
|||
|
|||
def _fixLastField(self): |
|||
""" |
|||
Try to fix last field when we know current field set size. |
|||
Returns the list of newly added fields (may be empty). 
|||
""" |
|||
assert self._size is not None |
|||
|
|||
# Stop parser |
|||
message = ["stop parser"] |
|||
self._field_generator = None |
|||
|
|||
# If last field is too big, delete it |
|||
while self._size < self._current_size: |
|||
field = self._deleteField(len(self._fields)-1) |
|||
message.append("delete field %s" % field.path) |
|||
assert self._current_size <= self._size |
|||
|
|||
blocks = [(x.absolute_address, x.size) for x in self._fields] |
|||
fields = [] |
|||
for start, length in getgaps(self.absolute_address, self._size, blocks): |
|||
self.seekBit(start, relative=False) |
|||
field = createRawField(self, length, "unparsed[]") |
|||
self.setUniqueFieldName(field) |
|||
self._fields.append(field.name, field) |
|||
fields.append(field) |
|||
message.append("found unparsed segment: start %s, length %s" % (start, length)) |
|||
|
|||
self.seekBit(self._size, relative=False) |
|||
message = ", ".join(message) |
|||
if fields: |
|||
self.warning("[Autofix] Fix parser error: " + message) |
|||
return fields |
|||
|
|||
def _stopFeeding(self): |
|||
new_field = None |
|||
if self._size is None: |
|||
if self._parent: |
|||
self._size = self._current_size |
|||
|
|||
new_field = self._fixLastField() |
|||
self._field_generator = None |
|||
return new_field |
|||
|
|||
class NewSeekableFieldSet(NewRootSeekableFieldSet): |
|||
def __init__(self, parent, name, description=None, size=None): |
|||
assert issubclass(parent.__class__, BasicFieldSet) |
|||
NewRootSeekableFieldSet.__init__(self, parent, name, parent.stream, description, size) |
@ -0,0 +1,138 @@ |
|||
from hachoir_core.field import Bits, Bytes |
|||
from hachoir_core.tools import makePrintable, humanFilesize |
|||
from hachoir_core import config |
|||
|
|||
class PaddingBits(Bits): |
|||
""" |
|||
Padding bits used, for example, to align address (of next field). |
|||
See also NullBits and PaddingBytes types. |
|||
|
|||
Arguments: |
|||
* nbits: Size of the field in bits |
|||
|
|||
Optional arguments: |
|||
* pattern (int): Content pattern, eg. 0 if all bits are set to 0 |
|||
""" |
|||
static_size = staticmethod(lambda *args, **kw: args[1]) |
|||
MAX_SIZE = 128 |
|||
|
|||
def __init__(self, parent, name, nbits, description="Padding", pattern=None): |
|||
Bits.__init__(self, parent, name, nbits, description) |
|||
self.pattern = pattern |
|||
self._display_pattern = self.checkPattern() |
|||
|
|||
def checkPattern(self): |
|||
if not(config.check_padding_pattern): |
|||
return False |
|||
if self.pattern != 0: |
|||
return False |
|||
|
|||
if self.MAX_SIZE < self._size: |
|||
value = self._parent.stream.readBits( |
|||
self.absolute_address, self.MAX_SIZE, self._parent.endian) |
|||
else: |
|||
value = self.value |
|||
if value != 0: |
|||
self.warning("padding contents doesn't look normal (invalid pattern)") |
|||
return False |
|||
if self.MAX_SIZE < self._size: |
|||
self.info("only check first %u bits" % self.MAX_SIZE) |
|||
return True |
|||
|
|||
def createDisplay(self): |
|||
if self._display_pattern: |
|||
return u"<padding pattern=%s>" % self.pattern |
|||
else: |
|||
return Bits.createDisplay(self) |
|||
|
|||
class PaddingBytes(Bytes): |
|||
""" |
|||
Padding bytes used, for example, to align address (of next field). |
|||
See also NullBytes and PaddingBits types. |
|||
|
|||
Arguments: |
|||
* nbytes: Size of the field in bytes |
|||
|
|||
Optional arguments: |
|||
* pattern (str): Content pattern, eg. "\0" for nul bytes |
|||
""" |
|||
|
|||
static_size = staticmethod(lambda *args, **kw: args[1]*8) |
|||
MAX_SIZE = 4096 |
|||
|
|||
def __init__(self, parent, name, nbytes, |
|||
description="Padding", pattern=None): |
|||
""" pattern is None or repeated string """ |
|||
assert (pattern is None) or (isinstance(pattern, str)) |
|||
Bytes.__init__(self, parent, name, nbytes, description) |
|||
self.pattern = pattern |
|||
self._display_pattern = self.checkPattern() |
|||
|
|||
def checkPattern(self): |
|||
if not(config.check_padding_pattern): |
|||
return False |
|||
if self.pattern is None: |
|||
return False |
|||
|
|||
if self.MAX_SIZE < self._size/8: |
|||
self.info("only check first %s of padding" % humanFilesize(self.MAX_SIZE)) |
|||
content = self._parent.stream.readBytes( |
|||
self.absolute_address, self.MAX_SIZE) |
|||
else: |
|||
content = self.value |
|||
index = 0 |
|||
pattern_len = len(self.pattern) |
|||
while index < len(content): |
|||
if content[index:index+pattern_len] != self.pattern: |
|||
self.warning( |
|||
"padding contents doesn't look normal" |
|||
" (invalid pattern at byte %u)!" |
|||
% index) |
|||
return False |
|||
index += pattern_len |
|||
return True |
|||
|
|||
def createDisplay(self): |
|||
if self._display_pattern: |
|||
return u"<padding pattern=%s>" % makePrintable(self.pattern, "ASCII", quote="'") |
|||
else: |
|||
return Bytes.createDisplay(self) |
|||
|
|||
def createRawDisplay(self): |
|||
return Bytes.createDisplay(self) |
|||
|
|||
class NullBits(PaddingBits): |
|||
""" |
|||
Null padding bits used, for example, to align address (of next field). |
|||
See also PaddingBits and NullBytes types. |
|||
|
|||
Arguments: |
|||
* nbits: Size of the field in bits |
|||
""" |
|||
|
|||
def __init__(self, parent, name, nbits, description=None): |
|||
PaddingBits.__init__(self, parent, name, nbits, description, pattern=0) |
|||
|
|||
def createDisplay(self): |
|||
if self._display_pattern: |
|||
return "<null>" |
|||
else: |
|||
return Bits.createDisplay(self) |
|||
|
|||
class NullBytes(PaddingBytes): |
|||
""" |
|||
Null padding bytes used, for example, to align address (of next field). |
|||
See also PaddingBytes and NullBits types. |
|||
|
|||
Arguments: |
|||
* nbytes: Size of the field in bytes |
|||
""" |
|||
def __init__(self, parent, name, nbytes, description=None): |
|||
PaddingBytes.__init__(self, parent, name, nbytes, description, pattern="\0") |
|||
|
|||
def createDisplay(self): |
|||
if self._display_pattern: |
|||
return "<null>" |
|||
else: |
|||
return Bytes.createDisplay(self) |
|||
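# A standalone sketch, not part of the original module, of the scan done
# by PaddingBytes.checkPattern() above: the content is compared one
# pattern-sized slice at a time, and any mismatch marks the padding as
# suspicious.
def _matches_pattern(content, pattern):
    index = 0
    while index < len(content):
        if content[index:index + len(pattern)] != pattern:
            return False
        index += len(pattern)
    return True

assert _matches_pattern("\0\0\0\0", "\0")
assert not _matches_pattern("\0\0x\0", "\0")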
|
@ -0,0 +1,40 @@ |
|||
from hachoir_core.endian import BIG_ENDIAN, LITTLE_ENDIAN |
|||
from hachoir_core.field import GenericFieldSet |
|||
from hachoir_core.log import Logger |
|||
import hachoir_core.config as config |
|||
|
|||
class Parser(GenericFieldSet): |
|||
""" |
|||
A parser is the root of all other fields. It creates the first level of fields 
|||
and has special attributes and methods: 
|||
- endian: Byte order (L{BIG_ENDIAN} or L{LITTLE_ENDIAN}) of input data ; |
|||
- stream: Data input stream (set in L{__init__()}) ; |
|||
- size: Field set size will be size of input stream. |
|||
""" |
|||
|
|||
def __init__(self, stream, description=None): |
|||
""" |
|||
Parser constructor |
|||
|
|||
@param stream: Data input stream (see L{InputStream}) |
|||
@param description: (optional) String description |
|||
""" |
|||
# Check arguments |
|||
assert hasattr(self, "endian") \ |
|||
and self.endian in (BIG_ENDIAN, LITTLE_ENDIAN) |
|||
|
|||
# Call parent constructor |
|||
GenericFieldSet.__init__(self, None, "root", stream, description, stream.askSize(self)) |
|||
|
|||
def _logger(self): |
|||
return Logger._logger(self) |
|||
|
|||
def _setSize(self, size): |
|||
self._truncate(size) |
|||
self.raiseEvent("field-resized", self) |
|||
size = property(lambda self: self._size, doc="Size in bits") |
|||
|
|||
path = property(lambda self: "/") |
|||
|
|||
# dummy definition to prevent hachoir-core from depending on hachoir-parser |
|||
autofix = property(lambda self: config.autofix) |
@ -0,0 +1,182 @@ |
|||
from hachoir_core.field import Field, BasicFieldSet, FakeArray, MissingField, ParserError |
|||
from hachoir_core.tools import makeUnicode |
|||
from hachoir_core.error import HACHOIR_ERRORS |
|||
from itertools import repeat |
|||
import hachoir_core.config as config |
|||
|
|||
class RootSeekableFieldSet(BasicFieldSet): |
|||
def __init__(self, parent, name, stream, description, size): |
|||
BasicFieldSet.__init__(self, parent, name, stream, description, size) |
|||
self._generator = self.createFields() |
|||
self._offset = 0 |
|||
self._current_size = 0 |
|||
if size: |
|||
self._current_max_size = size |
|||
else: |
|||
self._current_max_size = 0 |
|||
self._field_dict = {} |
|||
self._field_array = [] |
|||
|
|||
def _feedOne(self): |
|||
assert self._generator |
|||
field = self._generator.next() |
|||
self._addField(field) |
|||
return field |
|||
|
|||
def array(self, key): |
|||
return FakeArray(self, key) |
|||
|
|||
def getFieldByAddress(self, address, feed=True): |
|||
for field in self._field_array: |
|||
if field.address <= address < field.address + field.size: |
|||
return field |
|||
for field in self._readFields(): |
|||
if field.address <= address < field.address + field.size: |
|||
return field |
|||
return None |
|||
|
|||
def _stopFeed(self): |
|||
self._size = self._current_max_size |
|||
self._generator = None |
|||
done = property(lambda self: not bool(self._generator)) |
|||
|
|||
def _getSize(self): |
|||
if self._size is None: |
|||
self._feedAll() |
|||
return self._size |
|||
size = property(_getSize) |
|||
|
|||
def _getField(self, key, const): |
|||
field = Field._getField(self, key, const) |
|||
if field is not None: |
|||
return field |
|||
if key in self._field_dict: |
|||
return self._field_dict[key] |
|||
if self._generator and not const: |
|||
try: |
|||
while True: |
|||
field = self._feedOne() |
|||
if field.name == key: |
|||
return field |
|||
except StopIteration: |
|||
self._stopFeed() |
|||
except HACHOIR_ERRORS, err: |
|||
self.error("Error: %s" % makeUnicode(err)) |
|||
self._stopFeed() |
|||
return None |
|||
|
|||
def getField(self, key, const=True): |
|||
if isinstance(key, (int, long)): |
|||
if key < 0: |
|||
raise KeyError("Key must be positive!") |
|||
if not const: |
|||
self.readFirstFields(key+1) |
|||
if len(self._field_array) <= key: |
|||
raise MissingField(self, key) |
|||
return self._field_array[key] |
|||
return Field.getField(self, key, const) |
|||
|
|||
def _addField(self, field): |
|||
if field._name.endswith("[]"): |
|||
self.setUniqueFieldName(field) |
|||
if config.debug: |
|||
self.info("[+] DBG: _addField(%s)" % field.name) |
|||
|
|||
if field._address != self._offset: |
|||
self.warning("Set field %s address to %s (was %s)" % ( |
|||
field.path, self._offset//8, field._address//8)) |
|||
field._address = self._offset |
|||
assert field.name not in self._field_dict |
|||
|
|||
self._checkFieldSize(field) |
|||
|
|||
self._field_dict[field.name] = field |
|||
self._field_array.append(field) |
|||
self._current_size += field.size |
|||
self._offset += field.size |
|||
self._current_max_size = max(self._current_max_size, field.address + field.size) |
|||
|
|||
def _checkAddress(self, address): |
|||
if self._size is not None: |
|||
max_addr = self._size |
|||
else: |
|||
# FIXME: Use parent size |
|||
max_addr = self.stream.size |
|||
return address < max_addr |
|||
|
|||
def _checkFieldSize(self, field): |
|||
size = field.size |
|||
addr = field.address |
|||
if not self._checkAddress(addr+size-1): |
|||
raise ParserError("Unable to add %s: field is too large" % field.name) |
|||
|
|||
def seekBit(self, address, relative=True): |
|||
if not relative: |
|||
address -= self.absolute_address |
|||
if address < 0: |
|||
raise ParserError("Seek below field set start (%s.%s)" % divmod(address, 8)) |
|||
if not self._checkAddress(address): |
|||
raise ParserError("Seek above field set end (%s.%s)" % divmod(address, 8)) |
|||
self._offset = address |
|||
return None |
|||
|
|||
def seekByte(self, address, relative=True): |
|||
return self.seekBit(address*8, relative) |
|||
|
|||
def readMoreFields(self, number): |
|||
return self._readMoreFields(xrange(number)) |
|||
|
|||
def _feedAll(self): |
|||
return self._readMoreFields(repeat(1)) |
|||
|
|||
def _readFields(self): |
|||
while True: |
|||
added = self._readMoreFields(xrange(1)) |
|||
if not added: |
|||
break |
|||
yield self._field_array[-1] |
|||
|
|||
def _readMoreFields(self, index_generator): |
|||
added = 0 |
|||
if self._generator: |
|||
try: |
|||
for index in index_generator: |
|||
self._feedOne() |
|||
added += 1 |
|||
except StopIteration: |
|||
self._stopFeed() |
|||
except HACHOIR_ERRORS, err: |
|||
self.error("Error: %s" % makeUnicode(err)) |
|||
self._stopFeed() |
|||
return added |
|||
|
|||
current_length = property(lambda self: len(self._field_array)) |
|||
current_size = property(lambda self: self._offset) |
|||
|
|||
def __iter__(self): |
|||
for field in self._field_array: |
|||
yield field |
|||
if self._generator: |
|||
try: |
|||
while True: |
|||
yield self._feedOne() |
|||
except StopIteration: |
|||
self._stopFeed() |
|||
raise StopIteration |
|||
|
|||
def __len__(self): |
|||
if self._generator: |
|||
self._feedAll() |
|||
return len(self._field_array) |
|||
|
|||
def nextFieldAddress(self): |
|||
return self._offset |
|||
|
|||
def getFieldIndex(self, field): |
|||
return self._field_array.index(field) |
|||
|
|||
class SeekableFieldSet(RootSeekableFieldSet): |
|||
def __init__(self, parent, name, description=None, size=None): |
|||
assert issubclass(parent.__class__, BasicFieldSet) |
|||
RootSeekableFieldSet.__init__(self, parent, name, parent.stream, description, size) |
|||
|
@ -0,0 +1,54 @@ |
|||
from hachoir_core.field import FieldSet, ParserError |
|||
|
|||
class StaticFieldSet(FieldSet): |
|||
""" |
|||
Static field set: format class attribute is a tuple of all fields |
|||
in syntax like: |
|||
format = ( |
|||
(TYPE1, ARG1, ARG2, ...), |
|||
(TYPE2, ARG1, ARG2, ..., {KEY1=VALUE1, ...}), |
|||
... |
|||
) |
|||
|
|||
Types with dynamic size are forbidden, eg. CString, PascalString8, etc. |
|||
""" |
|||
format = None # You have to redefine this class variable |
|||
_class = None |
|||
|
|||
def __new__(cls, *args, **kw): |
|||
assert cls.format is not None, "Class attribute 'format' is not set" |
|||
if cls._class is not cls.__name__: |
|||
cls._class = cls.__name__ |
|||
cls.static_size = cls._computeStaticSize() |
|||
return object.__new__(cls, *args, **kw) |
|||
|
|||
@staticmethod |
|||
def _computeItemSize(item): |
|||
item_class = item[0] |
|||
if item_class.static_size is None: |
|||
raise ParserError("Unable to get static size of field type: %s" |
|||
% item_class.__name__) |
|||
if callable(item_class.static_size): |
|||
if isinstance(item[-1], dict): |
|||
return item_class.static_size(*item[1:-1], **item[-1]) |
|||
else: |
|||
return item_class.static_size(*item[1:]) |
|||
else: |
|||
assert isinstance(item_class.static_size, (int, long)) |
|||
return item_class.static_size |
|||
|
|||
def createFields(self): |
|||
for item in self.format: |
|||
if isinstance(item[-1], dict): |
|||
yield item[0](self, *item[1:-1], **item[-1]) |
|||
else: |
|||
yield item[0](self, *item[1:]) |
|||
|
|||
@classmethod |
|||
def _computeStaticSize(cls, *args): |
|||
return sum(cls._computeItemSize(item) for item in cls.format) |
|||
|
|||
# Initial value of static_size, it changes when first instance |
|||
# is created (see __new__) |
|||
static_size = _computeStaticSize |
|||
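# An illustrative user-level sketch, not part of the original module;
# ExampleRecord and its field names are invented. The format tuple
# describes the layout declaratively, and static_size is computed from it
# when the first instance is created (see __new__ above).
from hachoir_core.field import UInt8, UInt16, String

class ExampleRecord(StaticFieldSet):
    format = (
        (UInt16, "id"),
        (UInt8, "kind"),
        (String, "tag", 4, {"charset": "ASCII"}),  # fixed 4-byte string
    )
    # -> static_size becomes 16 + 8 + 4*8 = 56 bits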
|
@ -0,0 +1,402 @@ |
|||
""" |
|||
String field classes: |
|||
- String: Fixed length string (no prefix/no suffix) ; |
|||
- CString: String which ends with nul byte ("\0") ; |
|||
- UnixLine: Unix line of text, string which ends with "\n" ; |
|||
- PascalString8, PascalString16, PascalString32: String prefixed with |
|||
length written in a 8, 16, 32-bit integer (use parent endian). |
|||
|
|||
Constructor has optional arguments: |
|||
- strip: value can be a string or True ; |
|||
- charset: if set, convert string to unicode using this charset (in "replace" |
|||
mode, which replaces all undecodable characters with "."). 
|||
|
|||
Note: For PascalStringXX, the prefixed value is the number of bytes, not 
|||
the number of characters! 
|||
""" |
|||
|
|||
from hachoir_core.field import FieldError, Bytes |
|||
from hachoir_core.endian import LITTLE_ENDIAN, BIG_ENDIAN |
|||
from hachoir_core.tools import alignValue, makePrintable |
|||
from hachoir_core.i18n import guessBytesCharset, _ |
|||
from hachoir_core import config |
|||
from codecs import BOM_UTF16_LE, BOM_UTF16_BE, BOM_UTF32_LE, BOM_UTF32_BE |
|||
|
|||
# Default charset used to convert byte string to Unicode |
|||
# This charset is used if no charset is specified or on conversion error |
|||
FALLBACK_CHARSET = "ISO-8859-1" |
|||
|
|||
class GenericString(Bytes): |
|||
""" |
|||
Generic string class. |
|||
|
|||
charset has to be in CHARSET_8BIT or in UTF_CHARSET. 
|||
""" |
|||
|
|||
VALID_FORMATS = ("C", "UnixLine", |
|||
"fixed", "Pascal8", "Pascal16", "Pascal32") |
|||
|
|||
# 8-bit charsets |
|||
CHARSET_8BIT = set(( |
|||
"ASCII", # ANSI X3.4-1968 |
|||
"MacRoman", |
|||
"CP037", # EBCDIC 037 |
|||
"CP874", # Thai |
|||
"WINDOWS-1250", # Central Europe |
|||
"WINDOWS-1251", # Cyrillic |
|||
"WINDOWS-1252", # Latin I |
|||
"WINDOWS-1253", # Greek |
|||
"WINDOWS-1254", # Turkish |
|||
"WINDOWS-1255", # Hebrew |
|||
"WINDOWS-1256", # Arabic |
|||
"WINDOWS-1257", # Baltic |
|||
"WINDOWS-1258", # Vietnam |
|||
"ISO-8859-1", # Latin-1 |
|||
"ISO-8859-2", # Latin-2 |
|||
"ISO-8859-3", # Latin-3 |
|||
"ISO-8859-4", # Latin-4 |
|||
"ISO-8859-5", |
|||
"ISO-8859-6", |
|||
"ISO-8859-7", |
|||
"ISO-8859-8", |
|||
"ISO-8859-9", # Latin-5 |
|||
"ISO-8859-10", # Latin-6 |
|||
"ISO-8859-11", # Thai |
|||
"ISO-8859-13", # Latin-7 |
|||
"ISO-8859-14", # Latin-8 |
|||
"ISO-8859-15", # Latin-9 or ("Latin-0") |
|||
"ISO-8859-16", # Latin-10 |
|||
)) |
|||
|
|||
# UTF-xx charset family 
|||
UTF_CHARSET = { |
|||
"UTF-8": (8, None), |
|||
"UTF-16-LE": (16, LITTLE_ENDIAN), |
|||
"UTF-32LE": (32, LITTLE_ENDIAN), |
|||
"UTF-16-BE": (16, BIG_ENDIAN), |
|||
"UTF-32BE": (32, BIG_ENDIAN), |
|||
"UTF-16": (16, "BOM"), |
|||
"UTF-32": (32, "BOM"), |
|||
} |
|||
|
|||
# UTF-xx BOM => charset with endian |
|||
UTF_BOM = { |
|||
16: {BOM_UTF16_LE: "UTF-16-LE", BOM_UTF16_BE: "UTF-16-BE"}, |
|||
32: {BOM_UTF32_LE: "UTF-32LE", BOM_UTF32_BE: "UTF-32BE"}, |
|||
} |
|||
|
|||
# Suffix format: value is suffix (string) |
|||
SUFFIX_FORMAT = { |
|||
"C": { |
|||
8: {LITTLE_ENDIAN: "\0", BIG_ENDIAN: "\0"}, |
|||
16: {LITTLE_ENDIAN: "\0\0", BIG_ENDIAN: "\0\0"}, |
|||
32: {LITTLE_ENDIAN: "\0\0\0\0", BIG_ENDIAN: "\0\0\0\0"}, |
|||
}, |
|||
"UnixLine": { |
|||
8: {LITTLE_ENDIAN: "\n", BIG_ENDIAN: "\n"}, |
|||
16: {LITTLE_ENDIAN: "\n\0", BIG_ENDIAN: "\0\n"}, |
|||
32: {LITTLE_ENDIAN: "\n\0\0\0", BIG_ENDIAN: "\0\0\0\n"}, |
|||
}, |
|||
|
|||
} |
|||
|
|||
# Pascal format: value is the size of the prefix in bytes 
|||
PASCAL_FORMATS = { |
|||
"Pascal8": 1, |
|||
"Pascal16": 2, |
|||
"Pascal32": 4 |
|||
} |
|||
|
|||
# Raw value: with prefix and suffix, not stripped, |
|||
# and not converted to Unicode |
|||
_raw_value = None |
|||
|
|||
def __init__(self, parent, name, format, description=None, |
|||
strip=None, charset=None, nbytes=None, truncate=None): |
|||
Bytes.__init__(self, parent, name, 1, description) |
|||
|
|||
# Is format valid? |
|||
assert format in self.VALID_FORMATS |
|||
|
|||
# Store options |
|||
self._format = format |
|||
self._strip = strip |
|||
self._truncate = truncate |
|||
|
|||
# Check charset and compute character size in bytes |
|||
# (or None when it's not possible to guess character size) |
|||
if not charset or charset in self.CHARSET_8BIT: |
|||
self._character_size = 1 # one byte per character |
|||
elif charset in self.UTF_CHARSET: |
|||
self._character_size = None |
|||
else: |
|||
raise FieldError("Invalid charset for %s: \"%s\"" % |
|||
(self.path, charset)) |
|||
self._charset = charset |
|||
|
|||
# Is it a fixed string? 
|||
if nbytes is not None: |
|||
assert self._format == "fixed" |
|||
# Arbitrary limits, just to catch some bugs... |
|||
if not (1 <= nbytes <= 0xffff): |
|||
raise FieldError("Invalid string size for %s: %s" % |
|||
(self.path, nbytes)) |
|||
self._content_size = nbytes # content length in bytes |
|||
self._size = nbytes * 8 |
|||
self._content_offset = 0 |
|||
else: |
|||
# Format with a suffix: Find the end of the string |
|||
if self._format in self.SUFFIX_FORMAT: |
|||
self._content_offset = 0 |
|||
|
|||
# Choose the suffix |
|||
suffix = self.suffix_str |
|||
|
|||
# Find the suffix |
|||
length = self._parent.stream.searchBytesLength( |
|||
suffix, False, self.absolute_address) |
|||
if length is None: |
|||
raise FieldError("Unable to find end of string %s (format %s)!" |
|||
% (self.path, self._format)) |
|||
if 1 < len(suffix): |
|||
# Fix length for little endian bug with UTF-xx charset: |
|||
# u"abc" -> "a\0b\0c\0\0\0" (UTF-16-LE) |
|||
# search returns length=5, whereas the real length is 6 
|||
length = alignValue(length, len(suffix)) |
|||
|
|||
# Compute sizes |
|||
self._content_size = length # in bytes |
|||
self._size = (length + len(suffix)) * 8 |
|||
|
|||
# Format with a prefix: Read prefixed length in bytes |
|||
else: |
|||
assert self._format in self.PASCAL_FORMATS |
|||
|
|||
# Get the prefix size |
|||
prefix_size = self.PASCAL_FORMATS[self._format] |
|||
self._content_offset = prefix_size |
|||
|
|||
# Read the prefix and compute sizes |
|||
value = self._parent.stream.readBits( |
|||
self.absolute_address, prefix_size*8, self._parent.endian) |
|||
self._content_size = value # in bytes |
|||
self._size = (prefix_size + value) * 8 |
|||
|
|||
# For UTF-16 and UTF-32, choose the right charset using BOM |
|||
if self._charset in self.UTF_CHARSET: |
|||
# Charset requires a BOM? |
|||
bomsize, endian = self.UTF_CHARSET[self._charset] |
|||
if endian == "BOM": |
|||
# Read the BOM value |
|||
nbytes = bomsize // 8 |
|||
bom = self._parent.stream.readBytes(self.absolute_address, nbytes) |
|||
|
|||
# Choose right charset using the BOM |
|||
bom_endian = self.UTF_BOM[bomsize] |
|||
if bom not in bom_endian: |
|||
raise FieldError("String %s has invalid BOM (%s)!" |
|||
% (self.path, repr(bom))) |
|||
self._charset = bom_endian[bom] |
|||
self._content_size -= nbytes |
|||
self._content_offset += nbytes |
|||
|
|||
# Compute length in character if possible |
|||
if self._character_size: |
|||
self._length = self._content_size // self._character_size |
|||
else: |
|||
self._length = None |
|||
|
|||
@staticmethod |
|||
def staticSuffixStr(format, charset, endian): |
|||
if format not in GenericString.SUFFIX_FORMAT: |
|||
return '' |
|||
suffix = GenericString.SUFFIX_FORMAT[format] |
|||
if charset in GenericString.UTF_CHARSET: |
|||
suffix_size = GenericString.UTF_CHARSET[charset][0] |
|||
suffix = suffix[suffix_size] |
|||
else: |
|||
suffix = suffix[8] |
|||
return suffix[endian] |
|||
|
|||
def _getSuffixStr(self): |
|||
return self.staticSuffixStr( |
|||
self._format, self._charset, self._parent.endian) |
|||
suffix_str = property(_getSuffixStr) |
|||
|
|||
def _convertText(self, text): |
|||
if not self._charset: |
|||
# charset is still unknown: guess the charset |
|||
self._charset = guessBytesCharset(text, default=FALLBACK_CHARSET) |
|||
|
|||
# Try to convert to Unicode |
|||
try: |
|||
return unicode(text, self._charset, "strict") |
|||
except UnicodeDecodeError, err: |
|||
pass |
|||
|
|||
#--- Conversion error --- |
|||
|
|||
# Fix truncated UTF-16 string like 'B\0e' (3 bytes) |
|||
# => Add missing nul byte: 'B\0e\0' (4 bytes) |
|||
if err.reason == "truncated data" \ |
|||
and err.end == len(text) \ |
|||
and self._charset == "UTF-16-LE": |
|||
try: |
|||
text = unicode(text+"\0", self._charset, "strict") |
|||
self.warning("Fix truncated %s string: add missing nul byte" % self._charset) |
|||
return text |
|||
except UnicodeDecodeError, err: |
|||
pass |
|||
|
|||
# On error, use FALLBACK_CHARSET |
|||
self.warning(u"Unable to convert string to Unicode: %s" % err) |
|||
return unicode(text, FALLBACK_CHARSET, "strict") |
|||
|
|||
def _guessCharset(self): |
|||
addr = self.absolute_address + self._content_offset * 8 |
|||
bytes = self._parent.stream.readBytes(addr, self._content_size) |
|||
return guessBytesCharset(bytes, default=FALLBACK_CHARSET) |
|||
|
|||
def createValue(self, human=True): |
|||
# Compute data address (in bits) and size (in bytes) 
|||
if human: |
|||
addr = self.absolute_address + self._content_offset * 8 |
|||
size = self._content_size |
|||
else: |
|||
addr = self.absolute_address |
|||
size = self._size // 8 |
|||
if size == 0: |
|||
# Empty string |
|||
return u"" |
|||
|
|||
# Read bytes in data stream |
|||
text = self._parent.stream.readBytes(addr, size) |
|||
|
|||
# Don't transform data? |
|||
if not human: |
|||
return text |
|||
|
|||
# Convert text to Unicode |
|||
text = self._convertText(text) |
|||
|
|||
# Truncate |
|||
if self._truncate: |
|||
pos = text.find(self._truncate) |
|||
if 0 <= pos: |
|||
text = text[:pos] |
|||
|
|||
# Strip string if needed |
|||
if self._strip: |
|||
if isinstance(self._strip, (str, unicode)): |
|||
text = text.strip(self._strip) |
|||
else: |
|||
text = text.strip() |
|||
assert isinstance(text, unicode) |
|||
return text |
|||
|
|||
def createDisplay(self, human=True): |
|||
if not human: |
|||
if self._raw_value is None: |
|||
self._raw_value = GenericString.createValue(self, False) |
|||
value = makePrintable(self._raw_value, "ASCII", to_unicode=True) |
|||
elif self._charset: |
|||
value = makePrintable(self.value, "ISO-8859-1", to_unicode=True) |
|||
else: |
|||
value = self.value |
|||
if config.max_string_length < len(value): |
|||
# Truncate string if needed |
|||
value = "%s(...)" % value[:config.max_string_length] |
|||
if not self._charset or not human: |
|||
return makePrintable(value, "ASCII", quote='"', to_unicode=True) |
|||
else: |
|||
if value: |
|||
return '"%s"' % value.replace('"', '\\"') |
|||
else: |
|||
return _("(empty)") |
|||
|
|||
def createRawDisplay(self): |
|||
return GenericString.createDisplay(self, human=False) |
|||
|
|||
def _getLength(self): |
|||
if self._length is None: |
|||
self._length = len(self.value) |
|||
return self._length |
|||
length = property(_getLength, doc="String length in characters") |
|||
|
|||
def _getFormat(self): |
|||
return self._format |
|||
format = property(_getFormat, doc="String format (eg. 'C')") |
|||
|
|||
def _getCharset(self): |
|||
if not self._charset: |
|||
self._charset = self._guessCharset() |
|||
return self._charset |
|||
charset = property(_getCharset, doc="String charset (eg. 'ISO-8859-1')") |
|||
|
|||
def _getContentSize(self): |
|||
return self._content_size |
|||
content_size = property(_getContentSize, doc="Content size in bytes") |
|||
|
|||
def _getContentOffset(self): |
|||
return self._content_offset |
|||
content_offset = property(_getContentOffset, doc="Content offset in bytes") |
|||
|
|||
def getFieldType(self): |
|||
info = self.charset |
|||
if self._strip: |
|||
if isinstance(self._strip, (str, unicode)): |
|||
info += ",strip=%s" % makePrintable(self._strip, "ASCII", quote="'") |
|||
else: |
|||
info += ",strip=True" |
|||
return "%s<%s>" % (Bytes.getFieldType(self), info) |
|||
|
|||
def stringFactory(name, format, doc): |
|||
class NewString(GenericString): |
|||
__doc__ = doc |
|||
def __init__(self, parent, name, description=None, |
|||
strip=None, charset=None, truncate=None): |
|||
GenericString.__init__(self, parent, name, format, description, |
|||
strip=strip, charset=charset, truncate=truncate) |
|||
cls = NewString |
|||
cls.__name__ = name |
|||
return cls |
|||
|
|||
# String which ends with nul byte ("\0") |
|||
CString = stringFactory("CString", "C", |
|||
r"""C string: string ending with nul byte. |
|||
See GenericString to get more information.""") |
|||
|
|||
# Unix line of text: string which ends with "\n" (ASCII 0x0A) |
|||
UnixLine = stringFactory("UnixLine", "UnixLine", |
|||
r"""Unix line: string ending with "\n" (ASCII code 10). |
|||
See GenericString to get more information.""") |
|||
|
|||
# String prefixed with length written in an 8-bit integer |
|||
PascalString8 = stringFactory("PascalString8", "Pascal8", |
|||
r"""Pascal string: string prefixed with 8-bit integer containing its length (endian depends on parent endian). |
|||
See GenericString to get more information.""") |
|||
|
|||
# String prefixed with length written in a 16-bit integer (use parent endian) |
|||
PascalString16 = stringFactory("PascalString16", "Pascal16", |
|||
r"""Pascal string: string prefixed with 16-bit integer containing its length (endian depends on parent endian). |
|||
See GenericString to get more information.""") |
|||
|
|||
# String prefixed with length written in a 32-bit integer (use parent endian) |
|||
PascalString32 = stringFactory("PascalString32", "Pascal32", |
|||
r"""Pascal string: string prefixed with 32-bit integer containing its length (endian depends on parent endian). |
|||
See GenericString to get more information.""") |
|||
|
|||
|
|||
class String(GenericString): |
|||
""" |
|||
String with fixed size (size in bytes). |
|||
See GenericString to get more information. |
|||
""" |
|||
static_size = staticmethod(lambda *args, **kw: args[1]*8) |
|||
|
|||
def __init__(self, parent, name, nbytes, description=None, |
|||
strip=None, charset=None, truncate=None): |
|||
GenericString.__init__(self, parent, name, "fixed", description, |
|||
strip=strip, charset=charset, nbytes=nbytes, truncate=truncate) |
|||
String.__name__ = "FixedString" |
|||
|
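# Usage sketch (hypothetical example, not part of hachoir): the classes
# built above are yielded from a FieldSet parser's createFields() method,
# each one reading its own bytes from the parent stream.
if __name__ == "__main__":
    from hachoir_core.field import FieldSet

    class ExampleRecord(FieldSet):
        def createFields(self):
            yield CString(self, "title")         # stops at the nul byte
            yield PascalString8(self, "author")  # 1-byte length prefix
            yield String(self, "magic", 4)       # exactly 4 bytes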
@ -0,0 +1,72 @@ |
|||
from hachoir_core.field import Bytes |
|||
from hachoir_core.tools import makePrintable, humanFilesize |
|||
from hachoir_core.stream import InputIOStream |
|||
|
|||
class SubFile(Bytes): |
|||
""" |
|||
File stored in another file |
|||
""" |
|||
def __init__(self, parent, name, length, description=None, |
|||
parser=None, filename=None, mime_type=None, parser_class=None): |
|||
if filename: |
|||
if not isinstance(filename, unicode): |
|||
filename = makePrintable(filename, "ISO-8859-1") |
|||
if not description: |
|||
description = 'File "%s" (%s)' % (filename, humanFilesize(length)) |
|||
Bytes.__init__(self, parent, name, length, description) |
|||
def createInputStream(cis, **args): |
|||
tags = args.setdefault("tags",[]) |
|||
if parser_class: |
|||
tags.append(( "class", parser_class )) |
|||
if parser is not None: |
|||
tags.append(( "id", parser.PARSER_TAGS["id"] )) |
|||
if mime_type: |
|||
tags.append(( "mime", mime_type )) |
|||
if filename: |
|||
tags.append(( "filename", filename )) |
|||
return cis(**args) |
|||
self.setSubIStream(createInputStream) |
|||
|
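# Usage sketch (hypothetical field and file names): SubFile is yielded
# like any other field, and the tags set in createInputStream() above let
# hachoir pick a parser for the embedded file later on, for example:
#
#     yield SubFile(self, "thumbnail", size, filename="thumb.jpg",
#                   mime_type="image/jpeg")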
|||
class CompressedStream: |
|||
offset = 0 |
|||
|
|||
def __init__(self, stream, decompressor): |
|||
self.stream = stream |
|||
self.decompressor = decompressor(stream) |
|||
self._buffer = '' |
|||
|
|||
def read(self, size): |
|||
d = self._buffer |
|||
data = [ d[:size] ] |
|||
size -= len(d) |
|||
if size > 0: |
|||
d = self.decompressor(size) |
|||
data.append(d[:size]) |
|||
size -= len(d) |
|||
while size > 0: |
|||
n = 4096 |
|||
if self.stream.size: |
|||
n = min(self.stream.size - self.offset, n) |
|||
if not n: |
|||
break |
|||
d = self.stream.read(self.offset, n)[1] |
|||
self.offset += 8 * len(d) |
|||
d = self.decompressor(size, d) |
|||
data.append(d[:size]) |
|||
size -= len(d) |
|||
self._buffer = d[size+len(d):] |
|||
return ''.join(data) |
|||
|
|||
def CompressedField(field, decompressor): |
|||
def createInputStream(cis, source=None, **args): |
|||
if field._parent: |
|||
stream = cis(source=source) |
|||
args.setdefault("tags", []).extend(stream.tags) |
|||
else: |
|||
stream = field.stream |
|||
input = CompressedStream(stream, decompressor) |
|||
if source is None: |
|||
source = "Compressed source: '%s' (offset=%s)" % (stream.source, field.absolute_address) |
|||
return InputIOStream(input, source=source, **args) |
|||
field.setSubIStream(createInputStream) |
|||
return field |
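
# Minimal sketch of a decompressor usable with CompressedField(), under
# the protocol read from CompressedStream above: decompressor(stream)
# returns a callable taking (size, data='') and yielding decompressed
# bytes. "ZlibDecompressor" is illustrative, not part of hachoir.
import zlib

def ZlibDecompressor(stream):
    obj = zlib.decompressobj()
    def decompress(size, data=''):
        # Prepend input that zlib could not consume on the previous call,
        # then produce at most 'size' bytes of decompressed output.
        return obj.decompress(obj.unconsumed_tail + data, size)
    return decompress

# field = CompressedField(field, ZlibDecompressor)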
@ -0,0 +1,86 @@ |
|||
from hachoir_core.tools import (humanDatetime, humanDuration, |
|||
timestampUNIX, timestampMac32, timestampUUID60, |
|||
timestampWin64, durationWin64) |
|||
from hachoir_core.field import Bits, FieldSet |
|||
from datetime import datetime |
|||
|
|||
class GenericTimestamp(Bits): |
|||
def __init__(self, parent, name, size, description=None): |
|||
Bits.__init__(self, parent, name, size, description) |
|||
|
|||
def createDisplay(self): |
|||
return humanDatetime(self.value) |
|||
|
|||
def createRawDisplay(self): |
|||
value = Bits.createValue(self) |
|||
return unicode(value) |
|||
|
|||
def __nonzero__(self): |
|||
return Bits.createValue(self) != 0 |
|||
|
|||
def timestampFactory(cls_name, handler, size): |
|||
class Timestamp(GenericTimestamp): |
|||
def __init__(self, parent, name, description=None): |
|||
GenericTimestamp.__init__(self, parent, name, size, description) |
|||
|
|||
def createValue(self): |
|||
value = Bits.createValue(self) |
|||
return handler(value) |
|||
cls = Timestamp |
|||
cls.__name__ = cls_name |
|||
return cls |
|||
|
|||
TimestampUnix32 = timestampFactory("TimestampUnix32", timestampUNIX, 32) |
|||
TimestampUnix64 = timestampFactory("TimestampUnix64", timestampUNIX, 64) |
|||
TimestampMac32 = timestampFactory("TimestampMac32", timestampMac32, 32) |
|||
TimestampUUID60 = timestampFactory("TimestampUUID60", timestampUUID60, 60) |
|||
TimestampWin64 = timestampFactory("TimestampWin64", timestampWin64, 64) |
|||
|
|||
class TimeDateMSDOS32(FieldSet): |
|||
""" |
|||
32-bit MS-DOS timestamp (16-bit time, 16-bit date) |
|||
""" |
|||
static_size = 32 |
|||
|
|||
def createFields(self): |
|||
# TODO: Create type "MSDOS_Second" : value*2 |
|||
yield Bits(self, "second", 5, "Second/2") |
|||
yield Bits(self, "minute", 6) |
|||
yield Bits(self, "hour", 5) |
|||
|
|||
yield Bits(self, "day", 5) |
|||
yield Bits(self, "month", 4) |
|||
# TODO: Create type "MSDOS_Year" : value+1980 |
|||
yield Bits(self, "year", 7, "Number of year after 1980") |
|||
|
|||
def createValue(self): |
|||
return datetime( |
|||
1980+self["year"].value, self["month"].value, self["day"].value, |
|||
self["hour"].value, self["minute"].value, 2*self["second"].value) |
|||
|
|||
def createDisplay(self): |
|||
return humanDatetime(self.value) |
|||
|
|||
class DateTimeMSDOS32(TimeDateMSDOS32): |
|||
""" |
|||
32-bit MS-DOS timestamp (16-bit date, 16-bit time) |
|||
""" |
|||
def createFields(self): |
|||
yield Bits(self, "day", 5) |
|||
yield Bits(self, "month", 4) |
|||
yield Bits(self, "year", 7, "Number of year after 1980") |
|||
yield Bits(self, "second", 5, "Second/2") |
|||
yield Bits(self, "minute", 6) |
|||
yield Bits(self, "hour", 5) |
|||
|
|||
class TimedeltaWin64(GenericTimestamp): |
|||
def __init__(self, parent, name, description=None): |
|||
GenericTimestamp.__init__(self, parent, name, 64, description) |
|||
|
|||
def createDisplay(self): |
|||
return humanDuration(self.value) |
|||
|
|||
def createValue(self): |
|||
value = Bits.createValue(self) |
|||
return durationWin64(value) |
|||
|
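# Worked example (hypothetical value): decoding a raw 16-bit MS-DOS time
# word with plain bit operations, matching the widths used above
# (5 bits second/2, 6 bits minute, 5 bits hour in the classic layout).
if __name__ == "__main__":
    word = 0x6D4A
    second2 = word & 0x1F          # lowest 5 bits: seconds / 2
    minute = (word >> 5) & 0x3F    # next 6 bits
    hour = (word >> 11) & 0x1F     # top 5 bits
    print "%02u:%02u:%02u" % (hour, minute, 2 * second2)  # 13:42:20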
@ -0,0 +1,38 @@ |
|||
from hachoir_core.field import Field, FieldSet, ParserError |
|||
|
|||
class GenericVector(FieldSet): |
|||
def __init__(self, parent, name, nb_items, item_class, item_name="item", description=None): |
|||
# Sanity checks |
|||
assert issubclass(item_class, Field) |
|||
assert isinstance(item_class.static_size, (int, long)) |
|||
if not(0 < nb_items): |
|||
raise ParserError('Unable to create empty vector "%s" in %s' \ |
|||
% (name, parent.path)) |
|||
size = nb_items * item_class.static_size |
|||
self.__nb_items = nb_items |
|||
self._item_class = item_class |
|||
self._item_name = item_name |
|||
FieldSet.__init__(self, parent, name, description, size=size) |
|||
|
|||
def __len__(self): |
|||
return self.__nb_items |
|||
|
|||
def createFields(self): |
|||
name = self._item_name + "[]" |
|||
parser = self._item_class |
|||
for index in xrange(len(self)): |
|||
yield parser(self, name) |
|||
|
|||
class UserVector(GenericVector): |
|||
""" |
|||
To implement: |
|||
- item_name: name of a field without [] (eg. "color" becomes "color[0]"), |
|||
default value is "item" |
|||
- item_class: class of an item |
|||
""" |
|||
item_class = None |
|||
item_name = "item" |
|||
|
|||
def __init__(self, parent, name, nb_items, description=None): |
|||
GenericVector.__init__(self, parent, name, nb_items, self.item_class, self.item_name, description) |
|||
|
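# Usage sketch: subclass UserVector and fill in the two class attributes.
# The "palette" name and 16-item count below are hypothetical.
if __name__ == "__main__":
    from hachoir_core.field import UInt8

    class ByteVector(UserVector):
        item_class = UInt8      # any field class with a fixed static_size
        item_name = "byte"      # items are named "byte[0]", "byte[1]", ...

    # yielded from a parser as: yield ByteVector(self, "palette", 16)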
@ -0,0 +1,214 @@ |
|||
# -*- coding: UTF-8 -*- |
|||
""" |
|||
Functions to manage internationalisation (i18n): |
|||
- initLocale(): set up locales and install Unicode-compatible stdout and |
|||
  stderr; |
|||
- getTerminalCharset(): guess the terminal charset; |
|||
- gettext(text): translate a string to the current language. The function |
|||
  always returns a Unicode string. You can also use the alias _(); |
|||
- ngettext(singular, plural, count): translate a sentence with singular and |
|||
  plural forms. The function always returns a Unicode string. |
|||
|
|||
WARNING: Loading this module indirectly calls initLocale() which sets |
|||
locale LC_ALL to ''. This is needed to get user preferred locale |
|||
settings. |
|||
""" |
|||
|
|||
import hachoir_core.config as config |
|||
import hachoir_core |
|||
import locale |
|||
from os import path |
|||
import sys |
|||
from codecs import BOM_UTF8, BOM_UTF16_LE, BOM_UTF16_BE |
|||
|
|||
def _getTerminalCharset(): |
|||
""" |
|||
Function used by getTerminalCharset() to get terminal charset. |
|||
|
|||
@see getTerminalCharset() |
|||
""" |
|||
# (1) Try locale.getpreferredencoding() |
|||
try: |
|||
charset = locale.getpreferredencoding() |
|||
if charset: |
|||
return charset |
|||
except (locale.Error, AttributeError): |
|||
pass |
|||
|
|||
# (2) Try locale.nl_langinfo(CODESET) |
|||
try: |
|||
charset = locale.nl_langinfo(locale.CODESET) |
|||
if charset: |
|||
return charset |
|||
except (locale.Error, AttributeError): |
|||
pass |
|||
|
|||
# (3) Try sys.stdout.encoding |
|||
if hasattr(sys.stdout, "encoding") and sys.stdout.encoding: |
|||
return sys.stdout.encoding |
|||
|
|||
# (4) Otherwise, returns "ASCII" |
|||
return "ASCII" |
|||
|
|||
def getTerminalCharset(): |
|||
""" |
|||
Guess the terminal charset using different tests: |
|||
1. Try locale.getpreferredencoding() |
|||
2. Try locale.nl_langinfo(CODESET) |
|||
3. Try sys.stdout.encoding |
|||
4. Otherwise, returns "ASCII" |
|||
|
|||
WARNING: Call initLocale() before calling this function. |
|||
""" |
|||
try: |
|||
return getTerminalCharset.value |
|||
except AttributeError: |
|||
getTerminalCharset.value = _getTerminalCharset() |
|||
return getTerminalCharset.value |
|||
|
|||
class UnicodeStdout(object): |
|||
def __init__(self, old_device, charset): |
|||
self.device = old_device |
|||
self.charset = charset |
|||
|
|||
def flush(self): |
|||
self.device.flush() |
|||
|
|||
def write(self, text): |
|||
if isinstance(text, unicode): |
|||
text = text.encode(self.charset, 'replace') |
|||
self.device.write(text) |
|||
|
|||
def writelines(self, lines): |
|||
for text in lines: |
|||
self.write(text) |
|||
|
|||
def initLocale(): |
|||
# Only initialize locale once |
|||
if initLocale.is_done: |
|||
return getTerminalCharset() |
|||
initLocale.is_done = True |
|||
|
|||
# Setup locales |
|||
try: |
|||
locale.setlocale(locale.LC_ALL, "") |
|||
except (locale.Error, IOError): |
|||
pass |
|||
|
|||
# Get the terminal charset |
|||
charset = getTerminalCharset() |
|||
|
|||
# UnicodeStdout conflicts with the readline module |
|||
if config.unicode_stdout and ('readline' not in sys.modules): |
|||
# Replace stdout and stderr with objects supporting Unicode strings |
|||
sys.stdout = UnicodeStdout(sys.stdout, charset) |
|||
sys.stderr = UnicodeStdout(sys.stderr, charset) |
|||
return charset |
|||
initLocale.is_done = False |
|||
|
|||
def _dummy_gettext(text): |
|||
return unicode(text) |
|||
|
|||
def _dummy_ngettext(singular, plural, count): |
|||
if 1 < abs(count) or not count: |
|||
return unicode(plural) |
|||
else: |
|||
return unicode(singular) |
|||
|
|||
def _initGettext(): |
|||
charset = initLocale() |
|||
|
|||
# Try to load gettext module |
|||
if config.use_i18n: |
|||
try: |
|||
import gettext |
|||
ok = True |
|||
except ImportError: |
|||
ok = False |
|||
else: |
|||
ok = False |
|||
|
|||
# gettext is not available or not needed: use dummy gettext functions |
|||
if not ok: |
|||
return (_dummy_gettext, _dummy_ngettext) |
|||
|
|||
# Gettext variables |
|||
package = hachoir_core.PACKAGE |
|||
locale_dir = path.join(path.dirname(__file__), "..", "locale") |
|||
|
|||
# Initialize gettext module |
|||
gettext.bindtextdomain(package, locale_dir) |
|||
gettext.textdomain(package) |
|||
translate = gettext.gettext |
|||
ngettext = gettext.ngettext |
|||
|
|||
# TODO: translate_unicode lambda function really sucks! |
|||
# => find native function to do that |
|||
unicode_gettext = lambda text: \ |
|||
unicode(translate(text), charset) |
|||
unicode_ngettext = lambda singular, plural, count: \ |
|||
unicode(ngettext(singular, plural, count), charset) |
|||
return (unicode_gettext, unicode_ngettext) |
|||
|
|||
UTF_BOMS = ( |
|||
(BOM_UTF8, "UTF-8"), |
|||
(BOM_UTF16_LE, "UTF-16-LE"), |
|||
(BOM_UTF16_BE, "UTF-16-BE"), |
|||
) |
|||
|
|||
# Set of valid characters for specific charset |
|||
CHARSET_CHARACTERS = ( |
|||
# U+00E0: LATIN SMALL LETTER A WITH GRAVE |
|||
(set(u"©®éêè\xE0ç".encode("ISO-8859-1")), "ISO-8859-1"), |
|||
(set(u"©®éêè\xE0ç€".encode("ISO-8859-15")), "ISO-8859-15"), |
|||
(set(u"©®".encode("MacRoman")), "MacRoman"), |
|||
(set(u"εδηιθκμοΡσςυΈί".encode("ISO-8859-7")), "ISO-8859-7"), |
|||
) |
|||
|
|||
def guessBytesCharset(bytes, default=None): |
|||
r""" |
|||
>>> guessBytesCharset("abc") |
|||
'ASCII' |
|||
>>> guessBytesCharset("\xEF\xBB\xBFabc") |
|||
'UTF-8' |
|||
>>> guessBytesCharset("abc\xC3\xA9") |
|||
'UTF-8' |
|||
>>> guessBytesCharset("File written by Adobe Photoshop\xA8 4.0\0") |
|||
'MacRoman' |
|||
>>> guessBytesCharset("\xE9l\xE9phant") |
|||
'ISO-8859-1' |
|||
>>> guessBytesCharset("100 \xA4") |
|||
'ISO-8859-15' |
|||
>>> guessBytesCharset('Word \xb8\xea\xe4\xef\xf3\xe7 - Microsoft Outlook 97 - \xd1\xf5\xe8\xec\xdf\xf3\xe5\xe9\xf2 e-mail') |
|||
'ISO-8859-7' |
|||
""" |
|||
# Check for UTF BOM |
|||
for bom_bytes, charset in UTF_BOMS: |
|||
if bytes.startswith(bom_bytes): |
|||
return charset |
|||
|
|||
# Pure ASCII? |
|||
try: |
|||
text = unicode(bytes, 'ASCII', 'strict') |
|||
return 'ASCII' |
|||
except UnicodeDecodeError: |
|||
pass |
|||
|
|||
# Valid UTF-8? |
|||
try: |
|||
text = unicode(bytes, 'UTF-8', 'strict') |
|||
return 'UTF-8' |
|||
except UnicodeDecodeError: |
|||
pass |
|||
|
|||
# Create a set of non-ASCII characters |
|||
non_ascii_set = set( byte for byte in bytes if ord(byte) >= 128 ) |
|||
for characters, charset in CHARSET_CHARACTERS: |
|||
if characters.issuperset(non_ascii_set): |
|||
return charset |
|||
return default |
|||
|
|||
# Initialize _(), gettext() and ngettext() functions |
|||
gettext, ngettext = _initGettext() |
|||
_ = gettext |
|||
|
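
if __name__ == "__main__":
    # Quick demonstration of the helpers defined above: both functions
    # always return unicode, with or without an installed gettext catalog.
    print _("Hello")
    print ngettext("%u file", "%u files", 3) % 3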
@ -0,0 +1,558 @@ |
|||
# -*- coding: utf-8 -*- |
|||
""" |
|||
ISO639-2 standard: the module only contains the dictionary ISO639_2 |
|||
which maps a language code in three letters (eg. "fre") to a language |
|||
name in English (eg. "French"). |
|||
""" |
|||
|
|||
# ISO-639, the list comes from: |
|||
# http://www.loc.gov/standards/iso639-2/php/English_list.php |
|||
_ISO639 = ( |
|||
(u"Abkhazian", "abk", "ab"), |
|||
(u"Achinese", "ace", None), |
|||
(u"Acoli", "ach", None), |
|||
(u"Adangme", "ada", None), |
|||
(u"Adygei", "ady", None), |
|||
(u"Adyghe", "ady", None), |
|||
(u"Afar", "aar", "aa"), |
|||
(u"Afrihili", "afh", None), |
|||
(u"Afrikaans", "afr", "af"), |
|||
(u"Afro-Asiatic (Other)", "afa", None), |
|||
(u"Ainu", "ain", None), |
|||
(u"Akan", "aka", "ak"), |
|||
(u"Akkadian", "akk", None), |
|||
(u"Albanian", "alb/sqi", "sq"), |
|||
(u"Alemani", "gsw", None), |
|||
(u"Aleut", "ale", None), |
|||
(u"Algonquian languages", "alg", None), |
|||
(u"Altaic (Other)", "tut", None), |
|||
(u"Amharic", "amh", "am"), |
|||
(u"Angika", "anp", None), |
|||
(u"Apache languages", "apa", None), |
|||
(u"Arabic", "ara", "ar"), |
|||
(u"Aragonese", "arg", "an"), |
|||
(u"Aramaic", "arc", None), |
|||
(u"Arapaho", "arp", None), |
|||
(u"Araucanian", "arn", None), |
|||
(u"Arawak", "arw", None), |
|||
(u"Armenian", "arm/hye", "hy"), |
|||
(u"Aromanian", "rup", None), |
|||
(u"Artificial (Other)", "art", None), |
|||
(u"Arumanian", "rup", None), |
|||
(u"Assamese", "asm", "as"), |
|||
(u"Asturian", "ast", None), |
|||
(u"Athapascan languages", "ath", None), |
|||
(u"Australian languages", "aus", None), |
|||
(u"Austronesian (Other)", "map", None), |
|||
(u"Avaric", "ava", "av"), |
|||
(u"Avestan", "ave", "ae"), |
|||
(u"Awadhi", "awa", None), |
|||
(u"Aymara", "aym", "ay"), |
|||
(u"Azerbaijani", "aze", "az"), |
|||
(u"Bable", "ast", None), |
|||
(u"Balinese", "ban", None), |
|||
(u"Baltic (Other)", "bat", None), |
|||
(u"Baluchi", "bal", None), |
|||
(u"Bambara", "bam", "bm"), |
|||
(u"Bamileke languages", "bai", None), |
|||
(u"Banda", "bad", None), |
|||
(u"Bantu (Other)", "bnt", None), |
|||
(u"Basa", "bas", None), |
|||
(u"Bashkir", "bak", "ba"), |
|||
(u"Basque", "baq/eus", "eu"), |
|||
(u"Batak (Indonesia)", "btk", None), |
|||
(u"Beja", "bej", None), |
|||
(u"Belarusian", "bel", "be"), |
|||
(u"Bemba", "bem", None), |
|||
(u"Bengali", "ben", "bn"), |
|||
(u"Berber (Other)", "ber", None), |
|||
(u"Bhojpuri", "bho", None), |
|||
(u"Bihari", "bih", "bh"), |
|||
(u"Bikol", "bik", None), |
|||
(u"Bilin", "byn", None), |
|||
(u"Bini", "bin", None), |
|||
(u"Bislama", "bis", "bi"), |
|||
(u"Blin", "byn", None), |
|||
(u"Bokmål, Norwegian", "nob", "nb"), |
|||
(u"Bosnian", "bos", "bs"), |
|||
(u"Braj", "bra", None), |
|||
(u"Breton", "bre", "br"), |
|||
(u"Buginese", "bug", None), |
|||
(u"Bulgarian", "bul", "bg"), |
|||
(u"Buriat", "bua", None), |
|||
(u"Burmese", "bur/mya", "my"), |
|||
(u"Caddo", "cad", None), |
|||
(u"Carib", "car", None), |
|||
(u"Castilian", "spa", "es"), |
|||
(u"Catalan", "cat", "ca"), |
|||
(u"Caucasian (Other)", "cau", None), |
|||
(u"Cebuano", "ceb", None), |
|||
(u"Celtic (Other)", "cel", None), |
|||
(u"Central American Indian (Other)", "cai", None), |
|||
(u"Chagatai", "chg", None), |
|||
(u"Chamic languages", "cmc", None), |
|||
(u"Chamorro", "cha", "ch"), |
|||
(u"Chechen", "che", "ce"), |
|||
(u"Cherokee", "chr", None), |
|||
(u"Chewa", "nya", "ny"), |
|||
(u"Cheyenne", "chy", None), |
|||
(u"Chibcha", "chb", None), |
|||
(u"Chichewa", "nya", "ny"), |
|||
(u"Chinese", "chi/zho", "zh"), |
|||
(u"Chinook jargon", "chn", None), |
|||
(u"Chipewyan", "chp", None), |
|||
(u"Choctaw", "cho", None), |
|||
(u"Chuang", "zha", "za"), |
|||
(u"Church Slavic", "chu", "cu"), |
|||
(u"Church Slavonic", "chu", "cu"), |
|||
(u"Chuukese", "chk", None), |
|||
(u"Chuvash", "chv", "cv"), |
|||
(u"Classical Nepal Bhasa", "nwc", None), |
|||
(u"Classical Newari", "nwc", None), |
|||
(u"Coptic", "cop", None), |
|||
(u"Cornish", "cor", "kw"), |
|||
(u"Corsican", "cos", "co"), |
|||
(u"Cree", "cre", "cr"), |
|||
(u"Creek", "mus", None), |
|||
(u"Creoles and pidgins (Other)", "crp", None), |
|||
(u"Creoles and pidgins, English based (Other)", "cpe", None), |
|||
(u"Creoles and pidgins, French-based (Other)", "cpf", None), |
|||
(u"Creoles and pidgins, Portuguese-based (Other)", "cpp", None), |
|||
(u"Crimean Tatar", "crh", None), |
|||
(u"Crimean Turkish", "crh", None), |
|||
(u"Croatian", "scr/hrv", "hr"), |
|||
(u"Cushitic (Other)", "cus", None), |
|||
(u"Czech", "cze/ces", "cs"), |
|||
(u"Dakota", "dak", None), |
|||
(u"Danish", "dan", "da"), |
|||
(u"Dargwa", "dar", None), |
|||
(u"Dayak", "day", None), |
|||
(u"Delaware", "del", None), |
|||
(u"Dhivehi", "div", "dv"), |
|||
(u"Dimili", "zza", None), |
|||
(u"Dimli", "zza", None), |
|||
(u"Dinka", "din", None), |
|||
(u"Divehi", "div", "dv"), |
|||
(u"Dogri", "doi", None), |
|||
(u"Dogrib", "dgr", None), |
|||
(u"Dravidian (Other)", "dra", None), |
|||
(u"Duala", "dua", None), |
|||
(u"Dutch", "dut/nld", "nl"), |
|||
(u"Dutch, Middle (ca.1050-1350)", "dum", None), |
|||
(u"Dyula", "dyu", None), |
|||
(u"Dzongkha", "dzo", "dz"), |
|||
(u"Eastern Frisian", "frs", None), |
|||
(u"Efik", "efi", None), |
|||
(u"Egyptian (Ancient)", "egy", None), |
|||
(u"Ekajuk", "eka", None), |
|||
(u"Elamite", "elx", None), |
|||
(u"English", "eng", "en"), |
|||
(u"English, Middle (1100-1500)", "enm", None), |
|||
(u"English, Old (ca.450-1100)", "ang", None), |
|||
(u"Erzya", "myv", None), |
|||
(u"Esperanto", "epo", "eo"), |
|||
(u"Estonian", "est", "et"), |
|||
(u"Ewe", "ewe", "ee"), |
|||
(u"Ewondo", "ewo", None), |
|||
(u"Fang", "fan", None), |
|||
(u"Fanti", "fat", None), |
|||
(u"Faroese", "fao", "fo"), |
|||
(u"Fijian", "fij", "fj"), |
|||
(u"Filipino", "fil", None), |
|||
(u"Finnish", "fin", "fi"), |
|||
(u"Finno-Ugrian (Other)", "fiu", None), |
|||
(u"Flemish", "dut/nld", "nl"), |
|||
(u"Fon", "fon", None), |
|||
(u"French", "fre/fra", "fr"), |
|||
(u"French, Middle (ca.1400-1600)", "frm", None), |
|||
(u"French, Old (842-ca.1400)", "fro", None), |
|||
(u"Friulian", "fur", None), |
|||
(u"Fulah", "ful", "ff"), |
|||
(u"Ga", "gaa", None), |
|||
(u"Gaelic", "gla", "gd"), |
|||
(u"Galician", "glg", "gl"), |
|||
(u"Ganda", "lug", "lg"), |
|||
(u"Gayo", "gay", None), |
|||
(u"Gbaya", "gba", None), |
|||
(u"Geez", "gez", None), |
|||
(u"Georgian", "geo/kat", "ka"), |
|||
(u"German", "ger/deu", "de"), |
|||
(u"German, Low", "nds", None), |
|||
(u"German, Middle High (ca.1050-1500)", "gmh", None), |
|||
(u"German, Old High (ca.750-1050)", "goh", None), |
|||
(u"Germanic (Other)", "gem", None), |
|||
(u"Gikuyu", "kik", "ki"), |
|||
(u"Gilbertese", "gil", None), |
|||
(u"Gondi", "gon", None), |
|||
(u"Gorontalo", "gor", None), |
|||
(u"Gothic", "got", None), |
|||
(u"Grebo", "grb", None), |
|||
(u"Greek, Ancient (to 1453)", "grc", None), |
|||
(u"Greek, Modern (1453-)", "gre/ell", "el"), |
|||
(u"Greenlandic", "kal", "kl"), |
|||
(u"Guarani", "grn", "gn"), |
|||
(u"Gujarati", "guj", "gu"), |
|||
(u"Gwich´in", "gwi", None), |
|||
(u"Haida", "hai", None), |
|||
(u"Haitian", "hat", "ht"), |
|||
(u"Haitian Creole", "hat", "ht"), |
|||
(u"Hausa", "hau", "ha"), |
|||
(u"Hawaiian", "haw", None), |
|||
(u"Hebrew", "heb", "he"), |
|||
(u"Herero", "her", "hz"), |
|||
(u"Hiligaynon", "hil", None), |
|||
(u"Himachali", "him", None), |
|||
(u"Hindi", "hin", "hi"), |
|||
(u"Hiri Motu", "hmo", "ho"), |
|||
(u"Hittite", "hit", None), |
|||
(u"Hmong", "hmn", None), |
|||
(u"Hungarian", "hun", "hu"), |
|||
(u"Hupa", "hup", None), |
|||
(u"Iban", "iba", None), |
|||
(u"Icelandic", "ice/isl", "is"), |
|||
(u"Ido", "ido", "io"), |
|||
(u"Igbo", "ibo", "ig"), |
|||
(u"Ijo", "ijo", None), |
|||
(u"Iloko", "ilo", None), |
|||
(u"Inari Sami", "smn", None), |
|||
(u"Indic (Other)", "inc", None), |
|||
(u"Indo-European (Other)", "ine", None), |
|||
(u"Indonesian", "ind", "id"), |
|||
(u"Ingush", "inh", None), |
|||
(u"Interlingua", "ina", "ia"), |
|||
(u"Interlingue", "ile", "ie"), |
|||
(u"Inuktitut", "iku", "iu"), |
|||
(u"Inupiaq", "ipk", "ik"), |
|||
(u"Iranian (Other)", "ira", None), |
|||
(u"Irish", "gle", "ga"), |
|||
(u"Irish, Middle (900-1200)", "mga", None), |
|||
(u"Irish, Old (to 900)", "sga", None), |
|||
(u"Iroquoian languages", "iro", None), |
|||
(u"Italian", "ita", "it"), |
|||
(u"Japanese", "jpn", "ja"), |
|||
(u"Javanese", "jav", "jv"), |
|||
(u"Judeo-Arabic", "jrb", None), |
|||
(u"Judeo-Persian", "jpr", None), |
|||
(u"Kabardian", "kbd", None), |
|||
(u"Kabyle", "kab", None), |
|||
(u"Kachin", "kac", None), |
|||
(u"Kalaallisut", "kal", "kl"), |
|||
(u"Kalmyk", "xal", None), |
|||
(u"Kamba", "kam", None), |
|||
(u"Kannada", "kan", "kn"), |
|||
(u"Kanuri", "kau", "kr"), |
|||
(u"Kara-Kalpak", "kaa", None), |
|||
(u"Karachay-Balkar", "krc", None), |
|||
(u"Karelian", "krl", None), |
|||
(u"Karen", "kar", None), |
|||
(u"Kashmiri", "kas", "ks"), |
|||
(u"Kashubian", "csb", None), |
|||
(u"Kawi", "kaw", None), |
|||
(u"Kazakh", "kaz", "kk"), |
|||
(u"Khasi", "kha", None), |
|||
(u"Khmer", "khm", "km"), |
|||
(u"Khoisan (Other)", "khi", None), |
|||
(u"Khotanese", "kho", None), |
|||
(u"Kikuyu", "kik", "ki"), |
|||
(u"Kimbundu", "kmb", None), |
|||
(u"Kinyarwanda", "kin", "rw"), |
|||
(u"Kirdki", "zza", None), |
|||
(u"Kirghiz", "kir", "ky"), |
|||
(u"Kirmanjki", "zza", None), |
|||
(u"Klingon", "tlh", None), |
|||
(u"Komi", "kom", "kv"), |
|||
(u"Kongo", "kon", "kg"), |
|||
(u"Konkani", "kok", None), |
|||
(u"Korean", "kor", "ko"), |
|||
(u"Kosraean", "kos", None), |
|||
(u"Kpelle", "kpe", None), |
|||
(u"Kru", "kro", None), |
|||
(u"Kuanyama", "kua", "kj"), |
|||
(u"Kumyk", "kum", None), |
|||
(u"Kurdish", "kur", "ku"), |
|||
(u"Kurukh", "kru", None), |
|||
(u"Kutenai", "kut", None), |
|||
(u"Kwanyama", "kua", "kj"), |
|||
(u"Ladino", "lad", None), |
|||
(u"Lahnda", "lah", None), |
|||
(u"Lamba", "lam", None), |
|||
(u"Lao", "lao", "lo"), |
|||
(u"Latin", "lat", "la"), |
|||
(u"Latvian", "lav", "lv"), |
|||
(u"Letzeburgesch", "ltz", "lb"), |
|||
(u"Lezghian", "lez", None), |
|||
(u"Limburgan", "lim", "li"), |
|||
(u"Limburger", "lim", "li"), |
|||
(u"Limburgish", "lim", "li"), |
|||
(u"Lingala", "lin", "ln"), |
|||
(u"Lithuanian", "lit", "lt"), |
|||
(u"Lojban", "jbo", None), |
|||
(u"Low German", "nds", None), |
|||
(u"Low Saxon", "nds", None), |
|||
(u"Lower Sorbian", "dsb", None), |
|||
(u"Lozi", "loz", None), |
|||
(u"Luba-Katanga", "lub", "lu"), |
|||
(u"Luba-Lulua", "lua", None), |
|||
(u"Luiseno", "lui", None), |
|||
(u"Lule Sami", "smj", None), |
|||
(u"Lunda", "lun", None), |
|||
(u"Luo (Kenya and Tanzania)", "luo", None), |
|||
(u"Lushai", "lus", None), |
|||
(u"Luxembourgish", "ltz", "lb"), |
|||
(u"Macedo-Romanian", "rup", None), |
|||
(u"Macedonian", "mac/mkd", "mk"), |
|||
(u"Madurese", "mad", None), |
|||
(u"Magahi", "mag", None), |
|||
(u"Maithili", "mai", None), |
|||
(u"Makasar", "mak", None), |
|||
(u"Malagasy", "mlg", "mg"), |
|||
(u"Malay", "may/msa", "ms"), |
|||
(u"Malayalam", "mal", "ml"), |
|||
(u"Maldivian", "div", "dv"), |
|||
(u"Maltese", "mlt", "mt"), |
|||
(u"Manchu", "mnc", None), |
|||
(u"Mandar", "mdr", None), |
|||
(u"Mandingo", "man", None), |
|||
(u"Manipuri", "mni", None), |
|||
(u"Manobo languages", "mno", None), |
|||
(u"Manx", "glv", "gv"), |
|||
(u"Maori", "mao/mri", "mi"), |
|||
(u"Marathi", "mar", "mr"), |
|||
(u"Mari", "chm", None), |
|||
(u"Marshallese", "mah", "mh"), |
|||
(u"Marwari", "mwr", None), |
|||
(u"Masai", "mas", None), |
|||
(u"Mayan languages", "myn", None), |
|||
(u"Mende", "men", None), |
|||
(u"Mi'kmaq", "mic", None), |
|||
(u"Micmac", "mic", None), |
|||
(u"Minangkabau", "min", None), |
|||
(u"Mirandese", "mwl", None), |
|||
(u"Miscellaneous languages", "mis", None), |
|||
(u"Mohawk", "moh", None), |
|||
(u"Moksha", "mdf", None), |
|||
(u"Moldavian", "mol", "mo"), |
|||
(u"Mon-Khmer (Other)", "mkh", None), |
|||
(u"Mongo", "lol", None), |
|||
(u"Mongolian", "mon", "mn"), |
|||
(u"Mossi", "mos", None), |
|||
(u"Multiple languages", "mul", None), |
|||
(u"Munda languages", "mun", None), |
|||
(u"N'Ko", "nqo", None), |
|||
(u"Nahuatl", "nah", None), |
|||
(u"Nauru", "nau", "na"), |
|||
(u"Navaho", "nav", "nv"), |
|||
(u"Navajo", "nav", "nv"), |
|||
(u"Ndebele, North", "nde", "nd"), |
|||
(u"Ndebele, South", "nbl", "nr"), |
|||
(u"Ndonga", "ndo", "ng"), |
|||
(u"Neapolitan", "nap", None), |
|||
(u"Nepal Bhasa", "new", None), |
|||
(u"Nepali", "nep", "ne"), |
|||
(u"Newari", "new", None), |
|||
(u"Nias", "nia", None), |
|||
(u"Niger-Kordofanian (Other)", "nic", None), |
|||
(u"Nilo-Saharan (Other)", "ssa", None), |
|||
(u"Niuean", "niu", None), |
|||
(u"No linguistic content", "zxx", None), |
|||
(u"Nogai", "nog", None), |
|||
(u"Norse, Old", "non", None), |
|||
(u"North American Indian", "nai", None), |
|||
(u"North Ndebele", "nde", "nd"), |
|||
(u"Northern Frisian", "frr", None), |
|||
(u"Northern Sami", "sme", "se"), |
|||
(u"Northern Sotho", "nso", None), |
|||
(u"Norwegian", "nor", "no"), |
|||
(u"Norwegian Bokmål", "nob", "nb"), |
|||
(u"Norwegian Nynorsk", "nno", "nn"), |
|||
(u"Nubian languages", "nub", None), |
|||
(u"Nyamwezi", "nym", None), |
|||
(u"Nyanja", "nya", "ny"), |
|||
(u"Nyankole", "nyn", None), |
|||
(u"Nynorsk, Norwegian", "nno", "nn"), |
|||
(u"Nyoro", "nyo", None), |
|||
(u"Nzima", "nzi", None), |
|||
(u"Occitan (post 1500)", "oci", "oc"), |
|||
(u"Oirat", "xal", None), |
|||
(u"Ojibwa", "oji", "oj"), |
|||
(u"Old Bulgarian", "chu", "cu"), |
|||
(u"Old Church Slavonic", "chu", "cu"), |
|||
(u"Old Newari", "nwc", None), |
|||
(u"Old Slavonic", "chu", "cu"), |
|||
(u"Oriya", "ori", "or"), |
|||
(u"Oromo", "orm", "om"), |
|||
(u"Osage", "osa", None), |
|||
(u"Ossetian", "oss", "os"), |
|||
(u"Ossetic", "oss", "os"), |
|||
(u"Otomian languages", "oto", None), |
|||
(u"Pahlavi", "pal", None), |
|||
(u"Palauan", "pau", None), |
|||
(u"Pali", "pli", "pi"), |
|||
(u"Pampanga", "pam", None), |
|||
(u"Pangasinan", "pag", None), |
|||
(u"Panjabi", "pan", "pa"), |
|||
(u"Papiamento", "pap", None), |
|||
(u"Papuan (Other)", "paa", None), |
|||
(u"Pedi", "nso", None), |
|||
(u"Persian", "per/fas", "fa"), |
|||
(u"Persian, Old (ca.600-400 B.C.)", "peo", None), |
|||
(u"Philippine (Other)", "phi", None), |
|||
(u"Phoenician", "phn", None), |
|||
(u"Pilipino", "fil", None), |
|||
(u"Pohnpeian", "pon", None), |
|||
(u"Polish", "pol", "pl"), |
|||
(u"Portuguese", "por", "pt"), |
|||
(u"Prakrit languages", "pra", None), |
|||
(u"Provençal", "oci", "oc"), |
|||
(u"Provençal, Old (to 1500)", "pro", None), |
|||
(u"Punjabi", "pan", "pa"), |
|||
(u"Pushto", "pus", "ps"), |
|||
(u"Quechua", "que", "qu"), |
|||
(u"Raeto-Romance", "roh", "rm"), |
|||
(u"Rajasthani", "raj", None), |
|||
(u"Rapanui", "rap", None), |
|||
(u"Rarotongan", "rar", None), |
|||
(u"Reserved for local use", "qaa/qtz", None), |
|||
(u"Romance (Other)", "roa", None), |
|||
(u"Romanian", "rum/ron", "ro"), |
|||
(u"Romany", "rom", None), |
|||
(u"Rundi", "run", "rn"), |
|||
(u"Russian", "rus", "ru"), |
|||
(u"Salishan languages", "sal", None), |
|||
(u"Samaritan Aramaic", "sam", None), |
|||
(u"Sami languages (Other)", "smi", None), |
|||
(u"Samoan", "smo", "sm"), |
|||
(u"Sandawe", "sad", None), |
|||
(u"Sango", "sag", "sg"), |
|||
(u"Sanskrit", "san", "sa"), |
|||
(u"Santali", "sat", None), |
|||
(u"Sardinian", "srd", "sc"), |
|||
(u"Sasak", "sas", None), |
|||
(u"Saxon, Low", "nds", None), |
|||
(u"Scots", "sco", None), |
|||
(u"Scottish Gaelic", "gla", "gd"), |
|||
(u"Selkup", "sel", None), |
|||
(u"Semitic (Other)", "sem", None), |
|||
(u"Sepedi", "nso", None), |
|||
(u"Serbian", "scc/srp", "sr"), |
|||
(u"Serer", "srr", None), |
|||
(u"Shan", "shn", None), |
|||
(u"Shona", "sna", "sn"), |
|||
(u"Sichuan Yi", "iii", "ii"), |
|||
(u"Sicilian", "scn", None), |
|||
(u"Sidamo", "sid", None), |
|||
(u"Sign Languages", "sgn", None), |
|||
(u"Siksika", "bla", None), |
|||
(u"Sindhi", "snd", "sd"), |
|||
(u"Sinhala", "sin", "si"), |
|||
(u"Sinhalese", "sin", "si"), |
|||
(u"Sino-Tibetan (Other)", "sit", None), |
|||
(u"Siouan languages", "sio", None), |
|||
(u"Skolt Sami", "sms", None), |
|||
(u"Slave (Athapascan)", "den", None), |
|||
(u"Slavic (Other)", "sla", None), |
|||
(u"Slovak", "slo/slk", "sk"), |
|||
(u"Slovenian", "slv", "sl"), |
|||
(u"Sogdian", "sog", None), |
|||
(u"Somali", "som", "so"), |
|||
(u"Songhai", "son", None), |
|||
(u"Soninke", "snk", None), |
|||
(u"Sorbian languages", "wen", None), |
|||
(u"Sotho, Northern", "nso", None), |
|||
(u"Sotho, Southern", "sot", "st"), |
|||
(u"South American Indian (Other)", "sai", None), |
|||
(u"South Ndebele", "nbl", "nr"), |
|||
(u"Southern Altai", "alt", None), |
|||
(u"Southern Sami", "sma", None), |
|||
(u"Spanish", "spa", "es"), |
|||
(u"Sranan Togo", "srn", None), |
|||
(u"Sukuma", "suk", None), |
|||
(u"Sumerian", "sux", None), |
|||
(u"Sundanese", "sun", "su"), |
|||
(u"Susu", "sus", None), |
|||
(u"Swahili", "swa", "sw"), |
|||
(u"Swati", "ssw", "ss"), |
|||
(u"Swedish", "swe", "sv"), |
|||
(u"Swiss German", "gsw", None), |
|||
(u"Syriac", "syr", None), |
|||
(u"Tagalog", "tgl", "tl"), |
|||
(u"Tahitian", "tah", "ty"), |
|||
(u"Tai (Other)", "tai", None), |
|||
(u"Tajik", "tgk", "tg"), |
|||
(u"Tamashek", "tmh", None), |
|||
(u"Tamil", "tam", "ta"), |
|||
(u"Tatar", "tat", "tt"), |
|||
(u"Telugu", "tel", "te"), |
|||
(u"Tereno", "ter", None), |
|||
(u"Tetum", "tet", None), |
|||
(u"Thai", "tha", "th"), |
|||
(u"Tibetan", "tib/bod", "bo"), |
|||
(u"Tigre", "tig", None), |
|||
(u"Tigrinya", "tir", "ti"), |
|||
(u"Timne", "tem", None), |
|||
(u"Tiv", "tiv", None), |
|||
(u"tlhIngan-Hol", "tlh", None), |
|||
(u"Tlingit", "tli", None), |
|||
(u"Tok Pisin", "tpi", None), |
|||
(u"Tokelau", "tkl", None), |
|||
(u"Tonga (Nyasa)", "tog", None), |
|||
(u"Tonga (Tonga Islands)", "ton", "to"), |
|||
(u"Tsimshian", "tsi", None), |
|||
(u"Tsonga", "tso", "ts"), |
|||
(u"Tswana", "tsn", "tn"), |
|||
(u"Tumbuka", "tum", None), |
|||
(u"Tupi languages", "tup", None), |
|||
(u"Turkish", "tur", "tr"), |
|||
(u"Turkish, Ottoman (1500-1928)", "ota", None), |
|||
(u"Turkmen", "tuk", "tk"), |
|||
(u"Tuvalu", "tvl", None), |
|||
(u"Tuvinian", "tyv", None), |
|||
(u"Twi", "twi", "tw"), |
|||
(u"Udmurt", "udm", None), |
|||
(u"Ugaritic", "uga", None), |
|||
(u"Uighur", "uig", "ug"), |
|||
(u"Ukrainian", "ukr", "uk"), |
|||
(u"Umbundu", "umb", None), |
|||
(u"Undetermined", "und", None), |
|||
(u"Upper Sorbian", "hsb", None), |
|||
(u"Urdu", "urd", "ur"), |
|||
(u"Uyghur", "uig", "ug"), |
|||
(u"Uzbek", "uzb", "uz"), |
|||
(u"Vai", "vai", None), |
|||
(u"Valencian", "cat", "ca"), |
|||
(u"Venda", "ven", "ve"), |
|||
(u"Vietnamese", "vie", "vi"), |
|||
(u"Volapük", "vol", "vo"), |
|||
(u"Votic", "vot", None), |
|||
(u"Wakashan languages", "wak", None), |
|||
(u"Walamo", "wal", None), |
|||
(u"Walloon", "wln", "wa"), |
|||
(u"Waray", "war", None), |
|||
(u"Washo", "was", None), |
|||
(u"Welsh", "wel/cym", "cy"), |
|||
(u"Western Frisian", "fry", "fy"), |
|||
(u"Wolof", "wol", "wo"), |
|||
(u"Xhosa", "xho", "xh"), |
|||
(u"Yakut", "sah", None), |
|||
(u"Yao", "yao", None), |
|||
(u"Yapese", "yap", None), |
|||
(u"Yiddish", "yid", "yi"), |
|||
(u"Yoruba", "yor", "yo"), |
|||
(u"Yupik languages", "ypk", None), |
|||
(u"Zande", "znd", None), |
|||
(u"Zapotec", "zap", None), |
|||
(u"Zaza", "zza", None), |
|||
(u"Zazaki", "zza", None), |
|||
(u"Zenaga", "zen", None), |
|||
(u"Zhuang", "zha", "za"), |
|||
(u"Zulu", "zul", "zu"), |
|||
(u"Zuni", "zun", None), |
|||
) |
|||
|
|||
# Bibliographic ISO-639-2 form (eg. "fre" => "French") |
|||
ISO639_2 = {} |
|||
for line in _ISO639: |
|||
for key in line[1].split("/"): |
|||
ISO639_2[key] = line[0] |
|||
del _ISO639 |
|||
|
@ -0,0 +1,23 @@ |
|||
from hachoir_core.iso639 import ISO639_2 |
|||
|
|||
class Language: |
|||
def __init__(self, code): |
|||
code = str(code) |
|||
if code not in ISO639_2: |
|||
raise ValueError("Invalid language code: %r" % code) |
|||
self.code = code |
|||
|
|||
def __cmp__(self, other): |
|||
if other.__class__ != Language: |
|||
return 1 |
|||
return cmp(self.code, other.code) |
|||
|
|||
def __unicode__(self): |
|||
return ISO639_2[self.code] |
|||
|
|||
def __str__(self): |
|||
return self.__unicode__() |
|||
|
|||
def __repr__(self): |
|||
return "<Language '%s', code=%r>" % (unicode(self), self.code) |
|||
|
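if __name__ == "__main__":
    # Both the bibliographic ("fre") and terminologic ("fra") ISO 639-2
    # codes map to the same English name, so either one is accepted.
    print Language("fre")         # displays: French
    print repr(Language("fra"))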
@ -0,0 +1,144 @@ |
|||
import os, sys, time |
|||
import hachoir_core.config as config |
|||
from hachoir_core.i18n import _ |
|||
|
|||
class Log: |
|||
LOG_INFO = 0 |
|||
LOG_WARN = 1 |
|||
LOG_ERROR = 2 |
|||
|
|||
level_name = { |
|||
LOG_WARN: "[warn]", |
|||
LOG_ERROR: "[err!]", |
|||
LOG_INFO: "[info]" |
|||
} |
|||
|
|||
def __init__(self): |
|||
self.__buffer = {} |
|||
self.__file = None |
|||
self.use_print = True |
|||
self.use_buffer = False |
|||
self.on_new_message = None # Prototype: def func(level, prefix, text, context) |
|||
|
|||
def shutdown(self): |
|||
if self.__file: |
|||
self._writeIntoFile(_("Stop Hachoir")) |
|||
|
|||
def setFilename(self, filename, append=True): |
|||
""" |
|||
Use a file to store all messages. The |
|||
UTF-8 encoding will be used. Write an informative |
|||
message if the file can't be created. |
|||
|
|||
@param filename: C{L{string}} |
|||
""" |
|||
|
|||
# Check whether the file already exists |
|||
filename = os.path.expanduser(filename) |
|||
filename = os.path.realpath(filename) |
|||
append = os.access(filename, os.F_OK) |
|||
|
|||
# Create log file (or open it in append mode, if it already exists) |
|||
try: |
|||
import codecs |
|||
if append: |
|||
self.__file = codecs.open(filename, "a", "utf-8") |
|||
else: |
|||
self.__file = codecs.open(filename, "w", "utf-8") |
|||
self._writeIntoFile(_("Starting Hachoir")) |
|||
except IOError, err: |
|||
if err.errno == 2: |
|||
self.__file = None |
|||
self.info(_("[Log] setFilename(%s) fails: no such file") % filename) |
|||
else: |
|||
raise |
|||
|
|||
def _writeIntoFile(self, message): |
|||
timestamp = time.strftime("%Y-%m-%d %H:%M:%S") |
|||
self.__file.write(u"%s - %s\n" % (timestamp, message)) |
|||
self.__file.flush() |
|||
|
|||
def newMessage(self, level, text, ctxt=None): |
|||
""" |
|||
Write a new message: append it to the buffer, |
|||
display it on the screen (if needed), and write |
|||
it to the log file (if needed). |
|||
|
|||
@param level: Message level. |
|||
@type level: C{int} |
|||
@param text: Message content. |
|||
@type text: C{str} |
|||
@param ctxt: The caller instance. |
|||
""" |
|||
|
|||
if level < self.LOG_ERROR and config.quiet or \ |
|||
level <= self.LOG_INFO and not config.verbose: |
|||
return |
|||
if config.debug: |
|||
from hachoir_core.error import getBacktrace |
|||
backtrace = getBacktrace(None) |
|||
if backtrace: |
|||
text += "\n\n" + backtrace |
|||
|
|||
_text = text |
|||
if hasattr(ctxt, "_logger"): |
|||
_ctxt = ctxt._logger() |
|||
if _ctxt is not None: |
|||
text = "[%s] %s" % (_ctxt, text) |
|||
|
|||
# Add message to log buffer |
|||
if self.use_buffer: |
|||
if not self.__buffer.has_key(level): |
|||
self.__buffer[level] = [text] |
|||
else: |
|||
self.__buffer[level].append(text) |
|||
|
|||
# Add prefix |
|||
prefix = self.level_name.get(level, "[info]") |
|||
|
|||
# Display on stdout (if used) |
|||
if self.use_print: |
|||
sys.stdout.flush() |
|||
sys.stderr.write("%s %s\n" % (prefix, text)) |
|||
sys.stderr.flush() |
|||
|
|||
# Write into outfile (if used) |
|||
if self.__file: |
|||
self._writeIntoFile("%s %s" % (prefix, text)) |
|||
|
|||
# Use callback (if used) |
|||
if self.on_new_message: |
|||
self.on_new_message(level, prefix, _text, ctxt) |
|||
|
|||
def info(self, text): |
|||
""" |
|||
New informative message. |
|||
@type text: C{str} |
|||
""" |
|||
self.newMessage(Log.LOG_INFO, text) |
|||
|
|||
def warning(self, text): |
|||
""" |
|||
New warning message. |
|||
@type text: C{str} |
|||
""" |
|||
self.newMessage(Log.LOG_WARN, text) |
|||
|
|||
def error(self, text): |
|||
""" |
|||
New error message. |
|||
@type text: C{str} |
|||
""" |
|||
self.newMessage(Log.LOG_ERROR, text) |
|||
|
|||
log = Log() |
|||
|
|||
class Logger(object): |
|||
def _logger(self): |
|||
return "<%s>" % self.__class__.__name__ |
|||
def info(self, text): |
|||
log.newMessage(Log.LOG_INFO, text, self) |
|||
def warning(self, text): |
|||
log.newMessage(Log.LOG_WARN, text, self) |
|||
def error(self, text): |
|||
log.newMessage(Log.LOG_ERROR, text, self) |
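

# Usage sketch: mix Logger into any class to prefix its messages with the
# class name in the shared global log ("Downloader" is hypothetical).
if __name__ == "__main__":
    class Downloader(Logger):
        pass

    Downloader().warning("connection lost, retrying")
    # writes to stderr: [warn] [<Downloader>] connection lost, retrying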
@ -0,0 +1,99 @@ |
|||
import gc |
|||
|
|||
#---- Default implementation when resource is missing ---------------------- |
|||
PAGE_SIZE = 4096 |
|||
|
|||
def getMemoryLimit(): |
|||
""" |
|||
Get current memory limit in bytes. |
|||
|
|||
Return None on error. |
|||
""" |
|||
return None |
|||
|
|||
def setMemoryLimit(max_mem): |
|||
""" |
|||
Set memory limit in bytes. |
|||
Use value 'None' to disable memory limit. |
|||
|
|||
Return True if limit is set, False on error. |
|||
""" |
|||
return False |
|||
|
|||
def getMemorySize(): |
|||
""" |
|||
Read the current process memory size: size of available virtual memory. |
|||
This value is NOT the real memory usage. |
|||
|
|||
This function only works on Linux (use /proc/self/statm file). |
|||
""" |
|||
try: |
|||
statm = open('/proc/self/statm').readline().split() |
|||
except IOError: |
|||
return None |
|||
return int(statm[0]) * PAGE_SIZE |
|||
|
|||
def clearCaches(): |
|||
""" |
|||
Try to clear all caches: call gc.collect() (Python garbage collector). |
|||
""" |
|||
gc.collect() |
|||
#import re; re.purge() |
|||
|
|||
try: |
|||
#---- 'resource' implementation --------------------------------------------- |
|||
from resource import getpagesize, getrlimit, setrlimit, RLIMIT_AS |
|||
|
|||
PAGE_SIZE = getpagesize() |
|||
|
|||
def getMemoryLimit(): |
|||
try: |
|||
limit = getrlimit(RLIMIT_AS)[0] |
|||
if 0 < limit: |
|||
limit *= PAGE_SIZE |
|||
return limit |
|||
except ValueError: |
|||
return None |
|||
|
|||
def setMemoryLimit(max_mem): |
|||
if max_mem is None: |
|||
max_mem = -1 |
|||
try: |
|||
setrlimit(RLIMIT_AS, (max_mem, -1)) |
|||
return True |
|||
except ValueError: |
|||
return False |
|||
except ImportError: |
|||
pass |
|||
|
|||
def limitedMemory(limit, func, *args, **kw): |
|||
""" |
|||
Limit memory growth when calling func(*args, **kw): |
|||
restrict memory growth to 'limit' bytes. |
|||
|
|||
Use try/except MemoryError to catch the error. |
|||
""" |
|||
# First step: clear cache to gain memory |
|||
clearCaches() |
|||
|
|||
# Get total program size |
|||
max_rss = getMemorySize() |
|||
if max_rss is not None: |
|||
# Get old limit and then set our new memory limit |
|||
old_limit = getMemoryLimit() |
|||
limit = max_rss + limit |
|||
limited = setMemoryLimit(limit) |
|||
else: |
|||
limited = False |
|||
|
|||
try: |
|||
# Call function |
|||
return func(*args, **kw) |
|||
finally: |
|||
# and unset our memory limit |
|||
if limited: |
|||
setMemoryLimit(old_limit) |
|||
|
|||
# After calling the function: clear all caches |
|||
clearCaches() |
|||
|
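# Sketch of the intended calling pattern for limitedMemory(): wrap a
# memory-hungry call and catch MemoryError, as the docstring suggests.
# The 50 MB allowance and build_big_list() are illustrative values.
if __name__ == "__main__":
    def build_big_list(n):
        return list(xrange(n))

    try:
        data = limitedMemory(50 * 1024 * 1024, build_big_list, 10 ** 7)
    except MemoryError:
        print "aborted: the call needed more than the 50 MB allowance"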
@ -0,0 +1,31 @@ |
|||
from hotshot import Profile |
|||
from hotshot.stats import load as loadStats |
|||
from os import unlink |
|||
|
|||
def runProfiler(func, args=tuple(), kw={}, verbose=True, nb_func=25, sort_by=('cumulative', 'calls')): |
|||
profile_filename = "/tmp/profiler" |
|||
prof = Profile(profile_filename) |
|||
try: |
|||
if verbose: |
|||
print "[+] Run profiler" |
|||
result = prof.runcall(func, *args, **kw) |
|||
prof.close() |
|||
if verbose: |
|||
print "[+] Stop profiler" |
|||
print "[+] Process data..." |
|||
stat = loadStats(profile_filename) |
|||
if verbose: |
|||
print "[+] Strip..." |
|||
stat.strip_dirs() |
|||
if verbose: |
|||
print "[+] Sort data..." |
|||
stat.sort_stats(*sort_by) |
|||
if verbose: |
|||
print |
|||
print "[+] Display statistics" |
|||
print |
|||
stat.print_stats(nb_func) |
|||
return result |
|||
finally: |
|||
unlink(profile_filename) |
|||
|
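# Example run (illustrative): profile a small function and print the 25
# most expensive entries, sorted by cumulative time then call count.
if __name__ == "__main__":
    def busy():
        return sum(i * i for i in xrange(100000))

    runProfiler(busy)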
@ -0,0 +1,11 @@ |
|||
from hachoir_core.endian import BIG_ENDIAN, LITTLE_ENDIAN |
|||
from hachoir_core.stream.stream import StreamError |
|||
from hachoir_core.stream.input import ( |
|||
InputStreamError, |
|||
InputStream, InputIOStream, StringInputStream, |
|||
InputSubStream, InputFieldStream, |
|||
FragmentedStream, ConcatStream) |
|||
from hachoir_core.stream.input_helper import FileInputStream, guessStreamCharset |
|||
from hachoir_core.stream.output import (OutputStreamError, |
|||
FileOutputStream, StringOutputStream, OutputStream) |
|||
|
@ -0,0 +1,563 @@ |
|||
from hachoir_core.endian import BIG_ENDIAN, LITTLE_ENDIAN |
|||
from hachoir_core.error import info |
|||
from hachoir_core.log import Logger |
|||
from hachoir_core.bits import str2long |
|||
from hachoir_core.i18n import getTerminalCharset |
|||
from hachoir_core.tools import lowerBound |
|||
from hachoir_core.i18n import _ |
|||
from errno import ESPIPE |
|||
from weakref import ref as weakref_ref |
|||
from hachoir_core.stream import StreamError |
|||
|
|||
class InputStreamError(StreamError): |
|||
pass |
|||
|
|||
class ReadStreamError(InputStreamError): |
|||
def __init__(self, size, address, got=None): |
|||
self.size = size |
|||
self.address = address |
|||
self.got = got |
|||
if self.got is not None: |
|||
msg = _("Can't read %u bits at address %u (got %u bits)") % (self.size, self.address, self.got) |
|||
else: |
|||
msg = _("Can't read %u bits at address %u") % (self.size, self.address) |
|||
InputStreamError.__init__(self, msg) |
|||
|
|||
class NullStreamError(InputStreamError): |
|||
def __init__(self, source): |
|||
self.source = source |
|||
msg = _("Input size is nul (source='%s')!") % self.source |
|||
InputStreamError.__init__(self, msg) |
|||
|
|||
class FileFromInputStream: |
|||
_offset = 0 |
|||
_from_end = False |
|||
|
|||
def __init__(self, stream): |
|||
self.stream = stream |
|||
self._setSize(stream.askSize(self)) |
|||
|
|||
def _setSize(self, size): |
|||
if size is None: |
|||
self._size = size |
|||
elif size % 8: |
|||
raise InputStreamError("Invalid size") |
|||
else: |
|||
self._size = size // 8 |
|||
|
|||
def tell(self): |
|||
if self._from_end: |
|||
while self._size is None: |
|||
self.stream._feed(max(self.stream._current_size << 1, 1 << 16)) |
|||
self._from_end = False |
|||
self._offset += self._size |
|||
return self._offset |
|||
|
|||
def seek(self, pos, whence=0): |
|||
if whence == 0: |
|||
self._from_end = False |
|||
self._offset = pos |
|||
elif whence == 1: |
|||
self._offset += pos |
|||
elif whence == 2: |
|||
self._from_end = True |
|||
self._offset = pos |
|||
else: |
|||
raise ValueError("seek() second argument must be 0, 1 or 2") |
|||
|
|||
def read(self, size=None): |
|||
def read(address, size): |
|||
shift, data, missing = self.stream.read(8 * address, 8 * size) |
|||
if shift: |
|||
raise InputStreamError("TODO: handle non-byte-aligned data") |
|||
return data |
|||
if self._size or size is not None and not self._from_end: |
|||
# We don't want self.tell() to read anything |
|||
# and the size must be known if we read until the end. |
|||
pos = self.tell() |
|||
if size is None or None < self._size < pos + size: |
|||
size = self._size - pos |
|||
if size <= 0: |
|||
return '' |
|||
data = read(pos, size) |
|||
self._offset += len(data) |
|||
return data |
|||
elif self._from_end: |
|||
# TODO: not tested |
|||
max_size = - self._offset |
|||
if size is None or max_size < size: |
|||
size = max_size |
|||
if size <= 0: |
|||
return '' |
|||
data = '', '' |
|||
self._offset = max(0, self.stream._current_size // 8 + self._offset) |
|||
self._from_end = False |
|||
bs = max(max_size, 1 << 16) |
|||
while True: |
|||
d = read(self._offset, bs) |
|||
data = data[1], d |
|||
self._offset += len(d) |
|||
if self._size: |
|||
bs = self._size - self._offset |
|||
if not bs: |
|||
data = data[0] + data[1] |
|||
d = len(data) - max_size |
|||
return data[d:d+size] |
|||
else: |
|||
# TODO: not tested |
|||
data = [ ] |
|||
size = 1 << 16 |
|||
while True: |
|||
d = read(self._offset, size) |
|||
data.append(d) |
|||
self._offset += len(d) |
|||
if self._size: |
|||
size = self._size - self._offset |
|||
if not size: |
|||
return ''.join(data) |
|||
|
|||
|
|||
class InputStream(Logger): |
|||
_set_size = None |
|||
_current_size = 0 |
|||
|
|||
def __init__(self, source=None, size=None, packets=None, **args): |
|||
self.source = source |
|||
self._size = size # in bits |
|||
if size == 0: |
|||
raise NullStreamError(source) |
|||
self.tags = tuple(args.get("tags", tuple())) |
|||
self.packets = packets |
|||
|
|||
def askSize(self, client): |
|||
if self._size != self._current_size: |
|||
if self._set_size is None: |
|||
self._set_size = [] |
|||
self._set_size.append(weakref_ref(client)) |
|||
return self._size |
|||
|
|||
def _setSize(self, size=None): |
|||
assert self._size is None or self._current_size <= self._size |
|||
if self._size != self._current_size: |
|||
self._size = self._current_size |
|||
if not self._size: |
|||
raise NullStreamError(self.source) |
|||
if self._set_size: |
|||
for client in self._set_size: |
|||
client = client() |
|||
if client: |
|||
client._setSize(self._size) |
|||
del self._set_size |
|||
|
|||
size = property(lambda self: self._size, doc="Size of the stream in bits") |
|||
checked = property(lambda self: self._size == self._current_size) |
|||
|
|||
def sizeGe(self, size, const=False): |
|||
return self._current_size >= size or \ |
|||
not (None < self._size < size or const or self._feed(size)) |
|||
|
|||
def _feed(self, size): |
|||
return self.read(size-1,1)[2] |
|||
|
|||
def read(self, address, size): |
|||
""" |
|||
Read 'size' bits at position 'address' (in bits) |
|||
from the beginning of the stream. |
|||
""" |
|||
raise NotImplementedError |
|||
|
|||
def readBits(self, address, nbits, endian): |
|||
assert endian in (BIG_ENDIAN, LITTLE_ENDIAN) |
|||
|
|||
shift, data, missing = self.read(address, nbits) |
|||
if missing: |
|||
raise ReadStreamError(nbits, address) |
|||
value = str2long(data, endian) |
|||
if endian is BIG_ENDIAN: |
|||
value >>= len(data) * 8 - shift - nbits |
|||
else: |
|||
value >>= shift |
|||
return value & (1 << nbits) - 1 |
|||
|
|||
def readInteger(self, address, signed, nbits, endian): |
|||
""" Read an integer number """ |
|||
value = self.readBits(address, nbits, endian) |
|||
|
|||
# Signed number. Example with nbits=8: |
|||
# if 128 <= value: value -= 256 |
|||
if signed and (1 << (nbits-1)) <= value: |
|||
value -= (1 << nbits) |
|||
return value |
|||
|
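    # Worked example for readInteger() with nbits=8: the raw byte 0xFE
    # reads as 254 unsigned; signed, 128 <= 254 so 254 - 256 = -2 is
    # returned. The same rule applies for any nbits.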
|||
def readBytes(self, address, nb_bytes): |
|||
shift, data, missing = self.read(address, 8 * nb_bytes) |
|||
if shift: |
|||
raise InputStreamError("TODO: handle non-byte-aligned data") |
|||
if missing: |
|||
raise ReadStreamError(8 * nb_bytes, address) |
|||
return data |
|||
|
|||
def searchBytesLength(self, needle, include_needle, |
|||
start_address=0, end_address=None): |
|||
""" |
|||
If include_needle is True, add its length to the result. |
|||
Returns None if the needle can't be found. |
|||
""" |
|||
|
|||
pos = self.searchBytes(needle, start_address, end_address) |
|||
if pos is None: |
|||
return None |
|||
length = (pos - start_address) // 8 |
|||
if include_needle: |
|||
length += len(needle) |
|||
return length |
|||
|
|||
def searchBytes(self, needle, start_address=0, end_address=None): |
|||
""" |
|||
Search for some bytes in [start_address; end_address). Addresses must |
|||
be byte-aligned. Returns the address of the bytes if found, |
|||
None otherwise. |
|||
""" |
|||
if start_address % 8: |
|||
raise InputStreamError("Unable to search bytes with address with bit granularity") |
|||
length = len(needle) |
|||
size = max(3 * length, 4096) |
|||
buffer = '' |
|||
|
|||
if self._size and (end_address is None or self._size < end_address): |
|||
end_address = self._size |
|||
|
|||
while True: |
|||
if end_address is not None: |
|||
todo = (end_address - start_address) >> 3 |
|||
if todo < size: |
|||
if todo <= 0: |
|||
return None |
|||
size = todo |
|||
data = self.readBytes(start_address, size) |
|||
if end_address is None and self._size: |
|||
end_address = self._size |
|||
size = (end_address - start_address) >> 3 |
|||
assert size > 0 |
|||
data = data[:size] |
|||
start_address += 8 * size |
|||
buffer = buffer[len(buffer) - length + 1:] + data |
|||
found = buffer.find(needle) |
|||
if found >= 0: |
|||
return start_address + (found - len(buffer)) * 8 |
|||
|
|||
def file(self): |
|||
return FileFromInputStream(self) |
|||
|
|||
|
|||
class InputPipe(object): |
|||
""" |
|||
InputPipe makes input streams seekable by caching a certain |
|||
amount of data. The memory usage may be unlimited in worst cases. |
|||
A function (set_size) is called when the size of the stream is known. |
|||
|
|||
InputPipe sees the input stream as an array of blocks of |
|||
size = (2 ^ self.buffer_size) and self.buffers maps to this array. |
|||
It also maintains a circular ordered list of non-discarded blocks, |
|||
sorted by access time. |
|||
|
|||
Each element of self.buffers is an array of 3 elements: |
|||
* self.buffers[i][0] is the data. |
|||
len(self.buffers[i][0]) == 1 << self.buffer_size |
|||
(except at the end: the length may be smaller) |
|||
* self.buffers[i][1] is the index of a more recently used block |
|||
* self.buffers[i][2] is the opposite of self.buffers[i][1], |
|||
in order to have a double-linked list. |
|||
For any discarded block, self.buffers[i] = None |
|||
|
|||
self.last is the index of the most recently accessed block. |
|||
self.first is the first (= smallest index) non-discarded block. |
|||
|
|||
How InputPipe discards blocks: |
|||
* Just before returning from the read method. |
|||
* Only if there are more than self.buffer_nb_min blocks in memory. |
|||
* While self.buffers[self.first] is the least recently used block. |
|||
|
|||
Property: There is no hole in self.buffers, except at the beginning. |
|||
""" |
|||
buffer_nb_min = 256 |
|||
buffer_size = 16 |
|||
last = None |
|||
size = None |
|||
|
|||
def __init__(self, input, set_size=None): |
|||
self._input = input |
|||
self.first = self.address = 0 |
|||
self.buffers = [] |
|||
self.set_size = set_size |
|||
|
|||
current_size = property(lambda self: len(self.buffers) << self.buffer_size) |
|||
|
|||
def _append(self, data): |
|||
if self.last is None: |
|||
self.last = next = prev = 0 |
|||
else: |
|||
prev = self.last |
|||
last = self.buffers[prev] |
|||
next = last[1] |
|||
self.last = self.buffers[next][2] = last[1] = len(self.buffers) |
|||
self.buffers.append([ data, next, prev ]) |
|||
|
|||
def _get(self, index): |
|||
if index >= len(self.buffers): |
|||
return '' |
|||
buf = self.buffers[index] |
|||
if buf is None: |
|||
raise InputStreamError(_("Error: Buffers too small. Can't seek backward.")) |
|||
if self.last != index: |
|||
next = buf[1] |
|||
prev = buf[2] |
|||
self.buffers[next][2] = prev |
|||
self.buffers[prev][1] = next |
|||
first = self.buffers[self.last][1] |
|||
buf[1] = first |
|||
buf[2] = self.last |
|||
self.buffers[first][2] = index |
|||
self.buffers[self.last][1] = index |
|||
self.last = index |
|||
return buf[0] |
|||
|
|||
def _flush(self): |
|||
lim = len(self.buffers) - self.buffer_nb_min |
|||
while self.first < lim: |
|||
buf = self.buffers[self.first] |
|||
if buf[2] != self.last: |
|||
break |
|||
info("Discarding buffer %u." % self.first) |
|||
self.buffers[self.last][1] = buf[1] |
|||
self.buffers[buf[1]][2] = self.last |
|||
self.buffers[self.first] = None |
|||
self.first += 1 |
|||
|
|||
def seek(self, address): |
|||
assert 0 <= address |
|||
self.address = address |
|||
|
|||
def read(self, size): |
|||
end = self.address + size |
|||
for i in xrange(len(self.buffers), (end >> self.buffer_size) + 1): |
|||
data = self._input.read(1 << self.buffer_size) |
|||
if len(data) < 1 << self.buffer_size: |
|||
self.size = (len(self.buffers) << self.buffer_size) + len(data) |
|||
if self.set_size: |
|||
self.set_size(self.size) |
|||
if data: |
|||
self._append(data) |
|||
break |
|||
self._append(data) |
|||
block, offset = divmod(self.address, 1 << self.buffer_size) |
|||
data = ''.join(self._get(index) |
|||
for index in xrange(block, (end - 1 >> self.buffer_size) + 1) |
|||
)[offset:offset+size] |
|||
self._flush() |
|||
self.address += len(data) |
|||
return data |
|||
|
|||
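# Usage sketch (added illustration, not part of the original module): wrap
# a pipe-like object that only supports read(); cached blocks make backward
# seeks possible as long as they are still in the LRU list.
import os
rfd, wfd = os.pipe()
os.write(wfd, "x" * 4096)
os.close(wfd)
pipe = InputPipe(os.fdopen(rfd, "rb"))
head = pipe.read(16)       # reads and caches block 0
pipe.seek(0)               # the backward seek is served from the cache
assert pipe.read(16) == head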
class InputIOStream(InputStream): |
|||
def __init__(self, input, size=None, **args): |
|||
if not hasattr(input, "seek"): |
|||
if size is None: |
|||
input = InputPipe(input, self._setSize) |
|||
else: |
|||
input = InputPipe(input) |
|||
elif size is None: |
|||
try: |
|||
input.seek(0, 2) |
|||
size = input.tell() * 8 |
|||
except IOError, err: |
|||
if err.errno == ESPIPE: |
|||
input = InputPipe(input, self._setSize) |
|||
else: |
|||
charset = getTerminalCharset() |
|||
errmsg = unicode(str(err), charset) |
|||
source = args.get("source", "<inputio:%r>" % input) |
|||
raise InputStreamError(_("Unable to get size of %s: %s") % (source, errmsg)) |
|||
self._input = input |
|||
InputStream.__init__(self, size=size, **args) |
|||
|
|||
def __current_size(self): |
|||
if self._size: |
|||
return self._size |
|||
if self._input.size: |
|||
return 8 * self._input.size |
|||
return 8 * self._input.current_size |
|||
_current_size = property(__current_size) |
|||
|
|||
def read(self, address, size): |
|||
assert size > 0 |
|||
_size = self._size |
|||
address, shift = divmod(address, 8) |
|||
self._input.seek(address) |
|||
size = (size + shift + 7) >> 3 |
|||
data = self._input.read(size) |
|||
got = len(data) |
|||
missing = size != got |
|||
if missing and _size == self._size: |
|||
raise ReadStreamError(8 * size, 8 * address, 8 * got) |
|||
return shift, data, missing |
|||
|
|||
def file(self): |
|||
if hasattr(self._input, "fileno"): |
|||
from os import dup, fdopen |
|||
new_fd = dup(self._input.fileno()) |
|||
new_file = fdopen(new_fd, "r") |
|||
new_file.seek(0) |
|||
return new_file |
|||
return InputStream.file(self) |
|||
|
|||
|
|||
class StringInputStream(InputStream): |
|||
def __init__(self, data, source="<string>", **args): |
|||
self.data = data |
|||
InputStream.__init__(self, source=source, size=8*len(data), **args) |
|||
self._current_size = self._size |
|||
|
|||
def read(self, address, size): |
|||
address, shift = divmod(address, 8) |
|||
size = (size + shift + 7) >> 3 |
|||
data = self.data[address:address+size] |
|||
got = len(data) |
|||
if got != size: |
|||
raise ReadStreamError(8 * size, 8 * address, 8 * got) |
|||
return shift, data, False |
|||
|
|||
|
|||
class InputSubStream(InputStream): |
|||
def __init__(self, stream, offset, size=None, source=None, **args): |
|||
if offset is None: |
|||
offset = 0 |
|||
if size is None and stream.size is not None: |
|||
size = stream.size - offset |
|||
if None < size <= 0: |
|||
raise ValueError("InputSubStream: offset is outside input stream") |
|||
self.stream = stream |
|||
self._offset = offset |
|||
if source is None: |
|||
source = "<substream input=%s offset=%s size=%s>" % (stream.source, offset, size) |
|||
InputStream.__init__(self, source=source, size=size, **args) |
|||
self.stream.askSize(self) |
|||
|
|||
_current_size = property(lambda self: min(self._size, max(0, self.stream._current_size - self._offset))) |
|||
|
|||
def read(self, address, size): |
|||
return self.stream.read(self._offset + address, size) |
|||
|
|||
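# Usage sketch (illustrative, assuming an existing InputStream named
# 'stream'): expose bits 64..191 of it as an independent 128-bit
# substream; addresses passed to read() become relative to that offset.
#
#     sub = InputSubStream(stream, offset=64, size=128)
#     shift, data, missing = sub.read(0, 128)   # maps to stream.read(64, 128)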
def InputFieldStream(field, **args): |
|||
if not field.parent: |
|||
return field.stream |
|||
stream = field.parent.stream |
|||
args["size"] = field.size |
|||
args.setdefault("source", stream.source + field.path) |
|||
return InputSubStream(stream, field.absolute_address, **args) |
|||
|
|||
|
|||
class FragmentedStream(InputStream): |
|||
def __init__(self, field, **args): |
|||
self.stream = field.parent.stream |
|||
data = field.getData() |
|||
self.fragments = [ (0, data.absolute_address, data.size) ] |
|||
self.next = field.next |
|||
args.setdefault("source", "%s%s" % (self.stream.source, field.path)) |
|||
InputStream.__init__(self, **args) |
|||
if not self.next: |
|||
self._current_size = data.size |
|||
self._setSize() |
|||
|
|||
def _feed(self, end): |
|||
if self._current_size < end: |
|||
if self.checked: |
|||
raise ReadStreamError(end - self._size, self._size) |
|||
a, fa, fs = self.fragments[-1] |
|||
while self.stream.sizeGe(fa + min(fs, end - a)): |
|||
a += fs |
|||
f = self.next |
|||
if a >= end: |
|||
self._current_size = end |
|||
if a == end and not f: |
|||
self._setSize() |
|||
return False |
|||
if f: |
|||
self.next = f.next |
|||
f = f.getData() |
|||
if not f: |
|||
self._current_size = a |
|||
self._setSize() |
|||
return True |
|||
fa = f.absolute_address |
|||
fs = f.size |
|||
self.fragments += [ (a, fa, fs) ] |
|||
self._current_size = a + max(0, self.stream.size - fa) |
|||
self._setSize() |
|||
return True |
|||
return False |
|||
|
|||
def read(self, address, size): |
|||
assert size > 0 |
|||
missing = self._feed(address + size) |
|||
if missing: |
|||
size = self._size - address |
|||
if size <= 0: |
|||
return 0, '', True |
|||
d = [] |
|||
i = lowerBound(self.fragments, lambda x: x[0] <= address) |
|||
a, fa, fs = self.fragments[i-1] |
|||
a -= address |
|||
fa -= a |
|||
fs += a |
|||
s = None |
|||
while True: |
|||
n = min(fs, size) |
|||
u, v, w = self.stream.read(fa, n) |
|||
assert not w |
|||
if s is None: |
|||
s = u |
|||
else: |
|||
assert not u |
|||
d += [ v ] |
|||
size -= n |
|||
if not size: |
|||
return s, ''.join(d), missing |
|||
a, fa, fs = self.fragments[i] |
|||
i += 1 |
|||
|
|||
|
|||
class ConcatStream(InputStream): |
|||
# TODO: concatenate any number of streams of any type
|||
def __init__(self, streams, **args): |
|||
if len(streams) > 2 or not streams[0].checked: |
|||
raise NotImplementedError |
|||
self.__size0 = streams[0].size |
|||
size1 = streams[1].askSize(self) |
|||
if size1 is not None: |
|||
args["size"] = self.__size0 + size1 |
|||
self.__streams = streams |
|||
InputStream.__init__(self, **args) |
|||
|
|||
_current_size = property(lambda self: self.__size0 + self.__streams[1]._current_size) |
|||
|
|||
def read(self, address, size): |
|||
_size = self._size |
|||
s = self.__size0 - address |
|||
shift, data, missing = None, '', False |
|||
if s > 0: |
|||
s = min(size, s) |
|||
shift, data, w = self.__streams[0].read(address, s) |
|||
assert not w |
|||
a, s = 0, size - s |
|||
else: |
|||
a, s = -s, size |
|||
if s: |
|||
u, v, missing = self.__streams[1].read(a, s) |
|||
if missing and _size == self._size: |
|||
raise ReadStreamError(s, a) |
|||
if shift is None: |
|||
shift = u |
|||
else: |
|||
assert not u |
|||
data += v |
|||
return shift, data, missing |
@ -0,0 +1,38 @@ |
|||
from hachoir_core.i18n import getTerminalCharset, guessBytesCharset, _ |
|||
from hachoir_core.stream import InputIOStream, InputSubStream, InputStreamError |
|||
|
|||
def FileInputStream(filename, real_filename=None, **args): |
|||
""" |
|||
Create an input stream from a file. filename must be unicode.
|||
|
|||
real_filename is an optional argument used to specify the real filename, |
|||
its type can be 'str' or 'unicode'. Use real_filename when you are |
|||
not able to convert filename to a real unicode string (i.e. you have to
|||
use unicode(name, 'replace') or unicode(name, 'ignore')). |
|||
""" |
|||
assert isinstance(filename, unicode) |
|||
if not real_filename: |
|||
real_filename = filename |
|||
try: |
|||
inputio = open(real_filename, 'rb') |
|||
except IOError, err: |
|||
charset = getTerminalCharset() |
|||
errmsg = unicode(str(err), charset) |
|||
raise InputStreamError(_("Unable to open file %s: %s") % (filename, errmsg)) |
|||
source = "file:" + filename |
|||
offset = args.pop("offset", 0) |
|||
size = args.pop("size", None) |
|||
if offset or size: |
|||
if size: |
|||
size = 8 * size |
|||
stream = InputIOStream(inputio, source=source, **args) |
|||
return InputSubStream(stream, 8 * offset, size, **args) |
|||
else: |
|||
args.setdefault("tags",[]).append(("filename", filename)) |
|||
return InputIOStream(inputio, source=source, **args) |
|||
|
|||
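# Usage sketch (the path below is hypothetical): open a file and expose
# only bytes 512..1535 of it as a substream. Note that FileInputStream
# converts the byte-based offset/size arguments to bit units internally.
#
#     stream = FileInputStream(u"/tmp/example.bin", offset=512, size=1024)
#     data = stream.readBytes(0, 16)   # first 16 bytes of the substream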
def guessStreamCharset(stream, address, size, default=None): |
|||
size = min(size, 1024*8) |
|||
bytes = stream.readBytes(address, size//8) |
|||
return guessBytesCharset(bytes, default) |
|||
|
@ -0,0 +1,173 @@ |
|||
from cStringIO import StringIO |
|||
from hachoir_core.endian import BIG_ENDIAN |
|||
from hachoir_core.bits import long2raw |
|||
from hachoir_core.stream import StreamError |
|||
from errno import EBADF |
|||
|
|||
MAX_READ_NBYTES = 2 ** 16 |
|||
|
|||
class OutputStreamError(StreamError): |
|||
pass |
|||
|
|||
class OutputStream(object): |
|||
def __init__(self, output, filename=None): |
|||
self._output = output |
|||
self._filename = filename |
|||
self._bit_pos = 0 |
|||
self._byte = 0 |
|||
|
|||
def _getFilename(self): |
|||
return self._filename |
|||
filename = property(_getFilename) |
|||
|
|||
def writeBit(self, state, endian): |
|||
if self._bit_pos == 7: |
|||
self._bit_pos = 0 |
|||
if state: |
|||
if endian is BIG_ENDIAN: |
|||
self._byte |= 1 |
|||
else: |
|||
self._byte |= 128 |
|||
self._output.write(chr(self._byte)) |
|||
self._byte = 0 |
|||
else: |
|||
if state: |
|||
if endian is BIG_ENDIAN: |
|||
self._byte |= (1 << self._bit_pos) |
|||
else: |
|||
self._byte |= (1 << (7-self._bit_pos)) |
|||
self._bit_pos += 1 |
|||
|
|||
def writeBits(self, count, value, endian): |
|||
assert 0 <= value < 2**count |
|||
|
|||
# Feed bits to align to byte address |
|||
if self._bit_pos != 0: |
|||
n = 8 - self._bit_pos |
|||
if n <= count: |
|||
count -= n |
|||
if endian is BIG_ENDIAN: |
|||
self._byte |= (value >> count) |
|||
value &= ((1 << count) - 1) |
|||
else: |
|||
self._byte |= (value & ((1 << n)-1)) << self._bit_pos |
|||
value >>= n |
|||
self._output.write(chr(self._byte)) |
|||
self._bit_pos = 0 |
|||
self._byte = 0 |
|||
else: |
|||
if endian is BIG_ENDIAN: |
|||
self._byte |= (value << (8-self._bit_pos-count)) |
|||
else: |
|||
self._byte |= (value << self._bit_pos) |
|||
self._bit_pos += count |
|||
return |
|||
|
|||
# Write byte per byte |
|||
while 8 <= count: |
|||
count -= 8 |
|||
if endian is BIG_ENDIAN: |
|||
byte = (value >> count) |
|||
value &= ((1 << count) - 1) |
|||
else: |
|||
byte = (value & 0xFF) |
|||
value >>= 8 |
|||
self._output.write(chr(byte)) |
|||
|
|||
# Keep last bits |
|||
assert 0 <= count < 8 |
|||
self._bit_pos = count |
|||
if 0 < count: |
|||
assert 0 <= value < 2**count |
|||
if endian is BIG_ENDIAN: |
|||
self._byte = value << (8-count) |
|||
else: |
|||
self._byte = value |
|||
else: |
|||
assert value == 0 |
|||
self._byte = 0 |
|||
|
|||
def writeInteger(self, value, signed, size_byte, endian): |
|||
if signed: |
|||
value += 1 << (size_byte*8 - 1) |
|||
raw = long2raw(value, endian, size_byte) |
|||
self.writeBytes(raw) |
|||
|
|||
def copyBitsFrom(self, input, address, nb_bits, endian): |
|||
if (nb_bits % 8) == 0: |
|||
self.copyBytesFrom(input, address, nb_bits/8) |
|||
else: |
|||
# Arbitrary limit (we should use a buffer, like copyBytesFrom(),
# but that raises endianness problems)
|||
assert nb_bits <= 128 |
|||
data = input.readBits(address, nb_bits, endian) |
|||
self.writeBits(nb_bits, data, endian) |
|||
|
|||
def copyBytesFrom(self, input, address, nb_bytes): |
|||
if (address % 8): |
|||
raise OutputStreamError("Unable to copy bytes with address with bit granularity") |
|||
buffer_size = 1 << 12 # 8192 (8 KB) |
|||
while 0 < nb_bytes: |
|||
# Compute buffer size |
|||
if nb_bytes < buffer_size: |
|||
buffer_size = nb_bytes |
|||
|
|||
# Read |
|||
data = input.readBytes(address, buffer_size) |
|||
|
|||
# Write |
|||
self.writeBytes(data) |
|||
|
|||
# Move address |
|||
address += buffer_size*8 |
|||
nb_bytes -= buffer_size |
|||
|
|||
def writeBytes(self, bytes): |
|||
if self._bit_pos != 0: |
|||
raise NotImplementedError() |
|||
self._output.write(bytes) |
|||
|
|||
def readBytes(self, address, nbytes): |
|||
""" |
|||
Read bytes from the stream at specified address (in bits). |
|||
The address has to be a multiple of 8.
|||
nbytes has to be in 1..MAX_READ_NBYTES (64 KB).
|||
|
|||
This method is only supported for StringOutputStream (not on
|||
FileOutputStream). |
|||
|
|||
Return read bytes as byte string. |
|||
""" |
|||
assert (address % 8) == 0 |
|||
assert (1 <= nbytes <= MAX_READ_NBYTES) |
|||
self._output.flush() |
|||
oldpos = self._output.tell() |
|||
try: |
|||
self._output.seek(0) |
|||
try: |
|||
return self._output.read(nbytes) |
|||
except IOError, err: |
|||
if err[0] == EBADF: |
|||
raise OutputStreamError("Stream doesn't support read() operation") |
|||
finally: |
|||
self._output.seek(oldpos) |
|||
|
|||
def StringOutputStream(): |
|||
""" |
|||
Create an output stream into a string. |
|||
""" |
|||
data = StringIO() |
|||
return OutputStream(data) |
|||
|
|||
def FileOutputStream(filename, real_filename=None): |
|||
""" |
|||
Create an output stream into a file with the given name.
|||
|
|||
Filename has to be unicode, whereas the (optional) real_filename can be str.
|||
""" |
|||
assert isinstance(filename, unicode) |
|||
if not real_filename: |
|||
real_filename = filename |
|||
output = open(real_filename, 'wb') |
|||
return OutputStream(output, filename=filename) |
|||
|
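# Usage sketch (added illustration; BIG_ENDIAN is already imported at the
# top of this module): pack two bit fields into one byte with writeBits(),
# then read the byte back with readBytes(). With BIG_ENDIAN, the first
# bits written land in the most significant bits of the byte.
out = StringOutputStream()
out.writeBits(3, 5, BIG_ENDIAN)   # binary 101 -> bits 7..5
out.writeBits(5, 6, BIG_ENDIAN)   # binary 00110 -> bits 4..0
assert out.readBytes(0, 1) == "\xa6"   # 10100110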
@ -0,0 +1,5 @@ |
|||
from hachoir_core.error import HachoirError |
|||
|
|||
class StreamError(HachoirError): |
|||
pass |
|||
|
@ -0,0 +1,60 @@ |
|||
""" |
|||
Utilities used to convert a field to a classic human-readable representation of data.
|||
""" |
|||
|
|||
from hachoir_core.tools import ( |
|||
humanDuration, humanFilesize, alignValue, |
|||
durationWin64 as doDurationWin64, |
|||
deprecated) |
|||
from types import FunctionType, MethodType |
|||
from hachoir_core.field import Field |
|||
|
|||
def textHandler(field, handler): |
|||
assert isinstance(handler, (FunctionType, MethodType)) |
|||
assert issubclass(field.__class__, Field) |
|||
field.createDisplay = lambda: handler(field) |
|||
return field |
|||
|
|||
def displayHandler(field, handler): |
|||
assert isinstance(handler, (FunctionType, MethodType)) |
|||
assert issubclass(field.__class__, Field) |
|||
field.createDisplay = lambda: handler(field.value) |
|||
return field |
|||
|
|||
@deprecated("Use TimedeltaWin64 field type") |
|||
def durationWin64(field): |
|||
""" |
|||
Convert a Windows 64-bit duration to a string. The duration format is
a 64-bit number counting 100 ns units. See also timestampWin64().
|||
|
|||
>>> durationWin64(type("", (), dict(value=2146280000, size=64))) |
|||
u'3 min 34 sec 628 ms' |
|||
>>> durationWin64(type("", (), dict(value=(1 << 64)-1, size=64))) |
|||
u'58494 years 88 days 5 hours' |
|||
""" |
|||
assert hasattr(field, "value") and hasattr(field, "size") |
|||
assert field.size == 64 |
|||
delta = doDurationWin64(field.value) |
|||
return humanDuration(delta) |
|||
|
|||
def filesizeHandler(field): |
|||
""" |
|||
Format field value using humanFilesize() |
|||
""" |
|||
return displayHandler(field, humanFilesize) |
|||
|
|||
def hexadecimal(field): |
|||
""" |
|||
Convert an integer to hexadecimal in lower case. Returns unicode string. |
|||
|
|||
>>> hexadecimal(type("", (), dict(value=412, size=16))) |
|||
u'0x019c' |
|||
>>> hexadecimal(type("", (), dict(value=0, size=32))) |
|||
u'0x00000000' |
|||
""" |
|||
assert hasattr(field, "value") and hasattr(field, "size") |
|||
size = field.size |
|||
padding = alignValue(size, 4) // 4 |
|||
pattern = u"0x%%0%ux" % padding |
|||
return pattern % field.value |
|||
|
@ -0,0 +1,76 @@ |
|||
""" |
|||
limitedTime(): set a timeout in seconds when calling a function, |
|||
raise a Timeout error if the time limit is exceeded.
|||
""" |
|||
from math import ceil |
|||
|
|||
IMPLEMENTATION = None |
|||
|
|||
class Timeout(RuntimeError): |
|||
""" |
|||
Timeout error, inherits from RuntimeError |
|||
""" |
|||
pass |
|||
|
|||
def signalHandler(signum, frame): |
|||
""" |
|||
Signal handler to catch timeout signal: raise Timeout exception. |
|||
""" |
|||
raise Timeout("Timeout exceed!") |
|||
|
|||
def limitedTime(second, func, *args, **kw): |
|||
""" |
|||
Call func(*args, **kw) with a timeout of second seconds. |
|||
""" |
|||
return func(*args, **kw) |
|||
|
|||
def fixTimeout(second): |
|||
""" |
|||
Fix timeout value: convert to integer with a minimum of 1 second |
|||
""" |
|||
if isinstance(second, float): |
|||
second = int(ceil(second)) |
|||
assert isinstance(second, (int, long)) |
|||
return max(second, 1) |
|||
|
|||
if not IMPLEMENTATION: |
|||
try: |
|||
from signal import signal, alarm, SIGALRM |
|||
|
|||
# signal.alarm() implementation |
|||
def limitedTime(second, func, *args, **kw): |
|||
second = fixTimeout(second) |
|||
old_alarm = signal(SIGALRM, signalHandler) |
|||
try: |
|||
alarm(second) |
|||
return func(*args, **kw) |
|||
finally: |
|||
alarm(0) |
|||
signal(SIGALRM, old_alarm) |
|||
|
|||
IMPLEMENTATION = "signal.alarm()" |
|||
except ImportError: |
|||
pass |
|||
|
|||
if not IMPLEMENTATION: |
|||
try: |
|||
from signal import signal, SIGXCPU |
|||
from resource import getrlimit, setrlimit, RLIMIT_CPU |
|||
|
|||
# resource.setrlimit(RLIMIT_CPU) implementation |
|||
# "Bug": timeout is 'CPU' time so sleep() are not part of the timeout |
|||
def limitedTime(second, func, *args, **kw): |
|||
second = fixTimeout(second) |
|||
old_alarm = signal(SIGXCPU, signalHandler) |
|||
current = getrlimit(RLIMIT_CPU) |
|||
try: |
|||
setrlimit(RLIMIT_CPU, (second, current[1])) |
|||
return func(*args, **kw) |
|||
finally: |
|||
setrlimit(RLIMIT_CPU, current) |
|||
signal(SIGXCPU, old_alarm) |
|||
|
|||
IMPLEMENTATION = "resource.setrlimit(RLIMIT_CPU)" |
|||
except ImportError: |
|||
pass |
|||
|
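# Usage sketch (Unix-like platform assumed): limitedTime() aborts a call
# that runs too long. With the signal.alarm() implementation the limit is
# wall-clock time; with the setrlimit(RLIMIT_CPU) fallback it is CPU time,
# so a sleeping call would not be interrupted there.
import time
try:
    limitedTime(1, time.sleep, 10)
except Timeout:
    print "call aborted after about 1 second"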
@ -0,0 +1,582 @@ |
|||
# -*- coding: utf-8 -*- |
|||
|
|||
""" |
|||
Various utilities. |
|||
""" |
|||
|
|||
from hachoir_core.i18n import _, ngettext |
|||
import re |
|||
import stat |
|||
from datetime import datetime, timedelta, MAXYEAR |
|||
from warnings import warn |
|||
|
|||
def deprecated(comment=None): |
|||
""" |
|||
This is a decorator which can be used to mark functions |
|||
as deprecated. It will result in a warning being emitted
|||
when the function is used. |
|||
|
|||
Examples: :: |
|||
|
|||
@deprecated()
|||
def oldfunc(): ... |
|||
|
|||
@deprecated("use newfunc()!") |
|||
def oldfunc2(): ... |
|||
|
|||
Code from: http://code.activestate.com/recipes/391367/ |
|||
""" |
|||
def _deprecated(func): |
|||
def newFunc(*args, **kwargs): |
|||
message = "Call to deprecated function %s" % func.__name__ |
|||
if comment: |
|||
message += ": " + comment |
|||
warn(message, category=DeprecationWarning, stacklevel=2) |
|||
return func(*args, **kwargs) |
|||
newFunc.__name__ = func.__name__ |
|||
newFunc.__doc__ = func.__doc__ |
|||
newFunc.__dict__.update(func.__dict__) |
|||
return newFunc |
|||
return _deprecated |
|||
|
|||
def paddingSize(value, align): |
|||
""" |
|||
Compute size of a padding field. |
|||
|
|||
>>> paddingSize(31, 4) |
|||
1 |
|||
>>> paddingSize(32, 4) |
|||
0 |
|||
>>> paddingSize(33, 4) |
|||
3 |
|||
|
|||
Note: (value + paddingSize(value, align)) == alignValue(value, align) |
|||
""" |
|||
if value % align != 0: |
|||
return align - (value % align) |
|||
else: |
|||
return 0 |
|||
|
|||
def alignValue(value, align): |
|||
""" |
|||
Align a value to next 'align' multiple. |
|||
|
|||
>>> alignValue(31, 4) |
|||
32 |
|||
>>> alignValue(32, 4) |
|||
32 |
|||
>>> alignValue(33, 4) |
|||
36 |
|||
|
|||
Note: alignValue(value, align) == (value + paddingSize(value, align)) |
|||
""" |
|||
|
|||
if value % align != 0: |
|||
return value + align - (value % align) |
|||
else: |
|||
return value |
|||
|
|||
def timedelta2seconds(delta): |
|||
""" |
|||
Convert a datetime.timedelta() object to a number of seconds
(floating point number).
|||
|
|||
>>> timedelta2seconds(timedelta(seconds=2, microseconds=40000)) |
|||
2.04 |
|||
>>> timedelta2seconds(timedelta(minutes=1, milliseconds=250)) |
|||
60.25 |
|||
""" |
|||
return delta.microseconds / 1000000.0 \ |
|||
+ delta.seconds + delta.days * 60*60*24 |
|||
|
|||
def humanDurationNanosec(nsec): |
|||
""" |
|||
Convert a duration in nanoseconds to a human natural representation.
Returns a unicode string.
|||
|
|||
>>> humanDurationNanosec(60417893) |
|||
u'60.42 ms' |
|||
""" |
|||
|
|||
# Nanoseconds
|||
if nsec < 1000: |
|||
return u"%u nsec" % nsec |
|||
|
|||
# Microseconds
|||
usec, nsec = divmod(nsec, 1000) |
|||
if usec < 1000: |
|||
return u"%.2f usec" % (usec+float(nsec)/1000) |
|||
|
|||
# Milliseconds
|||
msec, usec = divmod(usec, 1000) |
|||
if msec < 1000: |
|||
return u"%.2f ms" % (msec + float(usec)/1000) |
|||
return humanDuration(msec) |
|||
|
|||
def humanDuration(delta): |
|||
""" |
|||
Convert a duration in milliseconds to a human natural representation.
Returns a unicode string.
|||
|
|||
>>> humanDuration(0) |
|||
u'0 ms' |
|||
>>> humanDuration(213) |
|||
u'213 ms' |
|||
>>> humanDuration(4213) |
|||
u'4 sec 213 ms' |
|||
>>> humanDuration(6402309) |
|||
u'1 hour 46 min 42 sec' |
|||
""" |
|||
if not isinstance(delta, timedelta): |
|||
delta = timedelta(microseconds=delta*1000) |
|||
|
|||
# Milliseconds |
|||
text = [] |
|||
if 1000 <= delta.microseconds: |
|||
text.append(u"%u ms" % (delta.microseconds//1000)) |
|||
|
|||
# Seconds |
|||
minutes, seconds = divmod(delta.seconds, 60) |
|||
hours, minutes = divmod(minutes, 60) |
|||
if seconds: |
|||
text.append(u"%u sec" % seconds) |
|||
if minutes: |
|||
text.append(u"%u min" % minutes) |
|||
if hours: |
|||
text.append(ngettext("%u hour", "%u hours", hours) % hours) |
|||
|
|||
# Days |
|||
years, days = divmod(delta.days, 365) |
|||
if days: |
|||
text.append(ngettext("%u day", "%u days", days) % days) |
|||
if years: |
|||
text.append(ngettext("%u year", "%u years", years) % years) |
|||
if 3 < len(text): |
|||
text = text[-3:] |
|||
elif not text: |
|||
return u"0 ms" |
|||
return u" ".join(reversed(text)) |
|||
|
|||
def humanFilesize(size): |
|||
""" |
|||
Convert a file size in bytes to a human natural representation.
It uses the values: 1 KB is 1024 bytes, 1 MB is 1024 KB, etc.
The result is a unicode string.
|||
|
|||
>>> humanFilesize(1) |
|||
u'1 byte' |
|||
>>> humanFilesize(790) |
|||
u'790 bytes' |
|||
>>> humanFilesize(256960) |
|||
u'250.9 KB' |
|||
""" |
|||
if size < 10000: |
|||
return ngettext("%u byte", "%u bytes", size) % size |
|||
units = [_("KB"), _("MB"), _("GB"), _("TB")] |
|||
size = float(size) |
|||
divisor = 1024 |
|||
for unit in units: |
|||
size = size / divisor |
|||
if size < divisor: |
|||
return "%.1f %s" % (size, unit) |
|||
return "%u %s" % (size, unit) |
|||
|
|||
def humanBitSize(size): |
|||
""" |
|||
Convert a size in bits to a classic human representation.
It uses the values: 1 Kbit is 1000 bits, 1 Mbit is 1000 Kbit, etc.
The result is a unicode string.
|||
|
|||
>>> humanBitSize(1) |
|||
u'1 bit' |
|||
>>> humanBitSize(790) |
|||
u'790 bits' |
|||
>>> humanBitSize(256960) |
|||
u'257.0 Kbit' |
|||
""" |
|||
divisor = 1000 |
|||
if size < divisor: |
|||
return ngettext("%u bit", "%u bits", size) % size |
|||
units = [u"Kbit", u"Mbit", u"Gbit", u"Tbit"] |
|||
size = float(size) |
|||
for unit in units: |
|||
size = size / divisor |
|||
if size < divisor: |
|||
return "%.1f %s" % (size, unit) |
|||
return u"%u %s" % (size, unit) |
|||
|
|||
def humanBitRate(size): |
|||
""" |
|||
Convert a bit rate to a classic human representation. It uses humanBitSize()
to convert the size into a human representation. The result is a unicode string.
|||
|
|||
>>> humanBitRate(790) |
|||
u'790 bits/sec' |
|||
>>> humanBitRate(256960) |
|||
u'257.0 Kbit/sec' |
|||
""" |
|||
return "".join((humanBitSize(size), "/sec")) |
|||
|
|||
def humanFrequency(hertz): |
|||
""" |
|||
Convert a frequency in hertz to a classic human representation.
It uses the values: 1 kHz is 1000 Hz, 1 MHz is 1000 kHz, etc.
The result is a unicode string.
|||
|
|||
>>> humanFrequency(790) |
|||
u'790 Hz' |
|||
>>> humanFrequency(629469) |
|||
u'629.5 kHz' |
|||
""" |
|||
divisor = 1000 |
|||
if hertz < divisor: |
|||
return u"%u Hz" % hertz |
|||
units = [u"kHz", u"MHz", u"GHz", u"THz"] |
|||
hertz = float(hertz) |
|||
for unit in units: |
|||
hertz = hertz / divisor |
|||
if hertz < divisor: |
|||
return u"%.1f %s" % (hertz, unit) |
|||
return u"%s %s" % (hertz, unit) |
|||
|
|||
regex_control_code = re.compile(r"([\x00-\x1f\x7f])") |
|||
controlchars = tuple({ |
|||
# Don't use "\0", because "\0"+"0"+"1" = "\001" = "\1" (1 character) |
|||
# Same reason not to use octal syntax ("\1")
|||
ord("\n"): r"\n", |
|||
ord("\r"): r"\r", |
|||
ord("\t"): r"\t", |
|||
ord("\a"): r"\a", |
|||
ord("\b"): r"\b", |
|||
}.get(code, '\\x%02x' % code) |
|||
for code in xrange(128) |
|||
) |
|||
|
|||
def makePrintable(data, charset, quote=None, to_unicode=False, smart=True): |
|||
r""" |
|||
Prepare a string to make it printable in the specified charset. |
|||
It escapes control characters. Characters with code bigger than 127 |
|||
are escaped if data type is 'str' or if charset is "ASCII". |
|||
|
|||
Examples with Unicode: |
|||
>>> aged = unicode("âgé", "UTF-8") |
|||
>>> repr(aged) # text type is 'unicode' |
|||
"u'\\xe2g\\xe9'" |
|||
>>> makePrintable("abc\0", "UTF-8") |
|||
'abc\\0' |
|||
>>> makePrintable(aged, "latin1") |
|||
'\xe2g\xe9' |
|||
>>> makePrintable(aged, "latin1", quote='"') |
|||
'"\xe2g\xe9"' |
|||
|
|||
Examples with string encoded in latin1: |
|||
>>> aged_latin = unicode("âgé", "UTF-8").encode("latin1") |
|||
>>> repr(aged_latin) # text type is 'str' |
|||
"'\\xe2g\\xe9'" |
|||
>>> makePrintable(aged_latin, "latin1") |
|||
'\\xe2g\\xe9' |
|||
>>> makePrintable("", "latin1") |
|||
'' |
|||
>>> makePrintable("a", "latin1", quote='"') |
|||
'"a"' |
|||
>>> makePrintable("", "latin1", quote='"') |
|||
'(empty)' |
|||
>>> makePrintable("abc", "latin1", quote="'") |
|||
"'abc'" |
|||
|
|||
Control codes: |
|||
>>> makePrintable("\0\x03\x0a\x10 \x7f", "latin1") |
|||
'\\0\\3\\n\\x10 \\x7f' |
|||
|
|||
Quote character may also be escaped (only ' and "): |
|||
>>> print makePrintable("a\"b", "latin-1", quote='"') |
|||
"a\"b" |
|||
>>> print makePrintable("a\"b", "latin-1", quote="'") |
|||
'a"b' |
|||
>>> print makePrintable("a'b", "latin-1", quote="'") |
|||
'a\'b' |
|||
""" |
|||
|
|||
if data: |
|||
if not isinstance(data, unicode): |
|||
data = unicode(data, "ISO-8859-1") |
|||
charset = "ASCII" |
|||
data = regex_control_code.sub( |
|||
lambda regs: controlchars[ord(regs.group(1))], data) |
|||
if quote: |
|||
if quote in "\"'": |
|||
data = data.replace(quote, '\\' + quote) |
|||
data = ''.join((quote, data, quote)) |
|||
elif quote: |
|||
data = "(empty)" |
|||
data = data.encode(charset, "backslashreplace") |
|||
if smart: |
|||
# Replace \x00\x01 by \0\1 |
|||
data = re.sub(r"\\x0([0-7])(?=[^0-7]|$)", r"\\\1", data) |
|||
if to_unicode: |
|||
data = unicode(data, charset) |
|||
return data |
|||
|
|||
def makeUnicode(text): |
|||
r""" |
|||
Convert text to a printable Unicode string. For byte strings (type 'str'),
use the charset ISO-8859-1 for the conversion to Unicode.
|||
|
|||
>>> makeUnicode(u'abc\0d') |
|||
u'abc\\0d' |
|||
>>> makeUnicode('a\xe9') |
|||
u'a\xe9' |
|||
""" |
|||
if isinstance(text, str): |
|||
text = unicode(text, "ISO-8859-1") |
|||
elif not isinstance(text, unicode): |
|||
text = unicode(text) |
|||
text = regex_control_code.sub( |
|||
lambda regs: controlchars[ord(regs.group(1))], text) |
|||
text = re.sub(r"\\x0([0-7])(?=[^0-7]|$)", r"\\\1", text) |
|||
return text |
|||
|
|||
def binarySearch(seq, cmp_func): |
|||
""" |
|||
Search a value in a sequence using binary search. Returns index of the |
|||
value, or None if the value doesn't exist. |
|||
|
|||
'seq' has to be sorted in ascending order according to the
comparison function;
|||
|
|||
'cmp_func', prototype func(x), is the compare function: |
|||
- Return strictly positive value if we have to search forward ; |
|||
- Return strictly negative value if we have to search backward ; |
|||
- Otherwise (zero) we got the value. |
|||
|
|||
>>> # Search number 5 (search forward) |
|||
... binarySearch([0, 4, 5, 10], lambda x: 5-x) |
|||
2 |
|||
>>> # Backward search |
|||
... binarySearch([10, 5, 4, 0], lambda x: x-5) |
|||
1 |
|||
""" |
|||
lower = 0 |
|||
upper = len(seq) |
|||
while lower < upper: |
|||
index = (lower + upper) >> 1 |
|||
diff = cmp_func(seq[index]) |
|||
if diff < 0: |
|||
upper = index |
|||
elif diff > 0: |
|||
lower = index + 1 |
|||
else: |
|||
return index |
|||
return None |
|||
|
|||
def lowerBound(seq, cmp_func):
"""
Return the first index i such that cmp_func(seq[i]) is false, assuming
seq is partitioned so that the elements satisfying cmp_func() come first.
Returns len(seq) if cmp_func() holds for every element.
"""
|||
f = 0 |
|||
l = len(seq) |
|||
while l > 0: |
|||
h = l >> 1 |
|||
m = f + h |
|||
if cmp_func(seq[m]): |
|||
f = m |
|||
f += 1 |
|||
l -= h + 1 |
|||
else: |
|||
l = h |
|||
return f |
|||
|
|||
def humanUnixAttributes(mode): |
|||
""" |
|||
Convert Unix file attributes (or "file mode") to a unicode string.
|||
|
|||
Original source code: |
|||
http://cvs.savannah.gnu.org/viewcvs/coreutils/lib/filemode.c?root=coreutils |
|||
|
|||
>>> humanUnixAttributes(0644) |
|||
u'-rw-r--r-- (644)' |
|||
>>> humanUnixAttributes(02755) |
|||
u'-rwxr-sr-x (2755)' |
|||
""" |
|||
|
|||
def ftypelet(mode): |
|||
if stat.S_ISREG (mode) or not stat.S_IFMT(mode): |
|||
return '-' |
|||
if stat.S_ISBLK (mode): return 'b' |
|||
if stat.S_ISCHR (mode): return 'c' |
|||
if stat.S_ISDIR (mode): return 'd' |
|||
if stat.S_ISFIFO(mode): return 'p' |
|||
if stat.S_ISLNK (mode): return 'l' |
|||
if stat.S_ISSOCK(mode): return 's' |
|||
return '?' |
|||
|
|||
chars = [ ftypelet(mode), 'r', 'w', 'x', 'r', 'w', 'x', 'r', 'w', 'x' ] |
|||
for i in xrange(1, 10): |
|||
if not mode & 1 << 9 - i: |
|||
chars[i] = '-' |
|||
if mode & stat.S_ISUID: |
|||
if chars[3] != 'x': |
|||
chars[3] = 'S' |
|||
else: |
|||
chars[3] = 's' |
|||
if mode & stat.S_ISGID: |
|||
if chars[6] != 'x': |
|||
chars[6] = 'S' |
|||
else: |
|||
chars[6] = 's' |
|||
if mode & stat.S_ISVTX: |
|||
if chars[9] != 'x': |
|||
chars[9] = 'T' |
|||
else: |
|||
chars[9] = 't' |
|||
return u"%s (%o)" % (''.join(chars), mode) |
|||
|
|||
def createDict(data, index): |
|||
""" |
|||
Create a new dictionary from a dictionary of key=>values:
just keep value number 'index' from all values.
|||
|
|||
>>> data={10: ("dix", 100, "a"), 20: ("vingt", 200, "b")} |
|||
>>> createDict(data, 0) |
|||
{10: 'dix', 20: 'vingt'} |
|||
>>> createDict(data, 2) |
|||
{10: 'a', 20: 'b'} |
|||
""" |
|||
return dict( (key,values[index]) for key, values in data.iteritems() ) |
|||
|
|||
# Start of UNIX timestamp (Epoch): 1st January 1970 at 00:00 |
|||
UNIX_TIMESTAMP_T0 = datetime(1970, 1, 1) |
|||
|
|||
def timestampUNIX(value): |
|||
""" |
|||
Convert a UNIX (32-bit) timestamp to a datetime object. The timestamp value
is the number of seconds since the 1st January 1970 at 00:00. The maximum
value is 2147483647: 19 January 2038 at 03:14:07.
|||
|
|||
May raise ValueError for an invalid value: the value has to be in 0..2147483647.
|||
|
|||
>>> timestampUNIX(0) |
|||
datetime.datetime(1970, 1, 1, 0, 0) |
|||
>>> timestampUNIX(1154175644) |
|||
datetime.datetime(2006, 7, 29, 12, 20, 44) |
|||
>>> timestampUNIX(1154175644.37) |
|||
datetime.datetime(2006, 7, 29, 12, 20, 44, 370000) |
|||
>>> timestampUNIX(2147483647) |
|||
datetime.datetime(2038, 1, 19, 3, 14, 7) |
|||
""" |
|||
if not isinstance(value, (float, int, long)): |
|||
raise TypeError("timestampUNIX(): an integer or float is required") |
|||
if not(0 <= value <= 2147483647): |
|||
raise ValueError("timestampUNIX(): value have to be in 0..2147483647") |
|||
return UNIX_TIMESTAMP_T0 + timedelta(seconds=value) |
|||
|
|||
# Start of Macintosh timestamp: 1st January 1904 at 00:00 |
|||
MAC_TIMESTAMP_T0 = datetime(1904, 1, 1) |
|||
|
|||
def timestampMac32(value): |
|||
""" |
|||
Convert a Mac (32-bit) timestamp to a datetime object. The format is the
number of seconds since the 1st January 1904 (to 2040). Returns a unicode
error message for an out-of-range value.
|||
|
|||
>>> timestampMac32(0) |
|||
datetime.datetime(1904, 1, 1, 0, 0) |
|||
>>> timestampMac32(2843043290) |
|||
datetime.datetime(1994, 2, 2, 14, 14, 50) |
|||
""" |
|||
if not isinstance(value, (float, int, long)): |
|||
raise TypeError("an integer or float is required") |
|||
if not(0 <= value <= 4294967295): |
|||
return _("invalid Mac timestamp (%s)") % value |
|||
return MAC_TIMESTAMP_T0 + timedelta(seconds=value) |
|||
|
|||
def durationWin64(value): |
|||
""" |
|||
Convert a Windows 64-bit duration to a datetime.timedelta object. The
duration format is a 64-bit number counting 100 ns units. See also
timestampWin64().
|||
|
|||
>>> str(durationWin64(1072580000)) |
|||
'0:01:47.258000' |
|||
>>> str(durationWin64(2146280000)) |
|||
'0:03:34.628000' |
|||
""" |
|||
if not isinstance(value, (float, int, long)): |
|||
raise TypeError("an integer or float is required") |
|||
if value < 0: |
|||
raise ValueError("value have to be a positive or nul integer") |
|||
return timedelta(microseconds=value/10) |
|||
|
|||
# Start of 64-bit Windows timestamp: 1st January 1600 at 00:00 |
|||
WIN64_TIMESTAMP_T0 = datetime(1601, 1, 1, 0, 0, 0) |
|||
|
|||
def timestampWin64(value): |
|||
""" |
|||
Convert a Windows 64-bit timestamp to a datetime object. The timestamp
format is a 64-bit number which represents the number of 100 ns units
since the 1st January 1601 at 00:00.
See also durationWin64(). The maximum date is 28 May 60056.
|||
|
|||
>>> timestampWin64(0) |
|||
datetime.datetime(1601, 1, 1, 0, 0) |
|||
>>> timestampWin64(127840491566710000) |
|||
datetime.datetime(2006, 2, 10, 12, 45, 56, 671000) |
|||
""" |
|||
try: |
|||
return WIN64_TIMESTAMP_T0 + durationWin64(value) |
|||
except OverflowError: |
|||
raise ValueError(_("date newer than year %s (value=%s)") % (MAXYEAR, value)) |
|||
|
|||
# Start of 60-bit UUID timestamp: 15 October 1582 at 00:00 |
|||
UUID60_TIMESTAMP_T0 = datetime(1582, 10, 15, 0, 0, 0) |
|||
|
|||
def timestampUUID60(value): |
|||
""" |
|||
Convert a UUID 60-bit timestamp to a datetime object. The timestamp
format is a 60-bit number which represents the number of 100 ns units
since the 15 October 1582 at 00:00.
|||
|
|||
>>> timestampUUID60(0) |
|||
datetime.datetime(1582, 10, 15, 0, 0) |
|||
>>> timestampUUID60(130435676263032368) |
|||
datetime.datetime(1996, 2, 14, 5, 13, 46, 303236) |
|||
""" |
|||
if not isinstance(value, (float, int, long)): |
|||
raise TypeError("an integer or float is required") |
|||
if value < 0: |
|||
raise ValueError("value have to be a positive or nul integer") |
|||
try: |
|||
return UUID60_TIMESTAMP_T0 + timedelta(microseconds=value/10) |
|||
except OverflowError: |
|||
raise ValueError(_("timestampUUID60() overflow (value=%s)") % value) |
|||
|
|||
def humanDatetime(value, strip_microsecond=True): |
|||
""" |
|||
Convert a timestamp to a Unicode string: use ISO format with a space separator.
|||
|
|||
>>> humanDatetime( datetime(2006, 7, 29, 12, 20, 44) ) |
|||
u'2006-07-29 12:20:44' |
|||
>>> humanDatetime( datetime(2003, 6, 30, 16, 0, 5, 370000) ) |
|||
u'2003-06-30 16:00:05' |
|||
>>> humanDatetime( datetime(2003, 6, 30, 16, 0, 5, 370000), False ) |
|||
u'2003-06-30 16:00:05.370000' |
|||
""" |
|||
text = unicode(value.isoformat()) |
|||
text = text.replace('T', ' ') |
|||
if strip_microsecond and "." in text: |
|||
text = text.split(".")[0] |
|||
return text |
|||
|
|||
NEWLINES_REGEX = re.compile("\n+") |
|||
|
|||
def normalizeNewline(text): |
|||
r""" |
|||
Replace Windows and Mac newlines with Unix newlines. |
|||
Replace multiple consecutive newlines with one newline. |
|||
|
|||
>>> normalizeNewline('a\r\nb') |
|||
'a\nb' |
|||
>>> normalizeNewline('a\r\rb') |
|||
'a\nb' |
|||
>>> normalizeNewline('a\n\nb') |
|||
'a\nb' |
|||
""" |
|||
text = text.replace("\r\n", "\n") |
|||
text = text.replace("\r", "\n") |
|||
return NEWLINES_REGEX.sub("\n", text) |
|||
|
@ -0,0 +1,5 @@ |
|||
PACKAGE = "hachoir-core" |
|||
VERSION = "1.3.4" |
|||
WEBSITE = 'http://bitbucket.org/haypo/hachoir/wiki/hachoir-core' |
|||
LICENSE = 'GNU GPL v2' |
|||
|
@ -0,0 +1,15 @@ |
|||
from hachoir_metadata.version import VERSION as __version__ |
|||
from hachoir_metadata.metadata import extractMetadata |
|||
|
|||
# Just import the modules;
# each module uses the registerExtractor() method
|||
import hachoir_metadata.archive |
|||
import hachoir_metadata.audio |
|||
import hachoir_metadata.file_system |
|||
import hachoir_metadata.image |
|||
import hachoir_metadata.jpeg |
|||
import hachoir_metadata.misc |
|||
import hachoir_metadata.program |
|||
import hachoir_metadata.riff |
|||
import hachoir_metadata.video |
|||
|
@ -0,0 +1,166 @@ |
|||
from hachoir_metadata.metadata_item import QUALITY_BEST, QUALITY_FASTEST |
|||
from hachoir_metadata.safe import fault_tolerant, getValue |
|||
from hachoir_metadata.metadata import ( |
|||
RootMetadata, Metadata, MultipleMetadata, registerExtractor) |
|||
from hachoir_parser.archive import (Bzip2Parser, CabFile, GzipParser, |
|||
TarFile, ZipFile, MarFile) |
|||
from hachoir_core.tools import humanUnixAttributes |
|||
from hachoir_core.i18n import _ |
|||
|
|||
def maxNbFile(meta): |
|||
if meta.quality <= QUALITY_FASTEST: |
|||
return 0 |
|||
if QUALITY_BEST <= meta.quality: |
|||
return None |
|||
return 1 + int(10 * meta.quality) |
|||
|
|||
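# Illustration (assuming the quality scale runs from 0.0 to 1.0):
# QUALITY_FASTEST processes no per-file metadata, QUALITY_BEST removes the
# limit, and e.g. quality 0.5 processes 1 + int(10 * 0.5) = 6 files.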
def computeCompressionRate(meta): |
|||
""" |
|||
Compute the compression rate; sizes have to be in bytes.
|||
""" |
|||
if not meta.has("file_size") \ |
|||
or not meta.get("compr_size", 0): |
|||
return |
|||
file_size = meta.get("file_size") |
|||
if not file_size: |
|||
return |
|||
meta.compr_rate = float(file_size) / meta.get("compr_size") |
|||
|
|||
class Bzip2Metadata(RootMetadata): |
|||
def extract(self, zip): |
|||
if "file" in zip: |
|||
self.compr_size = zip["file"].size/8 |
|||
|
|||
class GzipMetadata(RootMetadata): |
|||
def extract(self, gzip): |
|||
self.useHeader(gzip) |
|||
computeCompressionRate(self) |
|||
|
|||
@fault_tolerant |
|||
def useHeader(self, gzip): |
|||
self.compression = gzip["compression"].display |
|||
if gzip["mtime"]: |
|||
self.last_modification = gzip["mtime"].value |
|||
self.os = gzip["os"].display |
|||
if gzip["has_filename"].value: |
|||
self.filename = getValue(gzip, "filename") |
|||
if gzip["has_comment"].value: |
|||
self.comment = getValue(gzip, "comment") |
|||
self.compr_size = gzip["file"].size/8 |
|||
self.file_size = gzip["size"].value |
|||
|
|||
class ZipMetadata(MultipleMetadata): |
|||
def extract(self, zip): |
|||
max_nb = maxNbFile(self) |
|||
for index, field in enumerate(zip.array("file")): |
|||
if max_nb is not None and max_nb <= index: |
|||
self.warning("ZIP archive contains many files, but only first %s files are processed" % max_nb) |
|||
break |
|||
self.processFile(field) |
|||
|
|||
@fault_tolerant |
|||
def processFile(self, field): |
|||
meta = Metadata(self) |
|||
meta.filename = field["filename"].value |
|||
meta.creation_date = field["last_mod"].value |
|||
meta.compression = field["compression"].display |
|||
if "data_desc" in field: |
|||
meta.file_size = field["data_desc/file_uncompressed_size"].value |
|||
if field["data_desc/file_compressed_size"].value: |
|||
meta.compr_size = field["data_desc/file_compressed_size"].value |
|||
else: |
|||
meta.file_size = field["uncompressed_size"].value |
|||
if field["compressed_size"].value: |
|||
meta.compr_size = field["compressed_size"].value |
|||
computeCompressionRate(meta) |
|||
self.addGroup(field.name, meta, "File \"%s\"" % meta.get('filename')) |
|||
|
|||
class TarMetadata(MultipleMetadata): |
|||
def extract(self, tar): |
|||
max_nb = maxNbFile(self) |
|||
for index, field in enumerate(tar.array("file")): |
|||
if max_nb is not None and max_nb <= index: |
|||
self.warning("TAR archive contains many files, but only first %s files are processed" % max_nb) |
|||
break |
|||
meta = Metadata(self) |
|||
self.extractFile(field, meta) |
|||
if meta.has("filename"): |
|||
title = _('File "%s"') % meta.getText('filename') |
|||
else: |
|||
title = _("File") |
|||
self.addGroup(field.name, meta, title) |
|||
|
|||
@fault_tolerant |
|||
def extractFile(self, field, meta): |
|||
meta.filename = field["name"].value |
|||
meta.file_attr = humanUnixAttributes(field.getOctal("mode")) |
|||
meta.file_size = field.getOctal("size") |
|||
try: |
|||
if field.getOctal("mtime"): |
|||
meta.last_modification = field.getDatetime() |
|||
except ValueError: |
|||
pass |
|||
meta.file_type = field["type"].display |
|||
meta.author = "%s (uid=%s), group %s (gid=%s)" %\ |
|||
(field["uname"].value, field.getOctal("uid"), |
|||
field["gname"].value, field.getOctal("gid")) |
|||
|
|||
|
|||
class CabMetadata(MultipleMetadata): |
|||
def extract(self, cab): |
|||
if "folder[0]" in cab: |
|||
self.useFolder(cab["folder[0]"]) |
|||
self.format_version = "Microsoft Cabinet version %s" % cab["cab_version"].display |
|||
self.comment = "%s folders, %s files" % ( |
|||
cab["nb_folder"].value, cab["nb_files"].value) |
|||
max_nb = maxNbFile(self) |
|||
for index, field in enumerate(cab.array("file")): |
|||
if max_nb is not None and max_nb <= index: |
|||
self.warning("CAB archive contains many files, but only first %s files are processed" % max_nb) |
|||
break |
|||
self.useFile(field) |
|||
|
|||
@fault_tolerant |
|||
def useFolder(self, folder): |
|||
compr = folder["compr_method"].display |
|||
if folder["compr_method"].value != 0: |
|||
compr += " (level %u)" % folder["compr_level"].value |
|||
self.compression = compr |
|||
|
|||
@fault_tolerant |
|||
def useFile(self, field): |
|||
meta = Metadata(self) |
|||
meta.filename = field["filename"].value |
|||
meta.file_size = field["filesize"].value |
|||
meta.creation_date = field["timestamp"].value |
|||
attr = field["attributes"].value |
|||
if attr != "(none)": |
|||
meta.file_attr = attr |
|||
if meta.has("filename"): |
|||
title = _("File \"%s\"") % meta.getText('filename') |
|||
else: |
|||
title = _("File") |
|||
self.addGroup(field.name, meta, title) |
|||
|
|||
class MarMetadata(MultipleMetadata): |
|||
def extract(self, mar): |
|||
self.comment = "Contains %s files" % mar["nb_file"].value |
|||
self.format_version = "Microsoft Archive version %s" % mar["version"].value |
|||
max_nb = maxNbFile(self) |
|||
for index, field in enumerate(mar.array("file")): |
|||
if max_nb is not None and max_nb <= index: |
|||
self.warning("MAR archive contains many files, but only first %s files are processed" % max_nb) |
|||
break |
|||
meta = Metadata(self) |
|||
meta.filename = field["filename"].value |
|||
meta.compression = "None" |
|||
meta.file_size = field["filesize"].value |
|||
self.addGroup(field.name, meta, "File \"%s\"" % meta.getText('filename')) |
|||
|
|||
registerExtractor(CabFile, CabMetadata) |
|||
registerExtractor(GzipParser, GzipMetadata) |
|||
registerExtractor(Bzip2Parser, Bzip2Metadata) |
|||
registerExtractor(TarFile, TarMetadata) |
|||
registerExtractor(ZipFile, ZipMetadata) |
|||
registerExtractor(MarFile, MarMetadata) |
|||
|
@ -0,0 +1,406 @@ |
|||
from hachoir_metadata.metadata import (registerExtractor, |
|||
Metadata, RootMetadata, MultipleMetadata) |
|||
from hachoir_parser.audio import AuFile, MpegAudioFile, RealAudioFile, AiffFile, FlacParser |
|||
from hachoir_parser.container import OggFile, RealMediaFile |
|||
from hachoir_core.i18n import _ |
|||
from hachoir_core.tools import makePrintable, timedelta2seconds, humanBitRate |
|||
from datetime import timedelta |
|||
from hachoir_metadata.metadata_item import QUALITY_FAST, QUALITY_NORMAL, QUALITY_BEST |
|||
from hachoir_metadata.safe import fault_tolerant, getValue |
|||
|
|||
def computeComprRate(meta, size): |
|||
if not meta.has("duration") \ |
|||
or not meta.has("sample_rate") \ |
|||
or not meta.has("bits_per_sample") \ |
|||
or not meta.has("nb_channel") \ |
|||
or not size: |
|||
return |
|||
orig_size = timedelta2seconds(meta.get("duration")) * meta.get('sample_rate') * meta.get('bits_per_sample') * meta.get('nb_channel') |
|||
meta.compr_rate = float(orig_size) / size |
|||
|
|||
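# Example (added illustration): 60 seconds of CD audio is
# 60 * 44100 * 16 * 2 = 84672000 bits uncompressed; if the encoded stream
# is 5760000 bits (96 kbit/s), computeComprRate() yields about 14.7.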
def computeBitRate(meta): |
|||
if not meta.has("bits_per_sample") \ |
|||
or not meta.has("nb_channel") \ |
|||
or not meta.has("sample_rate"): |
|||
return |
|||
meta.bit_rate = meta.get('bits_per_sample') * meta.get('nb_channel') * meta.get('sample_rate') |
|||
|
|||
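# Example (added illustration): CD audio gives
# 16 bits/sample * 2 channels * 44100 Hz = 1411200 bit/s.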
VORBIS_KEY_TO_ATTR = { |
|||
"ARTIST": "artist", |
|||
"ALBUM": "album", |
|||
"TRACKNUMBER": "track_number", |
|||
"TRACKTOTAL": "track_total", |
|||
"ENCODER": "producer", |
|||
"TITLE": "title", |
|||
"LOCATION": "location", |
|||
"DATE": "creation_date", |
|||
"ORGANIZATION": "organization", |
|||
"GENRE": "music_genre", |
|||
"": "comment", |
|||
"COMPOSER": "music_composer", |
|||
"DESCRIPTION": "comment", |
|||
"COMMENT": "comment", |
|||
"WWW": "url", |
|||
"WOAF": "url", |
|||
"LICENSE": "copyright", |
|||
} |
|||
|
|||
@fault_tolerant |
|||
def readVorbisComment(metadata, comment): |
|||
metadata.producer = getValue(comment, "vendor") |
|||
for item in comment.array("metadata"): |
|||
if "=" in item.value: |
|||
key, value = item.value.split("=", 1) |
|||
key = key.upper() |
|||
if key in VORBIS_KEY_TO_ATTR: |
|||
key = VORBIS_KEY_TO_ATTR[key] |
|||
setattr(metadata, key, value) |
|||
elif value: |
|||
metadata.warning("Skip Vorbis comment %s: %s" % (key, value)) |
|||
|
|||
class OggMetadata(MultipleMetadata): |
|||
def extract(self, ogg): |
|||
granule_quotient = None |
|||
for index, page in enumerate(ogg.array("page")): |
|||
if "segments" not in page: |
|||
continue |
|||
page = page["segments"] |
|||
if "vorbis_hdr" in page: |
|||
meta = Metadata(self) |
|||
self.vorbisHeader(page["vorbis_hdr"], meta) |
|||
self.addGroup("audio[]", meta, "Audio") |
|||
if not granule_quotient and meta.has("sample_rate"): |
|||
granule_quotient = meta.get('sample_rate') |
|||
if "theora_hdr" in page: |
|||
meta = Metadata(self) |
|||
self.theoraHeader(page["theora_hdr"], meta) |
|||
self.addGroup("video[]", meta, "Video") |
|||
if "video_hdr" in page: |
|||
meta = Metadata(self) |
|||
self.videoHeader(page["video_hdr"], meta) |
|||
self.addGroup("video[]", meta, "Video") |
|||
if not granule_quotient and meta.has("frame_rate"): |
|||
granule_quotient = meta.get('frame_rate') |
|||
if "comment" in page: |
|||
readVorbisComment(self, page["comment"]) |
|||
if 3 <= index: |
|||
# Only process pages 0..3 |
|||
break |
|||
|
|||
# Compute duration |
|||
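# (the absolute granule position of the last page counts samples for a
# Vorbis stream, or frames for a video stream, so dividing it by the
# rate stored in granule_quotient gives the duration in seconds)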
if granule_quotient and QUALITY_NORMAL <= self.quality: |
|||
page = ogg.createLastPage() |
|||
if page and "abs_granule_pos" in page: |
|||
try: |
|||
self.duration = timedelta(seconds=float(page["abs_granule_pos"].value) / granule_quotient) |
|||
except OverflowError: |
|||
pass |
|||
|
|||
def videoHeader(self, header, meta): |
|||
meta.compression = header["fourcc"].display |
|||
meta.width = header["width"].value |
|||
meta.height = header["height"].value |
|||
meta.bits_per_pixel = header["bits_per_sample"].value |
|||
if header["time_unit"].value: |
|||
meta.frame_rate = 10000000.0 / header["time_unit"].value |
|||
|
|||
def theoraHeader(self, header, meta): |
|||
meta.compression = "Theora" |
|||
meta.format_version = "Theora version %u.%u (revision %u)" % (\ |
|||
header["version_major"].value, |
|||
header["version_minor"].value, |
|||
header["version_revision"].value) |
|||
meta.width = header["frame_width"].value |
|||
meta.height = header["frame_height"].value |
|||
if header["fps_den"].value: |
|||
meta.frame_rate = float(header["fps_num"].value) / header["fps_den"].value |
|||
if header["aspect_ratio_den"].value: |
|||
meta.aspect_ratio = float(header["aspect_ratio_num"].value) / header["aspect_ratio_den"].value |
|||
meta.pixel_format = header["pixel_format"].display |
|||
meta.comment = "Quality: %s" % header["quality"].value |
|||
|
|||
def vorbisHeader(self, header, meta): |
|||
meta.compression = u"Vorbis" |
|||
meta.sample_rate = header["audio_sample_rate"].value |
|||
meta.nb_channel = header["audio_channels"].value |
|||
meta.format_version = u"Vorbis version %s" % header["vorbis_version"].value |
|||
meta.bit_rate = header["bitrate_nominal"].value |
|||
|
|||
class AuMetadata(RootMetadata): |
|||
def extract(self, audio): |
|||
self.sample_rate = audio["sample_rate"].value |
|||
self.nb_channel = audio["channels"].value |
|||
self.compression = audio["codec"].display |
|||
if "info" in audio: |
|||
self.comment = audio["info"].value |
|||
self.bits_per_sample = audio.getBitsPerSample() |
|||
computeBitRate(self) |
|||
if "audio_data" in audio: |
|||
if self.has("bit_rate"): |
|||
self.duration = timedelta(seconds=float(audio["audio_data"].size) / self.get('bit_rate')) |
|||
computeComprRate(self, audio["audio_data"].size) |
|||
|
|||
class RealAudioMetadata(RootMetadata): |
|||
FOURCC_TO_BITRATE = { |
|||
u"28_8": 15200, # 28.8 kbit/sec (audio bit rate: 15.2 kbit/s) |
|||
u"14_4": 8000, # 14.4 kbit/sec |
|||
u"lpcJ": 8000, # 14.4 kbit/sec |
|||
} |
|||
|
|||
def extract(self, real): |
|||
version = real["version"].value |
|||
if "metadata" in real: |
|||
self.useMetadata(real["metadata"]) |
|||
self.useRoot(real) |
|||
self.format_version = "Real audio version %s" % version |
|||
if version == 3: |
|||
size = getValue(real, "data_size") |
|||
elif "filesize" in real and "headersize" in real: |
|||
size = (real["filesize"].value + 40) - (real["headersize"].value + 16) |
|||
else: |
|||
size = None |
|||
if size: |
|||
size *= 8 |
|||
if self.has("bit_rate"): |
|||
sec = float(size) / self.get('bit_rate') |
|||
self.duration = timedelta(seconds=sec) |
|||
computeComprRate(self, size) |
|||
|
|||
@fault_tolerant |
|||
def useMetadata(self, info): |
|||
self.title = info["title"].value |
|||
self.author = info["author"].value |
|||
self.copyright = info["copyright"].value |
|||
self.comment = info["comment"].value |
|||
|
|||
@fault_tolerant |
|||
def useRoot(self, real): |
|||
self.bits_per_sample = 16 # FIXME: Is that correct? |
|||
if real["version"].value != 3: |
|||
self.sample_rate = real["sample_rate"].value |
|||
self.nb_channel = real["channels"].value |
|||
else: |
|||
self.sample_rate = 8000 |
|||
self.nb_channel = 1 |
|||
fourcc = getValue(real, "FourCC") |
|||
if fourcc: |
|||
self.compression = fourcc |
|||
try: |
|||
self.bit_rate = self.FOURCC_TO_BITRATE[fourcc] |
|||
except LookupError: |
|||
pass |
|||
|
|||
class RealMediaMetadata(MultipleMetadata): |
|||
KEY_TO_ATTR = { |
|||
"generated by": "producer", |
|||
"creation date": "creation_date", |
|||
"modification date": "last_modification", |
|||
"description": "comment", |
|||
} |
|||
|
|||
def extract(self, media): |
|||
if "file_prop" in media: |
|||
self.useFileProp(media["file_prop"]) |
|||
if "content_desc" in media: |
|||
self.useContentDesc(media["content_desc"]) |
|||
for index, stream in enumerate(media.array("stream_prop")): |
|||
self.useStreamProp(stream, index) |
|||
|
|||
@fault_tolerant |
|||
def useFileInfoProp(self, prop): |
|||
key = prop["name"].value.lower() |
|||
value = prop["value"].value |
|||
if key in self.KEY_TO_ATTR: |
|||
setattr(self, self.KEY_TO_ATTR[key], value) |
|||
elif value: |
|||
self.warning("Skip %s: %s" % (prop["name"].value, value)) |
|||
|
|||
@fault_tolerant |
|||
def useFileProp(self, prop): |
|||
self.bit_rate = prop["avg_bit_rate"].value |
|||
self.duration = timedelta(milliseconds=prop["duration"].value) |
|||
|
|||
@fault_tolerant |
|||
def useContentDesc(self, content): |
|||
self.title = content["title"].value |
|||
self.author = content["author"].value |
|||
self.copyright = content["copyright"].value |
|||
self.comment = content["comment"].value |
|||
|
|||
@fault_tolerant |
|||
def useStreamProp(self, stream, index): |
|||
meta = Metadata(self) |
|||
meta.comment = "Start: %s" % stream["stream_start"].value |
|||
if getValue(stream, "mime_type") == "logical-fileinfo": |
|||
for prop in stream.array("file_info/prop"): |
|||
self.useFileInfoProp(prop) |
|||
else: |
|||
meta.bit_rate = stream["avg_bit_rate"].value |
|||
meta.duration = timedelta(milliseconds=stream["duration"].value) |
|||
meta.mime_type = getValue(stream, "mime_type") |
|||
meta.title = getValue(stream, "desc") |
|||
self.addGroup("stream[%u]" % index, meta, "Stream #%u" % (1+index)) |
|||
|
|||
class MpegAudioMetadata(RootMetadata): |
|||
TAG_TO_KEY = { |
|||
# ID3 version 2.2 |
|||
"TP1": "author", |
|||
"COM": "comment", |
|||
"TEN": "producer", |
|||
"TRK": "track_number", |
|||
"TAL": "album", |
|||
"TT2": "title", |
|||
"TYE": "creation_date", |
|||
"TCO": "music_genre", |
|||
|
|||
# ID3 version 2.3+ |
|||
"TPE1": "author", |
|||
"COMM": "comment", |
|||
"TENC": "producer", |
|||
"TRCK": "track_number", |
|||
"TALB": "album", |
|||
"TIT2": "title", |
|||
"TYER": "creation_date", |
|||
"WXXX": "url", |
|||
"TCON": "music_genre", |
|||
"TLAN": "language", |
|||
"TCOP": "copyright", |
|||
"TDAT": "creation_date", |
|||
"TRDA": "creation_date", |
|||
"TORY": "creation_date", |
|||
"TIT1": "title", |
|||
} |
|||
|
|||
def processID3v2(self, field): |
|||
# Read value |
|||
if "content" not in field: |
|||
return |
|||
content = field["content"] |
|||
if "text" not in content: |
|||
return |
|||
if "title" in content and content["title"].value: |
|||
value = "%s: %s" % (content["title"].value, content["text"].value) |
|||
else: |
|||
value = content["text"].value |
|||
|
|||
# Known tag? |
|||
tag = field["tag"].value |
|||
if tag not in self.TAG_TO_KEY: |
|||
if tag: |
|||
if isinstance(tag, str): |
|||
tag = makePrintable(tag, "ISO-8859-1", to_unicode=True) |
|||
self.warning("Skip ID3v2 tag %s: %s" % (tag, value)) |
|||
return |
|||
key = self.TAG_TO_KEY[tag] |
|||
setattr(self, key, value) |
|||
|
|||
def readID3v2(self, id3): |
|||
for field in id3: |
|||
if field.is_field_set and "tag" in field: |
|||
self.processID3v2(field) |
|||
|
|||
def extract(self, mp3): |
|||
if "/frames/frame[0]" in mp3: |
|||
frame = mp3["/frames/frame[0]"] |
|||
self.nb_channel = (frame.getNbChannel(), frame["channel_mode"].display) |
|||
self.format_version = u"MPEG version %s layer %s" % \ |
|||
(frame["version"].display, frame["layer"].display) |
|||
self.sample_rate = frame.getSampleRate() |
|||
self.bits_per_sample = 16 |
|||
if mp3["frames"].looksConstantBitRate(): |
|||
self.computeBitrate(frame) |
|||
else: |
|||
self.computeVariableBitrate(mp3) |
|||
if "id3v1" in mp3: |
|||
id3 = mp3["id3v1"] |
|||
self.comment = id3["comment"].value |
|||
self.author = id3["author"].value |
|||
self.title = id3["song"].value |
|||
self.album = id3["album"].value |
|||
if id3["year"].value != "0": |
|||
self.creation_date = id3["year"].value |
|||
if "track_nb" in id3: |
|||
self.track_number = id3["track_nb"].value |
|||
if "id3v2" in mp3: |
|||
self.readID3v2(mp3["id3v2"]) |
|||
if "frames" in mp3: |
|||
computeComprRate(self, mp3["frames"].size) |
|||
|
|||
def computeBitrate(self, frame): |
|||
bit_rate = frame.getBitRate() # may return None on error
|||
if not bit_rate: |
|||
return |
|||
self.bit_rate = (bit_rate, _("%s (constant)") % humanBitRate(bit_rate)) |
|||
self.duration = timedelta(seconds=float(frame["/frames"].size) / bit_rate) |
|||
|
|||
def computeVariableBitrate(self, mp3): |
|||
if self.quality <= QUALITY_FAST: |
|||
return |
|||
count = 0 |
|||
if QUALITY_BEST <= self.quality: |
|||
self.warning("Process all MPEG audio frames to compute exact duration") |
|||
max_count = None |
|||
else: |
|||
max_count = 500 * self.quality |
|||
total_bit_rate = 0.0 |
|||
for index, frame in enumerate(mp3.array("frames/frame")): |
|||
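# Skip the first few frames: they are presumably less representative
# (e.g. they may carry VBR header data rather than typical audio).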
if index < 3: |
|||
continue |
|||
bit_rate = frame.getBitRate() |
|||
if bit_rate: |
|||
total_bit_rate += float(bit_rate) |
|||
count += 1 |
|||
if max_count and max_count <= count: |
|||
break |
|||
if not count: |
|||
return |
|||
bit_rate = total_bit_rate / count |
|||
self.bit_rate = (bit_rate, |
|||
_("%s (Variable bit rate)") % humanBitRate(bit_rate)) |
|||
duration = timedelta(seconds=float(mp3["frames"].size) / bit_rate) |
|||
self.duration = duration |

class AiffMetadata(RootMetadata):
    def extract(self, aiff):
        if "common" in aiff:
            self.useCommon(aiff["common"])
        computeBitRate(self)

    @fault_tolerant
    def useCommon(self, info):
        self.nb_channel = info["nb_channel"].value
        self.bits_per_sample = info["sample_size"].value
        self.sample_rate = getValue(info, "sample_rate")
        if self.has("sample_rate"):
            rate = self.get("sample_rate")
            if rate:
                sec = float(info["nb_sample"].value) / rate
                self.duration = timedelta(seconds=sec)
        if "codec" in info:
            self.compression = info["codec"].display

class FlacMetadata(RootMetadata):
    def extract(self, flac):
        if "metadata/stream_info/content" in flac:
            self.useStreamInfo(flac["metadata/stream_info/content"])
        if "metadata/comment/content" in flac:
            readVorbisComment(self, flac["metadata/comment/content"])

    @fault_tolerant
    def useStreamInfo(self, info):
        self.nb_channel = info["nb_channel"].value + 1
        self.bits_per_sample = info["bits_per_sample"].value + 1
        self.sample_rate = info["sample_hertz"].value
        sec = info["total_samples"].value
        if sec:
            sec = float(sec) / info["sample_hertz"].value
            self.duration = timedelta(seconds=sec)

registerExtractor(AuFile, AuMetadata)
registerExtractor(MpegAudioFile, MpegAudioMetadata)
registerExtractor(OggFile, OggMetadata)
registerExtractor(RealMediaFile, RealMediaMetadata)
registerExtractor(RealAudioFile, RealAudioMetadata)
registerExtractor(AiffFile, AiffMetadata)
registerExtractor(FlacParser, FlacMetadata)

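For orientation, a minimal sketch of how these extractors are reached in practice, assuming the standard hachoir entry points createParser() and extractMetadata() (the file name is hypothetical):

from hachoir_parser import createParser
from hachoir_metadata import extractMetadata

parser = createParser(u"song.mp3")  # picks a parser such as MpegAudioFile
if parser:
    metadata = extractMetadata(parser)  # dispatched via registerExtractor()
    if metadata:
        for line in metadata.exportPlaintext():
            print line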
@ -0,0 +1,2 @@
MAX_STR_LENGTH = 300  # characters
RAW_OUTPUT = False
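
These two constants are the whole configuration surface: MAX_STR_LENGTH caps how many characters a stored metadata string keeps, and RAW_OUTPUT toggles unformatted values. A sketch of the kind of truncation this implies (truncate_for_display is a hypothetical helper, not part of the library):

from hachoir_metadata.config import MAX_STR_LENGTH

def truncate_for_display(text):
    # clamp overly long strings, as the MAX_STR_LENGTH limit suggests
    if len(text) > MAX_STR_LENGTH:
        return text[:MAX_STR_LENGTH] + u"(...)"
    return text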
@ -0,0 +1,28 @@
from hachoir_metadata.metadata import RootMetadata, registerExtractor
from hachoir_metadata.safe import fault_tolerant
from hachoir_parser.file_system import ISO9660
from datetime import datetime

class ISO9660_Metadata(RootMetadata):
    def extract(self, iso):
        desc = iso['volume[0]/content']
        # Metadata keys hold lists of values, so the second assignment to
        # title (and to author) adds a value rather than overwriting
        self.title = desc['volume_id'].value
        self.title = desc['vol_set_id'].value
        self.author = desc['publisher'].value
        self.author = desc['data_preparer'].value
        self.producer = desc['application'].value
        self.copyright = desc['copyright'].value
        self.readTimestamp('creation_date', desc['creation_ts'].value)
        self.readTimestamp('last_modification', desc['modification_ts'].value)

    @fault_tolerant
    def readTimestamp(self, key, value):
        # An all-zero year marks an unset ISO 9660 timestamp
        if value.startswith("0000"):
            return
        value = datetime(
            int(value[0:4]), int(value[4:6]), int(value[6:8]),
            int(value[8:10]), int(value[10:12]), int(value[12:14]))
        setattr(self, key, value)

registerExtractor(ISO9660, ISO9660_Metadata)

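The timestamp fields above hold ISO 9660 "decimal datetime" strings (YYYYMMDDHHMMSS, followed by centiseconds and a timezone offset that readTimestamp() ignores). A standalone sketch of the same decoding, with an invented sample value:

from datetime import datetime

def decode_iso9660_ts(value):
    # an all-zero year marks an unset field, mirroring readTimestamp() above
    if value.startswith("0000"):
        return None
    return datetime(int(value[0:4]), int(value[4:6]), int(value[6:8]),
                    int(value[8:10]), int(value[10:12]), int(value[12:14]))

print decode_iso9660_ts("2011061512300000")  # -> 2011-06-15 12:30:00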
@ -0,0 +1,52 @@
from hachoir_metadata.timezone import UTC
from datetime import date, datetime

# Accept years in 1850..2030
MIN_YEAR = 1850
MAX_YEAR = 2030

class Filter:
    def __init__(self, valid_types, min=None, max=None):
        self.types = valid_types
        self.min = min
        self.max = max

    def __call__(self, value):
        # Values of other types are accepted unconditionally
        if not isinstance(value, self.types):
            return True
        if self.min is not None and value < self.min:
            return False
        if self.max is not None and self.max < value:
            return False
        return True

class NumberFilter(Filter):
    def __init__(self, min=None, max=None):
        Filter.__init__(self, (int, long, float), min, max)

class DatetimeFilter(Filter):
    def __init__(self, min=None, max=None):
        # Note: the min/max arguments are ignored here; the bounds are
        # always pinned to MIN_YEAR..MAX_YEAR
        Filter.__init__(self, (date, datetime),
            datetime(MIN_YEAR, 1, 1),
            datetime(MAX_YEAR, 12, 31))
        self.min_date = date(MIN_YEAR, 1, 1)
        self.max_date = date(MAX_YEAR, 12, 31)
        self.min_tz = datetime(MIN_YEAR, 1, 1, tzinfo=UTC)
        self.max_tz = datetime(MAX_YEAR, 12, 31, tzinfo=UTC)

    def __call__(self, value):
        """
        Use different min/max values depending on the value type
        (datetime with timezone, naive datetime, or date).
        """
        if not isinstance(value, self.types):
            return True
        if hasattr(value, "tzinfo") and value.tzinfo:
            return (self.min_tz <= value <= self.max_tz)
        elif isinstance(value, datetime):
            return (self.min <= value <= self.max)
        else:
            return (self.min_date <= value <= self.max_date)

DATETIME_FILTER = DatetimeFilter()

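To illustrate the contract (values of other types pass through, in-range dates pass, out-of-range dates are rejected), a small usage sketch assuming this module is importable as hachoir_metadata.filter:

from datetime import date, datetime
from hachoir_metadata.filter import DATETIME_FILTER

print DATETIME_FILTER(date(1995, 5, 1))      # True: inside 1850..2030
print DATETIME_FILTER(datetime(1601, 1, 1))  # False: implausible year
print DATETIME_FILTER(u"not a date")         # True: not a date type, accepted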
@ -0,0 +1,25 @@
from hachoir_core.i18n import _, ngettext

NB_CHANNEL_NAME = {1: _("mono"), 2: _("stereo")}

def humanAudioChannel(value):
    return NB_CHANNEL_NAME.get(value, unicode(value))

def humanFrameRate(value):
    if isinstance(value, (int, long, float)):
        return _("%.1f fps") % value
    else:
        return value

def humanComprRate(rate):
    return u"%.1fx" % rate

def humanAltitude(value):
    return ngettext("%.1f meter", "%.1f meters", value) % value

def humanPixelSize(value):
    return ngettext("%s pixel", "%s pixels", value) % value

def humanDPI(value):
    return u"%s DPI" % value

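A quick demonstration of what these formatters return, assuming Python 2 as the rest of the library does:

print humanAudioChannel(2)    # u"stereo", via NB_CHANNEL_NAME
print humanFrameRate(23.976)  # u"24.0 fps" ("%.1f" rounds here)
print humanComprRate(4.27)    # u"4.3x"
print humanPixelSize(1)       # u"1 pixel"; ngettext picks the singular
print humanDPI(300)           # u"300 DPI"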