|
8 | 8 | import logging
|
9 | 9 | import re
|
10 | 10 | from typing import (
|
| 11 | + Dict, |
11 | 12 | List,
|
12 | 13 | Optional,
|
13 | 14 | )
|
@@ -88,7 +89,7 @@ def validate_email(trans, email, user=None, check_dup=True, allow_empty=False, v
|
88 | 89 | message = validate_email_domain_name(domain)
|
89 | 90 |
|
90 | 91 | if not message:
|
91 |
| - if is_email_banned(email, trans.app.config.email_ban_file): |
| 92 | + if is_email_banned(email, trans.app.config.email_ban_file, trans.app.config.canonical_email_rules): |
92 | 93 | message = "This email address has been banned."
|
93 | 94 |
|
94 | 95 | stmt = select(trans.app.model.User).filter(func.lower(trans.app.model.User.email) == email.lower()).limit(1)
|
@@ -175,33 +176,81 @@ def validate_preferred_object_store_id(
|
175 | 176 | return object_store.validate_selected_object_store_id(trans.user, preferred_object_store_id) or ""
|
176 | 177 |
|
177 | 178 |
|
178 |
| -def is_email_banned(email: str, filepath: Optional[str]) -> bool: |
| 179 | +def is_email_banned(email: str, filepath: Optional[str], canonical_email_rules: Optional[Dict]) -> bool: |
179 | 180 | if not filepath:
|
180 | 181 | return False
|
181 |
| - email = _make_canonical_email(email) |
| 182 | + normalizer = EmailAddressNormalizer(canonical_email_rules) |
| 183 | + email = normalizer.normalize(email) |
182 | 184 | banned_emails = _read_email_ban_list(filepath)
|
183 | 185 | for address in banned_emails:
|
184 |
| - if email == _make_canonical_email(address): |
| 186 | + if email == normalizer.normalize(address): |
185 | 187 | return True
|
186 | 188 | return False
|
187 | 189 |
|
188 | 190 |
|
189 |
| -def _make_canonical_email(email: str) -> str: |
190 |
| - """ |
191 |
| - Transform to canonical representation: |
192 |
| - - lowercase |
193 |
| - - gmail: drop periods in local-part |
194 |
| - - gmail: drop plus suffixes in local-part |
195 |
| - """ |
196 |
| - email = email.lower() |
197 |
| - localpart, domain = email.split("@") |
198 |
| - if domain == "gmail.com": |
199 |
| - localpart = localpart.replace(".", "") |
200 |
| - if localpart.find("+") > -1: |
201 |
| - localpart = localpart[: localpart.index("+")] |
202 |
| - return f"{localpart}@{domain}" |
203 |
| - |
204 |
| - |
205 | 191 | def _read_email_ban_list(filepath: str) -> List[str]:
|
206 | 192 | with open(filepath) as f:
|
207 | 193 | return [line.strip() for line in f if not line.startswith("#")]
|
| 194 | + |
| 195 | + |
| 196 | +class EmailAddressNormalizer: |
| 197 | + IGNORE_CASE_RULE = "ignore_case" |
| 198 | + IGNORE_DOTS_RULE = "ignore_dots" |
| 199 | + SUB_ADDRESSING_RULE = "sub_addressing" |
| 200 | + SUB_ADDRESSING_DELIM = "sub_addressing_delim" |
| 201 | + SUB_ADDRESSING_DELIM_DEFAULT = "+" |
| 202 | + ALL = "all" |
| 203 | + |
| 204 | + def __init__(self, canonical_email_rules: Optional[Dict]) -> None: |
| 205 | + self.config = canonical_email_rules |
| 206 | + |
| 207 | + def normalize(self, email: str) -> str: |
| 208 | + """Transform email to its canonical form.""" |
| 209 | + |
| 210 | + email_localpart, email_domain = email.split("@") |
| 211 | + # the domain part of an email address is case-insensitive (RFC1035) |
| 212 | + email_domain = email_domain.lower() |
| 213 | + |
| 214 | + # Step 1: If no rules are set, do not modify local-part |
| 215 | + if not self.config: |
| 216 | + return f"{email_localpart}@{email_domain}" |
| 217 | + |
| 218 | + # Step 2: Apply rules defined for all services before applying rules defined for specific services |
| 219 | + if self.ALL in self.config: |
| 220 | + email_localpart = self._apply_rules(email_localpart, self.ALL) |
| 221 | + |
| 222 | + # Step 3: Apply rules definied for each email service if email matches service |
| 223 | + for service in (s for s in self.config if s != self.ALL): |
| 224 | + service = service.lower() # ensure domain is lowercase |
| 225 | + apply_rules = False |
| 226 | + |
| 227 | + if email_domain == service: |
| 228 | + apply_rules = True |
| 229 | + elif self.config[service].get("aliases"): |
| 230 | + service_aliases = [ |
| 231 | + a.lower() for a in self.config[service]["aliases"] |
| 232 | + ] # ensure domain aliases are lowercase |
| 233 | + if email_domain in service_aliases: |
| 234 | + # email domain is an alias of the service. Change it to the service's primary domain name. |
| 235 | + email_domain = service |
| 236 | + apply_rules = True |
| 237 | + |
| 238 | + if apply_rules: |
| 239 | + email_localpart = self._apply_rules(email_localpart, service) |
| 240 | + |
| 241 | + return f"{email_localpart}@{email_domain}" |
| 242 | + |
| 243 | + def _apply_rules(self, email_localpart: str, service: str) -> str: |
| 244 | + assert self.config |
| 245 | + config = self.config[service] |
| 246 | + |
| 247 | + if config.get(self.IGNORE_CASE_RULE, False): |
| 248 | + email_localpart = email_localpart.lower() |
| 249 | + if config.get(self.IGNORE_DOTS_RULE, False): |
| 250 | + email_localpart = email_localpart.replace(".", "") |
| 251 | + if config.get(self.SUB_ADDRESSING_RULE, False): |
| 252 | + delim = config.get(self.SUB_ADDRESSING_DELIM, self.SUB_ADDRESSING_DELIM_DEFAULT) |
| 253 | + if email_localpart.find(delim) > -1: |
| 254 | + email_localpart = email_localpart[: email_localpart.index(delim)] |
| 255 | + |
| 256 | + return email_localpart |
0 commit comments