229 lines
8.4 KiB
Python
229 lines
8.4 KiB
Python
# face_detector_manager.py
|
|
"""
Manages on-demand starting/stopping of the anime-face-detector container
to free up VRAM when it is not needed.
"""
|
|
|
|
import asyncio
import subprocess
import time
from typing import Any, Awaitable, Callable, Dict, Optional

import aiohttp
|
|
|
|
|
|
class FaceDetectorManager:
    """Manages the anime-face-detector container lifecycle.

    The container is started on demand with ``docker compose up -d`` and
    stopped again with ``docker compose stop`` so it does not hold VRAM while
    idle. Readiness is determined by polling ``HEALTH_ENDPOINT``.

    NOTE(review): ``is_running`` is plain shared state with no locking. Two
    overlapping ``detect_face_with_management`` calls can interleave so that
    one stops the container while the other is still using it; callers
    presumably serialize these calls - confirm.
    """

    FACE_DETECTOR_API = "http://anime-face-detector:6078/detect"
    HEALTH_ENDPOINT = "http://anime-face-detector:6078/health"
    CONTAINER_NAME = "anime-face-detector"
    STARTUP_TIMEOUT = 30  # seconds to wait for the API to report healthy

    # Working directory for `docker compose` (where the compose file lives).
    # Assumes we run inside the bot container; override on the class or an
    # instance if the compose project lives elsewhere.
    COMPOSE_DIR = "/app"
    # Hard limits (seconds) for the `docker compose` CLI invocations.
    COMPOSE_UP_TIMEOUT = 30
    COMPOSE_STOP_TIMEOUT = 15

    def __init__(self):
        # Set True only after a successful start + health check, and reset by
        # a successful stop_container().
        self.is_running = False

    async def _run_compose(self, *args: str, timeout: float) -> subprocess.CompletedProcess:
        """Run ``docker compose <args...> CONTAINER_NAME`` off the event loop.

        Args:
            *args: Compose sub-command and flags, e.g. ``"up", "-d"``.
            timeout: Seconds before ``subprocess.run`` gives up.

        Returns:
            The finished ``CompletedProcess`` (stdout/stderr captured as text).

        Raises:
            subprocess.TimeoutExpired: docker did not finish within ``timeout``.
            OSError: docker could not be launched (e.g. binary or cwd missing).
        """
        cmd = ["docker", "compose", *args, self.CONTAINER_NAME]
        cwd = self.COMPOSE_DIR
        loop = asyncio.get_running_loop()
        # subprocess.run blocks; running it in the default thread pool keeps
        # the event loop responsive while docker works.
        return await loop.run_in_executor(
            None,
            lambda: subprocess.run(
                cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout
            ),
        )

    async def start_container(self, debug: bool = False) -> bool:
        """
        Start the anime-face-detector container and wait for its API.

        Args:
            debug: Print progress and error messages.

        Returns:
            True if the container started and the health endpoint answered
            HTTP 200 within STARTUP_TIMEOUT seconds, False otherwise. Errors
            are reported via print when ``debug`` is set and are never raised.

        Note:
            On a readiness timeout ``docker compose up`` has already
            succeeded, so the container may still be running even though
            False is returned.
        """
        try:
            if debug:
                print("🚀 Starting anime-face-detector container...")

            result = await self._run_compose("up", "-d", timeout=self.COMPOSE_UP_TIMEOUT)
            if result.returncode != 0:
                if debug:
                    print(f"⚠️ Failed to start container: {result.stderr}")
                return False

            # Poll until healthy; monotonic() is immune to wall-clock jumps.
            deadline = time.monotonic() + self.STARTUP_TIMEOUT
            while time.monotonic() < deadline:
                if await self._check_health():
                    self.is_running = True
                    if debug:
                        print("✅ Face detector container started and ready")
                    return True
                await asyncio.sleep(1)

            if debug:
                print(f"⚠️ Container started but API not ready after {self.STARTUP_TIMEOUT}s")
            return False

        except Exception as e:
            if debug:
                print(f"⚠️ Error starting face detector container: {e}")
            return False

    async def stop_container(self, debug: bool = False) -> bool:
        """
        Stop the anime-face-detector container to free VRAM.

        Args:
            debug: Print progress and error messages.

        Returns:
            True if ``docker compose stop`` exited with status 0 (``is_running``
            is then reset to False), False otherwise. Errors are never raised.
        """
        try:
            if debug:
                print("🛑 Stopping anime-face-detector container...")

            result = await self._run_compose("stop", timeout=self.COMPOSE_STOP_TIMEOUT)
            if result.returncode != 0:
                if debug:
                    print(f"⚠️ Failed to stop container: {result.stderr}")
                return False

            self.is_running = False
            if debug:
                print("✅ Face detector container stopped")
            return True

        except Exception as e:
            if debug:
                print(f"⚠️ Error stopping face detector container: {e}")
            return False

    async def _check_health(self) -> bool:
        """Return True if the health endpoint answers HTTP 200 within 2 seconds.

        Connection and timeout failures (e.g. the container is still booting)
        count as "not healthy". Task cancellation is deliberately not caught
        so callers can still cancel the polling loop.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    self.HEALTH_ENDPOINT,
                    timeout=aiohttp.ClientTimeout(total=2),
                ) as response:
                    return response.status == 200
        except (aiohttp.ClientError, asyncio.TimeoutError, OSError):
            return False

    async def detect_face_with_management(
        self,
        image_bytes: bytes,
        unload_vision_model: Optional[Callable[[], Awaitable[Any]]] = None,
        reload_vision_model: Optional[Callable[[], Awaitable[Any]]] = None,
        debug: bool = False,
    ) -> Optional[Dict]:
        """
        Detect a face with automatic container lifecycle management.

        If this call had to start the container, it also stops it afterwards,
        even when startup or detection fails. ``reload_vision_model`` is
        always awaited at the end if provided.

        Args:
            image_bytes: Image data as bytes (sent to the API as JPEG).
            unload_vision_model: Optional async callback awaited first to free VRAM.
            reload_vision_model: Optional async callback awaited at the end.
            debug: Enable debug output.

        Returns:
            Detection dict (see ``_parse_detection_result``) or None.
        """
        container_was_started = False

        try:
            # Step 1: Unload vision model if callback provided
            if unload_vision_model:
                if debug:
                    print("📤 Unloading vision model to free VRAM...")
                await unload_vision_model()
                await asyncio.sleep(2)  # Give time for VRAM to clear

            # Step 2: Start face detector if not running
            if not self.is_running:
                # Take ownership *before* starting: start_container() can
                # return False after `docker compose up` already succeeded
                # (API never became healthy). The finally block must still
                # stop it, otherwise the container keeps holding VRAM.
                container_was_started = True
                if not await self.start_container(debug=debug):
                    if debug:
                        print("⚠️ Could not start face detector container")
                    return None

            # Step 3: Detect face
            return await self._detect_face_api(image_bytes, debug=debug)

        finally:
            # Step 4: Stop container (only if we started it) and reload model
            if container_was_started:
                await self.stop_container(debug=debug)

            if reload_vision_model:
                if debug:
                    print("📥 Reloading vision model...")
                await reload_vision_model()

    async def _detect_face_api(self, image_bytes: bytes, debug: bool = False) -> Optional[Dict]:
        """POST the image to the detection API and return the best face.

        Returns:
            Result of ``_parse_detection_result`` on the JSON payload, or None
            on a non-200 status or any error (reported when ``debug``).
        """
        try:
            async with aiohttp.ClientSession() as session:
                form = aiohttp.FormData()
                form.add_field('file', image_bytes, filename='image.jpg', content_type='image/jpeg')

                async with session.post(
                    self.FACE_DETECTOR_API,
                    data=form,
                    timeout=aiohttp.ClientTimeout(total=30),
                ) as response:
                    if response.status != 200:
                        if debug:
                            print(f"⚠️ Face detection API returned status {response.status}")
                        return None
                    payload = await response.json()

            # Parsing stays inside the try so a malformed payload is reported
            # like any other API failure instead of propagating.
            return self._parse_detection_result(payload, debug=debug)

        except Exception as e:
            if debug:
                print(f"⚠️ Error calling face detection API: {e}")
            return None

    @staticmethod
    def _parse_detection_result(result: Dict, debug: bool = False) -> Optional[Dict]:
        """Reduce the detection API's JSON payload to the most confident face.

        Reads ``count`` and ``detections`` from ``result``. Each detection is
        read for ``bbox`` (first four values used as x1, y1, x2, y2),
        ``confidence`` and ``keypoints``. This schema is inferred from the
        keys read here; confirm it against the detector service.

        Returns:
            ``{'center': (cx, cy), 'bbox', 'confidence', 'keypoints', 'count'}``
            for the highest-confidence detection, or None when ``count`` is
            missing/0, there are no detections, or the best bbox has fewer
            than four values.
        """
        if result.get('count', 0) == 0:
            if debug:
                print("👤 No faces detected by API")
            return None

        detections = result.get('detections', [])
        if not detections:
            return None

        best_detection = max(detections, key=lambda d: d.get('confidence', 0))
        bbox = best_detection.get('bbox', [])
        confidence = best_detection.get('confidence', 0)
        keypoints = best_detection.get('keypoints', [])

        if len(bbox) < 4:
            return None

        x1, y1, x2, y2 = bbox[:4]
        center_x = int((x1 + x2) / 2)
        center_y = int((y1 + y2) / 2)

        if debug:
            width = int(x2 - x1)
            height = int(y2 - y1)
            print(f"👤 Detected {len(detections)} face(s) via API, using best at ({center_x}, {center_y}) [confidence: {confidence:.2%}]")
            print(f" Bounding box: x={int(x1)}, y={int(y1)}, w={width}, h={height}")
            print(f" Keypoints: {len(keypoints)} facial landmarks detected")

        return {
            'center': (center_x, center_y),
            'bbox': bbox,
            'confidence': confidence,
            'keypoints': keypoints,
            'count': len(detections),
        }
|
|
|
|
|
|
# Global instance: a single module-level FaceDetectorManager created at import
# time, so every importer of `face_detector_manager` shares its `is_running`
# state.
face_detector_manager = FaceDetectorManager()
|