srivatsavdamaraju's picture
Update app.py
4b69a2c verified
# Install required packages (if not done)
# !pip install gradio opencv-python mediapipe pandas
import gradio as gr
import cv2
import numpy as np
import os
import json
import pandas as pd
from datetime import datetime
from PIL import Image
import mediapipe as mp
# Setup folders: the overlay PNGs written by process_pose are stored here.
os.makedirs("pose_images", exist_ok=True)

# JSON file that persists every analysed pose between runs.
json_path = "pose_dataset.json"

# Load or create dataset: start empty, then replace with the saved contents
# when a previous run left a dataset file behind.
pose_dataset = {}
if os.path.exists(json_path):
    with open(json_path) as f:
        pose_dataset = json.load(f)
# MediaPipe pose setup
_mp_solutions = mp.solutions  # short alias for the sub-modules used below
mp_pose = _mp_solutions.pose
# Single detector created once at import time and reused by process_pose.
# static_image_mode / model_complexity semantics: see the MediaPipe Pose docs.
pose_model = mp_pose.Pose(
    static_image_mode=True,
    model_complexity=2,
)
mp_drawing = _mp_solutions.drawing_utils
mp_styles = _mp_solutions.drawing_styles
# Define function to extract nodes and edges
def create_pose_graph_data(pose_landmarks):
    """Convert MediaPipe pose landmarks into graph-style node and edge records.

    Args:
        pose_landmarks: the ``pose_landmarks`` object of a MediaPipe Pose
            result; only its ``.landmark`` sequence is read.

    Returns:
        ``(nodes, edges)`` where ``nodes`` maps landmark index to a dict with
        ``id``, ``name`` and rounded ``x``/``y``/``z`` (4 dp) and
        ``visibility`` (3 dp), and ``edges`` is a list of dicts (``from``,
        ``to``, ``from_name``, ``to_name``) for every connection in
        ``mp_pose.POSE_CONNECTIONS`` whose endpoints both exist.
    """
    landmarks = pose_landmarks.landmark
    count = len(landmarks)

    def _name(index):
        # Human-readable landmark name, e.g. the enum member name.
        return mp_pose.PoseLandmark(index).name

    nodes = {
        index: {
            "id": index,
            "name": _name(index),
            "x": round(point.x, 4),
            "y": round(point.y, 4),
            "z": round(point.z, 4),
            "visibility": round(point.visibility, 3),
        }
        for index, point in enumerate(landmarks)
    }

    # Keep only connections whose two endpoints are present in this result.
    edges = [
        {
            "from": start,
            "to": end,
            "from_name": _name(start),
            "to_name": _name(end),
        }
        for start, end in mp_pose.POSE_CONNECTIONS
        if start < count and end < count
    ]
    return nodes, edges
# Main pose processing function


def _to_rgb(image):
    """Return *image* as a 3-channel RGB array suitable for MediaPipe.

    Grayscale (2-D or single-channel) and RGBA arrays are converted; any other
    array is returned unchanged and assumed to already be RGB.
    """
    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.ndim == 3 and image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def _save_dataset():
    """Write ``pose_dataset`` to ``json_path`` via a temp file + ``os.replace``.

    Replacing the file in one step means a crash mid-write cannot leave a
    truncated dataset file (which would lose every stored pose).
    """
    tmp_path = f"{json_path}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(pose_dataset, f, indent=2)
    os.replace(tmp_path, json_path)


def _format_pose_summary(nodes, edges, pose_description, overlay_path):
    """Build the text summary shown in the "Pose Details" box."""
    parts = [
        "🎯 **Pose Analysis Results**\n",
        "📊 **Graph Structure:**\n",
        f"- Total Nodes: {len(nodes)}\n",
        f"- Total Edges: {len(edges)}\n",
        f"📝 **Pose Description:** {pose_description or 'No description provided'}\n",
        "🔍 **Key Nodes (First 10):**\n",
    ]
    for node in list(nodes.values())[:10]:
        parts.append(
            f"• {node['name']}: ({node['x']}, {node['y']}, {node['z']}) "
            f"[vis: {node['visibility']}]\n"
        )
    if len(nodes) > 10:
        parts.append(f"... and {len(nodes) - 10} more nodes\n")
    parts.append("\n🔗 **Sample Edges:**\n")
    for edge in edges[:5]:
        parts.append(f"• {edge['from_name']} → {edge['to_name']}\n")
    if len(edges) > 5:
        parts.append(f"... and {len(edges) - 5} more connections\n")
    parts.append(f"\n💾 **Saved as:** {overlay_path}")
    return "".join(parts)


def process_pose(image, pose_description=""):
    """Detect a pose in *image*, persist it, and build the Gradio outputs.

    Args:
        image: image array from ``gr.Image(type="numpy")``, expected RGB;
            grayscale and RGBA arrays are converted to RGB first. ``None``
            means nothing was uploaded.
        pose_description: optional free-text label stored with the pose.

    Returns:
        A 4-tuple ``(overlay_rgb, details_text, status_text, csv_file)``
        matching the UI outputs; ``csv_file`` is always ``None``. On failure
        ``overlay_rgb`` is ``None`` and ``details_text`` holds the error.

    Side effects:
        Writes ``pose_images/<pose_id>.png``, adds an entry to
        ``pose_dataset`` and rewrites ``json_path``.
    """
    if image is None:
        return None, "❌ Please upload an image.", "", None

    # MediaPipe takes RGB input; make sure it is 3-channel and contiguous.
    rgb = np.ascontiguousarray(_to_rgb(image))
    results = pose_model.process(rgb)
    if not results.pose_landmarks:
        return None, "❌ No pose detected.", "", None

    # Microsecond precision so two analyses in the same second do not
    # overwrite each other's dataset entry and overlay image. The fixed-width
    # format keeps ids lexicographically ordered by creation time.
    ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    pose_id = f"pose_{ts}"

    # Draw on a BGR copy because cv2.imwrite expects BGR channel order.
    overlay = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    mp_drawing.draw_landmarks(
        overlay,
        results.pose_landmarks,
        mp_pose.POSE_CONNECTIONS,
        landmark_drawing_spec=mp_styles.get_default_pose_landmarks_style(),
    )
    overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    overlay_path = f"pose_images/{pose_id}.png"
    cv2.imwrite(overlay_path, overlay)

    nodes, edges = create_pose_graph_data(results.pose_landmarks)
    pose_data = {
        "pose_id": pose_id,
        "total_nodes": len(nodes),
        "total_edges": len(edges),
        "nodes": nodes,
        "edges": edges,
        "description": pose_description or "No description provided",
    }
    pose_dataset[pose_id] = {
        "pose_name": pose_id,
        "image_path": overlay_path,
        "pose_description": pose_description,
        "pose_data": pose_data,
        "timestamp": ts,
    }
    _save_dataset()

    details = _format_pose_summary(nodes, edges, pose_description, overlay_path)
    return overlay_rgb, details, f"✅ Pose '{pose_id}' saved successfully!", None
# Save with description and return CSV
def save_with_description(image, description):
    """Analyse and store *image* with *description*, then export all poses.

    Returns a 4-tuple ``(overlay_image, details_text, status_text, csv_path)``
    for the Gradio outputs. ``csv_path`` comes from ``create_csv_download`` and
    therefore covers the whole dataset, not only this pose.
    """
    if image is not None:
        overlay_img, details, status, _ignored = process_pose(image, description)
        return overlay_img, details, status, create_csv_download()
    return None, "❌ Please upload an image first.", "❌ No image to process", None
# ✅ NEW Simplified CSV Creator (Only 3 fields)
def create_csv_download():
    """Export every stored pose to a timestamped three-column CSV file.

    Columns are ``image_name`` (basename of the saved overlay image),
    ``pose_data`` (the pose's graph data serialised as a JSON string) and
    ``pose_description``.

    Returns:
        The CSV filename written to the current working directory, or ``None``
        when ``pose_dataset`` is empty.
    """
    if not pose_dataset:
        return None

    rows = [
        {
            "image_name": os.path.basename(info.get("image_path", "")),
            "pose_data": json.dumps(info.get("pose_data", {})),
            "pose_description": info.get("pose_description", ""),
        }
        for info in pose_dataset.values()
    ]
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    out_name = f"simplified_pose_dataset_{stamp}.csv"
    pd.DataFrame(rows).to_csv(out_name, index=False)
    return out_name
# Export only current pose
def export_current_pose_csv(image, description=""):
    """Analyse *image* and export the newest dataset entry as a one-row CSV.

    Args:
        image: uploaded image array, or ``None`` when nothing is uploaded.
        description: optional text stored with the pose.

    Returns:
        ``(csv_filename, status_message)``. ``csv_filename`` is ``None`` when
        nothing was exported (no image, no pose detected, or empty dataset).
    """
    if image is None:
        return None, "❌ Please upload an image first."

    # Re-runs detection, which records a dataset entry for this image; that
    # entry is the "current" pose exported below.
    overlay_img, data_display, _status, _ = process_pose(image, description)
    if overlay_img is None:
        # Detection failed: do not export an older pose under a "current
        # pose" success message; surface process_pose's error instead.
        return None, data_display
    if not pose_dataset:
        return None, "❌ No pose data to export"

    # Pose ids embed a zero-padded timestamp, so the lexicographic max is the
    # most recently created entry.
    latest_pose_id = max(pose_dataset)
    pose_info = pose_dataset[latest_pose_id]
    row = {
        "image_name": os.path.basename(pose_info.get("image_path", "")),
        "pose_data": json.dumps(pose_info.get("pose_data", {})),
        "pose_description": pose_info.get("pose_description", ""),
    }
    csv_filename = f"current_pose_{latest_pose_id}.csv"
    pd.DataFrame([row]).to_csv(csv_filename, index=False)
    return csv_filename, f"✅ Current pose exported as {csv_filename}"
# Gradio Interface
with gr.Blocks(title="🧘 Simplified Pose Analysis Tool") as demo:
    gr.Markdown("# 🧘 Pose Analysis with MediaPipe + CSV Export")
    gr.Markdown("Upload a pose image and extract keypoints, description, and download results.")
    with gr.Row():
        # Left column: inputs and download controls.
        with gr.Column(scale=1):
            gr.Markdown("## 📤 Upload")
            input_image = gr.Image(type="numpy", label="Upload Pose Image")
            pose_description = gr.Textbox(
                label="Pose Description",
                placeholder="e.g. 'Triangle Pose with right arm up'",
                lines=3,
            )
            with gr.Row():
                analyze_btn = gr.Button("🔍 Analyze Pose", variant="primary")
                save_btn = gr.Button("💾 Save with Description", variant="secondary")
            gr.Markdown("### 📥 Download Options")
            with gr.Row():
                download_current_btn = gr.Button("📄 Download Current Pose CSV", variant="secondary")
                download_all_btn = gr.Button("📁 Download All Poses CSV", variant="secondary")
            # Hidden until an event produces a file; revealed by the .then() steps below.
            current_csv_download = gr.File(label="Current Pose CSV", visible=False)
            all_csv_download = gr.File(label="All Poses CSV", visible=False)
            download_status = gr.Textbox(label="Download Status", visible=False)
        # Right column: overlay image, status line and textual pose summary.
        with gr.Column(scale=1):
            gr.Markdown("## 📊 Results")
            output_image = gr.Image(label="Pose with Overlay")
            status_text = gr.Textbox(label="Status", lines=1)
            gr.Markdown("## 🧠 Pose Data")
            pose_data_display = gr.Textbox(label="Pose Details", lines=15, show_copy_button=True)
    # Event handlers
    # NOTE(review): process_pose stores a dataset entry on every call, so the
    # automatic analysis on upload plus a later Analyze/Save click each add one.
    analyze_btn.click(
        fn=lambda img: process_pose(img, ""),
        inputs=[input_image],
        outputs=[output_image, pose_data_display, status_text, current_csv_download],
    )
    # save_with_description returns the *all poses* CSV, so send it to the
    # "All Poses CSV" component and reveal it so the user can download it.
    save_btn.click(
        fn=save_with_description,
        inputs=[input_image, pose_description],
        outputs=[output_image, pose_data_display, status_text, all_csv_download],
    ).then(
        fn=lambda: gr.update(visible=True),
        outputs=[all_csv_download],
    )
    download_current_btn.click(
        fn=export_current_pose_csv,
        inputs=[input_image, pose_description],
        outputs=[current_csv_download, download_status],
    ).then(
        # Reveal both the file and its status message in a single step.
        fn=lambda: (gr.update(visible=True), gr.update(visible=True)),
        outputs=[current_csv_download, download_status],
    )
    download_all_btn.click(
        fn=create_csv_download,
        outputs=[all_csv_download],
    ).then(
        fn=lambda: gr.update(visible=True),
        outputs=[all_csv_download],
    )
    # Analyse automatically whenever the uploaded image changes.
    input_image.change(
        fn=lambda img: process_pose(img, ""),
        inputs=[input_image],
        outputs=[output_image, pose_data_display, status_text, current_csv_download],
    )
    gr.Markdown("""
## 📝 How It Works
1. Upload a pose image
2. Automatically analyze and overlay pose keypoints
3. Add an optional description
4. Save and export to CSV
### CSV Format (Simplified):
- `image_name`: Name of the saved pose image
- `pose_data`: JSON string containing nodes and edges
- `pose_description`: Your text description
""")
# Launch app only when run as a script (e.g. `python app.py`), so importing
# this module (tests, tooling) does not start a blocking web server.
if __name__ == "__main__":
    demo.launch(share=True)