Files
rag_chat/api/doubao_tts.py

173 lines
5.5 KiB
Python

"""
豆包 TTS (Text-to-Speech) 服务模块
基于火山引擎豆包语音合成 API
"""
import json
import uuid
import base64
import requests
from typing import Dict, Any, Optional, Tuple
from io import BytesIO
from config import (
DOUBAO_TTS_API_URL,
DOUBAO_TTS_APP_ID,
DOUBAO_TTS_ACCESS_KEY,
DOUBAO_TTS_RESOURCE_ID,
DOUBAO_TTS_SPEAKER,
DOUBAO_TTS_FORMAT,
DOUBAO_TTS_SAMPLE_RATE,
)
class DoubaoTTS:
    """Client for the Doubao (Volcengine) streaming text-to-speech HTTP API.

    Each instance owns one ``requests.Session`` so consecutive requests can
    reuse pooled HTTP connections. Call :meth:`close` when done, or use the
    instance as a context manager (``with DoubaoTTS() as tts: ...``), which
    closes the session on exit.
    """

    def __init__(self):
        # One Session per client so HTTP connections are pooled and reused.
        self.session = requests.Session()
        self.api_url = DOUBAO_TTS_API_URL

    def __enter__(self) -> "DoubaoTTS":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the session; returning None lets exceptions propagate.
        self.close()

    def _prepare_headers(self) -> Dict[str, str]:
        """Build the authentication headers for a single request.

        A fresh UUID4 is generated for ``X-Api-Request-Id`` on every call.
        """
        return {
            "X-Api-App-Id": DOUBAO_TTS_APP_ID,
            "X-Api-Access-Key": DOUBAO_TTS_ACCESS_KEY,
            "X-Api-Resource-Id": DOUBAO_TTS_RESOURCE_ID,
            "X-Api-Request-Id": str(uuid.uuid4()),
            "Content-Type": "application/json"
        }

    def _prepare_payload(self, text: str, user_id: str = "default") -> Dict[str, Any]:
        """Build the JSON request body.

        Args:
            text: Text to synthesize.
            user_id: Sent as ``user.uid``.

        Returns:
            The payload dict containing the text plus the configured speaker,
            audio format and sample rate.
        """
        return {
            "user": {
                "uid": user_id
            },
            "req_params": {
                "text": text,
                "speaker": DOUBAO_TTS_SPEAKER,
                "audio_params": {
                    "format": DOUBAO_TTS_FORMAT,
                    "sample_rate": DOUBAO_TTS_SAMPLE_RATE
                }
            }
        }

    @staticmethod
    def _merge_base64_chunks(chunks: "list[str]") -> str:
        """Merge independently base64-encoded audio chunks into one base64 string.

        Every streamed chunk is its own base64 document and may end with ``=``
        padding. Plain string concatenation would therefore place padding in
        the middle of the result, which strict decoders reject and lenient
        decoders stop at (truncating the audio). Decoding each chunk, joining
        the raw bytes and re-encoding once yields a single valid string.

        Raises:
            ValueError: (``binascii.Error``) if a chunk has invalid padding.
            TypeError: if a chunk is not a str or bytes-like object.
        """
        audio_bytes = b"".join(base64.b64decode(chunk) for chunk in chunks)
        return base64.b64encode(audio_bytes).decode("ascii")

    def text_to_speech(self, text: str, user_id: str = "default") -> Tuple[bool, str, Optional[str]]:
        """Synthesize ``text`` and return the audio as one base64 string.

        The response body is read line by line. Each non-empty line is parsed
        as JSON; lines that are not valid JSON are skipped. A message whose
        ``code`` is ``20000000`` ends the stream, any other non-zero ``code``
        is reported as an API error, and a non-empty ``data`` field is
        collected as a base64 audio chunk.

        Args:
            text: Text to synthesize.
            user_id: Sent as ``user.uid`` in the request payload.

        Returns:
            ``(success, message, audio_base64)``. On failure ``audio_base64``
            is ``None`` and ``message`` describes the problem; errors are
            reported through the tuple rather than raised.
        """
        try:
            headers = self._prepare_headers()
            payload = self._prepare_payload(text, user_id)
            # With stream=True the connection stays checked out until the body
            # is consumed or the response is closed; the ``with`` block
            # releases it on every exit path, including early returns/break.
            with self.session.post(
                self.api_url,
                headers=headers,
                json=payload,
                stream=True,
                timeout=30,
            ) as response:
                if response.status_code != 200:
                    return False, f"HTTP错误: {response.status_code}", None

                audio_base64_chunks = []
                for line in response.iter_lines():
                    if not line:
                        continue  # skip empty separator lines
                    try:
                        json_data = json.loads(line.decode('utf-8'))
                        code = json_data.get("code", 0)
                        if code != 0:
                            # 20000000 marks the end of the stream, not an error.
                            if code == 20000000:
                                break
                            error_msg = json_data.get("message", "未知错误")
                            return False, f"API错误: {error_msg} (code: {code})", None
                        # Keep the chunk in base64 form; merged after the loop.
                        if json_data.get("data"):
                            audio_base64_chunks.append(json_data["data"])
                    except json.JSONDecodeError:
                        continue  # skip non-JSON lines
                    except Exception as e:
                        return False, f"处理响应时出错: {str(e)}", None

            if not audio_base64_chunks:
                return False, "没有接收到音频数据", None
            try:
                complete_audio_base64 = self._merge_base64_chunks(audio_base64_chunks)
            except (ValueError, TypeError) as e:
                return False, f"音频数据解码失败: {e}", None
            return True, "转换成功", complete_audio_base64
        except requests.exceptions.Timeout:
            return False, "请求超时", None
        except requests.exceptions.ConnectionError:
            return False, "连接错误", None
        except Exception as e:
            return False, f"未知错误: {str(e)}", None

    def save_audio_to_file(self, audio_data: bytes, filename: str) -> bool:
        """Write raw audio bytes to ``filename`` (best effort).

        Args:
            audio_data: Binary audio data (not base64).
            filename: Destination path; an existing file is overwritten.

        Returns:
            True on success. On any error the message is printed and False is
            returned instead of raising.
        """
        try:
            with open(filename, 'wb') as f:
                f.write(audio_data)
            return True
        except Exception as e:
            print(f"保存音频文件失败: {e}")
            return False

    def close(self):
        """Close the underlying HTTP session and its pooled connections."""
        if self.session:
            self.session.close()
# Module-level, lazily created TTS client shared by all callers of this module.
_tts_instance = None


def get_tts_instance() -> DoubaoTTS:
    """Return the shared DoubaoTTS client, constructing it on the first call.

    NOTE(review): there is no lock here, so two threads making the very first
    call at the same time could each build a client — confirm that is acceptable.
    """
    global _tts_instance
    if _tts_instance is not None:
        return _tts_instance
    _tts_instance = DoubaoTTS()
    return _tts_instance
def text_to_speech(text: str, user_id: str = "default") -> Tuple[bool, str, Optional[str]]:
    """Convenience wrapper: synthesize ``text`` using the shared TTS client.

    Args:
        text: Text to synthesize.
        user_id: Forwarded unchanged to the client's ``text_to_speech``.

    Returns:
        ``(success, message, audio_base64)`` exactly as produced by the shared
        client's ``text_to_speech`` method.
    """
    client = get_tts_instance()
    result = client.text_to_speech(text, user_id)
    return result