From ad94132498ea8dc2fb2f114927fec6d143d1d9dd Mon Sep 17 00:00:00 2001 From: Aron Date: Thu, 8 Feb 2024 16:32:24 +0100 Subject: [PATCH] Refactor FK contractions --- .../backends/keras_backend/operations.py | 35 ------- n3fit/src/n3fit/layers/DIS.py | 50 ++++------ n3fit/src/n3fit/layers/DY.py | 69 ++++++-------- n3fit/src/n3fit/layers/observable.py | 93 +++++++++++++++---- 4 files changed, 121 insertions(+), 126 deletions(-) diff --git a/n3fit/src/n3fit/backends/keras_backend/operations.py b/n3fit/src/n3fit/backends/keras_backend/operations.py index 878b2ed11d..6c566df620 100644 --- a/n3fit/src/n3fit/backends/keras_backend/operations.py +++ b/n3fit/src/n3fit/backends/keras_backend/operations.py @@ -256,41 +256,6 @@ def concatenate(tensor_list, axis=-1, target_shape=None, name=None): return concatenated_tensor -# Mathematical operations -def pdf_masked_convolution(raw_pdf, basis_mask): - """Computes a masked convolution of two equal pdfs - And applies a basis_mask so that only the actually useful values - of the convolution are returned. - - If training many PDFs at once, it will use as a backend `einsum`, which - is better suited for running on GPU (but slower on CPU). - - Parameters - ---------- - pdf: tf.tensor - rank 4 (batchsize, replicas, xgrid, flavours) - basis_mask: tf.tensor - rank 2 tensor (flavours, flavours) - mask to apply to the pdf convolution - - Return - ------ - pdf_x_pdf: tf.tensor - rank3 (replicas, len(mask_true), xgrid, xgrid) - """ - if raw_pdf.shape[1] == 1: # only one replica! 
- pdf = tf.squeeze(raw_pdf, axis=(0, 1)) - luminosity = tensor_product(pdf, pdf, axes=0) - lumi_tmp = K.permute_dimensions(luminosity, (3, 1, 2, 0)) - pdf_x_pdf = batchit(boolean_mask(lumi_tmp, basis_mask), 0) - else: - pdf = tf.squeeze(raw_pdf, axis=0) # remove the batchsize - luminosity = tf.einsum('rai,rbj->rjiba', pdf, pdf) - # (xgrid, flavour, xgrid, flavour) - pdf_x_pdf = boolean_mask(luminosity, basis_mask, axis=1) - return pdf_x_pdf - - def einsum(equation, *args, **kwargs): """ Computes the tensor product using einsum diff --git a/n3fit/src/n3fit/layers/DIS.py b/n3fit/src/n3fit/layers/DIS.py index 0aeb2d00e6..ff4a3708ed 100644 --- a/n3fit/src/n3fit/layers/DIS.py +++ b/n3fit/src/n3fit/layers/DIS.py @@ -20,8 +20,7 @@ class DIS(Observable): the incoming pdf. The fktable is expected to be rank 3 (ndata, xgrid, flavours) - while the input pdf is rank 4 where the first dimension is the batch dimension - and the last dimension the number of replicas being fitted (1, replicas, xgrid, flavours) + while the input pdf is rank 4 of shape (batch_size, replicas, xgrid, flavours) """ def gen_mask(self, basis): @@ -32,6 +31,11 @@ def gen_mask(self, basis): ---------- basis: list(int) list of active flavours + + Returns + ------- + mask: tensor + rank 1 tensor (flavours) """ if basis is None: self.basis = np.ones(self.nfl, dtype=bool) @@ -41,39 +45,17 @@ def gen_mask(self, basis): basis_mask[i] = True return op.numpy_to_tensor(basis_mask, dtype=bool) - def call(self, pdf): - """ - This function perform the fktable \otimes pdf convolution. 
- - First pass the input PDF through a mask to remove the unactive flavors, - then a tensor_product between the PDF and each fktable is performed - finally the defined operation is applied to all the results - - Parameters - ---------- - pdf: backend tensor - rank 4 tensor (batch_size, replicas, xgrid, flavours) + def build(self, input_shape): + super().build(input_shape) + if self.num_replicas > 1: - Returns - ------- - result: backend tensor - rank 3 tensor (batchsize, replicas, ndata) - """ - # DIS never needs splitting - if self.splitting is not None: - raise ValueError("DIS layer call with a dataset that needs more than one xgrid?") + def compute_observable(pdf, mask, fk): + return op.einsum('fF, nFx, brxf -> brn', mask, fk, pdf) - results = [] - # Separate the two possible paths this layer can take - if self.many_masks: - for mask, fktable in zip(self.all_masks, self.fktables): - pdf_masked = op.boolean_mask(pdf, mask, axis=3) - res = op.tensor_product(pdf_masked, fktable, axes=[(2, 3), (2, 1)]) - results.append(res) else: - pdf_masked = op.boolean_mask(pdf, self.all_masks[0], axis=3) - for fktable in self.fktables: - res = op.tensor_product(pdf_masked, fktable, axes=[(2, 3), (2, 1)]) - results.append(res) - return self.operation(results) + def compute_observable(pdf, mask, fk): + pdf_masked = op.tensor_product(pdf, mask, axes=1) # brxf, fF -> brxF + return op.tensor_product(pdf_masked, fk, axes=[(2, 3), (2, 1)]) # brxF, nFx -> brn + + self.compute_observable = compute_observable diff --git a/n3fit/src/n3fit/layers/DY.py b/n3fit/src/n3fit/layers/DY.py index 3b42a394bd..9a152633cf 100644 --- a/n3fit/src/n3fit/layers/DY.py +++ b/n3fit/src/n3fit/layers/DY.py @@ -11,6 +11,19 @@ class DY(Observable): """ def gen_mask(self, basis): + """ + Receives a list of active flavours and generates a boolean mask tensor + + Parameters + ---------- + basis: list(int) + list of active flavours + + Returns + ------- + mask: tensor + rank 2 tensor (flavours, flavours) + """ if 
basis is None: basis_mask = np.ones((self.nfl, self.nfl), dtype=bool) else: @@ -19,48 +32,22 @@ def gen_mask(self, basis): basis_mask[i, j] = True return op.numpy_to_tensor(basis_mask, dtype=bool) - def call(self, pdf_raw): - """ - This function perform the fktable \otimes pdf \otimes pdf convolution. - - First uses the basis of active combinations to generate a luminosity tensor - with only some flavours active. + def build(self, input_shape): + super().build(input_shape) + if self.num_replicas > 1: - The concatenate function returns a rank-3 tensor (combination_index, xgrid, xgrid) - which can in turn be contracted with the rank-4 fktable. + def compute_observable(pdf, mask, fk): + return op.einsum('fgF, nFxy, brxf, bryg -> brn', mask, fk, pdf, pdf) - Parameters - ---------- - pdf_in: tensor - rank 4 tensor (batchsize, replicas, xgrid, flavours) - - Returns - ------- - results: tensor - rank 3 tensor (batchsize, replicas, ndata) - """ - # Hadronic observables might need splitting of the input pdf in the x dimension - # so we have 3 different paths for this layer - - results = [] - if self.many_masks: - if self.splitting: - splitted_pdf = op.split(pdf_raw, self.splitting, axis=2) - for mask, pdf, fk in zip(self.all_masks, splitted_pdf, self.fktables): - pdf_x_pdf = op.pdf_masked_convolution(pdf, mask) - res = op.tensor_product(fk, pdf_x_pdf, axes=[(1, 2, 3), (1, 2, 3)]) - results.append(res) - else: - for mask, fk in zip(self.all_masks, self.fktables): - pdf_x_pdf = op.pdf_masked_convolution(pdf_raw, mask) - res = op.tensor_product(fk, pdf_x_pdf, axes=[(1, 2, 3), (1, 2, 3)]) - results.append(res) else: - pdf_x_pdf = op.pdf_masked_convolution(pdf_raw, self.all_masks[0]) - for fk in self.fktables: - res = op.tensor_product(fk, pdf_x_pdf, axes=[(1, 2, 3), (1, 2, 3)]) - results.append(res) - # the masked convolution removes the batch dimension - ret = op.transpose(self.operation(results)) - return op.batchit(ret) + def compute_observable(pdf, mask, fk): + pdf = 
pdf[0][0] # yg + pdf_x_mask = op.tensor_product(pdf, mask, axes=[[1], [1]]) # yg, fgF -> yfF + pdf_x_pdf = op.tensor_product(pdf, pdf_x_mask, axes=[[1], [1]]) # xf, yfF -> xyF + observable = op.tensor_product( + fk, pdf_x_pdf, axes=[(1, 2, 3), (2, 0, 1)] + ) # nFxy, xyF + return op.batchit(op.batchit(observable)) # brn + + self.compute_observable = compute_observable diff --git a/n3fit/src/n3fit/layers/observable.py b/n3fit/src/n3fit/layers/observable.py index 739d7c775b..0f735b73f6 100644 --- a/n3fit/src/n3fit/layers/observable.py +++ b/n3fit/src/n3fit/layers/observable.py @@ -1,6 +1,8 @@ -from n3fit.backends import MetaLayer +from abc import ABC, abstractmethod + import numpy as np -from abc import abstractmethod, ABC + +from n3fit.backends import MetaLayer from n3fit.backends import operations as op @@ -25,7 +27,6 @@ class Observable(MetaLayer, ABC): fktables and pdfs - call: this is what does the actual operation - Parameters ---------- fktable_data: list[validphys.coredata.FKTableData] @@ -42,14 +43,17 @@ def __init__(self, fktable_data, fktable_arr, operation_name, nfl=14, **kwargs): super(MetaLayer, self).__init__(**kwargs) self.nfl = nfl + self.num_replicas = None # set in build + self.compute_observable = None # set in build - basis = [] + all_bases = [] xgrids = [] - self.fktables = [] + fktables = [] for fkdata, fk in zip(fktable_data, fktable_arr): xgrids.append(fkdata.xgrid.reshape(1, -1)) - basis.append(fkdata.luminosity_mapping) - self.fktables.append(op.numpy_to_tensor(fk)) + all_bases.append(fkdata.luminosity_mapping) + fktables.append(op.numpy_to_tensor(fk)) + self.fktables = fktables # check how many xgrids this dataset needs if is_unique(xgrids): @@ -57,21 +61,78 @@ def __init__(self, fktable_data, fktable_arr, operation_name, nfl=14, **kwargs): else: self.splitting = [i.shape[1] for i in xgrids] - # check how many basis this dataset needs - if is_unique(basis) and is_unique(xgrids): - self.all_masks = [self.gen_mask(basis[0])] - self.many_masks 
= False + self.operation = op.c_to_py_fun(operation_name) + self.output_dim = fktables[0].shape[0] + + if is_unique(all_bases) and is_unique(xgrids): + self.all_masks = [self.gen_mask(all_bases[0])] else: - self.many_masks = True - self.all_masks = [self.gen_mask(i) for i in basis] + self.all_masks = [self.gen_mask(basis) for basis in all_bases] - self.operation = op.c_to_py_fun(operation_name) - self.output_dim = self.fktables[0].shape[0] + self.masks = [self.compute_float_mask(bool_mask) for bool_mask in self.all_masks] + + def build(self, input_shape): + self.num_replicas = input_shape[1] + super().build(input_shape) def compute_output_shape(self, input_shape): return (self.output_dim, None) - # Overridables + def compute_float_mask(self, bool_mask): + """ + Compute a float form of the given boolean mask, that can be contracted over the full flavor + axes to obtain a PDF of only the active flavors. + + Parameters + ---------- + bool_mask: boolean tensor + mask of the active flavours + + Returns + ------- + masked_to_full: float tensor + float form of mask + """ + # Create a tensor with the shape (**bool_mask.shape, num_active_flavours) + masked_to_full = [] + for idx in np.argwhere(bool_mask): + temp_matrix = np.zeros(bool_mask.shape) + temp_matrix[tuple(idx)] = 1 + masked_to_full.append(temp_matrix) + masked_to_full = np.stack(masked_to_full, axis=-1) + masked_to_full = op.numpy_to_tensor(masked_to_full) + + return masked_to_full + + def call(self, pdf): + """ + This function performs the convolution with the fktable and one (DIS) or two (DY) pdfs. 
+ + Parameters + ---------- + pdf: backend tensor + rank 4 tensor (batch_size, replicas, xgrid, flavours) + + Returns + ------- + observables: backend tensor + rank 3 tensor (batchsize, replicas, ndata) + """ + if self.splitting: + pdfs = op.split(pdf, self.splitting, axis=2) + else: + pdfs = [pdf] * len(self.fktables) + # If we have only one mask (or PDF above), just repeat it to be used for all fktables + masks = self.masks * len(self.fktables) + + observables = [] + for pdf, mask, fk in zip(pdfs, masks, self.fktables): + observable = self.compute_observable(pdf, mask, fk) + observables.append(observable) + + observables = self.operation(observables) + return observables + @abstractmethod def gen_mask(self, basis): pass