Files
rag_chat/api/main.py
game-loader f858576c02
Some checks failed
Build and Push Docker / build-and-push (push) Has been cancelled
feat(api): add streaming chat endpoint proxying upstream SSE
2025-12-30 12:01:17 +08:00

190 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional
import httpx
import uvicorn
import os
import sys
# Add the project root directory to the Python path so sibling modules
# (chat_service, config, logging_config) resolve when run from this package.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from chat_service import ChatService
from config import OPENAI_API_BASE_URL_CONFIG, OPENAI_API_KEY_FROM_CONFIG, OPENAI_LLM_MODEL
from logging_config import setup_logging, get_logger

# Configure logging from environment variables (level and optional file logging)
setup_logging(
    level=os.getenv('LOG_LEVEL', 'INFO'),
    enable_file_logging=os.getenv('ENABLE_FILE_LOGGING', 'false').lower() == 'true'
)

# Module-level logger
logger = get_logger(__name__)
# Request and response models
class ChatRequest(BaseModel):
    """Request body for the POST /chat endpoint."""

    message: str                           # user message text
    user_id: Optional[str] = None          # optional user scope; a non-default id gets its own ChatService
    include_audio: Optional[bool] = True   # forwarded to ChatService.chat; presumably toggles audio synthesis — confirm in chat_service
class ChatResponse(BaseModel):
    """Response body for the POST /chat endpoint (built from ChatService.chat result)."""

    success: bool
    response: Optional[str] = None
    action: Optional[str] = None
    parse_error: Optional[str] = None
    tokens: Optional[int] = None
    user_id: str
    error: Optional[str] = None
    audio_data: Optional[str] = None   # base64-encoded audio data
    audio_error: Optional[str] = None
# Global chat service instance; /chat reuses it unless the request
# carries a different user_id.
chat_service = ChatService()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager: initialize the shared chat service at
    startup, log at shutdown."""
    logger.info("Starting Mem0 Memory API...")
    # Initialize the global chat service when the application starts
    chat_service.initialize()
    logger.info("Mem0 Memory API started successfully")
    yield
    # Cleanup on shutdown could go here
    # (e.g. closing database connections, releasing resources)
    logger.info("Shutting down Mem0 Memory API...")
# Create the FastAPI application (lifespan handles startup/shutdown)
app = FastAPI(
    title="Mem0 Memory API",
    description="基于 Mem0 的记忆增强聊天服务 API",
    version="1.0.0",
    lifespan=lifespan
)
@app.get("/")
async def root():
    """Root path: report that the API is running along with its version."""
    info = {
        "message": "Mem0 Memory API is running",
        "version": "1.0.0",
    }
    return info
@app.get("/health")
async def health_check():
    """Health-check endpoint."""
    status_payload = {"status": "healthy"}
    return status_payload
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """
    Chat endpoint.

    Receives a user message, processes it through the RAG pipeline and
    returns the reply (optionally including base64-encoded audio data).

    Raises:
        HTTPException: 500 on any unexpected processing error.
    """
    try:
        # If the request names a different user, serve it with a dedicated
        # service instance scoped to that user.
        if request.user_id and request.user_id != chat_service.user_id:
            user_chat_service = ChatService(request.user_id)
            user_chat_service.initialize()
            result = user_chat_service.chat(request.message, request.include_audio)
        else:
            result = chat_service.chat(request.message, request.include_audio)
        return ChatResponse(**result)
    except Exception as e:
        # Log the full traceback server-side; previously the error was only
        # visible in the client-facing detail string and left no trace here.
        logger.exception("Unhandled error in /chat")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.post("/chat/stream")
async def chat_stream_endpoint(request: Request):
    """
    Pure pass-through proxy for the upstream (Doubao) SSE streaming output
    (OpenAI-format compatible).

    Accepts OpenAI-standard 'messages'; also accepts a single 'message' field.

    Raises:
        HTTPException: 400 for a malformed body, 500 for missing upstream
            configuration, 502 if the upstream cannot be reached, or the
            upstream's own status code on a non-200 reply.
    """
    try:
        payload = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    # Normalize to OpenAI-style 'messages'; fall back to a bare 'message'.
    messages = payload.get("messages")
    if messages is None:
        message = payload.get("message")
        if not message:
            raise HTTPException(status_code=400, detail="Missing 'message' or 'messages'")
        messages = [{"role": "user", "content": message}]
        payload["messages"] = messages
    # Strip fields the upstream API does not understand. 'message' is popped
    # unconditionally (previously it was forwarded when 'messages' was also
    # present alongside it).
    payload.pop("message", None)
    payload.pop("user_id", None)
    payload.pop("include_audio", None)

    if not payload.get("model"):
        payload["model"] = OPENAI_LLM_MODEL
    payload["stream"] = True

    base_url = OPENAI_API_BASE_URL_CONFIG or ""
    if not base_url:
        raise HTTPException(status_code=500, detail="Upstream base URL is not configured")
    api_key = OPENAI_API_KEY_FROM_CONFIG or ""
    if not api_key or api_key.startswith("YOUR_API_KEY_PLACEHOLDER"):
        raise HTTPException(status_code=500, detail="Upstream API key is not configured")

    upstream_url = f"{base_url.rstrip('/')}/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }

    # No timeout: SSE streams stay open for the duration of the generation.
    client = httpx.AsyncClient(timeout=None)
    try:
        req = client.build_request("POST", upstream_url, headers=headers, json=payload)
        resp = await client.send(req, stream=True)
    except httpx.HTTPError as e:
        # Connection/protocol failure before any response arrived: close the
        # client (it previously leaked on this path) and report a gateway error.
        await client.aclose()
        raise HTTPException(status_code=502, detail=f"Upstream request failed: {e}")

    if resp.status_code != 200:
        detail_bytes = await resp.aread()
        await resp.aclose()
        await client.aclose()
        detail = detail_bytes.decode("utf-8", "ignore") if detail_bytes else "Upstream error"
        raise HTTPException(status_code=resp.status_code, detail=detail)

    async def event_stream():
        # Forward raw SSE bytes unchanged; always release response and client.
        try:
            async for chunk in resp.aiter_raw():
                if chunk:
                    yield chunk
        finally:
            await resp.aclose()
            await client.aclose()

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
if __name__ == "__main__":
    # Run the development server directly (auto-reload is for local dev;
    # disable in production deployments).
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True
    )