feat(motion_app): implement asynchronous similarity calculation

This commit is contained in:
game-loader 2025-06-22 15:17:07 +08:00
parent 53979d9501
commit 1d94e77557
3 changed files with 128 additions and 44 deletions

View File

@ -91,6 +91,8 @@ class AudioPlayer:
"""Restarts the audio from the beginning."""
if self.pygame_initialized and self.audio_file:
self.stop()
# Small delay to ensure proper stopping
time.sleep(0.1)
return self.play()
return False

View File

@ -4,6 +4,8 @@ import time
import os
import numpy as np
import torch
import threading
import queue
from rtmlib import Body, draw_skeleton
from audio_player import AudioPlayer
@ -32,6 +34,12 @@ class MotionComparisonApp:
self.last_error_time = 0
self.error_count = 0
# Async processing components
self.pose_data_queue = queue.Queue(maxsize=50)
self.similarity_thread = None
self.similarity_stop_flag = threading.Event()
if 'comparison_state' not in st.session_state:
st.session_state.comparison_state = {'is_running': False, 'should_stop': False, 'should_restart': False}
@ -118,6 +126,12 @@ class MotionComparisonApp:
def cleanup(self):
"""Cleans up all resources."""
# Stop similarity calculation thread
if self.similarity_thread and self.similarity_thread.is_alive():
self.similarity_stop_flag.set()
self.similarity_thread.join(timeout=2)
if self.standard_cap: self.standard_cap.release()
if self.webcam_cap: self.webcam_cap.release()
if self.is_realsense_active and self.realsense_pipeline: self.realsense_pipeline.stop()
@ -166,6 +180,30 @@ class MotionComparisonApp:
st.session_state.comparison_state['is_running'] = False # 确保状态重置
st.rerun()
def _similarity_calculation_worker(self, start_time, video_fps):
"""Background thread for similarity calculation and plot updates."""
while not self.similarity_stop_flag.is_set():
try:
# Get pose data from queue with timeout
pose_data = self.pose_data_queue.get(timeout=1.0)
if pose_data is None: # Poison pill to stop
break
elapsed_time, standard_angles, webcam_angles = pose_data
if standard_angles and webcam_angles:
current_similarity = self.similarity_analyzer.calculate_similarity(standard_angles, webcam_angles)
self.similarity_analyzer.add_similarity_score(current_similarity, elapsed_time)
self.pose_data_queue.task_done()
except queue.Empty:
continue
except Exception as e:
# Log error but continue processing
continue
def start_comparison(self, video_path):
"""The main loop for comparing motion, synchronized with real-time."""
@ -225,8 +263,18 @@ class MotionComparisonApp:
if stop_button: st.session_state.comparison_state['should_stop'] = True
if restart_button: st.session_state.comparison_state['should_restart'] = True
# --- 主循环 (基于时间同步) ---
# --- 主循环 (基于时间同步,专注于实时显示) ---
start_time = time.time()
# --- 启动异步相似度计算线程 ---
self.similarity_stop_flag.clear()
self.similarity_thread = threading.Thread(
target=self._similarity_calculation_worker,
args=(start_time, video_fps),
daemon=True
)
self.similarity_thread.start()
current_similarity = 0
last_plot_update = 0
if audio_loaded: self.audio_player.play()
@ -235,15 +283,19 @@ class MotionComparisonApp:
elapsed_time = time.time() - start_time
if elapsed_time >= video_duration or st.session_state.comparison_state['should_stop']:
# 停止音频播放
if audio_loaded:
self.audio_player.stop()
break
# --- 检查重新开始 ---
if st.session_state.comparison_state['should_restart']:
self.similarity_analyzer.reset()
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # 重置视频到开头
start_time = time.time() # 重置计时器
if audio_loaded: self.audio_player.restart()
if audio_loaded:
self.audio_player.restart()
st.session_state.comparison_state['should_restart'] = False
continue
# --- 基于真实时间获取标准视频帧 ---
target_frame_idx = int(elapsed_time * video_fps)
@ -264,18 +316,26 @@ class MotionComparisonApp:
webcam_frame = cv2.flip(webcam_frame, 1)
standard_frame = cv2.resize(standard_frame, (target_width, target_height))
webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))
# --- 关键点检测与相似度计算 (可以保持原来的逻辑) ---
# --- 关键点检测 (实时处理) ---
try:
standard_keypoints, standard_scores = self.body_detector(standard_frame)
webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)
# --- 异步相似度计算 (发送给后台线程) ---
if standard_keypoints is not None and webcam_keypoints is not None:
standard_angles = self.similarity_analyzer.extract_joint_angles(standard_keypoints, standard_scores)
webcam_angles = self.similarity_analyzer.extract_joint_angles(webcam_keypoints, webcam_scores)
if standard_angles and webcam_angles:
current_similarity = self.similarity_analyzer.calculate_similarity(standard_angles, webcam_angles)
self.similarity_analyzer.add_similarity_score(current_similarity, elapsed_time)
# 非阻塞地将数据放入队列
try:
self.pose_data_queue.put_nowait((elapsed_time, standard_angles, webcam_angles))
except queue.Full:
# 如果队列满了,丢弃最旧的数据
try:
self.pose_data_queue.get_nowait()
self.pose_data_queue.put_nowait((elapsed_time, standard_angles, webcam_angles))
except queue.Empty:
pass
except Exception:
standard_keypoints, webcam_keypoints = None, None
# --- 绘制与显示 (可以保持原来的逻辑只修改UI更新部分) ---
@ -285,38 +345,51 @@ class MotionComparisonApp:
standard_with_keypoints = draw_skeleton(standard_frame.copy(), standard_keypoints, standard_scores, openpose_skeleton=True, kpt_thr=0.43)
webcam_with_keypoints = draw_skeleton(webcam_frame.copy(), webcam_keypoints, webcam_scores, openpose_skeleton=True, kpt_thr=0.43)
# 更新视频画面
# 更新视频画面 (主线程专注于实时显示)
standard_placeholder.image(cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)
webcam_placeholder.image(cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)
# 更新状态信息
if self.similarity_analyzer.similarity_history:
# 更新UI (在主线程中,降低更新频率)
if self.frame_counter % 30 == 0: # 每10帧更新一次UI
# 更新状态信息
if self.similarity_analyzer.similarity_history:
latest_similarity = self.similarity_analyzer.similarity_history[-1]
avg_sim = np.mean(self.similarity_analyzer.similarity_history)
similarity_score_placeholder.metric("当前相似度", f"{latest_similarity:.1f}%")
avg_score_placeholder.metric("平均相似度", f"{avg_sim:.1f}%")
# 更新图表 (每50帧更新一次)
if self.frame_counter % 30 == 0:
plot = self.similarity_analyzer.get_similarity_plot()
if plot:
similarity_plot_placeholder.plotly_chart(plot, use_container_width=True)
avg_sim = np.mean(self.similarity_analyzer.similarity_history)
similarity_score_placeholder.metric("当前相似度", f"{current_similarity:.1f}%")
avg_score_placeholder.metric("平均相似度", f"{avg_sim:.1f}%")
# 更新图表 (节流以提高性能)
if self.frame_counter - last_plot_update > 5: # 每处理约5帧更新一次图表
plot = self.similarity_analyzer.get_similarity_plot()
if plot: similarity_plot_placeholder.plotly_chart(plot, use_container_width=True)
last_plot_update = self.frame_counter
# 更新进度条和状态文本
progress = min(elapsed_time / video_duration, 1.0)
progress_bar.progress(progress)
processing_fps = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
status_text.text(f"时间: {elapsed_time:.1f}s / {video_duration:.1f}s | "
f"处理帧率: {processing_fps:.1f} FPS | "
f"标准视频帧: {target_frame_idx}/{total_frames}")
# 更新进度条和状态文本
progress = min(elapsed_time / video_duration, 1.0)
progress_bar.progress(progress)
processing_fps = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
status_text.text(f"时间: {elapsed_time:.1f}s / {video_duration:.1f}s | "
f"处理帧率: {processing_fps:.1f} FPS | "
f"标准视频帧: {target_frame_idx}/{total_frames}")
except Exception as e:
# 如果UI更新失败避免程序崩溃
pass
# 这个循环没有 time.sleep(),它会尽力运行。
# 同步是通过每次循环都根据真实时间去标准视频中定位帧来实现的。
# --- 循环结束后的清理工作 ---
# 停止相似度计算线程
self.similarity_stop_flag.set()
if self.similarity_thread and self.similarity_thread.is_alive():
# 发送毒丸信号停止线程
try:
self.pose_data_queue.put_nowait(None)
except queue.Full:
pass
self.similarity_thread.join(timeout=2)
# 确保音频已停止
if audio_loaded:
self.audio_player.stop()
self.cleanup()
self.show_final_statistics()

View File

@ -1,6 +1,7 @@
import numpy as np
import math
import time
import threading
from collections import deque
import plotly.graph_objects as go
@ -11,6 +12,7 @@ class PoseSimilarityAnalyzer:
self.similarity_history = deque(maxlen=500)
self.frame_timestamps = deque(maxlen=500)
self.start_time = None
self._lock = threading.Lock() # Thread safety for shared data
self.keypoint_map = {
'nose': 0, 'neck': 1, 'left_shoulder': 2, 'left_elbow': 3, 'left_wrist': 4,
@ -94,21 +96,27 @@ class PoseSimilarityAnalyzer:
def add_similarity_score(self, score, timestamp=None):
"""Adds a similarity score to the history."""
if self.start_time is None: self.start_time = time.time()
timestamp = timestamp if timestamp is not None else time.time() - self.start_time
self.similarity_history.append(float(score))
self.frame_timestamps.append(float(timestamp))
with self._lock:
if self.start_time is None: self.start_time = time.time()
timestamp = timestamp if timestamp is not None else time.time() - self.start_time
self.similarity_history.append(float(score))
self.frame_timestamps.append(float(timestamp))
def get_similarity_plot(self):
"""Generates a Plotly figure for the similarity history."""
if len(self.similarity_history) < 2: return None
with self._lock:
if len(self.similarity_history) < 2: return None
# Create copies to avoid data changes during plotting
timestamps_copy = list(self.frame_timestamps)
history_copy = list(self.similarity_history)
fig = go.Figure()
fig.add_trace(go.Scatter(x=list(self.frame_timestamps), y=list(self.similarity_history),
fig.add_trace(go.Scatter(x=timestamps_copy, y=history_copy,
mode='lines+markers', name='Similarity',
line=dict(color='#2E86AB', width=2), marker=dict(size=4)))
avg_score = sum(self.similarity_history) / len(self.similarity_history)
avg_score = sum(history_copy) / len(history_copy)
fig.add_hline(y=avg_score, line_dash="dash", line_color="red",
annotation_text=f"Avg: {avg_score:.1f}%")
@ -119,6 +127,7 @@ class PoseSimilarityAnalyzer:
def reset(self):
"""Resets the analyzer's history."""
self.similarity_history.clear()
self.frame_timestamps.clear()
self.start_time = None
with self._lock:
self.similarity_history.clear()
self.frame_timestamps.clear()
self.start_time = None