121 lines
4.6 KiB
Python
121 lines
4.6 KiB
Python
|
|
# utils/conversation_history.py
"""
Centralized conversation history management for Miku bot.
Tracks conversation context per server/DM channel.
"""
|
||
|
|
|
||
|
|
from collections import defaultdict, deque
|
||
|
|
from datetime import datetime
|
||
|
|
from typing import Optional, List, Dict, Tuple
|
||
|
|
|
||
|
|
|
||
|
|
class ConversationHistory:
    """Manages conversation history per channel (server or DM).

    Each channel owns a bounded deque of
    ``(author_name, content, timestamp, is_bot)`` tuples. Once the deque is
    full, appending a new entry silently discards the oldest one.
    """

    def __init__(self, max_messages: int = 8):
        """
        Initialize conversation history manager.

        Args:
            max_messages: Base history size. Every channel stores at most
                ``max_messages * 2`` entries (presumably one user message
                plus one bot reply per exchange -- confirm with callers).
        """
        self.max_messages = max_messages
        # The capacity is fixed here; reassigning ``self.max_messages`` later
        # does not resize deques (existing or future).
        capacity = max_messages * 2
        # Key: channel_id (guild_id for servers, user_id for DMs)
        # Value: deque of (author_name, content, timestamp, is_bot) tuples
        self._histories: Dict[str, deque] = defaultdict(
            lambda: deque(maxlen=capacity)
        )

    def add_message(self, channel_id: str, author_name: str, content: str, is_bot: bool = False):
        """
        Add a message to the conversation history.

        Empty or whitespace-only content is ignored. Stored content has its
        leading and trailing whitespace stripped.

        Args:
            channel_id: Server ID (for server messages) or user ID (for DMs)
            author_name: Display name of the message author
            content: Message content
            is_bot: Whether this message is from Miku
        """
        # Skip empty messages
        if not content or not content.strip():
            return

        # NOTE: utcnow() returns a naive datetime and is deprecated since
        # Python 3.12. It is kept so stored timestamps stay naive for any
        # code that already compares against them.
        timestamp = datetime.utcnow()
        self._histories[channel_id].append(
            (author_name, content.strip(), timestamp, is_bot)
        )

    def get_recent_messages(self, channel_id: str, max_messages: Optional[int] = None) -> List[Tuple[str, str, bool]]:
        """
        Get recent messages from a channel.

        Args:
            channel_id: Server ID or user ID
            max_messages: Base window size (default: self.max_messages).
                Up to ``max_messages * 2`` entries are returned. A value
                of zero or less yields an empty list.

        Returns:
            List of (author_name, content, is_bot) tuples, oldest first
        """
        if max_messages is None:
            max_messages = self.max_messages

        limit = max_messages * 2
        # Guard the slice: ``history[-0:]`` would return the whole history
        # and a negative limit would slice from the wrong end.
        if limit <= 0:
            return []

        # ``.get`` avoids creating an empty deque for unknown channels.
        history = list(self._histories.get(channel_id, ()))
        recent = history[-limit:]

        # Return without timestamp for simpler API
        return [(author, content, is_bot) for author, content, _, is_bot in recent]

    def format_for_llm(self, channel_id: str, max_messages: Optional[int] = None,
                       max_chars_per_message: int = 500) -> List[Dict[str, str]]:
        """
        Format conversation history for LLM consumption (OpenAI messages format).

        Args:
            channel_id: Server ID or user ID
            max_messages: Base window size passed to get_recent_messages
                (default: self.max_messages)
            max_chars_per_message: Content longer than this is cut to this
                many characters and suffixed with "..."

        Returns:
            List of {"role": "user"|"assistant", "content": str} dicts
        """
        messages = []
        for author, content, is_bot in self.get_recent_messages(channel_id, max_messages):
            # Truncate very long messages
            if len(content) > max_chars_per_message:
                content = content[:max_chars_per_message] + "..."

            if is_bot:
                # Bot messages use the "assistant" role, without a name prefix
                messages.append({"role": "assistant", "content": content})
                continue

            # User messages are prefixed "username: message" so the model can
            # tell speakers apart; an empty author name leaves content as-is.
            formatted_content = f"{author}: {content}" if author else content
            messages.append({"role": "user", "content": formatted_content})

        return messages

    def clear_channel(self, channel_id: str):
        """Clear all history for a specific channel (no-op if unknown)."""
        self._histories.pop(channel_id, None)

    def get_channel_stats(self, channel_id: str) -> Dict[str, int]:
        """Get statistics about a channel's conversation history.

        Returns:
            Dict with "total_messages", "bot_messages" and "user_messages"
            counts over the entries currently stored for the channel.
        """
        # ``.get`` avoids creating an empty deque for unknown channels.
        history = self._histories.get(channel_id, ())
        total_messages = len(history)
        bot_messages = sum(1 for _, _, _, is_bot in history if is_bot)

        return {
            "total_messages": total_messages,
            "bot_messages": bot_messages,
            "user_messages": total_messages - bot_messages,
        }
|
||
|
|
|
||
|
|
|
||
|
|
# Global instance
|
||
|
|
# Module-level instance. With max_messages=8, each channel's deque holds at
# most 16 entries (the class sizes deques at max_messages * 2).
conversation_history = ConversationHistory(max_messages=8)
|