# embedding.py
from haystack.components.embedders import OpenAITextEmbedder
|
|
|
|
from haystack.utils import Secret
|
|
|
|
# 从 config 导入新的变量名
|
|
from config import (
|
|
OPENAI_EMBEDDING_MODEL,
|
|
OPENAI_API_KEY_FROM_CONFIG, # 使用配置中的 Key
|
|
OPENAI_API_BASE_URL_CONFIG, # 使用配置中的 Base URL
|
|
OPENAI_EMBEDDING_KEY,
|
|
OPENAI_EMBEDDING_BASE,
|
|
HUGGINGFACE_KEY,
|
|
HUGGINGFACE_EMBEDDING_MODEL,
|
|
OPENAI_EMBEDDING_DIM
|
|
)
|
|
|
|
|
|
def initialize_text_embedder() -> OpenAITextEmbedder:
    """
    Initialize the Haystack OpenAITextEmbedder component.

    The API key, base URL, model name and output dimension all come from the
    embedding-specific settings imported from config.py:
    ``OPENAI_EMBEDDING_KEY``, ``OPENAI_EMBEDDING_BASE``,
    ``OPENAI_EMBEDDING_MODEL`` and ``OPENAI_EMBEDDING_DIM``.

    A warning is printed (not raised) when the key is empty or still contains
    the ``YOUR_API_KEY`` placeholder.

    Returns:
        A configured ``OpenAITextEmbedder`` instance.
    """
    # Validate the key that is actually handed to the embedder below. An
    # earlier version checked OPENAI_API_KEY_FROM_CONFIG, which is a different
    # setting and so never caught a missing embedding key.
    if not OPENAI_EMBEDDING_KEY or "YOUR_API_KEY" in OPENAI_EMBEDDING_KEY:
        print("警告: OpenAI API Key 未在 config.py 中有效配置。")
        # TODO: raise ValueError here instead if the key is mandatory.
        # NOTE(review): Secret.from_token is expected to reject an empty
        # token, so an unset key will likely still fail below after this warning.

    print(f"Initializing OpenAI Text Embedder with model: {OPENAI_EMBEDDING_MODEL}")

    # Log the base URL that is really passed as api_base_url below.
    if OPENAI_EMBEDDING_BASE:
        print(f"Using custom API base URL from config: {OPENAI_EMBEDDING_BASE}")
    else:
        print("Using default OpenAI API base URL (None specified in config).")

    text_embedder = OpenAITextEmbedder(
        # Key and base URL come straight from config.py.
        api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
        api_base_url=OPENAI_EMBEDDING_BASE,
        model=OPENAI_EMBEDDING_MODEL,
        dimensions=OPENAI_EMBEDDING_DIM,
    )

    print("Text Embedder initialized.")
    return text_embedder


# Example usage: build the embedder from config.py settings and embed one
# sample sentence, reporting the embedding size and token usage.
if __name__ == "__main__":
    embedder = initialize_text_embedder()

    sample_text = "这是一个示例文本,用于测试嵌入功能。"

    try:
        # Keep the try body to the remote call only, so bugs in the reporting
        # code below are not misreported as API failures.
        result = embedder.run(text=sample_text)
    except Exception as e:  # top-level script boundary: report and stop
        print(f"Error during OpenAI embedding API call: {e}")
    else:
        embedding = result["embedding"]
        print(f"Sample text: '{sample_text}'")
        print(f"Generated embedding dimension: {len(embedding)}")

        # Usage metadata depends on what the provider returns; avoid a
        # KeyError if it is missing.
        meta = result.get("meta") or {}
        usage = meta.get("usage") or {}
        print(f"Tokens used: {usage.get('total_tokens', 'unknown')}")