diff --git a/src/import_graph/graph.rs b/src/import_graph/graph.rs index f191cdf..1dfde95 100644 --- a/src/import_graph/graph.rs +++ b/src/import_graph/graph.rs @@ -415,4 +415,91 @@ impl ImportGraph { reverse_imports, }) } + + pub fn squash_package(&self, package: &str) -> Result { + let package_to_squash = match self.packages_by_pypath.get(package) { + Some(package) => package, + None => { + return Err(Error::PackageNotFound(package.to_string()))?; + } + }; + let binding = self + .child_modules(package)? + .into_iter() + .filter(|m| m.ends_with(".__init__")) + .collect::>(); + let init_module_pypath = binding.first().unwrap(); + let init_module = self.modules_by_pypath.get(init_module_pypath).unwrap(); + + let packages_to_replace = self._descendant_packages(package)?; + let package_pypaths_to_replace = packages_to_replace + .iter() + .map(|p| p.pypath.clone()) + .collect::>(); + let modules_to_replace = { + let mut modules_to_replace = self._descendant_modules(package)?; + modules_to_replace.remove(init_module); + modules_to_replace + }; + let module_pypaths_to_replace = modules_to_replace + .iter() + .map(|m| m.pypath.clone()) + .collect::>(); + + let mut packages_by_pypath = self.packages_by_pypath.clone(); + for pypath in packages_by_pypath.clone().keys() { + if package_pypaths_to_replace.contains(pypath) { + packages_by_pypath.remove(pypath); + } + } + + let mut modules_by_pypath = self.modules_by_pypath.clone(); + for pypath in modules_by_pypath.clone().keys() { + if module_pypaths_to_replace.contains(pypath) { + modules_by_pypath.remove(pypath); + } + } + + let mut packages_by_module = self.packages_by_module.clone(); + for (module, package) in packages_by_module.clone().iter() { + if modules_to_replace.contains(module) && packages_to_replace.contains(package) { + packages_by_module.remove(module); + packages_by_module.insert(Arc::clone(init_module), Arc::clone(package_to_squash)); + } else if modules_to_replace.contains(module) { + packages_by_module.remove(module); + packages_by_module.insert(Arc::clone(init_module), Arc::clone(package)); + } else if packages_to_replace.contains(package) { + packages_by_module.remove(module); + packages_by_module.insert(Arc::clone(module), Arc::clone(package_to_squash)); + } + } + + let mut imports = self.imports.clone(); + for (module, imported_modules) in imports.clone().iter_mut() { + for imported_module in imported_modules.clone().iter() { + if modules_to_replace.contains(imported_module) { + imported_modules.remove(imported_module); + imported_modules.insert(Arc::clone(init_module)); + } + } + imports.insert(Arc::clone(module), imported_modules.clone()); + if modules_to_replace.contains(module) { + imports.remove(module); + imports + .get_mut(&Arc::clone(init_module)) + .unwrap() + .extend(imported_modules.clone()); + } + } + + let reverse_imports = indexing::reverse_imports(&imports)?; + + Ok(ImportGraph { + packages_by_pypath, + modules_by_pypath, + packages_by_module, + imports, + reverse_imports, + }) + } } diff --git a/src/main.rs b/src/main.rs index c263c38..03a5294 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,13 +9,14 @@ fn main() -> Result<()> { let import_graph = { let import_graph = ImportGraphBuilder::new(root_package_path).build()?; - if args.len() ==2 { + if args.len() == 2 { import_graph } else { + let import_graph = import_graph.squash_package(&args[2])?; import_graph.subgraph(&args[2])? } }; - + let imports = import_graph.direct_imports_flat(); println!("source, target"); for (from_module, to_module) in imports { diff --git a/tests/test_import_graph.rs b/tests/test_import_graph.rs index cebd58b..afe4763 100644 --- a/tests/test_import_graph.rs +++ b/tests/test_import_graph.rs @@ -1065,3 +1065,113 @@ fn test_subgraph() { .collect() ); } + +#[test] +fn test_squash_package() { + let root_package_path = Path::new("./testpackages/somesillypackage"); + let import_graph = ImportGraphBuilder::new(root_package_path).build().unwrap(); + let squashed = import_graph + .squash_package("somesillypackage.child1") + .unwrap(); + assert_eq!( + squashed.packages(), + hashset! { + "somesillypackage", + "somesillypackage.child1", + "somesillypackage.child2", + "somesillypackage.child3", + "somesillypackage.child4", + "somesillypackage.child5", + } + .into_iter() + .map(|s| s.to_string()) + .collect() + ); + assert_eq!( + squashed.modules(), + hashset! { + "somesillypackage.__init__", + "somesillypackage.a", + "somesillypackage.b", + "somesillypackage.c", + "somesillypackage.d", + "somesillypackage.e", + "somesillypackage.z", + "somesillypackage.child1.__init__", + "somesillypackage.child2.__init__", + "somesillypackage.child3.__init__", + "somesillypackage.child4.__init__", + "somesillypackage.child5.__init__", + } + .into_iter() + .map(|s| s.to_string()) + .collect() + ); + assert_eq!( + squashed.direct_imports(), + hashmap! { + "somesillypackage.__init__" => hashset!{ + "somesillypackage.a", + "somesillypackage.b", + "somesillypackage.c", + "somesillypackage.d", + "somesillypackage.e", + "somesillypackage.child1.__init__", + "somesillypackage.child2.__init__", + "somesillypackage.child3.__init__", + "somesillypackage.child4.__init__", + "somesillypackage.child5.__init__", + }, + "somesillypackage.a" => hashset!{ + "somesillypackage.b", + "somesillypackage.c", + }, + "somesillypackage.b" => hashset!{ + "somesillypackage.c", + }, + "somesillypackage.c" => hashset!{ + "somesillypackage.d", + "somesillypackage.e", + }, + "somesillypackage.d" => hashset!{ + "somesillypackage.e" + }, + "somesillypackage.e" => hashset!{}, + "somesillypackage.z" => hashset! { + "somesillypackage.a", + "somesillypackage.b", + "somesillypackage.c", + "somesillypackage.d", + "somesillypackage.e", + "somesillypackage.child1.__init__", + "somesillypackage.child2.__init__", + "somesillypackage.child3.__init__", + "somesillypackage.child4.__init__", + "somesillypackage.child5.__init__", + }, + "somesillypackage.child1.__init__" => hashset!{ + "somesillypackage.a", + "somesillypackage.b", + "somesillypackage.c", + "somesillypackage.d", + "somesillypackage.e", + "somesillypackage.__init__", + "somesillypackage.child1.__init__", + "somesillypackage.child2.__init__", + "somesillypackage.child3.__init__", + "somesillypackage.child4.__init__", + "somesillypackage.child5.__init__", + }, + "somesillypackage.child2.__init__" => hashset!{}, + "somesillypackage.child3.__init__" => hashset!{}, + "somesillypackage.child4.__init__" => hashset!{}, + "somesillypackage.child5.__init__" => hashset!{}, + } + .into_iter() + .map(|(k, v)| ( + k.to_string(), + v.into_iter().map(|v| v.to_string()).collect() + )) + .collect() + ); +}