|
|
|
""" |
|
Setup script for the Semantic Segmentation Gradio App.

This script installs the app's Python dependencies and sets up the project
environment: working directories, sample images, a launch script and a README.
|
""" |
|
|
|
import subprocess |
|
import sys |
|
import os |
|
from pathlib import Path |
|
|
|
def run_command(command, description):
    """Run *command*, report progress, and return whether it succeeded.

    Args:
        command: Either a sequence of program arguments, executed directly
            without a shell (preferred), or a string, executed through the
            system shell for backward compatibility. In the string form the
            shell interprets metacharacters: ``pip install pkg>=1.0`` is read
            as ``pip install pkg`` with stdout redirected to a file named
            ``=1.0``.
        description: Human-readable label used in the progress messages.

    Returns:
        True if the command exited with status 0. False if it exited non-zero
        or could not be started; the command and error output are printed.
    """
    print(f"\n🔄 {description}...")
    use_shell = isinstance(command, str)
    shown = command if use_shell else " ".join(str(part) for part in command)
    try:
        subprocess.run(
            command,
            shell=use_shell,
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"❌ Error during {description}:")
        print(f"Command: {shown}")
        # Some tools report failures on stdout only; fall back to it.
        print(f"Error: {e.stderr or e.stdout}")
        return False
    except OSError as e:
        # Raised (without a shell) when the executable cannot be found/started.
        print(f"❌ Error during {description}:")
        print(f"Command: {shown}")
        print(f"Error: {e}")
        return False

    print(f"✅ {description} completed successfully")
    return True
|
|
|
def check_python_version(min_version=(3, 8)):
    """Exit the process unless the running interpreter is new enough.

    Args:
        min_version: Minimum acceptable ``(major, minor[, ...])`` version
            tuple; defaults to Python 3.8.

    Prints the detected version on success. If the interpreter is older than
    *min_version*, prints an error and calls ``sys.exit(1)``.
    """
    version = sys.version_info
    # sys.version_info compares element-wise with plain tuples.
    if version < tuple(min_version):
        required = ".".join(str(part) for part in min_version)
        print(f"❌ Python {required} or higher is required")
        sys.exit(1)
    print(f"✅ Python {version.major}.{version.minor}.{version.micro} detected")
|
|
|
def install_dependencies():
    """Install the app's Python requirements into the running interpreter.

    Each requirement is installed with ``<sys.executable> -m pip install``,
    so packages land in the environment that runs this script; a bare ``pip``
    on PATH may belong to a different interpreter. The command is passed as
    an argument list without a shell. That way version specifiers such as
    ``>=4.0.0`` reach pip intact instead of being parsed as shell output
    redirection.

    Returns:
        True if every requirement installed; False at the first failure.
    """
    requirements = [
        "gradio>=4.0.0",
        "torch>=1.9.0",
        "torchvision>=0.10.0",
        "transformers>=4.21.0",
        "pillow>=8.0.0",
        "numpy>=1.21.0",
        "matplotlib>=3.5.0",
        "requests>=2.25.0",
    ]

    print("\n📦 Installing dependencies...")
    for requirement in requirements:
        package = requirement.split(">=")[0]
        command = [sys.executable, "-m", "pip", "install", requirement]
        if not run_command(command, f"Installing {package}"):
            return False
    return True
|
|
|
def create_directory_structure():
    """Create the project's working directories relative to the current directory.

    Missing parents are created and directories that already exist are left
    as they are (``parents=True, exist_ok=True``), so calling this repeatedly
    is harmless. A message is printed for every directory processed.
    """
    directories = [
        "src",
        "src/models",
        "sample_images",
        "outputs",
    ]

    for directory in directories:
        Path(directory).mkdir(parents=True, exist_ok=True)
        print(f"📁 Created directory: {directory}")
|
|
|
def download_sample_images():
    """Download a few street-scene images into ``sample_images/`` for testing.

    Best effort: if one image fails (network error, HTTP error, undecodable
    payload), a warning is printed and the loop continues with the next one.
    ``requests`` and Pillow are imported inside the function because they are
    installed by the setup itself. An ImportError from them is not caught
    here and propagates to the caller.
    """
    import io

    import requests
    from PIL import Image

    sample_urls = {
        "street_scene_1.jpg": "https://images.unsplash.com/photo-1449824913935-59a10b8d2000?w=800",
        "street_scene_2.jpg": "https://images.unsplash.com/photo-1502920917128-1aa500764cbd?w=800",
        "urban_road.jpg": "https://images.unsplash.com/photo-1516738901171-8eb4fc13bd20?w=800",
    }

    sample_dir = Path("sample_images")
    sample_dir.mkdir(exist_ok=True)

    print("\n🖼️ Downloading sample images...")
    for filename, url in sample_urls.items():
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()

            # Decoding with Pillow also verifies the payload is really an image.
            image = Image.open(io.BytesIO(response.content))
            # The files are saved as .jpg, and JPEG cannot store alpha or
            # palette modes, so normalise to RGB first.
            if image.mode != "RGB":
                image = image.convert("RGB")
            image.save(sample_dir / filename)
            print(f"✅ Downloaded: {filename}")
        except Exception as e:
            print(f"⚠️ Failed to download {filename}: {e}")
|
|
|
def create_launch_script():
    """Write ``launch_app.py``, a small entry point that starts the Gradio app.

    The generated script:

    - puts its own directory on ``sys.path``;
    - imports ``create_gradio_interface`` from ``complete_gradio_app``;
    - launches the interface on port 7860;
    - exits with status 1 if the import or the launch fails.

    Any existing ``launch_app.py`` in the current directory is overwritten.
    On non-Windows systems the file is also marked executable.
    """
    launch_script = '''#!/usr/bin/env python3
"""
Launch script for the Semantic Segmentation App
"""

import sys
import os

# Add the current directory to the path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Import and run the app
try:
    from complete_gradio_app import create_gradio_interface
    import torch

    print("🚀 Starting Semantic Segmentation App...")
    print("💻 Device:", "CUDA" if torch.cuda.is_available() else "CPU")

    demo = create_gradio_interface()
    # share=True also requests a public Gradio share link, and
    # server_name="0.0.0.0" listens on all network interfaces; use
    # share=False / "127.0.0.1" to keep the app local-only.
    demo.launch(
        share=True,
        debug=True,
        server_name="0.0.0.0",
        server_port=7860
    )

except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Please make sure all dependencies are installed by running: python setup.py")
    sys.exit(1)

except Exception as e:
    print(f"❌ Error starting app: {e}")
    sys.exit(1)
'''

    # Explicit UTF-8: the script contains emoji, which the platform default
    # encoding (e.g. cp1252 on Windows) cannot encode. UTF-8 is also the
    # source encoding Python assumes when it runs the file.
    with open("launch_app.py", "w", encoding="utf-8") as f:
        f.write(launch_script)

    # Unix-like systems only: allow running it directly as ./launch_app.py.
    if os.name != "nt":
        os.chmod("launch_app.py", 0o755)

    print("✅ Created launch script: launch_app.py")
|
|
|
def create_readme():
    """Write ``README.md`` with usage instructions to the current directory.

    Any existing ``README.md`` is overwritten. The file is written as UTF-8
    because the headings contain emoji, which locale default encodings such
    as cp1252 cannot represent.
    """
    readme_content = '''# Semantic Segmentation Gradio App

A user-friendly web interface for semantic segmentation using OneFormer and Mask2Former models.

## 🚀 Quick Start

1. **Install dependencies:**

   ```bash
   python setup.py
   ```

2. **Launch the app:**

   ```bash
   python launch_app.py
   ```

   Or run directly:

   ```bash
   python complete_gradio_app.py
   ```

3. **Open your browser** and go to the provided local URL (usually http://localhost:7860)

## 📋 Requirements

- Python 3.8+
- CUDA-compatible GPU (optional, but recommended)
- At least 8GB RAM
- Internet connection (for model downloads)

## 🎯 Features

- **Two State-of-the-Art Models:**
  - OneFormer: Universal segmentation (semantic, instance, panoptic)
  - Mask2Former: High-accuracy semantic segmentation

- **User-Friendly Interface:**
  - Upload custom images
  - Select from sample images
  - Adjustable overlay transparency
  - Real-time processing

- **Professional Output:**
  - Colored segmentation overlays
  - Detailed class statistics
  - High-quality visualizations

## 🔧 Troubleshooting

### Common Issues:

1. **CUDA out of memory:**
   - Reduce image size
   - Use CPU instead of GPU

2. **Model download fails:**
   - Check internet connection
   - Try again (models are large ~1-2GB each)

3. **ImportError:**
   - Run `python setup.py` again
   - Check Python version (3.8+ required)

### Performance Tips:

- First model load takes time (downloading from HuggingFace)
- GPU acceleration significantly speeds up processing
- Images are automatically resized to prevent memory issues

## 📊 Supported Classes

The models are trained on Cityscapes dataset and can recognize:

- Road, sidewalk, building, wall, fence
- Traffic light, traffic sign, pole
- Vegetation, terrain, sky
- Person, rider, car, truck, bus, train, motorcycle, bicycle

## 🎨 Color Coding

Each class is visualized with a specific color following Cityscapes conventions:

- Road: Dark purple
- Sky: Steel blue
- Person: Crimson
- Car: Dark blue
- Vegetation: Olive green
- And more...

## 📄 License

This project uses pre-trained models from HuggingFace:

- OneFormer: [Model License](https://huggingface.co/shi-labs/oneformer_cityscapes_swin_large)
- Mask2Former: [Model License](https://huggingface.co/facebook/mask2former-swin-large-cityscapes-semantic)

## 🤝 Contributing

Feel free to submit issues and enhancement requests!
'''

    with open("README.md", "w", encoding="utf-8") as f:
        f.write(readme_content)

    print("✅ Created README.md")
|
|
|
def main():
    """Run every setup step in order and report the outcome.

    Steps:

    1. Verify the Python version (this exits the process if it is too old).
    2. Create the working directories.
    3. Install dependencies; on failure, stop and return False.
    4. Download sample images; any failure here only produces a warning.
    5. Write ``launch_app.py`` and ``README.md``.

    Returns:
        True when setup completes; False if dependency installation failed.
    """
    print("🎯 Semantic Segmentation App Setup")
    print("=" * 50)

    check_python_version()

    create_directory_structure()

    if not install_dependencies():
        print("\n❌ Failed to install some dependencies. Please check the errors above.")
        return False

    # Sample images are optional: any error, including a failed import of
    # requests/Pillow inside the downloader, is downgraded to a warning.
    try:
        download_sample_images()
    except Exception as e:
        print(f"⚠️ Warning: Could not download sample images: {e}")

    create_launch_script()

    create_readme()

    print("\n" + "=" * 50)
    print("✅ Setup completed successfully!")
    print("\n🚀 To launch the app, run:")
    print(" python launch_app.py")
    print("\n📖 For more information, see README.md")

    return True
|
|
|
if __name__ == "__main__":
    # Map main()'s boolean result onto the process exit status (0 = success).
    sys.exit(0 if main() else 1)