109 lines
2.9 KiB
Python
109 lines
2.9 KiB
Python
from contextlib import asynccontextmanager
|
||
from fastapi import FastAPI, HTTPException
|
||
from pydantic import BaseModel
|
||
from typing import Optional
|
||
import uvicorn
|
||
import os
|
||
import sys
|
||
|
||
# Make the project root (the parent of the directory containing this file)
# importable, so the local modules imported below (chat_service,
# logging_config) can be resolved. It is appended, i.e. searched last.
sys.path.append(
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))
    )
)
|
||
|
||
from chat_service import ChatService
|
||
from logging_config import setup_logging, get_logger
|
||
|
||
# Configure logging from the environment:
#   LOG_LEVEL           - passed as-is to setup_logging (default 'INFO')
#   ENABLE_FILE_LOGGING - file logging is on only when the value equals
#                         'true' case-insensitively (default 'false')
setup_logging(
    level=os.environ.get('LOG_LEVEL', 'INFO'),
    enable_file_logging=(
        os.environ.get('ENABLE_FILE_LOGGING', 'false').lower() == 'true'
    ),
)

# Module-level logger for this file.
logger = get_logger(__name__)
|
||
|
||
|
||
# Request and response models


class ChatRequest(BaseModel):
    # Request body of POST /chat. Field order and annotations define the
    # generated API schema, so they are intentionally left as declared.

    # The user's message to send through the chat pipeline.
    message: str
    # When set and different from the shared service's user, /chat creates
    # a dedicated ChatService for this user.
    user_id: Optional[str] = None
    # Forwarded as-is to ChatService.chat(); explicit null is accepted.
    include_audio: Optional[bool] = True
|
||
|
||
|
||
class ChatResponse(BaseModel):
    # Response body of POST /chat, built by unpacking the mapping returned
    # from ChatService.chat(). Field order defines the generated schema.

    success: bool
    # Presumably the reply text (expected to be unset on failure) - confirm.
    response: Optional[str] = None
    # Presumably the token count reported by the service - confirm.
    tokens: Optional[int] = None
    # Required: the service result must always carry a user_id.
    user_id: str
    # Presumably an error description when success is False - confirm.
    error: Optional[str] = None
    # Base64-encoded audio data.
    audio_data: Optional[str] = None
    # Presumably set when audio generation fails - confirm.
    audio_error: Optional[str] = None
|
||
|
||
|
||
# Shared, module-level chat service for the default user. It is initialized
# once at startup by lifespan(), and /chat uses it unless the request names a
# different user_id.
chat_service = ChatService()
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the application lifecycle.

    Before ``yield`` (startup): initialize the shared ``chat_service``.
    After ``yield`` (shutdown): log that the API is stopping; cleanup such as
    closing database connections or releasing resources belongs there.
    """
    logger.info("Starting Mem0 Memory API...")
    chat_service.initialize()  # one-time setup before requests are served
    logger.info("Mem0 Memory API started successfully")

    yield  # the application serves requests while suspended here

    # Shutdown phase: add resource cleanup here.
    logger.info("Shutting down Mem0 Memory API...")
|
||
|
||
|
||
# The FastAPI application; lifespan() handles startup and shutdown.
app = FastAPI(
    title="Mem0 Memory API",
    version="1.0.0",
    description="基于 Mem0 的记忆增强聊天服务 API",
    lifespan=lifespan,
)
|
||
|
||
|
||
@app.get("/")
async def root():
    """Return basic API information: a running message and the version."""
    info = {"message": "Mem0 Memory API is running", "version": "1.0.0"}
    return info
|
||
|
||
|
||
@app.get("/health")
async def health_check():
    """Health-check endpoint; returns a static healthy status.

    No dependencies are checked.
    """
    status = {"status": "healthy"}
    return status
|
||
|
||
|
||
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Chat endpoint.

    Receives a user message, processes it through the RAG pipeline and returns
    the reply (optionally including base64-encoded audio data).

    ``ChatService.initialize()`` and ``ChatService.chat()`` are synchronous
    (they are called without ``await``, and ``chat()``'s result is unpacked
    directly). They are therefore run in a worker thread so a slow chat turn
    does not block the event loop and stall every other request.

    Raises:
        HTTPException: 500 if the chat service fails or returns data that does
            not validate as ``ChatResponse``.
    """
    # Local import keeps this block self-contained; fastapi re-exports
    # Starlette's run_in_threadpool.
    from fastapi.concurrency import run_in_threadpool

    try:
        # NOTE(review): concurrent requests may now call the shared
        # chat_service from several worker threads at once; confirm that
        # ChatService is thread-safe.
        result = await run_in_threadpool(_process_chat_request, request)
        return ChatResponse(**result)
    except Exception as e:
        # Keep the full traceback in the server log; previously the failure
        # was only visible to the client through the error detail.
        logger.exception("Chat request failed (user_id=%s)", request.user_id)
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {e}"
        ) from e


def _process_chat_request(request: ChatRequest):
    """Run one chat turn synchronously and return ChatService.chat()'s result.

    If the request names a user other than the shared service's user, a new
    ChatService is created and initialized for that user (one per request;
    nothing is cached). Otherwise the shared ``chat_service`` is used.
    """
    if request.user_id and request.user_id != chat_service.user_id:
        service = ChatService(request.user_id)
        service.initialize()
    else:
        service = chat_service
    return service.chat(request.message, request.include_audio)
|
||
|
||
|
||
if __name__ == "__main__":
    # Server settings can be overridden through the environment. The defaults
    # reproduce the previous hard-coded values (0.0.0.0:8000 with auto-reload).
    # Auto-reload is a development feature; set API_RELOAD=false in production.
    # NOTE(review): the "main:app" import string assumes this file is main.py.
    uvicorn.run(
        "main:app",
        host=os.getenv("API_HOST", "0.0.0.0"),
        port=int(os.getenv("API_PORT", "8000")),
        reload=os.getenv("API_RELOAD", "true").lower() == "true",
    )
|