Source code for rdc.etl.extra.db.load

# -*- coding: utf-8 -*-
#
# Copyright 2012-2014 Romain Dorgueil
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import copy

from sqlalchemy import MetaData, Table
from rdc.etl.error import ProhibitedOperationError
from rdc.etl.hash import Hash
from rdc.etl.io import STDIN, INSERT, UPDATE, SELECT, STDERR
from rdc.etl.transform import Transform
from rdc.etl.util import now, cached_property


class DatabaseLoad(Transform):
    """Buffered database loader.

    Looks up each input hash in the target table using the ``discriminant`` columns, then either UPDATEs the
    matching row or INSERTs a new one, depending on ``allowed_operations``. Rows are buffered and written in
    batched transactions.

    TODO: test this!
    """

    engine = None
    table_name = None
    fetch_columns = None
    insert_only_fields = ()
    discriminant = ('id', )
    created_at_field = 'created_at'
    updated_at_field = 'updated_at'
    allowed_operations = (INSERT, UPDATE, )

    def __init__(self, engine=None, table_name=None, fetch_columns=None, discriminant=None, created_at_field=None,
                 updated_at_field=None, insert_only_fields=None, allowed_operations=None):
        super(DatabaseLoad, self).__init__()

        self.engine = engine or self.engine
        self.table_name = table_name or self.table_name

        # xxx should take self.fetch_columns into account if provided
        self.fetch_columns = {}
        if isinstance(fetch_columns, list):
            self.add_fetch_column(*fetch_columns)
        elif isinstance(fetch_columns, dict):
            self.add_fetch_column(**fetch_columns)

        self.discriminant = discriminant or self.discriminant
        self.created_at_field = created_at_field or self.created_at_field
        self.updated_at_field = updated_at_field or self.updated_at_field
        self.insert_only_fields = insert_only_fields or self.insert_only_fields
        self.allowed_operations = allowed_operations or self.allowed_operations

        self._buffer = []
        self._connection = None
        self._max_buffer_size = 1000
        self._last_duration = None
        self._last_commit_at = None
        self._query_count = 0

    @property
    def connection(self):
        """Lazily opened database connection."""
        if self._connection is None:
            self._connection = self.engine.connect()
        return self._connection

    def commit(self):
        """Load all buffered hashes within a single transaction, routing failures to STDERR with context."""
        with self.connection.begin():
            while self._buffer:
                hash = self._buffer.pop(0)
                try:
                    yield self.do_transform(copy(hash))
                except Exception as e:
                    yield Hash((
                        ('_input', hash, ),
                        ('_transform', self, ),
                        ('_error', e, ),
                    )), STDERR

    def close_connection(self):
        self._connection.close()
        self._connection = None

    def get_insert_columns_for(self, hash):
        """List of columns we can use for INSERT."""
        return self.columns

    def get_update_columns_for(self, hash, row):
        """List of columns we can use for UPDATE."""
        return [
            column for column in self.columns
            if column not in self.insert_only_fields
        ]

    def get_columns_for(self, hash, row=None):
        """Retrieve the list of table column names for which we have a value in the given hash."""
        if row:
            column_names = self.get_update_columns_for(hash, row)
        else:
            column_names = self.get_insert_columns_for(hash)
        return [key for key in hash if key in column_names]

    def find(self, dataset, connection=None):
        """Find the existing row, if any, matching the discriminant values of the given dataset."""
        query = '''SELECT * FROM {table} WHERE {criteria} LIMIT 1'''.format(
            table=self.table_name,
            criteria=' AND '.join([key_atom + ' = %s' for key_atom in self.discriminant]),
        )
        rp = (connection or self.connection).execute(query, [dataset.get(key_atom) for key_atom in self.discriminant])

        # Increment stats
        self._input._special_stats[SELECT] += 1

        return rp.fetchone()

    def initialize(self):
        super(DatabaseLoad, self).initialize()
        self._input._special_stats[SELECT] = 0
        self._output._special_stats[INSERT] = 0
        self._output._special_stats[UPDATE] = 0

    def do_transform(self, hash):
        """Actual database load transformation logic, without the buffering / transaction logic."""

        # Find the row, if it exists.
        row = self.find(hash)
        now = self.now
        column_names = self.table.columns.keys()

        # "Updated at" field configured? Let's set the value in the source hash.
        if self.updated_at_field in column_names:
            hash[self.updated_at_field] = now
        # Otherwise, make sure there is no such field.
        elif self.updated_at_field in hash:
            del hash[self.updated_at_field]

        # UPDATE
        if row:
            if UPDATE not in self.allowed_operations:
                raise ProhibitedOperationError('UPDATE operations are not allowed by this transformation.')

            _columns = self.get_columns_for(hash, row)
            query = '''UPDATE {table} SET {values} WHERE {criteria}'''.format(
                table=self.table_name,
                values=', '.join((
                    '{column} = %s'.format(column=_column)
                    for _column in _columns
                    if _column not in self.discriminant
                )),
                criteria=' AND '.join((
                    '{key} = %s'.format(key=_key) for _key in self.discriminant
                ))
            )
            values = [hash[_column] for _column in _columns if _column not in self.discriminant] + \
                     [hash[_column] for _column in self.discriminant]

        # INSERT
        else:
            if INSERT not in self.allowed_operations:
                raise ProhibitedOperationError('INSERT operations are not allowed by this transformation.')

            # Same logic as for the "updated at" field above, applied to "created at".
            if self.created_at_field in column_names:
                hash[self.created_at_field] = now
            elif self.created_at_field in hash:
                del hash[self.created_at_field]

            _columns = self.get_columns_for(hash)
            query = '''INSERT INTO {table} ({keys}) VALUES ({values})'''.format(
                table=self.table_name,
                keys=', '.join(_columns),
                values=', '.join(['%s'] * len(_columns))
            )
            values = [hash[key] for key in _columns]

        # Execute
        self.connection.execute(query, values)

        # Increment stats
        if row:
            self._output._special_stats[UPDATE] += 1
        else:
            self._output._special_stats[INSERT] += 1

        # If the user required us to fetch some columns, let's query again to get their actual values.
        if self.fetch_columns:
            if not row:
                row = self.find(hash)
            if not row:
                raise ValueError('Could not find matching row after load.')

            for alias, column in self.fetch_columns.iteritems():
                hash[alias] = row[column]

        return hash

    def transform(self, hash, channel=STDIN):
        """Transform method. Stores the input in a buffer, and only flushes the buffer content once a limit is
        exceeded.

        TODO: for now the buffer limit is hardcoded as 1000, but we may use a few criteria to add intelligence to
        this: time since last commit, duration of last commit, buffer length ...
        """
        self._buffer.append(hash)
        if len(self._buffer) >= self._max_buffer_size:
            for _out in self.commit():
                yield _out

    def finalize(self):
        """Transform's finalize method. Loads the lines remaining in the buffer into the database, then closes the
        database connection.
        """
        super(DatabaseLoad, self).finalize()
        for _out in self.commit():
            yield _out
        self.close_connection()

    def add_fetch_column(self, *columns, **aliased_columns):
        """Register columns to read back from the database after load, optionally under an alias."""
        self.fetch_columns.update(aliased_columns)
        for column in columns:
            self.fetch_columns[column] = column

    @cached_property
    def columns(self):
        """Column names of the target table."""
        return self.table.columns.keys()

    @cached_property
    def metadata(self):
        """SQLAlchemy metadata."""
        return MetaData()

    @cached_property
    def table(self):
        """SQLAlchemy table object, using metadata autoloading from the database to avoid the need for column
        definitions.
        """
        return Table(self.table_name, self.metadata, autoload=True, autoload_with=self.engine)

    @property
    def now(self):
        """Current timestamp, used for created/updated at fields."""
        return now()
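A minimal construction sketch, for illustration only: the `product` table, `sku` discriminant, and DSN below are hypothetical, while the keyword arguments match the constructor above. The instance would then be wired as the final transform of an rdc.etl chain, which drives `initialize()`, `transform()` and `finalize()` and thereby triggers the buffered, batched commits:

    from sqlalchemy import create_engine

    from rdc.etl.extra.db.load import DatabaseLoad

    # Hypothetical DSN; any SQLAlchemy engine works.
    engine = create_engine('postgresql://user:password@localhost/example')

    load = DatabaseLoad(
        engine,
        table_name='product',                 # hypothetical target table
        discriminant=('sku', ),               # match existing rows on 'sku' instead of the default 'id'
        insert_only_fields=('created_at', ),  # never overwrite 'created_at' on UPDATE
        fetch_columns={'product_id': 'id'},   # after load, expose the database-generated 'id' as 'product_id'
    )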