Skip to content

Commit

Permalink
Fix mypy and add recursive tranformation
Browse files Browse the repository at this point in the history
  • Loading branch information
lazebnyi committed Dec 17, 2024
1 parent 453580d commit 7319974
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,7 @@ definitions:
- "$ref": "#/definitions/CustomTransformation"
- "$ref": "#/definitions/RemoveFields"
- "$ref": "#/definitions/KeysToLower"
- "$ref": "#/definitions/KeyToSnakeCase"
- "$ref": "#/definitions/KeysToSnakeCase"
state_migrations:
title: State Migrations
description: Array of state migrations to be applied on the input state
Expand Down Expand Up @@ -1839,7 +1839,7 @@ definitions:
$parameters:
type: object
additionalProperties: true
KeyToSnakeCase:
KeysToSnakeCase:
title: Key to Snake Case
description: A transformation that renames all keys to snake case.
type: object
Expand All @@ -1848,7 +1848,7 @@ definitions:
properties:
type:
type: string
enum: [KeyToSnakeCase]
enum: [KeysToSnakeCase]
$parameters:
type: object
additionalProperties: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -710,8 +710,8 @@ class KeysToLower(BaseModel):
parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")


class KeyToSnakeCase(BaseModel):
type: Literal["KeyToSnakeCase"]
class KeysToSnakeCase(BaseModel):
type: Literal["KeysToSnakeCase"]
parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")


Expand Down Expand Up @@ -1665,7 +1665,7 @@ class Config:
CustomTransformation,
RemoveFields,
KeysToLower,
KeyToSnakeCase,
KeysToSnakeCase,
]
]
] = Field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,19 @@

import re
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import unidecode

from airbyte_cdk.sources.declarative.transformations import RecordTransformation
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState

TOKEN_PATTERN = re.compile(r"[A-Z]+[a-z]*|[a-z]+|\d+|(?P<NoToken>[^a-zA-Z\d]+)")
DEFAULT_SEPARATOR = "_"


@dataclass
class KeyToSnakeCaseTransformation(RecordTransformation):
token_pattern: re.Pattern = TOKEN_PATTERN
class KeysToSnakeCaseTransformation(RecordTransformation):
token_pattern: re.Pattern[str] = re.compile(
r"[A-Z]+[a-z]*|[a-z]+|\d+|(?P<NoToken>[^a-zA-Z\d]+)"
)

def transform(
self,
Expand All @@ -26,13 +25,22 @@ def transform(
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
) -> None:
transformed_record = {}
for key in record:
transformed_key = self.process_key(key)
transformed_record[transformed_key] = record[key]
transformed_record = self._transform_record(record)
record.clear()
record.update(transformed_record)

def _transform_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
transformed_record = {}
for key, value in record.items():
transformed_key = self.process_key(key)
transformed_value = value

if isinstance(value, dict):
transformed_value = self._transform_record(value)

transformed_record[transformed_key] = transformed_value
return transformed_record

def process_key(self, key: str) -> str:
key = self.normalize_key(key)
tokens = self.tokenize_key(key)
Expand All @@ -42,19 +50,19 @@ def process_key(self, key: str) -> str:
def normalize_key(self, key: str) -> str:
return unidecode.unidecode(key)

def tokenize_key(self, key: str) -> list:
def tokenize_key(self, key: str) -> List[str]:
tokens = []
for match in self.token_pattern.finditer(key):
token = match.group(0) if match.group("NoToken") is None else ""
tokens.append(token)
return tokens

def filter_tokens(self, tokens: list) -> list:
def filter_tokens(self, tokens: List[str]) -> List[str]:
if len(tokens) >= 3:
tokens = tokens[:1] + [t for t in tokens[1:-1] if t] + tokens[-1:]
if tokens and tokens[0].isdigit():
tokens.insert(0, "")
return tokens

def tokens_to_snake_case(self, tokens: list) -> str:
def tokens_to_snake_case(self, tokens: List[str]) -> str:
return "_".join(token.lower() for token in tokens)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from airbyte_cdk.sources.declarative.transformations.keys_to_snake_transformation import (
KeyToSnakeCaseTransformation,
KeysToSnakeCaseTransformation,
)

_ANY_VALUE = -1
Expand All @@ -22,6 +22,16 @@
{"123Number": _ANY_VALUE, "456Another123": _ANY_VALUE},
{"_123_number": _ANY_VALUE, "_456_another_123": _ANY_VALUE},
),
(
{
"NestedRecord": {"FirstName": _ANY_VALUE, "lastName": _ANY_VALUE},
"456Another123": _ANY_VALUE,
},
{
"nested_record": {"first_name": _ANY_VALUE, "last_name": _ANY_VALUE},
"_456_another_123": _ANY_VALUE,
},
),
(
{"hello@world": _ANY_VALUE, "test#case": _ANY_VALUE},
{"hello_world": _ANY_VALUE, "test_case": _ANY_VALUE},
Expand All @@ -43,6 +53,6 @@
),
],
)
def test_key_transformation(input_keys, expected_keys):
KeyToSnakeCaseTransformation().transform(input_keys)
def test_keys_transformation(input_keys, expected_keys):
KeysToSnakeCaseTransformation().transform(input_keys)
assert input_keys == expected_keys

0 comments on commit 7319974

Please sign in to comment.