Skip to content

Commit a9b2514

Browse files
committed
ut fix
1 parent 8236d72 commit a9b2514

File tree

4 files changed

+103
-1
lines changed

4 files changed

+103
-1
lines changed

dlrover/python/diagnosis/common/diagnosis_action.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,10 @@ def clear(self):
260260
with self._lock:
261261
self._actions.clear()
262262

263+
def len(self):
264+
with self._lock:
265+
return sum(len(d) for d in self._actions.values())
266+
263267
def next_action(
264268
self,
265269
instance=DiagnosisConstant.LOCAL_INSTANCE,

dlrover/python/tests/test_args.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,28 @@
1313

1414
import unittest
1515

16-
from dlrover.python.master.args import parse_master_args
16+
from dlrover.python.master.args import parse_master_args, str2bool
1717

1818

1919
class ArgsTest(unittest.TestCase):
20+
def test_str2bool(self):
21+
self.assertTrue(str2bool("TRUE"))
22+
self.assertTrue(str2bool("True"))
23+
self.assertTrue(str2bool("true"))
24+
self.assertTrue(str2bool("yes"))
25+
self.assertTrue(str2bool("t"))
26+
self.assertTrue(str2bool("y"))
27+
self.assertTrue(str2bool("1"))
28+
self.assertTrue(str2bool(True))
29+
30+
self.assertFalse(str2bool("FALSE"))
31+
self.assertFalse(str2bool("False"))
32+
self.assertFalse(str2bool("false"))
33+
self.assertFalse(str2bool("no"))
34+
self.assertFalse(str2bool("n"))
35+
self.assertFalse(str2bool("0"))
36+
self.assertFalse(str2bool(False))
37+
2038
def test_parse_master_args(self):
2139
original_args = [
2240
"--job_name",

dlrover/python/tests/test_diagnosis_manager.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,17 @@
1515
import unittest
1616
from typing import List
1717
from unittest import mock
18+
from unittest.mock import MagicMock
1819

20+
from dlrover.python.common.constants import NodeStatus
1921
from dlrover.python.diagnosis.common.constants import (
2022
DiagnosisActionType,
2123
DiagnosisDataType,
2224
)
25+
from dlrover.python.diagnosis.common.diagnosis_action import (
26+
DiagnosisAction,
27+
NodeAction,
28+
)
2329
from dlrover.python.diagnosis.common.diagnosis_data import (
2430
DiagnosisData,
2531
TrainingLog,
@@ -38,6 +44,11 @@
3844
DiagnosisDataManager,
3945
)
4046
from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisManager
47+
from dlrover.python.master.diagnosis.precheck_operator import (
48+
PreCheckOperator,
49+
PreCheckResult,
50+
)
51+
from dlrover.python.master.node.job_context import get_job_context
4152

4253

4354
class DiagnosisManagerTest(unittest.TestCase):
@@ -107,5 +118,35 @@ def test_diagnosis_manager(self):
107118
self.assertEqual(action.action_type, DiagnosisActionType.NONE)
108119

109120
def test_pre_check(self):
121+
job_context = get_job_context()
110122
mgr = DiagnosisManager()
111123
mgr.pre_check()
124+
self.assertEqual(job_context._action_queue.len(), 0)
125+
126+
mgr.get_pre_check_operators = MagicMock(return_value=[TestOperator()])
127+
mgr.pre_check()
128+
self.assertTrue(isinstance(job_context.next_action(1), NodeAction))
129+
130+
131+
class TestOperator(PreCheckOperator):
132+
@classmethod
133+
def get_retry_interval_secs(cls) -> int:
134+
return 1
135+
136+
@classmethod
137+
def get_retry_limit_times(cls) -> int:
138+
return 1
139+
140+
def check(self) -> PreCheckResult:
141+
return PreCheckResult(1, "test", [1])
142+
143+
def recover(self):
144+
pass
145+
146+
def get_failed_action(self) -> DiagnosisAction:
147+
return NodeAction(
148+
node_id=1,
149+
node_status=NodeStatus.FAILED,
150+
reason="hang",
151+
action_type=DiagnosisActionType.MASTER_RELAUNCH_WORKER,
152+
)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2025 The DLRover Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import unittest
15+
16+
from dlrover.python.diagnosis.common.diagnosis_action import NoAction
17+
from dlrover.python.master.diagnosis.precheck_operator import (
18+
NoPreCheckOperator,
19+
)
20+
21+
22+
class PreCheckOperatorTest(unittest.TestCase):
23+
def setUp(self):
24+
pass
25+
26+
def tearDown(self):
27+
pass
28+
29+
def test_no_pre_check_op(self):
30+
op = NoPreCheckOperator()
31+
self.assertTrue(op.check())
32+
op.recover()
33+
self.assertEqual(op.get_retry_interval_secs(), 5)
34+
self.assertEqual(op.get_retry_limit_times(), 3)
35+
self.assertTrue(isinstance(op.get_failed_action(), NoAction))
36+
37+
38+
if __name__ == "__main__":
39+
unittest.main()

0 commit comments

Comments
 (0)