diff --git a/holoclean/dataengine.py b/holoclean/dataengine.py
index edaa987e..fa7a9d2b 100644
--- a/holoclean/dataengine.py
+++ b/holoclean/dataengine.py
@@ -168,22 +168,31 @@ def query(self, sql_query, spark_flag=0):
             print "Could not execute Query ", sql_query, "Check log for info"
             exit(5)
 
-    def ingest_data(self, filepath, dataset):
+    def ingest_data(self, filepath, dataset, sep=None, escape=None, multiLine=None):
         """
-        Load data from a file to a dataframe and store it on the db
+        Load data from a file to a dataframe and store it on the db.
+
+        The named parameters `sep`, `escape`, `multiLine`
+        correspond to those for `pyspark.sql.DataFrameReader.csv`.
 
         filepath : String
             File path of the .csv file for the dataset
 
         dataset: DataSet
             The DataSet object that holds the Session ID for HoloClean
-
+        sep : String
+            Single character used as a field separator
+        escape : String
+            Single character used for escaping quoted values
+        multiLine : boolean
+            Parse records that span multiple lines
         """
 
         # Spawn new reader and load data into dataframe
         filereader = Reader(self.holo_env)
 
         # read with an index column
-        df = filereader.read(filepath,1)
+        df = filereader.read(filepath, indexcol=1, sep=sep,
+                             escape=escape, multiLine=multiLine)
 
         # Store dataframe to DB table
         schema = df.schema.names
diff --git a/holoclean/holoclean.py b/holoclean/holoclean.py
index f4f7f335..f4b4411f 100644
--- a/holoclean/holoclean.py
+++ b/holoclean/holoclean.py
@@ -293,18 +293,26 @@ def __init__(self, holo_env, name="session"):
         self.inferred_values = None
         self.feature_count = 0
 
-    def load_data(self, file_path):
+    def load_data(self, file_path, sep=None, escape=None,
+                  multiLine=None):
         """
-        Loads a dataset from file into the database
+        Loads a dataset from file into the database.
+
+        The named parameters `sep`, `escape`, `multiLine`
+        correspond to those for `pyspark.sql.DataFrameReader.csv`.
 
         :param file_path: path to data file
+        :param sep: Single character used as a field separator
+        :param escape: Single character used for escaping quoted values
+        :param multiLine: Parse records that span multiple lines
 
         :return: pyspark dataframe
         """
         if self.holo_env.verbose:
             start = time.time()
 
-        self._ingest_dataset(file_path)
+        self._ingest_dataset(file_path, sep=sep, escape=escape,
+                             multiLine=multiLine)
 
         init = self.init_dataset
 
@@ -516,17 +524,25 @@ def compare_to_truth(self, truth_path):
         acc = Accuracy(self, truth_path)
         acc.accuracy_calculation()
 
-    def _ingest_dataset(self, src_path):
+    def _ingest_dataset(self, src_path, sep=None, escape=None,
+                        multiLine=None):
         """
-        Ingests a dataset from given source path
+        Ingests a dataset from given source path.
+
+        The named parameters `sep`, `escape`, `multiLine`
+        correspond to those for `pyspark.sql.DataFrameReader.csv`.
 
         :param src_path: string literal of path to the .csv file of the dataset
+        :param sep: Single character used as a field separator
+        :param escape: Single character used for escaping quoted values
+        :param multiLine: Parse records that span multiple lines
 
         :return: Null
         """
         self.holo_env.logger.info('ingesting file:' + src_path)
         self.init_dataset, self.attribute_map = \
-            self.holo_env.dataengine.ingest_data(src_path, self.dataset)
+            self.holo_env.dataengine.ingest_data(src_path, self.dataset,
+                                                 sep=sep, escape=escape, multiLine=multiLine)
         self.holo_env.logger.info(
             'creating dataset with id:' + self.dataset.print_id())
diff --git a/holoclean/utils/reader.py b/holoclean/utils/reader.py
index 21933ece..a6bc465a 100644
--- a/holoclean/utils/reader.py
+++ b/holoclean/utils/reader.py
@@ -20,28 +20,37 @@ def __init__(self, holo_object):
         self.spark_session = holo_object.spark_session
 
     # Internal Methods
-    def _findextesion(self, filepath):
+    def _findextension(self, filepath):
         """
-        Finds the extesion of the file.
+        Finds the extension of the file.
 
         :param filepath: The path to the file
         """
-        extention = filepath.split('.')[-1]
-        return extention
+        extension = filepath.split('.')[-1]
+        return extension
 
-    def read(self, filepath, indexcol=0, schema=None):
+    def read(self, filepath, indexcol=0, schema=None, sep=None,
+             escape=None, multiLine=None):
         """
         Calls the appropriate reader for the file
 
+        The named parameters `schema`, `sep`, `escape`, `multiLine`
+        correspond to those for `pyspark.sql.DataFrameReader.csv`.
+
         :param schema: optional schema when known
         :param filepath: The path to the file
+        :param sep: Single character used as a field separator
+        :param escape: Single character used for escaping quoted values
+        :param multiLine: Parse records that span multiple lines
 
         :return: data frame of the read data
         """
-        if self._findextesion(filepath) == "csv":
+        if self._findextension(filepath) == "csv":
             csv_obj = CSVReader(self.holo_object)
-            df = csv_obj.read(filepath, self.spark_session, indexcol, schema)
+            df = csv_obj.read(filepath, self.spark_session, indexcol, schema,
+                              sep=sep, escape=escape, multiLine=multiLine)
+
+            return df
         else:
             print("This extension doesn't support")
@@ -62,22 +71,31 @@ def __init__(self, holo_object):
         self.holo_obj = holo_object
 
     # Setters
-    def read(self, file_path, spark_session, indexcol=0, schema=None):
+    def read(self, file_path, spark_session, indexcol=0, schema=None, sep=None, escape=None, multiLine=None):
         """
-        Creates a dataframe from the csv file
+        Creates a dataframe from the csv file.
+
+        The named parameters `schema`, `sep`, `escape`, `multiLine`
+        correspond to those for `pyspark.sql.DataFrameReader.csv`.
 
         :param indexcol: if 1, create a tuple id column as auto increment
         :param schema: optional schema of file if known
         :param spark_session: The spark_session we created in Holoclean object
         :param file_path: The path to the file
+        :param sep: Single character used as a field separator
+        :param escape: Single character used for escaping quoted values
+        :param multiLine: Parse records that span multiple lines
 
         :return: dataframe
         """
         try:
             if schema is None:
-                df = spark_session.read.csv(file_path, header=True)
+                df = spark_session.read.csv(file_path, header=True, sep=sep,
+                                            escape=escape, multiLine=multiLine)
             else:
-                df = spark_session.read.csv(file_path, header=True, schema=schema)
+                df = spark_session.read.csv(file_path, header=True, schema=schema,
+                                            sep=sep, escape=escape,
+                                            multiLine=multiLine)
 
             if indexcol == 0:
                 return df
 
@@ -128,7 +146,7 @@ def checking_string_size(self, dataframe):
             dataframe = self.ignore_columns(columns, dataframe)
         return dataframe
 
-    def ignore_columns(self, columns, dataframe):        
+    def ignore_columns(self, columns, dataframe):
         """
         This method asks the user if he wants to drop a column
         which has a string with more than 255 characters
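
For context, a minimal sketch of how the new keyword arguments might be exercised end to end. The setup lines and the file paths (`data/hospital.tsv`, `data/hospital.csv`) are illustrative assumptions, not part of this patch; the patch itself only forwards `sep`, `escape`, and `multiLine` from `Session.load_data` down to `pyspark.sql.DataFrameReader.csv`.

```python
from holoclean.holoclean import HoloClean, Session

# Assumed setup: a HoloClean environment and a session, as in the existing API.
holo = HoloClean()
session = Session(holo)

# Hypothetical input: tab-separated file with backslash-escaped quotes and
# quoted fields that may span multiple lines. The three keyword arguments are
# passed through Session.load_data -> DataEngine.ingest_data -> Reader.read
# -> CSVReader.read -> spark_session.read.csv.
df = session.load_data("data/hospital.tsv",
                       sep="\t",
                       escape="\\",
                       multiLine=True)

# Leaving the new arguments out (they default to None) preserves the old
# behaviour: Spark applies its own defaults, i.e. comma-separated,
# single-line records.
df_default = session.load_data("data/hospital.csv")
```

Note that `multiLine` is a Spark 2.2+ option for the CSV reader, so passing it on older Spark installations will fail.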