# HuggingFace Spaces page header (scrape residue) — Space status: Sleeping
import torch
from transformers import AutoTokenizer
from palm_rlhf_pytorch import PaLM
import gradio as gr
def generate(prompt, seq_len=256, temperature=1.0, filter_thres=0.9, model=None):
    """Generate a text continuation for *prompt* with a PaLM language model.

    Args:
        prompt: Input text to continue.
        seq_len: Number of tokens to generate.
        temperature: Sampling temperature (higher = more random).
        filter_thres: Filtering threshold forwarded to ``PaLM.generate``.
        model: Optional pre-loaded ``PaLM`` instance. When ``None`` (the
            default, and what the Gradio UI provides), a 410M-parameter
            model is built and its checkpoint loaded from disk.

    Returns:
        The decoded generation, prompt included, as a single string.
    """
    device = torch.device("cpu")
    # BUG FIX: the original unconditionally rebuilt the network here,
    # shadowing the ``model`` argument; reuse a caller-supplied model.
    # Defaults on the other parameters also make the single-input
    # ``gr.Interface(inputs="text")`` call valid instead of raising
    # TypeError for missing arguments.
    if model is None:
        model = PaLM(
            num_tokens=50304,
            dim=1024,
            depth=24,
            dim_head=128,
            heads=8,
            flash_attn=False,
            qk_rmsnorm=False,
        ).to(device)
        # NOTE(review): reloading checkpoint weights on every request is
        # expensive; consider loading once at module import. Absolute path
        # assumed valid inside the Space container — confirm.
        model.load('/palm_410m_8k_v0.pt')
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    encoded_text = tokenizer(prompt, return_tensors="pt")
    output_tensor = model.generate(
        seq_len=seq_len,
        prompt=encoded_text["input_ids"].to(device),
        temperature=temperature,
        filter_thres=filter_thres,
        pad_value=0.0,
        eos_token=tokenizer.eos_token_id,
        return_seq_without_prompt=False,  # keep the prompt in the output
        use_tqdm=True,
    )
    decoded_output = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)
    return decoded_output[0]
# Minimal Gradio UI: one text box in, one text box out.
# NOTE(review): inputs="text" supplies only the ``prompt`` argument to
# ``generate``; its remaining parameters are not exposed in the UI —
# confirm they have defaults (otherwise every submission raises TypeError).
iface = gr.Interface(fn=generate, inputs="text", outputs="text")

if __name__ == "__main__":
    # Guard the launch so importing this module doesn't start a server.
    iface.launch()