From 9f763b466ba3917aace152306ac9d3aef66a069e Mon Sep 17 00:00:00 2001 From: Steven Kessler Date: Fri, 15 Nov 2024 13:57:39 -0500 Subject: [PATCH] fix(requirements): parsing ~= would not parse correctly and would parse as ~== --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/migrators/mod.rs | 16 ++-- src/migrators/requirements.rs | 131 ++++++++++++++++++-------------- tests/requirements_text_test.rs | 43 +++++++++-- 5 files changed, 124 insertions(+), 70 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 25d44f1..0f0d353 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1804,7 +1804,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uv-migrator" -version = "2024.7.1" +version = "2024.7.2" dependencies = [ "clap", "dirs", diff --git a/Cargo.toml b/Cargo.toml index fd27818..939feb5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "uv-migrator" -version = "2024.7.2" +version = "2024.7.3" edition = "2021" authors = ["stvnksslr@gmail.com"] description = "Tool for converting various python package soltutions to use the uv solution by astral" diff --git a/src/migrators/mod.rs b/src/migrators/mod.rs index ec83f3f..5db7d31 100644 --- a/src/migrators/mod.rs +++ b/src/migrators/mod.rs @@ -95,14 +95,20 @@ impl MigrationTool for UvTool { for dep in deps { let mut dep_str = if let Some(version) = &dep.version { let version = version.trim(); - if let Some(stripped) = version.strip_prefix('^') { - // Convert caret version to >= format - format!("{}>={}", dep.name, stripped) + if version.contains(',') { + // For multiple version constraints, preserve as-is + format!("{}{}", dep.name, version) + } else if version.starts_with("~=") { + // Already in correct format + format!("{}{}", dep.name, version) } else if let Some(stripped) = version.strip_prefix('~') { - // Convert tilde version to ~= format + // Convert single tilde to ~= format format!("{}~={}", dep.name, stripped) + } else if let Some(stripped) = version.strip_prefix('^') { + // Convert caret version to >= format + format!("{}>={}", dep.name, stripped) } else if version.starts_with(['>', '<', '=']) { - // Other version constraints remain as is + // Version constraints remain as is format!("{}{}", dep.name, version) } else { // For exact versions diff --git a/src/migrators/requirements.rs b/src/migrators/requirements.rs index e1f315f..d6d6af5 100644 --- a/src/migrators/requirements.rs +++ b/src/migrators/requirements.rs @@ -106,23 +106,36 @@ impl RequirementsMigrationSource { fn process_version_spec(&self, version_spec: &str) -> String { let version_spec = version_spec.trim(); - if let Some(stripped) = version_spec.strip_prefix("==") { - stripped.to_string() - } else if let Some(stripped) = version_spec.strip_prefix(">=") { - stripped.to_string() - } else if let Some(stripped) = version_spec.strip_prefix("<=") { - stripped.to_string() - } else if let Some(stripped) = version_spec.strip_prefix('>') { - stripped.to_string() - } else if let Some(stripped) = version_spec.strip_prefix('<') { + + // For version specs with multiple constraints, preserve as-is + if version_spec.contains(',') { + return version_spec.to_string(); + } + + // Handle special cases in order of precedence + if version_spec.starts_with("~=") + || version_spec.starts_with(">=") + || version_spec.starts_with("<=") + || version_spec.starts_with(">") + || version_spec.starts_with("<") + || version_spec.starts_with("!=") + { + // Preserve these operators as-is + version_spec.to_string() + } else if let Some(stripped) = version_spec.strip_prefix("==") { + // Remove double equals for exact versions stripped.to_string() + } else if let Some(stripped) = version_spec.strip_prefix('~') { + // Convert single tilde to tilde-equals + format!("~={}", stripped) } else { + // If no operator is present, preserve as-is version_spec.to_string() } } fn parse_requirement(&self, line: &str) -> Result, String> { - // Handle editable installs + // Handle editable installs (-e flag) let line = if line.starts_with("-e") { let parts: Vec<&str> = line.splitn(2, ' ').collect(); if parts.len() != 2 { @@ -145,54 +158,9 @@ impl RequirementsMigrationSource { // Handle URLs and git repositories let (name, version) = if package_spec.starts_with("git+") || package_spec.starts_with("http") { - let name = if let Some(egg_part) = package_spec.split('#').last() { - if egg_part.starts_with("egg=") { - egg_part.trim_start_matches("egg=") - } else if package_spec.ends_with(".whl") { - package_spec - .split('/') - .last() - .and_then(|f| f.split('-').next()) - .ok_or("Invalid wheel filename")? - } else { - return Err("Invalid URL format".to_string()); - } - } else { - package_spec - .split('/') - .last() - .and_then(|f| f.split('.').next()) - .ok_or("Invalid URL format")? - }; - (name.to_string(), None) + self.parse_url_requirement(package_spec)? } else { - // Regular package specification - let (name, version) = { - if !package_spec.contains(&['>', '<', '=', '~', '!'][..]) { - (package_spec.to_string(), None) - } else { - let name_end = package_spec - .find(|c| ['>', '<', '=', '~', '!'].contains(&c)) - .unwrap(); - let name = package_spec[..name_end].trim().to_string(); - let version_spec = package_spec[name_end..].trim(); - - let version = if version_spec.contains(',') { - // For multiple version constraints - let version_parts: Vec = version_spec - .split(',') - .map(|p| p.trim().to_string()) - .collect(); - Some(version_parts.join(",")) - } else { - // For single version constraint - Some(self.process_version_spec(version_spec)) - }; - - (name, version) - } - }; - (name, version) + self.parse_regular_requirement(package_spec)? }; if name == "python" { @@ -213,4 +181,53 @@ impl RequirementsMigrationSource { environment_markers, })) } + + fn parse_url_requirement( + &self, + package_spec: &str, + ) -> Result<(String, Option), String> { + let name = if let Some(egg_part) = package_spec.split('#').last() { + if egg_part.starts_with("egg=") { + egg_part.trim_start_matches("egg=").to_string() + } else if package_spec.ends_with(".whl") { + package_spec + .split('/') + .last() + .and_then(|f| f.split('-').next()) + .ok_or("Invalid wheel filename")? + .to_string() + } else { + return Err("Invalid URL format".to_string()); + } + } else { + package_spec + .split('/') + .last() + .and_then(|f| f.split('.').next()) + .ok_or("Invalid URL format")? + .to_string() + }; + + Ok((name, None)) + } + + fn parse_regular_requirement( + &self, + package_spec: &str, + ) -> Result<(String, Option), String> { + // Return early if no version specifier is present + if !package_spec.contains(&['>', '<', '=', '~', '!'][..]) { + return Ok((package_spec.to_string(), None)); + } + + let name_end = package_spec + .find(|c| ['>', '<', '=', '~', '!'].contains(&c)) + .unwrap(); + let name = package_spec[..name_end].trim().to_string(); + let version_spec = package_spec[name_end..].trim(); + + let version = Some(self.process_version_spec(version_spec)); + + Ok((name, version)) + } } diff --git a/tests/requirements_text_test.rs b/tests/requirements_text_test.rs index 9f6ba5c..5ecd6bc 100644 --- a/tests/requirements_text_test.rs +++ b/tests/requirements_text_test.rs @@ -51,8 +51,15 @@ sqlalchemy<2.0.0 assert_eq!(requests_dep.dep_type, DependencyType::Main); let flask_dep = dependencies.iter().find(|d| d.name == "flask").unwrap(); - assert_eq!(flask_dep.version, Some("2.0.0".to_string())); + assert_eq!(flask_dep.version, Some(">=2.0.0".to_string())); assert!(matches!(flask_dep.dep_type, DependencyType::Main)); + + let sqlalchemy_dep = dependencies + .iter() + .find(|d| d.name == "sqlalchemy") + .unwrap(); + assert_eq!(sqlalchemy_dep.version, Some("<2.0.0".to_string())); + assert!(matches!(sqlalchemy_dep.dep_type, DependencyType::Main)); } /// Test handling of comments and empty lines in requirements files. @@ -223,6 +230,8 @@ flask>=2.0.0,<3.0.0 requests~=2.31.0 django>3.0.0,<=4.2.0 sqlalchemy!=1.4.0,>=1.3.0 +django-filters~=23.5 +boto3~=1.35 "#; let (_temp_dir, project_dir) = create_test_project(vec![("requirements.txt", content)]); @@ -230,11 +239,36 @@ sqlalchemy!=1.4.0,>=1.3.0 let source = RequirementsMigrationSource; let dependencies = source.extract_dependencies(&project_dir).unwrap(); - assert_eq!(dependencies.len(), 4); + assert_eq!(dependencies.len(), 6); // Verify complex version constraints are preserved let flask_dep = dependencies.iter().find(|d| d.name == "flask").unwrap(); assert_eq!(flask_dep.version, Some(">=2.0.0,<3.0.0".to_string())); + + // Verify tilde-equal is preserved + let requests_dep = dependencies.iter().find(|d| d.name == "requests").unwrap(); + assert_eq!(requests_dep.version, Some("~=2.31.0".to_string())); + + // Verify multiple constraints with inequality + let django_dep = dependencies.iter().find(|d| d.name == "django").unwrap(); + assert_eq!(django_dep.version, Some(">3.0.0,<=4.2.0".to_string())); + + // Verify complex constraints with not-equal + let sqlalchemy_dep = dependencies + .iter() + .find(|d| d.name == "sqlalchemy") + .unwrap(); + assert_eq!(sqlalchemy_dep.version, Some("!=1.4.0,>=1.3.0".to_string())); + + // Verify tilde-equal cases + let filters_dep = dependencies + .iter() + .find(|d| d.name == "django-filters") + .unwrap(); + assert_eq!(filters_dep.version, Some("~=23.5".to_string())); + + let boto3_dep = dependencies.iter().find(|d| d.name == "boto3").unwrap(); + assert_eq!(boto3_dep.version, Some("~=1.35".to_string())); } /// Test handling of editable installs and URLs. @@ -257,12 +291,9 @@ git+https://github.com/user/other-project.git@v1.0.0#egg=other-project let dependencies = source.extract_dependencies(&project_dir).unwrap(); assert_eq!(dependencies.len(), 3); - - // Note: The exact assertions here will depend on how your implementation handles URLs - // You might want to add more specific assertions based on your implementation } -/// Test error handling for malformed requirements files. +/// Test handling of malformed requirements files. /// /// This test verifies that: /// 1. Invalid requirement formats are handled gracefully