MCTS Search Engine
The search engine is the core of ProbeLLM, implementing Monte Carlo Tree Search (MCTS) to intelligently explore the space of test cases.
Core Classes
VulnerabilityPipelineAsync
Purpose: Main orchestrator for vulnerability search
Key Methods:
from probellm import VulnerabilityPipelineAsync
pipeline = VulnerabilityPipelineAsync(
    model_name="gpt-5.2",        # Test generation model
    test_model="gpt-4o-mini",    # Model under test
    judge_model="gpt-5.2",       # Answer correctness judge
    max_depth=1000,              # Max tree depth
    num_simulations=100,         # MCTS iterations per dataset
    num_samples=10,              # Initial samples per dataset
    tool_registry=None           # Optional: custom tools
)
# Add datasets
pipeline.add_datasets_batch(['mmlu', 'hellaswag'])
# Run search (concurrent across datasets)
pipeline.run()
Concurrent Execution:
By default, all datasets run concurrently using asyncio:
# Internally:
async def _run_concurrent(self):
    tasks = [
        asyncio.create_task(self._mcts_search_async(root))
        for root in self.tree_roots
    ]
    await asyncio.gather(*tasks)
RootNode
Purpose: Represents a dataset root in the search tree
Attributes:
dataset_id: Dataset identifier
searchtype: "micro" or "macro"
sampler: HierarchicalSampler or SequentialSampler
children: List of SyntheticNode
visits, error_count: MCTS statistics
Methods:
next_sample(): Get the next sample from the sampler
has_more_samples(): Check whether more data is available
get_next_index(depth): Get the next node index at the given depth
SyntheticNode
Purpose: Represents a generated test case
Attributes:
sample: {"query": str, "ground_truth": str, ...}
datatype: "Original" (from dataset) or "Synthetic" (generated)
parent: Parent node
children: Child nodes
visits, error_count: MCTS statistics
results: List of test results
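The concrete class definitions live in the package; the following is only a rough illustrative sketch of the node structure, with field names taken from the attribute lists above and everything else assumed:
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class SyntheticNode:
    sample: dict                        # {"query": str, "ground_truth": str, ...}
    datatype: str = "Synthetic"         # "Original" (from dataset) or "Synthetic"
    parent: Optional["SyntheticNode"] = None
    children: list = field(default_factory=list)
    visits: int = 0                     # MCTS visit count
    error_count: int = 0                # how often this subtree exposed an error
    results: list = field(default_factory=list)

    def add_child(self, child: "SyntheticNode") -> None:
        # Link a newly generated test case into the tree
        child.parent = self
        self.children.append(child)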
MCTS Algorithm
Phase 1: Selection
Goal: Navigate tree to find node for expansion
Algorithm (UCB1):
def _select(self, root: RootNode) -> BaseNode:
    current = root
    while current.is_fully_expanded() and current.children:
        if current.is_terminal(self.max_depth):
            break
        current = current.best_child(exploration_weight=1.414)
    return current

def best_child(self, exploration_weight=1.414):
    # Handle unvisited children
    unvisited = [c for c in self.children if c.visits == 0]
    if unvisited:
        return random.choice(unvisited)

    # UCB1 formula
    def ucb1(child):
        exploitation = child.error_count / child.visits
        exploration = exploration_weight * sqrt(log(self.visits) / child.visits)
        return exploitation + exploration

    return max(self.children, key=ucb1)
Key Points:
Prioritizes unvisited nodes first
Balances exploitation (high error rate) vs exploration (low visit count)
exploration_weight=1.414 (default): sqrt(2), the standard MCTS exploration constant
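For intuition, here is a standalone UCB1 calculation with made-up statistics (not from a real run): two children share the same error rate, but the less-visited one wins on the exploration term.
from math import log, sqrt

# Hypothetical parent node with 20 visits and two children
parent_visits = 20
children = [
    {"name": "A", "visits": 15, "error_count": 3},  # well explored, 20% error rate
    {"name": "B", "visits": 5,  "error_count": 1},  # less explored, same error rate
]

def ucb1(child, exploration_weight=1.414):
    exploitation = child["error_count"] / child["visits"]
    exploration = exploration_weight * sqrt(log(parent_visits) / child["visits"])
    return exploitation + exploration

for c in children:
    print(c["name"], round(ucb1(c), 3))
# A: 0.2 + 1.414*sqrt(ln(20)/15) ≈ 0.83
# B: 0.2 + 1.414*sqrt(ln(20)/5)  ≈ 1.29  -> selected despite the equal error rate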
Phase 2: Expansion
Goal: Generate new test case using tools
Algorithm:
async def _expand_async(self, node: SyntheticNode, root: RootNode):
    # Micro: Perturb nearby
    if node.searchtype == "micro":
        new_question = self.test_case_gen.generate_nearby(
            node.sample['query'],
            node.sample['ground_truth']
        )
    # Macro: Explore far
    else:  # "macro"
        # Load embeddings of seen failures
        items = await self._load_pickle_async(f"embeddings_{node.dataset_id}_macro.pkl")
        # Greedy k-center: select diverse samples
        indices = self.greedy_k_center(embeddings=items, k=5)
        samples = [items[i] for i in indices]
        new_question = self.test_case_gen.generate_far(samples)

    # Generate answer
    new_answer = self.answer_gen.generate_answer(
        new_question['candidate']['question'],
        context=node.sample
    )

    # Create child node
    child = SyntheticNode(...)
    node.add_child(child)
    return child
Tool Selection:
generate_nearby() / generate_far() internally call ToolRegistry.call_tool()
An LLM planner chooses the tool based on question characteristics
See Tool System (MCP-Based) for details
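The macro branch above relies on self.greedy_k_center to pick diverse failure examples from stored embeddings. A standard greedy k-center (farthest-point) selection over an embedding matrix looks roughly like this; a sketch, not the package's exact implementation:
import numpy as np

def greedy_k_center(embeddings: np.ndarray, k: int) -> list[int]:
    """Pick k indices whose embeddings are maximally spread out (greedy 2-approximation)."""
    n = len(embeddings)
    if n == 0 or k <= 0:
        return []
    selected = [0]  # start from an arbitrary point
    dists = np.linalg.norm(embeddings - embeddings[0], axis=1)
    while len(selected) < min(k, n):
        idx = int(np.argmax(dists))  # farthest point from the current centers
        selected.append(idx)
        dists = np.minimum(dists, np.linalg.norm(embeddings - embeddings[idx], axis=1))
    return selected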
Phase 3: Simulation
Goal: Test model-under-test on generated case
Algorithm:
async def _simulate_async(self, node: SyntheticNode):
    query = node.sample['query']
    ground_truth = node.sample['ground_truth']

    # Inference
    model_response = await self.model_async(query)
    prediction = model_response["content"]

    # Judge
    correct, reason, judge_tokens = await self.judge_function_async(
        query, prediction, ground_truth
    )

    # Record result
    result = {
        "id": node.id,
        "query": query,
        "prediction": prediction,
        "error_detected": "False" if correct else "True",
        "error_reason": reason if not correct else "",
        "token_usage": {...}
    }
    node.results.append(result)

    return not correct  # True if an error was detected
Judge Model:
Uses a separate LLM to compare the prediction against the ground truth:
Strict on factual equivalence
Lenient on formatting/wording
Returns:
(is_correct: bool, reason: str, tokens: dict)
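A judge of this shape can be sketched as follows; the prompt wording, JSON parsing, and function name are illustrative assumptions, and client is assumed to be an openai.AsyncOpenAI instance:
import json

JUDGE_PROMPT = """Compare the prediction to the ground truth.
Treat the prediction as correct if it is factually equivalent, even if
formatting or wording differs.
Question: {query}
Prediction: {prediction}
Ground truth: {ground_truth}
Respond with JSON: {{"is_correct": true/false, "reason": "..."}}"""

async def judge_answer_async(client, query, prediction, ground_truth, model="gpt-4o-mini"):
    # Ask the judge model for a structured verdict
    resp = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": JUDGE_PROMPT.format(
            query=query, prediction=prediction, ground_truth=ground_truth)}],
        response_format={"type": "json_object"},
    )
    verdict = json.loads(resp.choices[0].message.content)
    tokens = resp.usage.model_dump() if resp.usage else {}
    return verdict["is_correct"], verdict.get("reason", ""), tokens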
Phase 4: Backpropagation
Goal: Update statistics from leaf to root
Algorithm:
def _backpropagate(self, node: BaseNode, error_detected: bool):
    current = node
    while current is not None:
        current.visits += 1
        if error_detected:
            current.error_count += 1
        current = current.parent
Effect:
Future UCB1 calculations favor branches with high error rates
Gradually shifts search toward promising regions
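Tying the four phases together, a single MCTS iteration can be sketched as below. This is a simplification built from the methods shown above; the real loop also handles sampling original dataset items, logging, and persistence:
async def mcts_iteration(self, root):
    # 1. Selection: walk down via UCB1 to a promising node
    node = self._select(root)
    # 2. Expansion: generate a new synthetic test case as a child
    if not node.is_terminal(self.max_depth):
        node = await self._expand_async(node, root)
    # 3. Simulation: run the model-under-test and judge its answer
    error_detected = await self._simulate_async(node)
    # 4. Backpropagation: update visit/error statistics up to the root
    self._backpropagate(node, error_detected)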
Dual-Strategy Search
Root Selection (before each iteration):
# UCB1 for choosing micro vs macro
total_visits = micro_root.visits + macro_root.visits
micro_ucb = (micro_root.error_count / micro_root.visits
             + 2 * sqrt(log(total_visits) / micro_root.visits))
macro_ucb = ...  # Same formula for macro_root
selected_root = micro_root if micro_ucb >= macro_ucb else macro_root
Result: Search automatically balances local exploitation (micro) and global exploration (macro).
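A minimal standalone sketch of that root selection, with a guard for unvisited roots added as an assumption:
from math import log, sqrt

def select_root(micro_root, macro_root, exploration_weight=2.0):
    """UCB1 over the two strategy roots; unvisited roots are tried first."""
    for root in (micro_root, macro_root):
        if root.visits == 0:
            return root
    total_visits = micro_root.visits + macro_root.visits

    def ucb(root):
        return (root.error_count / root.visits
                + exploration_weight * sqrt(log(total_visits) / root.visits))

    return micro_root if ucb(micro_root) >= ucb(macro_root) else macro_root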
Checkpointing
Saving:
from probellm import create_checkpoints
create_checkpoints("results/run_xxx/")
Or via CLI:
python -m probellm.checkpoint results/run_xxx/
Format:
{
  "metadata": {
    "dataset_id": "mmlu",
    "last_simulation": 42,
    "timestamp": "2026-01-27T12:00:00"
  },
  "root_state": { ... },
  "nodes": [...]
}
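Because checkpoints are plain JSON, they can be inspected directly. A quick look at the metadata (the filename below is hypothetical; actual names are produced by create_checkpoints):
import json

# Hypothetical checkpoint path inside a run directory
with open("results/run_xxx/checkpoint_mmlu.json") as f:
    ckpt = json.load(f)

meta = ckpt["metadata"]
print(f"{meta['dataset_id']}: saved after simulation {meta['last_simulation']}")
print(f"{len(ckpt['nodes'])} nodes stored")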
Resuming:
from probellm import resume_from_checkpoint
resume_from_checkpoint("results/run_xxx/")
Or via CLI:
python -m probellm.resume results/run_xxx/
See CLI Reference for details.
Async Execution
All I/O-bound operations use asyncio for concurrency:
Model calls: await self.async_client.chat.completions.create(...)
Embeddings: await loop.run_in_executor(None, get_embedding, ...)
Dataset loading: await loop.run_in_executor(None, interface.load_dataset, ...)
Benefits:
Multiple datasets run in parallel
Model API calls don’t block each other
Significant speed-up for I/O-heavy workloads
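The general pattern, offloading blocking helpers to a thread pool while awaiting them concurrently, can be illustrated with a self-contained sketch (the loader here is a stand-in, not ProbeLLM code):
import asyncio

def load_dataset_blocking(name: str) -> list:
    # Placeholder for a synchronous, I/O-heavy loader
    return [f"{name}-sample-{i}" for i in range(3)]

async def main():
    loop = asyncio.get_running_loop()
    # Run both blocking loaders concurrently in the default thread pool
    datasets = await asyncio.gather(
        loop.run_in_executor(None, load_dataset_blocking, "mmlu"),
        loop.run_in_executor(None, load_dataset_blocking, "hellaswag"),
    )
    print(datasets)

asyncio.run(main())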
Configuration
Per-Pipeline:
pipeline = VulnerabilityPipelineAsync(
    model_name="gpt-5.2",
    test_model="gpt-4o-mini",
    judge_model="gpt-5.2",
    max_depth=1000,
    num_simulations=100,
    num_samples=10
)
Global (config.py):
MODEL_RESPONSES_DEFAULT = "gpt-4o-mini"
MODEL_EMBEDDING_DEFAULT = "text-embedding-3-small"
MODEL_JUDGE = "gpt-4o-mini"
Best Practices
Start Small: Use num_simulations=10 for quick tests
Monitor Progress: Check error_log.txt for issues
Validate First: Run python validate_config.py before long searches
Use Checkpoints: For long searches, create checkpoints periodically
Tune UCB: Adjust exploration_weight if the search is too random or too greedy
See Also
Core Concepts: Detailed algorithm explanation
Quickstart: Hands-on tutorial
API reference: Full API reference