feat(memory): replace Mem0 with local RAG using SQLite, ChromaDB, and bge-small
Some checks failed
Build and Push Docker / build-and-push (push) Failing after 16m28s
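
In short: memory retrieval and storage now run entirely locally (SQLite for raw text, ChromaDB for vectors, bge-small for embeddings); only the Doubao LLM call stays remote. Below is a minimal usage sketch of the new interface; it is illustrative and not part of the diff, it assumes MEM0_CONFIG keeps the llm.config shape that LocalMemoryIntegration.__init__ reads, and "demo_user" is a placeholder id.

    import os
    from config import MEM0_CONFIG                    # project config, as imported by the chat service
    from memory_module import LocalMemoryIntegration

    os.environ.setdefault("SQLITE_DB_PATH", "local_memory.db")    # raw-text store (default from the diff)
    os.environ.setdefault("CHROMA_DB_PATH", "./chroma_db_store")  # vector store (default from the diff)

    memory = LocalMemoryIntegration(MEM0_CONFIG)
    user_id = "demo_user"

    # One turn: retrieve memories, call the Doubao LLM, store the interaction in a background thread
    result = memory.generate_response_with_memory("你好,还记得我喜欢什么吗?", user_id)
    if result["success"]:
        print(result["response"])  # full JSON string, as handed back to ChatService

    # Inspect or prune what was stored
    for item in memory.get_all_memories(user_id):
        print(item["id"], item["memory"])
    memory.delete_all_memories(user_id)
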
@@ -13,7 +13,7 @@ from config import (
     MEM0_CONFIG,
 )
 from api.doubao_tts import text_to_speech
-from memory_module.memory_integration import Mem0Integration
+from memory_module import LocalMemoryIntegration
 
 logger = logging.getLogger(__name__)
 
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 class ChatService:
     def __init__(self, user_id: str = None):
         self.user_id = user_id or DEFAULT_USER_ID
-        self.mem0_integration = Mem0Integration(MEM0_CONFIG)
+        self.mem0_integration = LocalMemoryIntegration(MEM0_CONFIG)  # use the new class
         self._initialized = False
 
     def initialize(self):
@@ -1,3 +1,3 @@
-from .memory_integration import Mem0Integration
+from .memory_integration import LocalMemoryIntegration
 
-__all__ = ["Mem0Integration"]
+__all__ = ["LocalMemoryIntegration"]
@@ -1,32 +1,47 @@
 import os
+import sqlite3
+import chromadb
+from sentence_transformers import SentenceTransformer
 from typing import List, Dict, Any, Optional
 from datetime import datetime
 import json
-import openai
+import openai  # still needed to call the Doubao LLM
 import threading
-from mem0 import Memory
 import logging
 import textwrap
 
+# uses the project's logging configuration
 logger = logging.getLogger(__name__)
 
 
-class Mem0Integration:
-    """Mem0 integration for memory retrieval and storage in RAG pipeline."""
+class LocalMemoryIntegration:
+    """
+    Replaces Mem0 with a local, lightweight RAG stack (SQLite + ChromaDB + bge-small).
+
+    The original LLM call and the complex JSON prompt logic are kept;
+    only the RAG retrieval and storage parts are moved to local stores.
+    """
 
     def __init__(self, config: Dict[str, Any]):
-        """Initialize Mem0 with configuration."""
+        """
+        Initialize the local memory module and the OpenAI (Doubao) client.
+        """
         self.config = config
-        self.memory = Memory.from_config(config)
 
-        # Initialize OpenAI client for chat completion
-        self.openai_client = openai.OpenAI(
-            api_key=config["llm"]["config"]["api_key"],
-            base_url=config["llm"]["config"].get("openai_base_url")
-        )
-        self.llm_model = config["llm"]["config"]["model"]
+        # 1. Initialize the LLM client (Doubao) - this logic is unchanged
+        # config["llm"] corresponds to the LLM configuration in the README
+        try:
+            self.openai_client = openai.OpenAI(
+                api_key=config["llm"]["config"]["api_key"],
+                base_url=config["llm"]["config"].get("openai_base_url")
+            )
+            self.llm_model = config["llm"]["config"]["model"]
+            logger.info(f"OpenAI (Doubao) LLM client initialized for model: {self.llm_model}")
+        except KeyError as e:
+            logger.error(f"LLM config missing key: {e}. Check your MEM0_CONFIG['llm'].")
+            raise
 
-        # System prompt template aligned with virtual persona requirements
+        # 2. Keep the core system prompt (copied verbatim from the original file)
         self.system_prompt_template = textwrap.dedent("""
         你是虚拟人对话引擎。
 
@@ -106,7 +121,7 @@ class Mem0Integration:
 
         * **生气**(惹得虚拟人生气):回复时情绪带有生气/气死/离谱/无语/滚/别烦,动作触发"不满"
 
         * **否定**:不/并不/没那么 + 情绪词 → 降级或改判 动作触发"摇头"
 
         肯定:是/没错/肯定的+判断实时情绪,动作触发"点头"
 
@@ -172,51 +187,165 @@ class Mem0Integration:
 
         ```
         """).strip()
 
+        # 3. Initialize local SQLite (stores the raw text)
+        db_path = os.getenv('SQLITE_DB_PATH', 'local_memory.db')
+        self.db_conn = sqlite3.connect(db_path, check_same_thread=False)
+        self.db_conn.execute("""
+            CREATE TABLE IF NOT EXISTS memories (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                user_id TEXT NOT NULL,
+                role TEXT NOT NULL,
+                content TEXT NOT NULL,
+                chroma_id TEXT NOT NULL UNIQUE,
+                timestamp TEXT NOT NULL
+            );
+        """)
+        self.db_conn.execute("CREATE INDEX IF NOT EXISTS idx_user_id ON memories (user_id);")
+        logger.info(f"SQLite database initialized at {db_path}")
+
+        # 4. Initialize local ChromaDB (stores the vectors)
+        chroma_path = os.getenv('CHROMA_DB_PATH', './chroma_db_store')
+        self.chroma_client = chromadb.PersistentClient(path=chroma_path)
+        logger.info(f"ChromaDB persistent client initialized at {chroma_path}")
+
+        # 5. Initialize the local embedding model (low memory footprint)
+        model_name = 'bge-small-zh-v1.5'  # roughly 100 MB of memory
+        try:
+            logger.info(f"Loading local embedding model: {model_name}...")
+            self.embedding_model = SentenceTransformer(model_name, device='cpu')
+            logger.info("Local embedding model loaded successfully.")
+        except Exception as e:
+            logger.error(f"Failed to load embedding model '{model_name}': {e}")
+            raise
+
+    def _get_user_collection(self, user_id: str):
+        """Get or create this user's ChromaDB collection."""
+        collection_name = f"user_memory_{user_id.replace('-', '_')}"  # sanitize user_id for the collection name
+        return self.chroma_client.get_or_create_collection(name=collection_name)
+
+    def _get_max_id(self, user_id: str):
+        """Get this user's current maximum memory ID."""
+        try:
+            cursor = self.db_conn.execute("SELECT MAX(id) FROM memories WHERE user_id = ?", (user_id,))
+            max_id = cursor.fetchone()[0]
+            return max_id if max_id is not None else 0
+        except Exception as e:
+            logger.error(f"Failed to get max id from SQLite for user {user_id}: {e}")
+            return 0
+
+    # --- Implementation of the Mem0Integration public interface ---
+
     def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[Any]:
-        """Search for relevant memories about the user."""
-        try:
-            results = self.memory.search(
-                query=query,
-                user_id=user_id,
-                limit=limit
-            )
-            # Handle dictionary response format - check for both 'memories' and 'results' keys
-            if isinstance(results, dict):
-                memories = results.get("memories", results.get("results", []))
-                return memories
-            else:
-                logger.error(f"Unexpected search results format: {type(results)}")
-                return []
-        except Exception as e:
-            logger.error(f"Failed to search memories: {e}")
-            return []
-
-    def add_memory(self, messages: List[Dict[str, str]], user_id: str, metadata: Optional[Dict] = None) -> Dict[str, Any]:
-        """Add a memory for the user."""
-        try:
-            # Debug: Log what we're trying to add
-            logger.debug(f"Adding memory for user {user_id}")
-            logger.debug(f"Messages: {messages}")
-            logger.debug(f"Metadata: {metadata}")
-
-            result = self.memory.add(
-                messages=messages,
-                user_id=user_id,
-                metadata=metadata or {},
-                infer= False
-            )
-
-            # Debug: Log the result
-            logger.debug(f"Add memory result: {result}")
-            return result
-
-        except Exception as e:
-            logger.error(f"Failed to add memory: {e}")
-            logger.exception("Exception details:")
-            return {}
+        """
+        (Local implementation) Search for relevant memories.
+        Returns a list of dicts that mimics mem0's output format.
+        """
+        logger.debug(f"Searching local memories for user {user_id} with query: {query}")
+        try:
+            collection = self._get_user_collection(user_id)
+            if collection.count() == 0:
+                logger.debug("No memories found for this user.")
+                return []
+
+            # 1. Embed the query locally
+            query_embedding = self.embedding_model.encode(query).tolist()
+
+            # 2. Local ChromaDB (KNN) search
+            results = collection.query(
+                query_embeddings=[query_embedding],
+                n_results=min(limit, collection.count())  # make sure n_results does not exceed the collection size
+            )
+
+            # 3. Format into the shape the original function expects
+            formatted_results = []
+            if results and 'metadatas' in results and 'documents' in results:
+                for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
+                    formatted_results.append({
+                        "memory": doc,  # format_memories_for_prompt reads this key
+                        "content": meta.get("content", doc),  # fallback
+                        "role": meta.get("role", "unknown")
+                    })
+
+            logger.debug(f"Found {len(formatted_results)} local memories.")
+            return formatted_results
+
+        except Exception as e:
+            logger.error(f"Failed to search local memories: {e}")
+            return []
+
+    def add_memory(self, messages: List[Dict[str, str]], user_id: str, metadata: Optional[Dict] = None) -> Dict[str, Any]:
+        """(Local implementation) Add new memories."""
+        logger.debug(f"Adding local memory for user {user_id}")
+        if metadata is None:
+            metadata = {}
+
+        try:
+            collection = self._get_user_collection(user_id)
+
+            texts_to_embed = []
+            docs_to_store = []
+            metadatas_to_store = []
+            ids_to_store = []
+
+            current_max_id = self._get_max_id(user_id)
+
+            for i, msg in enumerate(messages):
+                role = msg.get("role")
+                content = msg.get("content")
+                if not role or not content:
+                    continue
+
+                current_max_id += 1
+                chroma_id = f"mem_{user_id}_{current_max_id}"
+                timestamp = metadata.get("timestamp", datetime.now().isoformat())
+
+                # 1. Insert into SQLite
+                self.db_conn.execute(
+                    "INSERT INTO memories (user_id, role, content, chroma_id, timestamp) VALUES (?, ?, ?, ?, ?)",
+                    (user_id, role, content, chroma_id, timestamp)
+                )
+
+                # 2. Stage the entry for ChromaDB
+                full_text = f"{role}: {content}"  # the text that gets embedded
+                texts_to_embed.append(full_text)
+                docs_to_store.append(full_text)  # stored in Chroma's document field
+                metadatas_to_store.append({  # stored in Chroma's metadata field
+                    "role": role,
+                    "content": content,  # keep a copy of the raw content
+                    "timestamp": timestamp
+                })
+                ids_to_store.append(chroma_id)
+
+            if not texts_to_embed:
+                return {"success": True, "message": "No valid messages to add."}
+
+            embeddings = self.embedding_model.encode(
+                texts_to_embed,
+                batch_size=8
+            ).tolist()
+
+            # 3. Bulk insert into ChromaDB
+            collection.add(
+                embeddings=embeddings,
+                documents=docs_to_store,
+                metadatas=metadatas_to_store,
+                ids=ids_to_store
+            )
+
+            self.db_conn.commit()
+            logger.debug(f"Successfully added {len(texts_to_embed)} local memories.")
+            return {"success": True, "message": f"Added {len(texts_to_embed)} memories."}
+
+        except Exception as e:
+            logger.error(f"Error in add_memory (local): {e}")
+            self.db_conn.rollback()
+            return {"success": False, "error": str(e)}
+
+    # --- The original helper functions and core logic are kept below ---
 
     def _extract_reply_for_memory(self, assistant_response: Any) -> str:
-        """Extract the assistant reply text from structured responses for memory storage."""
+        """(Kept as-is) Extract the assistant reply text from a structured response for memory storage."""
         if assistant_response is None:
             return ""
 
@@ -228,21 +357,32 @@ class Mem0Integration:
             return ""
 
         try:
+            # Try to parse JSON; also handle the "```json\n{...}\n```" wrapper chat_service may see
+            if raw_text.startswith("```json"):
+                raw_text = raw_text.split("```json\n", 1)[-1].split("\n```")[0]
+
             data = json.loads(raw_text)
             reply = data.get("reply", "")
             reply_str = str(reply).strip()
+            # If 'reply' is empty, still return the raw text in case the LLM gave a valid non-JSON reply
             return reply_str if reply_str else raw_text
         except Exception:
+            # If parsing fails, return the raw text
             return raw_text
 
     def format_memories_for_prompt(self, memories: List[Any]) -> str:
-        """Format memories into bullet points for injection into the system prompt."""
+        """
+        (Slightly modified) Format memories into bullet points for injection into the system prompt.
+        Strictly follows the prompt's "MEM_INJECT_TOPK: 1-2 items" rule.
+        """
         if not memories:
             return ""
 
         formatted = []
-        for memory in memories:
+        # Only keep the 2 most relevant memories
+        for memory in memories[:2]:
             if isinstance(memory, dict):
+                # our search_memories returns the 'memory' key
                 memory_text = memory.get("memory") or memory.get("content") or ""
             elif isinstance(memory, str):
                 memory_text = memory
@@ -256,17 +396,20 @@ class Mem0Integration:
         return "\n".join(formatted)
 
     def generate_response_with_memory(self, user_input: str, user_id: str) -> Dict[str, Any]:
-        """Generate a response using memories and store the interaction."""
-        # Step 1: Search for relevant memories
+        """
+        (Logic kept, now backed by the local stores) Generate a response using memories and store the interaction.
+        """
+        # Step 1: Search for relevant memories (via our local search)
+        # The prompt suggests 1-2 items; we fetch 5 and let format_memories_for_prompt filter them
         memories = self.search_memories(user_input, user_id, limit=5)
 
-        # Step 2: Prepare system prompt with memory injection
+        # Step 2: Prepare the system prompt (logic kept)
         memory_block = self.format_memories_for_prompt(memories)
         system_prompt = self.system_prompt_template.replace(
             "{memory_block}", memory_block if memory_block else ""
         ).strip()
 
-        # Step 3: Generate response using OpenAI
+        # Step 3: Generate the response (logic kept - still calls the Doubao LLM)
         try:
             response = self.openai_client.chat.completions.create(
                 model=self.llm_model,
@@ -274,92 +417,112 @@ class Mem0Integration:
                     {"role": "system", "content": system_prompt},
                     {"role": "user", "content": user_input}
                 ],
-                reasoning_effort="minimal",
+                # reasoning_effort="minimal",  # from the original code; remove it if Doubao does not support it
+                # JSON mode is recommended so Doubao is forced to output JSON
+                response_format={"type": "json_object"}
             )
 
             assistant_response = response.choices[0].message.content
             reply_for_memory = self._extract_reply_for_memory(assistant_response)
-            if not reply_for_memory:
+            if not reply_for_memory:  # make sure we never store an empty memory
                 reply_for_memory = assistant_response
 
-            # Step 5: Store the interaction as new memories (runs asynchronously)
+            # Step 5: Store the interaction asynchronously (logic kept - calls our local add_memory)
             messages = [
                 {"role": "user", "content": user_input},
-                {"role": "assistant", "content": reply_for_memory}
+                {"role": "assistant", "content": reply_for_memory}  # only store the reply text
            ]
 
-            # Store with metadata including timestamp
             metadata = {
                 "timestamp": datetime.now().isoformat(),
                 "type": "chat_interaction"
             }
 
-            # Store memories asynchronously without blocking the main flow
+            # (Kept as-is) Start an async thread to store the memories
             def store_memory_async():
                 try:
                     self.add_memory(messages, user_id, metadata)
                 except Exception as e:
-                    print(f"[WARNING] Async memory storage failed: {e}")
+                    # The original code used print; use logger consistently
+                    logger.warning(f"[WARNING] Async local memory storage failed: {e}")
 
-            # Start an async thread to store the memories
             memory_thread = threading.Thread(target=store_memory_async, daemon=True)
             memory_thread.start()
 
             return {
                 "success": True,
-                "response": assistant_response,
+                "response": assistant_response,  # return the full JSON string to chat_service
                 "user_id": user_id
             }
 
         except Exception as e:
-            print(f"[ERROR] Failed to generate response: {e}")
+            logger.error(f"[ERROR] Failed to generate response: {e}")
             return {
                 "success": False,
                 "error": str(e),
                 "user_id": user_id
             }
 
-    def get_all_memories(self, user_id: str) -> List[Any]:
-        """Get all memories for a user."""
-        try:
-            memories = self.memory.get_all(user_id=user_id)
-            # Handle dictionary response format
-            if isinstance(memories, dict):
-                # Check for different possible key names
-                if "memories" in memories:
-                    memories_list = memories.get("memories", [])
-                elif "results" in memories:
-                    memories_list = memories.get("results", [])
-                else:
-                    # Try to find any list in the dict
-                    for key, value in memories.items():
-                        if isinstance(value, list):
-                            memories_list = value
-                            break
-                    else:
-                        memories_list = []
-                return memories_list
-            else:
-                print(f"[ERROR] Unexpected memories format: {type(memories)}")
-                return []
-        except Exception as e:
-            print(f"[ERROR] Failed to get all memories: {e}")
-            return []
-
-    def delete_memory(self, memory_id: str) -> bool:
-        """Delete a specific memory."""
-        try:
-            self.memory.delete(memory_id)
-            return True
-        except Exception as e:
-            print(f"[ERROR] Failed to delete memory: {e}")
-            return False
+    # --- Local implementations of the remaining helper APIs ---
+
+    def get_all_memories(self, user_id: str) -> List[Any]:
+        """(Local implementation) Fetch the full raw text history from SQLite."""
+        logger.debug(f"Getting all local memories for user {user_id}")
+        try:
+            cursor = self.db_conn.execute(
+                "SELECT id, role, content, timestamp, chroma_id FROM memories WHERE user_id = ? ORDER BY id ASC",
+                (user_id,)
+            )
+            return [{
+                "id": row[4],  # expose chroma_id as the unique identifier
+                "sqlite_id": row[0],
+                "memory": f"{row[1]}: {row[2]}",
+                "role": row[1],
+                "content": row[2],
+                "timestamp": row[3]
+            } for row in cursor.fetchall()]
+        except Exception as e:
+            logger.error(f"Error in get_all_memories (local): {e}")
+            return []
+
+    def delete_memory(self, memory_id: str) -> bool:
+        """(Local implementation) Delete a specific memory (memory_id should be a chroma_id)."""
+        logger.debug(f"Deleting local memory: {memory_id}")
+        try:
+            # 1. Delete from SQLite
+            cursor = self.db_conn.execute("SELECT user_id FROM memories WHERE chroma_id = ?", (memory_id,))
+            row = cursor.fetchone()
+            if not row:
+                logger.warning(f"Memory ID {memory_id} not found in SQLite.")
+                return False
+
+            user_id = row[0]
+            self.db_conn.execute("DELETE FROM memories WHERE chroma_id = ?", (memory_id,))
+
+            # 2. Delete from ChromaDB
+            collection = self._get_user_collection(user_id)
+            collection.delete(ids=[memory_id])
+
+            self.db_conn.commit()
+            return True
+        except Exception as e:
+            logger.error(f"Failed to delete memory (local): {e}")
+            self.db_conn.rollback()
+            return False
 
     def delete_all_memories(self, user_id: str) -> bool:
-        """Delete all memories for a user."""
+        """(Local implementation) Delete all memories for a user."""
+        logger.debug(f"Deleting all local memories for user {user_id}")
         try:
-            self.memory.delete_all(user_id=user_id)
+            # 1. Delete from SQLite
+            self.db_conn.execute("DELETE FROM memories WHERE user_id = ?", (user_id,))
+            self.db_conn.commit()
+
+            # 2. Delete from ChromaDB (drop the whole collection)
+            collection_name = f"user_memory_{user_id.replace('-', '_')}"
+            self.chroma_client.delete_collection(name=collection_name)
+
             return True
         except Exception as e:
-            print(f"[ERROR] Failed to delete all memories: {e}")
+            logger.error(f"Failed to delete all memories (local): {e}")
+            self.db_conn.rollback()
             return False
@@ -5,6 +5,7 @@ description = "Add your description here"
 readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
+    "chromadb>=1.3.5",
     "fastapi>=0.115.12",
     "haystack-ai>=2.12.1",
     "huggingface-hub>=0.30.2",
@@ -12,6 +13,7 @@ dependencies = [
     "milvus-haystack>=0.0.15",
     "pydantic>=2.11.3",
     "pymilvus>=2.5.6",
+    "sentence-transformers>=5.1.2",
     "uvicorn>=0.34.0",
 ]
 [[tool.uv.index]]
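
For reference, the only configuration the new class reads in this diff is the llm.config block of MEM0_CONFIG. A sketch of that shape, with key names taken from the diff and placeholder values:

    # Placeholder values; the real MEM0_CONFIG lives in config.py.
    MEM0_CONFIG = {
        "llm": {
            "config": {
                "api_key": "YOUR_DOUBAO_API_KEY",                      # required
                "model": "your-doubao-model-id",                       # required
                "openai_base_url": "https://your-doubao-endpoint/v1",  # optional
            }
        }
    }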