first commit
This commit is contained in:
195
rag_pipeline.py
Normal file
195
rag_pipeline.py
Normal file
@ -0,0 +1,195 @@
|
||||
# rag_pipeline.py
|
||||
from haystack import Pipeline
|
||||
from haystack import Document # 导入 Document
|
||||
|
||||
from milvus_haystack import MilvusDocumentStore
|
||||
from data_handling import initialize_milvus_lite
|
||||
from embedding import initialize_text_embedder
|
||||
from retrieval import initialize_vector_retriever
|
||||
from llm_integration import initialize_llm_and_prompt_builder
|
||||
from haystack.utils import Secret
|
||||
|
||||
|
||||
def build_rag_pipeline(user_id: str) -> tuple[Pipeline, MilvusDocumentStore]:
    """Build and return the RAG query pipeline plus its DocumentStore for one user.

    The DocumentStore is returned alongside the pipeline because the caller
    needs it to write (index) documents before querying.
    """
    print(f"\n--- Building RAG Pipeline for User: {user_id} ---")

    # Per-user Milvus store first, then the processing components.
    # (These could be shared/initialized once at app startup; for simplicity
    # they are created on every call.)
    store = initialize_milvus_lite(user_id)
    embedder = initialize_text_embedder()
    retriever = initialize_vector_retriever(store)
    generator, builder = initialize_llm_and_prompt_builder()

    pipeline = Pipeline()

    # Register each component under the name used for wiring below.
    for component_name, component in (
        ("text_embedder", embedder),
        ("retriever", retriever),
        ("prompt_builder", builder),
        ("llm", generator),
    ):
        pipeline.add_component(instance=component, name=component_name)

    # Wire the data flow:
    #   query text -> embedding -> retrieval -> prompt construction -> LLM.
    # (The query text itself is fed to both text_embedder and prompt_builder
    # at run time via the pipeline input dict.)
    for sender, receiver in (
        ("text_embedder.embedding", "retriever.query_embedding"),
        ("retriever.documents", "prompt_builder.documents"),
        ("prompt_builder.prompt", "llm.prompt"),
    ):
        pipeline.connect(sender, receiver)

    print("--- RAG Pipeline Built Successfully ---")
    # Return the store too: the main program writes documents through it.
    return pipeline, store
|
||||
|
||||
|
||||
# --- Corrected Test Block ---
if __name__ == "__main__":
    # Self-test: index a few sample documents for a test user, then run the
    # full RAG pipeline against them with a sample query.
    # Requires a valid OpenAI API key in config.py; makes real API calls.
    import sys
    import traceback

    # We need OpenAIDocumentEmbedder to index test documents
    from haystack.components.embedders import OpenAIDocumentEmbedder

    # Import necessary config for initializing the Document Embedder
    from config import (
        OPENAI_API_KEY_FROM_CONFIG,
        OPENAI_API_BASE_URL_CONFIG,
        OPENAI_EMBEDDING_MODEL,
    )

    # --- Configuration ---
    test_user = "test_user"
    test_query = "Haystack是什么?"
    # Sample documents to index for testing; meta carries user_id so the
    # data stays scoped to this test user's store.
    docs_to_index = [
        Document(
            content="Haystack是一个用于构建 NLP 应用程序(如问答系统、语义搜索)的开源框架。",
            meta={"user_id": test_user, "source": "test_doc_1"},
        ),
        Document(
            content="你可以使用 Haystack 连接不同的组件,如文档存储、检索器和生成器。",
            meta={"user_id": test_user, "source": "test_doc_2"},
        ),
        Document(
            content="Milvus 是一个流行的向量数据库,常用于 RAG 系统中存储嵌入。",
            meta={"user_id": test_user, "source": "test_doc_3"},
        ),
    ]

    print(f"--- Running Test for RAG Pipeline (User: {test_user}) ---")

    # --- 1. Check API Key Availability ---
    # Pipeline execution requires OpenAI API calls, so bail out early when the
    # key is absent or still the "YOUR_API_KEY" placeholder.
    api_key_configured = (
        OPENAI_API_KEY_FROM_CONFIG and "YOUR_API_KEY" not in OPENAI_API_KEY_FROM_CONFIG
    )
    if not api_key_configured:
        print("\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print("! WARNING: OpenAI API Key not configured in config.py. !")
        print("! Skipping RAG pipeline test execution. !")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        # Skipping is not a failure: exit with status 0.
        # sys.exit() instead of the site-injected exit() builtin, which is
        # not guaranteed to exist (e.g. frozen apps, python -S).
        sys.exit(0)
    else:
        print("\n[Test Setup] OpenAI API Key found in config.")

    # --- 2. Build the RAG Pipeline and get the Document Store ---
    # This function initializes the store (potentially dropping old data)
    # and builds the *querying* pipeline.
    try:
        pipeline, store = build_rag_pipeline(test_user)
    except Exception as e:
        print(f"\nError building RAG pipeline: {e}")
        traceback.print_exc()
        # A build failure is a real error: signal it with a nonzero status.
        sys.exit(1)

    # --- 3. Index Test Documents (with embeddings) ---
    print("\n[Test Setup] Initializing Document Embedder for indexing test data...")
    try:
        # Initialize the Document Embedder directly here for the test
        document_embedder = OpenAIDocumentEmbedder(
            api_key=Secret.from_token(OPENAI_API_KEY_FROM_CONFIG),
            model=OPENAI_EMBEDDING_MODEL,
            api_base_url=OPENAI_API_BASE_URL_CONFIG,
        )
        print("[Test Setup] Document Embedder initialized.")

        print("[Test Setup] Embedding test documents...")
        embedding_result = document_embedder.run(docs_to_index)
        embedded_docs = embedding_result.get("documents", [])

        if embedded_docs:
            print(
                f"[Test Setup] Writing {len(embedded_docs)} embedded documents to Milvus..."
            )
            store.write_documents(embedded_docs)
            print("[Test Setup] Test documents written successfully.")
            # Optional: Verify count
            # print(f"[Test Setup] Document count in store: {store.count_documents()}")
            documents_indexed = True
        else:
            print("[Test Setup] ERROR: Failed to embed test documents.")
            documents_indexed = False

    except Exception as e:
        print(f"\nError during test data indexing: {e}")
        traceback.print_exc()
        documents_indexed = False

    # --- 4. Run the RAG Pipeline (if setup succeeded) ---
    if documents_indexed:
        print(f"\n[Test Run] Running RAG pipeline for query: '{test_query}'")

        # Prepare input for the RAG pipeline instance built by build_rag_pipeline:
        # the query text feeds both the embedder and the prompt template.
        pipeline_input = {
            "text_embedder": {"text": test_query},  # Input for the query embedder
            "prompt_builder": {
                "query": test_query
            },  # Input for the prompt builder template
        }

        try:
            results = pipeline.run(pipeline_input)

            print("\n[Test Run] Pipeline Results:")
            # Process and print the generator's answer
            if "llm" in results and results["llm"]["replies"]:
                answer = results["llm"]["replies"][0]
                print(f"\nGenerated Answer: {answer}")
            else:
                print("\n[Test Run] Could not extract answer from generator.")
                print(
                    "Full Pipeline Output:", results
                )  # Print full output for debugging

        except Exception as e:
            print(f"\n[Test Run] Error running RAG pipeline: {e}")
            traceback.print_exc()
    else:
        print("\n[Test Run] Skipping RAG pipeline execution due to indexing failure.")

    # --- 5. Cleanup Note ---
    # Optional: Add instructions or commented-out code for cleaning up the test Milvus data
    print(
        f"\n[Test Cleanup] Test finished. Consider manually removing data in: ./milvus_user_data_openai_fixed/{test_user}"
    )
    # import shutil
    # from pathlib import Path
    # from config import MILVUS_PERSIST_BASE_DIR
    # test_db_path = MILVUS_PERSIST_BASE_DIR / test_user
    # if test_db_path.exists():
    #     print(f"\nAttempting to clean up test data at {test_db_path}...")
    #     # shutil.rmtree(test_db_path)  # Use with caution

    print("\n--- RAG Pipeline Test Complete ---")
|
Reference in New Issue
Block a user