From b284be1cc56147b7d9490eaa5aec1e9d3d849ee0 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Wed, 11 Sep 2024 16:59:51 -0600 Subject: [PATCH 1/4] implemented fix for escaping underscores in latex repr and added a unit test --- pymc/printing.py | 17 +++++++++++++++++ tests/test_printing.py | 24 ++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/pymc/printing.py b/pymc/printing.py index 56445ab9ea..fb002bff82 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -114,6 +114,7 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool if not var_reprs: return "" if "latex" in formatting: + var_reprs = [_format_underscore(x) for x in var_reprs] var_reprs = [ var_repr.replace(r"\sim", r"&\sim &").strip("$") for var_repr in var_reprs @@ -295,3 +296,19 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle): except (ModuleNotFoundError, AttributeError): # no ipython shell pass + + +def _format_underscore(variable: str) -> str: + """ + formats variables with underscores in its name by prefixing underscores by '\\' + --- + Params: + variable: The string representation of the variable in the model + """ + if "_" not in variable: + return variable + inds = [i for i, ltr in enumerate(variable) if ltr == "_"] + for i, ind in enumerate(inds): + ind = ind + i + variable = variable[:ind] + "\\" + variable[ind:] + return variable diff --git a/tests/test_printing.py b/tests/test_printing.py index 406032b124..554628625f 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re + import numpy as np from pytensor.tensor.random import normal @@ -316,3 +318,25 @@ def random(rng, mu, size): str_repr = model.str_repr(include_params=False) assert str_repr == "\n".join(["x ~ CustomDistNormal", "y ~ CustomRandomNormal"]) + + +class TestLatexRepr: + @staticmethod + def simple_model() -> Model: + with Model() as simple_model: + error = HalfNormal("error", 0.5) + alpha = Normal("alpha", 0, 1) + Normal("y", alpha, error) + return simple_model + + def test_latex_escaped_underscore(self): + """ + Ensures that all underscores in model variable names are properly escaped for LaTeX representation + """ + model = self.simple_model() + model_str = model.str_repr(formatting="latex") + underscores = re.finditer(r"_", model_str) + for match in underscores: + if match: + start = match.span(0)[0] - 1 + assert model_str[start : start + 1] == "\\" From 753082f529c6af9e052ffc42f41bca724e22fd7a Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Wed, 11 Sep 2024 19:20:05 -0600 Subject: [PATCH 2/4] updated unit test staticmethod to include underscore in var name --- tests/test_printing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_printing.py b/tests/test_printing.py index 554628625f..4b717382c6 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -325,8 +325,8 @@ class TestLatexRepr: def simple_model() -> Model: with Model() as simple_model: error = HalfNormal("error", 0.5) - alpha = Normal("alpha", 0, 1) - Normal("y", alpha, error) + alpha_a = Normal("alpha_a", 0, 1) + Normal("y", alpha_a, error) return simple_model def test_latex_escaped_underscore(self): From 67383c54e7b681d875d265748282016b8a7fbaba Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sat, 14 Sep 2024 09:26:07 -0600 Subject: [PATCH 3/4] add underscore escape fix to distribution repr as well as model repr, fixed testing to expect underscores in LaTeX representation to be escaped --- pymc/printing.py | 14 +++++++++++--- tests/test_printing.py | 14 +++++++------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/pymc/printing.py b/pymc/printing.py index fb002bff82..2215c6f573 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -58,6 +58,7 @@ def str_for_dist( if "latex" in formatting: if print_name is not None: print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}" + print_name = _format_underscore(print_name) op_name = ( dist.owner.op._print_name[1] @@ -307,8 +308,15 @@ def _format_underscore(variable: str) -> str: """ if "_" not in variable: return variable + inds = [i for i, ltr in enumerate(variable) if ltr == "_"] - for i, ind in enumerate(inds): - ind = ind + i - variable = variable[:ind] + "\\" + variable[ind:] + var_len_original = len(variable) + var_len = None + for ind in inds: + if var_len: + if var_len != var_len_original: + ind = ind + (var_len - var_len_original) + if variable[ind - 1 : ind] != "\\": + variable = variable[:ind] + "\\" + variable[ind:] + var_len = len(variable) return variable diff --git a/tests/test_printing.py b/tests/test_printing.py index 4b717382c6..9afc3a090d 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -173,15 +173,15 @@ def setup_class(self): r"$\text{mu} \sim \operatorname{Deterministic}(f(\text{beta},~\text{alpha}))$", r"$\text{beta} \sim \operatorname{Normal}(0,~10)$", r"$\text{Z} \sim \operatorname{MultivariateNormal}(f(),~f())$", - r"$\text{nb_with_p_n} \sim \operatorname{NegativeBinomial}(10,~\text{nbp})$", + r"$\text{nb\_with\_p\_n} \sim \operatorname{NegativeBinomial}(10,~\text{nbp})$", r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Poisson}(5))$", r"$\text{w} \sim \operatorname{Dirichlet}(\text{})$", ( - r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{w}," + r"$\text{nested\_mix} \sim \operatorname{MarginalMixture}(\text{w}," r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Poisson}(5))," r"~\operatorname{Censored}(\operatorname{Bernoulli}(0.5),~-1,~1))$" ), - r"$\text{Y_obs} \sim \operatorname{Normal}(\text{mu},~\text{sigma})$", + r"$\text{Y\_obs} \sim \operatorname{Normal}(\text{mu},~\text{sigma})$", r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$", r"$\text{pred} \sim \operatorname{Deterministic}(f(\text{}))", ], @@ -191,11 +191,11 @@ def setup_class(self): r"$\text{mu} \sim \operatorname{Deterministic}$", r"$\text{beta} \sim \operatorname{Normal}$", r"$\text{Z} \sim \operatorname{MultivariateNormal}$", - r"$\text{nb_with_p_n} \sim \operatorname{NegativeBinomial}$", + r"$\text{nb\_with\_p\_n} \sim \operatorname{NegativeBinomial}$", r"$\text{zip} \sim \operatorname{MarginalMixture}$", r"$\text{w} \sim \operatorname{Dirichlet}$", - r"$\text{nested_mix} \sim \operatorname{MarginalMixture}$", - r"$\text{Y_obs} \sim \operatorname{Normal}$", + r"$\text{nested\_mix} \sim \operatorname{MarginalMixture}$", + r"$\text{Y\_obs} \sim \operatorname{Normal}$", r"$\text{pot} \sim \operatorname{Potential}$", r"$\text{pred} \sim \operatorname{Deterministic}", ], @@ -258,7 +258,7 @@ def test_model_latex_repr_three_levels_model(): "$$", "\\begin{array}{rcl}", "\\text{mu} &\\sim & \\operatorname{Normal}(0,~5)\\\\\\text{sigma} &\\sim & " - "\\operatorname{HalfCauchy}(0,~2.5)\\\\\\text{censored_normal} &\\sim & " + "\\operatorname{HalfCauchy}(0,~2.5)\\\\\\text{censored\\_normal} &\\sim & " "\\operatorname{Censored}(\\operatorname{Normal}(\\text{mu},~\\text{sigma}),~-2,~2)", "\\end{array}", "$$", From 41568044e6e67cccbad056be63f0fd5db8a564da Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sat, 14 Sep 2024 09:49:15 -0600 Subject: [PATCH 4/4] added cleaner method using re to escape underscores, added cleaner test to assert underscores are escaped --- pymc/printing.py | 22 ++++------------------ tests/test_printing.py | 8 ++------ 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/pymc/printing.py b/pymc/printing.py index 2215c6f573..ef417f3799 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -13,6 +13,8 @@ # limitations under the License. +import re + from functools import partial from pytensor.compile import SharedVariable @@ -301,22 +303,6 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle): def _format_underscore(variable: str) -> str: """ - formats variables with underscores in its name by prefixing underscores by '\\' - --- - Params: - variable: The string representation of the variable in the model + Escapes all unescaped underscores in the variable name for LaTeX representation. """ - if "_" not in variable: - return variable - - inds = [i for i, ltr in enumerate(variable) if ltr == "_"] - var_len_original = len(variable) - var_len = None - for ind in inds: - if var_len: - if var_len != var_len_original: - ind = ind + (var_len - var_len_original) - if variable[ind - 1 : ind] != "\\": - variable = variable[:ind] + "\\" + variable[ind:] - var_len = len(variable) - return variable + return re.sub(r"(?