jhansss commited on
Commit
2195601
·
1 Parent(s): bbfbc08

Include model cleanup

Browse files
Files changed (1) hide show
  1. pipeline.py +15 -0
pipeline.py CHANGED
@@ -39,16 +39,31 @@ class SingingDialoguePipeline:
39
  self.evaluators = load_evaluators(config.get("evaluators", {}).get("svs", []))
40
 
41
  def set_asr_model(self, asr_model: str):
 
 
 
 
 
42
  self.asr = get_asr_model(
43
  asr_model, device=self.device, cache_dir=self.cache_dir
44
  )
45
 
46
  def set_llm_model(self, llm_model: str):
 
 
 
 
 
47
  self.llm = get_llm_model(
48
  llm_model, device=self.device, cache_dir=self.cache_dir
49
  )
50
 
51
  def set_svs_model(self, svs_model: str):
 
 
 
 
 
52
  self.svs = get_svs_model(
53
  svs_model, device=self.device, cache_dir=self.cache_dir
54
  )
 
39
  self.evaluators = load_evaluators(config.get("evaluators", {}).get("svs", []))
40
 
41
  def set_asr_model(self, asr_model: str):
42
+ if self.asr is not None:
43
+ del self.asr
44
+ import gc
45
+ gc.collect()
46
+ torch.cuda.empty_cache()
47
  self.asr = get_asr_model(
48
  asr_model, device=self.device, cache_dir=self.cache_dir
49
  )
50
 
51
  def set_llm_model(self, llm_model: str):
52
+ if self.llm is not None:
53
+ del self.llm
54
+ import gc
55
+ gc.collect()
56
+ torch.cuda.empty_cache()
57
  self.llm = get_llm_model(
58
  llm_model, device=self.device, cache_dir=self.cache_dir
59
  )
60
 
61
  def set_svs_model(self, svs_model: str):
62
+ if self.svs is not None:
63
+ del self.svs
64
+ import gc
65
+ gc.collect()
66
+ torch.cuda.empty_cache()
67
  self.svs = get_svs_model(
68
  svs_model, device=self.device, cache_dir=self.cache_dir
69
  )