Fine-tuned from GPT-2 on articles from ACLU.org on the topic of "transgender rights".
Training setup:
import torch
from transformers import (
pipeline,
AutoModelForCausalLM,
AutoTokenizer,
)
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
def print_mps_memory():
    """Print current Apple MPS memory usage; does nothing when MPS is unavailable.

    Two figures are printed, each divided by 1024**3 (so "GB" here means GiB):

    * ``torch.mps.current_allocated_memory()`` - memory currently occupied by
      tensors.
    * ``torch.mps.driver_allocated_memory()`` - total memory the Metal driver has
      allocated for this process. This includes blocks cached by the MPS
      allocator, so it is not a "cache only" figure.

    Returns:
        None. Output goes to stdout.
    """
    if not torch.backends.mps.is_available():
        return
    gib = 1024**3
    print(f"MPS allocated: {torch.mps.current_allocated_memory() / gib:.2f} GB")
    # Label fixed: this value is the driver's total, not just the cached portion.
    print(f"MPS driver total (incl. cache): {torch.mps.driver_allocated_memory() / gib:.2f} GB")
# Baseline memory snapshot taken once, before the model is loaded. To watch usage
# during training, call print_mps_memory() again later (e.g. from a training callback).
print_mps_memory()

# Prefer Apple's Metal (MPS) backend when it is available; otherwise fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS device found.")
else:
    device = torch.device("cpu")
    print("MPS device not found, using CPU.")
# GPT-2 has no dedicated pad token, so reuse EOS to allow batches to be padded.
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
model = model.to(device)  # `device` was selected earlier: MPS if available, else CPU

ds = load_dataset("gofilipa/aclu_transgender")
# Cap training at the first 2000 rows (or fewer if the split is smaller) to keep runs short.
train_dataset = ds["train"].select(range(min(2000, len(ds["train"]))))

# Release MPS blocks cached while loading, before training allocates its own.
# Reuses the device decision made above rather than re-querying availability.
if device.type == "mps":
    torch.mps.empty_cache()

# Memory-conscious configuration. The effective batch size is
# per_device_train_batch_size * gradient_accumulation_steps = 1 * 2 = 2.
training_params = SFTConfig(
    output_dir="../checkpoints",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=3,  # raised from 1 to 3 as memory allowed
    learning_rate=2e-4,
    weight_decay=0.001,
    dataset_text_field="text",  # dataset column that holds the training text
    report_to="none",  # no experiment-tracker integration
    bf16=False,  # mixed precision disabled; train in the model's default precision
    fp16=False,
    dataloader_pin_memory=False,  # pinned host memory targets CUDA; not used on MPS/CPU
    # NOTE(review): with SFTTrainer, keeping unused columns has caused collation
    # errors in some TRL versions; confirm this is still needed.
    remove_unused_columns=False,
    # Truncate samples to 512 tokens to bound activation memory.
    # NOTE(review): newer TRL releases renamed this to `max_length`; confirm
    # against the installed TRL version.
    max_seq_length=512,
    gradient_checkpointing=True,  # recompute activations in backward to save memory
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    args=training_params,
)
trainer.train()
license: gpl-3.0
datasets:
- gofilipa/aclu_transgender
base_model:
- openai-community/gpt2
pipeline_tag: text-generation
- Downloads last month: 3
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support