first commit

This commit is contained in:
game-loader 2025-06-16 14:21:04 +08:00
commit 636fda3598
6 changed files with 937 additions and 0 deletions

3
.streamlit/config.toml Normal file

@@ -0,0 +1,3 @@
[server]
maxUploadSize = 1024 # allow uploads of up to 1 GB
port = 8080
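
Streamlit reads this file automatically when the app is started from the repository root. As a quick sanity check, here is a minimal sketch (a hypothetical config_check.py, assuming the streamlit package from environment.yml is installed) that reads the effective values back with st.get_option:

# config_check.py -- hypothetical helper for verifying the config above
# Run from the repository root so that .streamlit/config.toml is picked up:
#   streamlit run config_check.py
import streamlit as st

# Both keys mirror the entries in .streamlit/config.toml
st.write("server.maxUploadSize (MB):", st.get_option("server.maxUploadSize"))  # expected: 1024
st.write("server.port:", st.get_option("server.port"))                         # expected: 8080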

143
environment.yml Normal file

@@ -0,0 +1,143 @@
name: /root/shared-nvme/posedet/posedet
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=5.1
- asttokens=3.0.0
- blas=1.0
- bzip2=1.0.8
- ca-certificates=2025.2.25
- comm=0.2.1
- debugpy=1.8.11
- decorator=5.1.1
- exceptiongroup=1.2.0
- executing=0.8.3
- expat=2.7.1
- intel-openmp=2023.1.0
- ipykernel=6.29.5
- ipython=8.30.0
- jedi=0.19.2
- jupyter_client=8.6.3
- jupyter_core=5.7.2
- ld_impl_linux-64=2.40
- libffi=3.4.4
- libgcc-ng=11.2.0
- libgomp=11.2.0
- libsodium=1.0.18
- libstdcxx-ng=11.2.0
- libuuid=1.41.5
- libxcb=1.17.0
- matplotlib-inline=0.1.6
- mkl=2023.1.0
- mkl-service=2.4.0
- mkl_fft=1.3.11
- mkl_random=1.2.8
- ncurses=6.4
- nest-asyncio=1.6.0
- numpy-base=2.2.5
- openssl=3.0.16
- packaging=24.2
- parso=0.8.4
- pexpect=4.8.0
- pip=25.1
- platformdirs=4.3.7
- prompt-toolkit=3.0.43
- prompt_toolkit=3.0.43
- psutil=5.9.0
- pthread-stubs=0.3
- ptyprocess=0.7.0
- pure_eval=0.2.2
- pygments=2.19.1
- python=3.10.18
- python-dateutil=2.9.0post0
- pyzmq=26.2.0
- readline=8.2
- setuptools=72.1.0
- six=1.17.0
- sqlite=3.45.3
- stack_data=0.2.0
- tbb=2021.8.0
- tk=8.6.14
- tornado=6.5.1
- traitlets=5.14.3
- typing_extensions=4.12.2
- tzdata=2025b
- wcwidth=0.2.5
- wheel=0.45.1
- xorg-libx11=1.8.12
- xorg-libxau=1.0.12
- xorg-libxdmcp=1.1.5
- xorg-xorgproto=2024.1
- xz=5.6.4
- zeromq=4.3.5
- zlib=1.2.13
- pip:
- altair==5.5.0
- attrs==25.3.0
- blinker==1.9.0
- cachetools==5.5.2
- certifi==2025.4.26
- charset-normalizer==3.4.2
- click==8.2.1
- coloredlogs==15.0.1
- contourpy==1.3.2
- cycler==0.12.1
- filelock==3.18.0
- flatbuffers==25.2.10
- fonttools==4.58.4
- fsspec==2025.5.1
- gitdb==4.0.12
- gitpython==3.1.44
- humanfriendly==10.0
- idna==3.10
- jinja2==3.1.6
- jsonschema==4.24.0
- jsonschema-specifications==2025.4.1
- kiwisolver==1.4.8
- markupsafe==3.0.2
- matplotlib==3.10.3
- mpmath==1.3.0
- narwhals==1.42.1
- networkx==3.4.2
- numpy==2.2.6
- nvidia-cublas-cu12==12.6.4.1
- nvidia-cuda-cupti-cu12==12.6.80
- nvidia-cuda-nvrtc-cu12==12.6.77
- nvidia-cuda-runtime-cu12==12.6.77
- nvidia-cudnn-cu12==9.5.1.17
- nvidia-cufft-cu12==11.3.0.4
- nvidia-cufile-cu12==1.11.1.6
- nvidia-curand-cu12==10.3.7.77
- nvidia-cusolver-cu12==11.7.1.2
- nvidia-cusparse-cu12==12.5.4.2
- nvidia-cusparselt-cu12==0.6.3
- nvidia-nccl-cu12==2.26.2
- nvidia-nvjitlink-cu12==12.6.85
- nvidia-nvtx-cu12==12.6.77
- onnxruntime==1.22.0
- opencv-contrib-python-headless==4.11.0.86
- pillow==11.2.1
- plotly==6.1.2
- protobuf==6.31.1
- pyarrow==20.0.0
- pydeck==0.9.1
- pyparsing==3.2.3
- pytz==2025.2
- referencing==0.36.2
- requests==2.32.4
- rpds-py==0.25.1
- rtmlib==0.0.13
- smmap==5.0.2
- streamlit==1.45.1
- sympy==1.14.0
- tenacity==9.1.2
- toml==0.10.2
- torch==2.7.1
- torchaudio==2.7.1
- torchvision==0.22.1
- tqdm==4.67.1
- triton==3.3.1
- urllib3==2.4.0
- watchdog==6.0.0
prefix: /root/shared-nvme/posedet/posedet

BIN
example.png Normal file

Binary file not shown.

Size: 1.0 MiB

54
main.py Normal file

@@ -0,0 +1,54 @@
import cv2
import numpy as np
from rtmlib import Body, draw_skeleton


def detect_body_keypoints_image(image_path, output_path=None):
    # Initialize the model
    device = 'cuda' if cv2.cuda.getCudaEnabledDeviceCount() > 0 else 'cpu'
    backend = 'onnxruntime'
    openpose_skeleton = True  # True for OpenPose-style output, False for MMPose style

    # Create a Body instance in balanced mode
    body = Body(
        mode='balanced',
        to_openpose=openpose_skeleton,
        backend=backend,
        device=device
    )

    # Read the image
    image = cv2.imread(image_path)
    if image is None:
        print(f"无法读取图片: {image_path}")
        return

    # Detect body keypoints
    keypoints, scores = body(image)
    print(keypoints)

    # Draw the skeleton
    result_image = draw_skeleton(
        image.copy(),
        keypoints,
        scores,
        openpose_skeleton=openpose_skeleton,
        kpt_thr=0.43
    )

    # Save or display the result
    if output_path:
        cv2.imwrite(output_path, result_image)
        print(f"已保存结果到: {output_path}")

    # Show the result in a window if needed
    # cv2.imshow('Body Keypoints Detection', result_image)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()


if __name__ == "__main__":
    # Replace with your own image path
    image_path = "example.png"
    output_path = "body_result.jpg"
    detect_body_keypoints_image(image_path, output_path)

705
motion_comparison_app.py Normal file

@@ -0,0 +1,705 @@
import streamlit as st
import cv2
import time
import tempfile
import os
import torch
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
from collections import deque
import math
# Import rtmlib
try:
    from rtmlib import Body, draw_skeleton
except ImportError:
    st.error("请安装rtmlib库: pip install rtmlib")
    st.stop()


class PoseSimilarityAnalyzer:
    def __init__(self):
        self.similarity_history = deque(maxlen=500)  # Keep the most recent 500 similarity scores
        self.frame_timestamps = deque(maxlen=500)
        self.start_time = None

        # OpenPose keypoint index map (18 body keypoints)
        self.keypoint_map = {
            'nose': 0, 'neck': 1,
            'left_shoulder': 2, 'left_elbow': 3, 'left_wrist': 4,
            'right_shoulder': 5, 'right_elbow': 6, 'right_wrist': 7,
            'left_hip': 8, 'left_knee': 9, 'left_ankle': 10,
            'right_hip': 11, 'right_knee': 12, 'right_ankle': 13,
            'left_eye': 14, 'right_eye': 15,
            'left_ear': 16, 'right_ear': 17
        }

        # Joint angle definitions (joint name: [point 1, vertex joint, point 2])
        self.joint_angles = {
            'left_elbow': ['left_shoulder', 'left_elbow', 'left_wrist'],
            'right_elbow': ['right_shoulder', 'right_elbow', 'right_wrist'],
            'left_shoulder': ['left_elbow', 'left_shoulder', 'neck'],
            'right_shoulder': ['right_elbow', 'right_shoulder', 'neck'],
            'left_knee': ['left_hip', 'left_knee', 'left_ankle'],
            'right_knee': ['right_hip', 'right_knee', 'right_ankle'],
            'left_hip': ['left_knee', 'left_hip', 'neck'],
            'right_hip': ['right_knee', 'right_hip', 'neck'],
        }

        # Joint weights (important joints are weighted more heavily)
        self.joint_weights = {
            'left_elbow': 1.2,
            'right_elbow': 1.2,
            'left_shoulder': 1.0,
            'right_shoulder': 1.0,
            'left_knee': 1.3,
            'right_knee': 1.3,
            'left_hip': 1.1,
            'right_hip': 1.1,
        }

    def calculate_angle(self, p1, p2, p3):
        """Calculate the angle formed by three points, with p2 as the vertex."""
        # Vectors v1 = p1 - p2, v2 = p3 - p2
        v1 = np.array([p1[0] - p2[0], p1[1] - p2[1]])
        v2 = np.array([p3[0] - p2[0], p3[1] - p2[1]])

        # Vector lengths
        v1_norm = np.linalg.norm(v1)
        v2_norm = np.linalg.norm(v2)
        if v1_norm == 0 or v2_norm == 0:
            return None

        # Angle between the vectors (radians)
        cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
        cos_angle = np.clip(cos_angle, -1.0, 1.0)  # Guard against numerical error
        angle = np.arccos(cos_angle)

        # Convert to degrees
        return np.degrees(angle)

    def extract_joint_angles(self, keypoints, scores, confidence_threshold=0.3):
        """Extract joint angles from detected keypoints."""
        if keypoints is None or len(keypoints) == 0:
            return None

        # Use the first detected person's keypoints
        person_keypoints = keypoints[0] if len(keypoints.shape) > 2 else keypoints
        person_scores = scores[0] if len(scores.shape) > 1 else scores

        joint_angles_result = {}
        for joint_name, (p1_name, p2_name, p3_name) in self.joint_angles.items():
            try:
                # Look up keypoint indices
                p1_idx = self.keypoint_map[p1_name]
                p2_idx = self.keypoint_map[p2_name]
                p3_idx = self.keypoint_map[p3_name]

                # Skip joints with low-confidence keypoints
                if (person_scores[p1_idx] < confidence_threshold or
                        person_scores[p2_idx] < confidence_threshold or
                        person_scores[p3_idx] < confidence_threshold):
                    continue

                # Get coordinates
                p1 = person_keypoints[p1_idx]
                p2 = person_keypoints[p2_idx]
                p3 = person_keypoints[p3_idx]

                # Compute the angle
                angle = self.calculate_angle(p1, p2, p3)
                if angle is not None:
                    joint_angles_result[joint_name] = angle
            except (KeyError, IndexError):
                continue

        return joint_angles_result

    def calculate_similarity(self, angles1, angles2):
        """Calculate the similarity between two sets of joint angles."""
        if not angles1 or not angles2:
            return 0.0

        # Joints present in both sets
        common_joints = set(angles1.keys()) & set(angles2.keys())
        if not common_joints:
            return 0.0

        total_weight = 0
        weighted_similarity = 0
        for joint in common_joints:
            angle_diff = abs(angles1[joint] - angles2[joint])
            # Map the angle difference to a similarity in [0, 1] with a Gaussian:
            # the smaller the difference, the higher the similarity (30-degree standard deviation)
            similarity = math.exp(-(angle_diff ** 2) / (2 * (30 ** 2)))
            # Apply the joint weight
            weight = self.joint_weights.get(joint, 1.0)
            weighted_similarity += similarity * weight
            total_weight += weight

        # Normalize and convert to a 0-100 score
        final_similarity = weighted_similarity / total_weight if total_weight > 0 else 0
        return min(max(final_similarity * 100, 0), 100)

    def add_similarity_score(self, score, timestamp=None):
        """Append a similarity score to the history."""
        if self.start_time is None:
            self.start_time = time.time()
        if timestamp is None:
            timestamp = time.time() - self.start_time
        self.similarity_history.append(score)
        self.frame_timestamps.append(timestamp)

    def get_similarity_plot(self):
        """Build the similarity-over-time line chart."""
        if len(self.similarity_history) < 2:
            return None

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=list(self.frame_timestamps),
            y=list(self.similarity_history),
            mode='lines+markers',
            name='动作相似度',
            line=dict(color='#2E86AB', width=2),
            marker=dict(size=4)
        ))
        fig.update_layout(
            title='动作相似度变化趋势',
            xaxis_title='时间 (秒)',
            yaxis_title='相似度分数 (%)',
            yaxis=dict(range=[0, 100]),
            height=300,
            margin=dict(l=50, r=50, t=50, b=50),
            showlegend=False
        )

        # Add a horizontal line at the average score
        if len(self.similarity_history) > 0:
            avg_score = sum(self.similarity_history) / len(self.similarity_history)
            fig.add_hline(
                y=avg_score,
                line_dash="dash",
                line_color="red",
                annotation_text=f"平均分: {avg_score:.1f}%"
            )
        return fig

    def reset(self):
        """Reset the analyzer."""
        self.similarity_history.clear()
        self.frame_timestamps.clear()
        self.start_time = None
class MotionComparisonApp:
    def __init__(self):
        self.body_detector = None
        self.is_running = False
        self.standard_video_path = None
        self.webcam_cap = None
        self.standard_cap = None
        self.similarity_analyzer = PoseSimilarityAnalyzer()
        self.frame_counter = 0

    def preview_webcam(self):
        """Show a webcam preview so the user can adjust their position."""
        # Open the webcam
        self.webcam_cap = cv2.VideoCapture(0)
        if not self.webcam_cap.isOpened():
            st.error("无法打开摄像头")
            return False

        # Create the display containers
        st.subheader("摄像头预览")
        preview_text = st.empty()
        preview_text.info("请调整您的位置,确保全身在画面中清晰可见")
        preview_placeholder = st.empty()

        # Stop-preview button
        col1, col2, col3 = st.columns([1, 1, 1])
        stop_preview = col2.button("停止预览", key="stop_preview_btn")

        # Preview loop
        while not stop_preview:
            # Read a webcam frame
            ret, frame = self.webcam_cap.read()
            if not ret:
                st.error("无法获取摄像头画面")
                break

            # Flip the frame horizontally (mirror effect)
            frame = cv2.flip(frame, 1)

            # If the detector is initialized, overlay the keypoint detection result
            if self.body_detector is not None:
                try:
                    keypoints, scores = self.body_detector(frame)
                    frame = draw_skeleton(
                        frame.copy(),
                        keypoints,
                        scores,
                        openpose_skeleton=True,
                        kpt_thr=0.43
                    )
                except Exception:
                    pass  # Ignore detection errors during preview

            # Convert the color space and display
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            preview_placeholder.image(frame_rgb, caption="摄像头预览", use_column_width=True)

            # Check the stop button state
            if col2.button("停止预览", key=f"stop_preview_btn_{time.time()}"):
                break

            # Limit the frame rate
            time.sleep(0.03)  # roughly 30 fps

        # Report that the preview ran
        return True

    def initialize_detector(self):
        """Initialize the body keypoint detector."""
        if self.body_detector is None:
            device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
            self.body_detector = Body(
                mode='lightweight',  # Lightweight mode for better real-time performance
                to_openpose=True,
                backend='onnxruntime',
                device=device
            )
            st.success(f"关键点检测器初始化完成 (设备: {device})")

    def process_frame_with_keypoints(self, frame):
        """Run keypoint detection on a frame and draw the skeleton."""
        if self.body_detector is None:
            return frame
        try:
            # Detect keypoints
            keypoints, scores = self.body_detector(frame)
            # Draw the skeleton
            result_frame = draw_skeleton(
                frame.copy(),
                keypoints,
                scores,
                openpose_skeleton=True,
                kpt_thr=0.43
            )
            return result_frame
        except Exception as e:
            st.error(f"关键点检测错误: {e}")
            return frame

    def start_comparison(self, video_path):
        """Start the motion comparison."""
        self.is_running = True
        self.standard_video_path = video_path
        self.frame_counter = 0

        # Reset the similarity analyzer
        self.similarity_analyzer.reset()

        # Open the reference video
        self.standard_cap = cv2.VideoCapture(video_path)
        if not self.standard_cap.isOpened():
            st.error("无法打开标准视频文件")
            return

        # Open the webcam if it is not already open
        if self.webcam_cap is None or not self.webcam_cap.isOpened():
            self.webcam_cap = cv2.VideoCapture(0)
            if not self.webcam_cap.isOpened():
                st.error("无法打开摄像头")
                return

        # Get video info
        fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
        frame_delay = 1.0 / fps if fps > 0 else 1.0 / 30

        # Create the display containers
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("标准动作视频")
            standard_placeholder = st.empty()
        with col2:
            st.subheader("摄像头实时影像")
            webcam_placeholder = st.empty()

        # Similarity display area
        similarity_col1, similarity_col2 = st.columns([1, 2])
        with similarity_col1:
            st.subheader("实时相似度")
            similarity_score_placeholder = st.empty()
            avg_score_placeholder = st.empty()
        with similarity_col2:
            st.subheader("相似度变化图")
            similarity_plot_placeholder = st.empty()

        # Control buttons
        control_col1, control_col2, control_col3 = st.columns(3)
        with control_col1:
            if st.button("停止", key="stop_btn"):
                self.is_running = False
        with control_col2:
            if st.button("重新开始", key="restart_btn"):
                self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                self.similarity_analyzer.reset()

        # Main loop
        frame_count = 0
        start_time = time.time()
        current_similarity = 0

        while self.is_running:
            self.frame_counter += 1

            # Read a frame from the reference video
            ret_standard, standard_frame = self.standard_cap.read()
            if not ret_standard:
                # The video ended; restart from the beginning
                self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                continue

            # Read a webcam frame
            ret_webcam, webcam_frame = self.webcam_cap.read()
            if not ret_webcam:
                st.error("无法获取摄像头画面")
                break

            # Flip the webcam frame horizontally (mirror effect)
            webcam_frame = cv2.flip(webcam_frame, 1)

            # Resize both frames to the same size
            target_height = 480
            target_width = 640
            standard_frame = cv2.resize(standard_frame, (target_width, target_height))
            webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))

            # Run keypoint detection
            standard_keypoints, standard_scores = self.body_detector(standard_frame)
            webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)

            # Compute similarity every 10 frames
            if self.frame_counter % 10 == 0:
                # Extract joint angles
                standard_angles = self.similarity_analyzer.extract_joint_angles(
                    standard_keypoints, standard_scores
                )
                webcam_angles = self.similarity_analyzer.extract_joint_angles(
                    webcam_keypoints, webcam_scores
                )

                # Compute the similarity score
                if standard_angles and webcam_angles:
                    current_similarity = self.similarity_analyzer.calculate_similarity(
                        standard_angles, webcam_angles
                    )
                    # Append to the history
                    timestamp = time.time() - start_time
                    self.similarity_analyzer.add_similarity_score(current_similarity, timestamp)

            # Draw the skeletons
            standard_with_keypoints = draw_skeleton(
                standard_frame.copy(),
                standard_keypoints,
                standard_scores,
                openpose_skeleton=True,
                kpt_thr=0.43
            )
            webcam_with_keypoints = draw_skeleton(
                webcam_frame.copy(),
                webcam_keypoints,
                webcam_scores,
                openpose_skeleton=True,
                kpt_thr=0.43
            )

            # Convert color space (BGR to RGB)
            standard_rgb = cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB)
            webcam_rgb = cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB)

            # Display both frames
            standard_placeholder.image(standard_rgb, caption="标准动作", use_column_width=True)
            webcam_placeholder.image(webcam_rgb, caption="您的动作", use_column_width=True)

            # Display the similarity metrics
            if len(self.similarity_analyzer.similarity_history) > 0:
                avg_similarity = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)

                # Color-code the current similarity
                if current_similarity >= 80:
                    similarity_color = "🟢"
                elif current_similarity >= 60:
                    similarity_color = "🟡"
                else:
                    similarity_color = "🔴"

                similarity_score_placeholder.metric(
                    "当前相似度",
                    f"{similarity_color} {current_similarity:.1f}%",
                    delta=f"{current_similarity - avg_similarity:.1f}%" if len(self.similarity_analyzer.similarity_history) > 1 else None
                )
                avg_score_placeholder.metric(
                    "平均相似度",
                    f"{avg_similarity:.1f}%"
                )

            # Update the similarity chart
            if len(self.similarity_analyzer.similarity_history) >= 2:
                similarity_plot = self.similarity_analyzer.get_similarity_plot()
                if similarity_plot:
                    similarity_plot_placeholder.plotly_chart(
                        similarity_plot,
                        use_container_width=True
                    )

            # Compute FPS
            frame_count += 1
            if frame_count % 30 == 0:
                elapsed = time.time() - start_time
                fps_display = frame_count / elapsed
                st.sidebar.metric("实时FPS", f"{fps_display:.1f}")

            # Pace playback to the reference video's frame rate
            time.sleep(frame_delay)

        # Show the final statistics
        if len(self.similarity_analyzer.similarity_history) > 0:
            final_avg = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)
            max_similarity = max(self.similarity_analyzer.similarity_history)
            min_similarity = min(self.similarity_analyzer.similarity_history)

            st.success("动作比较完成!")
            col1, col2, col3 = st.columns(3)
            with col1:
                st.metric("平均相似度", f"{final_avg:.1f}%")
            with col2:
                st.metric("最高相似度", f"{max_similarity:.1f}%")
            with col3:
                st.metric("最低相似度", f"{min_similarity:.1f}%")

        # Release resources
        if self.standard_cap:
            self.standard_cap.release()
        if self.webcam_cap:
            self.webcam_cap.release()
def main():
    os.makedirs("preset_videos", exist_ok=True)

    # Page config
    st.set_page_config(
        page_title="动作比较与关键点检测",
        page_icon="🏃",
        layout="wide"
    )
    st.title("🏃 动作比较与关键点检测系统")
    st.markdown("---")

    # Create the application instance
    if 'app' not in st.session_state:
        st.session_state.app = MotionComparisonApp()
    app = st.session_state.app

    # Sidebar control panel
    with st.sidebar:
        st.header("控制面板")

        # Preset video selection
        preset_videos = {
            "六字诀": "preset_videos/liuzi.mp4",
        }

        # Preset video availability check
        st.subheader("预设视频状态")
        video_status_ok = False  # Tracks whether at least one preset video is available
        for name, path in preset_videos.items():
            if os.path.exists(path):
                st.success(f"{name}")
                video_status_ok = True
            else:
                st.error(f"{name} (文件不存在)")
        if not video_status_ok:
            st.warning("⚠️ 没有找到任何预设视频文件!请将视频文件放入'preset_videos'文件夹。")

        # Video source selection
        video_source = st.radio(
            "选择视频来源",
            ["预设视频", "上传视频"],
            index=0 if video_status_ok else 1  # Default to upload when no preset video is available
        )

        if video_source == "预设视频":
            selected_preset = st.selectbox(
                "选择预设视频",
                list(preset_videos.keys())
            )
            video_path = preset_videos[selected_preset]

            # Check that the preset video file exists
            if not os.path.exists(video_path):
                st.error(f"预设视频文件不存在: {video_path}")
                st.info("请确保在'preset_videos'文件夹中放置了对应视频文件")
                video_path = None
            else:
                st.success(f"已选择预设视频: {selected_preset}")
        else:
            # Upload a reference video (upload size limit raised to 1 GB)
            uploaded_video = st.file_uploader(
                "上传标准动作视频",
                type=['mp4', 'avi', 'mov', 'mkv'],
                help="支持 MP4, AVI, MOV, MKV 格式",
                accept_multiple_files=False
            )

            # The 1 GB maximum upload size is set via Streamlit's configuration
            # (server.maxUploadSize in .streamlit/config.toml)
            st.markdown("""
            <style>
            .uploadedFile {max-width: 100% !important;}
            </style>
            """, unsafe_allow_html=True)

            if uploaded_video is not None:
                # Save the uploaded video to a temporary file
                with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
                    tmp_file.write(uploaded_video.read())
                    video_path = tmp_file.name
                st.success("✅ 视频上传成功!")
            else:
                video_path = None

        # Initialize the detector
        if st.button("初始化关键点检测器"):
            with st.spinner("正在初始化..."):
                app.initialize_detector()

        # System information
        st.subheader("系统信息")
        if torch.cuda.device_count() > 0:
            st.success("✅ CUDA 可用")
        else:
            st.info(" 使用 CPU 模式")

    # Main area
    if video_path is not None:
        # Show video information
        cap = cv2.VideoCapture(video_path)
        if cap.isOpened():
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = frame_count / fps if fps > 0 else 0
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("时长", f"{duration:.1f}")
            with col2:
                st.metric("帧率", f"{fps:.1f} FPS")
            with col3:
                st.metric("分辨率", f"{width}×{height}")
            with col4:
                st.metric("总帧数", f"{frame_count}")
            cap.release()

        # Start buttons
        st.markdown("---")
        col1, col2, col3 = st.columns([1, 1, 1])

        # Track the preview state in session_state
        if 'preview_done' not in st.session_state:
            st.session_state.preview_done = False

        with col2:
            # If the webcam has not been previewed yet, show the preview button
            if not st.session_state.preview_done:
                if st.button("📷 预览摄像头", use_container_width=True):
                    if app.body_detector is None:
                        st.error("请先初始化关键点检测器")
                    else:
                        # Run the webcam preview
                        preview_success = app.preview_webcam()
                        if preview_success:
                            st.session_state.preview_done = True
                            # Rerun so the "start comparison" button is shown
                            st.rerun()
            # After a successful preview, show the "start comparison" button
            else:
                if st.button("🚀 开始动作比较", use_container_width=True):
                    if app.body_detector is None:
                        st.error("请先初始化关键点检测器")
                    else:
                        st.info("开始动作比较...")
                        app.start_comparison(video_path)
                        # Reset the preview state for the next run
                        st.session_state.preview_done = False

                        # If the video was uploaded, clean up the temporary file when the app exits
                        if video_source == "上传视频" and os.path.exists(video_path) and video_path.startswith(tempfile.gettempdir()):
                            try:
                                # Do not delete the file immediately; register an atexit handler
                                # that removes it when the process exits
                                import atexit
                                atexit.register(lambda: os.unlink(video_path) if os.path.exists(video_path) else None)
                            except Exception:
                                pass
    else:
        if video_source == "预设视频":
            st.error("预设视频文件不可用,请检查文件路径")
        else:
            st.info("👆 请在左侧上传标准动作视频文件")

    # Usage instructions
    with st.expander("📖 使用说明"):
        st.markdown("""
        ### 如何使用此应用:
        1. **上传视频**: 在左侧上传标准动作视频文件
        2. **初始化检测器**: 点击"初始化关键点检测器"按钮
        3. **开始比较**: 点击"开始动作比较"按钮
        4. **跟随动作**: 在摄像头前跟随标准视频做相同动作
        5. **观察对比**: 系统会实时显示两边的关键点检测结果

        ### 特性:
        - 实时关键点检测
        - 镜像摄像头画面
        - 自动循环播放标准视频
        - FPS 监控
        - GPU 加速支持

        ### 要求:
        - 摄像头权限
        - 良好的光照条件
        - 清晰的人体轮廓
        """)


if __name__ == "__main__":
    main()
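
For reference, the score computed by PoseSimilarityAnalyzer.calculate_similarity above reduces to a weighted Gaussian over per-joint angle differences (30-degree standard deviation), rescaled to 0-100. Below is a minimal standalone sketch of that calculation; the joint angles are made up purely for illustration:

import math

# Subset of the joint weights used by PoseSimilarityAnalyzer above
joint_weights = {'left_elbow': 1.2, 'right_elbow': 1.2, 'left_knee': 1.3, 'right_knee': 1.3}

def similarity_score(angles1, angles2, sigma=30.0):
    """Weighted Gaussian similarity over joints present in both poses, scaled to 0-100."""
    common = set(angles1) & set(angles2)
    if not common:
        return 0.0
    weighted, total = 0.0, 0.0
    for joint in common:
        diff = abs(angles1[joint] - angles2[joint])
        s = math.exp(-(diff ** 2) / (2 * sigma ** 2))  # 1.0 at identical angles, ~0.61 at a 30-degree gap
        w = joint_weights.get(joint, 1.0)
        weighted += s * w
        total += w
    return min(max(weighted / total * 100, 0), 100)

# Example with made-up joint angles (degrees)
reference = {'left_elbow': 150.0, 'right_elbow': 148.0, 'left_knee': 170.0}
live = {'left_elbow': 140.0, 'right_elbow': 150.0, 'left_knee': 155.0}
print(f"similarity: {similarity_score(reference, live):.1f}%")  # about 94% for these values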

32
test.py Normal file

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""
Simple GPU detection script: checks whether the system has GPUs and how many.
"""
import torch
import cv2


def check_gpu():
    """Report the number of available GPUs."""
    print("🔍 正在检测GPU...")

    # Check whether CUDA is available
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        print(f"✅ 检测到 {gpu_count} 张GPU")
        print(f"cv2 检测到{cv2.cuda.getCudaEnabledDeviceCount()}张gpu")

        # Print each GPU's name
        for i in range(gpu_count):
            gpu_name = torch.cuda.get_device_name(i)
            print(f"  GPU {i}: {gpu_name}")
    else:
        print("❌ 未检测到可用的GPU")

    # Check for MPS (Apple Silicon)
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        print("✅ 检测到Apple Silicon GPU (MPS)")


if __name__ == "__main__":
    check_gpu()