Skip to content

Commit 2618439

Browse files
committed
fix(postgres): case-aware type name equality
1 parent 3214336 commit 2618439

File tree

1 file changed

+97
-1
lines changed

1 file changed

+97
-1
lines changed

sqlx-postgres/src/type_info.rs

+97-1
Original file line numberDiff line numberDiff line change
@@ -1154,7 +1154,103 @@ impl PartialEq<PgType> for PgType {
11541154
true
11551155
} else {
11561156
// Otherwise, perform a match on the name
1157-
self.name().eq_ignore_ascii_case(other.name())
1157+
name_eq(self.name(), other.name())
11581158
}
11591159
}
11601160
}
1161+
1162+
/// Check type names for equality, respecting Postgres' case sensitivity rules for identifiers.
1163+
///
1164+
/// https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
1165+
fn name_eq(name1: &str, name2: &str) -> bool {
1166+
// Cop-out of processing Unicode escapes by just using string equality.
1167+
if name1.starts_with("U&") {
1168+
// If `name2` doesn't start with `U&` this will automatically be `false`.
1169+
return name1 == name2;
1170+
}
1171+
1172+
let mut chars1 = identifier_chars(name1);
1173+
let mut chars2 = identifier_chars(name2);
1174+
1175+
while let (Some(a), Some(b)) = (chars1.next(), chars2.next()) {
1176+
if !a.eq(&b) {
1177+
return false;
1178+
}
1179+
}
1180+
1181+
chars1.next().is_none() && chars2.next().is_none()
1182+
}
1183+
1184+
struct IdentifierChar {
1185+
ch: char,
1186+
case_sensitive: bool,
1187+
}
1188+
1189+
impl IdentifierChar {
1190+
fn eq(&self, other: &Self) -> bool {
1191+
if self.case_sensitive || other.case_sensitive {
1192+
self.ch == other.ch
1193+
} else {
1194+
self.ch.eq_ignore_ascii_case(&other.ch)
1195+
}
1196+
}
1197+
}
1198+
1199+
/// Return an iterator over all significant characters of an identifier.
1200+
///
1201+
/// Ignores non-escaped quotation marks.
1202+
fn identifier_chars(ident: &str) -> impl Iterator<Item = IdentifierChar> + '_ {
1203+
let mut case_sensitive = false;
1204+
let mut last_char_quote = false;
1205+
1206+
ident.chars().filter_map(move |ch| {
1207+
if ch == '"' {
1208+
if last_char_quote {
1209+
last_char_quote = false;
1210+
} else {
1211+
last_char_quote = true;
1212+
return None;
1213+
}
1214+
} else if last_char_quote {
1215+
last_char_quote = false;
1216+
case_sensitive = !case_sensitive;
1217+
}
1218+
1219+
Some(IdentifierChar { ch, case_sensitive })
1220+
})
1221+
}
1222+
1223+
#[test]
1224+
fn test_name_eq() {
1225+
let test_values = [
1226+
("foo", "foo", true),
1227+
("foo", "Foo", true),
1228+
("foo", "FOO", true),
1229+
("foo", r#""foo""#, true),
1230+
("foo", r#""Foo""#, false),
1231+
("foo", "foo.foo", false),
1232+
("foo.foo", "foo.foo", true),
1233+
("foo.foo", "foo.Foo", true),
1234+
("foo.foo", "foo.FOO", true),
1235+
("foo.foo", "Foo.foo", true),
1236+
("foo.foo", "Foo.Foo", true),
1237+
("foo.foo", "FOO.FOO", true),
1238+
("foo.foo", "foo", false),
1239+
("foo.foo", r#"foo."foo""#, true),
1240+
("foo.foo", r#"foo."Foo""#, false),
1241+
("foo.foo", r#"foo."FOO""#, false),
1242+
];
1243+
1244+
for (left, right, eq) in test_values {
1245+
assert_eq!(
1246+
name_eq(left, right),
1247+
eq,
1248+
"failed check for name_eq({left:?}, {right:?})"
1249+
);
1250+
assert_eq!(
1251+
name_eq(right, left),
1252+
eq,
1253+
"failed check for name_eq({right:?}, {left:?})"
1254+
);
1255+
}
1256+
}

0 commit comments

Comments
 (0)