Supports MacOS in MiniCPM-o 2.6 (#19)
Browse files
- Supports MacOS in MiniCPM-o 2.6 (1ceb0cbfa4dd6c40d2d504994a50afd210222039)
Co-authored-by: Richard Fang <[email protected]>
- modeling_minicpmo.py +6 -6
modeling_minicpmo.py
CHANGED
@@ -184,7 +184,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
184 |
args=(),
|
185 |
init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}},
|
186 |
)
|
187 |
-
vocos = Vocos(feature_extractor, backbone, head).to("cuda").eval().to(torch.float32)
|
188 |
vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
|
189 |
return vocos
|
190 |
|
@@ -1207,7 +1207,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
1207 |
|
1208 |
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
|
1209 |
generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
|
1210 |
-
input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda()
|
1211 |
|
1212 |
spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
|
1213 |
spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
|
@@ -1311,7 +1311,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
1311 |
text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs
|
1312 |
tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[
|
1313 |
"input_ids"
|
1314 |
-
].cuda()
|
1315 |
return tts_input_ids
|
1316 |
|
1317 |
def _build_streaming_mask(self, tts_tokens_len):
|
@@ -1342,7 +1342,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
1342 |
gen_text = text.split("<|tts_eos|>")[0]
|
1343 |
tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
|
1344 |
tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
|
1345 |
-
tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long)
|
1346 |
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
|
1347 |
|
1348 |
logits_warpers, logits_processors = gen_logits(
|
@@ -1639,7 +1639,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
1639 |
|
1640 |
tts_input_ids = self.tts_processor.text_tokenizer(
|
1641 |
tts_text, return_tensors="pt", add_special_tokens=False
|
1642 |
-
)["input_ids"].cuda()
|
1643 |
text_input_ids = tts_input_ids[:, begin:end]
|
1644 |
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
|
1645 |
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
|
@@ -1748,7 +1748,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
1748 |
if end > begin:
|
1749 |
tts_input_ids = self.tts_processor.text_tokenizer(
|
1750 |
tts_text, return_tensors="pt", add_special_tokens=False
|
1751 |
-
)["input_ids"].cuda()
|
1752 |
text_input_ids = tts_input_ids[:, begin:end]
|
1753 |
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
|
1754 |
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
|
|
|
184 |
args=(),
|
185 |
init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}},
|
186 |
)
|
187 |
+
vocos = Vocos(feature_extractor, backbone, head).to(self.device).eval().to(torch.float32)
|
188 |
vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
|
189 |
return vocos
|
190 |
|
|
|
1207 |
|
1208 |
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
|
1209 |
generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
|
1210 |
+
input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(self.device)
|
1211 |
|
1212 |
spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
|
1213 |
spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
|
|
|
1311 |
text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs
|
1312 |
tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[
|
1313 |
"input_ids"
|
1314 |
+
].to(self.device)
|
1315 |
return tts_input_ids
|
1316 |
|
1317 |
def _build_streaming_mask(self, tts_tokens_len):
|
|
|
1342 |
gen_text = text.split("<|tts_eos|>")[0]
|
1343 |
tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
|
1344 |
tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
|
1345 |
+
tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long)
|
1346 |
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
|
1347 |
|
1348 |
logits_warpers, logits_processors = gen_logits(
|
|
|
1639 |
|
1640 |
tts_input_ids = self.tts_processor.text_tokenizer(
|
1641 |
tts_text, return_tensors="pt", add_special_tokens=False
|
1642 |
+
)["input_ids"].to(self.device)
|
1643 |
text_input_ids = tts_input_ids[:, begin:end]
|
1644 |
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
|
1645 |
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
|
|
|
1748 |
if end > begin:
|
1749 |
tts_input_ids = self.tts_processor.text_tokenizer(
|
1750 |
tts_text, return_tensors="pt", add_special_tokens=False
|
1751 |
+
)["input_ids"].to(self.device)
|
1752 |
text_input_ids = tts_input_ids[:, begin:end]
|
1753 |
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
|
1754 |
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
|