posedet/motion_comparison_app.py
2025-06-16 14:21:04 +08:00

706 lines
26 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import streamlit as st
import cv2
import time
import tempfile
import os
import torch
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
from collections import deque
import math
# Import rtmlib (pose detector and skeleton drawing helper).
try:
    from rtmlib import Body, draw_skeleton
except ImportError:
    # The rest of the module calls Body / draw_skeleton, so show the install
    # hint in the page and halt this script run instead of failing later.
    st.error("请安装rtmlib库: pip install rtmlib")
    st.stop()
class PoseSimilarityAnalyzer:
    """Score how closely two poses match by comparing their joint angles.

    The analyzer turns detected keypoints into a handful of joint angles,
    compares two angle sets with a weighted Gaussian kernel, and keeps a
    rolling history of scores that can be rendered as a Plotly line chart.
    """

    def __init__(self):
        # Rolling window: only the most recent 500 scores/timestamps are kept.
        self.similarity_history = deque(maxlen=500)
        self.frame_timestamps = deque(maxlen=500)
        self.start_time = None
        # Keypoint name -> index into the detector output. The table has 18
        # entries (indices 0-17, including the synthetic 'neck' point).
        # NOTE(review): OpenPose's 18-point convention is believed to list the
        # RIGHT-side limbs first (index 2 = right shoulder), whereas this table
        # lists the left side first — verify against the detector output.
        # Scores are not affected, because left/right joints carry identical
        # weights below and both videos go through the same table.
        self.keypoint_map = {
            'nose': 0, 'neck': 1,
            'left_shoulder': 2, 'left_elbow': 3, 'left_wrist': 4,
            'right_shoulder': 5, 'right_elbow': 6, 'right_wrist': 7,
            'left_hip': 8, 'left_knee': 9, 'left_ankle': 10,
            'right_hip': 11, 'right_knee': 12, 'right_ankle': 13,
            'left_eye': 14, 'right_eye': 15,
            'left_ear': 16, 'right_ear': 17,
        }
        # Joint angle definitions: joint name -> [point 1, vertex, point 2].
        self.joint_angles = {
            'left_elbow': ['left_shoulder', 'left_elbow', 'left_wrist'],
            'right_elbow': ['right_shoulder', 'right_elbow', 'right_wrist'],
            'left_shoulder': ['left_elbow', 'left_shoulder', 'neck'],
            'right_shoulder': ['right_elbow', 'right_shoulder', 'neck'],
            'left_knee': ['left_hip', 'left_knee', 'left_ankle'],
            'right_knee': ['right_hip', 'right_knee', 'right_ankle'],
            'left_hip': ['left_knee', 'left_hip', 'neck'],
            'right_hip': ['right_knee', 'right_hip', 'neck'],
        }
        # Per-joint weights; joints considered more important weigh more.
        self.joint_weights = {
            'left_elbow': 1.2,
            'right_elbow': 1.2,
            'left_shoulder': 1.0,
            'right_shoulder': 1.0,
            'left_knee': 1.3,
            'right_knee': 1.3,
            'left_hip': 1.1,
            'right_hip': 1.1,
        }

    def calculate_angle(self, p1, p2, p3):
        """Return the angle at ``p2`` formed by ``p1``-``p2``-``p3``, in degrees.

        Only the first two coordinates (x, y) of each point are used.

        Returns:
            A float in [0, 180], or ``None`` when either arm of the angle has
            zero length or a coordinate is NaN/inf (the angle is undefined).
        """
        # Arms of the angle, both pointing away from the vertex p2.
        v1 = np.array([p1[0] - p2[0], p1[1] - p2[1]], dtype=float)
        v2 = np.array([p3[0] - p2[0], p3[1] - p2[1]], dtype=float)
        v1_norm = np.linalg.norm(v1)
        v2_norm = np.linalg.norm(v2)
        # A non-finite norm would otherwise surface as a NaN angle and poison
        # every score computed from it.
        if not (np.isfinite(v1_norm) and np.isfinite(v2_norm)):
            return None
        if v1_norm == 0 or v2_norm == 0:
            return None
        cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
        # Clamp to the arccos domain to absorb floating-point round-off.
        cos_angle = np.clip(cos_angle, -1.0, 1.0)
        return float(np.degrees(np.arccos(cos_angle)))

    def extract_joint_angles(self, keypoints, scores, confidence_threshold=0.3):
        """Compute the configured joint angles for the first detected person.

        Args:
            keypoints: Keypoint coordinates, either for several people
                (3 dimensions) or for one person (2 dimensions). Arrays and
                nested sequences are both accepted.
            scores: Per-keypoint confidences with one dimension fewer than
                ``keypoints``.
            confidence_threshold: A joint is skipped when any of its three
                keypoints scores below this value.

        Returns:
            ``None`` when there is no detection at all, otherwise a dict
            mapping joint name to angle in degrees. The dict may be empty.
        """
        if keypoints is None or scores is None:
            return None
        # Accept plain lists as well as ndarrays.
        keypoints = np.asarray(keypoints)
        scores = np.asarray(scores)
        if keypoints.size == 0 or scores.size == 0:
            return None
        # Use the first person when a leading "person" axis is present.
        person_keypoints = keypoints[0] if keypoints.ndim > 2 else keypoints
        person_scores = scores[0] if scores.ndim > 1 else scores

        joint_angles_result = {}
        for joint_name, point_names in self.joint_angles.items():
            try:
                indices = [self.keypoint_map[name] for name in point_names]
                # Skip joints that rely on a low-confidence keypoint.
                if any(person_scores[i] < confidence_threshold for i in indices):
                    continue
                p1, p2, p3 = (person_keypoints[i] for i in indices)
                angle = self.calculate_angle(p1, p2, p3)
            except (KeyError, IndexError):
                # Unknown keypoint name or a detection with fewer points than
                # the table expects: this joint simply cannot be measured.
                continue
            if angle is not None:
                joint_angles_result[joint_name] = angle
        return joint_angles_result

    def calculate_similarity(self, angles1, angles2, sigma=30.0):
        """Return a 0-100 similarity score for two joint-angle dicts.

        Each joint present in both dicts contributes
        ``exp(-diff**2 / (2 * sigma**2))``; contributions are combined as a
        weighted mean using ``self.joint_weights`` (default weight 1.0).

        Args:
            angles1: Joint name -> angle in degrees.
            angles2: Joint name -> angle in degrees.
            sigma: Width of the Gaussian in degrees (default 30).

        Returns:
            A float in [0, 100]; 0.0 when either dict is empty or the dicts
            share no joint.

        Raises:
            ValueError: If ``sigma`` is not positive.
        """
        if sigma <= 0:
            raise ValueError("sigma must be positive")
        if not angles1 or not angles2:
            return 0.0
        common_joints = angles1.keys() & angles2.keys()
        if not common_joints:
            return 0.0

        total_weight = 0.0
        weighted_similarity = 0.0
        for joint in common_joints:
            angle_diff = abs(angles1[joint] - angles2[joint])
            # Gaussian kernel: the smaller the difference, the closer to 1.
            similarity = math.exp(-(angle_diff ** 2) / (2 * (sigma ** 2)))
            weight = self.joint_weights.get(joint, 1.0)
            weighted_similarity += similarity * weight
            total_weight += weight

        if total_weight <= 0:
            return 0.0
        # Normalise by the total weight and express as a 0-100 score.
        return min(max(weighted_similarity / total_weight * 100, 0.0), 100.0)

    def add_similarity_score(self, score, timestamp=None):
        """Append ``score`` to the history.

        When ``timestamp`` is omitted it is taken as the number of seconds
        since the first call after construction or ``reset()``.
        """
        if self.start_time is None:
            self.start_time = time.time()
        if timestamp is None:
            timestamp = time.time() - self.start_time
        self.similarity_history.append(score)
        self.frame_timestamps.append(timestamp)

    def get_similarity_plot(self):
        """Build a Plotly figure of the score history.

        Returns:
            ``None`` while fewer than two scores are recorded, otherwise a
            figure with the score line and a dashed line at the average.
        """
        if len(self.similarity_history) < 2:
            return None
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=list(self.frame_timestamps),
            y=list(self.similarity_history),
            mode='lines+markers',
            name='动作相似度',
            line=dict(color='#2E86AB', width=2),
            marker=dict(size=4),
        ))
        fig.update_layout(
            title='动作相似度变化趋势',
            xaxis_title='时间 (秒)',
            yaxis_title='相似度分数 (%)',
            yaxis=dict(range=[0, 100]),
            height=300,
            margin=dict(l=50, r=50, t=50, b=50),
            showlegend=False,
        )
        # Dashed horizontal line at the average score (history is non-empty
        # here because of the guard above).
        avg_score = sum(self.similarity_history) / len(self.similarity_history)
        fig.add_hline(
            y=avg_score,
            line_dash="dash",
            line_color="red",
            annotation_text=f"平均分: {avg_score:.1f}%",
        )
        return fig

    def reset(self):
        """Clear the score history and restart the relative clock."""
        self.similarity_history.clear()
        self.frame_timestamps.clear()
        self.start_time = None
class MotionComparisonApp:
    """Streamlit controller for the side-by-side motion comparison.

    It owns the pose detector, the webcam and reference-video captures and a
    ``PoseSimilarityAnalyzer``. It renders the webcam preview and the
    comparison view, in which both streams are annotated with skeletons and
    scored against each other.
    """

    # Keypoint confidence threshold handed to draw_skeleton.
    KPT_THRESHOLD = 0.43
    # Both streams are resized to this (width, height) before detection.
    FRAME_SIZE = (640, 480)
    # A similarity score is computed once every this many frames.
    SIMILARITY_INTERVAL = 10

    def __init__(self):
        self.body_detector = None
        self.is_running = False
        self.standard_video_path = None
        self.webcam_cap = None
        self.standard_cap = None
        self.similarity_analyzer = PoseSimilarityAnalyzer()
        self.frame_counter = 0

    @staticmethod
    def _mark_preview_done():
        """Button callback: record that the user has finished the preview."""
        # main() reads st.session_state.preview_done to decide whether to show
        # the "start comparison" button.
        st.session_state.preview_done = True

    def _draw_pose(self, frame, keypoints, scores):
        """Return a copy of ``frame`` with the detected skeleton drawn on it."""
        return draw_skeleton(
            frame.copy(),
            keypoints,
            scores,
            openpose_skeleton=True,
            kpt_thr=self.KPT_THRESHOLD,
        )

    def _release_captures(self):
        """Release whichever video captures are currently held."""
        for attr in ("standard_cap", "webcam_cap"):
            cap = getattr(self, attr)
            if cap is not None:
                cap.release()
                setattr(self, attr, None)

    def preview_webcam(self):
        """Show a live webcam preview so the user can position themselves.

        Returns:
            ``True`` when the preview ended without a camera error, ``False``
            when the camera could not be opened or a frame could not be read.
        """
        self.webcam_cap = cv2.VideoCapture(0)
        if not self.webcam_cap.isOpened():
            st.error("无法打开摄像头")
            self.webcam_cap.release()
            self.webcam_cap = None
            return False

        succeeded = True
        try:
            st.subheader("摄像头预览")
            preview_text = st.empty()
            preview_text.info("请调整您的位置,确保全身在画面中清晰可见")
            preview_placeholder = st.empty()

            # A single stop button, created once. Widgets must not be created
            # inside the frame loop (and st.button has no `replace` argument).
            # Clicking it makes Streamlit rerun the script; the callback runs
            # at the start of that rerun and marks the preview as finished.
            _, col2, _ = st.columns([1, 1, 1])
            stop_preview = col2.button(
                "停止预览",
                key="stop_preview_btn",
                on_click=self._mark_preview_done,
            )

            while not stop_preview:
                ret, frame = self.webcam_cap.read()
                if not ret:
                    st.error("无法获取摄像头画面")
                    succeeded = False
                    break
                # Mirror the image so it behaves like a mirror for the user.
                frame = cv2.flip(frame, 1)
                if self.body_detector is not None:
                    try:
                        keypoints, scores = self.body_detector(frame)
                        frame = self._draw_pose(frame, keypoints, scores)
                    except Exception:
                        # Best effort: during preview a detection failure just
                        # means the raw frame is shown without a skeleton.
                        pass
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                preview_placeholder.image(frame_rgb, caption="摄像头预览", use_column_width=True)
                # Throttle to roughly 30 frames per second.
                time.sleep(0.03)
        finally:
            # Also runs when the loop is interrupted, so the camera is never
            # left open; start_comparison reopens it when needed.
            if self.webcam_cap is not None:
                self.webcam_cap.release()
                self.webcam_cap = None
        return succeeded

    def initialize_detector(self):
        """Create the rtmlib body detector once (CUDA when available)."""
        if self.body_detector is None:
            device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
            self.body_detector = Body(
                mode='lightweight',  # lightweight mode for real-time use
                to_openpose=True,
                backend='onnxruntime',
                device=device,
            )
            st.success(f"关键点检测器初始化完成 (设备: {device})")

    def process_frame_with_keypoints(self, frame):
        """Return ``frame`` annotated with the detected skeleton.

        The input frame is returned unchanged when no detector has been
        initialised or when detection raises (the error is shown in the page).
        """
        if self.body_detector is None:
            return frame
        try:
            keypoints, scores = self.body_detector(frame)
            return self._draw_pose(frame, keypoints, scores)
        except Exception as e:
            st.error(f"关键点检测错误: {e}")
            return frame

    def _open_captures(self, video_path):
        """Open the reference video and the webcam; return True on success."""
        self.standard_cap = cv2.VideoCapture(video_path)
        if not self.standard_cap.isOpened():
            st.error("无法打开标准视频文件")
            return False
        # Reuse the webcam if it is still open, otherwise open it now.
        if self.webcam_cap is None or not self.webcam_cap.isOpened():
            self.webcam_cap = cv2.VideoCapture(0)
        if not self.webcam_cap.isOpened():
            st.error("无法打开摄像头")
            return False
        return True

    def _render_similarity(self, current_similarity, score_placeholder,
                           avg_placeholder, plot_placeholder):
        """Refresh the score metrics and the history chart."""
        history = self.similarity_analyzer.similarity_history
        avg_similarity = sum(history) / len(history)
        # Traffic-light marker for the current score.
        if current_similarity >= 80:
            similarity_color = "🟢"
        elif current_similarity >= 60:
            similarity_color = "🟡"
        else:
            similarity_color = "🔴"
        score_placeholder.metric(
            "当前相似度",
            f"{similarity_color} {current_similarity:.1f}%",
            delta=f"{current_similarity - avg_similarity:.1f}%" if len(history) > 1 else None,
        )
        avg_placeholder.metric("平均相似度", f"{avg_similarity:.1f}%")
        # get_similarity_plot returns None until two scores exist.
        similarity_plot = self.similarity_analyzer.get_similarity_plot()
        if similarity_plot is not None:
            plot_placeholder.plotly_chart(similarity_plot, use_container_width=True)

    def _show_final_stats(self):
        """Show average / best / worst score once the comparison has ended."""
        history = self.similarity_analyzer.similarity_history
        if not history:
            return
        final_avg = sum(history) / len(history)
        st.success("动作比较完成!")
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("平均相似度", f"{final_avg:.1f}%")
        with col2:
            st.metric("最高相似度", f"{max(history):.1f}%")
        with col3:
            st.metric("最低相似度", f"{min(history):.1f}%")

    def start_comparison(self, video_path):
        """Play the reference video next to the webcam and score the match.

        Runs until ``self.is_running`` is cleared, a stream fails, or the
        script run is interrupted. Both captures are released in every case.

        Args:
            video_path: Path of the reference ("standard") video.
        """
        if self.body_detector is None:
            st.error("请先初始化关键点检测器")
            return

        self.is_running = True
        self.standard_video_path = video_path
        self.frame_counter = 0
        self.similarity_analyzer.reset()

        try:
            if not self._open_captures(video_path):
                return

            # Use the video's own frame rate, falling back to 30 fps.
            fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
            frame_delay = 1.0 / fps if fps > 0 else 1.0 / 30

            # Page layout: the two video panes ...
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("标准动作视频")
                standard_placeholder = st.empty()
            with col2:
                st.subheader("摄像头实时影像")
                webcam_placeholder = st.empty()
            # ... the similarity read-outs ...
            similarity_col1, similarity_col2 = st.columns([1, 2])
            with similarity_col1:
                st.subheader("实时相似度")
                similarity_score_placeholder = st.empty()
                avg_score_placeholder = st.empty()
            with similarity_col2:
                st.subheader("相似度变化图")
                similarity_plot_placeholder = st.empty()
            # ... and the control buttons.
            # NOTE(review): these buttons are evaluated once, before the loop.
            # A click during the loop takes effect through Streamlit's rerun,
            # which presumably interrupts this run — confirm the intended
            # stop/restart behaviour.
            control_col1, control_col2, _ = st.columns(3)
            with control_col1:
                if st.button("停止", key="stop_btn"):
                    self.is_running = False
            with control_col2:
                if st.button("重新开始", key="restart_btn"):
                    self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    self.similarity_analyzer.reset()
            # One placeholder for the FPS read-out so it is updated in place
            # rather than appended to the sidebar again and again.
            fps_placeholder = st.sidebar.empty()

            frame_count = 0
            start_time = time.time()
            current_similarity = 0
            while self.is_running:
                loop_started = time.time()

                ret_standard, standard_frame = self.standard_cap.read()
                if not ret_standard:
                    # End of the video: rewind and try again once. If even the
                    # first frame cannot be read, stop instead of spinning.
                    self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    ret_standard, standard_frame = self.standard_cap.read()
                    if not ret_standard:
                        st.error("无法读取标准视频帧")
                        break

                ret_webcam, webcam_frame = self.webcam_cap.read()
                if not ret_webcam:
                    st.error("无法获取摄像头画面")
                    break
                self.frame_counter += 1

                # Mirror the webcam, then bring both frames to a common size.
                webcam_frame = cv2.flip(webcam_frame, 1)
                standard_frame = cv2.resize(standard_frame, self.FRAME_SIZE)
                webcam_frame = cv2.resize(webcam_frame, self.FRAME_SIZE)

                standard_keypoints, standard_scores = self.body_detector(standard_frame)
                webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)

                # Score only every SIMILARITY_INTERVAL-th frame.
                score_updated = False
                if self.frame_counter % self.SIMILARITY_INTERVAL == 0:
                    standard_angles = self.similarity_analyzer.extract_joint_angles(
                        standard_keypoints, standard_scores
                    )
                    webcam_angles = self.similarity_analyzer.extract_joint_angles(
                        webcam_keypoints, webcam_scores
                    )
                    if standard_angles and webcam_angles:
                        current_similarity = self.similarity_analyzer.calculate_similarity(
                            standard_angles, webcam_angles
                        )
                        self.similarity_analyzer.add_similarity_score(
                            current_similarity, time.time() - start_time
                        )
                        score_updated = True

                # Draw skeletons and show both frames (OpenCV is BGR, Streamlit
                # expects RGB).
                standard_rgb = cv2.cvtColor(
                    self._draw_pose(standard_frame, standard_keypoints, standard_scores),
                    cv2.COLOR_BGR2RGB,
                )
                webcam_rgb = cv2.cvtColor(
                    self._draw_pose(webcam_frame, webcam_keypoints, webcam_scores),
                    cv2.COLOR_BGR2RGB,
                )
                standard_placeholder.image(standard_rgb, caption="标准动作", use_column_width=True)
                webcam_placeholder.image(webcam_rgb, caption="您的动作", use_column_width=True)

                # The read-outs only change when a new score was recorded, so
                # the chart is rebuilt only then.
                if score_updated:
                    self._render_similarity(
                        current_similarity,
                        similarity_score_placeholder,
                        avg_score_placeholder,
                        similarity_plot_placeholder,
                    )

                frame_count += 1
                if frame_count % 30 == 0:
                    elapsed = time.time() - start_time
                    if elapsed > 0:
                        fps_placeholder.metric("实时FPS", f"{frame_count / elapsed:.1f}")

                # Pace playback to the video's frame rate: sleep only for what
                # is left of the frame period after processing.
                remaining = frame_delay - (time.time() - loop_started)
                if remaining > 0:
                    time.sleep(remaining)

            self._show_final_stats()
        finally:
            self.is_running = False
            self._release_captures()
def _remove_file_quietly(path):
    """Delete ``path`` if it exists; a missing or locked file is not an error."""
    try:
        os.unlink(path)
    except OSError:
        pass


def _store_uploaded_video(uploaded_video):
    """Write an uploaded video to a temp file once per upload; return its path.

    Streamlit reruns the script on every interaction, so the temp file is
    remembered in ``st.session_state`` and reused for as long as the same
    upload (same name and size) is selected.
    """
    import atexit

    upload_key = (uploaded_video.name, uploaded_video.size)
    cached = st.session_state.get("uploaded_video_cache")
    if cached and cached["key"] == upload_key and os.path.exists(cached["path"]):
        return cached["path"]
    if cached:
        # A different file was uploaded: drop the previous temp copy.
        _remove_file_quietly(cached["path"])

    # Keep the real extension; fall back to .mp4 when the name has none.
    suffix = os.path.splitext(uploaded_video.name)[1] or ".mp4"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
        # getvalue() returns the whole payload regardless of read position.
        tmp_file.write(uploaded_video.getvalue())
        path = tmp_file.name
    # Registered once per temp file; removes it when the process exits.
    atexit.register(_remove_file_quietly, path)
    st.session_state["uploaded_video_cache"] = {"key": upload_key, "path": path}
    return path


def main():
    """Render the Streamlit page: sidebar controls, video info and actions."""
    os.makedirs("preset_videos", exist_ok=True)
    # Page configuration.
    st.set_page_config(
        page_title="动作比较与关键点检测",
        page_icon="🏃",
        layout="wide"
    )
    st.title("🏃 动作比较与关键点检测系统")
    st.markdown("---")

    # One app instance per browser session, kept across reruns.
    if 'app' not in st.session_state:
        st.session_state.app = MotionComparisonApp()
    app = st.session_state.app

    # Sidebar control panel.
    with st.sidebar:
        st.header("控制面板")
        # Preset videos: display name -> path.
        preset_videos = {
            "六字诀": "preset_videos/liuzi.mp4",
        }
        # Report which preset files are actually present on disk.
        st.subheader("预设视频状态")
        video_status_ok = False  # becomes True once any preset file exists
        for name, path in preset_videos.items():
            if os.path.exists(path):
                st.success(f"{name}")
                video_status_ok = True
            else:
                st.error(f"{name} (文件不存在)")
        if not video_status_ok:
            st.warning("⚠️ 没有找到任何预设视频文件!请将视频文件放入'preset_videos'文件夹。")

        # Video source; default to "upload" when no preset file is available.
        video_source = st.radio(
            "选择视频来源",
            ["预设视频", "上传视频"],
            index=0 if video_status_ok else 1
        )
        if video_source == "预设视频":
            selected_preset = st.selectbox(
                "选择预设视频",
                list(preset_videos.keys())
            )
            video_path = preset_videos[selected_preset]
            if not os.path.exists(video_path):
                st.error(f"预设视频文件不存在: {video_path}")
                st.info("请确保在'preset_videos'文件夹中放置了对应视频文件")
                video_path = None
            else:
                st.success(f"已选择预设视频: {selected_preset}")
        else:
            uploaded_video = st.file_uploader(
                "上传标准动作视频",
                type=['mp4', 'avi', 'mov', 'mkv'],
                help="支持 MP4, AVI, MOV, MKV 格式",
                accept_multiple_files=False
            )
            # This CSS only affects how the uploaded file entry is laid out.
            # The maximum upload size is a Streamlit server setting
            # (server.maxUploadSize) and cannot be changed from here.
            st.markdown("""
<style>
.uploadedFile {max-width: 100% !important;}
</style>
""", unsafe_allow_html=True)
            if uploaded_video is not None:
                video_path = _store_uploaded_video(uploaded_video)
                st.success("✅ 视频上传成功!")
            else:
                video_path = None

        # Detector initialisation.
        if st.button("初始化关键点检测器"):
            with st.spinner("正在初始化..."):
                app.initialize_detector()

        # System information.
        st.subheader("系统信息")
        if torch.cuda.device_count() > 0:
            st.success("✅ CUDA 可用")
        else:
            st.info(" 使用 CPU 模式")

    # Main area.
    if video_path is not None:
        # Basic facts about the selected video.
        cap = cv2.VideoCapture(video_path)
        try:
            if cap.isOpened():
                fps = cap.get(cv2.CAP_PROP_FPS)
                frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                duration = frame_count / fps if fps > 0 else 0
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("时长", f"{duration:.1f}")
                with col2:
                    st.metric("帧率", f"{fps:.1f} FPS")
                with col3:
                    st.metric("分辨率", f"{width}×{height}")
                with col4:
                    st.metric("总帧数", f"{frame_count}")
            else:
                st.error("无法打开标准视频文件")
        finally:
            # Release the probe capture whether or not it opened.
            cap.release()

        st.markdown("---")
        _, col2, _ = st.columns([1, 1, 1])
        # Tracks whether the webcam preview step has been completed.
        if 'preview_done' not in st.session_state:
            st.session_state.preview_done = False
        with col2:
            if not st.session_state.preview_done:
                # Step 1: webcam preview.
                if st.button("📷 预览摄像头", use_container_width=True):
                    if app.body_detector is None:
                        st.error("请先初始化关键点检测器")
                    else:
                        preview_success = app.preview_webcam()
                        if preview_success:
                            st.session_state.preview_done = True
                            # Rerun so the "start comparison" button appears.
                            st.rerun()
            else:
                # Step 2: the comparison itself.
                if st.button("🚀 开始动作比较", use_container_width=True):
                    if app.body_detector is None:
                        st.error("请先初始化关键点检测器")
                    else:
                        st.info("开始动作比较...")
                        app.start_comparison(video_path)
                        # Require a fresh preview before the next comparison.
                        st.session_state.preview_done = False
    else:
        if video_source == "预设视频":
            st.error("预设视频文件不可用,请检查文件路径")
        else:
            st.info("👆 请在左侧上传标准动作视频文件")

    # Usage instructions.
    with st.expander("📖 使用说明"):
        st.markdown("""
### 如何使用此应用:
1. **上传视频**: 在左侧上传标准动作视频文件
2. **初始化检测器**: 点击"初始化关键点检测器"按钮
3. **开始比较**: 点击"开始动作比较"按钮
4. **跟随动作**: 在摄像头前跟随标准视频做相同动作
5. **观察对比**: 系统会实时显示两边的关键点检测结果
### 特性:
- ✅ 实时关键点检测
- ✅ 镜像摄像头画面
- ✅ 自动循环播放标准视频
- ✅ FPS 监控
- ✅ GPU 加速支持
### 要求:
- 摄像头权限
- 良好的光照条件
- 清晰的人体轮廓
""")


if __name__ == "__main__":
    main()