# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union
import torch
from torch import nn
from dataclasses import dataclass
from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast
)
from transformers.models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel, Qwen3Model
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging, ModelOutput
from .configuration_qwen3_ts import Qwen3TSConfig
logger = logging.get_logger(__name__)
# ======================== MLP TS Embedding ========================
class TimeSeriesEmbedding(nn.Module):
def __init__(self, config):
        super().__init__()
self.patch_size = config['patch_size']
self.num_layers = config['num_layers']
self.hidden_size = config['hidden_size']
self.num_features = config['num_features']
self.max_sequence_length = config['max_sequence_length'] # Maximum time series length
self.use_position_embedding = config.get('use_position_embedding', False)
self.use_position_idx = config.get('use_position_idx', False)
self.use_layer_norm = config.get('use_layer_norm', False)
self.embedding_dim = config.get('embedding_dim', 16) # Embedding dimension
if self.use_position_embedding:
# Extended vocabulary: [0, max_sequence_length) for real positions, max_sequence_length for padding
self.position_embedding = nn.Embedding(self.max_sequence_length + 1, self.embedding_dim)
self.padding_idx = self.max_sequence_length # Special index for padding
input_size = 1 * self.patch_size + self.embedding_dim * self.patch_size
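            # Per patch, the MLP consumes patch_size raw values plus patch_size
            # flattened position embeddings of size embedding_dim.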
elif self.use_position_idx:
input_size = 2 * self.patch_size
else:
input_size = 1 * self.patch_size
# Build MLP layers
layers = []
for _ in range(self.num_layers - 1):
layers.append(nn.Linear(input_size, self.hidden_size))
layers.append(nn.GELU())
input_size = self.hidden_size
layers.append(nn.Linear(input_size, self.hidden_size))
if self.use_layer_norm:
layers.append(nn.LayerNorm(self.hidden_size))
self.mlp = nn.Sequential(*layers)
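        # Resulting stack, e.g. for num_layers=3:
        #   Linear(input_size, hidden) -> GELU -> Linear(hidden, hidden) -> GELU
        #   -> Linear(hidden, hidden) [-> LayerNorm(hidden) if use_layer_norm]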
def forward(self, x: torch.Tensor):
batch_size = x.size(0)
x = x.reshape(batch_size, -1, self.num_features)
# Extract mask and calculate valid lengths
mask = x[:, :, -1].long()
valid_lengths = mask.sum(dim=1).long()
patch_cnt = (valid_lengths + self.patch_size - 1) // self.patch_size
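        # patch_cnt[i] == ceil(valid_lengths[i] / patch_size); e.g. 100 valid steps
        # with patch_size=16 yield 7 patches (the last one is padded below).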
patches_list = []
# Collect position indices for batch embedding lookup
all_position_indices = []
patch_info_list = [] # Store metadata for each patch group
for i in range(batch_size):
vl = valid_lengths[i].item()
pc = patch_cnt[i].item()
if pc == 0:
continue
# Extract time series data (excluding mask)
xi = x[i, :vl, :1] # Time-series data
total_padded_length = pc * self.patch_size
padding_length = total_padded_length - vl
# Create position indices: real positions for actual data, special index for padding
position_indices = torch.arange(vl, device=x.device)
if padding_length > 0:
# Pad with last value
last_value = xi[-1:, :]
padding = last_value.repeat(padding_length, 1)
xi = torch.cat([xi, padding], dim=0)
# Use special padding index for padding positions
padding_positions = torch.full((padding_length,), self.padding_idx, device=x.device)
position_indices = torch.cat([position_indices, padding_positions], dim=0)
# Reshape to patches
xi = xi.reshape(pc, self.patch_size) # (num_patches, patch_size)
position_indices = position_indices.reshape(pc, self.patch_size) # (num_patches, patch_size)
if self.use_position_embedding:
# Collect position indices instead of calling embedding immediately
all_position_indices.append(position_indices)
patch_info_list.append({
'xi': xi,
'pc': pc,
'sample_idx': i
})
elif self.use_position_idx:
# Normalize position indices
pos_indices = torch.arange(vl, device=x.device).unsqueeze(1)
pos_indices = pos_indices / max(1, valid_lengths.max().item() - 1)
if padding_length > 0:
# Use -1 for padding positions
padding_indices = torch.full((padding_length, 1), -1, device=x.device)
pos_indices = torch.cat([pos_indices, padding_indices], dim=0)
# Combine time series data with position indices
xi_combined = torch.cat([xi.reshape(-1, 1), pos_indices], dim=1)
patch_input = xi_combined.reshape(pc, self.patch_size * 2)
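                # Each patch row interleaves (value, normalized position) pairs:
                # [v0, p0, v1, p1, ...], giving 2 * patch_size inputs per patch.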
patches_list.append(patch_input)
else:
# No position embedding, use raw patches
patch_input = xi
patches_list.append(patch_input)
# Batch process position embeddings if needed
if self.use_position_embedding and all_position_indices:
# Concatenate all position indices for batch embedding lookup
batch_position_indices = torch.cat(all_position_indices, dim=0)
# print(f"{x.shape=}, {x.device=}, {len(all_position_indices)=}, {batch_position_indices=}")
batch_pos_emb = self.position_embedding(batch_position_indices) # Single embedding call
# Split embeddings back and create patch inputs
emb_start_idx = 0
for patch_info in patch_info_list:
xi = patch_info['xi']
pc = patch_info['pc']
# Extract corresponding embeddings
pos_emb = batch_pos_emb[emb_start_idx:emb_start_idx + pc]
emb_start_idx += pc
# Flatten and concatenate
xi = xi.unsqueeze(-1) # (num_patches, patch_size, 1)
patch_input = torch.cat([
xi.flatten(1), # (num_patches, patch_size)
pos_emb.flatten(1) # (num_patches, patch_size * embedding_dim)
], dim=1)
patches_list.append(patch_input)
# Process all patches through MLP
if patches_list:
x_patches = torch.cat(patches_list, dim=0)
x = self.mlp(x_patches)
else:
# Handle empty case
x = torch.empty(0, self.hidden_size, device=x.device)
return x, patch_cnt
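# Minimal usage sketch of the TS encoder in isolation (illustrative only; the config
# keys mirror the `ts` section of Qwen3TSConfig, but the concrete values below are
# assumptions, not the shipped configuration):
#
#   ts_cfg = {"patch_size": 16, "num_layers": 2, "hidden_size": 4096,
#             "num_features": 2, "max_sequence_length": 2048}
#   encoder = TimeSeriesEmbedding(ts_cfg)
#   values = torch.randn(2, 256, 1)        # two univariate series of length 256
#   mask = torch.ones(2, 256, 1)           # last feature channel is the validity mask
#   x = torch.cat([values, mask], dim=-1).reshape(2, -1)
#   patch_embeds, patch_cnt = encoder(x)   # (32, 4096) patch embeddings, patch_cnt = [16, 16]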
@dataclass
class Qwen3TSCausalLMOutputWithPast(CausalLMOutputWithPast):
"""
Output type of Qwen3TSForCausalLM that includes additional fields for timeseries processing.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
The attention mask used in the forward pass, potentially expanded to accommodate timeseries patches.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss.
new_token_positions (`torch.LongTensor` of shape `(batch_size, num_new_tokens)`, *optional*):
Positions where new tokens (not from timeseries) are located in the expanded sequence.
"""
attention_mask: Optional[torch.FloatTensor] = None
labels: Optional[torch.LongTensor] = None
new_token_positions: Optional[torch.LongTensor] = None
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Qwen3TSGenerationMixin(GenerationMixin):
"""
Generation mixin for Qwen3 models with timeseries support.
This mixin handles the special case where timeseries embeddings expand the sequence length
during the first forward pass, requiring special attention mask management.
"""
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[Cache] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
timeseries: Optional[torch.FloatTensor] = None,
**kwargs,
) -> Dict[str, Any]:
"""
Prepare inputs for generation with timeseries support.
Timeseries are only processed during the first forward pass. In subsequent
generation steps, they are already embedded in the past_key_values.
"""
# Check if we have timeseries data
has_ts = timeseries is not None and len(timeseries) > 0
# Handle timeseries for generation with past_key_values
if has_ts and past_key_values is not None:
# Get the number of tokens already processed
            if isinstance(past_key_values, Cache):
                # Cache.get_seq_length() reports how many tokens are already cached
                past_length = past_key_values.get_seq_length()
else:
past_length = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
# If we have processed tokens, timeseries have already been embedded
if past_length > 0:
# Only keep the last token for next token prediction
input_ids = input_ids[:, -1:]
# Clear timeseries as they've been processed
timeseries = None
has_ts = False
# Call parent's prepare_inputs_for_generation to handle all standard logic
model_inputs = super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
**kwargs
)
# Add timeseries to model inputs
model_inputs["timeseries"] = timeseries
return model_inputs
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
) -> Dict[str, Any]:
"""
Update model keyword arguments for generation, handling attention mask from outputs.
This is necessary because timeseries processing can expand the sequence length,
and we need to use the expanded attention_mask from the model outputs.
"""
# Handle special case: use attention_mask from outputs if available
# This is crucial for timeseries models where the sequence length is expanded
if hasattr(outputs, "attention_mask") and outputs.attention_mask is not None:
model_kwargs["attention_mask"] = outputs.attention_mask
# Call parent's implementation for standard updates
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
num_new_tokens=num_new_tokens,
)
return model_kwargs
def generate(
self,
inputs: Optional[torch.Tensor] = None,
timeseries: Optional[torch.FloatTensor] = None,
generation_config=None,
**kwargs,
):
"""
Generate sequences with timeseries support.
Args:
inputs: Input token ids
timeseries: Optional timeseries data to be processed in the first forward pass
generation_config: Generation configuration
**kwargs: Additional keyword arguments for generation
"""
# Add timeseries to kwargs if provided
if timeseries is not None:
kwargs["timeseries"] = timeseries
# Call parent's generate method
return super().generate(
inputs=inputs,
generation_config=generation_config,
**kwargs,
)
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]) -> None:
"""
Validate model kwargs, allowing timeseries as a valid argument.
"""
# Remove timeseries from model_kwargs temporarily for validation
timeseries = model_kwargs.pop("timeseries", None)
# Call parent's validation
super()._validate_model_kwargs(model_kwargs)
# Restore timeseries
if timeseries is not None:
model_kwargs["timeseries"] = timeseries
class Qwen3TSPreTrainedModel(Qwen3PreTrainedModel):
config_class = Qwen3TSConfig
@auto_docstring
class Qwen3TSForCausalLM(Qwen3TSPreTrainedModel, Qwen3TSGenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
self.model = Qwen3Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# TS embedding
self.ts_encoder = TimeSeriesEmbedding(config.ts)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def _merge_input_ids_with_time_series_features(self, time_series_features, inputs_embeds, input_ids, attention_mask, labels, patch_cnt):
batch_size, sequence_length = input_ids.shape
_left_padding = torch.any(attention_mask[:, 0] == 0)
_right_padding = torch.any(attention_mask[:, -1] == 0)
left_padding = False
if batch_size > 1:
if _left_padding and not _right_padding:
left_padding = True
elif not _left_padding and _right_padding:
left_padding = False
elif not _left_padding and not _right_padding:
left_padding = False
else:
raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
else:
if _left_padding and not _right_padding:
left_padding = True
else:
left_padding = False
# 1. Create a mask to know where special time series tokens are
special_ts_token_mask_start = input_ids == self.config.ts_token_start_index
special_ts_token_mask_end = input_ids == self.config.ts_token_end_index
special_ts_token_mask = special_ts_token_mask_start | special_ts_token_mask_end
# 2. Calculate patch count
num_special_ts_tokens = torch.sum(special_ts_token_mask_start, dim=-1)
total_time_steps, embed_dim = time_series_features.shape
        # Number of additional positions each sample gains from its time series patches
patch_index = 0
num_total_patches = torch.zeros(batch_size, dtype=patch_cnt.dtype, device=patch_cnt.device)
special_ts_token_mask_start_nonzero = special_ts_token_mask_start.nonzero()
special_ts_token_mask_start_with_size = special_ts_token_mask_start.clone().long()
attn_mask_cnt = attention_mask.sum(dim=-1)
for i in range(batch_size):
num_ts_in_batch = num_special_ts_tokens[i]
num_total_patches[i] = patch_cnt[patch_index : patch_index + num_ts_in_batch].sum() - 2 * num_ts_in_batch
for idx in range(patch_index, patch_index + num_ts_in_batch):
b_idx, pos = special_ts_token_mask_start_nonzero[idx]
special_ts_token_mask_start_with_size[b_idx, pos] *= (patch_cnt[idx].item() - 2)
patch_index += num_ts_in_batch
attn_mask_cnt[i] += num_total_patches[i].item()
        # 3. Expanded embedding length
max_embed_dim = sequence_length + num_total_patches.max()
        # 4. Indices of non-TS tokens
batch_indices, non_ts_indices = torch.where(~special_ts_token_mask)
attn_batch_indices, attn_indices = torch.where(attention_mask == 1)
        # 5. Final positions of text tokens in the expanded sequence
new_token_positions = torch.cumsum((special_ts_token_mask_start_with_size + 1), dim=-1) - 1
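        # cumsum trick: each token advances the write cursor by one position, while a TS
        # start token additionally reserves (patch_cnt - 2) slots for patch embeddings.
        # E.g. with 4 reserved slots, [t0, t1, <ts>, </ts>, t2] maps to [0, 1, 6, 7, 8],
        # leaving slots 2-5 free for the patches.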
        # nb_ts_pad: number of unused (padding) slots per sample in the expanded sequence
nb_ts_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_ts_pad[:, None]
text_to_overwrite = new_token_positions[batch_indices, non_ts_indices]
# 6. Final embedding and attention masks
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device)
for i in range(attention_mask.size(0)):
if left_padding:
final_attention_mask[i, max_embed_dim - attn_mask_cnt[i] :] = 1
else:
final_attention_mask[i, : attn_mask_cnt[i]] = 1
final_labels = None
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
)
target_device = inputs_embeds.device
batch_indices, non_ts_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_ts_indices.to(target_device),
text_to_overwrite.to(target_device),
)
# 7. Move embedding and labels to final positions
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_ts_indices]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_ts_indices]
# 8. Move time series to final positions
ts_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
ts_to_overwrite[batch_indices, text_to_overwrite] = False
reversed_cumsum = ts_to_overwrite.flip(dims=[-1]).cumsum(-1).flip(dims=[-1]) - 1
ts_to_overwrite &= reversed_cumsum >= nb_ts_pad[:, None].to(target_device)
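        # reversed_cumsum[j] is the number of candidate slots strictly to the right of j
        # (for candidate slots); requiring at least nb_ts_pad of them drops the trailing
        # padding slots so that only the reserved patch positions remain True.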
# Check that the number of time series tokens is correct
if ts_to_overwrite.sum() != time_series_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of time series tokens is {torch.sum(special_ts_token_mask_start)} while"
f" the number of time series given to the model is {len(patch_cnt)}. This prevents correct indexing and breaks batch generation."
)
final_embedding[ts_to_overwrite] = time_series_features.contiguous().reshape(-1, embed_dim).to(target_device)
# 9. Calculate position ids
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
if position_ids.size(-1) < input_ids.size(-1):
position_ids = position_ids[:, -input_ids.size(-1) :]
# 10. Move attention mask to final positions
# print(f"{type(input_ids)=}, {input_ids.shape=}, {self.config.pad_token_id=}")
pad_batch_indices, pad_indices = torch.where(input_ids == self.config.pad_token_id)
if len(pad_batch_indices) > 0:
indices_to_mask = new_token_positions[pad_batch_indices, pad_indices]
final_embedding[pad_batch_indices, indices_to_mask] = 0
# 11. Post-process new_token_positions (set -1 for padding positions)
new_token_positions = new_token_positions.masked_fill(attention_mask == 0, -1)
return final_embedding, final_attention_mask, position_ids, final_labels, new_token_positions
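    # Merge summary: text embeddings keep their relative order, patch embeddings fill the
    # slots reserved next to the TS start tokens, and final_attention_mask covers the
    # original valid length plus the extra patch positions for each sample.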
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
        timeseries: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Qwen3TSCausalLMOutputWithPast:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the attention blocks).
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence.
            timeseries (`torch.FloatTensor`, *optional*):
                Flattened timeseries inputs of shape `(num_timeseries, sequence_length * num_features)` (or any shape
                that reshapes to `(num_timeseries, sequence_length, num_features)`), where the last feature channel is
                a 0/1 validity mask. They are encoded by the TS encoder and merged with the text embeddings.
Returns:
[`Qwen3TSCausalLMOutputWithPast`] or `tuple(torch.FloatTensor)`:
The model outputs with potential timeseries-expanded attention mask.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
        if timeseries is not None and timeseries.shape[0] > 0:
            ts_features, patch_cnt = self.ts_encoder(timeseries)
            inputs_embeds = inputs_embeds.to(ts_features.dtype)
            inputs_embeds, attention_mask, position_ids, labels, new_token_positions = self._merge_input_ids_with_time_series_features(
                ts_features, inputs_embeds, input_ids, attention_mask, labels, patch_cnt
            )
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
# print(f"{logits.shape=}, {hidden_states.shape=}, {slice_indices=}, {logits_to_keep=}, labels={(labels != self.config.ignore_index).sum() if labels is not None else None}, {loss=}, {torch.sum(torch.isnan(logits))=}, {torch.sum(torch.isinf(logits))=}, {torch.sum(torch.isnan(hidden_states))=}, {torch.sum(torch.isinf(hidden_states))=}")
return Qwen3TSCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
attention_mask=attention_mask,
labels=labels,
new_token_positions=new_token_positions if timeseries is not None and timeseries.shape[0] > 0 else None,
)
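# End-to-end usage sketch (illustrative; the prompt format and the preprocessing that
# builds the flattened `timeseries` tensor are assumptions -- in practice they come from
# the accompanying ChatTS processor/tokenizer):
#
#   model = Qwen3TSForCausalLM.from_pretrained("<model_path>", torch_dtype=torch.bfloat16)
#   inputs = tokenizer(prompt, return_tensors="pt")   # prompt contains the TS placeholder tokens
#   out = model.generate(
#       inputs=inputs["input_ids"],
#       attention_mask=inputs["attention_mask"],
#       timeseries=ts_tensor,                          # flattened (num_ts, length * num_features)
#       max_new_tokens=512,
#   )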
__all__ = [
"Qwen3TSForCausalLM"
]