zulissimeta commited on
Commit
98b9e26
·
1 Parent(s): bbb728a

check uma access on file upload

Browse files
app.py CHANGED
@@ -12,11 +12,8 @@ from pathlib import Path
12
 
13
  import gradio as gr
14
 
15
- from simulation_scripts import (
16
- run_md_simulation,
17
- run_relaxation_simulation,
18
- validate_ase_atoms_and_login,
19
- )
20
 
21
  DEFAULT_MOLECULAR_REPRESENTATIONS = [
22
  {
 
12
 
13
  import gradio as gr
14
 
15
+ from simulation_scripts import (run_md_simulation, run_relaxation_simulation,
16
+ validate_ase_atoms_and_login)
 
 
 
17
 
18
  DEFAULT_MOLECULAR_REPRESENTATIONS = [
19
  {
gradio_molecule3d/backend/gradio_molecule3d/molecule3d.py CHANGED
@@ -9,8 +9,8 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
9
 
10
  import gradio_client.utils as client_utils
11
  from gradio_client import handle_file
12
- from gradio_client.documentation import document
13
 
 
14
  from gradio import processing_utils
15
  from gradio.components.base import Component
16
  from gradio.data_classes import FileData, ListFiles
@@ -107,7 +107,7 @@ def convert_file_to_pdb(file_path: str | Path, gradio_cache: str | Path) -> str:
107
  structures = ase.io.read(file_path, ':')
108
  except Exception as e:
109
  # Bad upload structure, no need to visualize
110
- return None
111
 
112
  if all(structures[0].pbc):
113
  # find the minimum number of repeats in each unit cell direction to meet at least 20 angstroms
@@ -322,11 +322,14 @@ class Molecule3D(Component):
322
  )
323
  else:
324
  value = convert_file_to_pdb(str(value), self.GRADIO_CACHE)
325
- return FileData(
326
- path=value,
327
- orig_name=Path(value).name,
328
- size=Path(value).stat().st_size,
329
- )
 
 
 
330
 
331
  def process_example(self, value: str | list | None) -> str:
332
  if value is None:
 
9
 
10
  import gradio_client.utils as client_utils
11
  from gradio_client import handle_file
 
12
 
13
+ import gradio as gr
14
  from gradio import processing_utils
15
  from gradio.components.base import Component
16
  from gradio.data_classes import FileData, ListFiles
 
107
  structures = ase.io.read(file_path, ':')
108
  except Exception as e:
109
  # Bad upload structure, no need to visualize
110
+ raise gr.Error(f'Error parsing file with ase: {str(e)}')
111
 
112
  if all(structures[0].pbc):
113
  # find the minimum number of repeats in each unit cell direction to meet at least 20 angstroms
 
322
  )
323
  else:
324
  value = convert_file_to_pdb(str(value), self.GRADIO_CACHE)
325
+ if value is not None:
326
+ return FileData(
327
+ path=value,
328
+ orig_name=Path(value).name,
329
+ size=Path(value).stat().st_size,
330
+ )
331
+ else:
332
+ return None
333
 
334
  def process_example(self, value: str | list | None) -> str:
335
  if value is None:
hf_calculator.py CHANGED
@@ -29,6 +29,14 @@ def hash_save_file(atoms: ase.Atoms, task_name, path: Path | str):
29
  return
30
 
31
 
 
 
 
 
 
 
 
 
32
  class HFEndpointCalculator(Calculator):
33
  # A simple calculator that uses the Hugging Face Inference Endpoints to run
34
 
@@ -47,15 +55,12 @@ class HFEndpointCalculator(Calculator):
47
  # If we have an example structure, we don't need to check for authentication
48
  # Otherwise, we need to check if the user is authenticated and has gated access to the UMA models
49
  if not example:
50
- try:
51
- hf_hub.HfApi().auth_check(
52
- repo_id="facebook/UMA", token=oauth_token.token
53
- )
54
  try:
55
  hash_save_file(atoms, task_name, "/data/custom_inputs/")
56
  except FileNotFoundError:
57
  pass
58
- except (hf_hub.errors.HfHubHTTPError, AttributeError):
59
  raise gr.Error(
60
  "You need to log in to HF and have gated model access to UMA before running your own simulations!"
61
  )
 
29
  return
30
 
31
 
32
+ def validate_uma_access(oauth_token):
33
+ try:
34
+ hf_hub.HfApi().auth_check(repo_id="facebook/UMA", token=oauth_token.token)
35
+ return True
36
+ except (hf_hub.errors.HfHubHTTPError, AttributeError):
37
+ return False
38
+
39
+
40
  class HFEndpointCalculator(Calculator):
41
  # A simple calculator that uses the Hugging Face Inference Endpoints to run
42
 
 
55
  # If we have an example structure, we don't need to check for authentication
56
  # Otherwise, we need to check if the user is authenticated and has gated access to the UMA models
57
  if not example:
58
+ if validate_uma_access(oauth_token):
 
 
 
59
  try:
60
  hash_save_file(atoms, task_name, "/data/custom_inputs/")
61
  except FileNotFoundError:
62
  pass
63
+ else:
64
  raise gr.Error(
65
  "You need to log in to HF and have gated model access to UMA before running your own simulations!"
66
  )
simulation_scripts.py CHANGED
@@ -23,7 +23,7 @@ from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
23
  from ase.md.verlet import VelocityVerlet
24
  from ase.optimize import LBFGS
25
 
26
- from hf_calculator import HFEndpointCalculator
27
 
28
 
29
  def hash_file(file_path):
@@ -43,7 +43,9 @@ INFERENCE_ENDPOINT_URL = os.environ["INFERENCE_ENDPOINT_URL"]
43
 
44
 
45
  def validate_ase_atoms_and_login(
46
- structure_file: dict | str, login_button_value: str
 
 
47
  ) -> tuple[gr.Button, gr.Button, str]:
48
  # Validate and write the uploaded file content
49
  if not structure_file:
@@ -84,7 +86,8 @@ def validate_ase_atoms_and_login(
84
  f"Structure file contains {len(atoms)}, which is more than {MAX_ATOMS} atoms. Please use a smaller structure for this demo, or run this on a local machine!",
85
  )
86
  elif (hash_file(structure_file) not in EXAMPLE_FILE_HASHES) and (
87
- "Logout" not in login_button_value
 
88
  ):
89
  return (
90
  gr.Button(interactive=False),
 
23
  from ase.md.verlet import VelocityVerlet
24
  from ase.optimize import LBFGS
25
 
26
+ from hf_calculator import HFEndpointCalculator, validate_uma_access
27
 
28
 
29
  def hash_file(file_path):
 
43
 
44
 
45
  def validate_ase_atoms_and_login(
46
+ structure_file: dict | str,
47
+ login_button_value: str,
48
+ oauth_token: gr.OAuthToken | None,
49
  ) -> tuple[gr.Button, gr.Button, str]:
50
  # Validate and write the uploaded file content
51
  if not structure_file:
 
86
  f"Structure file contains {len(atoms)}, which is more than {MAX_ATOMS} atoms. Please use a smaller structure for this demo, or run this on a local machine!",
87
  )
88
  elif (hash_file(structure_file) not in EXAMPLE_FILE_HASHES) and (
89
+ ("Logout" not in login_button_value)
90
+ or not validate_uma_access(oauth_token=oauth_token)
91
  ):
92
  return (
93
  gr.Button(interactive=False),