Usage:
import argparse

import torch
import transformers
import pyreft

parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cuda")
args = parser.parse_args()


def generate_response():
    """
    Simple interactive test loop for the REFT model.
    """
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_fast=False
    )
    streamer = transformers.TextStreamer(tokenizer, skip_prompt=True)
    reft_model = pyreft.ReftModel.load("benchang1110/Tinyllama-1.1B-Chat-REFT-v1.0", model)
    reft_model.set_device(device)

    while True:
        prompt = input("USER: ")
        if prompt == "exit":
            break
        print("Assistant: ")
        messages = [
            {"content": prompt, "role": "user"},
        ]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(prompt)
        # Move the tokenized prompt to the same device as the model.
        prompt = tokenizer(prompt, return_tensors="pt").to(device)

        # The intervention locations must match the ones used during training:
        # the first_n and last_n prompt positions on every hidden layer.
        base_unit_location = prompt["input_ids"].shape[-1] - 1  # last prompt position
        first_n = 8
        last_n = 8
        LAYER = list(range(model.config.num_hidden_layers))
        base_unit_locations = [
            [[i for i in range(first_n)] + [base_unit_location - i for i in range(last_n)]]
        ] * len(LAYER)

        _, reft_response = reft_model.generate(
            prompt,
            unit_locations={"sources->base": (None, base_unit_locations)},
            intervene_on_prompt=True,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.3,
            repetition_penalty=1.1,
            streamer=streamer,
        )


if __name__ == "__main__":
    device = args.device if torch.cuda.is_available() else "cpu"
    generate_response()
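If the script above is saved as, say, generate.py (the filename is arbitrary), it can be started with:

python generate.py --device cuda

The loop then reads prompts from stdin and streams the model's reply; type "exit" at the USER: prompt to quit.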