Refactoring model creation code #1734

Merged 74 commits on Jul 14, 2023

Commits
6b40d90
Name input layer xgrids
APJansen May 12, 2023
57f6552
Rename and simplify dense_me -> neural_network
APJansen May 12, 2023
d81814d
Rename i->i_replica and layer_seed->replica_seed
APJansen May 12, 2023
557b14c
Rename preprocessing -> prefactor
APJansen May 12, 2023
e087a38
Create apply_prefactor layer (can't name yet as it is created twice r…
APJansen May 12, 2023
ca187e4
Merge msr.py into layers/msr_normalization.py
APJansen May 12, 2023
97d7409
Turn msr_impose into class method
APJansen May 12, 2023
e0fb2f5
Use MSR_Normalization class attributes
APJansen May 12, 2023
c8e7e63
Create integration grid in MSR_Normalizations init
APJansen May 12, 2023
c5551f1
Renamings
APJansen May 12, 2023
6241614
Prepare call method in MSR_Normalization
APJansen May 12, 2023
4ecb1ba
Share more layers, add names
APJansen May 15, 2023
767977b
Create named layers in msr_normalization
APJansen May 15, 2023
d6ed37c
Add pdf_integrated step
APJansen May 15, 2023
396fc6f
Move computation of xdivided to init, move integration into call
APJansen May 15, 2023
472abed
remove tempcall
APJansen May 15, 2023
08c8876
Fix bug introduced after renaming preprocessing to prefactor
APJansen May 16, 2023
5337d18
Join neural network layers into their own NN_i model
APJansen May 16, 2023
ed85389
Clarify layer names
APJansen May 16, 2023
6bc651f
Revert all changes to MSR
APJansen May 17, 2023
96328a2
Rewrite msr into a model that takes as inputs the pdf, pdf on integra…
APJansen May 17, 2023
82cf95b
Set shape of integration grid to (None, 1) rather than (2000, 1) to d…
APJansen May 17, 2023
7b475ba
Revert renaming of Preprocessing to Prefactor (now use PreprocessingF…
APJansen May 17, 2023
7153e93
Clean up model creation code
APJansen May 17, 2023
dfb4eed
Reorganize input options, keeping effect the same
APJansen May 22, 2023
9dc1384
Make preprocessing layers dependent on which option is used (scaling,…
APJansen May 22, 2023
7c6d8ae
Fix msr with scaling
APJansen May 22, 2023
cbce181
Fix msr with scaling, shape of integration grid
APJansen May 22, 2023
3296aa4
Set shape of x=1 input to None so tensorflow displays proper shapes
APJansen May 22, 2023
853dd2e
Merge from master
APJansen May 23, 2023
fba7827
Add dummy ph_replica argument to MSR_Normalization layer
APJansen May 25, 2023
38e1984
Factor out scaler into own module scaler.py
APJansen May 23, 2023
d0065df
Do scaling of [0,1] to [-1, 1] inside scaler
APJansen May 23, 2023
2d2f8a0
Replace None layers with identity functions
APJansen May 23, 2023
08d26f2
Add option to save model plots (to current directory for now)
APJansen May 23, 2023
577d5d8
Add test for xDivide
APJansen May 25, 2023
f798c22
Update xDivide documentation to include v15 in the default settings, …
APJansen May 25, 2023
00a4974
Simplify xDivide layer
APJansen May 25, 2023
4c94100
Clarify xDivide documentation.
APJansen May 25, 2023
4f3d949
Turn off plotting of model for now
APJansen May 25, 2023
96e71b2
Merge branch 'master' into model_refactor
APJansen May 30, 2023
b279068
Manually change regression weights hdf5 file structure to fit new model
APJansen May 31, 2023
c1bb0ac
Manually change developing_weights.h5 file structure to fit new model
APJansen May 31, 2023
7d80a61
Remove inp option in pdfNN_layer_generator, instead enforcing adding …
APJansen Jun 1, 2023
38a1ccb
Incorporate Juan's comments
APJansen Jun 2, 2023
789a4c3
Incorporate Roy's comments
APJansen Jun 2, 2023
fe84657
Remove plotting of model from model_trainer
APJansen Jun 5, 2023
1f4537e
Remove unnecessary get_original layer from msr if scaler is not used
APJansen Jun 5, 2023
6353db3
Add model plot example script to documentation.
APJansen Jun 5, 2023
559d480
Merge master
APJansen Jun 5, 2023
0168360
Apply black to all changed files.
APJansen Jun 5, 2023
801366a
Remove sentence about plotting full model.
APJansen Jun 7, 2023
f99b77b
Remove now unused scatter_to_zero.
APJansen Jun 14, 2023
2b8c806
Merge branch 'master' into model_refactor
APJansen Jun 20, 2023
0bd5582
Pass replica number as additional argument to impose_msr model
APJansen Jun 20, 2023
8dbc98d
Replace indexing with tf.gather since ph_replica is now a tensor
APJansen Jun 23, 2023
a0f6e09
gather -> op_gather
APJansen Jun 23, 2023
20833ca
Revert to commit 2b8c806e
APJansen Jun 24, 2023
6a1ce02
Set photon integral as explicit argument to normalization model
APJansen Jun 24, 2023
6617da2
Remove duplicate imports
APJansen Jul 5, 2023
5aa72fc
Remove duplicate imports
APJansen Jul 5, 2023
a80a77e
Add link to plot_model documentation
APJansen Jul 5, 2023
f82b407
Move photon integrals down to final loop
APJansen Jul 5, 2023
80ef5d7
Simplify numpy_to_input, removing unused no_reshape and setting gridp…
APJansen Jul 5, 2023
ccfb7ce
Merge branch 'master' into model_refactor
APJansen Jul 6, 2023
0a4d684
Improve numpy_to_input documentation
APJansen Jul 7, 2023
0518c7c
Apply black and isort to all affected files
APJansen Jul 10, 2023
e92b706
Update n3fit/src/n3fit/backends/keras_backend/operations.py
APJansen Jul 10, 2023
5bcf169
Restore comments in scaler
APJansen Jul 10, 2023
1a9ce30
Add Optional type hint for arguments that can be None
APJansen Jul 12, 2023
9b6a675
Use numpy.typing
APJansen Jul 12, 2023
8566728
Raise error from original error
APJansen Jul 12, 2023
27f9357
Fix error when not imposing sumrule
APJansen Jul 12, 2023
694417d
Merge branch 'master' into model_refactor
APJansen Jul 13, 2023
Binary file added doc/sphinx/source/n3fit/figures/plot_pdf.png
50 changes: 50 additions & 0 deletions doc/sphinx/source/n3fit/methodology.rst
@@ -126,6 +126,56 @@ provide a faster convergence to the solution.
.. important::
Parameters like the number of layers, nodes, activation functions are hyper-parameters that require tuning.


To see the structure of the model, one can use Keras's ``plot_model`` function as illustrated in the script below.
See the `Keras documentation <https://www.tensorflow.org/api_docs/python/tf/keras/utils/plot_model>`_ for more details.

.. code-block:: python

from tensorflow.keras.utils import plot_model
from n3fit.model_gen import pdfNN_layer_generator
from validphys.api import API

fit_info = API.fit(fit="NNPDF40_nnlo_as_01180_1000").as_input()
basis_info = fit_info["fitting"]["basis"]

pdf_models = pdfNN_layer_generator(
nodes=[25, 20, 8],
activations=["tanh", "tanh", "linear"],
initializer_name="glorot_normal",
layer_type="dense",
flav_info=basis_info,
fitbasis="EVOL",
out=14,
seed=42,
dropout=0.0,
regularizer=None,
regularizer_args=None,
impose_sumrule="All",
scaler=None,
parallel_models=1,
)

pdf_model = pdf_models[0]
nn_model = pdf_model.get_layer("NN_0")
msr_model = pdf_model.get_layer("impose_msr")
models_to_plot = {
'plot_pdf': pdf_model,
'plot_nn': nn_model,
'plot_msr': msr_model
}

for name, model in models_to_plot.items():
plot_model(model, to_file=f"./{name}.png", show_shapes=True)


This will produce, for instance, the plot of the PDF model shown below, as well as plots of the
neural network model and the momentum sum rule model.

.. image::
figures/plot_pdf.png


.. _preprocessing:

Preprocessing
Binary file modified n3fit/runcards/examples/developing_weights.h5
Binary file not shown.
58 changes: 31 additions & 27 deletions n3fit/src/n3fit/backends/keras_backend/operations.py
@@ -23,20 +23,22 @@
equally operations are automatically converted to layers when used as such.
"""

from typing import Optional

import numpy as np
import numpy.typing as npt
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Lambda as keras_Lambda
from tensorflow.keras.layers import multiply as keras_multiply
from tensorflow.keras.layers import subtract as keras_subtract

from tensorflow.keras.layers import Input
from tensorflow.keras import backend as K

from validphys.convolution import OP


def evaluate(tensor):
""" Evaluate input tensor using the backend """
"""Evaluate input tensor using the backend"""
return K.eval(tensor)


@@ -107,36 +109,29 @@ def numpy_to_tensor(ival, **kwargs):

# f(x: tensor) -> y: tensor
def batchit(x, batch_dimension=0, **kwarg):
""" Add a batch dimension to tensor x """
"""Add a batch dimension to tensor x"""
return tf.expand_dims(x, batch_dimension, **kwarg)


# layer generation
def numpy_to_input(numpy_array, no_reshape=False, name=None):
def numpy_to_input(numpy_array: npt.NDArray, name: Optional[str] = None):
"""
Takes a numpy array and generates a Input layer.
By default it adds a batch dimension (of size 1) so that the shape of the layer
is that of the array
Takes a numpy array and generates an Input layer with the same shape,
but with a batch dimension (of size 1) added.

Parameters
----------
numpy_array: np.ndarray
no_reshape: bool
if true, don't add batch dimension, take the first dimension of the array as the batch
name: bool
name: str
name to give to the layer
"""
if no_reshape:
batched_array = numpy_array
batch_size = numpy_array.shape[0]
shape = numpy_array.shape[1:]
else:
batched_array = np.expand_dims(numpy_array, 0)
batch_size = 1
shape = numpy_array.shape
input_layer = Input(batch_size=batch_size, shape=shape, name=name)
batched_array = np.expand_dims(numpy_array, 0)
shape = list(numpy_array.shape)
# set the number of gridpoints to None, otherwise shapes don't show in model.summary
shape[0] = None

input_layer = Input(batch_size=1, shape=shape, name=name)
input_layer.tensor_content = batched_array
input_layer.original_shape = no_reshape
return input_layer


@@ -174,7 +169,7 @@ def op_gather_keep_dims(tensor, indices, axis=0, **kwargs):
both eager and non-eager tensors
"""
if indices == -1:
indices = tensor.shape[axis]-1
indices = tensor.shape[axis] - 1

def tmp(x):
y = tf.gather(x, indices, axis=axis, **kwargs)
@@ -189,6 +184,7 @@ def tmp(x):
# f(x: tensor[s]) -> y: tensor
#


# Generation operations
# generate tensors of given shape/content
@tf.function
@@ -216,7 +212,7 @@ def many_replication(grid, replications, axis=0, **kwargs):
# modify properties of the tensor like the shape or elements it has
@tf.function
def flatten(x):
""" Flatten tensor x """
"""Flatten tensor x"""
return tf.reshape(x, (-1,))


@@ -240,7 +236,7 @@ def transpose(tensor, **kwargs):


def stack(tensor_list, axis=0, **kwargs):
""" Stack a list of tensors
"""Stack a list of tensors
see full `docs <https://www.tensorflow.org/api_docs/python/tf/stack>`_
"""
return tf.stack(tensor_list, axis=axis, **kwargs)
@@ -280,8 +276,8 @@ def pdf_masked_convolution(raw_pdf, basis_mask):
pdf_x_pdf: tf.tensor
rank3 (len(mask_true), xgrid, xgrid, replicas)
"""
if raw_pdf.shape[-1] == 1: # only one replica!
pdf = tf.squeeze(raw_pdf, axis=(0,-1))
if raw_pdf.shape[-1] == 1: # only one replica!
pdf = tf.squeeze(raw_pdf, axis=(0, -1))
luminosity = tensor_product(pdf, pdf, axes=0)
lumi_tmp = K.permute_dimensions(luminosity, (3, 1, 2, 0))
pdf_x_pdf = batchit(boolean_mask(lumi_tmp, basis_mask), -1)
@@ -309,6 +305,14 @@ def tensor_product(*args, **kwargs):
return tf.tensordot(*args, **kwargs)


@tf.function
def pow(tensor, power):
"""
Computes the power of the tensor
"""
return tf.pow(tensor, power)


@tf.function(experimental_relax_shapes=True)
def op_log(o_tensor, **kwargs):
"""
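
For reference, a minimal usage sketch of the simplified ``numpy_to_input`` (not part of the diff; the grid values and layer name below are invented for illustration):

import numpy as np

from n3fit.backends.keras_backend.operations import numpy_to_input

# invented x-grid: 50 points, one feature per point
xgrid = np.linspace(1e-9, 1.0, 50).reshape(-1, 1)

# the refactored helper always adds a batch dimension of size 1 and sets the
# number of gridpoints to None so that model.summary() displays proper shapes
x_input = numpy_to_input(xgrid, name="pdf_input")
print(x_input.shape)                  # (1, None, 1)
print(x_input.tensor_content.shape)   # (1, 50, 1), the batched original array
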
26 changes: 16 additions & 10 deletions n3fit/src/n3fit/layers/msr_normalization.py
@@ -16,8 +16,7 @@ class MSR_Normalization(MetaLayer):
_msr_enabled = False
_vsr_enabled = False

def __init__(self, output_dim=14, mode="ALL", photons_contribution=None, **kwargs):
self._photons = photons_contribution
def __init__(self, output_dim=14, mode="ALL", **kwargs):
if mode == True or mode.upper() == "ALL":
self._msr_enabled = True
self._vsr_enabled = True
@@ -38,9 +37,9 @@ def __init__(self, output_dim=14, mode="ALL", photons_contribution=None, **kwarg
op.scatter_to_one, op_kwargs={"indices": idx, "output_dim": output_dim}
)

super().__init__(**kwargs, name="normalizer")
super().__init__(**kwargs)

def call(self, pdf_integrated, ph_replica):
def call(self, pdf_integrated, photon_integral):
"""Imposes the valence and momentum sum rules:
A_g = (1-sigma-photon)/g
A_v = A_v24 = A_v35 = 3/V
@@ -49,17 +48,24 @@ def call(self, pdf_integrated, ph_replica):
A_v15 = 3/V_15

Note that both the input and the output are in the 14-flavours fk-basis

Parameters
----------
pdf_integrated: (Tensor(1,None,14))
the integrated PDF
photon_integral: (Tensor(1)):
the integrated photon, not included in PDF

Returns
-------
normalization_factor: Tensor(14)
The normalization factors per flavour.
"""
y = op.flatten(pdf_integrated)
norm_constants = []

if self._photons:
photon_integral = self._photons[ph_replica]
else:
photon_integral = 0.0

if self._msr_enabled:
n_ag = [(1.0 - y[GLUON_IDX[0][0] - 1] - photon_integral) / y[GLUON_IDX[0][0]]] * len(
n_ag = [(1.0 - y[GLUON_IDX[0][0] - 1] - photon_integral[0]) / y[GLUON_IDX[0][0]]] * len(
GLUON_IDX
)
norm_constants += n_ag
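
A rough sketch of the new call signature (the photon integral is now an explicit tensor argument rather than a per-replica lookup). This assumes ``MSR_Normalization`` is exported from ``n3fit.layers`` and uses dummy values:

import numpy as np

from n3fit.backends import operations as op
from n3fit.layers import MSR_Normalization  # assumed import path

msr_layer = MSR_Normalization(mode="ALL")
# dummy integrated PDF in the 14-flavour fk-basis, shape (1, 1, 14)
pdf_integrated = op.numpy_to_tensor(np.full((1, 1, 14), 0.5))
# explicit photon integral, a length-1 tensor; zero when no photon is fitted
photon_integral = op.numpy_to_tensor([0.0])
normalization = msr_layer(pdf_integrated, photon_integral)  # one factor per flavour
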
20 changes: 8 additions & 12 deletions n3fit/src/n3fit/layers/preprocessing.py
@@ -1,11 +1,10 @@
from n3fit.backends import MetaLayer
from n3fit.backends import constraints
from n3fit.backends import MetaLayer, constraints
from n3fit.backends import operations as op


class Preprocessing(MetaLayer):
"""
Applies preprocessing to the PDF.
Computes preprocessing factor for the PDF.

This layer generates a factor (1-x)^beta*x^(1-alpha) where both beta and alpha
are model parameters that can be trained. If feature scaling is used, the preprocessing
@@ -21,29 +20,26 @@ class Preprocessing(MetaLayer):
Parameters
----------
flav_info: list
list of dicts containing the information about the fitting of the preprocessing
list of dicts containing the information about the fitting of the preprocessing factor
This corresponds to the `fitting::basis` parameter in the nnpdf runcard.
The dicts can contain the following fields:
`smallx`: range of alpha
`largex`: range of beta
`trainable`: whether these alpha-beta should be trained during the fit
(defaults to true)
large_x: bool
Whether large x preprocessing should be active
Whether large x preprocessing factor should be active
seed: int
seed for the initializer of the random alpha and beta values
"""

def __init__(
self,
flav_info=None,
seed=0,
initializer="random_uniform",
large_x=True,
**kwargs,
self, flav_info=None, seed=0, initializer="random_uniform", large_x=True, **kwargs,
):
if flav_info is None:
raise ValueError("Trying to instantiate a preprocessing with no basis information")
raise ValueError(
"Trying to instantiate a preprocessing factor with no basis information"
)
self.flav_info = flav_info
self.seed = seed
self.output_dim = len(flav_info)
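
The form of the preprocessing factor itself is unchanged by this PR; per flavour it is x^(1-alpha) * (1-x)^beta. A plain numpy sketch of that formula, with invented exponent values:

import numpy as np

# invented exponents for a single flavour; in the fit these are trainable weights
alpha, beta = 1.2, 3.0
x = np.array([1e-3, 1e-2, 0.1, 0.5, 0.9])

# the factor that multiplies the neural-network output for this flavour
prefactor = x ** (1 - alpha) * (1 - x) ** beta
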
37 changes: 20 additions & 17 deletions n3fit/src/n3fit/layers/x_operations.py
@@ -1,54 +1,57 @@
"""
This module contains layers acting on the x-grid input of the NN

The three operations included are:
The two operations included are:
- ``xDivide``
- ``xIntegrator``

The names are self-describing. The only subtlety is that they do not act equally
for all flavours. The choice of flavours on which to act in a different way is given
as an input argument.
"""
from typing import List, Optional

from n3fit.backends import MetaLayer
from n3fit.backends import operations as op

BASIS_SIZE=14
BASIS_SIZE = 14


class xDivide(MetaLayer):
"""
Divide some PDFs by x
Create tensor of either 1/x or ones depending on the flavour,
to be used to divide some PDFs by x by multiplying with the result.

By default it utilizes the 14-flavour FK basis and divides [v, v3, v8]
which corresponds to indices (3,4,5) from
By default it utilizes the 14-flavour FK basis and divides [v, v3, v8, v15]
which corresponds to indices (3, 4, 5, 6) from
(photon, sigma, g, v, v3, v8, v15, v24, v35, t3, t8, t15, t24, t35)

Parameters:
-----------
output_dim: int
dimension of the pdf
div_list: list
list of indices to be divided by X (by default [3,4,5]; [v, v3, v8]
list of indices to be divided by X (by default [3, 4, 5, 6]; [v, v3, v8, v15]
"""

def __init__(self, output_dim=BASIS_SIZE, div_list=None, **kwargs):
def __init__(
self, output_dim: int = BASIS_SIZE, div_list: Optional[List[int]] = None, **kwargs
):
if div_list is None:
div_list = [3, 4, 5, 6]
self.output_dim = output_dim
self.div_list = div_list
super().__init__(**kwargs)

self.powers = [-1 if i in div_list else 0 for i in range(output_dim)]

def call(self, x):
out_array = []
one = op.tensor_ones_like(x)
for i in range(self.output_dim):
if i in self.div_list:
res = one / x
else:
res = one
out_array.append(res)
out_tensor = op.concatenate(out_array)
return out_tensor
return op.pow(x, self.powers)

def get_config(self):
config = super().get_config()
config.update({"output_dim": self.output_dim, "div_list": self.div_list})
return config


class xIntegrator(MetaLayer):
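
The per-flavour branching in ``xDivide.call`` is replaced by a single elementwise power with exponents of -1 or 0; an equivalent pure-numpy sketch of the trick:

import numpy as np

div_list = [3, 4, 5, 6]  # v, v3, v8, v15 in the 14-flavour FK basis
powers = [-1 if i in div_list else 0 for i in range(14)]

x = np.array([[0.1], [0.5]])      # (n_x, 1) grid points
result = x ** np.array(powers)    # broadcasts to (n_x, 14): 1/x where the power is -1, 1 elsewhere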