Skip to content

Commit

Permalink
Pass replica number as additional argument to impose_msr model
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Jun 20, 2023
1 parent 2b8c806 commit 0d6516f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
4 changes: 2 additions & 2 deletions n3fit/src/n3fit/model_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,8 +676,8 @@ def compute_unnormalized_pdf(x, neural_network, compute_preprocessing_factor):
pdf_unnormalized = compute_unnormalized_pdf(pdf_input, nn, preprocessing_factor)
pdf_integration_grid = compute_unnormalized_pdf(integrator_input, nn, preprocessing_factor)

pdf_normalized = sumrule_layer([pdf_unnormalized, pdf_integration_grid, integrator_input])

# i_replica argument is necessary to select the right photon integral
pdf_normalized = sumrule_layer([pdf_unnormalized, pdf_integration_grid, integrator_input, i_replica])
if photons:
pdf_normalized = layer_photon(pdf_normalized, i_replica)

Expand Down
7 changes: 5 additions & 2 deletions n3fit/src/n3fit/msr.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,18 @@ def generate_msr_model_and_grid(
# 4. Integrate the pdf
pdf_integrated = xIntegrator(weights_array, input_shape=(nx,))(pdf_integrand)

# 5. If a photon is given, compute the photon component of the MSR
# 5. If a photon is given, retrieve the photon component of the MSR...
photons_c = None
if photons:
photons_c = photons.integral
# ... and add the replica number as an input, as the above contains the integrals for all replicas
replica_number = Input(shape=(1,), batch_size=1, name="replica_number", dtype="int32")

# 6. Compute the normalization factor
# For now set the photon component to None
normalization_factor = MSR_Normalization(
output_dim, mode, name="msr_weights", photons_contribution=photons_c
)(pdf_integrated, ph_replica=None)
)(pdf_integrated, ph_replica=replica_number)

# 7. Apply the normalization factor to the pdf
pdf_normalized = Lambda(lambda pdf_norm: pdf_norm[0] * pdf_norm[1], name="pdf_normalized")(
Expand All @@ -105,6 +107,7 @@ def generate_msr_model_and_grid(
"pdf_x": pdf_x,
"pdf_xgrid_integration": pdf_xgrid_integration,
"xgrid_integration": xgrid_integration,
"replica_number": replica_number,
}
model = MetaModel(inputs, pdf_normalized, name="impose_msr")

Expand Down

0 comments on commit 0d6516f

Please sign in to comment.