Skip to content

Commit e88afc5

Browse files
authored
Fix a crash in cstr.to_gres_dict when the input is just "gres:gpu" (#334)
Slurm allows to pass things like --gres=gpu, which in this case the default value will be just one. However, the char *gres value also really just contains "gres:gpu" when displayed via scontrol, with no indication of the count. The code previously assumed that Slurm internally modifies the char* explicitly to reflect the default count of 1.
1 parent 98e7357 commit e88afc5

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

pyslurm/utils/cstr.pyx

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,9 @@ cpdef dict to_gres_dict(char *gres):
214214
return {}
215215

216216
for item in re.split(",(?=[^,]+?:)", gres_str):
217+
# char *gres might contain just "gres:gpu", without any count.
218+
# If not given, the count is always 1, so default to it.
219+
cnt = typ = "1"
217220

218221
# Remove the additional "gres" specifier if it exists
219222
if gres_delim in item:
@@ -223,15 +226,22 @@ cpdef dict to_gres_dict(char *gres):
223226
":(?=[^:]+?)",
224227
item.replace("(", ":", 1).replace(")", "")
225228
)
229+
gres_splitted_len = len(gres_splitted)
226230

227-
name, typ, cnt = gres_splitted[0], gres_splitted[1], 0
231+
name = gres_splitted[0]
232+
if gres_splitted_len > 1:
233+
typ = gres_splitted[1]
228234

229235
# Check if we have a gres type.
230236
if typ.isdigit():
231237
cnt = typ
232238
typ = None
233-
else:
239+
elif gres_splitted_len > 2:
234240
cnt = gres_splitted[2]
241+
else:
242+
# String is somehow malformed, should never happen when the input
243+
# comes from the slurmctld. Ignore if it happens.
244+
continue
235245

236246
# Dict Key-Name depends on if we have a gres type or not
237247
name_and_typ = f"{name}:{typ}" if typ else name

tests/unit/test_common.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,32 @@ def test_dict_to_gres_str(self):
152152
assert cstr.from_gres_dict("tesla:3,a100:5", "gpu") == expected_str
153153

154154
def test_str_to_gres_dict(self):
155-
input_str = "gpu:nvidia-a100:1(IDX:0,1)"
156-
expected = {"gpu:nvidia-a100":{"count": 1, "indexes": "0,1"}}
155+
input_str = "gpu:nvidia-a100:1(IDX:0)"
156+
expected = {"gpu:nvidia-a100": {"count": 1, "indexes": "0"}}
157157
assert cstr.to_gres_dict(input_str) == expected
158158

159-
input_str = "gpu:nvidia-a100:1"
160-
expected = {"gpu:nvidia-a100": 1}
159+
input_str = "gpu:nvidia-a100:2(IDX:0,1)"
160+
expected = {"gpu:nvidia-a100": {"count": 2, "indexes": "0,1"}}
161+
assert cstr.to_gres_dict(input_str) == expected
162+
163+
input_str = "gpu:nvidia-a100:5"
164+
expected = {"gpu:nvidia-a100": 5}
165+
assert cstr.to_gres_dict(input_str) == expected
166+
167+
input_str = "gpu:nvidia-a100:5,gres:gpu:nvidia-v100:10"
168+
expected = {"gpu:nvidia-a100": 5, "gpu:nvidia-v100": 10}
169+
assert cstr.to_gres_dict(input_str) == expected
170+
171+
input_str = "gres:gpu:2"
172+
expected = {"gpu": 2}
173+
assert cstr.to_gres_dict(input_str) == expected
174+
175+
input_str = "gres:gpu"
176+
expected = {"gpu": 1}
177+
assert cstr.to_gres_dict(input_str) == expected
178+
179+
input_str = "gres:gpu:INVALID_COUNT"
180+
expected = {}
161181
assert cstr.to_gres_dict(input_str) == expected
162182

163183
def test_gres_from_tres_dict(self):

0 commit comments

Comments
 (0)