173 lines
5.5 KiB
Python
173 lines
5.5 KiB
Python
"""
|
|
豆包 TTS (Text-to-Speech) 服务模块
|
|
基于火山引擎豆包语音合成 API
|
|
"""
|
|
import json
|
|
import uuid
|
|
import base64
|
|
import requests
|
|
from typing import Dict, Any, Optional, Tuple
|
|
from io import BytesIO
|
|
|
|
from config import (
|
|
DOUBAO_TTS_API_URL,
|
|
DOUBAO_TTS_APP_ID,
|
|
DOUBAO_TTS_ACCESS_KEY,
|
|
DOUBAO_TTS_RESOURCE_ID,
|
|
DOUBAO_TTS_SPEAKER,
|
|
DOUBAO_TTS_FORMAT,
|
|
DOUBAO_TTS_SAMPLE_RATE,
|
|
)
|
|
|
|
|
|
class DoubaoTTS:
    """Client for the Doubao TTS service that reuses one HTTP session.

    Every request is sent through a single ``requests.Session`` so that
    connections can be reused between calls.
    """

    # Response ``code`` value that this client treats as the end-of-stream
    # marker rather than as an error.
    _STREAM_END_CODE = 20000000

    # Value passed as ``timeout`` to ``requests`` for every synthesis request.
    _REQUEST_TIMEOUT = 30

    def __init__(self):
        """Create the shared HTTP session and record the API endpoint."""
        # A requests.Session is used for connection reuse.
        self.session = requests.Session()
        self.api_url = DOUBAO_TTS_API_URL

    def _prepare_headers(self) -> Dict[str, str]:
        """Build the request headers, including a fresh random request ID."""
        return {
            "X-Api-App-Id": DOUBAO_TTS_APP_ID,
            "X-Api-Access-Key": DOUBAO_TTS_ACCESS_KEY,
            "X-Api-Resource-Id": DOUBAO_TTS_RESOURCE_ID,
            "X-Api-Request-Id": str(uuid.uuid4()),
            "Content-Type": "application/json",
        }

    def _prepare_payload(self, text: str, user_id: str = "default") -> Dict[str, Any]:
        """Build the JSON request body for the given text and user ID."""
        return {
            "user": {
                "uid": user_id,
            },
            "req_params": {
                "text": text,
                "speaker": DOUBAO_TTS_SPEAKER,
                "audio_params": {
                    "format": DOUBAO_TTS_FORMAT,
                    "sample_rate": DOUBAO_TTS_SAMPLE_RATE,
                },
            },
        }

    @staticmethod
    def _collect_audio_chunks(lines) -> Tuple[Optional[str], list]:
        """Parse streamed response lines into base64 audio chunks.

        Args:
            lines: Iterable of raw ``bytes`` lines, one JSON document per line.

        Returns:
            ``(None, chunks)`` when the stream was read without an API error,
            or ``(error_message, [])`` when a line reported an error or could
            not be processed.
        """
        chunks = []
        for line in lines:
            if not line:
                continue
            try:
                message = json.loads(line.decode('utf-8'))

                code = message.get("code", 0)
                if code != 0:
                    if code == DoubaoTTS._STREAM_END_CODE:
                        # End-of-stream marker: stop reading.
                        break
                    error_msg = message.get("message", "未知错误")
                    return f"API错误: {error_msg} (code: {code})", []

                # Keep the audio payload in its base64 form.
                data = message.get("data")
                if data:
                    chunks.append(data)
            except json.JSONDecodeError:
                # Skip lines that are not JSON.
                continue
            except Exception as e:
                return f"处理响应时出错: {str(e)}", []
        return None, chunks

    @staticmethod
    def _merge_base64_chunks(chunks) -> str:
        """Combine base64 chunks into one valid base64 string.

        Plain concatenation is only valid base64 when no chunk before the
        last one carries ``=`` padding, because decoders stop at the first
        padding sequence. When interior padding is present, every chunk is
        decoded and the combined bytes are encoded again.
        """
        if not any("=" in chunk for chunk in chunks[:-1]):
            return ''.join(chunks)
        raw_audio = b"".join(base64.b64decode(chunk) for chunk in chunks)
        return base64.b64encode(raw_audio).decode("ascii")

    def text_to_speech(self, text: str, user_id: str = "default") -> Tuple[bool, str, Optional[str]]:
        """Convert text to speech.

        Args:
            text: Text to convert.
            user_id: User ID placed in the request payload.

        Returns:
            Tuple[bool, str, Optional[str]]: (success flag, message,
            base64-encoded audio data). The audio element is ``None`` on
            failure. Errors derived from ``Exception`` are reported through
            the tuple instead of being raised.
        """
        try:
            headers = self._prepare_headers()
            payload = self._prepare_payload(text, user_id)

            # Send a streaming request.
            response = self.session.post(
                self.api_url,
                headers=headers,
                json=payload,
                stream=True,
                timeout=self._REQUEST_TIMEOUT,
            )
            try:
                if response.status_code != 200:
                    return False, f"HTTP错误: {response.status_code}", None
                error, chunks = self._collect_audio_chunks(response.iter_lines())
            finally:
                # With stream=True the connection is not released back to the
                # pool until the body is fully consumed or the response is
                # closed; the early exits above would otherwise leak it.
                response.close()

            if error is not None:
                return False, error, None
            if not chunks:
                return False, "没有接收到音频数据", None

            return True, "转换成功", self._merge_base64_chunks(chunks)

        except requests.exceptions.Timeout:
            return False, "请求超时", None
        except requests.exceptions.ConnectionError:
            return False, "连接错误", None
        except Exception as e:
            return False, f"未知错误: {str(e)}", None

    def save_audio_to_file(self, audio_data: bytes, filename: str) -> bool:
        """Write audio data to a file.

        Args:
            audio_data: Binary audio data.
            filename: Path of the file to write.

        Returns:
            bool: ``True`` if the file was written, ``False`` if any error
            occurred (the error is printed).
        """
        try:
            with open(filename, 'wb') as f:
                f.write(audio_data)
            return True
        except Exception as e:
            print(f"保存音频文件失败: {e}")
            return False

    def close(self):
        """Close the HTTP session."""
        if self.session:
            self.session.close()
|
|
|
|
|
|
# Module-wide TTS client (singleton), created lazily on first use.
_tts_instance = None


def get_tts_instance() -> DoubaoTTS:
    """Return the shared DoubaoTTS instance, creating it on the first call."""
    global _tts_instance
    if _tts_instance is not None:
        return _tts_instance
    _tts_instance = DoubaoTTS()
    return _tts_instance
|
|
|
|
def text_to_speech(text: str, user_id: str = "default") -> Tuple[bool, str, Optional[str]]:
    """Convenience wrapper: convert text to speech using the shared client.

    Args:
        text: Text to convert.
        user_id: User ID forwarded to the client.

    Returns:
        Tuple[bool, str, Optional[str]]: (success flag, message,
        base64-encoded audio data).
    """
    return get_tts_instance().text_to_speech(text, user_id)