Spaces: Running
Bram Vanroy committed
Commit · b818293
Parent(s): f9cf5b5
use new st caching
Unfortunately no more hash_funcs :(
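For context: Streamlit 1.18 deprecated st.cache in favor of st.cache_data and st.cache_resource, and the new decorators did not take hash_funcs at the time. st.cache_resource is the intended home for models and tokenizers, since it hands back the same object on every rerun without hashing or copying it; only the function's plain arguments need to be hashable. A minimal sketch of the pattern, with FakeModel and get_model as illustrative stand-ins rather than code from this repo:

    import streamlit as st

    # Stand-in for an expensive, unserializable object (in the real app:
    # an MBart model, tokenizer and logits processor).
    class FakeModel:
        def __init__(self, name: str):
            self.name = name

    # st.cache_resource stores the returned object once per unique set of
    # (hashable) arguments and returns the same instance on every rerun,
    # so hash_funcs and allow_output_mutation are no longer needed: only
    # the plain str/bool arguments are hashed, never the model itself.
    @st.cache_resource(show_spinner=False)
    def get_model(name: str, quantize: bool = True) -> FakeModel:
        return FakeModel(name)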
utils.py CHANGED

@@ -1,25 +1,16 @@
 from typing import Tuple
 
-import streamlit as st
-
-import torch
-from torch.quantization import quantize_dynamic
-from torch import nn, qint8, Tensor
-from torch.nn import Parameter
-from transformers import PreTrainedModel, PreTrainedTokenizer
 from optimum.bettertransformer import BetterTransformer
 from mbart_amr.constraints.constraints import AMRLogitsProcessor
 from mbart_amr.data.tokenization import AMRMBartTokenizer
+import streamlit as st
+import torch
+from torch.quantization import quantize_dynamic
+from torch import nn, qint8
 from transformers import MBartForConditionalGeneration
 
 
-st_hash_funcs = {PreTrainedModel: lambda model: model.name_or_path,
-                 PreTrainedTokenizer: lambda tokenizer: tokenizer.name_or_path,
-                 Parameter: lambda parameter: parameter.data,
-                 Tensor: lambda tensor: tensor.cpu()}
-
-
-@st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
+@st.cache_resource(show_spinner=False)
 def get_resources(multilingual: bool, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRMBartTokenizer, AMRLogitsProcessor]:
     """Get the relevant model, tokenizer and logits_processor. The loaded model depends on whether the multilingual
     model is requested, or not. If not, an English-only model is loaded. The model can be optionally quantized
@@ -51,7 +42,6 @@ def get_resources(multilingual: bool, quantize: bool = True, no_cuda: bool = Fal
     return model, tokenizer, logits_processor
 
 
-@st.cache(show_spinner=False, hash_funcs=st_hash_funcs)
 def translate(text: str, src_lang: str, model: MBartForConditionalGeneration, tokenizer: AMRMBartTokenizer, **gen_kwargs) -> str:
     """Translates a given text of a given source language with a given model and tokenizer. The generation is guided by
     potential keyword-arguments, which can include arguments such as max length, logits processors, etc.
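The quantize flag on get_resources lines up with the quantize_dynamic and qint8 imports this commit keeps. The diff does not show the function body, so the following is only a sketch of what dynamic quantization typically looks like, using a toy model in place of the real MBartForConditionalGeneration:

    import torch
    from torch import nn, qint8
    from torch.quantization import quantize_dynamic

    # Dynamic quantization converts the weights of the listed module types
    # (here: nn.Linear) to int8 and quantizes activations on the fly,
    # shrinking the model and speeding up CPU inference.
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    quantized = quantize_dynamic(model, {nn.Linear}, dtype=qint8)

    with torch.no_grad():
        print(quantized(torch.randn(1, 16)).shape)  # torch.Size([1, 4])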
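Note that translate simply loses its cache in this commit: without hash_funcs, Streamlit cannot hash the model and tokenizer arguments. One possible way to restore caching under the new API (not what this commit does) is st.cache_data with underscore-prefixed parameters, which Streamlit excludes from the cache key:

    import streamlit as st

    @st.cache_data(show_spinner=False)
    def translate(text: str, src_lang: str, _model, _tokenizer) -> str:
        # Parameters whose names start with "_" are skipped when Streamlit
        # builds the cache key, so the unhashable model/tokenizer pass
        # through while the key depends only on text and src_lang.
        # (Real code would call _model.generate(...) here.)
        return f"{src_lang}: {text}"

The trade-off is that the cache key then ignores which model produced the result, which is only safe while the model is fixed for the session; that may be why the decorator was dropped outright.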