107 lines
3.8 KiB
Python
107 lines
3.8 KiB
Python
|
|
# utils/core.py
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import aiohttp
|
||
|
|
import re
|
||
|
|
|
||
|
|
import globals
|
||
|
|
from langchain_community.vectorstores import FAISS
|
||
|
|
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
|
||
|
|
from langchain.schema import Document
|
||
|
|
|
||
|
|
|
||
|
|
async def switch_model(model_name: str, timeout: int = 600):
    """Switch the active Ollama model, unloading any other loaded models first.

    Warms the target model with a dummy ``/api/generate`` call and polls until
    it responds with HTTP 200 (up to *timeout* attempts, one second apart),
    then records it in ``globals.current_model``.

    Args:
        model_name: Name of the Ollama model to load.
        timeout: Maximum number of one-second polling attempts.

    Raises:
        TimeoutError: If the model does not become available in time.
    """
    if globals.current_model == model_name:
        print(f"🔁 Model '{model_name}' already loaded.")
        return

    # One session for the whole procedure; the original opened three.
    async with aiohttp.ClientSession() as session:
        # Unload all other models to clear VRAM.
        async with session.get(f"{globals.OLLAMA_URL}/api/show") as resp:
            if resp.status == 200:
                data = await resp.json()
                for model in data.get("models", []):
                    if model["name"] != model_name:
                        print(f"🔁 Unloading model: {model['name']}")
                        # `async with` releases the response; a bare
                        # `await session.post(...)` leaks the connection.
                        async with session.post(
                            f"{globals.OLLAMA_URL}/api/stop",
                            json={"name": model["name"]},
                        ):
                            pass
            else:
                print("⚠️ Failed to check currently loaded models.")

        print(f"🔄 Switching to model '{model_name}'...")
        async with session.post(f"{globals.OLLAMA_URL}/api/stop"):
            pass

        # Warm up the new model (dummy call to preload it).
        payload = {
            "model": model_name,
            "prompt": "Hello",
            "stream": False,
        }
        headers = {"Content-Type": "application/json"}

        # Poll until /api/generate returns 200.
        for _ in range(timeout):
            async with session.post(
                f"{globals.OLLAMA_URL}/api/generate", json=payload, headers=headers
            ) as resp:
                if resp.status == 200:
                    globals.current_model = model_name
                    print(f"✅ Model {model_name} ready!")
                    return
            await asyncio.sleep(1)  # Wait a second before trying again.

    raise TimeoutError(f"Timed out waiting for model '{model_name}' to become available.")
|
||
|
|
|
||
|
|
|
||
|
|
async def is_miku_addressed(message) -> bool:
    """Return True if *message* appears to address Miku.

    A message addresses Miku when it @mentions the bot, replies to one of the
    bot's messages, or contains the word "miku" as a standalone token.

    Args:
        message: A Discord message-like object (guild, mentions, reference,
            channel, content attributes are read).
    """
    # In DMs message.guild is None; the original crashed with AttributeError
    # there. Guard so DM messages fall through to the text check.
    bot_member = message.guild.me if message.guild else None

    # If message contains a ping for Miku, return true.
    if bot_member is not None and bot_member in message.mentions:
        return True

    # If message is a reply, check the referenced message author.
    if message.reference:
        try:
            referenced_msg = await message.channel.fetch_message(message.reference.message_id)
            if bot_member is not None and referenced_msg.author == bot_member:  # or globals.client.user if you use client
                return True
        except Exception as e:
            # Best-effort: the referenced message may be deleted/inaccessible.
            print(f"⚠️ Could not fetch referenced message: {e}")

    cleaned = message.content.strip()

    # Standalone "miku", optionally wrapped in up to two punctuation chars,
    # at a word boundary and followed by a comma / terminal punctuation / end.
    return bool(re.search(
        r'(?<![\w\(])(?:[^\w\s]{0,2}\s*)?miku(?:\s*[^\w\s]{0,2})?(?=,|\s*,|[!\.?\s]*$)',
        cleaned,
        re.IGNORECASE
    ))
|
||
|
|
|
||
|
|
# Load and index once at startup
|
||
|
|
def load_miku_knowledge():
    """Read miku_lore.txt, chunk it, and return a FAISS vector store.

    Uses the module-level RecursiveCharacterTextSplitter import (the original
    re-imported it locally, shadowing the top-of-file import for no benefit).

    Returns:
        A FAISS vector store built with ``globals.embeddings``.
    """
    with open("miku_lore.txt", "r", encoding="utf-8") as f:
        text = f.read()

    # Recursive splitter prefers paragraph/sentence boundaries before
    # falling back to smaller separators.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=520,
        chunk_overlap=50,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
    )

    docs = [Document(page_content=chunk) for chunk in text_splitter.split_text(text)]

    vectorstore = FAISS.from_documents(docs, globals.embeddings)
    return vectorstore
|
||
|
|
|
||
|
|
def load_miku_lyrics():
    """Read miku_lyrics.txt, chunk it, and return a FAISS vector store.

    Returns:
        A FAISS vector store built with ``globals.embeddings``.
    """
    with open("miku_lyrics.txt", "r", encoding="utf-8") as f:
        raw_lyrics = f.read()

    # Plain character splitter: fixed-size chunks with a small overlap.
    splitter = CharacterTextSplitter(chunk_size=520, chunk_overlap=50)
    documents = []
    for chunk in splitter.split_text(raw_lyrics):
        documents.append(Document(page_content=chunk))

    return FAISS.from_documents(documents, globals.embeddings)
|
||
|
|
|
||
|
|
# Module-level side effect: both indexes are built once at import time,
# reading miku_lore.txt / miku_lyrics.txt from the working directory.
miku_vectorstore = load_miku_knowledge()
miku_lyrics_vectorstore = load_miku_lyrics()
|