diff --git a/python/README.md b/python/README.md index 106ab5d..1157a28 100644 --- a/python/README.md +++ b/python/README.md @@ -21,21 +21,3 @@ vbr_download --dataset --save-dir You can get the list of all the available sequences by using the `-h | --help` flag -```shell -❯ vbr_download -h -usage: vbr_download [-h] --dataset - {all,campus_test0,campus_test1,campus_train0,campus_train1,ciampino_test0,ciampino_t -est1,ciampino_train0,ciampino_train1,colosseo_test0,colosseo_train0,diag_test0,diag_train0,pincio_test0, -pincio_train0,spagna_test0,spagna_train0} - --save-dir PATH - -╭─ options ────────────────────────────────────────────────────────────────────────────────────────────╮ -│ -h, --help show this help message and exit │ -│ --dataset │ -│ {all,campus_test0,campus_test1,campus_train0,campus_train1,ciampino_test0,ciampino_test1,ciampino_t… │ -│ (required) │ -│ --save-dir PATH (required) │ -╰──────────────────────────────────────────────────────────────────────────────────────────────────────╯ -``` - - diff --git a/python/pyproject.toml b/python/pyproject.toml index b196e85..d633280 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -24,15 +24,12 @@ dependencies = [ "natsort", "numpy", "rich", - "tqdm", "typer[all]>=0.10.0", - "rosbags", - "tyro" + "rosbags" ] [project.scripts] -vbr_download = "vbr_devkit.download.download_data:entrypoint" -vbr_convert = "vbr_devkit.datasets.convert_bag:entrypoint" +vbr = "vbr_devkit.tools.run:app" [project.urls] Homepage = "https://github.com/rvp-group/vbr-devkit" diff --git a/python/vbr_devkit/__init__.py b/python/vbr_devkit/__init__.py index b3c06d4..34cd888 100644 --- a/python/vbr_devkit/__init__.py +++ b/python/vbr_devkit/__init__.py @@ -1 +1 @@ -__version__ = "0.0.1" \ No newline at end of file +__version__ = "0.0.0-alpha" \ No newline at end of file diff --git a/python/vbr_devkit/datasets/__init__.py b/python/vbr_devkit/datasets/__init__.py index e69de29..de7e46a 100644 --- a/python/vbr_devkit/datasets/__init__.py +++ b/python/vbr_devkit/datasets/__init__.py @@ -0,0 +1,2 @@ +from .ros import RosReader +from .kitti import KittiWriter \ No newline at end of file diff --git a/python/vbr_devkit/datasets/convert_bag.py b/python/vbr_devkit/datasets/convert_bag.py index fea1208..3d7b2d5 100644 --- a/python/vbr_devkit/datasets/convert_bag.py +++ b/python/vbr_devkit/datasets/convert_bag.py @@ -1,25 +1,29 @@ -import tyro +import sys +sys.path.append("/home/eg/source/vbr-devkit/python") + from pathlib import Path -from typing import Union -from dataclasses import dataclass -from typing_extensions import Annotated -from kitti import KittiWriter -@dataclass -class Args: - input_dir: Path = "/" - to = Union[ - Annotated[KittiWriter, tyro.conf.subcommand(name="kitti")] - ] - output_dir: Path = "/" +from vbr_devkit.datasets import KittiWriter, RosReader +import typer +from enum import Enum +from rich.progress import track -def main(args: Args) -> None: - ... -def entrypoint(): - tyro.run(main) +class OutputDataInterface(str, Enum): + kitti = "kitti", + # Can insert additional conversion formats -if __name__ == "__main__": - entrypoint() - +OutputDataInterface_lut = { + OutputDataInterface.kitti: KittiWriter +} + +def main(to: OutputDataInterface, input_dir: Path, output_dir: Path) -> None: + with RosReader(input_dir) as reader: + with OutputDataInterface_lut[to](output_dir) as writer: + for timestamp, topic, message in track(reader, description="Processing..."): + writer.publish(timestamp, topic, message) + + +if __name__ == "__main__": + typer.run(main) diff --git a/python/vbr_devkit/datasets/kitti.py b/python/vbr_devkit/datasets/kitti.py index 767d760..c086e06 100644 --- a/python/vbr_devkit/datasets/kitti.py +++ b/python/vbr_devkit/datasets/kitti.py @@ -117,7 +117,7 @@ def __init__(self, data_dir: Path): def __enter__(self): return self - def publish(self, topic: str, timestamp, message: Union[PointCloudXf, Image, Imu]): + def publish(self, timestamp, topic: str, message: Union[PointCloudXf, Image, Imu]): if topic not in self.data_handles.keys(): # Infer path to store stuff # Remove first / on topic @@ -131,16 +131,4 @@ def publish(self, topic: str, timestamp, message: Union[PointCloudXf, Image, Imu def __exit__(self, exc_type, exc_val, exc_tb): for handle in self.data_handles: - self.data_handles[handle].close() - - -from ros import RosReader -from tqdm import tqdm - -if __name__ == "__main__": - with KittiWriter(Path("/home/eg/data/test_download/vbr_slam/campus/campus_test0/campus_test0_00_kitti")) as writer: - with RosReader(Path("/home/eg/data/test_download/vbr_slam/campus/campus_test0/campus_test0_00.bag"), - topics=["/ouster/points"]) as reader: - for timestamp, topic, message in tqdm(reader, desc="Reading bag"): - # print(f"Topic={topic} | Timestamp={timestamp} (type=({type(timestamp)}) | message_type=({type(message)})") - writer.publish(topic, timestamp, message) + self.data_handles[handle].close() \ No newline at end of file diff --git a/python/vbr_devkit/datasets/ros.py b/python/vbr_devkit/datasets/ros.py index c93d666..2f3acf9 100644 --- a/python/vbr_devkit/datasets/ros.py +++ b/python/vbr_devkit/datasets/ros.py @@ -9,7 +9,7 @@ class RosReader: - def __init__(self, data_dir: Sequence[Path], topics: List[str] = None, *args, + def __init__(self, data_dir: Union[Path, Sequence[Path]], topics: List[str] = None, *args, **kwargs): """ :param data_dir: Directory containing rosbags or path to a rosbag file diff --git a/python/vbr_devkit/datasets/split_bag.py b/python/vbr_devkit/datasets/split_bag.py deleted file mode 100644 index b5f42bf..0000000 --- a/python/vbr_devkit/datasets/split_bag.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -import sys -from pathlib import Path -import typer -from tqdm import tqdm -from typing import List - - -def main(input_bag_f: str, output_dir: str, size_gb: float = 5, skip_topics: List[str] = []): - try: - from rosbag import Bag - except ModuleNotFoundError: - #TODO: Rewrite error message - print("rosbag library not installed") - sys.exit(-1) - - print(f'Opening input bag: {Path(input_bag_f)}') - input_bag = Bag(input_bag_f, 'r') - - # Get list of available topics - input_topics = list(input_bag.get_type_and_topic_info()[1].keys()) - input_topics = list(filter(lambda x: x not in skip_topics, input_topics)) - print(f"Writing the following topics: {input_topics}") - - - - output_prefix = Path(input_bag_f).name.split(".")[0] - - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - # total_size = input_bag.size - # num_bags = int(total_size / (1024 ** 3 * size_gb)) + 1 - # print(f'Splitting {input_bag_f} in {num_bags} bags of size {size_gb} Gb') - - count = 0 - current_size = 0 - output_bag = None - - for topic, msg, t in tqdm(input_bag.read_messages(topics=input_topics, raw=True), total=input_bag.get_message_count(), - desc=f"Reading {input_bag_f}"): - if output_bag is None: - output_filename = f"{output_prefix}_{count}.bag" - output_filename = Path(output_dir) / output_filename - output_bag = Bag(output_filename, 'w') - current_size = 0 - - if current_size >= size_gb * 1024 ** 3: - if output_bag is not None: - output_bag.close() - count += 1 - output_filename = f"{output_prefix}_{count}.bag" - output_filename = Path(output_dir) / output_filename - output_bag = Bag(output_filename, 'w') - current_size = 0 - - output_bag.write(topic, msg, t, raw=True) - current_size += len(msg[1]) - - if output_bag is not None: - output_bag.close() - - input_bag.close() - - print("Success") - - -if __name__ == '__main__': - typer.run(main) diff --git a/python/vbr_devkit/download/download_data.py b/python/vbr_devkit/download/download_data.py index 6a43793..d664ff2 100644 --- a/python/vbr_devkit/download/download_data.py +++ b/python/vbr_devkit/download/download_data.py @@ -1,16 +1,13 @@ from pathlib import Path import ftplib -import tyro -from typing import TYPE_CHECKING -from rich.console import Console -from rich.progress import Progress +from rich.progress import Progress, SpinnerColumn, TextColumn from rich.panel import Panel +from vbr_devkit.tools.console import console DATASET_LINK = "151.100.59.119" FTP_USER = "anonymous" vbr_downloads = [ - "all", "campus_test0", "campus_test1", "campus_train0", @@ -29,44 +26,51 @@ "spagna_train0", ] -if TYPE_CHECKING: - VbrSlamCaptureName = str -else: - VbrSlamCaptureName = tyro.extras.literal_type_from_choices(vbr_downloads) +def download_seq_fld(seq: str, output_dir: Path) -> None: + def human_readable_size(size, decimal_places=2): + for unit in ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB']: + if size < 1024.0 or unit == 'PiB': + break + size /= 1024.0 + return f"{size:.{decimal_places}f} {unit}" -CONSOLE = Console(width=120) -def main( - dataset: VbrSlamCaptureName, - save_dir: Path, -): - CONSOLE.rule(f"[bold green] Downloading {dataset}") - if dataset == "all": - for seq in vbr_downloads: - if seq != "all": - main(seq, save_dir) - - - save_dir.mkdir(parents=True, exist_ok=True) + console.rule(f"[bold green] Downloading {seq}") + # output_dir.mkdir(parents=True, exist_ok=True) + # Establish FTP connection + console.log(f"Connecting to {DATASET_LINK}") ftp = ftplib.FTP(DATASET_LINK) ftp.login(FTP_USER, "") - db_path = "vbr_slam/" + dataset.split("_")[0] + "/" + dataset - ftp.cwd(db_path) - try: - available_files = ftp.nlst() - except ftplib.error_perm as resp: - if str(resp) == "550 No files found": - CONSOLE.log("[bold red] Invalid input sequence") - else: - raise + console.log(":white_check_mark: Connection established") + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + transient=True) as progress: + progress.add_task("Gathering files", total=None) + db_path = "vbr_slam/" + seq.split("_")[0] + "/" + seq + ftp.cwd(db_path) - CONSOLE.print(Panel( - " ".join([f"{f}" for f in available_files]), title="Download table")) + try: + available_files = ftp.nlst() + available_files = [(file, ftp.size(file)) for file in available_files] + except ftplib.error_perm as resp: + if str(resp) == "550 No files found": + console.log("[bold red] Invalid input sequence") + else: + raise + # Sort based on size + available_files = sorted(available_files, key=lambda x: x[1]) + console.print(Panel( + "\n".join(f"{f[0]}\t{human_readable_size(f[1])}" for f in available_files), title="Downloading files" + )) + + available_files = [x[0] for x in available_files] + # Downloading routine with Progress() as progress: for f in available_files: - local_path = save_dir / "vbr_slam" / dataset.split("_")[0] / dataset - local_path.mkdir(parents=True, exist_ok=True) + local_path = output_dir / "vbr_slam" / seq.split("_")[0] / seq + local_path.mkdir(exist_ok=True, parents=True) local_fname = local_path / f fout = open(local_fname, "wb") task = progress.add_task(f"Downloading {f}", total=ftp.size(f)) @@ -78,12 +82,4 @@ def write_cb(data): ftp.retrbinary("RETR " + f, write_cb) fout.close() ftp.quit() - CONSOLE.print("[bold green] Done!") - - -def entrypoint(): - tyro.cli(main) - - -if __name__ == "__main__": - entrypoint() + console.print(":tada: Completed") diff --git a/python/vbr_devkit/tools/console.py b/python/vbr_devkit/tools/console.py new file mode 100644 index 0000000..5f8e6bd --- /dev/null +++ b/python/vbr_devkit/tools/console.py @@ -0,0 +1,3 @@ +from rich.console import Console + +console = Console() \ No newline at end of file diff --git a/python/vbr_devkit/tools/run.py b/python/vbr_devkit/tools/run.py new file mode 100644 index 0000000..13333fd --- /dev/null +++ b/python/vbr_devkit/tools/run.py @@ -0,0 +1,76 @@ +import sys + +sys.path.append("/home/eg/source/vbr-devkit/python") +import typer +from pathlib import Path +from rich.console import Group +from rich.panel import Panel +from rich.progress import track +from typing import Sequence +from typing_extensions import Annotated +from vbr_devkit.datasets import RosReader +from vbr_devkit.download.download_data import vbr_downloads, download_seq_fld +from vbr_devkit.datasets.convert_bag import OutputDataInterface, OutputDataInterface_lut +from vbr_devkit.tools.console import console + +app = typer.Typer() + + +@app.command("list", + help="List all available VBR sequences") +def list_sequences() -> None: + panel_group = Group( + Panel("\n".join(["all", "train", "test"]), title="Meta"), + Panel("\n".join([x for x in vbr_downloads if "train" in x]), title="Train"), + Panel("\n".join([x for x in vbr_downloads if "test" in x]), title="Test") + ) + console.print(panel_group) + + +def complete_sequence(incomplete: str) -> Sequence[str]: + for seq in ["all", "train", "test"] + vbr_downloads: + if seq.startswith(incomplete): + yield seq + + +@app.command(help="Download one or more VBR sequences. Type 'vbr list' to see the available sequences.") +def download(sequence: Annotated[ + str, typer.Argument(help="Name of the sequence to download", show_default=False, autocompletion=complete_sequence)], + output_dir: Annotated[ + Path, typer.Argument(help="Output directory. The sequence will be stored in a sub-folder", + show_default=False)]) -> None: + if sequence == "all": + console.print(":boom: Downloading all sequences") + console.print("[yellow] It will take a while") + for seq in vbr_downloads: + download_seq_fld(seq, output_dir) + elif sequence == "train" or sequence == "test": + console.print(f":woman_student: Downloading {sequence} sequences") + console.print("[yellow] It will take a while") + for seq in filter(lambda x: f"{sequence}" in x, vbr_downloads): + download_seq_fld(seq, output_dir) + else: + if sequence not in vbr_downloads: + console.log( + f":thinking_face: Error {sequence} is not a valid sequence. Type 'vbr list' to see available sequences.") + sys.exit(-1) + download_seq_fld(sequence, output_dir) + + +@app.command(help="Convert a sequence from ROS1 to other known formats") +def convert(to: Annotated[OutputDataInterface, typer.Argument(help="Desired data format")], + input_dir: Annotated[ + Path, typer.Argument(help="Input bag or directory containing multiple bags", show_default=False)], + output_dir: Annotated[ + Path, typer.Argument(help="Output directory in which the data will be stored", show_default=False)], + ) -> None: + console.print(f"Converting {input_dir} to {to} format at {output_dir}") + with RosReader(input_dir) as reader: + with OutputDataInterface_lut[to](output_dir) as writer: + for timestamp, topic, message in track(reader, description="Processing..."): + writer.publish(timestamp, topic, message) + console.print(":tada: Completed") + + +if __name__ == "__main__": + app()