---
library_name: transformers
license: apache-2.0
base_model:
- state-spaces/mamba2-130m
---


## How to Get Started with the Model

Use the code below to get started with the model.

```python
import torch
from transformers import AutoTokenizer, Mamba2ForCausalLM


if __name__ == "__main__":
    device = "cuda"
    model_id = "benchang1110/mamba2-130m-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = Mamba2ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=device)
    model.eval()

    with torch.no_grad():
        text = input("Input: ")
        inputs = tokenizer(text, return_tensors="pt").to(device)
        output = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
        print(tokenizer.decode(output[0], skip_special_tokens=True))
```
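The snippet above hard-codes `device = "cuda"`, which raises an error on machines without a GPU. A minimal sketch of a device fallback (not part of the original example) that keeps the rest of the code unchanged:

```python
import torch

# Fall back to CPU when no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
```

On CPU you may also want to drop `torch_dtype=torch.bfloat16` (or use `torch.float32`), since bfloat16 inference can be slow or unsupported on older CPUs.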

Conversion script: [mamba2hf.py](https://gist.github.com/Benchangatrul284/9e98bcde5a64d6b918d905511e09598b)