xbarusui committed on
Commit
1d87d9c
·
verified ·
1 Parent(s): a00e9dd

first edit app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -4
app.py CHANGED
@@ -1,7 +1,95 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import gradio as gr
4
+ import matplotlib.pyplot as plt
5
+ import io
6
+ import re
7
+ import os
8
+ from datetime import datetime
9
+ import spaces
10
 
11
@spaces.GPU
def load_model():
    """Load the moderation tokenizer and regression model.

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``device`` is CUDA when
        ``torch.cuda.is_available()`` reports it, otherwise CPU, and the
        model has already been moved onto it.
    """
    # NOTE(review): this GPU-decorated loader is called at import time;
    # confirm that is the intended usage for the `spaces` runtime.
    checkpoint = "oshizo/japanese-sexual-moderation-v2"
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")

    tok = AutoTokenizer.from_pretrained(checkpoint)
    net = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        problem_type="regression",
    )
    net = net.to(target)
    return tok, net, target
22
 
23
@spaces.GPU
def analyze_text(text, tokenizer, model, device):
    """Score a single piece of text with the regression model.

    Args:
        text: The string to score.
        tokenizer: Tokenizer called with the text wrapped in a one-item list.
        model: Model whose output exposes a single-element ``logits`` tensor.
        device: Device the encoded tensors are moved to before inference.

    Returns:
        The model's single logit converted to a Python scalar.
    """
    # Input is padded/truncated to exactly 64 tokens.
    batch = tokenizer(
        [text],
        padding='max_length',
        truncation=True,
        max_length=64,
        return_tensors="pt",
    )
    on_device = {name: tensor.to(device) for name, tensor in batch.items()}
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output = model(**on_device)
    return output.logits.item()
30
+
31
# Sentence terminators: the Japanese full stop plus ASCII and full-width
# exclamation/question marks. A character class is used so that "?" is a
# literal rather than a regex quantifier.
_SENTENCE_END = re.compile(r"[。!?!?]")


def split_text(text, split_by='sentence'):
    """Split *text* into non-empty, whitespace-stripped segments.

    This is pure string processing, so it is deliberately not wrapped in a
    GPU decorator.

    Args:
        text: The input string.
        split_by: ``'sentence'`` splits on sentence terminators; any other
            value splits on newlines.

    Returns:
        A list of stripped segments with empty ones removed.
    """
    if split_by == 'sentence':
        pieces = _SENTENCE_END.split(text)
    else:  # split by line
        pieces = text.split('\n')
    stripped = (piece.strip() for piece in pieces)
    return [piece for piece in stripped if piece]
37
+
38
def create_graph(texts, scores):
    """Build a bar chart with one bar per score.

    Plotting runs on the CPU, so this function is not wrapped in a GPU
    decorator.

    Args:
        texts: The analysed segments. Unused here; kept so the signature
            stays compatible with existing callers.
        scores: Sequence of numeric scores, one per segment, in order.

    Returns:
        The matplotlib figure containing the chart.
    """
    positions = range(len(scores))
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(positions, scores)
    ax.set_xlabel('テキスト番号')
    ax.set_ylabel('スコア')
    ax.set_title("分析結果")
    ax.set_xticks(positions)
    # Tick labels are 1-based segment numbers.
    ax.set_xticklabels(range(1, len(scores) + 1))
    fig.tight_layout()
    # Remove the figure from pyplot's global registry so repeated requests do
    # not accumulate open figures; the Figure object itself stays usable.
    plt.close(fig)
    return fig
49
+
50
def create_not_r18_text(texts, scores, threshold=0.4):
    """Join the segments whose score is below *threshold*.

    This is pure list/string processing, so it is not wrapped in a GPU
    decorator.

    Args:
        texts: Sequence of text segments.
        scores: Sequence of numeric scores paired positionally with *texts*.
        threshold: Segments scoring strictly below this value are kept.
            Defaults to 0.4.

    Returns:
        A newline-joined string. Each excluded segment is replaced by an
        empty line so the remaining segments keep their positions.
    """
    kept = [
        text if score < threshold else ''
        for text, score in zip(texts, scores)
    ]
    return '\n'.join(kept)
59
+
60
# Load the tokenizer and model once at import time; process_text reads these
# module-level names on every request.
tokenizer, model, device = load_model()
61
+
62
@spaces.GPU
def process_text(text, split_by):
    """Split the input, score each segment, and assemble the three outputs.

    Args:
        text: Raw input text.
        split_by: Splitting mode forwarded to ``split_text``.

    Returns:
        A tuple ``(result, graph, not_r18_text)``: a dict with the
        ``"texts"`` and ``"scores"`` lists, the value returned by
        ``create_graph``, and the string returned by ``create_not_r18_text``.
    """
    segments = split_text(text, split_by)

    # Score every segment with the module-level tokenizer/model/device.
    segment_scores = []
    for segment in segments:
        segment_scores.append(analyze_text(segment, tokenizer, model, device))

    filtered = create_not_r18_text(segments, segment_scores)
    figure = create_graph(segments, segment_scores)

    return {"texts": segments, "scores": segment_scores}, figure, filtered
76
+
77
# Gradio interface definition: a text box and a split-mode selector feed
# process_text, whose three return values fill the JSON, plot, and text outputs.
iface = gr.Interface(
    fn=process_text,
    inputs=[
        gr.Textbox(label="テキスト入力"),
        gr.Radio(["sentence", "line"], label="分割方法", value="sentence")
    ],
    outputs=[
        gr.JSON(label="分析結果"),
        gr.Plot(label="スコアグラフ"),
        gr.Textbox(label="R18判定除外テキスト")
    ],
    title="テキスト分析API",
    description="テキストを入力し、R18判定と分析を行います。"
)
92
+
93
# Start the server
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)