Update ts_generation_mixin.py
Browse files- ts_generation_mixin.py +4 -1
ts_generation_mixin.py
CHANGED
|
@@ -28,6 +28,8 @@ class TSGenerationMixin(GenerationMixin):
|
|
| 28 |
streamer: Optional["BaseStreamer"] = None,
|
| 29 |
**model_kwargs,
|
| 30 |
) -> Union[GenerateNonBeamOutput, torch.Tensor]:
|
|
|
|
|
|
|
| 31 |
if len(input_ids.shape) == 2:
|
| 32 |
batch_size, cur_len = input_ids.shape
|
| 33 |
else:
|
|
@@ -169,6 +171,7 @@ class TSGenerationMixin(GenerationMixin):
|
|
| 169 |
if streamer is not None:
|
| 170 |
streamer.end()
|
| 171 |
|
|
|
|
| 172 |
if return_dict_in_generate:
|
| 173 |
if self.config.is_encoder_decoder:
|
| 174 |
return GenerateEncoderDecoderOutput(
|
|
@@ -192,7 +195,7 @@ class TSGenerationMixin(GenerationMixin):
|
|
| 192 |
past_key_values=model_kwargs.get("past_key_values"),
|
| 193 |
)
|
| 194 |
else:
|
| 195 |
-
return input_ids
|
| 196 |
|
| 197 |
def _update_model_kwargs_for_generation(
|
| 198 |
self,
|
|
|
|
| 28 |
streamer: Optional["BaseStreamer"] = None,
|
| 29 |
**model_kwargs,
|
| 30 |
) -> Union[GenerateNonBeamOutput, torch.Tensor]:
|
| 31 |
+
input_ids_origin_device = input_ids.device
|
| 32 |
+
input_ids = input_ids.to(self.device)
|
| 33 |
if len(input_ids.shape) == 2:
|
| 34 |
batch_size, cur_len = input_ids.shape
|
| 35 |
else:
|
|
|
|
| 171 |
if streamer is not None:
|
| 172 |
streamer.end()
|
| 173 |
|
| 174 |
+
input_ids.squeeze_(dim=-1).to(input_ids_origin_device)
|
| 175 |
if return_dict_in_generate:
|
| 176 |
if self.config.is_encoder_decoder:
|
| 177 |
return GenerateEncoderDecoderOutput(
|
|
|
|
| 195 |
past_key_values=model_kwargs.get("past_key_values"),
|
| 196 |
)
|
| 197 |
else:
|
| 198 |
+
return input_ids
|
| 199 |
|
| 200 |
def _update_model_kwargs_for_generation(
|
| 201 |
self,
|