hysts (HF Staff) committed
Commit 1bb0264 · 1 Parent(s): e820b78
Files changed (2):
  1. app.py +4 -15
  2. scheduler.py +78 -34
app.py CHANGED
@@ -4,9 +4,7 @@ import datetime
 import json
 import os
 import pathlib
-import shutil
 import tempfile
-import uuid
 from typing import Any
 
 import gradio as gr
@@ -15,11 +13,9 @@ from gradio_client import Client
 from scheduler import ParquetScheduler
 
 HF_TOKEN = os.getenv('HF_TOKEN')
-UPLOAD_REPO_ID = os.getenv('UPLOAD_REPO_ID')
+UPLOAD_REPO_ID = os.environ['UPLOAD_REPO_ID']
 UPLOAD_FREQUENCY = int(os.getenv('UPLOAD_FREQUENCY', '15'))
 USE_PUBLIC_REPO = os.getenv('USE_PUBLIC_REPO') == '1'
-LOCAL_SAVE_DIR = pathlib.Path(os.getenv('LOCAL_SAVE_DIR', 'results'))
-LOCAL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
 
 ABOUT_THIS_SPACE = '''
 This Space is a sample Space that collects user preferences for the results generated by a diffusion model.
@@ -35,8 +31,7 @@ scheduler = ParquetScheduler(repo_id=UPLOAD_REPO_ID,
                              repo_type='dataset',
                              every=UPLOAD_FREQUENCY,
                              private=not USE_PUBLIC_REPO,
-                             token=HF_TOKEN,
-                             folder_path=LOCAL_SAVE_DIR)
+                             token=HF_TOKEN)
 
 client = Client('stabilityai/stable-diffusion')
 
@@ -69,9 +64,6 @@ def get_selected_index(evt: gr.SelectData) -> int:
 
 def save_preference(config_path: str, gallery: list[dict[str, Any]],
                     selected_index: int) -> None:
-    save_dir = LOCAL_SAVE_DIR / f'{uuid.uuid4()}'
-    save_dir.mkdir(parents=True, exist_ok=True)
-
     # Load config
     with open(config_path) as f:
         data = json.load(f)
@@ -80,12 +72,9 @@ def save_preference(config_path: str, gallery: list[dict[str, Any]],
     data['selected_index'] = selected_index
     data['timestamp'] = datetime.datetime.utcnow().isoformat()
 
-    # Copy and add images
+    # Add images
     for index, path in enumerate(x['name'] for x in gallery):
-        name = f'{index:03d}'
-        dst_path = save_dir / f'{name}{pathlib.Path(path).suffix}'
-        shutil.move(path, dst_path)
-        data[f'image_{name}'] = dst_path
+        data[f'image_{index:03d}'] = path
 
     # Send to scheduler
    scheduler.append(data)
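
Two things change in `app.py`: `UPLOAD_REPO_ID` is now read with `os.environ['UPLOAD_REPO_ID']`, so a missing variable raises a `KeyError` at startup instead of silently passing `None` to the scheduler, and images are no longer moved into a local `LOCAL_SAVE_DIR` folder. The gallery's temporary file paths are appended directly, and the scheduler reads the bytes itself at push time. A minimal sketch of the kind of row `save_preference` now builds (all field values here are hypothetical):

```py
row = {
    'prompt': 'an astronaut riding a horse',    # from the loaded config (hypothetical)
    'selected_index': 2,                        # index the user picked in the gallery
    'timestamp': '2023-07-01T12:34:56.789012',  # datetime.datetime.utcnow().isoformat()
    'image_000': '/tmp/gradio/.../image.png',   # local path; bytes are read at push time
}
scheduler.append(row)  # one .append() call becomes one row in the parquet dataset
```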
scheduler.py CHANGED
@@ -1,71 +1,114 @@
 import json
-import tempfile
 import uuid
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Union
 
 import pyarrow as pa
 import pyarrow.parquet as pq
-from huggingface_hub import CommitScheduler
+from huggingface_hub import CommitScheduler, HfApi
 
 
 class ParquetScheduler(CommitScheduler):
+    """
+    Usage: configure the scheduler with a repo id. Once started, you can add data to be uploaded to the Hub. 1 `.append`
+    call will result in 1 row in your final dataset.
+
+    ```py
+    # Start scheduler
+    >>> scheduler = ParquetScheduler(repo_id="my-parquet-dataset")
+
+    # Append some data to be uploaded
+    >>> scheduler.append({...})
+    >>> scheduler.append({...})
+    >>> scheduler.append({...})
+    ```
+
+    The scheduler will automatically infer the schema from the data it pushes.
+    Optionally, you can manually set the schema yourself:
+
+    ```py
+    >>> scheduler = ParquetScheduler(
+    ...     repo_id="my-parquet-dataset",
+    ...     schema={
+    ...         "prompt": {"_type": "Value", "dtype": "string"},
+    ...         "negative_prompt": {"_type": "Value", "dtype": "string"},
+    ...         "guidance_scale": {"_type": "Value", "dtype": "int64"},
+    ...         "image": {"_type": "Image"},
+    ...     },
+    ... )
+
+    See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value for the list of
+    possible values.
+    """
+    def __init__(
+        self,
+        *,
+        repo_id: str,
+        schema: Optional[Dict[str, Dict[str, str]]] = None,
+        every: Union[int, float] = 5,
+        path_in_repo: Optional[str] = 'data',
+        repo_type: Optional[str] = 'dataset',
+        revision: Optional[str] = None,
+        private: bool = False,
+        token: Optional[str] = None,
+        allow_patterns: Union[List[str], str, None] = None,
+        ignore_patterns: Union[List[str], str, None] = None,
+        hf_api: Optional[HfApi] = None,
+    ) -> None:
+        super().__init__(
+            repo_id=repo_id,
+            folder_path='dummy',  # not used by the scheduler
+            every=every,
+            path_in_repo=path_in_repo,
+            repo_type=repo_type,
+            revision=revision,
+            private=private,
+            token=token,
+            allow_patterns=allow_patterns,
+            ignore_patterns=ignore_patterns,
+            hf_api=hf_api,
+        )
+
+        self._rows: List[Dict[str, Any]] = []
+        self._schema = schema
+
     def append(self, row: Dict[str, Any]) -> None:
+        """Add a new item to be uploaded."""
         with self.lock:
-            if not hasattr(self, 'rows') or self.rows is None:  # type: ignore
-                self.rows = []
-            self.rows.append(row)
-
-    def set_schema(self, schema: Dict[str, Dict[str, str]]) -> None:
-        """
-        Define a schema to help `datasets` load the generated library.
-        This method is optional and can be called once just after the scheduler had been created. If it is not called,
-        the schema is automatically inferred before pushing the data to the Hub.
-        See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value for the list of
-        possible values.
-        Example:
-        ```py
-        scheduler.set_schema({
-            "prompt": {"_type": "Value", "dtype": "string"},
-            "negative_prompt": {"_type": "Value", "dtype": "string"},
-            "guidance_scale": {"_type": "Value", "dtype": "int64"},
-            "image": {"_type": "Image"},
-        })
-        ```
-        """
-        self._schema = schema
+            self._rows.append(row)
 
     def push_to_hub(self):
         # Check for new rows to push
         with self.lock:
-            rows = self.rows
-            self.rows = None
+            rows = self._rows
+            self._rows = []
         if not rows:
             return
+        print(f'Got {len(rows)} item(s) to commit.')
 
         # Load images + create 'features' config for datasets library
-        hf_features: Dict[str, Dict] = getattr(self, '_schema', None) or {}
+        schema: Dict[str, Dict] = self._schema or {}
         path_to_cleanup: List[Path] = []
         for row in rows:
             for key, value in row.items():
                 # Infer schema (for `datasets` library)
-                if key not in hf_features:
-                    hf_features[key] = _infer_schema(key, value)
+                if key not in schema:
+                    schema[key] = _infer_schema(key, value)
 
                 # Load binary files if necessary
-                if hf_features[key]['_type'] in ('Image', 'Audio'):
+                if schema[key]['_type'] in ('Image', 'Audio'):
                     # It's an image or audio: we load the bytes and remember to cleanup the file
                     file_path = Path(value)
                     if file_path.is_file():
                         row[key] = {
                             'path': file_path.name,
-                            'bytes': file_path.read_bytes()
+                            'bytes': file_path.read_bytes(),
                         }
                         path_to_cleanup.append(file_path)
 
         # Complete rows if needed
         for row in rows:
-            for feature in hf_features:
+            for feature in schema:
                 if feature not in row:
                     row[feature] = None
 
@@ -75,7 +118,7 @@ class ParquetScheduler(CommitScheduler):
         # Add metadata (used by datasets library)
         table = table.replace_schema_metadata(
             {'huggingface': json.dumps({'info': {
-                'features': hf_features
+                'features': schema
             }})})
 
         # Write to parquet file
@@ -90,6 +133,7 @@ class ParquetScheduler(CommitScheduler):
             path_in_repo=f'{uuid.uuid4()}.parquet',
             path_or_fileobj=archive_file.name,
         )
+        print(f'Commit completed.')
 
         # Cleanup
        archive_file.close()
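
Both versions of `push_to_hub` call a module-level helper `_infer_schema(key, value)` that sits outside the changed hunks, so it does not appear in this diff. A minimal sketch of what such a helper can look like, assuming it maps a column name and a sample value to a `datasets` feature descriptor (the name-based Image/Audio heuristic matches how keys like `image_000` are built in `app.py`):

```py
from typing import Any, Dict

def _infer_schema(key: str, value: Any) -> Dict[str, str]:
    """Guess a `datasets` feature descriptor from a column name and a sample value."""
    if 'image' in key:
        return {'_type': 'Image'}
    if 'audio' in key:
        return {'_type': 'Audio'}
    if isinstance(value, bool):  # check bool before int: bool is a subclass of int
        return {'_type': 'Value', 'dtype': 'bool'}
    if isinstance(value, int):
        return {'_type': 'Value', 'dtype': 'int64'}
    if isinstance(value, float):
        return {'_type': 'Value', 'dtype': 'float64'}
    if isinstance(value, bytes):
        return {'_type': 'Value', 'dtype': 'binary'}
    # Last resort: treat everything else as a string
    return {'_type': 'Value', 'dtype': 'string'}
```

The elided context between the hunks presumably builds `table` from `rows` (e.g. with `pa.Table.from_pylist(rows)`) before the metadata is attached and the parquet file is uploaded.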