Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import open3d as o3d
|
5 |
+
import numpy as np
|
6 |
+
import cadquery as cq
|
7 |
+
|
8 |
+
# Load the tokenizer from Qwen2-1.5B and model weights from filapro/cad-recode
|
9 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B", trust_remote_code=True)
|
10 |
+
model = AutoModelForCausalLM.from_pretrained("filapro/cad-recode", trust_remote_code=True)
|
11 |
+
|
12 |
+
# Set device (GPU if available, CPU otherwise)
|
13 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
+
model.to(device)
|
15 |
+
print(f"Model loaded on {device}")
|
16 |
+
|
17 |
+
@st.cache(allow_output_mutation=True)
|
18 |
+
def load_point_cloud(file):
|
19 |
+
"""Loads a point cloud from a uploaded file."""
|
20 |
+
if not file:
|
21 |
+
return None
|
22 |
+
|
23 |
+
if file.type not in ("application/octet-stream", "text/plain"):
|
24 |
+
st.error("Please upload a point cloud file (.pcd, .xyz, etc.)")
|
25 |
+
return None
|
26 |
+
|
27 |
+
try:
|
28 |
+
point_cloud = o3d.io.read_point_cloud(file)
|
29 |
+
except Exception as e:
|
30 |
+
st.error(f"Error loading point cloud: {e}")
|
31 |
+
return None
|
32 |
+
|
33 |
+
return point_cloud
|
34 |
+
|
35 |
+
def prepare_input_data(point_cloud):
|
36 |
+
"""Prepares point cloud data for model input."""
|
37 |
+
if not point_cloud:
|
38 |
+
return None
|
39 |
+
|
40 |
+
point_cloud_array = np.asarray(point_cloud.points).flatten()
|
41 |
+
input_text = " ".join(map(str, point_cloud_array))
|
42 |
+
return input_text
|
43 |
+
|
44 |
+
def generate_cad_code(input_text):
|
45 |
+
"""Runs inference and decodes generated output."""
|
46 |
+
if not input_text:
|
47 |
+
return None
|
48 |
+
|
49 |
+
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
|
50 |
+
inputs = {key: val.to(device) for key, val in inputs.items()}
|
51 |
+
|
52 |
+
with torch.no_grad():
|
53 |
+
outputs = model.generate(**inputs, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
|
54 |
+
|
55 |
+
cad_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
56 |
+
return cad_code
|
57 |
+
|
58 |
+
def generate_cad_model(cad_code):
|
59 |
+
"""Generates a CAD model from the provided code."""
|
60 |
+
if not cad_code:
|
61 |
+
return None
|
62 |
+
|
63 |
+
try:
|
64 |
+
# Execute CAD code using CadQuery library
|
65 |
+
exec(cad_code)
|
66 |
+
cad_model = cq.Workplane("XY").val()
|
67 |
+
except Exception as e:
|
68 |
+
st.error(f"Error generating CAD model: {e}")
|
69 |
+
return None
|
70 |
+
|
71 |
+
return cad_model
|
72 |
+
|
73 |
+
def main():
|
74 |
+
"""Streamlit app for point cloud to CAD code conversion."""
|
75 |
+
st.title("Point Cloud to CAD Code Converter")
|
76 |
+
st.write("This app uses the filapro/cad-recode model to generate Python code for a 3D CAD model from your point cloud data.")
|
77 |
+
|
78 |
+
uploaded_file = st.file_uploader("Upload Point Cloud File")
|
79 |
+
point_cloud = load_point_cloud(uploaded_file)
|
80 |
+
|
81 |
+
if point_cloud:
|
82 |
+
input_text = prepare_input_data(point_cloud)
|
83 |
+
cad_code = generate_cad_code(input_text)
|
84 |
+
|
85 |
+
if cad_code:
|
86 |
+
st.success("Generated Python CAD Code:")
|
87 |
+
st.code(cad_code)
|
88 |
+
|
89 |
+
cad_model = generate_cad_model(cad_code)
|
90 |
+
if cad_model:
|
91 |
+
# Optionally, use a 3D visualization library like trimesh
|
92 |
+
# to display the generated CAD model (not included)
|
93 |
+
st.success("Generated CAD Model (Visualization not yet implemented)")
|
94 |
+
# st.write(cad_model) # Replace with visualization code
|
95 |
+
|
96 |
+
if __name__ == "__main__":
|
97 |
+
main()
|