Skip to content

Commit

Permalink
config_file now accepts file-like object #19
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed Oct 4, 2023
1 parent fedd695 commit 2da7fd5
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions showerpipe/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
arrays.
"""
import collections as cl
import contextlib as ctx
import io
import itertools as it
import math
import operator as op
Expand Down Expand Up @@ -244,11 +246,14 @@ class PythiaGenerator(base.GeneratorAdapter):
Parameters
----------
config_file : pathlib.Path | str
Path to Pythia cmnd configuration file.
lhe_file : pathlib.Path | str | bytes, optional
config_file : Path | str | file_like
Pythia cmnd configuration file. If path or string, it is assumed
that the input refers to the location of a file on disk. If file
object, it must be in a readable mode.
lhe_file : Path | str | bytes, optional
The variable or filepath containing the LHE data. May be a path,
string, or bytes object. If file, may be compressed with gzip.
string, or bytes object. If path to file, may be compressed with
gzip.
rng_seed : int
Seed passed to the random number generator used by Pythia.
quiet : bool
Expand All @@ -275,7 +280,7 @@ class PythiaGenerator(base.GeneratorAdapter):

def __init__(
self,
config_file: ty.Union[str, Path],
config_file: ty.Union[str, Path, ty.TextIO],
lhe_file: ty.Optional[lhe._LHE_STORAGE] = None,
rng_seed: ty.Optional[int] = -1,
quiet: bool = True,
Expand All @@ -290,13 +295,17 @@ def __init__(
"Print": {"quiet": "on" if quiet else "off"},
"Random": {"setSeed": "on", "seed": str(rng_seed)},
}
with open(config_file, encoding="utf-8") as f:
for line in f:
with ctx.ExitStack() as stack:
if not isinstance(config_file, io.TextIOBase):
config_file = stack.enter_context(
open(config_file, encoding="utf-8")
)
for line in config_file:
key, val = line.partition("=")[::2]
sup_key, sub_key = map(lambda s: s.strip(), key.split(":"))
if sup_key.startswith("#"):
continue
config.setdefault(sup_key, dict())
config.setdefault(sup_key, {})
config[sup_key][sub_key] = val.strip()
if lhe_file is not None:
frame_type = config.get("Beams", {}).get("frameType", None)
Expand Down

0 comments on commit 2da7fd5

Please sign in to comment.