import io
import os
import tempfile

import gradio as gr
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


# Checkpoint identifier shared by the tokenizer and the model below.
model_name = "arnir0/Tiny-LLM"

# Loaded once at import time; fine_tune_and_generate uses these module-level
# objects directly. The tuple is evaluated left to right, so the tokenizer is
# still loaded before the model.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)


def _read_training_lines(uploaded_file):
    """Return the stripped, non-blank lines of an uploaded UTF-8 text file.

    The value handed over by the upload widget is not assumed to have one
    fixed type. Accepted forms are: raw ``bytes``, a filesystem path
    (``str`` / ``os.PathLike``), an object whose ``name`` attribute is the
    path of an existing file, or any object with a ``read()`` method.

    Raises:
        ValueError: if ``uploaded_file`` is ``None``.
        TypeError: if ``uploaded_file`` is none of the supported forms.
        UnicodeDecodeError: if the content is not valid UTF-8.
    """
    if uploaded_file is None:
        raise ValueError("No training file was uploaded.")

    if isinstance(uploaded_file, (bytes, bytearray)):
        raw = bytes(uploaded_file)
    elif isinstance(uploaded_file, (str, os.PathLike)):
        with open(uploaded_file, "rb") as source:
            raw = source.read()
    else:
        backing_path = getattr(uploaded_file, "name", None)
        if isinstance(backing_path, str) and os.path.isfile(backing_path):
            # Re-open by path: the wrapper object itself may already be
            # closed or positioned at end-of-file.
            with open(backing_path, "rb") as source:
                raw = source.read()
        elif hasattr(uploaded_file, "read"):
            raw = uploaded_file.read()
            if isinstance(raw, str):
                raw = raw.encode("utf-8")
        else:
            raise TypeError(
                f"Unsupported upload type: {type(uploaded_file).__name__}"
            )

    # TextIOWrapper applies the same universal-newline splitting as a file
    # opened in text mode; "utf-8-sig" additionally drops a leading BOM so it
    # does not end up inside the first training line.
    with io.TextIOWrapper(io.BytesIO(raw), encoding="utf-8-sig") as reader:
        stripped = (line.strip() for line in reader)
        return [line for line in stripped if line]


def fine_tune_and_generate(uploaded_file, prompt):
    """Fine-tune the module-level model on an uploaded file, then generate text.

    Each non-blank line of the upload becomes one training example. The
    module-level ``model`` is trained in place for one epoch, so every call
    continues from the weights left by the previous call. After training,
    a sampled continuation of ``prompt`` is generated and returned.

    Args:
        uploaded_file: The uploaded training text; see
            ``_read_training_lines`` for the accepted forms.
        prompt: Text used to seed generation.

    Returns:
        The decoded generated text (special tokens removed).

    Raises:
        ValueError: if no file was uploaded or it has no non-blank lines.
        TypeError: if ``uploaded_file`` has an unsupported type.
    """
    lines = _read_training_lines(uploaded_file)
    if not lines:
        # Fail before any training work rather than deep inside the Trainer.
        raise ValueError("The uploaded file contains no non-blank lines to train on.")

    # padding="max_length" needs a pad token; fall back to EOS when the
    # checkpoint's tokenizer does not define one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    dataset = Dataset.from_dict({"text": lines})

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )

    # A causal LM only returns a loss when it receives labels. With mlm=False
    # the collator copies input_ids into labels and masks padding positions.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="./fine_tuned",
        num_train_epochs=1,
        per_device_train_batch_size=2,
        logging_dir="./logs",
        logging_steps=10,
        save_strategy="no",
        learning_rate=5e-5,
        weight_decay=0.01,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )
    trainer.train()

    # Leave training mode (disables dropout) before sampling.
    model.eval()

    # The Trainer may have moved the model to an accelerator; the prompt
    # tensors must live on the same device.
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(model.device)
    attention_mask = encoded.attention_mask.to(model.device)

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=50,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return generated_text


# Input widgets, listed in the positional order of fine_tune_and_generate's
# parameters (uploaded_file, prompt).
training_file_input = gr.File(label="Upload training text (.txt)")
prompt_input = gr.Textbox(
    lines=2,
    placeholder="Enter prompt for generation",
    label="Prompt",
)

# Web UI wiring: one upload plus one prompt in, generated text out.
iface = gr.Interface(
    fn=fine_tune_and_generate,
    inputs=[training_file_input, prompt_input],
    outputs="text",
    title="Tiny-LLM Fine-tune & Generate",
    description="Upload your text file to fine-tune Tiny-LLM and generate text from a prompt.",
)

# Launched at import time, as in the original script.
iface.launch()