feat(docker): containerize application and add TTS integration
This commit is contained in:
150
api/chat_service.py
Normal file
150
api/chat_service.py
Normal file
@ -0,0 +1,150 @@
|
||||
from typing import Dict, Any, Tuple
|
||||
import base64
|
||||
import threading
|
||||
from haystack import Document, Pipeline
|
||||
from milvus_haystack import MilvusDocumentStore
|
||||
from haystack.components.embedders import OpenAIDocumentEmbedder
|
||||
from haystack.utils import Secret
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from config import (
|
||||
DEFAULT_USER_ID,
|
||||
OPENAI_EMBEDDING_KEY,
|
||||
OPENAI_EMBEDDING_MODEL,
|
||||
OPENAI_EMBEDDING_BASE,
|
||||
)
|
||||
from haystack_rag.rag_pipeline import build_rag_pipeline
|
||||
from doubao_tts import text_to_speech
|
||||
|
||||
|
||||
class ChatService:
|
||||
def __init__(self, user_id: str = None):
|
||||
self.user_id = user_id or DEFAULT_USER_ID
|
||||
self.rag_pipeline = None
|
||||
self.document_store = None
|
||||
self.document_embedder = None
|
||||
self._initialized = False
|
||||
|
||||
def initialize(self):
|
||||
"""初始化 RAG 管道和相关组件"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 构建 RAG 查询管道和获取 DocumentStore 实例
|
||||
self.rag_pipeline, self.document_store = build_rag_pipeline(self.user_id)
|
||||
|
||||
# 初始化用于写入用户输入的 Document Embedder
|
||||
self.document_embedder = OpenAIDocumentEmbedder(
|
||||
api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
|
||||
model=OPENAI_EMBEDDING_MODEL,
|
||||
api_base_url=OPENAI_EMBEDDING_BASE,
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
|
||||
def _embed_and_store_async(self, user_input: str):
|
||||
"""异步嵌入并存储用户输入"""
|
||||
try:
|
||||
# 步骤 1: 嵌入用户输入并写入 Milvus
|
||||
user_doc_to_write = Document(content=user_input, meta={"user_id": self.user_id})
|
||||
|
||||
# 使用 OpenAIDocumentEmbedder 运行嵌入
|
||||
embedding_result = self.document_embedder.run([user_doc_to_write])
|
||||
embedded_docs = embedding_result.get("documents", [])
|
||||
|
||||
if embedded_docs:
|
||||
# 将带有嵌入的文档写入 DocumentStore
|
||||
self.document_store.write_documents(embedded_docs)
|
||||
print(f"[INFO] 用户输入已成功嵌入并存储: {user_input[:50]}...")
|
||||
else:
|
||||
print(f"[WARNING] 用户输入嵌入失败: {user_input[:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 异步嵌入和存储过程出错: {e}")
|
||||
|
||||
def chat(self, user_input: str, include_audio: bool = True) -> Dict[str, Any]:
|
||||
"""处理用户输入并返回回复(包含音频)"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
|
||||
try:
|
||||
# 步骤 1: 异步启动嵌入和存储过程(不阻塞主流程)
|
||||
embedding_thread = threading.Thread(
|
||||
target=self._embed_and_store_async,
|
||||
args=(user_input,),
|
||||
daemon=True
|
||||
)
|
||||
embedding_thread.start()
|
||||
|
||||
# 步骤 2: 立即使用 RAG 查询管道生成回复(不等待嵌入完成)
|
||||
pipeline_input = {
|
||||
"text_embedder": {"text": user_input},
|
||||
"prompt_builder": {"query": user_input},
|
||||
}
|
||||
|
||||
# 运行 RAG 查询管道
|
||||
results = self.rag_pipeline.run(pipeline_input)
|
||||
|
||||
# 步骤 3: 处理并返回结果
|
||||
if "llm" in results and results["llm"]["replies"]:
|
||||
answer = results["llm"]["replies"][0]
|
||||
|
||||
# 尝试获取 token 使用量
|
||||
total_tokens = None
|
||||
try:
|
||||
if (
|
||||
"meta" in results["llm"]
|
||||
and isinstance(results["llm"]["meta"], list)
|
||||
and results["llm"]["meta"]
|
||||
):
|
||||
usage_info = results["llm"]["meta"][0].get("usage", {})
|
||||
total_tokens = usage_info.get("total_tokens")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 步骤 4: 生成语音(如果需要)
|
||||
audio_data = None
|
||||
audio_error = None
|
||||
if include_audio:
|
||||
try:
|
||||
success, message, base64_audio = text_to_speech(answer, self.user_id)
|
||||
if success and base64_audio:
|
||||
# 直接使用 base64 音频数据
|
||||
audio_data = base64_audio
|
||||
else:
|
||||
audio_error = message
|
||||
except Exception as e:
|
||||
audio_error = f"TTS错误: {str(e)}"
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"response": answer,
|
||||
"user_id": self.user_id
|
||||
}
|
||||
|
||||
# 添加可选字段
|
||||
if total_tokens is not None:
|
||||
result["tokens"] = total_tokens
|
||||
if audio_data:
|
||||
result["audio_data"] = audio_data
|
||||
if audio_error:
|
||||
result["audio_error"] = audio_error
|
||||
|
||||
return result
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Could not generate an answer",
|
||||
"debug_info": results,
|
||||
"user_id": self.user_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"user_id": self.user_id
|
||||
}
|
173
api/doubao_tts.py
Normal file
173
api/doubao_tts.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""
|
||||
豆包 TTS (Text-to-Speech) 服务模块
|
||||
基于火山引擎豆包语音合成 API
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import requests
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from io import BytesIO
|
||||
|
||||
from config import (
|
||||
DOUBAO_TTS_API_URL,
|
||||
DOUBAO_TTS_APP_ID,
|
||||
DOUBAO_TTS_ACCESS_KEY,
|
||||
DOUBAO_TTS_RESOURCE_ID,
|
||||
DOUBAO_TTS_SPEAKER,
|
||||
DOUBAO_TTS_FORMAT,
|
||||
DOUBAO_TTS_SAMPLE_RATE,
|
||||
)
|
||||
|
||||
|
||||
class DoubaoTTS:
|
||||
"""豆包 TTS 服务类,支持连接复用"""
|
||||
|
||||
def __init__(self):
|
||||
# 使用 requests.Session 进行连接复用
|
||||
self.session = requests.Session()
|
||||
self.api_url = DOUBAO_TTS_API_URL
|
||||
|
||||
def _prepare_headers(self) -> Dict[str, str]:
|
||||
"""准备请求头"""
|
||||
return {
|
||||
"X-Api-App-Id": DOUBAO_TTS_APP_ID,
|
||||
"X-Api-Access-Key": DOUBAO_TTS_ACCESS_KEY,
|
||||
"X-Api-Resource-Id": DOUBAO_TTS_RESOURCE_ID,
|
||||
"X-Api-Request-Id": str(uuid.uuid4()),
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
def _prepare_payload(self, text: str, user_id: str = "default") -> Dict[str, Any]:
|
||||
"""准备请求负载"""
|
||||
return {
|
||||
"user": {
|
||||
"uid": user_id
|
||||
},
|
||||
"req_params": {
|
||||
"text": text,
|
||||
"speaker": DOUBAO_TTS_SPEAKER,
|
||||
"audio_params": {
|
||||
"format": DOUBAO_TTS_FORMAT,
|
||||
"sample_rate": DOUBAO_TTS_SAMPLE_RATE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def text_to_speech(self, text: str, user_id: str = "default") -> Tuple[bool, str, Optional[str]]:
|
||||
"""
|
||||
将文本转换为语音
|
||||
|
||||
Args:
|
||||
text: 要转换的文本
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, Optional[str]]: (成功状态, 消息, base64音频数据)
|
||||
"""
|
||||
try:
|
||||
headers = self._prepare_headers()
|
||||
payload = self._prepare_payload(text, user_id)
|
||||
|
||||
# 发送流式请求
|
||||
response = self.session.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
return False, f"HTTP错误: {response.status_code}", None
|
||||
|
||||
# 收集音频数据
|
||||
audio_base64_chunks = []
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
try:
|
||||
# 解析 JSON 响应
|
||||
json_data = json.loads(line.decode('utf-8'))
|
||||
|
||||
# 检查错误
|
||||
if json_data.get("code", 0) != 0:
|
||||
# 检查是否是结束标识
|
||||
if json_data.get("code") == 20000000:
|
||||
break
|
||||
else:
|
||||
error_msg = json_data.get("message", "未知错误")
|
||||
return False, f"API错误: {error_msg} (code: {json_data.get('code')})", None
|
||||
|
||||
# 提取音频数据(直接保存 base64 格式)
|
||||
if "data" in json_data and json_data["data"]:
|
||||
audio_base64_chunks.append(json_data["data"])
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
continue # 跳过非 JSON 行
|
||||
except Exception as e:
|
||||
return False, f"处理响应时出错: {str(e)}", None
|
||||
|
||||
if not audio_base64_chunks:
|
||||
return False, "没有接收到音频数据", None
|
||||
|
||||
# 合并所有 base64 音频块
|
||||
complete_audio_base64 = ''.join(audio_base64_chunks)
|
||||
|
||||
return True, "转换成功", complete_audio_base64
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
return False, "请求超时", None
|
||||
except requests.exceptions.ConnectionError:
|
||||
return False, "连接错误", None
|
||||
except Exception as e:
|
||||
return False, f"未知错误: {str(e)}", None
|
||||
|
||||
def save_audio_to_file(self, audio_data: bytes, filename: str) -> bool:
|
||||
"""
|
||||
将音频数据保存到文件
|
||||
|
||||
Args:
|
||||
audio_data: 音频二进制数据
|
||||
filename: 保存的文件名
|
||||
|
||||
Returns:
|
||||
bool: 保存是否成功
|
||||
"""
|
||||
try:
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(audio_data)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"保存音频文件失败: {e}")
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
"""关闭会话连接"""
|
||||
if self.session:
|
||||
self.session.close()
|
||||
|
||||
|
||||
# 全局 TTS 实例(单例模式)
|
||||
_tts_instance = None
|
||||
|
||||
def get_tts_instance() -> DoubaoTTS:
|
||||
"""获取 TTS 实例(单例)"""
|
||||
global _tts_instance
|
||||
if _tts_instance is None:
|
||||
_tts_instance = DoubaoTTS()
|
||||
return _tts_instance
|
||||
|
||||
def text_to_speech(text: str, user_id: str = "default") -> Tuple[bool, str, Optional[str]]:
|
||||
"""
|
||||
便捷函数:将文本转换为语音
|
||||
|
||||
Args:
|
||||
text: 要转换的文本
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, Optional[str]]: (成功状态, 消息, base64音频数据)
|
||||
"""
|
||||
tts = get_tts_instance()
|
||||
return tts.text_to_speech(text, user_id)
|
83
api/main.py
Normal file
83
api/main.py
Normal file
@ -0,0 +1,83 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import uvicorn
|
||||
|
||||
from chat_service import ChatService
|
||||
|
||||
|
||||
# 请求和响应模型
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
user_id: Optional[str] = None
|
||||
include_audio: Optional[bool] = True
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
success: bool
|
||||
response: Optional[str] = None
|
||||
tokens: Optional[int] = None
|
||||
user_id: str
|
||||
error: Optional[str] = None
|
||||
audio_data: Optional[str] = None # base64 编码的音频数据
|
||||
audio_error: Optional[str] = None
|
||||
|
||||
|
||||
# 创建 FastAPI 应用
|
||||
app = FastAPI(
|
||||
title="Haystack RAG API",
|
||||
description="基于 Haystack 的 RAG 聊天服务 API",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# 全局聊天服务实例
|
||||
chat_service = ChatService()
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""应用启动时初始化聊天服务"""
|
||||
chat_service.initialize()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""根路径,返回 API 信息"""
|
||||
return {"message": "Haystack RAG API is running", "version": "1.0.0"}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查端点"""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
@app.post("/chat", response_model=ChatResponse)
|
||||
async def chat_endpoint(request: ChatRequest):
|
||||
"""
|
||||
聊天接口
|
||||
|
||||
接收用户消息,通过 RAG 管道处理并返回回复(可包含 base64 音频数据)
|
||||
"""
|
||||
try:
|
||||
# 如果请求中指定了用户ID,创建新的服务实例
|
||||
if request.user_id and request.user_id != chat_service.user_id:
|
||||
user_chat_service = ChatService(request.user_id)
|
||||
user_chat_service.initialize()
|
||||
result = user_chat_service.chat(request.message, request.include_audio)
|
||||
else:
|
||||
result = chat_service.chat(request.message, request.include_audio)
|
||||
|
||||
return ChatResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
reload=True
|
||||
)
|
187
api/test_doubao_tts.py
Normal file
187
api/test_doubao_tts.py
Normal file
@ -0,0 +1,187 @@
|
||||
"""
|
||||
豆包 TTS 服务单元测试
|
||||
"""
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
from doubao_tts import DoubaoTTS, get_tts_instance, text_to_speech
|
||||
|
||||
|
||||
class TestDoubaoTTS(unittest.TestCase):
|
||||
"""豆包 TTS 服务测试类"""
|
||||
|
||||
def setUp(self):
|
||||
"""测试前置设置"""
|
||||
self.tts = DoubaoTTS()
|
||||
|
||||
def tearDown(self):
|
||||
"""测试后置清理"""
|
||||
self.tts.close()
|
||||
|
||||
def test_prepare_headers(self):
|
||||
"""测试请求头准备"""
|
||||
headers = self.tts._prepare_headers()
|
||||
|
||||
# 检查必要的头部字段
|
||||
required_headers = [
|
||||
"X-Api-App-Id",
|
||||
"X-Api-Access-Key",
|
||||
"X-Api-Resource-Id",
|
||||
"X-Api-Request-Id",
|
||||
"Content-Type"
|
||||
]
|
||||
|
||||
for header in required_headers:
|
||||
self.assertIn(header, headers)
|
||||
|
||||
self.assertEqual(headers["Content-Type"], "application/json")
|
||||
|
||||
def test_prepare_payload(self):
|
||||
"""测试请求负载准备"""
|
||||
test_text = "测试文本"
|
||||
test_user_id = "test_user"
|
||||
|
||||
payload = self.tts._prepare_payload(test_text, test_user_id)
|
||||
|
||||
# 检查负载结构
|
||||
self.assertIn("user", payload)
|
||||
self.assertIn("req_params", payload)
|
||||
self.assertEqual(payload["user"]["uid"], test_user_id)
|
||||
self.assertEqual(payload["req_params"]["text"], test_text)
|
||||
self.assertIn("speaker", payload["req_params"])
|
||||
self.assertIn("audio_params", payload["req_params"])
|
||||
|
||||
@patch('doubao_tts.requests.Session.post')
|
||||
def test_text_to_speech_success(self, mock_post):
|
||||
"""测试文本转语音成功场景"""
|
||||
# 模拟成功的 API 响应
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.iter_lines.return_value = [
|
||||
b'{"code": 0, "data": "dGVzdCBhdWRpbyBkYXRh"}', # base64 encoded "test audio data"
|
||||
b'{"code": 20000000, "message": "ok", "data": null}'
|
||||
]
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
success, message, audio_data = self.tts.text_to_speech("测试文本")
|
||||
|
||||
self.assertTrue(success)
|
||||
self.assertEqual(message, "转换成功")
|
||||
self.assertIsNotNone(audio_data)
|
||||
self.assertEqual(audio_data, b"test audio data")
|
||||
|
||||
@patch('doubao_tts.requests.Session.post')
|
||||
def test_text_to_speech_api_error(self, mock_post):
|
||||
"""测试 API 错误场景"""
|
||||
# 模拟 API 错误响应
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.iter_lines.return_value = [
|
||||
b'{"code": 40402003, "message": "TTSExceededTextLimit:exceed max limit"}'
|
||||
]
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
success, message, audio_data = self.tts.text_to_speech("测试文本")
|
||||
|
||||
self.assertFalse(success)
|
||||
self.assertIn("API错误", message)
|
||||
self.assertIsNone(audio_data)
|
||||
|
||||
@patch('doubao_tts.requests.Session.post')
|
||||
def test_text_to_speech_http_error(self, mock_post):
|
||||
"""测试 HTTP 错误场景"""
|
||||
# 模拟 HTTP 错误
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
success, message, audio_data = self.tts.text_to_speech("测试文本")
|
||||
|
||||
self.assertFalse(success)
|
||||
self.assertIn("HTTP错误", message)
|
||||
self.assertIsNone(audio_data)
|
||||
|
||||
def test_save_audio_to_file(self):
|
||||
"""测试音频文件保存"""
|
||||
test_audio_data = b"test audio data"
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as tmp_file:
|
||||
tmp_filename = tmp_file.name
|
||||
|
||||
try:
|
||||
# 测试保存文件
|
||||
success = self.tts.save_audio_to_file(test_audio_data, tmp_filename)
|
||||
self.assertTrue(success)
|
||||
|
||||
# 验证文件内容
|
||||
with open(tmp_filename, 'rb') as f:
|
||||
saved_data = f.read()
|
||||
self.assertEqual(saved_data, test_audio_data)
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(tmp_filename):
|
||||
os.unlink(tmp_filename)
|
||||
|
||||
def test_singleton_instance(self):
|
||||
"""测试单例模式"""
|
||||
instance1 = get_tts_instance()
|
||||
instance2 = get_tts_instance()
|
||||
|
||||
self.assertIs(instance1, instance2)
|
||||
|
||||
@patch('doubao_tts.get_tts_instance')
|
||||
def test_text_to_speech_function(self, mock_get_instance):
|
||||
"""测试便捷函数"""
|
||||
# 模拟 TTS 实例
|
||||
mock_tts = MagicMock()
|
||||
mock_tts.text_to_speech.return_value = (True, "成功", b"audio_data")
|
||||
mock_get_instance.return_value = mock_tts
|
||||
|
||||
success, message, audio_data = text_to_speech("测试文本", "test_user")
|
||||
|
||||
self.assertTrue(success)
|
||||
self.assertEqual(message, "成功")
|
||||
self.assertEqual(audio_data, b"audio_data")
|
||||
mock_tts.text_to_speech.assert_called_once_with("测试文本", "test_user")
|
||||
|
||||
|
||||
class TestDoubaoTTSIntegration(unittest.TestCase):
|
||||
"""豆包 TTS 集成测试(需要真实的 API 密钥)"""
|
||||
|
||||
def setUp(self):
|
||||
"""检查是否有有效的配置"""
|
||||
from config import DOUBAO_TTS_APP_ID, DOUBAO_TTS_ACCESS_KEY
|
||||
|
||||
# 如果没有有效配置,跳过集成测试
|
||||
if (not DOUBAO_TTS_APP_ID or DOUBAO_TTS_APP_ID == "YOUR_APP_ID" or
|
||||
not DOUBAO_TTS_ACCESS_KEY or DOUBAO_TTS_ACCESS_KEY == "YOUR_ACCESS_KEY"):
|
||||
self.skipTest("需要有效的豆包 TTS API 配置才能运行集成测试")
|
||||
|
||||
self.tts = DoubaoTTS()
|
||||
|
||||
def tearDown(self):
|
||||
"""测试后置清理"""
|
||||
if hasattr(self, 'tts'):
|
||||
self.tts.close()
|
||||
|
||||
def test_real_tts_request(self):
|
||||
"""真实的 TTS 请求测试"""
|
||||
test_text = "你好,这是豆包语音合成测试。"
|
||||
|
||||
success, message, audio_data = self.tts.text_to_speech(test_text, "test_user")
|
||||
|
||||
if success:
|
||||
self.assertIsNotNone(audio_data)
|
||||
self.assertGreater(len(audio_data), 0)
|
||||
print(f"TTS 测试成功: {message}")
|
||||
print(f"音频数据大小: {len(audio_data)} bytes")
|
||||
else:
|
||||
print(f"TTS 测试失败: {message}")
|
||||
# 集成测试失败时不强制断言失败,因为可能是网络或配置问题
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 运行测试
|
||||
unittest.main(verbosity=2)
|
Reference in New Issue
Block a user