import gc
import laspy
import torch
import base64
import tempfile
import numpy as np
import open3d as o3d
import streamlit as st
import plotly.graph_objs as go
import pointnet2_cls_msg as pn2
from utils import calculate_dbh, calc_canopy_volume, CLASSES
from SingleTreePointCloudLoader import SingleTreePointCloudLoader
gc.enable()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
with st.spinner("Loading PointNet++ model..."):
checkpoint = torch.load('checkpoints/best_model.pth', map_location=torch.device(device))
classifier = pn2.get_model(num_class=4, normal_channel=False)
classifier.load_state_dict(checkpoint['model_state_dict'])
classifier.eval()
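# The checkpoint is expected to be a dict holding the network weights under
# 'model_state_dict' (the usual layout of PointNet++ training scripts); only the
# weights are needed for inference. Note that the classifier stays on the CPU
# here unless it is explicitly moved with classifier.to(device).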
side_bg = "static/sidebar.png"
side_bg_ext = "png"
st.markdown(
f"""
<style>
[data-testid="stSidebar"] {{
background: url(data:image/{side_bg_ext};base64,{base64.b64encode(open(side_bg, "rb").read()).decode()});
color: #ffff00;
}}
[data-testid="stSidebarUserContent"] {{
padding-bottom: 3rem;
}}
.stMainBlockContainer {{
padding-top: 3rem;
}}
.main > div {{
padding-top: 3rem;
}}
</style>
""",
unsafe_allow_html=True
)
st.sidebar.markdown(
body=
"<div style='text-align: justify; color: #ffff00'>"
"<h1 style='color: #ffff00; font-size: 4rem;'>About</h1>"
"The species <strong>Pinus sylvestris (Scots Pine), Fagus sylvatica "
"(European Beech), Picea abies (Norway Spruce), and Betula pendula "
"(Silver Birch)</strong> are native to Europe and parts "
"of Asia but are also found in India (Parts of Himachal Pradesh, "
"Uttarakhand, Jammu and Kashmir, Sikkim and Arunachal Pradesh). "
"These temperate species, typically thriving in boreal and montane ecosystems, "
"are occasionally introduced in cooler Indian regions like the Himalayan "
"foothills for afforestation or experimental forestry, where climatic "
"conditions are favourable. However, their growth and ecological interactions "
"in India may vary significantly due to the region's unique biodiversity "
"and environmental factors.<br><br>"
"This AI-powered application employs the PointNet++ deep learning "
"architecture, optimized for processing 3D point cloud data from "
"individual <code>.las</code> <code>.laz</code> <code>.pcd</code> files "
"(fused aerial and terrestrial LiDAR) to classify tree species up to four classes "
"(<strong>Pinus sylvestris, Fagus sylvatica, Picea abies, and Betula pendula</strong>) "
"with associated confidence scores. Additionally, it calculates critical "
"metrics such as Diameter at Breast Height (DBH), actual height and "
"customizable canopy volume, enabling precise refinement of predictions "
"and analyses. By integrating species-specific and volumetric insights, "
"the tool enhances ecological research workflows, facilitating data-driven "
"decision-making.<br><br>"
"<div style='text-align: right; font-size: 10px;'>&copy;Copyright: WII, "
"Technology Laboratory<br>Authors: Shashank Sawan &amp; Paras Shah</div></div>"
,
unsafe_allow_html=True,
)
st.image("static/header.png", use_container_width=True)
uploaded_file = st.file_uploader(
label="Upload Point Cloud Data",
type=['laz', 'las', 'pcd'],
help="Please upload trees with ground points removed"
)
col1, col2 = st.columns(2)
with col1:
    st.image("static/canopy.png", use_container_width=True)
with col2:
    CANOPY_VOLUME = st.slider(
        label="Canopy Volume in % (Z)",
        min_value=10,
        max_value=90,
        value=70,
        step=1,
        help=
            "Adjust the Z-threshold value to calculate the canopy volume "
            "within the specified limit; it uses the Quickhull and DBSCAN algorithms."
    )
st.markdown(
body=
"<div style='text-align: justify; font-size: 13px;'>"
"The <b>Quickhull algorithm</b> computes the convex hull of a set of points "
"by identifying extreme points to form an initial boundary and recursively "
"refining it by adding the farthest points until all points lie within the "
"convex boundary. It uses a divide-and-conquer approach, similar to QuickSort. "
"<br>"
"<b>DBSCAN (Density-Based Spatial Clustering of Applications with Noise)</b> is "
"a density-based clustering algorithm that groups densely packed points within "
"a specified distance 'eps' and minimum points 'minpoints', while treating "
"sparse points as noise. It effectively identifies arbitrarily shaped clusters "
"and handles outliers, making it suitable for spatial data and anomaly detection."
"</div><br>",
unsafe_allow_html=True
)
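# Illustrative sketch only (not called anywhere in this app): one way a canopy
# volume in the spirit of calc_canopy_volume could be computed with DBSCAN and
# Quickhull. The z-threshold convention and the eps / min_samples defaults are
# assumptions for illustration, not the actual implementation in utils.py.
def _canopy_volume_sketch(points, canopy_percent, eps=0.5, min_samples=20):
    from scipy.spatial import ConvexHull   # Qhull, i.e. the Quickhull algorithm
    from sklearn.cluster import DBSCAN

    z_min, z_max = points[:, 2].min(), points[:, 2].max()
    # Keep the top canopy_percent of the tree's vertical extent as canopy points
    z_cut = z_max - (z_max - z_min) * canopy_percent / 100.0
    canopy = points[points[:, 2] >= z_cut]

    # Remove sparse outliers: DBSCAN labels noise points as -1
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(canopy)
    canopy = canopy[labels != -1]

    # Volume of the convex hull (computed by Quickhull) of the cleaned canopy
    return ConvexHull(canopy).volume, canopy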
col1, col2 = st.columns(2)
with col1:
    st.image("static/dbh.png", use_container_width=True)
with col2:
    DBH_HEIGHT = st.slider(
        label="DBH (Diameter at Breast Height) measurement height, in metres (H)",
        min_value=1.3,
        max_value=1.4,
        value=1.4,
        step=0.01,
        help=
            "Adjust the height above ground at which the DBH is calculated; "
            "it utilizes the least-squares circle fitting method with the "
            "Levenberg-Marquardt optimization technique."
    )
st.markdown(
body=
"<div style='text-align: justify; font-size:13px;'>"
"The <b>Least Squares Circle Fitting method</b> is used to find the "
"best-fitting circle to a set of 2D points by minimizing the sum of "
"squared distances between each point and the circle's circumference. "
"<b>Levenberg-Marquardt Optimization</b> is used to fit models "
"(like circles) to point cloud data by minimizing the error between "
"the model and the actual points.</div><br>",
unsafe_allow_html=True
)
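# Illustrative sketch only (not called anywhere in this app): least-squares
# circle fitting with Levenberg-Marquardt via SciPy, in the spirit of
# calculate_dbh. The slab half-thickness around the breast height is an
# assumption for illustration; the real implementation in utils.py may differ.
def _dbh_sketch(points, breast_height, slab=0.05):
    from scipy.optimize import least_squares

    z_min = points[:, 2].min()
    # Thin horizontal slab of trunk points around the chosen breast height
    mask = np.abs(points[:, 2] - (z_min + breast_height)) < slab
    xy = points[mask, :2]

    def residuals(params):
        # Signed distance of each point from the candidate circle
        cx, cy, r = params
        return np.hypot(xy[:, 0] - cx, xy[:, 1] - cy) - r

    # Initial guess: centroid of the slab, mean spread as radius
    x0 = [xy[:, 0].mean(), xy[:, 1].mean(), xy.std(axis=0).mean()]
    fit = least_squares(residuals, x0, method='lm')   # Levenberg-Marquardt
    return 2.0 * fit.x[2]                             # DBH = twice the fitted radius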
proceed = None
if uploaded_file:
    try:
        with st.spinner("Reading point cloud file..."):
            file_type = uploaded_file.name.split('.')[-1].lower()
            # Persist the upload to a temporary file so laspy/Open3D can read it from disk
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
                tmp.write(uploaded_file.read())
                temp_file_path = tmp.name
            # Both branches yield an (N, 3) array of XYZ coordinates
            if file_type == 'pcd':
                pcd = o3d.io.read_point_cloud(temp_file_path)
                points = np.asarray(pcd.points)
            else:
                point_cloud = laspy.read(temp_file_path)
                points = np.vstack((point_cloud.x, point_cloud.y, point_cloud.z)).transpose()
        proceed = st.button("Run model")
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
if proceed:
    try:
        with st.spinner("Calculating tree inventory..."):
            dbh, trunk_points = calculate_dbh(points, DBH_HEIGHT)
            z_min = np.min(points[:, 2])
            z_max = np.max(points[:, 2])
            height = z_max - z_min
            canopy_volume, canopy_points = calc_canopy_volume(points, CANOPY_VOLUME, height, z_min)
with st.spinner("Visualizing point cloud..."):
fig = go.Figure()
fig.add_trace(go.Scatter3d(
x=points[:, 0],
y=points[:, 1],
z=points[:, 2],
mode='markers',
marker=dict(
size=0.5,
color=points[:, 2],
colorscale='Viridis',
opacity=1.0,
),
name='Tree'
))
fig.add_trace(go.Scatter3d(
x=canopy_points[:, 0],
y=canopy_points[:, 1],
z=canopy_points[:, 2],
mode='markers',
marker=dict(
size=2,
color='blue',
opacity=0.8,
),
name='Canopy points'
))
fig.add_trace(go.Scatter3d(
x=trunk_points[:, 0],
y=trunk_points[:, 1],
z=trunk_points[:, 2],
mode='markers',
marker=dict(
size=2,
color='red',
opacity=0.9,
),
name='DBH'
))
fig.update_layout(
margin=dict(l=0, r=0, b=0, t=0),
scene=dict(
xaxis_title="X",
yaxis_title="Y",
zaxis_title="Z",
aspectmode='data'
),
showlegend=False
)
        col1, col2, col3 = st.columns([1, 3, 1])
        with col2:
            st.markdown("""
                <style>
                .centered-plot {
                    text-align: center;
                }
                </style>
                """, unsafe_allow_html=True)
            st.plotly_chart(fig, use_container_width=True)
        hide_st_style = """
            <style>
            #MainMenu {visibility: hidden;}
            footer {visibility: hidden;}
            header {visibility: hidden;}
            </style>
        """
        st.markdown(hide_st_style, unsafe_allow_html=True)
with st.spinner("Running inference..."):
testFile = SingleTreePointCloudLoader(temp_file_path, file_type)
testFileLoader = torch.utils.data.DataLoader(testFile, batch_size=8, shuffle=False, num_workers=0)
point_set, _ = next(iter(testFileLoader))
point_set = point_set.transpose(2, 1)
with torch.no_grad():
logits, _ = classifier(point_set)
probabilities = torch.softmax(logits, dim=-1)
predicted_class = torch.argmax(probabilities, dim=-1).item()
confidence_score = (probabilities.numpy().tolist())[0][predicted_class] * 100
predicted_label = CLASSES[predicted_class]
st.write(f"**Predicted class: {predicted_label}**")
st.write(f"**Confidence score: {confidence_score:.2f}%**")
st.write(f"**Height of tree: {height:.2f}m**")
st.write(f"**Canopy volume: {canopy_volume:.2f}m\u00b3**")
st.write(f"**DBH: {dbh:.2f}m**")
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")