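"""Fine-tune Salesforce/codet5-small to generate Python functions from docstrings.

The script pulls a 1,000-example slice of CodeSearchNet (Python), turns each
docstring into a prompt of the form "Write a Python function that: ...", uses the
function source as the target, and fine-tunes the model for a few epochs.
"""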
import re
from typing import Dict

import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, get_scheduler


def setup_model():
    """Load the pretrained CodeT5-small checkpoint and its matching tokenizer."""
    model_name = "Salesforce/codet5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer


def prepare_dataset():
    """Load a 1,000-example slice of CodeSearchNet (Python) and build prompt/target pairs."""
    dataset = load_dataset(
        "code_search_net", "python", split="train[:1000]", trust_remote_code=True
    )

    def extract_function_info(example: Dict) -> Dict:
        """Extract clean function definitions and docstrings."""
        code = example["whole_func_string"]

        # Skip anything that is not a plain function definition.
        if not code.strip().startswith("def "):
            return {
                "function": "",
                "documentation": "",
                "input": "",
                "output": "",
            }

        # Collapse runs of blank lines so the target sequences stay compact.
        code = re.sub(r"\n\s*\n", "\n", code)
        docstring = example["func_documentation_string"].strip()

        return {
            "function": code,
            "documentation": docstring,
            "input": f"Write a Python function that: {docstring}",
            "output": code,
        }

    processed_dataset = dataset.map(extract_function_info)

    # Drop the placeholder rows produced for non-function entries.
    processed_dataset = processed_dataset.filter(lambda x: x["function"] != "")
    return processed_dataset
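
# Each surviving example is a prompt/target pair shaped roughly like this
# (the docstring and function here are invented for illustration, not taken
# from the dataset):
#   input : "Write a Python function that: Return the sum of two integers."
#   output: "def add(a, b):\n    return a + b"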


def tokenize_data(examples, tokenizer, max_length=512):
    """Tokenize inputs and outputs for training."""
    model_inputs = tokenizer(
        examples['input'],
        max_length=max_length,
        padding='max_length',
        truncation=True
    )

    # Tokenize the targets; text_target replaces the deprecated
    # `with tokenizer.as_target_tokenizer():` context manager.
    labels = tokenizer(
        text_target=examples['output'],
        max_length=max_length,
        padding='max_length',
        truncation=True
    ).input_ids

    # Mask padding positions with -100 so they are ignored by the loss.
    labels = [
        [token if token != tokenizer.pad_token_id else -100 for token in label]
        for label in labels
    ]

    model_inputs['labels'] = labels
    return model_inputs


def train():
    model, tokenizer = setup_model()
    dataset = prepare_dataset()

    # Use Apple-silicon MPS when available, otherwise fall back to the CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)

    # Training hyperparameters.
    batch_size = 8
    num_epochs = 3
    learning_rate = 5e-5
    max_length = 512

    # Tokenize the whole dataset up front and drop the original text columns.
    tokenized_dataset = dataset.map(
        lambda x: tokenize_data(x, tokenizer, max_length),
        batched=True,
        batch_size=16,
        remove_columns=dataset.column_names,
    )

    # Every sequence was padded to max_length, so the fields stack directly
    # into fixed-shape tensors; they are moved to the device here so batches
    # can be passed straight to the model.
    def collate_fn(examples):
        return {
            'input_ids': torch.stack([torch.tensor(example['input_ids']) for example in examples]).to(device),
            'attention_mask': torch.stack([torch.tensor(example['attention_mask']) for example in examples]).to(device),
            'labels': torch.stack([torch.tensor(example['labels']) for example in examples]).to(device),
        }
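
    # Note: transformers also ships DataCollatorForSeq2Seq, which pads each
    # batch dynamically instead of relying on fixed max_length padding; the
    # hand-rolled collator above is kept for transparency.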

    train_dataloader = DataLoader(
        tokenized_dataset,
        shuffle=True,
        batch_size=batch_size,
        collate_fn=collate_fn
    )

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )
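
    # With num_warmup_steps=0, the "linear" schedule simply decays the learning
    # rate from learning_rate down to 0 over the course of training.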

    progress_bar = tqdm(range(num_training_steps))
    model.train()

    for epoch in range(num_epochs):
        for batch in train_dataloader:
            # Forward pass; the model computes the loss itself because the
            # batch includes labels.
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            progress_bar.set_description(f"Loss: {loss.item():.4f}")

        # Save a checkpoint at the end of every epoch.
        model.save_pretrained(f"checkpoint-epoch-{epoch}")
        tokenizer.save_pretrained(f"checkpoint-epoch-{epoch}")

    print("Training completed!")


if __name__ == "__main__":
    train()