# rag_pipeline.py
"""Builds the per-user RAG (Retrieval-Augmented Generation) query pipeline."""

from haystack import Document, Pipeline
from haystack.utils import Secret

from milvus_haystack import MilvusDocumentStore

from data_handling import initialize_milvus_lite
from embedding import initialize_text_embedder
from llm_integration import initialize_llm_and_prompt_builder
from retrieval import initialize_vector_retriever
def build_rag_pipeline(user_id: str) -> tuple[Pipeline, MilvusDocumentStore]:
    """Build and return the RAG query pipeline and its DocumentStore for one user.

    Args:
        user_id: Identifier of the user; selects that user's Milvus Lite store.

    Returns:
        A ``(pipeline, document_store)`` tuple. The store is returned alongside
        the pipeline because the caller needs it to write/index documents.
    """
    print(f"\n--- Building RAG Pipeline for User: {user_id} ---")

    # 1. Initialize this user's DocumentStore.
    document_store = initialize_milvus_lite(user_id)

    # 2. Initialize shared components. These could be created once at app
    #    startup; for simplicity they are recreated on every call.
    text_embedder = initialize_text_embedder()
    vector_retriever = initialize_vector_retriever(document_store)
    llm, prompt_builder = initialize_llm_and_prompt_builder()

    # 3. Create the Haystack Pipeline.
    rag_pipeline = Pipeline()

    # 4. Register components under explicit names (referenced by the
    #    connections below and by pipeline.run() input keys).
    rag_pipeline.add_component(instance=text_embedder, name="text_embedder")
    rag_pipeline.add_component(instance=vector_retriever, name="retriever")
    rag_pipeline.add_component(instance=prompt_builder, name="prompt_builder")
    rag_pipeline.add_component(instance=llm, name="llm")

    # 5. Wire the components:
    #    - the query embedding produced by text_embedder feeds the retriever
    #    - retrieved documents feed the prompt builder
    #    - the rendered prompt feeds the LLM
    #    (The raw query text is supplied directly to text_embedder and
    #    prompt_builder via pipeline.run() inputs.)
    rag_pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
    rag_pipeline.connect("retriever.documents", "prompt_builder.documents")
    rag_pipeline.connect("prompt_builder.prompt", "llm.prompt")

    print("--- RAG Pipeline Built Successfully ---")
    # Return both the pipeline and the store: the caller uses the store to
    # index data before querying.
    return rag_pipeline, document_store
# --- Corrected Test Block ---
if __name__ == "__main__":
    # We need OpenAIDocumentEmbedder to index test documents
    from haystack.components.embedders import OpenAIDocumentEmbedder

    # Import necessary config for initializing the Document Embedder
    from config import (
        OPENAI_API_KEY_FROM_CONFIG,
        OPENAI_API_BASE_URL_CONFIG,
        OPENAI_EMBEDDING_MODEL,
    )

    # --- Configuration ---
    test_user = "test_user"
    test_query = "Haystack是什么?"
    # Sample documents to index for testing; each carries the user_id in its
    # metadata so retrieval stays scoped to this test user.
    docs_to_index = [
        Document(
            content="Haystack是一个用于构建 NLP 应用程序(如问答系统、语义搜索)的开源框架。",
            meta={"user_id": test_user, "source": "test_doc_1"},
        ),
        Document(
            content="你可以使用 Haystack 连接不同的组件,如文档存储、检索器和生成器。",
            meta={"user_id": test_user, "source": "test_doc_2"},
        ),
        Document(
            content="Milvus 是一个流行的向量数据库,常用于 RAG 系统中存储嵌入。",
            meta={"user_id": test_user, "source": "test_doc_3"},
        ),
    ]

    print(f"--- Running Test for RAG Pipeline (User: {test_user}) ---")

    # --- 1. Check API Key Availability ---
    # Pipeline execution requires OpenAI API calls, so bail out early when the
    # key is missing or still the placeholder value.
    api_key_configured = (
        OPENAI_API_KEY_FROM_CONFIG and "YOUR_API_KEY" not in OPENAI_API_KEY_FROM_CONFIG
    )
    if not api_key_configured:
        print("\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print("! WARNING: OpenAI API Key not configured in config.py. !")
        print("! Skipping RAG pipeline test execution. !")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        exit()  # Exit script if key is missing for test run
    else:
        print("\n[Test Setup] OpenAI API Key found in config.")

    # --- 2. Build the RAG Pipeline and get the Document Store ---
    # This function initializes the store (potentially dropping old data)
    # and builds the *querying* pipeline.
    try:
        pipeline, store = build_rag_pipeline(test_user)
    except Exception as e:
        print(f"\nError building RAG pipeline: {e}")
        import traceback

        traceback.print_exc()
        exit()

    # --- 3. Index Test Documents (with embeddings) ---
    print("\n[Test Setup] Initializing Document Embedder for indexing test data...")
    try:
        # Initialize the Document Embedder directly here for the test
        document_embedder = OpenAIDocumentEmbedder(
            api_key=Secret.from_token(OPENAI_API_KEY_FROM_CONFIG),
            model=OPENAI_EMBEDDING_MODEL,
            api_base_url=OPENAI_API_BASE_URL_CONFIG,
        )
        print("[Test Setup] Document Embedder initialized.")

        print("[Test Setup] Embedding test documents...")
        embedding_result = document_embedder.run(docs_to_index)
        embedded_docs = embedding_result.get("documents", [])

        if embedded_docs:
            print(
                f"[Test Setup] Writing {len(embedded_docs)} embedded documents to Milvus..."
            )
            store.write_documents(embedded_docs)
            print("[Test Setup] Test documents written successfully.")
            # Optional: Verify count
            # print(f"[Test Setup] Document count in store: {store.count_documents()}")
            documents_indexed = True
        else:
            print("[Test Setup] ERROR: Failed to embed test documents.")
            documents_indexed = False

    except Exception as e:
        print(f"\nError during test data indexing: {e}")
        import traceback

        traceback.print_exc()
        documents_indexed = False

    # --- 4. Run the RAG Pipeline (if setup succeeded) ---
    if documents_indexed:
        print(f"\n[Test Run] Running RAG pipeline for query: '{test_query}'")

        # Prepare input for the RAG pipeline instance built by build_rag_pipeline:
        # the query text goes to both the embedder and the prompt template.
        pipeline_input = {
            "text_embedder": {"text": test_query},  # Input for the query embedder
            "prompt_builder": {
                "query": test_query
            },  # Input for the prompt builder template
        }

        try:
            results = pipeline.run(pipeline_input)

            print("\n[Test Run] Pipeline Results:")
            # Process and print the generator's answer
            if "llm" in results and results["llm"]["replies"]:
                answer = results["llm"]["replies"][0]
                print(f"\nGenerated Answer: {answer}")
            else:
                print("\n[Test Run] Could not extract answer from generator.")
                print(
                    "Full Pipeline Output:", results
                )  # Print full output for debugging

        except Exception as e:
            print(f"\n[Test Run] Error running RAG pipeline: {e}")
            import traceback

            traceback.print_exc()
    else:
        print("\n[Test Run] Skipping RAG pipeline execution due to indexing failure.")

    # --- 5. Cleanup Note ---
    # Optional: Add instructions or commented-out code for cleaning up the test Milvus data
    print(
        f"\n[Test Cleanup] Test finished. Consider manually removing data in: ./milvus_user_data_openai_fixed/{test_user}"
    )
    # import shutil
    # from pathlib import Path
    # from config import MILVUS_PERSIST_BASE_DIR
    # test_db_path = MILVUS_PERSIST_BASE_DIR / test_user
    # if test_db_path.exists():
    #     print(f"\nAttempting to clean up test data at {test_db_path}...")
    #     # shutil.rmtree(test_db_path) # Use with caution

    print("\n--- RAG Pipeline Test Complete ---")