feat(docker): containerize application and add TTS integration

This commit is contained in:
gameloader
2025-10-07 22:52:04 +08:00
parent 45470fd13d
commit 315dbfed90
19 changed files with 878 additions and 80 deletions

150
api/chat_service.py Normal file
View File

@ -0,0 +1,150 @@
from typing import Dict, Any, Tuple
import base64
import threading
from haystack import Document, Pipeline
from milvus_haystack import MilvusDocumentStore
from haystack.components.embedders import OpenAIDocumentEmbedder
from haystack.utils import Secret
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import (
DEFAULT_USER_ID,
OPENAI_EMBEDDING_KEY,
OPENAI_EMBEDDING_MODEL,
OPENAI_EMBEDDING_BASE,
)
from haystack_rag.rag_pipeline import build_rag_pipeline
from doubao_tts import text_to_speech
class ChatService:
def __init__(self, user_id: str = None):
self.user_id = user_id or DEFAULT_USER_ID
self.rag_pipeline = None
self.document_store = None
self.document_embedder = None
self._initialized = False
def initialize(self):
"""初始化 RAG 管道和相关组件"""
if self._initialized:
return
# 构建 RAG 查询管道和获取 DocumentStore 实例
self.rag_pipeline, self.document_store = build_rag_pipeline(self.user_id)
# 初始化用于写入用户输入的 Document Embedder
self.document_embedder = OpenAIDocumentEmbedder(
api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
model=OPENAI_EMBEDDING_MODEL,
api_base_url=OPENAI_EMBEDDING_BASE,
)
self._initialized = True
def _embed_and_store_async(self, user_input: str):
"""异步嵌入并存储用户输入"""
try:
# 步骤 1: 嵌入用户输入并写入 Milvus
user_doc_to_write = Document(content=user_input, meta={"user_id": self.user_id})
# 使用 OpenAIDocumentEmbedder 运行嵌入
embedding_result = self.document_embedder.run([user_doc_to_write])
embedded_docs = embedding_result.get("documents", [])
if embedded_docs:
# 将带有嵌入的文档写入 DocumentStore
self.document_store.write_documents(embedded_docs)
print(f"[INFO] 用户输入已成功嵌入并存储: {user_input[:50]}...")
else:
print(f"[WARNING] 用户输入嵌入失败: {user_input[:50]}...")
except Exception as e:
print(f"[ERROR] 异步嵌入和存储过程出错: {e}")
def chat(self, user_input: str, include_audio: bool = True) -> Dict[str, Any]:
"""处理用户输入并返回回复(包含音频)"""
if not self._initialized:
self.initialize()
try:
# 步骤 1: 异步启动嵌入和存储过程(不阻塞主流程)
embedding_thread = threading.Thread(
target=self._embed_and_store_async,
args=(user_input,),
daemon=True
)
embedding_thread.start()
# 步骤 2: 立即使用 RAG 查询管道生成回复(不等待嵌入完成)
pipeline_input = {
"text_embedder": {"text": user_input},
"prompt_builder": {"query": user_input},
}
# 运行 RAG 查询管道
results = self.rag_pipeline.run(pipeline_input)
# 步骤 3: 处理并返回结果
if "llm" in results and results["llm"]["replies"]:
answer = results["llm"]["replies"][0]
# 尝试获取 token 使用量
total_tokens = None
try:
if (
"meta" in results["llm"]
and isinstance(results["llm"]["meta"], list)
and results["llm"]["meta"]
):
usage_info = results["llm"]["meta"][0].get("usage", {})
total_tokens = usage_info.get("total_tokens")
except Exception:
pass
# 步骤 4: 生成语音(如果需要)
audio_data = None
audio_error = None
if include_audio:
try:
success, message, base64_audio = text_to_speech(answer, self.user_id)
if success and base64_audio:
# 直接使用 base64 音频数据
audio_data = base64_audio
else:
audio_error = message
except Exception as e:
audio_error = f"TTS错误: {str(e)}"
result = {
"success": True,
"response": answer,
"user_id": self.user_id
}
# 添加可选字段
if total_tokens is not None:
result["tokens"] = total_tokens
if audio_data:
result["audio_data"] = audio_data
if audio_error:
result["audio_error"] = audio_error
return result
else:
return {
"success": False,
"error": "Could not generate an answer",
"debug_info": results,
"user_id": self.user_id
}
except Exception as e:
return {
"success": False,
"error": str(e),
"user_id": self.user_id
}

173
api/doubao_tts.py Normal file
View File

@ -0,0 +1,173 @@
"""
豆包 TTS (Text-to-Speech) 服务模块
基于火山引擎豆包语音合成 API
"""
import json
import uuid
import base64
import requests
from typing import Dict, Any, Optional, Tuple
from io import BytesIO
from config import (
DOUBAO_TTS_API_URL,
DOUBAO_TTS_APP_ID,
DOUBAO_TTS_ACCESS_KEY,
DOUBAO_TTS_RESOURCE_ID,
DOUBAO_TTS_SPEAKER,
DOUBAO_TTS_FORMAT,
DOUBAO_TTS_SAMPLE_RATE,
)
class DoubaoTTS:
"""豆包 TTS 服务类,支持连接复用"""
def __init__(self):
# 使用 requests.Session 进行连接复用
self.session = requests.Session()
self.api_url = DOUBAO_TTS_API_URL
def _prepare_headers(self) -> Dict[str, str]:
"""准备请求头"""
return {
"X-Api-App-Id": DOUBAO_TTS_APP_ID,
"X-Api-Access-Key": DOUBAO_TTS_ACCESS_KEY,
"X-Api-Resource-Id": DOUBAO_TTS_RESOURCE_ID,
"X-Api-Request-Id": str(uuid.uuid4()),
"Content-Type": "application/json"
}
def _prepare_payload(self, text: str, user_id: str = "default") -> Dict[str, Any]:
"""准备请求负载"""
return {
"user": {
"uid": user_id
},
"req_params": {
"text": text,
"speaker": DOUBAO_TTS_SPEAKER,
"audio_params": {
"format": DOUBAO_TTS_FORMAT,
"sample_rate": DOUBAO_TTS_SAMPLE_RATE
}
}
}
def text_to_speech(self, text: str, user_id: str = "default") -> Tuple[bool, str, Optional[str]]:
"""
将文本转换为语音
Args:
text: 要转换的文本
user_id: 用户ID
Returns:
Tuple[bool, str, Optional[str]]: (成功状态, 消息, base64音频数据)
"""
try:
headers = self._prepare_headers()
payload = self._prepare_payload(text, user_id)
# 发送流式请求
response = self.session.post(
self.api_url,
headers=headers,
json=payload,
stream=True,
timeout=30
)
if response.status_code != 200:
return False, f"HTTP错误: {response.status_code}", None
# 收集音频数据
audio_base64_chunks = []
for line in response.iter_lines():
if line:
try:
# 解析 JSON 响应
json_data = json.loads(line.decode('utf-8'))
# 检查错误
if json_data.get("code", 0) != 0:
# 检查是否是结束标识
if json_data.get("code") == 20000000:
break
else:
error_msg = json_data.get("message", "未知错误")
return False, f"API错误: {error_msg} (code: {json_data.get('code')})", None
# 提取音频数据(直接保存 base64 格式)
if "data" in json_data and json_data["data"]:
audio_base64_chunks.append(json_data["data"])
except json.JSONDecodeError as e:
continue # 跳过非 JSON 行
except Exception as e:
return False, f"处理响应时出错: {str(e)}", None
if not audio_base64_chunks:
return False, "没有接收到音频数据", None
# 合并所有 base64 音频块
complete_audio_base64 = ''.join(audio_base64_chunks)
return True, "转换成功", complete_audio_base64
except requests.exceptions.Timeout:
return False, "请求超时", None
except requests.exceptions.ConnectionError:
return False, "连接错误", None
except Exception as e:
return False, f"未知错误: {str(e)}", None
def save_audio_to_file(self, audio_data: bytes, filename: str) -> bool:
"""
将音频数据保存到文件
Args:
audio_data: 音频二进制数据
filename: 保存的文件名
Returns:
bool: 保存是否成功
"""
try:
with open(filename, 'wb') as f:
f.write(audio_data)
return True
except Exception as e:
print(f"保存音频文件失败: {e}")
return False
def close(self):
"""关闭会话连接"""
if self.session:
self.session.close()
# 全局 TTS 实例(单例模式)
_tts_instance = None
def get_tts_instance() -> DoubaoTTS:
"""获取 TTS 实例(单例)"""
global _tts_instance
if _tts_instance is None:
_tts_instance = DoubaoTTS()
return _tts_instance
def text_to_speech(text: str, user_id: str = "default") -> Tuple[bool, str, Optional[str]]:
"""
便捷函数:将文本转换为语音
Args:
text: 要转换的文本
user_id: 用户ID
Returns:
Tuple[bool, str, Optional[str]]: (成功状态, 消息, base64音频数据)
"""
tts = get_tts_instance()
return tts.text_to_speech(text, user_id)

83
api/main.py Normal file
View File

@ -0,0 +1,83 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import uvicorn
from chat_service import ChatService
# 请求和响应模型
class ChatRequest(BaseModel):
message: str
user_id: Optional[str] = None
include_audio: Optional[bool] = True
class ChatResponse(BaseModel):
success: bool
response: Optional[str] = None
tokens: Optional[int] = None
user_id: str
error: Optional[str] = None
audio_data: Optional[str] = None # base64 编码的音频数据
audio_error: Optional[str] = None
# 创建 FastAPI 应用
app = FastAPI(
title="Haystack RAG API",
description="基于 Haystack 的 RAG 聊天服务 API",
version="1.0.0"
)
# 全局聊天服务实例
chat_service = ChatService()
@app.on_event("startup")
async def startup_event():
"""应用启动时初始化聊天服务"""
chat_service.initialize()
@app.get("/")
async def root():
"""根路径,返回 API 信息"""
return {"message": "Haystack RAG API is running", "version": "1.0.0"}
@app.get("/health")
async def health_check():
"""健康检查端点"""
return {"status": "healthy"}
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
"""
聊天接口
接收用户消息,通过 RAG 管道处理并返回回复(可包含 base64 音频数据)
"""
try:
# 如果请求中指定了用户ID创建新的服务实例
if request.user_id and request.user_id != chat_service.user_id:
user_chat_service = ChatService(request.user_id)
user_chat_service.initialize()
result = user_chat_service.chat(request.message, request.include_audio)
else:
result = chat_service.chat(request.message, request.include_audio)
return ChatResponse(**result)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
if __name__ == "__main__":
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True
)

187
api/test_doubao_tts.py Normal file
View File

@ -0,0 +1,187 @@
"""
豆包 TTS 服务单元测试
"""
import unittest
import tempfile
import os
from unittest.mock import patch, MagicMock
from doubao_tts import DoubaoTTS, get_tts_instance, text_to_speech
class TestDoubaoTTS(unittest.TestCase):
"""豆包 TTS 服务测试类"""
def setUp(self):
"""测试前置设置"""
self.tts = DoubaoTTS()
def tearDown(self):
"""测试后置清理"""
self.tts.close()
def test_prepare_headers(self):
"""测试请求头准备"""
headers = self.tts._prepare_headers()
# 检查必要的头部字段
required_headers = [
"X-Api-App-Id",
"X-Api-Access-Key",
"X-Api-Resource-Id",
"X-Api-Request-Id",
"Content-Type"
]
for header in required_headers:
self.assertIn(header, headers)
self.assertEqual(headers["Content-Type"], "application/json")
def test_prepare_payload(self):
"""测试请求负载准备"""
test_text = "测试文本"
test_user_id = "test_user"
payload = self.tts._prepare_payload(test_text, test_user_id)
# 检查负载结构
self.assertIn("user", payload)
self.assertIn("req_params", payload)
self.assertEqual(payload["user"]["uid"], test_user_id)
self.assertEqual(payload["req_params"]["text"], test_text)
self.assertIn("speaker", payload["req_params"])
self.assertIn("audio_params", payload["req_params"])
@patch('doubao_tts.requests.Session.post')
def test_text_to_speech_success(self, mock_post):
"""测试文本转语音成功场景"""
# 模拟成功的 API 响应
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.iter_lines.return_value = [
b'{"code": 0, "data": "dGVzdCBhdWRpbyBkYXRh"}', # base64 encoded "test audio data"
b'{"code": 20000000, "message": "ok", "data": null}'
]
mock_post.return_value = mock_response
success, message, audio_data = self.tts.text_to_speech("测试文本")
self.assertTrue(success)
self.assertEqual(message, "转换成功")
self.assertIsNotNone(audio_data)
self.assertEqual(audio_data, b"test audio data")
@patch('doubao_tts.requests.Session.post')
def test_text_to_speech_api_error(self, mock_post):
"""测试 API 错误场景"""
# 模拟 API 错误响应
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.iter_lines.return_value = [
b'{"code": 40402003, "message": "TTSExceededTextLimit:exceed max limit"}'
]
mock_post.return_value = mock_response
success, message, audio_data = self.tts.text_to_speech("测试文本")
self.assertFalse(success)
self.assertIn("API错误", message)
self.assertIsNone(audio_data)
@patch('doubao_tts.requests.Session.post')
def test_text_to_speech_http_error(self, mock_post):
"""测试 HTTP 错误场景"""
# 模拟 HTTP 错误
mock_response = MagicMock()
mock_response.status_code = 500
mock_post.return_value = mock_response
success, message, audio_data = self.tts.text_to_speech("测试文本")
self.assertFalse(success)
self.assertIn("HTTP错误", message)
self.assertIsNone(audio_data)
def test_save_audio_to_file(self):
"""测试音频文件保存"""
test_audio_data = b"test audio data"
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as tmp_file:
tmp_filename = tmp_file.name
try:
# 测试保存文件
success = self.tts.save_audio_to_file(test_audio_data, tmp_filename)
self.assertTrue(success)
# 验证文件内容
with open(tmp_filename, 'rb') as f:
saved_data = f.read()
self.assertEqual(saved_data, test_audio_data)
finally:
# 清理临时文件
if os.path.exists(tmp_filename):
os.unlink(tmp_filename)
def test_singleton_instance(self):
"""测试单例模式"""
instance1 = get_tts_instance()
instance2 = get_tts_instance()
self.assertIs(instance1, instance2)
@patch('doubao_tts.get_tts_instance')
def test_text_to_speech_function(self, mock_get_instance):
"""测试便捷函数"""
# 模拟 TTS 实例
mock_tts = MagicMock()
mock_tts.text_to_speech.return_value = (True, "成功", b"audio_data")
mock_get_instance.return_value = mock_tts
success, message, audio_data = text_to_speech("测试文本", "test_user")
self.assertTrue(success)
self.assertEqual(message, "成功")
self.assertEqual(audio_data, b"audio_data")
mock_tts.text_to_speech.assert_called_once_with("测试文本", "test_user")
class TestDoubaoTTSIntegration(unittest.TestCase):
"""豆包 TTS 集成测试(需要真实的 API 密钥)"""
def setUp(self):
"""检查是否有有效的配置"""
from config import DOUBAO_TTS_APP_ID, DOUBAO_TTS_ACCESS_KEY
# 如果没有有效配置,跳过集成测试
if (not DOUBAO_TTS_APP_ID or DOUBAO_TTS_APP_ID == "YOUR_APP_ID" or
not DOUBAO_TTS_ACCESS_KEY or DOUBAO_TTS_ACCESS_KEY == "YOUR_ACCESS_KEY"):
self.skipTest("需要有效的豆包 TTS API 配置才能运行集成测试")
self.tts = DoubaoTTS()
def tearDown(self):
"""测试后置清理"""
if hasattr(self, 'tts'):
self.tts.close()
def test_real_tts_request(self):
"""真实的 TTS 请求测试"""
test_text = "你好,这是豆包语音合成测试。"
success, message, audio_data = self.tts.text_to_speech(test_text, "test_user")
if success:
self.assertIsNotNone(audio_data)
self.assertGreater(len(audio_data), 0)
print(f"TTS 测试成功: {message}")
print(f"音频数据大小: {len(audio_data)} bytes")
else:
print(f"TTS 测试失败: {message}")
# 集成测试失败时不强制断言失败,因为可能是网络或配置问题
if __name__ == '__main__':
# 运行测试
unittest.main(verbosity=2)