from typing import Any, Dict, Optional
import threading
import sys
import os

from haystack import Document
from haystack.components.embedders import OpenAIDocumentEmbedder
from haystack.utils import Secret

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import (
    DEFAULT_USER_ID,
    OPENAI_EMBEDDING_KEY,
    OPENAI_EMBEDDING_MODEL,
    OPENAI_EMBEDDING_BASE,
)
from haystack_rag.rag_pipeline import build_rag_pipeline
from doubao_tts import text_to_speech


class ChatService:
    def __init__(self, user_id: Optional[str] = None):
        self.user_id = user_id or DEFAULT_USER_ID
        self.rag_pipeline = None
        self.document_store = None
        self.document_embedder = None
        self._initialized = False

    def initialize(self):
        """Initialize the RAG pipeline and related components."""
        if self._initialized:
            return

        # Build the RAG query pipeline and obtain the DocumentStore instance.
        self.rag_pipeline, self.document_store = build_rag_pipeline(self.user_id)

        # Initialize the Document Embedder used to write user inputs.
        self.document_embedder = OpenAIDocumentEmbedder(
            api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
            model=OPENAI_EMBEDDING_MODEL,
            api_base_url=OPENAI_EMBEDDING_BASE,
        )

        self._initialized = True

    def _embed_and_store_async(self, user_input: str):
        """Embed the user input and store it asynchronously."""
        try:
            # Embed the user input and write it to Milvus.
            user_doc_to_write = Document(content=user_input, meta={"user_id": self.user_id})

            # Run the embedding with OpenAIDocumentEmbedder.
            embedding_result = self.document_embedder.run([user_doc_to_write])
            embedded_docs = embedding_result.get("documents", [])

            if embedded_docs:
                # Write the embedded documents to the DocumentStore.
                self.document_store.write_documents(embedded_docs)
                print(f"[INFO] User input embedded and stored: {user_input[:50]}...")
            else:
                print(f"[WARNING] Failed to embed user input: {user_input[:50]}...")
        except Exception as e:
            print(f"[ERROR] Asynchronous embed-and-store failed: {e}")

    def chat(self, user_input: str, include_audio: bool = True) -> Dict[str, Any]:
        """Process user input and return a reply (optionally with audio)."""
        if not self._initialized:
            self.initialize()

        try:
            # Step 1: Kick off embedding and storage in the background
            # (does not block the main flow).
            embedding_thread = threading.Thread(
                target=self._embed_and_store_async,
                args=(user_input,),
                daemon=True,
            )
            embedding_thread.start()

            # Step 2: Generate a reply with the RAG query pipeline immediately
            # (without waiting for the embedding to finish).
            pipeline_input = {
                "text_embedder": {"text": user_input},
                "prompt_builder": {"query": user_input},
            }

            # Run the RAG query pipeline.
            results = self.rag_pipeline.run(pipeline_input)

            # Step 3: Process and return the results.
            if "llm" in results and results["llm"]["replies"]:
                answer = results["llm"]["replies"][0]

                # Try to read the token usage.
                total_tokens = None
                try:
                    if (
                        "meta" in results["llm"]
                        and isinstance(results["llm"]["meta"], list)
                        and results["llm"]["meta"]
                    ):
                        usage_info = results["llm"]["meta"][0].get("usage", {})
                        total_tokens = usage_info.get("total_tokens")
                except Exception:
                    pass

                # Step 4: Generate speech if requested.
                audio_data = None
                audio_error = None
                if include_audio:
                    try:
                        success, message, base64_audio = text_to_speech(answer, self.user_id)
                        if success and base64_audio:
                            # Use the base64-encoded audio data directly.
                            audio_data = base64_audio
                        else:
                            audio_error = message
                    except Exception as e:
                        audio_error = f"TTS error: {str(e)}"

                result = {
                    "success": True,
                    "response": answer,
                    "user_id": self.user_id,
                }

                # Add optional fields.
                if total_tokens is not None:
                    result["tokens"] = total_tokens
                if audio_data:
                    result["audio_data"] = audio_data
                if audio_error:
                    result["audio_error"] = audio_error

                return result
            else:
                return {
                    "success": False,
                    "error": "Could not generate an answer",
                    "debug_info": results,
                    "user_id": self.user_id,
                }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "user_id": self.user_id,
            }
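

# Example usage: a minimal sketch, assuming the config module holds valid OpenAI
# embedding credentials and that build_rag_pipeline / text_to_speech are available
# in this environment; the user id and prompt below are purely illustrative.
if __name__ == "__main__":
    service = ChatService(user_id="demo-user")
    service.initialize()
    reply = service.chat("What did we talk about last time?", include_audio=False)
    if reply["success"]:
        print(reply["response"])
    else:
        print(f"Chat failed: {reply['error']}")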