"""Model selection routes: query available models and set per-persona models.""" import aiohttp import asyncio import globals from fastapi import APIRouter from fastapi.responses import JSONResponse from utils.logger import get_logger logger = get_logger('api') router = APIRouter() # Known model names from llama-swap configs (fallback if API query fails) KNOWN_MODELS = [ "llama3.1", "darkidol", "swallow", "vision", "rocinante", "qwen3.5", ] # Which GPU each model is available on MODEL_GPU_MAP = { "llama3.1": {"nvidia", "amd"}, "darkidol": {"nvidia", "amd"}, "swallow": {"nvidia", "amd"}, "vision": {"nvidia"}, "rocinante": {"amd"}, "qwen3.5": {"amd"}, } async def _query_llama_swap_models(url: str, timeout: int = 10) -> list: """Query a llama-swap instance for its available models via /v1/models.""" try: async with aiohttp.ClientSession() as session: async with session.get( f"{url}/v1/models", timeout=aiohttp.ClientTimeout(total=timeout), ) as resp: if resp.status == 200: data = await resp.json() # OpenAI-compatible format: { data: [{ id: "model_name", ... }] } return [m["id"] for m in data.get("data", []) if "id" in m] else: logger.warning(f"llama-swap models query failed ({resp.status}) for {url}") return [] except (asyncio.TimeoutError, aiohttp.ClientError) as e: logger.warning(f"llama-swap unreachable at {url}: {e}") return [] except Exception as e: logger.warning(f"Unexpected error querying {url}: {e}") return [] @router.get("/models/available") async def get_available_models(): """ Query both NVIDIA and AMD llama-swap instances for available models. Returns model lists per GPU, their intersection, and all unique models. Falls back to known model list if containers are unreachable. """ nvidia_models = await _query_llama_swap_models(globals.LLAMA_URL) amd_models = await _query_llama_swap_models(globals.LLAMA_AMD_URL) # If both failed, use the known model list from configs if not nvidia_models and not amd_models: logger.info("Both llama-swap instances unreachable, using known model list") nvidia_set = {m for m, gpus in MODEL_GPU_MAP.items() if "nvidia" in gpus} amd_set = {m for m, gpus in MODEL_GPU_MAP.items() if "amd" in gpus} return { "success": True, "nvidia": sorted(nvidia_set), "amd": sorted(amd_set), "intersection": sorted(nvidia_set & amd_set), "all": sorted(nvidia_set | amd_set), "gpu_map": MODEL_GPU_MAP, "source": "fallback", } nvidia_set = set(nvidia_models) amd_set = set(amd_models) return { "success": True, "nvidia": sorted(nvidia_set), "amd": sorted(amd_set), "intersection": sorted(nvidia_set & amd_set), "all": sorted(nvidia_set | amd_set), "gpu_map": MODEL_GPU_MAP, "source": "live", } @router.post("/models/select") async def select_model(body: dict): """ Set the model for a specific persona. Body: { "persona": "regular" | "evil" | "japanese", "model": "model_name" } Persists the selection so it survives bot restarts. """ persona = body.get("persona", "").strip().lower() model = body.get("model", "").strip() valid_personas = {"regular", "evil", "japanese"} if persona not in valid_personas: return JSONResponse( status_code=400, content={"success": False, "error": f"Invalid persona '{persona}'. Must be one of: {', '.join(valid_personas)}"} ) if not model: return JSONResponse( status_code=400, content={"success": False, "error": "model is required"} ) # Map persona to globals attribute and config key PERSONA_MAP = { "regular": ("TEXT_MODEL", "models.text"), "evil": ("EVIL_TEXT_MODEL", "models.evil"), "japanese": ("JAPANESE_TEXT_MODEL", "models.japanese"), } attr_name, config_key = PERSONA_MAP[persona] # Set the global setattr(globals, attr_name, model) logger.info(f"Model selection: {persona} → {model} (globals.{attr_name})") # Persist via config manager try: from config_manager import config_manager config_manager.set(config_key, model, persist=True) except Exception as e: logger.warning(f"Failed to persist model selection: {e}") return { "success": True, "persona": persona, "model": model, "message": f"{persona.capitalize()} model set to '{model}'", } @router.get("/models/status") async def get_model_status(): """Return the current per-persona model assignments.""" return { "success": True, "regular": getattr(globals, "TEXT_MODEL", "llama3.1"), "evil": getattr(globals, "EVIL_TEXT_MODEL", "darkidol"), "japanese": getattr(globals, "JAPANESE_TEXT_MODEL", "swallow"), }