|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.utils.data import Dataset, DataLoader |
|
from datasets import load_dataset |
|
from transformers import AutoTokenizer |
|
from tqdm import tqdm |
|
import math |
|
|
|
|
|
class FullChatDataset(Dataset):
    """Concatenation of several conversational HF datasets, tokenized for LM training.

    Each example is a dialog's context (all turns but the last, joined with
    " [SEP] ") paired with the final turn as the response, encoded by a BERT
    tokenizer to fixed-length tensors.
    """

    # Tuple default avoids the mutable-default-argument pitfall; it is only
    # iterated, so this is backward compatible with the original list.
    def __init__(self, dataset_names=("blended_skill_talk", "conv_ai_2", "social_i_qa"), max_length=128):
        """
        Args:
            dataset_names: HF dataset identifiers to load (train split).
                Datasets that fail to load are logged and skipped.
            max_length: token length every example is padded/truncated to.
        """
        self.datasets = []

        for name in dataset_names:
            try:
                dataset = load_dataset(name, split="train")
                self.datasets.append(dataset)
            except Exception as e:
                # Best-effort: one unavailable dataset should not abort the run.
                print(f"Failed to load dataset {name}: {e}")

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        # bert-base-uncased already defines [PAD]; only add one if missing so
        # we don't needlessly grow the vocabulary (which would desync any
        # model sized before the addition).
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.max_length = max_length

    def __len__(self):
        # Total size is the sum over all successfully loaded datasets.
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        """Return tokenized {'input_ids', 'attention_mask', 'labels'} tensors.

        Raises:
            IndexError: if idx is out of range (the original fell through the
                routing loop and crashed with an unbound-variable NameError).
            ValueError: if the record yields no usable text turns.
        """
        item = self._locate(idx)
        dialog = self._extract_dialog(item)

        # Last turn is the target response; everything before it is context.
        # A single-turn dialog yields an empty context string.
        context = " [SEP] ".join(dialog[:-1])
        response = dialog[-1]

        inputs = self.tokenizer(
            context,
            text_pair=response,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'labels': inputs['input_ids'].flatten()
        }

    def _locate(self, idx):
        """Map a global index to the record in the owning sub-dataset."""
        if idx < 0:
            raise IndexError(f"index {idx} out of range for FullChatDataset")
        for dataset in self.datasets:
            if idx < len(dataset):
                return dataset[idx]
            idx -= len(dataset)
        raise IndexError("index out of range for FullChatDataset")

    def _extract_dialog(self, item):
        """Pull a list of utterance strings out of a heterogeneous record."""
        if 'dialog' in item:
            dialog = item['dialog']
        elif 'messages' in item:
            dialog = [msg['text'] for msg in item['messages']]
        else:
            # Fallback: treat every string-valued field as one turn.
            dialog = [v for k, v in item.items() if isinstance(v, str)]
        if not dialog:
            raise ValueError("record contains no usable text turns")
        return dialog
|
|
|
|
|
class SimpleTransformerModel(nn.Module):
    """Per-token vocabulary logits over a (non-causal) Transformer encoder.

    forward(x, mask) takes batch-first token ids of shape (batch, seq) and an
    optional HF-style attention mask (1 = real token, 0 = padding) and returns
    logits of shape (batch, seq, vocab_size).
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True: inputs/outputs are (batch, seq, d_model).  The
        # original used the seq-first default, silently treating the batch
        # axis as the time axis.
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """
        Args:
            x: LongTensor of token ids, shape (batch, seq).
            mask: optional (batch, seq) attention mask with 1 for real tokens
                and 0 for padding, as produced by HF tokenizers.

        Returns:
            FloatTensor of logits, shape (batch, seq, vocab_size).
        """
        h = self.embedding(x)
        h = self.pos_encoder(h)
        # A (batch, seq) 0/1 attention mask is a key-padding mask; the original
        # passed it positionally as `mask`, which TransformerEncoder interprets
        # as a (seq, seq) attention mask — wrong shape and wrong semantics.
        pad_mask = (mask == 0) if mask is not None else None
        h = self.transformer(h, src_key_padding_mask=pad_mask)
        return self.fc(h)
|
|
|
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to batch-first embeddings."""

    def __init__(self, d_model, max_len=500):
        """
        Args:
            d_model: embedding dimension.  Odd values are supported; the
                original crashed on odd d_model because the cosine slice has
                one fewer column than div_term has entries.
            max_len: longest sequence length the table covers.
        """
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # Even d_model: div_term[:d_model//2] is all of div_term (unchanged
        # behavior).  Odd d_model: drop the extra frequency so shapes match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer (not parameter): moves with .to(device), excluded from grads.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add encodings for the first x.size(1) positions.

        Assumes x is batch-first, shape (batch, seq, d_model); pe[:seq] of
        shape (seq, d_model) broadcasts over the batch axis.
        """
        return x + self.pe[:x.size(1)]
|
|
|
|
|
def train(model, dataloader, epochs=3, lr=3e-4):
    """Fit `model` on `dataloader` with Adam and token-level cross-entropy.

    Moves the model to GPU when available, shows a tqdm progress bar per
    epoch, and prints each epoch's mean loss.  Positions whose label is 0
    (the pad token id) are excluded from the loss via ignore_index.

    Args:
        model: module mapping (ids, attention_mask) -> (batch, seq, vocab) logits.
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'labels'.
        epochs: number of passes over the data.
        lr: Adam learning rate.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    opt = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running = 0.0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in progress:
            ids = batch['input_ids'].to(device)
            attn = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            opt.zero_grad()
            logits = model(ids, attn)
            # Flatten (batch, seq, vocab) against (batch, seq) for token-level CE.
            step_loss = loss_fn(logits.view(-1, logits.size(-1)), targets.view(-1))
            step_loss.backward()
            opt.step()

            running += step_loss.item()
            progress.set_postfix({'loss': step_loss.item()})

        print(f"Epoch {epoch+1} - Avg loss: {running/len(dataloader):.4f}")
|
|
|
|
|
def chat(model, tokenizer, prompt, max_length=50):
    """Generate a sampled continuation of `prompt` with `model`.

    The original called `model.generate(...)`, which exists on HuggingFace
    `PreTrainedModel`s but not on a plain `nn.Module` such as
    `SimpleTransformerModel`, so it raised AttributeError at runtime.  This
    version runs the sampling loop (temperature 0.7, top-k 50, top-p 0.95 —
    the same settings the original requested) directly against the model's
    forward pass.  The prompt is no longer padded to max_length: pad tokens
    in the middle of the sequence would corrupt generation.

    Args:
        model: module whose forward maps (1, seq) token ids to (1, seq, vocab)
            logits; its parameters determine the device used.
        tokenizer: HF-style tokenizer used for encoding and decoding.
        prompt: user text to continue.
        max_length: maximum number of new tokens to sample.

    Returns:
        The decoded prompt-plus-continuation string.
    """
    device = next(model.parameters()).device
    model.eval()

    enc = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=128,
        truncation=True
    ).to(device)
    generated = enc['input_ids']  # (1, seq)

    with torch.no_grad():
        for _ in range(max_length):
            logits = model(generated)            # (1, seq, vocab)
            next_logits = logits[0, -1] / 0.7    # temperature scaling

            # Top-k filtering (k capped by vocab size).
            k = min(50, next_logits.size(-1))
            topk_vals, topk_idx = torch.topk(next_logits, k)
            probs = torch.softmax(topk_vals, dim=-1)

            # Top-p (nucleus) filtering within the top-k candidates.
            sorted_probs, order = torch.sort(probs, descending=True)
            keep = torch.cumsum(sorted_probs, dim=-1) <= 0.95
            keep[0] = True  # always keep the single most likely token
            sorted_probs = sorted_probs * keep

            choice = torch.multinomial(sorted_probs / sorted_probs.sum(), 1)
            next_id = topk_idx[order[choice]]

            # Stop on the tokenizer's end-of-segment token, if it has one.
            stop_id = getattr(tokenizer, 'sep_token_id', None)
            if stop_id is not None and next_id.item() == stop_id:
                break
            generated = torch.cat([generated, next_id.view(1, 1)], dim=1)

    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
dataset = FullChatDataset() |
|
dataloader = DataLoader(dataset, batch_size=16, shuffle=True) |
|
|
|
|
|
model = SimpleTransformerModel(len(dataset.tokenizer)) |
|
|
|
|
|
train(model, dataloader) |
|
|
|
|
|
torch.save(model.state_dict(), "chatbot_model.pt") |
|
dataset.tokenizer.save_pretrained("chatbot_tokenizer") |
|
|
|
|
|
while True: |
|
user_input = input("You: ") |
|
if user_input.lower() in ['exit', 'quit']: |
|
break |
|
response = chat(model, dataset.tokenizer, user_input) |
|
print(f"Bot: {response}") |