File size: 3,274 Bytes
4f5520c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import open3d as o3d
import numpy as np
import cadquery as cq

@st.cache_resource
def _load_model():
    """Load the tokenizer and CAD-Recode weights once per server process.

    Streamlit re-executes this script from the top on every widget
    interaction. Caching the loader with ``st.cache_resource`` prevents the
    model from being re-instantiated (and re-moved to the device) on every
    rerun.

    Returns:
        A ``(tokenizer, model, device)`` tuple.
    """
    # Tokenizer comes from the Qwen2-1.5B base model; weights come from
    # filapro/cad-recode. NOTE: trust_remote_code=True executes Python code
    # shipped in those Hub repositories.
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B", trust_remote_code=True)
    mdl = AutoModelForCausalLM.from_pretrained("filapro/cad-recode", trust_remote_code=True)

    # Use the GPU if one is available, otherwise fall back to the CPU.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mdl.to(dev)
    print(f"Model loaded on {dev}")
    return tok, mdl, dev


# Module-level names used by the inference helpers below.
tokenizer, model, device = _load_model()

@st.cache(allow_output_mutation=True)
def load_point_cloud(file):
    """Loads a point cloud from a uploaded file."""
    if not file:
        return None

    if file.type not in ("application/octet-stream", "text/plain"):
        st.error("Please upload a point cloud file (.pcd, .xyz, etc.)")
        return None

    try:
        point_cloud = o3d.io.read_point_cloud(file)
    except Exception as e:
        st.error(f"Error loading point cloud: {e}")
        return None

    return point_cloud

def prepare_input_data(point_cloud):
    """Serialize a point cloud's coordinates into one whitespace-separated string.

    Every coordinate of every point is flattened, in point order, into a
    single sequence. Each value is rendered with ``str()`` and the results
    are joined with single spaces.

    Args:
        point_cloud: An object exposing a ``.points`` array-like, or a falsy
            value.

    Returns:
        The serialized coordinates, or ``None`` when ``point_cloud`` is falsy.
    """
    if not point_cloud:
        return None

    coords = np.asarray(point_cloud.points).reshape(-1)
    tokens = [str(value) for value in coords]
    return " ".join(tokens)

def generate_cad_code(input_text, *, max_input_tokens=512, max_new_tokens=256):
    """Run the language model on the serialized point cloud and decode its output.

    Args:
        input_text: Prompt string (e.g. from ``prepare_input_data``).
        max_input_tokens: Prompt length cap passed to the tokenizer. Longer
            prompts are truncated, so only the leading coordinates reach
            the model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded generated text, excluding the prompt, or ``None`` when
        ``input_text`` is empty or ``None``.
    """
    if not input_text:
        return None

    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_input_tokens)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id)

    # For causal (decoder-only) models, generate() returns the prompt tokens
    # followed by the continuation. Drop the prompt so the caller receives
    # only the generated code, not the echoed coordinates.
    prompt_len = inputs["input_ids"].shape[-1]
    cad_code = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return cad_code

def generate_cad_model(cad_code):
    """Execute generated CadQuery code and return the shape it builds.

    SECURITY: ``cad_code`` is model output derived from a user-uploaded
    file, and ``exec`` runs it with full interpreter privileges. Only deploy
    this app in a sandboxed or otherwise trusted environment.

    Args:
        cad_code: Python source expected to build a ``cq.Workplane`` and
            bind it to a top-level variable.

    Returns:
        ``Workplane.val()`` of the ``cq.Workplane`` variable introduced last
        by the script, or ``None`` when the code is empty, raises, or
        produces no non-empty Workplane. On failure, an error has already
        been shown via ``st.error``.
    """
    if not cad_code:
        return None

    # Run in a private namespace so the script cannot overwrite this
    # module's globals, and so its variables can be inspected afterwards.
    namespace = {"cq": cq}
    try:
        exec(cad_code, namespace)  # noqa: S102 - untrusted code; see SECURITY note
    except Exception as e:
        st.error(f"Error generating CAD model: {e}")
        return None

    # Dicts keep insertion order, so the last entry is the Workplane variable
    # that was first bound latest. NOTE(review): assumes the generated script
    # binds its final result after any intermediate Workplanes - confirm
    # against real model outputs.
    workplanes = [value for value in namespace.values() if isinstance(value, cq.Workplane)]
    if not workplanes or not workplanes[-1].objects:
        st.error("Error generating CAD model: the generated code did not build a cq.Workplane")
        return None

    return workplanes[-1].val()

def main():
    """Streamlit app for point cloud to CAD code conversion.

    Pipeline: upload a point cloud, serialize its coordinates, generate
    CadQuery code with the model, display the code, then execute it to
    build a CAD model.
    """
    st.title("Point Cloud to CAD Code Converter")
    st.write("This app uses the filapro/cad-recode model to generate Python code for a 3D CAD model from your point cloud data.")

    # Offer only the extensions that open3d.io.read_point_cloud can parse.
    uploaded_file = st.file_uploader(
        "Upload Point Cloud File",
        type=["pcd", "ply", "xyz", "xyzn", "xyzrgb", "pts"],
    )
    point_cloud = load_point_cloud(uploaded_file)
    # Compare against None explicitly instead of relying on the truthiness
    # of an open3d geometry object.
    if point_cloud is None:
        return

    with st.spinner("Generating CAD code..."):
        input_text = prepare_input_data(point_cloud)
        cad_code = generate_cad_code(input_text)
    if not cad_code:
        return

    st.success("Generated Python CAD Code:")
    st.code(cad_code)

    cad_model = generate_cad_model(cad_code)
    if cad_model is not None:
        # TODO: render the model, e.g. by exporting a mesh and showing it
        # with a 3D viewer such as trimesh.
        st.success("Generated CAD Model (Visualization not yet implemented)")

# Run the Streamlit app when this file is executed as the main script.
if __name__ == "__main__":
    main()