From 1d94e775579038f8c85116cb136283ac4def2e92 Mon Sep 17 00:00:00 2001 From: game-loader Date: Sun, 22 Jun 2025 15:17:07 +0800 Subject: [PATCH] feat(motion_app): implement asynchronous similarity calculation --- audio_player.py | 2 + motion_app.py | 141 +++++++++++++++++++++++++++++++++++------------ pose_analyzer.py | 29 ++++++---- 3 files changed, 128 insertions(+), 44 deletions(-) diff --git a/audio_player.py b/audio_player.py index d338267..f3e46c9 100644 --- a/audio_player.py +++ b/audio_player.py @@ -91,6 +91,8 @@ class AudioPlayer: """Restarts the audio from the beginning.""" if self.pygame_initialized and self.audio_file: self.stop() + # Small delay to ensure proper stopping + time.sleep(0.1) return self.play() return False diff --git a/motion_app.py b/motion_app.py index 48503b2..6c3db64 100644 --- a/motion_app.py +++ b/motion_app.py @@ -4,6 +4,8 @@ import time import os import numpy as np import torch +import threading +import queue from rtmlib import Body, draw_skeleton from audio_player import AudioPlayer @@ -32,6 +34,12 @@ class MotionComparisonApp: self.last_error_time = 0 self.error_count = 0 + # Async processing components + self.pose_data_queue = queue.Queue(maxsize=50) + self.similarity_thread = None + self.similarity_stop_flag = threading.Event() + + if 'comparison_state' not in st.session_state: st.session_state.comparison_state = {'is_running': False, 'should_stop': False, 'should_restart': False} @@ -118,6 +126,12 @@ class MotionComparisonApp: def cleanup(self): """Cleans up all resources.""" + # Stop similarity calculation thread + if self.similarity_thread and self.similarity_thread.is_alive(): + self.similarity_stop_flag.set() + self.similarity_thread.join(timeout=2) + + if self.standard_cap: self.standard_cap.release() if self.webcam_cap: self.webcam_cap.release() if self.is_realsense_active and self.realsense_pipeline: self.realsense_pipeline.stop() @@ -166,6 +180,30 @@ class MotionComparisonApp: st.session_state.comparison_state['is_running'] = False # 确保状态重置 st.rerun() + def _similarity_calculation_worker(self, start_time, video_fps): + """Background thread for similarity calculation and plot updates.""" + while not self.similarity_stop_flag.is_set(): + try: + # Get pose data from queue with timeout + pose_data = self.pose_data_queue.get(timeout=1.0) + if pose_data is None: # Poison pill to stop + break + + elapsed_time, standard_angles, webcam_angles = pose_data + + if standard_angles and webcam_angles: + current_similarity = self.similarity_analyzer.calculate_similarity(standard_angles, webcam_angles) + self.similarity_analyzer.add_similarity_score(current_similarity, elapsed_time) + + self.pose_data_queue.task_done() + + except queue.Empty: + continue + except Exception as e: + # Log error but continue processing + continue + + def start_comparison(self, video_path): """The main loop for comparing motion, synchronized with real-time.""" @@ -225,8 +263,18 @@ class MotionComparisonApp: if stop_button: st.session_state.comparison_state['should_stop'] = True if restart_button: st.session_state.comparison_state['should_restart'] = True - # --- 主循环 (基于时间同步) --- + # --- 主循环 (基于时间同步,专注于实时显示) --- start_time = time.time() + + # --- 启动异步相似度计算线程 --- + self.similarity_stop_flag.clear() + self.similarity_thread = threading.Thread( + target=self._similarity_calculation_worker, + args=(start_time, video_fps), + daemon=True + ) + self.similarity_thread.start() + current_similarity = 0 last_plot_update = 0 if audio_loaded: self.audio_player.play() @@ -235,15 +283,19 @@ class MotionComparisonApp: elapsed_time = time.time() - start_time if elapsed_time >= video_duration or st.session_state.comparison_state['should_stop']: + # 停止音频播放 + if audio_loaded: + self.audio_player.stop() break # --- 检查重新开始 --- if st.session_state.comparison_state['should_restart']: self.similarity_analyzer.reset() + self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # 重置视频到开头 start_time = time.time() # 重置计时器 - if audio_loaded: self.audio_player.restart() + if audio_loaded: + self.audio_player.restart() st.session_state.comparison_state['should_restart'] = False - continue # --- 基于真实时间获取标准视频帧 --- target_frame_idx = int(elapsed_time * video_fps) @@ -264,18 +316,26 @@ class MotionComparisonApp: webcam_frame = cv2.flip(webcam_frame, 1) standard_frame = cv2.resize(standard_frame, (target_width, target_height)) webcam_frame = cv2.resize(webcam_frame, (target_width, target_height)) - # --- 关键点检测与相似度计算 (可以保持原来的逻辑) --- + # --- 关键点检测 (实时处理) --- try: standard_keypoints, standard_scores = self.body_detector(standard_frame) webcam_keypoints, webcam_scores = self.body_detector(webcam_frame) + + # --- 异步相似度计算 (发送给后台线程) --- if standard_keypoints is not None and webcam_keypoints is not None: standard_angles = self.similarity_analyzer.extract_joint_angles(standard_keypoints, standard_scores) - webcam_angles = self.similarity_analyzer.extract_joint_angles(webcam_keypoints, webcam_scores) - if standard_angles and webcam_angles: - current_similarity = self.similarity_analyzer.calculate_similarity(standard_angles, webcam_angles) - - self.similarity_analyzer.add_similarity_score(current_similarity, elapsed_time) + + # 非阻塞地将数据放入队列 + try: + self.pose_data_queue.put_nowait((elapsed_time, standard_angles, webcam_angles)) + except queue.Full: + # 如果队列满了,丢弃最旧的数据 + try: + self.pose_data_queue.get_nowait() + self.pose_data_queue.put_nowait((elapsed_time, standard_angles, webcam_angles)) + except queue.Empty: + pass except Exception: standard_keypoints, webcam_keypoints = None, None # --- 绘制与显示 (可以保持原来的逻辑,只修改UI更新部分) --- @@ -285,38 +345,51 @@ class MotionComparisonApp: standard_with_keypoints = draw_skeleton(standard_frame.copy(), standard_keypoints, standard_scores, openpose_skeleton=True, kpt_thr=0.43) webcam_with_keypoints = draw_skeleton(webcam_frame.copy(), webcam_keypoints, webcam_scores, openpose_skeleton=True, kpt_thr=0.43) - # 更新视频画面 + # 更新视频画面 (主线程专注于实时显示) standard_placeholder.image(cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True) webcam_placeholder.image(cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True) - # 更新状态信息 - if self.similarity_analyzer.similarity_history: + + # 更新UI (在主线程中,降低更新频率) + if self.frame_counter % 30 == 0: # 每10帧更新一次UI + # 更新状态信息 + if self.similarity_analyzer.similarity_history: + latest_similarity = self.similarity_analyzer.similarity_history[-1] + avg_sim = np.mean(self.similarity_analyzer.similarity_history) + similarity_score_placeholder.metric("当前相似度", f"{latest_similarity:.1f}%") + avg_score_placeholder.metric("平均相似度", f"{avg_sim:.1f}%") + + # 更新图表 (每50帧更新一次) + if self.frame_counter % 30 == 0: + plot = self.similarity_analyzer.get_similarity_plot() + if plot: + similarity_plot_placeholder.plotly_chart(plot, use_container_width=True) - avg_sim = np.mean(self.similarity_analyzer.similarity_history) - similarity_score_placeholder.metric("当前相似度", f"{current_similarity:.1f}%") - - avg_score_placeholder.metric("平均相似度", f"{avg_sim:.1f}%") - # 更新图表 (节流以提高性能) - if self.frame_counter - last_plot_update > 5: # 每处理约5帧更新一次图表 - plot = self.similarity_analyzer.get_similarity_plot() - - if plot: similarity_plot_placeholder.plotly_chart(plot, use_container_width=True) - last_plot_update = self.frame_counter - - # 更新进度条和状态文本 - - progress = min(elapsed_time / video_duration, 1.0) - progress_bar.progress(progress) - processing_fps = self.frame_counter / elapsed_time if elapsed_time > 0 else 0 - status_text.text(f"时间: {elapsed_time:.1f}s / {video_duration:.1f}s | " - f"处理帧率: {processing_fps:.1f} FPS | " - f"标准视频帧: {target_frame_idx}/{total_frames}") + # 更新进度条和状态文本 + progress = min(elapsed_time / video_duration, 1.0) + progress_bar.progress(progress) + processing_fps = self.frame_counter / elapsed_time if elapsed_time > 0 else 0 + status_text.text(f"时间: {elapsed_time:.1f}s / {video_duration:.1f}s | " + f"处理帧率: {processing_fps:.1f} FPS | " + f"标准视频帧: {target_frame_idx}/{total_frames}") except Exception as e: # 如果UI更新失败,避免程序崩溃 pass - # 这个循环没有 time.sleep(),它会尽力运行。 - # 同步是通过每次循环都根据真实时间去标准视频中定位帧来实现的。 - # --- 循环结束后的清理工作 --- + # 停止相似度计算线程 + self.similarity_stop_flag.set() + if self.similarity_thread and self.similarity_thread.is_alive(): + # 发送毒丸信号停止线程 + try: + self.pose_data_queue.put_nowait(None) + except queue.Full: + pass + self.similarity_thread.join(timeout=2) + + + # 确保音频已停止 + if audio_loaded: + self.audio_player.stop() + self.cleanup() self.show_final_statistics() diff --git a/pose_analyzer.py b/pose_analyzer.py index d647d97..996fd4a 100644 --- a/pose_analyzer.py +++ b/pose_analyzer.py @@ -1,6 +1,7 @@ import numpy as np import math import time +import threading from collections import deque import plotly.graph_objects as go @@ -11,6 +12,7 @@ class PoseSimilarityAnalyzer: self.similarity_history = deque(maxlen=500) self.frame_timestamps = deque(maxlen=500) self.start_time = None + self._lock = threading.Lock() # Thread safety for shared data self.keypoint_map = { 'nose': 0, 'neck': 1, 'left_shoulder': 2, 'left_elbow': 3, 'left_wrist': 4, @@ -94,21 +96,27 @@ class PoseSimilarityAnalyzer: def add_similarity_score(self, score, timestamp=None): """Adds a similarity score to the history.""" - if self.start_time is None: self.start_time = time.time() - timestamp = timestamp if timestamp is not None else time.time() - self.start_time - self.similarity_history.append(float(score)) - self.frame_timestamps.append(float(timestamp)) + with self._lock: + if self.start_time is None: self.start_time = time.time() + timestamp = timestamp if timestamp is not None else time.time() - self.start_time + self.similarity_history.append(float(score)) + self.frame_timestamps.append(float(timestamp)) def get_similarity_plot(self): """Generates a Plotly figure for the similarity history.""" - if len(self.similarity_history) < 2: return None + with self._lock: + if len(self.similarity_history) < 2: return None + + # Create copies to avoid data changes during plotting + timestamps_copy = list(self.frame_timestamps) + history_copy = list(self.similarity_history) fig = go.Figure() - fig.add_trace(go.Scatter(x=list(self.frame_timestamps), y=list(self.similarity_history), + fig.add_trace(go.Scatter(x=timestamps_copy, y=history_copy, mode='lines+markers', name='Similarity', line=dict(color='#2E86AB', width=2), marker=dict(size=4))) - avg_score = sum(self.similarity_history) / len(self.similarity_history) + avg_score = sum(history_copy) / len(history_copy) fig.add_hline(y=avg_score, line_dash="dash", line_color="red", annotation_text=f"Avg: {avg_score:.1f}%") @@ -119,6 +127,7 @@ class PoseSimilarityAnalyzer: def reset(self): """Resets the analyzer's history.""" - self.similarity_history.clear() - self.frame_timestamps.clear() - self.start_time = None + with self._lock: + self.similarity_history.clear() + self.frame_timestamps.clear() + self.start_time = None