TheOnlyHatem commited on
Commit
82696af
·
verified ·
1 Parent(s): 45da514

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -1,22 +1,29 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import MBartForConditionalGeneration, MBartTokenizer
4
 
5
- # Remplace par ton repo exact si besoin :
6
- model_name = "alice/mini/mBART_french_correction"
7
 
8
- # Chargement du tokenizer et du modèle
9
- tokenizer = MBartTokenizer.from_pretrained(model_name)
10
- model = MBartForConditionalGeneration.from_pretrained(model_name)
11
 
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  model.to(device)
14
 
15
  def correction_grammaticale(texte):
 
 
 
 
 
 
16
  # Tokenisation
17
- inputs = tokenizer(texte, return_tensors="pt", max_length=512, truncation=True).to(device)
18
 
19
- # Génération
20
  outputs = model.generate(
21
  **inputs,
22
  max_length=512,
@@ -24,15 +31,16 @@ def correction_grammaticale(texte):
24
  early_stopping=True
25
  )
26
 
27
- # Décodage
28
  correction = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
  return correction
30
 
 
31
  demo = gr.Interface(
32
  fn=correction_grammaticale,
33
  inputs=gr.Textbox(label="Texte à corriger"),
34
  outputs=gr.Textbox(label="Texte corrigé"),
35
- title="Correcteur MBART Français"
36
  )
37
 
38
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
+ # Modèle Hugging Face sélectionné
6
+ model_name = "PoloHuggingface/French_grammar_error_corrector"
7
 
8
+ # Chargement du modèle et du tokenizer
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
11
 
12
+ # Vérification GPU/CPU
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  model.to(device)
15
 
16
  def correction_grammaticale(texte):
17
+ """
18
+ Fonction qui envoie le texte au modèle T5 pour correction grammaticale.
19
+ """
20
+ # Préfixe facultatif (à tester si nécessaire)
21
+ input_text = texte
22
+
23
  # Tokenisation
24
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
25
 
26
+ # Génération du texte corrigé
27
  outputs = model.generate(
28
  **inputs,
29
  max_length=512,
 
31
  early_stopping=True
32
  )
33
 
34
+ # Décodage du texte corrigé
35
  correction = tokenizer.decode(outputs[0], skip_special_tokens=True)
36
  return correction
37
 
38
+ # Interface utilisateur Gradio
39
  demo = gr.Interface(
40
  fn=correction_grammaticale,
41
  inputs=gr.Textbox(label="Texte à corriger"),
42
  outputs=gr.Textbox(label="Texte corrigé"),
43
+ title="Correcteur Grammatical Français"
44
  )
45
 
46
  if __name__ == "__main__":