feat(models): add model selection API endpoints
- GET /models/available: query both llama-swap instances for model lists - POST /models/select: set per-persona model (regular/evil/japanese) with persistence - GET /models/status: return current per-persona model assignments - Fall back to known model list when containers are unreachable
This commit is contained in:
161
bot/routes/models_selector.py
Normal file
161
bot/routes/models_selector.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Model selection routes: query available models and set per-persona models."""
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import globals
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
from utils.logger import get_logger
|
||||
|
||||
logger = get_logger('api')
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Known model names from llama-swap configs (fallback if API query fails)
|
||||
KNOWN_MODELS = [
|
||||
"llama3.1",
|
||||
"darkidol",
|
||||
"swallow",
|
||||
"vision",
|
||||
"rocinante",
|
||||
"qwen3.5",
|
||||
]
|
||||
|
||||
# Which GPU each model is available on
|
||||
MODEL_GPU_MAP = {
|
||||
"llama3.1": {"nvidia", "amd"},
|
||||
"darkidol": {"nvidia", "amd"},
|
||||
"swallow": {"nvidia", "amd"},
|
||||
"vision": {"nvidia"},
|
||||
"rocinante": {"amd"},
|
||||
"qwen3.5": {"amd"},
|
||||
}
|
||||
|
||||
|
||||
async def _query_llama_swap_models(url: str, timeout: int = 10) -> list:
|
||||
"""Query a llama-swap instance for its available models via /v1/models."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{url}/v1/models",
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
# OpenAI-compatible format: { data: [{ id: "model_name", ... }] }
|
||||
return [m["id"] for m in data.get("data", []) if "id" in m]
|
||||
else:
|
||||
logger.warning(f"llama-swap models query failed ({resp.status}) for {url}")
|
||||
return []
|
||||
except (asyncio.TimeoutError, aiohttp.ClientError) as e:
|
||||
logger.warning(f"llama-swap unreachable at {url}: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected error querying {url}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/models/available")
|
||||
async def get_available_models():
|
||||
"""
|
||||
Query both NVIDIA and AMD llama-swap instances for available models.
|
||||
Returns model lists per GPU, their intersection, and all unique models.
|
||||
Falls back to known model list if containers are unreachable.
|
||||
"""
|
||||
nvidia_models = await _query_llama_swap_models(globals.LLAMA_URL)
|
||||
amd_models = await _query_llama_swap_models(globals.LLAMA_AMD_URL)
|
||||
|
||||
# If both failed, use the known model list from configs
|
||||
if not nvidia_models and not amd_models:
|
||||
logger.info("Both llama-swap instances unreachable, using known model list")
|
||||
nvidia_set = {m for m, gpus in MODEL_GPU_MAP.items() if "nvidia" in gpus}
|
||||
amd_set = {m for m, gpus in MODEL_GPU_MAP.items() if "amd" in gpus}
|
||||
return {
|
||||
"success": True,
|
||||
"nvidia": sorted(nvidia_set),
|
||||
"amd": sorted(amd_set),
|
||||
"intersection": sorted(nvidia_set & amd_set),
|
||||
"all": sorted(nvidia_set | amd_set),
|
||||
"gpu_map": MODEL_GPU_MAP,
|
||||
"source": "fallback",
|
||||
}
|
||||
|
||||
nvidia_set = set(nvidia_models)
|
||||
amd_set = set(amd_models)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"nvidia": sorted(nvidia_set),
|
||||
"amd": sorted(amd_set),
|
||||
"intersection": sorted(nvidia_set & amd_set),
|
||||
"all": sorted(nvidia_set | amd_set),
|
||||
"gpu_map": MODEL_GPU_MAP,
|
||||
"source": "live",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/models/select")
|
||||
async def select_model(body: dict):
|
||||
"""
|
||||
Set the model for a specific persona.
|
||||
|
||||
Body: {
|
||||
"persona": "regular" | "evil" | "japanese",
|
||||
"model": "model_name"
|
||||
}
|
||||
|
||||
Persists the selection so it survives bot restarts.
|
||||
"""
|
||||
persona = body.get("persona", "").strip().lower()
|
||||
model = body.get("model", "").strip()
|
||||
|
||||
valid_personas = {"regular", "evil", "japanese"}
|
||||
if persona not in valid_personas:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"success": False, "error": f"Invalid persona '{persona}'. Must be one of: {', '.join(valid_personas)}"}
|
||||
)
|
||||
|
||||
if not model:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"success": False, "error": "model is required"}
|
||||
)
|
||||
|
||||
# Map persona to globals attribute and config key
|
||||
PERSONA_MAP = {
|
||||
"regular": ("TEXT_MODEL", "models.text"),
|
||||
"evil": ("EVIL_TEXT_MODEL", "models.evil"),
|
||||
"japanese": ("JAPANESE_TEXT_MODEL", "models.japanese"),
|
||||
}
|
||||
|
||||
attr_name, config_key = PERSONA_MAP[persona]
|
||||
|
||||
# Set the global
|
||||
setattr(globals, attr_name, model)
|
||||
logger.info(f"Model selection: {persona} → {model} (globals.{attr_name})")
|
||||
|
||||
# Persist via config manager
|
||||
try:
|
||||
from config_manager import config_manager
|
||||
config_manager.set(config_key, model, persist=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to persist model selection: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"persona": persona,
|
||||
"model": model,
|
||||
"message": f"{persona.capitalize()} model set to '{model}'",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/models/status")
|
||||
async def get_model_status():
|
||||
"""Return the current per-persona model assignments."""
|
||||
return {
|
||||
"success": True,
|
||||
"regular": getattr(globals, "TEXT_MODEL", "llama3.1"),
|
||||
"evil": getattr(globals, "EVIL_TEXT_MODEL", "darkidol"),
|
||||
"japanese": getattr(globals, "JAPANESE_TEXT_MODEL", "swallow"),
|
||||
}
|
||||
Reference in New Issue
Block a user