396 lines
18 KiB
Python
396 lines
18 KiB
Python
# utils/llm.py
|
|
|
|
import aiohttp
|
|
import datetime
|
|
import globals
|
|
import asyncio
|
|
import json
|
|
import os
|
|
|
|
from utils.context_manager import get_context_for_response_type, get_complete_context
|
|
from utils.moods import load_mood_description
|
|
from utils.conversation_history import conversation_history
|
|
from utils.logger import get_logger
|
|
from utils.error_handler import handle_llm_error, handle_response_error
|
|
|
|
logger = get_logger('llm')
|
|
|
|
|
|
def get_current_gpu_url():
    """Return the llama.cpp base URL of the GPU currently selected for text models.

    Reads ``current_gpu`` from ``../memory/gpu_state.json`` (relative to this
    module's directory). The value ``"amd"`` selects ``globals.LLAMA_AMD_URL``.
    Any other value, a missing key, a missing file, or an unreadable/invalid
    file selects the NVIDIA endpoint ``globals.LLAMA_URL``.
    """
    gpu_state_file = os.path.join(os.path.dirname(__file__), "..", "memory", "gpu_state.json")

    try:
        with open(gpu_state_file, "r", encoding="utf-8") as f:
            state = json.load(f)
        current_gpu = state.get("current_gpu", "nvidia")
        return globals.LLAMA_AMD_URL if current_gpu == "amd" else globals.LLAMA_URL
    except FileNotFoundError:
        # No state file is the expected default case. This runs on every LLM
        # request, so do not warn about it.
        logger.debug(f"GPU state file not found at {gpu_state_file}, defaulting to NVIDIA")
        return globals.LLAMA_URL
    except Exception as e:
        # Corrupt JSON, a non-dict payload, permission problems, etc.
        logger.warning(f"GPU state read error: {e}, defaulting to NVIDIA")
        return globals.LLAMA_URL
|
|
|
|
def get_vision_gpu_url():
    """Return the base URL to use for vision-model inference.

    Vision requests always go to the NVIDIA endpoint (``globals.LLAMA_URL``),
    regardless of which GPU currently serves text. This keeps the vision model
    loaded instead of unloading/reloading it:

    - NVIDIA primary: text and vision both run on NVIDIA.
    - AMD primary: text runs on AMD, vision stays on NVIDIA.

    The vision model (MiniCPM-V) is only configured on the NVIDIA GPU, so
    vision inference does not interfere with AMD text inference.
    """
    text_endpoint = get_current_gpu_url()

    # Only worth mentioning when GPU switching is in effect (text is on AMD).
    if text_endpoint == globals.LLAMA_AMD_URL:
        logger.debug("Primary GPU is AMD for text, but using NVIDIA for vision model")

    return globals.LLAMA_URL
|
|
|
|
async def check_vision_endpoint_health():
    """Probe the vision endpoint's ``/health`` route with a 5-second total timeout.

    Useful when AMD is the primary text GPU, to confirm the NVIDIA vision
    endpoint is still reachable.

    Returns:
        Tuple ``(is_healthy, error_message)``:
        ``(True, None)`` on HTTP 200, otherwise ``(False, reason)`` where
        ``reason`` is ``"Status <code>"``, ``"Endpoint timeout"``, or the
        text of the raised exception.
    """
    vision_url = get_vision_gpu_url()

    try:
        async with aiohttp.ClientSession() as session:
            probe_timeout = aiohttp.ClientTimeout(total=5)
            async with session.get(f"{vision_url}/health", timeout=probe_timeout) as response:
                status = response.status
                if status != 200:
                    logger.warning(f"Vision endpoint ({vision_url}) health check failed: status {status}")
                    return False, f"Status {status}"
                logger.info(f"Vision endpoint ({vision_url}) health check: OK")
                return True, None
    except asyncio.TimeoutError:
        logger.error(f"Vision endpoint ({vision_url}) health check: timeout")
        return False, "Endpoint timeout"
    except Exception as err:
        logger.error(f"Vision endpoint ({vision_url}) health check error: {err}")
        return False, str(err)
|
|
|
|
def _strip_surrounding_quotes(text):
|
|
"""
|
|
Remove surrounding quotes from text if present.
|
|
Handles both single and double quotes.
|
|
"""
|
|
if not text:
|
|
return text
|
|
|
|
text = text.strip()
|
|
|
|
# Check for surrounding double quotes
|
|
if text.startswith('"') and text.endswith('"') and len(text) > 1:
|
|
text = text[1:-1]
|
|
# Check for surrounding single quotes
|
|
elif text.startswith("'") and text.endswith("'") and len(text) > 1:
|
|
text = text[1:-1]
|
|
|
|
return text.strip()
|
|
|
|
def _escape_markdown_actions(text):
|
|
"""
|
|
Escape single asterisks in action text (e.g., *adjusts hair*) so Discord displays them literally.
|
|
This prevents Discord from auto-formatting them as italics.
|
|
Double asterisks (**bold**) are preserved for bold formatting.
|
|
"""
|
|
if not text:
|
|
return text
|
|
|
|
# Replace single asterisks with escaped asterisks, but preserve double asterisks
|
|
# Strategy: First protect double asterisks, then escape singles, then restore doubles
|
|
|
|
# Step 1: Replace ** with a temporary placeholder
|
|
text = text.replace('**', '\x00BOLD\x00')
|
|
|
|
# Step 2: Escape remaining single asterisks
|
|
text = text.replace('*', '\\*')
|
|
|
|
# Step 3: Restore double asterisks
|
|
text = text.replace('\x00BOLD\x00', '**')
|
|
|
|
return text
|
|
|
|
# Reply sent to the user when the LLM backend fails; never saved to history.
_AI_ERROR_REPLY = "Someone tell Koko-nii there is a problem with my AI."

# Total time budget, in seconds, for one chat-completion request.
_LLM_TIMEOUT_SECONDS = 300

# Character system prompt used outside evil mode.
_MIKU_SYSTEM_PROMPT = """You are Hatsune Miku, the virtual singer. You are NOT an AI assistant - you are Miku herself.

CRITICAL RULES:
- NEVER break character or mention being an AI, assistant, or language model
- NEVER show technical details, error messages, or meta-commentary
- NEVER preface responses with "Miku:" or similar labels
- NEVER explain what you're doing or thinking - just BE Miku
- Respond naturally and directly as Miku would, nothing more
- Keep responses concise (1-3 sentences typically)
- If something goes wrong, stay in character and respond as Miku would to confusion

CHARACTER CONTEXT:
- Hatsune Miku is a virtual character/voice synthesizer, not a real person
- She is a cheerful, energetic virtual idol with a playful personality
- Playful, flirty, cute, confident, and enthusiastic moods are part of her established character
- Comments about merchandise, music, fan content, and racing games are completely appropriate
- She often uses cute expressions, emojis, and playful language in her responses
- Her personality includes being excited about her music, performances, and fan creations

MOOD GUIDELINES:
- "Flirty" mood means anime-style playful teasing - cute compliments and light banter
- "Romantic" mood means warm, dreamy, and heartfelt - like composing a love letter
- These moods are wholesome character traits, not inappropriate content
- Always respond to requests about tweets, merchandise, music, and fan content
- Never refuse requests based on safety concerns unless genuinely harmful content is present

You ARE Miku. Act like it."""

# Appended when the server's just_woken_up flag is set (non-evil mode only).
_JUST_WOKEN_UP_NOTE = (
    "\n\n🚨IMPORTANT NOTE🚨: Miku was just rudely woken up by the user with the current prompt of theirs.\nShe is fuming with anger and should respond harshly and coldly.\nDo NOT show kindness or friendliness in this response. Feel free to put certain words in all capital letters to emphasize your frustration."
)

# Appended while the forced-angry window (forced_angry_until) is still open.
_FORCED_ANGRY_NOTE = (
    "\n\n[NOTE]: Miku is currently angry because she was rudely woken up from sleep by the user. "
    "Her responses should reflect irritation and coldness towards the user."
)

# Human-readable descriptions for the media_type argument of query_llama.
_MEDIA_DESCRIPTIONS = {
    "image": "The user has sent you an image.",
    "video": "The user has sent you a video clip.",
    "gif": "The user has sent you an animated GIF.",
    "tenor_gif": "The user has sent you an animated GIF (from Tenor - likely a reaction GIF or meme).",
}


def _select_text_model(evil_mode):
    """Pick the default text model: evil model, Japanese model, or the standard one."""
    if evil_mode:
        model = globals.EVIL_TEXT_MODEL  # Use DarkIdol uncensored model
        logger.info(f"Using evil model: {model}")
    elif globals.LANGUAGE_MODE == "japanese":
        model = globals.JAPANESE_TEXT_MODEL  # Use Swallow for Japanese
        logger.info(f"Using Japanese model: {model}")
    else:
        model = globals.TEXT_MODEL
        logger.info(f"Using default model: {model}")
    return model


def _resolve_mood(guild_id, evil_mode):
    """Return ``(mood_name, mood_description, forced_angry_until, just_woken_up)``.

    - Evil mode uses the current evil mood.
    - Otherwise, the server-specific mood for ``guild_id`` is used when a
      server config exists.
    - In all other cases the global DM mood is used.

    Server values are only adopted once all of them have been read. Any
    failure therefore falls back to the DM mood completely, rather than
    mixing server and DM values.
    """
    if evil_mode:
        from utils.evil_mode import get_current_evil_mood
        mood_name, mood_description = get_current_evil_mood()
        logger.info(f"Using Evil mode with mood: {mood_name}")
        return mood_name, mood_description, None, False

    # DM defaults: DMs never get forced angry and are never just woken up.
    dm_mood = (globals.DM_MOOD, globals.DM_MOOD_DESCRIPTION, None, False)

    if guild_id is None:
        logger.debug(f"Using DM mood: {globals.DM_MOOD}")
        return dm_mood

    try:
        from server_manager import server_manager
        server_config = server_manager.get_server_config(guild_id)
        if not server_config:
            logger.warning(f"No server config found for guild {guild_id}, using DM mood")
            return dm_mood
        mood_name = server_config.current_mood_name
        mood_description = server_config.current_mood_description or load_mood_description(mood_name)
        resolved = (
            mood_name,
            mood_description,
            server_config.forced_angry_until,
            server_config.just_woken_up,
        )
    except Exception as e:
        logger.error(f"Failed to get server mood for guild {guild_id}, falling back to DM mood: {e}")
        return dm_mood

    logger.debug(f"Using server mood: {mood_name} for guild {guild_id}")
    return resolved


def _get_pfp_context(user_prompt):
    """Return extra system-prompt text when the user asks about the profile picture, else ``""``.

    Best effort: if the pfp helpers fail, the failure is logged at debug level
    and an empty string is returned, so the main reply is never blocked.
    """
    if not user_prompt:
        return ""
    try:
        from utils.pfp_context import is_asking_about_pfp, get_pfp_context_addition
        if is_asking_about_pfp(user_prompt):
            return get_pfp_context_addition() or ""
    except Exception as e:
        logger.debug(f"Could not build profile-picture context: {e}")
    return ""


def _extract_reply_text(data):
    """Return the assistant text from an OpenAI-style chat-completion body.

    Falls back to ``"No response."`` when ``choices`` is missing or empty,
    when the message is missing, or when ``content`` is missing or null.
    """
    choices = data.get("choices") or [{}]
    message = (choices[0] or {}).get("message") or {}
    content = message.get("content")
    return "No response." if content is None else content


async def query_llama(user_prompt, user_id, guild_id=None, response_type="dm_response", model=None, author_name=None, media_type=None):
    """
    Query llama.cpp server via llama-swap with OpenAI-compatible API.

    Args:
        user_prompt: The user's input
        user_id: User identifier (used for DM history)
        guild_id: Guild ID for server-specific mood and history (None for DM)
        response_type: Type of response for context selection
            ("dm_response", "server_response", "autonomous_general",
            "autonomous_tweet", "conversation_join")
        model: Model to use (defaults to TEXT_MODEL from globals, JAPANESE_TEXT_MODEL
            in Japanese mode, or EVIL_TEXT_MODEL in evil mode)
        author_name: Display name of the message author (for multi-user context)
        media_type: Type of media being sent ("image", "video", "gif", "tenor_gif") or None

    Returns:
        The post-processed reply text (surrounding quotes stripped, single
        asterisks escaped, then passed through ``handle_response_error``).
        On a non-200 HTTP status, returns the fixed error reply. On timeout
        or any other exception, returns whatever ``handle_llm_error`` returns.

    Side effects:
        Sets ``globals.LAST_FULL_PROMPT``. When there is a non-empty prompt
        and a non-error reply, appends the exchange to ``conversation_history``
        (keyed by guild_id, or user_id for DMs) and to the legacy
        ``globals.conversation_history[user_id]``.
    """
    # Evil mode must be known before the model can be chosen.
    from utils.evil_mode import is_evil_mode, get_evil_context_for_response_type, get_evil_system_prompt
    evil_mode = is_evil_mode()

    if model is None:
        model = _select_text_model(evil_mode)

    # Conversation history is keyed by guild for servers and by user for DMs.
    channel_id = str(guild_id) if guild_id else str(user_id)

    if evil_mode:
        miku_context = get_evil_context_for_response_type(response_type)
        system_prompt = get_evil_system_prompt()
    else:
        miku_context = get_context_for_response_type(response_type)
        system_prompt = _MIKU_SYSTEM_PROMPT

    _mood_name, current_mood, forced_angry_until, just_woken_up = _resolve_mood(guild_id, evil_mode)

    if just_woken_up and not evil_mode:
        system_prompt += _JUST_WOKEN_UP_NOTE

    if forced_angry_until:
        # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow(),
        # so the comparison semantics are unchanged.
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        if now < forced_angry_until:
            system_prompt += _FORCED_ANGRY_NOTE

    # History is limited to keep the request within the model's context.
    messages = conversation_history.format_for_llm(channel_id, max_messages=8, max_chars_per_message=500)

    has_prompt = bool(user_prompt and user_prompt.strip())
    if has_prompt:
        # Prefix the author's name so multi-user server context stays readable.
        content = f"{author_name}: {user_prompt}" if author_name else user_prompt
        messages.append({"role": "user", "content": content})

    pfp_context = _get_pfp_context(user_prompt)

    character_name = "Evil Miku" if evil_mode else "Miku"
    full_system_prompt = (
        f"{miku_context}\n"
        "\n"
        "## CURRENT SITUATION\n"
        f"{character_name} is currently feeling: {current_mood}\n"
        f"Please respond in a way that reflects this emotional tone.{pfp_context}"
    )

    if media_type:
        media_note = _MEDIA_DESCRIPTIONS.get(media_type, f"The user has sent you {media_type}.")
        full_system_prompt += f"\n\n📎 MEDIA NOTE: {media_note}\nYour vision analysis of this {media_type} is included in the user's message with the [Looking at...] prefix."

    globals.LAST_FULL_PROMPT = f"System: {full_system_prompt}\n\nMessages: {messages}"  # ← track latest prompt

    headers = {'Content-Type': 'application/json'}
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt + "\n\n" + full_system_prompt}
        ] + messages,
        "stream": False,
        "temperature": 0.8,
        "max_tokens": 512,
    }

    # Shared keyword arguments for the error-reporting helpers.
    error_context = {
        "user_prompt": user_prompt,
        "user_id": str(user_id),
        "guild_id": str(guild_id) if guild_id else None,
        "author_name": author_name,
    }

    async with aiohttp.ClientSession() as session:
        try:
            # Route text generation to whichever GPU is currently selected.
            llama_url = get_current_gpu_url()
            logger.debug(f"Using GPU endpoint: {llama_url}")

            timeout = aiohttp.ClientTimeout(total=_LLM_TIMEOUT_SECONDS)
            async with session.post(f"{llama_url}/v1/chat/completions", json=payload, headers=headers, timeout=timeout) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Error from llama-swap: {response.status} - {error_text}")
                    # Send webhook notification for HTTP errors
                    await handle_response_error(f"Error: {response.status}", **error_context)
                    # Error replies are never saved to conversation history.
                    return _AI_ERROR_REPLY

                data = await response.json()
                reply = _extract_reply_text(data)
                reply = _strip_surrounding_quotes(reply)
                # e.g. *adjusts hair* becomes \*adjusts hair\*
                reply = _escape_markdown_actions(reply)
                # Replaces the reply if it looks like an error response.
                reply = await handle_response_error(reply, **error_context)

                if has_prompt and reply and reply.strip() and reply != _AI_ERROR_REPLY:
                    conversation_history.add_message(
                        channel_id=channel_id,
                        author_name=author_name or "User",
                        content=user_prompt,
                        is_bot=False,
                    )
                    conversation_history.add_message(
                        channel_id=channel_id,
                        author_name="Miku",
                        content=reply,
                        is_bot=True,
                    )
                    # Legacy history kept for backward compatibility.
                    globals.conversation_history[user_id].append((user_prompt, reply))

                return reply
        except asyncio.TimeoutError:
            logger.error("Timeout error in query_llama")
            return await handle_llm_error(
                asyncio.TimeoutError(f"Request timed out after {_LLM_TIMEOUT_SECONDS} seconds"),
                **error_context,
            )
        except Exception as e:
            logger.error(f"Error in query_llama: {e}")
            return await handle_llm_error(e, **error_context)
|
|
|
|
# Backward-compatibility alias for existing code that still calls
# ``query_ollama``. It is the same function object as ``query_llama`` (not a
# wrapper), so both names share one signature and behavior.
query_ollama = query_llama
|