706 lines
26 KiB
Python
706 lines
26 KiB
Python
import streamlit as st
|
||
import cv2
|
||
import time
|
||
import tempfile
|
||
import os
|
||
import torch
|
||
import matplotlib.pyplot as plt
|
||
import plotly.graph_objects as go
|
||
from plotly.subplots import make_subplots
|
||
import pandas as pd
|
||
import numpy as np
|
||
from collections import deque
|
||
import math
|
||
|
||
|
||
# 导入rtmlib
|
||
try:
|
||
from rtmlib import Body, draw_skeleton
|
||
except ImportError:
|
||
st.error("请安装rtmlib库: pip install rtmlib")
|
||
st.stop()
|
||
|
||
class PoseSimilarityAnalyzer:
|
||
def __init__(self):
|
||
self.similarity_history = deque(maxlen=500) # 保存最近500个相似度值
|
||
self.frame_timestamps = deque(maxlen=500)
|
||
self.start_time = None
|
||
|
||
# OpenPose关键点索引映射(Body 17个关键点)
|
||
self.keypoint_map = {
|
||
'nose': 0, 'neck': 1,
|
||
'left_shoulder': 2, 'left_elbow': 3, 'left_wrist': 4,
|
||
'right_shoulder': 5, 'right_elbow': 6, 'right_wrist': 7,
|
||
'left_hip': 8, 'left_knee': 9, 'left_ankle': 10,
|
||
'right_hip': 11, 'right_knee': 12, 'right_ankle': 13,
|
||
'left_eye': 14, 'right_eye': 15,
|
||
'left_ear': 16, 'right_ear': 17
|
||
}
|
||
|
||
# 关节角度定义(关节名:[点1, 关节点, 点2])
|
||
self.joint_angles = {
|
||
'left_elbow': ['left_shoulder', 'left_elbow', 'left_wrist'],
|
||
'right_elbow': ['right_shoulder', 'right_elbow', 'right_wrist'],
|
||
'left_shoulder': ['left_elbow', 'left_shoulder', 'neck'],
|
||
'right_shoulder': ['right_elbow', 'right_shoulder', 'neck'],
|
||
'left_knee': ['left_hip', 'left_knee', 'left_ankle'],
|
||
'right_knee': ['right_hip', 'right_knee', 'right_ankle'],
|
||
'left_hip': ['left_knee', 'left_hip', 'neck'],
|
||
'right_hip': ['right_knee', 'right_hip', 'neck'],
|
||
}
|
||
|
||
# 关节权重(重要关节权重更高)
|
||
self.joint_weights = {
|
||
'left_elbow': 1.2,
|
||
'right_elbow': 1.2,
|
||
'left_shoulder': 1.0,
|
||
'right_shoulder': 1.0,
|
||
'left_knee': 1.3,
|
||
'right_knee': 1.3,
|
||
'left_hip': 1.1,
|
||
'right_hip': 1.1,
|
||
}
|
||
|
||
def calculate_angle(self, p1, p2, p3):
|
||
"""计算三个点组成的角度"""
|
||
# 向量v1 = p1 - p2, v2 = p3 - p2
|
||
v1 = np.array([p1[0] - p2[0], p1[1] - p2[1]])
|
||
v2 = np.array([p3[0] - p2[0], p3[1] - p2[1]])
|
||
|
||
# 计算向量长度
|
||
v1_norm = np.linalg.norm(v1)
|
||
v2_norm = np.linalg.norm(v2)
|
||
|
||
if v1_norm == 0 or v2_norm == 0:
|
||
return None
|
||
|
||
# 计算夹角(弧度)
|
||
cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
|
||
cos_angle = np.clip(cos_angle, -1.0, 1.0) # 防止数值误差
|
||
angle = np.arccos(cos_angle)
|
||
|
||
# 转换为角度
|
||
return np.degrees(angle)
|
||
|
||
def extract_joint_angles(self, keypoints, scores, confidence_threshold=0.3):
|
||
"""从关键点提取关节角度"""
|
||
if keypoints is None or len(keypoints) == 0:
|
||
return None
|
||
|
||
# 取第一个人的关键点
|
||
person_keypoints = keypoints[0] if len(keypoints.shape) > 2 else keypoints
|
||
person_scores = scores[0] if len(scores.shape) > 1 else scores
|
||
|
||
joint_angles_result = {}
|
||
|
||
for joint_name, (p1_name, p2_name, p3_name) in self.joint_angles.items():
|
||
try:
|
||
# 获取关键点索引
|
||
p1_idx = self.keypoint_map[p1_name]
|
||
p2_idx = self.keypoint_map[p2_name]
|
||
p3_idx = self.keypoint_map[p3_name]
|
||
|
||
# 检查置信度
|
||
if (person_scores[p1_idx] < confidence_threshold or
|
||
person_scores[p2_idx] < confidence_threshold or
|
||
person_scores[p3_idx] < confidence_threshold):
|
||
continue
|
||
|
||
# 获取坐标
|
||
p1 = person_keypoints[p1_idx]
|
||
p2 = person_keypoints[p2_idx]
|
||
p3 = person_keypoints[p3_idx]
|
||
|
||
# 计算角度
|
||
angle = self.calculate_angle(p1, p2, p3)
|
||
if angle is not None:
|
||
joint_angles_result[joint_name] = angle
|
||
|
||
except (KeyError, IndexError) as e:
|
||
continue
|
||
|
||
return joint_angles_result
|
||
|
||
def calculate_similarity(self, angles1, angles2):
|
||
"""计算两组关节角度的相似度"""
|
||
if not angles1 or not angles2:
|
||
return 0.0
|
||
|
||
# 找到共同的关节
|
||
common_joints = set(angles1.keys()) & set(angles2.keys())
|
||
if not common_joints:
|
||
return 0.0
|
||
|
||
total_weight = 0
|
||
weighted_similarity = 0
|
||
|
||
for joint in common_joints:
|
||
angle_diff = abs(angles1[joint] - angles2[joint])
|
||
|
||
# 角度差异转换为相似度(0-1)
|
||
# 使用高斯函数,角度差异越小,相似度越高
|
||
similarity = math.exp(-(angle_diff ** 2) / (2 * (30 ** 2))) # 30度标准差
|
||
|
||
# 应用权重
|
||
weight = self.joint_weights.get(joint, 1.0)
|
||
weighted_similarity += similarity * weight
|
||
total_weight += weight
|
||
|
||
# 归一化相似度
|
||
final_similarity = weighted_similarity / total_weight if total_weight > 0 else 0
|
||
return min(max(final_similarity * 100, 0), 100) # 转换为0-100分
|
||
|
||
def add_similarity_score(self, score, timestamp=None):
|
||
"""添加相似度分数到历史记录"""
|
||
if self.start_time is None:
|
||
self.start_time = time.time()
|
||
|
||
if timestamp is None:
|
||
timestamp = time.time() - self.start_time
|
||
|
||
self.similarity_history.append(score)
|
||
self.frame_timestamps.append(timestamp)
|
||
|
||
def get_similarity_plot(self):
|
||
"""生成相似度变化折线图"""
|
||
if len(self.similarity_history) < 2:
|
||
return None
|
||
|
||
fig = go.Figure()
|
||
|
||
fig.add_trace(go.Scatter(
|
||
x=list(self.frame_timestamps),
|
||
y=list(self.similarity_history),
|
||
mode='lines+markers',
|
||
name='动作相似度',
|
||
line=dict(color='#2E86AB', width=2),
|
||
marker=dict(size=4)
|
||
))
|
||
|
||
fig.update_layout(
|
||
title='动作相似度变化趋势',
|
||
xaxis_title='时间 (秒)',
|
||
yaxis_title='相似度分数 (%)',
|
||
yaxis=dict(range=[0, 100]),
|
||
height=300,
|
||
margin=dict(l=50, r=50, t=50, b=50),
|
||
showlegend=False
|
||
)
|
||
|
||
# 添加平均分数线
|
||
if len(self.similarity_history) > 0:
|
||
avg_score = sum(self.similarity_history) / len(self.similarity_history)
|
||
fig.add_hline(
|
||
y=avg_score,
|
||
line_dash="dash",
|
||
line_color="red",
|
||
annotation_text=f"平均分: {avg_score:.1f}%"
|
||
)
|
||
|
||
return fig
|
||
|
||
def reset(self):
|
||
"""重置分析器"""
|
||
self.similarity_history.clear()
|
||
self.frame_timestamps.clear()
|
||
self.start_time = None
|
||
|
||
|
||
class MotionComparisonApp:
|
||
def __init__(self):
|
||
self.body_detector = None
|
||
self.is_running = False
|
||
self.standard_video_path = None
|
||
self.webcam_cap = None
|
||
self.standard_cap = None
|
||
self.similarity_analyzer = PoseSimilarityAnalyzer()
|
||
self.frame_counter = 0
|
||
|
||
|
||
def preview_webcam(self):
|
||
"""显示摄像头预览,帮助用户调整位置"""
|
||
# 打开摄像头
|
||
self.webcam_cap = cv2.VideoCapture(0)
|
||
if not self.webcam_cap.isOpened():
|
||
st.error("无法打开摄像头")
|
||
return False
|
||
|
||
# 创建显示容器
|
||
st.subheader("摄像头预览")
|
||
preview_text = st.empty()
|
||
preview_text.info("请调整您的位置,确保全身在画面中清晰可见")
|
||
|
||
preview_placeholder = st.empty()
|
||
|
||
# 显示停止预览按钮
|
||
col1, col2, col3 = st.columns([1, 1, 1])
|
||
stop_preview = col2.button("停止预览", key="stop_preview_btn")
|
||
|
||
# 预览循环
|
||
while not stop_preview:
|
||
# 读取摄像头帧
|
||
ret, frame = self.webcam_cap.read()
|
||
if not ret:
|
||
st.error("无法获取摄像头画面")
|
||
break
|
||
|
||
# 翻转摄像头画面(镜像效果)
|
||
frame = cv2.flip(frame, 1)
|
||
|
||
# 如果检测器已初始化,可以显示关键点检测结果
|
||
if self.body_detector is not None:
|
||
try:
|
||
keypoints, scores = self.body_detector(frame)
|
||
frame = draw_skeleton(
|
||
frame.copy(),
|
||
keypoints,
|
||
scores,
|
||
openpose_skeleton=True,
|
||
kpt_thr=0.43
|
||
)
|
||
except Exception as e:
|
||
pass # 预览阶段,忽略检测错误
|
||
|
||
# 转换颜色空间并显示
|
||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||
preview_placeholder.image(frame_rgb, caption="摄像头预览", use_column_width=True)
|
||
|
||
# 检查停止按钮状态
|
||
if col2.button("停止预览", key=f"stop_preview_btn_{time.time()}", replace=True):
|
||
break
|
||
|
||
# 控制帧率
|
||
time.sleep(0.03) # 约30fps
|
||
|
||
# 返回预览是否成功
|
||
return True
|
||
|
||
def initialize_detector(self):
|
||
"""初始化身体关键点检测器"""
|
||
if self.body_detector is None:
|
||
device = 'cuda' if torch.cuda.device_count()> 0 else 'cpu'
|
||
self.body_detector = Body(
|
||
mode='lightweight', # 使用轻量级模式提高实时性能
|
||
to_openpose=True,
|
||
backend='onnxruntime',
|
||
device=device
|
||
)
|
||
st.success(f"关键点检测器初始化完成 (设备: {device})")
|
||
|
||
def process_frame_with_keypoints(self, frame):
|
||
"""处理帧并添加关键点"""
|
||
if self.body_detector is None:
|
||
return frame
|
||
|
||
try:
|
||
# 检测关键点
|
||
keypoints, scores = self.body_detector(frame)
|
||
|
||
# 绘制关键点
|
||
result_frame = draw_skeleton(
|
||
frame.copy(),
|
||
keypoints,
|
||
scores,
|
||
openpose_skeleton=True,
|
||
kpt_thr=0.43
|
||
)
|
||
return result_frame
|
||
except Exception as e:
|
||
st.error(f"关键点检测错误: {e}")
|
||
return frame
|
||
|
||
def start_comparison(self, video_path):
|
||
"""开始动作比较"""
|
||
self.is_running = True
|
||
self.standard_video_path = video_path
|
||
self.frame_counter = 0
|
||
|
||
# 重置相似度分析器
|
||
self.similarity_analyzer.reset()
|
||
|
||
# 打开标准视频
|
||
self.standard_cap = cv2.VideoCapture(video_path)
|
||
if not self.standard_cap.isOpened():
|
||
st.error("无法打开标准视频文件")
|
||
return
|
||
|
||
# 如果摄像头未打开,则打开摄像头
|
||
if self.webcam_cap is None or not self.webcam_cap.isOpened():
|
||
self.webcam_cap = cv2.VideoCapture(0)
|
||
if not self.webcam_cap.isOpened():
|
||
st.error("无法打开摄像头")
|
||
return
|
||
|
||
# 获取视频信息
|
||
fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
|
||
frame_delay = 1.0 / fps if fps > 0 else 1.0 / 30
|
||
|
||
# 创建显示容器
|
||
col1, col2 = st.columns(2)
|
||
|
||
with col1:
|
||
st.subheader("标准动作视频")
|
||
standard_placeholder = st.empty()
|
||
|
||
with col2:
|
||
st.subheader("摄像头实时影像")
|
||
webcam_placeholder = st.empty()
|
||
|
||
# 添加相似度显区域
|
||
similarity_col1, similarity_col2 = st.columns([1, 2])
|
||
with similarity_col1:
|
||
st.subheader("实时相似度")
|
||
similarity_score_placeholder = st.empty()
|
||
avg_score_placeholder = st.empty()
|
||
|
||
with similarity_col2:
|
||
st.subheader("相似度变化图")
|
||
similarity_plot_placeholder = st.empty()
|
||
|
||
# 添加控制按钮
|
||
control_col1, control_col2, control_col3 = st.columns(3)
|
||
with control_col1:
|
||
if st.button("停止", key="stop_btn"):
|
||
self.is_running = False
|
||
with control_col2:
|
||
if st.button("重新开始", key="restart_btn"):
|
||
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||
self.similarity_analyzer.reset()
|
||
|
||
# 主循环
|
||
frame_count = 0
|
||
start_time = time.time()
|
||
current_similarity = 0
|
||
|
||
while self.is_running:
|
||
self.frame_counter += 1
|
||
|
||
# 读取标准视频帧
|
||
ret_standard, standard_frame = self.standard_cap.read()
|
||
if not ret_standard:
|
||
# 视频结束,重新开始
|
||
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||
continue
|
||
|
||
# 读取摄像头帧
|
||
ret_webcam, webcam_frame = self.webcam_cap.read()
|
||
if not ret_webcam:
|
||
st.error("无法获取摄像头画面")
|
||
break
|
||
|
||
# 翻转摄像头画面(镜像效果)
|
||
webcam_frame = cv2.flip(webcam_frame, 1)
|
||
|
||
# 调整尺寸使两个视频大小一致
|
||
target_height = 480
|
||
target_width = 640
|
||
|
||
standard_frame = cv2.resize(standard_frame, (target_width, target_height))
|
||
webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))
|
||
|
||
# 处理关键点检测
|
||
standard_keypoints, standard_scores = self.body_detector(standard_frame)
|
||
webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)
|
||
|
||
# 每10帧计算一次相似度
|
||
if self.frame_counter % 10 == 0:
|
||
# 提取关节角度
|
||
standard_angles = self.similarity_analyzer.extract_joint_angles(
|
||
standard_keypoints, standard_scores
|
||
)
|
||
webcam_angles = self.similarity_analyzer.extract_joint_angles(
|
||
webcam_keypoints, webcam_scores
|
||
)
|
||
|
||
# 计算相似度
|
||
if standard_angles and webcam_angles:
|
||
current_similarity = self.similarity_analyzer.calculate_similarity(
|
||
standard_angles, webcam_angles
|
||
)
|
||
|
||
# 添加到历史记录
|
||
timestamp = time.time() - start_time
|
||
self.similarity_analyzer.add_similarity_score(current_similarity, timestamp)
|
||
|
||
# 绘制关键点
|
||
standard_with_keypoints = draw_skeleton(
|
||
standard_frame.copy(),
|
||
standard_keypoints,
|
||
standard_scores,
|
||
openpose_skeleton=True,
|
||
kpt_thr=0.43
|
||
)
|
||
webcam_with_keypoints = draw_skeleton(
|
||
webcam_frame.copy(),
|
||
webcam_keypoints,
|
||
webcam_scores,
|
||
openpose_skeleton=True,
|
||
kpt_thr=0.43
|
||
)
|
||
|
||
# 转换颜色空间 (BGR to RGB)
|
||
standard_rgb = cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB)
|
||
webcam_rgb = cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB)
|
||
|
||
# 显示画面
|
||
standard_placeholder.image(standard_rgb, caption="标准动作", use_column_width=True)
|
||
webcam_placeholder.image(webcam_rgb, caption="您的动作", use_column_width=True)
|
||
|
||
# 显示相似度信息
|
||
if len(self.similarity_analyzer.similarity_history) > 0:
|
||
avg_similarity = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)
|
||
|
||
# 使用不同颜色显示相似度
|
||
if current_similarity >= 80:
|
||
similarity_color = "🟢"
|
||
elif current_similarity >= 60:
|
||
similarity_color = "🟡"
|
||
else:
|
||
similarity_color = "🔴"
|
||
|
||
similarity_score_placeholder.metric(
|
||
"当前相似度",
|
||
f"{similarity_color} {current_similarity:.1f}%",
|
||
delta=f"{current_similarity - avg_similarity:.1f}%" if len(self.similarity_analyzer.similarity_history) > 1 else None
|
||
)
|
||
|
||
avg_score_placeholder.metric(
|
||
"平均相似度",
|
||
f"{avg_similarity:.1f}%"
|
||
)
|
||
|
||
# 更新相似度图表
|
||
if len(self.similarity_analyzer.similarity_history) >= 2:
|
||
similarity_plot = self.similarity_analyzer.get_similarity_plot()
|
||
if similarity_plot:
|
||
similarity_plot_placeholder.plotly_chart(
|
||
similarity_plot,
|
||
use_container_width=True
|
||
)
|
||
|
||
# 计算FPS
|
||
frame_count += 1
|
||
if frame_count % 30 == 0:
|
||
elapsed = time.time() - start_time
|
||
fps_display = frame_count / elapsed
|
||
st.sidebar.metric("实时FPS", f"{fps_display:.1f}")
|
||
|
||
# 控制播放速度
|
||
time.sleep(frame_delay)
|
||
|
||
# 显示最终统计
|
||
if len(self.similarity_analyzer.similarity_history) > 0:
|
||
final_avg = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)
|
||
max_similarity = max(self.similarity_analyzer.similarity_history)
|
||
min_similarity = min(self.similarity_analyzer.similarity_history)
|
||
|
||
st.success(f"动作比较完成!")
|
||
col1, col2, col3 = st.columns(3)
|
||
with col1:
|
||
st.metric("平均相似度", f"{final_avg:.1f}%")
|
||
with col2:
|
||
st.metric("最高相似度", f"{max_similarity:.1f}%")
|
||
with col3:
|
||
st.metric("最低相似度", f"{min_similarity:.1f}%")
|
||
|
||
# 清理资源
|
||
if self.standard_cap:
|
||
self.standard_cap.release()
|
||
if self.webcam_cap:
|
||
self.webcam_cap.release()
|
||
|
||
def main():
|
||
os.makedirs("preset_videos", exist_ok=True)
|
||
# 页面配置
|
||
st.set_page_config(
|
||
page_title="动作比较与关键点检测",
|
||
page_icon="🏃",
|
||
layout="wide"
|
||
)
|
||
|
||
st.title("🏃 动作比较与关键点检测系统")
|
||
st.markdown("---")
|
||
|
||
# 创建应用实例
|
||
if 'app' not in st.session_state:
|
||
st.session_state.app = MotionComparisonApp()
|
||
|
||
app = st.session_state.app
|
||
|
||
# 侧边栏控制面板
|
||
with st.sidebar:
|
||
st.header("控制面板")
|
||
|
||
# 预设视频选择
|
||
preset_videos = {
|
||
"六字诀": "preset_videos/liuzi.mp4",
|
||
}
|
||
|
||
# 在这里放置预设视频状态检查代码
|
||
st.subheader("预设视频状态")
|
||
video_status_ok = False # 跟踪是否至少有一个预设视频可用
|
||
for name, path in preset_videos.items():
|
||
if os.path.exists(path):
|
||
st.success(f"✅ {name}")
|
||
video_status_ok = True
|
||
else:
|
||
st.error(f"❌ {name} (文件不存在)")
|
||
|
||
if not video_status_ok:
|
||
st.warning("⚠️ 没有找到任何预设视频文件!请将视频文件放入'preset_videos'文件夹。")
|
||
|
||
# 接下来是视频来源选择
|
||
video_source = st.radio(
|
||
"选择视频来源",
|
||
["预设视频", "上传视频"],
|
||
index=0 if video_status_ok else 1 # 如果没有预设视频可用,默认选择上传
|
||
)
|
||
if video_source == "预设视频":
|
||
selected_preset = st.selectbox(
|
||
"选择预设视频",
|
||
list(preset_videos.keys())
|
||
)
|
||
video_path = preset_videos[selected_preset]
|
||
|
||
# 检查预设视频文件是否存在
|
||
if not os.path.exists(video_path):
|
||
st.error(f"预设视频文件不存在: {video_path}")
|
||
st.info("请确保在'preset_videos'文件夹中放置了对应视频文件")
|
||
video_path = None
|
||
else:
|
||
st.success(f"已选择预设视频: {selected_preset}")
|
||
else:
|
||
# 上传标准视频,调高大小限制到1GB
|
||
uploaded_video = st.file_uploader(
|
||
"上传标准动作视频",
|
||
type=['mp4', 'avi', 'mov', 'mkv'],
|
||
help="支持 MP4, AVI, MOV, MKV 格式",
|
||
accept_multiple_files=False
|
||
)
|
||
|
||
# 设置最大上传大小 (1GB)
|
||
# 注意:这需要修改streamlit的配置
|
||
st.markdown("""
|
||
<style>
|
||
.uploadedFile {max-width: 100% !important;}
|
||
</style>
|
||
""", unsafe_allow_html=True)
|
||
|
||
if uploaded_video is not None:
|
||
# 保存上传的视频到临时文件
|
||
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
|
||
tmp_file.write(uploaded_video.read())
|
||
video_path = tmp_file.name
|
||
st.success("✅ 视频上传成功!")
|
||
else:
|
||
video_path = None
|
||
|
||
# 初始化检测器
|
||
if st.button("初始化关键点检测器"):
|
||
with st.spinner("正在初始化..."):
|
||
app.initialize_detector()
|
||
|
||
# 显示系统信息
|
||
st.subheader("系统信息")
|
||
if torch.cuda.device_count() > 0:
|
||
st.success("✅ CUDA 可用")
|
||
else:
|
||
st.info("ℹ️ 使用 CPU 模式")
|
||
|
||
|
||
# 主界面
|
||
if video_path is not None:
|
||
# 显示视频信息
|
||
cap = cv2.VideoCapture(video_path)
|
||
if cap.isOpened():
|
||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||
duration = frame_count / fps if fps > 0 else 0
|
||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||
|
||
col1, col2, col3, col4 = st.columns(4)
|
||
with col1:
|
||
st.metric("时长", f"{duration:.1f}秒")
|
||
with col2:
|
||
st.metric("帧率", f"{fps:.1f} FPS")
|
||
with col3:
|
||
st.metric("分辨率", f"{width}×{height}")
|
||
with col4:
|
||
st.metric("总帧数", f"{frame_count}")
|
||
|
||
cap.release()
|
||
|
||
# 开始按钮
|
||
st.markdown("---")
|
||
col1, col2, col3 = st.columns([1, 1, 1])
|
||
# 添加一个session_state变量来跟踪预览状态
|
||
if 'preview_done' not in st.session_state:
|
||
st.session_state.preview_done = False
|
||
with col2:
|
||
# 如果还没预览,显示"预览摄像头"按钮
|
||
if not st.session_state.preview_done:
|
||
if st.button("📷 预览摄像头", use_container_width=True):
|
||
if app.body_detector is None:
|
||
st.error("请先初始化关键点检测器")
|
||
else:
|
||
# 进行摄像头预览
|
||
preview_success = app.preview_webcam()
|
||
if preview_success:
|
||
st.session_state.preview_done = True
|
||
# 使用rerun来刷新界面,显示开始比较按钮
|
||
st.rerun()
|
||
# 如果已经预览过,显示"开始动作比较"按钮
|
||
else:
|
||
if st.button("🚀 开始动作比较", use_container_width=True):
|
||
if app.body_detector is None:
|
||
st.error("请先初始化关键点检测器")
|
||
else:
|
||
st.info("开始动作比较...")
|
||
app.start_comparison(video_path)
|
||
# 重置预览状态,以便下次使用
|
||
st.session_state.preview_done = False
|
||
|
||
# 如果是上传的视频,结束时清理临时文件
|
||
if video_source == "上传视频" and os.path.exists(video_path) and video_path.startswith(tempfile.gettempdir()):
|
||
try:
|
||
# 这里不要立即删除,而是在应用结束时删除
|
||
# 注册一个退出时删除文件的函数
|
||
import atexit
|
||
atexit.register(lambda: os.unlink(video_path) if os.path.exists(video_path) else None)
|
||
except:
|
||
pass
|
||
else:
|
||
if video_source == "预设视频":
|
||
st.error("预设视频文件不可用,请检查文件路径")
|
||
else:
|
||
st.info("👆 请在左侧上传标准动作视频文件")
|
||
|
||
# 显示使用说明
|
||
with st.expander("📖 使用说明"):
|
||
st.markdown("""
|
||
### 如何使用此应用:
|
||
|
||
1. **上传视频**: 在左侧上传标准动作视频文件
|
||
2. **初始化检测器**: 点击"初始化关键点检测器"按钮
|
||
3. **开始比较**: 点击"开始动作比较"按钮
|
||
4. **跟随动作**: 在摄像头前跟随标准视频做相同动作
|
||
5. **观察对比**: 系统会实时显示两边的关键点检测结果
|
||
|
||
### 特性:
|
||
- ✅ 实时关键点检测
|
||
- ✅ 镜像摄像头画面
|
||
- ✅ 自动循环播放标准视频
|
||
- ✅ FPS 监控
|
||
- ✅ GPU 加速支持
|
||
|
||
### 要求:
|
||
- 摄像头权限
|
||
- 良好的光照条件
|
||
- 清晰的人体轮廓
|
||
""")
|
||
|
||
if __name__ == "__main__":
|
||
main()
|
||
|