"""Streamlit app that compares a live camera pose against a reference video."""

import logging
import os
import queue
import threading
import time

import cv2
import numpy as np
import streamlit as st
import torch
from rtmlib import Body, draw_skeleton

from audio_player import AudioPlayer
from config import REALSENSE_AVAILABLE
from pose_analyzer import PoseSimilarityAnalyzer

if REALSENSE_AVAILABLE:
    import pyrealsense2 as rs

logger = logging.getLogger(__name__)

# Limb connections, expressed as pairs of OpenPose keypoint indices.
_OPENPOSE_SKELETON = (
    (1, 0), (1, 2), (1, 5), (2, 3), (3, 4), (5, 6), (6, 7),
    (1, 8), (8, 9), (9, 10), (1, 11), (11, 12), (12, 13),
    (0, 14), (0, 15), (14, 16), (15, 17),
)

# Limbs whose colour is driven by a per-joint similarity score, keyed by the
# limb's (joint_a, joint_b) pair and mapped to the similarity dictionary key.
_LIMB_TO_ANGLE = {
    (2, 3): 'left_shoulder',
    (3, 4): 'left_elbow',
    (5, 6): 'right_shoulder',
    (6, 7): 'right_elbow',
    (8, 9): 'left_hip',
    (9, 10): 'left_knee',
    (11, 12): 'right_hip',
    (12, 13): 'right_knee',
}

# BGR colours used when drawing.
_COLOR_GREEN = (0, 255, 0)
_COLOR_YELLOW = (0, 255, 255)
_COLOR_RED = (0, 0, 255)
_COLOR_KEYPOINT = (255, 0, 255)


def _similarity_color(similarity):
    """Map a similarity score to a BGR colour (>90 green, >80 yellow, else red)."""
    if similarity > 90:
        return _COLOR_GREEN
    if similarity > 80:
        return _COLOR_YELLOW
    return _COLOR_RED


def draw_skeleton_with_similarity(img, keypoints, scores, joint_similarities=None,
                                  openpose_skeleton=True, kpt_thr=0.43, line_width=2):
    """Draw skeletons on ``img``, colouring limbs by per-joint similarity.

    Limbs listed in the joint-angle mapping are drawn green when their
    similarity is above 90, yellow when above 80 and red otherwise. All other
    limbs (and every limb when ``joint_similarities`` is None) are yellow.
    Every detected person is drawn.

    Args:
        img: BGR image; it is drawn on in place.
        keypoints: Array of keypoint coordinates, either for one person
            (2-D) or for several people (3-D, person axis first).
        scores: Keypoint confidences matching ``keypoints`` without the
            coordinate axis.
        joint_similarities: Optional mapping from joint-angle name to a
            similarity score.
        openpose_skeleton: Unused; kept for call compatibility.
        kpt_thr: Minimum confidence for a keypoint to be drawn.
        line_width: Thickness of the limb lines.

    Returns:
        The same ``img`` object, with the skeletons drawn on it.
    """
    if keypoints is None or len(keypoints) == 0:
        return img

    keypoints = np.asarray(keypoints)
    scores = np.asarray(scores)

    # Promote single-person input to a batch of one. Ellipsis indexing is used
    # because the score array has one axis fewer than the keypoint array, so a
    # fixed ``[None, :, :]`` would fail on a 1-D score vector.
    if keypoints.ndim == 2:
        keypoints = keypoints[None, ...]
        scores = scores[None, ...]

    for person_kpts, person_scores in zip(keypoints, scores):
        num_scores = len(person_scores)

        # Limbs
        for joint_a, joint_b in _OPENPOSE_SKELETON:
            if joint_a >= num_scores or joint_b >= num_scores:
                continue
            if person_scores[joint_a] < kpt_thr or person_scores[joint_b] < kpt_thr:
                continue

            x_a, y_a = person_kpts[joint_a]
            x_b, y_b = person_kpts[joint_b]

            color = _COLOR_YELLOW
            if joint_similarities is not None:
                angle_name = _LIMB_TO_ANGLE.get((joint_a, joint_b))
                if angle_name is not None and angle_name in joint_similarities:
                    color = _similarity_color(joint_similarities[angle_name])

            cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)),
                     color, thickness=line_width)

        # Keypoints
        for kpt_id, (x, y) in enumerate(person_kpts):
            if person_scores[kpt_id] < kpt_thr:
                continue
            cv2.circle(img, (int(x), int(y)), 3, _COLOR_KEYPOINT, -1)

    return img


class MotionComparisonApp:
    """Main application class for motion comparison."""

    def __init__(self):
        self.body_detector = None
        self.is_running = False
        self.standard_video_path = None
        self.webcam_cap = None
        self.standard_cap = None
        self.similarity_analyzer = PoseSimilarityAnalyzer()
        self.frame_counter = 0
        self.audio_player = AudioPlayer()
        self.display_settings = {'resolution_mode': 'high', 'target_width': 960, 'target_height': 720}
        self.realsense_pipeline = None
        self.is_realsense_active = False
        self.last_error_time = 0
        self.error_count = 0

        # Async processing components: the main loop pushes pose data into the
        # queue and a background thread turns it into similarity scores.
        self.pose_data_queue = queue.Queue(maxsize=50)
        self.similarity_thread = None
        self.similarity_stop_flag = threading.Event()

        if 'comparison_state' not in st.session_state:
            st.session_state.comparison_state = {
                'is_running': False, 'should_stop': False, 'should_restart': False}

    def get_display_resolution(self):
        """Return the (width, height) for the configured resolution mode."""
        modes = {'high': (1024, 576), 'medium': (960, 720), 'low': (640, 480)}
        mode = self.display_settings.get('resolution_mode', 'medium')
        return modes.get(mode, (960, 720))

    def initialize_detector(self):
        """Create the body keypoint detector if needed.

        Returns:
            True when a detector is available, False if creation failed.
        """
        if self.body_detector is not None:
            return True
        try:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.body_detector = Body(mode='lightweight', to_openpose=True,
                                      backend='onnxruntime', device=device)
            st.success(f"Keypoint detector initialized on device: {device}")
            return True
        except Exception as e:
            st.error(f"Detector initialization failed: {e}")
            return False

    def initialize_camera(self):
        """Open a RealSense camera when available, otherwise a USB webcam.

        Returns:
            True when a camera was opened, False otherwise.
        """
        if not REALSENSE_AVAILABLE:
            return self._initialize_webcam()
        try:
            self.realsense_pipeline = rs.pipeline()
            config = rs.config()
            width, height = self.get_display_resolution()
            config.enable_stream(rs.stream.color, width, height, rs.format.bgr8, 30)
            profile = self.realsense_pipeline.start(config)
            device = profile.get_device().get_info(rs.camera_info.name)
            st.success(f"✅ RealSense摄像头初始化成功: {device} ({width}x{height})")
            self.is_realsense_active = True
            return True
        except Exception as e:
            st.warning(f"RealSense初始化失败: {e}. 切换到USB摄像头.")
            return self._initialize_webcam()

    def _initialize_webcam(self):
        """Open USB camera 0 and request MJPG at the display resolution."""
        try:
            self.webcam_cap = cv2.VideoCapture(0)
            if not self.webcam_cap.isOpened():
                st.error("❌ 无法打开USB摄像头")
                return False
            self.webcam_cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'))
            width, height = self.get_display_resolution()
            self.webcam_cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
            self.webcam_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
            self.webcam_cap.set(cv2.CAP_PROP_FPS, 30)
            # Report what the driver actually granted, not what was requested.
            actual_w = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            actual_h = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            st.success(f"✅ USB摄像头初始化成功 ({actual_w}x{actual_h})")
            return True
        except Exception as e:
            st.error(f"❌ USB摄像头初始化失败: {e}")
            return False

    def read_camera_frame(self):
        """Read one frame from the active camera.

        Returns:
            A ``(ok, frame)`` pair; ``(False, None)`` when no camera is open
            or the RealSense read failed.
        """
        if self.is_realsense_active and self.realsense_pipeline:
            try:
                frames = self.realsense_pipeline.wait_for_frames(timeout_ms=1000)
                color_frame = frames.get_color_frame()
                if not color_frame:
                    return False, None
                return True, np.asanyarray(color_frame.get_data())
            except Exception:
                logger.debug("RealSense frame read failed", exc_info=True)
                return False, None
        if self.webcam_cap and self.webcam_cap.isOpened():
            return self.webcam_cap.read()
        return False, None

    def get_camera_preview_frame(self):
        """Return a mirrored RGB preview frame with the skeleton drawn, or None."""
        ret, frame = self.read_camera_frame()
        if not ret or frame is None:
            return None
        frame = cv2.flip(frame, 1)
        if self.body_detector:
            try:
                keypoints, scores = self.body_detector(frame)
                frame = draw_skeleton_with_similarity(
                    frame.copy(), keypoints, scores, joint_similarities=None,
                    openpose_skeleton=True, kpt_thr=0.43, line_width=1)
            except Exception:
                # Best effort: show the plain frame if detection fails.
                logger.debug("Preview pose detection failed", exc_info=True)
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    def cleanup(self):
        """Cleans up all resources."""
        # Stop similarity calculation thread
        if self.similarity_thread and self.similarity_thread.is_alive():
            self.similarity_stop_flag.set()
            self.similarity_thread.join(timeout=2)
        if self.standard_cap:
            self.standard_cap.release()
        if self.webcam_cap:
            self.webcam_cap.release()
        if self.is_realsense_active and self.realsense_pipeline:
            try:
                self.realsense_pipeline.stop()
            except Exception:
                logger.exception("Failed to stop the RealSense pipeline")
        # Forget the stopped pipeline so the next run re-initializes the
        # camera instead of reading from a dead pipeline.
        self.realsense_pipeline = None
        self.is_realsense_active = False
        self.audio_player.cleanup()
        self.is_running = False
        st.session_state.comparison_state['is_running'] = False

    def show_final_statistics(self):
        """Displays final statistics after the comparison ends."""
        st.markdown("---")  # separator
        history = self.similarity_analyzer.similarity_history
        if not history:
            st.warning("没有足够的比较数据来生成分析结果。")
            if st.button("返回主菜单", use_container_width=True):
                st.rerun()
            return

        final_avg = sum(history) / len(history)

        # Rating text and its markdown colour are chosen together.
        if final_avg >= 80:
            level, md_color = "非常棒! 👏", "green"
        elif final_avg >= 60:
            level, md_color = "整体不错! 👍", "blue"
        else:
            level, md_color = "需要改进! 💪", "orange"

        st.success("🎉 比较完成!")
        st.markdown(f"**整体表现**: :{md_color}[{level}]")

        col1, col2, col3 = st.columns(3)
        col1.metric("平均相似度", f"{final_avg:.1f}%")
        col2.metric("最高相似度", f"{max(history):.1f}%")
        col3.metric("最低相似度", f"{min(history):.1f}%")

        if final_avg < 60:
            with st.expander("💡 改善建议"):
                st.markdown("- 确保您的全身在摄像头画面中清晰可见\n"
                            "- 尽量匹配标准视频的节奏和动作幅度\n"
                            "- 确保光线充足且稳定")

        # Back button
        st.markdown("---")
        if st.button("返回主菜单", use_container_width=True):
            st.session_state.comparison_state['is_running'] = False  # make sure state is reset
            st.rerun()

    def _similarity_calculation_worker(self, start_time, video_fps):
        """Background thread: turn queued pose data into similarity scores.

        Runs until the stop flag is set or a ``None`` poison pill is dequeued.
        ``start_time`` and ``video_fps`` are accepted for call compatibility
        and are not used.
        """
        while not self.similarity_stop_flag.is_set():
            try:
                # Time out regularly so the stop flag is re-checked.
                pose_data = self.pose_data_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            try:
                if pose_data is None:  # Poison pill to stop
                    break
                elapsed_time, standard_angles, webcam_angles = pose_data
                if standard_angles and webcam_angles:
                    current_similarity = self.similarity_analyzer.calculate_similarity(
                        standard_angles, webcam_angles)
                    self.similarity_analyzer.add_similarity_score(current_similarity, elapsed_time)
            except Exception:
                # Log the error but keep processing later samples.
                logger.exception("Similarity calculation failed for one sample")
            finally:
                self.pose_data_queue.task_done()

    def _drain_pose_queue(self):
        """Discard everything currently waiting in the pose data queue."""
        while True:
            try:
                self.pose_data_queue.get_nowait()
            except queue.Empty:
                return
            self.pose_data_queue.task_done()

    def _enqueue_pose_data(self, pose_data):
        """Queue pose data without blocking, dropping the oldest entry if full."""
        try:
            self.pose_data_queue.put_nowait(pose_data)
            return
        except queue.Full:
            pass
        try:
            self.pose_data_queue.get_nowait()
            self.pose_data_queue.task_done()
            self.pose_data_queue.put_nowait(pose_data)
        except (queue.Empty, queue.Full):
            # The worker raced us; skipping one sample is harmless.
            pass

    def _stop_similarity_worker(self):
        """Ask the worker to finish the queued samples and then exit."""
        worker = self.similarity_thread
        if worker and worker.is_alive():
            try:
                # Pill first, so samples queued before it are still scored.
                self.pose_data_queue.put_nowait(None)
            except queue.Full:
                self.similarity_stop_flag.set()
            worker.join(timeout=2)
        # Fallback in case the pill was not consumed in time.
        self.similarity_stop_flag.set()

    def _abort_comparison(self):
        """Reset the running state after a failed start."""
        if self.standard_cap:
            self.standard_cap.release()
        self.is_running = False
        st.session_state.comparison_state['is_running'] = False

    def start_comparison(self, video_path):
        """The main loop for comparing motion, synchronized with real-time.

        Plays ``video_path`` against the wall clock, runs pose detection on
        both the reference frame and the camera frame, displays both with
        skeletons, and hands joint angles to a background thread for
        similarity scoring. Ends with cleanup and the final statistics view.

        Args:
            video_path: Path of the reference video (also used to load audio).
        """
        self.is_running = True
        st.session_state.comparison_state.update(
            {'is_running': True, 'should_stop': False, 'should_restart': False})
        self.standard_video_path = video_path
        self.frame_counter = 0
        self.similarity_analyzer.reset()

        # Without a detector every frame would fail silently.
        if not self.initialize_detector():
            self._abort_comparison()
            return

        audio_loaded = self.audio_player.load_audio(video_path)
        if audio_loaded:
            st.success("✅ 音频加载成功")
        else:
            st.info("ℹ️ 将不会播放音频")

        self.standard_cap = cv2.VideoCapture(video_path)
        if not self.standard_cap.isOpened():
            st.error("无法打开标准视频文件")
            self._abort_comparison()
            return

        if not self.is_realsense_active and (not self.webcam_cap or not self.webcam_cap.isOpened()):
            if not self.initialize_camera():
                self._abort_comparison()
                return

        # --- Video info and loop variables ---
        total_frames = int(self.standard_cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
        if not video_fps or video_fps <= 0 or video_fps != video_fps:
            video_fps = 30  # guard against zero/negative/NaN FPS metadata
        video_duration = total_frames / video_fps
        target_width, target_height = self.get_display_resolution()

        # --- UI placeholders ---
        st.markdown("### 📺 视频比较")
        video_col1, video_col2 = st.columns(2, gap="small")
        standard_placeholder = video_col1.empty()
        webcam_placeholder = video_col2.empty()
        with video_col1:
            st.markdown("#### 🎯 标准动作视频")
        with video_col2:
            st.markdown(f"#### 📹 {'RealSense' if self.is_realsense_active else 'USB'}实时影像")

        st.markdown("---")
        control_col1, control_col2 = st.columns(2)
        stop_button = control_col1.button("⏹️ 停止", use_container_width=True, key="stop_comparison")
        restart_button = control_col2.button("🔄 重新开始", use_container_width=True, key="restart_comparison")

        st.markdown("---")
        st.markdown("### 📊 动作相似度分析")
        sim_col1, sim_col2, sim_col3 = st.columns([1, 1, 2])
        similarity_score_placeholder = sim_col1.empty()
        avg_score_placeholder = sim_col2.empty()
        similarity_plot_placeholder = sim_col3.empty()
        progress_bar = st.progress(0)
        status_text = st.empty()

        # --- Button logic ---
        if stop_button:
            st.session_state.comparison_state['should_stop'] = True
        if restart_button:
            st.session_state.comparison_state['should_restart'] = True

        # --- Main loop (time-synchronized, focused on live display) ---
        start_time = time.time()

        # --- Start the asynchronous similarity worker ---
        # Drop leftovers from a previous run first; a stale poison pill would
        # otherwise make the new worker exit immediately.
        self._drain_pose_queue()
        self.similarity_stop_flag.clear()
        self.similarity_thread = threading.Thread(
            target=self._similarity_calculation_worker,
            args=(start_time, video_fps),
            daemon=True,
        )
        self.similarity_thread.start()

        if audio_loaded:
            self.audio_player.play()

        while True:
            # --- Exit conditions ---
            elapsed_time = time.time() - start_time
            if elapsed_time >= video_duration or st.session_state.comparison_state['should_stop']:
                if audio_loaded:
                    self.audio_player.stop()
                break

            # --- Restart request ---
            if st.session_state.comparison_state['should_restart']:
                # Discard samples from before the restart, then start over.
                self._drain_pose_queue()
                self.similarity_analyzer.reset()
                self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                start_time = time.time()
                self.frame_counter = 0  # keeps the FPS readout meaningful
                if audio_loaded:
                    self.audio_player.restart()
                st.session_state.comparison_state['should_restart'] = False
                continue

            # --- Pick the reference frame that matches the wall clock ---
            target_frame_idx = int(elapsed_time * video_fps)
            self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_idx)
            ret_standard, standard_frame = self.standard_cap.read()
            if not ret_standard:
                continue  # skip this round if the read failed

            # --- Current camera frame ---
            ret_webcam, webcam_frame = self.read_camera_frame()
            if not ret_webcam or webcam_frame is None:
                status_text.warning("摄像头画面读取失败,请检查连接...")
                time.sleep(0.1)  # brief pause before retrying
                continue

            self.frame_counter += 1
            webcam_frame = cv2.flip(webcam_frame, 1)

            standard_frame = cv2.resize(standard_frame, (target_width, target_height))
            webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))

            # --- Keypoint detection (real time) ---
            # Pre-bind everything so a detection failure still lets the raw
            # frames be displayed below.
            standard_keypoints = standard_scores = None
            webcam_keypoints = webcam_scores = None
            standard_angles = webcam_angles = None
            try:
                standard_keypoints, standard_scores = self.body_detector(standard_frame)
                webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)

                if standard_keypoints is not None and webcam_keypoints is not None:
                    standard_angles = self.similarity_analyzer.extract_joint_angles(
                        standard_keypoints, standard_scores)
                    webcam_angles = self.similarity_analyzer.extract_joint_angles(
                        webcam_keypoints, webcam_scores)
                    # Hand the sample to the background thread without blocking.
                    self._enqueue_pose_data((elapsed_time, standard_angles, webcam_angles))
            except Exception:
                logger.debug("Pose detection failed for this frame", exc_info=True)
                standard_keypoints = webcam_keypoints = None
                standard_angles = webcam_angles = None

            # --- Drawing and display ---
            try:
                joint_similarities = None
                if standard_angles and webcam_angles:
                    joint_similarities = self.similarity_analyzer.calculate_joint_similarities(
                        standard_angles, webcam_angles)

                # Thin lines; only the camera view is coloured by similarity.
                standard_with_keypoints = draw_skeleton_with_similarity(
                    standard_frame.copy(), standard_keypoints, standard_scores,
                    joint_similarities=None, openpose_skeleton=True,
                    kpt_thr=0.43, line_width=1)
                webcam_with_keypoints = draw_skeleton_with_similarity(
                    webcam_frame.copy(), webcam_keypoints, webcam_scores,
                    joint_similarities=joint_similarities, openpose_skeleton=True,
                    kpt_thr=0.43, line_width=1)

                # Video frames are refreshed on every iteration.
                standard_placeholder.image(
                    cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB),
                    use_container_width=True)
                webcam_placeholder.image(
                    cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB),
                    use_container_width=True)

                # Metrics, plot and progress are refreshed every 30 frames.
                if self.frame_counter % 30 == 0:
                    history = self.similarity_analyzer.similarity_history
                    if history:
                        latest_similarity = history[-1]
                        avg_sim = np.mean(history)
                        similarity_score_placeholder.metric("当前相似度", f"{latest_similarity:.1f}%")
                        avg_score_placeholder.metric("平均相似度", f"{avg_sim:.1f}%")

                        plot = self.similarity_analyzer.get_similarity_plot()
                        if plot:
                            similarity_plot_placeholder.plotly_chart(plot, use_container_width=True)

                    progress = min(elapsed_time / video_duration, 1.0)
                    progress_bar.progress(progress)

                    processing_fps = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
                    status_text.text(f"时间: {elapsed_time:.1f}s / {video_duration:.1f}s | "
                                     f"处理帧率: {processing_fps:.1f} FPS | "
                                     f"标准视频帧: {target_frame_idx}/{total_frames}")
            except Exception:
                # A failed UI update must not end the comparison.
                logger.debug("UI update failed for this frame", exc_info=True)

        # Stop the similarity worker, letting it finish queued samples first.
        self._stop_similarity_worker()

        # Make sure audio is stopped.
        if audio_loaded:
            self.audio_player.stop()

        self.cleanup()
        self.show_final_statistics()