diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4273237..d541139 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -10,21 +10,29 @@ on: jobs: test: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout code uses: actions/checkout@v4 + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev libopencv-dev + - name: Set up Python uses: actions/setup-python@v5 with: python-version: '3.10' + # poetry==2.1.4 has a bug - name: Install dependencies run: | - pip install poetry - POETRY_VIRTUALENVS_CREATE=false poetry install + pip install poetry==2.1.3 + poetry install + env: + POETRY_VIRTUALENVS_CREATE: false - name: Pre-commit checks run: | @@ -42,3 +50,4 @@ jobs: uses: actions-rs/cargo@v1 with: command: test + args: --features=improc,gaia diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1c9e770..b47cbc3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,23 +14,13 @@ jobs: make-release: name: Create Release - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout code uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - - name: Install dependencies and create source distribution - run: | - pip install poetry - poetry build --format sdist - - name: Create Release id: create_release uses: actions/create-release@v1 @@ -42,6 +32,36 @@ jobs: draft: false prerelease: false + build-python-sdist: + name: Build Python sdist + runs-on: ubuntu-22.04 + + needs: + - make-release + + steps: + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev libopencv-dev + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + # poetry==2.1.4 has a bug + - name: Install dependencies and create source distribution + run: | + pip install poetry==2.1.3 + poetry build --format sdist + env: + POETRY_VIRTUALENVS_CREATE: false + - name: Upload sdist run: | for file in ./dist/*; do @@ -68,15 +88,23 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev libopencv-dev + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + # poetry==2.1.4 has a bug - name: Install dependencies and build the wheel run: | - pip install poetry + pip install poetry==2.1.3 poetry build --format wheel + env: + POETRY_VIRTUALENVS_CREATE: false - name: Upload wheel run: | @@ -90,7 +118,7 @@ jobs: pypi-publish: name: Upload release to PyPI - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 environment: name: pypi @@ -104,6 +132,7 @@ jobs: needs: - build-python-wheels + - build-python-sdist steps: diff --git a/.github/workflows/version_bump.yml b/.github/workflows/version_bump.yml index 7edcc5f..41935dc 100644 --- a/.github/workflows/version_bump.yml +++ b/.github/workflows/version_bump.yml @@ -10,7 +10,7 @@ permissions: jobs: bump-version: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 if: "!contains(github.event.head_commit.message, 'ci: bump')" diff --git a/Cargo.toml b/Cargo.toml index 940adf4..cb0b921 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,13 +5,26 @@ edition = "2021" [lib] name = "ruststartracker" -crate-type = ["cdylib"] +crate-type = ["cdylib", "rlib"] [dependencies] itertools = "0.13.0" kdtree = "0.7.0" maths-rs = "0.2.6" -nalgebra = "0.31.4" +nalgebra = "0.33.2" polyfit-rs = "0.2.1" pyo3 = { version = "0.21.0", features = ["extension-module"] } numpy = "0.21.0" +opencv = { version = "0.94.4", optional = true } +chrono = "0.4.41" +csv = "1.3.1" +serde = { version = "1.0.219", features = ["derive"] } + +[dev-dependencies] +rand = "0.9.1" +rand_distr = "0.5.1" + +[features] +default = [] +improc = ["opencv"] +gaia = [] diff --git a/README.md b/README.md index d79b17a..8c7b1e9 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,35 @@ Features: ## Example +### Rust + +See [examples/basic.rs](examples/basic.rs) + +```rust +// Get catalog positions +let catalog: StarCatalog = StarCatalog::from_gaia(max_magnitude: ...).unwrap(); +let stars_xyz: Vec<[f32; 3]> = catalog.normalized_positions(epoch: ..., observer_position: ...); +let stars_mag: Vec = catalog.magnitudes(); + +// Create StarTracker instance (reuse this) +let star_matcher = StarMatcher::new( + stars_xyz, + stars_mag, + max_lookup_magnitude: ... + max_inter_star_angle: ..., + inter_star_angle_tolerance: ..., + min_matches: ..., + timeout: ... +); + +// Normalized observation in the camera frame +let obs_xyz_camera: Vec<[f32; 3]> = ... + +let result = star_matcher.find(&obs_xyz_camera); +println!("Result: {:?}", result); +``` + +### Python ```python import ruststartracker @@ -51,11 +80,6 @@ print(result) - Install with `pip install ruststartracker` (Currently only ARM/x86 Linux wheels available). -## TODOs - -- Improve error messages. -- Return more diagnostic data. - ## Attributions ### Gaia Data diff --git a/build_script.py b/build_script.py index 5e99887..660ecb7 100644 --- a/build_script.py +++ b/build_script.py @@ -50,7 +50,12 @@ def build_script() -> None: if not gaia_file.exists(): download_gaia_data(gaia_file) - subprocess.check_call(["cargo", "build", "--release"], cwd=cwd) # noqa: S603, S607 + subprocess.check_call( # noqa: S603 + ["cargo", "build", "--release", "--features", "improc,gaia"], # noqa: S607 + cwd=cwd, + stdout=None, + stderr=None, + ) shutil.copy( cwd / "target/release/libruststartracker.so", cwd / "ruststartracker/libruststartracker.so" ) diff --git a/examples/basic.rs b/examples/basic.rs new file mode 100644 index 0000000..878b5e0 --- /dev/null +++ b/examples/basic.rs @@ -0,0 +1,35 @@ +use ruststartracker::star::StarMatcher; +use ruststartracker::starcat::StarCatalog; + +fn get_observations() -> Vec<[f32; 3]> { + // For demo purposes we extract some bright stars from the catalog. + + // Get catalog positions + let catalog = StarCatalog::new_from_gaia(Some(5.0)).unwrap(); + let stars_xyz: Vec<[f32; 3]> = catalog.normalized_positions(Some(2025.0), None); + + // Get some observations from the catalog + stars_xyz + .iter() + .filter(|x| x[1] > f32::cos(0.5)) + .map(|x| *x) + .collect() +} + +fn main() { + // Get catalog positions + let catalog = StarCatalog::new_from_gaia(Some(6.0)).unwrap(); + let stars_xyz: Vec<[f32; 3]> = catalog.normalized_positions(Some(2025.0), None); + let stars_mag: Vec = catalog.magnitudes(); + + // Create StarTracker instance (reuse this) + let star_matcher = StarMatcher::new(stars_xyz, &stars_mag, 5.0, 1.0, 0.002, 10, 0.2).unwrap(); + + // Get observation in the camera frame (provide this function) + let obs_xyz_camera: Vec<[f32; 3]> = get_observations(); + + // Lookup attitude + let result = star_matcher.find(&obs_xyz_camera).unwrap(); + + println!("Result: {:?}", result); +} diff --git a/ruststartracker/libruststartracker.pyi b/ruststartracker/libruststartracker.pyi index fa266ad..a48fc65 100644 --- a/ruststartracker/libruststartracker.pyi +++ b/ruststartracker/libruststartracker.pyi @@ -2,12 +2,15 @@ from collections.abc import Iterator import numpy as np import numpy.typing as npt +from typing_extensions import Self class StarMatcher: def __init__( self, stars_xyz: npt.NDArray[np.float32], + stars_mag: npt.NDArray[np.float32], max_inter_star_angle: float, + max_lookup_magnitude: float, inter_star_angle_tolerance: float, n_minimum_matches: int, timeout_secs: float, @@ -19,7 +22,7 @@ class StarMatcher: npt.NDArray[np.uint32], npt.NDArray[np.uint32], int, - list[list[float]], + npt.NDArray[np.float32], float, ]: ... @@ -44,15 +47,43 @@ class IterTriangleFinder: class UnitVectorLookup: def __init__(self, vec: npt.NDArray[np.float32]) -> None: ... def lookup_nearest(self, key: npt.NDArray[np.float32]) -> int: ... - def get_inter_star_index_numpy( - self, vec: npt.NDArray[np.float32], angle_threshold: float - ) -> tuple[list[list[int]], list[float], list[float]]: ... def get_inter_star_index( - self, vec: npt.NDArray[np.float32], angle_threshold: float + self, + stars: npt.NDArray[np.float32], + magnitudes: npt.NDArray[np.float32], + max_angle_rad: float, + max_magnitude: float, ) -> tuple[list[list[int]], list[float], list[float]]: ... def look_up_close_angles( - self, vectors: npt.NDArray[np.float32], max_angle_rad: float + self, + vectors: npt.NDArray[np.float32], + magnitudes: npt.NDArray[np.float32], + max_angle_rad: float, + max_magnitude: float, ) -> list[tuple[list[float], float]]: ... def look_up_close_angles_naive( - self, vectors: npt.NDArray[np.float32], max_angle_rad: float + self, + vectors: npt.NDArray[np.float32], + magnitudes: npt.NDArray[np.float32], + max_angle_rad: float, + max_magnitude: float, ) -> list[tuple[list[float], float]]: ... + +def get_threshold_from_histogram( + img: npt.NDArray[np.uint8], + *, + fraction: float, +) -> int: ... +def extract_observations( + img: npt.NDArray[np.uint8], + threshold: int, + min_star_area: int, + max_star_area: int, +) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]: ... + +class StarCatalog: + @classmethod + def from_gaia(cls, *, max_magnitude: float | None) -> Self: ... + def normalized_positions( + self, *, epoch: float | None, observer_position: np.ndarray | None + ) -> npt.NDArray[np.float32]: ... diff --git a/ruststartracker/star.py b/ruststartracker/star.py index 32cc2d5..aa3dd1c 100644 --- a/ruststartracker/star.py +++ b/ruststartracker/star.py @@ -87,8 +87,10 @@ class StarTracker: def __init__( self, stars_xyz: npt.NDArray[np.float32], + stars_mag: npt.NDArray[np.float32], camera_params: CameraParameters, *, + max_lookup_magnitude: float | None = None, max_inter_star_angle: float | None = None, inter_star_angle_tolerance: float = 0.0008, n_minimum_matches: int = 10, @@ -98,11 +100,15 @@ def __init__( Args: stars_xyz: Positions of catalog stars + stars_mag: Magnitudes of catalog stars camera_params: Calibrated camera parameters + max_lookup_magnitude: Maximum magnitude of stars used in the triangulation. Reducing + this number means only bright stars are used for triangulation. This results in + faster lookup performance. max_inter_star_angle: Maximum angle between stars that should be indexed. Calculating large inter star angles is expensive. If None, the angle is calculated from the camera field of view. - inter_star_angle_tolerance: Tolerance for inter star angle matching. + inter_star_angle_tolerance: Tolerance for inter star angle matching in rad. n_minimum_matches: Minimum amount of required matches for a successful attitude estimation timeout_secs: Maximum allowed search time in seconds. A StarTrackerError is raised @@ -127,8 +133,13 @@ def __init__( dot_products = (corner_coords_xyz * np.array([0, 0, 1], dtype=np.float32)).sum(axis=-1) max_inter_star_angle = float(np.arccos(dot_products.max())) * 2 + if max_lookup_magnitude is None: + max_lookup_magnitude = 100.0 # A very faint star. Almost infinity + self._star_matcher = ruststartracker.libruststartracker.StarMatcher( np.ascontiguousarray(stars_xyz, dtype=np.float32), + np.ascontiguousarray(stars_mag, dtype=np.float32), + float(max_lookup_magnitude), float(max_inter_star_angle), float(inter_star_angle_tolerance), int(n_minimum_matches), @@ -178,6 +189,20 @@ def get_centroids( max_star_area=max_star_area, ) + if threshold is None: + threshold = ruststartracker.libruststartracker.get_threshold_from_histogram( + img, fraction=0.99 + ) + + centroids, intensities = ruststartracker.libruststartracker.extract_observations( + img, + threshold, + min_star_area, + max_star_area, + ) + centroids = np.array(centroids, dtype=np.float32) + intensities = np.array(intensities, dtype=np.float32) + # At least 3 observations are required (one triangle) if len(centroids) < 3: raise StarTrackerError("Found too few star candidates (< 3) to continue.") @@ -238,12 +263,12 @@ def process_observation_vectors(self, x_obs: npt.NDArray[np.floating]) -> StarTr quat, match_ids, obs_indices, n_matches, matched_obs, duration_s = result return StarTrackerResult( - quat=np.asarray(quat, dtype=np.float32), - match_ids=np.asarray(match_ids, dtype=np.uint32), + quat=quat, + match_ids=match_ids, n_matches=n_matches, duration_s=duration_s, - mached_obs_x=np.asarray(matched_obs, dtype=np.float32), - obs_indices=np.asarray(obs_indices, dtype=np.uint32), + mached_obs_x=matched_obs, + obs_indices=obs_indices, ) diff --git a/ruststartracker/test_backend.py b/ruststartracker/test_backend.py index ac5122d..c9aece5 100644 --- a/ruststartracker/test_backend.py +++ b/ruststartracker/test_backend.py @@ -54,9 +54,12 @@ def test_unit_vector_lookup(): close_indices_gt = np.concatenate(results, axis=-1).T[args] angles_gt = angles[args] - close_indices, angles, poly = uvl.get_inter_star_index_numpy(vec, angle_threshold) - - close_indices, angles, poly = uvl.get_inter_star_index(vec[:, :3], angle_threshold) + close_indices, angles, poly = uvl.get_inter_star_index( + np.array(vec[:, :3], dtype=np.float32), + np.ones(len(vec), dtype=np.float32), + angle_threshold, + 10, + ) close_indices = np.array(close_indices) angles = np.array(angles) poly = np.array(poly) @@ -101,6 +104,8 @@ def test_star_matcher(): vec = rng.normal(size=[n_cat_stars, 3]).astype(np.float32) vec /= np.linalg.norm(vec, axis=-1, keepdims=True) + magnitudes = rng.uniform(0, 10, size=vec.shape[:1]).astype(np.float32) + key = rng.normal(size=[3]).astype(np.float32) key /= np.linalg.norm(key, axis=-1, keepdims=True) @@ -113,10 +118,16 @@ def test_star_matcher(): rot = scipy.spatial.transform.Rotation.from_rotvec([1, 1, 1]) - obs_rotated = rot.apply(obs) + obs_rotated = rot.apply(obs).astype(np.float32) index = libruststartracker.StarMatcher( - vec, np.radians(10).item(), np.radians(0.1).item(), 4, 999.0 + vec, + magnitudes, + 10, + np.radians(10).item(), + np.radians(0.1).item(), + 4, + 999.0, ) res = index.find(obs_rotated) @@ -124,7 +135,7 @@ def test_star_matcher(): assert res is not None quat, match_ids, obs_indices, n_matches, matched_obs, time_s = res - np.testing.assert_allclose(quat, rot.inv().as_quat()) + np.testing.assert_allclose(quat, rot.inv().as_quat(), rtol=1e-6) assert n_matches >= 4 assert len(obs_index) == len(match_ids) diff --git a/ruststartracker/test_catalog.py b/ruststartracker/test_catalog.py index 7d6d86d..70794f0 100644 --- a/ruststartracker/test_catalog.py +++ b/ruststartracker/test_catalog.py @@ -1,10 +1,12 @@ import datetime +import time import astropy.time # type: ignore[import] import numpy as np import pytest import ruststartracker.catalog +import ruststartracker.libruststartracker def test_time_to_epoch(): @@ -43,5 +45,17 @@ def test_extract_observations(): np.testing.assert_allclose(np.linalg.norm(positions, axis=-1), 1.0, rtol=1e-5) +def test_python_rust(): + t0 = time.monotonic() + positions = ruststartracker.catalog.StarCatalog().normalized_positions(epoch=2025.0) + print(f"Python catalog took {time.monotonic() - t0:.3f} seconds") + t0 = time.monotonic() + positions2 = ruststartracker.libruststartracker.StarCatalog.from_gaia( + max_magnitude=6.0 + ).normalized_positions(epoch=2025.0, observer_position=None) + print(f"Rust catalog took {time.monotonic() - t0:.3f} seconds") + np.testing.assert_allclose(positions, positions2, rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/ruststartracker/test_integration.py b/ruststartracker/test_integration.py index e70abc7..920e9eb 100644 --- a/ruststartracker/test_integration.py +++ b/ruststartracker/test_integration.py @@ -26,6 +26,7 @@ def prepare() -> tuple[ruststartracker.StarTracker, np.ndarray]: catalog = ruststartracker.StarCatalog() star_catalog_vecs = catalog.normalized_positions(epoch=2024) + star_catalog_magnitudes = catalog.magnitude camera_params = ruststartracker.CameraParameters( camera_matrix=camera_matrix, @@ -35,6 +36,7 @@ def prepare() -> tuple[ruststartracker.StarTracker, np.ndarray]: st = ruststartracker.StarTracker( star_catalog_vecs, + star_catalog_magnitudes, camera_params, inter_star_angle_tolerance=np.radians(0.05).item(), n_minimum_matches=5, diff --git a/ruststartracker/test_star.py b/ruststartracker/test_star.py index 16d9974..7f8ee63 100644 --- a/ruststartracker/test_star.py +++ b/ruststartracker/test_star.py @@ -5,16 +5,31 @@ import scipy.spatial import ruststartracker +import ruststartracker.libruststartracker import ruststartracker.star -def test_extract_observations(): - size_x, size_y = (40, 50) +@pytest.mark.parametrize("impl", ["python", "rust"]) +def test_extract_observations(impl: str): + size_x, size_y = (960, 480) img = np.zeros((size_y, size_x), np.uint8) - points = np.array([(3, 5), (23, 13)]) + points = np.array([(3, 5), (23, 13), (30, 50), (230, 130)]) for x, y in points: img[y - 1 : y + 3, x - 1 : x + 3] = 50 - centers, intensities = ruststartracker.star._extract_observations(img, threshold=30) + + t0 = time.monotonic() + if impl == "python": + centers, intensities = ruststartracker.star._extract_observations(img, threshold=30) + elif impl == "rust": + centers, intensities = ruststartracker.libruststartracker.extract_observations( + img, 30, 3, 300 + ) + else: + raise AssertionError + print(f"Extracting observations took {time.monotonic() - t0:.5f} seconds") + + assert isinstance(centers, np.ndarray) + assert isinstance(intensities, np.ndarray) np.testing.assert_almost_equal(centers, points + 0.5) np.testing.assert_almost_equal(intensities, 50 * 16) @@ -28,6 +43,8 @@ def setup(): vec = rng.normal(size=[n_cat_stars, 3]).astype(np.float32) vec /= np.linalg.norm(vec, axis=-1, keepdims=True) + mag = rng.uniform(0, 10, size=vec.shape[:1]).astype(np.float32) + angle_threshold = np.radians(10) dotp = np.sum([0, 0, 1] * vec, axis=-1) threshold = np.cos(angle_threshold).item() @@ -55,17 +72,18 @@ def setup(): image_patch = img[y - 1 : y + 2, x - 1 : x + 2] image_patch[:] = 50 - return img, vec, camera_params + return img, vec, mag, pixel_in_frame, camera_params def test_star_matcher_success(setup): - img, vec, camera_params = setup + img, vec, mag, _, camera_params = setup rot = scipy.spatial.transform.Rotation.from_rotvec([1, 1, 1]) vec = rot.inv().apply(vec) st = ruststartracker.StarTracker( vec, + mag, camera_params, inter_star_angle_tolerance=np.radians(0.1).item(), n_minimum_matches=6, @@ -78,34 +96,51 @@ def test_star_matcher_success(setup): def test_star_matcher_exhaust(setup): - img, vec, camera_params = setup + img, vec, mag, _, camera_params = setup st = ruststartracker.StarTracker( vec, + mag, camera_params, inter_star_angle_tolerance=np.radians(0.001).item(), n_minimum_matches=500, timeout_secs=999.0, ) - with pytest.raises(ruststartracker.StarTrackerError, match="exhaust"): + with pytest.raises(ruststartracker.StarTrackerError, match="SearchExhausted"): st.process_image(img) def test_star_matcher_timout(setup): - img, vec, camera_params = setup - timeout = 0.2 + img, vec, mag, _, camera_params = setup + timeout = 0.0002 st = ruststartracker.StarTracker( vec, + mag, camera_params, inter_star_angle_tolerance=np.radians(0.1).item(), n_minimum_matches=500, timeout_secs=timeout, ) t = time.monotonic() - with pytest.raises(ruststartracker.StarTrackerError, match="Timeout reached"): + with pytest.raises(ruststartracker.StarTrackerError, match="Timeout"): st.process_image(img) passed_time = time.monotonic() - t assert passed_time > timeout +def test_star_matcher_not_enough_stars(setup): + _, vec, mag, pixel_in_frame, camera_params = setup + timeout = 0.2 + st = ruststartracker.StarTracker( + vec, + mag, + camera_params, + inter_star_angle_tolerance=np.radians(0.1).item(), + n_minimum_matches=500, + timeout_secs=timeout, + ) + with pytest.raises(ruststartracker.StarTrackerError, match="NotEnoughStars"): + st.process_image_coordiantes(pixel_in_frame[:2]) + + if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "--capture=no"]) diff --git a/src/lib.rs b/src/lib.rs index 44cbb39..df58fdf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,22 @@ -use numpy::{self, PyUntypedArrayMethods}; +use numpy::{self, PyArrayMethods, PyUntypedArrayMethods, ToPyArray}; +use pyo3::prelude::*; +use pyo3::types::PyAny; +#[cfg(feature = "gaia")] +use pyo3::types::PyType; use pyo3::{ exceptions::PyRuntimeError, pyclass, pymethods, pymodule, types::PyModule, Bound, PyRef, PyRefMut, PyResult, }; +use std::path::PathBuf; use std::{time::Instant, usize}; mod ordered_combinations; -mod star; +pub mod star; +pub mod starcat; +#[cfg(feature = "improc")] +pub mod starextraction; mod tree; mod trianglefinder; -mod util; #[pymodule] fn libruststartracker(m: &Bound<'_, PyModule>) -> PyResult<()> { @@ -17,6 +24,11 @@ fn libruststartracker(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + #[cfg(feature = "improc")] + m.add_function(wrap_pyfunction!(get_threshold_from_histogram, m)?)?; + #[cfg(feature = "improc")] + m.add_function(wrap_pyfunction!(extract_observations, m)?)?; Ok(()) } @@ -89,15 +101,21 @@ struct StarMatcher { #[pymethods] impl StarMatcher { #[new] - fn new( - stars_xyz: Vec<[f32; 3]>, + fn new<'py>( + stars_xyz: numpy::PyReadonlyArray2<'py, f32>, + stars_mag: numpy::PyReadonlyArray1<'py, f32>, + max_lookup_magnitude: f32, max_inter_star_angle: f32, inter_star_angle_tolerance: f32, n_minimum_matches: usize, timeout_secs: f32, ) -> PyResult { + let stars_slice: &[[f32; 3]] = numpy_to_slice_2d(&stars_xyz)?; + let mags_slice: &[f32] = numpy_to_slice_1d(&stars_mag)?; match star::StarMatcher::new( - stars_xyz, + stars_slice.to_vec(), + mags_slice, + max_lookup_magnitude, max_inter_star_angle, inter_star_angle_tolerance, n_minimum_matches, @@ -108,24 +126,33 @@ impl StarMatcher { } } - pub fn find( + pub fn find<'py>( &self, - obs_xyz: Vec<[f32; 3]>, - ) -> PyResult<([f32; 4], Vec, Vec, u32, Vec<[f32; 3]>, f32)> { + py: Python<'py>, + obs_xyz: numpy::PyReadonlyArray2<'py, f32>, + ) -> PyResult<( + Bound<'py, numpy::PyArray1>, + Bound<'py, numpy::PyArray1>, + Bound<'py, numpy::PyArray1>, + u32, + Bound<'py, numpy::PyArray2>, + f32, + )> { + let obs_xyz_slice: &[[f32; 3]] = numpy_to_slice_2d(&obs_xyz)?; let now = Instant::now(); - let res = self.inner.find(obs_xyz); + let res = self + .inner + .find(obs_xyz_slice) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; let duration_s = now.elapsed().as_secs_f32(); - match res { - Err(x) => Err(PyRuntimeError::new_err(x)), - Ok(x) => Ok(( - x.quat, - x.match_ids, - x.obs_indices, - x.n_matches, - x.obs_matched, - duration_s, - )), - } + Ok(( + numpy_from_slice_1d(py, &res.quat), + numpy_from_slice_1d(py, &res.match_ids), + numpy_from_slice_1d(py, &res.obs_indices), + res.n_matches, + numpy_from_slice_2d(py, &res.obs_matched), + duration_s, + )) } } @@ -147,52 +174,42 @@ impl UnitVectorLookup { Ok(self.inner.lookup_nearest(&vector)) } - pub fn get_inter_star_index( + pub fn get_inter_star_index<'py>( &self, - vectors: Vec<[f32; 3]>, + stars: numpy::PyReadonlyArray2<'py, f32>, + magnitudes: numpy::PyReadonlyArray1<'py, f32>, max_angle_rad: f32, - ) -> PyResult<(Vec<[u32; 2]>, Vec, Vec)> { + max_magnitude: f32, + ) -> PyResult<(Vec<[u32; 2]>, Vec, [f32; 3])> { + let stars_slice: &[[f32; 3]] = numpy_to_slice_2d(&stars)?; + let magnitudes_slice: &[f32] = numpy_to_slice_1d(&magnitudes)?; let now = Instant::now(); - let res = match star::get_inter_star_index(&self.inner, &vectors, max_angle_rad) { - Ok(res) => res, - Err(s) => { - return Err(PyRuntimeError::new_err(format!( - "Could not calculate inter star angle: {}", - s - ))) - } - }; - println!("Time passed: {:?}", now.elapsed()); - Ok(res) - } + let res = star::InterStarIndex::new( + &self.inner, + stars_slice, + magnitudes_slice, + max_angle_rad, + max_magnitude, + ) + .map_err(|e| { + PyRuntimeError::new_err(format!("Could not calculate inter star angle: {}", e)) + })?; - pub fn get_inter_star_index_numpy<'py>( - &self, - vectors: numpy::PyReadonlyArray2<'py, f32>, - max_angle_rad: f32, - ) -> PyResult<(Vec<[u32; 2]>, Vec, Vec)> { - let now = Instant::now(); - let vectors_inner = numpy_to_vec_3_32f(&vectors).unwrap(); - let res = match star::get_inter_star_index(&self.inner, vectors_inner, max_angle_rad) { - Ok(res) => res, - Err(s) => { - return Err(PyRuntimeError::new_err(format!( - "Could not calculate inter star angle: {}", - s - ))) - } - }; println!("Time passed: {:?}", now.elapsed()); - Ok(res) + Ok((res.pairs, res.angles, res.polynomial)) } pub fn look_up_close_angles( &self, vectors: Vec<[f32; 3]>, + magnitudes: Vec, max_angle_rad: f32, + max_magnitude: f32, ) -> PyResult> { let now = Instant::now(); - let res = self.inner.look_up_close_angles(&vectors, max_angle_rad); + let res = + self.inner + .look_up_close_angles(&vectors, &magnitudes, max_angle_rad, max_magnitude); println!("Time passed: {:?}", now.elapsed()); Ok(res) } @@ -200,23 +217,141 @@ impl UnitVectorLookup { pub fn look_up_close_angles_naive( &self, vectors: Vec<[f32; 3]>, + magnitudes: Vec, max_angle_rad: f32, + max_magnitude: f32, ) -> PyResult> { let now = Instant::now(); - let res = star::look_up_close_angles_naive(&vectors, max_angle_rad); + let res = + star::look_up_close_angles_naive(&vectors, &magnitudes, max_angle_rad, max_magnitude); println!("Time passed: {:?}", now.elapsed()); Ok(res) } } -fn numpy_to_vec_3_32f<'py, const L: usize>( - vectors: &'py numpy::PyReadonlyArray2<'py, f32>, -) -> PyResult<&[[f32; L]]> { - if !vectors.is_c_contiguous() || vectors.ndim() != 2 || vectors.shape()[1] != L { +#[pyclass] +struct StarCatalog { + inner: starcat::StarCatalog, +} + +#[pymethods] +impl StarCatalog { + #[new] + fn new(filename: Bound<'_, PyAny>, epoch: f64, max_magnitude: Option) -> PyResult { + let path: PathBuf = filename.extract()?; + Ok(StarCatalog { + inner: starcat::StarCatalog::new_from_file(path, epoch, max_magnitude) + .map_err(PyRuntimeError::new_err)?, + }) + } + + #[cfg(feature = "gaia")] + #[classmethod] + fn from_gaia(_cls: &Bound<'_, PyType>, max_magnitude: Option) -> PyResult { + Ok(StarCatalog { + inner: starcat::StarCatalog::new_from_gaia(max_magnitude) + .map_err(PyRuntimeError::new_err)?, + }) + } + + pub fn normalized_positions( + &self, + epoch: Option, + observer_position: Option<[f64; 3]>, + ) -> Vec<[f64; 3]> { + self.inner.normalized_positions(epoch, observer_position) + } +} + +#[cfg(feature = "improc")] +#[pyfunction] +pub fn get_threshold_from_histogram<'py>( + img: numpy::PyReadonlyArray2<'py, u8>, + fraction: f64, +) -> PyResult { + if !img.is_c_contiguous() { + return Err(PyRuntimeError::new_err("Image must be a c_contiguous")); + } + Ok(starextraction::get_threshold_from_histogram( + img.as_slice()?, + fraction, + )) +} + +#[cfg(feature = "improc")] +#[pyfunction] +pub fn extract_observations<'py>( + py: Python<'py>, + img: numpy::PyReadonlyArray2<'py, u8>, + threshold_value: u8, + min_area: usize, + max_area: usize, +) -> PyResult<( + Bound<'py, numpy::PyArray2>, + Bound<'py, numpy::PyArray1>, +)> { + if !img.is_c_contiguous() || img.ndim() != 2 { + return Err(PyRuntimeError::new_err( + "Image must be a c_contiguous 2D array", + )); + } + let (centroids, intensities) = starextraction::extract_observations( + img.as_slice()?, + (img.shape()[1], img.shape()[0]), + threshold_value, + min_area, + max_area, + ) + .map_err(|e| PyRuntimeError::new_err(e))?; + + let centroids_np = numpy_from_slice_2d(py, ¢roids); + let intensities_np = numpy_from_slice_1d(py, &intensities); + + Ok((centroids_np, intensities_np)) +} + +fn numpy_from_slice_2d<'py, const N: usize, T>( + py: Python<'py>, + slice: &[[T; N]], +) -> Bound<'py, numpy::PyArray2> +where + T: numpy::Element, +{ + let len = slice.len(); // n + let ptr = slice.as_ptr() as *const T; + let total_len = len * N; + let contiguous_slice = unsafe { std::slice::from_raw_parts(ptr, total_len) }; + let flat_np = contiguous_slice.to_pyarray_bound(py); + flat_np.reshape((len, N)).unwrap() // Save to unwrap as we know the shape is correct +} + +fn numpy_from_slice_1d<'py, T>(py: Python<'py>, slice: &[T]) -> Bound<'py, numpy::PyArray1> +where + T: numpy::Element, +{ + slice.to_pyarray_bound(py) +} + +fn numpy_to_slice_2d<'py, T: numpy::Element + Copy, const L: usize>( + array: &'py numpy::PyReadonlyArray2<'py, T>, +) -> PyResult<&'py [[T; L]]> { + if !array.is_c_contiguous() || array.ndim() != 2 || array.shape()[1] != L { return Err(PyRuntimeError::new_err(format!( "vectors must be a c_contiguous array with shape=[n, {}]", L ))); } - Ok(util::as_vec_of_arrays(vectors.as_slice()?).unwrap()) + let slice = array.as_slice()?; + Ok(unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const [T; L], slice.len() / L) }) +} + +fn numpy_to_slice_1d<'py, T: numpy::Element + Copy>( + array: &'py numpy::PyReadonlyArray1<'py, T>, +) -> PyResult<&'py [T]> { + if !array.is_c_contiguous() || array.ndim() != 1 { + return Err(PyRuntimeError::new_err( + "vectors must be a c_contiguous array with shape=[n]", + )); + } + Ok(array.as_slice()?) } diff --git a/src/ordered_combinations.rs b/src/ordered_combinations.rs index df22b87..0d80ae7 100644 --- a/src/ordered_combinations.rs +++ b/src/ordered_combinations.rs @@ -4,18 +4,20 @@ pub struct OrderedCombinations { } impl OrderedCombinations { - pub fn new(n: u32) -> Self { - assert!(n >= K as u32, "n must be at least K"); + pub fn new(n: u32) -> Result { + if n < K as u32 { + return Err("n must be at least K"); + } let mut initial = [0u32; K]; for i in 0..K { initial[i] = i as u32; } - Self { + Ok(Self { n, indices: Some(initial), - } + }) } } @@ -83,7 +85,7 @@ mod tests { #[test] fn test_combination_order_sfa_3_5() { let n = 5; - let iter = OrderedCombinations::<3>::new(n); + let iter = OrderedCombinations::<3>::new(n).unwrap(); let v: Vec<[u32; 3]> = iter.take(11).collect(); assert_eq!(v[0], [0, 1, 2]); @@ -102,7 +104,7 @@ mod tests { #[test] fn test_combination_order_3_6() { let n = 6; - let iter = OrderedCombinations::<3>::new(n); + let iter = OrderedCombinations::<3>::new(n).unwrap(); let v: Vec<[u32; 3]> = iter.take(21).collect(); assert_eq!(v[0], [0, 1, 2]); @@ -131,7 +133,7 @@ mod tests { #[test] fn test_combination_order_4_5() { let n = 5; - let iter = OrderedCombinations::<4>::new(n); + let iter = OrderedCombinations::<4>::new(n).unwrap(); let v: Vec<[u32; 4]> = iter.take(12).collect(); assert_eq!(v[0], [0, 1, 2, 3]); @@ -145,8 +147,8 @@ mod tests { #[test] fn test_combination_order_4_6() { let n = 6; - let iter = OrderedCombinations::<4>::new(n); - let v: Vec<[u32; 4]> = iter.take(16).collect(); + let iter = OrderedCombinations::<4>::new(n).unwrap(); + let v: Vec<[u32; 4]> = iter.collect(); assert_eq!(v[0], [0, 1, 2, 3]); assert_eq!(v[1], [0, 1, 2, 4]); @@ -165,4 +167,17 @@ mod tests { assert_eq!(v[14], [2, 3, 4, 5]); assert_eq!(v.len(), 15); } + + #[test] + fn test_combination_order_3_3() { + let iter = OrderedCombinations::<3>::new(3).unwrap(); + let v: Vec<[u32; 3]> = iter.collect(); + assert_eq!(v[0], [0, 1, 2]); + assert_eq!(v.len(), 1); + } + + #[test] + fn test_combination_order_3_2() { + assert!(OrderedCombinations::<3>::new(2).is_err()); + } } diff --git a/src/star.rs b/src/star.rs index 56d27a1..6cb29c3 100644 --- a/src/star.rs +++ b/src/star.rs @@ -1,23 +1,11 @@ +use std::fmt; + use crate::tree; use crate::trianglefinder; use std::time::Instant; extern crate nalgebra as na; -pub struct StarMatcher { - stars_xyz: Vec<[f32; 3]>, - star_index: tree::UnitVectorLookup, - inter_star_angles: Vec, - inter_star_angle_pairs: Vec<[u32; 2]>, - /// tolerance of inter star angle in rad - inter_star_angle_tolerance: f32, - // polynomial with terms [c0, c1, c2, ..] (c0 + x c1 + x^2 c2) - inter_star_index_polynomial: Vec, - max_inter_star_angle: f32, - n_minimum_matches: usize, - timeout_secs: f32, -} - /// Return angle between two normalized 3-dimensional vectors. fn angle(a: &[f32; 3], b: &[f32; 3]) -> f32 { maths_rs::acos(a[0] * b[0] + a[1] * b[1] + a[2] * b[2]) @@ -33,13 +21,21 @@ fn polyval(coeffs: &[f32], x: f32) -> f32 { /// Turned out to be faster in some cases pub fn look_up_close_angles_naive( vectors: &[[f32; 3]], + magnitudes: &[f32], max_angle_rad: f32, + max_magnitude: f32, ) -> Vec<([u32; 2], f32)> { let threshold = maths_rs::cos(max_angle_rad); let mut index_pairs = Vec::new(); for a in 0..vectors.len() { + if magnitudes[a] > max_magnitude { + continue; + } let vec_a = &vectors[a]; for b in (a + 1)..vectors.len() { + if magnitudes[b] > max_magnitude { + continue; + } let vec_b = &vectors[b]; let dotp = tree::dot_product(vec_a, vec_b); if dotp >= threshold { @@ -50,74 +46,194 @@ pub fn look_up_close_angles_naive( index_pairs } -pub fn get_inter_star_index( - star_index: &tree::UnitVectorLookup, - stars_xyz: &[[f32; 3]], - max_angle_rad: f32, -) -> Result<(Vec<[u32; 2]>, Vec, Vec), &'static str> { - let mut index_pairs = star_index.look_up_close_angles(stars_xyz, max_angle_rad); +pub struct InterStarIndex { + pub pairs: Vec<[u32; 2]>, + pub angles: Vec, + // polynomial with terms [c0, c1, c2, ..] (c0 + x c1 + x^2 c2) + pub polynomial: [f32; 3], +} + +impl InterStarIndex { + pub fn new( + star_index: &tree::UnitVectorLookup, + stars_xyz: &[[f32; 3]], + stars_mag: &[f32], + max_angle_rad: f32, + max_magnitude: f32, + ) -> Result { + if stars_xyz.len() != stars_mag.len() { + return Err("stars_xyz and stars_mag must have same length"); + } + + let mut index_pairs = + star_index.look_up_close_angles(stars_xyz, stars_mag, max_angle_rad, max_magnitude); + + if index_pairs.is_empty() { + return Err("Given star positions do not result in any angles below the threshold."); + } + + index_pairs.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + + let angles: Vec = index_pairs.iter().map(|x| x.1).collect(); + let indices: Vec = (0..angles.len()).map(|i| i as f32).collect(); + + // polynomial with terms [c0, c1, c2, ..] (c0 + x c1 + x^2 c2) + let mut polynomial: [f32; 3] = polyfit_rs::polyfit_rs::polyfit(&angles, &indices, 2)? + .try_into() + .map_err(|_| "Failed to convert polynomial coefficients")?; + + let errors: Vec = angles + .iter() + .enumerate() + .map(|x| polyval(&polynomial, *x.1) - (x.0 as f32)) + .collect(); + + // let min = errors.iter().cloned().fold(f32::INFINITY, f32::min); + let max = errors.iter().cloned().fold(f32::NEG_INFINITY, f32::max); - if index_pairs.is_empty() { - return Err("Given star positions do not result in any angles below the threshold."); + polynomial[0] -= max; + + let pairs = index_pairs.iter().map(|x| x.0).collect(); + + Ok(InterStarIndex { + pairs, + angles, + polynomial, + }) } - index_pairs.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + /// Get star pairs that match given inter star angle. + fn pair_lookup(&self, inter_star_angle: f32, tolerance_angle: f32) -> &[[u32; 2]] { + let lower_threshold = maths_rs::max(inter_star_angle - tolerance_angle, 0.0); + let upper_threshold = maths_rs::max(inter_star_angle + tolerance_angle, 0.0); + let lower_index_float = polyval(&self.polynomial, lower_threshold); + let upper_index_float = polyval(&self.polynomial, upper_threshold); + let max = self.angles.len() - 1; + let mut lower_index = (lower_index_float as usize).clamp(0, max); + let mut upper_index = (upper_index_float as usize).clamp(0, max); - let angles: Vec = index_pairs.iter().map(|x| x.1).collect(); - let indices: Vec = (0..angles.len()).map(|i| i as f32).collect(); + lower_index = maths_rs::min(lower_index, upper_index); - // polynomial with terms [c0, c1, c2, ..] (c0 + x c1 + x^2 c2) - let mut polynomial = polyfit_rs::polyfit_rs::polyfit(&angles, &indices, 2)?; + while lower_index < max { + if self.angles[lower_index] > lower_threshold { + break; + } else { + lower_index += 1; + } + } + while upper_index < max { + if self.angles[upper_index] > upper_threshold { + break; + } else { + upper_index += 1; + } + } - let errors: Vec = angles - .iter() - .enumerate() - .map(|x| polyval(&polynomial, *x.1) - (x.0 as f32)) - .collect(); + &self.pairs[lower_index..upper_index] + } +} - // let min = errors.iter().cloned().fold(f32::INFINITY, f32::min); - let max = errors.iter().cloned().fold(f32::NEG_INFINITY, f32::max); +#[derive(Debug)] +pub struct MatchResult { + pub quat: [f32; 4], + pub match_ids: Vec, + pub n_matches: u32, + pub obs_matched: Vec<[f32; 3]>, + pub obs_indices: Vec, +} + +#[derive(Debug)] +pub enum FailureReason { + Unspecified, + Timeout, + NotEnoughStars, + SearchExhausted, +} - polynomial[0] -= max; +#[derive(Debug)] +pub struct DiagnosticData { + pub reason: FailureReason, + pub first_svd_failures: usize, + pub second_svd_failures: usize, + pub third_svd_failures: usize, + pub max_first_matches: usize, + pub max_refined_matches: usize, +} + +impl DiagnosticData { + pub fn new() -> Self { + DiagnosticData { + reason: FailureReason::Unspecified, + first_svd_failures: 0, + second_svd_failures: 0, + third_svd_failures: 0, + max_first_matches: 0, + max_refined_matches: 0, + } + } +} - let pairs = index_pairs.iter().map(|x| x.0).collect(); +impl fmt::Display for DiagnosticData { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // Delegate to Debug + write!(f, "{:?}", self) + } +} - Ok((pairs, angles, polynomial)) +pub struct StarMatcher { + stars_xyz: Vec<[f32; 3]>, + star_index: tree::UnitVectorLookup, + inter_star_index: InterStarIndex, + /// tolerance of inter star angle in rad + inter_star_angle_tolerance: f32, + max_inter_star_angle: f32, + n_minimum_matches: usize, + timeout_secs: f32, } impl StarMatcher { pub fn new( stars_xyz: Vec<[f32; 3]>, + stars_mag: &[f32], + max_lookup_magnitude: f32, max_inter_star_angle: f32, inter_star_angle_tolerance: f32, n_minimum_matches: usize, timeout_secs: f32, ) -> Result { let star_index = tree::UnitVectorLookup::new(&stars_xyz); - - let (pairs, angles, polynomial) = - match get_inter_star_index(&star_index, &stars_xyz, max_inter_star_angle) { - Ok(result) => result, - Err(s) => return Err(s), - }; - + let inter_star_index = InterStarIndex::new( + &star_index, + &stars_xyz, + &stars_mag, + max_inter_star_angle, + max_lookup_magnitude, + )?; Ok(StarMatcher { stars_xyz, star_index, - inter_star_angles: angles, - inter_star_angle_pairs: pairs, + inter_star_index, inter_star_angle_tolerance, - inter_star_index_polynomial: polynomial, max_inter_star_angle, n_minimum_matches, timeout_secs, }) } - pub fn find(&self, obs_xyz: Vec<[f32; 3]>) -> Result { + pub fn find(&self, obs_xyz: &[[f32; 3]]) -> Result { let start_instant = Instant::now(); - for obs_indices in self.triangle_combinations_iterator(obs_xyz.len() as u32) { + let mut diagnostic_data = DiagnosticData::new(); + + let iter = match self.triangle_combinations_iterator(obs_xyz.len() as u32) { + Ok(x) => x, + Err(_) => { + diagnostic_data.reason = FailureReason::NotEnoughStars; + return Err(diagnostic_data); + } + }; + + for obs_indices in iter { let [a, b, c] = obs_indices; // Get positions of the 3 observations @@ -138,9 +254,15 @@ impl StarMatcher { } // Look up all pairs in the catalog that could match the observations - let ab_pairs = self.pair_lookup(angle_ab); - let ac_pairs = self.pair_lookup(angle_ac); - let bc_pairs = self.pair_lookup(angle_bc); + let ab_pairs = self + .inter_star_index + .pair_lookup(angle_ab, self.inter_star_angle_tolerance); + let ac_pairs = self + .inter_star_index + .pair_lookup(angle_ac, self.inter_star_angle_tolerance); + let bc_pairs = self + .inter_star_index + .pair_lookup(angle_bc, self.inter_star_angle_tolerance); let finder = trianglefinder::TriangleFinder::new( ab_pairs.to_vec(), @@ -152,61 +274,34 @@ impl StarMatcher { // Iterate over possible matching triangles for value in iter_finder { - match self.check(&obs_indices, &obs_xyz, &value) { - None => {} - Some(x) => return Ok(x), + if let Some(x) = self.check(&obs_indices, &obs_xyz, &value, &mut diagnostic_data) { + return Ok(x); }; } if start_instant.elapsed().as_secs_f32() > self.timeout_secs { - return Err("Timeout reached"); + diagnostic_data.reason = FailureReason::Timeout; + return Err(diagnostic_data); } } - Err("Search exhausted") + diagnostic_data.reason = FailureReason::SearchExhausted; + Err(diagnostic_data) } /// Iterator over combinations of stars forming triangles. - fn triangle_combinations_iterator(&self, n: u32) -> impl Iterator { + fn triangle_combinations_iterator( + &self, + n: u32, + ) -> Result, &'static str> { crate::ordered_combinations::OrderedCombinations::<3>::new(n) } - /// Get star pairs that match given inter star angle. - fn pair_lookup(&self, inter_star_angle: f32) -> &[[u32; 2]] { - let lower_threshold = - maths_rs::max(inter_star_angle - self.inter_star_angle_tolerance, 0.0); - let upper_threshold = - maths_rs::max(inter_star_angle + self.inter_star_angle_tolerance, 0.0); - let lower_index_float = polyval(&self.inter_star_index_polynomial, lower_threshold); - let upper_index_float = polyval(&self.inter_star_index_polynomial, upper_threshold); - let max = self.inter_star_angles.len() - 1; - let mut lower_index = (lower_index_float as usize).clamp(0, max); - let mut upper_index = (upper_index_float as usize).clamp(0, max); - - lower_index = maths_rs::min(lower_index, upper_index); - - while lower_index < max { - if self.inter_star_angles[lower_index] > lower_threshold { - break; - } else { - lower_index += 1; - } - } - while upper_index < max { - if self.inter_star_angles[upper_index] > upper_threshold { - break; - } else { - upper_index += 1; - } - } - - &self.inter_star_angle_pairs[lower_index..upper_index] - } - fn check( &self, obs_indices: &[u32; 3], obs_xyz: &[[f32; 3]], cat_indices: &[u32; 3], + diagnostic_data: &mut DiagnosticData, ) -> Option { // Get vectors of observed triangle let obs_triangle_xyz = [ @@ -225,7 +320,7 @@ impl StarMatcher { // Fit rotation matrix on triangle let rotm = match attitude_svd(&cat_triangle_xyz, &obs_triangle_xyz) { None => { - // println!("1st Attitude svd failed"); + diagnostic_data.first_svd_failures += 1; return None; } Some(value) => value.cast::(), @@ -245,7 +340,7 @@ impl StarMatcher { let obs_vec = obs_transformed.column(obs_i); let obs = [obs_vec[0], obs_vec[1], obs_vec[2]]; - //Look up closest star in the catalog to the transformed position of the observation + // Look up closest star in the catalog to the transformed position of the observation let closest_index = self.star_index.lookup_nearest(&obs); // Use star if it is close than the allowed threshold @@ -258,14 +353,15 @@ impl StarMatcher { // Do not proceed if there are less than the minimum required amount of stars if selected_cat_xyz.len() < self.n_minimum_matches { - // println!("Less than {} close neighbors found", self.n_minimum_matches); + diagnostic_data.max_first_matches = + usize::max(diagnostic_data.max_first_matches, selected_cat_xyz.len()); return None; } // Fit rotation matrix on selected observations let rotm = match attitude_svd(&selected_cat_xyz, &selected_obs_xyz) { None => { - // println!("2nd Attitude svd failed"); + diagnostic_data.second_svd_failures += 1; return None; } Some(value) => value.cast::(), @@ -301,14 +397,15 @@ impl StarMatcher { // Do not proceed if there are less than the minimum required amount of stars if selected_cat_xyz.len() < self.n_minimum_matches { - // println!("Less than {} close neighbors found", self.n_minimum_matches); + diagnostic_data.max_refined_matches = + usize::max(diagnostic_data.max_refined_matches, selected_cat_xyz.len()); return None; } // Fit rotation matrix on selected observations let final_rotm = match attitude_svd(&selected_cat_xyz, &selected_obs_xyz) { None => { - // println!("3rd Attitude svd failed"); + diagnostic_data.third_svd_failures += 1; return None; } Some(value) => value, @@ -328,14 +425,10 @@ impl StarMatcher { obs_indices: selected_obs_indices, }) } -} -pub struct MatchResult { - pub quat: [f32; 4], - pub match_ids: Vec, - pub n_matches: u32, - pub obs_matched: Vec<[f32; 3]>, - pub obs_indices: Vec, + pub fn stars_xyz(&self) -> &[[f32; 3]] { + &self.stars_xyz + } } /// Solve Wahba's problem using SVD method. @@ -355,17 +448,17 @@ pub fn attitude_svd(cat_xyz: &[[f32; 3]], obs_xyz: &[[f32; 3]]) -> Option().unwrap(); + mat += outer_prod; } // Perform SVD let svd = mat.svd(true, true); diff --git a/src/starcat.rs b/src/starcat.rs new file mode 100644 index 0000000..9b93614 --- /dev/null +++ b/src/starcat.rs @@ -0,0 +1,262 @@ +use std::fs::File; +use std::io::BufReader; +use std::path::Path; + +use chrono::{DateTime, TimeZone, Utc}; +use csv::ReaderBuilder; +use nalgebra::Vector3; +use serde::Deserialize; + +const AU: f64 = 149_597_870.693; +#[cfg(feature = "gaia")] +const GAIA_EPOCH: f64 = 2016.0; +#[cfg(feature = "gaia")] +const GAIA_2016_CSV: &str = include_str!("../ruststartracker/gaia_data_j2016.csv"); + +pub trait Cast { + fn cast(self) -> T; +} + +impl Cast for f64 { + #[inline(always)] + fn cast(self) -> f32 { + self as f32 + } +} + +impl Cast for f64 { + #[inline(always)] + fn cast(self) -> f64 { + self + } +} + +#[derive(Debug, Deserialize)] +pub struct StarRaw { + pub source_id: u64, + pub ra: f64, + pub dec: f64, + pub parallax: f64, + pub pmra: f64, + pub pmdec: f64, + pub phot_g_mean_mag: f64, +} + +pub struct Star { + pub source_id: u64, + pub ra: f64, + pub dec: f64, + pub parallax: f64, + pub proper_motion_ra: f64, + pub proper_motion_dec: f64, + pub magnitude: f64, +} + +impl Star { + const DEG2RAD: f64 = core::f64::consts::PI / 180.0; + const MAS2RAD: f64 = core::f64::consts::PI / (180.0 * 3600.0 * 1000.0); + + pub fn from_raw(raw: StarRaw) -> Self { + Star { + source_id: raw.source_id, + ra: raw.ra * Star::DEG2RAD, + dec: raw.dec * Star::DEG2RAD, + parallax: raw.parallax * Star::MAS2RAD, + proper_motion_ra: raw.pmra * Star::MAS2RAD, + proper_motion_dec: raw.pmdec * Star::MAS2RAD, + magnitude: raw.phot_g_mean_mag, + } + } +} + +fn time_to_epoch(t: DateTime) -> f64 { + let delta_t_j2000 = 64.0; + let j2000 = Utc + .with_ymd_and_hms(2000, 1, 1, 12, 0, 0) + .unwrap() + .timestamp() as f64 + - delta_t_j2000; + let seconds_in_j_year = 365.25 * 86400.0; + (t.timestamp() as f64 - j2000) / seconds_in_j_year + 2000.0 +} + +pub struct StarCatalog { + pub stars: Vec, + pub epoch: f64, +} + +impl StarCatalog { + pub fn new_from_file>( + filename: P, + epoch: f64, + max_magnitude: Option, + ) -> Result { + let file = File::open(filename).map_err(|_| "Unable to open star catalog file")?; + let reader = BufReader::new(file); + Self::new_from_buffer(reader, epoch, max_magnitude) + } + + pub fn new_from_string( + data: &str, + epoch: f64, + max_magnitude: Option, + ) -> Result { + let reader = BufReader::new(data.as_bytes()); + Self::new_from_buffer(reader, epoch, max_magnitude) + } + + #[cfg(feature = "gaia")] + pub fn new_from_gaia(max_magnitude: Option) -> Result { + StarCatalog::new_from_string(GAIA_2016_CSV, GAIA_EPOCH, max_magnitude) + } + + fn new_from_buffer( + reader: BufReader, + epoch: f64, + max_magnitude: Option, + ) -> Result { + let mut csv_reader = ReaderBuilder::new().from_reader(reader); + let mut stars = Vec::new(); + for line in csv_reader.deserialize() { + let star: Star = Star::from_raw(line.map_err(|e| e.to_string())?); + if star.magnitude < max_magnitude.unwrap_or(f64::INFINITY) { + stars.push(star); + } + } + Ok(StarCatalog { stars, epoch }) + } + + pub fn normalized_positions( + &self, + epoch: Option, + observer_position: Option<[f64; 3]>, + ) -> Vec<[T; 3]> + where + f64: Cast, + { + let epoch = epoch.unwrap_or_else(|| time_to_epoch(Utc::now())); + let delta_epoch = epoch - self.epoch; + + let mut vectors: Vec<[T; 3]> = Vec::with_capacity(self.stars.len()); + + let parallax_correction_factor = match observer_position { + Some(obs_pos) => { + let obs_pos_vec = Vector3::from_column_slice(&obs_pos); + Some(obs_pos_vec / AU) + } + None => None, + }; + + for star in self.stars.iter() { + let cos_ra = star.ra.cos(); + let sin_ra = star.ra.sin(); + let cos_dec = star.dec.cos(); + let sin_dec = star.dec.sin(); + + let mut vec = Vector3::new(cos_dec * cos_ra, cos_dec * sin_ra, sin_dec); + + // Proper motion correction + let p_hat = Vector3::new(-sin_ra, cos_ra, 0.0); + let q_hat = Vector3::new(-sin_dec * cos_ra, -sin_dec * sin_ra, cos_dec); + let pm = delta_epoch * (star.proper_motion_ra * p_hat + star.proper_motion_dec * q_hat); + vec += pm; + + // Parallax correction + if let Some(f) = parallax_correction_factor { + let plx = star.parallax * f; + vec -= plx; + } + + // Normalize + let norm_vec: [f64; 3] = vec.normalize().into(); + + vectors.push([ + Cast::::cast(norm_vec[0]), + Cast::::cast(norm_vec[1]), + Cast::::cast(norm_vec[2]), + ]); + } + + vectors + } + + pub fn magnitudes(&self) -> Vec + where + f64: Cast, + { + self.stars + .iter() + .map(|s| Cast::::cast(s.magnitude)) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const GAIA_EPOCH: f64 = 2016.0; + + #[test] + fn test_starcat() { + let cat = StarCatalog::new_from_file( + "ruststartracker/gaia_data_j2016.csv", + GAIA_EPOCH, + Some(5.5), + ) + .unwrap(); + + assert!(cat.epoch == GAIA_EPOCH); + assert!(cat.stars.len() > 1000); + + let positions: Vec<[f64; 3]> = cat.normalized_positions(Some(2025.23), None); + println!("{:?}", positions.len()); + let positions: Vec<[f32; 3]> = cat.normalized_positions(Some(2025.23), None); + println!("{:?}", positions.len()); + } + + #[test] + fn test_proper_motion_star() { + let cat = StarCatalog::new_from_file( + "ruststartracker/gaia_data_j2016.csv", + GAIA_EPOCH, + Some(8.0), + ) + .unwrap(); + + let id = 725422076533653504; + + let index = cat.stars.iter().position(|s| s.source_id == id).unwrap(); + + let pos_2010: [f64; 3] = cat.normalized_positions(Some(2010.0), None)[index]; + let pos_2030: [f64; 3] = cat.normalized_positions(Some(2030.0), None)[index]; + + let angle_rad = f64::acos( + pos_2010[0] * pos_2030[0] + pos_2010[1] * pos_2030[1] + pos_2010[2] * pos_2030[2], + ); + + let angle_mas = angle_rad / Star::MAS2RAD; + + let expected_travel_mas = 20.0 * 22.717; + + assert!( + (angle_mas / expected_travel_mas).abs() - 1.0 < 0.001, + "Expected travel: {}, got: {}", + expected_travel_mas, + angle_mas + ); + } + + #[cfg(feature = "gaia")] + #[test] + fn test_gaia() { + let cat1 = StarCatalog::new_from_file( + "ruststartracker/gaia_data_j2016.csv", + GAIA_EPOCH, + Some(5.5), + ) + .unwrap(); + let cat2 = StarCatalog::new_from_gaia(Some(5.5)).unwrap(); + assert!(cat1.stars.len() == cat2.stars.len()); + } +} diff --git a/src/starextraction.rs b/src/starextraction.rs new file mode 100644 index 0000000..b302189 --- /dev/null +++ b/src/starextraction.rs @@ -0,0 +1,307 @@ +#[cfg(feature = "improc")] +use opencv; +use opencv::{core::Mat, prelude::*}; + +pub fn get_threshold_from_histogram(image_row_major: &[u8], fraction: f64) -> u8 { + // Calculate histogram + let mut hist: [usize; 256] = [0; 256]; + for &v in image_row_major { + hist[v as usize] += 1; + } + + // Calculate cumulative sum of histogram + let mut acc = 0; + for i in 0..hist.len() { + acc += hist[i]; + hist[i] = acc; + } + + // Define threshold in terms of cumulative sum value + let total: usize = *hist.last().unwrap(); // Safe: hist always has 256 elements + let threshold_value = (total as f64 * f64::clamp(fraction, 0.0, 1.0)) as usize; + + //Find index in histogram where threshold is exceeded + // Find threshold index + hist.iter() + .position(|&v| v >= threshold_value) + .unwrap_or(255) as u8 +} + +pub fn threshold(image: &[u8], threshold: u8) -> Vec { + image.iter().map(|&x| (x > threshold) as u8).collect() +} + +pub fn extract_observations( + image_row_major: &[u8], + imsize: (usize, usize), + threshold_value: u8, + min_area: usize, + max_area: usize, +) -> Result<(Vec<[f64; 2]>, Vec), String> { + let (width, height) = imsize; + + let thresholded_row_major: Vec = threshold(image_row_major, threshold_value); + + // Call opencv function + let thresholded = opencv::core::Mat::new_rows_cols_with_data( + height as i32, + width as i32, + &thresholded_row_major, + ) + .map_err(|e| format!("Failed to create Mat: {}", e))?; + let mut labels = Mat::default(); + let mut stats = Mat::default(); + let mut centroids = Mat::default(); + _ = opencv::imgproc::connected_components_with_stats( + &thresholded, + &mut labels, + &mut stats, + &mut centroids, + 4, + opencv::core::CV_16U, + ) + .map_err(|e| format!("Failed to perform connected component analysis: {}", e))?; + + let stats_data: &[i32] = stats + .data_typed::() + .map_err(|e| format!("Failed to get stats data: {}", e))?; + let stats_rows = stats.rows() as usize; + let stats_cols = stats.cols() as usize; + + let mut centers_intensities = Vec::new(); + + for i in 0..stats_rows { + // Get bounding box around star + let stats_base = i * stats_cols; + let x = stats_data[stats_base + opencv::imgproc::CC_STAT_LEFT as usize] as usize; + let y = stats_data[stats_base + opencv::imgproc::CC_STAT_TOP as usize] as usize; + let w = stats_data[stats_base + opencv::imgproc::CC_STAT_WIDTH as usize] as usize; + let h = stats_data[stats_base + opencv::imgproc::CC_STAT_HEIGHT as usize] as usize; + let area = stats_data[stats_base + opencv::imgproc::CC_STAT_AREA as usize] as usize; + + if area < min_area + || area > max_area + || x == 0 + || y == 0 + || x + w >= width + || y + h >= height + { + continue; + } + + let mut mid_x = 0; + let mut mid_y = 0; + let mut intensity_sum = 0; + + // Accumulate intensity and intensity weighted positions in the window + for yy in y - 1..y + h + 1 { + //row should be inner loop for better cache optimization + let img_base = yy * width; + for xx in x - 1..x + w + 1 { + let value = image_row_major[img_base + xx] as usize; + intensity_sum += value; + mid_x += xx * value; + mid_y += yy * value; + } + } + + if intensity_sum == 0 { + continue; + } + + centers_intensities.push(( + mid_x as f64 / intensity_sum as f64, + mid_y as f64 / intensity_sum as f64, + intensity_sum as f64, + )); + } + + // Sort centers by intensity in descending order + centers_intensities.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap()); + let centers = centers_intensities.iter().map(|x| [x.0, x.1]).collect(); + let intensities = centers_intensities.iter().map(|x| x.2).collect(); + + Ok((centers, intensities)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simple() { + #[rustfmt::skip] + let data= [ + 0, 0, 0, 0, 0, 0, 0, + 0, 4, 4, 0, 0, 0, 0, + 0, 4, 4, 0, 4, 8, 0, + 0, 0, 0, 0, 4, 8, 0, + 0, 0, 0, 0, 0, 0, 0, + ]; + + let (centers, intensities) = extract_observations(&data, (7, 5), 3, 2, 6).unwrap(); + + let centers_expected = vec![[4.666666, 2.5], [1.5, 1.5]]; + let intensities_expected = vec![24.0, 16.0]; + + // Helper function to compare floats with tolerance + fn assert_vec2_close(a: &[f64; 2], b: &[f64; 2], tol: f64) { + assert!((a[0] - b[0]).abs() < tol, "x: {} vs {}", a[0], b[0]); + assert!((a[1] - b[1]).abs() < tol, "y: {} vs {}", a[1], b[1]); + } + + fn assert_vec_close(a: &[f64], b: &[f64], tol: f64) { + assert_eq!(a.len(), b.len()); + for (x, y) in a.iter().zip(b.iter()) { + assert!((x - y).abs() < tol, "{} vs {}", x, y); + } + } + + assert_eq!(centers.len(), centers_expected.len()); + for (c, ce) in centers.iter().zip(centers_expected.iter()) { + assert_vec2_close(c, ce, 1e-4); + } + assert_vec_close(&intensities, &intensities_expected, 1e-4); + + // Additional test: single bright pixel + #[rustfmt::skip] + let data2 = [ + 0, 0, 0, + 0, 9, 0, + 0, 0, 0 + ]; + let (centers2, intensities2) = extract_observations(&data2, (3, 3), 5, 1, 2).unwrap(); + let centers2_expected = vec![[1.0, 1.0]]; + let intensities2_expected = vec![9.0]; + assert_eq!(centers2.len(), centers2_expected.len()); + for (c, ce) in centers2.iter().zip(centers2_expected.iter()) { + assert_vec2_close(c, ce, 1e-4); + } + assert_vec_close(&intensities2, &intensities2_expected, 1e-4); + + // Additional test: two separated stars + #[rustfmt::skip] + let data3 = [ + 0, 0, 0, 0, 0, + 0, 7, 0, 8, 0, + 0, 0, 0, 0, 0 + ]; + let (centers3, intensities3) = extract_observations(&data3, (5, 3), 5, 1, 2).unwrap(); + let centers3_expected = vec![[1.0, 1.0], [3.0, 1.0]]; + let intensities3_expected = vec![7.0, 8.0]; + // Order is not guaranteed, so sort by x + let mut zipped: Vec<_> = centers3.iter().zip(intensities3.iter()).collect(); + zipped.sort_by(|a, b| a.0[0].partial_cmp(&b.0[0]).unwrap()); + let (centers3, intensities3): (Vec<_>, Vec<_>) = + zipped.into_iter().map(|(c, i)| (*c, *i)).unzip(); + for (c, ce) in centers3.iter().zip(centers3_expected.iter()) { + assert_vec2_close(c, ce, 1e-4); + } + assert_vec_close(&intensities3, &intensities3_expected, 1e-4); + + // Additional test: no stars above threshold + #[rustfmt::skip] + let data4 = [ + 0, 0, 0, 0, 0, + 0, 1, 1, 1, 0, + 0, 1, 1, 1, 0, + 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, + ]; + let (centers4, intensities4) = extract_observations(&data4, (5, 5), 10, 1, 10).unwrap(); + assert!(centers4.is_empty()); + assert!(intensities4.is_empty()); + + // Additional test: stars in all corners + #[rustfmt::skip] + let data5 = [ + 5, 5, 0, 5, 5, + 5, 5, 1, 5, 5, + 0, 1, 8, 1, 0, + 4, 5, 1, 5, 7, + 4, 4, 0, 7, 7, + ]; + let (centers5, intensities5) = extract_observations(&data5, (5, 5), 2, 1, 10).unwrap(); + let centers5_expected = vec![[2.0, 2.0]]; + let intensities5_expected = vec![32.0]; + assert_eq!(centers.len(), centers_expected.len()); + for (c, ce) in centers5.iter().zip(centers5_expected.iter()) { + assert_vec2_close(c, ce, 1e-4); + } + assert_vec_close(&intensities5, &intensities5_expected, 1e-4); + } + + pub fn threshold_opencv(image: &[u8], threshold: u8) -> Mat { + let image_mat = + opencv::core::Mat::new_rows_cols_with_data(image.len() as i32, 1, &image).unwrap(); + let mut out = Mat::default(); + opencv::imgproc::threshold( + &image_mat, + &mut out, + threshold as f64, + 1 as f64, + opencv::imgproc::THRESH_BINARY, + ) + .unwrap(); + out + } + + #[test] + fn test_threshold() { + use std::time::Instant; + + let image = vec![56; 4000000]; + + let start1 = Instant::now(); + let out1 = threshold(&image, 42); + let duration1 = start1.elapsed(); + + let start2 = Instant::now(); + let out2 = threshold_opencv(&image, 42); + let duration2 = start2.elapsed(); + + println!("threshold (Rust) took: {:?}", duration1); + println!("threshold_opencv (OpenCV) took: {:?}", duration2); + + let sum1: usize = out1.into_iter().map(|x| x as usize).sum(); + let d = out2.data_typed::().unwrap(); + let sum2: usize = d.into_iter().map(|&x| x as usize).sum(); + + println!("{:?}", sum1); + println!("{:?}", sum2); + } + + #[test] + fn test_threshold_vlues() { + let image = vec![1, 2, 3, 4, 5, 6, 7, 8]; + + let out1 = threshold(&image, 4); + let out2 = threshold_opencv(&image, 4); + + let sum1: usize = out1.into_iter().map(|x| x as usize).sum(); + let d = out2.data_typed::().unwrap(); + let sum2: usize = d.into_iter().map(|&x| x as usize).sum(); + + assert_eq!(sum1, 4); + assert_eq!(sum2, 4); + } + + #[test] + fn test_histogram() { + let image = vec![0, 4, 4]; + let threshold = get_threshold_from_histogram(&image, 0.9); + println! {"{:?}", threshold}; + assert_eq!(threshold, 4); + + let image = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let threshold = get_threshold_from_histogram(&image, 0.899); + println! {"{:?}", threshold}; + assert_eq!(threshold, 7); + let threshold = get_threshold_from_histogram(&image, 0.900); + println! {"{:?}", threshold}; + assert_eq!(threshold, 8); + let threshold = get_threshold_from_histogram(&image, 0.901); + println! {"{:?}", threshold}; + assert_eq!(threshold, 8); + } +} diff --git a/src/tree.rs b/src/tree.rs index 1f3b1ce..b4093ab 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -21,6 +21,7 @@ impl UnitVectorLookup { } pub fn lookup_nearest(&self, vector: &[f32; 3]) -> usize { + // kdtree.nearest is equally fast as kdtree.iter_nearest let res = self.kdtree.nearest(vector, 1, &squared_euclidean).unwrap(); *(res[0].1) } @@ -28,14 +29,21 @@ impl UnitVectorLookup { pub fn look_up_close_angles( &self, vectors: &[[f32; 3]], + magnitudes: &[f32], max_angle_rad: f32, + max_magnitude: f32, ) -> Vec<([u32; 2], f32)> { let threshold = maths_rs::cos(max_angle_rad); let mut index_pairs = Vec::new(); for a in 0..vectors.len() { + if magnitudes[a] > max_magnitude { + continue; + } let vec_a = &vectors[a]; - for (_, b) in self.kdtree.iter_nearest(vec_a, &squared_euclidean).unwrap() { + if magnitudes[*b] > max_magnitude { + continue; + } let vec_b = &vectors[*b]; let dotp = dot_product(vec_a, vec_b); if dotp < threshold { @@ -51,3 +59,32 @@ impl UnitVectorLookup { index_pairs } } + +#[cfg(test)] +mod tests { + + use rand::rng; + use rand_distr::{Distribution, Normal}; + + use super::*; + + #[test] + fn test_tree() { + let mut rng = rng(); + let normal = Normal::new(0.0, 1.0).unwrap(); // mean = 0, std dev = 1 + + let samples: Vec<[f32; 3]> = (0..100) + .map(|_| { + let x = normal.sample(&mut rng) as f32; + let y = normal.sample(&mut rng) as f32; + let z = normal.sample(&mut rng) as f32; + let mag = f32::sqrt(x * x + y * y + z * z); + [x / mag, y / mag, z / mag] + }) + .collect(); + + let lookup = UnitVectorLookup::new(&samples); + + assert!(lookup.lookup_nearest(&samples[0]) == 0); + } +} diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index 668b4f6..0000000 --- a/src/util.rs +++ /dev/null @@ -1,37 +0,0 @@ -pub fn as_vec_of_arrays(slice: &[f32]) -> Option<&[[f32; L]]> { - match slice.len() % L { - 0 => Some(unsafe { - std::slice::from_raw_parts(slice.as_ptr() as *const [f32; L], slice.len() / L) - }), - _ => None, - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_as_vec_of_arrays_valid_length() { - let data: &[f32] = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; - let array_ref = as_vec_of_arrays(data).expect("Expected valid array reference"); - assert_eq!(array_ref, &[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - } - - #[test] - fn test_as_vec_of_arrays_invalid_length() { - let data: &[f32] = &[1.0, 2.0, 3.0, 4.0, 5.0]; - assert!( - as_vec_of_arrays::<3>(data).is_none(), - "Expected None for invalid length" - ); - } - - #[test] - fn test_as_vec_of_arrays_empty_slice() { - let data: &[f32] = &[]; - let array_ref = - as_vec_of_arrays::<3>(data).expect("Expected valid array reference for empty slice"); - assert_eq!(array_ref.len(), 0); - } -}