feat(models): add model selection API endpoints
- GET /models/available: query both llama-swap instances for model lists
- POST /models/select: set per-persona model (regular/evil/japanese) with persistence
- GET /models/status: return current per-persona model assignments
- Fall back to known model list when containers are unreachable
This commit is contained in:
161
bot/routes/models_selector.py
Normal file
161
bot/routes/models_selector.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""Model selection routes: query available models and set per-persona models."""
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import globals
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from utils.logger import get_logger
|
||||||
|
|
||||||
|
# Module-level logger; records from these routes go to the 'api' channel.
logger = get_logger("api")

# Router that the model-selection endpoints below are registered on.
router = APIRouter()
|
||||||
|
|
||||||
|
# Which GPU each model is available on.
# This is the single source of truth for the fallback model data; the
# insertion order here is also the order of KNOWN_MODELS below.
MODEL_GPU_MAP = {
    "llama3.1": {"nvidia", "amd"},
    "darkidol": {"nvidia", "amd"},
    "swallow": {"nvidia", "amd"},
    "vision": {"nvidia"},
    "rocinante": {"amd"},
    "qwen3.5": {"amd"},
}

# Known model names from llama-swap configs (fallback if API query fails).
# Derived from MODEL_GPU_MAP so the two can never drift apart.
KNOWN_MODELS = list(MODEL_GPU_MAP)
|
||||||
|
|
||||||
|
|
||||||
|
async def _query_llama_swap_models(url: str, timeout: int = 10) -> list:
    """Query a llama-swap instance for its available models via /v1/models.

    This helper is best-effort: it never raises for network or payload
    problems, it logs a warning and returns an empty list instead.

    Args:
        url: Base URL of the llama-swap instance. A trailing slash is
            tolerated.
        timeout: Total request timeout in seconds.

    Returns:
        The model ids reported by the instance, or an empty list when the
        URL is empty, the instance is unreachable, it answers with a
        non-200 status, or the payload is not in the expected shape.
    """
    if not url:
        logger.warning("llama-swap models query skipped: no URL configured")
        return []

    # Strip a trailing slash so we never request "//v1/models".
    endpoint = f"{url.rstrip('/')}/v1/models"

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                endpoint,
                timeout=aiohttp.ClientTimeout(total=timeout),
            ) as resp:
                if resp.status != 200:
                    logger.warning(f"llama-swap models query failed ({resp.status}) for {url}")
                    return []
                data = await resp.json()
    except (asyncio.TimeoutError, aiohttp.ClientError) as e:
        logger.warning(f"llama-swap unreachable at {url}: {e}")
        return []
    except Exception as e:
        # Deliberate catch-all: a broken container must not break the route.
        logger.warning(f"Unexpected error querying {url}: {e}")
        return []

    # OpenAI-compatible format: { data: [{ id: "model_name", ... }] }
    entries = data.get("data", []) if isinstance(data, dict) else None
    if not isinstance(entries, list):
        logger.warning(f"Unexpected error querying {url}: malformed /v1/models payload")
        return []
    return [m["id"] for m in entries if isinstance(m, dict) and "id" in m]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models/available")
async def get_available_models():
    """
    Query both NVIDIA and AMD llama-swap instances for available models.

    Returns a dict with:
        success: always True.
        nvidia / amd: sorted model names per GPU.
        intersection: sorted models available on both GPUs.
        all: sorted union of both lists.
        gpu_map: the static MODEL_GPU_MAP, with each GPU set rendered as a
            sorted list so the serialized output is deterministic.
        source: "live" when at least one instance answered with models,
            "fallback" when neither did and the known model list was used.
    """
    # The two instances are independent, so query them concurrently; the
    # helper swallows its own errors and returns [] on failure, which makes
    # gather() safe here.
    nvidia_models, amd_models = await asyncio.gather(
        _query_llama_swap_models(globals.LLAMA_URL),
        _query_llama_swap_models(globals.LLAMA_AMD_URL),
    )

    if nvidia_models or amd_models:
        nvidia_set = set(nvidia_models)
        amd_set = set(amd_models)
        source = "live"
    else:
        # If both failed, use the known model list from configs
        logger.info("Both llama-swap instances unreachable, using known model list")
        nvidia_set = {m for m, gpus in MODEL_GPU_MAP.items() if "nvidia" in gpus}
        amd_set = {m for m, gpus in MODEL_GPU_MAP.items() if "amd" in gpus}
        source = "fallback"

    return {
        "success": True,
        "nvidia": sorted(nvidia_set),
        "amd": sorted(amd_set),
        "intersection": sorted(nvidia_set & amd_set),
        "all": sorted(nvidia_set | amd_set),
        # Sets have no stable iteration order; emit sorted lists instead.
        "gpu_map": {model: sorted(gpus) for model, gpus in MODEL_GPU_MAP.items()},
        "source": source,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/models/select")
async def select_model(body: dict):
    """
    Set the model for a specific persona.

    Body: {
        "persona": "regular" | "evil" | "japanese",
        "model": "model_name"
    }

    The selection is applied to the matching attribute on ``globals`` and
    then handed to the config manager for persistence. A persistence
    failure is logged and reported through the ``persisted`` flag; it does
    not undo the in-memory change.

    Returns:
        400 JSON error when ``persona`` is not one of the supported values
        or ``model`` is missing, empty, or not a string. Otherwise a dict
        with ``success``, ``persona``, ``model``, ``persisted`` and a
        human-readable ``message``.
    """
    # Map persona to globals attribute and config key
    persona_map = {
        "regular": ("TEXT_MODEL", "models.text"),
        "evil": ("EVIL_TEXT_MODEL", "models.evil"),
        "japanese": ("JAPANESE_TEXT_MODEL", "models.japanese"),
    }

    # Non-string values (e.g. JSON null or a number) are treated as missing
    # so they produce a 400 instead of an AttributeError / 500.
    raw_persona = body.get("persona")
    raw_model = body.get("model")
    persona = raw_persona.strip().lower() if isinstance(raw_persona, str) else ""
    model = raw_model.strip() if isinstance(raw_model, str) else ""

    if persona not in persona_map:
        # Joining the dict keeps the listed personas in a stable order.
        return JSONResponse(
            status_code=400,
            content={"success": False, "error": f"Invalid persona '{persona}'. Must be one of: {', '.join(persona_map)}"},
        )

    if not model:
        return JSONResponse(
            status_code=400,
            content={"success": False, "error": "model is required"},
        )

    attr_name, config_key = persona_map[persona]

    # Set the global
    setattr(globals, attr_name, model)
    logger.info(f"Model selection: {persona} → {model} (globals.{attr_name})")

    # Persist via config manager
    persisted = True
    try:
        # NOTE(review): imported lazily here as in the original; presumably
        # to avoid an import cycle at module load — confirm.
        from config_manager import config_manager
        config_manager.set(config_key, model, persist=True)
    except Exception as e:
        # Best-effort: the in-memory selection stays active even when it
        # could not be written out.
        persisted = False
        logger.warning(f"Failed to persist model selection: {e}")

    return {
        "success": True,
        "persona": persona,
        "model": model,
        "persisted": persisted,
        "message": f"{persona.capitalize()} model set to '{model}'",
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models/status")
async def get_model_status():
    """Report which model each persona is currently assigned.

    Each value is read from ``globals``; when the attribute is absent the
    listed default name is reported instead.
    """
    # (response key, globals attribute, default when the attribute is unset)
    persona_attrs = (
        ("regular", "TEXT_MODEL", "llama3.1"),
        ("evil", "EVIL_TEXT_MODEL", "darkidol"),
        ("japanese", "JAPANESE_TEXT_MODEL", "swallow"),
    )

    status = {"success": True}
    for persona, attr_name, default in persona_attrs:
        status[persona] = getattr(globals, attr_name, default)
    return status
|
||||||
Reference in New Issue
Block a user