"""Unit tests for the Doubao TTS service."""
import unittest
import tempfile
import os
from unittest.mock import patch, MagicMock

from doubao_tts import DoubaoTTS, get_tts_instance, text_to_speech


class TestDoubaoTTS(unittest.TestCase):
    """Unit tests for DoubaoTTS; HTTP calls are replaced with mocks."""

    def setUp(self):
        """Create a fresh DoubaoTTS client for each test."""
        self.tts = DoubaoTTS()

    def tearDown(self):
        """Release the client by calling its close() method."""
        self.tts.close()

    def test_prepare_headers(self):
        """_prepare_headers() returns every required header field."""
        headers = self.tts._prepare_headers()

        # Header fields the request must carry.
        required_headers = [
            "X-Api-App-Id",
            "X-Api-Access-Key",
            "X-Api-Resource-Id",
            "X-Api-Request-Id",
            "Content-Type",
        ]
        for header in required_headers:
            # subTest reports each missing header individually instead of
            # stopping at the first one.
            with self.subTest(header=header):
                self.assertIn(header, headers)

        self.assertEqual(headers["Content-Type"], "application/json")

    def test_prepare_payload(self):
        """_prepare_payload() builds the user / req_params structure."""
        test_text = "测试文本"
        test_user_id = "test_user"

        payload = self.tts._prepare_payload(test_text, test_user_id)

        # Check the payload structure.
        self.assertIn("user", payload)
        self.assertIn("req_params", payload)
        self.assertEqual(payload["user"]["uid"], test_user_id)
        self.assertEqual(payload["req_params"]["text"], test_text)
        self.assertIn("speaker", payload["req_params"])
        self.assertIn("audio_params", payload["req_params"])

    @patch('doubao_tts.requests.Session.post')
    def test_text_to_speech_success(self, mock_post):
        """A streamed audio chunk followed by a final status line decodes to audio bytes."""
        # Simulate a successful streaming API response.
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.iter_lines.return_value = [
            # base64 encoding of b"test audio data"
            b'{"code": 0, "data": "dGVzdCBhdWRpbyBkYXRh"}',
            # presumably the end-of-stream status line — confirm against doubao_tts
            b'{"code": 20000000, "message": "ok", "data": null}',
        ]
        mock_post.return_value = mock_response

        success, message, audio_data = self.tts.text_to_speech("测试文本")

        self.assertTrue(success)
        self.assertEqual(message, "转换成功")
        self.assertIsNotNone(audio_data)
        self.assertEqual(audio_data, b"test audio data")

    @patch('doubao_tts.requests.Session.post')
    def test_text_to_speech_api_error(self, mock_post):
        """A non-zero API error code in the stream is reported as a failure."""
        # Simulate an API-level error response.
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.iter_lines.return_value = [
            b'{"code": 40402003, "message": "TTSExceededTextLimit:exceed max limit"}',
        ]
        mock_post.return_value = mock_response

        success, message, audio_data = self.tts.text_to_speech("测试文本")

        self.assertFalse(success)
        self.assertIn("API错误", message)
        self.assertIsNone(audio_data)

    @patch('doubao_tts.requests.Session.post')
    def test_text_to_speech_http_error(self, mock_post):
        """A non-200 HTTP status is reported as a failure."""
        # Simulate an HTTP-level error.
        mock_response = MagicMock()
        mock_response.status_code = 500
        mock_post.return_value = mock_response

        success, message, audio_data = self.tts.text_to_speech("测试文本")

        self.assertFalse(success)
        self.assertIn("HTTP错误", message)
        self.assertIsNone(audio_data)

    def test_save_audio_to_file(self):
        """save_audio_to_file() writes the exact bytes it was given."""
        test_audio_data = b"test audio data"

        # delete=False so the file can be reopened by name after the handle closes.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as tmp_file:
            tmp_filename = tmp_file.name

        try:
            # Save the file.
            success = self.tts.save_audio_to_file(test_audio_data, tmp_filename)
            self.assertTrue(success)

            # Verify the file contents.
            with open(tmp_filename, 'rb') as f:
                saved_data = f.read()
            self.assertEqual(saved_data, test_audio_data)
        finally:
            # Remove the temporary file.
            if os.path.exists(tmp_filename):
                os.unlink(tmp_filename)

    def test_singleton_instance(self):
        """get_tts_instance() returns the same object on every call."""
        instance1 = get_tts_instance()
        instance2 = get_tts_instance()

        self.assertIs(instance1, instance2)

    @patch('doubao_tts.get_tts_instance')
    def test_text_to_speech_function(self, mock_get_instance):
        """The module-level text_to_speech() delegates to the shared instance."""
        # Substitute a mock TTS instance.
        mock_tts = MagicMock()
        mock_tts.text_to_speech.return_value = (True, "成功", b"audio_data")
        mock_get_instance.return_value = mock_tts

        success, message, audio_data = text_to_speech("测试文本", "test_user")

        self.assertTrue(success)
        self.assertEqual(message, "成功")
        self.assertEqual(audio_data, b"audio_data")
        mock_tts.text_to_speech.assert_called_once_with("测试文本", "test_user")


class TestDoubaoTTSIntegration(unittest.TestCase):
    """Integration tests against the real Doubao TTS API (need real credentials)."""

    def setUp(self):
        """Skip unless usable credentials are available from the config module."""
        not_configured = "需要有效的豆包 TTS API 配置才能运行集成测试"

        # A missing config module or missing settings means "not configured":
        # skip instead of letting ImportError turn into a test error.
        try:
            from config import DOUBAO_TTS_APP_ID, DOUBAO_TTS_ACCESS_KEY
        except ImportError:
            self.skipTest(not_configured)

        # Empty values or the template placeholders also count as unconfigured.
        if (not DOUBAO_TTS_APP_ID or DOUBAO_TTS_APP_ID == "YOUR_APP_ID" or
                not DOUBAO_TTS_ACCESS_KEY or DOUBAO_TTS_ACCESS_KEY == "YOUR_ACCESS_KEY"):
            self.skipTest(not_configured)

        self.tts = DoubaoTTS()

    def tearDown(self):
        """Close the client if setUp got far enough to create one."""
        if hasattr(self, 'tts'):
            self.tts.close()

    def test_real_tts_request(self):
        """Send a real synthesis request and check that audio bytes come back."""
        test_text = "你好,这是豆包语音合成测试。"

        success, message, audio_data = self.tts.text_to_speech(test_text, "test_user")

        if not success:
            # Network or configuration problems should not fail the suite, but
            # they must not be reported as a pass either: mark the test skipped
            # so the failure reason is visible in the run summary.
            self.skipTest(f"TTS 测试失败: {message}")

        self.assertIsNotNone(audio_data)
        self.assertGreater(len(audio_data), 0)
        print(f"TTS 测试成功: {message}")
        print(f"音频数据大小: {len(audio_data)} bytes")


if __name__ == '__main__':
    # Run the tests.
    unittest.main(verbosity=2)