feat(memory): integrate Mem0 for enhanced conversational memory
Some checks failed
Build and Push Docker / build-and-push (push) Has been cancelled
Some checks failed
Build and Push Docker / build-and-push (push) Has been cancelled
This commit is contained in:
@ -1,147 +1,140 @@
|
||||
from typing import Dict, Any, Tuple
|
||||
from typing import Dict, Any, Optional
|
||||
import base64
|
||||
import threading
|
||||
from haystack import Document, Pipeline
|
||||
from milvus_haystack import MilvusDocumentStore
|
||||
from haystack.components.embedders import OpenAIDocumentEmbedder
|
||||
from haystack.utils import Secret
|
||||
|
||||
from datetime import datetime
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from config import (
|
||||
DEFAULT_USER_ID,
|
||||
OPENAI_EMBEDDING_KEY,
|
||||
OPENAI_EMBEDDING_MODEL,
|
||||
OPENAI_EMBEDDING_BASE,
|
||||
MEM0_CONFIG,
|
||||
)
|
||||
from haystack_rag.rag_pipeline import build_rag_pipeline
|
||||
from doubao_tts import text_to_speech
|
||||
from api.doubao_tts import text_to_speech
|
||||
from memory_module.memory_integration import Mem0Integration
|
||||
|
||||
|
||||
class ChatService:
|
||||
def __init__(self, user_id: str = None):
|
||||
self.user_id = user_id or DEFAULT_USER_ID
|
||||
self.rag_pipeline = None
|
||||
self.document_store = None
|
||||
self.document_embedder = None
|
||||
self.mem0_integration = Mem0Integration(MEM0_CONFIG)
|
||||
self._initialized = False
|
||||
|
||||
def initialize(self):
|
||||
"""初始化 RAG 管道和相关组件"""
|
||||
"""初始化 Mem0 集成"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 构建 RAG 查询管道和获取 DocumentStore 实例
|
||||
self.rag_pipeline, self.document_store = build_rag_pipeline(self.user_id)
|
||||
|
||||
# 初始化用于写入用户输入的 Document Embedder
|
||||
self.document_embedder = OpenAIDocumentEmbedder(
|
||||
api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
|
||||
model=OPENAI_EMBEDDING_MODEL,
|
||||
api_base_url=OPENAI_EMBEDDING_BASE,
|
||||
)
|
||||
|
||||
print(f"[INFO] Initializing Mem0 integration for user: {self.user_id}")
|
||||
self._initialized = True
|
||||
|
||||
def _embed_and_store_async(self, user_input: str):
|
||||
"""异步嵌入并存储用户输入"""
|
||||
try:
|
||||
# 步骤 1: 嵌入用户输入并写入 Milvus
|
||||
user_doc_to_write = Document(content=user_input, meta={"user_id": self.user_id})
|
||||
|
||||
# 使用 OpenAIDocumentEmbedder 运行嵌入
|
||||
embedding_result = self.document_embedder.run([user_doc_to_write])
|
||||
embedded_docs = embedding_result.get("documents", [])
|
||||
|
||||
if embedded_docs:
|
||||
# 将带有嵌入的文档写入 DocumentStore
|
||||
self.document_store.write_documents(embedded_docs)
|
||||
print(f"[INFO] 用户输入已成功嵌入并存储: {user_input[:50]}...")
|
||||
else:
|
||||
print(f"[WARNING] 用户输入嵌入失败: {user_input[:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 异步嵌入和存储过程出错: {e}")
|
||||
|
||||
def chat(self, user_input: str, include_audio: bool = True) -> Dict[str, Any]:
|
||||
"""处理用户输入并返回回复(包含音频)"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
|
||||
try:
|
||||
# 步骤 1: 异步启动嵌入和存储过程(不阻塞主流程)
|
||||
embedding_thread = threading.Thread(
|
||||
target=self._embed_and_store_async,
|
||||
args=(user_input,),
|
||||
daemon=True
|
||||
# Step 1: Get response with memory integration
|
||||
result = self.mem0_integration.generate_response_with_memory(
|
||||
user_input=user_input,
|
||||
user_id=self.user_id
|
||||
)
|
||||
embedding_thread.start()
|
||||
|
||||
# 步骤 2: 立即使用 RAG 查询管道生成回复(不等待嵌入完成)
|
||||
pipeline_input = {
|
||||
"text_embedder": {"text": user_input},
|
||||
"prompt_builder": {"query": user_input},
|
||||
}
|
||||
|
||||
# 运行 RAG 查询管道
|
||||
results = self.rag_pipeline.run(pipeline_input)
|
||||
|
||||
# 步骤 3: 处理并返回结果
|
||||
if "llm" in results and results["llm"]["replies"]:
|
||||
answer = results["llm"]["replies"][0]
|
||||
|
||||
# 尝试获取 token 使用量
|
||||
total_tokens = None
|
||||
try:
|
||||
if (
|
||||
"meta" in results["llm"]
|
||||
and isinstance(results["llm"]["meta"], list)
|
||||
and results["llm"]["meta"]
|
||||
):
|
||||
usage_info = results["llm"]["meta"][0].get("usage", {})
|
||||
total_tokens = usage_info.get("total_tokens")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 步骤 4: 生成语音(如果需要)
|
||||
audio_data = None
|
||||
audio_error = None
|
||||
if include_audio:
|
||||
try:
|
||||
success, message, base64_audio = text_to_speech(answer, self.user_id)
|
||||
if success and base64_audio:
|
||||
# 直接使用 base64 音频数据
|
||||
audio_data = base64_audio
|
||||
else:
|
||||
audio_error = message
|
||||
except Exception as e:
|
||||
audio_error = f"TTS错误: {str(e)}"
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"response": answer,
|
||||
"user_id": self.user_id
|
||||
}
|
||||
|
||||
# 添加可选字段
|
||||
if total_tokens is not None:
|
||||
result["tokens"] = total_tokens
|
||||
if audio_data:
|
||||
result["audio_data"] = audio_data
|
||||
if audio_error:
|
||||
result["audio_error"] = audio_error
|
||||
|
||||
return result
|
||||
else:
|
||||
if not result["success"]:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Could not generate an answer",
|
||||
"debug_info": results,
|
||||
"error": result.get("error", "Unknown error"),
|
||||
"user_id": self.user_id
|
||||
}
|
||||
|
||||
|
||||
assistant_response = result["response"]
|
||||
|
||||
# Step 2: Generate audio if requested
|
||||
audio_data = None
|
||||
audio_error = None
|
||||
if include_audio:
|
||||
try:
|
||||
success, message, base64_audio = text_to_speech(assistant_response, self.user_id)
|
||||
if success and base64_audio:
|
||||
audio_data = base64_audio
|
||||
else:
|
||||
audio_error = message
|
||||
except Exception as e:
|
||||
audio_error = f"TTS错误: {str(e)}"
|
||||
|
||||
# Step 3: Prepare response
|
||||
response_data = {
|
||||
"success": True,
|
||||
"response": assistant_response,
|
||||
"user_id": self.user_id
|
||||
}
|
||||
|
||||
# Add optional fields
|
||||
if audio_data:
|
||||
response_data["audio_data"] = audio_data
|
||||
if audio_error:
|
||||
response_data["audio_error"] = audio_error
|
||||
|
||||
return response_data
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"user_id": self.user_id
|
||||
}
|
||||
|
||||
def get_user_memories(self) -> Dict[str, Any]:
|
||||
"""获取当前用户的所有记忆"""
|
||||
try:
|
||||
memories = self.mem0_integration.get_all_memories(self.user_id)
|
||||
return {
|
||||
"success": True,
|
||||
"memories": memories,
|
||||
"count": len(memories),
|
||||
"user_id": self.user_id
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"user_id": self.user_id
|
||||
}
|
||||
|
||||
def clear_user_memories(self) -> Dict[str, Any]:
|
||||
"""清除当前用户的所有记忆"""
|
||||
try:
|
||||
success = self.mem0_integration.delete_all_memories(self.user_id)
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"所有记忆已清除,用户: {self.user_id}",
|
||||
"user_id": self.user_id
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "清除记忆失败",
|
||||
"user_id": self.user_id
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"user_id": self.user_id
|
||||
}
|
||||
|
||||
def search_memories(self, query: str, limit: int = 5) -> Dict[str, Any]:
|
||||
"""搜索当前用户的记忆"""
|
||||
try:
|
||||
memories = self.mem0_integration.search_memories(query, self.user_id, limit)
|
||||
return {
|
||||
"success": True,
|
||||
"memories": memories,
|
||||
"count": len(memories),
|
||||
"query": query,
|
||||
"user_id": self.user_id
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
|
@ -25,8 +25,8 @@ class ChatResponse(BaseModel):
|
||||
|
||||
# 创建 FastAPI 应用
|
||||
app = FastAPI(
|
||||
title="Haystack RAG API",
|
||||
description="基于 Haystack 的 RAG 聊天服务 API",
|
||||
title="Mem0 Memory API",
|
||||
description="基于 Mem0 的记忆增强聊天服务 API",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
@ -43,7 +43,7 @@ async def startup_event():
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""根路径,返回 API 信息"""
|
||||
return {"message": "Haystack RAG API is running", "version": "1.0.0"}
|
||||
return {"message": "Mem0 Memory API is running", "version": "1.0.0"}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
Reference in New Issue
Block a user