from contextlib import asynccontextmanager
from typing import Optional
import os
import sys
import threading

from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import httpx
import uvicorn

# Add the project root (the parent of this file's directory) to the Python
# path so the project-local modules imported below can be resolved.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from chat_service import ChatService  # noqa: E402
from config import OPENAI_API_BASE_URL_CONFIG, OPENAI_API_KEY_FROM_CONFIG, OPENAI_LLM_MODEL  # noqa: E402
from logging_config import setup_logging, get_logger  # noqa: E402

# Configure logging from environment variables.
setup_logging(
    level=os.getenv('LOG_LEVEL', 'INFO'),
    enable_file_logging=os.getenv('ENABLE_FILE_LOGGING', 'false').lower() == 'true'
)

# Module logger.
logger = get_logger(__name__)

# Seconds allowed for establishing the upstream connection. Reads remain
# unbounded because a streamed LLM response may legitimately pause.
UPSTREAM_CONNECT_TIMEOUT_SECONDS = 10.0


# Request and response models
class ChatRequest(BaseModel):
    message: str
    user_id: Optional[str] = None
    include_audio: Optional[bool] = True


class ChatResponse(BaseModel):
    success: bool
    response: Optional[str] = None
    action: Optional[str] = None
    parse_error: Optional[str] = None
    tokens: Optional[int] = None
    user_id: str
    error: Optional[str] = None
    audio_data: Optional[str] = None  # base64-encoded audio data
    audio_error: Optional[str] = None


# Global chat service instance, shared by requests that do not name a
# different user.
chat_service = ChatService()

# Serialises calls into the shared ``chat_service``. /chat now runs the
# blocking chat call in a worker thread; this lock keeps calls into the shared
# instance one-at-a-time, as they effectively were when run on the event loop.
_chat_service_lock = threading.Lock()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager: initialise on startup, log on shutdown."""
    logger.info("Starting Mem0 Memory API...")

    # Initialise the shared chat service when the application starts.
    chat_service.initialize()
    logger.info("Mem0 Memory API started successfully")

    yield

    # Cleanup on shutdown (e.g. closing connections, releasing resources)
    # could go here.
    logger.info("Shutting down Mem0 Memory API...")


# Create the FastAPI application.
app = FastAPI(
    title="Mem0 Memory API",
    description="基于 Mem0 的记忆增强聊天服务 API",
    version="1.0.0",
    lifespan=lifespan
)


@app.get("/")
async def root():
    """Root path: return basic API information."""
    return {"message": "Mem0 Memory API is running", "version": "1.0.0"}


@app.get("/health")
async def health_check():
    """Health-check endpoint."""
    return {"status": "healthy"}


def _run_chat(message: str, user_id: Optional[str], include_audio: Optional[bool]):
    """Run one blocking chat turn; meant to be called from a worker thread.

    If ``user_id`` is given and differs from the shared service's user id, a
    fresh ChatService is created and initialised for this call. Otherwise the
    shared ``chat_service`` is used while holding its lock.

    Returns whatever ``ChatService.chat`` returns (unpacked into
    ``ChatResponse`` by the caller).
    """
    if user_id and user_id != chat_service.user_id:
        # A new instance per request shares no in-object state, so no lock.
        user_chat_service = ChatService(user_id)
        user_chat_service.initialize()
        return user_chat_service.chat(message, include_audio)
    with _chat_service_lock:
        return chat_service.chat(message, include_audio)


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Chat endpoint.

    Receives a user message, processes it through the chat service (RAG
    pipeline) and returns the reply, optionally including base64 audio data.
    Any failure is logged and reported as HTTP 500.
    """
    try:
        # ChatService calls are synchronous; running them in the threadpool
        # keeps one slow chat turn from stalling every other request.
        result = await run_in_threadpool(
            _run_chat, request.message, request.user_id, request.include_audio
        )
        return ChatResponse(**result)
    except Exception as e:
        logger.exception("Chat request failed")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e


@app.post("/chat/stream")
async def chat_stream_endpoint(request: Request):
    """Proxy the upstream (Doubao) SSE stream unchanged, in OpenAI-compatible form.

    Accepts OpenAI-standard ``messages`` and also a single ``message`` field.
    Fields specific to this API (``user_id``, ``include_audio``) are stripped,
    ``model`` defaults to ``OPENAI_LLM_MODEL`` and ``stream`` is forced on.

    Raises:
        HTTPException: 400 for a malformed body, 500 when upstream settings are
            missing, 502 when the upstream cannot be reached, or the upstream's
            own status code when it answers with a non-200 response.
    """
    try:
        payload = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body") from None

    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    messages = payload.get("messages")
    if messages is None:
        message = payload.get("message")
        if not message:
            raise HTTPException(status_code=400, detail="Missing 'message' or 'messages'")
        messages = [{"role": "user", "content": message}]
    elif not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="'messages' must be a list")

    # Remove fields that belong to this API rather than the upstream one.
    payload.pop("message", None)
    payload["messages"] = messages
    payload.pop("user_id", None)
    payload.pop("include_audio", None)
    if not payload.get("model"):
        payload["model"] = OPENAI_LLM_MODEL
    payload["stream"] = True

    base_url = OPENAI_API_BASE_URL_CONFIG or ""
    if not base_url:
        raise HTTPException(status_code=500, detail="Upstream base URL is not configured")

    api_key = OPENAI_API_KEY_FROM_CONFIG or ""
    if not api_key or api_key.startswith("YOUR_API_KEY_PLACEHOLDER"):
        raise HTTPException(status_code=500, detail="Upstream API key is not configured")

    upstream_url = f"{base_url.rstrip('/')}/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }

    client = httpx.AsyncClient(
        timeout=httpx.Timeout(None, connect=UPSTREAM_CONNECT_TIMEOUT_SECONDS)
    )
    req = client.build_request("POST", upstream_url, headers=headers, json=payload)
    try:
        resp = await client.send(req, stream=True)
    except httpx.HTTPError as e:
        # Without this the client would leak and the error would surface as
        # an unhandled 500.
        await client.aclose()
        logger.error("Upstream request failed: %s", e)
        raise HTTPException(status_code=502, detail="Upstream request failed") from e

    if resp.status_code != 200:
        try:
            detail_bytes = await resp.aread()
        finally:
            await resp.aclose()
            await client.aclose()
        detail = detail_bytes.decode("utf-8", "ignore") if detail_bytes else "Upstream error"
        logger.warning("Upstream returned HTTP %s: %s", resp.status_code, detail)
        raise HTTPException(status_code=resp.status_code, detail=detail)

    async def event_stream():
        """Yield upstream body chunks, closing the response and client at the end."""
        try:
            # aiter_bytes() (unlike aiter_raw()) undoes any Content-Encoding
            # applied upstream; the StreamingResponse below does not forward
            # that header, so still-compressed bytes would be unreadable.
            async for chunk in resp.aiter_bytes():
                if chunk:
                    yield chunk
        except httpx.HTTPError:
            logger.exception("Upstream stream interrupted")
            raise
        finally:
            await resp.aclose()
            await client.aclose()

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


if __name__ == "__main__":
    # NOTE: the "main:app" import string assumes this module is importable as
    # ``main``; reload=True is a development setting.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True
    )