"""Gradio demo: fine-tune Tiny-LLM on an uploaded text file, then generate text.

Each request trains a private copy of the base model for one epoch on the
non-empty lines of the uploaded file and samples a continuation of the prompt.
"""

import copy
import os
import tempfile

import gradio as gr
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Load base tokenizer and model
model_name = "arnir0/Tiny-LLM"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Batching pads sequences to a common length, which requires a pad token.
# Fall back to the end-of-sequence token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def _read_training_lines(uploaded_file):
    """Return the stripped, non-empty lines of an uploaded text file.

    Accepts the value shapes ``gr.File`` may deliver: raw bytes, a filesystem
    path, an object exposing the path as ``.name``, or a readable file object.

    Raises:
        UnicodeDecodeError: If the content is not valid UTF-8.
    """
    if isinstance(uploaded_file, (bytes, bytearray)):
        raw = bytes(uploaded_file)
    elif isinstance(uploaded_file, (str, os.PathLike)):
        with open(uploaded_file, "rb") as handle:
            raw = handle.read()
    elif isinstance(getattr(uploaded_file, "name", None), str):
        # Re-open by path: the wrapper object may already be closed or
        # positioned at end-of-file.
        with open(uploaded_file.name, "rb") as handle:
            raw = handle.read()
    else:
        raw = uploaded_file.read()

    # "utf-8-sig" also drops a leading byte-order mark, which would otherwise
    # end up glued to the first training line.
    text = raw if isinstance(raw, str) else raw.decode("utf-8-sig")
    return [line.strip() for line in text.splitlines() if line.strip()]


def fine_tune_and_generate(uploaded_file, prompt):
    """Fine-tune a copy of the base model on the upload and complete *prompt*.

    Args:
        uploaded_file: Value produced by the ``gr.File`` input (path, file
            wrapper, or bytes) holding UTF-8 text, one training example per
            line.
        prompt: Text the fine-tuned model should continue.

    Returns:
        The decoded generation (prompt included) with special tokens removed.

    Raises:
        gr.Error: If no file was uploaded, the file is not UTF-8 text, or it
            contains no non-empty lines.
    """
    if uploaded_file is None:
        raise gr.Error("Please upload a training text file.")

    try:
        lines = _read_training_lines(uploaded_file)
    except UnicodeDecodeError as err:
        raise gr.Error("The uploaded file must be UTF-8 encoded text.") from err
    if not lines:
        raise gr.Error("The uploaded file contains no non-empty lines.")

    # Create dataset for fine-tuning
    dataset = Dataset.from_dict({"text": lines})

    def tokenize_function(examples):
        # Padding is left to the data collator so each batch is only padded
        # to its own longest sequence.
        return tokenizer(examples["text"], truncation=True, max_length=128)

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )

    # Train a private copy so one request's data never leaks into the model
    # used by later (or concurrent) requests.
    request_model = copy.deepcopy(model)

    # Set training args (very small for demo)
    training_args = TrainingArguments(
        output_dir="./fine_tuned",
        num_train_epochs=1,
        per_device_train_batch_size=2,
        logging_dir="./logs",
        logging_steps=10,
        save_strategy="no",
        learning_rate=5e-5,
        weight_decay=0.01,
    )

    trainer = Trainer(
        model=request_model,
        args=training_args,
        train_dataset=tokenized_dataset,
        # With mlm=False the collator copies input_ids into labels (padding
        # masked out); without labels a causal LM returns no loss to train on.
        data_collator=DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
        ),
    )

    # Fine-tune the model
    trainer.train()

    # Training leaves the model in train mode (dropout active) and possibly
    # on an accelerator; switch to eval and move the prompt to that device.
    request_model.eval()
    encoded = tokenizer(prompt or "", return_tensors="pt").to(request_model.device)
    outputs = request_model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        # max_new_tokens bounds only the continuation, so a long prompt
        # cannot use up the whole generation budget.
        max_new_tokens=50,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Gradio interface
iface = gr.Interface(
    fn=fine_tune_and_generate,
    inputs=[
        gr.File(label="Upload training text (.txt)"),
        gr.Textbox(lines=2, placeholder="Enter prompt for generation", label="Prompt"),
    ],
    outputs="text",
    title="Tiny-LLM Fine-tune & Generate",
    description="Upload your text file to fine-tune Tiny-LLM and generate text from a prompt.",
)

iface.launch()