Bram Vanroy committed
Commit b818293 · 1 Parent(s): f9cf5b5

use new st caching


Unfortunately no more hash_funcs :(
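
Background on the change: the legacy @st.cache decorator hashed function arguments and return values, so unhashable objects such as models, tokenizers, and tensors needed custom hash_funcs (and allow_output_mutation to skip output hashing). Its replacement, st.cache_resource, caches the returned object as a shared global resource and never hashes the return value, which is why the hash_funcs mapping can be dropped for get_resources. A minimal sketch of the new pattern (illustrative only; load_model is a hypothetical name, not from this repo):

import streamlit as st

@st.cache_resource(show_spinner=False)
def load_model(name: str):
    # The returned object is created once per unique `name` and shared
    # across sessions and reruns; it is never copied or hashed.
    ...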

Files changed (1):
  utils.py +5 -15
utils.py CHANGED
@@ -1,25 +1,16 @@
 from typing import Tuple
 
-import streamlit as st
-
-import torch
-from torch.quantization import quantize_dynamic
-from torch import nn, qint8, Tensor
-from torch.nn import Parameter
-from transformers import PreTrainedModel, PreTrainedTokenizer
 from optimum.bettertransformer import BetterTransformer
 from mbart_amr.constraints.constraints import AMRLogitsProcessor
 from mbart_amr.data.tokenization import AMRMBartTokenizer
+import streamlit as st
+import torch
+from torch.quantization import quantize_dynamic
+from torch import nn, qint8
 from transformers import MBartForConditionalGeneration
 
 
-st_hash_funcs = {PreTrainedModel: lambda model: model.name_or_path,
-                 PreTrainedTokenizer: lambda tokenizer: tokenizer.name_or_path,
-                 Parameter: lambda parameter: parameter.data,
-                 Tensor: lambda tensor: tensor.cpu()}
-
-
-@st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
+@st.cache_resource(show_spinner=False)
 def get_resources(multilingual: bool, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRMBartTokenizer, AMRLogitsProcessor]:
     """Get the relevant model, tokenizer and logits_processor. The loaded model depends on whether the multilingual
     model is requested, or not. If not, an English-only model is loaded. The model can be optionally quantized
@@ -51,7 +42,6 @@ def get_resources(multilingual: bool, quantize: bool = True, no_cuda: bool = Fal
     return model, tokenizer, logits_processor
 
 
-@st.cache(show_spinner=False, hash_funcs=st_hash_funcs)
 def translate(text: str, src_lang: str, model: MBartForConditionalGeneration, tokenizer: AMRMBartTokenizer, **gen_kwargs) -> str:
     """Translates a given text of a given source language with a given model and tokenizer. The generation is guided by
     potential keyword-arguments, which can include arguments such as max length, logits processors, etc.
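
Note that the @st.cache on translate() is removed without a replacement: as the commit message laments, the new caching API here offers no hash_funcs, so translate() is simply recomputed on every run. If caching were still wanted, one documented alternative (a hedged sketch, an assumption rather than the author's code) is st.cache_data, which excludes any argument whose parameter name starts with an underscore from the cache key, covering the unhashable model and tokenizer:

import streamlit as st

@st.cache_data(show_spinner=False)
def translate(text: str, src_lang: str, _model, _tokenizer, **gen_kwargs) -> str:
    # _model and _tokenizer are excluded from the cache key; only text,
    # src_lang, and any (hashable) gen_kwargs determine cache hits.
    ...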