223 lines
8.7 KiB
Python
223 lines
8.7 KiB
Python
import os
|
|
from typing import List, Dict, Any, Optional
|
|
from datetime import datetime
|
|
import openai
|
|
from mem0 import Memory
|
|
|
|
|
|
class Mem0Integration:
    """Mem0 integration for memory retrieval and storage in RAG pipeline.

    Bundles a ``mem0.Memory`` store with an OpenAI chat client: memories
    relevant to a query are looked up, injected into the prompt, and the
    resulting user/assistant exchange is written back to the store.
    """

    # Keys under which a dict-shaped Mem0 response may carry its list of
    # memories, in lookup order.
    _MEMORY_LIST_KEYS = ("memories", "results")

    def __init__(self, config: Dict[str, Any]):
        """Initialize Mem0 with configuration.

        Args:
            config: Mem0 configuration dict. ``config["llm"]["config"]`` must
                provide ``model``; ``api_key`` and ``openai_base_url`` are
                optional.

        Raises:
            KeyError: If ``config["llm"]["config"]["model"]`` is missing.
        """
        self.config = config
        self.memory = Memory.from_config(config)

        llm_config = config["llm"]["config"]

        # Initialize OpenAI client for chat completion. Both settings are
        # optional: when ``api_key`` is None the OpenAI client falls back to
        # the OPENAI_API_KEY environment variable.
        self.openai_client = openai.OpenAI(
            api_key=llm_config.get("api_key"),
            base_url=llm_config.get("openai_base_url"),
        )
        self.llm_model = llm_config["model"]

        # Memory prompt template; ``{memories}`` and ``{query}`` are filled in
        # by generate_response_with_memory().
        self.memory_template = (
            "Based on the following memories about the user:\n"
            "{memories}\n"
            "\n"
            "Please respond to the user's query: {query}\n"
            "\n"
            "In your response, consider the memories above to provide a "
            "personalized answer."
        )

    @classmethod
    def _extract_memory_list(cls, payload: Any) -> Optional[List[Any]]:
        """Pull the list of memories out of a Mem0 response.

        Returns the list itself for list responses. For dict responses,
        returns the list stored under "memories" or "results", otherwise the
        first list-valued entry, otherwise an empty list. Returns None when
        the response is neither a list nor a dict, so callers can warn.
        """
        if isinstance(payload, list):
            return payload
        if not isinstance(payload, dict):
            return None

        for key in cls._MEMORY_LIST_KEYS:
            if key in payload:
                value = payload[key]
                # Guard against a null/odd value under a known key.
                return value if isinstance(value, list) else []

        # Unknown key names: take the first list found in the dict.
        for value in payload.values():
            if isinstance(value, list):
                return value
        return []

    @staticmethod
    def _format_timestamp(created_at: Any) -> Any:
        """Render an ISO-8601 timestamp as "YYYY-MM-DD HH:MM".

        Returns ``created_at`` unchanged when it cannot be parsed.
        """
        try:
            # A trailing "Z" is rewritten because datetime.fromisoformat()
            # only accepts it natively on newer Python versions.
            parsed = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
        except (AttributeError, TypeError, ValueError):
            return created_at
        return parsed.strftime("%Y-%m-%d %H:%M")

    def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[Any]:
        """Search for relevant memories about the user.

        Args:
            query: Text to search the memory store with.
            user_id: Identifier of the user whose memories are searched.
            limit: Maximum number of memories requested from the store.

        Returns:
            The list of matching memories, or an empty list when nothing is
            found, the response has an unexpected shape, or the search fails
            (failures are printed, not raised).
        """
        try:
            results = self.memory.search(
                query=query,
                user_id=user_id,
                limit=limit,
            )
            # Debug: Print the actual format of results
            print(f"[DEBUG] Search results type: {type(results)}")
            print(f"[DEBUG] Search results content: {results}")

            # Same extraction rules as get_all_memories(), so a dict keyed by
            # "results" is understood here as well.
            memories = self._extract_memory_list(results)
            if memories is None:
                print(f"[WARNING] Unexpected search results format: {type(results)}")
                return []

            if isinstance(results, dict):
                if memories:
                    print(f"[DEBUG] First memory type: {type(memories[0])}")
                    print(f"[DEBUG] First memory: {memories[0]}")
                else:
                    print("[DEBUG] No memories found in search results")
            return memories
        except Exception as e:
            # Best-effort: a failing memory lookup must not break the caller.
            print(f"[ERROR] Failed to search memories: {e}")
            return []

    def add_memory(self, messages: List[Dict[str, str]], user_id: str, metadata: Optional[Dict] = None) -> Dict[str, Any]:
        """Add a memory for the user.

        Args:
            messages: Chat messages (dicts with "role" and "content") to store.
            user_id: Identifier of the user the memory belongs to.
            metadata: Optional metadata stored with the memory; an empty dict
                is sent when omitted.

        Returns:
            Whatever the memory store returns, or an empty dict when storing
            fails (the error and traceback are printed, not raised).
        """
        try:
            # Debug: Print what we're trying to add
            print(f"[DEBUG] Adding memory for user {user_id}")
            print(f"[DEBUG] Messages: {messages}")
            print(f"[DEBUG] Metadata: {metadata}")

            result = self.memory.add(
                messages=messages,
                user_id=user_id,
                metadata=metadata or {},
            )

            # Debug: Print the result
            print(f"[DEBUG] Add memory result: {result}")
            return result
        except Exception as e:
            print(f"[ERROR] Failed to add memory: {e}")
            import traceback
            traceback.print_exc()
            return {}

    def format_memories_for_prompt(self, memories: List[Any]) -> str:
        """Format memories into a string for the prompt.

        Args:
            memories: Memories as plain strings or as dicts with a "memory"
                text and an optional "created_at" ISO timestamp. Entries of
                any other type are skipped.

        Returns:
            One numbered line per memory, or a fixed placeholder sentence when
            ``memories`` is empty.
        """
        if not memories:
            return "No previous memories about this user."

        formatted = []
        for i, memory in enumerate(memories, 1):
            if isinstance(memory, str):
                formatted.append(f"{i}. {memory}")
                continue
            if not isinstance(memory, dict):
                # Unknown entry type: nothing sensible to render.
                continue

            memory_text = memory.get("memory", "")
            created_at = memory.get("created_at", "")
            if created_at:
                created_str = self._format_timestamp(created_at)
                formatted.append(f"{i}. {memory_text} (remembered on: {created_str})")
            else:
                formatted.append(f"{i}. {memory_text}")

        return "\n".join(formatted)

    def generate_response_with_memory(self, user_input: str, user_id: str) -> Dict[str, Any]:
        """Generate a response using memories and store the interaction.

        Args:
            user_input: The user's query.
            user_id: Identifier of the user, used for lookup and storage.

        Returns:
            ``{"success": True, "response": <text>, "user_id": user_id}`` on
            success, or ``{"success": False, "error": <message>,
            "user_id": user_id}`` when the chat completion fails.
        """
        # Step 1: Search for relevant memories
        memories = self.search_memories(user_input, user_id)

        # Step 2: Format memories for the prompt; the formatter already
        # supplies the "no memories" placeholder for an empty result.
        formatted_memories = self.format_memories_for_prompt(memories)

        # Step 3: Create the enhanced prompt
        enhanced_prompt = self.memory_template.format(
            memories=formatted_memories,
            query=user_input,
        )

        # Step 4: Generate response using OpenAI
        try:
            response = self.openai_client.chat.completions.create(
                model=self.llm_model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant with access to user memories. Use the provided memories to personalize your responses."},
                    {"role": "user", "content": enhanced_prompt},
                ],
            )

            assistant_response = response.choices[0].message.content

            # Step 5: Store the interaction as new memories. The raw user
            # input is stored, not the memory-augmented prompt.
            messages = [
                {"role": "user", "content": user_input},
                {"role": "assistant", "content": assistant_response},
            ]

            # Store with metadata including timestamp
            metadata = {
                "timestamp": datetime.now().isoformat(),
                "type": "chat_interaction",
            }

            # add_memory() reports its own failures, so a storage problem
            # does not turn a successful completion into an error.
            self.add_memory(messages, user_id, metadata)

            return {
                "success": True,
                "response": assistant_response,
                "user_id": user_id,
            }

        except Exception as e:
            print(f"[ERROR] Failed to generate response: {e}")
            return {
                "success": False,
                "error": str(e),
                "user_id": user_id,
            }

    def get_all_memories(self, user_id: str) -> List[Any]:
        """Get all memories for a user.

        Args:
            user_id: Identifier of the user whose memories are fetched.

        Returns:
            The list of memories, or an empty list when the response has an
            unexpected shape or the lookup fails (the error and traceback are
            printed, not raised).
        """
        try:
            memories = self.memory.get_all(user_id=user_id)
            # Debug: Print the actual format of memories
            print(f"[DEBUG] All memories type: {type(memories)}")
            print(f"[DEBUG] All memories content: {memories}")

            memories_list = self._extract_memory_list(memories)
            if memories_list is None:
                print(f"[WARNING] Unexpected memories format: {type(memories)}")
                return []

            if isinstance(memories, dict) and memories_list:
                print(f"[DEBUG] First memory type: {type(memories_list[0])}")
                print(f"[DEBUG] First memory: {memories_list[0]}")
            return memories_list
        except Exception as e:
            print(f"[ERROR] Failed to get all memories: {e}")
            import traceback
            traceback.print_exc()
            return []

    def delete_memory(self, memory_id: str) -> bool:
        """Delete a specific memory.

        Returns:
            True when the store accepted the deletion, False when it raised
            (the error is printed).
        """
        try:
            self.memory.delete(memory_id)
        except Exception as e:
            print(f"[ERROR] Failed to delete memory: {e}")
            return False
        return True

    def delete_all_memories(self, user_id: str) -> bool:
        """Delete all memories for a user.

        Returns:
            True when the store accepted the deletion, False when it raised
            (the error is printed).
        """
        try:
            self.memory.delete_all(user_id=user_id)
        except Exception as e:
            print(f"[ERROR] Failed to delete all memories: {e}")
            return False
        return True
|