feat(api): integrate Mem0 context into chat stream
Some checks failed
Build and Push Docker / build-and-push (push) Has been cancelled
Some checks failed
Build and Push Docker / build-and-push (push) Has been cancelled
This commit is contained in:
68
api/main.py
68
api/main.py
@@ -23,8 +23,8 @@ from logging_config import setup_logging, get_logger
|
|||||||
|
|
||||||
# 设置日志
|
# 设置日志
|
||||||
setup_logging(
|
setup_logging(
|
||||||
level=os.getenv('LOG_LEVEL', 'INFO'),
|
level=os.getenv("LOG_LEVEL", "INFO"),
|
||||||
enable_file_logging=os.getenv('ENABLE_FILE_LOGGING', 'false').lower() == 'true'
|
enable_file_logging=os.getenv("ENABLE_FILE_LOGGING", "false").lower() == "true",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取logger
|
# 获取logger
|
||||||
@@ -85,7 +85,7 @@ app = FastAPI(
|
|||||||
title="Mem0 Memory API",
|
title="Mem0 Memory API",
|
||||||
description="基于 Mem0 的记忆增强聊天服务 API",
|
description="基于 Mem0 的记忆增强聊天服务 API",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
lifespan=lifespan
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -136,32 +136,64 @@ async def chat_stream_endpoint(request: Request):
|
|||||||
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
||||||
|
|
||||||
if not isinstance(payload, dict):
|
if not isinstance(payload, dict):
|
||||||
raise HTTPException(status_code=400, detail="Request body must be a JSON object")
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Request body must be a JSON object"
|
||||||
|
)
|
||||||
|
|
||||||
messages = payload.get("messages")
|
messages = payload.get("messages")
|
||||||
if messages is None:
|
if messages is None:
|
||||||
message = payload.get("message")
|
message = payload.get("message")
|
||||||
if not message:
|
if not message:
|
||||||
raise HTTPException(status_code=400, detail="Missing 'message' or 'messages'")
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Missing 'message' or 'messages'"
|
||||||
|
)
|
||||||
messages = [{"role": "user", "content": message}]
|
messages = [{"role": "user", "content": message}]
|
||||||
payload.pop("message", None)
|
payload.pop("message", None)
|
||||||
|
|
||||||
payload["messages"] = messages
|
user_id = payload.get("user_id") or DEFAULT_USER_ID
|
||||||
|
|
||||||
|
user_input = ""
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if isinstance(msg, dict) and msg.get("role") == "user" and msg.get("content"):
|
||||||
|
user_input = msg["content"]
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
mem0 = chat_service.mem0_integration
|
||||||
|
memories = (
|
||||||
|
mem0.search_memories(user_input, user_id, limit=5) if user_input else []
|
||||||
|
)
|
||||||
|
memory_block = mem0.format_memories_for_prompt(memories)
|
||||||
|
system_prompt = mem0.system_prompt_template.replace(
|
||||||
|
"{memory_block}", memory_block if memory_block else ""
|
||||||
|
).strip()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to build system prompt for stream: %s", exc)
|
||||||
|
system_prompt = chat_service.mem0_integration.system_prompt_template.replace(
|
||||||
|
"{memory_block}", ""
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
payload["messages"] = [{"role": "system", "content": system_prompt}] + messages
|
||||||
payload.pop("user_id", None)
|
payload.pop("user_id", None)
|
||||||
payload.pop("include_audio", None)
|
payload.pop("include_audio", None)
|
||||||
payload["model"] = OPENAI_LLM_MODEL
|
payload["model"] = OPENAI_LLM_MODEL
|
||||||
payload["thinking"] = {"type": "disabled"}
|
payload["thinking"] = {"type": "disabled"}
|
||||||
payload["reasoning"] = {"effort": "minimal"}
|
payload["reasoning"] = {"effort": "minimal"}
|
||||||
payload["max_tokens"] = 200
|
payload["max_tokens"] = 400
|
||||||
|
payload.setdefault("response_format", {"type": "json_object"})
|
||||||
payload["stream"] = True
|
payload["stream"] = True
|
||||||
|
|
||||||
base_url = OPENAI_API_BASE_URL_CONFIG or ""
|
base_url = OPENAI_API_BASE_URL_CONFIG or ""
|
||||||
if not base_url:
|
if not base_url:
|
||||||
raise HTTPException(status_code=500, detail="Upstream base URL is not configured")
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Upstream base URL is not configured"
|
||||||
|
)
|
||||||
|
|
||||||
api_key = OPENAI_API_KEY_FROM_CONFIG or ""
|
api_key = OPENAI_API_KEY_FROM_CONFIG or ""
|
||||||
if not api_key or api_key.startswith("YOUR_API_KEY_PLACEHOLDER"):
|
if not api_key or api_key.startswith("YOUR_API_KEY_PLACEHOLDER"):
|
||||||
raise HTTPException(status_code=500, detail="Upstream API key is not configured")
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Upstream API key is not configured"
|
||||||
|
)
|
||||||
|
|
||||||
upstream_url = f"{base_url.rstrip('/')}/chat/completions"
|
upstream_url = f"{base_url.rstrip('/')}/chat/completions"
|
||||||
headers = {
|
headers = {
|
||||||
@@ -178,7 +210,9 @@ async def chat_stream_endpoint(request: Request):
|
|||||||
detail_bytes = await resp.aread()
|
detail_bytes = await resp.aread()
|
||||||
await resp.aclose()
|
await resp.aclose()
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
detail = detail_bytes.decode("utf-8", "ignore") if detail_bytes else "Upstream error"
|
detail = (
|
||||||
|
detail_bytes.decode("utf-8", "ignore") if detail_bytes else "Upstream error"
|
||||||
|
)
|
||||||
raise HTTPException(status_code=resp.status_code, detail=detail)
|
raise HTTPException(status_code=resp.status_code, detail=detail)
|
||||||
|
|
||||||
async def event_stream():
|
async def event_stream():
|
||||||
@@ -218,25 +252,17 @@ async def tts_endpoint(request: TTSRequest):
|
|||||||
success, message, base64_audio = text_to_speech(text, user_id)
|
success, message, base64_audio = text_to_speech(text, user_id)
|
||||||
if success and base64_audio:
|
if success and base64_audio:
|
||||||
return TTSResponse(
|
return TTSResponse(
|
||||||
success=True,
|
success=True, audio_data=base64_audio, message=message, user_id=user_id
|
||||||
audio_data=base64_audio,
|
|
||||||
message=message,
|
|
||||||
user_id=user_id
|
|
||||||
)
|
)
|
||||||
return TTSResponse(
|
return TTSResponse(
|
||||||
success=False,
|
success=False,
|
||||||
message=message,
|
message=message,
|
||||||
error=message or "TTS failed",
|
error=message or "TTS failed",
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
uvicorn.run(
|
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||||||
"main:app",
|
|
||||||
host="0.0.0.0",
|
|
||||||
port=8000,
|
|
||||||
reload=True
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user