# taxagent/dataset_processing.py
from datasets import load_dataset
from transformers import AutoTokenizer
# Base model whose tokenizer we use (Falcon-7B on the Hugging Face Hub)
MODEL_NAME = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Falcon's tokenizer ships without a padding token, so reuse EOS for padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Load dataset from a local JSON file
dataset = load_dataset("json", data_files="tax_train_data.json")  # Replace with the actual dataset path
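
# The preprocessing below expects every record to carry "prompt" and
# "response" fields; check one example up front ("train" is the split name
# load_dataset assigns to a single JSON file by default)
example = dataset["train"][0]
assert "prompt" in example and "response" in example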
# Preprocessing function: tokenize prompts and responses into fixed-length
# sequences
def preprocess_function(examples):
    inputs = examples["prompt"]      # Prompt text
    targets = examples["response"]   # Response text
    # Tokenize both inputs and targets to a maximum of 512 tokens
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=512)
    labels = tokenizer(targets, padding="max_length", truncation=True, max_length=512)
    # Use the tokenized responses as labels; replace padding token ids with
    # -100 so the loss function ignores padded positions
    model_inputs["labels"] = [
        [(token if token != tokenizer.pad_token_id else -100) for token in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs
# Apply preprocessing to dataset
processed_dataset = dataset.map(preprocess_function, batched=True)
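
# Quick look at the result: map() keeps the original "prompt"/"response"
# columns and adds input_ids, attention_mask, and labels
print(processed_dataset)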
# Save processed dataset (save_to_disk writes an Arrow-format directory,
# not a JSON file)
processed_dataset.save_to_disk("processed_dataset")
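
# Minimal sketch of reloading the saved dataset for training (assumes the
# "processed_dataset" directory written above)
from datasets import load_from_disk

reloaded = load_from_disk("processed_dataset")
print(reloaded["train"][0]["labels"][:10])  # First ten label ids of example 0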