File size: 2,243 Bytes
8e8cd3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from network_wrapper import network_wrapper

class ClearVoice:
    """Main user-facing interface for running speech-processing models.

    An instance loads one model per requested model name for a single task.
    Calling the instance runs every loaded model on the given input, and
    ``write`` asks every loaded model to write its output.
    """

    def __init__(self, task, model_names):
        """Load the desired models for the specified task.

        Parameters
        ----------
        task: str
            the task matching any of the provided tasks:
            'speech_enhancement'
            'speech_separation'
            'target_speaker_extraction'
        model_names: str or list of str
            the model names matching any of the provided models:
            'FRCRN_SE_16K'
            'MossFormer2_SE_48K'
            'MossFormerGAN_SE_16K'
            'MossFormer2_SS_16K'
            'AV_MossFormer2_TSE_16K'
            A single name passed as a plain string is treated as a
            one-element list.
        """
        self.network_wrapper = network_wrapper()
        self.models = [
            self.network_wrapper(task, model_name)
            for model_name in self._as_name_list(model_names)
        ]

    @staticmethod
    def _as_name_list(model_names):
        """Return ``model_names`` as a list, wrapping a bare string."""
        # Iterating a str yields single characters, so a lone model name must
        # be wrapped instead of being looped over directly.
        if isinstance(model_names, str):
            return [model_names]
        return list(model_names)

    def __call__(self, input_path, online_write=False, output_path=None):
        """Run every loaded model on ``input_path``.

        All three arguments are forwarded unchanged to each model's
        ``process`` method.

        Parameters
        ----------
        input_path:
            the input handed to each model.
        online_write: bool
            when true, the values returned by ``process`` are discarded and
            this call returns None (presumably the models write their own
            output in that mode -- confirm against the model implementation).
        output_path:
            the output location handed to each model.

        Returns
        -------
        None when ``online_write`` is true. Otherwise the single result when
        exactly one result was collected, else a dict mapping each model's
        ``name`` to its result.
        """
        results = {}
        for model in self.models:
            result = model.process(input_path, online_write, output_path)
            if not online_write:
                results[model.name] = result

        if online_write:
            return None
        if len(results) == 1:
            # Exactly one entry: hand back the bare result instead of a dict.
            return next(iter(results.values()))
        return results

    def write(self, results, output_path):
        """Ask every loaded model to write its output under ``output_path``.

        ``results`` is only inspected to derive the ``use_key`` flag; the
        models' ``write`` methods receive ``output_path`` and the two flags,
        not ``results`` itself.

        Parameters
        ----------
        results:
            the value previously returned by calling this object.
        output_path:
            the output location forwarded to each model's ``write``.
        """
        # More than one model: ask the models to use a sub-directory.
        add_subdir = len(self.models) > 1

        # Only the first model is consulted when deciding use_key.
        use_key = False
        if self.models and isinstance(results, dict):
            first_name = self.models[0].name
            if first_name in results:
                # results is keyed by model name; the entry for the first
                # model holds more than one item (presumably one per input
                # -- confirm against the model's process()).
                use_key = len(results[first_name]) > 1
            else:
                # results is not keyed by model name: multi-input case.
                use_key = len(results) > 1

        for model in self.models:
            model.write(output_path, add_subdir, use_key)