Skip to content

Commit

Permalink
fix(requirements): parsing ~= would not parse correctly and would par…
Browse files Browse the repository at this point in the history
…se as ~==
  • Loading branch information
stvnksslr committed Nov 15, 2024
1 parent d1bc2b0 commit 9f763b4
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 70 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "uv-migrator"
version = "2024.7.2"
version = "2024.7.3"
edition = "2021"
authors = ["stvnksslr@gmail.com"]
description = "Tool for converting various python package soltutions to use the uv solution by astral"
Expand Down
16 changes: 11 additions & 5 deletions src/migrators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,20 @@ impl MigrationTool for UvTool {
for dep in deps {
let mut dep_str = if let Some(version) = &dep.version {
let version = version.trim();
if let Some(stripped) = version.strip_prefix('^') {
// Convert caret version to >= format
format!("{}>={}", dep.name, stripped)
if version.contains(',') {
// For multiple version constraints, preserve as-is
format!("{}{}", dep.name, version)
} else if version.starts_with("~=") {
// Already in correct format
format!("{}{}", dep.name, version)
} else if let Some(stripped) = version.strip_prefix('~') {
// Convert tilde version to ~= format
// Convert single tilde to ~= format
format!("{}~={}", dep.name, stripped)
} else if let Some(stripped) = version.strip_prefix('^') {
// Convert caret version to >= format
format!("{}>={}", dep.name, stripped)
} else if version.starts_with(['>', '<', '=']) {
// Other version constraints remain as is
// Version constraints remain as is
format!("{}{}", dep.name, version)
} else {
// For exact versions
Expand Down
131 changes: 74 additions & 57 deletions src/migrators/requirements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,23 +106,36 @@ impl RequirementsMigrationSource {

fn process_version_spec(&self, version_spec: &str) -> String {
let version_spec = version_spec.trim();
if let Some(stripped) = version_spec.strip_prefix("==") {
stripped.to_string()
} else if let Some(stripped) = version_spec.strip_prefix(">=") {
stripped.to_string()
} else if let Some(stripped) = version_spec.strip_prefix("<=") {
stripped.to_string()
} else if let Some(stripped) = version_spec.strip_prefix('>') {
stripped.to_string()
} else if let Some(stripped) = version_spec.strip_prefix('<') {

// For version specs with multiple constraints, preserve as-is
if version_spec.contains(',') {
return version_spec.to_string();
}

// Handle special cases in order of precedence
if version_spec.starts_with("~=")
|| version_spec.starts_with(">=")
|| version_spec.starts_with("<=")
|| version_spec.starts_with(">")
|| version_spec.starts_with("<")
|| version_spec.starts_with("!=")
{
// Preserve these operators as-is
version_spec.to_string()
} else if let Some(stripped) = version_spec.strip_prefix("==") {
// Remove double equals for exact versions
stripped.to_string()
} else if let Some(stripped) = version_spec.strip_prefix('~') {
// Convert single tilde to tilde-equals
format!("~={}", stripped)
} else {
// If no operator is present, preserve as-is
version_spec.to_string()
}
}

fn parse_requirement(&self, line: &str) -> Result<Option<Dependency>, String> {
// Handle editable installs
// Handle editable installs (-e flag)
let line = if line.starts_with("-e") {
let parts: Vec<&str> = line.splitn(2, ' ').collect();
if parts.len() != 2 {
Expand All @@ -145,54 +158,9 @@ impl RequirementsMigrationSource {
// Handle URLs and git repositories
let (name, version) =
if package_spec.starts_with("git+") || package_spec.starts_with("http") {
let name = if let Some(egg_part) = package_spec.split('#').last() {
if egg_part.starts_with("egg=") {
egg_part.trim_start_matches("egg=")
} else if package_spec.ends_with(".whl") {
package_spec
.split('/')
.last()
.and_then(|f| f.split('-').next())
.ok_or("Invalid wheel filename")?
} else {
return Err("Invalid URL format".to_string());
}
} else {
package_spec
.split('/')
.last()
.and_then(|f| f.split('.').next())
.ok_or("Invalid URL format")?
};
(name.to_string(), None)
self.parse_url_requirement(package_spec)?
} else {
// Regular package specification
let (name, version) = {
if !package_spec.contains(&['>', '<', '=', '~', '!'][..]) {
(package_spec.to_string(), None)
} else {
let name_end = package_spec
.find(|c| ['>', '<', '=', '~', '!'].contains(&c))
.unwrap();
let name = package_spec[..name_end].trim().to_string();
let version_spec = package_spec[name_end..].trim();

let version = if version_spec.contains(',') {
// For multiple version constraints
let version_parts: Vec<String> = version_spec
.split(',')
.map(|p| p.trim().to_string())
.collect();
Some(version_parts.join(","))
} else {
// For single version constraint
Some(self.process_version_spec(version_spec))
};

(name, version)
}
};
(name, version)
self.parse_regular_requirement(package_spec)?
};

if name == "python" {
Expand All @@ -213,4 +181,53 @@ impl RequirementsMigrationSource {
environment_markers,
}))
}

fn parse_url_requirement(
&self,
package_spec: &str,
) -> Result<(String, Option<String>), String> {
let name = if let Some(egg_part) = package_spec.split('#').last() {
if egg_part.starts_with("egg=") {
egg_part.trim_start_matches("egg=").to_string()
} else if package_spec.ends_with(".whl") {
package_spec
.split('/')
.last()
.and_then(|f| f.split('-').next())
.ok_or("Invalid wheel filename")?
.to_string()
} else {
return Err("Invalid URL format".to_string());
}
} else {
package_spec
.split('/')
.last()
.and_then(|f| f.split('.').next())
.ok_or("Invalid URL format")?
.to_string()
};

Ok((name, None))
}

fn parse_regular_requirement(
&self,
package_spec: &str,
) -> Result<(String, Option<String>), String> {
// Return early if no version specifier is present
if !package_spec.contains(&['>', '<', '=', '~', '!'][..]) {
return Ok((package_spec.to_string(), None));
}

let name_end = package_spec
.find(|c| ['>', '<', '=', '~', '!'].contains(&c))
.unwrap();
let name = package_spec[..name_end].trim().to_string();
let version_spec = package_spec[name_end..].trim();

let version = Some(self.process_version_spec(version_spec));

Ok((name, version))
}
}
43 changes: 37 additions & 6 deletions tests/requirements_text_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,15 @@ sqlalchemy<2.0.0
assert_eq!(requests_dep.dep_type, DependencyType::Main);

let flask_dep = dependencies.iter().find(|d| d.name == "flask").unwrap();
assert_eq!(flask_dep.version, Some("2.0.0".to_string()));
assert_eq!(flask_dep.version, Some(">=2.0.0".to_string()));
assert!(matches!(flask_dep.dep_type, DependencyType::Main));

let sqlalchemy_dep = dependencies
.iter()
.find(|d| d.name == "sqlalchemy")
.unwrap();
assert_eq!(sqlalchemy_dep.version, Some("<2.0.0".to_string()));
assert!(matches!(sqlalchemy_dep.dep_type, DependencyType::Main));
}

/// Test handling of comments and empty lines in requirements files.
Expand Down Expand Up @@ -223,18 +230,45 @@ flask>=2.0.0,<3.0.0
requests~=2.31.0
django>3.0.0,<=4.2.0
sqlalchemy!=1.4.0,>=1.3.0
django-filters~=23.5
boto3~=1.35
"#;

let (_temp_dir, project_dir) = create_test_project(vec![("requirements.txt", content)]);

let source = RequirementsMigrationSource;
let dependencies = source.extract_dependencies(&project_dir).unwrap();

assert_eq!(dependencies.len(), 4);
assert_eq!(dependencies.len(), 6);

// Verify complex version constraints are preserved
let flask_dep = dependencies.iter().find(|d| d.name == "flask").unwrap();
assert_eq!(flask_dep.version, Some(">=2.0.0,<3.0.0".to_string()));

// Verify tilde-equal is preserved
let requests_dep = dependencies.iter().find(|d| d.name == "requests").unwrap();
assert_eq!(requests_dep.version, Some("~=2.31.0".to_string()));

// Verify multiple constraints with inequality
let django_dep = dependencies.iter().find(|d| d.name == "django").unwrap();
assert_eq!(django_dep.version, Some(">3.0.0,<=4.2.0".to_string()));

// Verify complex constraints with not-equal
let sqlalchemy_dep = dependencies
.iter()
.find(|d| d.name == "sqlalchemy")
.unwrap();
assert_eq!(sqlalchemy_dep.version, Some("!=1.4.0,>=1.3.0".to_string()));

// Verify tilde-equal cases
let filters_dep = dependencies
.iter()
.find(|d| d.name == "django-filters")
.unwrap();
assert_eq!(filters_dep.version, Some("~=23.5".to_string()));

let boto3_dep = dependencies.iter().find(|d| d.name == "boto3").unwrap();
assert_eq!(boto3_dep.version, Some("~=1.35".to_string()));
}

/// Test handling of editable installs and URLs.
Expand All @@ -257,12 +291,9 @@ git+https://github.com/user/other-project.git@v1.0.0#egg=other-project
let dependencies = source.extract_dependencies(&project_dir).unwrap();

assert_eq!(dependencies.len(), 3);

// Note: The exact assertions here will depend on how your implementation handles URLs
// You might want to add more specific assertions based on your implementation
}

/// Test error handling for malformed requirements files.
/// Test handling of malformed requirements files.
///
/// This test verifies that:
/// 1. Invalid requirement formats are handled gracefully
Expand Down

0 comments on commit 9f763b4

Please sign in to comment.