bokesyo wanderor committed on
Commit
df3d4c9
·
verified ·
1 Parent(s): 9a8db9d

Supports MacOS in MiniCPM-o 2.6 (#19)

Browse files

- Supports MacOS in MiniCPM-o 2.6 (1ceb0cbfa4dd6c40d2d504994a50afd210222039)


Co-authored-by: Richard Fang <[email protected]>

Files changed (1) hide show
  1. modeling_minicpmo.py +6 -6
modeling_minicpmo.py CHANGED
@@ -184,7 +184,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
184
  args=(),
185
  init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}},
186
  )
187
- vocos = Vocos(feature_extractor, backbone, head).to("cuda").eval().to(torch.float32)
188
  vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
189
  return vocos
190
 
@@ -1207,7 +1207,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
1207
 
1208
  terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
1209
  generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
1210
- input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda()
1211
 
1212
  spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
1213
  spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
@@ -1311,7 +1311,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
1311
  text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs
1312
  tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[
1313
  "input_ids"
1314
- ].cuda()
1315
  return tts_input_ids
1316
 
1317
  def _build_streaming_mask(self, tts_tokens_len):
@@ -1342,7 +1342,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
1342
  gen_text = text.split("<|tts_eos|>")[0]
1343
  tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
1344
  tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
1345
- tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long)
1346
  streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
1347
 
1348
  logits_warpers, logits_processors = gen_logits(
@@ -1639,7 +1639,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
1639
 
1640
  tts_input_ids = self.tts_processor.text_tokenizer(
1641
  tts_text, return_tensors="pt", add_special_tokens=False
1642
- )["input_ids"].cuda()
1643
  text_input_ids = tts_input_ids[:, begin:end]
1644
  streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
1645
  position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
@@ -1748,7 +1748,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
1748
  if end > begin:
1749
  tts_input_ids = self.tts_processor.text_tokenizer(
1750
  tts_text, return_tensors="pt", add_special_tokens=False
1751
- )["input_ids"].cuda()
1752
  text_input_ids = tts_input_ids[:, begin:end]
1753
  streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
1754
  position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
 
184
  args=(),
185
  init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}},
186
  )
187
+ vocos = Vocos(feature_extractor, backbone, head).to(self.device).eval().to(torch.float32)
188
  vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
189
  return vocos
190
 
 
1207
 
1208
  terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
1209
  generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
1210
+ input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(self.device)
1211
 
1212
  spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
1213
  spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
 
1311
  text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs
1312
  tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[
1313
  "input_ids"
1314
+ ].to(self.device)
1315
  return tts_input_ids
1316
 
1317
  def _build_streaming_mask(self, tts_tokens_len):
 
1342
  gen_text = text.split("<|tts_eos|>")[0]
1343
  tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
1344
  tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
1345
+ tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long)
1346
  streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
1347
 
1348
  logits_warpers, logits_processors = gen_logits(
 
1639
 
1640
  tts_input_ids = self.tts_processor.text_tokenizer(
1641
  tts_text, return_tensors="pt", add_special_tokens=False
1642
+ )["input_ids"].to(self.device)
1643
  text_input_ids = tts_input_ids[:, begin:end]
1644
  streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
1645
  position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
 
1748
  if end > begin:
1749
  tts_input_ids = self.tts_processor.text_tokenizer(
1750
  tts_text, return_tensors="pt", add_special_tokens=False
1751
+ )["input_ids"].to(self.device)
1752
  text_input_ids = tts_input_ids[:, begin:end]
1753
  streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
1754
  position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)