feat(motion_app): implement asynchronous similarity calculation
This commit is contained in:
parent
53979d9501
commit
1d94e77557
@ -91,6 +91,8 @@ class AudioPlayer:
|
||||
"""Restarts the audio from the beginning."""
|
||||
if self.pygame_initialized and self.audio_file:
|
||||
self.stop()
|
||||
# Small delay to ensure proper stopping
|
||||
time.sleep(0.1)
|
||||
return self.play()
|
||||
return False
|
||||
|
||||
|
141
motion_app.py
141
motion_app.py
@ -4,6 +4,8 @@ import time
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import threading
|
||||
import queue
|
||||
from rtmlib import Body, draw_skeleton
|
||||
|
||||
from audio_player import AudioPlayer
|
||||
@ -32,6 +34,12 @@ class MotionComparisonApp:
|
||||
self.last_error_time = 0
|
||||
self.error_count = 0
|
||||
|
||||
# Async processing components
|
||||
self.pose_data_queue = queue.Queue(maxsize=50)
|
||||
self.similarity_thread = None
|
||||
self.similarity_stop_flag = threading.Event()
|
||||
|
||||
|
||||
if 'comparison_state' not in st.session_state:
|
||||
st.session_state.comparison_state = {'is_running': False, 'should_stop': False, 'should_restart': False}
|
||||
|
||||
@ -118,6 +126,12 @@ class MotionComparisonApp:
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleans up all resources."""
|
||||
# Stop similarity calculation thread
|
||||
if self.similarity_thread and self.similarity_thread.is_alive():
|
||||
self.similarity_stop_flag.set()
|
||||
self.similarity_thread.join(timeout=2)
|
||||
|
||||
|
||||
if self.standard_cap: self.standard_cap.release()
|
||||
if self.webcam_cap: self.webcam_cap.release()
|
||||
if self.is_realsense_active and self.realsense_pipeline: self.realsense_pipeline.stop()
|
||||
@ -166,6 +180,30 @@ class MotionComparisonApp:
|
||||
st.session_state.comparison_state['is_running'] = False # 确保状态重置
|
||||
st.rerun()
|
||||
|
||||
def _similarity_calculation_worker(self, start_time, video_fps):
|
||||
"""Background thread for similarity calculation and plot updates."""
|
||||
while not self.similarity_stop_flag.is_set():
|
||||
try:
|
||||
# Get pose data from queue with timeout
|
||||
pose_data = self.pose_data_queue.get(timeout=1.0)
|
||||
if pose_data is None: # Poison pill to stop
|
||||
break
|
||||
|
||||
elapsed_time, standard_angles, webcam_angles = pose_data
|
||||
|
||||
if standard_angles and webcam_angles:
|
||||
current_similarity = self.similarity_analyzer.calculate_similarity(standard_angles, webcam_angles)
|
||||
self.similarity_analyzer.add_similarity_score(current_similarity, elapsed_time)
|
||||
|
||||
self.pose_data_queue.task_done()
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
# Log error but continue processing
|
||||
continue
|
||||
|
||||
|
||||
def start_comparison(self, video_path):
|
||||
"""The main loop for comparing motion, synchronized with real-time."""
|
||||
|
||||
@ -225,8 +263,18 @@ class MotionComparisonApp:
|
||||
if stop_button: st.session_state.comparison_state['should_stop'] = True
|
||||
if restart_button: st.session_state.comparison_state['should_restart'] = True
|
||||
|
||||
# --- 主循环 (基于时间同步) ---
|
||||
# --- 主循环 (基于时间同步,专注于实时显示) ---
|
||||
start_time = time.time()
|
||||
|
||||
# --- 启动异步相似度计算线程 ---
|
||||
self.similarity_stop_flag.clear()
|
||||
self.similarity_thread = threading.Thread(
|
||||
target=self._similarity_calculation_worker,
|
||||
args=(start_time, video_fps),
|
||||
daemon=True
|
||||
)
|
||||
self.similarity_thread.start()
|
||||
|
||||
current_similarity = 0
|
||||
last_plot_update = 0
|
||||
if audio_loaded: self.audio_player.play()
|
||||
@ -235,15 +283,19 @@ class MotionComparisonApp:
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
if elapsed_time >= video_duration or st.session_state.comparison_state['should_stop']:
|
||||
# 停止音频播放
|
||||
if audio_loaded:
|
||||
self.audio_player.stop()
|
||||
break
|
||||
|
||||
# --- 检查重新开始 ---
|
||||
if st.session_state.comparison_state['should_restart']:
|
||||
self.similarity_analyzer.reset()
|
||||
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # 重置视频到开头
|
||||
start_time = time.time() # 重置计时器
|
||||
if audio_loaded: self.audio_player.restart()
|
||||
if audio_loaded:
|
||||
self.audio_player.restart()
|
||||
st.session_state.comparison_state['should_restart'] = False
|
||||
|
||||
continue
|
||||
# --- 基于真实时间获取标准视频帧 ---
|
||||
target_frame_idx = int(elapsed_time * video_fps)
|
||||
@ -264,18 +316,26 @@ class MotionComparisonApp:
|
||||
webcam_frame = cv2.flip(webcam_frame, 1)
|
||||
standard_frame = cv2.resize(standard_frame, (target_width, target_height))
|
||||
webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))
|
||||
# --- 关键点检测与相似度计算 (可以保持原来的逻辑) ---
|
||||
# --- 关键点检测 (实时处理) ---
|
||||
try:
|
||||
standard_keypoints, standard_scores = self.body_detector(standard_frame)
|
||||
webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)
|
||||
|
||||
# --- 异步相似度计算 (发送给后台线程) ---
|
||||
if standard_keypoints is not None and webcam_keypoints is not None:
|
||||
standard_angles = self.similarity_analyzer.extract_joint_angles(standard_keypoints, standard_scores)
|
||||
|
||||
webcam_angles = self.similarity_analyzer.extract_joint_angles(webcam_keypoints, webcam_scores)
|
||||
if standard_angles and webcam_angles:
|
||||
current_similarity = self.similarity_analyzer.calculate_similarity(standard_angles, webcam_angles)
|
||||
|
||||
self.similarity_analyzer.add_similarity_score(current_similarity, elapsed_time)
|
||||
|
||||
# 非阻塞地将数据放入队列
|
||||
try:
|
||||
self.pose_data_queue.put_nowait((elapsed_time, standard_angles, webcam_angles))
|
||||
except queue.Full:
|
||||
# 如果队列满了,丢弃最旧的数据
|
||||
try:
|
||||
self.pose_data_queue.get_nowait()
|
||||
self.pose_data_queue.put_nowait((elapsed_time, standard_angles, webcam_angles))
|
||||
except queue.Empty:
|
||||
pass
|
||||
except Exception:
|
||||
standard_keypoints, webcam_keypoints = None, None
|
||||
# --- 绘制与显示 (可以保持原来的逻辑,只修改UI更新部分) ---
|
||||
@ -285,38 +345,51 @@ class MotionComparisonApp:
|
||||
standard_with_keypoints = draw_skeleton(standard_frame.copy(), standard_keypoints, standard_scores, openpose_skeleton=True, kpt_thr=0.43)
|
||||
webcam_with_keypoints = draw_skeleton(webcam_frame.copy(), webcam_keypoints, webcam_scores, openpose_skeleton=True, kpt_thr=0.43)
|
||||
|
||||
# 更新视频画面
|
||||
# 更新视频画面 (主线程专注于实时显示)
|
||||
standard_placeholder.image(cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)
|
||||
webcam_placeholder.image(cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)
|
||||
# 更新状态信息
|
||||
if self.similarity_analyzer.similarity_history:
|
||||
|
||||
# 更新UI (在主线程中,降低更新频率)
|
||||
if self.frame_counter % 30 == 0: # 每10帧更新一次UI
|
||||
# 更新状态信息
|
||||
if self.similarity_analyzer.similarity_history:
|
||||
latest_similarity = self.similarity_analyzer.similarity_history[-1]
|
||||
avg_sim = np.mean(self.similarity_analyzer.similarity_history)
|
||||
similarity_score_placeholder.metric("当前相似度", f"{latest_similarity:.1f}%")
|
||||
avg_score_placeholder.metric("平均相似度", f"{avg_sim:.1f}%")
|
||||
|
||||
# 更新图表 (每50帧更新一次)
|
||||
if self.frame_counter % 30 == 0:
|
||||
plot = self.similarity_analyzer.get_similarity_plot()
|
||||
if plot:
|
||||
similarity_plot_placeholder.plotly_chart(plot, use_container_width=True)
|
||||
|
||||
avg_sim = np.mean(self.similarity_analyzer.similarity_history)
|
||||
similarity_score_placeholder.metric("当前相似度", f"{current_similarity:.1f}%")
|
||||
|
||||
avg_score_placeholder.metric("平均相似度", f"{avg_sim:.1f}%")
|
||||
# 更新图表 (节流以提高性能)
|
||||
if self.frame_counter - last_plot_update > 5: # 每处理约5帧更新一次图表
|
||||
plot = self.similarity_analyzer.get_similarity_plot()
|
||||
|
||||
if plot: similarity_plot_placeholder.plotly_chart(plot, use_container_width=True)
|
||||
last_plot_update = self.frame_counter
|
||||
|
||||
# 更新进度条和状态文本
|
||||
|
||||
progress = min(elapsed_time / video_duration, 1.0)
|
||||
progress_bar.progress(progress)
|
||||
processing_fps = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
|
||||
status_text.text(f"时间: {elapsed_time:.1f}s / {video_duration:.1f}s | "
|
||||
f"处理帧率: {processing_fps:.1f} FPS | "
|
||||
f"标准视频帧: {target_frame_idx}/{total_frames}")
|
||||
# 更新进度条和状态文本
|
||||
progress = min(elapsed_time / video_duration, 1.0)
|
||||
progress_bar.progress(progress)
|
||||
processing_fps = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
|
||||
status_text.text(f"时间: {elapsed_time:.1f}s / {video_duration:.1f}s | "
|
||||
f"处理帧率: {processing_fps:.1f} FPS | "
|
||||
f"标准视频帧: {target_frame_idx}/{total_frames}")
|
||||
|
||||
except Exception as e:
|
||||
# 如果UI更新失败,避免程序崩溃
|
||||
pass
|
||||
|
||||
# 这个循环没有 time.sleep(),它会尽力运行。
|
||||
# 同步是通过每次循环都根据真实时间去标准视频中定位帧来实现的。
|
||||
# --- 循环结束后的清理工作 ---
|
||||
# 停止相似度计算线程
|
||||
self.similarity_stop_flag.set()
|
||||
if self.similarity_thread and self.similarity_thread.is_alive():
|
||||
# 发送毒丸信号停止线程
|
||||
try:
|
||||
self.pose_data_queue.put_nowait(None)
|
||||
except queue.Full:
|
||||
pass
|
||||
self.similarity_thread.join(timeout=2)
|
||||
|
||||
|
||||
# 确保音频已停止
|
||||
if audio_loaded:
|
||||
self.audio_player.stop()
|
||||
|
||||
self.cleanup()
|
||||
self.show_final_statistics()
|
||||
|
@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
import threading
|
||||
from collections import deque
|
||||
import plotly.graph_objects as go
|
||||
|
||||
@ -11,6 +12,7 @@ class PoseSimilarityAnalyzer:
|
||||
self.similarity_history = deque(maxlen=500)
|
||||
self.frame_timestamps = deque(maxlen=500)
|
||||
self.start_time = None
|
||||
self._lock = threading.Lock() # Thread safety for shared data
|
||||
|
||||
self.keypoint_map = {
|
||||
'nose': 0, 'neck': 1, 'left_shoulder': 2, 'left_elbow': 3, 'left_wrist': 4,
|
||||
@ -94,21 +96,27 @@ class PoseSimilarityAnalyzer:
|
||||
|
||||
def add_similarity_score(self, score, timestamp=None):
|
||||
"""Adds a similarity score to the history."""
|
||||
if self.start_time is None: self.start_time = time.time()
|
||||
timestamp = timestamp if timestamp is not None else time.time() - self.start_time
|
||||
self.similarity_history.append(float(score))
|
||||
self.frame_timestamps.append(float(timestamp))
|
||||
with self._lock:
|
||||
if self.start_time is None: self.start_time = time.time()
|
||||
timestamp = timestamp if timestamp is not None else time.time() - self.start_time
|
||||
self.similarity_history.append(float(score))
|
||||
self.frame_timestamps.append(float(timestamp))
|
||||
|
||||
def get_similarity_plot(self):
|
||||
"""Generates a Plotly figure for the similarity history."""
|
||||
if len(self.similarity_history) < 2: return None
|
||||
with self._lock:
|
||||
if len(self.similarity_history) < 2: return None
|
||||
|
||||
# Create copies to avoid data changes during plotting
|
||||
timestamps_copy = list(self.frame_timestamps)
|
||||
history_copy = list(self.similarity_history)
|
||||
|
||||
fig = go.Figure()
|
||||
fig.add_trace(go.Scatter(x=list(self.frame_timestamps), y=list(self.similarity_history),
|
||||
fig.add_trace(go.Scatter(x=timestamps_copy, y=history_copy,
|
||||
mode='lines+markers', name='Similarity',
|
||||
line=dict(color='#2E86AB', width=2), marker=dict(size=4)))
|
||||
|
||||
avg_score = sum(self.similarity_history) / len(self.similarity_history)
|
||||
avg_score = sum(history_copy) / len(history_copy)
|
||||
fig.add_hline(y=avg_score, line_dash="dash", line_color="red",
|
||||
annotation_text=f"Avg: {avg_score:.1f}%")
|
||||
|
||||
@ -119,6 +127,7 @@ class PoseSimilarityAnalyzer:
|
||||
|
||||
def reset(self):
|
||||
"""Resets the analyzer's history."""
|
||||
self.similarity_history.clear()
|
||||
self.frame_timestamps.clear()
|
||||
self.start_time = None
|
||||
with self._lock:
|
||||
self.similarity_history.clear()
|
||||
self.frame_timestamps.clear()
|
||||
self.start_time = None
|
||||
|
Loading…
Reference in New Issue
Block a user