190 lines
5.6 KiB
Python
190 lines
5.6 KiB
Python
from contextlib import asynccontextmanager
|
||
from fastapi import FastAPI, HTTPException, Request
|
||
from fastapi.responses import StreamingResponse
|
||
from pydantic import BaseModel
|
||
from typing import Optional
|
||
import httpx
|
||
import uvicorn
|
||
import os
|
||
import sys
|
||
|
||
# 添加项目根目录到Python路径
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from chat_service import ChatService
|
||
from config import OPENAI_API_BASE_URL_CONFIG, OPENAI_API_KEY_FROM_CONFIG, OPENAI_LLM_MODEL
|
||
from logging_config import setup_logging, get_logger
|
||
|
||
# Logging setup: level and optional file logging are driven by env vars.
_log_level = os.getenv('LOG_LEVEL', 'INFO')
_file_logging_enabled = os.getenv('ENABLE_FILE_LOGGING', 'false').lower() == 'true'
setup_logging(level=_log_level, enable_file_logging=_file_logging_enabled)

# Module-level logger for this file.
logger = get_logger(__name__)
|
||
|
||
|
||
# Request and response models
class ChatRequest(BaseModel):
    """Request payload for POST /chat."""
    # The user's chat message (required).
    message: str
    # Optional user identifier; when it differs from the global service's
    # user_id, a dedicated ChatService instance is created for this request.
    user_id: Optional[str] = None
    # Whether the response should include base64-encoded audio data.
    include_audio: Optional[bool] = True
|
||
|
||
|
||
class ChatResponse(BaseModel):
    """Response body for POST /chat; fields mirror ChatService.chat()'s result dict."""
    # True when the chat pipeline completed without error.
    success: bool
    # The assistant's text reply, if any.
    response: Optional[str] = None
    # Action extracted from the reply — TODO confirm semantics against ChatService.chat.
    action: Optional[str] = None
    # Error message produced while parsing the reply/action, if parsing failed.
    parse_error: Optional[str] = None
    # Token count — presumably usage tokens reported by the model; verify upstream.
    tokens: Optional[int] = None
    # Identifier of the user this response belongs to.
    user_id: str
    # Top-level error message (expected when success is False).
    error: Optional[str] = None
    # base64-encoded audio data.
    audio_data: Optional[str] = None
    # Error message from audio generation, if it failed.
    audio_error: Optional[str] = None
|
||
|
||
|
||
# Global chat service instance (default user). Requests that carry a
# different user_id get their own ChatService in chat_endpoint.
chat_service = ChatService()
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager: startup/shutdown hooks for FastAPI."""
    logger.info("Starting Mem0 Memory API...")
    # Initialize the global chat service once at application startup.
    chat_service.initialize()
    logger.info("Mem0 Memory API started successfully")
    yield
    # Shutdown-time cleanup (e.g. closing connections, releasing resources)
    # would go here.
    logger.info("Shutting down Mem0 Memory API...")
|
||
|
||
|
||
# Create the FastAPI application (OpenAPI docs metadata + lifespan hooks).
app = FastAPI(
    title="Mem0 Memory API",
    description="基于 Mem0 的记忆增强聊天服务 API",
    version="1.0.0",
    lifespan=lifespan
)
|
||
|
||
|
||
@app.get("/")
async def root():
    """Root endpoint: report that the API is up, along with its version."""
    info = {
        "message": "Mem0 Memory API is running",
        "version": "1.0.0",
    }
    return info
|
||
|
||
|
||
@app.get("/health")
async def health_check():
    """Liveness-probe endpoint."""
    status_payload = {"status": "healthy"}
    return status_payload
|
||
|
||
|
||
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """
    Chat endpoint.

    Processes the user's message through the RAG pipeline and returns the
    reply (optionally including base64-encoded audio data).

    Raises:
        HTTPException: 500 on any unexpected failure in the chat pipeline.
    """
    try:
        # When the request targets a user other than the global service's,
        # build a dedicated, freshly-initialized service for that user.
        # NOTE(review): this initializes a new ChatService on every such
        # request — consider caching per-user instances if init is costly.
        if request.user_id and request.user_id != chat_service.user_id:
            user_chat_service = ChatService(request.user_id)
            user_chat_service.initialize()
            result = user_chat_service.chat(request.message, request.include_audio)
        else:
            result = chat_service.chat(request.message, request.include_audio)

        return ChatResponse(**result)

    except Exception as e:
        # Fix: the original discarded the traceback entirely; log it and
        # chain the cause so the 500 is debuggable.
        logger.exception("Unhandled error in /chat")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
|
||
|
||
|
||
@app.post("/chat/stream")
async def chat_stream_endpoint(request: Request):
    """
    Pass-through proxy for the upstream SSE streaming chat endpoint
    (OpenAI-format compatible).

    Accepts either OpenAI-standard 'messages' or a single 'message' field.

    Raises:
        HTTPException: 400 for a malformed request body, 500 when upstream
            configuration is missing, or the upstream's own status code on
            a non-200 upstream response.
    """
    try:
        payload = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body")

    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    # Normalize to the OpenAI 'messages' shape; accept a bare 'message' too.
    messages = payload.get("messages")
    if messages is None:
        message = payload.get("message")
        if not message:
            raise HTTPException(status_code=400, detail="Missing 'message' or 'messages'")
        messages = [{"role": "user", "content": message}]
        payload.pop("message", None)

    payload["messages"] = messages
    # Strip fields local to this API; the upstream does not understand them.
    payload.pop("user_id", None)
    payload.pop("include_audio", None)
    if not payload.get("model"):
        payload["model"] = OPENAI_LLM_MODEL
    payload["stream"] = True

    base_url = OPENAI_API_BASE_URL_CONFIG or ""
    if not base_url:
        raise HTTPException(status_code=500, detail="Upstream base URL is not configured")

    api_key = OPENAI_API_KEY_FROM_CONFIG or ""
    if not api_key or api_key.startswith("YOUR_API_KEY_PLACEHOLDER"):
        raise HTTPException(status_code=500, detail="Upstream API key is not configured")

    upstream_url = f"{base_url.rstrip('/')}/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }

    # timeout=None: SSE streams are long-lived; the upstream decides when to close.
    client = httpx.AsyncClient(timeout=None)
    req = client.build_request("POST", upstream_url, headers=headers, json=payload)
    try:
        resp = await client.send(req, stream=True)
    except Exception:
        # Fix: if the upstream connection itself fails, the original leaked
        # the AsyncClient (and its connection pool). Close it, then re-raise.
        await client.aclose()
        raise

    if resp.status_code != 200:
        # Read the error body before closing so it can be surfaced to the caller.
        detail_bytes = await resp.aread()
        await resp.aclose()
        await client.aclose()
        detail = detail_bytes.decode("utf-8", "ignore") if detail_bytes else "Upstream error"
        raise HTTPException(status_code=resp.status_code, detail=detail)

    async def event_stream():
        # Relay raw upstream SSE bytes; always release resources when done.
        try:
            async for chunk in resp.aiter_raw():
                if chunk:
                    yield chunk
        finally:
            await resp.aclose()
            await client.aclose()

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable reverse-proxy buffering (nginx) so SSE chunks flush immediately.
            "X-Accel-Buffering": "no",
        },
    )
|
||
|
||
|
||
if __name__ == "__main__":
    # Development server entry point. reload=True enables auto-reload — a
    # development setting; disable it in production.
    # NOTE(review): "main:app" assumes this module is main.py — confirm the
    # filename matches, otherwise the reloadable import string will fail.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True
    )
|