Wrap assistant messages inside "{% generation %}" markers in chat_template.jinja
Browse filesAdded "{% generation %}" markers enables the TRL [SFTTrainer](https://huggingface.co/docs/trl/main/en/sft_trainer#trl.SFTConfig)'s `assistant_only_loss` config option. `assistant_only_loss` tells the SFTTrainer to only enable gradients on the assistant messages, which are wrapped around `{% generation %}` by this PR. I confirmed that this behaves as expected by using this custom template for the `gpt-oss-20b` tokenizer as the `processing_class` for SFTTrainer.
See this transformers [PR](https://github.com/huggingface/transformers/pull/30650) that introduced this change
See also how [trl/trainer/sft_trainer.py](https://github.com/huggingface/trl/blob/206964ce16e15f2afd4f8f12fe49d1d828312f97/trl/trainer/sft_trainer.py#L845) uses this marker in [transformers/utils/chat_template_utils.py](https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/utils/chat_template_utils.py#L475).
Code segment to verify the masking is done correctly, where assistant tokens are printed in green:
```
tokenizer = AutoTokenizer.from_pretrained('openai/gpt-oss-20b', trust_remote_code=True)
tokenizer.chat_template = CORRECTED_JINJA_TEMPLATE
templated_output = tokenizer.apply_chat_template(
sample['messages'],
tokenize=True,
add_generation_prompt=False,
return_assistant_tokens_mask=True,
return_dict=True,
)
print("Visualizing token masks. Green text is used for loss calculation.\n")
GREEN = "\033[92m"
RESET = "\033[0m"
input_ids = templated_output['input_ids']
assistant_mask = templated_output['assistant_masks']
if len(input_ids) != len(assistant_mask):
raise ValueError("Mismatch between input_ids and assistant_masks length.")
current_chunk_tokens = []
current_mask_status = None
for token_id, is_assistant in zip(input_ids, assistant_mask):
mask_status = bool(is_assistant)
if current_mask_status is None:
current_mask_status = mask_status
if mask_status != current_mask_status:
# Decode and print the completed chunk
decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
if current_mask_status:
print(f"{GREEN}{decoded_text}{RESET}", end="")
else:
print(decoded_text, end="")
# Start a new chunk
current_chunk_tokens = [token_id]
current_mask_status = mask_status
else:
current_chunk_tokens.append(token_id)
# Print the final chunk after the loop
if current_chunk_tokens:
decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
if current_mask_status:
print(f"{GREEN}{decoded_text}{RESET}", end="")
else:
print(decoded_text, end="")
```
Prints something like:
```
<|start|>user<|message|>USER_MESSAGE<|end|>[GREEN_STARTS]<|start|>assistant<|channel|>analysis<|message|>...<|call|>[GREEN_ENDS]
```
- chat_template.jinja +24 -17
|
@@ -288,30 +288,37 @@
|
|
| 288 |
{%- endif %}
|
| 289 |
{%- if message.content and message.thinking %}
|
| 290 |
{{- raise_exception("Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }}
|
| 291 |
-
{%- elif message.content and not future_final_message.found %}
|
| 292 |
-
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
|
| 293 |
-
{%- elif message.thinking and not future_final_message.found %}
|
| 294 |
-
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
|
| 295 |
{%- endif %}
|
| 296 |
-
{
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
{%- set last_tool_call.name = tool_call.name %}
|
| 302 |
{%- elif loop.last and not add_generation_prompt %}
|
| 303 |
{#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}
|
| 304 |
{#- This is a situation that should only occur in training, never in inference. #}
|
| 305 |
-
{
|
| 306 |
-
{
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
|
|
|
|
|
|
| 312 |
{%- else %}
|
| 313 |
{#- CoT is dropped during all previous turns, so we never render it for inference #}
|
| 314 |
-
{
|
|
|
|
|
|
|
| 315 |
{%- set last_tool_call.name = none %}
|
| 316 |
{%- endif %}
|
| 317 |
{%- elif message.role == 'tool' -%}
|
|
|
|
| 288 |
{%- endif %}
|
| 289 |
{%- if message.content and message.thinking %}
|
| 290 |
{{- raise_exception("Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
{%- endif %}
|
| 292 |
+
{% generation %}
|
| 293 |
+
{%- if message.content and not future_final_message.found %}
|
| 294 |
+
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
|
| 295 |
+
{%- elif message.thinking and not future_final_message.found %}
|
| 296 |
+
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
|
| 297 |
+
{%- endif %}
|
| 298 |
+
{{- "<|start|>assistant to=" }}
|
| 299 |
+
{{- "functions." + tool_call.name + "<|channel|>commentary " }}
|
| 300 |
+
{{- (tool_call.content_type if tool_call.content_type is defined else "json") + "<|message|>" }}
|
| 301 |
+
{{- tool_call.arguments|tojson }}
|
| 302 |
+
{{- "<|call|>" }}
|
| 303 |
+
{% endgeneration %}
|
| 304 |
{%- set last_tool_call.name = tool_call.name %}
|
| 305 |
{%- elif loop.last and not add_generation_prompt %}
|
| 306 |
{#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}
|
| 307 |
{#- This is a situation that should only occur in training, never in inference. #}
|
| 308 |
+
{% generation %}
|
| 309 |
+
{%- if "thinking" in message %}
|
| 310 |
+
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
|
| 311 |
+
{%- endif %}
|
| 312 |
+
{#- <|return|> indicates the end of generation, but <|end|> does not #}
|
| 313 |
+
{#- <|return|> should never be an input to the model, but we include it as the final token #}
|
| 314 |
+
{#- when training, so the model learns to emit it. #}
|
| 315 |
+
{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|return|>" }}
|
| 316 |
+
{% endgeneration %}
|
| 317 |
{%- else %}
|
| 318 |
{#- CoT is dropped during all previous turns, so we never render it for inference #}
|
| 319 |
+
{% generation %}
|
| 320 |
+
{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
|
| 321 |
+
{% endgeneration %}
|
| 322 |
{%- set last_tool_call.name = none %}
|
| 323 |
{%- endif %}
|
| 324 |
{%- elif message.role == 'tool' -%}
|