Add dual GPU support with web UI selector
Features:
- Built custom ROCm container for AMD RX 6800 GPU
- Added GPU selection toggle in web UI (NVIDIA/AMD)
- Unified model names across both GPUs for seamless switching
- Vision model always uses NVIDIA GPU (optimal performance)
- Text models (llama3.1, darkidol) can use either GPU
- Added /gpu-status and /gpu-select API endpoints (usage sketch below)
- Implemented GPU state persistence in memory/gpu_state.json

Technical details:
- Multi-stage Dockerfile.llamaswap-rocm with ROCm 6.2.4
- llama.cpp compiled with GGML_HIP=ON for gfx1030 (RX 6800)
- Proper GPU permissions without root (groups 187/989)
- AMD container on port 8091, NVIDIA on port 8090
- Updated bot/utils/llm.py with get_current_gpu_url() and get_vision_gpu_url()
- Modified bot/utils/image_handling.py to always use NVIDIA for vision
- Enhanced web UI with GPU selector button (blue=NVIDIA, red=AMD)

Files modified:
- docker-compose.yml (added llama-swap-amd service)
- bot/globals.py (added LLAMA_AMD_URL)
- bot/api.py (added GPU selection endpoints and helper function)
- bot/utils/llm.py (GPU routing for text models)
- bot/utils/image_handling.py (GPU routing for vision models)
- bot/static/index.html (GPU selector UI)
- llama-swap-rocm-config.yaml (unified model names)

New files:
- Dockerfile.llamaswap-rocm
- bot/memory/gpu_state.json
- bot/utils/gpu_router.py (load balancing utility)
- setup-dual-gpu.sh (setup verification script)
- DUAL_GPU_*.md (documentation files)
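As a rough usage sketch of the two new endpoints (the bot API's base URL is an assumption here, not something stated in this commit; point it at wherever bot/api.py is served):

```python
import requests

# Assumed base URL for the FastAPI app in bot/api.py -- adjust to your deployment.
BASE_URL = "http://localhost:8000"

# Read the current GPU selection; falls back to "nvidia" when memory/gpu_state.json is missing.
print(requests.get(f"{BASE_URL}/gpu-status").json())  # e.g. {"gpu": "nvidia"}

# Route text-model inference to the AMD RX 6800 container (llama-swap-amd on port 8091).
resp = requests.post(f"{BASE_URL}/gpu-select", json={"gpu": "amd"})
print(resp.json())  # {"status": "ok", "message": "Switched to AMD GPU", "gpu": "amd"}

# Anything other than "nvidia" or "amd" is rejected by the endpoint.
print(requests.post(f"{BASE_URL}/gpu-select", json={"gpu": "tpu"}).json())
```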
bot/api.py (157 lines changed)
@@ -52,6 +52,22 @@ from utils.figurine_notifier import (
 from utils.dm_logger import dm_logger
 nest_asyncio.apply()
+
+# ========== GPU Selection Helper ==========
+def get_current_gpu_url():
+    """Get the URL for the currently selected GPU"""
+    gpu_state_file = os.path.join(os.path.dirname(__file__), "memory", "gpu_state.json")
+    try:
+        with open(gpu_state_file, "r") as f:
+            state = json.load(f)
+        current_gpu = state.get("current_gpu", "nvidia")
+        if current_gpu == "amd":
+            return globals.LLAMA_AMD_URL
+        else:
+            return globals.LLAMA_URL
+    except:
+        # Default to NVIDIA if state file doesn't exist
+        return globals.LLAMA_URL
+
 app = FastAPI()
 
 # Serve static folder
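The state file this helper reads is the same one the /gpu-select endpoint further down writes; a minimal sketch of producing bot/memory/gpu_state.json by hand (the path and timestamp here are illustrative):

```python
import json
import os
from datetime import datetime

# Mirror the shape written by the /gpu-select handler: current_gpu plus a timestamp.
state = {"current_gpu": "amd", "last_updated": datetime.now().isoformat()}

state_path = os.path.join("bot", "memory", "gpu_state.json")  # illustrative path
with open(state_path, "w") as f:
    json.dump(state, f, indent=2)
```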
@@ -363,6 +379,97 @@ def trigger_argument(data: BipolarTriggerRequest):
         "channel_id": channel_id
     }
+
+@app.post("/bipolar-mode/trigger-dialogue")
+def trigger_dialogue(data: dict):
+    """Manually trigger a persona dialogue from a message
+
+    Forces the opposite persona to start a dialogue (bypasses the interjection check).
+    """
+    from utils.persona_dialogue import get_dialogue_manager
+    from utils.bipolar_mode import is_bipolar_mode, is_argument_in_progress
+
+    message_id_str = data.get("message_id")
+    if not message_id_str:
+        return {"status": "error", "message": "Message ID is required"}
+
+    # Parse message ID
+    try:
+        message_id = int(message_id_str)
+    except ValueError:
+        return {"status": "error", "message": "Invalid message ID format"}
+
+    if not is_bipolar_mode():
+        return {"status": "error", "message": "Bipolar mode is not enabled"}
+
+    if not globals.client or not globals.client.loop or not globals.client.loop.is_running():
+        return {"status": "error", "message": "Discord client not ready"}
+
+    import asyncio
+
+    async def trigger_dialogue_task():
+        try:
+            # Fetch the message
+            message = None
+            for channel in globals.client.get_all_channels():
+                if hasattr(channel, 'fetch_message'):
+                    try:
+                        message = await channel.fetch_message(message_id)
+                        break
+                    except:
+                        continue
+
+            if not message:
+                print(f"⚠️ Message {message_id} not found")
+                return
+
+            # Check if there's already an argument or dialogue in progress
+            dialogue_manager = get_dialogue_manager()
+            if dialogue_manager.is_dialogue_active(message.channel.id):
+                print(f"⚠️ Dialogue already active in channel {message.channel.id}")
+                return
+
+            if is_argument_in_progress(message.channel.id):
+                print(f"⚠️ Argument already in progress in channel {message.channel.id}")
+                return
+
+            # Determine current persona from the message author
+            if message.webhook_id:
+                # It's a webhook message, need to determine which persona
+                current_persona = "evil" if globals.EVIL_MODE else "miku"
+            elif message.author.id == globals.client.user.id:
+                # It's the bot's message
+                current_persona = "evil" if globals.EVIL_MODE else "miku"
+            else:
+                # User message - can't trigger dialogue from user messages
+                print(f"⚠️ Cannot trigger dialogue from user message")
+                return
+
+            opposite_persona = "evil" if current_persona == "miku" else "miku"
+
+            print(f"🎭 [Manual Trigger] Forcing {opposite_persona} to start dialogue on message {message_id}")
+
+            # Force start the dialogue (bypass interjection check)
+            dialogue_manager.start_dialogue(message.channel.id)
+            asyncio.create_task(
+                dialogue_manager.handle_dialogue_turn(
+                    message.channel,
+                    opposite_persona,
+                    trigger_reason="manual_trigger"
+                )
+            )
+
+        except Exception as e:
+            print(f"⚠️ Error triggering dialogue: {e}")
+            import traceback
+            traceback.print_exc()
+
+    globals.client.loop.create_task(trigger_dialogue_task())
+
+    return {
+        "status": "ok",
+        "message": f"Dialogue triggered for message {message_id}"
+    }
+
 @app.get("/bipolar-mode/scoreboard")
 def get_bipolar_scoreboard():
     """Get the bipolar mode argument scoreboard"""
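A hedged usage sketch for the new manual-trigger endpoint (base URL assumed, and the message ID below is a placeholder for a real Discord message ID belonging to a bot or webhook message):

```python
import requests

BASE_URL = "http://localhost:8000"  # assumed deployment address for bot/api.py

# message_id is parsed with int(), so a numeric string or integer both work.
resp = requests.post(
    f"{BASE_URL}/bipolar-mode/trigger-dialogue",
    json={"message_id": "1234567890123456789"},  # placeholder Discord message ID
)
print(resp.json())  # {"status": "ok", ...} on success, {"status": "error", ...} otherwise
```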
@@ -392,6 +499,51 @@ def cleanup_bipolar_webhooks():
     globals.client.loop.create_task(cleanup_webhooks(globals.client))
     return {"status": "ok", "message": "Webhook cleanup started"}
+
+# ========== GPU Selection ==========
+@app.get("/gpu-status")
+def get_gpu_status():
+    """Get current GPU selection"""
+    gpu_state_file = os.path.join(os.path.dirname(__file__), "memory", "gpu_state.json")
+    try:
+        with open(gpu_state_file, "r") as f:
+            state = json.load(f)
+        return {"gpu": state.get("current_gpu", "nvidia")}
+    except:
+        return {"gpu": "nvidia"}
+
+@app.post("/gpu-select")
+async def select_gpu(request: Request):
+    """Select which GPU to use for inference"""
+    from utils.gpu_preload import preload_amd_models
+
+    data = await request.json()
+    gpu = data.get("gpu", "nvidia").lower()
+
+    if gpu not in ["nvidia", "amd"]:
+        return {"status": "error", "message": "Invalid GPU selection. Must be 'nvidia' or 'amd'"}
+
+    gpu_state_file = os.path.join(os.path.dirname(__file__), "memory", "gpu_state.json")
+    try:
+        from datetime import datetime
+        state = {
+            "current_gpu": gpu,
+            "last_updated": datetime.now().isoformat()
+        }
+        with open(gpu_state_file, "w") as f:
+            json.dump(state, f, indent=2)
+
+        print(f"🎮 GPU Selection: Switched to {gpu.upper()} GPU")
+
+        # Preload models on AMD GPU (16GB VRAM - can hold both text + vision)
+        if gpu == "amd":
+            asyncio.create_task(preload_amd_models())
+            print("🔧 Preloading text and vision models on AMD GPU...")
+
+        return {"status": "ok", "message": f"Switched to {gpu.upper()} GPU", "gpu": gpu}
+    except Exception as e:
+        print(f"🎮 GPU Selection Error: {e}")
+        return {"status": "error", "message": str(e)}
+
 @app.get("/bipolar-mode/arguments")
 def get_active_arguments():
     """Get all active arguments"""
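utils/gpu_preload.py itself is not part of this diff, so the following is only a plausible sketch of preload_amd_models(), assuming llama-swap loads a model the first time it is requested and that the unified text-model names from the commit message apply; the vision model is omitted because its name is not given here:

```python
# Hypothetical sketch of utils/gpu_preload.py -- the real module is not shown in this diff.
import aiohttp

import globals

TEXT_MODELS = ["llama3.1", "darkidol"]  # unified model names from the commit message

async def preload_amd_models():
    """Warm up llama-swap on the AMD container by sending a 1-token request per model."""
    async with aiohttp.ClientSession() as session:
        for model in TEXT_MODELS:
            payload = {
                "model": model,
                "messages": [{"role": "user", "content": "ping"}],
                "max_tokens": 1,
            }
            try:
                async with session.post(
                    f"{globals.LLAMA_AMD_URL}/v1/chat/completions", json=payload
                ) as resp:
                    await resp.read()
            except aiohttp.ClientError as e:
                print(f"⚠️ Preload of {model} on AMD GPU failed: {e}")
```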
@@ -2100,10 +2252,13 @@ Be detailed but conversational. React to what you see with Miku's cheerful, play
 
     headers = {'Content-Type': 'application/json'}
 
+    # Get current GPU URL based on user selection
+    llama_url = get_current_gpu_url()
+
     # Make streaming request to llama.cpp
     async with aiohttp.ClientSession() as session:
         async with session.post(
-            f"{globals.LLAMA_URL}/v1/chat/completions",
+            f"{llama_url}/v1/chat/completions",
             json=payload,
             headers=headers
         ) as response:
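The commit message also mentions a get_vision_gpu_url() helper in bot/utils/llm.py; that file's diff is not shown here, but given that vision requests stay pinned to the NVIDIA GPU, a minimal sketch would be:

```python
# Sketch only -- the real helper lives in bot/utils/llm.py, which is not part of this diff.
import globals

def get_vision_gpu_url() -> str:
    """Vision inference always targets the NVIDIA container (port 8090), regardless of the selected GPU."""
    return globals.LLAMA_URL
```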