ManasSharma07 commited on
Commit
104d956
·
verified ·
1 Parent(s): 253b7f3

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +92 -112
src/streamlit_app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  import os
 
3
  import tempfile
4
  import torch
5
  # FOR CPU only mode
@@ -24,7 +25,11 @@ import yaml # Added for FairChem reference energies
24
  import subprocess
25
  import sys
26
  import pkg_resources
 
 
27
  mattersim_available = False
 
 
28
  # try:
29
  # subprocess.check_call([sys.executable, "-m", "pip", "install", "mattersim"])
30
  # except Exception as e:
@@ -1028,14 +1033,14 @@ if model_type == "ORB":
1028
  if "omat" in selected_model:
1029
  st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
1030
  # selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64'])
1031
- selected_default_dtype = 'float64'
1032
  if model_type == "MatterSim":
1033
  selected_model = st.sidebar.selectbox("Select MatterSim Model:", list(MATTERSIM_MODELS.keys()))
1034
  model_path = MATTERSIM_MODELS[selected_model]
1035
  if model_type == "SEVEN_NET":
1036
  selected_model = st.sidebar.selectbox("Select SEVENNET Model:", list(SEVEN_NET_MODELS.keys()))
1037
  if selected_model == '7net-mf-ompa':
1038
- selected_modal_7net = st.sidebar.selectbox("Select Modal (multi fidelity model):", ['omat', 'mpa'])
1039
  model_path = SEVEN_NET_MODELS[selected_model]
1040
  if atoms is not None:
1041
  if not check_atom_limit(atoms, selected_model):
@@ -1055,7 +1060,8 @@ task = st.sidebar.selectbox("Select Calculation Task:",
1055
  "Energy + Forces Calculation",
1056
  "Atomization/Cohesive Energy", # New Task Added
1057
  "Geometry Optimization",
1058
- "Cell + Geometry Optimization"])
 
1059
 
1060
  if "Optimization" in task:
1061
  st.sidebar.markdown("### Optimization Parameters")
@@ -1117,18 +1123,19 @@ if atoms is not None:
1117
  table_placeholder = st.empty() # Recreate placeholder for table
1118
 
1119
  try:
 
1120
  with st.spinner("Running calculation... Please wait."):
1121
  calc_atoms = atoms.copy()
1122
 
1123
  if model_type == "MACE":
1124
  # st.write("Setting up MACE calculator...")
1125
- calc = get_mace_model(model_path, device, selected_default_dtype)
1126
  elif model_type == "FairChem": # FairChem
1127
  # st.write("Setting up FairChem calculator...")
1128
  # Workaround for potential dtype issues when switching models
1129
- if device == "cpu": # Ensure torch default dtype matches if needed
1130
- torch.set_default_dtype(torch.float32)
1131
- _ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call
1132
  calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
1133
  elif model_type == "ORB":
1134
  # st.write("Setting up ORB calculator...")
@@ -1136,10 +1143,14 @@ if atoms is not None:
1136
  calc = ORBCalculator(orbff, device=device)
1137
  elif model_type == "MatterSim":
1138
  # st.write("Setting up MatterSim calculator...")
 
 
 
 
1139
  calc = MatterSimCalculator(load_path=model_path, device=device)
1140
  elif model_type == "SEVEN_NET":
1141
  # st.write("Setting up SEVENNET calculator...")
1142
- if model_path=='7net-mf-omp':
1143
  calc = SevenNetCalculator(model=model_path, modal=selected_modal_7net, device=device)
1144
  else:
1145
  calc = SevenNetCalculator(model=model_path, device=device)
@@ -1234,6 +1245,7 @@ if atoms is not None:
1234
  results["Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV"
1235
 
1236
  elif "Optimization" in task: # Handles both Geometry and Cell+Geometry Opt
 
1237
  opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms
1238
  # Create temporary trajectory file
1239
  traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name
@@ -1268,30 +1280,15 @@ if atoms is not None:
1268
 
1269
  if "Optimization" in task and "Final Energy" in results: # Check if opt was successful
1270
  st.markdown("### Optimized Structure")
1271
- # Need get_structure_viz function that takes atoms obj
1272
- def get_structure_viz_simple(atoms_obj_viz):
1273
- xyz_str_viz = f"{len(atoms_obj_viz)}\nStructure\n"
1274
- for atom_viz in atoms_obj_viz:
1275
- xyz_str_viz += f"{atom_viz.symbol} {atom_viz.position[0]:.6f} {atom_viz.position[1]:.6f} {atom_viz.position[2]:.6f}\n"
1276
- view_viz = py3Dmol.view(width=400, height=400)
1277
- view_viz.addModel(xyz_str_viz, "xyz")
1278
- view_viz.setStyle({'stick': {}})
1279
- if any(atoms_obj_viz.pbc): # Show cell for optimized periodic structures
1280
- cell_viz = atoms_obj_viz.get_cell()
1281
- if cell_viz is not None and cell_viz.any():
1282
- # Simplified cell drawing for brevity, use get_structure_viz2 if full cell needed
1283
- view_viz.addUnitCell({'box': {'lx':cell_viz.lengths()[0],'ly':cell_viz.lengths()[1],'lz':cell_viz.lengths()[2],
1284
- 'hx':cell_viz.cellpar()[3],'hy':cell_viz.cellpar()[4],'hz':cell_viz.cellpar()[5]}})
1285
 
1286
- view_viz.zoomTo()
1287
- view_viz.setBackgroundColor('white')
1288
- return view_viz
1289
-
1290
- opt_view = get_structure_viz2(calc_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
1291
  st.components.v1.html(opt_view._make_html(), width=400, height=400)
1292
 
1293
  with tempfile.NamedTemporaryFile(delete=False, suffix=".xyz", mode="w+") as tmp_file_opt:
1294
- write(tmp_file_opt.name, calc_atoms, format="extxyz")
 
 
 
1295
  tmp_filepath_opt = tmp_file_opt.name
1296
 
1297
  with open(tmp_filepath_opt, 'r') as file_opt:
@@ -1309,81 +1306,6 @@ if atoms is not None:
1309
  show_optimized_structure_download_button()
1310
  os.unlink(tmp_filepath_opt)
1311
 
1312
- # # Convert trajectory to XYZ for download
1313
- # @st.fragment
1314
- # def show_trajectory():
1315
- # if os.path.exists(traj_filename):
1316
- # try:
1317
- # from ase.io import read
1318
- # from ase.visualize import view
1319
- # import py3Dmol
1320
-
1321
- # trajectory = read(traj_filename, index=':')
1322
- # st.markdown("### Optimization Trajectory")
1323
- # st.write(f"Captured {len(trajectory)} optimization steps")
1324
-
1325
- # # Store the trajectory in session state
1326
- # if "traj_frames" not in st.session_state:
1327
- # st.session_state.traj_frames = trajectory
1328
- # st.session_state.traj_index = 0
1329
-
1330
- # # Navigation Buttons
1331
- # col1, col2, col3, col4 = st.columns(4)
1332
- # with col1:
1333
- # if st.button("⏮ First"):
1334
- # st.session_state.traj_index = 0
1335
- # with col2:
1336
- # if st.button("◀ Previous") and st.session_state.traj_index > 0:
1337
- # st.session_state.traj_index -= 1
1338
- # with col3:
1339
- # if st.button("Next ▶") and st.session_state.traj_index < len(st.session_state.traj_frames) - 1:
1340
- # st.session_state.traj_index += 1
1341
- # with col4:
1342
- # if st.button("Last ⏭"):
1343
- # st.session_state.traj_index = len(st.session_state.traj_frames) - 1
1344
-
1345
- # # Show current frame
1346
- # current_atoms = st.session_state.traj_frames[st.session_state.traj_index]
1347
- # st.write(f"Frame {st.session_state.traj_index + 1}/{len(st.session_state.traj_frames)}")
1348
-
1349
- # # Convert to xyz string for py3Dmol
1350
- # def atoms_to_xyz_string(atoms):
1351
- # xyz_str = f"{len(atoms)}\nStep {st.session_state.traj_index}, Energy = {atoms.get_potential_energy():.6f} eV\n"
1352
- # for atom in atoms:
1353
- # xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
1354
- # return xyz_str
1355
-
1356
- # xyz_str = atoms_to_xyz_string(current_atoms)
1357
- # view = py3Dmol.view(width=400, height=400)
1358
- # view.addModel(xyz_str, "xyz")
1359
- # view.setStyle({'stick': {}})
1360
- # view.zoomTo()
1361
- # view.setBackgroundColor("white")
1362
- # st.components.v1.html(view._make_html(), height=400, width=400)
1363
-
1364
- # # Download entire trajectory
1365
- # @st.fragment
1366
- # def show_trajectory_download_button():
1367
- # trajectory_xyz = ""
1368
- # for i, atoms in enumerate(st.session_state.traj_frames):
1369
- # trajectory_xyz += f"{len(atoms)}\nStep {i}, Energy = {atoms.get_potential_energy():.6f} eV\n"
1370
- # for atom in atoms:
1371
- # trajectory_xyz += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
1372
- # st.download_button(
1373
- # label="Download Optimization Trajectory (XYZ)",
1374
- # data=trajectory_xyz,
1375
- # file_name="optimization_trajectory.xyz",
1376
- # mime="chemical/x-xyz"
1377
- # )
1378
- # show_trajectory_download_button()
1379
-
1380
- # except Exception as e:
1381
- # st.warning(f"Could not process trajectory: {e}")
1382
-
1383
- # finally:
1384
- # os.unlink(traj_filename)
1385
-
1386
- # show_trajectory()
1387
  @st.fragment
1388
  def show_trajectory_and_controls():
1389
  from ase.io import read
@@ -1439,14 +1361,6 @@ if atoms is not None:
1439
  xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
1440
  return xyz_str
1441
 
1442
- # xyz_str = atoms_to_xyz_string(current_atoms, st.session_state.traj_index)
1443
-
1444
- # view = py3Dmol.view(width=400, height=400)
1445
- # view.addModel(xyz_str, "xyz")
1446
- # view.setStyle({'stick': {}})
1447
- # view.zoomTo()
1448
- # view.setBackgroundColor("white")
1449
- # st.components.v1.html(view._make_html(), height=400, width=400)
1450
  traj_view = get_structure_viz2(current_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
1451
  st.components.v1.html(traj_view._make_html(), width=400, height=400)
1452
 
@@ -1462,7 +1376,73 @@ if atoms is not None:
1462
  )
1463
 
1464
  show_trajectory_and_controls()
 
1465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1466
  except Exception as e:
1467
  st.error(f"🔴 Calculation error: {str(e)}")
1468
  st.error("Please check the structure, model compatibility, and parameters. For FairChem UMA, ensure the task type (omol, omat etc.) is appropriate for your system (e.g. omol for molecules, omat for materials).")
@@ -1499,4 +1479,4 @@ with st.expander('ℹ️ About This App & Foundational MLIPs'):
1499
  - For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
1500
  """)
1501
  st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem, SevenNet, ORB and ❤️")
1502
- st.markdown("Developed by [Manas Sharma](https://manas.bragitoff.com/) in the groups of [Prof. Ananth Govind Rajan Group](https://www.agrgroup.org/) and [Prof. Sudeep Punnathanam](https://chemeng.iisc.ac.in/sudeep/) at [IISc Bangalore](https://iisc.ac.in/)")
 
1
  import streamlit as st
2
  import os
3
+ import io
4
  import tempfile
5
  import torch
6
  # FOR CPU only mode
 
25
  import subprocess
26
  import sys
27
  import pkg_resources
28
+ from ase.vibrations import Vibrations
29
+ import matplotlib.pyplot as plt
30
  mattersim_available = False
31
+ if mattersim_available:
32
+ from mattersim.forcefield import MatterSimCalculator
33
  # try:
34
  # subprocess.check_call([sys.executable, "-m", "pip", "install", "mattersim"])
35
  # except Exception as e:
 
1033
  if "omat" in selected_model:
1034
  st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
1035
  # selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64'])
1036
+ selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest'])
1037
  if model_type == "MatterSim":
1038
  selected_model = st.sidebar.selectbox("Select MatterSim Model:", list(MATTERSIM_MODELS.keys()))
1039
  model_path = MATTERSIM_MODELS[selected_model]
1040
  if model_type == "SEVEN_NET":
1041
  selected_model = st.sidebar.selectbox("Select SEVENNET Model:", list(SEVEN_NET_MODELS.keys()))
1042
  if selected_model == '7net-mf-ompa':
1043
+ selected_modal_7net = st.sidebar.selectbox("Select Modal (multi fidelity model):", ['omat24', 'mpa'])
1044
  model_path = SEVEN_NET_MODELS[selected_model]
1045
  if atoms is not None:
1046
  if not check_atom_limit(atoms, selected_model):
 
1060
  "Energy + Forces Calculation",
1061
  "Atomization/Cohesive Energy", # New Task Added
1062
  "Geometry Optimization",
1063
+ "Cell + Geometry Optimization",
1064
+ "Vibrational Mode Analysis"])
1065
 
1066
  if "Optimization" in task:
1067
  st.sidebar.markdown("### Optimization Parameters")
 
1123
  table_placeholder = st.empty() # Recreate placeholder for table
1124
 
1125
  try:
1126
+ torch.set_default_dtype(torch.float32)
1127
  with st.spinner("Running calculation... Please wait."):
1128
  calc_atoms = atoms.copy()
1129
 
1130
  if model_type == "MACE":
1131
  # st.write("Setting up MACE calculator...")
1132
+ calc = get_mace_model(model_path, device, 'float32')
1133
  elif model_type == "FairChem": # FairChem
1134
  # st.write("Setting up FairChem calculator...")
1135
  # Workaround for potential dtype issues when switching models
1136
+ # if device == "cpu": # Ensure torch default dtype matches if needed
1137
+ # torch.set_default_dtype(torch.float32)
1138
+ # _ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call
1139
  calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
1140
  elif model_type == "ORB":
1141
  # st.write("Setting up ORB calculator...")
 
1143
  calc = ORBCalculator(orbff, device=device)
1144
  elif model_type == "MatterSim":
1145
  # st.write("Setting up MatterSim calculator...")
1146
+ # NOTE: Running mattersim on windows requires changing source code file
1147
+ # https://github.com/microsoft/mattersim/issues/112
1148
+ # mattersim/datasets/utils/convertor.py: 117
1149
+ # to pbc_ = np.array(structure.pbc, dtype=np.int64)
1150
  calc = MatterSimCalculator(load_path=model_path, device=device)
1151
  elif model_type == "SEVEN_NET":
1152
  # st.write("Setting up SEVENNET calculator...")
1153
+ if model_path=='7net-mf-ompa':
1154
  calc = SevenNetCalculator(model=model_path, modal=selected_modal_7net, device=device)
1155
  else:
1156
  calc = SevenNetCalculator(model=model_path, device=device)
 
1245
  results["Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV"
1246
 
1247
  elif "Optimization" in task: # Handles both Geometry and Cell+Geometry Opt
1248
+ is_periodic = any(calc_atoms.pbc)
1249
  opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms
1250
  # Create temporary trajectory file
1251
  traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name
 
1280
 
1281
  if "Optimization" in task and "Final Energy" in results: # Check if opt was successful
1282
  st.markdown("### Optimized Structure")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1283
 
1284
+ opt_view = get_structure_viz2(opt_atoms_obj, style=viz_style, show_unit_cell=True, width=400, height=400)
 
 
 
 
1285
  st.components.v1.html(opt_view._make_html(), width=400, height=400)
1286
 
1287
  with tempfile.NamedTemporaryFile(delete=False, suffix=".xyz", mode="w+") as tmp_file_opt:
1288
+ if is_periodic:
1289
+ write(tmp_file_opt.name, calc_atoms, format="extxyz")
1290
+ else:
1291
+ write(tmp_file_opt.name, calc_atoms, format="xyz")
1292
  tmp_filepath_opt = tmp_file_opt.name
1293
 
1294
  with open(tmp_filepath_opt, 'r') as file_opt:
 
1306
  show_optimized_structure_download_button()
1307
  os.unlink(tmp_filepath_opt)
1308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1309
  @st.fragment
1310
  def show_trajectory_and_controls():
1311
  from ase.io import read
 
1361
  xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
1362
  return xyz_str
1363
 
 
 
 
 
 
 
 
 
1364
  traj_view = get_structure_viz2(current_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
1365
  st.components.v1.html(traj_view._make_html(), width=400, height=400)
1366
 
 
1376
  )
1377
 
1378
  show_trajectory_and_controls()
1379
+ elif task == "Vibrational Mode Analysis":
1380
 
1381
+ st.write("Running vibrational mode analysis using finite differences...")
1382
+
1383
+ natoms = len(calc_atoms)
1384
+ is_linear = False # Set manually or auto-detect
1385
+ nmodes_expected = 3 * natoms - (5 if is_linear else 6)
1386
+
1387
+ # Create temporary directory to store .vib files
1388
+ with tempfile.TemporaryDirectory() as tmpdir:
1389
+ vib = Vibrations(calc_atoms, name=os.path.join(tmpdir, 'vib'))
1390
+
1391
+ with st.spinner("Calculating vibrational modes... This may take a few minutes."):
1392
+ vib.run()
1393
+ freqs = vib.get_frequencies()
1394
+
1395
+ # Convert frequencies to cm⁻¹
1396
+ freqs_cm = freqs #/ cm
1397
+
1398
+ # Classify frequencies
1399
+ mode_data = []
1400
+ for i, freq in enumerate(freqs_cm):
1401
+ if freq < 0:
1402
+ label = "Imaginary"
1403
+ elif abs(freq) < 500:
1404
+ label = "Low"
1405
+ else:
1406
+ label = "Physical"
1407
+ mode_data.append({
1408
+ "Mode": i + 1,
1409
+ "Frequency (cm⁻¹)": round(freq, 2),
1410
+ "Type": label
1411
+ })
1412
+
1413
+ df_modes = pd.DataFrame(mode_data)
1414
+
1415
+ # Display summary and mode count
1416
+ st.success("Vibrational analysis completed.")
1417
+ st.write(f"Number of atoms: {natoms}")
1418
+ st.write(f"Expected vibrational modes: {nmodes_expected}")
1419
+ st.write(f"Found {len(freqs_cm)} modes (including translational/rotational modes).")
1420
+
1421
+ # Show table of modes
1422
+ st.write("### Vibrational Mode Summary")
1423
+ st.dataframe(df_modes, use_container_width=True)
1424
+
1425
+ # Store in results dictionary
1426
+ results["Vibrational Modes"] = df_modes.to_dict(orient="records")
1427
+
1428
+ # Histogram plot of vibrational frequencies
1429
+ st.write("### Frequency Distribution Histogram")
1430
+ fig, ax = plt.subplots()
1431
+ ax.hist(freqs_cm, bins=30, color='skyblue', edgecolor='black')
1432
+ ax.set_xlabel("Frequency (cm⁻¹)")
1433
+ ax.set_ylabel("Number of Modes")
1434
+ ax.set_title("Distribution of Vibrational Frequencies")
1435
+ st.pyplot(fig)
1436
+
1437
+ # CSV download
1438
+ csv_buffer = io.StringIO()
1439
+ df_modes.to_csv(csv_buffer, index=False)
1440
+ st.download_button(
1441
+ label="Download Vibrational Frequencies (CSV)",
1442
+ data=csv_buffer.getvalue(),
1443
+ file_name="vibrational_modes.csv",
1444
+ mime="text/csv"
1445
+ )
1446
  except Exception as e:
1447
  st.error(f"🔴 Calculation error: {str(e)}")
1448
  st.error("Please check the structure, model compatibility, and parameters. For FairChem UMA, ensure the task type (omol, omat etc.) is appropriate for your system (e.g. omol for molecules, omat for materials).")
 
1479
  - For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
1480
  """)
1481
  st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem, SevenNet, ORB and ❤️")
1482
+ st.markdown("Developed by [Manas Sharma](https://manas.bragitoff.com/) in the groups of [Prof. Ananth Govind Rajan Group](https://www.agrgroup.org/) and [Prof. Sudeep Punnathanam](https://chemeng.iisc.ac.in/sudeep/) at [IISc Bangalore](https://iisc.ac.in/)")