Files
rag_chat/api/chat_service.py

151 lines
5.6 KiB
Python

from typing import Dict, Any, Tuple
import base64
import threading
from haystack import Document, Pipeline
from milvus_haystack import MilvusDocumentStore
from haystack.components.embedders import OpenAIDocumentEmbedder
from haystack.utils import Secret
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import (
DEFAULT_USER_ID,
OPENAI_EMBEDDING_KEY,
OPENAI_EMBEDDING_MODEL,
OPENAI_EMBEDDING_BASE,
)
from haystack_rag.rag_pipeline import build_rag_pipeline
from doubao_tts import text_to_speech
class ChatService:
    """Chat front-end that answers a user's message through a RAG pipeline.

    For every message the service (1) hands the raw input to a background
    thread that embeds it and writes it to the document store, (2) runs the
    RAG query pipeline to produce a reply, and (3) optionally converts the
    reply to speech via ``text_to_speech``.

    Heavy components are built lazily by :meth:`initialize`, which
    :meth:`chat` calls on first use.
    """

    def __init__(self, user_id: str = None):
        """Create an uninitialized service bound to one user.

        Args:
            user_id: Identifier stored in document metadata and passed to the
                pipeline builder and to TTS. ``DEFAULT_USER_ID`` is used when
                it is omitted or falsy.
        """
        self.user_id = user_id or DEFAULT_USER_ID
        # Populated by initialize(); None until then.
        self.rag_pipeline = None
        self.document_store = None
        self.document_embedder = None
        self._initialized = False

    def initialize(self):
        """Build the RAG pipeline and the document embedder (runs once).

        Subsequent calls are no-ops once ``_initialized`` is set.
        """
        if self._initialized:
            return

        # The builder returns both the query pipeline and the document store
        # that user inputs are written to.
        self.rag_pipeline, self.document_store = build_rag_pipeline(self.user_id)

        # Separate embedder used only for writing user inputs to the store.
        self.document_embedder = OpenAIDocumentEmbedder(
            api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
            model=OPENAI_EMBEDDING_MODEL,
            api_base_url=OPENAI_EMBEDDING_BASE,
        )

        self._initialized = True

    def _embed_and_store_async(self, user_input: str):
        """Embed *user_input* and write it to the document store.

        Intended to run on a background thread started by :meth:`chat`.
        Every exception is caught and printed so a storage failure never
        affects the reply already being generated.

        Args:
            user_input: Raw user message to persist, tagged with
                ``self.user_id`` in the document metadata.
        """
        try:
            doc = Document(content=user_input, meta={"user_id": self.user_id})

            # The embedder returns a mapping; embedded docs are under "documents".
            embedded_docs = self.document_embedder.run([doc]).get("documents", [])

            if embedded_docs:
                self.document_store.write_documents(embedded_docs)
                print(f"[INFO] 用户输入已成功嵌入并存储: {user_input[:50]}...")
            else:
                print(f"[WARNING] 用户输入嵌入失败: {user_input[:50]}...")
        except Exception as e:
            # Best-effort background work: report and continue.
            print(f"[ERROR] 异步嵌入和存储过程出错: {e}")

    @staticmethod
    def _extract_total_tokens(llm_output: Dict[str, Any]) -> Any:
        """Return ``meta[0]["usage"]["total_tokens"]`` from *llm_output*, or None."""
        meta = llm_output.get("meta")
        if not isinstance(meta, list) or not meta:
            return None
        try:
            usage = meta[0].get("usage") or {}
            return usage.get("total_tokens")
        except AttributeError:
            # Token usage is optional telemetry; a meta entry (or usage value)
            # that is not mapping-like simply yields no count.
            return None

    def _synthesize_audio(self, answer: str) -> Tuple[Any, Any]:
        """Run TTS on *answer*; return ``(audio_data, audio_error)``."""
        try:
            success, message, base64_audio = text_to_speech(answer, self.user_id)
        except Exception as e:
            # TTS is optional: surface the failure in the response instead of
            # failing the whole chat turn.
            return None, f"TTS错误: {str(e)}"
        if success and base64_audio:
            # The TTS helper's third value is used as-is as the audio payload.
            return base64_audio, None
        return None, message

    def chat(self, user_input: str, include_audio: bool = True) -> Dict[str, Any]:
        """Answer *user_input* and optionally attach synthesized audio.

        Args:
            user_input: The user's message. Must be a non-empty,
                non-whitespace string.
            include_audio: When true, the reply is passed to
                ``text_to_speech`` and the result is attached.

        Returns:
            On success: ``{"success": True, "response": ..., "user_id": ...}``
            plus ``"tokens"`` (when the LLM metadata reports a total),
            ``"audio_data"`` (when TTS succeeded) or ``"audio_error"`` (when
            TTS failed with a message).
            On failure: ``{"success": False, "error": ..., "user_id": ...}``;
            ``"debug_info"`` carries the raw pipeline output when the pipeline
            ran but produced no reply.

        Raises:
            Whatever :meth:`initialize` raises on first use; errors after
            initialization are reported in the returned dict instead.
        """
        # Reject unusable input before starting a thread or calling any
        # remote service with it.
        if not isinstance(user_input, str) or not user_input.strip():
            return {
                "success": False,
                "error": "user_input must be a non-empty string",
                "user_id": self.user_id,
            }

        if not self._initialized:
            self.initialize()

        try:
            # Step 1: persist the user input in the background so the reply
            # is not delayed by embedding and storage.
            threading.Thread(
                target=self._embed_and_store_async,
                args=(user_input,),
                daemon=True,
            ).start()

            # Step 2: run the RAG query pipeline right away, without waiting
            # for the background write to finish.
            results = self.rag_pipeline.run(
                {
                    "text_embedder": {"text": user_input},
                    "prompt_builder": {"query": user_input},
                }
            )

            # Step 3: pull the first reply; a missing "llm"/"replies" entry is
            # treated the same as an empty reply list.
            llm_output = results.get("llm") or {}
            replies = llm_output.get("replies")
            if not replies:
                return {
                    "success": False,
                    "error": "Could not generate an answer",
                    "debug_info": results,
                    "user_id": self.user_id,
                }

            answer = replies[0]
            total_tokens = self._extract_total_tokens(llm_output)

            # Step 4: optional speech synthesis.
            audio_data = None
            audio_error = None
            if include_audio:
                audio_data, audio_error = self._synthesize_audio(answer)

            result = {
                "success": True,
                "response": answer,
                "user_id": self.user_id,
            }
            # Optional fields are only present when they carry a value.
            if total_tokens is not None:
                result["tokens"] = total_tokens
            if audio_data:
                result["audio_data"] = audio_data
            if audio_error:
                result["audio_error"] = audio_error
            return result

        except Exception as e:
            # Service boundary: report the failure to the caller as data.
            return {
                "success": False,
                "error": str(e),
                "user_id": self.user_id,
            }