import itertools
import re
import urllib.request
from pathlib import Path

import streamlit as st
import torch
from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration

# ---------------------------------------------------------------------------
# Helpers that turn model token streams into LaTeX
# ---------------------------------------------------------------------------

# LaTeX for the gauge-representation values that can follow "SU3=" / "SU2=".
# Conjugate SU(2) representations are rendered identically to the
# unconjugated ones.
rep_tex_dict = {
    "SU3": {"-3": r"\bar{\textbf{3}}", "3": r"\textbf{3}"},
    "SU2": {"-2": r"\textbf{2}", "2": r"\textbf{2}", "-3": r"\textbf{3}", "3": r"\textbf{3}"},
}

# Object markers used to split a single Lagrangian term into its pieces.
_TERM_DELIMITERS = [" QF ", " D ", " GM ", " CMAD ", " CMBD ", " CON "]
# Symmetry markers used to split the " CON ..." (contraction) part of a term.
_CON_DELIMITERS = [" LZ ", " SU2 ", " SU3 "]


def fieldobj_to_tex(obj, lor_index, pos):
    """Render one field, given as a token list, as LaTeX.

    ``obj`` is a list of tokens such as ``["QF", "SP=0", "SU3=3", "U1=1/3"]``
    (generated terms) or ``["FIELD", "SPIN=0", "SU2=2"]`` (fields chosen in the
    UI). Tokens are recognised by substring: ``SU3``, ``SU2``, ``U1``, ``HEL``
    and ``SP`` (which also matches ``SPIN``). The value is the text after the
    last ``=``.

    Args:
        obj: Token list describing a single field.
        lor_index: LaTeX Lorentz index, used only for spin-1 fields.
        pos: ``"^"`` or ``"_"``, the position of the Lorentz index.

    Returns:
        LaTeX such as ``\\phi_{(\\textbf{3} ,\\textbf{2} ,1/3 )}``. Missing
        SU(3)/SU(2)/U(1) labels default to ``\\textbf{1}``/``\\textbf{1}``/``\\textbf{0}``.

    Raises:
        ValueError: If no spin token is present.
        KeyError: If an SU(3)/SU(2) value has no entry in ``rep_tex_dict``.
    """
    su3 = su2 = u1 = hel = sp = None
    for tok in obj:
        value = tok.split("=")[-1]
        if "SU3" in tok:
            su3 = value
        if "SU2" in tok:
            su2 = value
        if "U1" in tok:
            u1 = value
        if "HEL" in tok:
            hel = "+1" if value == "1" else value
        if "SP" in tok:
            sp = value
    if sp is None:
        raise ValueError(f"No spin token found in field {obj!r}")

    # Spin selects the symbol; an unrecognised spin leaves it empty.
    if sp == "0":
        symbol = r"\phi"
    elif sp == "1":
        symbol = "A" + pos + lor_index
    elif sp == "1/2":
        symbol = r"\psi"
    else:
        symbol = ""

    labels = [
        rep_tex_dict["SU3"][su3] if su3 is not None else r"\textbf{1}",
        rep_tex_dict["SU2"][su2] if su2 is not None else r"\textbf{1}",
        u1 if u1 is not None else r"\textbf{0}",
    ]
    if hel is not None:
        labels.append("h:" + hel)
    # Labels are separated by " ," and the list is closed with " )}".
    return symbol + "_{(" + " ,".join(labels) + " )}"


def derobj_to_tex(obj, lor_index, pos):
    """Render a derivative token list as LaTeX.

    Without gauge-boson markers the result is a plain partial derivative.
    Otherwise it is a covariant derivative tagged with its gauge groups:
    ``G`` -> SU3, ``W`` -> SU2, ``B`` -> U1.

    Args:
        obj: Token list, e.g. ``["D", "G", "B"]``.
        lor_index: LaTeX Lorentz index.
        pos: ``"^"`` or ``"_"``.

    Raises:
        ValueError: If ``pos`` is neither ``"^"`` nor ``"_"``.
    """
    if pos == "^":
        prefix = "D^{" + lor_index + "}_{("
    elif pos == "_":
        # A lower Lorentz index shares the subscript with the group labels.
        prefix = "D_{" + lor_index + "("
    else:
        raise ValueError("pos must be ^ or _")

    groups = [name for tag, name in (("G", "SU3"), ("W", "SU2"), ("B", "U1")) if tag in obj]
    if not groups:
        return r"\partial" + pos + "{" + lor_index + "}"
    return prefix + ",".join(groups) + ")}"


def gamobj_to_tex(obj, lor_index, pos):
    """Render a gamma-matrix token as ``\\gamma`` with the given Lorentz index."""
    return r"\gamma" + pos + lor_index


def obj_to_tex(obj, lor_index=r"\mu", pos="^"):
    """Render a single object (field, derivative, gamma matrix or sign) as LaTeX.

    Args:
        obj: A space-separated token string, a list of tokens, or a tuple of tokens.
        lor_index: LaTeX Lorentz index for objects that carry one.
        pos: ``"^"`` or ``"_"``, the index position.

    Returns:
        The LaTeX string, or ``None`` if the leading token is not recognised.
    """
    if isinstance(obj, tuple):
        obj = list(obj)
    if isinstance(obj, str):
        obj = [tok for tok in obj.split(" ") if tok != ""]

    head = obj[0]
    if head == "+":
        return r"\quad\quad+"
    if head == "-":
        return r"\quad\quad-"
    if head == "i":
        return "i"
    # "QF" marks fields in generated Lagrangians; "FIELD" marks the fields
    # built by generate_field() for the UI preview. Both share one renderer.
    if head in ("QF", "FIELD"):
        return fieldobj_to_tex(obj, lor_index, pos)
    if head == "D":
        return derobj_to_tex(obj, lor_index, pos)
    if head == "GM":
        return gamobj_to_tex(obj, lor_index, pos)
    if head == "CMAD":
        return "[ " + derobj_to_tex(obj, lor_index, pos)
    if head == "CMBD":
        return ", " + derobj_to_tex(obj, lor_index, pos) + " ]"
    return None


def split_with_delimiter_preserved(string, delimiters, ignore_dots=False):
    """Split ``string`` on any of ``delimiters``, keeping the delimiters.

    Empty pieces are dropped, and a bare ``"+ "`` piece is normalised to
    ``" + "``.

    Raises:
        ValueError: If ``string`` contains ``"."`` and ``ignore_dots`` is false.
            The generated text is expected to have its terminating ``.``
            removed before it gets here.
    """
    if "." in string and not ignore_dots:
        raise ValueError("Unexpected ending to the generated Lagrangian")
    pattern = "(" + "|".join(map(re.escape, delimiters)) + ")"
    pieces = re.split(pattern, string)
    pieces = [" + " if piece == "+ " else piece for piece in pieces]
    return [piece for piece in pieces if piece != ""]


def clean_split(inlist, delimiters):
    """Join each delimiter element with the element that follows it.

    Example: ``[" QF ", "SP=0", " D ", "G"]`` -> ``[" QF SP=0", " D G"]``.
    A delimiter in last position is kept on its own.
    """
    merged_list = []
    i = 0
    while i < len(inlist):
        item = inlist[i]
        if item in delimiters and i + 1 < len(inlist):
            merged_list.append(item + inlist[i + 1])
            i += 2  # the following element has been consumed
        else:
            merged_list.append(item)
            i += 1
    return merged_list


def get_obj_dict(inlist):
    """Map each object string of a term to ``{"ID": ..., "LATEX": ...}``.

    ``ID`` is the single token containing ``"ID"``, or ``None``. ``LATEX`` is
    pre-filled for fields (``\\mu`` upper index) and for bare sign/``i``
    items. Other objects keep ``None``, and contraction handling may fill
    them in later.
    """
    outdict = {}
    for iitem in inlist:
        idict = {"ID": None, "LATEX": None}
        ids = [tok for tok in iitem.split() if "ID" in tok]
        if len(ids) == 1:
            idict["ID"] = ids[0]
        if "QF" in iitem:
            idict["LATEX"] = obj_to_tex(iitem, r"\mu", "^")
        if iitem in ("+", "-", "i"):
            idict["LATEX"] = obj_to_tex(iitem)
        outdict[iitem] = idict
    return outdict


def get_con_dict(inlist):
    """Group contraction entries by symmetry.

    Each entry must contain exactly one symmetry token (containing ``"SU"`` or
    ``"LZ"``). The result maps that token to a list of the entries' remaining
    tokens (the IDs of the contracted objects).

    Raises:
        ValueError: If an entry does not contain exactly one symmetry token.
    """
    outdict = {}
    for iitem in inlist:
        tokens = iitem.split()
        sym = [tok for tok in tokens if "SU" in tok or "LZ" in tok]
        if len(sym) != 1:
            raise ValueError("Expected exactly one symmetry in contraction")
        ids = [tok for tok in tokens if "SU" not in tok and "LZ" not in tok]
        outdict.setdefault(sym[0], []).append(ids)
    return outdict


def term_to_tex(term, verbose=False):
    """Render one generated Lagrangian term as LaTeX.

    Objects linked by a Lorentz (``LZ``) contraction receive matching indices.
    The first contraction uses ``\\mu`` and later ones use ``\\nu``. The
    first object of each contraction gets an upper index and the rest get
    lower indices.

    Raises:
        ValueError: If the term contains more than one ``CON`` section.
    """
    # Normalise tokenisation: drop ".", glue "key = value" into "key=value",
    # and merge the two-token commutator markers.
    term = (
        term.replace(".", "")
        .replace(" = ", "=")
        .replace(" =- ", "=-")
        .replace(" / ", "/")
        .replace("CMA D", "CMAD")
        .replace("CMB D", "CMBD")
    )
    term = split_with_delimiter_preserved(term, _TERM_DELIMITERS)
    term = clean_split(term, _TERM_DELIMITERS)
    if verbose:
        print(term)
    if term in ([" + "], [" - "], [" i "]):
        return term[0]

    objdict = get_obj_dict([item for item in term if " CON " not in item])
    if verbose:
        for key, val in objdict.items():
            print(key, "\t\t", val)

    contractions = [item for item in term if " CON " in item]
    if len(contractions) >= 2:
        raise ValueError("More than one contraction in term")
    if len(contractions) == 1 and contractions != [" CON "]:
        pieces = split_with_delimiter_preserved(contractions[0], _CON_DELIMITERS)
        pieces = clean_split(pieces, _CON_DELIMITERS)
        pieces = [piece for piece in pieces if piece != " CON"]
        condict = get_con_dict(pieces)
        if verbose:
            print(condict)
        first_contraction = True
        for con in condict.get("LZ", []):
            lsymb = r"\mu" if first_contraction else r"\nu"
            upper = True
            for kobj, iobj in objdict.items():
                if iobj["ID"] is None or iobj["ID"] not in con:
                    continue
                iobj["LATEX"] = obj_to_tex(kobj, lsymb, "^" if upper else "_")
                upper = False
            first_contraction = False

    return " ".join(objdict[item]["LATEX"] for item in term if " CON " not in item)


def display_in_latex(instring, verbose=False):
    """Show ``instring`` as large inline LaTeX in the Streamlit page.

    Returns ``instring`` unchanged.
    """
    latex_string = r"$\Large{" + instring + "}$"
    if verbose:
        print(latex_string)
    # Originally called IPython's display/Latex, which are not imported here
    # and raised NameError. Streamlit markdown renders $...$ math.
    st.markdown(latex_string)
    return instring


def str_tex(instr, num=0):
    """Build the body of an ``aligned`` LaTeX environment from split terms.

    ``instr`` is the output of ``split_with_delimiter_preserved``. The first
    element is treated as a leading sign and replaced by ``\\mathcal{L}=``.
    Terms are prefixed with a numbered coupling ``\\lambda_k`` or mass
    ``m^2_k``, depending on their scalar-field and derivative counts.
    A line break is inserted every 4 elements.

    Args:
        instr: Sequence of term strings.
        num: If non-zero, only ``instr[:num]`` is rendered.

    Returns:
        LaTeX starting with ``\\begin{aligned}``. The caller closes the environment.
    """
    if num != 0:
        instr = instr[:num]
    inlist = [term.replace(".", "") for term in instr]
    coup = 0
    mass = 0
    outstr = r"\begin{aligned}"
    for i, iterm in enumerate(inlist):
        if i == 0:
            outstr += r" \mathcal{L}= \quad \\ & "
            continue
        nqf = iterm.count("QF SP = 0")
        n_der = iterm.count(" D ")
        if nqf != 0 and nqf != 2 and n_der == 0:
            coup += 1
            outstr += r" \lambda_{" + str(coup) + r"} \,"
        if nqf == 2 and n_der == 0:
            mass += 1
            outstr += r" m^2_{" + str(mass) + r"} \,"
        outstr += term_to_tex(iterm, False) + r" \quad "
        if i % 4 == 0:
            outstr += r" \\ \\ & "
    return outstr


def master_str_tex(iinstr):
    """Render a full generated Lagrangian string as an ``aligned`` LaTeX block.

    If rendering fails, the last term is dropped and ``\\cdots`` is appended.
    This is a best-effort fallback; the exception is printed.
    """
    instr = split_with_delimiter_preserved(iinstr, [" + ", "+ ", " - "])
    try:
        outstr = str_tex(instr)
    except Exception as e:
        outstr = str_tex(instr, -1)
        outstr += r" \cdots"
        print(e)
    outstr += r"\end{aligned}"
    return outstr


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
device = "cpu"
model_name = "JoseEliel/BART-Lagrangian"

# Maximum number of fields a user may add in this demo (matches the UI text).
MAX_FIELDS = 2


@st.cache_resource
def load_model():
    """Load the BART model once per Streamlit server process and move it to ``device``."""
    model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
    return model


model = load_model()


@st.cache_resource
def load_tokenizer():
    """Load the matching tokenizer once per Streamlit server process."""
    return PreTrainedTokenizerFast.from_pretrained(model_name)


hf_tokenizer = load_tokenizer()


def process_input(input_text):
    """Canonicalise a field list into the model prompt.

    Strips ``[SOS]``/``[EOS]`` and splits the text into one token list per
    ``FIELD``. The lists are sorted, so the same set of fields always produces
    the same prompt, and the result is re-wrapped in ``[SOS] ... [EOS]``.
    """
    input_text = input_text.replace("[SOS]", "").replace("[EOS]", "").replace("FIELD", "SPLITFIELD")
    fields = input_text.split("SPLIT")[1:]
    fields = sorted(x.strip().split(" ") for x in fields)
    return "[SOS] " + " ".join(" ".join(x) for x in fields) + " [EOS]"


def process_output(output_text):
    """Strip ``[SOS]``, ``[EOS]`` and every ``.`` from decoded model output."""
    return output_text.replace("[SOS]", "").replace("[EOS]", "").replace(".", "")


def process_output_pretty_print(output_text):
    """Tidy decoded output for plain-text display, one ``+`` term per line."""
    pretty_output = output_text.replace(" / ", "/")
    pretty_output = pretty_output.replace("=- ", "= -")
    pretty_output = pretty_output.replace("+", "\n+")
    return pretty_output


def generate_lagrangian(input_text):
    """Run the model on a field list and return the cleaned Lagrangian text."""
    input_text = process_input(input_text)
    inputs = hf_tokenizer([input_text], return_tensors="pt").to(device)
    with st.spinner(text="Generating Lagrangian..."):
        lagrangian_ids = model.generate(inputs["input_ids"], max_length=1024)
    lagrangian = hf_tokenizer.decode(lagrangian_ids[0].tolist(), skip_special_tokens=False)
    return process_output(lagrangian)


def generate_field(sp, su2, su3, u1):
    """Build a field token string such as ``"FIELD SPIN=0 SU2=2 SU3=-3 U1=1/3"``.

    ``su2``/``su3`` are the radio labels (e.g. ``"$2$"``, ``"$\\bar{3}$"``).
    Trivial representations (``"$1$"``, U(1) ``"0"``) are omitted, and
    ``\\bar{3}`` is encoded as ``SU3=-3``.
    """
    components = [f"FIELD SPIN={sp}"]
    if su2 != "$1$":
        components.append(f"SU2={su2}")
    if su3 == "$\\bar{3}$":
        components.append("SU3=-3")
    elif su3 != "$1$":
        components.append(f"SU3={su3}")
    if u1 != "0":
        components.append(f"U1={u1}")
    # The radio labels carry $...$ math delimiters; the model expects bare values.
    return " ".join(components).replace("$", "")


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
def main():
    """Streamlit entry point: pick fields, preview them, generate a Lagrangian."""
    st.title("$\\mathscr{L}$agrangian Generator")
    st.markdown(" ### For a set of chosen fields, this model generates the corresponding Lagrangian which encodes all interactions and dynamics of the fields.")
    st.markdown(" #### This is a simple demo of our smaller [BART](https://arxiv.org/abs/1910.13461)-based model with 110M parameters")
    st.markdown(" ##### :violet[Due to computational resources, we limit the number of fields to 2 and spin = 0]")
    st.markdown(" ##### Choose up to two different fields:")

    su2_options = ["$1$", "$2$", "$3$"]
    su3_options = ["$1$", "$3$", "$\\bar{3}$"]
    u1_options = ["-1", "-2/3", "-1/2", "-1/3", "0", "1/3", "1/2", "2/3", "1"]

    # Session state persists across Streamlit reruns.
    if "count" not in st.session_state:
        st.session_state.count = 0  # number of fields added so far
    if "field_strings" not in st.session_state:
        st.session_state.field_strings = []  # field strings from generate_field()

    with st.form("field_selection"):
        su2_selection = st.radio("Select $\\mathrm{SU}(2)$ value:", su2_options)
        su3_selection = st.radio("Select $\\mathrm{SU}(3)$ value:", su3_options)
        u1_selection = st.radio("Select $\\mathrm{U}(1)$ value:", u1_options)
        submitted = st.form_submit_button("Add field")
        if submitted:
            # Previously the check was `< 4`, so users could add 4 fields
            # despite the advertised limit of 2.
            if st.session_state.count < MAX_FIELDS:
                sp_value = 0  # the demo only supports scalars
                field_string = generate_field(sp_value, su2_selection, su3_selection, u1_selection)
                st.session_state.field_strings.append(field_string)
                st.session_state.count += 1
            else:
                st.write("You have reached the maximum number of fields we allow in this demo.")

    if st.button("Clear fields"):
        st.session_state.field_strings = []
        st.session_state.count = 0

    st.write("Input Fields:")
    for i, fs in enumerate(st.session_state.field_strings, 1):
        texfield = obj_to_tex(fs)
        fieldname = f"Field {i}:"
        st.latex(r"\text{" + fieldname + r"} \quad" + texfield)

    if st.button("Generate Lagrangian"):
        input_fields = " ".join(st.session_state.field_strings)
        if input_fields == "":
            st.write("Please add fields before generating the Lagrangian.")
            return
        print(input_fields)
        print("\n")
        input_fields = "[SOS] " + input_fields + " [EOS]"
        generated_lagrangian = generate_lagrangian(input_fields)
        print(generated_lagrangian)
        print("\n")
        # LaTeX rendering via master_str_tex() is currently disabled; the raw
        # generated text is shown instead.
        st.markdown("### Generated Lagrangian")
        st.text(generated_lagrangian)

    st.markdown("### Contact")
    st.markdown("If you have any questions or suggestions, please feel free to Email us. [Eliel](mailto:eliel.camargo-molina@physics.uu.se) or [Yong Sheng](mailto:yongsheng.koay@physics.uu.se).")


if __name__ == "__main__":
    main()