Warn when system prompt is modified
Browse files- processing_midashenglm.py +45 -1
processing_midashenglm.py
CHANGED
|
@@ -1,10 +1,16 @@
|
|
|
|
|
|
|
|
| 1 |
from typing import Dict, List, Optional, Union, cast
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
| 5 |
from transformers import Qwen2Tokenizer, Qwen2TokenizerFast, Wav2Vec2FeatureExtractor
|
| 6 |
from transformers.feature_extraction_utils import BatchFeature
|
| 7 |
-
from transformers.processing_utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from typing_extensions import Unpack
|
| 9 |
|
| 10 |
|
|
@@ -153,6 +159,44 @@ class MiDashengLMProcessor(ProcessorMixin):
|
|
| 153 |
f"Expected audio to be a numpy array, torch tensor, or string, but got {type(sample)}."
|
| 154 |
)
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
def __call__(
|
| 157 |
self,
|
| 158 |
text: Optional[List[str]] = None,
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from collections.abc import Mapping
|
| 3 |
from typing import Dict, List, Optional, Union, cast
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
from transformers import Qwen2Tokenizer, Qwen2TokenizerFast, Wav2Vec2FeatureExtractor
|
| 8 |
from transformers.feature_extraction_utils import BatchFeature
|
| 9 |
+
from transformers.processing_utils import (
|
| 10 |
+
AllKwargsForChatTemplate,
|
| 11 |
+
ProcessingKwargs,
|
| 12 |
+
ProcessorMixin,
|
| 13 |
+
)
|
| 14 |
from typing_extensions import Unpack
|
| 15 |
|
| 16 |
|
|
|
|
| 159 |
f"Expected audio to be a numpy array, torch tensor, or string, but got {type(sample)}."
|
| 160 |
)
|
| 161 |
|
| 162 |
+
def apply_chat_template(
    self,
    conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
    chat_template: Optional[str] = None,
    **kwargs: Unpack[AllKwargsForChatTemplate],
) -> str:
    """Apply the chat template, warning when the default system prompt was changed.

    Inspects the first message of each conversation; if it is a ``system``
    message whose text differs from the model's default system prompt, a
    warning is logged (the conversation itself is never modified). All
    formatting is then delegated to ``super().apply_chat_template``.

    Args:
        conversation: A single conversation (list of message dicts) or a
            batch of conversations (list of lists of message dicts).
        chat_template: Optional template string overriding the default.
        **kwargs: Forwarded verbatim to the parent implementation.

    Returns:
        The rendered chat template string from the parent implementation.
    """
    default_system_prompt = (
        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
        "capable of perceiving auditory and visual inputs, as well as generating text and speech."
    )
    if conversation:
        # A single conversation is a list of Mappings; a batch is a list of
        # such lists. Normalize to "first message of each conversation".
        first_msgs = (
            [conversation[0]]
            if isinstance(conversation[0], Mapping)
            else [conv[0] for conv in conversation if conv]
        )
        for first_msg in first_msgs:
            # Use .get so malformed messages (missing "role"/"content") are
            # left for the parent implementation to report, instead of
            # raising KeyError here.
            if first_msg.get("role") != "system":
                continue
            content = first_msg.get("content")
            system_prompt: Optional[str] = None
            if isinstance(content, str):
                system_prompt = content
            elif isinstance(content, list):
                # Multimodal content: compare against the first text part, if any.
                for part in content:
                    if isinstance(part, dict) and "text" in part:
                        system_prompt = part["text"]
                        break
            if system_prompt is not None and system_prompt != default_system_prompt:
                # Named logger (not the root logger) so downstream users can
                # filter/configure this warning per stdlib logging convention.
                logging.getLogger(__name__).warning(
                    "The system prompt has been modified, which may reduce model performance. "
                    "Prefer using the default system prompt by omitting the system role from the input."
                )

    return super().apply_chat_template(conversation, chat_template, **kwargs)
|
| 199 |
+
|
| 200 |
def __call__(
|
| 201 |
self,
|
| 202 |
text: Optional[List[str]] = None,
|