José Eliel Camargo Molina committed on
Commit
302961a
·
1 Parent(s): ce3d862
Files changed (1) hide show
  1. app.py +419 -0
app.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration
3
+ import torch
4
+ from pathlib import Path
5
+ import urllib.request
6
+
7
+ # To latex stuff
8
+ ####################################
9
+
10
+ import itertools
11
+ import re
12
+
13
# LaTeX spelling of gauge-group representation labels: the outer key is the
# group name, the inner key is the representation value found in a token.
rep_tex_dict = dict(
    SU3={
        "-3": r"\bar{\textbf{3}}",
        "3": r"\textbf{3}",
    },
    SU2={
        "-2": r"\textbf{2}",
        "2": r"\textbf{2}",
        "-3": r"\textbf{3}",
        "3": r"\textbf{3}",
    },
)
17
+
18
def fieldobj_to_tex(obj, lor_index, pos):
    """Render a field object as a LaTeX symbol with its quantum numbers.

    Args:
        obj: Iterable of string tokens. Tokens containing ``SU3``, ``SU2``,
            ``U1``, ``HEL`` or ``SP`` contribute the text after their last
            ``=`` as the corresponding value; other tokens are ignored.
        lor_index: Lorentz index text, used only for spin ``"1"`` fields.
        pos: Index position marker placed between ``A`` and ``lor_index``
            for spin ``"1"`` fields.

    Returns:
        The field symbol followed by a subscript of the form
        ``_{(SU3 ,SU2 ,U1 [,h:HEL] )}``.

    Raises:
        ValueError: If no token carries a spin (``SP``) value.
        KeyError: If an SU3/SU2 value has no entry in ``rep_tex_dict``.
    """
    su3 = su2 = u1 = hel = sp = None

    # Substring checks are deliberate: e.g. "SPIN=0" must count as a spin token.
    for tok in obj:
        value = tok.split("=")[-1]
        if "SU3" in tok:
            su3 = value
        if "SU2" in tok:
            su2 = value
        if "U1" in tok:
            u1 = value
        if "HEL" in tok:
            # Positive helicity is displayed with an explicit sign.
            hel = "+1" if value == "1" else value
        if "SP" in tok:
            sp = value

    if sp is None:
        raise ValueError("field object has no spin (SP) token")

    # Unrecognised spin values produce no symbol, only the subscript.
    symbol = {
        "0": r"\phi",
        "1": "A" + pos + lor_index,
        "1/2": r"\psi",
    }.get(sp, "")

    labels = [
        rep_tex_dict["SU3"][su3] if su3 is not None else r"\textbf{1}",
        rep_tex_dict["SU2"][su2] if su2 is not None else r"\textbf{1}",
        u1 if u1 is not None else r"\textbf{0}",
    ]
    if hel is not None:
        labels.append("h:" + hel)

    return symbol + "_{(" + " ,".join(labels) + " )}"
65
+
66
def derobj_to_tex(obj, lor_index, pos):
    """Render a derivative object as LaTeX.

    Args:
        obj: Container of tokens; membership of ``"G"``, ``"W"`` and ``"B"``
            selects the ``SU3``, ``SU2`` and ``U1`` labels respectively.
        lor_index: Lorentz index text.
        pos: ``"^"`` for an upper index or ``"_"`` for a lower index.

    Returns:
        ``\\partial`` with the index when none of the three markers is
        present, otherwise ``D`` with the index and the list of group labels.

    Raises:
        ValueError: If ``pos`` is neither ``"^"`` nor ``"_"``.
    """
    if pos not in ("^", "_"):
        raise ValueError("pos must be ^ or _")

    markers = (("G", "SU3"), ("W", "SU2"), ("B", "U1"))
    groups = [label for marker, label in markers if marker in obj]

    if not groups:
        return r"\partial" + pos + "{" + lor_index + "}"

    if pos == "^":
        prefix = "D^{" + lor_index + "}_{("
    else:
        # A lower index shares the subscript braces with the group list.
        prefix = "D_{" + lor_index + "("
    return prefix + ",".join(groups) + ")}"
85
+
86
def gamobj_to_tex(obj, lor_index, pos):
    """Render a gamma-matrix object as ``\\gamma`` + ``pos`` + ``lor_index``.

    ``obj`` is accepted for signature parity with the other renderers and is
    not inspected.
    """
    return r"\gamma" + pos + lor_index
89
+
90
def obj_to_tex(obj, lor_index=r"\mu", pos="^"):
    """Render one Lagrangian object as LaTeX, dispatching on its first token.

    Args:
        obj: A space-separated string, a tuple, or a list of tokens.
        lor_index: Lorentz index text forwarded to the specific renderer.
        pos: Index position (``"^"`` or ``"_"``) forwarded to the renderer.

    Returns:
        The LaTeX string for the object, or ``None`` when the first token is
        not a recognised object kind.
    """
    if isinstance(obj, tuple):
        obj = list(obj)
    if isinstance(obj, str):
        # Splitting on a single space and dropping empties tolerates padding.
        obj = [tok for tok in obj.split(" ") if tok != ""]

    head = obj[0]
    if head == "+":
        return r"\quad\quad+"
    if head == "-":
        return r"\quad\quad-"
    if head == "i":
        return "i"
    # "FIELD ..." is the spelling produced by generate_field for input
    # fields; it carries the same kind of tokens as a "QF" object.
    if head in ("QF", "FIELD"):
        return fieldobj_to_tex(obj, lor_index, pos)
    if head == "D":
        return derobj_to_tex(obj, lor_index, pos)
    if head == "GM":
        return gamobj_to_tex(obj, lor_index, pos)
    if head == "CMAD":
        return "[ " + derobj_to_tex(obj, lor_index, pos)
    if head == "CMBD":
        return ", " + derobj_to_tex(obj, lor_index, pos) + ' ]'
    return None
110
+
111
def split_with_delimiter_preserved(string, delimiters, ignore_dots=False):
    """Split ``string`` on any of ``delimiters``, keeping the delimiters.

    Args:
        string: Text to split.
        delimiters: Literal delimiter strings (they are regex-escaped).
        ignore_dots: When false, a ``"."`` anywhere in ``string`` is treated
            as a malformed generation and rejected.

    Returns:
        The non-empty pieces in order, delimiters included as their own
        items. A ``"+ "`` piece is normalised to ``" + "``.

    Raises:
        ValueError: If ``string`` contains ``"."`` and ``ignore_dots`` is
            false.
    """
    if "." in string and not ignore_dots:
        raise ValueError("Unexpected ending to the generated Lagrangian")

    # A capturing group makes re.split return the delimiters as well.
    splitter = '(' + '|'.join(map(re.escape, delimiters)) + ')'
    pieces = re.split(splitter, string)
    return [" + " if piece == "+ " else piece for piece in pieces if piece != ""]
130
+
131
def clean_split(inlist, delimiters):
    """Glue every delimiter item onto the item that follows it.

    Walks ``inlist`` once; an item found in ``delimiters`` is concatenated
    with its successor (which is then consumed), a trailing delimiter is kept
    unchanged, and all other items pass through untouched.
    """
    end = object()  # sentinel: no successor left
    merged_list = []
    remaining = iter(inlist)
    for piece in remaining:
        if piece not in delimiters:
            merged_list.append(piece)
            continue
        follower = next(remaining, end)
        merged_list.append(piece if follower is end else piece + follower)
    return merged_list
145
+
146
+
147
def get_obj_dict(inlist):
    """Map each object string to its ID token and an initial LaTeX rendering.

    For every item the result holds ``{"ID": ..., "LATEX": ...}``: ``ID`` is
    the single whitespace-separated token containing ``"ID"`` (``None`` when
    there are zero or several), and ``LATEX`` is filled in via ``obj_to_tex``
    for items containing ``"QF"`` and for the bare ``"+"``, ``"-"``, ``"i"``
    items, otherwise left as ``None``.
    """
    outdict = {}
    for item in inlist:
        id_tokens = [tok for tok in item.split() if "ID" in tok]
        entry = {
            "ID": id_tokens[0] if len(id_tokens) == 1 else None,
            "LATEX": None,
        }
        if "QF" in item:
            entry["LATEX"] = obj_to_tex(item, "\\mu", "^")
        if item in ("+", "-", "i"):
            entry["LATEX"] = obj_to_tex(item)
        outdict[item] = entry
    return outdict
160
+
161
def get_con_dict(inlist):
    """Group contraction index lists by their symmetry token.

    Each item is split on whitespace. The one token containing ``"SU"`` or
    ``"LZ"`` names the symmetry, and the remaining tokens form one index list
    appended under that symmetry's key.
    """
    outdict = {}
    for raw in inlist:
        tokens = [tok for tok in raw.split() if tok != ""]
        sym, ids = [], []
        for tok in tokens:
            is_symmetry = "SU" in tok or "LZ" in tok
            (sym if is_symmetry else ids).append(tok)
        assert len(sym) == 1, "More than one symmetry in contraction"
        outdict.setdefault(sym[0], []).append(ids)
    return outdict
174
+
175
def term_to_tex(term,verbose=False):
    """Convert one term of a generated Lagrangian into a LaTeX string.

    The term text is normalised, split into object strings on the
    ``QF``/``D``/``GM``/``CMAD``/``CMBD``/``CON`` markers, and each object is
    given an initial rendering by ``get_obj_dict``. If the term carries a
    ``CON`` part with Lorentz (``LZ``) contractions, every object whose ID
    appears in a contraction is re-rendered with an explicit index symbol and
    position. The renderings of all non-``CON`` pieces are joined with spaces.

    Args:
        term: The raw term string.
        verbose: When true, print intermediate structures.

    Returns:
        The LaTeX string, or the bare piece itself when the term is only
        ``" + "``, ``" - "`` or ``" i "``.

    NOTE(review): the nesting below was reconstructed from a diff view that
    lost indentation; the placement of the ``cma``/``firstlz`` resets in the
    contraction loop should be confirmed against the original file.
    """
    # Clean term
    # Dots are stripped first, so the dot check in the splitter cannot fire here.
    term = term.replace(".","").replace(" = ", "=").replace(" =- ", "=-").replace(" / ", "/").replace("CMA D", "CMAD").replace("CMB D", "CMBD")
    term = split_with_delimiter_preserved(term,[" QF "," D "," GM "," CMAD "," CMBD "," CON "])
    # Re-attach each marker to the text that follows it.
    term = clean_split(term, [" QF "," D "," GM "," CMAD "," CMBD "," CON "])

    if verbose: print(term)

    # A lone sign or "i" is passed through unchanged.
    if term == [" + "] or term == [" - "] or term == [" i "]:
        return term[0]

    # Get Dictionary of objects
    objdict = get_obj_dict([i for i in term if " CON " not in i])

    if verbose:
        for i,j in objdict.items():
            print(i,"\t\t",j)


    # Do contractions
    contractions = [i for i in term if " CON " in i]
    assert len(contractions) < 2, "More than one contraction in term"
    if (len(contractions) == 1) and contractions != [" CON "]:

        contractions = contractions[0]
        contractions = split_with_delimiter_preserved(contractions,[" LZ "," SU2 "," SU3 "])
        contractions = clean_split(contractions, [" LZ "," SU2 "," SU3 "])
        # Drop the leftover marker piece so only symmetry entries remain.
        contractions = [i for i in contractions if i != " CON"]
        condict = get_con_dict(contractions)
        if verbose: print(condict)
        # Only Lorentz contractions change the rendering.
        if "LZ" in condict.keys():
            firstlz = True
            cma = True
            for con in condict["LZ"]:
                for kobj , iobj in objdict.items():
                    if iobj["ID"] is None : continue
                    if iobj["ID"] in con:
                        # Index symbol: \mu while `cma` is set, \nu afterwards.
                        if cma: lsymb = "\\mu"
                        else: lsymb = "\\nu"

                        # First matching object gets an upper index, later ones a lower index.
                        if firstlz:
                            iobj["LATEX"] = obj_to_tex(kobj,lsymb,"^")
                            firstlz = False
                        else:
                            iobj["LATEX"] = obj_to_tex(kobj,lsymb,"_")
                # NOTE(review): assumed to run once per contraction (see docstring).
                cma = False
                firstlz = True

    # Raises TypeError if any object still has no LATEX rendering (None).
    outstr = " ".join([objdict[i]["LATEX"] for i in term if " CON " not in i])

    return outstr
226
def display_in_latex(instring, verbose=False):
    """Show ``instring`` as large LaTeX in the Streamlit page.

    The previous body called ``display`` and ``Latex``, which are not
    imported anywhere in this file; rendering now goes through ``st.latex``,
    the same call the rest of the app uses.

    Args:
        instring: LaTeX source without math delimiters.
        verbose: When true, also print the ``$``-delimited source to stdout.

    Returns:
        ``instring`` unchanged.
    """
    sized = r"\Large{" + instring + "}"
    if verbose:
        print("$" + sized + "$")
    # st.latex takes the expression without surrounding "$" delimiters.
    st.latex(sized)
    return instring
232
+
233
+
234
def str_tex(instr, num=0):
    """Assemble the LaTeX ``aligned`` body for a list of Lagrangian pieces.

    Args:
        instr: Sequence of term strings (signs are separate items).
        num: When non-zero, only ``instr[:num]`` is rendered (so ``-1``
            drops the last item).

    Returns:
        A string starting with ``\\begin{aligned}``; the caller is expected
        to append the closing ``\\end{aligned}``.

    NOTE(review): indentation was lost in the diff view this was recovered
    from. The first item is assumed to be replaced by the ``\\mathcal{L}=``
    header, with all term rendering belonging to the non-first branch.
    """
    if num != 0:
        instr = instr[:num]

    terms = [term.replace(".", "") for term in instr]
    coup = 0
    mass = 0
    outstr = r"\begin{aligned}"
    for i, iterm in enumerate(terms):
        if i == 0:
            outstr += r" \mathcal{L}= \quad \\ & "
            continue

        nqf = iterm.count("QF SP = 0")
        nD = iterm.count(" D ")
        # Derivative-free terms get a numbered coefficient: m^2 for exactly
        # two spin-0 fields, lambda for any other non-zero count.
        if nqf != 0 and nqf != 2 and nD == 0:
            coup += 1
            outstr += r" \lambda_{" + str(coup) + r"} \,"
        if nqf == 2 and nD == 0:
            mass += 1
            outstr += " m^2_{" + str(mass) + r"} \,"
        outstr += term_to_tex(iterm, False) + r" \quad "
        # Break the line after every fourth item.
        if i % 4 == 0:
            outstr += r" \\ \\ & "
    return outstr
264
+
265
def master_str_tex(iinstr):
    """Render a whole generated Lagrangian string as a LaTeX ``aligned`` block.

    The input is split on its ``+``/``-`` separators and rendered with
    ``str_tex``. If that fails, rendering is retried without the last piece
    and ``\\cdots`` is appended to mark the truncation; the error is printed.

    Raises:
        ValueError: From the splitter when ``iinstr`` contains ``"."``.
        Exception: Whatever the retry raises if it fails as well.
    """
    instr = split_with_delimiter_preserved(iinstr, [" + ", "+ ", " - "])
    try:
        outstr = str_tex(instr)
    except Exception as e:
        # Best effort: an incomplete final term is the expected failure, so
        # drop it and flag the output as truncated.
        outstr = str_tex(instr, -1)
        outstr += r" \cdots"
        print(e)
    outstr += r"\end{aligned}"
    return outstr
275
+
276
+
277
+
278
# Device the model and the tokenized inputs are moved to.
device = 'cpu'
# Identifier passed to `from_pretrained` for both the model and the tokenizer.
model_name = "JoseEliel/BART-Lagrangian"
280
+
281
@st.cache_resource
def load_model():
    """Load the BART model named by ``model_name`` and move it to ``device``.

    Wrapped in ``st.cache_resource``.
    """
    return BartForConditionalGeneration.from_pretrained(model_name).to(device)


model = load_model()
289
+
290
@st.cache_resource
def load_tokenizer():
    """Load the fast tokenizer published under ``model_name``.

    Wrapped in ``st.cache_resource``.
    """
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name)
    return tokenizer


hf_tokenizer = load_tokenizer()
295
+
296
def process_input(input_text):
    """Normalise a field list into a sorted ``[SOS] ... [EOS]`` prompt.

    Any existing ``[SOS]``/``[EOS]`` markers are removed, the text is cut
    into one chunk per ``FIELD`` occurrence, the chunks are sorted by their
    token lists, and the result is re-wrapped in the markers.
    """
    cleaned = input_text
    for marker in ("[SOS]", "[EOS]"):
        cleaned = cleaned.replace(marker, "")

    # Inserting a split marker in front of FIELD keeps the keyword itself
    # inside each chunk; text before the first FIELD is discarded.
    chunks = cleaned.replace("FIELD", "SPLITFIELD").split('SPLIT')[1:]
    token_lists = [chunk.strip().split(' ') for chunk in chunks]
    token_lists.sort()

    body = " ".join(" ".join(tokens) for tokens in token_lists)
    return "[SOS] " + body + " [EOS]"
303
+
304
def process_output(output_text):
    """Remove the ``[SOS]``/``[EOS]`` markers and every ``.`` from the text."""
    cleaned = output_text
    for unwanted in ("[SOS]", "[EOS]", "."):
        cleaned = cleaned.replace(unwanted, "")
    return cleaned
306
+
307
def process_output_pretty_print(output_text):
    """Tidy a Lagrangian string for plain-text display.

    Tightens `` / `` to ``/``, rewrites ``=- `` as ``= -``, and starts a new
    line before every ``+``.
    """
    substitutions = (
        (" / ", "/"),
        ("=- ", "= -"),
        ("+", "\n+"),
    )
    pretty_output = output_text
    for old, new in substitutions:
        pretty_output = pretty_output.replace(old, new)
    return pretty_output
312
+
313
def generate_lagrangian(input_text):
    """Run the model on a field description and return the cleaned output.

    The text is normalised with ``process_input``, tokenized, passed to
    ``model.generate`` (``max_length=1024``) under a Streamlit spinner,
    decoded with special tokens kept, and finally cleaned by
    ``process_output``.
    """
    prompt = process_input(input_text)
    encoded = hf_tokenizer([prompt], return_tensors='pt').to(device)
    with st.spinner(text="Generating Lagrangian..."):
        output_ids = model.generate(encoded['input_ids'], max_length=1024)
    decoded = hf_tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=False)
    return process_output(decoded)
321
+
322
def generate_field(sp, su2, su3, u1):
    """Build the ``FIELD ...`` string for one field from the UI selections.

    ``su2`` and ``su3`` arrive as the ``$``-wrapped option labels; the
    singlet (``"$1$"``) and a ``"0"`` hypercharge are omitted, and the
    ``\\bar{3}`` label is written as ``SU3=-3``. All ``$`` characters are
    stripped from the result.
    """
    parts = [f"FIELD SPIN={sp}"]

    if su2 != "$1$":
        parts.append(f"SU2={su2}")

    if su3 == "$\\bar{3}$":
        parts.append("SU3=-3")
    elif su3 != "$1$":
        parts.append(f"SU3={su3}")

    if u1 != "0":
        parts.append(f"U1={u1}")

    return " ".join(parts).replace("$", "")
338
+
339
def main():
    """Build the Streamlit page.

    Lets the user add spin-0 fields through a form, lists the chosen fields
    as LaTeX, runs the model on request, and shows the generated Lagrangian
    as plain text. Chosen fields and their count live in
    ``st.session_state``.
    """
    # The page text promises a two-field limit; keep the check in sync with it.
    max_fields = 2

    st.title("$\\mathscr{L}$agrangian Generator")
    st.markdown(" ### For a set of chosen fields, this model generates the corresponding Lagrangian which encodes all interactions and dynamics of the fields.")

    st.markdown(" #### This is a simple demo of our smaller [BART](https://arxiv.org/abs/1910.13461)-based model with 110M parameters")

    st.markdown(" ##### :violet[Due to computational resources, we limit the number of fields to 2 and spin = 0]")
    st.markdown(" ##### Choose up to two different fields:")

    su2_options = ["$1$", "$2$", "$3$"]
    su3_options = ["$1$", "$3$", "$\\bar{3}$"]
    u1_options = ["-1","-2/3", "-1/2", "-1/3", "0","1/3" ,"1/2", "2/3", "1"]

    # Initialize session state on the first run of the script.
    if 'count' not in st.session_state:
        st.session_state.count = 0  # Number of fields added so far
    if 'field_strings' not in st.session_state:
        st.session_state.field_strings = []  # Generated field strings

    with st.form("field_selection"):
        su2_selection = st.radio("Select $\\mathrm{SU}(2)$ value:", su2_options)
        su3_selection = st.radio("Select $\\mathrm{SU}(3)$ value:", su3_options)
        u1_selection = st.radio("Select $\\mathrm{U}(1)$ value:", u1_options)
        submitted = st.form_submit_button("Add field")

    if submitted:
        if st.session_state.count < max_fields:
            sp_value = 0  # Only spin-0 fields are offered in this demo
            field_string = generate_field(sp_value, su2_selection, su3_selection, u1_selection)
            st.session_state.field_strings.append(field_string)
            st.session_state.count += 1
        else:
            st.write("You have reached the maximum number of fields we allow in this demo.")

    if st.button("Clear fields"):
        st.session_state.field_strings = []
        st.session_state.count = 0

    st.write("Input Fields:")
    for i, fs in enumerate(st.session_state.field_strings, 1):
        # Input strings are always fields, so use the field renderer directly.
        texfield = fieldobj_to_tex(fs.split(), "\\mu", "^")
        fieldname = f"Field {i}:"
        st.latex("\\text{" + fieldname + "} \\quad" + texfield)

    if st.button("Generate Lagrangian"):
        input_fields = " ".join(st.session_state.field_strings)
        if input_fields == "":
            # No early return: the contact section below must still render.
            st.write("Please add fields before generating the Lagrangian.")
        else:
            print(input_fields)
            print("\n")
            # Wrap the input fields in the SOS and EOS tokens.
            input_fields = "[SOS] " + input_fields + " [EOS]"
            generated_lagrangian = generate_lagrangian(input_fields)
            print(generated_lagrangian)
            print("\n")
            # LaTeX rendering via master_str_tex is currently disabled; the
            # raw model output is shown instead.
            st.markdown("### Generated Lagrangian")
            st.text(generated_lagrangian)

    # Contact information
    st.markdown("### Contact")
    st.markdown("If you have any questions or suggestions, please feel free to Email us. [Eliel](mailto:[email protected]) or [Yong Sheng](mailto:[email protected]).")
417
+
418
# Build the page only when this file is run as the entry script.
if __name__ == "__main__":
    main()