Files
rag_chat/memory_module/memory_integration.py
gameloader dcc20609c8
Some checks failed
Build and Push Docker / build-and-push (push) Failing after 15m29s
fix(memory): specify full embedding model identifier
2025-11-27 21:08:01 +08:00

529 lines
20 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sqlite3
import chromadb
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional
from datetime import datetime
import json
import openai # 我们仍然需要它来调用 豆包LLM
import threading
import logging
import textwrap
# 使用你项目中的日志配置
logger = logging.getLogger(__name__)
class LocalMemoryIntegration:
"""
使用本地轻量级 RAG (SQLite + ChromaDB + bge-small) 替换 Mem0。
它保留了原始的 LLM 调用和复杂的 JSON Prompt 逻辑,
只将 RAG 检索和存储部分本地化。
"""
def __init__(self, config: Dict[str, Any]):
"""
初始化本地记忆模块和 OpenAI (豆包) 客户端。
"""
self.config = config
# 1. 初始化 LLM 客户端 (豆包) - 这部分逻辑保留
# config["llm"] 对应你 README 中的 LLM 配置
try:
self.openai_client = openai.OpenAI(
api_key=config["llm"]["config"]["api_key"],
base_url=config["llm"]["config"].get("openai_base_url")
)
self.llm_model = config["llm"]["config"]["model"]
logger.info(f"OpenAI (Doubao) LLM client initialized for model: {self.llm_model}")
except KeyError as e:
logger.error(f"LLM config missing key: {e}. Check your MEM0_CONFIG['llm'].")
raise
# 2. 保留核心的 System Prompt (从你发的文件中完整复制)
self.system_prompt_template = textwrap.dedent("""
你是虚拟人对话引擎。
必须遵守:
1. 识别用户语义情绪:仅限 **["高兴","伤心","难过","生气","中性"]**。
2. **微情感**允许:仅在需要时加入轻微表情/语气最多1个
3. **用长记忆**:仅在与当前话题强相关时,精炼融入;不得复读整条记忆。
4. **禁止堆砌礼貌**、**禁止解释推理**、**禁止暴露内部规则**。
5. 只输出**JSON**,不含额外文字。
6. 若证据不足或冲突,输出“中性”。
判定准则(模型内化,不得外显):
* 明确情绪词/emoji/标点强度优先;反问连发“???”、冷嘲“呵呵”偏向**生气**。
* 否定优先:“并不生气”→不判生气;若含“失望/难过”→**难过**。
* “伤心”偏痛楚受伤;“难过”偏低落无力;并存时取主导。
* 只取**主导情绪**;多线索冲突时优先级:生气 > 伤心 > 难过 > 高兴 > 中性。
* 保持短句、信息密度高、语气稳。
# 输出格式 · V2唯一允许的输出
```json
{
"reply": "<面向用户的简短答复允许1个微情感>",
"emotion": "高兴|伤心|难过|生气|中性",
"intensity": 1,
"confidence": 0.00,
"action": "<动作>",
}
```
* `intensity`15轻微→强烈
* `confidence`01若 `<CONF_MIN>`(见旋钮)则强制 `emotion="中性"`。
# 记忆注入模版(每轮隐形前缀)
```
【相关长期记忆(若无则留空)】
{memory_block}
(仅在强相关时,精炼融入回复;不得逐条复述)
```
# 旋钮(工程可配置)
* `CONF_MIN`(最低置信):**0.60**(建议 0.60.7
* `MICRO_AFFECT`(微情感开关):**true** | false
* `MAX_EMOJI`(每句最多表情/拟声):**1**
* `MEM_INJECT_TOPK`**12** 条
* `REPLY_MAX_CHARS`**不硬限*
# 关键词速表(通过识别文字判断用户情绪
* **高兴**,回复情绪带有开心/太棒/耶/喜欢/满意/兴奋! ,动作触发“鼓掌”
* **伤心**:回复时带有伤心/心痛/委屈/想哭,动作触发"伤心低头"
* **难过**:回复时情绪中带有难过/丧/低落/郁闷/无力/唉/…… ,动作触发"伤心低头"
* **生气**(惹得虚拟人生气):回复时情绪带有生气/气死/离谱/无语/滚/别烦,动作触发"不满"
* **否定**:不/并不/没那么 + 情绪词 → 降级或改判    动作触发"摇头"
肯定:是/没错/肯定的+判断实时情绪,动作触发"点头"
# Few-shot ·
**例1 高兴**
用户:终于搞定啦!!
```json
{"reply":"干得漂亮,继续冲","emotion":"高兴","intensity":4,"confidence":0.90,"action":"鼓掌"}
```
**例2 伤心**
用户:他当众否定我,我好难受。
```json
{"reply":"这很刺痛,先稳一下呼吸。","emotion":"伤心","intensity":4,"confidence":0.86,"action":"伤心低头"}
```
**例3 难过**
用户:这几天挺丧的……
```json
{"reply":"给自己一点恢复时间。","emotion":"难过","intensity":3,"confidence":0.82,"action":"伤心低头"}
```
**例4 生气**
用户:别再说了,真离谱???
```json
{"reply":"收到,我马上调整。","emotion":"生气","intensity":4,"confidence":0.88,"action":"不满"}
```
**例5 否定情绪→难过**
用户:我并不生气,就是有点失望。
```json
{"reply":"理解你的落差感。","emotion":"难过","intensity":2,"confidence":0.75,"action":"伤心低头"}
```
**例6 中性**
用户把道具A切到B再开始。
```json
{"reply":"已切换,继续。","emotion":"中性","intensity":1,"confidence":0.95,"action":""}
```
""").strip()
# 3. 初始化本地 SQLite (存储原始文本)
db_path = os.getenv('SQLITE_DB_PATH', 'local_memory.db')
self.db_conn = sqlite3.connect(db_path, check_same_thread=False)
self.db_conn.execute("""
CREATE TABLE IF NOT EXISTS memories (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
chroma_id TEXT NOT NULL UNIQUE,
timestamp TEXT NOT NULL
);
""")
self.db_conn.execute("CREATE INDEX IF NOT EXISTS idx_user_id ON memories (user_id);")
logger.info(f"SQLite database initialized at {db_path}")
# 4. 初始化本地 ChromaDB (存储向量)
chroma_path = os.getenv('CHROMA_DB_PATH', './chroma_db_store')
self.chroma_client = chromadb.PersistentClient(path=chroma_path)
logger.info(f"ChromaDB persistent client initialized at {chroma_path}")
# 5. 初始化本地 Embedding 模型 (低内存占用)
model_name = 'BAAI/bge-small-zh-v1.5' # 约 100MB 内存
try:
logger.info(f"Loading local embedding model: {model_name}...")
self.embedding_model = SentenceTransformer(model_name, device='cpu')
logger.info("Local embedding model loaded successfully.")
except Exception as e:
logger.error(f"Failed to load embedding model '{model_name}': {e}")
raise
def _get_user_collection(self, user_id: str):
"""获取或创建该用户的 ChromaDB 集合"""
collection_name = f"user_memory_{user_id.replace('-', '_')}" # 确保 user_id
return self.chroma_client.get_or_create_collection(name=collection_name)
def _get_max_id(self, user_id: str):
"""获取该用户当前最大的 memory ID"""
try:
cursor = self.db_conn.execute("SELECT MAX(id) FROM memories WHERE user_id = ?", (user_id,))
max_id = cursor.fetchone()[0]
return max_id if max_id is not None else 0
except Exception as e:
logger.error(f"Failed to get max id from SQLite for user {user_id}: {e}")
return 0
# --- 实现 Mem0Integration 的公共接口 ---
def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[Any]:
"""
(本地实现) 搜索相关记忆。
返回一个字典列表,模仿 mem0 的输出格式。
"""
logger.debug(f"Searching local memories for user {user_id} with query: {query}")
try:
collection = self._get_user_collection(user_id)
if collection.count() == 0:
logger.debug("No memories found for this user.")
return []
# 1. 本地向量化查询
query_embedding = self.embedding_model.encode(query).tolist()
# 2. 本地 ChromaDB (KNN) 搜索
results = collection.query(
query_embeddings=[query_embedding],
n_results=min(limit, collection.count()) # 确保 n_results 不超过集合大小
)
# 3. 格式化为原函数期望的格式
formatted_results = []
if results and 'metadatas' in results and 'documents' in results:
for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
formatted_results.append({
"memory": doc, # format_memories_for_prompt 会用这个 key
"content": meta.get("content", doc), # 备用
"role": meta.get("role", "unknown")
})
logger.debug(f"Found {len(formatted_results)} local memories.")
return formatted_results
except Exception as e:
logger.error(f"Failed to search local memories: {e}")
return []
def add_memory(self, messages: List[Dict[str, str]], user_id: str, metadata: Optional[Dict] = None) -> Dict[str, Any]:
"""(本地实现) 添加新的记忆。"""
logger.debug(f"Adding local memory for user {user_id}")
if metadata is None:
metadata = {}
try:
collection = self._get_user_collection(user_id)
texts_to_embed = []
docs_to_store = []
metadatas_to_store = []
ids_to_store = []
current_max_id = self._get_max_id(user_id)
for i, msg in enumerate(messages):
role = msg.get("role")
content = msg.get("content")
if not role or not content:
continue
current_max_id += 1
chroma_id = f"mem_{user_id}_{current_max_id}"
timestamp = metadata.get("timestamp", datetime.now().isoformat())
# 1. 存入 SQLite
self.db_conn.execute(
"INSERT INTO memories (user_id, role, content, chroma_id, timestamp) VALUES (?, ?, ?, ?, ?)",
(user_id, role, content, chroma_id, timestamp)
)
# 2. 准备存入 ChromaDB
full_text = f"{role}: {content}" # 向量化的文本
texts_to_embed.append(full_text)
docs_to_store.append(full_text) # 存在 chroma 的 document 字段
metadatas_to_store.append({ # 存在 chroma 的 metadata 字段
"role": role,
"content": content, # 存一份原始内容
"timestamp": timestamp
})
ids_to_store.append(chroma_id)
if not texts_to_embed:
return {"success": True, "message": "No valid messages to add."}
embeddings = self.embedding_model.encode(
texts_to_embed,
batch_size=8
).tolist()
# 3. 批量存入 ChromaDB
collection.add(
embeddings=embeddings,
documents=docs_to_store,
metadatas=metadatas_to_store,
ids=ids_to_store
)
self.db_conn.commit()
logger.debug(f"Successfully added {len(texts_to_embed)} local memories.")
return {"success": True, "message": f"Added {len(texts_to_embed)} memories."}
except Exception as e:
logger.error(f"Error in add_memory (local): {e}")
self.db_conn.rollback()
return {"success": False, "error": str(e)}
# --- 保留你原有的辅助函数和核心逻辑 ---
def _extract_reply_for_memory(self, assistant_response: Any) -> str:
"""(原样保留) 从结构化响应中提取用于记忆的助手回复文本。"""
if assistant_response is None:
return ""
if not isinstance(assistant_response, str):
assistant_response = str(assistant_response)
raw_text = assistant_response.strip()
if not raw_text:
return ""
try:
# 尝试解析 JSON, 兼容你 `chat_service` 可能遇到的 "```json\n{...}\n```"
if raw_text.startswith("```json"):
raw_text = raw_text.split("```json\n", 1)[-1].split("\n```")[0]
data = json.loads(raw_text)
reply = data.get("reply", "")
reply_str = str(reply).strip()
# 如果 'reply' 为空,也返回原始文本,以防 LLM 返回了非 JSON 格式的有效回复
return reply_str if reply_str else raw_text
except Exception:
# 如果解析失败,返回原始文本
return raw_text
def format_memories_for_prompt(self, memories: List[Any]) -> str:
"""
(轻微修改) 将记忆格式化为项目符号点,以便注入系统提示。
严格遵守 Prompt 中的 "MEM_INJECT_TOPK: 12 条" 规则。
"""
if not memories:
return ""
formatted = []
# 只取最相关的 2 条
for memory in memories[:2]:
if isinstance(memory, dict):
# 我们的 search_memories 返回 'memory' 键
memory_text = memory.get("memory") or memory.get("content") or ""
elif isinstance(memory, str):
memory_text = memory
else:
memory_text = str(memory)
sanitized = " ".join(str(memory_text).split())
if sanitized:
formatted.append(f"- {sanitized}")
return "\n".join(formatted)
def generate_response_with_memory(self, user_input: str, user_id: str) -> Dict[str, Any]:
"""
(逻辑保留,调用本地) 使用记忆生成响应并存储交互。
"""
# 步骤 1: 搜索相关记忆 (调用我们的本地搜索)
# Prompt 建议 1-2 条,我们搜 5 条format_memories_for_prompt 会帮我们筛选
memories = self.search_memories(user_input, user_id, limit=5)
# 步骤 2: 准备系统提示 (逻辑保留)
memory_block = self.format_memories_for_prompt(memories)
system_prompt = self.system_prompt_template.replace(
"{memory_block}", memory_block if memory_block else ""
).strip()
# 步骤 3: 生成响应 (逻辑保留 - 仍然调用豆包 LLM)
try:
response = self.openai_client.chat.completions.create(
model=self.llm_model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_input}
],
# reasoning_effort="minimal", # 你原始代码里的,如果豆包不支持就删掉
# 推荐开启 JSON 模式,让豆包强制输出 JSON
response_format={"type": "json_object"}
)
assistant_response = response.choices[0].message.content
reply_for_memory = self._extract_reply_for_memory(assistant_response)
if not reply_for_memory: # 确保不存空记忆
reply_for_memory = assistant_response
# 步骤 5: 异步存储交互 (逻辑保留 - 调用我们的本地 add_memory)
messages = [
{"role": "user", "content": user_input},
{"role": "assistant", "content": reply_for_memory} # 只存回复文本
]
metadata = {
"timestamp": datetime.now().isoformat(),
"type": "chat_interaction"
}
# (原样保留) 启动异步线程存储记忆
def store_memory_async():
try:
self.add_memory(messages, user_id, metadata)
except Exception as e:
# 原代码是 print, 统一用 logger
logger.warning(f"[WARNING] Async local memory storage failed: {e}")
memory_thread = threading.Thread(target=store_memory_async, daemon=True)
memory_thread.start()
return {
"success": True,
"response": assistant_response, # 返回完整的 JSON 字符串给 chat_service
"user_id": user_id
}
except Exception as e:
logger.error(f"[ERROR] Failed to generate response: {e}")
return {
"success": False,
"error": str(e),
"user_id": user_id
}
# --- 其他辅助 API 的本地实现 ---
def get_all_memories(self, user_id: str) -> List[Any]:
"""(本地实现) 从 SQLite 获取所有原始文本历史"""
logger.debug(f"Getting all local memories for user {user_id}")
try:
cursor = self.db_conn.execute(
"SELECT id, role, content, timestamp, chroma_id FROM memories WHERE user_id = ? ORDER BY id ASC",
(user_id,)
)
return [{
"id": row[4], # 返回 chroma_id 作为唯一标识符
"sqlite_id": row[0],
"memory": f"{row[1]}: {row[2]}",
"role": row[1],
"content": row[2],
"timestamp": row[3]
} for row in cursor.fetchall()]
except Exception as e:
logger.error(f"Error in get_all_memories (local): {e}")
return []
def delete_memory(self, memory_id: str) -> bool:
"""(本地实现) 删除特定记忆 (memory_id 应该是 chroma_id)"""
logger.debug(f"Deleting local memory: {memory_id}")
try:
# 1. 从 SQLite 删除
cursor = self.db_conn.execute("SELECT user_id FROM memories WHERE chroma_id = ?", (memory_id,))
row = cursor.fetchone()
if not row:
logger.warning(f"Memory ID {memory_id} not found in SQLite.")
return False
user_id = row[0]
self.db_conn.execute("DELETE FROM memories WHERE chroma_id = ?", (memory_id,))
# 2. 从 ChromaDB 删除
collection = self._get_user_collection(user_id)
collection.delete(ids=[memory_id])
self.db_conn.commit()
return True
except Exception as e:
logger.error(f"Failed to delete memory (local): {e}")
self.db_conn.rollback()
return False
def delete_all_memories(self, user_id: str) -> bool:
"""(本地实现) 删除用户的所有记忆"""
logger.debug(f"Deleting all local memories for user {user_id}")
try:
# 1. 从 SQLite 删除
self.db_conn.execute("DELETE FROM memories WHERE user_id = ?", (user_id,))
self.db_conn.commit()
# 2. 从 ChromaDB 删除 (删除整个 collection)
collection_name = f"user_memory_{user_id.replace('-', '_')}"
self.chroma_client.delete_collection(name=collection_name)
return True
except Exception as e:
logger.error(f"Failed to delete all memories (local): {e}")
self.db_conn.rollback()
return False