# (file-viewer header residue) 147 lines · 4.7 KiB · Python
"""Chat service: memory-backed replies via Mem0 plus optional TTS audio."""
from typing import Dict, Any, Optional
# NOTE(review): base64, threading and datetime appear unused in this file —
# confirm nothing else relies on them before removing.
import base64
import threading
from datetime import datetime
import sys
import os
import logging

# Append the grandparent directory of this file (two os.path.dirname calls on
# its absolute path) to sys.path. Presumably that is the project root holding
# the `config`, `api` and `memory_module` packages imported below — confirm.
# This must run before those imports, so keep it in this position.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config import (
    DEFAULT_USER_ID,
    MEM0_CONFIG,
)
from api.doubao_tts import text_to_speech
from memory_module.memory_integration import Mem0Integration

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class ChatService:
    """Chat front-end combining Mem0 memory-backed replies with optional TTS audio.

    Every public method returns a plain dict carrying a boolean ``"success"``
    key and the ``"user_id"`` it operated on. Failures are reported through an
    ``"error"`` string rather than raised; the traceback is logged instead.
    """

    def __init__(self, user_id: Optional[str] = None):
        """Create a service bound to one user.

        Args:
            user_id: Identifier whose memories are read and written. Any falsy
                value (``None``, ``""``) falls back to ``DEFAULT_USER_ID``.
        """
        self.user_id = user_id or DEFAULT_USER_ID
        self.mem0_integration = Mem0Integration(MEM0_CONFIG)
        self._initialized = False

    def initialize(self):
        """Mark the Mem0 integration as initialized; repeat calls are no-ops."""
        if self._initialized:
            return

        logger.info("Initializing Mem0 integration for user: %s", self.user_id)
        self._initialized = True

    def _error_response(self, error: str) -> Dict[str, Any]:
        """Build the uniform failure payload shared by every public method."""
        return {
            "success": False,
            "error": error,
            "user_id": self.user_id,
        }

    def _synthesize_audio(self, text: str) -> tuple:
        """Convert *text* to speech without raising.

        Returns:
            ``(audio_data, audio_error)``. On success this is the base64 audio
            and ``None``. Otherwise it is ``None`` and the message reported by
            ``text_to_speech``, or a ``"TTS错误: ..."`` string if the call raised.
        """
        try:
            success, message, base64_audio = text_to_speech(text, self.user_id)
        except Exception as e:
            # TTS is best-effort: a failure degrades to a text-only reply.
            logger.exception("TTS failed for user %s", self.user_id)
            return None, f"TTS错误: {str(e)}"
        if success and base64_audio:
            return base64_audio, None
        return None, message

    def chat(self, user_input: str, include_audio: bool = True) -> Dict[str, Any]:
        """Answer *user_input* using the user's memories, optionally with audio.

        Args:
            user_input: The user's message.
            include_audio: When true, also synthesize the reply with TTS.

        Returns:
            On success, ``{"success": True, "response", "user_id"}``, plus
            ``"audio_data"`` (base64) and/or ``"audio_error"`` when TTS produced
            them. On failure, ``{"success": False, "error", "user_id"}``.
        """
        if not self._initialized:
            self.initialize()

        try:
            result = self.mem0_integration.generate_response_with_memory(
                user_input=user_input,
                user_id=self.user_id,
            )
            # .get(): a result lacking "success" counts as a failure and is
            # reported with its own "error" instead of a bare KeyError text.
            if not result.get("success"):
                return self._error_response(result.get("error", "Unknown error"))

            assistant_response = result["response"]
            response_data: Dict[str, Any] = {
                "success": True,
                "response": assistant_response,
                "user_id": self.user_id,
            }

            if include_audio:
                audio_data, audio_error = self._synthesize_audio(assistant_response)
                # Optional fields are only present when they carry a value.
                if audio_data:
                    response_data["audio_data"] = audio_data
                if audio_error:
                    response_data["audio_error"] = audio_error

            return response_data
        except Exception as e:
            # Boundary handler: callers get an error payload, not an exception,
            # but the traceback is logged rather than silently discarded.
            logger.exception("Chat failed for user %s", self.user_id)
            return self._error_response(str(e))

    def get_user_memories(self) -> Dict[str, Any]:
        """Return all memories stored for the current user, with their count."""
        try:
            memories = self.mem0_integration.get_all_memories(self.user_id)
            count = len(memories)
        except Exception as e:
            logger.exception("Fetching memories failed for user %s", self.user_id)
            return self._error_response(str(e))
        return {
            "success": True,
            "memories": memories,
            "count": count,
            "user_id": self.user_id,
        }

    def clear_user_memories(self) -> Dict[str, Any]:
        """Delete every memory of the current user.

        A falsy return from ``delete_all_memories`` is reported as a failure.
        """
        try:
            success = self.mem0_integration.delete_all_memories(self.user_id)
        except Exception as e:
            logger.exception("Clearing memories failed for user %s", self.user_id)
            return self._error_response(str(e))
        if not success:
            return self._error_response("清除记忆失败")
        return {
            "success": True,
            "message": f"所有记忆已清除,用户: {self.user_id}",
            "user_id": self.user_id,
        }

    def search_memories(self, query: str, limit: int = 5) -> Dict[str, Any]:
        """Search the current user's memories for *query*.

        Args:
            query: Search text forwarded to the Mem0 integration.
            limit: Maximum number of results requested from the backend.
        """
        try:
            memories = self.mem0_integration.search_memories(query, self.user_id, limit)
            count = len(memories)
        except Exception as e:
            logger.exception("Memory search failed for user %s", self.user_id)
            return self._error_response(str(e))
        return {
            "success": True,
            "memories": memories,
            "count": count,
            "query": query,
            "user_id": self.user_id,
        }
