Skip to content

Commit

Permalink
Protect from_hepevt against invalid parents/children record (#47)
Browse files Browse the repository at this point in the history
This protects from_hepevt from invalid parent/children ranges.
Previously the code would abort or even crash on such input. Now it
raises a Python RuntimeError.
  • Loading branch information
HDembinski authored Oct 6, 2022
1 parent 968be50 commit eb6a3e6
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 14 deletions.
34 changes: 20 additions & 14 deletions src/from_hepevt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,6 @@
#include <utility>
#include <vector>

// template <>
// struct std::hash<std::pair<int, int>> {
// std::size_t operator()(const std::pair<int, int>& p) const noexcept {
// auto h1 = std::hash<int>{}(p.first);
// auto h2 = std::hash<int>{}(p.second);
// return h1 ^ (h2 << 1); // or use boost::hash_combine
// }
// };

template <>
struct std::less<std::pair<int, int>> {
bool operator()(const std::pair<int, int>& a,
Expand All @@ -35,9 +26,10 @@ void normalize(int& m1, int& m2) {
// m1 < m2, both > 0: interaction
// m2 < m1, both > 0: same, needs swapping

if (m1 > 0 && m2 > 0 && m2 < m1) std::swap(m1, m2);

if (m1 > 0 && m2 == 0) m2 = m1;
if (m1 > 0 && m2 == 0)
m2 = m1;
else if (m2 < m1)
std::swap(m1, m2);

--m1; // fortran to c index
}
Expand Down Expand Up @@ -102,14 +94,28 @@ void connect_parents_and_children(GenEvent& event, bool parents,

int m1 = vi.first.first;
int m2 = vi.first.second;

// there must be at least one parent or child when we arrive here...
normalize(m1, m2);
assert(m1 < m2);
assert(m1 < m2); // postcondition after normalize

if (m1 < 0 || m2 > n) {
std::ostringstream os;
os << "invalid " << (parents ? "parents" : "children") << " range for vertex "
<< event.vertices().size() << " [" << m1 << ", " << m2
<< ") total number of particles " << n;
throw std::runtime_error(os.str().c_str());
}

// ...with at least one child or parent
const auto& co = vi.second;
assert(!co.empty());

if (co.empty()) {
std::ostringstream os;
os << "invalid empty " << (!parents ? "parents" : "children")
<< " list for vertex " << event.vertices().size();
throw std::runtime_error(os.str().c_str());
}
FourVector pos;
if (has_vertex) {
// we assume this is a production vertex
Expand Down
57 changes: 57 additions & 0 deletions tests/test_from_hepevt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pyhepmc as hep
import numpy as np
import pytest


def test_no_vertex_info():
px = py = pz = en = m = np.linspace(0, 1, 4)

pid = np.arange(4) + 1
sta = np.zeros(4, dtype=np.int32)
parents = [(0, 0), (1, 1), (0, 0), (0, 0)]
hev = hep.GenEvent()
hev.from_hepevt(0, px, py, pz, en, m, pid, sta, parents)
assert len(hev.vertices) == 1
assert len(hev.particles) == 4


def test_parents_range_exceeding_particle_range():
px = py = pz = en = m = np.linspace(0, 1, 6)
pid = np.arange(6) + 1
sta = np.zeros(6, dtype=np.int32)
parents = [(0, 0), (1, 1), (2, 0), (3, 5), (4, 10), (3, 5)]
with pytest.raises(RuntimeError):
hep.GenEvent().from_hepevt(0, px, py, pz, en, m, pid, sta, parents)


def test_invalid_length_of_parents():
px = py = pz = en = m = np.linspace(0, 1, 3)
pid = np.arange(3) + 1
sta = np.zeros(3, dtype=np.int32)
parents = [(0, 0), (1, 2)]
with pytest.raises(RuntimeError):
hep.GenEvent().from_hepevt(0, px, py, pz, en, m, pid, sta, parents)


def test_inverted_parents_range():
px = py = pz = en = m = vx = vy = vz = vt = np.linspace(0, 1, 4)
pid = np.arange(4) + 1
sta = np.zeros(4, dtype=np.int32)
# inverted range is not an error (2, 1) will be converted to (1, 2)
parents = [(0, 0), (2, 1), (3, 3), (3, 3)]
hev = hep.GenEvent()
hev.from_hepevt(0, px, py, pz, en, m, pid, sta, parents)
expected = [[0, 1], [2]]
got = [[p.id - 1 for p in v.particles_in] for v in hev.vertices]
assert expected == got


@pytest.mark.parametrize("bad", ([-4, 1], [1, -4]))
def test_negative_parents_range(bad):
px = py = pz = en = m = vx = vy = vz = vt = np.linspace(0, 1, 4)
pid = np.arange(4) + 1
sta = np.zeros(4, dtype=np.int32)
# inverted range is not an error (2, 1) will be converted to (1, 2)
parents = [(0, 0), bad, (3, 3), (3, 3)]
with pytest.raises(RuntimeError):
hep.GenEvent().from_hepevt(0, px, py, pz, en, m, pid, sta, parents)

0 comments on commit eb6a3e6

Please sign in to comment.