Spaces:
Sleeping
Sleeping
import base64
import os
import tempfile
import time

import cv2
import mediapipe as mp
import numpy as np
import streamlit as st
# Module-level MediaPipe Pose estimator used by process_video().
mp_pose = mp.solutions.pose

# Constructor options kept in one place so they are easy to review and tune.
_POSE_OPTIONS = {
    "static_image_mode": False,
    "model_complexity": 1,
    "enable_segmentation": True,
    "min_detection_confidence": 0.5,
    "min_tracking_confidence": 0.5,
}

pose = mp_pose.Pose(**_POSE_OPTIONS)
def calculate_angle_between_vectors(v1, v2):
    """Return the angle between two vectors, in degrees.

    Args:
        v1: First vector (NumPy array or any sequence of numbers).
        v2: Second vector, with the same number of components as ``v1``.

    Returns:
        The unsigned angle between ``v1`` and ``v2`` in degrees, in the
        range [0, 180].

    Raises:
        ValueError: If either vector has zero length, because the angle is
            undefined (normalising would divide by zero and yield NaN).
    """
    vec_a = np.asarray(v1, dtype=float)
    vec_b = np.asarray(v2, dtype=float)

    norm_a = np.linalg.norm(vec_a)
    norm_b = np.linalg.norm(vec_b)
    if norm_a == 0 or norm_b == 0:
        raise ValueError("Cannot compute the angle with a zero-length vector.")

    cosine = np.dot(vec_a / norm_a, vec_b / norm_b)
    # Rounding can push the cosine slightly outside [-1, 1], which would make
    # arccos return NaN, so clip before taking the inverse cosine.
    angle = np.arccos(np.clip(cosine, -1.0, 1.0))
    return np.degrees(angle)
def process_video(video_path):
    """Detect golf swing phases frame by frame and stream the results.

    This is a generator. For every frame read from the video it yields
    ``(frame_jpg, current_phase, processed_frames, total_frames)`` where
    ``frame_jpg`` is the JPEG-encoded frame as bytes and ``current_phase`` is
    the phase label in effect for that frame (it may be ``""`` or the initial
    ``"Not Setup Phase"``). After the last frame it yields one final tuple
    ``(phase_images, None, processed_frames, total_frames)`` where
    ``phase_images`` maps each phase name to the first JPEG captured for it,
    or ``None`` if the phase was never seen.

    Args:
        video_path: Path of the video file to open with OpenCV.

    Note:
        Uses the module-level ``pose`` estimator and closes it when done, so
        the estimator cannot be reused after this generator finishes.
    """
    phases_order = [
        "Setup Phase",
        "Mid Backswing Phase",
        "Top Backswing Phase",
        "Mid Downswing Phase",
        "Ball Impact Phase",
        "Follow Through Phase"
    ]
    # First JPEG captured for each phase (None until the phase is detected).
    phase_images = {phase: None for phase in phases_order}

    # Phase-detection state carried from frame to frame.
    current_phase = "Not Setup Phase"
    prev_wrist_left_y = None
    prev_wrist_right_y = None
    top_backswing_detected = False
    mid_downswing_detected = False
    ball_impact_detected = False
    top_backswing_frame = -2
    mid_downswing_frame = -2
    ball_impact_frame = -2
    mid_backswing_wrist_left_y = None
    mid_backswing_wrist_right_y = None

    BALL_IMPACT_DURATION = 2  # Frames to keep showing the Ball Impact phase
    MIN_MOVEMENT_THRESHOLD = 0.01
    HIP_NEAR_THRESHOLD = 0.05
    MID_SWING_THRESHOLD = 0.05

    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    processed_frames = 0

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_no = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
            # OpenCV decodes frames as BGR; MediaPipe expects RGB input.
            image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            result = pose.process(image_rgb)

            if result.pose_landmarks:
                landmarks = result.pose_landmarks.landmark
                # Landmark y values are image-normalised; a smaller y is
                # higher in the frame (per the MediaPipe Pose documentation).
                wrist_left_y = landmarks[mp_pose.PoseLandmark.LEFT_WRIST].y
                wrist_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_WRIST].y
                hip_left_y = landmarks[mp_pose.PoseLandmark.LEFT_HIP].y
                hip_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_HIP].y
                shoulder_left_y = landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER].y
                shoulder_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER].y

                hip_y_avg = (hip_left_y + hip_right_y) / 2
                shoulder_y_avg = (shoulder_left_y + shoulder_right_y) / 2
                mid_swing_y = (shoulder_y_avg + hip_y_avg) / 2

                # Phase detection: the first matching rule wins.
                if ball_impact_detected and frame_no > ball_impact_frame + BALL_IMPACT_DURATION:
                    current_phase = "Follow Through Phase"
                elif (abs(wrist_left_y - mid_swing_y) < MID_SWING_THRESHOLD
                        and abs(wrist_right_y - mid_swing_y) < MID_SWING_THRESHOLD
                        and not top_backswing_detected
                        and not ball_impact_detected):
                    current_phase = "Mid Backswing Phase"
                    # Remember the wrist heights so the downswing can be
                    # recognised when the wrists pass the same level again.
                    mid_backswing_wrist_left_y = wrist_left_y
                    mid_backswing_wrist_right_y = wrist_right_y
                elif (wrist_left_y < shoulder_left_y
                        and wrist_right_y < shoulder_right_y
                        and not mid_downswing_detected
                        and not ball_impact_detected):
                    current_phase = "Top Backswing Phase"
                    top_backswing_detected = True
                    top_backswing_frame = frame_no
                elif (mid_backswing_wrist_left_y is not None
                        and mid_backswing_wrist_right_y is not None
                        and abs(wrist_left_y - mid_backswing_wrist_left_y) < MID_SWING_THRESHOLD
                        and abs(wrist_right_y - mid_backswing_wrist_right_y) < MID_SWING_THRESHOLD
                        and top_backswing_detected
                        and frame_no > top_backswing_frame):
                    current_phase = "Mid Downswing Phase"
                    mid_downswing_detected = True
                    mid_downswing_frame = frame_no
                elif (abs(wrist_left_y - hip_y_avg) < HIP_NEAR_THRESHOLD
                        and abs(wrist_right_y - hip_y_avg) < HIP_NEAR_THRESHOLD):
                    # With no previous sample (first detected frame) the
                    # wrists are treated as still.
                    wrists_still = (
                        prev_wrist_left_y is None
                        or prev_wrist_right_y is None
                        or (abs(wrist_left_y - prev_wrist_left_y) < MIN_MOVEMENT_THRESHOLD
                            and abs(wrist_right_y - prev_wrist_right_y) < MIN_MOVEMENT_THRESHOLD)
                    )
                    if not wrists_still:
                        current_phase = ""
                    elif mid_downswing_detected and frame_no > mid_downswing_frame:
                        current_phase = "Ball Impact Phase"
                        ball_impact_detected = True
                        ball_impact_frame = frame_no
                    else:
                        # Still hands at hip height before any downswing:
                        # address position, so restart swing tracking.
                        current_phase = "Setup Phase"
                        top_backswing_detected = False
                        mid_downswing_detected = False
                else:
                    current_phase = ""

                prev_wrist_left_y = wrist_left_y
                prev_wrist_right_y = wrist_right_y

            # Encode the untouched BGR frame. cv2.imencode expects BGR, so
            # converting first would swap red and blue in the JPEG.
            # NOTE(review): every frame is yielded, including frames without
            # a detected pose (they keep the previous phase) — confirm this
            # matches the intended behaviour.
            encoded_ok, buffer = cv2.imencode('.jpg', frame)
            if not encoded_ok:
                continue
            frame_jpg = buffer.tobytes()

            # Keep only the first frame seen for each phase.
            if current_phase in phases_order and phase_images[current_phase] is None:
                phase_images[current_phase] = frame_jpg

            processed_frames += 1
            yield frame_jpg, current_phase, processed_frames, total_frames
    finally:
        # Runs on normal completion, on errors, and when the consumer closes
        # the generator early, so the capture handle is never leaked.
        cap.release()
        pose.close()

    yield phase_images, None, processed_frames, total_frames
# Streamlit UI: page styling, video upload, live processing view and the
# final grid of detected swing phases.
# NOTE(review): the emoji in the UI strings were mis-encoded in the source.
# The title, success, warning and copyright symbols were recovered exactly;
# the upload, spinner and heading emoji were chosen from context — confirm.

# Custom CSS for styling
st.markdown("""
<style>
/* General Styles */
body {
    background-color: #f5f5f5;
}
.title {
    font-size: 36px;
    font-weight: bold;
    color: #2E86C1;
    text-align: center;
    margin-bottom: 10px;
}
.subtitle {
    font-size: 18px;
    color: #555;
    text-align: center;
    margin-bottom: 30px;
}
.phase-placeholder {
    border: 2px solid #2E86C1;
    background-color: #D6EAF8;
    padding: 15px;
    text-align: center;
    font-size: 24px;
    border-radius: 10px;
    height: 60px;
    margin-bottom: 20px;
    color: #1B4F72;
}
.detected-phases {
    display: flex;
    flex-wrap: wrap;
    justify-content: center;
    gap: 30px; /* Increased gap for more space between phase cards */
}
.phase-card {
    border: 1px solid #ddd;
    border-radius: 10px;
    padding: 15px;
    background-color: #fff;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
    text-align: center;
    width: 250px;
    height: 300px; /* Fixed height to accommodate image and text */
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
}
.phase-card img {
    width: 200px; /* Fixed width */
    height: 200px; /* Fixed height */
    object-fit: cover; /* Ensures image covers the area without distortion */
    border-radius: 10px;
    margin-bottom: 10px;
}
.footer {
    text-align: center;
    font-size: 14px;
    color: #aaa;
    margin-top: 40px;
}
</style>
""", unsafe_allow_html=True)

# Title and Description
st.markdown("<div class='title'>🏌️‍♂️ Golf Swing Phase Detection</div><hr>", unsafe_allow_html=True)
st.markdown("<div class='subtitle'>Upload a golf swing video to automatically detect and visualize its different phases.</div>", unsafe_allow_html=True)

# File Uploader
video_file = st.file_uploader("🎥 **Upload Your Golf Swing Video**", type=["mp4", "avi", "mov", "mkv"])

if video_file is not None:
    tfile_path = None
    with st.spinner("🔄 Processing video..."):
        try:
            # Save the upload to a temporary file. The file is closed before
            # OpenCV opens it so all bytes are flushed to disk, and it keeps
            # the original extension to help container detection.
            suffix = os.path.splitext(video_file.name)[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
                tfile_path = tfile.name
                tfile.write(video_file.read())

            phase_placeholder = st.empty()
            video_placeholder = st.empty()
            progress_bar = st.progress(0)
            progress_text = st.empty()
            phase_images = {}

            for result, current_phase, processed, total in process_video(tfile_path):
                if current_phase is None:
                    # Final item from the generator: phase -> first JPEG.
                    phase_images = result
                    continue

                # Update phase placeholder
                phase_placeholder.markdown(
                    f"<div class='phase-placeholder'>{current_phase}</div>",
                    unsafe_allow_html=True
                )
                # Display video frame
                video_placeholder.image(result, channels='BGR', use_column_width=True)
                # The reported frame count can be lower than the number of
                # frames actually decoded; clamp so st.progress never gets a
                # value above 1.0.
                progress = min(processed / total, 1.0) if total > 0 else 0
                progress_bar.progress(progress)
                progress_text.text(f"Processing frame {processed} of {total}")
                time.sleep(0.01)  # Adjust delay as needed for smoother progress

            # Cleanup placeholders
            phase_placeholder.empty()
            video_placeholder.empty()
            progress_bar.empty()
            progress_text.empty()
            st.success("✅ Video processing complete!")

            # Define the order of phases for fixed grid positions
            phases_order = [
                "Setup Phase",
                "Mid Backswing Phase",
                "Top Backswing Phase",
                "Mid Downswing Phase",
                "Ball Impact Phase",
                "Follow Through Phase"
            ]

            # Display Detected Phases in Fixed Grid Layout
            st.markdown("### 📊 Detected Phases:")
            if any(phase_images.get(phase) for phase in phases_order):
                cards = []
                for phase in phases_order:
                    image = phase_images.get(phase)
                    if image:
                        encoded_image = base64.b64encode(image).decode('utf-8')
                        cards.append(
                            "<div class='phase-card'>"
                            f"<img src=\"data:image/jpeg;base64,{encoded_image}\" alt=\"{phase}\">"
                            f"<h3>{phase}</h3>"
                            "</div>"
                        )
                    else:
                        # Placeholder for missing phases
                        cards.append(
                            "<div class='phase-card'>"
                            f"<img src=\"https://via.placeholder.com/200x200.png?text=No+Image\" alt=\"{phase}\">"
                            f"<h3>{phase}</h3>"
                            "<p style=\"color: #aaa; font-size: 14px;\">Not Detected</p>"
                            "</div>"
                        )
                # Render the whole grid in one call: each st.markdown call is
                # placed in its own container, so a wrapper <div> opened in
                # one call cannot enclose cards emitted by later calls.
                st.markdown(
                    "<div class='detected-phases'>" + "".join(cards) + "</div>",
                    unsafe_allow_html=True
                )
            else:
                st.warning("⚠️ No distinct phases detected.")
        except Exception as e:
            st.error(f"An error occurred during video processing: {e}")
        finally:
            # The temp file was created with delete=False, so remove it here.
            if tfile_path is not None:
                try:
                    os.remove(tfile_path)
                except OSError:
                    # Best-effort cleanup: a leftover temp file must not hide
                    # the processing result or the original error.
                    pass

# Footer
st.markdown("<div class='footer'>Developed by Your Name | © 2024 Golf Swing Analyzer</div>", unsafe_allow_html=True)