Files
miku-discord/backups/2025-12-07/old-bot-bak-80825/utils/core.py

107 lines
3.8 KiB
Python
Raw Normal View History

2025-12-07 17:15:09 +02:00
# utils/core.py
import asyncio
import aiohttp
import re
import globals
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.schema import Document
async def _unload_other_models(session: "aiohttp.ClientSession", keep: str) -> None:
    """Best-effort: ask Ollama to unload every running model except *keep*.

    Failures are logged and otherwise ignored; the caller's warm-up loop is
    what decides whether the switch succeeded.
    """
    try:
        # GET /api/ps lists the models Ollama currently holds in memory.
        async with session.get(f"{globals.OLLAMA_URL}/api/ps") as resp:
            if resp.status != 200:
                print("⚠️ Failed to check currently loaded models.")
                return
            data = await resp.json()
    except (aiohttp.ClientError, asyncio.TimeoutError) as e:
        print(f"⚠️ Failed to check currently loaded models: {e}")
        return

    # Ollama reports untagged names with an implicit ":latest" tag, so treat
    # both spellings as the model we want to keep.
    keep_names = {keep, f"{keep}:latest"}
    for model in data.get("models", []):
        name = model.get("name") or model.get("model")
        if not name or name in keep_names:
            continue
        print(f"🔁 Unloading model: {name}")
        try:
            # A generate request with keep_alive=0 and no prompt tells Ollama
            # to evict the model immediately (there is no /api/stop endpoint).
            async with session.post(
                f"{globals.OLLAMA_URL}/api/generate",
                json={"model": name, "keep_alive": 0},
            ) as unload_resp:
                if unload_resp.status != 200:
                    print(f"⚠️ Failed to unload model '{name}' (HTTP {unload_resp.status}).")
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            print(f"⚠️ Failed to unload model '{name}': {e}")


async def _wait_until_ready(session: "aiohttp.ClientSession", model_name: str, timeout: float) -> bool:
    """Poll /api/generate until *model_name* answers 200 or *timeout* seconds pass.

    The dummy generate call doubles as the warm-up request that makes Ollama
    preload the model. Connection errors are treated as "not ready yet".

    Returns:
        True once the model responded with HTTP 200, False after the deadline.
    """
    payload = {
        "model": model_name,
        "prompt": "Hello",
        "stream": False,
    }
    headers = {"Content-Type": "application/json"}
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while True:
        try:
            async with session.post(
                f"{globals.OLLAMA_URL}/api/generate", json=payload, headers=headers
            ) as resp:
                if resp.status == 200:
                    return True
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            # The server may still be starting up or busy; keep polling.
            print(f"⚠️ Model '{model_name}' not reachable yet: {e}")
        if loop.time() >= deadline:
            return False
        await asyncio.sleep(1)  # Wait a second before trying again


async def switch_model(model_name: str, timeout: int = 600):
    """Make *model_name* the active Ollama model.

    Does nothing if ``globals.current_model`` already equals *model_name*.
    Otherwise it unloads every other running model to clear VRAM. It then
    repeats a warm-up generate request until the new model responds, and
    records it in ``globals.current_model``.

    Args:
        model_name: Ollama model to switch to.
        timeout: Wall-clock seconds to keep retrying the warm-up request.

    Raises:
        TimeoutError: If the model does not answer within *timeout* seconds.
    """
    if globals.current_model == model_name:
        print(f"🔁 Model '{model_name}' already loaded.")
        return

    async with aiohttp.ClientSession() as session:
        await _unload_other_models(session, model_name)
        print(f"🔄 Switching to model '{model_name}'...")
        if await _wait_until_ready(session, model_name, timeout):
            globals.current_model = model_name
            print(f"✅ Model {model_name} ready!")
            return

    raise TimeoutError(f"Timed out waiting for model '{model_name}' to become available.")
# Matches "miku" used as a direct address. The word must not be glued to a
# preceding word character or "(" and may be wrapped in up to two punctuation
# characters. It must also be followed by a comma, or end the message apart
# from trailing "!", ".", "?" or whitespace. The pattern is compiled once at
# import instead of on every message.
_MIKU_ADDRESS_RE = re.compile(
    r'(?<![\w\(])(?:[^\w\s]{0,2}\s*)?miku(?:\s*[^\w\s]{0,2})?(?=,|\s*,|[!\.?\s]*$)',
    re.IGNORECASE,
)


async def is_miku_addressed(message) -> bool:
    """Return True if *message* is directed at Miku (the bot).

    A message counts as addressed when any of these hold:
      * it mentions the bot;
      * it is a reply to a message the bot authored;
      * its text uses "miku" as a direct address (see ``_MIKU_ADDRESS_RE``).

    In a DM ``message.guild`` is None, so the bot's own user is taken from
    ``message.channel.me`` (discord.py's DMChannel attribute) instead. If no
    bot user can be determined, only the text check runs.
    """
    if message.guild is not None:
        me = message.guild.me
    else:
        me = getattr(message.channel, "me", None)

    if me is not None:
        # If message contains a ping for Miku, return true
        if me in message.mentions:
            return True

        ref = message.reference
        if ref is not None:
            # discord.py often pre-populates the replied-to message; only hit
            # the API when it is missing and we actually have an id to fetch.
            referenced = getattr(ref, "resolved", None)
            if referenced is None and ref.message_id is not None:
                try:
                    referenced = await message.channel.fetch_message(ref.message_id)
                except Exception as e:  # best effort: an unfetchable reply is "not to Miku"
                    print(f"⚠️ Could not fetch referenced message: {e}")
            # A deleted referenced message (discord.py's DeletedReferencedMessage)
            # carries no author, so it can never count as a reply to Miku.
            author = getattr(referenced, "author", None)
            if author is not None and author.id == me.id:
                return True

    return bool(_MIKU_ADDRESS_RE.search(message.content.strip()))
# Load and index once at startup
def load_miku_knowledge(path: str = "miku_lore.txt"):
    """Build a FAISS vector store over Miku's lore text.

    The UTF-8 file at *path* is split with ``RecursiveCharacterTextSplitter``
    into ~520-character chunks with 50 characters of overlap. The splitter
    tries paragraph, line, sentence-punctuation, comma and space boundaries
    in that order. Each chunk becomes a ``Document`` embedded with
    ``globals.embeddings``.

    Args:
        path: Lore file to index. Defaults to ``miku_lore.txt``, resolved
            relative to the current working directory.

    Returns:
        The FAISS vector store built from the chunks.
    """
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    # The splitter class comes from the module-level import; no local re-import needed.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=520,
        chunk_overlap=50,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
    )
    docs = [Document(page_content=chunk) for chunk in text_splitter.split_text(text)]
    return FAISS.from_documents(docs, globals.embeddings)
def load_miku_lyrics():
    """Return a FAISS vector store that indexes the lyrics in ``miku_lyrics.txt``.

    The UTF-8 file is cut with ``CharacterTextSplitter`` (chunk_size=520,
    chunk_overlap=50). Every resulting chunk is wrapped in a ``Document`` and
    embedded with ``globals.embeddings``.
    """
    with open("miku_lyrics.txt", "r", encoding="utf-8") as lyrics_file:
        raw_lyrics = lyrics_file.read()

    splitter = CharacterTextSplitter(chunk_size=520, chunk_overlap=50)
    pieces = splitter.split_text(raw_lyrics)
    documents = [Document(page_content=piece) for piece in pieces]

    return FAISS.from_documents(documents, globals.embeddings)
# Build both indexes once, when this module is first imported; the lore index
# is built before the lyrics index.
miku_vectorstore, miku_lyrics_vectorstore = load_miku_knowledge(), load_miku_lyrics()