Files
rag_chat/api/main.py
gameloader cfd06717e9
Some checks failed
Build and Push Docker / build-and-push (push) Has been cancelled
feat(api): integrate Mem0 context into chat stream
2026-01-24 17:02:54 +08:00

269 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional
import httpx
import uvicorn
import os
import sys
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from chat_service import ChatService
from config import (
DEFAULT_USER_ID,
OPENAI_API_BASE_URL_CONFIG,
OPENAI_API_KEY_FROM_CONFIG,
OPENAI_LLM_MODEL,
)
from api.doubao_tts import text_to_speech
from logging_config import setup_logging, get_logger
# Configure logging from the environment: LOG_LEVEL (default "INFO") is passed
# as the level, and enable_file_logging is True only when ENABLE_FILE_LOGGING
# equals "true" (case-insensitive; default "false").
setup_logging(
    level=os.getenv("LOG_LEVEL", "INFO"),
    enable_file_logging=os.getenv("ENABLE_FILE_LOGGING", "false").lower() == "true",
)
# Module-level logger used by the handlers in this file.
logger = get_logger(__name__)
# Request and response models
class ChatRequest(BaseModel):
    """Request body for the ``/chat`` endpoint."""

    # User message text; passed as the first argument to ``ChatService.chat``.
    message: str
    # Optional user ID. When omitted (or equal to the shared service's user),
    # the module-level chat service handles the request.
    user_id: Optional[str] = None
    # Passed through as the second argument to ``ChatService.chat``.
    include_audio: Optional[bool] = True
class ChatResponse(BaseModel):
    """Response body for the ``/chat`` endpoint.

    Built by unpacking the dict returned by ``ChatService.chat``.
    """

    success: bool
    response: Optional[str] = None
    # NOTE(review): semantics of action / parse_error / tokens are defined by
    # ChatService.chat, which is not visible here — confirm against it.
    action: Optional[str] = None
    parse_error: Optional[str] = None
    tokens: Optional[int] = None
    # Required: the result dict must always carry a user_id.
    user_id: str
    error: Optional[str] = None
    audio_data: Optional[str] = None  # base64-encoded audio data
    audio_error: Optional[str] = None
class TTSRequest(BaseModel):
    """Request body for the ``/tts`` endpoint."""

    # Text to synthesize; the endpoint rejects empty/whitespace-only text.
    text: str
    # Optional user ID; the endpoint falls back to DEFAULT_USER_ID.
    user_id: Optional[str] = None
class TTSResponse(BaseModel):
    """Response body for the ``/tts`` endpoint."""

    success: bool
    audio_data: Optional[str] = None  # base64-encoded audio data
    # Message returned by ``text_to_speech``.
    message: Optional[str] = None
    user_id: Optional[str] = None
    # Set on failure to the TTS message, or "TTS failed" when there is none.
    error: Optional[str] = None
# Global chat service instance. Constructed at import time; its
# ``initialize()`` method is called from the ``lifespan`` startup hook.
chat_service = ChatService()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Code before ``yield`` runs at startup; code after it runs at shutdown.
    """
    logger.info("Starting Mem0 Memory API...")
    # Initialize the shared chat service when the application starts.
    chat_service.initialize()
    logger.info("Mem0 Memory API started successfully")
    yield
    # Shutdown cleanup (e.g. closing database connections, releasing
    # resources) can be added here; currently only a log line is emitted.
    logger.info("Shutting down Mem0 Memory API...")
# Create the FastAPI application; startup/shutdown is handled by ``lifespan``.
app = FastAPI(
    title="Mem0 Memory API",
    description="基于 Mem0 的记忆增强聊天服务 API",
    version="1.0.0",
    lifespan=lifespan,
)
@app.get("/")
async def root():
    """Root path: return basic API information (status message and version)."""
    api_info = dict(message="Mem0 Memory API is running", version="1.0.0")
    return api_info
@app.get("/health")
async def health_check():
    """Health-check endpoint: always reports a static "healthy" status."""
    status_report = dict(status="healthy")
    return status_report
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Chat endpoint.

    Passes the user message to ``ChatService.chat`` and returns its result
    as a ``ChatResponse`` (which may include base64 audio data).

    Args:
        request: Parsed request body (message, optional user_id,
            include_audio flag).

    Returns:
        ChatResponse built from the dict returned by ``ChatService.chat``.

    Raises:
        HTTPException: 500 when service creation, the chat call, or response
            construction fails. The failure is logged with its traceback.
    """
    try:
        # A request for a user other than the shared service's user gets its
        # own freshly initialized service instance.
        if request.user_id and request.user_id != chat_service.user_id:
            service = ChatService(request.user_id)
            service.initialize()
        else:
            service = chat_service
        # NOTE(review): chat() is invoked synchronously inside an async
        # handler, so it occupies the event loop for its whole duration.
        # Confirm ChatService is thread-safe before offloading to a thread.
        result = service.chat(request.message, request.include_audio)
        return ChatResponse(**result)
    except Exception as e:
        # Record the traceback server-side; previously the error was only
        # echoed to the client and never logged.
        logger.exception("Chat request failed")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {e}"
        ) from e
def _last_user_content(messages):
    """Return the content of the most recent non-empty user message, or ""."""
    for msg in reversed(messages):
        if isinstance(msg, dict) and msg.get("role") == "user" and msg.get("content"):
            return msg["content"]
    return ""


def _build_stream_system_prompt(user_input, user_id):
    """Render the system prompt, embedding retrieved memories when possible.

    Any failure while searching or formatting memories is logged and the
    template is rendered with an empty memory block instead.
    """
    try:
        mem0 = chat_service.mem0_integration
        # NOTE(review): user_input is passed through as-is; if a client sends
        # non-string content, a failure here falls back to the empty block.
        memories = (
            mem0.search_memories(user_input, user_id, limit=5) if user_input else []
        )
        memory_block = mem0.format_memories_for_prompt(memories)
        return mem0.system_prompt_template.replace(
            "{memory_block}", memory_block if memory_block else ""
        ).strip()
    except Exception as exc:
        logger.warning("Failed to build system prompt for stream: %s", exc)
        return chat_service.mem0_integration.system_prompt_template.replace(
            "{memory_block}", ""
        ).strip()


@app.post("/chat/stream")
async def chat_stream_endpoint(request: Request):
    """Forward the upstream (Doubao) SSE stream unchanged, OpenAI-format compatible.

    Accepts either OpenAI-style ``messages`` or a single ``message`` field.
    A system prompt (with retrieved memories when available) is prepended,
    and several generation parameters are forced before the request is sent
    to ``{base_url}/chat/completions``.

    Raises:
        HTTPException: 400 for an invalid body, missing message(s), or a
            non-list ``messages``; 500 when the upstream base URL or API key
            is not configured; 502 when the upstream cannot be reached; the
            upstream status code when the upstream answers with a non-200.
    """
    try:
        payload = await request.json()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
    if not isinstance(payload, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    # Always strip the shorthand field so it is never forwarded upstream,
    # even when the client supplied both 'message' and 'messages'.
    message = payload.pop("message", None)
    messages = payload.get("messages")
    if messages is None:
        if not message:
            raise HTTPException(
                status_code=400, detail="Missing 'message' or 'messages'"
            )
        messages = [{"role": "user", "content": message}]
    elif not isinstance(messages, list):
        # A non-list value would otherwise crash with TypeError (HTTP 500)
        # when the system message is prepended below.
        raise HTTPException(status_code=400, detail="'messages' must be a list")

    # Validate upstream configuration before doing any memory lookup.
    base_url = OPENAI_API_BASE_URL_CONFIG or ""
    if not base_url:
        raise HTTPException(
            status_code=500, detail="Upstream base URL is not configured"
        )
    api_key = OPENAI_API_KEY_FROM_CONFIG or ""
    if not api_key or api_key.startswith("YOUR_API_KEY_PLACEHOLDER"):
        raise HTTPException(
            status_code=500, detail="Upstream API key is not configured"
        )

    user_id = payload.get("user_id") or DEFAULT_USER_ID
    user_input = _last_user_content(messages)
    system_prompt = _build_stream_system_prompt(user_input, user_id)

    payload["messages"] = [{"role": "system", "content": system_prompt}] + messages
    # Local-only fields must not reach the upstream API.
    payload.pop("user_id", None)
    payload.pop("include_audio", None)
    # Server-enforced generation settings; client-supplied values are overridden.
    payload["model"] = OPENAI_LLM_MODEL
    payload["thinking"] = {"type": "disabled"}
    payload["reasoning"] = {"effort": "minimal"}
    payload["max_tokens"] = 400
    payload.setdefault("response_format", {"type": "json_object"})
    payload["stream"] = True

    upstream_url = f"{base_url.rstrip('/')}/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }

    # The client outlives this function: it is closed by event_stream() once
    # the body has been relayed, or right here on any failure path.
    client = httpx.AsyncClient(timeout=None)
    try:
        req = client.build_request("POST", upstream_url, headers=headers, json=payload)
        resp = await client.send(req, stream=True)
    except httpx.HTTPError as exc:
        await client.aclose()
        logger.error("Upstream request failed: %s", exc)
        raise HTTPException(
            status_code=502, detail="Failed to reach upstream service"
        ) from exc
    except BaseException:
        # Includes cancellation and invalid-URL errors: never leak the client.
        await client.aclose()
        raise

    if resp.status_code != 200:
        try:
            detail_bytes = await resp.aread()
        finally:
            await resp.aclose()
            await client.aclose()
        detail = (
            detail_bytes.decode("utf-8", "ignore") if detail_bytes else "Upstream error"
        )
        raise HTTPException(status_code=resp.status_code, detail=detail)

    async def event_stream():
        """Relay raw upstream bytes, then release the response and client."""
        try:
            async for chunk in resp.aiter_raw():
                if chunk:
                    yield chunk
        finally:
            await resp.aclose()
            await client.aclose()

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )
@app.post("/tts", response_model=TTSResponse)
async def tts_endpoint(request: TTSRequest):
    """Text-to-speech endpoint.

    Only converts the input text into base64 audio data.

    Args:
        request: Parsed request body (text, optional user_id).

    Returns:
        TTSResponse with ``success=True`` and ``audio_data`` when synthesis
        succeeded, otherwise ``success=False`` with ``error`` populated.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only; 500
            when the TTS call or response construction raises. The failure
            is logged with its traceback.
    """
    text = request.text.strip() if request.text else ""
    if not text:
        raise HTTPException(status_code=400, detail="text is required")
    user_id = request.user_id or DEFAULT_USER_ID
    try:
        # NOTE(review): text_to_speech is invoked synchronously inside an
        # async handler, so it occupies the event loop while it runs.
        success, message, base64_audio = text_to_speech(text, user_id)
        if success and base64_audio:
            return TTSResponse(
                success=True, audio_data=base64_audio, message=message, user_id=user_id
            )
        return TTSResponse(
            success=False,
            message=message,
            error=message or "TTS failed",
            user_id=user_id,
        )
    except Exception as e:
        # Record the traceback server-side; previously the error was only
        # echoed to the client and never logged.
        logger.exception("TTS request failed")
        raise HTTPException(status_code=500, detail=f"TTS error: {e}") from e
if __name__ == "__main__":
    # Run the app with uvicorn when this file is executed as a script.
    # NOTE(review): reload=True and binding to 0.0.0.0 look like development
    # settings — confirm this entry point is not used in production.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)