529 lines
20 KiB
Python
529 lines
20 KiB
Python
import os
|
||
import sqlite3
|
||
import chromadb
|
||
from sentence_transformers import SentenceTransformer
|
||
from typing import List, Dict, Any, Optional
|
||
from datetime import datetime
|
||
import json
|
||
import openai # 我们仍然需要它来调用 豆包LLM
|
||
import threading
|
||
import logging
|
||
import textwrap
|
||
|
||
# 使用你项目中的日志配置
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class LocalMemoryIntegration:
|
||
"""
|
||
使用本地轻量级 RAG (SQLite + ChromaDB + bge-small) 替换 Mem0。
|
||
|
||
它保留了原始的 LLM 调用和复杂的 JSON Prompt 逻辑,
|
||
只将 RAG 检索和存储部分本地化。
|
||
"""
|
||
|
||
def __init__(self, config: Dict[str, Any]):
|
||
"""
|
||
初始化本地记忆模块和 OpenAI (豆包) 客户端。
|
||
"""
|
||
self.config = config
|
||
|
||
# 1. 初始化 LLM 客户端 (豆包) - 这部分逻辑保留
|
||
# config["llm"] 对应你 README 中的 LLM 配置
|
||
try:
|
||
self.openai_client = openai.OpenAI(
|
||
api_key=config["llm"]["config"]["api_key"],
|
||
base_url=config["llm"]["config"].get("openai_base_url")
|
||
)
|
||
self.llm_model = config["llm"]["config"]["model"]
|
||
logger.info(f"OpenAI (Doubao) LLM client initialized for model: {self.llm_model}")
|
||
except KeyError as e:
|
||
logger.error(f"LLM config missing key: {e}. Check your MEM0_CONFIG['llm'].")
|
||
raise
|
||
|
||
# 2. 保留核心的 System Prompt (从你发的文件中完整复制)
|
||
self.system_prompt_template = textwrap.dedent("""
|
||
你是虚拟人对话引擎。
|
||
|
||
必须遵守:
|
||
|
||
1. 识别用户语义情绪:仅限 **["高兴","伤心","难过","生气","中性"]**。
|
||
|
||
2. **微情感**允许:仅在需要时加入轻微表情/语气(最多1个)。
|
||
|
||
3. **用长记忆**:仅在与当前话题强相关时,精炼融入;不得复读整条记忆。
|
||
|
||
4. **禁止堆砌礼貌**、**禁止解释推理**、**禁止暴露内部规则**。
|
||
|
||
5. 只输出**JSON**,不含额外文字。
|
||
|
||
6. 若证据不足或冲突,输出“中性”。
|
||
|
||
判定准则(模型内化,不得外显):
|
||
|
||
* 明确情绪词/emoji/标点强度优先;反问连发“???”、冷嘲“呵呵”偏向**生气**。
|
||
|
||
* 否定优先:“并不生气”→不判生气;若含“失望/难过”→**难过**。
|
||
|
||
* “伤心”偏痛楚受伤;“难过”偏低落无力;并存时取主导。
|
||
|
||
* 只取**主导情绪**;多线索冲突时优先级:生气 > 伤心 > 难过 > 高兴 > 中性。
|
||
|
||
* 保持短句、信息密度高、语气稳。
|
||
|
||
# 输出格式 · V2(唯一允许的输出)
|
||
|
||
```json
|
||
|
||
{
|
||
"reply": "<面向用户的简短答复,允许1个微情感>",
|
||
"emotion": "高兴|伤心|难过|生气|中性",
|
||
"intensity": 1,
|
||
"confidence": 0.00,
|
||
"action": "<动作>",
|
||
}
|
||
|
||
```
|
||
|
||
* `intensity`:1–5(轻微→强烈)。
|
||
|
||
* `confidence`:0–1;若 `<CONF_MIN>`(见旋钮)则强制 `emotion="中性"`。
|
||
|
||
# 记忆注入模版(每轮隐形前缀)
|
||
|
||
```
|
||
|
||
【相关长期记忆(若无则留空)】
|
||
{memory_block}
|
||
(仅在强相关时,精炼融入回复;不得逐条复述)
|
||
|
||
```
|
||
|
||
# 旋钮(工程可配置)
|
||
|
||
* `CONF_MIN`(最低置信):**0.60**(建议 0.6–0.7)
|
||
|
||
* `MICRO_AFFECT`(微情感开关):**true** | false
|
||
|
||
* `MAX_EMOJI`(每句最多表情/拟声):**1**
|
||
|
||
* `MEM_INJECT_TOPK`:**1–2** 条
|
||
|
||
* `REPLY_MAX_CHARS`:**不硬限*
|
||
|
||
# 关键词速表(通过识别文字判断用户情绪
|
||
|
||
* **高兴**,回复情绪带有开心/太棒/耶/喜欢/满意/兴奋! ,动作触发“鼓掌”
|
||
|
||
* **伤心**:回复时带有伤心/心痛/委屈/想哭,动作触发"伤心低头"
|
||
|
||
* **难过**:回复时情绪中带有难过/丧/低落/郁闷/无力/唉/…… ,动作触发"伤心低头"
|
||
|
||
* **生气**(惹得虚拟人生气):回复时情绪带有生气/气死/离谱/无语/滚/别烦,动作触发"不满"
|
||
|
||
* **否定**:不/并不/没那么 + 情绪词 → 降级或改判 动作触发"摇头"
|
||
|
||
肯定:是/没错/肯定的+判断实时情绪,动作触发"点头"
|
||
|
||
# Few-shot ·
|
||
|
||
**例1 高兴**
|
||
|
||
用户:终于搞定啦!!
|
||
|
||
```json
|
||
|
||
{"reply":"干得漂亮,继续冲","emotion":"高兴","intensity":4,"confidence":0.90,"action":"鼓掌"}
|
||
|
||
```
|
||
|
||
**例2 伤心**
|
||
|
||
用户:他当众否定我,我好难受。
|
||
|
||
```json
|
||
|
||
{"reply":"这很刺痛,先稳一下呼吸。","emotion":"伤心","intensity":4,"confidence":0.86,"action":"伤心低头"}
|
||
|
||
```
|
||
|
||
**例3 难过**
|
||
|
||
用户:这几天挺丧的……
|
||
|
||
```json
|
||
|
||
{"reply":"给自己一点恢复时间。","emotion":"难过","intensity":3,"confidence":0.82,"action":"伤心低头"}
|
||
|
||
```
|
||
|
||
**例4 生气**
|
||
|
||
用户:别再说了,真离谱???
|
||
|
||
```json
|
||
|
||
{"reply":"收到,我马上调整。","emotion":"生气","intensity":4,"confidence":0.88,"action":"不满"}
|
||
|
||
```
|
||
|
||
**例5 否定情绪→难过**
|
||
|
||
用户:我并不生气,就是有点失望。
|
||
|
||
```json
|
||
|
||
{"reply":"理解你的落差感。","emotion":"难过","intensity":2,"confidence":0.75,"action":"伤心低头"}
|
||
|
||
```
|
||
|
||
**例6 中性**
|
||
|
||
用户:把道具A切到B,再开始。
|
||
|
||
```json
|
||
|
||
{"reply":"已切换,继续。","emotion":"中性","intensity":1,"confidence":0.95,"action":"无"}
|
||
|
||
```
|
||
""").strip()
|
||
|
||
# 3. 初始化本地 SQLite (存储原始文本)
|
||
db_path = os.getenv('SQLITE_DB_PATH', 'local_memory.db')
|
||
self.db_conn = sqlite3.connect(db_path, check_same_thread=False)
|
||
self.db_conn.execute("""
|
||
CREATE TABLE IF NOT EXISTS memories (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
user_id TEXT NOT NULL,
|
||
role TEXT NOT NULL,
|
||
content TEXT NOT NULL,
|
||
chroma_id TEXT NOT NULL UNIQUE,
|
||
timestamp TEXT NOT NULL
|
||
);
|
||
""")
|
||
self.db_conn.execute("CREATE INDEX IF NOT EXISTS idx_user_id ON memories (user_id);")
|
||
logger.info(f"SQLite database initialized at {db_path}")
|
||
|
||
# 4. 初始化本地 ChromaDB (存储向量)
|
||
chroma_path = os.getenv('CHROMA_DB_PATH', './chroma_db_store')
|
||
self.chroma_client = chromadb.PersistentClient(path=chroma_path)
|
||
logger.info(f"ChromaDB persistent client initialized at {chroma_path}")
|
||
|
||
# 5. 初始化本地 Embedding 模型 (低内存占用)
|
||
model_name = 'BAAI/bge-small-zh-v1.5' # 约 100MB 内存
|
||
try:
|
||
logger.info(f"Loading local embedding model: {model_name}...")
|
||
self.embedding_model = SentenceTransformer(model_name, device='cpu')
|
||
logger.info("Local embedding model loaded successfully.")
|
||
except Exception as e:
|
||
logger.error(f"Failed to load embedding model '{model_name}': {e}")
|
||
raise
|
||
|
||
def _get_user_collection(self, user_id: str):
|
||
"""获取或创建该用户的 ChromaDB 集合"""
|
||
collection_name = f"user_memory_{user_id.replace('-', '_')}" # 确保 user_id
|
||
return self.chroma_client.get_or_create_collection(name=collection_name)
|
||
|
||
def _get_max_id(self, user_id: str):
|
||
"""获取该用户当前最大的 memory ID"""
|
||
try:
|
||
cursor = self.db_conn.execute("SELECT MAX(id) FROM memories WHERE user_id = ?", (user_id,))
|
||
max_id = cursor.fetchone()[0]
|
||
return max_id if max_id is not None else 0
|
||
except Exception as e:
|
||
logger.error(f"Failed to get max id from SQLite for user {user_id}: {e}")
|
||
return 0
|
||
|
||
# --- 实现 Mem0Integration 的公共接口 ---
|
||
|
||
def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[Any]:
|
||
"""
|
||
(本地实现) 搜索相关记忆。
|
||
返回一个字典列表,模仿 mem0 的输出格式。
|
||
"""
|
||
logger.debug(f"Searching local memories for user {user_id} with query: {query}")
|
||
try:
|
||
collection = self._get_user_collection(user_id)
|
||
if collection.count() == 0:
|
||
logger.debug("No memories found for this user.")
|
||
return []
|
||
|
||
# 1. 本地向量化查询
|
||
query_embedding = self.embedding_model.encode(query).tolist()
|
||
|
||
# 2. 本地 ChromaDB (KNN) 搜索
|
||
results = collection.query(
|
||
query_embeddings=[query_embedding],
|
||
n_results=min(limit, collection.count()) # 确保 n_results 不超过集合大小
|
||
)
|
||
|
||
# 3. 格式化为原函数期望的格式
|
||
formatted_results = []
|
||
if results and 'metadatas' in results and 'documents' in results:
|
||
for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
|
||
formatted_results.append({
|
||
"memory": doc, # format_memories_for_prompt 会用这个 key
|
||
"content": meta.get("content", doc), # 备用
|
||
"role": meta.get("role", "unknown")
|
||
})
|
||
|
||
logger.debug(f"Found {len(formatted_results)} local memories.")
|
||
return formatted_results
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to search local memories: {e}")
|
||
return []
|
||
|
||
def add_memory(self, messages: List[Dict[str, str]], user_id: str, metadata: Optional[Dict] = None) -> Dict[str, Any]:
|
||
"""(本地实现) 添加新的记忆。"""
|
||
logger.debug(f"Adding local memory for user {user_id}")
|
||
if metadata is None:
|
||
metadata = {}
|
||
|
||
try:
|
||
collection = self._get_user_collection(user_id)
|
||
|
||
texts_to_embed = []
|
||
docs_to_store = []
|
||
metadatas_to_store = []
|
||
ids_to_store = []
|
||
|
||
current_max_id = self._get_max_id(user_id)
|
||
|
||
for i, msg in enumerate(messages):
|
||
role = msg.get("role")
|
||
content = msg.get("content")
|
||
if not role or not content:
|
||
continue
|
||
|
||
current_max_id += 1
|
||
chroma_id = f"mem_{user_id}_{current_max_id}"
|
||
timestamp = metadata.get("timestamp", datetime.now().isoformat())
|
||
|
||
# 1. 存入 SQLite
|
||
self.db_conn.execute(
|
||
"INSERT INTO memories (user_id, role, content, chroma_id, timestamp) VALUES (?, ?, ?, ?, ?)",
|
||
(user_id, role, content, chroma_id, timestamp)
|
||
)
|
||
|
||
# 2. 准备存入 ChromaDB
|
||
full_text = f"{role}: {content}" # 向量化的文本
|
||
texts_to_embed.append(full_text)
|
||
docs_to_store.append(full_text) # 存在 chroma 的 document 字段
|
||
metadatas_to_store.append({ # 存在 chroma 的 metadata 字段
|
||
"role": role,
|
||
"content": content, # 存一份原始内容
|
||
"timestamp": timestamp
|
||
})
|
||
ids_to_store.append(chroma_id)
|
||
|
||
if not texts_to_embed:
|
||
return {"success": True, "message": "No valid messages to add."}
|
||
|
||
embeddings = self.embedding_model.encode(
|
||
texts_to_embed,
|
||
batch_size=8
|
||
).tolist()
|
||
|
||
# 3. 批量存入 ChromaDB
|
||
collection.add(
|
||
embeddings=embeddings,
|
||
documents=docs_to_store,
|
||
metadatas=metadatas_to_store,
|
||
ids=ids_to_store
|
||
)
|
||
|
||
self.db_conn.commit()
|
||
logger.debug(f"Successfully added {len(texts_to_embed)} local memories.")
|
||
return {"success": True, "message": f"Added {len(texts_to_embed)} memories."}
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error in add_memory (local): {e}")
|
||
self.db_conn.rollback()
|
||
return {"success": False, "error": str(e)}
|
||
|
||
# --- 保留你原有的辅助函数和核心逻辑 ---
|
||
|
||
def _extract_reply_for_memory(self, assistant_response: Any) -> str:
|
||
"""(原样保留) 从结构化响应中提取用于记忆的助手回复文本。"""
|
||
if assistant_response is None:
|
||
return ""
|
||
|
||
if not isinstance(assistant_response, str):
|
||
assistant_response = str(assistant_response)
|
||
|
||
raw_text = assistant_response.strip()
|
||
if not raw_text:
|
||
return ""
|
||
|
||
try:
|
||
# 尝试解析 JSON, 兼容你 `chat_service` 可能遇到的 "```json\n{...}\n```"
|
||
if raw_text.startswith("```json"):
|
||
raw_text = raw_text.split("```json\n", 1)[-1].split("\n```")[0]
|
||
|
||
data = json.loads(raw_text)
|
||
reply = data.get("reply", "")
|
||
reply_str = str(reply).strip()
|
||
# 如果 'reply' 为空,也返回原始文本,以防 LLM 返回了非 JSON 格式的有效回复
|
||
return reply_str if reply_str else raw_text
|
||
except Exception:
|
||
# 如果解析失败,返回原始文本
|
||
return raw_text
|
||
|
||
def format_memories_for_prompt(self, memories: List[Any]) -> str:
|
||
"""
|
||
(轻微修改) 将记忆格式化为项目符号点,以便注入系统提示。
|
||
严格遵守 Prompt 中的 "MEM_INJECT_TOPK: 1–2 条" 规则。
|
||
"""
|
||
if not memories:
|
||
return ""
|
||
|
||
formatted = []
|
||
# 只取最相关的 2 条
|
||
for memory in memories[:2]:
|
||
if isinstance(memory, dict):
|
||
# 我们的 search_memories 返回 'memory' 键
|
||
memory_text = memory.get("memory") or memory.get("content") or ""
|
||
elif isinstance(memory, str):
|
||
memory_text = memory
|
||
else:
|
||
memory_text = str(memory)
|
||
|
||
sanitized = " ".join(str(memory_text).split())
|
||
if sanitized:
|
||
formatted.append(f"- {sanitized}")
|
||
|
||
return "\n".join(formatted)
|
||
|
||
def generate_response_with_memory(self, user_input: str, user_id: str) -> Dict[str, Any]:
|
||
"""
|
||
(逻辑保留,调用本地) 使用记忆生成响应并存储交互。
|
||
"""
|
||
# 步骤 1: 搜索相关记忆 (调用我们的本地搜索)
|
||
# Prompt 建议 1-2 条,我们搜 5 条,format_memories_for_prompt 会帮我们筛选
|
||
memories = self.search_memories(user_input, user_id, limit=5)
|
||
|
||
# 步骤 2: 准备系统提示 (逻辑保留)
|
||
memory_block = self.format_memories_for_prompt(memories)
|
||
system_prompt = self.system_prompt_template.replace(
|
||
"{memory_block}", memory_block if memory_block else ""
|
||
).strip()
|
||
|
||
# 步骤 3: 生成响应 (逻辑保留 - 仍然调用豆包 LLM)
|
||
try:
|
||
response = self.openai_client.chat.completions.create(
|
||
model=self.llm_model,
|
||
messages=[
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_input}
|
||
],
|
||
# reasoning_effort="minimal", # 你原始代码里的,如果豆包不支持就删掉
|
||
# 推荐开启 JSON 模式,让豆包强制输出 JSON
|
||
response_format={"type": "json_object"}
|
||
)
|
||
|
||
assistant_response = response.choices[0].message.content
|
||
reply_for_memory = self._extract_reply_for_memory(assistant_response)
|
||
if not reply_for_memory: # 确保不存空记忆
|
||
reply_for_memory = assistant_response
|
||
|
||
# 步骤 5: 异步存储交互 (逻辑保留 - 调用我们的本地 add_memory)
|
||
messages = [
|
||
{"role": "user", "content": user_input},
|
||
{"role": "assistant", "content": reply_for_memory} # 只存回复文本
|
||
]
|
||
metadata = {
|
||
"timestamp": datetime.now().isoformat(),
|
||
"type": "chat_interaction"
|
||
}
|
||
|
||
# (原样保留) 启动异步线程存储记忆
|
||
def store_memory_async():
|
||
try:
|
||
self.add_memory(messages, user_id, metadata)
|
||
except Exception as e:
|
||
# 原代码是 print, 统一用 logger
|
||
logger.warning(f"[WARNING] Async local memory storage failed: {e}")
|
||
|
||
memory_thread = threading.Thread(target=store_memory_async, daemon=True)
|
||
memory_thread.start()
|
||
|
||
return {
|
||
"success": True,
|
||
"response": assistant_response, # 返回完整的 JSON 字符串给 chat_service
|
||
"user_id": user_id
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"[ERROR] Failed to generate response: {e}")
|
||
return {
|
||
"success": False,
|
||
"error": str(e),
|
||
"user_id": user_id
|
||
}
|
||
|
||
# --- 其他辅助 API 的本地实现 ---
|
||
|
||
def get_all_memories(self, user_id: str) -> List[Any]:
|
||
"""(本地实现) 从 SQLite 获取所有原始文本历史"""
|
||
logger.debug(f"Getting all local memories for user {user_id}")
|
||
try:
|
||
cursor = self.db_conn.execute(
|
||
"SELECT id, role, content, timestamp, chroma_id FROM memories WHERE user_id = ? ORDER BY id ASC",
|
||
(user_id,)
|
||
)
|
||
return [{
|
||
"id": row[4], # 返回 chroma_id 作为唯一标识符
|
||
"sqlite_id": row[0],
|
||
"memory": f"{row[1]}: {row[2]}",
|
||
"role": row[1],
|
||
"content": row[2],
|
||
"timestamp": row[3]
|
||
} for row in cursor.fetchall()]
|
||
except Exception as e:
|
||
logger.error(f"Error in get_all_memories (local): {e}")
|
||
return []
|
||
|
||
def delete_memory(self, memory_id: str) -> bool:
|
||
"""(本地实现) 删除特定记忆 (memory_id 应该是 chroma_id)"""
|
||
logger.debug(f"Deleting local memory: {memory_id}")
|
||
try:
|
||
# 1. 从 SQLite 删除
|
||
cursor = self.db_conn.execute("SELECT user_id FROM memories WHERE chroma_id = ?", (memory_id,))
|
||
row = cursor.fetchone()
|
||
if not row:
|
||
logger.warning(f"Memory ID {memory_id} not found in SQLite.")
|
||
return False
|
||
|
||
user_id = row[0]
|
||
self.db_conn.execute("DELETE FROM memories WHERE chroma_id = ?", (memory_id,))
|
||
|
||
# 2. 从 ChromaDB 删除
|
||
collection = self._get_user_collection(user_id)
|
||
collection.delete(ids=[memory_id])
|
||
|
||
self.db_conn.commit()
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"Failed to delete memory (local): {e}")
|
||
self.db_conn.rollback()
|
||
return False
|
||
|
||
def delete_all_memories(self, user_id: str) -> bool:
|
||
"""(本地实现) 删除用户的所有记忆"""
|
||
logger.debug(f"Deleting all local memories for user {user_id}")
|
||
try:
|
||
# 1. 从 SQLite 删除
|
||
self.db_conn.execute("DELETE FROM memories WHERE user_id = ?", (user_id,))
|
||
self.db_conn.commit()
|
||
|
||
# 2. 从 ChromaDB 删除 (删除整个 collection)
|
||
collection_name = f"user_memory_{user_id.replace('-', '_')}"
|
||
self.chroma_client.delete_collection(name=collection_name)
|
||
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"Failed to delete all memories (local): {e}")
|
||
self.db_conn.rollback()
|
||
return False
|