feat(api): add streaming chat endpoint proxying upstream SSE
Some checks failed
Build and Push Docker / build-and-push (push) Has been cancelled
.gitignore (vendored): 2 changes
@@ -11,3 +11,5 @@ wheels/
 .DS_Store
 milvus_user_data_openai
 milvus_lite.db
+chroma_db_store/chroma.sqlite3
+local_memory.db
api/main.py: 81 changes
@@ -1,7 +1,9 @@
 from contextlib import asynccontextmanager
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from typing import Optional
+import httpx
 import uvicorn
 import os
 import sys
@@ -10,6 +12,7 @@ import sys
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 from chat_service import ChatService
+from config import OPENAI_API_BASE_URL_CONFIG, OPENAI_API_KEY_FROM_CONFIG, OPENAI_LLM_MODEL
 from logging_config import setup_logging, get_logger
 
 # Set up logging
@@ -101,6 +104,82 @@ async def chat_endpoint(request: ChatRequest):
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 
+@app.post("/chat/stream")
+async def chat_stream_endpoint(request: Request):
+    """
+    Endpoint that purely forwards Doubao SSE streaming output (OpenAI-format compatible).
+
+    Accepts OpenAI-standard messages; a single message field is also supported.
+    """
+    try:
+        payload = await request.json()
+    except Exception:
+        raise HTTPException(status_code=400, detail="Invalid JSON body")
+
+    if not isinstance(payload, dict):
+        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
+
+    messages = payload.get("messages")
+    if messages is None:
+        message = payload.get("message")
+        if not message:
+            raise HTTPException(status_code=400, detail="Missing 'message' or 'messages'")
+        messages = [{"role": "user", "content": message}]
+        payload.pop("message", None)
+
+    payload["messages"] = messages
+    payload.pop("user_id", None)
+    payload.pop("include_audio", None)
+    if not payload.get("model"):
+        payload["model"] = OPENAI_LLM_MODEL
+    payload["stream"] = True
+
+    base_url = OPENAI_API_BASE_URL_CONFIG or ""
+    if not base_url:
+        raise HTTPException(status_code=500, detail="Upstream base URL is not configured")
+
+    api_key = OPENAI_API_KEY_FROM_CONFIG or ""
+    if not api_key or api_key.startswith("YOUR_API_KEY_PLACEHOLDER"):
+        raise HTTPException(status_code=500, detail="Upstream API key is not configured")
+
+    upstream_url = f"{base_url.rstrip('/')}/chat/completions"
+    headers = {
+        "Authorization": f"Bearer {api_key}",
+        "Content-Type": "application/json",
+        "Accept": "text/event-stream",
+    }
+
+    client = httpx.AsyncClient(timeout=None)
+    req = client.build_request("POST", upstream_url, headers=headers, json=payload)
+    resp = await client.send(req, stream=True)
+
+    if resp.status_code != 200:
+        detail_bytes = await resp.aread()
+        await resp.aclose()
+        await client.aclose()
+        detail = detail_bytes.decode("utf-8", "ignore") if detail_bytes else "Upstream error"
+        raise HTTPException(status_code=resp.status_code, detail=detail)
+
+    async def event_stream():
+        try:
+            async for chunk in resp.aiter_raw():
+                if chunk:
+                    yield chunk
+        finally:
+            await resp.aclose()
+            await client.aclose()
+
+    return StreamingResponse(
+        event_stream(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "X-Accel-Buffering": "no",
+        },
+    )
+
+
 if __name__ == "__main__":
     uvicorn.run(
         "main:app",
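For reference, a minimal client-side sketch of how the new /chat/stream endpoint could be consumed. This is not part of the commit: the base URL and port are assumptions for local testing, httpx is reused because the server side already depends on it, and the parsing assumes the upstream keeps the usual OpenAI chat-completions SSE framing ("data: {json}" frames ending with "data: [DONE]"), which the endpoint forwards untouched.

import asyncio
import json

import httpx


async def main() -> None:
    # Either OpenAI-style "messages" or a single "message" field is accepted by the endpoint.
    payload = {"messages": [{"role": "user", "content": "Hello!"}]}

    async with httpx.AsyncClient(timeout=None) as client:
        # Hypothetical local address; adjust to wherever api/main.py is actually served.
        async with client.stream("POST", "http://localhost:8000/chat/stream", json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                # The proxy passes raw SSE frames through, e.g. 'data: {...}' and 'data: [DONE]'.
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):].strip()
                if data == "[DONE]":
                    break
                choices = json.loads(data).get("choices") or []
                delta = choices[0].get("delta", {}).get("content") if choices else None
                if delta:
                    print(delta, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())

The Cache-Control: no-cache and X-Accel-Buffering: no headers set by the endpoint are there so that reverse proxies such as nginx do not buffer the stream; a client only needs to keep the connection open and read lines as they arrive.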