feat(models): add model selection API endpoints

- GET /models/available: query both llama-swap instances for model lists
- POST /models/select: set per-persona model (regular/evil/japanese) with persistence
- GET /models/status: return current per-persona model assignments
- Fall back to known model list when containers are unreachable
This commit is contained in:
2026-05-20 13:54:59 +03:00
parent 7cb21a372b
commit ae4e40f2d7

View File

@@ -0,0 +1,161 @@
"""Model selection routes: query available models and set per-persona models."""
import aiohttp
import asyncio
import globals
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from utils.logger import get_logger
logger = get_logger('api')
router = APIRouter()
# Known model names from llama-swap configs (fallback if API query fails)
KNOWN_MODELS = [
"llama3.1",
"darkidol",
"swallow",
"vision",
"rocinante",
"qwen3.5",
]
# Which GPU each model is available on
MODEL_GPU_MAP = {
"llama3.1": {"nvidia", "amd"},
"darkidol": {"nvidia", "amd"},
"swallow": {"nvidia", "amd"},
"vision": {"nvidia"},
"rocinante": {"amd"},
"qwen3.5": {"amd"},
}
async def _query_llama_swap_models(url: str, timeout: int = 10) -> list:
"""Query a llama-swap instance for its available models via /v1/models."""
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{url}/v1/models",
timeout=aiohttp.ClientTimeout(total=timeout),
) as resp:
if resp.status == 200:
data = await resp.json()
# OpenAI-compatible format: { data: [{ id: "model_name", ... }] }
return [m["id"] for m in data.get("data", []) if "id" in m]
else:
logger.warning(f"llama-swap models query failed ({resp.status}) for {url}")
return []
except (asyncio.TimeoutError, aiohttp.ClientError) as e:
logger.warning(f"llama-swap unreachable at {url}: {e}")
return []
except Exception as e:
logger.warning(f"Unexpected error querying {url}: {e}")
return []
@router.get("/models/available")
async def get_available_models():
"""
Query both NVIDIA and AMD llama-swap instances for available models.
Returns model lists per GPU, their intersection, and all unique models.
Falls back to known model list if containers are unreachable.
"""
nvidia_models = await _query_llama_swap_models(globals.LLAMA_URL)
amd_models = await _query_llama_swap_models(globals.LLAMA_AMD_URL)
# If both failed, use the known model list from configs
if not nvidia_models and not amd_models:
logger.info("Both llama-swap instances unreachable, using known model list")
nvidia_set = {m for m, gpus in MODEL_GPU_MAP.items() if "nvidia" in gpus}
amd_set = {m for m, gpus in MODEL_GPU_MAP.items() if "amd" in gpus}
return {
"success": True,
"nvidia": sorted(nvidia_set),
"amd": sorted(amd_set),
"intersection": sorted(nvidia_set & amd_set),
"all": sorted(nvidia_set | amd_set),
"gpu_map": MODEL_GPU_MAP,
"source": "fallback",
}
nvidia_set = set(nvidia_models)
amd_set = set(amd_models)
return {
"success": True,
"nvidia": sorted(nvidia_set),
"amd": sorted(amd_set),
"intersection": sorted(nvidia_set & amd_set),
"all": sorted(nvidia_set | amd_set),
"gpu_map": MODEL_GPU_MAP,
"source": "live",
}
@router.post("/models/select")
async def select_model(body: dict):
"""
Set the model for a specific persona.
Body: {
"persona": "regular" | "evil" | "japanese",
"model": "model_name"
}
Persists the selection so it survives bot restarts.
"""
persona = body.get("persona", "").strip().lower()
model = body.get("model", "").strip()
valid_personas = {"regular", "evil", "japanese"}
if persona not in valid_personas:
return JSONResponse(
status_code=400,
content={"success": False, "error": f"Invalid persona '{persona}'. Must be one of: {', '.join(valid_personas)}"}
)
if not model:
return JSONResponse(
status_code=400,
content={"success": False, "error": "model is required"}
)
# Map persona to globals attribute and config key
PERSONA_MAP = {
"regular": ("TEXT_MODEL", "models.text"),
"evil": ("EVIL_TEXT_MODEL", "models.evil"),
"japanese": ("JAPANESE_TEXT_MODEL", "models.japanese"),
}
attr_name, config_key = PERSONA_MAP[persona]
# Set the global
setattr(globals, attr_name, model)
logger.info(f"Model selection: {persona}{model} (globals.{attr_name})")
# Persist via config manager
try:
from config_manager import config_manager
config_manager.set(config_key, model, persist=True)
except Exception as e:
logger.warning(f"Failed to persist model selection: {e}")
return {
"success": True,
"persona": persona,
"model": model,
"message": f"{persona.capitalize()} model set to '{model}'",
}
@router.get("/models/status")
async def get_model_status():
"""Return the current per-persona model assignments."""
return {
"success": True,
"regular": getattr(globals, "TEXT_MODEL", "llama3.1"),
"evil": getattr(globals, "EVIL_TEXT_MODEL", "darkidol"),
"japanese": getattr(globals, "JAPANESE_TEXT_MODEL", "swallow"),
}