arcleife commited on
Commit
de9d075
·
verified ·
1 Parent(s): b762b23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -17
app.py CHANGED
@@ -24,6 +24,7 @@ import uuid
24
  import filelock
25
  import csv
26
 
 
27
  class HuggingFaceDatasetSaver(FlaggingCallback):
28
  """
29
  A callback that saves each flagged sample (both the input and output data) to a HuggingFace dataset.
@@ -311,6 +312,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
311
  hf_writer = HuggingFaceDatasetSaver(hf_token, "crowdsourced-sentiment_analysis")
312
 
313
  # Prepare model
 
314
  tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", token=hf_token)
315
  model = AutoModelForSequenceClassification.from_pretrained("arcleife/roberta-sentiment-id", num_labels=3, token=hf_token).to(device)
316
 
@@ -343,20 +345,4 @@ io = gr.Interface(fn=text_classification,
343
  # flagging_callback=hf_writer
344
  )
345
 
346
- io.launch(inline=False)
347
-
348
- # with gr.Blocks() as main_interface:
349
- # gr.LoginButton()
350
-
351
- # gr.Markdown("# 人格否定検知")
352
- # gr.Markdown("**Input**にテキストを入力し、**実行**をクリックしてください。")
353
- # with gr.Row():
354
- # with gr.Column():
355
- # inp = gr.Textbox(placeholder="テキストを入力してください。", label="Input", lines=4)
356
- # with gr.Column():
357
- # out = gr.Label(label="Result")
358
- # flag = gr.Button("Flag")
359
- # btn = gr.Button("実行")
360
- # btn.click(fn=text_classification, inputs=inp, outputs=out)
361
-
362
- # main_interface.launch()
 
24
  import filelock
25
  import csv
26
 
27
+ # TODO move to separate file for cleaner code
28
  class HuggingFaceDatasetSaver(FlaggingCallback):
29
  """
30
  A callback that saves each flagged sample (both the input and output data) to a HuggingFace dataset.
 
312
  hf_writer = HuggingFaceDatasetSaver(hf_token, "crowdsourced-sentiment_analysis")
313
 
314
  # Prepare model
315
+ # TODO convert the model to ONNX
316
  tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", token=hf_token)
317
  model = AutoModelForSequenceClassification.from_pretrained("arcleife/roberta-sentiment-id", num_labels=3, token=hf_token).to(device)
318
 
 
345
  # flagging_callback=hf_writer
346
  )
347
 
348
+ io.launch(inline=False)