---
license: gemma
library_name: transformers
pipeline_tag: text-generation
datasets:
- Norod78/hebrew_lyrics_prompting
- Norod78/hebrew_lyrics_prompting_finetune
language:
- he
base_model:
- google/gemma-2-2b-it
---

# Silly Song Generator :)

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import torch

model_id = "Norod78/hebrew_lyrics-gemma2_2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(f"model.device = {model.device}")

# "Write me a song about a potato with social anxiety"
input_text = "כתוב לי שיר על תפוח אדמה עם חרדה חברתית"

# Build the prompt with the model's chat template and generate in one shot
input_template = tokenizer.apply_chat_template(
    [{"role": "user", "content": input_text}],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(input_template, return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    repetition_penalty=1.05,
    temperature=0.5,
    no_repeat_ngram_size=4,
    do_sample=True,
)
decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Replace the chat-template role markers with Hebrew labels ("User:" / "Model:")
result = decoded_output.replace("user\n", "משתמש:\n").replace("model\n", "\nמודל:\n")
print("result = ", result)

# Same prompt again, this time streaming tokens to stdout as they are generated
chat = [{"role": "user", "content": input_text}]
chat_with_template = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([chat_with_template], return_tensors="pt").to(model.device)

text_streamer = TextStreamer(tokenizer)
_ = model.generate(
    **inputs,
    streamer=text_streamer,
    max_new_tokens=256,
    repetition_penalty=1.1,
    temperature=0.6,
    top_p=0.4,
    top_k=40,
    do_sample=True,
)
```
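
As a lighter-weight alternative, the same checkpoint can also be loaded through the high-level `pipeline` API. This is a minimal sketch, not part of the original card, and assumes a recent `transformers` release whose text-generation pipeline accepts chat-style message lists:

```python
from transformers import pipeline
import torch

# Minimal sketch (assumption: not from the original card) using the
# high-level text-generation pipeline with chat-style messages.
generator = pipeline(
    "text-generation",
    model="Norod78/hebrew_lyrics-gemma2_2b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Same prompt as above: "Write me a song about a potato with social anxiety"
messages = [{"role": "user", "content": "כתוב לי שיר על תפוח אדמה עם חרדה חברתית"}]
outputs = generator(messages, max_new_tokens=256, do_sample=True, temperature=0.5, top_p=0.4)

# With chat input, generated_text holds the full conversation;
# the last turn is the model's reply.
print(outputs[0]["generated_text"][-1]["content"])
```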