Initial commit: Miku Discord Bot
This commit is contained in:
106
.bot.bak.80825/utils/core.py
Normal file
106
.bot.bak.80825/utils/core.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# utils/core.py
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import re
|
||||
|
||||
import globals
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
async def switch_model(model_name: str, timeout: int = 600):
|
||||
if globals.current_model == model_name:
|
||||
print(f"🔁 Model '{model_name}' already loaded.")
|
||||
return
|
||||
|
||||
# Unload all other models to clear VRAM
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{globals.OLLAMA_URL}/api/show") as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
loaded_models = data.get("models", [])
|
||||
for model in loaded_models:
|
||||
if model["name"] != model_name:
|
||||
print(f"🔁 Unloading model: {model['name']}")
|
||||
await session.post(f"{globals.OLLAMA_URL}/api/stop", json={"name": model["name"]})
|
||||
else:
|
||||
print("⚠️ Failed to check currently loaded models.")
|
||||
|
||||
print(f"🔄 Switching to model '{model_name}'...")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await session.post(f"{globals.OLLAMA_URL}/api/stop")
|
||||
# Warm up the new model (dummy call to preload it)
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"prompt": "Hello",
|
||||
"stream": False
|
||||
}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
# Poll until /api/generate returns 200
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for _ in range(timeout):
|
||||
async with session.post(f"{globals.OLLAMA_URL}/api/generate", json=payload, headers=headers) as resp:
|
||||
if resp.status == 200:
|
||||
globals.current_model = model_name
|
||||
print(f"✅ Model {model_name} ready!")
|
||||
return
|
||||
await asyncio.sleep(1) # Wait a second before trying again
|
||||
|
||||
raise TimeoutError(f"Timed out waiting for model '{model_name}' to become available.")
|
||||
|
||||
|
||||
async def is_miku_addressed(message) -> bool:
|
||||
# If message contains a ping for Miku, return true
|
||||
if message.guild.me in message.mentions:
|
||||
return True
|
||||
|
||||
# If message is a reply, check the referenced message author
|
||||
if message.reference:
|
||||
try:
|
||||
referenced_msg = await message.channel.fetch_message(message.reference.message_id)
|
||||
if referenced_msg.author == message.guild.me: # or globals.client.user if you use client
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"⚠️ Could not fetch referenced message: {e}")
|
||||
|
||||
cleaned = message.content.strip()
|
||||
|
||||
return bool(re.search(
|
||||
r'(?<![\w\(])(?:[^\w\s]{0,2}\s*)?miku(?:\s*[^\w\s]{0,2})?(?=,|\s*,|[!\.?\s]*$)',
|
||||
cleaned,
|
||||
re.IGNORECASE
|
||||
))
|
||||
|
||||
# Load and index once at startup
|
||||
def load_miku_knowledge():
|
||||
with open("miku_lore.txt", "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=520,
|
||||
chunk_overlap=50,
|
||||
separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
|
||||
)
|
||||
|
||||
docs = [Document(page_content=chunk) for chunk in text_splitter.split_text(text)]
|
||||
|
||||
vectorstore = FAISS.from_documents(docs, globals.embeddings)
|
||||
return vectorstore
|
||||
|
||||
def load_miku_lyrics():
|
||||
with open("miku_lyrics.txt", "r", encoding="utf-8") as f:
|
||||
lyrics_text = f.read()
|
||||
|
||||
text_splitter = CharacterTextSplitter(chunk_size=520, chunk_overlap=50)
|
||||
docs = [Document(page_content=chunk) for chunk in text_splitter.split_text(lyrics_text)]
|
||||
|
||||
vectorstore = FAISS.from_documents(docs, globals.embeddings)
|
||||
return vectorstore
|
||||
|
||||
miku_vectorstore = load_miku_knowledge()
|
||||
miku_lyrics_vectorstore = load_miku_lyrics()
|
||||
Reference in New Issue
Block a user