Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typing #57

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
56 changes: 56 additions & 0 deletions .github/workflows/typecheck.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
name: typecheck

on:
push:
branches:
- main
- typing
pull_request:
branches:
- main

jobs:
mypy:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version:
["3.8", "3.9", "3.10", "3.11", "3.12", "pypy-3.9", "pypy-3.10"]
option: ["", "--strict"]
include:
- python-version: "3.8"
version-name: "3.8"
- python-version: "3.9"
version-name: "3.9"
- python-version: "3.10"
version-name: "3.10"
- python-version: "3.11"
version-name: "3.11"
- python-version: "3.12"
version-name: "3.12"
- python-version: "pypy-3.9"
version-name: "3.9"
- python-version: "pypy-3.10"
version-name: "3.10"

steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

# TODO typing: fix
- name: Install dependencies
run: |
pip install "mypy==1.10.1"
pip install -e .[testing]
pip install -e .[css_inlining]
pip install types-docopt types-beautifulsoup4

- name: Run mypy ${{ matrix.option }}
run: |
mypy ${{ matrix.option }} --python-version="${{ matrix.version-name }}" mjml
if: ${{ !cancelled() }}
4 changes: 4 additions & 0 deletions mjml/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import typing as t


_Direction = t.Literal["top", "bottom", "left", "right"]
42 changes: 26 additions & 16 deletions mjml/core/api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

from dotmap import DotMap
import typing as t

from ..lib import merge_dicts
from .registry import components


__all__ = ['initComponent', 'Component']

def initComponent(name, **initialDatas):
def initComponent(name: t.Optional[str],
**initialDatas: t.Any) -> t.Optional["Component"]:
if name is None:
return None
component_cls = components[name]
Expand All @@ -24,15 +24,22 @@ def initComponent(name, **initialDatas):


class Component:
component_name: t.ClassVar[str]

# LATER: not sure upstream also passes tagName, makes code easier for us
def __init__(self, *, attributes=None, children=(), content='', context=None,
props=None, globalAttributes=None, headStyle=None, tagName=None):
def __init__(self, *, attributes=None, children=(), content: str='',
context: t.Optional[t.Dict[str, t.Any]]=None,
props: t.Optional[t.Dict[str, t.Any]]=None,
globalAttributes: t.Optional[t.Dict[str, t.Any]]=None,
headStyle: t.Optional[t.Any]=None,
tagName: t.Optional[str]=None) -> None:
self.children = list(children)
self.content = content
self.context = context
# TODO typing: verify that this is the intent
self.context = context or dict()
self.tagName = tagName

self.props = DotMap(merge_dicts(props, {'children': children, 'content': content}))
self.props = merge_dicts(props or {}, {'children': children, 'content': content})

# upstream also checks "self.allowed_attrs"
self.attrs = merge_dicts(
Expand All @@ -46,26 +53,26 @@ def __init__(self, *, attributes=None, children=(), content='', context=None,
self.headStyle = headStyle

@classmethod
def getTagName(cls):
def getTagName(cls) -> str:
cls_name = cls.__name__
return cls_name

@classmethod
def isRawElement(cls):
def isRawElement(cls) -> bool:
cls_value = getattr(cls, 'rawElement', None)
return bool(cls_value)

# js: static defaultAttributes
@classmethod
def default_attrs(cls):
def default_attrs(cls) -> t.Dict[str, t.Any]:
return {}

# js: static allowedAttributes
@classmethod
def allowed_attrs(cls):
return ()
def allowed_attrs(cls) -> t.Dict[str, str]:
return {}

def getContent(self):
def getContent(self) -> str:
# Actually "self.content" should not be None but sometimes it is
# (probably due to bugs in this Python port). This special guard
# clause is the final fix to render the "welcome-email.mjml" from
Expand All @@ -74,17 +81,20 @@ def getContent(self):
return ''
return self.content.strip()

def getChildContext(self):
def getChildContext(self) -> t.Dict[str, t.Any]:
return self.context

# js: getAttribute(name)
def get_attr(self, name, *, missing_ok=False):
def get_attr(self, name: str, *, missing_ok: bool=False) -> t.Optional[t.Any]:
is_allowed_attr = name in self.allowed_attrs()
is_default_attr = name in self.default_attrs()
if not missing_ok and (not is_allowed_attr) and (not is_default_attr):
raise AssertionError(f'{self.__class__.__name__} has no declared attr {name}')
return self.attrs.get(name)
getAttribute = get_attr

def render(self):
def handler(self) -> t.Optional[str]:
return None

def render(self) -> str:
return ''
62 changes: 37 additions & 25 deletions mjml/elements/_base.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@


from dotmap import DotMap
import itertools
import typing as t

from ..core import Component, initComponent
from ..core.registry import components
from ..helpers import *
from ..lib import merge_dicts


if t.TYPE_CHECKING:
from mjml._types import _Direction


__all__ = [
'BodyComponent',
]


class BodyComponent(Component):
def render(self):
raise NotImplementedError(f'{self.__cls__.__name__} should override ".render()"')
def render(self) -> str:
raise NotImplementedError(f'{self.__class__.__name__} should override ".render()"')

def getShorthandAttrValue(self, attribute, direction, attr_with_direction=True):
def getShorthandAttrValue(self,
attribute: str, direction: "_Direction",
attr_with_direction: bool=True) -> int:
if attr_with_direction:
mjAttributeDirection = self.getAttribute(f'{attribute}-{direction}')
else:
Expand All @@ -29,29 +35,28 @@ def getShorthandAttrValue(self, attribute, direction, attr_with_direction=True):
return 0
return shorthandParser(mjAttribute, direction)

def getShorthandBorderValue(self, direction):
def getShorthandBorderValue(self, direction: "_Direction") -> int:
borderDirection = direction and self.getAttribute(f'border-{direction}')
border = self.getAttribute('border')
return borderParser(borderDirection or border or '0')

def getBoxWidths(self):
def getBoxWidths(self) -> t.Dict[str, t.Any]:
containerWidth = self.context['containerWidth']
parsedWidth = strip_unit(containerWidth)
get_padding = lambda d: self.getShorthandAttrValue('padding', d)
paddings = get_padding('right') + get_padding('left')
borders = self.getShorthandBorderValue('right') + self.getShorthandBorderValue('left')

return DotMap({
return {
'totalWidth': parsedWidth,
'borders' : borders,
'paddings' : paddings,
'box' : parsedWidth - paddings - borders,
})
}

# js: htmlAttributes(attributes)
def html_attrs(self, **attrs):
def _to_str(kv):
key, value = kv
def html_attrs(self, **attrs: t.Any) -> str:
def _to_str(key: str, value: t.Any) -> t.Optional[str]:
if key == 'style':
value = self.styles(value)
elif key in ['class_', 'for_']:
Expand All @@ -61,46 +66,53 @@ def _to_str(kv):
if value is None:
return None
return f'{key}="{value}"'
serialized_attrs = map(_to_str, attrs.items())
serialized_attrs = itertools.starmap(_to_str, attrs.items())
return ' '.join(filter(None, serialized_attrs))

# js: getStyles()
def get_styles(self):
def get_styles(self) -> t.Dict[str, t.Any]:
return {}

# js: styles(styles)
def styles(self, key=None):
_styles = None
def styles(self, key: t.Optional[t.Any]=None) -> str:
_styles: t.Optional[t.Dict[str, t.Any]] = None

if key and isinstance(key, str):
_styles_dict = self.get_styles()
keys = key.split('.')
_styles = _styles_dict.get(keys[0])
if len(keys) > 1:
# TODO typing: fix
if not _styles:
raise RuntimeError()
_styles = _styles.get(keys[1])
if _styles and not isinstance(_styles, dict):
raise ValueError(f'key={key}')
elif key:
# predefined dict
_styles = key

if not _styles:
_styles = {}

def serializer(kv):
k, v = kv
def serializer(k: str, v: t.Any) -> t.Optional[str]:
return f'{k}:{v}' if is_not_empty(v) else None
style_attr_strs = filter(None, map(serializer, _styles.items()))

style_attr_strs = filter(None, itertools.starmap(serializer, _styles.items()))
style_str = ';'.join(style_attr_strs)
return style_str

def renderChildren(self, childrens=None, props=None, renderer=None,
attributes=None, rawXML=False):
# TODO typing: finish rest of type annotations
def renderChildren(self, childrens=None, props=None,
renderer: t.Optional[t.Callable[[Component], str]]=None,
attributes=None, rawXML=False) -> str:
if not props:
props = {}
if not renderer:
renderer = lambda component: component.render()
if not attributes:
attributes = {}
childrens = childrens or self.props.children
childrens = childrens or self.props.get("children")

if rawXML:
# return childrens.map(child => jsonToXML(child)).join('\n')
Expand All @@ -118,8 +130,8 @@ def renderChildren(self, childrens=None, props=None, renderer=None,
# child => !find(rawComponents, c => c.getTagName() === child.tagName),
#).length
raw_tag_names = set()
for tag_name, component in components.items():
if component.isRawElement():
for tag_name, component_cls in components.items():
if component_cls.isRawElement():
raw_tag_names.add(tag_name)
is_raw_element = lambda c: (c['tagName'] in raw_tag_names)

Expand Down
7 changes: 5 additions & 2 deletions mjml/elements/head/_head_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import typing as t

from mjml.core import Component, initComponent


__all__ = ['HeadComponent']


class HeadComponent(Component):
# TODO typing: figure out proper type annotations
def handlerChildren(self):
def handle_children(children):
def handle_children(children: t.Dict[str, t.Any]) -> t.Optional[str]:
tagName = children['tagName']
component = initComponent(
name = tagName,
Expand All @@ -25,5 +28,5 @@ def handle_children(children):
return component.render()
return None

childrens = self.props.children
childrens = self.props.get("children")
return tuple(map(handle_children, childrens))
12 changes: 9 additions & 3 deletions mjml/elements/head/mj_attributes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import typing as t

import typing_extensions as te

from mjml.helpers import omit

Expand All @@ -6,12 +9,15 @@

__all__ = ['MjAttributes']


class MjAttributes(HeadComponent):
component_name = 'mj-attributes'
component_name: t.ClassVar[str] = 'mj-attributes'

def handler(self):
@te.override
def handler(self) -> None:
add = self.context['add']
_children = self.props.children
if (_children := self.props.get("children")) is None:
return None

for child in _children:
tagName = child['tagName']
Expand Down
11 changes: 8 additions & 3 deletions mjml/elements/head/mj_breakpoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import typing as t

import typing_extensions as te

from ._head_base import HeadComponent

Expand All @@ -6,14 +9,16 @@


class MjBreakpoint(HeadComponent):
component_name = 'mj-breakpoint'
component_name: t.ClassVar[str] = 'mj-breakpoint'

@te.override
@classmethod
def allowed_attrs(cls):
def allowed_attrs(cls) -> t.Dict[str, str]:
return {
'width': 'unit(px)',
}

def handler(self):
@te.override
def handler(self) -> None:
add = self.context['add']
add('breakpoint', self.getAttribute('width'))
Loading
Loading