Skip to content

Commit

Permalink
Refactor FK contractions
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Feb 22, 2024
1 parent a7dfc0d commit ad94132
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 126 deletions.
35 changes: 0 additions & 35 deletions n3fit/src/n3fit/backends/keras_backend/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,41 +256,6 @@ def concatenate(tensor_list, axis=-1, target_shape=None, name=None):
return concatenated_tensor


# Mathematical operations
def pdf_masked_convolution(raw_pdf, basis_mask):
"""Computes a masked convolution of two equal pdfs
And applies a basis_mask so that only the actually useful values
of the convolution are returned.
If training many PDFs at once, it will use as a backend `einsum`, which
is better suited for running on GPU (but slower on CPU).
Parameters
----------
pdf: tf.tensor
rank 4 (batchsize, replicas, xgrid, flavours)
basis_mask: tf.tensor
rank 2 tensor (flavours, flavours)
mask to apply to the pdf convolution
Return
------
pdf_x_pdf: tf.tensor
rank3 (replicas, len(mask_true), xgrid, xgrid)
"""
if raw_pdf.shape[1] == 1: # only one replica!
pdf = tf.squeeze(raw_pdf, axis=(0, 1))
luminosity = tensor_product(pdf, pdf, axes=0)
lumi_tmp = K.permute_dimensions(luminosity, (3, 1, 2, 0))
pdf_x_pdf = batchit(boolean_mask(lumi_tmp, basis_mask), 0)
else:
pdf = tf.squeeze(raw_pdf, axis=0) # remove the batchsize
luminosity = tf.einsum('rai,rbj->rjiba', pdf, pdf)
# (xgrid, flavour, xgrid, flavour)
pdf_x_pdf = boolean_mask(luminosity, basis_mask, axis=1)
return pdf_x_pdf


def einsum(equation, *args, **kwargs):
"""
Computes the tensor product using einsum
Expand Down
50 changes: 16 additions & 34 deletions n3fit/src/n3fit/layers/DIS.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ class DIS(Observable):
the incoming pdf.
The fktable is expected to be rank 3 (ndata, xgrid, flavours)
while the input pdf is rank 4 where the first dimension is the batch dimension
and the last dimension the number of replicas being fitted (1, replicas, xgrid, flavours)
while the input pdf is rank 4 of shape (batch_size, replicas, xgrid, flavours)
"""

def gen_mask(self, basis):
Expand All @@ -32,6 +31,11 @@ def gen_mask(self, basis):
----------
basis: list(int)
list of active flavours
Returns
-------
mask: tensor
rank 1 tensor (flavours)
"""
if basis is None:
self.basis = np.ones(self.nfl, dtype=bool)
Expand All @@ -41,39 +45,17 @@ def gen_mask(self, basis):
basis_mask[i] = True
return op.numpy_to_tensor(basis_mask, dtype=bool)

def call(self, pdf):
"""
This function perform the fktable \otimes pdf convolution.
First pass the input PDF through a mask to remove the unactive flavors,
then a tensor_product between the PDF and each fktable is performed
finally the defined operation is applied to all the results
Parameters
----------
pdf: backend tensor
rank 4 tensor (batch_size, replicas, xgrid, flavours)
def build(self, input_shape):
super().build(input_shape)
if self.num_replicas > 1:

Returns
-------
result: backend tensor
rank 3 tensor (batchsize, replicas, ndata)
"""
# DIS never needs splitting
if self.splitting is not None:
raise ValueError("DIS layer call with a dataset that needs more than one xgrid?")
def compute_observable(pdf, mask, fk):
return op.einsum('fF, nFx, brxf -> brn', mask, fk, pdf)

results = []
# Separate the two possible paths this layer can take
if self.many_masks:
for mask, fktable in zip(self.all_masks, self.fktables):
pdf_masked = op.boolean_mask(pdf, mask, axis=3)
res = op.tensor_product(pdf_masked, fktable, axes=[(2, 3), (2, 1)])
results.append(res)
else:
pdf_masked = op.boolean_mask(pdf, self.all_masks[0], axis=3)
for fktable in self.fktables:
res = op.tensor_product(pdf_masked, fktable, axes=[(2, 3), (2, 1)])
results.append(res)

return self.operation(results)
def compute_observable(pdf, mask, fk):
pdf_masked = op.tensor_product(pdf, mask, axes=1) # brxf, fF -> brxF
return op.tensor_product(pdf_masked, fk, axes=[(2, 3), (2, 1)]) # brxF, nFx -> brn

self.compute_observable = compute_observable
69 changes: 28 additions & 41 deletions n3fit/src/n3fit/layers/DY.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@ class DY(Observable):
"""

def gen_mask(self, basis):
"""
Receives a list of active flavours and generates a boolean mask tensor
Parameters
----------
basis: list(int)
list of active flavours
Returns
-------
mask: tensor
rank 2 tensor (flavours, flavours)
"""
if basis is None:
basis_mask = np.ones((self.nfl, self.nfl), dtype=bool)
else:
Expand All @@ -19,48 +32,22 @@ def gen_mask(self, basis):
basis_mask[i, j] = True
return op.numpy_to_tensor(basis_mask, dtype=bool)

def call(self, pdf_raw):
"""
This function perform the fktable \otimes pdf \otimes pdf convolution.
First uses the basis of active combinations to generate a luminosity tensor
with only some flavours active.
def build(self, input_shape):
super().build(input_shape)
if self.num_replicas > 1:

The concatenate function returns a rank-3 tensor (combination_index, xgrid, xgrid)
which can in turn be contracted with the rank-4 fktable.
def compute_observable(pdf, mask, fk):
return op.einsum('fgF, nFxy, brxf, bryg -> brn', mask, fk, pdf, pdf)

Parameters
----------
pdf_in: tensor
rank 4 tensor (batchsize, replicas, xgrid, flavours)
Returns
-------
results: tensor
rank 3 tensor (batchsize, replicas, ndata)
"""
# Hadronic observables might need splitting of the input pdf in the x dimension
# so we have 3 different paths for this layer

results = []
if self.many_masks:
if self.splitting:
splitted_pdf = op.split(pdf_raw, self.splitting, axis=2)
for mask, pdf, fk in zip(self.all_masks, splitted_pdf, self.fktables):
pdf_x_pdf = op.pdf_masked_convolution(pdf, mask)
res = op.tensor_product(fk, pdf_x_pdf, axes=[(1, 2, 3), (1, 2, 3)])
results.append(res)
else:
for mask, fk in zip(self.all_masks, self.fktables):
pdf_x_pdf = op.pdf_masked_convolution(pdf_raw, mask)
res = op.tensor_product(fk, pdf_x_pdf, axes=[(1, 2, 3), (1, 2, 3)])
results.append(res)
else:
pdf_x_pdf = op.pdf_masked_convolution(pdf_raw, self.all_masks[0])
for fk in self.fktables:
res = op.tensor_product(fk, pdf_x_pdf, axes=[(1, 2, 3), (1, 2, 3)])
results.append(res)

# the masked convolution removes the batch dimension
ret = op.transpose(self.operation(results))
return op.batchit(ret)
def compute_observable(pdf, mask, fk):
pdf = pdf[0][0] # yg
pdf_x_mask = op.tensor_product(pdf, mask, axes=[[1], [1]]) # yg, fgF -> yfF
pdf_x_pdf = op.tensor_product(pdf, pdf_x_mask, axes=[[1], [1]]) # xf, yfF -> xyF
observable = op.tensor_product(
fk, pdf_x_pdf, axes=[(1, 2, 3), (2, 0, 1)]
) # nFxy, xyF
return op.batchit(op.batchit(observable)) # brn

self.compute_observable = compute_observable
93 changes: 77 additions & 16 deletions n3fit/src/n3fit/layers/observable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from n3fit.backends import MetaLayer
from abc import ABC, abstractmethod

import numpy as np
from abc import abstractmethod, ABC

from n3fit.backends import MetaLayer
from n3fit.backends import operations as op


Expand All @@ -25,7 +27,6 @@ class Observable(MetaLayer, ABC):
fktables and pdfs
- call: this is what does the actual operation
Parameters
----------
fktable_data: list[validphys.coredata.FKTableData]
Expand All @@ -42,36 +43,96 @@ def __init__(self, fktable_data, fktable_arr, operation_name, nfl=14, **kwargs):
super(MetaLayer, self).__init__(**kwargs)

self.nfl = nfl
self.num_replicas = None # set in build
self.compute_observable = None # set in build

basis = []
all_bases = []
xgrids = []
self.fktables = []
fktables = []
for fkdata, fk in zip(fktable_data, fktable_arr):
xgrids.append(fkdata.xgrid.reshape(1, -1))
basis.append(fkdata.luminosity_mapping)
self.fktables.append(op.numpy_to_tensor(fk))
all_bases.append(fkdata.luminosity_mapping)
fktables.append(op.numpy_to_tensor(fk))
self.fktables = fktables

# check how many xgrids this dataset needs
if is_unique(xgrids):
self.splitting = None
else:
self.splitting = [i.shape[1] for i in xgrids]

# check how many basis this dataset needs
if is_unique(basis) and is_unique(xgrids):
self.all_masks = [self.gen_mask(basis[0])]
self.many_masks = False
self.operation = op.c_to_py_fun(operation_name)
self.output_dim = fktables[0].shape[0]

if is_unique(all_bases) and is_unique(xgrids):
self.all_masks = [self.gen_mask(all_bases[0])]
else:
self.many_masks = True
self.all_masks = [self.gen_mask(i) for i in basis]
self.all_masks = [self.gen_mask(basis) for basis in all_bases]

self.operation = op.c_to_py_fun(operation_name)
self.output_dim = self.fktables[0].shape[0]
self.masks = [self.compute_float_mask(bool_mask) for bool_mask in self.all_masks]

def build(self, input_shape):
self.num_replicas = input_shape[1]
super().build(input_shape)

def compute_output_shape(self, input_shape):
return (self.output_dim, None)

# Overridables
def compute_float_mask(self, bool_mask):
"""
Compute a float form of the given boolean mask, that can be contracted over the full flavor
axes to obtain a PDF of only the active flavors.
Parameters
----------
bool_mask: boolean tensor
mask of the active flavours
Returns
-------
masked_to_full: float tensor
float form of mask
"""
# Create a tensor with the shape (**bool_mask.shape, num_active_flavours)
masked_to_full = []
for idx in np.argwhere(bool_mask):
temp_matrix = np.zeros(bool_mask.shape)
temp_matrix[tuple(idx)] = 1
masked_to_full.append(temp_matrix)
masked_to_full = np.stack(masked_to_full, axis=-1)
masked_to_full = op.numpy_to_tensor(masked_to_full)

return masked_to_full

def call(self, pdf):
"""
This function perform the convolution with the fktable and one (DY) or two (DIS) pdfs.
Parameters
----------
pdf: backend tensor
rank 4 tensor (batch_size, replicas, xgrid, flavours)
Returns
-------
observables: backend tensor
rank 3 tensor (batchsize, replicas, ndata)
"""
if self.splitting:
pdfs = op.split(pdf, self.splitting, axis=2)
else:
pdfs = [pdf] * len(self.fktables)
# If we have only one mask (or PDF above), just repeat it to be used for all fktables
masks = self.masks * len(self.fktables)

observables = []
for pdf, mask, fk in zip(pdfs, masks, self.fktables):
observable = self.compute_observable(pdf, mask, fk)
observables.append(observable)

observables = self.operation(observables)
return observables

@abstractmethod
def gen_mask(self, basis):
pass

0 comments on commit ad94132

Please sign in to comment.