From a22064aa47cc577959d92ed184b394a934c580cc Mon Sep 17 00:00:00 2001 From: Dougal Dobie Date: Wed, 12 Feb 2025 11:28:48 +1100 Subject: [PATCH] Switch measurement pairs handling to dask (#590) * Some initial fixes * Fix measurements_df type check * I think this works now * Fixed os.path.isfile patch * Fixed test_pipeanalysis_load_two_epoch_metrics_dask * Fixed test__filter_meas_pairs_df * test_recalc_measurement_pairs_df * Fix two epoch plotting tests * Update column names in test_sources * Tests should all now pass * Seems to work?? * Initial cleanup * Manual cleanup * PEP8 --- tests/data/recalc_sources_df_output.csv | 2 +- tests/data/test_sources.csv | 2 +- tests/test_pipeline.py | 184 ++++++------ vasttools/pipeline.py | 365 ++++++++++-------------- 4 files changed, 241 insertions(+), 312 deletions(-) diff --git a/tests/data/recalc_sources_df_output.csv b/tests/data/recalc_sources_df_output.csv index 0654d5e8..58489703 100644 --- a/tests/data/recalc_sources_df_output.csv +++ b/tests/data/recalc_sources_df_output.csv @@ -1,4 +1,4 @@ -id,wavg_ra,wavg_dec,avg_compactness,min_snr,max_snr,wavg_uncertainty_ew,wavg_uncertainty_ns,avg_flux_int,avg_flux_peak,max_flux_peak,max_flux_int,min_flux_peak,min_flux_int,min_flux_peak_isl_ratio,min_flux_int_isl_ratio,v_int,v_peak,eta_int,eta_peak,new,new_high_sigma,n_neighbour_dist,vs_significant_max_peak,m_abs_significant_max_peak,vs_significant_max_int,m_abs_significant_max_int,n_measurements,n_selavy,n_forced,n_siblings,n_relations +id,wavg_ra,wavg_dec,avg_compactness,min_snr,max_snr,wavg_uncertainty_ew,wavg_uncertainty_ns,avg_flux_int,avg_flux_peak,max_flux_peak,max_flux_int,min_flux_peak,min_flux_int,min_flux_peak_isl_ratio,min_flux_int_isl_ratio,v_int,v_peak,eta_int,eta_peak,new,new_high_sigma,n_neighbour_dist,vs_abs_significant_max_peak,m_abs_significant_max_peak,vs_abs_significant_max_int,m_abs_significant_max_int,n_measurements,n_selavy,n_forced,n_siblings,n_relations 729,321.9001325179949,-6.193030233233051,1.8073103029326183,276.2080200501253,392.31612903225806,0.00013898019685224127,0.00013898019685224127,217.508,120.70375,126.772,222.104,110.207,213.979,1.0,1.0,0.01738472025034037,0.06053938024395404,16.072133072157158,327.6134309054469,False,0.0,0.0,30.212044,0.13980141,5.8425226,0.03726355,4,4,0,0,1 730,324.1633295343454,-5.040717073358131,1.1167516411885052,47.25,57.471365638766514,0.00014096412317649853,0.00014096412317649853,14.525750000000002,12.972000000000001,13.643,17.672,12.096,12.831,1.0,1.0,0.1481646810260763,0.04956644262980651,15.489511624915242,5.842483557954741,False,0.0,2.529947868662914,4.04786,0.12020669,6.483879,0.3174114,4,4,0,0,0 2251,321.9001325179949,-6.193030233233051,1.8073103029326183,276.2080200501253,392.31612903225806,0.00013898019685224127,0.00013898019685224127,217.508,120.70375,126.772,222.104,110.207,213.979,1.0,1.0,0.01738472025034037,0.06053938024395403,16.072133072157158,327.61343090548564,False,0.0,0.0,30.212044,0.13980141,5.8425226,0.03726355,4,4,0,0,1 diff --git a/tests/data/test_sources.csv b/tests/data/test_sources.csv index e7da2b18..a0a9c7fa 100644 --- a/tests/data/test_sources.csv +++ b/tests/data/test_sources.csv @@ -1,4 +1,4 @@ -id,n_meas_forced,n_meas,n_meas_sel,n_sibl,wavg_ra,wavg_dec,avg_compactness,min_snr,max_snr,wavg_uncertainty_ew,wavg_uncertainty_ns,avg_flux_int,avg_flux_peak,max_flux_peak,max_flux_int,min_flux_peak,min_flux_int,min_flux_peak_isl_ratio,min_flux_int_isl_ratio,v_int,v_peak,eta_int,eta_peak,n_rel,new,new_high_sigma,n_neighbour_dist,vs_significant_max_peak,m_abs_significant_max_peak,vs_significant_max_int,m_abs_significant_max_int +id,n_meas_forced,n_meas,n_meas_sel,n_sibl,wavg_ra,wavg_dec,avg_compactness,min_snr,max_snr,wavg_uncertainty_ew,wavg_uncertainty_ns,avg_flux_int,avg_flux_peak,max_flux_peak,max_flux_int,min_flux_peak,min_flux_int,min_flux_peak_isl_ratio,min_flux_int_isl_ratio,v_int,v_peak,eta_int,eta_peak,n_rel,new,new_high_sigma,n_neighbour_dist,vs_abs_significant_max_peak,m_abs_significant_max_peak,vs_abs_significant_max_int,m_abs_significant_max_int 729,0,5,5,1,321.90006200895346,-6.19328120595884,1.6497708211574185,276.2080200501253,392.31612903225806,0.00012430646757473707,0.00012430646757473707,196.96380000000002,119.0788,126.772,222.104,110.207,114.787,0.6058888745371567,0.6098230887743717,0.23382367320174888,0.0612809995805187,3595.588994778409,345.1476926590294,1,False,0.0,0.0006578635631167275,30.21204300309697,0.13980141700319448,93.84061084533053,0.6371022081326008 730,0,5,5,0,324.1633336543154,-5.040746047434953,1.0532970708775455,47.25,57.471365638766514,0.00012597876059013304,0.00012597876059013304,13.829400000000001,13.1404,13.814,17.672,12.096,11.044,1.0,1.0,0.17561736823859211,0.051155411831745705,24.630200983798716,6.862681205039394,0,False,0.0,0.03902594624083344,4.856547003561874,0.13261289077576224,9.635110562912836,0.461624181640897 2251,0,5,5,1,321.90044443824974,-6.192744331356945,1.6464327102481555,183.0725,392.31612903225806,0.00012432638100232205,0.00012432638100232205,188.69500000000002,111.2088,126.772,222.104,73.229,73.443,0.3941111254628434,0.3901769112256282,0.34187973190761733,0.1992146017123999,7204.5820982079385,3235.2523805832607,1,False,0.0,0.0006578635631167572,101.04454706219408,0.5354273228633857,130.73169739380583,1.0060058129502245 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index d9d3ae68..9dbbfe4a 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -12,7 +12,7 @@ from mocpy import MOC from pathlib import Path from pytest_mock import mocker, MockerFixture # noqa: F401 -from typing import Dict, List +from typing import Dict, List, Union import vasttools.pipeline as vtp @@ -266,7 +266,7 @@ def dummy_pipeline_measurement_pairs(*args, **kwargs) -> pd.DataFrame: return measurement_pairs_df -def dummy_pipeline_measurement_pairs_vaex( +def dummy_pipeline_measurement_pairs_dask( *args, **kwargs ) -> vaex.dataframe.DataFrame: """ @@ -282,8 +282,8 @@ def dummy_pipeline_measurement_pairs_vaex( The dummy pipeline measurements pairs vaex dataframe. """ filepath = TEST_DATA_DIR / 'test_measurement_pairs.csv' - measurements_pairs_df = pd.read_csv(filepath) - measurements_pairs_df = vaex.from_pandas(measurements_pairs_df) + temp_df = pd.read_csv(filepath) + measurements_pairs_df = dd.from_pandas(temp_df, npartitions=1) return measurements_pairs_df @@ -435,8 +435,8 @@ def dummy_PipeAnalysis_wtwoepoch( The vtp.PipeAnalysis instance with two epoch data attached. """ pandas_read_parquet_mocker = mocker.patch( - 'vasttools.pipeline.pd.read_parquet', - side_effect=dummy_pipeline_measurement_pairs + 'vasttools.pipeline.dd.read_parquet', + side_effect=dummy_pipeline_measurement_pairs_dask ) dummy_PipeAnalysis.load_two_epoch_metrics() @@ -466,7 +466,7 @@ def dummy_PipeAnalysis_dask( The vtp.PipeAnalysis instance. """ mock_isdir = mocker.patch('os.path.isdir', return_value=True) - + # NOTE: This is really not great - we're basically mindlessly mocking # functions that are called potentially several times mock_isfile = mocker.patch('os.path.isfile', return_value=True) @@ -506,8 +506,8 @@ def dummy_PipeAnalysis_dask_wtwoepoch( The vtp.PipeAnalysis instance with two epoch data attached. """ vaex_open_mocker = mocker.patch( - 'vasttools.pipeline.vaex.open', - side_effect=dummy_pipeline_measurement_pairs_vaex + 'vasttools.pipeline.dd.read_parquet', + side_effect=dummy_pipeline_measurement_pairs_dask ) dummy_PipeAnalysis_dask.load_two_epoch_metrics() @@ -558,7 +558,7 @@ def _filter_source(id: int): The measurements dataframe for a requested source. """ meas = dummy_PipeAnalysis.measurements - meas = meas.loc[meas['source'] == id] + meas = meas.loc[id] return meas return _filter_source @@ -586,7 +586,7 @@ def filter_moc() -> MOC: @pytest.fixture def gen_measurement_pairs_df( dummy_PipeAnalysis_wtwoepoch: vtp.PipeAnalysis -) -> pd.DataFrame: +) -> Union[pd.DataFrame, dd.DataFrame]: """ Generates a measurement pairs dataframe for a specific 'pair epoch'. @@ -597,12 +597,14 @@ def gen_measurement_pairs_df( Returns: The measurement pairs df filtered for a pair epoch. """ - def _gen_df(epoch_id: int = 2) -> pd.DataFrame: + def _gen_df(epoch_id: int = 2, compute: bool = False + ) -> Union[pd.DataFrame, dd.DataFrame]: """ Filters a measurement pairs dataframe for a specific 'pair epoch'. Args: epoch_id: The id of the measurement pair epoch. + compute: Whether or not to compute the DataFrame. Defaults to True. Returns: The measurement pairs df filtered for a pair epoch. @@ -619,7 +621,8 @@ def _gen_df(epoch_id: int = 2) -> pd.DataFrame: ] == epoch_key ] ).copy() - + if compute: + measurement_pairs_df = measurement_pairs_df.compute() return measurement_pairs_df return _gen_df @@ -918,7 +921,7 @@ def test_load_run_no_dask_check_columns( assert 'centre_ra' in run.images.columns assert run.images.shape[1] == 29 - assert run.measurements.shape[1] == 42 + assert run.measurements.shape[1] == 41 def test_load_run_dask( self, @@ -1000,7 +1003,7 @@ def test__check_measurement_pairs_file( None """ mocker_isfile = mocker.patch( - "os.path.isfile", + "os.path.exists", side_effect=pairs_existence ) @@ -1390,40 +1393,16 @@ def test_pipeanalysis_get_source_dask( assert mocker_calls.args[0] == expected_sources_skycoord[0] assert the_source == -99 - def test_pipeanalysis_load_two_epoch_metrics_pandas( - self, - dummy_PipeAnalysis: vtp.PipeAnalysis, - dummy_pipeline_pairs_df: pd.DataFrame, - mocker: MockerFixture - ) -> None: - """ - Tests the method that loads the two epoch metrics. - - This test is for pandas loaded dataframes. - - Args: - dummy_PipeAnalysis: The dummy PipeAnalysis object that is used - for testing. - dummy_pipeline_pairs_df: The dummy pairs dataframe. - mocker: The pytest mocker mock object. - - Returns: - None - """ - pandas_read_parquet_mocker = mocker.patch( - 'vasttools.pipeline.pd.read_parquet', - side_effect=dummy_pipeline_measurement_pairs - ) - - dummy_PipeAnalysis.load_two_epoch_metrics() - - assert dummy_PipeAnalysis.pairs_df.equals(dummy_pipeline_pairs_df) - assert dummy_PipeAnalysis.measurement_pairs_df.shape[0] == 30 - + @pytest.mark.parametrize( + "compute", + [True, False], + ids=('compute', 'no-compute') + ) def test_pipeanalysis_load_two_epoch_metrics_dask( self, dummy_PipeAnalysis_dask: vtp.PipeAnalysis, dummy_pipeline_pairs_df: pd.DataFrame, + compute: bool, mocker: MockerFixture ) -> None: """ @@ -1435,22 +1414,37 @@ def test_pipeanalysis_load_two_epoch_metrics_dask( dummy_PipeAnalysis: The dummy PipeAnalysis object that is used for testing. dummy_pipeline_pairs_df: The dummy pairs dataframe. + compute: Whether or not to compute the dask dataframe as part of + the loading process. mocker: The pytest mocker mock object. Returns: None """ - vaex_open_mocker = mocker.patch( - 'vasttools.pipeline.vaex.open', - side_effect=dummy_pipeline_measurement_pairs_vaex + dask_open_mocker = mocker.patch( + 'vasttools.pipeline.dd.read_parquet', + side_effect=dummy_pipeline_measurement_pairs_dask ) - dummy_PipeAnalysis_dask.load_two_epoch_metrics() + dummy_PipeAnalysis_dask.load_two_epoch_metrics(compute=compute) assert dummy_PipeAnalysis_dask.pairs_df.equals( dummy_pipeline_pairs_df ) - assert dummy_PipeAnalysis_dask.measurement_pairs_df.shape[0] == 30 + + assert compute != dummy_PipeAnalysis_dask._dask_meas_pairs + + assert len(dummy_PipeAnalysis_dask.measurement_pairs_df.index) == 30 + if compute: + assert isinstance( + dummy_PipeAnalysis_dask.measurement_pairs_df, + pd.DataFrame + ) + else: + assert isinstance( + dummy_PipeAnalysis_dask.measurement_pairs_df, + dd.DataFrame + ) @pytest.mark.parametrize("row, kwargs, expected", [ ( @@ -1687,11 +1681,11 @@ def test__filter_meas_pairs_df(self, fixture_name: str, request) -> None: result = the_fixture._filter_meas_pairs_df( new_measurements - ) + ).compute() - assert result.shape[0] == 18 - assert np.any(result['meas_id_a'].isin(meas_ids).to_numpy()) == False - assert np.any(result['meas_id_b'].isin(meas_ids).to_numpy()) == False + assert len(result.index) == 18 + assert np.any(result['meas_id_a'].isin(meas_ids)) == False + assert np.any(result['meas_id_b'].isin(meas_ids)) == False @pytest.mark.parametrize( 'fixture_name', @@ -1728,23 +1722,25 @@ def test_recalc_measurement_pairs_df( expected_m_peak = the_fixture.measurement_pairs_df['m_peak'] expected_m_int = the_fixture.measurement_pairs_df['m_int'] - assert result['vs_peak'].to_numpy() == pytest.approx( - expected_vs_peak.to_numpy() + result = result.compute() + + assert result['vs_peak'].values == pytest.approx( + expected_vs_peak.values ) - assert result['vs_int'].to_numpy() == pytest.approx( - expected_vs_int.to_numpy() + assert result['vs_int'].values == pytest.approx( + expected_vs_int.values ) - assert result['m_peak'].to_numpy() == pytest.approx( - expected_m_peak.to_numpy() + assert result['m_peak'].values == pytest.approx( + expected_m_peak.values ) - assert result['m_int'].to_numpy() == pytest.approx( - expected_m_int.to_numpy() + assert result['m_int'].values == pytest.approx( + expected_m_int.values ) def test_recalc_sources_df( self, - dummy_PipeAnalysis: vtp.PipeAnalysis, + dummy_PipeAnalysis_wtwoepoch: vtp.PipeAnalysis, mocker: MockerFixture ) -> None: """ @@ -1759,10 +1755,10 @@ def test_recalc_sources_df( None """ - pandas_read_parquet_mocker = mocker.patch( - 'vasttools.pipeline.pd.read_parquet', - side_effect=dummy_pipeline_measurement_pairs - ) + # pandas_read_parquet_mocker = mocker.patch( + # 'vasttools.pipeline.pd.read_parquet', + # side_effect=dummy_pipeline_measurement_pairs + # ) # define this to speed up the test to avoid dask """dask_from_pandas_mocker = mocker.patch( @@ -1803,7 +1799,7 @@ def test_recalc_sources_df( .return_value ) = metrics_return_value""" - dummy_PipeAnalysis.load_two_epoch_metrics() + # dummy_PipeAnalysis.load_two_epoch_metrics() expected_result = pd.read_csv( TEST_DATA_DIR / @@ -1811,34 +1807,13 @@ def test_recalc_sources_df( index_col='id') # remove measurements from image id 2 - new_measurements = dummy_PipeAnalysis.measurements[ - dummy_PipeAnalysis.measurements.image_id != 2 + new_measurements = dummy_PipeAnalysis_wtwoepoch.measurements[ + dummy_PipeAnalysis_wtwoepoch.measurements.image_id != 2 ].copy() - result = dummy_PipeAnalysis.recalc_sources_df(new_measurements) - - print(result.columns) - print(expected_result.columns) - print(expected_result) - - # result.to_csv('recalc_sources_df_output.csv') - - # print(result) - # print(dummy_PipeAnalysis.sources) - - cols_to_test = ['min_flux_int', 'avg_flux_int', 'max_flux_int'] - for col in cols_to_test: - check = (result[col].values == - dummy_PipeAnalysis.sources[col].values).all() - if not check: - print(col) - print(result[col]) - print(dummy_PipeAnalysis.sources[col]) - - # assert result['n_selavy'].to_list() == [4, 4, 4] - # assert result.shape[1] == dummy_PipeAnalysis.sources.shape[1] + result = dummy_PipeAnalysis_wtwoepoch.recalc_sources_df( + new_measurements) - print(set(expected_result.columns) - set(result.columns)) pd.testing.assert_frame_equal(result, expected_result) # assert 1==0 @@ -1866,7 +1841,7 @@ def test__get_epoch_pair_plotting_df( df_filter, num_pairs, num_candidates, td_days = ( dummy_PipeAnalysis_wtwoepoch._get_epoch_pair_plotting_df( - dummy_PipeAnalysis_wtwoepoch.measurement_pairs_df, + dummy_PipeAnalysis_wtwoepoch.measurement_pairs_df.compute(), epoch_id, 'vs_peak', 'm_peak', @@ -1903,7 +1878,8 @@ def test__plot_epoch_pair_matplotlib( None """ epoch_id = 2 - expected_measurement_pairs_df = gen_measurement_pairs_df(epoch_id) + expected_measurement_pairs_df = gen_measurement_pairs_df( + epoch_id, compute=True) result = dummy_PipeAnalysis_wtwoepoch._plot_epoch_pair_matplotlib( epoch_id, @@ -1948,7 +1924,8 @@ def test__plot_epoch_pair_matplotlib_styleb( None """ epoch_id = 2 - expected_measurement_pairs_df = gen_measurement_pairs_df(epoch_id) + expected_measurement_pairs_df = gen_measurement_pairs_df( + epoch_id, compute=True) result = dummy_PipeAnalysis_wtwoepoch._plot_epoch_pair_matplotlib( epoch_id, @@ -1994,7 +1971,8 @@ def test__plot_epoch_pair_matplotlib_int_flux( None """ epoch_id = 2 - expected_measurement_pairs_df = gen_measurement_pairs_df(epoch_id) + expected_measurement_pairs_df = gen_measurement_pairs_df( + epoch_id, compute=True) result = dummy_PipeAnalysis_wtwoepoch._plot_epoch_pair_matplotlib( epoch_id, @@ -2039,7 +2017,8 @@ def test__plot_epoch_pair_bokeh( None """ epoch_id = 2 - expected_measurement_pairs_df = gen_measurement_pairs_df(epoch_id) + expected_measurement_pairs_df = gen_measurement_pairs_df( + epoch_id, compute=True) result = dummy_PipeAnalysis_wtwoepoch._plot_epoch_pair_bokeh( epoch_id, @@ -2070,7 +2049,8 @@ def test__plot_epoch_pair_bokeh_styleb( None """ epoch_id = 2 - expected_measurement_pairs_df = gen_measurement_pairs_df(epoch_id) + expected_measurement_pairs_df = gen_measurement_pairs_df( + epoch_id, compute=True) result = dummy_PipeAnalysis_wtwoepoch._plot_epoch_pair_bokeh( epoch_id, @@ -2108,7 +2088,8 @@ def test_plot_two_epoch_pairs_matplotlib( ) epoch_id = 2 - expected_measurement_pairs_df = gen_measurement_pairs_df(epoch_id) + expected_measurement_pairs_df = gen_measurement_pairs_df( + epoch_id, compute=True) result = dummy_PipeAnalysis_wtwoepoch.plot_two_epoch_pairs( epoch_id, @@ -2148,7 +2129,8 @@ def test_plot_two_epoch_pairs_bokeh( ) epoch_id = 2 - expected_measurement_pairs_df = gen_measurement_pairs_df(epoch_id) + expected_measurement_pairs_df = gen_measurement_pairs_df( + epoch_id, compute=True) result = dummy_PipeAnalysis_wtwoepoch.plot_two_epoch_pairs(epoch_id) diff --git a/vasttools/pipeline.py b/vasttools/pipeline.py index 73c673bb..5bef7faa 100644 --- a/vasttools/pipeline.py +++ b/vasttools/pipeline.py @@ -18,7 +18,6 @@ import mocpy import matplotlib import logging -import vaex import matplotlib.pyplot as plt from typing import List, Tuple @@ -88,7 +87,7 @@ class PipeRun(object): Dataframe containing all the information on the measurements of the pipeline run. measurement_pairs_file (List[str]): List containing the locations of - the measurement_pairs.parquet (or.arrow) file(s). + the measurement_pairs.parquet file(s). name (str): The pipeline run name. n_workers (int): Number of workers (cpus) available. relations (pandas.core.frame.DataFrame): Dataframe containing all the @@ -177,7 +176,7 @@ def _check_measurement_pairs_file(self): measurement_pairs_exists = True for filepath in self.measurement_pairs_file: - if not os.path.isfile(filepath): + if not os.path.exists(filepath): self.logger.warning(f"Measurement pairs file ({filepath}) does" f" not exist. You will be unable to access" f" measurement pairs or two-epoch metrics." @@ -239,8 +238,9 @@ def combine_with_run( else: self.measurements = pd.concat( [self.measurements, other_PipeRun.measurements], - ignore_index=True - ).drop_duplicates(['id', 'source']) + ).reset_index().drop_duplicates( + ['id', 'source'] + ).set_index('source', drop=True) sources_to_add = other_PipeRun.sources.loc[ ~(other_PipeRun.sources.index.isin( @@ -362,12 +362,9 @@ def get_source( else: the_sources = user_sources + measurements = the_measurements.loc[id].reset_index() if self._dask_meas: - measurements = the_measurements.loc[id].compute().reset_index() - else: - measurements = the_measurements.loc[ - the_measurements['source'] == id - ] + measurements = measurements.compute() measurements = measurements.merge( self.images[[ @@ -453,10 +450,9 @@ def _raise_if_no_pairs(self): "available for this pipeline run." ) - def load_two_epoch_metrics(self) -> None: + def load_two_epoch_metrics(self, compute: bool = False) -> None: """ - Loads the two epoch metrics dataframe, usually stored as either - 'measurement_pairs.parquet' or 'measurement_pairs.arrow'. + Loads the two epoch metrics dataframe from 'measurement_pairs.parquet'. The two epoch metrics dataframe is stored as an attribute to the PipeRun object as self.measurement_pairs_df. An epoch 'key' is also @@ -465,6 +461,10 @@ def load_two_epoch_metrics(self) -> None: Also creates a 'pairs_df' that lists all the possible epoch pairs. This is stored as the attribute self.pairs_df. + Args: + compute: If `True`, compute the measurement_pairs_df, otherwise + leave it as a dask dataframe. Defaults to `False`. + Returns: None @@ -513,64 +513,24 @@ def load_two_epoch_metrics(self) -> None: ) ) - self._vaex_meas_pairs = False - if len(self.measurement_pairs_file) > 1: - arrow_files = ( - [i.endswith(".arrow") for i in self.measurement_pairs_file] - ) - if np.any(arrow_files): - measurement_pairs_df = vaex.open_many( - self.measurement_pairs_file[arrow_files] - ) - for i in self.measurement_pairs_file[~arrow_files]: - temp = pd.read_parquet(i) - temp = vaex.from_pandas(temp) - measurement_pairs_df = measurement_pairs_df.concat(temp) - self._vaex_meas_pairs = True - warnings.warn("Measurement pairs have been loaded with vaex.") - else: - measurement_pairs_df = ( - dd.read_parquet(self.measurement_pairs_file).compute() - ) - else: - if self.measurement_pairs_file[0].endswith('.arrow'): - measurement_pairs_df = ( - vaex.open(self.measurement_pairs_file[0]) - ) - self._vaex_meas_pairs = True - warnings.warn("Measurement pairs have been loaded with vaex.") - else: - measurement_pairs_df = ( - pd.read_parquet(self.measurement_pairs_file[0]) - ) - - if self._vaex_meas_pairs: - measurement_pairs_df['pair_epoch_key'] = ( - measurement_pairs_df['image_name_a'] + "_" - + measurement_pairs_df['image_name_b'] - ) - - pair_counts = measurement_pairs_df.groupby( - measurement_pairs_df.pair_epoch_key, agg='count' - ) + measurement_pairs_df = ( + dd.read_parquet(self.measurement_pairs_file) + ) - pair_counts = pair_counts.to_pandas_df().rename( - columns={'count': 'total_pairs'} - ).set_index('pair_epoch_key') - else: - measurement_pairs_df['pair_epoch_key'] = ( - measurement_pairs_df[['image_name_a', 'image_name_b']] - .apply( - lambda x: f"{x['image_name_a']}_{x['image_name_b']}", - axis=1 - ) + measurement_pairs_df['pair_epoch_key'] = ( + measurement_pairs_df[['image_name_a', 'image_name_b']] + .apply( + lambda x: f"{x['image_name_a']}_{x['image_name_b']}", + axis=1, + meta=(None, 'object') ) + ) - pair_counts = measurement_pairs_df[ - ['pair_epoch_key', 'image_name_a'] - ].groupby('pair_epoch_key').count().rename( - columns={'image_name_a': 'total_pairs'} - ) + pair_counts = measurement_pairs_df[ + ['pair_epoch_key', 'image_name_a'] + ].groupby('pair_epoch_key').count().rename( + columns={'image_name_a': 'total_pairs'} + ).compute() pairs_df = pairs_df.merge( pair_counts, left_on='pair_epoch_key', right_index=True @@ -580,7 +540,13 @@ def load_two_epoch_metrics(self) -> None: pairs_df = pairs_df.dropna(subset=['total_pairs']).set_index('id') - self.measurement_pairs_df = measurement_pairs_df + if compute: + self.measurement_pairs_df = measurement_pairs_df.compute() + self._dask_meas_pairs = False + else: + self.measurement_pairs_df = measurement_pairs_df + self._dask_meas_pairs = True + self.pairs_df = pairs_df.sort_values(by='td') self._loaded_two_epoch_metrics = True @@ -736,14 +702,9 @@ def filter_by_moc(self, moc: mocpy.MOC) -> PipeAnalysis: new_sources = self.sources.loc[source_mask].copy() + new_meas = self.measurements.loc[new_sources.index.values] if self._dask_meas: - new_meas = self.measurements.loc[new_sources.index.values] new_meas = new_meas.compute() - else: - new_meas = self.measurements.loc[ - self.measurements['source'].isin( - new_sources.index.values - )].copy() new_images = self.images.loc[ self.images.index.isin(new_meas['image_id'].tolist())].copy() @@ -847,8 +808,8 @@ class PipeAnalysis(PipeRun): Dataframe containing all the information on the measurements of the pipeline run. measurement_pairs_file (List[str]): - List containing the locations of the measurement_pairs.parquet (or - .arrow) file(s). + List containing the locations of the measurement_pairs.parquet + file(s). name (str): The pipeline run name. n_workers (int): @@ -929,8 +890,8 @@ def __init__( def _filter_meas_pairs_df( self, - measurements_df: Union[pd.DataFrame, vaex.dataframe.DataFrame] - ) -> Union[pd.DataFrame, vaex.dataframe.DataFrame]: + measurements_df: Union[pd.DataFrame, dd.DataFrame] + ) -> Union[pd.DataFrame, dd.DataFrame]: """ A utility method to filter the measurement pairs dataframe to remove pairs that are no longer in the measurements dataframe. @@ -946,43 +907,46 @@ def _filter_meas_pairs_df( if not self._loaded_two_epoch_metrics: self.load_two_epoch_metrics() - if self._vaex_meas_pairs: - new_measurement_pairs = self.measurement_pairs_df.copy() + new_measurement_pairs = self.measurement_pairs_df.copy() + + if isinstance(measurements_df, dd.DataFrame): + measurement_ids = measurements_df['id'].compute().values else: - new_measurement_pairs = vaex.from_pandas( - self.measurement_pairs_df - ) + measurement_ids = measurements_df['id'].values mask_a = new_measurement_pairs['meas_id_a'].isin( - measurements_df['id'].values - ).values + measurement_ids + ) mask_b = new_measurement_pairs['meas_id_b'].isin( - measurements_df['id'].values - ).values + measurement_ids + ) - new_measurement_pairs['mask_a'] = mask_a - new_measurement_pairs['mask_b'] = mask_b + if isinstance(new_measurement_pairs, dd.DataFrame): + new_measurement_pairs = new_measurement_pairs[mask_a & mask_b] + else: + mask = np.logical_and(mask_a.values, mask_b.values) + new_measurement_pairs = new_measurement_pairs[mask] - mask = np.logical_and(mask_a, mask_b) - new_measurement_pairs['mask'] = mask - new_measurement_pairs = new_measurement_pairs[ - new_measurement_pairs['mask'] == True - ] + return new_measurement_pairs - new_measurement_pairs = new_measurement_pairs.extract() - new_measurement_pairs = new_measurement_pairs.drop( - ['mask_a', 'mask_b', 'mask'] - ) + def _assign_new_flux_values( + self, measurement_pairs_df, flux_cols, measurements_df): + for j in ['a', 'b']: + id_values = measurement_pairs_df[f'meas_id_{j}'].to_numpy() - if not self._vaex_meas_pairs: - new_measurement_pairs = new_measurement_pairs.to_pandas_df() + for i in flux_cols: + if i == 'id': + continue + pairs_i = i + f'_{j}' + new_flux_values = measurements_df.loc[id_values, i].values + measurement_pairs_df[pairs_i] = new_flux_values - return new_measurement_pairs + return measurement_pairs_df def recalc_measurement_pairs_df( self, - measurements_df: Union[pd.DataFrame, vaex.dataframe.DataFrame] + measurements_df: Union[pd.DataFrame, dd.DataFrame] ) -> Union[pd.DataFrame, dd.DataFrame]: """ A method to recalculate the two epoch pair metrics based upon a @@ -1004,20 +968,10 @@ def recalc_measurement_pairs_df( new_measurement_pairs = self._filter_meas_pairs_df( measurements_df[['id']] ) - - # NOTE: This needs to be re-done to correctly handle measurement pairs - # being in dask dataframes - - # an attempt to conserve memory - if isinstance(new_measurement_pairs, vaex.dataframe.DataFrame): - new_measurement_pairs = new_measurement_pairs.drop( - ['vs_peak', 'vs_int', 'm_peak', 'm_int'] - ) - else: - new_measurement_pairs = new_measurement_pairs.drop( - ['vs_peak', 'vs_int', 'm_peak', 'm_int'], - axis=1 - ) + new_measurement_pairs = new_measurement_pairs.drop( + ['vs_peak', 'vs_int', 'm_peak', 'm_int'], + axis=1 + ) flux_cols = [ 'flux_int', @@ -1030,6 +984,8 @@ def recalc_measurement_pairs_df( # convert pandas measurements to dask for consistency if isinstance(measurements_df, pd.DataFrame): measurements_df = pandas_to_dask(measurements_df) + if isinstance(new_measurement_pairs, pd.DataFrame): + new_measurement_pairs = pandas_to_dask(new_measurement_pairs) measurements_df = measurements_df[flux_cols] @@ -1039,40 +995,55 @@ def recalc_measurement_pairs_df( .set_index('id') ) - for i in flux_cols: - if i == 'id': - continue - for j in ['a', 'b']: + new_cols = [] + for j in ['a', 'b']: + for i in flux_cols: + if i == 'id': + continue + pairs_i = i + f'_{j}' - id_values = new_measurement_pairs[f'meas_id_{j}'].to_numpy() - new_flux_values = measurements_df.loc[id_values][i] - new_flux_values = new_flux_values.compute().to_numpy() - new_measurement_pairs[pairs_i] = new_flux_values + new_cols.append(pairs_i) + cols = list(new_measurement_pairs.columns) + new_cols + meta = pd.DataFrame(columns=cols, dtype=float) + + n_partitions = new_measurement_pairs.npartitions + + new_measurement_pairs = new_measurement_pairs.map_partitions( + self._assign_new_flux_values, + flux_cols, + measurements_df, + align_dataframes=False, + meta=new_measurement_pairs.head() + ) del measurements_df # calculate 2-epoch metrics new_measurement_pairs["vs_peak"] = calculate_vs_metric( - new_measurement_pairs['flux_peak_a'].to_numpy(), - new_measurement_pairs['flux_peak_b'].to_numpy(), - new_measurement_pairs['flux_peak_err_a'].to_numpy(), - new_measurement_pairs['flux_peak_err_b'].to_numpy() + new_measurement_pairs['flux_peak_a'].values, + new_measurement_pairs['flux_peak_b'].values, + new_measurement_pairs['flux_peak_err_a'].values, + new_measurement_pairs['flux_peak_err_b'].values, ) + new_measurement_pairs["vs_int"] = calculate_vs_metric( - new_measurement_pairs['flux_int_a'].to_numpy(), - new_measurement_pairs['flux_int_b'].to_numpy(), - new_measurement_pairs['flux_int_err_a'].to_numpy(), - new_measurement_pairs['flux_int_err_b'].to_numpy() + new_measurement_pairs['flux_int_a'].values, + new_measurement_pairs['flux_int_b'].values, + new_measurement_pairs['flux_int_err_a'].values, + new_measurement_pairs['flux_int_err_b'].values, ) new_measurement_pairs["m_peak"] = calculate_m_metric( - new_measurement_pairs['flux_peak_a'].to_numpy(), - new_measurement_pairs['flux_peak_b'].to_numpy() + new_measurement_pairs['flux_peak_a'].values, + new_measurement_pairs['flux_peak_b'].values, ) new_measurement_pairs["m_int"] = calculate_m_metric( - new_measurement_pairs['flux_int_a'].to_numpy(), - new_measurement_pairs['flux_int_b'].to_numpy() + new_measurement_pairs['flux_int_a'].values, + new_measurement_pairs['flux_int_b'].values, ) + if not self._dask_meas_pairs: + new_measurement_pairs = new_measurement_pairs.compute() + return new_measurement_pairs def recalc_sources_df( @@ -1111,10 +1082,7 @@ def recalc_sources_df( if not self._loaded_two_epoch_metrics: self.load_two_epoch_metrics() - # this should actually check the type of the measurements_df - # rather than assuming it's the same as self.measurements - # TO DO: fix that!! - if not self._dask_meas: + if isinstance(measurements_df, pd.DataFrame): measurements_df = pandas_to_dask(measurements_df) # account for RA wrapping @@ -1206,9 +1174,9 @@ def recalc_sources_df( # df is in pandas format. # TraP variability metrics, using Dask. - measurements_df_temp = measurements_df[[ - 'flux_int', 'flux_int_err', 'flux_peak', 'flux_peak_err', 'source' - ]] + measurements_df_temp = measurements_df[ + ['flux_int', 'flux_int_err', 'flux_peak', 'flux_peak_err'] + ] col_dtype = { 'v_int': 'f', @@ -1239,20 +1207,17 @@ def recalc_sources_df( measurements_df[['id']] ) - if isinstance(measurement_pairs_df, vaex.dataframe.DataFrame): - new_measurement_pairs = ( - measurement_pairs_df[ - measurement_pairs_df['vs_int'].abs() >= min_vs - or measurement_pairs_df['vs_peak'].abs() >= min_vs - ] - ) + if isinstance(measurement_pairs_df, dd.DataFrame): + mask_int = measurement_pairs_df['vs_int'].abs() >= min_vs + mask_peak = measurement_pairs_df['vs_peak'].abs() >= min_vs + new_measurement_pairs = measurement_pairs_df[mask_int | mask_peak] else: min_vs_mask = np.logical_or( (measurement_pairs_df['vs_int'].abs() >= min_vs).to_numpy(), (measurement_pairs_df['vs_peak'].abs() >= min_vs).to_numpy() ) new_measurement_pairs = measurement_pairs_df.loc[min_vs_mask] - new_measurement_pairs = vaex.from_pandas(new_measurement_pairs) + new_measurement_pairs = pandas_to_dask(new_measurement_pairs) new_measurement_pairs['vs_int_abs'] = ( new_measurement_pairs['vs_int'].abs() @@ -1270,23 +1235,27 @@ def recalc_sources_df( new_measurement_pairs['m_peak'].abs() ) - sources_df_two_epochs = new_measurement_pairs.groupby( - 'source_id', - agg={ - 'vs_significant_max_int': vaex.agg.max('vs_int_abs'), - 'vs_significant_max_peak': vaex.agg.max('vs_peak_abs'), - 'm_abs_significant_max_int': vaex.agg.max('m_int_abs'), - 'm_abs_significant_max_peak': vaex.agg.max('m_peak_abs'), - } - ) + sources_df_metrics = new_measurement_pairs.groupby('source_id').agg({ + 'vs_int_abs': 'max', + 'vs_peak_abs': 'max', + 'm_int_abs': 'max', + 'm_peak_abs': 'max', + }) - sources_df_two_epochs = ( - sources_df_two_epochs.to_pandas_df().set_index('source_id') + sources_df_metrics.columns = [ + 'vs_abs_significant_max_int', + 'vs_abs_significant_max_peak', + 'm_abs_significant_max_int', + 'm_abs_significant_max_peak' + ] + + sources_df_metrics = ( + sources_df_metrics.compute() ) - sources_df = sources_df.join(sources_df_two_epochs) + sources_df = sources_df.join(sources_df_metrics) - del sources_df_two_epochs + del sources_df_metrics # new relation numbers relation_mask = np.logical_and( @@ -1314,9 +1283,9 @@ def recalc_sources_df( # Fill the NaN values. sources_df = sources_df.fillna(value={ - "vs_significant_max_peak": 0.0, + "vs_abs_significant_max_peak": 0.0, "m_abs_significant_max_peak": 0.0, - "vs_significant_max_int": 0.0, + "vs_abs_significant_max_int": 0.0, "m_abs_significant_max_int": 0.0, 'n_relations': 0, 'v_int': 0., @@ -1392,13 +1361,7 @@ def _get_epoch_pair_plotting_df( ][['id', 'forced']] if self._dask_meas: - temp_meas = self.measurements.loc[unique_meas_ids][[ - 'id', 'forced']] temp_meas = temp_meas.compute() - else: - temp_meas = self.measurements[ - self.measurements['id'].isin(unique_meas_ids) - ][['id', 'forced']] temp_meas = temp_meas.drop_duplicates('id').set_index('id') @@ -1773,8 +1736,8 @@ def plot_two_epoch_pairs( ] ) - if self._vaex_meas_pairs: - pairs_df = pairs_df.extract().to_pandas_df() + if self._dask_meas_pairs: + pairs_df = pairs_df.compute() pairs_df = pairs_df[pairs_df['source_id'].isin(df.index.values)] @@ -1803,7 +1766,8 @@ def plot_two_epoch_pairs( def run_two_epoch_analysis( self, vs: float, m: float, query: Optional[str] = None, - df: Optional[pd.DataFrame] = None, use_int_flux: bool = False + df: Optional[pd.DataFrame] = None, use_int_flux: bool = False, + compute: bool = True ) -> Tuple[pd.DataFrame, pd.DataFrame]: """ Run the two epoch analysis on the pipeline run, with optional @@ -1819,6 +1783,9 @@ def run_two_epoch_analysis( If None then the sources from the PipeAnalysis object are used. use_int_flux: Use integrated fluxes for the analysis instead of peak fluxes, defaults to 'False'. + compute: Whether or not to compute the resulting dataframes if + the pairs are loaded with dask. This is only relevant if + `self._dask_meas_pairs==True`. Defaults to `True`. Returns: Tuple containing two dataframes of the candidate sources and pairs. @@ -1851,14 +1818,9 @@ def run_two_epoch_analysis( pairs_df = self.measurement_pairs_df.copy() if len(allowed_sources) != self.sources.shape[0]: - if self._vaex_meas_pairs: - pairs_df = pairs_df[ - pairs_df['source_id'].isin(allowed_sources) - ] - else: - pairs_df = pairs_df.loc[ - pairs_df['source_id'].isin(allowed_sources) - ] + pairs_df = pairs_df.loc[ + pairs_df['source_id'].isin(allowed_sources) + ] vs_label = 'vs_int' if use_int_flux else 'vs_peak' m_abs_label = 'm_int' if use_int_flux else 'm_peak' @@ -1866,23 +1828,17 @@ def run_two_epoch_analysis( pairs_df[vs_label] = pairs_df[vs_label].abs() pairs_df[m_abs_label] = pairs_df[m_abs_label].abs() - # If vaex convert these to pandas - if self._vaex_meas_pairs: - candidate_pairs = pairs_df[ - (pairs_df[vs_label] > vs) & (pairs_df[m_abs_label] > m) - ] - - candidate_pairs = candidate_pairs.to_pandas_df() - - else: - candidate_pairs = pairs_df.loc[ - (pairs_df[vs_label] > vs) & (pairs_df[m_abs_label] > m) - ] + candidate_pairs = pairs_df.loc[ + (pairs_df[vs_label] > vs) & (pairs_df[m_abs_label] > m) + ] unique_sources = candidate_pairs['source_id'].unique() candidate_sources = self.sources.loc[unique_sources] + if self._dask_meas_pairs and compute: + candidate_pairs = candidate_pairs.compute() + return candidate_sources, candidate_pairs def _fit_eta_v( @@ -2762,23 +2718,14 @@ def load_run( right_on='id' ) .rename(columns={'source_id': 'source'}) - ).reset_index(drop=True) + ).set_index('source') images = images.set_index('id') - if os.path.isfile(os.path.join( + measurement_pairs_file = [os.path.join( run_dir, - "measurement_pairs.arrow" - )): - measurement_pairs_file = [os.path.join( - run_dir, - "measurement_pairs.arrow" - )] - else: - measurement_pairs_file = [os.path.join( - run_dir, - "measurement_pairs.parquet" - )] + "measurement_pairs.parquet" + )] piperun = PipeAnalysis( name=run_name,