diff --git a/Cargo.lock b/Cargo.lock index 92ec430..34e5c4c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -463,7 +463,7 @@ dependencies = [ [[package]] name = "pyimports" -version = "0.3.2" +version = "0.3.3" dependencies = [ "anyhow", "lazy_static", diff --git a/Cargo.toml b/Cargo.toml index c280ed4..a95e2b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ repository = "https://github.com/Peter554/pyimports" documentation = "https://docs.rs/pyimports/" readme = "README.md" license = "MIT" -version = "0.3.2" +version = "0.3.3" edition = "2021" exclude = [ ".github/*", diff --git a/src/imports_info/queries/internal_imports.rs b/src/imports_info/queries/internal_imports.rs index c831800..6164e56 100644 --- a/src/imports_info/queries/internal_imports.rs +++ b/src/imports_info/queries/internal_imports.rs @@ -346,24 +346,11 @@ impl<'a> InternalImportsQueries<'a> { /// # Ok(()) /// # } /// ``` - pub fn get_downstream_items( + pub fn get_downstream_items>( &'a self, - item: PackageItemToken, + items: T, ) -> Result> { - self.imports_info.package_info.get_item(item)?; - - let mut items = bfs_reach(item, |item| { - self.imports_info - .internal_imports - .get(item) - .unwrap() - .clone() - }) - .collect::>(); - - items.remove(&item); - - Ok(items) + self.bfs_reach(items, &self.imports_info.internal_imports) } /// Returns the upstream package items. @@ -407,24 +394,40 @@ impl<'a> InternalImportsQueries<'a> { /// # Ok(()) /// # } /// ``` - pub fn get_upstream_items( + pub fn get_upstream_items>( &'a self, - item: PackageItemToken, + items: T, ) -> Result> { - self.imports_info.package_info.get_item(item)?; + self.bfs_reach(items, &self.imports_info.reverse_internal_imports) + } + + fn bfs_reach>( + &'a self, + items: T, + imports_map: &HashMap>, + ) -> Result> { + let items: PackageItemTokenSet = items.into(); + + for item in items.iter() { + self.imports_info.package_info.get_item(*item)?; + } - let mut items = bfs_reach(item, |item| { - self.imports_info - .reverse_internal_imports - .get(item) - .unwrap() - .clone() + let reachable_items = bfs_reach(PathfindingNode::Initial, |item| { + let items = match item { + PathfindingNode::Initial => items.clone(), + PathfindingNode::PackageItem(item) => imports_map.get(item).unwrap().clone(), + }; + items.into_iter().map(PathfindingNode::PackageItem) + }) + .filter_map(|item| match item { + PathfindingNode::Initial => None, + PathfindingNode::PackageItem(item) => Some(item), }) .collect::>(); - items.remove(&item); + let reachable_items = &reachable_items - &items; - Ok(items) + Ok(reachable_items) } /// Returns the metadata associated with the passed import. @@ -730,31 +733,38 @@ from testpackage import colors #[test] fn test_get_downstream_items() -> Result<()> { let testpackage = testpackage! { - "__init__.py" => " -from testpackage import fruit -", + "__init__.py" => "", - "fruit.py" => " -from testpackage import colors -from testpackage import books", + "a.py" => "from testpackage import b", + "b.py" => "from testpackage import c", + "c.py" => "", - "colors.py" => "", - "books.py" => "" + "d.py" => "from testpackage import e", + "e.py" => "from testpackage import f", + "f.py" => "" }; let package_info = PackageInfo::build(testpackage.path())?; let imports_info = ImportsInfo::build(package_info)?; - let root_package_init = imports_info._item("testpackage.__init__"); - let fruit = imports_info._item("testpackage.fruit"); - let colors = imports_info._item("testpackage.colors"); - let books = imports_info._item("testpackage.books"); + let a = imports_info._item("testpackage.a"); + let b = imports_info._item("testpackage.b"); + let c = imports_info._item("testpackage.c"); + let d = imports_info._item("testpackage.d"); + let e = imports_info._item("testpackage.e"); + let f = imports_info._item("testpackage.f"); + + let imports = imports_info + .internal_imports() + .get_downstream_items(a) + .unwrap(); + assert_eq!(imports, hashset! {b, c},); let imports = imports_info .internal_imports() - .get_downstream_items(root_package_init) + .get_downstream_items(hashset! {a, d}) .unwrap(); - assert_eq!(imports, hashset! {fruit, colors, books},); + assert_eq!(imports, hashset! {b, c, e, f},); Ok(()) }