feat(memory): integrate Mem0 for enhanced conversational memory
Some checks failed
Build and Push Docker / build-and-push (push) Has been cancelled
Some checks failed
Build and Push Docker / build-and-push (push) Has been cancelled
This commit is contained in:
@ -1,146 +1,81 @@
|
|||||||
from typing import Dict, Any, Tuple
|
from typing import Dict, Any, Optional
|
||||||
import base64
|
import base64
|
||||||
import threading
|
import threading
|
||||||
from haystack import Document, Pipeline
|
from datetime import datetime
|
||||||
from milvus_haystack import MilvusDocumentStore
|
|
||||||
from haystack.components.embedders import OpenAIDocumentEmbedder
|
|
||||||
from haystack.utils import Secret
|
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
from config import (
|
from config import (
|
||||||
DEFAULT_USER_ID,
|
DEFAULT_USER_ID,
|
||||||
OPENAI_EMBEDDING_KEY,
|
MEM0_CONFIG,
|
||||||
OPENAI_EMBEDDING_MODEL,
|
|
||||||
OPENAI_EMBEDDING_BASE,
|
|
||||||
)
|
)
|
||||||
from haystack_rag.rag_pipeline import build_rag_pipeline
|
from api.doubao_tts import text_to_speech
|
||||||
from doubao_tts import text_to_speech
|
from memory_module.memory_integration import Mem0Integration
|
||||||
|
|
||||||
|
|
||||||
class ChatService:
|
class ChatService:
|
||||||
def __init__(self, user_id: str = None):
|
def __init__(self, user_id: str = None):
|
||||||
self.user_id = user_id or DEFAULT_USER_ID
|
self.user_id = user_id or DEFAULT_USER_ID
|
||||||
self.rag_pipeline = None
|
self.mem0_integration = Mem0Integration(MEM0_CONFIG)
|
||||||
self.document_store = None
|
|
||||||
self.document_embedder = None
|
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
"""初始化 RAG 管道和相关组件"""
|
"""初始化 Mem0 集成"""
|
||||||
if self._initialized:
|
if self._initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 构建 RAG 查询管道和获取 DocumentStore 实例
|
print(f"[INFO] Initializing Mem0 integration for user: {self.user_id}")
|
||||||
self.rag_pipeline, self.document_store = build_rag_pipeline(self.user_id)
|
|
||||||
|
|
||||||
# 初始化用于写入用户输入的 Document Embedder
|
|
||||||
self.document_embedder = OpenAIDocumentEmbedder(
|
|
||||||
api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
|
|
||||||
model=OPENAI_EMBEDDING_MODEL,
|
|
||||||
api_base_url=OPENAI_EMBEDDING_BASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
def _embed_and_store_async(self, user_input: str):
|
|
||||||
"""异步嵌入并存储用户输入"""
|
|
||||||
try:
|
|
||||||
# 步骤 1: 嵌入用户输入并写入 Milvus
|
|
||||||
user_doc_to_write = Document(content=user_input, meta={"user_id": self.user_id})
|
|
||||||
|
|
||||||
# 使用 OpenAIDocumentEmbedder 运行嵌入
|
|
||||||
embedding_result = self.document_embedder.run([user_doc_to_write])
|
|
||||||
embedded_docs = embedding_result.get("documents", [])
|
|
||||||
|
|
||||||
if embedded_docs:
|
|
||||||
# 将带有嵌入的文档写入 DocumentStore
|
|
||||||
self.document_store.write_documents(embedded_docs)
|
|
||||||
print(f"[INFO] 用户输入已成功嵌入并存储: {user_input[:50]}...")
|
|
||||||
else:
|
|
||||||
print(f"[WARNING] 用户输入嵌入失败: {user_input[:50]}...")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[ERROR] 异步嵌入和存储过程出错: {e}")
|
|
||||||
|
|
||||||
def chat(self, user_input: str, include_audio: bool = True) -> Dict[str, Any]:
|
def chat(self, user_input: str, include_audio: bool = True) -> Dict[str, Any]:
|
||||||
"""处理用户输入并返回回复(包含音频)"""
|
"""处理用户输入并返回回复(包含音频)"""
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
self.initialize()
|
self.initialize()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 步骤 1: 异步启动嵌入和存储过程(不阻塞主流程)
|
# Step 1: Get response with memory integration
|
||||||
embedding_thread = threading.Thread(
|
result = self.mem0_integration.generate_response_with_memory(
|
||||||
target=self._embed_and_store_async,
|
user_input=user_input,
|
||||||
args=(user_input,),
|
user_id=self.user_id
|
||||||
daemon=True
|
|
||||||
)
|
)
|
||||||
embedding_thread.start()
|
|
||||||
|
|
||||||
# 步骤 2: 立即使用 RAG 查询管道生成回复(不等待嵌入完成)
|
if not result["success"]:
|
||||||
pipeline_input = {
|
return {
|
||||||
"text_embedder": {"text": user_input},
|
"success": False,
|
||||||
"prompt_builder": {"query": user_input},
|
"error": result.get("error", "Unknown error"),
|
||||||
|
"user_id": self.user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
# 运行 RAG 查询管道
|
assistant_response = result["response"]
|
||||||
results = self.rag_pipeline.run(pipeline_input)
|
|
||||||
|
|
||||||
# 步骤 3: 处理并返回结果
|
# Step 2: Generate audio if requested
|
||||||
if "llm" in results and results["llm"]["replies"]:
|
|
||||||
answer = results["llm"]["replies"][0]
|
|
||||||
|
|
||||||
# 尝试获取 token 使用量
|
|
||||||
total_tokens = None
|
|
||||||
try:
|
|
||||||
if (
|
|
||||||
"meta" in results["llm"]
|
|
||||||
and isinstance(results["llm"]["meta"], list)
|
|
||||||
and results["llm"]["meta"]
|
|
||||||
):
|
|
||||||
usage_info = results["llm"]["meta"][0].get("usage", {})
|
|
||||||
total_tokens = usage_info.get("total_tokens")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 步骤 4: 生成语音(如果需要)
|
|
||||||
audio_data = None
|
audio_data = None
|
||||||
audio_error = None
|
audio_error = None
|
||||||
if include_audio:
|
if include_audio:
|
||||||
try:
|
try:
|
||||||
success, message, base64_audio = text_to_speech(answer, self.user_id)
|
success, message, base64_audio = text_to_speech(assistant_response, self.user_id)
|
||||||
if success and base64_audio:
|
if success and base64_audio:
|
||||||
# 直接使用 base64 音频数据
|
|
||||||
audio_data = base64_audio
|
audio_data = base64_audio
|
||||||
else:
|
else:
|
||||||
audio_error = message
|
audio_error = message
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
audio_error = f"TTS错误: {str(e)}"
|
audio_error = f"TTS错误: {str(e)}"
|
||||||
|
|
||||||
result = {
|
# Step 3: Prepare response
|
||||||
|
response_data = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"response": answer,
|
"response": assistant_response,
|
||||||
"user_id": self.user_id
|
"user_id": self.user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
# 添加可选字段
|
# Add optional fields
|
||||||
if total_tokens is not None:
|
|
||||||
result["tokens"] = total_tokens
|
|
||||||
if audio_data:
|
if audio_data:
|
||||||
result["audio_data"] = audio_data
|
response_data["audio_data"] = audio_data
|
||||||
if audio_error:
|
if audio_error:
|
||||||
result["audio_error"] = audio_error
|
response_data["audio_error"] = audio_error
|
||||||
|
|
||||||
return result
|
return response_data
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Could not generate an answer",
|
|
||||||
"debug_info": results,
|
|
||||||
"user_id": self.user_id
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {
|
return {
|
||||||
@ -148,3 +83,61 @@ class ChatService:
|
|||||||
"error": str(e),
|
"error": str(e),
|
||||||
"user_id": self.user_id
|
"user_id": self.user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_user_memories(self) -> Dict[str, Any]:
|
||||||
|
"""获取当前用户的所有记忆"""
|
||||||
|
try:
|
||||||
|
memories = self.mem0_integration.get_all_memories(self.user_id)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"memories": memories,
|
||||||
|
"count": len(memories),
|
||||||
|
"user_id": self.user_id
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"user_id": self.user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear_user_memories(self) -> Dict[str, Any]:
|
||||||
|
"""清除当前用户的所有记忆"""
|
||||||
|
try:
|
||||||
|
success = self.mem0_integration.delete_all_memories(self.user_id)
|
||||||
|
if success:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"所有记忆已清除,用户: {self.user_id}",
|
||||||
|
"user_id": self.user_id
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "清除记忆失败",
|
||||||
|
"user_id": self.user_id
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"user_id": self.user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
def search_memories(self, query: str, limit: int = 5) -> Dict[str, Any]:
|
||||||
|
"""搜索当前用户的记忆"""
|
||||||
|
try:
|
||||||
|
memories = self.mem0_integration.search_memories(query, self.user_id, limit)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"memories": memories,
|
||||||
|
"count": len(memories),
|
||||||
|
"query": query,
|
||||||
|
"user_id": self.user_id
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"user_id": self.user_id
|
||||||
|
}
|
||||||
|
@ -25,8 +25,8 @@ class ChatResponse(BaseModel):
|
|||||||
|
|
||||||
# 创建 FastAPI 应用
|
# 创建 FastAPI 应用
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Haystack RAG API",
|
title="Mem0 Memory API",
|
||||||
description="基于 Haystack 的 RAG 聊天服务 API",
|
description="基于 Mem0 的记忆增强聊天服务 API",
|
||||||
version="1.0.0"
|
version="1.0.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ async def startup_event():
|
|||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
"""根路径,返回 API 信息"""
|
"""根路径,返回 API 信息"""
|
||||||
return {"message": "Haystack RAG API is running", "version": "1.0.0"}
|
return {"message": "Mem0 Memory API is running", "version": "1.0.0"}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
|
@ -52,17 +52,31 @@ MILVUS_INDEX_PARAMS = {"index_type": "FLAT", "metric_type": "L2", "params": {}}
|
|||||||
MILVUS_SEARCH_PARAMS = {"metric_type": "L2", "params": {}}
|
MILVUS_SEARCH_PARAMS = {"metric_type": "L2", "params": {}}
|
||||||
MILVUS_STAND_URI = ""
|
MILVUS_STAND_URI = ""
|
||||||
|
|
||||||
# --- RAG Pipeline Configuration (保持不变) ---
|
MEM0_CONFIG = {
|
||||||
RETRIEVER_TOP_K = 3
|
"vector_store": {
|
||||||
DEFAULT_PROMPT_TEMPLATE = """
|
"provider": "milvus",
|
||||||
hello
|
"config": {
|
||||||
{% for doc in documents %}
|
"embedding_model_dims": 2048,
|
||||||
{{ doc.content }}
|
}
|
||||||
{% endfor %}
|
},
|
||||||
|
"llm": {
|
||||||
|
"provider": "openai",
|
||||||
|
"config": {
|
||||||
|
"api_key": OPENAI_API_KEY_FROM_CONFIG,
|
||||||
|
"model": "doubao-seed-1-6-250615",
|
||||||
|
"openai_base_url": OPENAI_API_BASE_URL_CONFIG
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"embedder": {
|
||||||
|
"provider": "openai",
|
||||||
|
"config": {
|
||||||
|
"api_key": OPENAI_EMBEDDING_KEY,
|
||||||
|
"model": "doubao-embedding-large-text-250515",
|
||||||
|
"openai_base_url": OPENAI_EMBEDDING_BASE
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
问题: {{query}}
|
|
||||||
答案:
|
|
||||||
"""
|
|
||||||
|
|
||||||
# --- Application Settings (保持不变) ---
|
# --- Application Settings (保持不变) ---
|
||||||
DEFAULT_USER_ID = "user_openai"
|
DEFAULT_USER_ID = "user_openai"
|
||||||
|
86
memory_module/README.md
Normal file
86
memory_module/README.md
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
# Memory Module Integration
|
||||||
|
|
||||||
|
This module provides memory integration for the chat service using Mem0, allowing the system to remember user preferences and past conversations.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Persistent Memory**: Stores user interactions and preferences
|
||||||
|
- **Contextual Responses**: Uses stored memories to provide personalized responses
|
||||||
|
- **Memory Search**: Search through stored memories
|
||||||
|
- **Memory Management**: View and clear user memories
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Basic Chat with Memory
|
||||||
|
|
||||||
|
```python
|
||||||
|
from api.chat_service import ChatService
|
||||||
|
|
||||||
|
# Initialize chat service
|
||||||
|
chat_service = ChatService("user_id")
|
||||||
|
chat_service.initialize()
|
||||||
|
|
||||||
|
# Send a message
|
||||||
|
result = chat_service.chat("My name is Alice and I love sci-fi movies")
|
||||||
|
print(result["response"])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Memory Operations
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Get all memories for a user
|
||||||
|
memories = chat_service.get_user_memories()
|
||||||
|
|
||||||
|
# Search memories
|
||||||
|
search_results = chat_service.search_memories("movies")
|
||||||
|
|
||||||
|
# Clear all memories
|
||||||
|
chat_service.clear_user_memories()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
The Mem0 configuration is defined in `config/config.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
MEM0_CONFIG = {
|
||||||
|
"vector_store": {
|
||||||
|
"provider": "milvus",
|
||||||
|
"config": {
|
||||||
|
"embedding_model_dims": 2048,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"llm": {
|
||||||
|
"provider": "openai",
|
||||||
|
"config": {
|
||||||
|
"api_key": OPENAI_API_KEY_FROM_CONFIG,
|
||||||
|
"model": "doubao-seed-1-6-250615",
|
||||||
|
"openai_base_url": OPENAI_API_BASE_URL_CONFIG
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"embedder": {
|
||||||
|
"provider": "openai",
|
||||||
|
"config": {
|
||||||
|
"api_key": OPENAI_EMBEDDING_KEY,
|
||||||
|
"model": "doubao-embedding-large-text-250515",
|
||||||
|
"openai_base_url": OPENAI_EMBEDDING_BASE
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
1. **Memory Retrieval**: When a user sends a message, the system searches for relevant memories about the user
|
||||||
|
2. **Enhanced Prompt**: The retrieved memories are formatted and included in the prompt to the LLM
|
||||||
|
3. **Response Generation**: The LLM generates a response considering the user's memories
|
||||||
|
4. **Memory Storage**: The conversation is automatically stored as new memories
|
||||||
|
|
||||||
|
## API Endpoints
|
||||||
|
|
||||||
|
The main API endpoints remain the same:
|
||||||
|
|
||||||
|
- `POST /chat` - Send a message and get a response
|
||||||
|
- `GET /health` - Health check
|
||||||
|
|
||||||
|
Additional memory management endpoints can be added to the main API if needed.
|
3
memory_module/__init__.py
Normal file
3
memory_module/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .memory_integration import Mem0Integration
|
||||||
|
|
||||||
|
__all__ = ["Mem0Integration"]
|
158
memory_module/memory_integration.py
Normal file
158
memory_module/memory_integration.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
import os
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
import openai
|
||||||
|
from mem0 import Memory
|
||||||
|
|
||||||
|
|
||||||
|
class Mem0Integration:
|
||||||
|
"""Mem0 integration for memory retrieval and storage in RAG pipeline."""
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any]):
|
||||||
|
"""Initialize Mem0 with configuration."""
|
||||||
|
self.config = config
|
||||||
|
self.memory = Memory.from_config(config)
|
||||||
|
|
||||||
|
# Initialize OpenAI client for chat completion
|
||||||
|
self.openai_client = openai.OpenAI(
|
||||||
|
api_key=config["llm"]["config"]["api_key"],
|
||||||
|
base_url=config["llm"]["config"].get("openai_base_url")
|
||||||
|
)
|
||||||
|
self.llm_model = config["llm"]["config"]["model"]
|
||||||
|
|
||||||
|
# Memory prompt template
|
||||||
|
self.memory_template = """Based on the following memories about the user:
|
||||||
|
{memories}
|
||||||
|
|
||||||
|
Please respond to the user's query: {query}
|
||||||
|
|
||||||
|
In your response, consider the memories above to provide a personalized answer."""
|
||||||
|
|
||||||
|
def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[Dict[str, Any]]:
|
||||||
|
"""Search for relevant memories about the user."""
|
||||||
|
try:
|
||||||
|
results = self.memory.search(
|
||||||
|
query=query,
|
||||||
|
user_id=user_id,
|
||||||
|
limit=limit
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to search memories: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def add_memory(self, messages: List[Dict[str, str]], user_id: str, metadata: Optional[Dict] = None) -> Dict[str, Any]:
|
||||||
|
"""Add a memory for the user."""
|
||||||
|
try:
|
||||||
|
result = self.memory.add(
|
||||||
|
messages=messages,
|
||||||
|
user_id=user_id,
|
||||||
|
metadata=metadata or {}
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to add memory: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def format_memories_for_prompt(self, memories: List[Dict[str, Any]]) -> str:
|
||||||
|
"""Format memories into a string for the prompt."""
|
||||||
|
if not memories:
|
||||||
|
return "No previous memories about this user."
|
||||||
|
|
||||||
|
formatted = []
|
||||||
|
for i, memory in enumerate(memories, 1):
|
||||||
|
memory_text = memory.get("memory", "")
|
||||||
|
created_at = memory.get("created_at", "")
|
||||||
|
if created_at:
|
||||||
|
try:
|
||||||
|
# Format the date if it's available
|
||||||
|
created_date = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
|
||||||
|
created_str = created_date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
except:
|
||||||
|
created_str = created_at
|
||||||
|
formatted.append(f"{i}. {memory_text} (remembered on: {created_str})")
|
||||||
|
else:
|
||||||
|
formatted.append(f"{i}. {memory_text}")
|
||||||
|
|
||||||
|
return "\n".join(formatted)
|
||||||
|
|
||||||
|
def generate_response_with_memory(self, user_input: str, user_id: str) -> Dict[str, Any]:
|
||||||
|
"""Generate a response using memories and store the interaction."""
|
||||||
|
# Step 1: Search for relevant memories
|
||||||
|
memories = self.search_memories(user_input, user_id)
|
||||||
|
|
||||||
|
# Step 2: Format memories for the prompt
|
||||||
|
formatted_memories = self.format_memories_for_prompt(memories)
|
||||||
|
|
||||||
|
# Step 3: Create the enhanced prompt
|
||||||
|
enhanced_prompt = self.memory_template.format(
|
||||||
|
memories=formatted_memories,
|
||||||
|
query=user_input
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 4: Generate response using OpenAI
|
||||||
|
try:
|
||||||
|
response = self.openai_client.chat.completions.create(
|
||||||
|
model=self.llm_model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a helpful assistant with access to user memories. Use the provided memories to personalize your responses."},
|
||||||
|
{"role": "user", "content": enhanced_prompt}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assistant_response = response.choices[0].message.content
|
||||||
|
|
||||||
|
# Step 5: Store the interaction as new memories
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": user_input},
|
||||||
|
{"role": "assistant", "content": assistant_response}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Store with metadata including timestamp
|
||||||
|
metadata = {
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"type": "chat_interaction"
|
||||||
|
}
|
||||||
|
|
||||||
|
self.add_memory(messages, user_id, metadata)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"response": assistant_response,
|
||||||
|
"user_id": user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to generate response: {e}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"user_id": user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_all_memories(self, user_id: str) -> List[Dict[str, Any]]:
|
||||||
|
"""Get all memories for a user."""
|
||||||
|
try:
|
||||||
|
memories = self.memory.get_all(user_id=user_id)
|
||||||
|
return memories
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to get all memories: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def delete_memory(self, memory_id: str) -> bool:
|
||||||
|
"""Delete a specific memory."""
|
||||||
|
try:
|
||||||
|
self.memory.delete(memory_id)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to delete memory: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete_all_memories(self, user_id: str) -> bool:
|
||||||
|
"""Delete all memories for a user."""
|
||||||
|
try:
|
||||||
|
self.memory.delete_all(user_id=user_id)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to delete all memories: {e}")
|
||||||
|
return False
|
@ -8,6 +8,7 @@ dependencies = [
|
|||||||
"fastapi>=0.115.12",
|
"fastapi>=0.115.12",
|
||||||
"haystack-ai>=2.12.1",
|
"haystack-ai>=2.12.1",
|
||||||
"huggingface-hub>=0.30.2",
|
"huggingface-hub>=0.30.2",
|
||||||
|
"mem0ai>=0.1.118",
|
||||||
"milvus-haystack>=0.0.15",
|
"milvus-haystack>=0.0.15",
|
||||||
"pydantic>=2.11.3",
|
"pydantic>=2.11.3",
|
||||||
"pymilvus>=2.5.6",
|
"pymilvus>=2.5.6",
|
||||||
|
78
test_mem0_service.py
Normal file
78
test_mem0_service.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Test script for memory module-based chat service."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from api.chat_service import ChatService
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_service():
|
||||||
|
"""Test the memory module-based chat service."""
|
||||||
|
print("=== Testing Memory Module-based Chat Service ===\n")
|
||||||
|
|
||||||
|
# Initialize chat service
|
||||||
|
chat_service = ChatService("test_user")
|
||||||
|
chat_service.initialize()
|
||||||
|
|
||||||
|
# Test conversations
|
||||||
|
test_inputs = [
|
||||||
|
"Hi, my name is Alice and I love science fiction movies.",
|
||||||
|
"What kind of movies do I like?",
|
||||||
|
"I also enjoy reading science fiction books.",
|
||||||
|
"Tell me about my hobbies and interests.",
|
||||||
|
"I went to Paris last summer and loved it!",
|
||||||
|
"Where did I travel recently?"
|
||||||
|
]
|
||||||
|
|
||||||
|
print("Starting conversation test...\n")
|
||||||
|
|
||||||
|
for i, user_input in enumerate(test_inputs, 1):
|
||||||
|
print(f"--- Test {i} ---")
|
||||||
|
print(f"User: {user_input}")
|
||||||
|
|
||||||
|
# Get response from chat service
|
||||||
|
result = chat_service.chat(user_input, include_audio=False)
|
||||||
|
|
||||||
|
if result["success"]:
|
||||||
|
print(f"Assistant: {result['response']}")
|
||||||
|
print(f"Status: Success")
|
||||||
|
else:
|
||||||
|
print(f"Error: {result['error']}")
|
||||||
|
print(f"Status: Failed")
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test memory retrieval
|
||||||
|
print("\n--- Testing Memory Retrieval ---")
|
||||||
|
memories_result = chat_service.get_user_memories()
|
||||||
|
|
||||||
|
if memories_result["success"]:
|
||||||
|
print(f"Total memories stored: {memories_result['count']}")
|
||||||
|
print("\nStored memories:")
|
||||||
|
for i, memory in enumerate(memories_result["memories"], 1):
|
||||||
|
print(f"{i}. {memory.get('memory', 'N/A')}")
|
||||||
|
else:
|
||||||
|
print(f"Failed to retrieve memories: {memories_result['error']}")
|
||||||
|
|
||||||
|
# Test memory search
|
||||||
|
print("\n--- Testing Memory Search ---")
|
||||||
|
search_queries = ["movies", "travel", "hobbies"]
|
||||||
|
|
||||||
|
for query in search_queries:
|
||||||
|
print(f"\nSearching for '{query}':")
|
||||||
|
search_result = chat_service.search_memories(query)
|
||||||
|
|
||||||
|
if search_result["success"]:
|
||||||
|
print(f"Found {search_result['count']} memories:")
|
||||||
|
for memory in search_result["memories"]:
|
||||||
|
print(f"- {memory.get('memory', 'N/A')}")
|
||||||
|
else:
|
||||||
|
print(f"Search failed: {search_result['error']}")
|
||||||
|
|
||||||
|
print("\n=== Test completed ===")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_chat_service()
|
Reference in New Issue
Block a user