commit 636fda3598a08a802bf5a3cc19a96dd53072da6f
Author: game-loader
Date:   Mon Jun 16 14:21:04 2025 +0800

    first commit

diff --git a/.streamlit/config.toml b/.streamlit/config.toml
new file mode 100644
index 0000000..fd8297e
--- /dev/null
+++ b/.streamlit/config.toml
@@ -0,0 +1,3 @@
+[server]
+maxUploadSize = 1024  # allow uploads of up to 1 GB
+port = 8080
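Editor's note (not part of the commit): Streamlit's default upload cap is 200 MB, so this config raises it to 1 GB for large reference videos and pins the server to port 8080. The same settings can also be supplied at launch via Streamlit's standard --server.maxUploadSize and --server.port command-line options instead of a config file.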
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..e8e9f6b
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,143 @@
+name: /root/shared-nvme/posedet/posedet
+channels:
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1
+  - _openmp_mutex=5.1
+  - asttokens=3.0.0
+  - blas=1.0
+  - bzip2=1.0.8
+  - ca-certificates=2025.2.25
+  - comm=0.2.1
+  - debugpy=1.8.11
+  - decorator=5.1.1
+  - exceptiongroup=1.2.0
+  - executing=0.8.3
+  - expat=2.7.1
+  - intel-openmp=2023.1.0
+  - ipykernel=6.29.5
+  - ipython=8.30.0
+  - jedi=0.19.2
+  - jupyter_client=8.6.3
+  - jupyter_core=5.7.2
+  - ld_impl_linux-64=2.40
+  - libffi=3.4.4
+  - libgcc-ng=11.2.0
+  - libgomp=11.2.0
+  - libsodium=1.0.18
+  - libstdcxx-ng=11.2.0
+  - libuuid=1.41.5
+  - libxcb=1.17.0
+  - matplotlib-inline=0.1.6
+  - mkl=2023.1.0
+  - mkl-service=2.4.0
+  - mkl_fft=1.3.11
+  - mkl_random=1.2.8
+  - ncurses=6.4
+  - nest-asyncio=1.6.0
+  - numpy-base=2.2.5
+  - openssl=3.0.16
+  - packaging=24.2
+  - parso=0.8.4
+  - pexpect=4.8.0
+  - pip=25.1
+  - platformdirs=4.3.7
+  - prompt-toolkit=3.0.43
+  - prompt_toolkit=3.0.43
+  - psutil=5.9.0
+  - pthread-stubs=0.3
+  - ptyprocess=0.7.0
+  - pure_eval=0.2.2
+  - pygments=2.19.1
+  - python=3.10.18
+  - python-dateutil=2.9.0post0
+  - pyzmq=26.2.0
+  - readline=8.2
+  - setuptools=72.1.0
+  - six=1.17.0
+  - sqlite=3.45.3
+  - stack_data=0.2.0
+  - tbb=2021.8.0
+  - tk=8.6.14
+  - tornado=6.5.1
+  - traitlets=5.14.3
+  - typing_extensions=4.12.2
+  - tzdata=2025b
+  - wcwidth=0.2.5
+  - wheel=0.45.1
+  - xorg-libx11=1.8.12
+  - xorg-libxau=1.0.12
+  - xorg-libxdmcp=1.1.5
+  - xorg-xorgproto=2024.1
+  - xz=5.6.4
+  - zeromq=4.3.5
+  - zlib=1.2.13
+  - pip:
+    - altair==5.5.0
+    - attrs==25.3.0
+    - blinker==1.9.0
+    - cachetools==5.5.2
+    - certifi==2025.4.26
+    - charset-normalizer==3.4.2
+    - click==8.2.1
+    - coloredlogs==15.0.1
+    - contourpy==1.3.2
+    - cycler==0.12.1
+    - filelock==3.18.0
+    - flatbuffers==25.2.10
+    - fonttools==4.58.4
+    - fsspec==2025.5.1
+    - gitdb==4.0.12
+    - gitpython==3.1.44
+    - humanfriendly==10.0
+    - idna==3.10
+    - jinja2==3.1.6
+    - jsonschema==4.24.0
+    - jsonschema-specifications==2025.4.1
+    - kiwisolver==1.4.8
+    - markupsafe==3.0.2
+    - matplotlib==3.10.3
+    - mpmath==1.3.0
+    - narwhals==1.42.1
+    - networkx==3.4.2
+    - numpy==2.2.6
+    - nvidia-cublas-cu12==12.6.4.1
+    - nvidia-cuda-cupti-cu12==12.6.80
+    - nvidia-cuda-nvrtc-cu12==12.6.77
+    - nvidia-cuda-runtime-cu12==12.6.77
+    - nvidia-cudnn-cu12==9.5.1.17
+    - nvidia-cufft-cu12==11.3.0.4
+    - nvidia-cufile-cu12==1.11.1.6
+    - nvidia-curand-cu12==10.3.7.77
+    - nvidia-cusolver-cu12==11.7.1.2
+    - nvidia-cusparse-cu12==12.5.4.2
+    - nvidia-cusparselt-cu12==0.6.3
+    - nvidia-nccl-cu12==2.26.2
+    - nvidia-nvjitlink-cu12==12.6.85
+    - nvidia-nvtx-cu12==12.6.77
+    - onnxruntime==1.22.0
+    - opencv-contrib-python-headless==4.11.0.86
+    - pillow==11.2.1
+    - plotly==6.1.2
+    - protobuf==6.31.1
+    - pyarrow==20.0.0
+    - pydeck==0.9.1
+    - pyparsing==3.2.3
+    - pytz==2025.2
+    - referencing==0.36.2
+    - requests==2.32.4
+    - rpds-py==0.25.1
+    - rtmlib==0.0.13
+    - smmap==5.0.2
+    - streamlit==1.45.1
+    - sympy==1.14.0
+    - tenacity==9.1.2
+    - toml==0.10.2
+    - torch==2.7.1
+    - torchaudio==2.7.1
+    - torchvision==0.22.1
+    - tqdm==4.67.1
+    - triton==3.3.1
+    - urllib3==2.4.0
+    - watchdog==6.0.0
+prefix: /root/shared-nvme/posedet/posedet
diff --git a/example.png b/example.png
new file mode 100644
index 0000000..faa828f
Binary files /dev/null and b/example.png differ
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..ace76b4
--- /dev/null
+++ b/main.py
@@ -0,0 +1,54 @@
+import cv2
+import numpy as np
+from rtmlib import Body, draw_skeleton
+
+def detect_body_keypoints_image(image_path, output_path=None):
+    # Initialize the model
+    device = 'cuda' if cv2.cuda.getCudaEnabledDeviceCount() > 0 else 'cpu'
+    backend = 'onnxruntime'
+    openpose_skeleton = True  # True for OpenPose-style output, False for MMPose-style
+
+    # Create a Body instance in balanced mode
+    body = Body(
+        mode='balanced',
+        to_openpose=openpose_skeleton,
+        backend=backend,
+        device=device
+    )
+
+    # Read the image
+    image = cv2.imread(image_path)
+    if image is None:
+        print(f"Could not read image: {image_path}")
+        return
+
+    # Detect body keypoints
+    keypoints, scores = body(image)
+
+    print(keypoints)
+
+    # Draw the keypoints
+    result_image = draw_skeleton(
+        image.copy(),
+        keypoints,
+        scores,
+        openpose_skeleton=openpose_skeleton,
+        kpt_thr=0.43
+    )
+
+    # Save or display the result
+    if output_path:
+        cv2.imwrite(output_path, result_image)
+        print(f"Saved result to: {output_path}")
+
+    # Show the result
+    # cv2.imshow('Body Keypoints Detection', result_image)
+    # cv2.waitKey(0)
+    # cv2.destroyAllWindows()
+
+if __name__ == "__main__":
+    # Replace with your own image path
+    image_path = "example.png"
+    output_path = "body_result.jpg"
+    detect_body_keypoints_image(image_path, output_path)
+
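Editor's note (not part of the commit): the comparison app added next scores poses by joint angles, taking the arccos of the normalized dot product between the two limb vectors that meet at a joint. A minimal standalone sketch of that formula, checked on synthetic shoulder/elbow/wrist points that form a right angle:

import numpy as np

def angle_at(p1, p2, p3):
    # Angle at p2 formed by the segments p2->p1 and p2->p3, in degrees.
    v1 = np.asarray(p1, dtype=float) - np.asarray(p2, dtype=float)
    v2 = np.asarray(p3, dtype=float) - np.asarray(p2, dtype=float)
    cos_a = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    return float(np.degrees(np.arccos(np.clip(cos_a, -1.0, 1.0))))

# shoulder, elbow, wrist laid out in an L-shape -> 90 degrees
print(angle_at((0, 0), (0, 100), (100, 100)))  # 90.0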
diff --git a/motion_comparison_app.py b/motion_comparison_app.py
new file mode 100644
index 0000000..8d87e37
--- /dev/null
+++ b/motion_comparison_app.py
@@ -0,0 +1,705 @@
+import streamlit as st
+import cv2
+import time
+import tempfile
+import os
+import torch
+import matplotlib.pyplot as plt
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+import pandas as pd
+import numpy as np
+from collections import deque
+import math
+
+
+# Import rtmlib
+try:
+    from rtmlib import Body, draw_skeleton
+except ImportError:
+    st.error("Please install rtmlib: pip install rtmlib")
+    st.stop()
+
+class PoseSimilarityAnalyzer:
+    def __init__(self):
+        self.similarity_history = deque(maxlen=500)  # keep the 500 most recent similarity values
+        self.frame_timestamps = deque(maxlen=500)
+        self.start_time = None
+
+        # OpenPose keypoint index map (18 body keypoints)
+        self.keypoint_map = {
+            'nose': 0, 'neck': 1,
+            'left_shoulder': 2, 'left_elbow': 3, 'left_wrist': 4,
+            'right_shoulder': 5, 'right_elbow': 6, 'right_wrist': 7,
+            'left_hip': 8, 'left_knee': 9, 'left_ankle': 10,
+            'right_hip': 11, 'right_knee': 12, 'right_ankle': 13,
+            'left_eye': 14, 'right_eye': 15,
+            'left_ear': 16, 'right_ear': 17
+        }
+
+        # Joint angle definitions (joint name: [point 1, joint point, point 2])
+        self.joint_angles = {
+            'left_elbow': ['left_shoulder', 'left_elbow', 'left_wrist'],
+            'right_elbow': ['right_shoulder', 'right_elbow', 'right_wrist'],
+            'left_shoulder': ['left_elbow', 'left_shoulder', 'neck'],
+            'right_shoulder': ['right_elbow', 'right_shoulder', 'neck'],
+            'left_knee': ['left_hip', 'left_knee', 'left_ankle'],
+            'right_knee': ['right_hip', 'right_knee', 'right_ankle'],
+            'left_hip': ['left_knee', 'left_hip', 'neck'],
+            'right_hip': ['right_knee', 'right_hip', 'neck'],
+        }
+
+        # Joint weights (more important joints weigh more)
+        self.joint_weights = {
+            'left_elbow': 1.2,
+            'right_elbow': 1.2,
+            'left_shoulder': 1.0,
+            'right_shoulder': 1.0,
+            'left_knee': 1.3,
+            'right_knee': 1.3,
+            'left_hip': 1.1,
+            'right_hip': 1.1,
+        }
+
+    def calculate_angle(self, p1, p2, p3):
+        """Compute the angle formed by three points (vertex at p2)."""
+        # Vectors v1 = p1 - p2, v2 = p3 - p2
+        v1 = np.array([p1[0] - p2[0], p1[1] - p2[1]])
+        v2 = np.array([p3[0] - p2[0], p3[1] - p2[1]])
+
+        # Vector lengths
+        v1_norm = np.linalg.norm(v1)
+        v2_norm = np.linalg.norm(v2)
+
+        if v1_norm == 0 or v2_norm == 0:
+            return None
+
+        # Angle in radians
+        cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
+        cos_angle = np.clip(cos_angle, -1.0, 1.0)  # guard against numerical error
+        angle = np.arccos(cos_angle)
+
+        # Convert to degrees
+        return np.degrees(angle)
+
+    def extract_joint_angles(self, keypoints, scores, confidence_threshold=0.3):
+        """Extract joint angles from detected keypoints."""
+        if keypoints is None or len(keypoints) == 0:
+            return None
+
+        # Use the first detected person's keypoints
+        person_keypoints = keypoints[0] if len(keypoints.shape) > 2 else keypoints
+        person_scores = scores[0] if len(scores.shape) > 1 else scores
+
+        joint_angles_result = {}
+
+        for joint_name, (p1_name, p2_name, p3_name) in self.joint_angles.items():
+            try:
+                # Keypoint indices
+                p1_idx = self.keypoint_map[p1_name]
+                p2_idx = self.keypoint_map[p2_name]
+                p3_idx = self.keypoint_map[p3_name]
+
+                # Confidence check
+                if (person_scores[p1_idx] < confidence_threshold or
+                    person_scores[p2_idx] < confidence_threshold or
+                    person_scores[p3_idx] < confidence_threshold):
+                    continue
+
+                # Coordinates
+                p1 = person_keypoints[p1_idx]
+                p2 = person_keypoints[p2_idx]
+                p3 = person_keypoints[p3_idx]
+
+                # Angle
+                angle = self.calculate_angle(p1, p2, p3)
+                if angle is not None:
+                    joint_angles_result[joint_name] = angle
+
+            except (KeyError, IndexError) as e:
+                continue
+
+        return joint_angles_result
+
+    def calculate_similarity(self, angles1, angles2):
+        """Compute the similarity between two sets of joint angles."""
+        if not angles1 or not angles2:
+            return 0.0
+
+        # Joints present in both sets
+        common_joints = set(angles1.keys()) & set(angles2.keys())
+        if not common_joints:
+            return 0.0
+
+        total_weight = 0
+        weighted_similarity = 0
+
+        for joint in common_joints:
+            angle_diff = abs(angles1[joint] - angles2[joint])
+
+            # Map the angle difference to a similarity in (0, 1]:
+            # a Gaussian kernel, so smaller differences score higher
+            similarity = math.exp(-(angle_diff ** 2) / (2 * (30 ** 2)))  # 30-degree standard deviation
+
+            # Apply the joint weight
+            weight = self.joint_weights.get(joint, 1.0)
+            weighted_similarity += similarity * weight
+            total_weight += weight
+
+        # Normalize
+        final_similarity = weighted_similarity / total_weight if total_weight > 0 else 0
+        return min(max(final_similarity * 100, 0), 100)  # scale to a 0-100 score
+
+    def add_similarity_score(self, score, timestamp=None):
+        """Append a similarity score to the history."""
+        if self.start_time is None:
+            self.start_time = time.time()
+
+        if timestamp is None:
+            timestamp = time.time() - self.start_time
+
+        self.similarity_history.append(score)
+        self.frame_timestamps.append(timestamp)
+
+    def get_similarity_plot(self):
+        """Build the similarity-over-time line chart."""
+        if len(self.similarity_history) < 2:
+            return None
+
+        fig = go.Figure()
+
+        fig.add_trace(go.Scatter(
+            x=list(self.frame_timestamps),
+            y=list(self.similarity_history),
+            mode='lines+markers',
+            name='Motion similarity',
+            line=dict(color='#2E86AB', width=2),
+            marker=dict(size=4)
+        ))
+
+        fig.update_layout(
+            title='Motion Similarity Over Time',
+            xaxis_title='Time (s)',
+            yaxis_title='Similarity score (%)',
+            yaxis=dict(range=[0, 100]),
+            height=300,
+            margin=dict(l=50, r=50, t=50, b=50),
+            showlegend=False
+        )
+
+        # Average-score reference line
+        if len(self.similarity_history) > 0:
+            avg_score = sum(self.similarity_history) / len(self.similarity_history)
+            fig.add_hline(
+                y=avg_score,
+                line_dash="dash",
+                line_color="red",
+                annotation_text=f"Average: {avg_score:.1f}%"
+            )
+
+        return fig
+
+    def reset(self):
+        """Reset the analyzer."""
+        self.similarity_history.clear()
+        self.frame_timestamps.clear()
+        self.start_time = None
+
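Editor's note (not part of the commit): calculate_similarity above maps each joint's angle difference d (in degrees) through the Gaussian kernel exp(-d^2 / (2 * 30^2)), then takes a weighted average and scales it to 0-100. A quick numeric sketch of that kernel for a single joint:

import math

def gaussian_similarity(angle_diff_deg, sigma=30.0):
    # Same kernel as PoseSimilarityAnalyzer.calculate_similarity, scaled to 0-100.
    return 100.0 * math.exp(-(angle_diff_deg ** 2) / (2 * sigma ** 2))

for d in (0, 15, 30, 60):
    print(d, round(gaussian_similarity(d), 1))
# 0 -> 100.0, 15 -> 88.2, 30 -> 60.7, 60 -> 13.5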
+
+class MotionComparisonApp:
+    def __init__(self):
+        self.body_detector = None
+        self.is_running = False
+        self.standard_video_path = None
+        self.webcam_cap = None
+        self.standard_cap = None
+        self.similarity_analyzer = PoseSimilarityAnalyzer()
+        self.frame_counter = 0
+
+
+    def preview_webcam(self):
+        """Show a webcam preview so the user can adjust their position."""
+        # Open the webcam
+        self.webcam_cap = cv2.VideoCapture(0)
+        if not self.webcam_cap.isOpened():
+            st.error("Could not open the webcam")
+            return False
+
+        # Display containers
+        st.subheader("Webcam Preview")
+        preview_text = st.empty()
+        preview_text.info("Adjust your position so your whole body is clearly visible in the frame")
+
+        preview_placeholder = st.empty()
+
+        # Stop-preview button
+        col1, col2, col3 = st.columns([1, 1, 1])
+        stop_preview = col2.button("Stop preview", key="stop_preview_btn")
+
+        # Preview loop
+        while not stop_preview:
+            # Read a webcam frame
+            ret, frame = self.webcam_cap.read()
+            if not ret:
+                st.error("Could not read from the webcam")
+                break
+
+            # Mirror the webcam image
+            frame = cv2.flip(frame, 1)
+
+            # If the detector is already initialized, overlay keypoint detections
+            if self.body_detector is not None:
+                try:
+                    keypoints, scores = self.body_detector(frame)
+                    frame = draw_skeleton(
+                        frame.copy(),
+                        keypoints,
+                        scores,
+                        openpose_skeleton=True,
+                        kpt_thr=0.43
+                    )
+                except Exception as e:
+                    pass  # ignore detection errors during preview
+
+            # Convert the color space and display
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            preview_placeholder.image(frame_rgb, caption="Webcam preview", use_column_width=True)
+
+            # Check the stop button again
+            if col2.button("Stop preview", key=f"stop_preview_btn_{time.time()}"):
+                break
+
+            # Limit the frame rate
+            time.sleep(0.03)  # roughly 30 fps
+
+        # Report that the preview ran
+        return True
+
+    def initialize_detector(self):
+        """Initialize the body keypoint detector."""
+        if self.body_detector is None:
+            device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
+            self.body_detector = Body(
+                mode='lightweight',  # lightweight mode for better real-time performance
+                to_openpose=True,
+                backend='onnxruntime',
+                device=device
+            )
+            st.success(f"Keypoint detector initialized (device: {device})")
+
+    def process_frame_with_keypoints(self, frame):
+        """Run detection on a frame and draw the keypoints."""
+        if self.body_detector is None:
+            return frame
+
+        try:
+            # Detect keypoints
+            keypoints, scores = self.body_detector(frame)
+
+            # Draw them
+            result_frame = draw_skeleton(
+                frame.copy(),
+                keypoints,
+                scores,
+                openpose_skeleton=True,
+                kpt_thr=0.43
+            )
+            return result_frame
+        except Exception as e:
+            st.error(f"Keypoint detection error: {e}")
+            return frame
+
+    def start_comparison(self, video_path):
+        """Start the motion comparison."""
+        self.is_running = True
+        self.standard_video_path = video_path
+        self.frame_counter = 0
+
+        # Reset the similarity analyzer
+        self.similarity_analyzer.reset()
+
+        # Open the reference video
+        self.standard_cap = cv2.VideoCapture(video_path)
+        if not self.standard_cap.isOpened():
+            st.error("Could not open the reference video file")
+            return
+
+        # Open the webcam if it is not already open
+        if self.webcam_cap is None or not self.webcam_cap.isOpened():
+            self.webcam_cap = cv2.VideoCapture(0)
+            if not self.webcam_cap.isOpened():
+                st.error("Could not open the webcam")
+                return
+
+        # Video info
+        fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
+        frame_delay = 1.0 / fps if fps > 0 else 1.0 / 30
+
+        # Display containers
+        col1, col2 = st.columns(2)
+
+        with col1:
+            st.subheader("Reference Video")
+            standard_placeholder = st.empty()
+
+        with col2:
+            st.subheader("Live Webcam")
+            webcam_placeholder = st.empty()
+
+        # Similarity display area
+        similarity_col1, similarity_col2 = st.columns([1, 2])
+        with similarity_col1:
+            st.subheader("Live Similarity")
+            similarity_score_placeholder = st.empty()
+            avg_score_placeholder = st.empty()
+
+        with similarity_col2:
+            st.subheader("Similarity Chart")
+            similarity_plot_placeholder = st.empty()
+
+        # Control buttons
+        control_col1, control_col2, control_col3 = st.columns(3)
+        with control_col1:
+            if st.button("Stop", key="stop_btn"):
+                self.is_running = False
+        with control_col2:
+            if st.button("Restart", key="restart_btn"):
+                self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+                self.similarity_analyzer.reset()
+
+        # Main loop
+        frame_count = 0
+        start_time = time.time()
+        current_similarity = 0
+
+        while self.is_running:
+            self.frame_counter += 1
+
+            # Read a frame from the reference video
+            ret_standard, standard_frame = self.standard_cap.read()
+            if not ret_standard:
+                # Video finished; loop back to the start
+                self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+                continue
+
+            # Read a webcam frame
+            ret_webcam, webcam_frame = self.webcam_cap.read()
+            if not ret_webcam:
+                st.error("Could not read from the webcam")
+                break
+
+            # Mirror the webcam image
+            webcam_frame = cv2.flip(webcam_frame, 1)
+
+            # Resize both streams to the same dimensions
+            target_height = 480
+            target_width = 640
+
+            standard_frame = cv2.resize(standard_frame, (target_width, target_height))
+            webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))
+
+            # Keypoint detection
+            standard_keypoints, standard_scores = self.body_detector(standard_frame)
+            webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)
+
+            # Compute the similarity once every 10 frames
+            if self.frame_counter % 10 == 0:
+                # Extract joint angles
+                standard_angles = self.similarity_analyzer.extract_joint_angles(
+                    standard_keypoints, standard_scores
+                )
+                webcam_angles = self.similarity_analyzer.extract_joint_angles(
+                    webcam_keypoints, webcam_scores
+                )
+
+                # Compute the similarity
+                if standard_angles and webcam_angles:
+                    current_similarity = self.similarity_analyzer.calculate_similarity(
+                        standard_angles, webcam_angles
+                    )
+
+                    # Record it in the history
+                    timestamp = time.time() - start_time
+                    self.similarity_analyzer.add_similarity_score(current_similarity, timestamp)
+
+            # Draw keypoints
+            standard_with_keypoints = draw_skeleton(
+                standard_frame.copy(),
+                standard_keypoints,
+                standard_scores,
+                openpose_skeleton=True,
+                kpt_thr=0.43
+            )
+            webcam_with_keypoints = draw_skeleton(
+                webcam_frame.copy(),
+                webcam_keypoints,
+                webcam_scores,
+                openpose_skeleton=True,
+                kpt_thr=0.43
+            )
+
+            # Convert the color space (BGR to RGB)
+            standard_rgb = cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB)
+            webcam_rgb = cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB)
+
+            # Show both frames
+            standard_placeholder.image(standard_rgb, caption="Reference motion", use_column_width=True)
+            webcam_placeholder.image(webcam_rgb, caption="Your motion", use_column_width=True)
+
+            # Similarity readout
+            if len(self.similarity_analyzer.similarity_history) > 0:
+                avg_similarity = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)
+
+                # Color-code the similarity level
+                if current_similarity >= 80:
+                    similarity_color = "🟢"
+                elif current_similarity >= 60:
+                    similarity_color = "🟡"
+                else:
+                    similarity_color = "🔴"
+
+                similarity_score_placeholder.metric(
+                    "Current similarity",
+                    f"{similarity_color} {current_similarity:.1f}%",
+                    delta=f"{current_similarity - avg_similarity:.1f}%" if len(self.similarity_analyzer.similarity_history) > 1 else None
+                )
+
+                avg_score_placeholder.metric(
+                    "Average similarity",
+                    f"{avg_similarity:.1f}%"
+                )
+
+            # Update the similarity chart
+            if len(self.similarity_analyzer.similarity_history) >= 2:
+                similarity_plot = self.similarity_analyzer.get_similarity_plot()
+                if similarity_plot:
+                    similarity_plot_placeholder.plotly_chart(
+                        similarity_plot,
+                        use_container_width=True
+                    )
+
+            # FPS calculation
+            frame_count += 1
+            if frame_count % 30 == 0:
+                elapsed = time.time() - start_time
+                fps_display = frame_count / elapsed
+                st.sidebar.metric("Live FPS", f"{fps_display:.1f}")
+
+            # Pace playback
+            time.sleep(frame_delay)
+
+        # Final statistics
+        if len(self.similarity_analyzer.similarity_history) > 0:
+            final_avg = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)
+            max_similarity = max(self.similarity_analyzer.similarity_history)
+            min_similarity = min(self.similarity_analyzer.similarity_history)
+
+            st.success(f"Motion comparison finished!")
+            col1, col2, col3 = st.columns(3)
+            with col1:
f"{final_avg:.1f}%") + with col2: + st.metric("最高相似度", f"{max_similarity:.1f}%") + with col3: + st.metric("最低相似度", f"{min_similarity:.1f}%") + + # 清理资源 + if self.standard_cap: + self.standard_cap.release() + if self.webcam_cap: + self.webcam_cap.release() + +def main(): + os.makedirs("preset_videos", exist_ok=True) + # 页面配置 + st.set_page_config( + page_title="动作比较与关键点检测", + page_icon="🏃", + layout="wide" + ) + + st.title("🏃 动作比较与关键点检测系统") + st.markdown("---") + + # 创建应用实例 + if 'app' not in st.session_state: + st.session_state.app = MotionComparisonApp() + + app = st.session_state.app + + # 侧边栏控制面板 + with st.sidebar: + st.header("控制面板") + + # 预设视频选择 + preset_videos = { + "六字诀": "preset_videos/liuzi.mp4", + } + + # 在这里放置预设视频状态检查代码 + st.subheader("预设视频状态") + video_status_ok = False # 跟踪是否至少有一个预设视频可用 + for name, path in preset_videos.items(): + if os.path.exists(path): + st.success(f"✅ {name}") + video_status_ok = True + else: + st.error(f"❌ {name} (文件不存在)") + + if not video_status_ok: + st.warning("⚠️ 没有找到任何预设视频文件!请将视频文件放入'preset_videos'文件夹。") + + # 接下来是视频来源选择 + video_source = st.radio( + "选择视频来源", + ["预设视频", "上传视频"], + index=0 if video_status_ok else 1 # 如果没有预设视频可用,默认选择上传 + ) + if video_source == "预设视频": + selected_preset = st.selectbox( + "选择预设视频", + list(preset_videos.keys()) + ) + video_path = preset_videos[selected_preset] + + # 检查预设视频文件是否存在 + if not os.path.exists(video_path): + st.error(f"预设视频文件不存在: {video_path}") + st.info("请确保在'preset_videos'文件夹中放置了对应视频文件") + video_path = None + else: + st.success(f"已选择预设视频: {selected_preset}") + else: + # 上传标准视频,调高大小限制到1GB + uploaded_video = st.file_uploader( + "上传标准动作视频", + type=['mp4', 'avi', 'mov', 'mkv'], + help="支持 MP4, AVI, MOV, MKV 格式", + accept_multiple_files=False + ) + + # 设置最大上传大小 (1GB) + # 注意:这需要修改streamlit的配置 + st.markdown(""" + + """, unsafe_allow_html=True) + + if uploaded_video is not None: + # 保存上传的视频到临时文件 + with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file: + tmp_file.write(uploaded_video.read()) + video_path = tmp_file.name + st.success("✅ 视频上传成功!") + else: + video_path = None + + # 初始化检测器 + if st.button("初始化关键点检测器"): + with st.spinner("正在初始化..."): + app.initialize_detector() + + # 显示系统信息 + st.subheader("系统信息") + if torch.cuda.device_count() > 0: + st.success("✅ CUDA 可用") + else: + st.info("ℹ️ 使用 CPU 模式") + + + # 主界面 + if video_path is not None: + # 显示视频信息 + cap = cv2.VideoCapture(video_path) + if cap.isOpened(): + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + duration = frame_count / fps if fps > 0 else 0 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("时长", f"{duration:.1f}秒") + with col2: + st.metric("帧率", f"{fps:.1f} FPS") + with col3: + st.metric("分辨率", f"{width}×{height}") + with col4: + st.metric("总帧数", f"{frame_count}") + + cap.release() + + # 开始按钮 + st.markdown("---") + col1, col2, col3 = st.columns([1, 1, 1]) +# 添加一个session_state变量来跟踪预览状态 + if 'preview_done' not in st.session_state: + st.session_state.preview_done = False + with col2: + # 如果还没预览,显示"预览摄像头"按钮 + if not st.session_state.preview_done: + if st.button("📷 预览摄像头", use_container_width=True): + if app.body_detector is None: + st.error("请先初始化关键点检测器") + else: + # 进行摄像头预览 + preview_success = app.preview_webcam() + if preview_success: + st.session_state.preview_done = True + # 使用rerun来刷新界面,显示开始比较按钮 + st.rerun() + # 如果已经预览过,显示"开始动作比较"按钮 + else: + if st.button("🚀 开始动作比较", 
+                if st.button("🚀 Start motion comparison", use_container_width=True):
+                    if app.body_detector is None:
+                        st.error("Please initialize the keypoint detector first")
+                    else:
+                        st.info("Starting motion comparison...")
+                        app.start_comparison(video_path)
+                        # Reset the preview flag for the next run
+                        st.session_state.preview_done = False
+
+        # If the video was uploaded, clean up the temporary file when the app exits
+        if video_source == "Upload video" and os.path.exists(video_path) and video_path.startswith(tempfile.gettempdir()):
+            try:
+                # Do not delete the file immediately; register a handler
+                # that removes it when the process exits
+                import atexit
+                atexit.register(lambda: os.unlink(video_path) if os.path.exists(video_path) else None)
+            except:
+                pass
+    else:
+        if video_source == "Preset video":
+            st.error("The preset video file is unavailable; please check the file path")
+        else:
+            st.info("👆 Please upload a reference motion video on the left")
+
+    # Usage instructions
+    with st.expander("📖 Instructions"):
+        st.markdown("""
+        ### How to use this app:
+
+        1. **Upload a video**: upload a reference motion video on the left
+        2. **Initialize the detector**: click the "Initialize keypoint detector" button
+        3. **Start the comparison**: click the "Start motion comparison" button
+        4. **Follow along**: mirror the reference video's movements in front of your webcam
+        5. **Watch the comparison**: the system shows keypoint detections for both streams in real time
+
+        ### Features:
+        - ✅ Real-time keypoint detection
+        - ✅ Mirrored webcam view
+        - ✅ Automatic looping of the reference video
+        - ✅ FPS monitoring
+        - ✅ GPU acceleration support
+
+        ### Requirements:
+        - Webcam permission
+        - Good lighting
+        - A clearly visible body silhouette
+        """)
+
+if __name__ == "__main__":
+    main()
+
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..2e2aa69
--- /dev/null
+++ b/test.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""
+Simple GPU check script: detects whether the system has GPUs and how many
+"""
+
+import torch
+import cv2
+
+def check_gpu():
+    """Count the available GPUs."""
+    print("🔍 Checking for GPUs...")
+
+    # Check whether CUDA is available
+    if torch.cuda.is_available():
+        gpu_count = torch.cuda.device_count()
+        print(f"✅ Found {gpu_count} GPU(s)")
+
+        print(f"cv2 reports {cv2.cuda.getCudaEnabledDeviceCount()} CUDA-enabled GPU(s)")
+
+        # GPU names
+        for i in range(gpu_count):
+            gpu_name = torch.cuda.get_device_name(i)
+            print(f"  GPU {i}: {gpu_name}")
+    else:
+        print("❌ No usable GPU detected")
+
+    # Check for MPS (Apple Silicon)
+    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+        print("✅ Apple Silicon GPU (MPS) detected")
+
+if __name__ == "__main__":
+    check_gpu()
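Editor's note (not part of the commit): with the conda environment above activated, the app is launched with Streamlit's standard "streamlit run motion_comparison_app.py", which picks up the 1 GB upload limit and port 8080 from .streamlit/config.toml; running "python test.py" first is a quick way to confirm whether CUDA is visible and therefore whether the detectors will take the 'cuda' device path.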