yasserrmd committed on
Commit 19c5106 · verified · 1 Parent(s): 1d27550

Create app.py

Files changed (1): app.py +206 -0
app.py ADDED
@@ -0,0 +1,206 @@
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from dataclasses import dataclass
+ import tiktoken
+ import math
+
+ # GPT model definition: GPTConfig, LayerNorm, CausalSelfAttention, MLP, Block, and GPT.
+
+ @dataclass
+ class GPTConfig:
+     block_size: int
+     vocab_size: int
+     n_layer: int
+     n_head: int
+     n_embd: int
+     dropout: float = 0.1
+     bias: bool = True
+
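+ # LayerNorm with an optional bias term; F.layer_norm receives bias=None when it is disabled.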
+ class LayerNorm(nn.Module):
+     def __init__(self, ndim, bias):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(ndim))
+         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+     def forward(self, x):
+         return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
+
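+ # Causal self-attention: uses the fused F.scaled_dot_product_attention path when available,
+ # otherwise falls back to an explicit lower-triangular mask over the attention scores.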
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.flash = hasattr(F, 'scaled_dot_product_attention')
+         if not self.flash:
+             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                  .view(1, 1, config.block_size, config.block_size))
+     def forward(self, x):
+         B, T, C = x.size()
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         if self.flash:
+             y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
+         else:
+             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+             att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+             att = F.softmax(att, dim=-1)
+             att = self.attn_dropout(att)
+             y = att @ v
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         y = self.resid_dropout(self.c_proj(y))
+         return y
+
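+ # Position-wise feed-forward block: linear -> GELU -> linear, with a 4x hidden expansion.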
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+         self.gelu = nn.GELU()
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+     def forward(self, x):
+         return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
+
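+ # Transformer block: pre-norm attention and MLP sub-layers, each wrapped in a residual connection.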
+ class Block(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln1 = LayerNorm(config.n_embd, config.bias)
+         self.attn = CausalSelfAttention(config)
+         self.ln2 = LayerNorm(config.n_embd, config.bias)
+         self.mlp = MLP(config)
+     def forward(self, x):
+         x = x + self.attn(self.ln1(x))
+         x = x + self.mlp(self.ln2(x))
+         return x
+
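+ # Full GPT model: token + positional embeddings, a stack of Blocks, a final LayerNorm,
+ # and a language-model head tied to the token embedding weights. generate() samples
+ # autoregressively with temperature, top-k, and top-p filtering.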
+ class GPT(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.transformer = nn.ModuleDict(dict(
+             wte=nn.Embedding(config.vocab_size, config.n_embd),
+             wpe=nn.Embedding(config.block_size, config.n_embd),
+             drop=nn.Dropout(config.dropout),
+             h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f=LayerNorm(config.n_embd, config.bias),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.transformer.wte.weight = self.lm_head.weight  # weight tying
+         self.apply(self._init_weights)
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+     def forward(self, idx, targets=None):
+         device = idx.device
+         b, t = idx.size()
+         assert t <= self.config.block_size
+         pos = torch.arange(0, t, dtype=torch.long, device=device)
+         tok_emb = self.transformer.wte(idx)
+         pos_emb = self.transformer.wpe(pos)
+         x = self.transformer.drop(tok_emb + pos_emb)
+         for block in self.transformer.h:
+             x = block(x)
+         x = self.transformer.ln_f(x)
+         if targets is not None:
+             logits = self.lm_head(x)
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+             return logits, loss
+         else:
+             logits = self.lm_head(x[:, [-1], :])
+             return logits, None
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, top_p=None):
+         for _ in range(max_new_tokens):
+             idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+             logits, _ = self(idx_cond)
+             logits = logits[:, -1, :] / temperature
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float('Inf')
+             if top_p is not None:
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = -float('Inf')
+             probs = F.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1)
+             idx = torch.cat((idx, idx_next), dim=1)
+         return idx
+
+ # --- Load checkpoint and tokenizer ---
+
+ checkpoint_path = "best_model_params.pt"  # update path if needed
+
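+ # The config below must match the architecture the checkpoint was trained with,
+ # otherwise load_state_dict will fail.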
+ config = GPTConfig(
+     vocab_size=50257,
+     block_size=128,
+     n_layer=6,
+     n_head=6,
+     n_embd=384,
+     dropout=0.1,
+     bias=True,
+ )
+
+ model = GPT(config)
+ model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
+ model.eval()
+
+ enc = tiktoken.get_encoding("gpt2")
+
+ # --- Gradio interface ---
+
+ samples = [
+     "The Fourth Amendment protects citizens against unreasonable searches and seizures.",
+     "Under the doctrine of stare decisis, courts follow precedent to ensure legal consistency.",
+     "The Commerce Clause grants Congress the power to regulate interstate commerce.",
+     "Due process requires that the government respect all legal rights owed to a person.",
+     "The principle of double jeopardy prevents a defendant from being tried twice for the same offense."
+ ]
+
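+ # Encode the prompt with the GPT-2 tokenizer, sample a continuation, and decode it back to text.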
+ def generate_text(prompt, max_new_tokens=150, temperature=0.7, top_k=50, top_p=0.9):
+     input_ids = torch.tensor(enc.encode_ordinary(prompt)).unsqueeze(0)
+     with torch.no_grad():
+         output_ids = model.generate(
+             input_ids, max_new_tokens=max_new_tokens,
+             temperature=temperature, top_k=top_k, top_p=top_p
+         )
+     generated = enc.decode(output_ids.squeeze().tolist())
+     return generated
+
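+ # Build the Gradio UI: a sample-prompt dropdown that fills the prompt box, plus a generate button.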
+ with gr.Blocks() as demo:
+     gr.Markdown("# Legal GPT Text Generation Demo")
+
+     sample_dropdown = gr.Dropdown(label="Sample prompts", choices=samples, value=samples[0])
+     prompt_input = gr.Textbox(label="Input Prompt", lines=3, value=samples[0])
+
+     def update_prompt(selected):
+         return selected
+
+     sample_dropdown.change(update_prompt, inputs=sample_dropdown, outputs=prompt_input)
+
+     generate_button = gr.Button("Generate Text")
+     output_text = gr.Textbox(label="Generated Output", lines=15)
+
+     generate_button.click(generate_text, inputs=prompt_input, outputs=output_text)
+
+ demo.launch()