---
library_name: transformers
license: apache-2.0
base_model:
- state-spaces/mamba2-2.7b
---

## How to Get Started with the Model

Use the code below to get started with the model.
```python
import torch
from transformers import AutoTokenizer, Mamba2ForCausalLM

if __name__ == "__main__":
    device = "cuda"
    model_id = "benchang1110/mamba2-2.7b-hf"

    # Load the tokenizer and the converted checkpoint in bfloat16.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = Mamba2ForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map=device
    )
    model.eval()

    with torch.no_grad():
        text = input("Input: ")
        inputs = tokenizer(text, return_tensors="pt").to(device)
        # Greedy decoding, up to 1024 new tokens.
        output = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
        print(tokenizer.decode(output[0], skip_special_tokens=True))
```
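The example above uses greedy decoding (`do_sample=False`), which is deterministic but can loop or repeat on open-ended prompts. Below is a minimal sampling variant, reusing the `model` and `tokenizer` from the snippet above; the `temperature` and `top_p` values are illustrative defaults, not values tuned for this model.

```python
# Sketch: nucleus sampling instead of greedy decoding.
with torch.no_grad():
    inputs = tokenizer("Mamba is a state-space model that", return_tensors="pt").to(device)
    output = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,    # sample from the token distribution instead of taking the argmax
        temperature=0.8,   # <1 sharpens the distribution; illustrative value
        top_p=0.95,        # nucleus sampling: keep the smallest token set with mass >= 0.95
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```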

Conversion script (used to convert the original [state-spaces/mamba2-2.7b](https://huggingface.co/state-spaces/mamba2-2.7b) checkpoint to the Hugging Face format): [mamba2hf.py](https://gist.github.com/Benchangatrul284/9e98bcde5a64d6b918d905511e09598b)
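
For reference, conversion scripts of this kind typically follow the pattern sketched below: load the original state dict, rename parameters to the `transformers` naming scheme, and save the result with `save_pretrained`. This is a rough, hypothetical sketch only; the actual key mapping and config values live in the linked gist, and the paths and config numbers here are illustrative assumptions.

```python
# Hypothetical sketch of the convert-to-HF pattern; see mamba2hf.py for the real mapping.
import torch
from transformers import Mamba2Config, Mamba2ForCausalLM

# 1. Load the original state-spaces checkpoint (path is illustrative).
original_sd = torch.load("mamba2-2.7b/pytorch_model.bin", map_location="cpu")

# 2. Build a Mamba2Config whose sizes match the original model
#    (the values below are assumptions, not the verified 2.7b config).
config = Mamba2Config(hidden_size=2560, num_hidden_layers=64)

# 3. Rename parameters to the transformers naming scheme; the identity mapping
#    below is a placeholder for the real rename rules in the gist.
converted_sd = {name: tensor for name, tensor in original_sd.items()}

# 4. Load the remapped weights and save in Hugging Face format.
model = Mamba2ForCausalLM(config)
model.load_state_dict(converted_sd, strict=False)  # strict=False while iterating on the mapping
model.save_pretrained("mamba2-2.7b-hf")
```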