feat(motion_app): improve comparison loop and UI

This commit is contained in:
game-loader 2025-06-22 12:11:30 +08:00
parent 899c85ab8f
commit 53979d9501
2 changed files with 117 additions and 256 deletions

View File

@ -190,7 +190,6 @@ def main():
app.start_comparison(video_path)
# 比较结束后,重置状态,以便下次可以重新显示主按钮
st.session_state.start_main_comparison = False
st.rerun() # 强制重跑以显示主按钮界面
except Exception as e:
st.error(f"❌ 处理视频时出现错误: {str(e)}")

View File

@ -127,16 +127,28 @@ class MotionComparisonApp:
def show_final_statistics(self):
"""Displays final statistics after the comparison ends."""
st.markdown("---") # 添加分隔线
history = self.similarity_analyzer.similarity_history
if not history: return
if not history:
st.warning("没有足够的比较数据来生成分析结果。")
if st.button("返回主菜单", use_container_width=True):
st.rerun()
return
final_avg = sum(history) / len(history)
level, color = ("非常棒! 👏", "success") if final_avg >= 80 else \
("整体不错! 👍", "info") if final_avg >= 60 else \
("需要改进! 💪", "warning")
# 我们将评级文本和颜色分开定义,逻辑更清晰
if final_avg >= 80:
level = "非常棒! 👏"
md_color = "green" # 使用 'green' 颜色
elif final_avg >= 60:
level = "整体不错! 👍"
md_color = "blue" # 使用 'blue' 颜色
else:
level = "需要改进! 💪"
md_color = "orange" # 使用 'orange' 颜色替代 'warning'
st.success("🎉 比较完成!")
st.markdown(f"**整体表现**: :{color}[{level}]")
st.markdown(f"**整体表现**: :{md_color}[{level}]")
col1, col2, col3 = st.columns(3)
col1.metric("平均相似度", f"{final_avg:.1f}%")
@ -148,16 +160,23 @@ class MotionComparisonApp:
st.markdown("- 确保您的全身在摄像头画面中清晰可见\n"
"- 尽量匹配标准视频的节奏和动作幅度\n"
"- 确保光线充足且稳定")
# 添加返回按钮
st.markdown("---")
if st.button("返回主菜单", use_container_width=True):
st.session_state.comparison_state['is_running'] = False # 确保状态重置
st.rerun()
def start_comparison(self, video_path):
"""The main loop for comparing motion."""
# Setup and initialization... (abbreviated for clarity, logic is the same as original)
"""The main loop for comparing motion, synchronized with real-time."""
self.is_running = True
st.session_state.comparison_state.update({'is_running': True, 'should_stop': False, 'should_restart': False})
self.standard_video_path = video_path
self.frame_counter = 0
self.similarity_analyzer.reset()
audio_loaded = self.audio_player.load_audio(video_path)
if audio_loaded: st.success("✅ 音频加载成功")
@ -170,50 +189,28 @@ class MotionComparisonApp:
if not self.is_realsense_active and (not self.webcam_cap or not self.webcam_cap.isOpened()):
if not self.initialize_camera(): return
# Get video info and setup variables
# --- 获取视频信息和设置变量 ---
total_frames = int(self.standard_cap.get(cv2.CAP_PROP_FRAME_COUNT))
video_fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
frame_delay = 1.0 / video_fps if video_fps > 0 else 1.0 / 30.0
if video_fps == 0: video_fps = 30 # 防止除零错误
video_duration = total_frames / video_fps
target_width, target_height = self.get_display_resolution()
# UI Placeholders
# --- UI 占位符 ---
st.markdown("### 📺 视频比较")
video_col1, video_col2 = st.columns(2, gap="small")
standard_placeholder = video_col1.empty()
webcam_placeholder = video_col2.empty()
with video_col1:
st.markdown("#### 🎯 标准动作视频")
with video_col2:
camera_type = "RealSense摄像头" if self.is_realsense_active else "USB摄像头"
st.markdown(f"#### 📹 {camera_type}实时影像")
# 创建控制按钮区域(紧凑布局)
st.markdown("---")
control_container = st.container()
with control_container:
control_col1, control_col2, control_col3, control_col4 = st.columns([1, 1, 1, 1])
with control_col1:
stop_button = st.button("⏹️ 停止", use_container_width=True, key="stop_comparison")
with control_col2:
restart_button = st.button("🔄 重新开始", use_container_width=True, key="restart_comparison")
with control_col3:
if audio_loaded:
if st.button("🔊 音频状态", use_container_width=True, key="audio_status"):
st.info(f"音频: {'播放中' if self.audio_player.is_playing else '已停止'}")
with control_col4:
# 分辨率切换按钮
if st.button("📐 切换分辨率", use_container_width=True, key="resolution_toggle"):
modes = ['high', 'medium', 'low']
current_idx = modes.index(self.display_settings.get('resolution_mode', 'high'))
next_idx = (current_idx + 1) % len(modes)
self.display_settings['resolution_mode'] = modes[next_idx]
st.info(f"分辨率模式: {modes[next_idx]}")
with video_col1: st.markdown("#### 🎯 标准动作视频")
# Similarity UI
with video_col2: st.markdown(f"#### 📹 {'RealSense' if self.is_realsense_active else 'USB'}实时影像")
st.markdown("---")
control_col1, control_col2 = st.columns(2)
stop_button = control_col1.button("⏹️ 停止", use_container_width=True, key="stop_comparison")
restart_button = control_col2.button("🔄 重新开始", use_container_width=True, key="restart_comparison")
st.markdown("---")
st.markdown("### 📊 动作相似度分析")
sim_col1, sim_col2, sim_col3 = st.columns([1, 1, 2])
@ -221,240 +218,105 @@ class MotionComparisonApp:
avg_score_placeholder = sim_col2.empty()
similarity_plot_placeholder = sim_col3.empty()
# 处理按钮点击
if stop_button:
st.session_state.comparison_state['should_stop'] = True
if restart_button:
st.session_state.comparison_state['should_restart'] = True
# 状态显示区域
status_container = st.container()
with status_container:
progress_bar = st.progress(0)
status_text = st.empty()
# 主循环
video_frame_idx = 0
progress_bar = st.progress(0)
status_text = st.empty()
# --- 按钮逻辑 ---
if stop_button: st.session_state.comparison_state['should_stop'] = True
if restart_button: st.session_state.comparison_state['should_restart'] = True
# --- 主循环 (基于时间同步) ---
start_time = time.time()
current_similarity = 0
last_plot_update = 0
# Start Audio
if audio_loaded: self.audio_player.play()
while True:
# --- 检查退出条件 ---
elapsed_time = time.time() - start_time
while (st.session_state.comparison_state['is_running'] and
not st.session_state.comparison_state['should_stop']):
if elapsed_time >= video_duration or st.session_state.comparison_state['should_stop']:
break
loop_start = time.time()
self.frame_counter += 1
# 检查重新开始
# --- 检查重新开始 ---
if st.session_state.comparison_state['should_restart']:
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
self.similarity_analyzer.reset()
start_time = time.time()
video_frame_idx = 0
if audio_loaded:
self.audio_player.restart()
start_time = time.time() # 重置计时器
if audio_loaded: self.audio_player.restart()
st.session_state.comparison_state['should_restart'] = False
continue
# 读取标准视频的当前帧
# --- 基于真实时间获取标准视频帧 ---
target_frame_idx = int(elapsed_time * video_fps)
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_idx)
ret_standard, standard_frame = self.standard_cap.read()
if not ret_standard:
# 视频结束,重新开始
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
video_frame_idx = 0
self.similarity_analyzer.reset()
start_time = time.time()
if audio_loaded:
self.audio_player.restart()
continue
# 读取摄像头当前帧
if not ret_standard: continue # 如果读取失败,跳过这一轮
# --- 读取当前摄像头帧 ---
ret_webcam, webcam_frame = self.read_camera_frame()
if not ret_webcam or webcam_frame is None:
self.error_count += 1
if self.error_count > 100: # 连续错误过多时停止
st.error("摄像头连接出现问题,停止比较")
break
status_text.warning("摄像头画面读取失败,请检查连接...")
time.sleep(0.1) # 短暂等待
continue
self.error_count = 0 # 重置错误计数
# 翻转摄像头画面(镜像效果)
self.frame_counter += 1
webcam_frame = cv2.flip(webcam_frame, 1)
# 调整尺寸使两个视频大小一致(使用更高分辨率)
standard_frame = cv2.resize(standard_frame, (target_width, target_height))
webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))
# 处理关键点检测
# --- 关键点检测与相似度计算 (可以保持原来的逻辑) ---
try:
standard_keypoints, standard_scores = self.body_detector(standard_frame)
webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)
except Exception as e:
# 关键点检测失败时继续显示原始图像
standard_keypoints, standard_scores = None, None
webcam_keypoints, webcam_scores = None, None
# 计算相似度每5帧计算一次以提高性能
if (self.frame_counter % 5 == 0 and
standard_keypoints is not None and
webcam_keypoints is not None):
try:
# 提取当前标准视频帧和摄像头帧的关节角度
standard_angles = self.similarity_analyzer.extract_joint_angles(
standard_keypoints, standard_scores
)
webcam_angles = self.similarity_analyzer.extract_joint_angles(
webcam_keypoints, webcam_scores
)
# 计算相似度
if standard_angles and webcam_angles:
current_similarity = self.similarity_analyzer.calculate_similarity(
standard_angles, webcam_angles
)
# 添加到历史记录
timestamp = time.time() - start_time
self.similarity_analyzer.add_similarity_score(current_similarity, timestamp)
except Exception as e:
pass # 忽略相似度计算错误
# 绘制关键点
try:
if standard_keypoints is not None and standard_scores is not None:
standard_with_keypoints = draw_skeleton(
standard_frame.copy(),
standard_keypoints,
standard_scores,
openpose_skeleton=True,
kpt_thr=0.43
)
else:
standard_with_keypoints = standard_frame.copy()
if webcam_keypoints is not None and webcam_scores is not None:
webcam_with_keypoints = draw_skeleton(
webcam_frame.copy(),
webcam_keypoints,
webcam_scores,
openpose_skeleton=True,
kpt_thr=0.43
)
else:
webcam_with_keypoints = webcam_frame.copy()
except Exception as e:
# 如果绘制失败,使用原始帧
standard_with_keypoints = standard_frame.copy()
webcam_with_keypoints = webcam_frame.copy()
# 转换颜色空间 (BGR to RGB)
standard_rgb = cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB)
webcam_rgb = cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB)
# 添加帧信息到图像上
current_time = time.time() - start_time
frame_info = f"时间: {current_time:.1f}s | 帧: {video_frame_idx}/{total_frames}"
audio_info = f" | 音频: {'🔊' if self.audio_player.is_playing else '🔇'}" if audio_loaded else ""
resolution_info = f" | {target_width}x{target_height}"
# 显示大尺寸画面
with video_col1:
standard_placeholder.image(
standard_rgb,
caption=f"标准动作 - {frame_info}{audio_info}{resolution_info}",
use_container_width=True
)
with video_col2:
webcam_placeholder.image(
webcam_rgb,
caption=f"您的动作 - 实时画面{resolution_info}",
use_container_width=True
)
# 显示相似度信息(紧凑显示)
if len(self.similarity_analyzer.similarity_history) > 0:
try:
avg_similarity = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)
# 使用不同颜色显示相似度
if current_similarity >= 80:
similarity_color = "🟢"
level = "优秀"
elif current_similarity >= 60:
similarity_color = "🟡"
level = "良好"
else:
similarity_color = "🔴"
level = "需要改进"
with sim_col1:
similarity_score_placeholder.metric(
"当前相似度",
f"{similarity_color} {current_similarity:.1f}%",
delta=f"{level}"
)
with sim_col2:
avg_score_placeholder.metric(
"平均相似度",
f"{avg_similarity:.1f}%"
)
# 更新相似度图表每20帧更新一次以提高性能
if (len(self.similarity_analyzer.similarity_history) >= 2 and
self.frame_counter - last_plot_update >= 20):
try:
similarity_plot = self.similarity_analyzer.get_similarity_plot()
if similarity_plot:
with sim_col3:
similarity_plot_placeholder.plotly_chart(
similarity_plot,
use_container_width=True,
key=f"similarity_plot_{int(time.time() * 1000)}" # 使用时间戳避免重复ID
)
last_plot_update = self.frame_counter
except Exception as e:
pass # 忽略图表更新错误
except Exception as e:
pass # 忽略显示错误
# 更新进度和状态(紧凑显示)
try:
progress = min(video_frame_idx / total_frames, 1.0) if total_frames > 0 else 0
progress_bar.progress(progress)
elapsed_time = time.time() - start_time
fps_actual = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
status_text.text(
f"进度: {video_frame_idx}/{total_frames} | "
f"实际FPS: {fps_actual:.1f} | "
f"分辨率: {target_width}x{target_height} | "
f"模式: {self.display_settings['resolution_mode']}"
)
except Exception as e:
pass # 忽略状态更新错误
video_frame_idx += 1
# 精确的帧率控制
loop_elapsed = time.time() - loop_start
sleep_time = max(0, frame_delay - loop_elapsed)
if sleep_time > 0:
time.sleep(sleep_time)
# 强制更新UI每30帧一次
if self.frame_counter % 30 == 0:
st.empty() # 触发UI更新
if standard_keypoints is not None and webcam_keypoints is not None:
standard_angles = self.similarity_analyzer.extract_joint_angles(standard_keypoints, standard_scores)
webcam_angles = self.similarity_analyzer.extract_joint_angles(webcam_keypoints, webcam_scores)
if standard_angles and webcam_angles:
current_similarity = self.similarity_analyzer.calculate_similarity(standard_angles, webcam_angles)
self.similarity_analyzer.add_similarity_score(current_similarity, elapsed_time)
except Exception:
standard_keypoints, webcam_keypoints = None, None
# --- 绘制与显示 (可以保持原来的逻辑只修改UI更新部分) ---
try:
# 绘制骨骼
standard_with_keypoints = draw_skeleton(standard_frame.copy(), standard_keypoints, standard_scores, openpose_skeleton=True, kpt_thr=0.43)
webcam_with_keypoints = draw_skeleton(webcam_frame.copy(), webcam_keypoints, webcam_scores, openpose_skeleton=True, kpt_thr=0.43)
# 更新视频画面
standard_placeholder.image(cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)
webcam_placeholder.image(cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)
# 更新状态信息
if self.similarity_analyzer.similarity_history:
avg_sim = np.mean(self.similarity_analyzer.similarity_history)
similarity_score_placeholder.metric("当前相似度", f"{current_similarity:.1f}%")
avg_score_placeholder.metric("平均相似度", f"{avg_sim:.1f}%")
# 更新图表 (节流以提高性能)
if self.frame_counter - last_plot_update > 5: # 每处理约5帧更新一次图表
plot = self.similarity_analyzer.get_similarity_plot()
if plot: similarity_plot_placeholder.plotly_chart(plot, use_container_width=True)
last_plot_update = self.frame_counter
# 更新进度条和状态文本
progress = min(elapsed_time / video_duration, 1.0)
progress_bar.progress(progress)
processing_fps = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
status_text.text(f"时间: {elapsed_time:.1f}s / {video_duration:.1f}s | "
f"处理帧率: {processing_fps:.1f} FPS | "
f"标准视频帧: {target_frame_idx}/{total_frames}")
except Exception as e:
# 如果UI更新失败避免程序崩溃
pass
# 这个循环没有 time.sleep(),它会尽力运行。
# 同步是通过每次循环都根据真实时间去标准视频中定位帧来实现的。
# --- 循环结束后的清理工作 ---
self.cleanup()
self.show_final_statistics()