MCTS Search Engine
==================

The search engine is the core of ProbeLLM, implementing Monte Carlo Tree Search (MCTS) to explore the space of test cases intelligently.

Core Classes
------------

VulnerabilityPipelineAsync
^^^^^^^^^^^^^^^^^^^^^^^^^^

**Purpose**: Main orchestrator for the vulnerability search.

**Key Methods**:

.. code-block:: python

   from probellm import VulnerabilityPipelineAsync

   pipeline = VulnerabilityPipelineAsync(
       model_name="gpt-5.2",        # Test generation model
       test_model="gpt-4o-mini",    # Model under test
       judge_model="gpt-5.2",       # Answer correctness judge
       max_depth=1000,              # Max tree depth
       num_simulations=100,         # MCTS iterations per dataset
       num_samples=10,              # Initial samples per dataset
       tool_registry=None           # Optional: custom tools
   )

   # Add datasets
   pipeline.add_datasets_batch(['mmlu', 'hellaswag'])

   # Run search (concurrent across datasets)
   pipeline.run()

**Concurrent Execution**: By default, all datasets run **concurrently** using ``asyncio``:

.. code-block:: python

   # Internally:
   async def _run_concurrent(self):
       tasks = [
           asyncio.create_task(self._mcts_search_async(root))
           for root in self.tree_roots
       ]
       await asyncio.gather(*tasks)

RootNode
^^^^^^^^

**Purpose**: Represents a dataset root in the search tree.

**Attributes**:

- ``dataset_id``: Dataset identifier
- ``searchtype``: ``"micro"`` or ``"macro"``
- ``sampler``: ``HierarchicalSampler`` or ``SequentialSampler``
- ``children``: List of ``SyntheticNode``
- ``visits``, ``error_count``: MCTS statistics

**Methods**:

- ``next_sample()``: Get the next sample from the sampler
- ``has_more_samples()``: Check whether more data is available
- ``get_next_index(depth)``: Get the next node index at the given depth

SyntheticNode
^^^^^^^^^^^^^

**Purpose**: Represents a generated test case.

**Attributes**:

- ``sample``: ``{"query": str, "ground_truth": str, ...}``
- ``datatype``: ``"Original"`` (from the dataset) or ``"Synthetic"`` (generated)
- ``parent``: Parent node
- ``children``: Child nodes
- ``visits``, ``error_count``: MCTS statistics
- ``results``: List of test results
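
The attribute lists above map naturally onto small node classes. The following sketch is illustrative only, assuming a simplified constructor and omitting the dataset, sampler, and search-type state; the actual ``RootNode``/``SyntheticNode`` definitions live in the ProbeLLM source.

.. code-block:: python

   # Illustrative sketch of the tree node structure (not the actual ProbeLLM classes).
   # Field names follow the attribute lists above; everything else is assumed.
   from __future__ import annotations

   from dataclasses import dataclass, field
   from typing import Any, Optional


   @dataclass
   class SketchNode:
       sample: dict[str, Any]                     # {"query": ..., "ground_truth": ..., ...}
       datatype: str = "Synthetic"                # "Original" or "Synthetic"
       parent: Optional[SketchNode] = None
       children: list[SketchNode] = field(default_factory=list)
       visits: int = 0                            # MCTS statistics
       error_count: int = 0
       results: list[dict] = field(default_factory=list)

       def add_child(self, child: SketchNode) -> None:
           child.parent = self
           self.children.append(child)

       @property
       def error_rate(self) -> float:
           # Fraction of visits that detected an error (0.0 if never visited)
           return self.error_count / self.visits if self.visits else 0.0

The real classes additionally carry the dataset and sampler state listed under ``RootNode`` above.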

MCTS Algorithm
--------------

Phase 1: Selection
^^^^^^^^^^^^^^^^^^

**Goal**: Navigate the tree to find a node to expand.

**Algorithm** (UCB1):

.. code-block:: python

   def _select(self, root: RootNode) -> BaseNode:
       current = root
       while current.is_fully_expanded() and current.children:
           if current.is_terminal(self.max_depth):
               break
           current = current.best_child(exploration_weight=1.414)
       return current

   def best_child(self, exploration_weight=1.414):
       # Handle unvisited children first
       unvisited = [c for c in self.children if c.visits == 0]
       if unvisited:
           return random.choice(unvisited)

       # UCB1 formula
       def ucb1(child):
           exploitation = child.error_count / child.visits
           exploration = exploration_weight * sqrt(log(self.visits) / child.visits)
           return exploitation + exploration

       return max(self.children, key=ucb1)

**Key Points**:

- Prioritizes **unvisited** nodes first
- Balances **exploitation** (high error rate) against **exploration** (low visit count)
- ``exploration_weight=1.414`` (default): sqrt(2), the standard MCTS parameter

Phase 2: Expansion
^^^^^^^^^^^^^^^^^^

**Goal**: Generate a new test case using tools.

**Algorithm**:

.. code-block:: python

   async def _expand_async(self, node: SyntheticNode, root: RootNode):
       # Micro: perturb nearby
       if node.searchtype == "micro":
           new_question = self.test_case_gen.generate_nearby(
               node.sample['query'],
               node.sample['ground_truth']
           )
       # Macro: explore far
       else:  # "macro"
           # Load embeddings of previously seen failures
           items = await self._load_pickle_async(
               f"embeddings_{node.dataset_id}_macro.pkl"
           )

           # Greedy k-center: select diverse samples
           indices = self.greedy_k_center(embeddings=items, k=5)
           samples = [items[i] for i in indices]

           new_question = self.test_case_gen.generate_far(samples)

       # Generate an answer for the new question
       new_answer = self.answer_gen.generate_answer(
           new_question['candidate']['question'],
           context=node.sample
       )

       # Create a child node
       child = SyntheticNode(...)
       node.add_child(child)
       return child

**Tool Selection**:

- ``generate_nearby()`` / ``generate_far()`` internally call ``ToolRegistry.call_tool()``
- An LLM planner chooses the tool based on question characteristics
- See :doc:`tools` for details

Phase 3: Simulation
^^^^^^^^^^^^^^^^^^^

**Goal**: Test the model under test on the generated case.

**Algorithm**:

.. code-block:: python

   async def _simulate_async(self, node: SyntheticNode):
       query = node.sample['query']
       ground_truth = node.sample['ground_truth']

       # Inference
       model_response = await self.model_async(query)
       prediction = model_response["content"]

       # Judge
       correct, reason, judge_tokens = await self.judge_function_async(
           query, prediction, ground_truth
       )

       # Record result
       result = {
           "id": node.id,
           "query": query,
           "prediction": prediction,
           "error_detected": "False" if correct else "True",
           "error_reason": reason if not correct else "",
           "token_usage": {...}
       }
       node.results.append(result)

       return not correct  # True if an error was detected

**Judge Model**: A separate LLM compares the prediction against the ground truth:

- Strict on **factual equivalence**
- Lenient on **formatting/wording**
- Returns ``(is_correct: bool, reason: str, tokens: dict)``
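
The judge interface above can be implemented with a single LLM call. The sketch below is a minimal example of what such a judge might look like; it assumes the OpenAI async client, a made-up ``JUDGE_PROMPT`` template, and naive verdict parsing. The actual ProbeLLM judge prompt and parsing logic may differ.

.. code-block:: python

   # Minimal sketch of an async LLM judge (illustrative; not ProbeLLM's actual judge).
   # Assumes the OpenAI Python client; prompt wording and parsing are placeholders.
   from openai import AsyncOpenAI

   client = AsyncOpenAI()

   JUDGE_PROMPT = (
       "Question: {query}\n"
       "Ground truth: {ground_truth}\n"
       "Prediction: {prediction}\n\n"
       "Reply CORRECT if the prediction is factually equivalent to the ground truth "
       "(ignore formatting and wording), otherwise reply INCORRECT, then give a one-line reason."
   )

   async def judge_async(query: str, prediction: str, ground_truth: str):
       response = await client.chat.completions.create(
           model="gpt-4o-mini",  # the judge_model from the pipeline configuration
           messages=[{
               "role": "user",
               "content": JUDGE_PROMPT.format(
                   query=query, ground_truth=ground_truth, prediction=prediction
               ),
           }],
       )
       verdict = response.choices[0].message.content or ""
       is_correct = verdict.strip().upper().startswith("CORRECT")
       tokens = {
           "prompt_tokens": response.usage.prompt_tokens,
           "completion_tokens": response.usage.completion_tokens,
       }
       # The full verdict text doubles as the "reason" in this sketch
       return is_correct, verdict, tokens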

Phase 4: Backpropagation
^^^^^^^^^^^^^^^^^^^^^^^^

**Goal**: Update statistics from the leaf back to the root.

**Algorithm**:

.. code-block:: python

   def _backpropagate(self, node: BaseNode, error_detected: bool):
       current = node
       while current is not None:
           current.visits += 1
           if error_detected:
               current.error_count += 1
           current = current.parent

**Effect**:

- Future UCB1 calculations favor branches with high error rates
- The search gradually shifts toward promising regions

Dual-Strategy Search
--------------------

**Root Selection** (before each iteration):

.. code-block:: python

   # UCB1 for choosing micro vs. macro
   total_visits = micro_root.visits + macro_root.visits

   micro_ucb = (micro_root.error_count / micro_root.visits
                + 2 * sqrt(log(total_visits) / micro_root.visits))
   macro_ucb = ...  # Same formula

   selected_root = micro_root if micro_ucb >= macro_ucb else macro_root

**Result**: The search automatically balances local exploitation (micro) and global exploration (macro).

Checkpointing
-------------

**Saving**:

.. code-block:: python

   from probellm import create_checkpoints

   create_checkpoints("results/run_xxx/")

Or via the CLI:

.. code-block:: bash

   python -m probellm.checkpoint results/run_xxx/

**Format**:

.. code-block:: text

   {
     "metadata": {
       "dataset_id": "mmlu",
       "last_simulation": 42,
       "timestamp": "2026-01-27T12:00:00"
     },
     "root_state": { ... },
     "nodes": [...]
   }

**Resuming**:

.. code-block:: python

   from probellm import resume_from_checkpoint

   resume_from_checkpoint("results/run_xxx/")

Or via the CLI:

.. code-block:: bash

   python -m probellm.resume results/run_xxx/

See :doc:`../cli` for details.

Async Execution
---------------

All I/O-bound operations use ``asyncio`` for concurrency:

- **Model calls**: ``await self.async_client.chat.completions.create(...)``
- **Embeddings**: ``await loop.run_in_executor(None, get_embedding, ...)``
- **Dataset loading**: ``await loop.run_in_executor(None, interface.load_dataset, ...)``

**Benefits**:

- Multiple datasets run in parallel
- Model API calls don't block each other
- Significant speed-up for I/O-heavy workloads

Configuration
-------------

**Per-Pipeline**:

.. code-block:: python

   pipeline = VulnerabilityPipelineAsync(
       model_name="gpt-5.2",
       test_model="gpt-4o-mini",
       judge_model="gpt-5.2",
       max_depth=1000,
       num_simulations=100,
       num_samples=10
   )

**Global** (``config.py``):

.. code-block:: python

   MODEL_RESPONSES_DEFAULT = "gpt-4o-mini"
   MODEL_EMBEDDING_DEFAULT = "text-embedding-3-small"
   MODEL_JUDGE = "gpt-4o-mini"

Best Practices
--------------

1. **Start Small**: Use ``num_simulations=10`` for quick tests
2. **Monitor Progress**: Check ``error_log.txt`` for issues
3. **Validate First**: Run ``python validate_config.py`` before long searches
4. **Use Checkpoints**: For long searches, create checkpoints periodically
5. **Tune UCB**: Adjust ``exploration_weight`` if the search is too random or too greedy (see the worked sketch at the end of this page)

See Also
--------

- :doc:`../concepts`: Detailed algorithm explanation
- :doc:`../quickstart`: Hands-on tutorial
- :doc:`../api`: Full API reference
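
To make the UCB-tuning advice concrete, the snippet below (illustrative only, with invented visit and error counts) evaluates the UCB1 formula from the Selection phase for two children under different ``exploration_weight`` values, showing how a lower weight favors the frequently failing child and a higher weight favors the rarely visited one.

.. code-block:: python

   # Illustrative only: how exploration_weight shifts UCB1 preferences.
   # The visit/error numbers are invented for this example.
   from math import log, sqrt

   parent_visits = 20
   children = {
       "high_error": {"visits": 15, "errors": 9},   # error rate 0.60, well explored
       "rarely_seen": {"visits": 2, "errors": 0},   # error rate 0.00, barely explored
   }

   for weight in (0.5, 1.414, 3.0):
       scores = {
           name: c["errors"] / c["visits"]
           + weight * sqrt(log(parent_visits) / c["visits"])
           for name, c in children.items()
       }
       best = max(scores, key=scores.get)
       print(f"weight={weight}: {scores} -> pick {best}")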