File size: 912 Bytes
da51fe7
2074891
 
 
 
 
576bca9
2074891
da51fe7
 
 
 
 
2074891
da51fe7
 
96937e3
da51fe7
 
 
 
96937e3
 
 
2074891
 
da51fe7
96937e3
 
da51fe7
96937e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from transformers import AutoProcessor, AutoModel
import torch

# This handler is only supported on a CUDA GPU: refuse to import on any
# other host instead of quietly running on CPU.
if not torch.cuda.is_available():
    raise ValueError("need to run on GPU")

# Device that the model and every request's input tensors are moved to.
DEVICE = torch.device("cuda")

class EndpointHandler:
    """Inference-endpoint handler that synthesizes speech from text.

    The processor and model are loaded from ``path`` via the ``transformers``
    Auto* classes, and the model is moved to the module-level ``DEVICE``.
    (The default voice preset name suggests a Bark checkpoint — confirm
    against the deployed model.)
    """

    # Voice preset used when a request does not supply one. This is the value
    # the handler always used before the preset became configurable.
    DEFAULT_VOICE_PRESET = "v2/en_speaker_6"

    def __init__(self, path=""):
        """Load the processor and model from ``path`` and move the model to ``DEVICE``.

        Args:
            path: Model directory or hub id forwarded to ``from_pretrained``.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)
        self.model.to(DEVICE)

    def __call__(self, data) -> dict:
        """Generate audio for the text in ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` is the text handed to the
                processor. An optional ``data["parameters"]["voice_preset"]``
                overrides ``DEFAULT_VOICE_PRESET``.

        Returns:
            A dict with ``"audio"`` (the output of ``model.generate`` converted
            to nested Python lists) and ``"sample_rate"`` (read from the
            model's generation config).

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` entry.
        """
        input_text = data["inputs"]
        # ``parameters`` is optional; a missing or null value means "use defaults".
        parameters = data.get("parameters") or {}
        voice_preset = parameters.get("voice_preset", self.DEFAULT_VOICE_PRESET)

        inputs = self.processor(
            text=input_text,
            return_tensors="pt",
            voice_preset=voice_preset,
        ).to(DEVICE)

        # do_sample=True makes generation stochastic, so the same text may
        # produce different audio on different calls.
        speech_values = self.model.generate(**inputs, do_sample=True)
        sample_rate = self.model.generation_config.sample_rate

        # Copy the result to host memory explicitly, then convert it to plain
        # Python lists so the response can be JSON-serialized.
        return {"audio": speech_values.cpu().tolist(), "sample_rate": sample_rate}