Skip to content

Commit

Permalink
sweep: DIRACGrid#8026 fix: make the setting of inputDataBulk extendable
Browse files Browse the repository at this point in the history
  • Loading branch information
fstagni committed Feb 4, 2025
1 parent 970b50f commit 479663f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 18 deletions.
45 changes: 27 additions & 18 deletions src/DIRAC/TransformationSystem/Client/WorkflowTasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername
from DIRAC.Core.Security.ProxyInfo import getProxyInfo
from DIRAC.Core.Utilities.DErrno import ETSDATA, ETSUKN
from DIRAC.Core.Utilities.DErrno import ETSUKN
from DIRAC.Core.Utilities.List import fromChar
from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader
from DIRAC.Interfaces.API.Job import Job
Expand Down Expand Up @@ -83,6 +83,8 @@ def __init__(
self.outputDataModule_o = None
self.objectLoader = ObjectLoader()

self.parametricSequencedKeys = ["JOB_ID", "PRODUCTION_ID", "InputData"]

def prepareTransformationTasks(self, transBody, taskDict, owner="", ownerGroup="", bulkSubmissionFlag=False):
"""Prepare tasks, given a taskDict, that is created (with some manipulation) by the DB
jobClass is by default "DIRAC.Interfaces.API.Job.Job". An extension of it also works.
Expand Down Expand Up @@ -191,22 +193,7 @@ def _prepareTasksBulk(self, transBody, taskDict, owner, ownerGroup):
method=method,
)

# Handle Input Data
inputData = paramsDict.get("InputData")
if inputData:
if isinstance(inputData, str):
inputData = inputData.replace(" ", "").split(";")
self._logVerbose(f"Setting input data to {inputData}", transID=transID, method=method)
seqDict["InputData"] = inputData
elif paramSeqDict.get("InputData") is not None:
self._logError("Invalid mixture of jobs with and without input data")
return S_ERROR(ETSDATA, "Invalid mixture of jobs with and without input data")

for paramName, paramValue in paramsDict.items():
if paramName not in ("InputData", "Site", "TargetSE"):
if paramValue:
self._logVerbose(f"Setting {paramName} to {paramValue}", transID=transID, method=method)
seqDict[paramName] = paramValue
inputData = self._handleInputsBulk(seqDict, paramsDict, transID)

outputParameterList = []
if self.outputDataModule:
Expand Down Expand Up @@ -235,7 +222,7 @@ def _prepareTasksBulk(self, transBody, taskDict, owner, ownerGroup):
paramSeqDict.setdefault(pName, []).append(seq)

for paramName, paramSeq in paramSeqDict.items():
if paramName in ["JOB_ID", "PRODUCTION_ID", "InputData"] + outputParameterList:
if paramName in self.parametricSequencedKeys + outputParameterList:
res = oJob.setParameterSequence(paramName, paramSeq, addToWorkflow=paramName)
else:
res = oJob.setParameterSequence(paramName, paramSeq)
Expand Down Expand Up @@ -399,6 +386,28 @@ def _handleInputs(self, oJob, paramsDict):
if not res["OK"]:
self._logError(f"Could not set the inputs: {res['Message']}", transID=transID, method="_handleInputs")

def _handleInputsBulk(self, seqDict, paramsDict, transID):
"""set job inputs (+ metadata)"""
method = "_handleInputsBulk"
if seqDict:
self._logVerbose(f"Setting job input data to {seqDict}", transID=transID, method=method)

# Handle Input Data
inputData = paramsDict.get("InputData")
if inputData:
if isinstance(inputData, str):
inputData = inputData.replace(" ", "").split(";")
self._logVerbose(f"Setting input data {inputData} to {seqDict}", transID=transID, method=method)
seqDict["InputData"] = inputData

for paramName, paramValue in paramsDict.items():
if paramName not in ("InputData", "Site", "TargetSE"):
if paramValue:
self._logVerbose(f"Setting {paramName} to {paramValue}", transID=transID, method=method)
seqDict[paramName] = paramValue

return inputData

def _handleRest(self, oJob, paramsDict):
"""add as JDL parameters all the other parameters that are not for inputs or destination"""
transID = paramsDict["TransformationID"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# pylint: disable=protected-access,missing-docstring,invalid-name

from unittest.mock import MagicMock

import pytest

from DIRAC import gLogger, S_OK
Expand Down Expand Up @@ -136,3 +137,28 @@ def test__handleDestination(mocker, paramsDict, expected):
mocker.patch("DIRAC.TransformationSystem.Client.TaskManagerPlugin.getSitesForSE", side_effect=ourgetSitesForSE)
res = wfTasks._handleDestination(paramsDict)
assert sorted(res) == sorted(expected)


@pytest.mark.parametrize(
"seqDict, paramsDict, expected",
[
({}, {}, None),
({"Site": "Site1", "JobName": "Job1", "JOB_ID": "00000001"}, {}, None),
(
{"Site": "Site1", "JobName": "Job1", "JOB_ID": "00000001"},
{"Site": "Site1", "JobType": "Sprucing", "TransformationID": 1},
None,
),
(
{"Site": "Site1", "JobName": "Job1", "JOB_ID": "00000001"},
{"Site": "Site1", "JobType": "Sprucing", "TransformationID": 1, "InputData": ["a1", "a2"]},
["a1", "a2"],
),
# ({"a1": "aa1", "a2": "aa2", "a3": "aa3"}, {"b1": "bb1", "b2": "bb2", "b3": "bb3"}, {"b1": "bb1", "b2": "bb2"}, ["a1", "a2"]),
],
)
def test__handleInputsBulk(mocker, seqDict, paramsDict, expected):
"""Test the _handleInputsBulk method WorkflowTasks"""
mocker.patch("DIRAC.TransformationSystem.Client.TaskManagerPlugin.getSitesForSE", side_effect=ourgetSitesForSE)
res = wfTasks._handleInputsBulk(seqDict, paramsDict, transID=1)
assert res == expected

0 comments on commit 479663f

Please sign in to comment.