amurienne commited on
Commit
12ea454
·
verified ·
1 Parent(s): 01f53fe

enabling zerogpu

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -21,6 +21,8 @@
21
  # THE SOFTWARE.
22
 
23
  import os
 
 
24
 
25
  import gradio as gr
26
 
@@ -32,6 +34,9 @@ from transformers import (
32
 
33
  from huggingface_hub import InferenceClient
34
 
 
 
 
35
  # CHAT MODEL
36
 
37
  class chat_engine_hf_api:
@@ -67,6 +72,7 @@ bw_tokenizer = AutoTokenizer.from_pretrained(bw_modelcard)
67
  bw_translation_pipeline = pipeline("translation", model=bw_model, tokenizer=bw_tokenizer, src_lang='br', tgt_lang='fr', max_length=400, device="cpu")
68
 
69
  # translation function
 
70
  def translate(text, forward: bool):
71
  if forward:
72
  return fw_translation_pipeline("traduis de français en breton: " + text)[0]['translation_text']
 
21
  # THE SOFTWARE.
22
 
23
  import os
24
+ import spaces
25
+ import torch
26
 
27
  import gradio as gr
28
 
 
34
 
35
  from huggingface_hub import InferenceClient
36
 
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ print(f"current device is: {device}")
39
+
40
  # CHAT MODEL
41
 
42
  class chat_engine_hf_api:
 
72
  bw_translation_pipeline = pipeline("translation", model=bw_model, tokenizer=bw_tokenizer, src_lang='br', tgt_lang='fr', max_length=400, device="cpu")
73
 
74
  # translation function
75
+ @spaces.GPU
76
  def translate(text, forward: bool):
77
  if forward:
78
  return fw_translation_pipeline("traduis de français en breton: " + text)[0]['translation_text']