Text Generation
Transformers
English
legal
chat
transformer
SkillForge45 committed · verified · Commit 12ea3de · 1 Parent(s): 247bc76

Create model.py

Files changed (1)
  1. model.py +169 -0
model.py ADDED
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm
import math

# 1. Dataset class for loading and processing data
class FullChatDataset(Dataset):
    def __init__(self, dataset_names=["blended_skill_talk", "conv_ai_2", "social_i_qa"], max_length=128):
        self.datasets = []

        # Load all specified datasets, skipping any that fail to download
        for name in dataset_names:
            try:
                dataset = load_dataset(name, split="train")
                self.datasets.append(dataset)
            except Exception as e:
                print(f"Failed to load dataset {name}: {e}")

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.max_length = max_length

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        # Determine which dataset the global index belongs to
        for dataset in self.datasets:
            if idx < len(dataset):
                item = dataset[idx]
                break
            idx -= len(dataset)

        # Handle different dataset formats
        if 'dialog' in item:  # For Daily Dialog-style datasets
            dialog = item['dialog']
            # Some datasets store dialog turns as dicts; keep only the text field
            if dialog and isinstance(dialog[0], dict):
                dialog = [turn.get('text', '') for turn in dialog]
        elif 'messages' in item:  # For some other datasets
            dialog = [msg['text'] for msg in item['messages']]
        else:  # Universal handling: use every string field as a turn
            dialog = [v for k, v in item.items() if isinstance(v, str)]

        # Guard against items that yield fewer than two utterances
        while len(dialog) < 2:
            dialog.insert(0, "")

        context = " [SEP] ".join(dialog[:-1])
        response = dialog[-1]

        inputs = self.tokenizer(
            context,
            text_pair=response,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'labels': inputs['input_ids'].flatten()
        }

# 2. Model architecture
class SimpleTransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True so the encoder accepts [batch, seq_len, d_model] inputs
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, padding_mask=None):
        # padding_mask: [batch, seq_len] bool tensor, True where tokens are padding
        x = self.embedding(x)
        x = self.pos_encoder(x)
        x = self.transformer(x, src_key_padding_mask=padding_mask)
        return self.fc(x)

class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding added to the token embeddings."""
    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch, seq_len, d_model]; broadcast the encoding over the batch
        return x + self.pe[:x.size(1)]

# 3. Model training
def train(model, dataloader, epochs=3, lr=3e-4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in pbar:
            inputs = batch['input_ids'].to(device)
            masks = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            # attention_mask is 1 for real tokens; the model expects True for padding
            outputs = model(inputs, padding_mask=(masks == 0))
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})

        print(f"Epoch {epoch+1} - Avg loss: {total_loss/len(dataloader):.4f}")

# 4. Response generation
# Note: the custom model has no Hugging Face .generate() method, so responses
# are produced with a simple temperature / top-k sampling loop over its logits.
def chat(model, tokenizer, prompt, max_length=50, top_k=50, temperature=0.7):
    device = next(model.parameters()).device
    model.eval()

    inputs = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
    generated = inputs['input_ids'].to(device)

    with torch.no_grad():
        for _ in range(max_length):
            # Logits for every position; keep only the last one
            logits = model(generated)[:, -1, :] / temperature

            # Restrict sampling to the top-k most likely tokens
            top_values, top_indices = torch.topk(logits, top_k, dim=-1)
            probs = torch.softmax(top_values, dim=-1)
            next_token = top_indices.gather(-1, torch.multinomial(probs, num_samples=1))

            generated = torch.cat([generated, next_token], dim=-1)

            # Stop once the separator token is produced
            if next_token.item() == tokenizer.sep_token_id:
                break

    return tokenizer.decode(generated[0], skip_special_tokens=True)

# 5. Main process
if __name__ == "__main__":
    # Initialization
    dataset = FullChatDataset()
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    # Model creation (vocab size comes from the tokenizer, including added tokens)
    model = SimpleTransformerModel(len(dataset.tokenizer))

    # Training
    train(model, dataloader)

    # Saving
    torch.save(model.state_dict(), "chatbot_model.pt")
    dataset.tokenizer.save_pretrained("chatbot_tokenizer")

    # Interactive chat loop
    while True:
        user_input = input("You: ")
        if user_input.lower() in ['exit', 'quit']:
            break
        response = chat(model, dataset.tokenizer, user_input)
        print(f"Bot: {response}")
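
Once the script has run end to end, the saved checkpoint ("chatbot_model.pt") and tokenizer directory ("chatbot_tokenizer") can be reloaded for inference without retraining. A minimal reload sketch, assuming model.py is importable from the working directory and the default hyperparameters were used when the checkpoint was created:

# Hypothetical usage sketch; file names match those written by model.py above.
import torch
from transformers import AutoTokenizer

from model import SimpleTransformerModel, chat

tokenizer = AutoTokenizer.from_pretrained("chatbot_tokenizer")
model = SimpleTransformerModel(len(tokenizer))
model.load_state_dict(torch.load("chatbot_model.pt", map_location="cpu"))
model.eval()

print(chat(model, tokenizer, "Hello! How are you today?"))

Because training and the interactive loop sit behind the __name__ == "__main__" guard, importing from model.py does not trigger a new training run.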