"""Model selection routes: query available models and set per-persona models.

Endpoints:
    GET  /models/available  query both llama-swap instances for model lists,
                            falling back to a static list when neither answers
    POST /models/select     set the model for one persona (regular / evil /
                            japanese) and persist the choice
    GET  /models/status     return the current per-persona model assignments
"""

# Provenance: bot/routes/models_selector.py, added by commit ae4e40f2
# "feat(models): add model selection API endpoints".

import asyncio

import aiohttp
from fastapi import APIRouter
from fastapi.responses import JSONResponse

import globals
from utils.logger import get_logger

logger = get_logger('api')

router = APIRouter()

# Known model names from llama-swap configs (fallback if API query fails).
# Not referenced inside this module; kept as part of the module's public names.
KNOWN_MODELS = [
    "llama3.1",
    "darkidol",
    "swallow",
    "vision",
    "rocinante",
    "qwen3.5",
]

# Which GPU each model is available on
MODEL_GPU_MAP = {
    "llama3.1": {"nvidia", "amd"},
    "darkidol": {"nvidia", "amd"},
    "swallow": {"nvidia", "amd"},
    "vision": {"nvidia"},
    "rocinante": {"amd"},
    "qwen3.5": {"amd"},
}

# persona -> (attribute name on `globals`, config key, default used by /models/status).
# Single source of truth shared by select_model() and get_model_status().
_PERSONA_MAP = {
    "regular": ("TEXT_MODEL", "models.text", "llama3.1"),
    "evil": ("EVIL_TEXT_MODEL", "models.evil", "darkidol"),
    "japanese": ("JAPANESE_TEXT_MODEL", "models.japanese", "swallow"),
}


def _bad_request(message: str) -> JSONResponse:
    """Build the 400 response used for every request-validation failure."""
    return JSONResponse(
        status_code=400,
        content={"success": False, "error": message},
    )


async def _query_llama_swap_models(url: str, timeout: int = 10) -> list:
    """Query a llama-swap instance for its available models via /v1/models.

    Args:
        url: Base URL of the llama-swap instance (no trailing "/v1/models").
        timeout: Total request timeout in seconds.

    Returns:
        The list of model ids reported by the instance, or an empty list on
        any failure (non-200 status, timeout, connection error, bad payload).
        This function never raises for those cases; it logs a warning instead.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{url}/v1/models",
                timeout=aiohttp.ClientTimeout(total=timeout),
            ) as resp:
                if resp.status != 200:
                    logger.warning(f"llama-swap models query failed ({resp.status}) for {url}")
                    return []
                data = await resp.json()
                # OpenAI-compatible format: { data: [{ id: "model_name", ... }] }
                # Only dict entries are considered: `"id" in m` on a string
                # would be a substring test and m["id"] would then raise.
                return [
                    m["id"]
                    for m in data.get("data", [])
                    if isinstance(m, dict) and "id" in m
                ]
    except (asyncio.TimeoutError, aiohttp.ClientError) as e:
        logger.warning(f"llama-swap unreachable at {url}: {e}")
        return []
    except Exception as e:
        # Best-effort endpoint: a malformed payload must not fail the request.
        logger.warning(f"Unexpected error querying {url}: {e}")
        return []


@router.get("/models/available")
async def get_available_models():
    """
    Query both NVIDIA and AMD llama-swap instances for available models.
    Returns model lists per GPU, their intersection, and all unique models.
    Falls back to known model list if containers are unreachable.

    The response's "source" field is "live" when at least one instance
    returned models and "fallback" when the static MODEL_GPU_MAP was used.
    """
    # The two instances are independent, so query them concurrently; the
    # helper swallows its own errors, so gather() cannot fail on one of them.
    nvidia_models, amd_models = await asyncio.gather(
        _query_llama_swap_models(globals.LLAMA_URL),
        _query_llama_swap_models(globals.LLAMA_AMD_URL),
    )

    if nvidia_models or amd_models:
        nvidia_set = set(nvidia_models)
        amd_set = set(amd_models)
        source = "live"
    else:
        # If both failed, use the known model list from configs
        logger.info("Both llama-swap instances unreachable, using known model list")
        nvidia_set = {m for m, gpus in MODEL_GPU_MAP.items() if "nvidia" in gpus}
        amd_set = {m for m, gpus in MODEL_GPU_MAP.items() if "amd" in gpus}
        source = "fallback"

    return {
        "success": True,
        "nvidia": sorted(nvidia_set),
        "amd": sorted(amd_set),
        "intersection": sorted(nvidia_set & amd_set),
        "all": sorted(nvidia_set | amd_set),
        # Sets have no stable iteration order; emit sorted lists so the JSON
        # arrays are deterministic across processes.
        "gpu_map": {model: sorted(gpus) for model, gpus in MODEL_GPU_MAP.items()},
        "source": source,
    }


@router.post("/models/select")
async def select_model(body: dict):
    """
    Set the model for a specific persona.

    Body: {
        "persona": "regular" | "evil" | "japanese",
        "model": "model_name"
    }

    Persists the selection so it survives bot restarts.

    Returns a 400 JSON response when the persona is unknown, the model is
    empty, or either field is not a string. On success the response includes
    "persisted": False if the in-memory value was set but saving it failed.
    """
    raw_persona = body.get("persona", "")
    raw_model = body.get("model", "")

    # JSON null / numbers / lists would otherwise crash on .strip() (HTTP 500).
    if not isinstance(raw_persona, str) or not isinstance(raw_model, str):
        return _bad_request("persona and model must be strings")

    persona = raw_persona.strip().lower()
    model = raw_model.strip()

    if persona not in _PERSONA_MAP:
        # sorted() keeps the message stable; set/dict-of-str order is not
        # something clients or tests should have to depend on.
        return _bad_request(
            f"Invalid persona '{persona}'. Must be one of: {', '.join(sorted(_PERSONA_MAP))}"
        )

    if not model:
        return _bad_request("model is required")

    attr_name, config_key, _default = _PERSONA_MAP[persona]

    # Set the global
    setattr(globals, attr_name, model)
    logger.info(f"Model selection: {persona} → {model} (globals.{attr_name})")

    # Persist via config manager. Failure is deliberately non-fatal (the
    # in-memory selection is already active) but is reported to the caller.
    persisted = True
    try:
        # Imported lazily as in the original code -- presumably to avoid an
        # import cycle or import-time side effects; confirm before hoisting.
        from config_manager import config_manager
        config_manager.set(config_key, model, persist=True)
    except Exception as e:
        persisted = False
        logger.warning(f"Failed to persist model selection: {e}")

    return {
        "success": True,
        "persona": persona,
        "model": model,
        "persisted": persisted,
        "message": f"{persona.capitalize()} model set to '{model}'",
    }


@router.get("/models/status")
async def get_model_status():
    """Return the current per-persona model assignments.

    Each persona's value is read from its attribute on `globals`; if the
    attribute is absent, the persona's default model name is reported.
    """
    status = {"success": True}
    for persona, (attr_name, _config_key, default) in _PERSONA_MAP.items():
        status[persona] = getattr(globals, attr_name, default)
    return status