feat(api): integrate Mem0 context into chat stream
Some checks failed
Build and Push Docker / build-and-push (push) Has been cancelled

This commit is contained in:
gameloader
2026-01-24 17:02:54 +08:00
parent d98879a2db
commit cfd06717e9

View File

@@ -23,8 +23,8 @@ from logging_config import setup_logging, get_logger
# 设置日志
setup_logging(
level=os.getenv('LOG_LEVEL', 'INFO'),
enable_file_logging=os.getenv('ENABLE_FILE_LOGGING', 'false').lower() == 'true'
level=os.getenv("LOG_LEVEL", "INFO"),
enable_file_logging=os.getenv("ENABLE_FILE_LOGGING", "false").lower() == "true",
)
# 获取logger
@@ -85,7 +85,7 @@ app = FastAPI(
title="Mem0 Memory API",
description="基于 Mem0 的记忆增强聊天服务 API",
version="1.0.0",
lifespan=lifespan
lifespan=lifespan,
)
@@ -136,32 +136,64 @@ async def chat_stream_endpoint(request: Request):
raise HTTPException(status_code=400, detail="Invalid JSON body")
if not isinstance(payload, dict):
raise HTTPException(status_code=400, detail="Request body must be a JSON object")
raise HTTPException(
status_code=400, detail="Request body must be a JSON object"
)
messages = payload.get("messages")
if messages is None:
message = payload.get("message")
if not message:
raise HTTPException(status_code=400, detail="Missing 'message' or 'messages'")
raise HTTPException(
status_code=400, detail="Missing 'message' or 'messages'"
)
messages = [{"role": "user", "content": message}]
payload.pop("message", None)
payload["messages"] = messages
user_id = payload.get("user_id") or DEFAULT_USER_ID
user_input = ""
for msg in reversed(messages):
if isinstance(msg, dict) and msg.get("role") == "user" and msg.get("content"):
user_input = msg["content"]
break
try:
mem0 = chat_service.mem0_integration
memories = (
mem0.search_memories(user_input, user_id, limit=5) if user_input else []
)
memory_block = mem0.format_memories_for_prompt(memories)
system_prompt = mem0.system_prompt_template.replace(
"{memory_block}", memory_block if memory_block else ""
).strip()
except Exception as exc:
logger.warning("Failed to build system prompt for stream: %s", exc)
system_prompt = chat_service.mem0_integration.system_prompt_template.replace(
"{memory_block}", ""
).strip()
payload["messages"] = [{"role": "system", "content": system_prompt}] + messages
payload.pop("user_id", None)
payload.pop("include_audio", None)
payload["model"] = OPENAI_LLM_MODEL
payload["thinking"] = {"type": "disabled"}
payload["reasoning"] = {"effort": "minimal"}
payload["max_tokens"] = 200
payload["max_tokens"] = 400
payload.setdefault("response_format", {"type": "json_object"})
payload["stream"] = True
base_url = OPENAI_API_BASE_URL_CONFIG or ""
if not base_url:
raise HTTPException(status_code=500, detail="Upstream base URL is not configured")
raise HTTPException(
status_code=500, detail="Upstream base URL is not configured"
)
api_key = OPENAI_API_KEY_FROM_CONFIG or ""
if not api_key or api_key.startswith("YOUR_API_KEY_PLACEHOLDER"):
raise HTTPException(status_code=500, detail="Upstream API key is not configured")
raise HTTPException(
status_code=500, detail="Upstream API key is not configured"
)
upstream_url = f"{base_url.rstrip('/')}/chat/completions"
headers = {
@@ -178,7 +210,9 @@ async def chat_stream_endpoint(request: Request):
detail_bytes = await resp.aread()
await resp.aclose()
await client.aclose()
detail = detail_bytes.decode("utf-8", "ignore") if detail_bytes else "Upstream error"
detail = (
detail_bytes.decode("utf-8", "ignore") if detail_bytes else "Upstream error"
)
raise HTTPException(status_code=resp.status_code, detail=detail)
async def event_stream():
@@ -218,25 +252,17 @@ async def tts_endpoint(request: TTSRequest):
success, message, base64_audio = text_to_speech(text, user_id)
if success and base64_audio:
return TTSResponse(
success=True,
audio_data=base64_audio,
message=message,
user_id=user_id
success=True, audio_data=base64_audio, message=message, user_id=user_id
)
return TTSResponse(
success=False,
message=message,
error=message or "TTS failed",
user_id=user_id
user_id=user_id,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}")
if __name__ == "__main__":
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True
)
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)