MCTS Search Engine

The search engine is the core of ProbeLLM, implementing Monte Carlo Tree Search (MCTS) to intelligently explore the space of test cases.

Core Classes

VulnerabilityPipelineAsync

Purpose: Main orchestrator for vulnerability search

Key Methods:

from probellm import VulnerabilityPipelineAsync

pipeline = VulnerabilityPipelineAsync(
    model_name="gpt-5.2",        # Test generation model
    test_model="gpt-4o-mini", # Model under test
    judge_model="gpt-5.2",        # Answer correctness judge
    max_depth=1000,             # Max tree depth
    num_simulations=100,        # MCTS iterations per dataset
    num_samples=10,             # Initial samples per dataset
    tool_registry=None          # Optional: custom tools
)

# Add datasets
pipeline.add_datasets_batch(['mmlu', 'hellaswag'])

# Run search (concurrent across datasets)
pipeline.run()

Concurrent Execution:

By default, all datasets run concurrently using asyncio:

# Internally:
async def _run_concurrent(self):
    tasks = [
        asyncio.create_task(self._mcts_search_async(root))
        for root in self.tree_roots
    ]
    await asyncio.gather(*tasks)

RootNode

Purpose: Represents a dataset root in the search tree

Attributes:

  • dataset_id: Dataset identifier

  • searchtype: "micro" or "macro"

  • sampler: HierarchicalSampler or SequentialSampler

  • children: List of SyntheticNode

  • visits, error_count: MCTS statistics

Methods:

  • next_sample(): Get next sample from sampler

  • has_more_samples(): Check if more data available

  • get_next_index(depth): Get next node index at given depth
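
To show how these pieces fit together, here is a minimal sketch of seeding a root with its initial samples; the SyntheticNode constructor arguments and add_child on RootNode are assumptions for illustration, not the confirmed API:

# Sketch: seed a dataset root with its first num_samples items
# (SyntheticNode arguments and RootNode.add_child are assumed)
def seed_root(root, num_samples):
    children = []
    while root.has_more_samples() and len(children) < num_samples:
        sample = root.next_sample()        # {"query": ..., "ground_truth": ...}
        child = SyntheticNode(
            sample=sample,
            datatype="Original",           # drawn directly from the dataset
            parent=root,
        )
        root.add_child(child)              # assumed to mirror SyntheticNode.add_child
        children.append(child)
    return children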

SyntheticNode

Purpose: Represents a generated test case

Attributes:

  • sample: {"query": str, "ground_truth": str, ...}

  • datatype: "Original" (from dataset) or "Synthetic" (generated)

  • parent: Parent node

  • children: Child nodes

  • visits, error_count: MCTS statistics

  • results: List of test results

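For orientation, the attributes above map onto a structure roughly like the following dataclass; this is a hedged sketch of the shape only, not the project's actual class definition:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class SyntheticNodeSketch:                      # illustrative stand-in for SyntheticNode
    sample: dict                                # {"query": str, "ground_truth": str, ...}
    datatype: str = "Synthetic"                 # "Original" or "Synthetic"
    parent: Optional["SyntheticNodeSketch"] = None
    children: list = field(default_factory=list)
    visits: int = 0
    error_count: int = 0
    results: list = field(default_factory=list)

    def add_child(self, child: "SyntheticNodeSketch") -> None:
        child.parent = self
        self.children.append(child)
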
MCTS Algorithm

Phase 1: Selection

Goal: Navigate tree to find node for expansion

Algorithm (UCB1):

from math import log, sqrt
import random

def _select(self, root: RootNode) -> BaseNode:
    current = root

    while current.is_fully_expanded() and current.children:
        if current.is_terminal(self.max_depth):
            break

        current = current.best_child(exploration_weight=1.414)

    return current

def best_child(self, exploration_weight=1.414):
    # Handle unvisited children
    unvisited = [c for c in self.children if c.visits == 0]
    if unvisited:
        return random.choice(unvisited)

    # UCB1 formula
    def ucb1(child):
        exploitation = child.error_count / child.visits
        exploration = exploration_weight * sqrt(log(self.visits) / child.visits)
        return exploitation + exploration

    return max(self.children, key=ucb1)

Key Points:

  • Prioritizes unvisited nodes first

  • Balances exploitation (high error rate) vs exploration (low visit count)

  • exploration_weight=1.414 (default): approximately sqrt(2), the standard MCTS exploration constant
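
To make the trade-off concrete, here is a small worked example of the UCB1 score for two children under a parent visited 20 times (numbers are illustrative):

from math import log, sqrt

parent_visits = 20
c = 1.414  # exploration weight, roughly sqrt(2)

# Child A: visited often, high error rate
ucb_a = 8 / 10 + c * sqrt(log(parent_visits) / 10)   # ≈ 0.80 + 0.77 = 1.57

# Child B: visited rarely, low error rate so far
ucb_b = 1 / 3 + c * sqrt(log(parent_visits) / 3)     # ≈ 0.33 + 1.41 = 1.75

# B wins despite the lower error rate: the exploration term rewards
# under-visited branches until their statistics firm up.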

Phase 2: Expansion

Goal: Generate new test case using tools

Algorithm:

async def _expand_async(self, node: SyntheticNode, root: RootNode):
    # Micro: Perturb nearby
    if node.searchtype == "micro":
        new_question = self.test_case_gen.generate_nearby(
            node.sample['query'],
            node.sample['ground_truth']
        )

    # Macro: Explore far
    else:  # "macro"
        # Load embeddings of seen failures
        items = await self._load_pickle_async(f"embeddings_{node.dataset_id}_macro.pkl")

        # Greedy-k-center: select diverse samples
        indices = self.greedy_k_center(embeddings=items, k=5)
        samples = [items[i] for i in indices]

        new_question = self.test_case_gen.generate_far(samples)

    # Generate answer
    new_answer = self.answer_gen.generate_answer(
        new_question['candidate']['question'],
        context=node.sample
    )

    # Create child node
    child = SyntheticNode(...)
    node.add_child(child)

    return child
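
The macro branch depends on greedy_k_center to pick a diverse subset of failure embeddings. Below is a minimal sketch of greedy k-center (farthest-first) selection, assuming the embeddings are NumPy vectors; the project's actual implementation may differ:

import numpy as np

def greedy_k_center(embeddings: np.ndarray, k: int) -> list[int]:
    """Greedily pick k indices that spread out over the embedding space."""
    n = len(embeddings)
    if n == 0 or k <= 0:
        return []
    selected = [0]                                   # start from an arbitrary point
    # Distance from every point to its closest selected center
    dists = np.linalg.norm(embeddings - embeddings[0], axis=1)
    while len(selected) < min(k, n):
        idx = int(np.argmax(dists))                  # farthest point from current centers
        selected.append(idx)
        dists = np.minimum(dists, np.linalg.norm(embeddings - embeddings[idx], axis=1))
    return selected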

Tool Selection:

  • generate_nearby() / generate_far() internally call ToolRegistry.call_tool()

  • LLM planner chooses tool based on question characteristics

  • See Tool System (MCP-Based) for details

Phase 3: Simulation

Goal: Test model-under-test on generated case

Algorithm:

async def _simulate_async(self, node: SyntheticNode):
    query = node.sample['query']
    ground_truth = node.sample['ground_truth']

    # Inference
    model_response = await self.model_async(query)
    prediction = model_response["content"]

    # Judge
    correct, reason, judge_tokens = await self.judge_function_async(
        query, prediction, ground_truth
    )

    # Record result
    result = {
        "id": node.id,
        "query": query,
        "prediction": prediction,
        "error_detected": "False" if correct else "True",
        "error_reason": reason if not correct else "",
        "token_usage": {...}
    }
    node.results.append(result)

    return not correct  # True if error detected

Judge Model:

Uses a separate LLM to compare the prediction against the ground truth:

  • Strict on factual equivalence

  • Lenient on formatting/wording

  • Returns: (is_correct: bool, reason: str, tokens: dict)
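
A hedged sketch of such a judge call, using the OpenAI-style async chat API the pipeline already relies on; the prompt wording, judge model name, and response parsing are illustrative assumptions rather than ProbeLLM's exact judge:

import json
from openai import AsyncOpenAI

async def judge_async(client: AsyncOpenAI, query: str, prediction: str, ground_truth: str):
    """Return (is_correct, reason, token_usage); a sketch, not the project's judge."""
    prompt = (
        "Compare the prediction to the ground truth. Be strict on factual "
        "equivalence but lenient on formatting and wording.\n"
        f"Question: {query}\nPrediction: {prediction}\nGround truth: {ground_truth}\n"
        'Reply as JSON: {"is_correct": true or false, "reason": "..."}'
    )
    response = await client.chat.completions.create(
        model="gpt-4o-mini",                              # illustrative judge model
        messages=[{"role": "user", "content": prompt}],
    )
    verdict = json.loads(response.choices[0].message.content)
    tokens = {"total": response.usage.total_tokens}
    return verdict["is_correct"], verdict.get("reason", ""), tokens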

Phase 4: Backpropagation

Goal: Update statistics from leaf to root

Algorithm:

def _backpropagate(self, node: BaseNode, error_detected: bool):
    current = node
    while current is not None:
        current.visits += 1
        if error_detected:
            current.error_count += 1
        current = current.parent

Effect:

  • Future UCB1 calculations favor branches with high error rates

  • Gradually shifts search toward promising regions

Checkpointing

Saving:

from probellm import create_checkpoints
create_checkpoints("results/run_xxx/")

Or via CLI:

python -m probellm.checkpoint results/run_xxx/

Format:

{
  "metadata": {
    "dataset_id": "mmlu",
    "last_simulation": 42,
    "timestamp": "2026-01-27T12:00:00"
  },
  "root_state": { ... },
  "nodes": [...]
}
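
To inspect a saved checkpoint, the metadata block can be read back with the standard json module; the file name below is an assumption about the on-disk layout, not a documented path:

import json
from pathlib import Path

# Assumed file name; adjust to whatever create_checkpoints actually writes
checkpoint = json.loads(Path("results/run_xxx/checkpoint_mmlu.json").read_text())
meta = checkpoint["metadata"]
print(f"{meta['dataset_id']}: last simulation {meta['last_simulation']} at {meta['timestamp']}")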

Resuming:

from probellm import resume_from_checkpoint
resume_from_checkpoint("results/run_xxx/")

Or via CLI:

python -m probellm.resume results/run_xxx/

See CLI Reference for details.

Async Execution

All I/O-bound operations use asyncio for concurrency:

  • Model calls: await self.async_client.chat.completions.create(...)

  • Embeddings: await loop.run_in_executor(None, get_embedding, ...)

  • Dataset loading: await loop.run_in_executor(None, interface.load_dataset, ...)
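
The run_in_executor pattern wraps blocking helpers so they do not stall the event loop. A minimal sketch of that pattern, assuming a synchronous get_embedding helper (the placeholder body below is not ProbeLLM's implementation):

import asyncio

def get_embedding(text: str) -> list[float]:
    # Placeholder for a blocking embedding call; a real client would go here
    return [0.0] * 8

async def embed_async(text: str) -> list[float]:
    loop = asyncio.get_running_loop()
    # Offload the blocking helper to the default thread pool so other
    # coroutines (model calls, dataset loads) keep making progress
    return await loop.run_in_executor(None, get_embedding, text)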

Benefits:

  • Multiple datasets run in parallel

  • Model API calls don’t block each other

  • Significant speed-up for I/O-heavy workloads

Configuration

Per-Pipeline:

pipeline = VulnerabilityPipelineAsync(
    model_name="gpt-5.2",
    test_model="gpt-4o-mini",
    judge_model="gpt-5.2",
    max_depth=1000,
    num_simulations=100,
    num_samples=10
)

Global (config.py):

MODEL_RESPONSES_DEFAULT = "gpt-4o-mini"
MODEL_EMBEDDING_DEFAULT = "text-embedding-3-small"
MODEL_JUDGE = "gpt-4o-mini"

Best Practices

  1. Start Small: Use num_simulations=10 for quick tests

  2. Monitor Progress: Check error_log.txt for issues

  3. Validate First: Run python validate_config.py before long searches

  4. Use Checkpoints: For long searches, create checkpoints periodically

  5. Tune UCB: Adjust exploration_weight if the search is too random or too greedy
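
As a concrete starting point for practice 1, a quick smoke-test run might look like the following; it only reuses the constructor arguments documented above, with deliberately small budgets:

# Quick smoke test before committing to a long search
pipeline = VulnerabilityPipelineAsync(
    model_name="gpt-5.2",
    test_model="gpt-4o-mini",
    judge_model="gpt-5.2",
    max_depth=50,            # shallow tree for a trial run
    num_simulations=10,      # practice 1: start small
    num_samples=3
)
pipeline.add_datasets_batch(['mmlu'])
pipeline.run()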

See Also