diff --git a/tests/common/sudoku_test.py b/tests/common/sudoku_test.py new file mode 100644 index 0000000000..ad43754e6d --- /dev/null +++ b/tests/common/sudoku_test.py @@ -0,0 +1,116 @@ +from trinity.common.workflows.envs.sudoku.sudoku_generator import SudokuGenerator +from trinity.common.workflows.envs.sudoku.sudoku_judge import SudokuJudge + +# ---------- Generator Tests (9x9) ---------- + + +def test_9x9_generator_produces_valid_solution(): + gen = SudokuGenerator() + puzzle, solution = gen.generate() + + assert len(puzzle) == 9 + assert len(solution) == 9 + assert SudokuJudge.is_valid(solution) + + +def test_9x9_generator_creates_holes(): + gen = SudokuGenerator() + puzzle, _ = gen.generate() + + zero_count = sum(row.count(0) for row in puzzle) + assert zero_count > 0 + + +def test_9x9_solution_is_fully_filled(): + gen = SudokuGenerator() + _, solution = gen.generate() + + for row in solution: + assert 0 not in row + + +# ---------- Judge Tests (9x9) ---------- + + +def test_judge_allows_incomplete_board(): + board = [ + [5, 3, 0, 0, 7, 0, 0, 0, 0], + [6, 0, 0, 1, 9, 5, 0, 0, 0], + [0, 9, 8, 0, 0, 0, 0, 6, 0], + [8, 0, 0, 0, 6, 0, 0, 0, 3], + [4, 0, 0, 8, 0, 3, 0, 0, 1], + [7, 0, 0, 0, 2, 0, 0, 0, 6], + [0, 6, 0, 0, 0, 0, 2, 8, 0], + [0, 0, 0, 4, 1, 9, 0, 0, 5], + [0, 0, 0, 0, 8, 0, 0, 7, 9], + ] + + assert SudokuJudge.is_valid(board) + + +def test_judge_detects_row_violation(): + board = [ + [1, 1, 0, 0, 0, 0, 0, 0, 0], + ] + [[0] * 9 for _ in range(8)] + + assert not SudokuJudge.is_valid(board) + + +def test_judge_detects_column_violation(): + board = [ + [5, 0, 0, 0, 0, 0, 0, 0, 0], + [5, 0, 0, 0, 0, 0, 0, 0, 0], + ] + [[0] * 9 for _ in range(7)] + + assert not SudokuJudge.is_valid(board) + + +def test_judge_detects_block_violation(): + board = [ + [1, 2, 3, 0, 0, 0, 0, 0, 0], + [4, 1, 0, 0, 0, 0, 0, 0, 0], + ] + [[0] * 9 for _ in range(7)] + + assert not SudokuJudge.is_valid(board) + + +# ---------- Generator & Judge Tests (4x4) ---------- + + +def test_4x4_generator_produces_valid_solution(): + gen = SudokuGenerator(size=4) + puzzle, solution = gen.generate() + + assert len(puzzle) == 4 + assert len(solution) == 4 + assert SudokuJudge.is_valid(solution) + + +def test_4x4_solution_is_fully_filled(): + gen = SudokuGenerator(size=4) + _, solution = gen.generate() + + for row in solution: + assert 0 not in row + + +def test_4x4_judge_detects_row_violation(): + board = [ + [1, 1, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + ] + + assert not SudokuJudge.is_valid(board) + + +def test_4x4_judge_detects_block_violation(): + board = [ + [1, 2, 0, 0], + [3, 1, 0, 0], # duplicate "1" in top-left 2x2 block + [0, 0, 0, 0], + [0, 0, 0, 0], + ] + + assert not SudokuJudge.is_valid(board) diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index ea7390b4a4..6d1e0b9614 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -48,6 +48,8 @@ # on-policy distillation workflows "on_policy_distill_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillWorkflow", "on_policy_distill_math_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillMathWorkflow", + # custom workflows + "sudoku_workflow": "trinity.common.workflows.envs.sudoku.sudoku_workflow.SudokuWorkflow", }, ) diff --git a/trinity/common/workflows/envs/sudoku/sudoku_generator.py b/trinity/common/workflows/envs/sudoku/sudoku_generator.py new file mode 100644 index 0000000000..3d524a82ef --- /dev/null +++ b/trinity/common/workflows/envs/sudoku/sudoku_generator.py @@ -0,0 +1,138 @@ +import math +import random + + +class SudokuGenerator: + """ + Sudoku puzzle generator using randomized backtracking. + + Features: + - Supports arbitrary square sizes (e.g., 9x9, 4x4) + - Generates a fully solved board first + - Removes cells based on difficulty to create a puzzle + - Avoids relying on a single canonical solution + """ + + def __init__(self, size: int = 9): + """ + Initialize the generator. + + Args: + size (int): Size of the Sudoku board (must be a perfect square). + Examples: 9 for 9x9, 4 for 4x4. + """ + self.size = size + self.block = int(math.sqrt(size)) + assert self.block * self.block == size, "Size must be a perfect square" + + def generate(self, difficulty: str = "medium"): + """ + Generate a Sudoku puzzle and its solution. + + Args: + difficulty (str): Difficulty level ("easy", "medium", "hard"). + + Returns: + tuple: (puzzle, solution), where puzzle contains zeros for empty cells. + """ + holes_map = { + "easy": self.size * self.size // 3, + "medium": self.size * self.size // 2, + "hard": self.size * self.size * 2 // 3, + } + holes = holes_map.get(difficulty, holes_map["medium"]) + + board = [[0 for _ in range(self.size)] for _ in range(self.size)] + self._fill_board(board) + + solution = [row[:] for row in board] + self._remove_cells(board, holes) + + return board, solution + + def _fill_board(self, board): + """ + Recursively fill the board using backtracking. + + Args: + board (list[list[int]]): Current board state. + + Returns: + bool: True if the board is completely filled. + """ + empty = self._find_empty(board) + if not empty: + return True + + r, c = empty + nums = list(range(1, self.size + 1)) + random.shuffle(nums) + + for v in nums: + if self._is_valid(board, r, c, v): + board[r][c] = v + if self._fill_board(board): + return True + board[r][c] = 0 + + return False + + def _find_empty(self, board): + """ + Find the next empty cell in the board. + + Args: + board (list[list[int]]): Current board state. + + Returns: + tuple | None: (row, col) of empty cell, or None if full. + """ + for i in range(self.size): + for j in range(self.size): + if board[i][j] == 0: + return i, j + return None + + def _is_valid(self, board, r, c, v): + """ + Check whether placing value v at (r, c) is valid. + + Args: + board (list[list[int]]): Current board state. + r (int): Row index. + c (int): Column index. + v (int): Value to place. + + Returns: + bool: True if valid, False otherwise. + """ + if v in board[r]: + return False + + for i in range(self.size): + if board[i][c] == v: + return False + + br = (r // self.block) * self.block + bc = (c // self.block) * self.block + for i in range(br, br + self.block): + for j in range(bc, bc + self.block): + if board[i][j] == v: + return False + + return True + + def _remove_cells(self, board, holes): + """ + Remove cells from a solved board to create a puzzle. + + Args: + board (list[list[int]]): Solved board. + holes (int): Number of cells to clear. + """ + cells = [(i, j) for i in range(self.size) for j in range(self.size)] + random.shuffle(cells) + + for i in range(min(holes, self.size * self.size)): + r, c = cells[i] + board[r][c] = 0 diff --git a/trinity/common/workflows/envs/sudoku/sudoku_judge.py b/trinity/common/workflows/envs/sudoku/sudoku_judge.py new file mode 100644 index 0000000000..e3e701b23c --- /dev/null +++ b/trinity/common/workflows/envs/sudoku/sudoku_judge.py @@ -0,0 +1,53 @@ +import math + + +class SudokuJudge: + """ + Judge Sudoku board state. + + - Supports both 9x9 and 4x4 Sudoku boards + - Allows incomplete boards (zeros are treated as empty cells) + - Checks: + * Row validity + * Column validity + * Sub-grid validity (3x3 for 9x9, 2x2 for 4x4) + """ + + @staticmethod + def is_valid(board): + size = len(board) + block = int(math.sqrt(size)) + + # Check rows + for row in board: + nums = [v for v in row if v != 0] + if len(nums) != len(set(nums)): + return False + + # Check columns + for c in range(size): + nums = [] + for r in range(size): + v = board[r][c] + if v != 0: + nums.append(v) + if len(nums) != len(set(nums)): + return False + + # Check sub-grids + for br in range(0, size, block): + for bc in range(0, size, block): + nums = [] + for r in range(br, br + block): + for c in range(bc, bc + block): + v = board[r][c] + if v != 0: + nums.append(v) + if len(nums) != len(set(nums)): + return False + + return True + + @staticmethod + def is_solved(board, solution): + return board == solution diff --git a/trinity/common/workflows/envs/sudoku/sudoku_workflow.py b/trinity/common/workflows/envs/sudoku/sudoku_workflow.py new file mode 100644 index 0000000000..7a00d90afd --- /dev/null +++ b/trinity/common/workflows/envs/sudoku/sudoku_workflow.py @@ -0,0 +1,208 @@ +from trinity.common.experience import Experience +from trinity.common.workflows.workflow import Workflow + +from .sudoku_generator import SudokuGenerator +from .sudoku_judge import SudokuJudge + + +class SudokuWorkflow(Workflow): + """ + Agentic multi-step Sudoku solving workflow. + + The workflow: + - Presents the current Sudoku board to the model + - Allows the model to propose multiple moves per step + - Applies moves incrementally and validates them + - Terminates on success, invalid action, or step limit + """ + + can_reset = True + + def __init__(self, task, model, auxiliary_models=None): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + + # Load puzzle from task if provided, otherwise generate a new one + if "puzzle" in task.raw_task and "solution" in task.raw_task: + self.board = [row[:] for row in task.raw_task["puzzle"]] + self.solution = [row[:] for row in task.raw_task["solution"]] + else: + generator = SudokuGenerator() + self.board, self.solution = generator.generate() + + self.judge = SudokuJudge() + + # Workflow configuration + self.max_steps = 20 + self.max_moves_per_step = 5 + + # Runtime state + self.current_step = 0 + self.last_board = None + self.last_action = None + + def reset(self, task): + """ + Reset workflow state for a new task instance. + """ + self.board = [row[:] for row in task.raw_task["puzzle"]] + self.solution = [row[:] for row in task.raw_task["solution"]] + self.current_step = 0 + self.last_board = None + self.last_action = None + + def render_board(self): + """ + Render the board into a human-readable string format + for inclusion in the prompt. + """ + return "\n".join(" ".join(str(v) for v in row) for row in self.board) + + def _build_prompt(self): + """ + Build a step-aware prompt describing: + - Sudoku rules + - Current board state + - Allowed action format + """ + prompt = ( + "You are playing a Sudoku game.\n\n" + "Rules:\n" + "- The board is 9x9.\n" + "- 0 means empty.\n" + "- Numbers 1–9 must appear exactly once in every row, column, and 3x3 block.\n" + "- You may only fill empty cells.\n\n" + "Task:\n" + "- In each step, output ONE OR MORE valid moves.\n" + f"- You may output up to {self.max_moves_per_step} moves per step.\n\n" + "Output format (STRICT):\n" + "row col value\n" + "row col value\n\n" + "Example:\n" + "0 2 4\n" + "1 3 5\n\n" + f"Current step: {self.current_step}\n" + f"Remaining steps: {self.max_steps - self.current_step}\n\n" + f"Current board:\n{self.render_board()}\n" + ) + + # Feedback when the previous step made no progress + if self.last_board is not None and self.board == self.last_board: + prompt += ( + "\nYour previous response was invalid or had no effect. " + "Please follow the rules and output format strictly." + ) + + return prompt + + def parse_action(self, text): + """ + Parse model output into a list of (row, col, value) moves. + + Expected format: + row col value + row col value + """ + lines = text.strip().splitlines() + actions = [] + + for line in lines: + line = line.strip() + if not line: + continue + + parts = line.split() + if len(parts) != 3: + return None + + try: + r, c, v = map(int, parts) + except ValueError: + return None + + if not (0 <= r <= 8 and 0 <= c <= 8 and 1 <= v <= 9): + return None + + actions.append((r, c, v)) + + if not actions or len(actions) > self.max_moves_per_step: + return None + + return actions + + def run(self): + """ + Execute the Sudoku workflow until: + - The puzzle is solved + - An invalid action is produced + - The maximum number of steps is reached + """ + experiences = [] + + for _ in range(self.max_steps): + prompt = self._build_prompt() + responses = self.model.chat([{"role": "user", "content": prompt}]) + resp = responses[0] + + # Snapshot board to detect no-op steps + self.last_board = [row[:] for row in self.board] + + actions = self.parse_action(resp.response_text) + if actions is None: + experiences.append( + Experience( + tokens=resp.tokens, + prompt_length=resp.prompt_length, + reward=-1.0, + logprobs=resp.logprobs, + ) + ) + break + + board_changed = False + invalid_move = False + + for r, c, v in actions: + if self.board[r][c] != 0: + invalid_move = True + break + self.board[r][c] = v + board_changed = True + + # Invalid or ineffective step + if invalid_move or not board_changed or not self.judge.is_valid(self.board): + experiences.append( + Experience( + tokens=resp.tokens, + prompt_length=resp.prompt_length, + reward=-1.0, + logprobs=resp.logprobs, + ) + ) + break + + # Solved successfully + if self.judge.is_solved(self.board, self.solution): + experiences.append( + Experience( + tokens=resp.tokens, + prompt_length=resp.prompt_length, + reward=1.0, + logprobs=resp.logprobs, + ) + ) + break + + # Intermediate step + experiences.append( + Experience( + tokens=resp.tokens, + prompt_length=resp.prompt_length, + reward=0.0, + logprobs=resp.logprobs, + ) + ) + + self.last_action = actions + self.current_step += 1 + + return experiences