Skip to content

Commit b230bb5

Browse files
Exposes methods that would be required by pineko#183 and more
1 parent 3386f4b commit b230bb5

File tree

2 files changed

+89
-3
lines changed

2 files changed

+89
-3
lines changed

pineappl_py/src/grid.rs

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,16 @@ impl PyGrid {
698698
.collect()
699699
}
700700

701+
/// Rotate the Grid into the specified basis
702+
///
703+
/// Parameters
704+
/// ----------
705+
/// pid_basis: PyPidBasis
706+
/// PID basis of the resulting Grid
707+
pub fn rotate_pid_basis(&mut self, pid_basis: PyPidBasis) {
708+
self.grid.rotate_pid_basis(pid_basis.into());
709+
}
710+
701711
/// Scale all subgrids.
702712
///
703713
/// Parameters
@@ -716,12 +726,23 @@ impl PyGrid {
716726
///
717727
/// Parameters
718728
/// ----------
719-
/// factors : numpy.ndarray[float]
729+
/// factors : list[float]
720730
/// bin-dependent factors by which to scale
721731
pub fn scale_by_bin(&mut self, factors: Vec<f64>) {
722732
self.grid.scale_by_bin(&factors);
723733
}
724734

735+
/// Delete orders with the corresponding `order_indices`. Repeated indices and indices larger
736+
/// or equal than the number of orders are ignored.
737+
///
738+
/// Parameters
739+
/// ----------
740+
/// order_indices : list[int]
741+
/// list of indices of orders to be removed
742+
pub fn delete_orders(&mut self, order_indices: Vec<usize>) {
743+
self.grid.delete_orders(&order_indices);
744+
}
745+
725746
/// Delete bins.
726747
///
727748
/// # Panics
@@ -732,11 +753,27 @@ impl PyGrid {
732753
///
733754
/// Parameters
734755
/// ----------
735-
/// bin_indices : numpy.ndarray[int]
736-
/// list of indices of bins to removed
756+
/// bin_indices : list[int]
757+
/// list of indices of bins to be removed
737758
pub fn delete_bins(&mut self, bin_indices: Vec<usize>) {
738759
self.grid.delete_bins(&bin_indices);
739760
}
761+
762+
/// Deletes channels with the corresponding `channel_indices`. Repeated indices and indices
763+
/// larger or equal than the number of channels are ignored.
764+
///
765+
/// Parameters
766+
/// ----------
767+
/// bin_indices : list[int]
768+
/// list of indices of bins to be removed
769+
pub fn delete_channels(&mut self, channel_indices: Vec<usize>) {
770+
self.grid.delete_channels(&channel_indices);
771+
}
772+
773+
/// Splits the grid such that each channel contains only a single tuple of PIDs.
774+
pub fn split_channels(&mut self) {
775+
self.grid.split_channels();
776+
}
740777
}
741778

742779
/// Register submodule in parent.

pineappl_py/tests/test_grid.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,55 @@ def test_bins(self, fake_grids):
212212
np.testing.assert_allclose(g.bin_left(1), [2, 3])
213213
np.testing.assert_allclose(g.bin_right(1), [3, 5])
214214

215+
def test_rotate_pidbasis(self, fake_grids):
216+
g = fake_grids.grid_with_generic_convolution(
217+
nb_convolutions=2,
218+
channels=CHANNELS,
219+
orders=ORDERS,
220+
convolutions=[CONVOBJECT, CONVOBJECT],
221+
)
222+
# Rotate the Grid into the PDG basis
223+
g.rotate_pid_basis(PidBasis.Pdg)
224+
assert g.pid_basis == PidBasis.Pdg
225+
226+
def test_delete_orders(
227+
self,
228+
download_objects,
229+
gridname: str = "GRID_STAR_WMWP_510GEV_WP-AL-POL.pineappl.lz4",
230+
order_indices: list[int] = [1],
231+
):
232+
grid = download_objects(f"{gridname}")
233+
g = Grid.read(grid)
234+
orders = [o.as_tuple() for o in g.orders()]
235+
g.delete_orders(order_indices)
236+
for idx in order_indices:
237+
assert orders[idx] not in g.orders()
238+
239+
def test_delete_channels(
240+
self,
241+
download_objects,
242+
gridname: str = "GRID_STAR_WMWP_510GEV_WP-AL-POL.pineappl.lz4",
243+
channel_indices: list[int] = [1, 4, 5],
244+
):
245+
grid = download_objects(f"{gridname}")
246+
g = Grid.read(grid)
247+
channels = g.channels()
248+
g.delete_channels(channel_indices)
249+
for idx in channel_indices:
250+
assert channels[idx] not in g.channels()
251+
252+
def test_split_channels(
253+
self,
254+
pdf,
255+
download_objects,
256+
gridname: str = "GRID_DYE906R_D_bin_1.pineappl.lz4",
257+
):
258+
grid = download_objects(f"{gridname}")
259+
g = Grid.read(grid)
260+
assert len(g.channels()) == 15
261+
g.split_channels()
262+
assert len(g.channels()) == 170
263+
215264
def test_grid(
216265
self,
217266
download_objects,

0 commit comments

Comments
 (0)