Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| import pathlib | |
| import shutil | |
| import tempfile | |
| from pathlib import Path | |
| import gradio as gr | |
| import pandas as pd | |
| from gt4sd.properties.crystals import CRYSTALS_PROPERTY_PREDICTOR_FACTORY | |
# Module logger. Attaching a NullHandler means records from this module are
# not routed to logging's last-resort stderr handler when the application has
# configured no handlers of its own.
logger = logging.getLogger(__name__)
_null_handler = logging.NullHandler()
logger.addHandler(_null_handler)

# Property name -> input-file suffix that property expects.
suffix_dict = dict(metal_nonmetal_classifier=".csv")
def create_temp_file(path: str) -> str:
    """Copy *path* into the shared ``gt4sd_crystal`` temp folder and return the copy's path.

    The folder ``<system tmp>/gt4sd_crystal`` is created if needed and emptied
    (files and sub-directories alike) before the copy, so afterwards it holds
    only the freshly copied file.

    Args:
        path: Filesystem path of the file to copy.

    Returns:
        Path of the copy inside the temp folder, as a string.

    Note:
        The folder name is fixed, so every call in the process shares it;
        concurrent calls can delete each other's files.
    """
    log = logging.getLogger(__name__)
    temp_folder = os.path.join(tempfile.gettempdir(), "gt4sd_crystal")
    os.makedirs(temp_folder, exist_ok=True)

    # Empty the folder. Entries may be directories (e.g. left over from an
    # unpacked ``.zip``), which ``os.remove`` cannot delete.
    for entry in os.listdir(temp_folder):
        entry_path = os.path.join(temp_folder, entry)
        log.debug("Removing %s", entry)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)

    # ``os.path.basename`` honours the platform separator, unlike splitting on "/".
    temp_path = os.path.join(temp_folder, os.path.basename(path))
    shutil.copy2(path, temp_path)
    return temp_path
def main(property: str, data_file) -> pd.DataFrame:
    """Run the selected crystal-property predictor on an uploaded file.

    Args:
        property: Property name as shown in the UI (e.g. "Formation Energy").
            It is mapped to the factory key by replacing spaces with
            underscores and lower-casing.
        data_file: The uploaded file, given either as a path (``str`` /
            ``os.PathLike``) or as an object exposing its path via ``.name``
            (e.g. a tempfile-like upload object).

    Returns:
        A ``pd.DataFrame`` built from the predictor's output.

    Raises:
        TypeError: If no file is given, or its suffix is not ``.cif``,
            ``.csv`` or ``.zip``.
        ValueError: If a ``.zip`` contains no ``.cif`` files at its top level.
    """
    # Validate before touching any attribute of ``data_file``.
    if data_file is None:
        raise TypeError("You have to pass an input file for the crystal model.")

    # Path-like inputs are used directly (note ``Path.name`` is only the
    # basename, so it must not be read via ``.name``); upload objects carry
    # their on-disk path in ``.name``.
    if isinstance(data_file, (str, os.PathLike)):
        source_path = os.fspath(data_file)
    else:
        source_path = data_file.name
    logger.debug("Received input file %s", source_path)

    # Copy into the temp folder; create_temp_file empties that folder first,
    # so it holds only this input (the folder is shared, not per-request).
    file_path = Path(create_temp_file(source_path))
    folder = file_path.parent
    logger.debug("Working copy %s in folder %s", file_path, folder)

    if file_path.suffix == ".cif":
        # Predictors for .cif inputs are given the folder, not the file.
        input_path = folder
    elif file_path.suffix == ".csv":
        input_path = file_path
    elif file_path.suffix == ".zip":
        shutil.unpack_archive(file_path, folder)
        if not any(name.endswith(".cif") for name in os.listdir(folder)):
            raise ValueError("No `.cif` files were found inside the `.zip`.")
        input_path = folder
    else:
        raise TypeError(
            "You have to pass a `.csv` (for `metal_nonmetal_classifier`),"
            " a `.cif` (for all other properties) or a `.zip` with multiple"
            f" `.cif` files. Not `{file_path.suffix or file_path.name}`."
        )

    prop_name = property.replace(" ", "_").lower()
    algo, config = CRYSTALS_PROPERTY_PREDICTOR_FACTORY[prop_name]
    # Pass hyperparameters if applicable.
    kwargs = {"algorithm_version": "v0"}
    model = algo(config(**kwargs))
    result = model(input=input_path)
    return pd.DataFrame(result)
if __name__ == "__main__":
    # Dropdown choices: factory keys in reverse order, shown title-cased with
    # spaces (e.g. "formation_energy" -> "Formation Energy"); main() undoes
    # this mapping.
    properties = [
        name.replace("_", " ").title()
        for name in list(CRYSTALS_PROPERTY_PREDICTOR_FACTORY.keys())[::-1]
    ]

    # Load metadata (examples, article and description) shipped next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    # Example labels must match the title-cased dropdown choices above.
    examples = [
        ["Formation Energy", metadata_root.joinpath("7206075.cif")],
        ["Bulk Moduli", metadata_root.joinpath("crystals.zip")],
        ["Metal Nonmetal Classifier", metadata_root.joinpath("metal.csv")],
        ["Bulk Moduli", metadata_root.joinpath("9000046.cif")],
    ]
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # default encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=main,
        title="Crystal properties",
        inputs=[
            gr.Dropdown(properties, label="Property", value="Instability"),
            gr.File(
                file_types=[".cif", ".csv", ".zip"],
                label="Input file for crystal model",
            ),
        ],
        outputs=gr.DataFrame(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)