hysts and Wauplin committed
Commit b09ce4a · Parent: e51b1cb

Update parquet scheduler


Co-authored-by: Lucain Pouget <[email protected]>

Files changed (1):
  1. scheduler.py (+53 -18)
scheduler.py CHANGED
@@ -16,6 +16,25 @@ class ParquetScheduler(CommitScheduler):
                 self.rows = []
             self.rows.append(row)
 
+    def set_schema(self, schema: Dict[str, Dict[str, str]]) -> None:
+        """
+        Define a schema to help `datasets` load the generated dataset.
+        This method is optional and can be called once just after the scheduler has been created. If it is not called,
+        the schema is automatically inferred before pushing the data to the Hub.
+        See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value for the list of
+        possible values.
+        Example:
+        ```py
+        scheduler.set_schema({
+            "prompt": {"_type": "Value", "dtype": "string"},
+            "negative_prompt": {"_type": "Value", "dtype": "string"},
+            "guidance_scale": {"_type": "Value", "dtype": "int64"},
+            "image": {"_type": "Image"},
+        })
+        ```
+        """
+        self._schema = schema
+
     def push_to_hub(self):
         # Check for new rows to push
         with self.lock:
@@ -25,29 +44,24 @@ class ParquetScheduler(CommitScheduler):
             return
 
         # Load images + create 'features' config for datasets library
-        hf_features: dict[str, Dict] = {}
+        hf_features: Dict[str, Dict] = getattr(self, '_schema', None) or {}
         path_to_cleanup: List[Path] = []
         for row in rows:
             for key, value in row.items():
-                if 'image' in key:
-                    # It's an image: we load the bytes, define a special schema and remember to cleanup the file
-                    # Note: could do the same with "Audio"
-                    image_path = Path(value)
-                    if image_path.is_file():
+                # Infer schema (for `datasets` library)
+                if key not in hf_features:
+                    hf_features[key] = _infer_schema(key, value)
+
+                # Load binary files if necessary
+                if hf_features[key]['_type'] in ('Image', 'Audio'):
+                    # It's an image or audio: we load the bytes and remember to cleanup the file
+                    file_path = Path(value)
+                    if file_path.is_file():
                         row[key] = {
-                            'path': image_path.name,
-                            'bytes': image_path.read_bytes()
-                        }
-                        path_to_cleanup.append(image_path)
-                    if key not in hf_features:
-                        hf_features[key] = {'_type': 'Image'}
-                else:
-                    # Otherwise, do nothing special
-                    if key not in hf_features:
-                        hf_features[key] = {
-                            '_type': 'Value',
-                            'dtype': 'string'
+                            'path': file_path.name,
+                            'bytes': file_path.read_bytes()
                         }
+                        path_to_cleanup.append(file_path)
 
         # Complete rows if needed
         for row in rows:
@@ -81,3 +95,24 @@ class ParquetScheduler(CommitScheduler):
         archive_file.close()
         for path in path_to_cleanup:
             path.unlink(missing_ok=True)
+
+
+def _infer_schema(key: str, value: Any) -> Dict[str, str]:
+    """Infer schema for the `datasets` library.
+
+    See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value.
+    """
+    if 'image' in key:
+        return {'_type': 'Image'}
+    if 'audio' in key:
+        return {'_type': 'Audio'}
+    if isinstance(value, int):
+        return {'_type': 'Value', 'dtype': 'int64'}
+    if isinstance(value, float):
+        return {'_type': 'Value', 'dtype': 'float64'}
+    if isinstance(value, bool):
+        return {'_type': 'Value', 'dtype': 'bool'}
+    if isinstance(value, bytes):
+        return {'_type': 'Value', 'dtype': 'binary'}
+    # Otherwise in last resort => convert it to a string
+    return {'_type': 'Value', 'dtype': 'string'}
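
For context, a minimal usage sketch of the updated scheduler follows. It is an assumption-heavy illustration, not part of the commit: the constructor arguments are taken from `huggingface_hub.CommitScheduler` (which `ParquetScheduler` subclasses), the method wrapping `self.rows.append(row)` in the first hunk is assumed to be called `append`, and the repo id, folder path, and row fields are placeholders.

```py
from pathlib import Path

# Hypothetical wiring: constructor arguments follow huggingface_hub.CommitScheduler,
# and `append` is assumed to be the method containing `self.rows.append(row)` above.
scheduler = ParquetScheduler(
    repo_id="user/generated-images",  # placeholder dataset repo
    repo_type="dataset",
    folder_path="parquet_data",       # local folder required by CommitScheduler
    every=5,                          # push every 5 minutes
)

# Optional since this commit: declare the schema up front instead of relying on inference.
scheduler.set_schema({
    "prompt": {"_type": "Value", "dtype": "string"},
    "seed": {"_type": "Value", "dtype": "int64"},
    "image": {"_type": "Image"},
})

# Each appended row is buffered and written to a parquet file on the next scheduled push;
# image paths are read as bytes and the local files are cleaned up afterwards.
scheduler.append({
    "prompt": "an astronaut riding a horse",
    "seed": 42,
    "image": str(Path("outputs") / "astronaut.png"),
})
```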
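
The `{'_type': ..., 'dtype': ...}` dictionaries used by `set_schema` and `_infer_schema` follow the serialized feature format of the `datasets` library, so a schema can be sanity-checked locally with `datasets.Features.from_dict`. One behaviour of the inference fallback worth knowing: `bool` is a subclass of `int` in Python, so the `isinstance(value, int)` branch matches booleans before the `bool` branch is reached, which is another reason to call `set_schema` explicitly for boolean columns. A quick check with made-up keys and values, assuming it runs in the same module as `scheduler.py` (or after importing `_infer_schema` from it):

```py
from datasets import Features

# The schema dicts round-trip through the `datasets` feature serialization format.
features = Features.from_dict({
    "prompt": {"_type": "Value", "dtype": "string"},
    "guidance_scale": {"_type": "Value", "dtype": "float64"},
    "image": {"_type": "Image"},
})
print(features)  # a Features object with Value/Image feature types

# Behaviour of the inference fallback on illustrative values:
print(_infer_schema("image_path", "cat.png"))  # {'_type': 'Image'} because the key contains "image"
print(_infer_schema("guidance_scale", 7.5))    # {'_type': 'Value', 'dtype': 'float64'}
print(_infer_schema("num_steps", 30))          # {'_type': 'Value', 'dtype': 'int64'}
print(_infer_schema("flagged", True))          # {'_type': 'Value', 'dtype': 'int64'} -- bool matches int first
```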