#!/usr/bin/env python3
"""
Smart model router - routes tasks to appropriate models based on complexity.
Supports multiple providers: Anthropic, OpenAI, Google, OpenRouter.
Helps reduce token costs by using cheaper models for simpler tasks.
Version: 1.1.0
"""
import json
import os
import re
import sys
# ============================================================================
# PROVIDER CONFIGURATION
# ============================================================================
# Detect primary provider from environment (default: anthropic)
def detect_provider():
    """Detect which provider to use based on available API keys."""
    if os.environ.get("ANTHROPIC_API_KEY"):
        return "anthropic"
    elif os.environ.get("OPENAI_API_KEY"):
        return "openai"
    elif os.environ.get("GOOGLE_API_KEY"):
        return "google"
    elif os.environ.get("OPENROUTER_API_KEY"):
        return "openrouter"
    # Default to anthropic
    return "anthropic"
# Model tiers per provider
PROVIDER_MODELS = {
    "anthropic": {
        "cheap": "anthropic/claude-haiku-4",
        "balanced": "anthropic/claude-sonnet-4-5",
        "smart": "anthropic/claude-opus-4",
        "costs": {  # $/MTok (input)
            "cheap": 0.25,
            "balanced": 3.00,
            "smart": 15.00
        }
    },
    "openai": {
        "cheap": "openai/gpt-4.1-nano",
        "balanced": "openai/gpt-4.1-mini",
        "smart": "openai/gpt-4.1",
        "premium": "openai/gpt-5",
        "costs": {
            "cheap": 0.10,
            "balanced": 0.40,
            "smart": 2.00,
            "premium": 10.00
        }
    },
    "google": {
        "cheap": "google/gemini-2.0-flash",
        "balanced": "google/gemini-2.5-flash",
        "smart": "google/gemini-2.5-pro",
        "costs": {
            "cheap": 0.075,
            "balanced": 0.15,
            "smart": 1.25
        }
    },
    "openrouter": {
        "cheap": "google/gemini-2.0-flash",
        "balanced": "anthropic/claude-sonnet-4-5",
        "smart": "anthropic/claude-opus-4",
        "costs": {
            "cheap": 0.075,
            "balanced": 3.00,
            "smart": 15.00
        }
    }
}
# Tier mapping for cross-provider compatibility
TIER_ALIASES = {
    "haiku": "cheap",
    "sonnet": "balanced",
    "opus": "smart",
    "nano": "cheap",
    "mini": "balanced",
    "flash": "cheap",
    "pro": "smart"
}
# ============================================================================
# TASK CLASSIFICATION PATTERNS
# ============================================================================
# Communication patterns that should ALWAYS use cheap tier (never balanced/smart)
COMMUNICATION_PATTERNS = [
    r'^(hi|hey|hello|yo|sup)\b',
    r'^(thanks|thank you|thx)\b',
    r'^(ok|okay|sure|got it|understood)\b',
    r'^(yes|yeah|yep|yup|no|nope)\b',
    r'^(good|great|nice|cool|awesome)\b',
    r"^(what|how)'s (up|it going)",
    r'^\w{1,15}$',  # Single short word
    r'^(lol|haha|lmao)\b',
]
# Background/routine tasks that should ALWAYS use cheap tier
BACKGROUND_TASK_PATTERNS = [
    # Heartbeat checks
    r'heartbeat',
    r'check\s+(email|calendar|weather|monitoring)',
    r'monitor\s+',
    r'poll\s+',
    # Cronjob/scheduled tasks
    r'cron',
    r'scheduled\s+task',
    r'periodic\s+check',
    r'reminder',
    # Document parsing/extraction
    r'parse\s+(document|file|log|csv|json|xml)',
    r'extract\s+(text|data|content)\s+from',
    r'read\s+(log|logs)',
    r'scan\s+(file|document)',
    r'process\s+(csv|json|xml|yaml)',
]
# Model routing rules with tier-based approach
ROUTING_RULES = {
    "cheap": {
        "patterns": [
            r"read\s+file",
            r"list\s+files",
            r"show\s+(me\s+)?the\s+contents?",
            r"what's\s+in",
            r"cat\s+",
            r"get\s+status",
            r"check\s+(if|whether)",
            r"is\s+\w+\s+(running|active|enabled)"
        ],
        "keywords": ["read", "list", "show", "status", "check", "get"],
        "cost_multiplier": 0.083  # vs balanced
    },
    "balanced": {
        "patterns": [
            r"write\s+\w+",
            r"create\s+\w+",
            r"edit\s+\w+",
            r"fix\s+\w+",
            r"debug\s+\w+",
            r"explain\s+\w+",
            r"how\s+(do|can)\s+i"
        ],
        "keywords": ["write", "create", "edit", "update", "fix", "debug", "explain"],
        "cost_multiplier": 1.0
    },
    "smart": {
        "patterns": [
            r"complex\s+\w+",
            r"design\s+\w+",
            r"architect\w+",
            r"analyze\s+deeply",
            r"comprehensive\s+\w+"
        ],
        "keywords": ["design", "architect", "complex", "comprehensive", "deep"],
        "cost_multiplier": 5.0
    }
}
# Legacy tier names for backwards compatibility
LEGACY_TIER_MAP = {
    "haiku": "cheap",
    "sonnet": "balanced",
    "opus": "smart"
}
# ============================================================================
# CORE FUNCTIONS
# ============================================================================
def classify_task(prompt):
    """Classify task complexity based on prompt text.

    Args:
        prompt: User's message/request

    Returns:
        tuple of (tier, confidence, reasoning)
        tier is one of: cheap, balanced, smart
    """
    prompt_lower = prompt.lower()
    # FIRST: Check if this is simple communication (ALWAYS cheap)
    for pattern in COMMUNICATION_PATTERNS:
        if re.search(pattern, prompt_lower):
            return ("cheap", 1.0, "Simple communication - use cheapest model")
    # SECOND: Check if this is a background/routine task (ALWAYS cheap)
    for pattern in BACKGROUND_TASK_PATTERNS:
        if re.search(pattern, prompt_lower):
            return ("cheap", 1.0, "Background task (heartbeat/cron/parsing) - use cheapest model")
    # Score each tier
    scores = {}
    for tier, rules in ROUTING_RULES.items():
        score = 0
        matches = []
        # Pattern matching
        for pattern in rules["patterns"]:
            if re.search(pattern, prompt_lower):
                score += 2
                matches.append(f"pattern:{pattern}")
        # Keyword matching
        for keyword in rules["keywords"]:
            if keyword in prompt_lower:
                score += 1
                matches.append(f"keyword:{keyword}")
        scores[tier] = {
            "score": score,
            "matches": matches
        }
    # Determine best tier
    best_tier = max(scores.items(), key=lambda x: x[1]["score"])
    if best_tier[1]["score"] == 0:
        # Default to balanced if unclear
        return ("balanced", 0.5, "No clear indicators, defaulting to balanced model")
    confidence = min(best_tier[1]["score"] / 5.0, 1.0)  # Cap at 1.0
    reasoning = f"Matched: {', '.join(best_tier[1]['matches'][:3])}"
    return (best_tier[0], confidence, reasoning)
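# Illustrative classifications under the rules above (confidence is score/5,
# capped at 1.0):
#
#   classify_task("thanks!")                 # -> ("cheap", 1.0, ...)  communication
#   classify_task("read file config.yaml")   # -> ("cheap", 0.6, ...)  pattern + keyword
#   classify_task("design a caching layer")  # -> ("smart", 0.6, ...)  "design" pattern + keyword
#   classify_task("summarize this meeting")  # -> ("balanced", 0.5, ...)  no match, default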
def normalize_tier(tier):
    """Normalize tier name to standard format (cheap/balanced/smart)."""
    tier_lower = tier.lower()
    # Check legacy mappings
    if tier_lower in LEGACY_TIER_MAP:
        return LEGACY_TIER_MAP[tier_lower]
    # Check aliases
    if tier_lower in TIER_ALIASES:
        return TIER_ALIASES[tier_lower]
    # Already standard or unknown
    if tier_lower in ["cheap", "balanced", "smart", "premium"]:
        return tier_lower
    return "balanced"  # Default
def get_model_for_tier(tier, provider=None):
    """Get the specific model name for a tier and provider.

    Args:
        tier: cheap, balanced, smart, or premium
        provider: anthropic, openai, google, openrouter (or None to auto-detect)

    Returns:
        Model identifier string
    """
    if provider is None:
        provider = detect_provider()
    provider_config = PROVIDER_MODELS.get(provider, PROVIDER_MODELS["anthropic"])
    # Normalize tier
    tier = normalize_tier(tier)
    # Get model (fallback to balanced if tier not available)
    return provider_config.get(tier, provider_config.get("balanced"))
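# Illustrative lookups (model names as configured in PROVIDER_MODELS above):
#
#   get_model_for_tier("opus", "openai")     # -> "openai/gpt-4.1" ("opus" -> "smart")
#   get_model_for_tier("premium", "google")  # -> "google/gemini-2.5-flash" (no premium tier; falls back to balanced)
#   get_model_for_tier("cheap")              # cheap-tier model of the auto-detected provider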
def route_task(prompt, current_model=None, force_tier=None, provider=None):
    """Route a task to appropriate model.

    Args:
        prompt: User's message/request
        current_model: Current model being used (optional)
        force_tier: Override classification (cheap/balanced/smart or haiku/sonnet/opus)
        provider: Force specific provider (anthropic/openai/google/openrouter)

    Returns:
        dict with routing decision
    """
    # Auto-detect provider if not specified
    if provider is None:
        provider = detect_provider()
    # Set default current model
    if current_model is None:
        current_model = get_model_for_tier("balanced", provider)
    if force_tier:
        tier = normalize_tier(force_tier)
        confidence = 1.0
        reasoning = "User-specified tier"
    else:
        tier, confidence, reasoning = classify_task(prompt)
    recommended_model = get_model_for_tier(tier, provider)
    # Calculate cost savings relative to the balanced tier
    provider_config = PROVIDER_MODELS.get(provider, PROVIDER_MODELS["anthropic"])
    base_cost = provider_config["costs"].get("balanced", 1.0)
    tier_cost = provider_config["costs"].get(tier, base_cost)
    cost_savings = (1.0 - (tier_cost / base_cost)) * 100
    return {
        "provider": provider,
        "current_model": current_model,
        "recommended_model": recommended_model,
        "tier": tier,
        "tier_display": {
            "cheap": "Cheap (Haiku/Nano/Flash)",
            "balanced": "Balanced (Sonnet/Mini/Flash)",
            "smart": "Smart (Opus/GPT-4.1/Pro)",
            "premium": "Premium (GPT-5)"
        }.get(tier, tier),
        "confidence": confidence,
        "reasoning": reasoning,
        "cost_savings_percent": max(0, cost_savings),
        "should_switch": recommended_model != current_model,
        "all_providers": {
            p: get_model_for_tier(tier, p) for p in PROVIDER_MODELS.keys()
        }
    }
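# Sketch of a routing decision (values follow the cost table above; assumes
# ANTHROPIC_API_KEY is the only key set, so the provider resolves to anthropic):
#
#   route_task("thanks!") -> {
#       "provider": "anthropic",
#       "recommended_model": "anthropic/claude-haiku-4",
#       "tier": "cheap",
#       "confidence": 1.0,
#       "cost_savings_percent": 91.7,  # approx. (1 - 0.25/3.00) * 100, vs balanced
#       "should_switch": True,         # default current model is the balanced tier
#       ...
#   }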
def get_model_comparison():
    """Get a comparison of all models across providers.

    Returns:
        dict with provider -> tier -> model mapping
    """
    result = {}
    for provider, config in PROVIDER_MODELS.items():
        result[provider] = {
            tier: {
                "model": model,
                "cost_per_mtok": config["costs"].get(tier, "N/A")
            }
            for tier, model in config.items()
            if tier != "costs"
        }
    return result
# ============================================================================
# CLI INTERFACE
# ============================================================================
def main():
    """CLI interface for model router."""
    if len(sys.argv) < 2:
        print("Usage: model_router.py <command> [args]")
        print("")
        print("Commands:")
        print("  route '<prompt>' [current_model] [force_tier] [--provider <provider>]")
        print("  compare   — Show all models across providers")
        print("  providers — List available providers")
        print("  detect    — Show auto-detected provider")
        print("")
        print("Examples:")
        print("  model_router.py route 'thanks!'")
        print("  model_router.py route 'design an architecture' --provider openai")
        print("  model_router.py compare")
        sys.exit(1)
    command = sys.argv[1]
    # Known commands; anything else is treated as a shorthand prompt below
    known_commands = ["route", "compare", "providers", "detect"]
    if command == "route" or command not in known_commands:
        # Route a prompt (either explicit "route" command or shorthand)
        if command == "route":
            if len(sys.argv) < 3:
                print("Usage: model_router.py route '<prompt>'")
                sys.exit(1)
            prompt = sys.argv[2]
            start_idx = 3
        else:
            # Shorthand: first arg is the prompt
            prompt = command
            start_idx = 2
        # Parse remaining args
        current_model = None
        force_tier = None
        provider = None
        i = start_idx
        while i < len(sys.argv):
            arg = sys.argv[i]
            if arg.startswith("--provider="):
                provider = arg.split("=", 1)[1]
            elif arg.startswith("--tier="):
                force_tier = arg.split("=", 1)[1]
            elif arg == "--provider" and i + 1 < len(sys.argv):
                provider = sys.argv[i + 1]
                i += 1
            elif arg == "--tier" and i + 1 < len(sys.argv):
                force_tier = sys.argv[i + 1]
                i += 1
            elif arg.startswith("--"):
                pass  # Skip unknown flags
            elif current_model is None and "/" in arg:
                current_model = arg
            elif force_tier is None:
                force_tier = arg
            i += 1
        result = route_task(prompt, current_model, force_tier, provider)
        print(json.dumps(result, indent=2))
    elif command == "compare":
        result = get_model_comparison()
        print(json.dumps(result, indent=2))
    elif command == "providers":
        print("Available providers:")
        for provider in PROVIDER_MODELS.keys():
            detected = " (detected)" if provider == detect_provider() else ""
            print(f"  - {provider}{detected}")
    elif command == "detect":
        provider = detect_provider()
        print(f"Auto-detected provider: {provider}")
        print(f"Models: {json.dumps(PROVIDER_MODELS[provider], indent=2)}")
    # No final "else" is needed: unknown commands are consumed by the
    # shorthand routing branch above, so that case is unreachable.


if __name__ == "__main__":
    main()