feat(memory): replace Mem0 with local RAG using SQLite, ChromaDB, and bge-small
Some checks failed
Build and Push Docker / build-and-push (push) Failing after 16m28s

This commit is contained in:
gameloader
2025-11-27 19:14:26 +08:00
parent b88a897151
commit a43c6a10ac
5 changed files with 2220 additions and 110 deletions

View File

@ -13,7 +13,7 @@ from config import (
MEM0_CONFIG, MEM0_CONFIG,
) )
from api.doubao_tts import text_to_speech from api.doubao_tts import text_to_speech
from memory_module.memory_integration import Mem0Integration from memory_module import LocalMemoryIntegration
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class ChatService: class ChatService:
def __init__(self, user_id: str = None): def __init__(self, user_id: str = None):
self.user_id = user_id or DEFAULT_USER_ID self.user_id = user_id or DEFAULT_USER_ID
self.mem0_integration = Mem0Integration(MEM0_CONFIG) self.mem0_integration = LocalMemoryIntegration(MEM0_CONFIG) # 使用新类
self._initialized = False self._initialized = False
def initialize(self): def initialize(self):

View File

@ -1,3 +1,3 @@
from .memory_integration import Mem0Integration from .memory_integration import LocalMemoryIntegration
__all__ = ["Mem0Integration"] __all__ = ["LocalMemoryIntegration"]

View File

@ -1,32 +1,47 @@
import os import os
import sqlite3
import chromadb
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from datetime import datetime from datetime import datetime
import json import json
import openai import openai # 我们仍然需要它来调用 豆包LLM
import threading import threading
from mem0 import Memory
import logging import logging
import textwrap import textwrap
# 使用你项目中的日志配置
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Mem0Integration: class LocalMemoryIntegration:
"""Mem0 integration for memory retrieval and storage in RAG pipeline.""" """
使用本地轻量级 RAG (SQLite + ChromaDB + bge-small) 替换 Mem0。
它保留了原始的 LLM 调用和复杂的 JSON Prompt 逻辑,
只将 RAG 检索和存储部分本地化。
"""
def __init__(self, config: Dict[str, Any]): def __init__(self, config: Dict[str, Any]):
"""Initialize Mem0 with configuration.""" """
初始化本地记忆模块和 OpenAI (豆包) 客户端。
"""
self.config = config self.config = config
self.memory = Memory.from_config(config)
# Initialize OpenAI client for chat completion # 1. 初始化 LLM 客户端 (豆包) - 这部分逻辑保留
# config["llm"] 对应你 README 中的 LLM 配置
try:
self.openai_client = openai.OpenAI( self.openai_client = openai.OpenAI(
api_key=config["llm"]["config"]["api_key"], api_key=config["llm"]["config"]["api_key"],
base_url=config["llm"]["config"].get("openai_base_url") base_url=config["llm"]["config"].get("openai_base_url")
) )
self.llm_model = config["llm"]["config"]["model"] self.llm_model = config["llm"]["config"]["model"]
logger.info(f"OpenAI (Doubao) LLM client initialized for model: {self.llm_model}")
except KeyError as e:
logger.error(f"LLM config missing key: {e}. Check your MEM0_CONFIG['llm'].")
raise
# System prompt template aligned with virtual persona requirements # 2. 保留核心的 System Prompt (从你发的文件中完整复制)
self.system_prompt_template = textwrap.dedent(""" self.system_prompt_template = textwrap.dedent("""
你是虚拟人对话引擎。 你是虚拟人对话引擎。
@ -106,7 +121,7 @@ class Mem0Integration:
* **生气**(惹得虚拟人生气):回复时情绪带有生气/气死/离谱/无语/滚/别烦,动作触发"不满" * **生气**(惹得虚拟人生气):回复时情绪带有生气/气死/离谱/无语/滚/别烦,动作触发"不满"
* **否定**:不/并不/没那么 + 情绪词 → 降级或改判 动作触发"摇头" * **否定**:不/并不/没那么 + 情绪词 → 降级或改判    动作触发"摇头"
肯定:是/没错/肯定的+判断实时情绪,动作触发"点头" 肯定:是/没错/肯定的+判断实时情绪,动作触发"点头"
@ -173,50 +188,164 @@ class Mem0Integration:
``` ```
""").strip() """).strip()
def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[Any]: # 3. 初始化本地 SQLite (存储原始文本)
"""Search for relevant memories about the user.""" db_path = os.getenv('SQLITE_DB_PATH', 'local_memory.db')
self.db_conn = sqlite3.connect(db_path, check_same_thread=False)
self.db_conn.execute("""
CREATE TABLE IF NOT EXISTS memories (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
chroma_id TEXT NOT NULL UNIQUE,
timestamp TEXT NOT NULL
);
""")
self.db_conn.execute("CREATE INDEX IF NOT EXISTS idx_user_id ON memories (user_id);")
logger.info(f"SQLite database initialized at {db_path}")
# 4. 初始化本地 ChromaDB (存储向量)
chroma_path = os.getenv('CHROMA_DB_PATH', './chroma_db_store')
self.chroma_client = chromadb.PersistentClient(path=chroma_path)
logger.info(f"ChromaDB persistent client initialized at {chroma_path}")
# 5. 初始化本地 Embedding 模型 (低内存占用)
model_name = 'bge-small-zh-v1.5' # 约 100MB 内存
try: try:
results = self.memory.search( logger.info(f"Loading local embedding model: {model_name}...")
query=query, self.embedding_model = SentenceTransformer(model_name, device='cpu')
user_id=user_id, logger.info("Local embedding model loaded successfully.")
limit=limit
)
# Handle dictionary response format - check for both 'memories' and 'results' keys
if isinstance(results, dict):
memories = results.get("memories", results.get("results", []))
return memories
else:
logger.error(f"Unexpected search results format: {type(results)}")
return []
except Exception as e: except Exception as e:
logger.error(f"Failed to search memories: {e}") logger.error(f"Failed to load embedding model '{model_name}': {e}")
raise
def _get_user_collection(self, user_id: str):
"""获取或创建该用户的 ChromaDB 集合"""
collection_name = f"user_memory_{user_id.replace('-', '_')}" # 确保 user_id
return self.chroma_client.get_or_create_collection(name=collection_name)
def _get_max_id(self, user_id: str):
"""获取该用户当前最大的 memory ID"""
try:
cursor = self.db_conn.execute("SELECT MAX(id) FROM memories WHERE user_id = ?", (user_id,))
max_id = cursor.fetchone()[0]
return max_id if max_id is not None else 0
except Exception as e:
logger.error(f"Failed to get max id from SQLite for user {user_id}: {e}")
return 0
# --- 实现 Mem0Integration 的公共接口 ---
def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[Any]:
    """Search the user's local vector store for memories relevant to ``query``.

    Args:
        query: Text that is embedded and used as the nearest-neighbour probe.
        user_id: Owner of the memories; selects the per-user collection.
        limit: Maximum number of memories to return.

    Returns:
        A list of dicts with the keys ``memory`` (the stored document text),
        ``content`` (the ``content`` metadata value, falling back to the
        document text) and ``role`` (``"unknown"`` when absent). An empty
        list is returned when the user has no memories, when ``limit`` is not
        positive, or when any step fails (the failure is logged).
    """
    logger.debug(f"Searching local memories for user {user_id} with query: {query}")
    try:
        collection = self._get_user_collection(user_id)

        # Count once and reuse the value: it both detects an empty collection
        # and caps n_results, so the two decisions cannot disagree.
        stored = collection.count()
        if stored == 0:
            logger.debug("No memories found for this user.")
            return []
        if limit <= 0:
            # Nothing was asked for; do not hand a non-positive n_results to the store.
            return []

        # 1. Embed the query locally.
        query_embedding = self.embedding_model.encode(query).tolist()

        # 2. Nearest-neighbour search; n_results must not exceed the collection size.
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=min(limit, stored),
        )

        # 3. Convert to the shape format_memories_for_prompt consumes.
        formatted_results = []
        if results and 'metadatas' in results and 'documents' in results:
            for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
                # An entry stored without metadata must not discard the whole result set.
                meta = meta or {}
                formatted_results.append({
                    "memory": doc,  # key read by format_memories_for_prompt
                    "content": meta.get("content", doc),  # fallback copy
                    "role": meta.get("role", "unknown"),
                })

        logger.debug(f"Found {len(formatted_results)} local memories.")
        return formatted_results
    except Exception as e:
        # logger.exception keeps the traceback alongside the original message.
        logger.exception(f"Failed to search local memories: {e}")
        return []
def add_memory(self, messages: List[Dict[str, str]], user_id: str, metadata: Optional[Dict] = None) -> Dict[str, Any]:
    """Persist chat messages as memories: raw text in SQLite, vectors in ChromaDB.

    Args:
        messages: Dicts with ``role`` and ``content``; entries missing either
            value are skipped.
        user_id: Owner of the memories.
        metadata: Optional extra data. Only its ``timestamp`` entry is read;
            when absent, the current local time is used.

    Returns:
        ``{"success": True, "message": ...}`` on success (including the case
        where no message was valid), or ``{"success": False, "error": ...}``
        when embedding or storage fails. Exceptions are logged, not raised.
    """
    logger.debug(f"Adding local memory for user {user_id}")
    if metadata is None:
        metadata = {}

    try:
        # Keep only messages that carry both a role and some content.
        entries = []
        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")
            if role and content:
                entries.append((role, content))

        if not entries:
            return {"success": True, "message": "No valid messages to add."}

        # The embedded text doubles as the ChromaDB document.
        texts_to_embed = [f"{role}: {content}" for role, content in entries]

        # Embed before touching the database so that slow model inference
        # neither holds the write lock nor leaves a transaction open.
        embeddings = self.embedding_model.encode(
            texts_to_embed,
            batch_size=8
        ).tolist()
        collection = self._get_user_collection(user_id)
    except Exception as e:
        logger.exception(f"Error in add_memory (local): {e}")
        return {"success": False, "error": str(e)}

    # This method is called from background threads that share one SQLite
    # connection. Without serialization two calls could read the same max id
    # (duplicate chroma_id) or roll back each other's pending inserts.
    # dict.setdefault makes the lazy creation of the per-instance lock atomic.
    # NOTE(review): other writers on this connection do not take this lock.
    write_lock = self.__dict__.setdefault("_memory_write_lock", threading.Lock())
    with write_lock:
        try:
            metadatas_to_store = []
            ids_to_store = []
            current_max_id = self._get_max_id(user_id)

            for role, content in entries:
                current_max_id += 1
                chroma_id = f"mem_{user_id}_{current_max_id}"
                timestamp = metadata.get("timestamp", datetime.now().isoformat())

                # 1. Raw text goes to SQLite (committed only after ChromaDB succeeds).
                self.db_conn.execute(
                    "INSERT INTO memories (user_id, role, content, chroma_id, timestamp) VALUES (?, ?, ?, ?, ?)",
                    (user_id, role, content, chroma_id, timestamp)
                )

                # 2. Collect the matching ChromaDB metadata and id.
                metadatas_to_store.append({
                    "role": role,
                    "content": content,  # copy of the unprefixed text
                    "timestamp": timestamp
                })
                ids_to_store.append(chroma_id)

            # 3. Store all vectors in one batch.
            collection.add(
                embeddings=embeddings,
                documents=texts_to_embed,
                metadatas=metadatas_to_store,
                ids=ids_to_store
            )

            self.db_conn.commit()
        except Exception as e:
            logger.exception(f"Error in add_memory (local): {e}")
            self.db_conn.rollback()
            return {"success": False, "error": str(e)}

    logger.debug(f"Successfully added {len(texts_to_embed)} local memories.")
    return {"success": True, "message": f"Added {len(texts_to_embed)} memories."}
# --- 保留你原有的辅助函数和核心逻辑 ---
def _extract_reply_for_memory(self, assistant_response: Any) -> str: def _extract_reply_for_memory(self, assistant_response: Any) -> str:
"""Extract the assistant reply text from structured responses for memory storage.""" """(原样保留) 从结构化响应中提取用于记忆的助手回复文本。"""
if assistant_response is None: if assistant_response is None:
return "" return ""
@ -228,21 +357,32 @@ class Mem0Integration:
return "" return ""
try: try:
# 尝试解析 JSON, 兼容你 `chat_service` 可能遇到的 "```json\n{...}\n```"
if raw_text.startswith("```json"):
raw_text = raw_text.split("```json\n", 1)[-1].split("\n```")[0]
data = json.loads(raw_text) data = json.loads(raw_text)
reply = data.get("reply", "") reply = data.get("reply", "")
reply_str = str(reply).strip() reply_str = str(reply).strip()
# 如果 'reply' 为空,也返回原始文本,以防 LLM 返回了非 JSON 格式的有效回复
return reply_str if reply_str else raw_text return reply_str if reply_str else raw_text
except Exception: except Exception:
# 如果解析失败,返回原始文本
return raw_text return raw_text
def format_memories_for_prompt(self, memories: List[Any]) -> str: def format_memories_for_prompt(self, memories: List[Any]) -> str:
"""Format memories into bullet points for injection into the system prompt.""" """
(轻微修改) 将记忆格式化为项目符号点,以便注入系统提示。
严格遵守 Prompt 中的 "MEM_INJECT_TOPK: 1-2 条" 规则。
"""
if not memories: if not memories:
return "" return ""
formatted = [] formatted = []
for memory in memories: # 只取最相关的 2 条
for memory in memories[:2]:
if isinstance(memory, dict): if isinstance(memory, dict):
# 我们的 search_memories 返回 'memory' 键
memory_text = memory.get("memory") or memory.get("content") or "" memory_text = memory.get("memory") or memory.get("content") or ""
elif isinstance(memory, str): elif isinstance(memory, str):
memory_text = memory memory_text = memory
@ -256,17 +396,20 @@ class Mem0Integration:
return "\n".join(formatted) return "\n".join(formatted)
def generate_response_with_memory(self, user_input: str, user_id: str) -> Dict[str, Any]: def generate_response_with_memory(self, user_input: str, user_id: str) -> Dict[str, Any]:
"""Generate a response using memories and store the interaction.""" """
# Step 1: Search for relevant memories (逻辑保留,调用本地) 使用记忆生成响应并存储交互。
"""
# 步骤 1: 搜索相关记忆 (调用我们的本地搜索)
# Prompt 建议 1-2 条,我们搜 5 条,format_memories_for_prompt 会帮我们筛选
memories = self.search_memories(user_input, user_id, limit=5) memories = self.search_memories(user_input, user_id, limit=5)
# Step 2: Prepare system prompt with memory injection # 步骤 2: 准备系统提示 (逻辑保留)
memory_block = self.format_memories_for_prompt(memories) memory_block = self.format_memories_for_prompt(memories)
system_prompt = self.system_prompt_template.replace( system_prompt = self.system_prompt_template.replace(
"{memory_block}", memory_block if memory_block else "" "{memory_block}", memory_block if memory_block else ""
).strip() ).strip()
# Step 3: Generate response using OpenAI # 步骤 3: 生成响应 (逻辑保留 - 仍然调用豆包 LLM)
try: try:
response = self.openai_client.chat.completions.create( response = self.openai_client.chat.completions.create(
model=self.llm_model, model=self.llm_model,
@ -274,92 +417,112 @@ class Mem0Integration:
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": user_input} {"role": "user", "content": user_input}
], ],
reasoning_effort="minimal", # reasoning_effort="minimal", # 你原始代码里的,如果豆包不支持就删掉
# 推荐开启 JSON 模式,让豆包强制输出 JSON
response_format={"type": "json_object"}
) )
assistant_response = response.choices[0].message.content assistant_response = response.choices[0].message.content
reply_for_memory = self._extract_reply_for_memory(assistant_response) reply_for_memory = self._extract_reply_for_memory(assistant_response)
if not reply_for_memory: if not reply_for_memory: # 确保不存空记忆
reply_for_memory = assistant_response reply_for_memory = assistant_response
# Step 5: Store the interaction as new memories (异步执行) # 步骤 5: 异步存储交互 (逻辑保留 - 调用我们的本地 add_memory)
messages = [ messages = [
{"role": "user", "content": user_input}, {"role": "user", "content": user_input},
{"role": "assistant", "content": reply_for_memory} {"role": "assistant", "content": reply_for_memory} # 只存回复文本
] ]
# Store with metadata including timestamp
metadata = { metadata = {
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"type": "chat_interaction" "type": "chat_interaction"
} }
# 异步存储记忆,不阻塞主流程 # (原样保留) 启动异步线程存储记忆
def store_memory_async(): def store_memory_async():
try: try:
self.add_memory(messages, user_id, metadata) self.add_memory(messages, user_id, metadata)
except Exception as e: except Exception as e:
print(f"[WARNING] Async memory storage failed: {e}") # 原代码是 print, 统一用 logger
logger.warning(f"[WARNING] Async local memory storage failed: {e}")
# 启动异步线程存储记忆
memory_thread = threading.Thread(target=store_memory_async, daemon=True) memory_thread = threading.Thread(target=store_memory_async, daemon=True)
memory_thread.start() memory_thread.start()
return { return {
"success": True, "success": True,
"response": assistant_response, "response": assistant_response, # 返回完整的 JSON 字符串给 chat_service
"user_id": user_id "user_id": user_id
} }
except Exception as e: except Exception as e:
print(f"[ERROR] Failed to generate response: {e}") logger.error(f"[ERROR] Failed to generate response: {e}")
return { return {
"success": False, "success": False,
"error": str(e), "error": str(e),
"user_id": user_id "user_id": user_id
} }
# --- 其他辅助 API 的本地实现 ---
def get_all_memories(self, user_id: str) -> List[Any]:
    """Return every stored memory of ``user_id`` from SQLite, oldest first.

    Each entry is a dict with ``id`` (the chroma_id, used as the public
    identifier), ``sqlite_id``, ``memory`` (``"<role>: <content>"``),
    ``role``, ``content`` and ``timestamp``. An empty list is returned when
    the query fails (the failure is logged).
    """
    logger.debug(f"Getting all local memories for user {user_id}")
    try:
        rows = self.db_conn.execute(
            "SELECT id, role, content, timestamp, chroma_id FROM memories WHERE user_id = ? ORDER BY id ASC",
            (user_id,)
        ).fetchall()

        history = []
        for sqlite_id, role, content, timestamp, chroma_id in rows:
            history.append({
                "id": chroma_id,  # chroma_id is the externally visible identifier
                "sqlite_id": sqlite_id,
                "memory": f"{role}: {content}",
                "role": role,
                "content": content,
                "timestamp": timestamp
            })
        return history
    except Exception as e:
        logger.error(f"Error in get_all_memories (local): {e}")
        return []
def delete_memory(self, memory_id: str) -> bool:
    """Delete one memory, identified by its chroma_id, from both stores.

    Returns True when the row and its vector were removed. Returns False
    when no SQLite row has that chroma_id, or when any step fails; on
    failure the SQLite transaction is rolled back and the error is logged.
    """
    logger.debug(f"Deleting local memory: {memory_id}")
    try:
        # Look up the owner first: the vector lives in a per-user collection.
        owner_row = self.db_conn.execute(
            "SELECT user_id FROM memories WHERE chroma_id = ?", (memory_id,)
        ).fetchone()
        if not owner_row:
            logger.warning(f"Memory ID {memory_id} not found in SQLite.")
            return False
        owner_id = owner_row[0]

        # 1. Remove the SQLite row (left uncommitted until ChromaDB succeeds).
        self.db_conn.execute("DELETE FROM memories WHERE chroma_id = ?", (memory_id,))

        # 2. Remove the vector from the owner's collection.
        self._get_user_collection(owner_id).delete(ids=[memory_id])

        self.db_conn.commit()
        return True
    except Exception as e:
        logger.error(f"Failed to delete memory (local): {e}")
        self.db_conn.rollback()
        return False
def delete_all_memories(self, user_id: str) -> bool:
    """Delete every memory of ``user_id`` from SQLite and ChromaDB.

    Returns True when both stores were cleared, including the case where the
    user had nothing stored. Returns False when any step fails; the SQLite
    deletion is then rolled back and the error is logged.
    """
    logger.debug(f"Deleting all local memories for user {user_id}")
    try:
        # 1. Remove the SQLite rows, but hold the commit until ChromaDB has
        #    succeeded so a vector-store failure can still be rolled back.
        self.db_conn.execute("DELETE FROM memories WHERE user_id = ?", (user_id,))

        # 2. Drop the user's whole collection. Resolving it through the shared
        #    helper keeps the naming rule in one place, and get-or-create
        #    guarantees the collection exists, so a user without any vectors
        #    does not turn the deletion into an error.
        collection = self._get_user_collection(user_id)
        self.chroma_client.delete_collection(name=collection.name)

        self.db_conn.commit()
        return True
    except Exception as e:
        logger.error(f"Failed to delete all memories (local): {e}")
        self.db_conn.rollback()
        return False

View File

@ -5,6 +5,7 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = [ dependencies = [
"chromadb>=1.3.5",
"fastapi>=0.115.12", "fastapi>=0.115.12",
"haystack-ai>=2.12.1", "haystack-ai>=2.12.1",
"huggingface-hub>=0.30.2", "huggingface-hub>=0.30.2",
@ -12,6 +13,7 @@ dependencies = [
"milvus-haystack>=0.0.15", "milvus-haystack>=0.0.15",
"pydantic>=2.11.3", "pydantic>=2.11.3",
"pymilvus>=2.5.6", "pymilvus>=2.5.6",
"sentence-transformers>=5.1.2",
"uvicorn>=0.34.0", "uvicorn>=0.34.0",
] ]
[[tool.uv.index]] [[tool.uv.index]]

1951
uv.lock generated

File diff suppressed because it is too large Load Diff