Based on claude-code-tools TmuxCLIController, this refactor: - Added DockerTmuxController class for robust tmux session management - Implements send_keys() with configurable delay_enter - Implements capture_pane() for output retrieval - Implements wait_for_prompt() for pattern-based completion detection - Implements wait_for_idle() for content-hash-based idle detection - Implements wait_for_shell_prompt() for shell prompt detection Also includes workflow improvements: - Pre-task git snapshot before agent execution - Post-task commit protocol in agent guidelines Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
593 lines
20 KiB
Python
593 lines
20 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Smart Router - Intelligent task routing using Gemini 3 Flash for decision making.
|
|
|
|
Key decision points:
|
|
1. Task Complexity Analysis - Before dispatch, assess complexity
|
|
2. Agent Selection - Route to optimal agent/model based on task
|
|
3. Response Validation - Check output quality before returning
|
|
4. Continuation Decisions - Determine if follow-up is needed
|
|
|
|
Uses Gemini 3 Flash for fast, cost-effective decisions at critical flow points.
|
|
"""
|
|
|
|
import json
import logging
import os
import time
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
|
|
# Try to import google.generativeai
|
|
try:
|
|
import google.generativeai as genai
|
|
GEMINI_AVAILABLE = True
|
|
except ImportError:
|
|
GEMINI_AVAILABLE = False
|
|
genai = None
|
|
|
|
# Module logger plus a root-logger configuration applied at import time.
# NOTE(review): basicConfig here affects every program that imports this
# module; consider moving it under the CLI entry point.
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
)
|
|
|
|
|
|
class TaskComplexity(Enum):
    """Task complexity levels used when making routing decisions.

    Members (value in parentheses):

    * ``TRIVIAL`` ("trivial")   -- simple command or quick lookup
    * ``SIMPLE`` ("simple")     -- single-step task with a clear path
    * ``MODERATE`` ("moderate") -- multi-step work needing some reasoning
    * ``COMPLEX`` ("complex")   -- deep analysis or multi-agent coordination
    * ``RESEARCH`` ("research") -- open-ended exploration
    """

    TRIVIAL = "trivial"
    SIMPLE = "simple"
    MODERATE = "moderate"
    COMPLEX = "complex"
    RESEARCH = "research"
|
|
|
|
|
|
class AgentTier(Enum):
    """Agent tiers available for model selection.

    Members (value in parentheses):

    * ``FLASH`` ("flash")   -- Gemini Flash, fast decisions
    * ``HAIKU`` ("haiku")   -- Claude Haiku, quick tasks
    * ``SONNET`` ("sonnet") -- Claude Sonnet, balanced
    * ``OPUS`` ("opus")     -- Claude Opus, complex tasks
    * ``PRO`` ("pro")       -- Gemini Pro, deep reasoning
    """

    FLASH = "flash"
    HAIKU = "haiku"
    SONNET = "sonnet"
    OPUS = "opus"
    PRO = "pro"
|
|
|
|
|
|
@dataclass
class RoutingDecision:
    """Outcome of analysing and routing a single task.

    Built by ``SmartRouter.analyze_and_route``.
    """

    # Assessed difficulty of the task.
    complexity: TaskComplexity
    # Model tier chosen to handle the task.
    recommended_agent: AgentTier
    # Human-readable justification for the choice.
    reasoning: str
    # Confidence in the complexity assessment (expected in 0.0-1.0).
    confidence: float
    # Optional step-by-step plan for the task.
    suggested_steps: List[str]
    # Rough token budget for the task.
    estimated_tokens: int
    # True when the task should get human sign-off before running.
    requires_human: bool = False
    # Whether the agent's output should be validated afterwards.
    validation_needed: bool = True
|
|
|
|
|
|
@dataclass
class ValidationResult:
    """Assessment of an agent's output for a given task."""

    # Whether the output is acceptable overall.
    is_valid: bool
    # Quality estimate in the range 0-1.
    quality_score: float
    # Problems found in the output.
    issues: List[str]
    # Suggested improvements.
    suggestions: List[str]
    # True when the task should be attempted again.
    needs_retry: bool = False
    # Optional follow-up prompt to continue or repair the work.
    continuation_prompt: Optional[str] = None
|
|
|
|
|
|
class GeminiDecisionEngine:
    """Gemini Flash-powered decision engine for fast routing decisions.

    Each public method (``analyze_complexity``, ``validate_output``,
    ``route_task``) falls back to a keyword heuristic in three cases: when the
    ``google-generativeai`` package or an API key is unavailable, when the
    Gemini call fails, or when its response cannot be parsed as a JSON object.
    """

    # Files searched (in order) for an API key when none is passed in and the
    # GEMINI_API_KEY environment variable is unset. ``.env`` files are parsed
    # for a GEMINI_API_KEY assignment; any other file is read as the bare key.
    API_KEY_SOURCES = (
        "/opt/pal-mcp-server/.env",  # PAL MCP server env (primary)
        "/etc/shared-ai-credentials/gemini/api-key",  # Shared credentials
    )

    # Model used when no ``model_name`` is given.
    DEFAULT_MODEL = "gemini-2.0-flash"

    # Maximum number of output characters sent to Gemini for validation
    # (keeps prompts small).
    _MAX_VALIDATION_CHARS = 3000

    # Heuristic keyword table: (keywords, complexity, confidence).
    # Rows are checked top to bottom and the first match wins, so order matters.
    _COMPLEXITY_KEYWORDS = (
        (("list", "show", "check", "status", "what is"), "trivial", 0.7),
        (("fix", "update", "change", "add"), "simple", 0.6),
        (("implement", "create", "build", "develop"), "moderate", 0.5),
        (("refactor", "optimize", "debug", "investigate"), "complex", 0.5),
        (("research", "analyze", "design", "architect"), "research", 0.5),
    )

    # Heuristic complexity -> agent mapping; unknown labels route to sonnet.
    _COMPLEXITY_TO_AGENT = {
        "trivial": "haiku",
        "simple": "haiku",
        "moderate": "sonnet",
        "complex": "sonnet",
        "research": "pro",
    }

    def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None):
        """Initialize Gemini decision engine.

        Args:
            api_key: Gemini API key (defaults to the GEMINI_API_KEY env var,
                then to the files in ``API_KEY_SOURCES``)
            model_name: Gemini model to use (defaults to ``DEFAULT_MODEL``)
        """
        self.api_key = api_key or os.environ.get("GEMINI_API_KEY")
        self.model_name = model_name or self.DEFAULT_MODEL
        self.model = None
        self.available = False
        self._initialize()

    def _initialize(self) -> None:
        """Configure the Gemini client; ``available`` stays False on any failure."""
        if not GEMINI_AVAILABLE:
            logger.warning("google-generativeai not installed - falling back to heuristics")
            return

        if not self.api_key:
            self.api_key = self._load_api_key()

        if not self.api_key:
            logger.warning("Gemini API key not found - falling back to heuristics")
            return

        try:
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel(self.model_name)
            self.available = True
            logger.info(f"Gemini decision engine initialized ({self.model_name})")
        except Exception as e:
            logger.warning(f"Failed to initialize Gemini: {e}")

    def _load_api_key(self) -> Optional[str]:
        """Return the first non-empty API key found in ``API_KEY_SOURCES``, else None."""
        for source in self.API_KEY_SOURCES:
            try:
                with open(source, "r", encoding="utf-8") as f:
                    if source.endswith(".env"):
                        key = self._parse_env_key(f)
                    else:
                        # Plain text file containing only the key
                        key = f.read().strip()
            except (OSError, UnicodeDecodeError):
                # Missing, unreadable, directory, or undecodable: try the next source.
                continue
            if key:
                logger.info(f"Gemini API key loaded from {source}")
                return key
        return None

    @staticmethod
    def _parse_env_key(lines: Iterable[str]) -> Optional[str]:
        """Extract GEMINI_API_KEY from ``.env``-style lines.

        Accepts an optional ``export`` prefix and whitespace around ``=``.
        Surrounding single/double quotes are stripped from the value. Returns
        the first match (possibly an empty string), or None if absent.
        """
        for raw in lines:
            line = raw.strip()
            if line.startswith("export "):
                line = line[len("export "):].lstrip()
            name, sep, value = line.partition("=")
            if sep and name.strip() == "GEMINI_API_KEY":
                return value.strip().strip('"\'')
        return None

    @staticmethod
    def _extract_json(text: str) -> Dict[str, Any]:
        """Parse a JSON object out of a model response.

        Handles bare JSON, markdown code fences (with an optional ``json``
        tag), and JSON embedded in surrounding prose.

        Raises:
            ValueError: if no JSON object can be parsed (``json.JSONDecodeError``
                is a subclass), or the parsed value is not an object.
        """
        text = text.strip()
        if text.startswith("```"):
            # Keep only the fenced body, dropping an optional language tag.
            text = text.split("```")[1]
            if text[:4].lower() == "json":
                text = text[4:]
        text = text.strip()

        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # Fall back to the outermost {...} span (e.g. prose around the JSON).
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                raise
            parsed = json.loads(text[start:end + 1])

        if not isinstance(parsed, dict):
            raise ValueError(f"Expected a JSON object, got {type(parsed).__name__}")
        return parsed

    def _generate_json(self, prompt: str, max_output_tokens: int) -> Dict[str, Any]:
        """Send ``prompt`` to Gemini at low temperature and parse the JSON reply.

        Any SDK error or parse failure propagates; callers fall back to heuristics.
        """
        response = self.model.generate_content(
            prompt,
            generation_config=genai.GenerationConfig(
                temperature=0.1,
                max_output_tokens=max_output_tokens,
            ),
        )
        return self._extract_json(response.text)

    def analyze_complexity(self, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
        """Analyze task complexity using Gemini Flash.

        Args:
            task: Task description
            context: Optional context about project, history

        Returns:
            Complexity analysis result (heuristic result if Gemini is unusable)
        """
        if not self.available:
            return self._heuristic_complexity(task)

        # default=str: context is free-form, so values that are not JSON-serializable
        # are stringified for the prompt instead of raising before the fallback path.
        context_json = json.dumps(context or {}, indent=2, default=str)

        prompt = f"""Analyze this task's complexity for routing to an AI agent.

TASK: {task}

CONTEXT: {context_json}

Respond in JSON:
{{
    "complexity": "trivial|simple|moderate|complex|research",
    "reasoning": "brief explanation",
    "confidence": 0.0-1.0,
    "estimated_steps": ["step1", "step2"],
    "requires_code_changes": true/false,
    "requires_file_reads": true/false,
    "requires_external_calls": true/false,
    "risk_level": "low|medium|high"
}}"""

        try:
            return self._generate_json(prompt, 500)
        except Exception as e:
            logger.warning(f"Gemini complexity analysis failed: {e}")
            return self._heuristic_complexity(task)

    def _heuristic_complexity(self, task: str) -> Dict[str, Any]:
        """Fallback keyword-based complexity analysis (first matching row wins)."""
        task_lower = task.lower()

        complexity, confidence = "moderate", 0.4  # no keyword matched
        for keywords, level, score in self._COMPLEXITY_KEYWORDS:
            if any(word in task_lower for word in keywords):
                complexity, confidence = level, score
                break

        return {
            "complexity": complexity,
            "reasoning": "Heuristic analysis (Gemini unavailable)",
            "confidence": confidence,
            "estimated_steps": [],
            "requires_code_changes": "implement" in task_lower or "fix" in task_lower,
            "requires_file_reads": True,
            "requires_external_calls": False,
            "risk_level": "medium"
        }

    def validate_output(self, task: str, output: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
        """Validate agent output quality.

        Args:
            task: Original task
            output: Agent's output (only the first 3000 chars are sent to Gemini)
            context: Additional context, included in the prompt when provided

        Returns:
            Validation result (heuristic result if Gemini is unusable)
        """
        if not self.available:
            return self._heuristic_validation(task, output)

        # Truncate output for validation (avoid huge prompts)
        output_truncated = output[:self._MAX_VALIDATION_CHARS]

        # Empty when no context is given, which keeps the prompt unchanged.
        context_section = ""
        if context:
            context_section = f"\nCONTEXT: {json.dumps(context, indent=2, default=str)}\n"

        prompt = f"""Validate this AI agent's response to a task.

TASK: {task}
{context_section}
RESPONSE (may be truncated):
{output_truncated}

Respond in JSON:
{{
    "is_valid": true/false,
    "quality_score": 0.0-1.0,
    "issues": ["issue1", "issue2"],
    "suggestions": ["suggestion1"],
    "task_completed": true/false,
    "needs_follow_up": true/false,
    "follow_up_prompt": "optional continuation prompt"
}}"""

        try:
            return self._generate_json(prompt, 500)
        except Exception as e:
            logger.warning(f"Gemini validation failed: {e}")
            return self._heuristic_validation(task, output)

    def _heuristic_validation(self, task: str, output: str) -> Dict[str, Any]:
        """Fallback heuristic output validation based on length, code and error markers."""
        output_lower = output.lower()
        task_lower = task.lower()

        has_content = len(output.strip()) > 50
        has_code = "```" in output or "def " in output or "function " in output
        has_error = "error" in output_lower or "failed" in output_lower

        quality = 0.5
        if has_content:
            quality += 0.2
        if has_code and ("implement" in task_lower or "code" in task_lower):
            quality += 0.2
        if has_error:
            quality -= 0.3

        return {
            "is_valid": has_content and not has_error,
            "quality_score": max(0.0, min(1.0, quality)),
            "issues": ["Error detected in output"] if has_error else [],
            "suggestions": [],
            "task_completed": has_content,
            "needs_follow_up": has_error,
            "follow_up_prompt": "Please fix the errors and try again" if has_error else None
        }

    def route_task(self, task: str, project: str, complexity: str) -> Dict[str, Any]:
        """Determine optimal agent/model for task.

        Args:
            task: Task description
            project: Target project
            complexity: Pre-analyzed complexity

        Returns:
            Routing recommendation (heuristic result if Gemini is unusable)
        """
        if not self.available:
            return self._heuristic_routing(task, project, complexity)

        prompt = f"""Recommend the best AI agent configuration for this task.

TASK: {task}
PROJECT: {project}
COMPLEXITY: {complexity}

Available agents:
- flash: Gemini Flash - Fast, cheap, good for simple tasks
- haiku: Claude Haiku - Quick, efficient, good for straightforward coding
- sonnet: Claude Sonnet - Balanced, good for most development tasks
- opus: Claude Opus - Most capable, for complex analysis
- pro: Gemini Pro - Deep reasoning, research tasks

Respond in JSON:
{{
    "recommended_agent": "flash|haiku|sonnet|opus|pro",
    "reasoning": "why this agent",
    "backup_agent": "alternative if first fails",
    "special_instructions": "any task-specific guidance",
    "estimated_time": "quick|moderate|long",
    "suggested_tools": ["Read", "Edit", "Bash"]
}}"""

        try:
            return self._generate_json(prompt, 400)
        except Exception as e:
            logger.warning(f"Gemini routing failed: {e}")
            return self._heuristic_routing(task, project, complexity)

    def _heuristic_routing(self, task: str, project: str, complexity: str) -> Dict[str, Any]:
        """Fallback routing: map the complexity label to an agent (default sonnet)."""
        return {
            "recommended_agent": self._COMPLEXITY_TO_AGENT.get(complexity, "sonnet"),
            "reasoning": f"Heuristic routing for {complexity} task",
            "backup_agent": "sonnet",
            "special_instructions": None,
            "estimated_time": "moderate",
            "suggested_tools": ["Read", "Edit", "Bash", "Glob", "Grep"]
        }
|
|
|
|
|
|
class SmartRouter:
    """Main smart routing orchestrator integrating Gemini decisions.

    Wraps a GeminiDecisionEngine and converts its loosely-typed dict results
    (which may come straight from an LLM) into RoutingDecision /
    ValidationResult objects. It also keeps a bounded in-memory history of
    routing decisions for ``get_stats``.
    """

    # Rough token budget per complexity level.
    _TOKEN_ESTIMATES = {
        TaskComplexity.TRIVIAL: 500,
        TaskComplexity.SIMPLE: 2000,
        TaskComplexity.MODERATE: 8000,
        TaskComplexity.COMPLEX: 20000,
        TaskComplexity.RESEARCH: 50000,
    }

    # Substrings (matched case-insensitively in task or error) that force escalation.
    _ESCALATE_PATTERNS = (
        "permission denied",
        "authentication",
        "security",
        "production",
        "database migration",
        "delete",
        "remove",
    )

    def __init__(self, api_key: Optional[str] = None):
        """Initialize smart router.

        Args:
            api_key: Optional Gemini API key
        """
        self.decision_engine = GeminiDecisionEngine(api_key)
        self.routing_history: List[Dict[str, Any]] = []
        self.max_history = 100
        logger.info(f"SmartRouter initialized (Gemini: {self.decision_engine.available})")

    @staticmethod
    def _coerce_complexity(value: Any) -> TaskComplexity:
        """Map a possibly malformed label (e.g. "Moderate ") to TaskComplexity; default MODERATE."""
        try:
            return TaskComplexity(str(value).strip().lower())
        except ValueError:
            return TaskComplexity.MODERATE

    @staticmethod
    def _coerce_agent(value: Any) -> AgentTier:
        """Map a possibly malformed agent label to AgentTier; default SONNET."""
        try:
            return AgentTier(str(value).strip().lower())
        except ValueError:
            return AgentTier.SONNET

    @staticmethod
    def _coerce_score(value: Any, default: float) -> float:
        """Return ``value`` as a float clamped to [0, 1], or ``default`` if not numeric."""
        try:
            score = float(value)
        except (TypeError, ValueError):
            return default
        return max(0.0, min(1.0, score))

    def analyze_and_route(self, task: str, project: str,
                          context: Dict[str, Any] = None) -> RoutingDecision:
        """Full analysis and routing for a task.

        Args:
            task: Task description
            project: Target project
            context: Additional context

        Returns:
            Complete routing decision. Unrecognized labels from the engine fall
            back to MODERATE complexity / SONNET agent instead of raising.
        """
        start_time = time.time()

        # Step 1: Analyze complexity
        complexity_result = self.decision_engine.analyze_complexity(task, context)
        if not isinstance(complexity_result, dict):
            complexity_result = {}
        complexity = self._coerce_complexity(complexity_result.get("complexity", "moderate"))

        # Step 2: Get routing recommendation (using the normalized label)
        routing_result = self.decision_engine.route_task(task, project, complexity.value)
        if not isinstance(routing_result, dict):
            routing_result = {}
        recommended_agent = self._coerce_agent(routing_result.get("recommended_agent", "sonnet"))

        # Step 3: Build decision
        steps = complexity_result.get("estimated_steps")
        if isinstance(steps, str):
            steps = [steps]
        elif not isinstance(steps, list):
            steps = []
        risk_level = str(complexity_result.get("risk_level", "low")).strip().lower()

        decision = RoutingDecision(
            complexity=complexity,
            recommended_agent=recommended_agent,
            reasoning=f"{complexity_result.get('reasoning', '')} | {routing_result.get('reasoning', '')}",
            confidence=self._coerce_score(complexity_result.get("confidence"), 0.5),
            suggested_steps=steps,
            estimated_tokens=self._TOKEN_ESTIMATES.get(complexity, 8000),
            requires_human=risk_level == "high",
            validation_needed=complexity is not TaskComplexity.TRIVIAL,
        )

        # Record history
        elapsed = time.time() - start_time
        self._record_routing(task, project, decision, elapsed)

        return decision

    def validate_response(self, task: str, output: str,
                          context: Dict[str, Any] = None) -> ValidationResult:
        """Validate agent response quality.

        Args:
            task: Original task
            output: Agent output
            context: Additional context

        Returns:
            Validation result with quality assessment
        """
        result = self.decision_engine.validate_output(task, output, context)
        if not isinstance(result, dict):
            result = {}

        return ValidationResult(
            is_valid=result.get("is_valid", True),
            quality_score=self._coerce_score(result.get("quality_score"), 0.5),
            issues=result.get("issues") or [],
            suggestions=result.get("suggestions") or [],
            needs_retry=not result.get("task_completed", True),
            continuation_prompt=result.get("follow_up_prompt"),
        )

    def should_escalate(self, task: str, error: str) -> Tuple[bool, str]:
        """Determine if a failed task should be escalated.

        Args:
            task: Original task
            error: Error encountered

        Returns:
            (should_escalate, reason); reason is "" when not escalating
        """
        error_lower = error.lower()
        task_lower = task.lower()

        # Sensitive operations mentioned in either the task or the error
        for pattern in self._ESCALATE_PATTERNS:
            if pattern in error_lower or pattern in task_lower:
                return True, f"Task involves sensitive operation: {pattern}"

        # Check if error suggests human intervention
        if "requires approval" in error_lower or "blocked" in error_lower:
            return True, "Task requires human approval"

        return False, ""

    def _record_routing(self, task: str, project: str,
                        decision: RoutingDecision, elapsed: float) -> None:
        """Append a routing record, keeping at most ``max_history`` entries."""
        self.routing_history.append({
            "timestamp": time.time(),
            "task": task[:200],  # Truncate
            "project": project,
            "complexity": decision.complexity.value,
            "agent": decision.recommended_agent.value,
            "confidence": decision.confidence,
            "elapsed_ms": round(elapsed * 1000, 2)
        })

        # Drop the oldest entries in place. Computing the overflow explicitly
        # also handles max_history == 0 (a [-0:] slice would keep everything).
        overflow = len(self.routing_history) - self.max_history
        if overflow > 0:
            del self.routing_history[:overflow]

    def get_stats(self) -> Dict[str, Any]:
        """Get routing statistics over the recorded history."""
        history = self.routing_history
        if not history:
            return {"total_routings": 0}

        avg_elapsed = sum(r["elapsed_ms"] for r in history) / len(history)

        return {
            "total_routings": len(history),
            "complexity_distribution": dict(Counter(r["complexity"] for r in history)),
            "agent_distribution": dict(Counter(r["agent"] for r in history)),
            "avg_routing_time_ms": round(avg_elapsed, 2),
            "gemini_available": self.decision_engine.available
        }
|
|
|
|
|
|
# CLI for testing
def main() -> None:
    """Route a few sample tasks and log each decision plus summary statistics."""
    logger.info("=" * 60)
    logger.info("Smart Router - Gemini 3 Flash Decision Engine")
    logger.info("=" * 60)

    router = SmartRouter()

    # (task, project) pairs; under the keyword heuristics these cover all
    # five complexity levels.
    test_cases = [
        ("List all running containers", "admin"),
        ("Fix the bug in track component", "musica"),
        ("Implement new authentication system with OAuth2", "overbits"),
        ("Research microservices architecture patterns", "dss"),
        ("Refactor the entire API layer for better performance", "musica"),
    ]

    for task, project in test_cases:
        logger.info(f"\nTask: '{task}'")
        logger.info(f"Project: {project}")

        decision = router.analyze_and_route(task, project)

        logger.info(f"  Complexity: {decision.complexity.value}")
        logger.info(f"  Agent: {decision.recommended_agent.value}")
        logger.info(f"  Confidence: {decision.confidence:.2f}")
        logger.info(f"  Tokens Est: {decision.estimated_tokens}")
        logger.info(f"  Human Required: {decision.requires_human}")
        if decision.suggested_steps:
            logger.info(f"  Steps: {decision.suggested_steps[:3]}")

    # Show stats
    logger.info("\n" + "=" * 60)
    stats = router.get_stats()
    logger.info(f"Stats: {json.dumps(stats, indent=2)}")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()
|