# Source code for pemi.data_subject

import json

import pandas as pd
import sqlalchemy as sa

import pemi
from pemi.fields import *

__all__ = [
    'PdDataSubject',
    'SaDataSubject',
    'SparkDataSubject'
]

class MissingFieldsError(Exception):
    '''Raised when a dataframe's columns and a schema's fields do not line up.'''

class DataSubject:
    '''
    Base class pairing a schema with a generic data object.

    In practice it is mostly a schema that knows which pipe it belongs to
    (if any) and that can be converted from and to a pandas dataframe
    (the conversion is really only needed for testing to work).
    '''

    def __init__(self, schema=None, name=None, pipe=None):
        '''
        Args:
            schema: Schema describing the subject's fields.  A falsy value
                is replaced with a fresh ``pemi.Schema()``.
            name: Name used when rendering the subject as a string.
            pipe: Pipe this subject belongs to, if any.
        '''
        self.schema = schema if schema else pemi.Schema()
        self.name = name
        self.pipe = pipe

    def __str__(self):
        'Render as ``<Class(name) id>``, prefixed with the pipe when one is set.'
        label = '<{}({}) {}>'.format(type(self).__name__, self.name, id(self))
        if not self.pipe:
            return label
        return '{}.{}'.format(self.pipe, label)

    def to_pd(self):
        'Return the data as a pandas dataframe; subclasses must implement.'
        raise NotImplementedError

    def from_pd(self, df, **kwargs):
        'Load the data from a pandas dataframe; subclasses must implement.'
        raise NotImplementedError

    def connect_from(self, _other):
        'Validate the schema, then raise; subclasses must implement.'
        self.validate_schema()
        raise NotImplementedError

    def validate_schema(self): #pylint: disable=no-self-use
        'Base implementation performs no checks and always returns True.'
        return True
class PdDataSubject(DataSubject):
    '''Data subject whose data object is an in-memory pandas dataframe.'''

    def __init__(self, df=None, strict_match_schema=False, **kwargs):
        '''
        Args:
            df: Initial dataframe.  ``None`` or a dataframe of shape (0, 0)
                is replaced by an empty dataframe with the schema's columns.
            strict_match_schema: When true, ``validate_schema`` checks that
                every dataframe column is present in the schema.
            **kwargs: Forwarded to ``DataSubject.__init__``.
        '''
        super().__init__(**kwargs)

        if df is None or df.shape == (0, 0):
            df = self._empty_df()

        self.strict_match_schema = strict_match_schema
        self.df = df

    def to_pd(self):
        'Return the underlying dataframe.'
        return self.df

    def from_pd(self, df, **kwargs):
        'Replace the underlying dataframe with ``df``.'
        self.df = df

    def connect_from(self, other):
        'Take the dataframe held by ``other`` and validate it against the schema.'
        if other.df is None or other.df.shape == (0, 0):
            self.df = self._empty_df()
        else:
            self.df = other.df

        self.validate_schema()

    def validate_schema(self):
        '''
        Verify that the dataframe contains all of the columns specified in the schema.

        When ``strict_match_schema`` is set, the check performed instead is
        ``validate_data_frame_columns``.

        Raises:
            MissingFieldsError: if the check fails.
        '''
        if self.strict_match_schema:
            return self.validate_data_frame_columns()

        missing = set(self.schema.keys()) - set(self.df.columns)
        if not missing:
            return True

        raise MissingFieldsError('DataFrame missing expected fields: {}'.format(missing))

    def validate_data_frame_columns(self):
        '''
        Verify that the schema contains all the columns specified in the dataframe.

        Raises:
            MissingFieldsError: if the dataframe has columns the schema lacks.
        '''
        missing = set(self.df.columns) - set(self.schema.keys())
        if missing:
            raise MissingFieldsError("Schema is missing current columns: {}".format(missing))
        return True

    def _empty_df(self):
        'Build an empty dataframe whose columns are the schema keys.'
        return pd.DataFrame(columns=self.schema.keys())


class SaDataSubject(DataSubject):
    '''Data subject whose data lives in a SQL table reached through SQLAlchemy.'''

    def __init__(self, engine, table, sql_schema=None, **kwargs):
        '''
        Args:
            engine: SQLAlchemy engine used to open connections.
            table: Name of the table to read from and write to.
            sql_schema: Optional database schema passed to pandas.
            **kwargs: Forwarded to ``DataSubject.__init__``.
        '''
        super().__init__(**kwargs)

        self.engine = engine
        self.table = table
        self.sql_schema = sql_schema

        self.cached_test_df = None

    def to_pd(self):
        '''
        Read the table into a dataframe, coercing the columns that the
        schema knows about.  The result is cached on the instance.
        '''
        if self.cached_test_df is not None:
            return self.cached_test_df

        with self.engine.connect() as conn:
            df = pd.read_sql_table(
                self.table,
                conn,
                schema=self.sql_schema,
            )

            for column in set(df.columns) & set(self.schema.keys()):
                df[column] = df[column].apply(self.schema[column].coerce)

        self.cached_test_df = df
        return df

    def from_pd(self, df, **to_sql_opts):
        '''
        Write ``df`` to the table via ``DataFrame.to_sql``.

        ``if_exists`` defaults to ``'append'`` and ``index`` to ``False``
        unless given in ``to_sql_opts``; a configured ``sql_schema`` always
        overrides ``schema``.  ``JsonField`` columns are serialized with
        ``json.dumps`` on a copy, so ``df`` itself is left untouched.
        '''
        self.cached_test_df = df
        pemi.log.debug('loading SaDataSubject with:\n%s', self.cached_test_df)

        to_sql_opts.setdefault('if_exists', 'append')
        to_sql_opts.setdefault('index', False)
        if self.sql_schema:
            to_sql_opts['schema'] = self.sql_schema

        df_to_sql = df.copy()
        for field in self.schema.values():
            if isinstance(field, JsonField):
                df_to_sql[field.name] = df_to_sql[field.name].apply(json.dumps)

        with self.engine.connect() as conn:
            df_to_sql.to_sql(self.table, conn, **to_sql_opts)

    def connect_from(self, _other):
        'Dispose of the engine and validate the schema; ``_other`` is unused.'
        self.engine.dispose()
        self.validate_schema()

    def __getstate__(self):
        'Pickle only the engine URL, table name and SQL schema.'
        return (
            [],
            {
                'url': self.engine.url,
                'table': self.table,
                'sql_schema': self.sql_schema
            }
        )

    def __setstate__(self, state):
        'Rebuild the engine from the pickled URL and restore the table settings.'
        _args, kwargs = state
        self.engine = sa.create_engine(kwargs['url'])
        self.table = kwargs['table']
        self.sql_schema = kwargs['sql_schema']
        # __init__ is not run on unpickle; without this, to_pd() would fail
        # with AttributeError on its very first line.
        self.cached_test_df = None
        # NOTE(review): schema, name and pipe are not part of the pickled
        # state, so they are still absent on a restored instance - confirm
        # whether callers re-attach them.


class SparkDataSubject(DataSubject):
    '''Data subject whose data object is a Spark dataframe.'''

    def __init__(self, spark, df=None, **kwargs):
        '''
        Args:
            spark: Spark session object used to create dataframes.
            df: Optional initial Spark dataframe.
            **kwargs: Forwarded to ``DataSubject.__init__``.
        '''
        super().__init__(**kwargs)
        self.spark = spark
        self.df = df

        self.cached_test_df = None

    def to_pd(self):
        '''
        Convert the Spark dataframe to pandas, coercing the columns that the
        schema knows about.  The result is cached on the instance.

        Columns missing from the schema are passed through uncoerced, which
        matches ``SaDataSubject.to_pd``.
        '''
        if self.cached_test_df is not None:
            return self.cached_test_df

        converted_df = self.df.toPandas()
        schema_columns = set(self.schema.keys())

        # Build into a local and publish the cache only once every column
        # succeeded, so a failing coercion cannot leave a partial frame
        # behind to be served by the next call.
        coerced_df = pd.DataFrame()
        for column in list(converted_df):
            series = converted_df[column]
            if column in schema_columns:
                series = series.apply(self.schema[column].coerce)
            coerced_df[column] = series

        self.cached_test_df = coerced_df
        return coerced_df

    def from_pd(self, df, **kwargs):
        'Replace the Spark dataframe with one created from the pandas ``df``.'
        self.df = self.spark.createDataFrame(df)
        # The cached pandas view belongs to the previous dataframe.
        self.cached_test_df = None

    def connect_from(self, other):
        'Adopt the Spark session and dataframe of ``other``, then validate.'
        self.spark = other.spark.builder.getOrCreate()
        self.df = other.df
        # The cached pandas view belongs to the previous dataframe.
        self.cached_test_df = None
        self.validate_schema()