arcleife committed
Commit a8268e2 · verified · 1 Parent(s): de9d075

Create hf_dataset_saver.py

Files changed (1)
hf_dataset_saver.py +272 -0
hf_dataset_saver.py ADDED
@@ -0,0 +1,272 @@
# NOTE: the file as committed starts directly at the class definition; the imports
# below are reconstructed on the assumption that it runs inside the gradio package.
from __future__ import annotations

import csv
import json
import uuid
from collections import OrderedDict
from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any

import filelock
import huggingface_hub
from gradio_client import utils as client_utils

import gradio as gr
from gradio import utils
from gradio.components import Component
from gradio.flagging import FlaggingCallback


class HuggingFaceDatasetSaver(FlaggingCallback):
    """
    A callback that saves each flagged sample (both the input and output data) to a HuggingFace dataset.

    Example:
        import gradio as gr
        hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            allow_flagging="manual", flagging_callback=hf_writer)
    Guides: using-flagging
    """

    def __init__(
        self,
        hf_token: str,
        dataset_name: str,
        private: bool = False,
        info_filename: str = "dataset_info.json",
        separate_dirs: bool = False,
    ):
        """
        Parameters:
            hf_token: The HuggingFace token used to create the dataset and write the flagged samples to it. Must have write access to the dataset repo.
            dataset_name: The repo_id of the dataset to save the data to, e.g. "image-classifier-1" or "username/image-classifier-1".
            private: Whether the dataset should be private (defaults to False).
            info_filename: The name of the file to save the dataset info to (defaults to "dataset_info.json").
            separate_dirs: If True, each flagged item will be saved in a separate directory. This makes the flagging more robust to concurrent editing, but may be less convenient to use.
        """
        self.hf_token = hf_token
        self.dataset_id = dataset_name  # TODO: rename parameter (but ensure backward compatibility somehow)
        self.dataset_private = private
        self.info_filename = info_filename
        self.separate_dirs = separate_dirs

    def setup(self, components: Sequence[Component], flagging_dir: str):
        """
        Params:
            flagging_dir (str): local directory where the dataset is cloned,
            updated, and pushed from.
        """
        # Create (or reuse) the dataset repo on the Hub
        self.dataset_id = huggingface_hub.create_repo(
            repo_id=self.dataset_id,
            token=self.hf_token,
            private=self.dataset_private,
            repo_type="dataset",
            exist_ok=True,
        ).repo_id
        path_glob = "**/*.jsonl" if self.separate_dirs else "data.csv"
        huggingface_hub.metadata_update(
            repo_id=self.dataset_id,
            repo_type="dataset",
            metadata={
                "configs": [
                    {
                        "config_name": "default",
                        "data_files": [{"split": "train", "path": path_glob}],
                    }
                ]
            },
            overwrite=True,
            token=self.hf_token,
        )
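        # Note (added for clarity): metadata_update rewrites the YAML block at the
        # top of the dataset's README.md. With the defaults above it should end up
        # roughly as:
        #
        #   configs:
        #   - config_name: default
        #     data_files:
        #     - split: train
        #       path: data.csv      # or "**/*.jsonl" when separate_dirs=True
        #
        # which is what tells the Hub's dataset viewer where to find the flagged rows.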

        # Setup flagging dir
        self.components = components
        self.dataset_dir = (
            Path(flagging_dir).absolute() / self.dataset_id.split("/")[-1]
        )
        self.dataset_dir.mkdir(parents=True, exist_ok=True)
        self.infos_file = self.dataset_dir / self.info_filename

        # Download remote files to local
        remote_files = [self.info_filename]
        if not self.separate_dirs:
            # No separate dirs => all data is in the same CSV file => download it to get its current content
            remote_files.append("data.csv")

        for filename in remote_files:
            try:
                huggingface_hub.hf_hub_download(
                    repo_id=self.dataset_id,
                    repo_type="dataset",
                    filename=filename,
                    local_dir=self.dataset_dir,
                    token=self.hf_token,
                )
            except huggingface_hub.utils.EntryNotFoundError:
                pass

    def flag(
        self,
        flag_data: list[Any],
        flag_option: str = "",
        username: str | None = None,
    ) -> int:
        if self.separate_dirs:
            # JSONL files to support dataset preview on the Hub
            unique_id = str(uuid.uuid4())
            components_dir = self.dataset_dir / unique_id
            data_file = components_dir / "metadata.jsonl"
            path_in_repo = unique_id  # upload in sub folder (safer for concurrency)
        else:
            # Unique CSV file
            components_dir = self.dataset_dir
            data_file = components_dir / "data.csv"
            path_in_repo = None  # upload at root level

        return self._flag_in_dir(
            data_file=data_file,
            components_dir=components_dir,
            path_in_repo=path_in_repo,
            flag_data=flag_data,
            flag_option=flag_option,
            username=username or "",
        )

    def _flag_in_dir(
        self,
        data_file: Path,
        components_dir: Path,
        path_in_repo: str | None,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> int:
        # Deserialize components (write images/audio to files)
        features, row = self._deserialize_components(
            components_dir, flag_data, flag_option, username
        )

        # Write generic info to the dataset info file + upload it (on first flag only)
        with filelock.FileLock(str(self.infos_file) + ".lock"):
            if not self.infos_file.exists():
                self.infos_file.write_text(
                    json.dumps({"flagged": {"features": features}})
                )

                huggingface_hub.upload_file(
                    repo_id=self.dataset_id,
                    repo_type="dataset",
                    token=self.hf_token,
                    path_in_repo=self.infos_file.name,
                    path_or_fileobj=self.infos_file,
                )

        headers = list(features.keys())

        if not self.separate_dirs:
            with filelock.FileLock(components_dir / ".lock"):
                sample_nb = self._save_as_csv(data_file, headers=headers, row=row)
                sample_name = str(sample_nb)
                huggingface_hub.upload_folder(
                    repo_id=self.dataset_id,
                    repo_type="dataset",
                    commit_message=f"Flagged sample #{sample_name}",
                    path_in_repo=path_in_repo,
                    ignore_patterns="*.lock",
                    folder_path=components_dir,
                    token=self.hf_token,
                )
        else:
            sample_name = self._save_as_jsonl(data_file, headers=headers, row=row)
            sample_nb = len(
                [path for path in self.dataset_dir.iterdir() if path.is_dir()]
            )
            huggingface_hub.upload_folder(
                repo_id=self.dataset_id,
                repo_type="dataset",
                commit_message=f"Flagged sample #{sample_name}",
                path_in_repo=path_in_repo,
                ignore_patterns="*.lock",
                folder_path=components_dir,
                token=self.hf_token,
            )

        return sample_nb

    @staticmethod
    def _save_as_csv(data_file: Path, headers: list[str], row: list[Any]) -> int:
        """Save data as CSV and return the sample name (row number)."""
        is_new = not data_file.exists()

        with data_file.open("a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)

            # Write CSV headers if new file
            if is_new:
                writer.writerow(utils.sanitize_list_for_csv(headers))

            # Write CSV row for flagged sample
            writer.writerow(utils.sanitize_list_for_csv(row))

        with data_file.open(encoding="utf-8") as csvfile:
            return sum(1 for _ in csv.reader(csvfile)) - 1
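        # Example (added for clarity): a data.csv holding the header line plus
        # three flagged rows has four CSV records, so the count above returns
        # 4 - 1 = 3, i.e. the 1-based index of the row that was just appended.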

    @staticmethod
    def _save_as_jsonl(data_file: Path, headers: list[str], row: list[Any]) -> str:
        """Save data as JSONL and return the sample name (uuid)."""
        data_file.parent.mkdir(parents=True, exist_ok=True)
        with open(data_file, "w", encoding="utf-8") as f:
            json.dump(dict(zip(headers, row)), f)
        return data_file.parent.name
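        # Illustrative (added): for an image input flagged as "incorrect", the single
        # record written above would look something like (all values hypothetical):
        #   {"image": "<uuid>/image/example.png",
        #    "image file": "https://huggingface.co/datasets/<repo>/resolve/main/<uuid>/image/example.png",
        #    "flag": "incorrect", "username": "", "timestamp": "2024-01-01T12:00:00+09:00"}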

    def _deserialize_components(
        self,
        data_dir: Path,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the corresponding row for the flagged sample.

        Images/audio are saved to disk as individual files.
        """
        # Components that can have a preview on dataset repos
        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

        # Generate the row corresponding to the flagged sample
        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            # Get deserialized object (saves the sample to disk if applicable: file, audio, image, ...)
            label = component.label or ""
            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
            save_dir.mkdir(exist_ok=True, parents=True)
            deserialized = utils.simplify_file_data_in_str(
                component.flag(sample, save_dir)
            )

            # Add deserialized object to row
            features[label] = {"dtype": "string", "_type": "Value"}
            try:
                deserialized_path = Path(deserialized)
                if not deserialized_path.exists():
                    raise FileNotFoundError(f"File {deserialized} not found")
                row.append(str(deserialized_path.relative_to(self.dataset_dir)))
            except (FileNotFoundError, TypeError, ValueError, OSError):
                deserialized = "" if deserialized is None else str(deserialized)
                row.append(deserialized)

            # If component is eligible for a preview, add the URL of the file
            # Be mindful that images and audio can be None
            if isinstance(component, tuple(file_preview_types)):  # type: ignore
                for _component, _type in file_preview_types.items():
                    if isinstance(component, _component):
                        features[label + " file"] = {"_type": _type}
                        break
                if deserialized:
                    path_in_repo = str(  # returned filepath is absolute; we want it relative to compute the URL
                        Path(deserialized).relative_to(self.dataset_dir)
                    ).replace("\\", "/")
                    row.append(
                        huggingface_hub.hf_hub_url(
                            repo_id=self.dataset_id,
                            filename=path_in_repo,
                            repo_type="dataset",
                        )
                    )
                else:
                    row.append("")

        # Timestamp pinned to a fixed UTC+9 offset (e.g. JST/KST) rather than the
        # server's local timezone
        timestamp = datetime.now(timezone(timedelta(hours=9))).isoformat()
        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        features["timestamp"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        row.append(timestamp)
        return features, row
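
For context, a minimal sketch of how this callback would be wired into a gradio app. The token variable and dataset id are placeholders; gradio's flagging machinery calls setup() at launch and flag() on each flag click, so nothing below invokes them directly.

    import os
    import gradio as gr

    # Hypothetical wiring: HF_TOKEN and the dataset id are placeholders.
    hf_writer = HuggingFaceDatasetSaver(
        hf_token=os.environ["HF_TOKEN"],
        dataset_name="username/flagged-samples",
        separate_dirs=True,  # one sub-directory + metadata.jsonl per flagged sample
    )

    def image_classifier(inp):
        return {"cat": 0.3, "dog": 0.7}

    demo = gr.Interface(
        fn=image_classifier,
        inputs="image",
        outputs="label",
        allow_flagging="manual",
        flagging_callback=hf_writer,
    )
    demo.launch()

With separate_dirs=True, each flag lands in its own uuid-named folder containing metadata.jsonl plus any media files, and the "**/*.jsonl" glob registered in setup() lets the Hub's dataset viewer aggregate those records into a single train split. With the default separate_dirs=False, every flag is appended to one shared data.csv guarded by a file lock, which is simpler but more prone to conflicts under concurrent flagging.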