romit
commited on
Fixed output cls
Browse files
torch-ext/mamba_ssm/utils/generation.py
CHANGED
|
@@ -11,7 +11,7 @@ import torch.nn.functional as F
|
|
| 11 |
from einops import rearrange, repeat
|
| 12 |
from torch import Tensor
|
| 13 |
from torch.profiler import ProfilerActivity, profile, record_function
|
| 14 |
-
from transformers.generation import
|
| 15 |
|
| 16 |
|
| 17 |
@dataclass
|
|
@@ -146,7 +146,7 @@ def decode(
|
|
| 146 |
max_length: int
|
| 147 |
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
|
| 148 |
logits, the next token is taken from the teacher_outputs. Useful for testing.
|
| 149 |
-
Returns:
|
| 150 |
sequences: (batch, max_length)
|
| 151 |
scores: tuples of (batch, vocab_size)
|
| 152 |
"""
|
|
@@ -240,7 +240,7 @@ def decode(
|
|
| 240 |
end.record()
|
| 241 |
torch.cuda.synchronize()
|
| 242 |
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
|
| 243 |
-
output_cls =
|
| 244 |
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
|
| 245 |
|
| 246 |
|
|
|
|
| 11 |
from einops import rearrange, repeat
|
| 12 |
from torch import Tensor
|
| 13 |
from torch.profiler import ProfilerActivity, profile, record_function
|
| 14 |
+
from transformers.generation import GenerateDecoderOnlyOutput, TextStreamer
|
| 15 |
|
| 16 |
|
| 17 |
@dataclass
|
|
|
|
| 146 |
max_length: int
|
| 147 |
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
|
| 148 |
logits, the next token is taken from the teacher_outputs. Useful for testing.
|
| 149 |
+
Returns: GenerateDecoderOnlyOutput, with the following fields:
|
| 150 |
sequences: (batch, max_length)
|
| 151 |
scores: tuples of (batch, vocab_size)
|
| 152 |
"""
|
|
|
|
| 240 |
end.record()
|
| 241 |
torch.cuda.synchronize()
|
| 242 |
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
|
| 243 |
+
output_cls = GenerateDecoderOnlyOutput
|
| 244 |
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
|
| 245 |
|
| 246 |
|