Spaces:
Runtime error
Runtime error
Update x.py
Browse files
x.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import os
|
2 |
-
import glob
|
3 |
import stat
|
4 |
import xml.etree.ElementTree as ET
|
5 |
import torch
|
@@ -7,17 +6,14 @@ import torch.nn as nn
|
|
7 |
import torch.nn.functional as F
|
8 |
import logging
|
9 |
import requests
|
10 |
-
import faiss
|
11 |
from collections import defaultdict
|
12 |
-
from typing import List, Dict, Any
|
13 |
from colorama import Fore, Style, init
|
14 |
from accelerate import Accelerator
|
15 |
from torch.utils.data import DataLoader, TensorDataset
|
16 |
-
from torch.cuda.amp import GradScaler, autocast
|
17 |
from transformers import AutoTokenizer, AutoModel
|
18 |
from sentence_transformers import SentenceTransformer
|
19 |
-
|
20 |
-
from sentence_transformers.uniformer import Uniformer
|
21 |
|
22 |
# Initialize colorama
|
23 |
init(autoreset=True)
|
@@ -84,46 +80,50 @@ class DynamicModel(nn.Module):
|
|
84 |
self.sections = nn.ModuleDict({sn: nn.ModuleList([self.create_layer(lp) for lp in layers]) for sn, layers in sections.items()})
|
85 |
|
86 |
def create_layer(self, lp):
|
87 |
-
|
88 |
if lp.get('batch_norm', True):
|
89 |
-
|
90 |
-
|
91 |
-
if
|
92 |
-
|
93 |
-
elif
|
94 |
-
|
95 |
-
elif
|
96 |
-
|
97 |
-
elif
|
98 |
-
|
99 |
-
elif
|
100 |
-
|
101 |
-
if
|
102 |
-
|
103 |
if lp.get('memory_augmentation', False):
|
104 |
-
|
105 |
if lp.get('hybrid_attention', False):
|
106 |
-
|
107 |
if lp.get('dynamic_flash_attention', False):
|
108 |
-
|
109 |
if lp.get('magic_state', False):
|
110 |
-
|
111 |
-
return nn.Sequential(*
|
112 |
|
113 |
-
def forward(self, x,
|
114 |
-
if
|
115 |
-
for
|
116 |
-
x =
|
117 |
else:
|
118 |
-
for
|
119 |
-
for
|
120 |
-
x =
|
121 |
return x
|
122 |
|
123 |
def parse_xml_file(file_path):
|
124 |
tree, root, layers = ET.parse(file_path), ET.parse(file_path).getroot(), []
|
125 |
for layer in root.findall('.//layer'):
|
126 |
-
lp = {
|
|
|
|
|
|
|
|
|
127 |
if lp['activation'] not in ['relu', 'tanh', 'sigmoid', 'none']:
|
128 |
raise ValueError(f"Unsupported activation function: {lp['activation']}")
|
129 |
if lp['input_size'] <= 0 or lp['output_size'] <= 0:
|
@@ -154,7 +154,10 @@ def create_model_from_folder(folder_path):
|
|
154 |
return DynamicModel(dict(sections))
|
155 |
|
156 |
def create_embeddings_and_stores(folder_path, model_name="sentence-transformers/all-MiniLM-L6-v2"):
|
157 |
-
tokenizer
|
|
|
|
|
|
|
158 |
for root, dirs, files in os.walk(folder_path):
|
159 |
for file in files:
|
160 |
if file.endswith('.xml'):
|
@@ -166,23 +169,26 @@ def create_embeddings_and_stores(folder_path, model_name="sentence-transformers/
|
|
166 |
text = elem.text.strip()
|
167 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
168 |
with torch.no_grad():
|
169 |
-
embeddings = model(**inputs).last_hidden_state.mean(dim=1).numpy()
|
170 |
-
|
171 |
doc_store.append(text)
|
172 |
except Exception as e:
|
173 |
logging.error(f"Error processing {file_path}: {str(e)}")
|
174 |
-
return
|
175 |
|
176 |
-
def
|
177 |
-
tokenizer
|
|
|
178 |
inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)
|
179 |
with torch.no_grad():
|
180 |
-
query_embedding = model(**inputs).last_hidden_state.mean(dim=1).numpy()
|
181 |
-
|
182 |
-
|
|
|
183 |
|
184 |
def fetch_courtlistener_data(query):
|
185 |
-
base_url
|
|
|
186 |
try:
|
187 |
response = requests.get(base_url, params=params, headers={"Accept": "application/json"}, timeout=10)
|
188 |
response.raise_for_status()
|
@@ -194,14 +200,14 @@ def fetch_courtlistener_data(query):
|
|
194 |
class CustomModel(nn.Module):
|
195 |
def __init__(self, model_name="distilbert-base-uncased"):
|
196 |
super().__init__()
|
197 |
-
self.model_name = model_name
|
198 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
199 |
self.encoder = AutoModel.from_pretrained(model_name)
|
200 |
self.hidden_size = self.encoder.config.hidden_size
|
201 |
-
self.dropout = nn.Dropout(p=0.
|
202 |
-
self.fc1 = nn.Linear(self.hidden_size,
|
203 |
-
self.fc2 = nn.Linear(
|
204 |
-
self.fc3 = nn.Linear(
|
|
|
205 |
self.memory = nn.LSTM(self.hidden_size, 64, bidirectional=True, batch_first=True)
|
206 |
self.memory_fc1 = nn.Linear(64 * 2, 32)
|
207 |
self.memory_fc2 = nn.Linear(32, 16)
|
@@ -212,7 +218,8 @@ class CustomModel(nn.Module):
|
|
212 |
x = outputs.last_hidden_state.mean(dim=1)
|
213 |
x = self.dropout(F.relu(self.fc1(x)))
|
214 |
x = self.dropout(F.relu(self.fc2(x)))
|
215 |
-
x = self.fc3(x)
|
|
|
216 |
return x
|
217 |
|
218 |
def training_step(self, data, labels, optimizer, criterion):
|
@@ -234,45 +241,17 @@ class CustomModel(nn.Module):
|
|
234 |
with torch.no_grad():
|
235 |
return self.forward(input)
|
236 |
|
237 |
-
class CustomModelInference(nn.Module):
|
238 |
-
def __init__(self, model_name="distilbert-base-uncased"):
|
239 |
-
super().__init__()
|
240 |
-
self.model_name = model_name
|
241 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
242 |
-
self.encoder = AutoModel.from_pretrained(model_name)
|
243 |
-
self.hidden_size = self.encoder.config.hidden_size
|
244 |
-
self.dropout = nn.Dropout(p=0.2)
|
245 |
-
self.fc1 = nn.Linear(self.hidden_size, 64)
|
246 |
-
self.fc2 = nn.Linear(64, 32)
|
247 |
-
self.fc3 = nn.Linear(32, 16)
|
248 |
-
self.reader = ParagraphReader("data/docstore.json")
|
249 |
-
self.model_embedding = SentenceTransformer('sentence-transformers/multilingual-v2')
|
250 |
-
self.vectorstore = Uniformer("distilusembert-base-nli-mean-tokens", torch.nn.CrossEntropyLoss(), margin=0.5, temperature=0.1, top_k=4)
|
251 |
-
|
252 |
-
def forward(self, data):
|
253 |
-
tokens = self.tokenizer(data, return_tensors="pt", truncation=True, padding=True)
|
254 |
-
outputs = self.encoder(**tokens)
|
255 |
-
x = outputs.last_hidden_state.mean(dim=1)
|
256 |
-
x = self.dropout(F.relu(self.fc1(x)))
|
257 |
-
x = self.dropout(F.relu(self.fc2(x)))
|
258 |
-
x = self.fc3(x)
|
259 |
-
return x
|
260 |
-
|
261 |
-
def infer(self, input):
|
262 |
-
self.eval()
|
263 |
-
with torch.no_grad():
|
264 |
-
return self.forward(input)
|
265 |
-
|
266 |
-
def update_memory(self, data):
|
267 |
-
embeddings = self.model_embedding.encode(data, convert_to_tensor=True)
|
268 |
-
self.vectorstore.add(embeddings)
|
269 |
-
|
270 |
def main():
|
271 |
-
folder_path
|
|
|
272 |
logging.info(f"Created dynamic PyTorch model with sections: {list(model.sections.keys())}")
|
273 |
-
|
274 |
-
accelerator
|
275 |
-
|
|
|
|
|
|
|
|
|
276 |
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
277 |
for epoch in range(num_epochs):
|
278 |
model.train()
|
@@ -287,10 +266,10 @@ def main():
|
|
287 |
avg_loss = total_loss / len(dataloader)
|
288 |
logging.info(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
|
289 |
query = "example query text"
|
290 |
-
results =
|
291 |
logging.info(f"Query results: {results}")
|
292 |
courtlistener_data = fetch_courtlistener_data(query)
|
293 |
logging.info(f"CourtListener API results: {courtlistener_data}")
|
294 |
|
295 |
if __name__ == "__main__":
|
296 |
-
main()
|
|
|
1 |
import os
|
|
|
2 |
import stat
|
3 |
import xml.etree.ElementTree as ET
|
4 |
import torch
|
|
|
6 |
import torch.nn.functional as F
|
7 |
import logging
|
8 |
import requests
|
|
|
9 |
from collections import defaultdict
|
10 |
+
from typing import List, Dict, Any
|
11 |
from colorama import Fore, Style, init
|
12 |
from accelerate import Accelerator
|
13 |
from torch.utils.data import DataLoader, TensorDataset
|
|
|
14 |
from transformers import AutoTokenizer, AutoModel
|
15 |
from sentence_transformers import SentenceTransformer
|
16 |
+
import numpy as np
|
|
|
17 |
|
18 |
# Initialize colorama
|
19 |
init(autoreset=True)
|
|
|
80 |
self.sections = nn.ModuleDict({sn: nn.ModuleList([self.create_layer(lp) for lp in layers]) for sn, layers in sections.items()})
|
81 |
|
82 |
def create_layer(self, lp):
|
83 |
+
layers = [nn.Linear(lp['input_size'], lp['output_size'])]
|
84 |
if lp.get('batch_norm', True):
|
85 |
+
layers.append(nn.BatchNorm1d(lp['output_size']))
|
86 |
+
activation = lp.get('activation', 'relu')
|
87 |
+
if activation == 'relu':
|
88 |
+
layers.append(nn.ReLU(inplace=True))
|
89 |
+
elif activation == 'tanh':
|
90 |
+
layers.append(nn.Tanh())
|
91 |
+
elif activation == 'sigmoid':
|
92 |
+
layers.append(nn.Sigmoid())
|
93 |
+
elif activation == 'leaky_relu':
|
94 |
+
layers.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
|
95 |
+
elif activation == 'elu':
|
96 |
+
layers.append(nn.ELU(alpha=1.0, inplace=True))
|
97 |
+
if dropout := lp.get('dropout', 0.0):
|
98 |
+
layers.append(nn.Dropout(p=dropout))
|
99 |
if lp.get('memory_augmentation', False):
|
100 |
+
layers.append(MemoryAugmentationLayer(lp['output_size']))
|
101 |
if lp.get('hybrid_attention', False):
|
102 |
+
layers.append(HybridAttentionLayer(lp['output_size']))
|
103 |
if lp.get('dynamic_flash_attention', False):
|
104 |
+
layers.append(DynamicFlashAttentionLayer(lp['output_size']))
|
105 |
if lp.get('magic_state', False):
|
106 |
+
layers.append(MagicStateLayer(lp['output_size']))
|
107 |
+
return nn.Sequential(*layers)
|
108 |
|
109 |
+
def forward(self, x, section_name=None):
|
110 |
+
if section_name:
|
111 |
+
for layer in self.sections[section_name]:
|
112 |
+
x = layer(x)
|
113 |
else:
|
114 |
+
for section_name, layers in self.sections.items():
|
115 |
+
for layer in layers:
|
116 |
+
x = layer(x)
|
117 |
return x
|
118 |
|
119 |
def parse_xml_file(file_path):
|
120 |
tree, root, layers = ET.parse(file_path), ET.parse(file_path).getroot(), []
|
121 |
for layer in root.findall('.//layer'):
|
122 |
+
lp = {
|
123 |
+
'input_size': int(layer.get('input_size', 128)),
|
124 |
+
'output_size': int(layer.get('output_size', 256)),
|
125 |
+
'activation': layer.get('activation', 'relu').lower()
|
126 |
+
}
|
127 |
if lp['activation'] not in ['relu', 'tanh', 'sigmoid', 'none']:
|
128 |
raise ValueError(f"Unsupported activation function: {lp['activation']}")
|
129 |
if lp['input_size'] <= 0 or lp['output_size'] <= 0:
|
|
|
154 |
return DynamicModel(dict(sections))
|
155 |
|
156 |
def create_embeddings_and_stores(folder_path, model_name="sentence-transformers/all-MiniLM-L6-v2"):
|
157 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
158 |
+
model = AutoModel.from_pretrained(model_name)
|
159 |
+
doc_store = []
|
160 |
+
embeddings_list = []
|
161 |
for root, dirs, files in os.walk(folder_path):
|
162 |
for file in files:
|
163 |
if file.endswith('.xml'):
|
|
|
169 |
text = elem.text.strip()
|
170 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
171 |
with torch.no_grad():
|
172 |
+
embeddings = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
|
173 |
+
embeddings_list.append(embeddings)
|
174 |
doc_store.append(text)
|
175 |
except Exception as e:
|
176 |
logging.error(f"Error processing {file_path}: {str(e)}")
|
177 |
+
return embeddings_list, doc_store
|
178 |
|
179 |
+
def query_embeddings(query, embeddings_list, doc_store, model_name="sentence-transformers/all-MiniLM-L6-v2"):
|
180 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
181 |
+
model = AutoModel.from_pretrained(model_name)
|
182 |
inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)
|
183 |
with torch.no_grad():
|
184 |
+
query_embedding = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
|
185 |
+
similarities = [np.dot(query_embedding, emb.T) for emb in embeddings_list]
|
186 |
+
top_k_indices = np.argsort(similarities, axis=0)[-5:][::-1]
|
187 |
+
return [doc_store[i] for i in top_k_indices]
|
188 |
|
189 |
def fetch_courtlistener_data(query):
|
190 |
+
base_url = "https://nzlii.org/cgi-bin/sinosrch.cgi"
|
191 |
+
params = {"method": "auto", "query": query, "meta": "/nz", "results": "50", "format": "json"}
|
192 |
try:
|
193 |
response = requests.get(base_url, params=params, headers={"Accept": "application/json"}, timeout=10)
|
194 |
response.raise_for_status()
|
|
|
200 |
class CustomModel(nn.Module):
|
201 |
def __init__(self, model_name="distilbert-base-uncased"):
|
202 |
super().__init__()
|
|
|
203 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
204 |
self.encoder = AutoModel.from_pretrained(model_name)
|
205 |
self.hidden_size = self.encoder.config.hidden_size
|
206 |
+
self.dropout = nn.Dropout(p=0.3)
|
207 |
+
self.fc1 = nn.Linear(self.hidden_size, 128)
|
208 |
+
self.fc2 = nn.Linear(128, 64)
|
209 |
+
self.fc3 = nn.Linear(64, 32)
|
210 |
+
self.fc4 = nn.Linear(32, 16)
|
211 |
self.memory = nn.LSTM(self.hidden_size, 64, bidirectional=True, batch_first=True)
|
212 |
self.memory_fc1 = nn.Linear(64 * 2, 32)
|
213 |
self.memory_fc2 = nn.Linear(32, 16)
|
|
|
218 |
x = outputs.last_hidden_state.mean(dim=1)
|
219 |
x = self.dropout(F.relu(self.fc1(x)))
|
220 |
x = self.dropout(F.relu(self.fc2(x)))
|
221 |
+
x = self.dropout(F.relu(self.fc3(x)))
|
222 |
+
x = self.fc4(x)
|
223 |
return x
|
224 |
|
225 |
def training_step(self, data, labels, optimizer, criterion):
|
|
|
241 |
with torch.no_grad():
|
242 |
return self.forward(input)
|
243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
def main():
|
245 |
+
folder_path = 'data'
|
246 |
+
model = create_model_from_folder(folder_path)
|
247 |
logging.info(f"Created dynamic PyTorch model with sections: {list(model.sections.keys())}")
|
248 |
+
embeddings_list, doc_store = create_embeddings_and_stores(folder_path)
|
249 |
+
accelerator = Accelerator()
|
250 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
251 |
+
criterion = nn.CrossEntropyLoss()
|
252 |
+
num_epochs = 10
|
253 |
+
dataset = TensorDataset(torch.randn(100, 128), torch.randint(0, 2, (100,)))
|
254 |
+
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
|
255 |
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
256 |
for epoch in range(num_epochs):
|
257 |
model.train()
|
|
|
266 |
avg_loss = total_loss / len(dataloader)
|
267 |
logging.info(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
|
268 |
query = "example query text"
|
269 |
+
results = query_embeddings(query, embeddings_list, doc_store)
|
270 |
logging.info(f"Query results: {results}")
|
271 |
courtlistener_data = fetch_courtlistener_data(query)
|
272 |
logging.info(f"CourtListener API results: {courtlistener_data}")
|
273 |
|
274 |
if __name__ == "__main__":
|
275 |
+
main()
|