# main.py
import logging

from haystack import Document
# OpenAIDocumentEmbedder is needed to embed the documents we write to the store
from haystack.components.embedders import OpenAIDocumentEmbedder
from haystack.utils import Secret

# Set up the module logger
logger = logging.getLogger(__name__)

# Import the required configuration values and the pipeline builder
from config import (
    DEFAULT_USER_ID,
    OPENAI_API_KEY_FROM_CONFIG,
    OPENAI_API_BASE_URL_CONFIG,
    OPENAI_EMBEDDING_MODEL,
    OPENAI_EMBEDDING_KEY,
    OPENAI_EMBEDDING_BASE,
)
from rag_pipeline import build_rag_pipeline  # builds the RAG query pipeline


# Helper: initialize the Document Embedder (similar to the one in embedding.py)
def initialize_document_embedder() -> OpenAIDocumentEmbedder:
    """Initialize the OpenAIDocumentEmbedder used to embed documents before writing."""
    if not OPENAI_API_KEY_FROM_CONFIG or "YOUR_API_KEY" in OPENAI_API_KEY_FROM_CONFIG:
        print("Warning: OpenAI API key is not properly configured in config.py.")
        # raise ValueError("OpenAI API Key not configured correctly in config.py")

    print(f"Initializing OpenAI Document Embedder with model: {OPENAI_EMBEDDING_MODEL}")
    if OPENAI_API_BASE_URL_CONFIG:
        print(f"Using custom API base URL from config: {OPENAI_API_BASE_URL_CONFIG}")
    else:
        print("Using default OpenAI API base URL (None specified in config).")

    document_embedder = OpenAIDocumentEmbedder(
        api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
        model=OPENAI_EMBEDDING_MODEL,
        api_base_url=OPENAI_EMBEDDING_BASE,
        # meta_fields_to_embed=["name"],  # embed metadata fields if needed
        # embedding_batch_size=10,        # batch size can be tuned
    )
    print("OpenAI Document Embedder initialized.")
    return document_embedder


def run_chat_session(user_id: str):
    """
    Run the main RAG chat session loop.

    On every user input, the text is first embedded and written to Milvus,
    then the RAG query pipeline is run to generate a reply.
    """
    print(f"--- Starting Chat Session for User: {user_id} ---")

    # Build the RAG query pipeline and get the DocumentStore instance
    rag_query_pipeline, document_store = build_rag_pipeline(user_id)

    # Initialize the Document Embedder used to write user inputs
    document_embedder = initialize_document_embedder()

    print("\nChatbot is ready! Type your questions or 'exit' to quit.")

    # Optionally print information about the models in use
    try:
        pass
        # print(f"Using LLM: {rag_query_pipeline.get_component('llm').model}")
        # Note: the query embedder in the RAG pipeline is named 'text_embedder'
        # print(f"Using Query Embedder: {rag_query_pipeline.get_component('text_embedder').model}")
        # print(f"Using Document Embedder (for writing): {document_embedder.model}")
    except Exception as e:
        print(f"Warning: Could not get component model names - {e}")

    while True:
        try:
            query = input(f"[{user_id}] You: ")
            if query.lower() == "exit":
                print("Exiting chat session. Goodbye!")
                break
            if not query.strip():
                continue

            # --- Step 1: embed the user input and write it to Milvus ---
            # print("[Workflow] Embedding user input as a document...")
            # Wrap the user input in a Haystack Document
            user_doc_to_write = Document(content=query, meta={"user_id": user_id})

            # Run the embedding with OpenAIDocumentEmbedder.
            # It expects a list as input, even for a single document.
            embedding_result = document_embedder.run([user_doc_to_write])
            embedded_docs = embedding_result.get("documents", [])  # documents with embeddings attached

            if embedded_docs:
                # print(f"[Workflow] Writing embedded document to Milvus for user {user_id}...")
                # Write the embedded documents to the DocumentStore
                document_store.write_documents(embedded_docs)
                # print("[Workflow] Document written to Milvus.")
            else:
                print("[Workflow] Warning: Failed to embed document, skipping write.")
                # Additional error handling or logging could go here

            # --- Step 2: generate a reply with the RAG query pipeline ---
            # print("[Workflow] Running RAG query pipeline...")
            # Prepare the pipeline inputs:
            # - text_embedder needs the raw query text
            # - prompt_builder also needs the raw query text (used as {{query}} in the template)
            pipeline_input = {
                "text_embedder": {"text": query},
                "prompt_builder": {"query": query},
            }

            # Run the RAG query pipeline
            results = rag_query_pipeline.run(pipeline_input)

            # --- Step 3: process and print the result ---
            # The generator component (named 'llm' in this pipeline) returns its output
            # under the 'replies' key.
            if "llm" in results and results["llm"]["replies"]:
                answer = results["llm"]["replies"][0]
                # Try to read the token usage (may be present in the meta)
                total_tokens = "N/A"
                try:
                    # The meta structure can vary by version or configuration
                    if (
                        "meta" in results["llm"]
                        and isinstance(results["llm"]["meta"], list)
                        and results["llm"]["meta"]
                    ):
                        usage_info = results["llm"]["meta"][0].get("usage", {})
                        total_tokens = usage_info.get("total_tokens", "N/A")
                except Exception:
                    pass  # ignore errors while reading token usage

                print(f"Chatbot: {answer} (Tokens: {total_tokens})")
            else:
                print("Chatbot: Sorry, I couldn't generate an answer for that.")
                logger.debug(f"Pipeline Results: {results}")  # log full results for debugging

        except KeyboardInterrupt:
            print("\nExiting chat session. Goodbye!")
            break
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            import traceback

            traceback.print_exc()  # print the full traceback


if __name__ == "__main__":
    current_user_id = DEFAULT_USER_ID
    run_chat_session(current_user_id)