Skip to content

Commit

Permalink
implemented fix for escaping underscores in latex repr and added a un…
Browse files Browse the repository at this point in the history
…it test
  • Loading branch information
Dekermanjian committed Sep 11, 2024
1 parent d596afb commit b284be1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
17 changes: 17 additions & 0 deletions pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool
if not var_reprs:
return ""
if "latex" in formatting:
var_reprs = [_format_underscore(x) for x in var_reprs]
var_reprs = [
var_repr.replace(r"\sim", r"&\sim &").strip("$")
for var_repr in var_reprs
Expand Down Expand Up @@ -295,3 +296,19 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
except (ModuleNotFoundError, AttributeError):
# no ipython shell
pass


def _format_underscore(variable: str) -> str:
"""
formats variables with underscores in its name by prefixing underscores by '\\'
---
Params:
variable: The string representation of the variable in the model
"""
if "_" not in variable:
return variable
inds = [i for i, ltr in enumerate(variable) if ltr == "_"]
for i, ind in enumerate(inds):
ind = ind + i
variable = variable[:ind] + "\\" + variable[ind:]
return variable
24 changes: 24 additions & 0 deletions tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

import numpy as np

from pytensor.tensor.random import normal
Expand Down Expand Up @@ -316,3 +318,25 @@ def random(rng, mu, size):

str_repr = model.str_repr(include_params=False)
assert str_repr == "\n".join(["x ~ CustomDistNormal", "y ~ CustomRandomNormal"])


class TestLatexRepr:
@staticmethod
def simple_model() -> Model:
with Model() as simple_model:
error = HalfNormal("error", 0.5)
alpha = Normal("alpha", 0, 1)
Normal("y", alpha, error)
return simple_model

def test_latex_escaped_underscore(self):
"""
Ensures that all underscores in model variable names are properly escaped for LaTeX representation
"""
model = self.simple_model()
model_str = model.str_repr(formatting="latex")
underscores = re.finditer(r"_", model_str)
for match in underscores:
if match:
start = match.span(0)[0] - 1
assert model_str[start : start + 1] == "\\"

0 comments on commit b284be1

Please sign in to comment.