147 lines
4.7 KiB
Python
147 lines
4.7 KiB
Python
from typing import Dict, Any, Optional
|
|
import base64
|
|
import threading
|
|
from datetime import datetime
|
|
import sys
|
|
import os
|
|
import logging
|
|
# Append the parent of this file's directory to the import path — presumably
# the project root, so the top-level `config`, `api` and `memory_module`
# imports in this file resolve even when it is not run from that root.
# NOTE(review): append() places it LAST on sys.path, so an already-installed
# package with the same name (e.g. a third-party `config`) would be imported
# instead of the local one — confirm this is acceptable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from config import (
|
|
DEFAULT_USER_ID,
|
|
MEM0_CONFIG,
|
|
)
|
|
from api.doubao_tts import text_to_speech
|
|
from memory_module.memory_integration import Mem0Integration
|
|
|
|
# Module-scoped logger, named after this module so its level and handlers can
# be configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ChatService:
    """Per-user chat facade over ``Mem0Integration`` with optional TTS audio.

    The query methods (``chat``, ``get_user_memories``,
    ``clear_user_memories``, ``search_memories``) return plain dicts that
    always carry ``"success"`` and ``"user_id"``. Exceptions from the Mem0 /
    TTS calls are logged and reported as ``{"success": False, "error": ...}``
    rather than propagated.
    """

    def __init__(self, user_id: Optional[str] = None):
        """Create a service bound to a single user.

        Args:
            user_id: User whose memories are read and written. Falls back to
                ``DEFAULT_USER_ID`` when ``None`` (or any falsy value).
        """
        self.user_id = user_id or DEFAULT_USER_ID
        self.mem0_integration = Mem0Integration(MEM0_CONFIG)
        self._initialized = False

    def initialize(self):
        """Mark the service as initialized; repeated calls are no-ops.

        The ``Mem0Integration`` instance is already constructed in
        ``__init__``; this method currently only logs and sets the
        ``_initialized`` flag that ``chat`` checks.
        """
        if self._initialized:
            return
        logger.info(f"Initializing Mem0 integration for user: {self.user_id}")
        self._initialized = True

    def _error_response(self, error: str) -> Dict[str, Any]:
        """Build the failure payload shared by every public method."""
        return {
            "success": False,
            "error": error,
            "user_id": self.user_id,
        }

    def _synthesize_audio(self, text: str):
        """Run TTS on ``text`` without ever raising.

        Returns:
            A ``(audio_data, audio_error)`` pair. ``audio_data`` is the
            base64 audio returned by ``text_to_speech`` when it reports
            success with non-empty audio, else ``None``. ``audio_error`` is
            ``None`` in that case, otherwise the message returned by
            ``text_to_speech`` or a description of the exception it raised.
        """
        try:
            success, message, base64_audio = text_to_speech(text, self.user_id)
        except Exception as e:
            # Audio is best-effort: a TTS failure must not discard the text reply.
            logger.warning("TTS failed for user %s", self.user_id, exc_info=True)
            return None, f"TTS错误: {str(e)}"
        if success and base64_audio:
            return base64_audio, None
        return None, message

    def chat(self, user_input: str, include_audio: bool = True) -> Dict[str, Any]:
        """Answer ``user_input`` using the user's memories, optionally with audio.

        Args:
            user_input: The user's message.
            include_audio: When True, also synthesize the reply via TTS.

        Returns:
            On success ``{"success": True, "response": <reply>, "user_id": ...}``,
            plus ``"audio_data"`` and/or ``"audio_error"`` when non-empty.
            On failure ``{"success": False, "error": <message>, "user_id": ...}``.
        """
        if not self._initialized:
            self.initialize()

        try:
            # Step 1: get a memory-aware reply.
            result = self.mem0_integration.generate_response_with_memory(
                user_input=user_input,
                user_id=self.user_id,
            )
            # A None result or one without a truthy "success" is a failure;
            # report it cleanly instead of leaking a TypeError/KeyError text.
            payload = result or {}
            if not payload.get("success"):
                return self._error_response(payload.get("error", "Unknown error"))

            assistant_response = payload["response"]

            # Step 2: optional, best-effort audio.
            if include_audio:
                audio_data, audio_error = self._synthesize_audio(assistant_response)
            else:
                audio_data, audio_error = None, None

            # Step 3: assemble the payload; audio fields only when non-empty.
            response_data = {
                "success": True,
                "response": assistant_response,
                "user_id": self.user_id,
            }
            if audio_data:
                response_data["audio_data"] = audio_data
            if audio_error:
                response_data["audio_error"] = audio_error
            return response_data
        except Exception as e:
            logger.exception("Chat failed for user %s", self.user_id)
            return self._error_response(str(e))

    def get_user_memories(self) -> Dict[str, Any]:
        """Return every stored memory for the current user.

        Returns:
            ``{"success": True, "memories": ..., "count": len(memories),
            "user_id": ...}``, or the failure payload if the lookup raises.
        """
        try:
            memories = self.mem0_integration.get_all_memories(self.user_id)
            return {
                "success": True,
                "memories": memories,
                "count": len(memories),
                "user_id": self.user_id,
            }
        except Exception as e:
            logger.exception("Fetching memories failed for user %s", self.user_id)
            return self._error_response(str(e))

    def clear_user_memories(self) -> Dict[str, Any]:
        """Delete every stored memory for the current user.

        Returns:
            ``{"success": True, "message": ..., "user_id": ...}`` when the
            integration reports success; otherwise the failure payload (a
            falsy return from ``delete_all_memories`` or a raised exception).
        """
        try:
            deleted = self.mem0_integration.delete_all_memories(self.user_id)
        except Exception as e:
            logger.exception("Clearing memories failed for user %s", self.user_id)
            return self._error_response(str(e))

        if not deleted:
            return self._error_response("清除记忆失败")
        return {
            "success": True,
            "message": f"所有记忆已清除,用户: {self.user_id}",
            "user_id": self.user_id,
        }

    def search_memories(self, query: str, limit: int = 5) -> Dict[str, Any]:
        """Search the current user's memories.

        Args:
            query: Search text forwarded to the integration.
            limit: Maximum number of results, forwarded as-is.

        Returns:
            ``{"success": True, "memories": ..., "count": ..., "query": query,
            "user_id": ...}``, or the failure payload if the search raises.
        """
        try:
            memories = self.mem0_integration.search_memories(query, self.user_id, limit)
            return {
                "success": True,
                "memories": memories,
                "count": len(memories),
                "query": query,
                "user_id": self.user_id,
            }
        except Exception as e:
            logger.exception("Searching memories failed for user %s", self.user_id)
            return self._error_response(str(e))
|