zhoukz committed on
Commit
2c2b5be
·
unverified ·
1 Parent(s): 5b1f7b7

Warn when system prompt is modified

Browse files
Files changed (1) hide show
  1. processing_midashenglm.py +45 -1
processing_midashenglm.py CHANGED
@@ -1,10 +1,16 @@
 
 
1
  from typing import Dict, List, Optional, Union, cast
2
 
3
  import numpy as np
4
  import torch
5
  from transformers import Qwen2Tokenizer, Qwen2TokenizerFast, Wav2Vec2FeatureExtractor
6
  from transformers.feature_extraction_utils import BatchFeature
7
- from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
 
 
 
 
8
  from typing_extensions import Unpack
9
 
10
 
@@ -153,6 +159,44 @@ class MiDashengLMProcessor(ProcessorMixin):
153
  f"Expected audio to be a numpy array, torch tensor, or string, but got {type(sample)}."
154
  )
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def __call__(
157
  self,
158
  text: Optional[List[str]] = None,
 
1
+ import logging
2
+ from collections.abc import Mapping
3
  from typing import Dict, List, Optional, Union, cast
4
 
5
  import numpy as np
6
  import torch
7
  from transformers import Qwen2Tokenizer, Qwen2TokenizerFast, Wav2Vec2FeatureExtractor
8
  from transformers.feature_extraction_utils import BatchFeature
9
+ from transformers.processing_utils import (
10
+ AllKwargsForChatTemplate,
11
+ ProcessingKwargs,
12
+ ProcessorMixin,
13
+ )
14
  from typing_extensions import Unpack
15
 
16
 
 
159
  f"Expected audio to be a numpy array, torch tensor, or string, but got {type(sample)}."
160
  )
161
 
162
+ def apply_chat_template(
163
+ self,
164
+ conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
165
+ chat_template: Optional[str] = None,
166
+ **kwargs: Unpack[AllKwargsForChatTemplate],
167
+ ) -> str:
168
+ if conversation:
169
+ first_msgs = (
170
+ [conversation[0]]
171
+ if isinstance(conversation[0], Mapping)
172
+ else [conv[0] for conv in conversation if conv]
173
+ )
174
+ for first_msg in first_msgs:
175
+ if first_msg["role"] != "system":
176
+ continue
177
+ system_prompt: str
178
+ if isinstance(first_msg["content"], str):
179
+ system_prompt = first_msg["content"]
180
+ elif isinstance(first_msg["content"], list):
181
+ for part in first_msg["content"]:
182
+ if isinstance(part, dict) and "text" in part:
183
+ system_prompt = part["text"]
184
+ break
185
+ else:
186
+ continue
187
+ else:
188
+ continue
189
+ if system_prompt != (
190
+ "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
191
+ "capable of perceiving auditory and visual inputs, as well as generating text and speech."
192
+ ):
193
+ logging.warning(
194
+ "The system prompt has been modified, which may reduce model performance. "
195
+ "Prefer using the default system prompt by omitting the system role from the input."
196
+ )
197
+
198
+ return super().apply_chat_template(conversation, chat_template, **kwargs)
199
+
200
  def __call__(
201
  self,
202
  text: Optional[List[str]] = None,