Skip to content

Commit

Permalink
Make training work
Browse files Browse the repository at this point in the history
  • Loading branch information
Radonirinaunimi committed May 14, 2023
1 parent af6f859 commit 88793ed
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
5 changes: 4 additions & 1 deletion n3fit/src/n3fit/backends/keras_backend/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,5 +400,8 @@ def add_target_dependence(xgrid, a_value):
# in the input. Now, as opposed to the free-proton fit the input
# (still called `xgrid`) is a two-dimensional array.
# TODO: Find a better (again) to propagate the A-dependence
if xgrid.ndim == 1:
xgrid = np.expand_dims(xgrid, axis=-1)

a_value_expand = np.full(xgrid.shape, a_value)
return np.stack((xgrid, a_value_expand), axis=-1)
return np.concatenate((xgrid, a_value_expand), axis=-1)
3 changes: 3 additions & 0 deletions n3fit/src/n3fit/io/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import validphys
import n3fit
from n3fit import vpinterface
from n3fit.backends import operations as op

XGRID = np.array(
[
Expand Down Expand Up @@ -326,6 +327,8 @@ def jsonfit(replica_status, pdf_object, tr_chi2, vl_chi2, true_chi2, stop_epoch,
all_info["erf_vl"] = vl_chi2
all_info["chi2"] = true_chi2
all_info["pos_state"] = replica_status.positivity_status
# all_info["arc_lengths"] = None
# all_info["integrability"] = None
all_info["arc_lengths"] = vpinterface.compute_arclength(pdf_object).tolist()
all_info["integrability"] = vpinterface.integrability_numbers(pdf_object).tolist()
all_info["timing"] = timing
Expand Down
7 changes: 5 additions & 2 deletions n3fit/src/n3fit/vpinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from collections.abc import Iterable
import numpy as np
import numpy.linalg as la
from n3fit.backends import operations as op
from validphys.core import PDF, MCStats
from validphys.pdfbases import ALL_FLAVOURS, check_basis
from validphys.lhapdfset import LHAPDFSet
Expand Down Expand Up @@ -102,15 +103,17 @@ def __call__(self, xarr, flavours=None, replica=None):
# Ensures that the input has the shape the model expect, no matter the input
# as the scaling is done by the model itself
mod_xgrid = xarr.reshape(1, -1, 1)
# TODO: Make different `A` predictions depending on some inputs.
nn_input = op.add_target_dependence(mod_xgrid, a_value=1)

if replica is None or replica == 0:
# We need generate output values for all replicas
result = np.concatenate([m.predict({"pdf_input": mod_xgrid}) for m in self._lhapdf_set], axis=0)
result = np.concatenate([m.predict({"pdf_input": nn_input}) for m in self._lhapdf_set], axis=0)
if replica == 0:
# We want _only_ the central value
result = np.mean(result, axis=0, keepdims=True)
else:
result = self._lhapdf_set[replica - 1].predict({"pdf_input": mod_xgrid})
result = self._lhapdf_set[replica - 1].predict({"pdf_input": nn_input})

if flavours != "n3fit":
# Ensure that the result has its flavour in the basis-defined order
Expand Down

0 comments on commit 88793ed

Please sign in to comment.