Skip to content

Commit

Permalink
more flexible open function, support for convertible events (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
HDembinski authored Aug 22, 2022
1 parent 6368bc0 commit a0a1a25
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 21 deletions.
102 changes: 82 additions & 20 deletions src/pyhepmc/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
WriterAsciiHepMC2,
WriterHEPEVT,
)
from pathlib import PurePath
import typing as _tp


class _Iter:
Expand Down Expand Up @@ -76,39 +78,99 @@ def _read(self):
WriterHEPEVT.write = WriterHEPEVT.write_event


# pythonic wrapper for AsciiWriter, to be used by `open`
class WrappedAsciiWriter:
def __init__(self, filename, precision=None):
self._writer = (filename, precision)
# Wrapper for Writer, to be used by `open`
class _WrappedWriter:
def __init__(self, filename, precision, Writer):
self._writer = (filename, precision, Writer)

def write(self, event):
if not isinstance(event, GenEvent):
if hasattr(event, "to_hepmc3"):
event = event.to_hepmc3()
else:
raise TypeError(
"event must be an instance of GenEvent or "
"convertible to it by providing a to_hepmc3() method"
)
if isinstance(self._writer, tuple):
# first call
filename, precision = self._writer
self._writer = WriterAscii(filename, event.run_info)
if precision is not None:
filename, precision, Writer = self._writer
if Writer is WriterHEPEVT:
self._writer = Writer(filename)
else:
self._writer = Writer(filename, event.run_info)
if precision is not None and hasattr(self._writer, "precision"):
self._writer.precision = precision
self._writer.write_event(event)

def close(self):
self._writer.close()
if not isinstance(self._writer, tuple):
self._writer.close()

__enter__ = _enter
__exit__ = _exit


def pyhepmc_open(filename, mode="r", precision=None):
def pyhepmc_open(
filename: _tp.Union[str, PurePath],
mode: str = "r",
precision: int = None,
format: str = None,
):
"""
Open HepMC files for reading or writing.
Parameters
----------
filename : str or Path
Filename to open for reading or writing. When writing to existing files,
the contents are replaced.
mode : str, optional
Must be either "r" (default) or "w", to indicate whether to open for reading
or writing.
precision : int or None, optional
How many digits of precision to use when writing to a file. Can be used to
improve the compression rate.
format : str or None, optional
Which format to use for reading or writing. If None (default), autodetect
format when reading (this is fast and thus safe to use), and use the latest
HepMC3 format when writing. Allowed values: "HepMC3", "HepMC2", "LHEF",
"HEPEVT". "LHEF" is not supported for writing.
"""
if mode == "r":
with open(filename, "r") as f:
header = f.read(256)
if "HepMC::Asciiv3" in header:
return ReaderAscii(filename)
if "HepMC::IO_GenEvent" in header:
return ReaderAsciiHepMC2(filename)
if "<LesHouchesEvents" in header:
return ReaderLHEF(filename)
return ReaderHEPEVT(filename)

if format is None:
# auto-detect
with open(filename, "r") as f:
header = f.read(256)
if "HepMC::Asciiv3" in header:
Reader = ReaderAscii
elif "HepMC::IO_GenEvent" in header:
Reader = ReaderAsciiHepMC2
elif "<LesHouchesEvents" in header:
Reader = ReaderLHEF
else:
# this one has no header
Reader = ReaderHEPEVT
else:
Reader = {
"hepmc3": ReaderAscii,
"hepmc2": ReaderAsciiHepMC2,
"lhef": ReaderLHEF,
"hepevt": ReaderHEPEVT,
}.get(format.lower(), None)
if Reader is None:
raise ValueError(f"format {format} not recognized for reading")
return Reader(str(filename))
elif mode == "w":
return WrappedAsciiWriter(filename, precision)
if format is None:
Writer = WriterAscii
else:
Writer = {
"hepmc3": WriterAscii,
"hepmc2": WriterAsciiHepMC2,
"hepevt": WriterHEPEVT,
}.get(format.lower(), None)
if Writer is None:
raise ValueError(f"format {format} not recognized for writing")
return _WrappedWriter(str(filename), precision, Writer)
raise ValueError("mode must be r or w")
56 changes: 55 additions & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from test_basic import evt # noqa
from pyhepmc._core import stringstream
from pathlib import Path


def test_read_write(evt): # noqa
Expand Down Expand Up @@ -50,7 +51,33 @@ def test_read_empty_stream(evt): # noqa
assert ok is True # reading empty stream is ok in HepMC


def test_open(evt): # noqa
@pytest.mark.parametrize("format", ("hepmc3", "hepmc2", "hepevt"))
def test_open_1(evt, format): # noqa
with hep.open("test_read_write_file.dat", "w", format=format) as f:
f.write(evt)

with hep.open("test_read_write_file.dat", format=format) as f:
evt2 = f.read()

if format in ("hepmc2", "hepevt"):
# ToolInfo not stored in this format, so adding it manually
evt2.run_info.tools = evt.run_info.tools

assert evt == evt2

with hep.open("test_read_write_file.dat") as f:
evt3 = f.read()

if format in ("hepmc2", "hepevt"):
# ToolInfo not stored in this format, so adding it manually
evt3.run_info.tools = evt.run_info.tools

assert evt == evt3

os.unlink("test_read_write_file.dat")


def test_open_2(evt): # noqa
with hep.open("test_read_write_file.dat", "w", precision=3) as f:
f.write(evt)

Expand All @@ -70,6 +97,33 @@ def test_open(evt): # noqa
os.unlink("test_read_write_file.dat")


def test_open_3(evt): # noqa
filename = Path("test_read_write_file.dat")

with hep.open(filename, "w") as f:
with pytest.raises(TypeError):
f.write(None)

with pytest.raises(TypeError):
f.write("foo")

class Foo:
def to_hepmc3(self):
return evt

foo = Foo()

with hep.open(filename, "w") as f:
f.write(foo)

with hep.open(filename) as f:
evt2 = f.read()

assert evt == evt2

filename.unlink()


@pytest.mark.parametrize(
"writer", (hep.WriterAscii, hep.WriterAsciiHepMC2, hep.WriterHEPEVT)
)
Expand Down

0 comments on commit a0a1a25

Please sign in to comment.