diff --git a/src/dolphin/timeseries.py b/src/dolphin/timeseries.py index 5da2b9ca..336783ae 100644 --- a/src/dolphin/timeseries.py +++ b/src/dolphin/timeseries.py @@ -84,18 +84,19 @@ def run( """ condition_func = argmax_index if condition == CallFunc.MAX else argmin_index - Path(output_dir).mkdir(exist_ok=True, parents=True) + def _get_reference(): + # First we find the reference point for the unwrapped interferograms + if reference_point == (-1, -1): + return select_reference_point( + condition_file=condition_file, + output_dir=Path(output_dir), + condition_func=condition_func, + ccl_file_list=conncomp_paths, + ) + else: + return ReferencePoint(row=reference_point[0], col=reference_point[1]) - # First we find the reference point for the unwrapped interferograms - if reference_point == (-1, -1): - reference = select_reference_point( - condition_file=condition_file, - output_dir=Path(output_dir), - condition_func=condition_func, - ccl_file_list=conncomp_paths, - ) - else: - reference = ReferencePoint(row=reference_point[0], col=reference_point[1]) + Path(output_dir).mkdir(exist_ok=True, parents=True) ifg_date_pairs = [get_dates(f) for f in unwrapped_paths] sar_dates = sorted(set(utils.flatten(ifg_date_pairs))) @@ -110,7 +111,7 @@ def run( logger.info("Inverting network of %s unwrapped ifgs", len(unwrapped_paths)) inverted_phase_paths = invert_unw_network( unw_file_list=unwrapped_paths, - reference=reference, + reference=_get_reference(), output_dir=output_dir, num_threads=num_threads, ) @@ -144,7 +145,7 @@ def run( create_velocity( unw_file_list=inverted_phase_paths, output_file=velocity_file, - reference=reference, + reference=_get_reference(), cor_file_list=cor_file_list, cor_threshold=correlation_threshold, num_threads=num_threads, @@ -711,7 +712,7 @@ def invert_unw_network( unw_reader = io.VRTStack( file_list=unw_file_list, outfile=out_vrt_name, skip_size_check=True ) - cor_vrt_name = Path(output_dir) / "unw_network.vrt" + cor_vrt_name = Path(output_dir) / "cor_network.vrt" # Get the reference point data ref_row, ref_col = reference @@ -835,6 +836,12 @@ def select_reference_point( component label files """ + output_file = output_dir / "reference_point.txt" + if output_file.exists(): + ref_point = _read_reference_point(output_file=output_file) + logger.info(f"Read {ref_point!r} from existing {output_file}") + return ref_point + logger.info("Selecting reference point") condition_file_values = io.load_gdal(condition_file, masked=True) @@ -859,7 +866,18 @@ def select_reference_point( ref_row, ref_col = condition_func(condition_file_values) # Cast to `int` to avoid having `np.int64` types - return ReferencePoint(int(ref_row), int(ref_col)) + ref_point = ReferencePoint(int(ref_row), int(ref_col)) + logger.info(f"Saving {ref_point!r} to from existing {output_file}") + _write_reference_point(output_file=output_file, ref_point=ref_point) + return ref_point + + +def _write_reference_point(output_file: Path, ref_point: ReferencePoint) -> None: + output_file.write_text(",".join(list(map(str, ref_point)))) + + +def _read_reference_point(output_file: Path): + return ReferencePoint(*[int(n) for n in output_file.read_text().split(",")]) def _get_largest_conncomp_mask( @@ -891,7 +909,6 @@ def intersect_conncomp(arr: np.ma.MaskedArray, axis: int) -> np.ndarray: read_masked=True, ) - logger.info("Selecting reference point") conncomp_intersection = io.load_gdal(conncomp_intersection_file, masked=True) # Find the largest conncomp region in the intersection