File size: 3,490 Bytes
c703bc8
19e399c
c703bc8
19e399c
 
 
c703bc8
 
 
19e399c
c703bc8
 
 
 
19e399c
c703bc8
19e399c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c703bc8
19e399c
 
 
 
c703bc8
19e399c
 
 
 
 
 
 
 
 
c703bc8
 
 
 
 
19e399c
 
c703bc8
 
 
 
19e399c
 
 
 
 
 
c703bc8
19e399c
c703bc8
19e399c
c703bc8
 
 
19e399c
 
c703bc8
19e399c
 
 
 
c703bc8
 
19e399c
c703bc8
 
19e399c
c703bc8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import logging
import os
import pathlib
import shutil
import tempfile
from pathlib import Path

import gradio as gr
import pandas as pd
from gt4sd.properties.crystals import CRYSTALS_PROPERTY_PREDICTOR_FACTORY

logger: logging.Logger = logging.getLogger(__name__)
# A NullHandler keeps this module quiet (no "last resort" stderr output) unless
# the hosting application configures logging handlers itself.
logger.addHandler(logging.NullHandler())

# Property key -> input-file suffix it expects.
suffix_dict = dict(metal_nonmetal_classifier=".csv")


def create_temp_file(path: str) -> str:
    """Copy ``path`` into a scratch folder and return the location of the copy.

    The scratch folder is ``<system temp dir>/gt4sd_crystal``. It is emptied on
    every call, removing both files and sub-directories (e.g. those left behind
    by an unpacked ``.zip``), so afterwards it holds only the freshly copied file.

    NOTE(review): the folder is shared by every call, so concurrent requests can
    delete each other's inputs; a per-request ``tempfile.mkdtemp`` would avoid it.

    Args:
        path: path of the file to copy.

    Returns:
        The path of the copy, as a string.
    """
    temp_folder = os.path.join(tempfile.gettempdir(), "gt4sd_crystal")
    os.makedirs(temp_folder, exist_ok=True)

    # Clean up leftovers from the previous call. ``os.remove`` alone fails on
    # directories, which an extracted archive may have created here.
    for entry in os.listdir(temp_folder):
        entry_path = os.path.join(temp_folder, entry)
        logger.debug("Removing %s", entry_path)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)

    # os.path.basename understands the platform's path separators; splitting on
    # "/" kept the full source path on Windows.
    temp_path = os.path.join(temp_folder, os.path.basename(path))
    shutil.copy2(path, temp_path)
    return temp_path


def _resolve_input_path(file_path: Path) -> Path:
    """Map an uploaded file (already copied to the scratch folder) to model input.

    ``.cif`` files and ``.zip`` archives resolve to the containing folder (a zip
    is first extracted in place); ``.csv`` files resolve to the file itself.

    Raises:
        ValueError: if a ``.zip`` archive yields no top-level ``.cif`` file.
        TypeError: if the suffix is not ``.cif``, ``.csv`` or ``.zip``.
    """
    folder = file_path.parent
    suffix = file_path.suffix

    if suffix == ".cif":
        return folder
    if suffix == ".csv":
        return file_path
    if suffix == ".zip":
        shutil.unpack_archive(file_path, folder)
        if not any(name.endswith(".cif") for name in os.listdir(folder)):
            raise ValueError("No `.cif` files were found inside the `.zip`.")
        return folder

    # Report the offending file, not the type of the upload wrapper object.
    raise TypeError(
        "You have to pass a `.csv` (for `metal_nonmetal_classifier`),"
        " a `.cif` (for all other properties) or a `.zip` with multiple"
        f" `.cif` files. Not {file_path.name!r}."
    )


def main(property: str, data_file: object) -> pd.DataFrame:
    """Predict a crystal property for an uploaded file and return it as a table.

    Args:
        property: property label from the dropdown (e.g. ``"Formation Energy"``);
            lower-cased with spaces turned into underscores to form the key into
            ``CRYSTALS_PROPERTY_PREDICTOR_FACTORY``.
        data_file: the upload. Either a plain path (``str`` / ``os.PathLike``)
            or an object exposing the file path as ``.name``.

    Returns:
        A ``pd.DataFrame`` built from the model's output.

    Raises:
        TypeError: if no file was uploaded or its suffix is not supported.
        ValueError: if a ``.zip`` upload contains no ``.cif`` file.
        KeyError: presumably, if the property is not a key of the factory
            (depends on the factory's mapping type; not verified here).
    """
    # Check first: the old debug print dereferenced ``data_file.orig_name``
    # before this guard, so a missing upload raised AttributeError instead.
    if data_file is None:
        raise TypeError("You have to pass an input file for the crystal model.")

    # Support both plain paths and file wrappers carrying the path in ``.name``.
    # Checking for paths first matters: ``pathlib.Path.name`` is only the basename.
    if isinstance(data_file, (str, os.PathLike)):
        source_path = os.fspath(data_file)
    else:
        source_path = data_file.name

    # Copy into the scratch folder managed by ``create_temp_file`` (shared and
    # wiped on each call, i.e. NOT unique per request).
    file_path = Path(create_temp_file(source_path))
    logger.debug("Copied upload to %s", file_path)
    input_path = _resolve_input_path(file_path)

    prop_name = property.replace(" ", "_").lower()
    algo, config = CRYSTALS_PROPERTY_PREDICTOR_FACTORY[prop_name]
    # Only the algorithm version is configured; no other hyperparameters are set.
    model = algo(config(algorithm_version="v0"))

    result = model(input=input_path)
    return pd.DataFrame(result)


if __name__ == "__main__":

    # Dropdown choices: factory keys in reverse order, shown in Title Case
    # ("metal_nonmetal_classifier" -> "Metal Nonmetal Classifier"); ``main``
    # lower-cases them and restores the underscores to look the key up again.
    properties = [
        key.replace("_", " ").title()
        for key in list(CRYSTALS_PROPERTY_PREDICTOR_FACTORY.keys())[::-1]
    ]

    # Example inputs and the markdown texts live next to this script.
    metadata_root = Path(__file__).parent / "model_cards"

    # Example labels must match a dropdown choice exactly; the choices come
    # from ``.title()``, so the bulk-moduli label is "Bulk Moduli".
    examples = [
        ["Formation Energy", metadata_root / "7206075.cif"],
        ["Bulk Moduli", metadata_root / "crystals.zip"],
        ["Metal Nonmetal Classifier", metadata_root / "metal.csv"],
        ["Bulk Moduli", metadata_root / "9000046.cif"],
    ]

    # Explicit encoding so the markdown reads the same on every platform locale.
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=main,
        title="Crystal properties",
        inputs=[
            gr.Dropdown(properties, label="Property", value="Instability"),
            gr.File(
                file_types=[".cif", ".csv", ".zip"],
                label="Input file for crystal model",
            ),
        ],
        outputs=gr.DataFrame(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)