Mamba2-hf
Collection: HF-compatible format of state-spaces/mamba2 • 5 items
Use the code below to get started with the model.
```python
import torch
from transformers import AutoTokenizer, Mamba2ForCausalLM

if __name__ == "__main__":
    device = "cuda"
    model_id = "benchang1110/mamba2-130m-hf"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = Mamba2ForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map=device
    )
    model.eval()

    with torch.no_grad():
        text = input("Input: ")
        input_ids = tokenizer(text, return_tensors="pt").to(device)
        output = model.generate(**input_ids, max_new_tokens=1024, do_sample=False)
        print(tokenizer.decode(output[0], skip_special_tokens=True))
```
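For context, `do_sample=False` makes `generate` use greedy decoding: at each step it picks the single highest-scoring token rather than sampling from the distribution. A minimal, dependency-free sketch of that selection rule (the function name and toy scores below are illustrative, not part of the model's API):

```python
def greedy_next_token(logits):
    # Greedy decoding step: return the index (token id) with the highest score.
    return max(range(len(logits)), key=lambda i: logits[i])

# Toy logits over a 5-token vocabulary; index 3 has the highest score.
logits = [0.1, -1.2, 0.3, 2.5, 0.0]
print(greedy_next_token(logits))  # 3
```

Greedy decoding is deterministic, which is convenient for reproducible demos like the one above; for more varied outputs you would pass `do_sample=True` (optionally with `temperature` or `top_p`) to `generate`.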
Conversion script: mamba2hf.py
Base model: state-spaces/mamba2-130m