Refactor cockpit to use DockerTmuxController pattern
Based on claude-code-tools TmuxCLIController, this refactor: - Added DockerTmuxController class for robust tmux session management - Implements send_keys() with configurable delay_enter - Implements capture_pane() for output retrieval - Implements wait_for_prompt() for pattern-based completion detection - Implements wait_for_idle() for content-hash-based idle detection - Implements wait_for_shell_prompt() for shell prompt detection Also includes workflow improvements: - Pre-task git snapshot before agent execution - Post-task commit protocol in agent guidelines Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
592
lib/smart_router.py
Normal file
592
lib/smart_router.py
Normal file
@@ -0,0 +1,592 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Smart Router - Intelligent task routing using Gemini 3 Flash for decision making.
|
||||
|
||||
Key decision points:
|
||||
1. Task Complexity Analysis - Before dispatch, assess complexity
|
||||
2. Agent Selection - Route to optimal agent/model based on task
|
||||
3. Response Validation - Check output quality before returning
|
||||
4. Continuation Decisions - Determine if follow-up is needed
|
||||
|
||||
Uses Gemini 3 Flash for fast, cost-effective decisions at critical flow points.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import time
|
||||
|
||||
# Try to import google.generativeai
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
GEMINI_AVAILABLE = True
|
||||
except ImportError:
|
||||
GEMINI_AVAILABLE = False
|
||||
genai = None
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskComplexity(Enum):
|
||||
"""Task complexity levels for routing decisions."""
|
||||
TRIVIAL = "trivial" # Simple command, quick lookup
|
||||
SIMPLE = "simple" # Single-step task, clear path
|
||||
MODERATE = "moderate" # Multi-step, some reasoning needed
|
||||
COMPLEX = "complex" # Deep analysis, multi-agent coordination
|
||||
RESEARCH = "research" # Open-ended exploration
|
||||
|
||||
|
||||
class AgentTier(Enum):
|
||||
"""Agent tiers for model selection."""
|
||||
FLASH = "flash" # Gemini Flash - fast decisions
|
||||
HAIKU = "haiku" # Claude Haiku - quick tasks
|
||||
SONNET = "sonnet" # Claude Sonnet - balanced
|
||||
OPUS = "opus" # Claude Opus - complex tasks
|
||||
PRO = "pro" # Gemini Pro - deep reasoning
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoutingDecision:
|
||||
"""Result of routing analysis."""
|
||||
complexity: TaskComplexity
|
||||
recommended_agent: AgentTier
|
||||
reasoning: str
|
||||
confidence: float
|
||||
suggested_steps: List[str]
|
||||
estimated_tokens: int
|
||||
requires_human: bool = False
|
||||
validation_needed: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of output validation."""
|
||||
is_valid: bool
|
||||
quality_score: float # 0-1
|
||||
issues: List[str]
|
||||
suggestions: List[str]
|
||||
needs_retry: bool = False
|
||||
continuation_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class GeminiDecisionEngine:
|
||||
"""Gemini Flash-powered decision engine for fast routing decisions."""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
"""Initialize Gemini decision engine.
|
||||
|
||||
Args:
|
||||
api_key: Gemini API key (defaults to env var)
|
||||
"""
|
||||
self.api_key = api_key or os.environ.get("GEMINI_API_KEY")
|
||||
self.model = None
|
||||
self.available = False
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""Initialize Gemini client."""
|
||||
if not GEMINI_AVAILABLE:
|
||||
logger.warning("google-generativeai not installed - falling back to heuristics")
|
||||
return
|
||||
|
||||
if not self.api_key:
|
||||
# Try multiple sources for API key
|
||||
api_key_sources = [
|
||||
"/opt/pal-mcp-server/.env", # PAL MCP server env (primary)
|
||||
"/etc/shared-ai-credentials/gemini/api-key", # Shared credentials
|
||||
]
|
||||
|
||||
for source in api_key_sources:
|
||||
try:
|
||||
if source.endswith('.env'):
|
||||
# Parse .env file
|
||||
with open(source, "r") as f:
|
||||
for line in f:
|
||||
if line.startswith("GEMINI_API_KEY="):
|
||||
self.api_key = line.split("=", 1)[1].strip().strip('"\'')
|
||||
break
|
||||
else:
|
||||
# Plain text file
|
||||
with open(source, "r") as f:
|
||||
self.api_key = f.read().strip()
|
||||
|
||||
if self.api_key:
|
||||
logger.info(f"Gemini API key loaded from {source}")
|
||||
break
|
||||
except (FileNotFoundError, PermissionError):
|
||||
continue
|
||||
|
||||
if not self.api_key:
|
||||
logger.warning("Gemini API key not found - falling back to heuristics")
|
||||
return
|
||||
|
||||
try:
|
||||
genai.configure(api_key=self.api_key)
|
||||
self.model = genai.GenerativeModel("gemini-2.0-flash")
|
||||
self.available = True
|
||||
logger.info("Gemini decision engine initialized (gemini-2.0-flash)")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize Gemini: {e}")
|
||||
|
||||
def analyze_complexity(self, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""Analyze task complexity using Gemini Flash.
|
||||
|
||||
Args:
|
||||
task: Task description
|
||||
context: Optional context about project, history
|
||||
|
||||
Returns:
|
||||
Complexity analysis result
|
||||
"""
|
||||
if not self.available:
|
||||
return self._heuristic_complexity(task)
|
||||
|
||||
prompt = f"""Analyze this task's complexity for routing to an AI agent.
|
||||
|
||||
TASK: {task}
|
||||
|
||||
CONTEXT: {json.dumps(context or {}, indent=2)}
|
||||
|
||||
Respond in JSON:
|
||||
{{
|
||||
"complexity": "trivial|simple|moderate|complex|research",
|
||||
"reasoning": "brief explanation",
|
||||
"confidence": 0.0-1.0,
|
||||
"estimated_steps": ["step1", "step2"],
|
||||
"requires_code_changes": true/false,
|
||||
"requires_file_reads": true/false,
|
||||
"requires_external_calls": true/false,
|
||||
"risk_level": "low|medium|high"
|
||||
}}"""
|
||||
|
||||
try:
|
||||
response = self.model.generate_content(
|
||||
prompt,
|
||||
generation_config=genai.GenerationConfig(
|
||||
temperature=0.1,
|
||||
max_output_tokens=500
|
||||
)
|
||||
)
|
||||
|
||||
# Parse JSON from response
|
||||
text = response.text.strip()
|
||||
# Handle markdown code blocks
|
||||
if text.startswith("```"):
|
||||
text = text.split("```")[1]
|
||||
if text.startswith("json"):
|
||||
text = text[4:]
|
||||
|
||||
return json.loads(text)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Gemini complexity analysis failed: {e}")
|
||||
return self._heuristic_complexity(task)
|
||||
|
||||
def _heuristic_complexity(self, task: str) -> Dict[str, Any]:
|
||||
"""Fallback heuristic-based complexity analysis."""
|
||||
task_lower = task.lower()
|
||||
|
||||
# Simple keyword matching for fallback
|
||||
if any(word in task_lower for word in ["list", "show", "check", "status", "what is"]):
|
||||
complexity = "trivial"
|
||||
confidence = 0.7
|
||||
elif any(word in task_lower for word in ["fix", "update", "change", "add"]):
|
||||
complexity = "simple"
|
||||
confidence = 0.6
|
||||
elif any(word in task_lower for word in ["implement", "create", "build", "develop"]):
|
||||
complexity = "moderate"
|
||||
confidence = 0.5
|
||||
elif any(word in task_lower for word in ["refactor", "optimize", "debug", "investigate"]):
|
||||
complexity = "complex"
|
||||
confidence = 0.5
|
||||
elif any(word in task_lower for word in ["research", "analyze", "design", "architect"]):
|
||||
complexity = "research"
|
||||
confidence = 0.5
|
||||
else:
|
||||
complexity = "moderate"
|
||||
confidence = 0.4
|
||||
|
||||
return {
|
||||
"complexity": complexity,
|
||||
"reasoning": "Heuristic analysis (Gemini unavailable)",
|
||||
"confidence": confidence,
|
||||
"estimated_steps": [],
|
||||
"requires_code_changes": "implement" in task_lower or "fix" in task_lower,
|
||||
"requires_file_reads": True,
|
||||
"requires_external_calls": False,
|
||||
"risk_level": "medium"
|
||||
}
|
||||
|
||||
def validate_output(self, task: str, output: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""Validate agent output quality.
|
||||
|
||||
Args:
|
||||
task: Original task
|
||||
output: Agent's output
|
||||
context: Additional context
|
||||
|
||||
Returns:
|
||||
Validation result
|
||||
"""
|
||||
if not self.available:
|
||||
return self._heuristic_validation(task, output)
|
||||
|
||||
# Truncate output for validation (avoid huge prompts)
|
||||
output_truncated = output[:3000] if len(output) > 3000 else output
|
||||
|
||||
prompt = f"""Validate this AI agent's response to a task.
|
||||
|
||||
TASK: {task}
|
||||
|
||||
RESPONSE (may be truncated):
|
||||
{output_truncated}
|
||||
|
||||
Respond in JSON:
|
||||
{{
|
||||
"is_valid": true/false,
|
||||
"quality_score": 0.0-1.0,
|
||||
"issues": ["issue1", "issue2"],
|
||||
"suggestions": ["suggestion1"],
|
||||
"task_completed": true/false,
|
||||
"needs_follow_up": true/false,
|
||||
"follow_up_prompt": "optional continuation prompt"
|
||||
}}"""
|
||||
|
||||
try:
|
||||
response = self.model.generate_content(
|
||||
prompt,
|
||||
generation_config=genai.GenerationConfig(
|
||||
temperature=0.1,
|
||||
max_output_tokens=500
|
||||
)
|
||||
)
|
||||
|
||||
text = response.text.strip()
|
||||
if text.startswith("```"):
|
||||
text = text.split("```")[1]
|
||||
if text.startswith("json"):
|
||||
text = text[4:]
|
||||
|
||||
return json.loads(text)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Gemini validation failed: {e}")
|
||||
return self._heuristic_validation(task, output)
|
||||
|
||||
def _heuristic_validation(self, task: str, output: str) -> Dict[str, Any]:
|
||||
"""Fallback heuristic output validation."""
|
||||
# Basic checks
|
||||
has_content = len(output.strip()) > 50
|
||||
has_code = "```" in output or "def " in output or "function " in output
|
||||
has_error = "error" in output.lower() or "failed" in output.lower()
|
||||
|
||||
quality = 0.5
|
||||
if has_content:
|
||||
quality += 0.2
|
||||
if has_code and ("implement" in task.lower() or "code" in task.lower()):
|
||||
quality += 0.2
|
||||
if has_error:
|
||||
quality -= 0.3
|
||||
|
||||
return {
|
||||
"is_valid": has_content and not has_error,
|
||||
"quality_score": max(0.0, min(1.0, quality)),
|
||||
"issues": ["Error detected in output"] if has_error else [],
|
||||
"suggestions": [],
|
||||
"task_completed": has_content,
|
||||
"needs_follow_up": has_error,
|
||||
"follow_up_prompt": "Please fix the errors and try again" if has_error else None
|
||||
}
|
||||
|
||||
def route_task(self, task: str, project: str, complexity: str) -> Dict[str, Any]:
|
||||
"""Determine optimal agent/model for task.
|
||||
|
||||
Args:
|
||||
task: Task description
|
||||
project: Target project
|
||||
complexity: Pre-analyzed complexity
|
||||
|
||||
Returns:
|
||||
Routing recommendation
|
||||
"""
|
||||
if not self.available:
|
||||
return self._heuristic_routing(task, project, complexity)
|
||||
|
||||
prompt = f"""Recommend the best AI agent configuration for this task.
|
||||
|
||||
TASK: {task}
|
||||
PROJECT: {project}
|
||||
COMPLEXITY: {complexity}
|
||||
|
||||
Available agents:
|
||||
- flash: Gemini Flash - Fast, cheap, good for simple tasks
|
||||
- haiku: Claude Haiku - Quick, efficient, good for straightforward coding
|
||||
- sonnet: Claude Sonnet - Balanced, good for most development tasks
|
||||
- opus: Claude Opus - Most capable, for complex analysis
|
||||
- pro: Gemini Pro - Deep reasoning, research tasks
|
||||
|
||||
Respond in JSON:
|
||||
{{
|
||||
"recommended_agent": "flash|haiku|sonnet|opus|pro",
|
||||
"reasoning": "why this agent",
|
||||
"backup_agent": "alternative if first fails",
|
||||
"special_instructions": "any task-specific guidance",
|
||||
"estimated_time": "quick|moderate|long",
|
||||
"suggested_tools": ["Read", "Edit", "Bash"]
|
||||
}}"""
|
||||
|
||||
try:
|
||||
response = self.model.generate_content(
|
||||
prompt,
|
||||
generation_config=genai.GenerationConfig(
|
||||
temperature=0.1,
|
||||
max_output_tokens=400
|
||||
)
|
||||
)
|
||||
|
||||
text = response.text.strip()
|
||||
if text.startswith("```"):
|
||||
text = text.split("```")[1]
|
||||
if text.startswith("json"):
|
||||
text = text[4:]
|
||||
|
||||
return json.loads(text)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Gemini routing failed: {e}")
|
||||
return self._heuristic_routing(task, project, complexity)
|
||||
|
||||
def _heuristic_routing(self, task: str, project: str, complexity: str) -> Dict[str, Any]:
|
||||
"""Fallback heuristic task routing."""
|
||||
# Map complexity to agent
|
||||
complexity_to_agent = {
|
||||
"trivial": "haiku",
|
||||
"simple": "haiku",
|
||||
"moderate": "sonnet",
|
||||
"complex": "sonnet",
|
||||
"research": "pro"
|
||||
}
|
||||
|
||||
return {
|
||||
"recommended_agent": complexity_to_agent.get(complexity, "sonnet"),
|
||||
"reasoning": f"Heuristic routing for {complexity} task",
|
||||
"backup_agent": "sonnet",
|
||||
"special_instructions": None,
|
||||
"estimated_time": "moderate",
|
||||
"suggested_tools": ["Read", "Edit", "Bash", "Glob", "Grep"]
|
||||
}
|
||||
|
||||
|
||||
class SmartRouter:
|
||||
"""Main smart routing orchestrator integrating Gemini decisions."""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
"""Initialize smart router.
|
||||
|
||||
Args:
|
||||
api_key: Optional Gemini API key
|
||||
"""
|
||||
self.decision_engine = GeminiDecisionEngine(api_key)
|
||||
self.routing_history: List[Dict[str, Any]] = []
|
||||
self.max_history = 100
|
||||
logger.info(f"SmartRouter initialized (Gemini: {self.decision_engine.available})")
|
||||
|
||||
def analyze_and_route(self, task: str, project: str,
|
||||
context: Dict[str, Any] = None) -> RoutingDecision:
|
||||
"""Full analysis and routing for a task.
|
||||
|
||||
Args:
|
||||
task: Task description
|
||||
project: Target project
|
||||
context: Additional context
|
||||
|
||||
Returns:
|
||||
Complete routing decision
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Step 1: Analyze complexity
|
||||
complexity_result = self.decision_engine.analyze_complexity(task, context)
|
||||
complexity = TaskComplexity(complexity_result.get("complexity", "moderate"))
|
||||
|
||||
# Step 2: Get routing recommendation
|
||||
routing_result = self.decision_engine.route_task(
|
||||
task, project, complexity_result.get("complexity", "moderate")
|
||||
)
|
||||
|
||||
# Step 3: Build decision
|
||||
agent_map = {
|
||||
"flash": AgentTier.FLASH,
|
||||
"haiku": AgentTier.HAIKU,
|
||||
"sonnet": AgentTier.SONNET,
|
||||
"opus": AgentTier.OPUS,
|
||||
"pro": AgentTier.PRO
|
||||
}
|
||||
|
||||
recommended_agent = agent_map.get(
|
||||
routing_result.get("recommended_agent", "sonnet"),
|
||||
AgentTier.SONNET
|
||||
)
|
||||
|
||||
# Estimate tokens based on complexity
|
||||
token_estimates = {
|
||||
TaskComplexity.TRIVIAL: 500,
|
||||
TaskComplexity.SIMPLE: 2000,
|
||||
TaskComplexity.MODERATE: 8000,
|
||||
TaskComplexity.COMPLEX: 20000,
|
||||
TaskComplexity.RESEARCH: 50000
|
||||
}
|
||||
|
||||
decision = RoutingDecision(
|
||||
complexity=complexity,
|
||||
recommended_agent=recommended_agent,
|
||||
reasoning=f"{complexity_result.get('reasoning', '')} | {routing_result.get('reasoning', '')}",
|
||||
confidence=complexity_result.get("confidence", 0.5),
|
||||
suggested_steps=complexity_result.get("estimated_steps", []),
|
||||
estimated_tokens=token_estimates.get(complexity, 8000),
|
||||
requires_human=complexity_result.get("risk_level", "low") == "high",
|
||||
validation_needed=complexity not in [TaskComplexity.TRIVIAL]
|
||||
)
|
||||
|
||||
# Record history
|
||||
elapsed = time.time() - start_time
|
||||
self._record_routing(task, project, decision, elapsed)
|
||||
|
||||
return decision
|
||||
|
||||
def validate_response(self, task: str, output: str,
|
||||
context: Dict[str, Any] = None) -> ValidationResult:
|
||||
"""Validate agent response quality.
|
||||
|
||||
Args:
|
||||
task: Original task
|
||||
output: Agent output
|
||||
context: Additional context
|
||||
|
||||
Returns:
|
||||
Validation result with quality assessment
|
||||
"""
|
||||
result = self.decision_engine.validate_output(task, output, context)
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=result.get("is_valid", True),
|
||||
quality_score=result.get("quality_score", 0.5),
|
||||
issues=result.get("issues", []),
|
||||
suggestions=result.get("suggestions", []),
|
||||
needs_retry=not result.get("task_completed", True),
|
||||
continuation_prompt=result.get("follow_up_prompt")
|
||||
)
|
||||
|
||||
def should_escalate(self, task: str, error: str) -> Tuple[bool, str]:
|
||||
"""Determine if a failed task should be escalated.
|
||||
|
||||
Args:
|
||||
task: Original task
|
||||
error: Error encountered
|
||||
|
||||
Returns:
|
||||
(should_escalate, reason)
|
||||
"""
|
||||
# Check for patterns that need escalation
|
||||
escalate_patterns = [
|
||||
"permission denied",
|
||||
"authentication",
|
||||
"security",
|
||||
"production",
|
||||
"database migration",
|
||||
"delete",
|
||||
"remove"
|
||||
]
|
||||
|
||||
error_lower = error.lower()
|
||||
task_lower = task.lower()
|
||||
|
||||
for pattern in escalate_patterns:
|
||||
if pattern in error_lower or pattern in task_lower:
|
||||
return True, f"Task involves sensitive operation: {pattern}"
|
||||
|
||||
# Check if error suggests human intervention
|
||||
if "requires approval" in error_lower or "blocked" in error_lower:
|
||||
return True, "Task requires human approval"
|
||||
|
||||
return False, ""
|
||||
|
||||
def _record_routing(self, task: str, project: str,
|
||||
decision: RoutingDecision, elapsed: float) -> None:
|
||||
"""Record routing decision for learning."""
|
||||
record = {
|
||||
"timestamp": time.time(),
|
||||
"task": task[:200], # Truncate
|
||||
"project": project,
|
||||
"complexity": decision.complexity.value,
|
||||
"agent": decision.recommended_agent.value,
|
||||
"confidence": decision.confidence,
|
||||
"elapsed_ms": round(elapsed * 1000, 2)
|
||||
}
|
||||
|
||||
self.routing_history.append(record)
|
||||
|
||||
# Trim history
|
||||
if len(self.routing_history) > self.max_history:
|
||||
self.routing_history = self.routing_history[-self.max_history:]
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get routing statistics."""
|
||||
if not self.routing_history:
|
||||
return {"total_routings": 0}
|
||||
|
||||
complexities = [r["complexity"] for r in self.routing_history]
|
||||
agents = [r["agent"] for r in self.routing_history]
|
||||
avg_elapsed = sum(r["elapsed_ms"] for r in self.routing_history) / len(self.routing_history)
|
||||
|
||||
return {
|
||||
"total_routings": len(self.routing_history),
|
||||
"complexity_distribution": {c: complexities.count(c) for c in set(complexities)},
|
||||
"agent_distribution": {a: agents.count(a) for a in set(agents)},
|
||||
"avg_routing_time_ms": round(avg_elapsed, 2),
|
||||
"gemini_available": self.decision_engine.available
|
||||
}
|
||||
|
||||
|
||||
# CLI for testing
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("Smart Router - Gemini 3 Flash Decision Engine")
|
||||
logger.info("=" * 60)
|
||||
|
||||
router = SmartRouter()
|
||||
|
||||
# Test tasks
|
||||
test_cases = [
|
||||
("List all running containers", "admin"),
|
||||
("Fix the bug in track component", "musica"),
|
||||
("Implement new authentication system with OAuth2", "overbits"),
|
||||
("Research microservices architecture patterns", "dss"),
|
||||
("Refactor the entire API layer for better performance", "musica"),
|
||||
]
|
||||
|
||||
for task, project in test_cases:
|
||||
logger.info(f"\nTask: '{task}'")
|
||||
logger.info(f"Project: {project}")
|
||||
|
||||
decision = router.analyze_and_route(task, project)
|
||||
|
||||
logger.info(f" Complexity: {decision.complexity.value}")
|
||||
logger.info(f" Agent: {decision.recommended_agent.value}")
|
||||
logger.info(f" Confidence: {decision.confidence:.2f}")
|
||||
logger.info(f" Tokens Est: {decision.estimated_tokens}")
|
||||
logger.info(f" Human Required: {decision.requires_human}")
|
||||
if decision.suggested_steps:
|
||||
logger.info(f" Steps: {decision.suggested_steps[:3]}")
|
||||
|
||||
# Show stats
|
||||
logger.info("\n" + "=" * 60)
|
||||
stats = router.get_stats()
|
||||
logger.info(f"Stats: {json.dumps(stats, indent=2)}")
|
||||
logger.info("=" * 60)
|
||||
Reference in New Issue
Block a user