From cfd06717e93d069e7088af70e2128e5cc8993314 Mon Sep 17 00:00:00 2001 From: gameloader Date: Sat, 24 Jan 2026 17:02:54 +0800 Subject: [PATCH] feat(api): integrate Mem0 context into chat stream --- api/main.py | 74 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 24 deletions(-) diff --git a/api/main.py b/api/main.py index 5551589..20be643 100644 --- a/api/main.py +++ b/api/main.py @@ -23,8 +23,8 @@ from logging_config import setup_logging, get_logger # 设置日志 setup_logging( - level=os.getenv('LOG_LEVEL', 'INFO'), - enable_file_logging=os.getenv('ENABLE_FILE_LOGGING', 'false').lower() == 'true' + level=os.getenv("LOG_LEVEL", "INFO"), + enable_file_logging=os.getenv("ENABLE_FILE_LOGGING", "false").lower() == "true", ) # 获取logger @@ -85,7 +85,7 @@ app = FastAPI( title="Mem0 Memory API", description="基于 Mem0 的记忆增强聊天服务 API", version="1.0.0", - lifespan=lifespan + lifespan=lifespan, ) @@ -105,7 +105,7 @@ async def health_check(): async def chat_endpoint(request: ChatRequest): """ 聊天接口 - + 接收用户消息,通过 RAG 管道处理并返回回复(可包含 base64 音频数据) """ try: @@ -116,9 +116,9 @@ async def chat_endpoint(request: ChatRequest): result = user_chat_service.chat(request.message, request.include_audio) else: result = chat_service.chat(request.message, request.include_audio) - + return ChatResponse(**result) - + except Exception as e: raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") @@ -136,32 +136,64 @@ async def chat_stream_endpoint(request: Request): raise HTTPException(status_code=400, detail="Invalid JSON body") if not isinstance(payload, dict): - raise HTTPException(status_code=400, detail="Request body must be a JSON object") + raise HTTPException( + status_code=400, detail="Request body must be a JSON object" + ) messages = payload.get("messages") if messages is None: message = payload.get("message") if not message: - raise HTTPException(status_code=400, detail="Missing 'message' or 'messages'") + raise HTTPException( + status_code=400, detail="Missing 'message' or 'messages'" + ) messages = [{"role": "user", "content": message}] payload.pop("message", None) - payload["messages"] = messages + user_id = payload.get("user_id") or DEFAULT_USER_ID + + user_input = "" + for msg in reversed(messages): + if isinstance(msg, dict) and msg.get("role") == "user" and msg.get("content"): + user_input = msg["content"] + break + + try: + mem0 = chat_service.mem0_integration + memories = ( + mem0.search_memories(user_input, user_id, limit=5) if user_input else [] + ) + memory_block = mem0.format_memories_for_prompt(memories) + system_prompt = mem0.system_prompt_template.replace( + "{memory_block}", memory_block if memory_block else "" + ).strip() + except Exception as exc: + logger.warning("Failed to build system prompt for stream: %s", exc) + system_prompt = chat_service.mem0_integration.system_prompt_template.replace( + "{memory_block}", "" + ).strip() + + payload["messages"] = [{"role": "system", "content": system_prompt}] + messages payload.pop("user_id", None) payload.pop("include_audio", None) payload["model"] = OPENAI_LLM_MODEL payload["thinking"] = {"type": "disabled"} payload["reasoning"] = {"effort": "minimal"} - payload["max_tokens"] = 200 + payload["max_tokens"] = 400 + payload.setdefault("response_format", {"type": "json_object"}) payload["stream"] = True base_url = OPENAI_API_BASE_URL_CONFIG or "" if not base_url: - raise HTTPException(status_code=500, detail="Upstream base URL is not configured") + raise HTTPException( + status_code=500, detail="Upstream base URL is not configured" + ) api_key = OPENAI_API_KEY_FROM_CONFIG or "" if not api_key or api_key.startswith("YOUR_API_KEY_PLACEHOLDER"): - raise HTTPException(status_code=500, detail="Upstream API key is not configured") + raise HTTPException( + status_code=500, detail="Upstream API key is not configured" + ) upstream_url = f"{base_url.rstrip('/')}/chat/completions" headers = { @@ -178,7 +210,9 @@ async def chat_stream_endpoint(request: Request): detail_bytes = await resp.aread() await resp.aclose() await client.aclose() - detail = detail_bytes.decode("utf-8", "ignore") if detail_bytes else "Upstream error" + detail = ( + detail_bytes.decode("utf-8", "ignore") if detail_bytes else "Upstream error" + ) raise HTTPException(status_code=resp.status_code, detail=detail) async def event_stream(): @@ -218,25 +252,17 @@ async def tts_endpoint(request: TTSRequest): success, message, base64_audio = text_to_speech(text, user_id) if success and base64_audio: return TTSResponse( - success=True, - audio_data=base64_audio, - message=message, - user_id=user_id + success=True, audio_data=base64_audio, message=message, user_id=user_id ) return TTSResponse( success=False, message=message, error=message or "TTS failed", - user_id=user_id + user_id=user_id, ) except Exception as e: raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}") if __name__ == "__main__": - uvicorn.run( - "main:app", - host="0.0.0.0", - port=8000, - reload=True - ) + uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)