Section 06

Code: MCTS Selection and Program-of-Thought Verification

rStar-Math: Small LLMs Can Master Math Reasoning with Self-Evolved Deep Thinking (2025)

Two Python examples that illustrate the core mechanics of rStar-Math in miniature. Both use only the Python standard library and run on a free Google Colab notebook.

Block 1: MCTS UCB Selection

Implement the Upper Confidence Bound formula and simulate one round of MCTS node selection.
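The score computed by `ucb_score` in the listing can be written as the formula below. Here Q(v) is the summed rollout reward of node v, N(v) is its visit count, and C is the exploration constant (√2 by default in the code). Siblings share the same N(parent(v)), so for N(parent(v)) > 1 their exploration bonuses differ only through the 1/√N(v) factor. The second line compares a child visited 2 times with one visited 19 times (the bold and cautious counts used in the simulation); the ratio holds for any C and any parent count.

```latex
\mathrm{UCB}(v) = \frac{Q(v)}{N(v)} + C\,\sqrt{\frac{\ln N(\mathrm{parent}(v))}{N(v)}}

\frac{\text{bonus at } N(v) = 2}{\text{bonus at } N(v) = 19} = \sqrt{\frac{19}{2}} \approx 3.08
```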

import math

class MCTSNode:
    """Represents one node in the MCTS search tree (a partial solution)."""
    def __init__(self, step_description, parent=None):
        self.step = step_description  # what solution step this node represents
        self.parent = parent           # parent node in tree
        self.children = []             # child nodes (next steps)
        self.visits = 0                # N(v): times this node has been visited
        self.total_reward = 0.0        # Q(v): sum of rewards from rollouts

    def ucb_score(self, exploration_const=math.sqrt(2)):
        """
        Compute UCB = Q(v)/N(v) + C * sqrt(ln(N(parent)) / N(v))
        Returns inf for unvisited nodes (always explore first).
        """
        if self.visits == 0:
            return float('inf')  # unvisited nodes have infinite UCB
        avg_reward = self.total_reward / self.visits
        parent_visits = self.parent.visits if self.parent else self.visits
        exploration_bonus = exploration_const * math.sqrt(
            math.log(parent_visits) / self.visits
        )
        return avg_reward + exploration_bonus

def select_best_child(node):
    """Return the child node with highest UCB score."""
    return max(node.children, key=lambda c: c.ucb_score())

# Simulate: MCTS has explored a problem and has 3 candidate next steps
root = MCTSNode("Problem: count integers 1-99 divisible by 3 but not 5")
root.visits = 27  # N(root) must be >= its children's combined visits (19 + 6 + 2 = 27)

# Three candidate approaches: careful count, formula, list comprehension
approaches = [
    "Count multiples of 3 using range(3, 100, 3)",
    "Use arithmetic sequence formula",
    "List comprehension: [x for x in range(1,100) if x%3==0]"
]
histories = [
    (16.0, 19),  # (total_reward, visits) — cautious, heavily explored
    (4.2, 6),    # — moderate, less explored
    (0.8, 2),    # — bold, barely tried
]

for approach, (reward, visits) in zip(approaches, histories):
    child = MCTSNode(approach, parent=root)
    child.visits = visits
    child.total_reward = reward
    root.children.append(child)

# Show UCB scores
print("MCTS Node Selection (UCB Scores):")
print("=" * 70)
for i, child in enumerate(root.children, 1):
    ucb = child.ucb_score()
    avg = child.total_reward / child.visits
    print(f"\nApproach {i}: {child.step[:45]}...")
    print(f"  Visits: {child.visits:2d}  |  Total Reward: {child.total_reward:5.1f}")
    print(f"  Avg Reward: {avg:.3f}  |  Exploration Bonus: {ucb - avg:.3f}")
    print(f"  UCB Score: {ucb:.3f}")

selected = select_best_child(root)
print(f"\n{'=' * 70}")
print(f"MCTS selects: Approach with UCB = {selected.ucb_score():.3f}")
print(f"Step: {selected.step}")
print("\nWhy? The bold approach hasn't been tried much (2 visits).")
print("Exploration bonus outweighs lower average reward.")

What you see: Although the cautious approach has the highest average reward (16/19 ≈ 0.84, versus 0.8/2 = 0.40 for the bold approach), MCTS selects the bold approach. With only 2 visits it receives the largest exploration bonus, which lifts its UCB score above the other two. This pushes MCTS to sample different strategies instead of repeatedly exploiting the best-known one.

Key insight: UCB prevents MCTS from prematurely committing to a suboptimal strategy. It balances learning what works (exploitation) with trying new things (exploration).
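Block 1 scores a single selection step. The sketch below repeats selection many times: after each pick, a stand-in rollout reward is added to the chosen child and the visit counts are updated (backpropagation), so the UCB scores shift every round. The helper names `ucb` and `simulate_selection` and the fixed per-child rewards are illustrative assumptions, not rStar-Math's implementation.

```python
import math

def ucb(total_reward, visits, parent_visits, c=math.sqrt(2)):
    """UCB score of one child; unvisited children score +inf so they are tried first."""
    if visits == 0:
        return float("inf")
    return total_reward / visits + c * math.sqrt(math.log(parent_visits) / visits)

def simulate_selection(rollout_rewards, iterations, c=math.sqrt(2)):
    """Select a child by UCB, 'roll out', and backpropagate, `iterations` times.

    Simplifying assumption (illustration only): a rollout from child i always
    returns rollout_rewards[i], so the run is deterministic.
    """
    n = len(rollout_rewards)
    visits = [0] * n
    totals = [0.0] * n
    parent_visits = 0
    for _ in range(iterations):
        # Pick the child with the highest UCB score (ties go to the lowest index)
        best = max(range(n), key=lambda i: ucb(totals[i], visits[i], parent_visits, c))
        reward = rollout_rewards[best]  # stand-in for a real rollout
        visits[best] += 1               # backpropagate to the chosen child...
        totals[best] += reward
        parent_visits += 1              # ...and to the parent
    return visits

print(simulate_selection([0.8, 0.7, 0.4], iterations=200))
```

With these rewards the best child ends up with the most visits, while the weakest one is still revisited periodically: exploitation and exploration happen in the same run.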


Block 2: Program-of-Thought Auto-Verification

Generate Python code for a solution, execute it, and automatically verify correctness.

def execute_and_verify(solution_code: str, expected_answer) -> tuple:
    """
    Execute a Python solution and verify if it produces expected_answer.
    Returns (result, is_correct, error_msg).
    """
    local_namespace = {}
    error_msg = None
    is_correct = False
    result = None
    
    try:
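        # Run the generated code in a fresh namespace. One dict serves as both
        # globals and locals, so functions and comprehensions defined in the code
        # can see its top-level variables. NOTE: exec() is not a security sandbox.
        exec(solution_code, local_namespace)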
        
        # Extract the final 'answer' variable
        result = local_namespace.get('answer', None)
        
        # Check if it matches expected
        is_correct = (result == expected_answer)
        
    except Exception as e:
        # Code failed: syntax error, runtime error, etc.
        error_msg = f"Execution error ({type(e).__name__}): {str(e)[:50]}"
        result = None
        is_correct = False
    
    return result, is_correct, error_msg

# Example 1: Correct solution (divisibility problem from Section 5)
correct_code = """
# Count integers 1-99 divisible by 3 but not 5
count_div3 = len([x for x in range(1, 100) if x % 3 == 0])
count_div15 = len([x for x in range(1, 100) if x % 15 == 0])
answer = count_div3 - count_div15
"""
result, is_correct, error = execute_and_verify(correct_code, 27)
print("Correct Solution:")
print(f"  Result: {result}  |  Correct: {is_correct}  |  Error: {error}")

# Example 2: Incorrect solution (forgot to subtract multiples of 15)
incomplete_code = """
count_div3 = len([x for x in range(1, 100) if x % 3 == 0])
answer = count_div3  # forgot to subtract!
"""
result, is_correct, error = execute_and_verify(incomplete_code, 27)
print("\nIncomplete Solution (forgot to exclude multiples of 5):")
print(f"  Result: {result}  |  Correct: {is_correct}  |  Error: {error}")

# Example 3: Syntax error (code doesn't run at all)
broken_code = """
count = len([x for x in range(1, 100) if x % 3 == 0)  # missing bracket!
answer = count
"""
result, is_correct, error = execute_and_verify(broken_code, 27)
print("\nBroken Syntax:")
print(f"  Result: {result}  |  Correct: {is_correct}")
print(f"  Error: {error}")

# Show the filtering in data collection
print("\n" + "=" * 70)
print("Data Collection Simulation:")
print("=" * 70)

# PRM scores below are hard-coded placeholders standing in for a process reward model
solutions = [
    (correct_code, "Round 1 candidate 1 (high-quality)", 0.95),
    (incomplete_code, "Round 1 candidate 2 (low-quality)", 0.60),
    (broken_code, "Round 1 candidate 3 (broken)", 0.10),
]

collected_count = 0
for code, description, prm_score in solutions:
    result, is_correct, _ = execute_and_verify(code, 27)
    status = "KEEP" if (is_correct and prm_score > 0.7) else "DISCARD"
    if status == "KEEP":
        collected_count += 1
    print(f"{description:40s} | Correct: {is_correct} | PRM: {prm_score} | {status}")

print(f"\nAfter filtering: {collected_count} high-quality solutions collected")
print("(Only these are used for supervised fine-tuning)")

What you see:

  1. The correct solution returns 27 ✓
  2. The incomplete solution returns 33 (wrong) ✗
  3. The broken solution fails to execute (SyntaxError) ✗

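Data collection logic: Keep only solutions that (a) execute without error, (b) produce the correct answer, and (c) have a high PRM score. The first two checks catch broken or wrong code; the PRM score is meant to catch solutions that reach the right answer through a flawed step. Together they filter out most low-quality samples before fine-tuning.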
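`execute_and_verify` checks correctness with `==`, which is exact. That works for integer answers like 27, but a float result can fail the check even when the math is right (0.1 + 0.2 is not exactly 0.3). A minimal sketch of a more forgiving comparison, assuming answers are plain numbers. `answers_match` and its tolerances are hypothetical choices for illustration, not part of rStar-Math, and real math graders also normalize strings, LaTeX and symbolic forms, which this does not attempt.

```python
import math
from fractions import Fraction

def answers_match(result, expected, rel_tol=1e-9, abs_tol=1e-9):
    """Compare a computed answer with a reference answer (numeric answers only).

    Ints, Fractions and other values are compared exactly with ==; if either
    side is a float, math.isclose is used instead.
    """
    if result is None:  # execution failed or no `answer` variable was set
        return False
    if isinstance(result, float) or isinstance(expected, float):
        try:
            return math.isclose(float(result), float(expected),
                                rel_tol=rel_tol, abs_tol=abs_tol)
        except (TypeError, ValueError):  # e.g. a string that is not a number
            return False
    return result == expected

print(0.1 + 0.2 == 0.3)                    # False: exact float comparison fails
print(answers_match(0.1 + 0.2, 0.3))       # True
print(answers_match(Fraction(1, 2), 0.5))  # True
print(answers_match(33, 27))               # False
```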


Key Takeaways

MCTS UCB (Block 1)

  • Exploitation: Nodes with high average reward get selected often
  • Exploration: Rarely-visited nodes get a bonus, forcing MCTS to try different strategies
  • Balance: The constant C controls the exploration-exploitation tradeoff (higher C = more exploration)
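To see the effect of C concretely, the sketch below reuses the three children from Block 1 and asks which one UCB selects for several values of C. It assumes the parent's visit count equals the children's combined visits (27); the names `CHILDREN`, `ucb` and `select` are illustrative.

```python
import math

# Child statistics from Block 1: (label, total_reward, visits)
CHILDREN = [("cautious", 16.0, 19), ("moderate", 4.2, 6), ("bold", 0.8, 2)]
# Assumption: the parent's visit count equals the sum of its children's visits
PARENT_VISITS = sum(v for _, _, v in CHILDREN)  # 27

def ucb(total_reward, visits, parent_visits, c):
    """Average reward plus exploration bonus, as in MCTSNode.ucb_score."""
    return total_reward / visits + c * math.sqrt(math.log(parent_visits) / visits)

def select(c):
    """Return the label of the child with the highest UCB score for constant c."""
    return max(CHILDREN, key=lambda ch: ucb(ch[1], ch[2], PARENT_VISITS, c))[0]

for c in (0.0, 0.5, math.sqrt(2)):
    print(f"C = {c:.2f} -> selects {select(c)}")
```

This prints `cautious` for C = 0.00 (pure exploitation), `moderate` for C = 0.50, and `bold` for C = 1.41: raising C shifts the choice toward less-visited children.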

Program-of-Thought (Block 2)

  • Automatic verification: No human grader is needed to judge correctness, provided a reference answer is available to compare against
  • Scalability: Can generate and verify thousands of solutions automatically
  • Filtering: Combine execution verification with PRM scores to select high-quality training data
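Scaling verification to thousands of generated solutions raises a practical issue: `execute_and_verify` runs code with `exec` inside the current interpreter, so a generated infinite loop would hang the whole pipeline, and nothing stops the code from touching files. One common mitigation, sketched below, runs each candidate in a separate Python process with a timeout. `run_in_subprocess` is a hypothetical helper; it assumes the solution stores its result in `answer` and that `repr(answer)` can be parsed back with `ast.literal_eval`. A process boundary plus a timeout is still not a security sandbox.

```python
import ast
import subprocess
import sys

def run_in_subprocess(solution_code: str, timeout_s: float = 2.0):
    """Run generated code in a separate Python process; return (answer, error_msg).

    The child process is killed if it exceeds timeout_s, which exec() in the
    current process cannot do. This is NOT a security sandbox.
    """
    # Assumption: the generated code defines a variable named `answer`
    program = solution_code + "\nprint(repr(answer))\n"
    try:
        proc = subprocess.run(
            [sys.executable, "-c", program],
            capture_output=True, text=True, timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        return None, f"Timed out after {timeout_s}s"
    if proc.returncode != 0:
        # The last stderr line is usually "ExceptionType: message"
        err_lines = proc.stderr.strip().splitlines()
        return None, err_lines[-1] if err_lines else f"Exit code {proc.returncode}"
    out_lines = proc.stdout.strip().splitlines()
    try:
        # The last printed line is repr(answer); parse it back into a Python value
        return ast.literal_eval(out_lines[-1]), None
    except (IndexError, ValueError, SyntaxError):
        return None, "Could not parse `answer`"

# Usage: a correct solution and one that never terminates
ok_code = "answer = sum(1 for x in range(1, 100) if x % 3 == 0 and x % 5 != 0)"
print(run_in_subprocess(ok_code))  # (27, None)
print(run_in_subprocess("while True:\n    pass\nanswer = 0", timeout_s=1.0))
# (None, 'Timed out after 1.0s')
```

Starting a process per candidate costs more than `exec`, but it lets the pipeline survive hangs and crashes in generated code.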

Why this enables rStar-Math:

  1. MCTS efficiently explores solution spaces using guided search
  2. Python verification gives automatic, scalable correctness checking
  3. Together, they produce filtered training data without human step-by-step annotation
  4. The model trains on this data and improves, enabling better search next round