arcleife committed on
Commit beb81ec · verified · 1 Parent(s): a8268e2

Update app.py

Files changed (1)
  1. app.py +6 -279
app.py CHANGED
@@ -7,9 +7,9 @@ import polars as pl
  import re
  import json
  from datetime import datetime, timezone, timedelta
- from transformers import pipeline
- from transformers import AutoModelForSequenceClassification
- from transformers import AutoTokenizer, DistilBertTokenizerFast
+ from optimum.pipelines import pipeline
+ from optimum.onnxruntime import ORTModelForSequenceClassification
+ from transformers import AutoTokenizer

  # version: 0.2.1

@@ -24,279 +24,7 @@ import uuid
  import filelock
  import csv

- # TODO move to separate file for cleaner code
- class HuggingFaceDatasetSaver(FlaggingCallback):
-     """
-     A callback that saves each flagged sample (both the input and output data) to a HuggingFace dataset.
-
-     Example:
-         import gradio as gr
-         hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
-         def image_classifier(inp):
-             return {'cat': 0.3, 'dog': 0.7}
-         demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
-                             allow_flagging="manual", flagging_callback=hf_writer)
-     Guides: using-flagging
-     """
-
-     def __init__(
-         self,
-         hf_token: str,
-         dataset_name: str,
-         private: bool = False,
-         info_filename: str = "dataset_info.json",
-         separate_dirs: bool = False,
-     ):
-         """
-         Parameters:
-             hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset (defaults to the registered one).
-             dataset_name: The repo_id of the dataset to save the data to, e.g. "image-classifier-1" or "username/image-classifier-1".
-             private: Whether the dataset should be private (defaults to False).
-             info_filename: The name of the file to save the dataset info (defaults to "dataset_infos.json").
-             separate_dirs: If True, each flagged item will be saved in a separate directory. This makes the flagging more robust to concurrent editing, but may be less convenient to use.
-         """
-         self.hf_token = hf_token
-         self.dataset_id = dataset_name  # TODO: rename parameter (but ensure backward compatibility somehow)
-         self.dataset_private = private
-         self.info_filename = info_filename
-         self.separate_dirs = separate_dirs
-
-     def setup(self, components: Sequence[Component], flagging_dir: str):
-         """
-         Params:
-         flagging_dir (str): local directory where the dataset is cloned,
-         updated, and pushed from.
-         """
-         # Setup dataset on the Hub
-         self.dataset_id = huggingface_hub.create_repo(
-             repo_id=self.dataset_id,
-             token=self.hf_token,
-             private=self.dataset_private,
-             repo_type="dataset",
-             exist_ok=True,
-         ).repo_id
-         path_glob = "**/*.jsonl" if self.separate_dirs else "data.csv"
-         huggingface_hub.metadata_update(
-             repo_id=self.dataset_id,
-             repo_type="dataset",
-             metadata={
-                 "configs": [
-                     {
-                         "config_name": "default",
-                         "data_files": [{"split": "train", "path": path_glob}],
-                     }
-                 ]
-             },
-             overwrite=True,
-             token=self.hf_token,
-         )
-
-         # Setup flagging dir
-         self.components = components
-         self.dataset_dir = (
-             Path(flagging_dir).absolute() / self.dataset_id.split("/")[-1]
-         )
-         self.dataset_dir.mkdir(parents=True, exist_ok=True)
-         self.infos_file = self.dataset_dir / self.info_filename
-
-         # Download remote files to local
-         remote_files = [self.info_filename]
-         if not self.separate_dirs:
-             # No separate dirs => means all data is in the same CSV file => download it to get its current content
-             remote_files.append("data.csv")
-
-         for filename in remote_files:
-             try:
-                 huggingface_hub.hf_hub_download(
-                     repo_id=self.dataset_id,
-                     repo_type="dataset",
-                     filename=filename,
-                     local_dir=self.dataset_dir,
-                     token=self.hf_token,
-                 )
-             except huggingface_hub.utils.EntryNotFoundError:
-                 pass
-
-     def flag(
-         self,
-         flag_data: list[Any],
-         flag_option: str = "",
-         username: str | None = None,
-     ) -> int:
-         if self.separate_dirs:
-             # JSONL files to support dataset preview on the Hub
-             unique_id = str(uuid.uuid4())
-             components_dir = self.dataset_dir / unique_id
-             data_file = components_dir / "metadata.jsonl"
-             path_in_repo = unique_id  # upload in sub folder (safer for concurrency)
-         else:
-             # Unique CSV file
-             components_dir = self.dataset_dir
-             data_file = components_dir / "data.csv"
-             path_in_repo = None  # upload at root level
-
-         return self._flag_in_dir(
-             data_file=data_file,
-             components_dir=components_dir,
-             path_in_repo=path_in_repo,
-             flag_data=flag_data,
-             flag_option=flag_option,
-             username=username or "",
-         )
-
-     def _flag_in_dir(
-         self,
-         data_file: Path,
-         components_dir: Path,
-         path_in_repo: str | None,
-         flag_data: list[Any],
-         flag_option: str = "",
-         username: str = "",
-     ) -> int:
-         # Deserialize components (write images/audio to files)
-         features, row = self._deserialize_components(
-             components_dir, flag_data, flag_option, username
-         )
-
-         # Write generic info to dataset_infos.json + upload
-         with filelock.FileLock(str(self.infos_file) + ".lock"):
-             if not self.infos_file.exists():
-                 self.infos_file.write_text(
-                     json.dumps({"flagged": {"features": features}})
-                 )
-
-             huggingface_hub.upload_file(
-                 repo_id=self.dataset_id,
-                 repo_type="dataset",
-                 token=self.hf_token,
-                 path_in_repo=self.infos_file.name,
-                 path_or_fileobj=self.infos_file,
-             )
-
-         headers = list(features.keys())
-
-         if not self.separate_dirs:
-             with filelock.FileLock(components_dir / ".lock"):
-                 sample_nb = self._save_as_csv(data_file, headers=headers, row=row)
-                 sample_name = str(sample_nb)
-                 huggingface_hub.upload_folder(
-                     repo_id=self.dataset_id,
-                     repo_type="dataset",
-                     commit_message=f"Flagged sample #{sample_name}",
-                     path_in_repo=path_in_repo,
-                     ignore_patterns="*.lock",
-                     folder_path=components_dir,
-                     token=self.hf_token,
-                 )
-         else:
-             sample_name = self._save_as_jsonl(data_file, headers=headers, row=row)
-             sample_nb = len(
-                 [path for path in self.dataset_dir.iterdir() if path.is_dir()]
-             )
-             huggingface_hub.upload_folder(
-                 repo_id=self.dataset_id,
-                 repo_type="dataset",
-                 commit_message=f"Flagged sample #{sample_name}",
-                 path_in_repo=path_in_repo,
-                 ignore_patterns="*.lock",
-                 folder_path=components_dir,
-                 token=self.hf_token,
-             )
-
-         return sample_nb
-
-     @staticmethod
-     def _save_as_csv(data_file: Path, headers: list[str], row: list[Any]) -> int:
-         """Save data as CSV and return the sample name (row number)."""
-         is_new = not data_file.exists()
-
-         with data_file.open("a", newline="", encoding="utf-8") as csvfile:
-             writer = csv.writer(csvfile)
-
-             # Write CSV headers if new file
-             if is_new:
-                 writer.writerow(utils.sanitize_list_for_csv(headers))
-
-             # Write CSV row for flagged sample
-             writer.writerow(utils.sanitize_list_for_csv(row))
-
-         with data_file.open(encoding="utf-8") as csvfile:
-             return sum(1 for _ in csv.reader(csvfile)) - 1
-
-     @staticmethod
-     def _save_as_jsonl(data_file: Path, headers: list[str], row: list[Any]) -> str:
-         """Save data as JSONL and return the sample name (uuid)."""
-         Path.mkdir(data_file.parent, parents=True, exist_ok=True)
-         with open(data_file, "w", encoding="utf-8") as f:
-             json.dump(dict(zip(headers, row)), f)
-         return data_file.parent.name
-
-     def _deserialize_components(
-         self,
-         data_dir: Path,
-         flag_data: list[Any],
-         flag_option: str = "",
-         username: str = "",
-     ) -> tuple[dict[Any, Any], list[Any]]:
-         """Deserialize components and return the corresponding row for the flagged sample.
-
-         Images/audio are saved to disk as individual files.
-         """
-         # Components that can have a preview on dataset repos
-         file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
-
-         # Generate the row corresponding to the flagged sample
-         features = OrderedDict()
-         row = []
-         for component, sample in zip(self.components, flag_data):
-             # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
-             label = component.label or ""
-             save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
-             save_dir.mkdir(exist_ok=True, parents=True)
-             deserialized = utils.simplify_file_data_in_str(
-                 component.flag(sample, save_dir)
-             )
-
-             # Add deserialized object to row
-             features[label] = {"dtype": "string", "_type": "Value"}
-             try:
-                 deserialized_path = Path(deserialized)
-                 if not deserialized_path.exists():
-                     raise FileNotFoundError(f"File {deserialized} not found")
-                 row.append(str(deserialized_path.relative_to(self.dataset_dir)))
-             except (FileNotFoundError, TypeError, ValueError, OSError):
-                 deserialized = "" if deserialized is None else str(deserialized)
-                 row.append(deserialized)
-
-             # If component is eligible for a preview, add the URL of the file
-             # Be mindful that images and audio can be None
-             if isinstance(component, tuple(file_preview_types)):  # type: ignore
-                 for _component, _type in file_preview_types.items():
-                     if isinstance(component, _component):
-                         features[label + " file"] = {"_type": _type}
-                         break
-                 if deserialized:
-                     path_in_repo = str(  # returned filepath is absolute, we want it relative to compute URL
-                         Path(deserialized).relative_to(self.dataset_dir)
-                     ).replace("\\", "/")
-                     row.append(
-                         huggingface_hub.hf_hub_url(
-                             repo_id=self.dataset_id,
-                             filename=path_in_repo,
-                             repo_type="dataset",
-                         )
-                     )
-                 else:
-                     row.append("")
-
-         timestamp = datetime.now(timezone(timedelta(hours=9))).isoformat()
-         features["flag"] = {"dtype": "string", "_type": "Value"}
-         features["username"] = {"dtype": "string", "_type": "Value"}
-         features["timestamp"] = {"dtype": "string", "_type": "Value"}
-         row.append(flag_option)
-         row.append(username)
-         row.append(timestamp)
-         return features, row
+ from .hf_dataset_saver import HuggingFaceDatasetSaver

  # Get environment variable
  hf_token = os.getenv('HF_TOKEN')
@@ -312,11 +40,10 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  hf_writer = HuggingFaceDatasetSaver(hf_token, "crowdsourced-sentiment_analysis")

  # Prepare model
- # TODO convert the model to ONNX
  tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", token=hf_token)
- model = AutoModelForSequenceClassification.from_pretrained("arcleife/roberta-sentiment-id", num_labels=3, token=hf_token).to(device)
+ model = ORTModelForSequenceClassification.from_pretrained("arcleife/roberta-sentiment-id-onnx", num_labels=3, token=hf_token).to(device)

- pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device, return_token_type_ids=False)
+ pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device, return_token_type_ids=False, accelerator="ort")

  def get_label(result):
      if result[0]['label'] == "LABEL_0":
 
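
For context on the hunk above: the new ORTModelForSequenceClassification line expects a checkpoint that has already been exported to ONNX ("arcleife/roberta-sentiment-id-onnx"). A minimal sketch of how such an export could be produced with optimum, assuming the original PyTorch repo named in the removed line and an output folder chosen here only for illustration; this step is not part of the commit itself:

# Hypothetical export sketch (not part of this commit): convert the PyTorch
# checkpoint to ONNX so ORTModelForSequenceClassification can load it directly.
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer

# export=True converts the transformers weights to ONNX while loading;
# a Hub token may be required if the source repo is private.
ort_model = ORTModelForSequenceClassification.from_pretrained(
    "arcleife/roberta-sentiment-id", export=True
)
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

# Save the exported model and tokenizer so they can be pushed as a separate
# model repo, e.g. the "arcleife/roberta-sentiment-id-onnx" repo referenced above.
ort_model.save_pretrained("roberta-sentiment-id-onnx")
tokenizer.save_pretrained("roberta-sentiment-id-onnx")

Loading a pre-exported repo at startup means the Space only has to create the ONNX Runtime session, which is presumably why the "convert the model to ONNX" TODO was dropped in this commit.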