Skip to content

Commit

Permalink
Refactor 1 replica DY
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Feb 26, 2024
1 parent 1a751c2 commit 16551f6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
20 changes: 16 additions & 4 deletions n3fit/src/n3fit/layers/DY.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ def gen_mask(self, basis):
return op.numpy_to_tensor(basis_mask, dtype=bool)

def mask_fk(self, fk, mask):
return op.einsum('fgF, nFxy -> nxfyg', mask, fk)
if self.num_replicas > 1:
return op.einsum('fgF, nFxy -> nxfyg', mask, fk)
else:
mask = op.einsum('fgF -> Ffg', mask)
fk = op.einsum('nFxy -> nFyx', fk)
mask_and_fk = (mask, fk)
return mask_and_fk

def build(self, input_shape):
super().build(input_shape)
Expand All @@ -45,10 +51,16 @@ def compute_observable(pdf, masked_fk):

else:

def compute_observable(pdf, masked_fk):
def compute_observable(pdf, mask_and_fk):
# with 1 replica, it's more efficient to mask the PDF rather than the fk table
mask, unmasked_fk = mask_and_fk
pdf = pdf[0][0] # yg
temp = op.tensor_product(masked_fk, pdf, axes=2) # nxfyg, yg -> nxf
observable = op.tensor_product(temp, pdf, axes=2) # nxf, xf -> n

mask_x_pdf = op.tensor_product(mask, pdf, axes=[(2,), (1,)]) # Ffg, yg -> Ffy
pdf_x_pdf = op.tensor_product(mask_x_pdf, pdf, axes=[(1,), (1,)]) # Ffy, xf -> Fyx
# nFyx, Fyx -> n
observable = op.tensor_product(unmasked_fk, pdf_x_pdf, axes=[(1, 2, 3), (0, 1, 2)])

return op.batchit(op.batchit(observable)) # brn

self.compute_observable = compute_observable
16 changes: 9 additions & 7 deletions n3fit/src/n3fit/layers/observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, fktable_data, fktable_arr, operation_name, nfl=14, **kwargs):
xgrids.append(fkdata.xgrid.reshape(1, -1))
all_bases.append(fkdata.luminosity_mapping)
fktables.append(op.numpy_to_tensor(fk))
fktables = fktables
self.fktables = fktables

# check how many xgrids this dataset needs
if is_unique(xgrids):
Expand All @@ -62,21 +62,23 @@ def __init__(self, fktable_data, fktable_arr, operation_name, nfl=14, **kwargs):
self.splitting = [i.shape[1] for i in xgrids]

self.operation = op.c_to_py_fun(operation_name)
self.output_dim = fktables[0].shape[0]
self.output_dim = self.fktables[0].shape[0]

if is_unique(all_bases) and is_unique(xgrids):
self.all_masks = [self.gen_mask(all_bases[0])]
else:
self.all_masks = [self.gen_mask(basis) for basis in all_bases]

masks = [self.compute_float_mask(bool_mask) for bool_mask in self.all_masks]
# repeat the masks if necessary for fktables (if not, the extra copies
# will get lost in the zip)
masks = masks * len(fktables)
self.masked_fktables = [self.mask_fk(fk, mask) for fk, mask in zip(fktables, masks)]
self.masks = [self.compute_float_mask(bool_mask) for bool_mask in self.all_masks]

def build(self, input_shape):
self.num_replicas = input_shape[1]

# repeat the masks if necessary for fktables (if not, the extra copies
# will get lost in the zip)
masks = self.masks * len(self.fktables)
self.masked_fktables = [self.mask_fk(fk, mask) for fk, mask in zip(self.fktables, masks)]

super().build(input_shape)

def compute_output_shape(self, input_shape):
Expand Down

0 comments on commit 16551f6

Please sign in to comment.