first commit
This commit is contained in:
commit
636fda3598
3
.streamlit/config.toml
Normal file
3
.streamlit/config.toml
Normal file
@ -0,0 +1,3 @@
|
||||
[server]
|
||||
maxUploadSize = 1024 # 允许上传最大1GB文件
|
||||
port = 8080
|
143
environment.yml
Normal file
143
environment.yml
Normal file
@ -0,0 +1,143 @@
|
||||
name: /root/shared-nvme/posedet/posedet
|
||||
channels:
|
||||
- defaults
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1
|
||||
- _openmp_mutex=5.1
|
||||
- asttokens=3.0.0
|
||||
- blas=1.0
|
||||
- bzip2=1.0.8
|
||||
- ca-certificates=2025.2.25
|
||||
- comm=0.2.1
|
||||
- debugpy=1.8.11
|
||||
- decorator=5.1.1
|
||||
- exceptiongroup=1.2.0
|
||||
- executing=0.8.3
|
||||
- expat=2.7.1
|
||||
- intel-openmp=2023.1.0
|
||||
- ipykernel=6.29.5
|
||||
- ipython=8.30.0
|
||||
- jedi=0.19.2
|
||||
- jupyter_client=8.6.3
|
||||
- jupyter_core=5.7.2
|
||||
- ld_impl_linux-64=2.40
|
||||
- libffi=3.4.4
|
||||
- libgcc-ng=11.2.0
|
||||
- libgomp=11.2.0
|
||||
- libsodium=1.0.18
|
||||
- libstdcxx-ng=11.2.0
|
||||
- libuuid=1.41.5
|
||||
- libxcb=1.17.0
|
||||
- matplotlib-inline=0.1.6
|
||||
- mkl=2023.1.0
|
||||
- mkl-service=2.4.0
|
||||
- mkl_fft=1.3.11
|
||||
- mkl_random=1.2.8
|
||||
- ncurses=6.4
|
||||
- nest-asyncio=1.6.0
|
||||
- numpy-base=2.2.5
|
||||
- openssl=3.0.16
|
||||
- packaging=24.2
|
||||
- parso=0.8.4
|
||||
- pexpect=4.8.0
|
||||
- pip=25.1
|
||||
- platformdirs=4.3.7
|
||||
- prompt-toolkit=3.0.43
|
||||
- prompt_toolkit=3.0.43
|
||||
- psutil=5.9.0
|
||||
- pthread-stubs=0.3
|
||||
- ptyprocess=0.7.0
|
||||
- pure_eval=0.2.2
|
||||
- pygments=2.19.1
|
||||
- python=3.10.18
|
||||
- python-dateutil=2.9.0post0
|
||||
- pyzmq=26.2.0
|
||||
- readline=8.2
|
||||
- setuptools=72.1.0
|
||||
- six=1.17.0
|
||||
- sqlite=3.45.3
|
||||
- stack_data=0.2.0
|
||||
- tbb=2021.8.0
|
||||
- tk=8.6.14
|
||||
- tornado=6.5.1
|
||||
- traitlets=5.14.3
|
||||
- typing_extensions=4.12.2
|
||||
- tzdata=2025b
|
||||
- wcwidth=0.2.5
|
||||
- wheel=0.45.1
|
||||
- xorg-libx11=1.8.12
|
||||
- xorg-libxau=1.0.12
|
||||
- xorg-libxdmcp=1.1.5
|
||||
- xorg-xorgproto=2024.1
|
||||
- xz=5.6.4
|
||||
- zeromq=4.3.5
|
||||
- zlib=1.2.13
|
||||
- pip:
|
||||
- altair==5.5.0
|
||||
- attrs==25.3.0
|
||||
- blinker==1.9.0
|
||||
- cachetools==5.5.2
|
||||
- certifi==2025.4.26
|
||||
- charset-normalizer==3.4.2
|
||||
- click==8.2.1
|
||||
- coloredlogs==15.0.1
|
||||
- contourpy==1.3.2
|
||||
- cycler==0.12.1
|
||||
- filelock==3.18.0
|
||||
- flatbuffers==25.2.10
|
||||
- fonttools==4.58.4
|
||||
- fsspec==2025.5.1
|
||||
- gitdb==4.0.12
|
||||
- gitpython==3.1.44
|
||||
- humanfriendly==10.0
|
||||
- idna==3.10
|
||||
- jinja2==3.1.6
|
||||
- jsonschema==4.24.0
|
||||
- jsonschema-specifications==2025.4.1
|
||||
- kiwisolver==1.4.8
|
||||
- markupsafe==3.0.2
|
||||
- matplotlib==3.10.3
|
||||
- mpmath==1.3.0
|
||||
- narwhals==1.42.1
|
||||
- networkx==3.4.2
|
||||
- numpy==2.2.6
|
||||
- nvidia-cublas-cu12==12.6.4.1
|
||||
- nvidia-cuda-cupti-cu12==12.6.80
|
||||
- nvidia-cuda-nvrtc-cu12==12.6.77
|
||||
- nvidia-cuda-runtime-cu12==12.6.77
|
||||
- nvidia-cudnn-cu12==9.5.1.17
|
||||
- nvidia-cufft-cu12==11.3.0.4
|
||||
- nvidia-cufile-cu12==1.11.1.6
|
||||
- nvidia-curand-cu12==10.3.7.77
|
||||
- nvidia-cusolver-cu12==11.7.1.2
|
||||
- nvidia-cusparse-cu12==12.5.4.2
|
||||
- nvidia-cusparselt-cu12==0.6.3
|
||||
- nvidia-nccl-cu12==2.26.2
|
||||
- nvidia-nvjitlink-cu12==12.6.85
|
||||
- nvidia-nvtx-cu12==12.6.77
|
||||
- onnxruntime==1.22.0
|
||||
- opencv-contrib-python-headless==4.11.0.86
|
||||
- pillow==11.2.1
|
||||
- plotly==6.1.2
|
||||
- protobuf==6.31.1
|
||||
- pyarrow==20.0.0
|
||||
- pydeck==0.9.1
|
||||
- pyparsing==3.2.3
|
||||
- pytz==2025.2
|
||||
- referencing==0.36.2
|
||||
- requests==2.32.4
|
||||
- rpds-py==0.25.1
|
||||
- rtmlib==0.0.13
|
||||
- smmap==5.0.2
|
||||
- streamlit==1.45.1
|
||||
- sympy==1.14.0
|
||||
- tenacity==9.1.2
|
||||
- toml==0.10.2
|
||||
- torch==2.7.1
|
||||
- torchaudio==2.7.1
|
||||
- torchvision==0.22.1
|
||||
- tqdm==4.67.1
|
||||
- triton==3.3.1
|
||||
- urllib3==2.4.0
|
||||
- watchdog==6.0.0
|
||||
prefix: /root/shared-nvme/posedet/posedet
|
BIN
example.png
Normal file
BIN
example.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.0 MiB |
54
main.py
Normal file
54
main.py
Normal file
@ -0,0 +1,54 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from rtmlib import Body, draw_skeleton
|
||||
|
||||
def detect_body_keypoints_image(image_path, output_path=None):
|
||||
# 初始化模型
|
||||
device = 'cuda' if cv2.cuda.getCudaEnabledDeviceCount() > 0 else 'cpu'
|
||||
backend = 'onnxruntime'
|
||||
openpose_skeleton = True # True为OpenPose风格,False为MMPose风格
|
||||
|
||||
# 创建Body实例,使用平衡模式
|
||||
body = Body(
|
||||
mode='balanced',
|
||||
to_openpose=openpose_skeleton,
|
||||
backend=backend,
|
||||
device=device
|
||||
)
|
||||
|
||||
# 读取图片
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
print(f"无法读取图片: {image_path}")
|
||||
return
|
||||
|
||||
# 检测身体关键点
|
||||
keypoints, scores = body(image)
|
||||
|
||||
print(keypoints)
|
||||
|
||||
# 绘制关键点
|
||||
result_image = draw_skeleton(
|
||||
image.copy(),
|
||||
keypoints,
|
||||
scores,
|
||||
openpose_skeleton=openpose_skeleton,
|
||||
kpt_thr=0.43
|
||||
)
|
||||
|
||||
# 保存或显示结果
|
||||
if output_path:
|
||||
cv2.imwrite(output_path, result_image)
|
||||
print(f"已保存结果到: {output_path}")
|
||||
|
||||
# 显示结果
|
||||
# cv2.imshow('Body Keypoints Detection', result_image)
|
||||
# cv2.waitKey(0)
|
||||
# cv2.destroyAllWindows()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 替换为您的图片路径
|
||||
image_path = "example.png"
|
||||
output_path = "body_result.jpg"
|
||||
detect_body_keypoints_image(image_path, output_path)
|
||||
|
705
motion_comparison_app.py
Normal file
705
motion_comparison_app.py
Normal file
@ -0,0 +1,705 @@
|
||||
import streamlit as st
|
||||
import cv2
|
||||
import time
|
||||
import tempfile
|
||||
import os
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
import math
|
||||
|
||||
|
||||
# 导入rtmlib
|
||||
try:
|
||||
from rtmlib import Body, draw_skeleton
|
||||
except ImportError:
|
||||
st.error("请安装rtmlib库: pip install rtmlib")
|
||||
st.stop()
|
||||
|
||||
class PoseSimilarityAnalyzer:
|
||||
def __init__(self):
|
||||
self.similarity_history = deque(maxlen=500) # 保存最近500个相似度值
|
||||
self.frame_timestamps = deque(maxlen=500)
|
||||
self.start_time = None
|
||||
|
||||
# OpenPose关键点索引映射(Body 17个关键点)
|
||||
self.keypoint_map = {
|
||||
'nose': 0, 'neck': 1,
|
||||
'left_shoulder': 2, 'left_elbow': 3, 'left_wrist': 4,
|
||||
'right_shoulder': 5, 'right_elbow': 6, 'right_wrist': 7,
|
||||
'left_hip': 8, 'left_knee': 9, 'left_ankle': 10,
|
||||
'right_hip': 11, 'right_knee': 12, 'right_ankle': 13,
|
||||
'left_eye': 14, 'right_eye': 15,
|
||||
'left_ear': 16, 'right_ear': 17
|
||||
}
|
||||
|
||||
# 关节角度定义(关节名:[点1, 关节点, 点2])
|
||||
self.joint_angles = {
|
||||
'left_elbow': ['left_shoulder', 'left_elbow', 'left_wrist'],
|
||||
'right_elbow': ['right_shoulder', 'right_elbow', 'right_wrist'],
|
||||
'left_shoulder': ['left_elbow', 'left_shoulder', 'neck'],
|
||||
'right_shoulder': ['right_elbow', 'right_shoulder', 'neck'],
|
||||
'left_knee': ['left_hip', 'left_knee', 'left_ankle'],
|
||||
'right_knee': ['right_hip', 'right_knee', 'right_ankle'],
|
||||
'left_hip': ['left_knee', 'left_hip', 'neck'],
|
||||
'right_hip': ['right_knee', 'right_hip', 'neck'],
|
||||
}
|
||||
|
||||
# 关节权重(重要关节权重更高)
|
||||
self.joint_weights = {
|
||||
'left_elbow': 1.2,
|
||||
'right_elbow': 1.2,
|
||||
'left_shoulder': 1.0,
|
||||
'right_shoulder': 1.0,
|
||||
'left_knee': 1.3,
|
||||
'right_knee': 1.3,
|
||||
'left_hip': 1.1,
|
||||
'right_hip': 1.1,
|
||||
}
|
||||
|
||||
def calculate_angle(self, p1, p2, p3):
|
||||
"""计算三个点组成的角度"""
|
||||
# 向量v1 = p1 - p2, v2 = p3 - p2
|
||||
v1 = np.array([p1[0] - p2[0], p1[1] - p2[1]])
|
||||
v2 = np.array([p3[0] - p2[0], p3[1] - p2[1]])
|
||||
|
||||
# 计算向量长度
|
||||
v1_norm = np.linalg.norm(v1)
|
||||
v2_norm = np.linalg.norm(v2)
|
||||
|
||||
if v1_norm == 0 or v2_norm == 0:
|
||||
return None
|
||||
|
||||
# 计算夹角(弧度)
|
||||
cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
|
||||
cos_angle = np.clip(cos_angle, -1.0, 1.0) # 防止数值误差
|
||||
angle = np.arccos(cos_angle)
|
||||
|
||||
# 转换为角度
|
||||
return np.degrees(angle)
|
||||
|
||||
def extract_joint_angles(self, keypoints, scores, confidence_threshold=0.3):
|
||||
"""从关键点提取关节角度"""
|
||||
if keypoints is None or len(keypoints) == 0:
|
||||
return None
|
||||
|
||||
# 取第一个人的关键点
|
||||
person_keypoints = keypoints[0] if len(keypoints.shape) > 2 else keypoints
|
||||
person_scores = scores[0] if len(scores.shape) > 1 else scores
|
||||
|
||||
joint_angles_result = {}
|
||||
|
||||
for joint_name, (p1_name, p2_name, p3_name) in self.joint_angles.items():
|
||||
try:
|
||||
# 获取关键点索引
|
||||
p1_idx = self.keypoint_map[p1_name]
|
||||
p2_idx = self.keypoint_map[p2_name]
|
||||
p3_idx = self.keypoint_map[p3_name]
|
||||
|
||||
# 检查置信度
|
||||
if (person_scores[p1_idx] < confidence_threshold or
|
||||
person_scores[p2_idx] < confidence_threshold or
|
||||
person_scores[p3_idx] < confidence_threshold):
|
||||
continue
|
||||
|
||||
# 获取坐标
|
||||
p1 = person_keypoints[p1_idx]
|
||||
p2 = person_keypoints[p2_idx]
|
||||
p3 = person_keypoints[p3_idx]
|
||||
|
||||
# 计算角度
|
||||
angle = self.calculate_angle(p1, p2, p3)
|
||||
if angle is not None:
|
||||
joint_angles_result[joint_name] = angle
|
||||
|
||||
except (KeyError, IndexError) as e:
|
||||
continue
|
||||
|
||||
return joint_angles_result
|
||||
|
||||
def calculate_similarity(self, angles1, angles2):
|
||||
"""计算两组关节角度的相似度"""
|
||||
if not angles1 or not angles2:
|
||||
return 0.0
|
||||
|
||||
# 找到共同的关节
|
||||
common_joints = set(angles1.keys()) & set(angles2.keys())
|
||||
if not common_joints:
|
||||
return 0.0
|
||||
|
||||
total_weight = 0
|
||||
weighted_similarity = 0
|
||||
|
||||
for joint in common_joints:
|
||||
angle_diff = abs(angles1[joint] - angles2[joint])
|
||||
|
||||
# 角度差异转换为相似度(0-1)
|
||||
# 使用高斯函数,角度差异越小,相似度越高
|
||||
similarity = math.exp(-(angle_diff ** 2) / (2 * (30 ** 2))) # 30度标准差
|
||||
|
||||
# 应用权重
|
||||
weight = self.joint_weights.get(joint, 1.0)
|
||||
weighted_similarity += similarity * weight
|
||||
total_weight += weight
|
||||
|
||||
# 归一化相似度
|
||||
final_similarity = weighted_similarity / total_weight if total_weight > 0 else 0
|
||||
return min(max(final_similarity * 100, 0), 100) # 转换为0-100分
|
||||
|
||||
def add_similarity_score(self, score, timestamp=None):
|
||||
"""添加相似度分数到历史记录"""
|
||||
if self.start_time is None:
|
||||
self.start_time = time.time()
|
||||
|
||||
if timestamp is None:
|
||||
timestamp = time.time() - self.start_time
|
||||
|
||||
self.similarity_history.append(score)
|
||||
self.frame_timestamps.append(timestamp)
|
||||
|
||||
def get_similarity_plot(self):
|
||||
"""生成相似度变化折线图"""
|
||||
if len(self.similarity_history) < 2:
|
||||
return None
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_trace(go.Scatter(
|
||||
x=list(self.frame_timestamps),
|
||||
y=list(self.similarity_history),
|
||||
mode='lines+markers',
|
||||
name='动作相似度',
|
||||
line=dict(color='#2E86AB', width=2),
|
||||
marker=dict(size=4)
|
||||
))
|
||||
|
||||
fig.update_layout(
|
||||
title='动作相似度变化趋势',
|
||||
xaxis_title='时间 (秒)',
|
||||
yaxis_title='相似度分数 (%)',
|
||||
yaxis=dict(range=[0, 100]),
|
||||
height=300,
|
||||
margin=dict(l=50, r=50, t=50, b=50),
|
||||
showlegend=False
|
||||
)
|
||||
|
||||
# 添加平均分数线
|
||||
if len(self.similarity_history) > 0:
|
||||
avg_score = sum(self.similarity_history) / len(self.similarity_history)
|
||||
fig.add_hline(
|
||||
y=avg_score,
|
||||
line_dash="dash",
|
||||
line_color="red",
|
||||
annotation_text=f"平均分: {avg_score:.1f}%"
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def reset(self):
|
||||
"""重置分析器"""
|
||||
self.similarity_history.clear()
|
||||
self.frame_timestamps.clear()
|
||||
self.start_time = None
|
||||
|
||||
|
||||
class MotionComparisonApp:
|
||||
def __init__(self):
|
||||
self.body_detector = None
|
||||
self.is_running = False
|
||||
self.standard_video_path = None
|
||||
self.webcam_cap = None
|
||||
self.standard_cap = None
|
||||
self.similarity_analyzer = PoseSimilarityAnalyzer()
|
||||
self.frame_counter = 0
|
||||
|
||||
|
||||
def preview_webcam(self):
|
||||
"""显示摄像头预览,帮助用户调整位置"""
|
||||
# 打开摄像头
|
||||
self.webcam_cap = cv2.VideoCapture(0)
|
||||
if not self.webcam_cap.isOpened():
|
||||
st.error("无法打开摄像头")
|
||||
return False
|
||||
|
||||
# 创建显示容器
|
||||
st.subheader("摄像头预览")
|
||||
preview_text = st.empty()
|
||||
preview_text.info("请调整您的位置,确保全身在画面中清晰可见")
|
||||
|
||||
preview_placeholder = st.empty()
|
||||
|
||||
# 显示停止预览按钮
|
||||
col1, col2, col3 = st.columns([1, 1, 1])
|
||||
stop_preview = col2.button("停止预览", key="stop_preview_btn")
|
||||
|
||||
# 预览循环
|
||||
while not stop_preview:
|
||||
# 读取摄像头帧
|
||||
ret, frame = self.webcam_cap.read()
|
||||
if not ret:
|
||||
st.error("无法获取摄像头画面")
|
||||
break
|
||||
|
||||
# 翻转摄像头画面(镜像效果)
|
||||
frame = cv2.flip(frame, 1)
|
||||
|
||||
# 如果检测器已初始化,可以显示关键点检测结果
|
||||
if self.body_detector is not None:
|
||||
try:
|
||||
keypoints, scores = self.body_detector(frame)
|
||||
frame = draw_skeleton(
|
||||
frame.copy(),
|
||||
keypoints,
|
||||
scores,
|
||||
openpose_skeleton=True,
|
||||
kpt_thr=0.43
|
||||
)
|
||||
except Exception as e:
|
||||
pass # 预览阶段,忽略检测错误
|
||||
|
||||
# 转换颜色空间并显示
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
preview_placeholder.image(frame_rgb, caption="摄像头预览", use_column_width=True)
|
||||
|
||||
# 检查停止按钮状态
|
||||
if col2.button("停止预览", key=f"stop_preview_btn_{time.time()}", replace=True):
|
||||
break
|
||||
|
||||
# 控制帧率
|
||||
time.sleep(0.03) # 约30fps
|
||||
|
||||
# 返回预览是否成功
|
||||
return True
|
||||
|
||||
def initialize_detector(self):
|
||||
"""初始化身体关键点检测器"""
|
||||
if self.body_detector is None:
|
||||
device = 'cuda' if torch.cuda.device_count()> 0 else 'cpu'
|
||||
self.body_detector = Body(
|
||||
mode='lightweight', # 使用轻量级模式提高实时性能
|
||||
to_openpose=True,
|
||||
backend='onnxruntime',
|
||||
device=device
|
||||
)
|
||||
st.success(f"关键点检测器初始化完成 (设备: {device})")
|
||||
|
||||
def process_frame_with_keypoints(self, frame):
|
||||
"""处理帧并添加关键点"""
|
||||
if self.body_detector is None:
|
||||
return frame
|
||||
|
||||
try:
|
||||
# 检测关键点
|
||||
keypoints, scores = self.body_detector(frame)
|
||||
|
||||
# 绘制关键点
|
||||
result_frame = draw_skeleton(
|
||||
frame.copy(),
|
||||
keypoints,
|
||||
scores,
|
||||
openpose_skeleton=True,
|
||||
kpt_thr=0.43
|
||||
)
|
||||
return result_frame
|
||||
except Exception as e:
|
||||
st.error(f"关键点检测错误: {e}")
|
||||
return frame
|
||||
|
||||
def start_comparison(self, video_path):
|
||||
"""开始动作比较"""
|
||||
self.is_running = True
|
||||
self.standard_video_path = video_path
|
||||
self.frame_counter = 0
|
||||
|
||||
# 重置相似度分析器
|
||||
self.similarity_analyzer.reset()
|
||||
|
||||
# 打开标准视频
|
||||
self.standard_cap = cv2.VideoCapture(video_path)
|
||||
if not self.standard_cap.isOpened():
|
||||
st.error("无法打开标准视频文件")
|
||||
return
|
||||
|
||||
# 如果摄像头未打开,则打开摄像头
|
||||
if self.webcam_cap is None or not self.webcam_cap.isOpened():
|
||||
self.webcam_cap = cv2.VideoCapture(0)
|
||||
if not self.webcam_cap.isOpened():
|
||||
st.error("无法打开摄像头")
|
||||
return
|
||||
|
||||
# 获取视频信息
|
||||
fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_delay = 1.0 / fps if fps > 0 else 1.0 / 30
|
||||
|
||||
# 创建显示容器
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
st.subheader("标准动作视频")
|
||||
standard_placeholder = st.empty()
|
||||
|
||||
with col2:
|
||||
st.subheader("摄像头实时影像")
|
||||
webcam_placeholder = st.empty()
|
||||
|
||||
# 添加相似度显区域
|
||||
similarity_col1, similarity_col2 = st.columns([1, 2])
|
||||
with similarity_col1:
|
||||
st.subheader("实时相似度")
|
||||
similarity_score_placeholder = st.empty()
|
||||
avg_score_placeholder = st.empty()
|
||||
|
||||
with similarity_col2:
|
||||
st.subheader("相似度变化图")
|
||||
similarity_plot_placeholder = st.empty()
|
||||
|
||||
# 添加控制按钮
|
||||
control_col1, control_col2, control_col3 = st.columns(3)
|
||||
with control_col1:
|
||||
if st.button("停止", key="stop_btn"):
|
||||
self.is_running = False
|
||||
with control_col2:
|
||||
if st.button("重新开始", key="restart_btn"):
|
||||
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||||
self.similarity_analyzer.reset()
|
||||
|
||||
# 主循环
|
||||
frame_count = 0
|
||||
start_time = time.time()
|
||||
current_similarity = 0
|
||||
|
||||
while self.is_running:
|
||||
self.frame_counter += 1
|
||||
|
||||
# 读取标准视频帧
|
||||
ret_standard, standard_frame = self.standard_cap.read()
|
||||
if not ret_standard:
|
||||
# 视频结束,重新开始
|
||||
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||||
continue
|
||||
|
||||
# 读取摄像头帧
|
||||
ret_webcam, webcam_frame = self.webcam_cap.read()
|
||||
if not ret_webcam:
|
||||
st.error("无法获取摄像头画面")
|
||||
break
|
||||
|
||||
# 翻转摄像头画面(镜像效果)
|
||||
webcam_frame = cv2.flip(webcam_frame, 1)
|
||||
|
||||
# 调整尺寸使两个视频大小一致
|
||||
target_height = 480
|
||||
target_width = 640
|
||||
|
||||
standard_frame = cv2.resize(standard_frame, (target_width, target_height))
|
||||
webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))
|
||||
|
||||
# 处理关键点检测
|
||||
standard_keypoints, standard_scores = self.body_detector(standard_frame)
|
||||
webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)
|
||||
|
||||
# 每10帧计算一次相似度
|
||||
if self.frame_counter % 10 == 0:
|
||||
# 提取关节角度
|
||||
standard_angles = self.similarity_analyzer.extract_joint_angles(
|
||||
standard_keypoints, standard_scores
|
||||
)
|
||||
webcam_angles = self.similarity_analyzer.extract_joint_angles(
|
||||
webcam_keypoints, webcam_scores
|
||||
)
|
||||
|
||||
# 计算相似度
|
||||
if standard_angles and webcam_angles:
|
||||
current_similarity = self.similarity_analyzer.calculate_similarity(
|
||||
standard_angles, webcam_angles
|
||||
)
|
||||
|
||||
# 添加到历史记录
|
||||
timestamp = time.time() - start_time
|
||||
self.similarity_analyzer.add_similarity_score(current_similarity, timestamp)
|
||||
|
||||
# 绘制关键点
|
||||
standard_with_keypoints = draw_skeleton(
|
||||
standard_frame.copy(),
|
||||
standard_keypoints,
|
||||
standard_scores,
|
||||
openpose_skeleton=True,
|
||||
kpt_thr=0.43
|
||||
)
|
||||
webcam_with_keypoints = draw_skeleton(
|
||||
webcam_frame.copy(),
|
||||
webcam_keypoints,
|
||||
webcam_scores,
|
||||
openpose_skeleton=True,
|
||||
kpt_thr=0.43
|
||||
)
|
||||
|
||||
# 转换颜色空间 (BGR to RGB)
|
||||
standard_rgb = cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB)
|
||||
webcam_rgb = cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 显示画面
|
||||
standard_placeholder.image(standard_rgb, caption="标准动作", use_column_width=True)
|
||||
webcam_placeholder.image(webcam_rgb, caption="您的动作", use_column_width=True)
|
||||
|
||||
# 显示相似度信息
|
||||
if len(self.similarity_analyzer.similarity_history) > 0:
|
||||
avg_similarity = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)
|
||||
|
||||
# 使用不同颜色显示相似度
|
||||
if current_similarity >= 80:
|
||||
similarity_color = "🟢"
|
||||
elif current_similarity >= 60:
|
||||
similarity_color = "🟡"
|
||||
else:
|
||||
similarity_color = "🔴"
|
||||
|
||||
similarity_score_placeholder.metric(
|
||||
"当前相似度",
|
||||
f"{similarity_color} {current_similarity:.1f}%",
|
||||
delta=f"{current_similarity - avg_similarity:.1f}%" if len(self.similarity_analyzer.similarity_history) > 1 else None
|
||||
)
|
||||
|
||||
avg_score_placeholder.metric(
|
||||
"平均相似度",
|
||||
f"{avg_similarity:.1f}%"
|
||||
)
|
||||
|
||||
# 更新相似度图表
|
||||
if len(self.similarity_analyzer.similarity_history) >= 2:
|
||||
similarity_plot = self.similarity_analyzer.get_similarity_plot()
|
||||
if similarity_plot:
|
||||
similarity_plot_placeholder.plotly_chart(
|
||||
similarity_plot,
|
||||
use_container_width=True
|
||||
)
|
||||
|
||||
# 计算FPS
|
||||
frame_count += 1
|
||||
if frame_count % 30 == 0:
|
||||
elapsed = time.time() - start_time
|
||||
fps_display = frame_count / elapsed
|
||||
st.sidebar.metric("实时FPS", f"{fps_display:.1f}")
|
||||
|
||||
# 控制播放速度
|
||||
time.sleep(frame_delay)
|
||||
|
||||
# 显示最终统计
|
||||
if len(self.similarity_analyzer.similarity_history) > 0:
|
||||
final_avg = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)
|
||||
max_similarity = max(self.similarity_analyzer.similarity_history)
|
||||
min_similarity = min(self.similarity_analyzer.similarity_history)
|
||||
|
||||
st.success(f"动作比较完成!")
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("平均相似度", f"{final_avg:.1f}%")
|
||||
with col2:
|
||||
st.metric("最高相似度", f"{max_similarity:.1f}%")
|
||||
with col3:
|
||||
st.metric("最低相似度", f"{min_similarity:.1f}%")
|
||||
|
||||
# 清理资源
|
||||
if self.standard_cap:
|
||||
self.standard_cap.release()
|
||||
if self.webcam_cap:
|
||||
self.webcam_cap.release()
|
||||
|
||||
def main():
|
||||
os.makedirs("preset_videos", exist_ok=True)
|
||||
# 页面配置
|
||||
st.set_page_config(
|
||||
page_title="动作比较与关键点检测",
|
||||
page_icon="🏃",
|
||||
layout="wide"
|
||||
)
|
||||
|
||||
st.title("🏃 动作比较与关键点检测系统")
|
||||
st.markdown("---")
|
||||
|
||||
# 创建应用实例
|
||||
if 'app' not in st.session_state:
|
||||
st.session_state.app = MotionComparisonApp()
|
||||
|
||||
app = st.session_state.app
|
||||
|
||||
# 侧边栏控制面板
|
||||
with st.sidebar:
|
||||
st.header("控制面板")
|
||||
|
||||
# 预设视频选择
|
||||
preset_videos = {
|
||||
"六字诀": "preset_videos/liuzi.mp4",
|
||||
}
|
||||
|
||||
# 在这里放置预设视频状态检查代码
|
||||
st.subheader("预设视频状态")
|
||||
video_status_ok = False # 跟踪是否至少有一个预设视频可用
|
||||
for name, path in preset_videos.items():
|
||||
if os.path.exists(path):
|
||||
st.success(f"✅ {name}")
|
||||
video_status_ok = True
|
||||
else:
|
||||
st.error(f"❌ {name} (文件不存在)")
|
||||
|
||||
if not video_status_ok:
|
||||
st.warning("⚠️ 没有找到任何预设视频文件!请将视频文件放入'preset_videos'文件夹。")
|
||||
|
||||
# 接下来是视频来源选择
|
||||
video_source = st.radio(
|
||||
"选择视频来源",
|
||||
["预设视频", "上传视频"],
|
||||
index=0 if video_status_ok else 1 # 如果没有预设视频可用,默认选择上传
|
||||
)
|
||||
if video_source == "预设视频":
|
||||
selected_preset = st.selectbox(
|
||||
"选择预设视频",
|
||||
list(preset_videos.keys())
|
||||
)
|
||||
video_path = preset_videos[selected_preset]
|
||||
|
||||
# 检查预设视频文件是否存在
|
||||
if not os.path.exists(video_path):
|
||||
st.error(f"预设视频文件不存在: {video_path}")
|
||||
st.info("请确保在'preset_videos'文件夹中放置了对应视频文件")
|
||||
video_path = None
|
||||
else:
|
||||
st.success(f"已选择预设视频: {selected_preset}")
|
||||
else:
|
||||
# 上传标准视频,调高大小限制到1GB
|
||||
uploaded_video = st.file_uploader(
|
||||
"上传标准动作视频",
|
||||
type=['mp4', 'avi', 'mov', 'mkv'],
|
||||
help="支持 MP4, AVI, MOV, MKV 格式",
|
||||
accept_multiple_files=False
|
||||
)
|
||||
|
||||
# 设置最大上传大小 (1GB)
|
||||
# 注意:这需要修改streamlit的配置
|
||||
st.markdown("""
|
||||
<style>
|
||||
.uploadedFile {max-width: 100% !important;}
|
||||
</style>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
if uploaded_video is not None:
|
||||
# 保存上传的视频到临时文件
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
|
||||
tmp_file.write(uploaded_video.read())
|
||||
video_path = tmp_file.name
|
||||
st.success("✅ 视频上传成功!")
|
||||
else:
|
||||
video_path = None
|
||||
|
||||
# 初始化检测器
|
||||
if st.button("初始化关键点检测器"):
|
||||
with st.spinner("正在初始化..."):
|
||||
app.initialize_detector()
|
||||
|
||||
# 显示系统信息
|
||||
st.subheader("系统信息")
|
||||
if torch.cuda.device_count() > 0:
|
||||
st.success("✅ CUDA 可用")
|
||||
else:
|
||||
st.info("ℹ️ 使用 CPU 模式")
|
||||
|
||||
|
||||
# 主界面
|
||||
if video_path is not None:
|
||||
# 显示视频信息
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if cap.isOpened():
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
duration = frame_count / fps if fps > 0 else 0
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
st.metric("时长", f"{duration:.1f}秒")
|
||||
with col2:
|
||||
st.metric("帧率", f"{fps:.1f} FPS")
|
||||
with col3:
|
||||
st.metric("分辨率", f"{width}×{height}")
|
||||
with col4:
|
||||
st.metric("总帧数", f"{frame_count}")
|
||||
|
||||
cap.release()
|
||||
|
||||
# 开始按钮
|
||||
st.markdown("---")
|
||||
col1, col2, col3 = st.columns([1, 1, 1])
|
||||
# 添加一个session_state变量来跟踪预览状态
|
||||
if 'preview_done' not in st.session_state:
|
||||
st.session_state.preview_done = False
|
||||
with col2:
|
||||
# 如果还没预览,显示"预览摄像头"按钮
|
||||
if not st.session_state.preview_done:
|
||||
if st.button("📷 预览摄像头", use_container_width=True):
|
||||
if app.body_detector is None:
|
||||
st.error("请先初始化关键点检测器")
|
||||
else:
|
||||
# 进行摄像头预览
|
||||
preview_success = app.preview_webcam()
|
||||
if preview_success:
|
||||
st.session_state.preview_done = True
|
||||
# 使用rerun来刷新界面,显示开始比较按钮
|
||||
st.rerun()
|
||||
# 如果已经预览过,显示"开始动作比较"按钮
|
||||
else:
|
||||
if st.button("🚀 开始动作比较", use_container_width=True):
|
||||
if app.body_detector is None:
|
||||
st.error("请先初始化关键点检测器")
|
||||
else:
|
||||
st.info("开始动作比较...")
|
||||
app.start_comparison(video_path)
|
||||
# 重置预览状态,以便下次使用
|
||||
st.session_state.preview_done = False
|
||||
|
||||
# 如果是上传的视频,结束时清理临时文件
|
||||
if video_source == "上传视频" and os.path.exists(video_path) and video_path.startswith(tempfile.gettempdir()):
|
||||
try:
|
||||
# 这里不要立即删除,而是在应用结束时删除
|
||||
# 注册一个退出时删除文件的函数
|
||||
import atexit
|
||||
atexit.register(lambda: os.unlink(video_path) if os.path.exists(video_path) else None)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
if video_source == "预设视频":
|
||||
st.error("预设视频文件不可用,请检查文件路径")
|
||||
else:
|
||||
st.info("👆 请在左侧上传标准动作视频文件")
|
||||
|
||||
# 显示使用说明
|
||||
with st.expander("📖 使用说明"):
|
||||
st.markdown("""
|
||||
### 如何使用此应用:
|
||||
|
||||
1. **上传视频**: 在左侧上传标准动作视频文件
|
||||
2. **初始化检测器**: 点击"初始化关键点检测器"按钮
|
||||
3. **开始比较**: 点击"开始动作比较"按钮
|
||||
4. **跟随动作**: 在摄像头前跟随标准视频做相同动作
|
||||
5. **观察对比**: 系统会实时显示两边的关键点检测结果
|
||||
|
||||
### 特性:
|
||||
- ✅ 实时关键点检测
|
||||
- ✅ 镜像摄像头画面
|
||||
- ✅ 自动循环播放标准视频
|
||||
- ✅ FPS 监控
|
||||
- ✅ GPU 加速支持
|
||||
|
||||
### 要求:
|
||||
- 摄像头权限
|
||||
- 良好的光照条件
|
||||
- 清晰的人体轮廓
|
||||
""")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
32
test.py
Normal file
32
test.py
Normal file
@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
简单GPU检测脚本 - 检测系统是否有显卡及数量
|
||||
"""
|
||||
|
||||
import torch
|
||||
import cv2
|
||||
|
||||
def check_gpu():
|
||||
"""检测GPU数量"""
|
||||
print("🔍 正在检测GPU...")
|
||||
|
||||
# 检查CUDA是否可用
|
||||
if torch.cuda.is_available():
|
||||
gpu_count = torch.cuda.device_count()
|
||||
print(f"✅ 检测到 {gpu_count} 张GPU")
|
||||
|
||||
print(f"cv2 检测到{cv2.cuda.getCudaEnabledDeviceCount()}张gpu")
|
||||
|
||||
# 显示GPU名称
|
||||
for i in range(gpu_count):
|
||||
gpu_name = torch.cuda.get_device_name(i)
|
||||
print(f" GPU {i}: {gpu_name}")
|
||||
else:
|
||||
print("❌ 未检测到可用的GPU")
|
||||
|
||||
# 检查MPS (Apple Silicon)
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
print("✅ 检测到Apple Silicon GPU (MPS)")
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_gpu()
|
Loading…
Reference in New Issue
Block a user