engrharis commited on
Commit
4f5520c
·
verified ·
1 Parent(s): 71e0269

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import open3d as o3d
5
+ import numpy as np
6
+ import cadquery as cq
7
+
8
+ # Load the tokenizer from Qwen2-1.5B and model weights from filapro/cad-recode
9
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B", trust_remote_code=True)
10
+ model = AutoModelForCausalLM.from_pretrained("filapro/cad-recode", trust_remote_code=True)
11
+
12
+ # Set device (GPU if available, CPU otherwise)
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ model.to(device)
15
+ print(f"Model loaded on {device}")
16
+
17
+ @st.cache(allow_output_mutation=True)
18
+ def load_point_cloud(file):
19
+ """Loads a point cloud from a uploaded file."""
20
+ if not file:
21
+ return None
22
+
23
+ if file.type not in ("application/octet-stream", "text/plain"):
24
+ st.error("Please upload a point cloud file (.pcd, .xyz, etc.)")
25
+ return None
26
+
27
+ try:
28
+ point_cloud = o3d.io.read_point_cloud(file)
29
+ except Exception as e:
30
+ st.error(f"Error loading point cloud: {e}")
31
+ return None
32
+
33
+ return point_cloud
34
+
35
+ def prepare_input_data(point_cloud):
36
+ """Prepares point cloud data for model input."""
37
+ if not point_cloud:
38
+ return None
39
+
40
+ point_cloud_array = np.asarray(point_cloud.points).flatten()
41
+ input_text = " ".join(map(str, point_cloud_array))
42
+ return input_text
43
+
44
+ def generate_cad_code(input_text):
45
+ """Runs inference and decodes generated output."""
46
+ if not input_text:
47
+ return None
48
+
49
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
50
+ inputs = {key: val.to(device) for key, val in inputs.items()}
51
+
52
+ with torch.no_grad():
53
+ outputs = model.generate(**inputs, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
54
+
55
+ cad_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
56
+ return cad_code
57
+
58
+ def generate_cad_model(cad_code):
59
+ """Generates a CAD model from the provided code."""
60
+ if not cad_code:
61
+ return None
62
+
63
+ try:
64
+ # Execute CAD code using CadQuery library
65
+ exec(cad_code)
66
+ cad_model = cq.Workplane("XY").val()
67
+ except Exception as e:
68
+ st.error(f"Error generating CAD model: {e}")
69
+ return None
70
+
71
+ return cad_model
72
+
73
+ def main():
74
+ """Streamlit app for point cloud to CAD code conversion."""
75
+ st.title("Point Cloud to CAD Code Converter")
76
+ st.write("This app uses the filapro/cad-recode model to generate Python code for a 3D CAD model from your point cloud data.")
77
+
78
+ uploaded_file = st.file_uploader("Upload Point Cloud File")
79
+ point_cloud = load_point_cloud(uploaded_file)
80
+
81
+ if point_cloud:
82
+ input_text = prepare_input_data(point_cloud)
83
+ cad_code = generate_cad_code(input_text)
84
+
85
+ if cad_code:
86
+ st.success("Generated Python CAD Code:")
87
+ st.code(cad_code)
88
+
89
+ cad_model = generate_cad_model(cad_code)
90
+ if cad_model:
91
+ # Optionally, use a 3D visualization library like trimesh
92
+ # to display the generated CAD model (not included)
93
+ st.success("Generated CAD Model (Visualization not yet implemented)")
94
+ # st.write(cad_model) # Replace with visualization code
95
+
96
+ if __name__ == "__main__":
97
+ main()