Files
rag_chat/api/main.py
gameloader 618c2ec209
Some checks failed
Build and Push Docker / build-and-push (push) Has been cancelled
feat(memory): integrate Mem0 for enhanced conversational memory
2025-10-12 16:36:41 +08:00

83 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import threading
from typing import Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

from chat_service import ChatService
# Request and response models
class ChatRequest(BaseModel):
    # JSON body accepted by POST /chat. (Comments rather than a class
    # docstring: pydantic uses a model docstring as its schema description.)
    message: str  # user message, passed as the first argument to ChatService.chat
    # None or "" -> the shared default service handles the request; any other
    # id that differs from the default service's user_id -> a ChatService
    # built for that id.
    user_id: Optional[str] = None
    # Passed through unchanged as ChatService.chat's second argument; Optional,
    # so a client may send null and it is forwarded as None.
    include_audio: Optional[bool] = True
class ChatResponse(BaseModel):
    # Response body of POST /chat, constructed as ChatResponse(**result) from
    # the mapping returned by ChatService.chat. Which fields are filled on
    # success vs. failure is decided by ChatService and is not visible here.
    success: bool
    response: Optional[str] = None
    tokens: Optional[int] = None
    # Required: if ChatService's result lacks it, validation fails inside the
    # endpoint's try block and the client receives HTTP 500.
    user_id: str
    error: Optional[str] = None
    audio_data: Optional[str] = None  # base64-encoded audio data
    audio_error: Optional[str] = None
# Build the FastAPI application. title/description/version are the metadata
# FastAPI exposes in the generated OpenAPI schema and interactive docs.
app = FastAPI(
    version="1.0.0",
    title="Mem0 Memory API",
    description="基于 Mem0 的记忆增强聊天服务 API",
)

# Process-wide default chat service, constructed at import time with its
# default arguments and initialised by the startup handler. /chat requests
# that name a different user are served by a separate ChatService instance.
chat_service = ChatService()
@app.on_event("startup")
async def startup_event():
    """应用启动时初始化聊天服务"""
    # Startup hook: initialise the shared default ChatService before the app
    # begins serving requests. (The docstring above is kept verbatim.)
    # NOTE(review): recent FastAPI releases deprecate on_event in favour of a
    # lifespan handler passed to FastAPI(...).
    default_service = chat_service
    default_service.initialize()
@app.get("/")
async def root():
    """根路径,返回 API 信息"""
    # Root path: report that the service is running, plus its version.
    info = {"message": "Mem0 Memory API is running"}
    info["version"] = "1.0.0"
    return info
@app.get("/health")
async def health_check():
    """健康检查端点"""
    # Health check: returns a constant payload without touching ChatService.
    payload = dict(status="healthy")
    return payload
_logger = logging.getLogger(__name__)

# Serialises all ChatService work. While these calls ran inline on the event
# loop they could never overlap; now that they run in worker threads, this
# lock keeps that guarantee, since ChatService is not known to be thread-safe.
_chat_lock = threading.Lock()


def _run_chat(request: ChatRequest):
    """Run one chat turn synchronously (intended to execute in a worker thread).

    Uses the shared default ``chat_service`` unless the request names a
    ``user_id`` different from the default service's, in which case a fresh
    ``ChatService`` is created and initialised for this call.

    Returns the mapping produced by ``ChatService.chat``.
    """
    with _chat_lock:
        if request.user_id and request.user_id != chat_service.user_id:
            # NOTE(review): a new service is built and initialised on every
            # request for a non-default user; caching per-user instances
            # would avoid that cost if initialize() is expensive.
            service = ChatService(request.user_id)
            service.initialize()
        else:
            service = chat_service
        return service.chat(request.message, request.include_audio)


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """
    聊天接口
    接收用户消息,通过 RAG 管道处理并返回回复(可包含 base64 音频数据)
    """
    # Chat endpoint (docstring above kept verbatim for the OpenAPI docs):
    # process the user's message through the chat pipeline and return the
    # reply, optionally including base64-encoded audio.
    #
    # ChatService.initialize()/chat() are synchronous (their results are used
    # directly, never awaited). Calling them inline would block the event
    # loop, stalling every other request (including /health) for the whole
    # call, so they run in the threadpool instead.
    #
    # Raises:
    #     HTTPException(500): on any failure, including a ChatService result
    #         that does not validate as ChatResponse.
    try:
        result = await run_in_threadpool(_run_chat, request)
        return ChatResponse(**result)
    except Exception as e:
        # Keep the full traceback in the server log; the client only sees
        # the message in the HTTP 500 detail.
        _logger.exception("Chat request failed")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
if __name__ == "__main__":
    # Development entry point: serve `app` on all interfaces, port 8000, with
    # auto-reload. uvicorn requires an import string ("module:attribute")
    # when reload is enabled. "main:app" presumably expects the process to be
    # launched from this file's directory, as `from chat_service import ...`
    # does too.
    run_options = dict(host="0.0.0.0", port=8000, reload=True)
    uvicorn.run("main:app", **run_options)