diff --git a/n3fit/src/n3fit/layers/DY.py b/n3fit/src/n3fit/layers/DY.py index 232434c11c..9fb6897083 100644 --- a/n3fit/src/n3fit/layers/DY.py +++ b/n3fit/src/n3fit/layers/DY.py @@ -33,7 +33,13 @@ def gen_mask(self, basis): return op.numpy_to_tensor(basis_mask, dtype=bool) def mask_fk(self, fk, mask): - return op.einsum('fgF, nFxy -> nxfyg', mask, fk) + if self.num_replicas > 1: + return op.einsum('fgF, nFxy -> nxfyg', mask, fk) + else: + mask = op.einsum('fgF -> Ffg', mask) + fk = op.einsum('nFxy -> nFyx', fk) + mask_and_fk = (mask, fk) + return mask_and_fk def build(self, input_shape): super().build(input_shape) @@ -45,10 +51,16 @@ def compute_observable(pdf, masked_fk): else: - def compute_observable(pdf, masked_fk): + def compute_observable(pdf, mask_and_fk): + # with 1 replica, it's more efficient to mask the PDF rather than the fk table + mask, unmasked_fk = mask_and_fk pdf = pdf[0][0] # yg - temp = op.tensor_product(masked_fk, pdf, axes=2) # nxfyg, yg -> nxf - observable = op.tensor_product(temp, pdf, axes=2) # nxf, xf -> n + + mask_x_pdf = op.tensor_product(mask, pdf, axes=[(2,), (1,)]) # Ffg, yg -> Ffy + pdf_x_pdf = op.tensor_product(mask_x_pdf, pdf, axes=[(1,), (1,)]) # Ffy, xf -> Fyx + # nFyx, Fyx -> n + observable = op.tensor_product(unmasked_fk, pdf_x_pdf, axes=[(1, 2, 3), (0, 1, 2)]) + return op.batchit(op.batchit(observable)) # brn self.compute_observable = compute_observable diff --git a/n3fit/src/n3fit/layers/observable.py b/n3fit/src/n3fit/layers/observable.py index b4163b67b0..bb1a3eec5d 100644 --- a/n3fit/src/n3fit/layers/observable.py +++ b/n3fit/src/n3fit/layers/observable.py @@ -53,7 +53,7 @@ def __init__(self, fktable_data, fktable_arr, operation_name, nfl=14, **kwargs): xgrids.append(fkdata.xgrid.reshape(1, -1)) all_bases.append(fkdata.luminosity_mapping) fktables.append(op.numpy_to_tensor(fk)) - fktables = fktables + self.fktables = fktables # check how many xgrids this dataset needs if is_unique(xgrids): @@ -62,21 +62,23 @@ def __init__(self, fktable_data, fktable_arr, operation_name, nfl=14, **kwargs): self.splitting = [i.shape[1] for i in xgrids] self.operation = op.c_to_py_fun(operation_name) - self.output_dim = fktables[0].shape[0] + self.output_dim = self.fktables[0].shape[0] if is_unique(all_bases) and is_unique(xgrids): self.all_masks = [self.gen_mask(all_bases[0])] else: self.all_masks = [self.gen_mask(basis) for basis in all_bases] - masks = [self.compute_float_mask(bool_mask) for bool_mask in self.all_masks] - # repeat the masks if necessary for fktables (if not, the extra copies - # will get lost in the zip) - masks = masks * len(fktables) - self.masked_fktables = [self.mask_fk(fk, mask) for fk, mask in zip(fktables, masks)] + self.masks = [self.compute_float_mask(bool_mask) for bool_mask in self.all_masks] def build(self, input_shape): self.num_replicas = input_shape[1] + + # repeat the masks if necessary for fktables (if not, the extra copies + # will get lost in the zip) + masks = self.masks * len(self.fktables) + self.masked_fktables = [self.mask_fk(fk, mask) for fk, mask in zip(self.fktables, masks)] + super().build(input_shape) def compute_output_shape(self, input_shape):