Spaces:
Build error
Build error
Upload 30 files
Browse files- .gitattributes +7 -0
- configs/alpaca.json +6 -0
- configs/cleaning_config.yaml +203 -0
- configs/config.yaml +55 -0
- configs/config_noUI.yaml +37 -0
- configs/latex_commands.yaml +162 -0
- datasets/computer_vision.gexf +0 -0
- datasets/economics.gexf +3 -0
- datasets/finance.gexf +3 -0
- datasets/large_language_models.gexf +3 -0
- datasets/quantative_biology.gexf +3 -0
- datasets/quantum_physics.gexf +3 -0
- datasets/robotics.gexf +3 -0
- datasets/telecommunications.gexf +3 -0
- litbench_pipeline.py +746 -0
- retriever/retriever.py +129 -0
- tasks/abs_2_title.py +23 -0
- tasks/abs_completion.py +29 -0
- tasks/citation_sentence.py +25 -0
- tasks/gen_related_work.py +430 -0
- tasks/influential_papers.py +41 -0
- tasks/intro_2_abs.py +28 -0
- tasks/link_pred.py +23 -0
- tasks/paper_retrieval.py +21 -0
- train.py +385 -0
- utils/de-macro.py +1110 -0
- utils/def_handle.py +75 -0
- utils/gradio_utils.py +20 -0
- utils/graph_utils.py +111 -0
- utils/latexpand +713 -0
- utils/utils.py +701 -0
.gitattributes
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets/economics.gexf filter=lfs diff=lfs merge=lfs -text
|
2 |
+
datasets/finance.gexf filter=lfs diff=lfs merge=lfs -text
|
3 |
+
datasets/large_language_models.gexf filter=lfs diff=lfs merge=lfs -text
|
4 |
+
datasets/quantative_biology.gexf filter=lfs diff=lfs merge=lfs -text
|
5 |
+
datasets/quantum_physics.gexf filter=lfs diff=lfs merge=lfs -text
|
6 |
+
datasets/robotics.gexf filter=lfs diff=lfs merge=lfs -text
|
7 |
+
datasets/telecommunications.gexf filter=lfs diff=lfs merge=lfs -text
|
configs/alpaca.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"description": "Template used by Alpaca-LoRA.",
|
3 |
+
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
|
4 |
+
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
|
5 |
+
"response_split": "### Response:"
|
6 |
+
}
|
configs/cleaning_config.yaml
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
patterns_and_insertions:
|
2 |
+
[
|
3 |
+
{
|
4 |
+
"pattern" : '(?:\\figcomp{\s*)(?P<first>.*?)\s*}\s*{\s*(?P<second>.*?)\s*}\s*{\s*(?P<third>.*?)\s*}',
|
5 |
+
"insertion" : '\parbox[c]{{ {second} \linewidth}} {{ \includegraphics[width= {third} \linewidth]{{figures/{first} }} }}',
|
6 |
+
"description" : "Replace figcomp",
|
7 |
+
},
|
8 |
+
]
|
9 |
+
|
10 |
+
verbose: False
|
11 |
+
|
12 |
+
commands_to_delete: [
|
13 |
+
'footnote',
|
14 |
+
'footnote ',
|
15 |
+
'crdata',
|
16 |
+
'appendixhead',
|
17 |
+
'selectlanguage',
|
18 |
+
'name',
|
19 |
+
'expandafter',
|
20 |
+
'copyrightyear',
|
21 |
+
'acmYear',
|
22 |
+
'acmBooktitle',
|
23 |
+
'acmPrice',
|
24 |
+
'authorcontributions',
|
25 |
+
'thanksref',
|
26 |
+
'funding',
|
27 |
+
'conflictsofinterest',
|
28 |
+
'externalbibliography',
|
29 |
+
'acmDOI',
|
30 |
+
'acmISBN',
|
31 |
+
'acmConference',
|
32 |
+
'titleheader',
|
33 |
+
'affil',
|
34 |
+
'authorrunning',
|
35 |
+
'pagenumbering',
|
36 |
+
'enlargethispage',
|
37 |
+
'author',
|
38 |
+
'AuthorNames',
|
39 |
+
'author\n',
|
40 |
+
'Author\n',
|
41 |
+
'Author',
|
42 |
+
'fntext',
|
43 |
+
'icmlauthor',
|
44 |
+
'icmlauthor\n',
|
45 |
+
'icmladdress',
|
46 |
+
'icmladdress\n',
|
47 |
+
'received',
|
48 |
+
'runninghead',
|
49 |
+
'bstctlcite',
|
50 |
+
'slugger',
|
51 |
+
'tocauthor',
|
52 |
+
'author\*',
|
53 |
+
'vspace\*',
|
54 |
+
'\write18',
|
55 |
+
'hspace\*',
|
56 |
+
'vspace',
|
57 |
+
'hspace',
|
58 |
+
'maketitle',
|
59 |
+
'institute',
|
60 |
+
'label',
|
61 |
+
'urlstyle',
|
62 |
+
'acks',
|
63 |
+
'tnoteref',
|
64 |
+
'Appendix',
|
65 |
+
'urlstyle',
|
66 |
+
'url',
|
67 |
+
'editor',
|
68 |
+
'ccsdesc',
|
69 |
+
'cortext',
|
70 |
+
'bibliography',
|
71 |
+
'permission',
|
72 |
+
'usetikzlibrary',
|
73 |
+
'thanks',
|
74 |
+
'thispagestyle',
|
75 |
+
'abovedisplayskip',
|
76 |
+
'belowdisplayskip',
|
77 |
+
'bibliographystyle',
|
78 |
+
'IfSubStringInString',
|
79 |
+
'hyphenation',
|
80 |
+
'theoremstyle',
|
81 |
+
'colorbox',
|
82 |
+
'textcolor',
|
83 |
+
'color',
|
84 |
+
'caption',
|
85 |
+
'thlabel',
|
86 |
+
'fancyfoot',
|
87 |
+
'captionof',
|
88 |
+
'settopmatter',
|
89 |
+
'IEEEtriggeratref',
|
90 |
+
'IEEEauthorblockN',
|
91 |
+
'IEEEauthorblockA',
|
92 |
+
'IEEEauthorblockN\n',
|
93 |
+
'IEEEauthorblockA\n',
|
94 |
+
'IEEEauthorrefmark',
|
95 |
+
'orcid',
|
96 |
+
'typeout',
|
97 |
+
'fancyhead',
|
98 |
+
'pagestyle',
|
99 |
+
'biboptions',
|
100 |
+
'affiliation',
|
101 |
+
'address',
|
102 |
+
'institution',
|
103 |
+
'printalgoIEEE',
|
104 |
+
'date',
|
105 |
+
'authornote',
|
106 |
+
'numberofauthors',
|
107 |
+
'footnotetext',
|
108 |
+
'email',
|
109 |
+
'reftitle',
|
110 |
+
'setcopyright',
|
111 |
+
'ead',
|
112 |
+
'deleted',
|
113 |
+
'includegraphics',
|
114 |
+
'comment',
|
115 |
+
'abstract',
|
116 |
+
'replaced',
|
117 |
+
'xspace',
|
118 |
+
]
|
119 |
+
|
120 |
+
commands_only_to_delete: [
|
121 |
+
'titlerunning',
|
122 |
+
'runningtitle',
|
123 |
+
'title',
|
124 |
+
'title\*',
|
125 |
+
'accept',
|
126 |
+
'added',
|
127 |
+
'icmltitle',
|
128 |
+
'textsuperscript',
|
129 |
+
'texttt',
|
130 |
+
'textsc',
|
131 |
+
'textit',
|
132 |
+
'mathit',
|
133 |
+
'makebox',
|
134 |
+
'mbox',
|
135 |
+
'textbf',
|
136 |
+
'acl',
|
137 |
+
'textnormal',
|
138 |
+
'texttt ',
|
139 |
+
'textsc ',
|
140 |
+
'textit ',
|
141 |
+
'mathit ',
|
142 |
+
'textbf ',
|
143 |
+
'gls',
|
144 |
+
'Gls',
|
145 |
+
'glspl',
|
146 |
+
'textnormal ',
|
147 |
+
'inlinetitle',
|
148 |
+
'mbox',
|
149 |
+
'hl',
|
150 |
+
'highlight',
|
151 |
+
'IEEEraisesectionheading',
|
152 |
+
'IEEEtitleabstractindextext',
|
153 |
+
'IEEEPARstart',
|
154 |
+
'it',
|
155 |
+
'added',
|
156 |
+
'paragraph',
|
157 |
+
'paragraph\*',
|
158 |
+
'MakeLowercase',
|
159 |
+
'emph',
|
160 |
+
'emph ',
|
161 |
+
'text',
|
162 |
+
'acp',
|
163 |
+
'ac'
|
164 |
+
]
|
165 |
+
|
166 |
+
environments_to_delete: [
|
167 |
+
'icmlauthorlist',
|
168 |
+
'tikzpicture',
|
169 |
+
'groupplot',
|
170 |
+
'biography',
|
171 |
+
'IEEEbiographynophoto',
|
172 |
+
'acronym',
|
173 |
+
'MSCcodes',
|
174 |
+
'IEEEbiography',
|
175 |
+
'figure',
|
176 |
+
'AMS',
|
177 |
+
'acknowledgement',
|
178 |
+
'acknowledgments',
|
179 |
+
'acknowledgements',
|
180 |
+
'figure\*',
|
181 |
+
'minipage',
|
182 |
+
'table',
|
183 |
+
'table\*',
|
184 |
+
'glist',
|
185 |
+
'tabular',
|
186 |
+
'tabular\*',
|
187 |
+
'center',
|
188 |
+
'remark',
|
189 |
+
'algorithm',
|
190 |
+
'algorithmic',
|
191 |
+
'CCSXML',
|
192 |
+
'acks',
|
193 |
+
'lstlisting',
|
194 |
+
'tabu',
|
195 |
+
'algorithm\*',
|
196 |
+
'algorithmic\*',
|
197 |
+
'longtable',
|
198 |
+
'sidewaystable\*',
|
199 |
+
'sidewaystable',
|
200 |
+
'appendices',
|
201 |
+
'wrapfigure',
|
202 |
+
'appendix'
|
203 |
+
]
|
configs/config.yaml
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data_downloading:
|
2 |
+
download_directory: "quant_bio_retrieval/" # directory where the papers will be downloaded and the graph will be saved
|
3 |
+
gexf_file: "test_graph.gexf" # name of the graph file that will be created only if downloading option is true
|
4 |
+
processing:
|
5 |
+
random_seed: 10
|
6 |
+
keep_unstructured_content: false # keep unstructured content of the papers as graph node attribute if true
|
7 |
+
arxiv_rate_limit: 3 # time in seconds to wait between each arxiv api call to avoid ban
|
8 |
+
|
9 |
+
retriever:
|
10 |
+
embedder: BAAI/bge-large-en-v1.5
|
11 |
+
num_retrievals: 30000
|
12 |
+
load_arxiv_embeds: True # load arxiv embeddings from huggingface if true else generate them
|
13 |
+
|
14 |
+
inference:
|
15 |
+
base_model: meta-llama/Meta-Llama-3-8B
|
16 |
+
pretrained_model: "models/Robotics/Meta-LLama-3-8B-Quantative-Robotics" # used only if training option is false
|
17 |
+
generation_args:
|
18 |
+
max_new_tokens: 1000
|
19 |
+
do_sample: True
|
20 |
+
top_p: 0.9
|
21 |
+
top_k: 50
|
22 |
+
temperature: 0.7
|
23 |
+
no_repeat_ngram_size: 2
|
24 |
+
num_beams: 1
|
25 |
+
gen_related_work_instruct_model: meta-llama/Llama-3.1-8B-Instruct # Model assisting at the generation of related work instructions
|
26 |
+
|
27 |
+
training:
|
28 |
+
predefined_graph_path: "robotics.gexf" # path to the graph dataset used for fine-tuning only if downloading option is false
|
29 |
+
trainer_args:
|
30 |
+
per_device_train_batch_size: 4
|
31 |
+
warmup_steps: 100
|
32 |
+
num_train_epochs: 1
|
33 |
+
learning_rate: 0.0002
|
34 |
+
lr_scheduler_type: 'cosine'
|
35 |
+
fp16: true
|
36 |
+
logging_steps: 1
|
37 |
+
save_steps: 50
|
38 |
+
trainer_output_dir: trainer_outputs/
|
39 |
+
tokenizer:
|
40 |
+
max_length: 1024
|
41 |
+
qlora:
|
42 |
+
rank: 8
|
43 |
+
lora_alpha: 32
|
44 |
+
lora_dropout: 0.05
|
45 |
+
target_modules: # modules for which to train lora adapters
|
46 |
+
- q_proj
|
47 |
+
- k_proj
|
48 |
+
- v_proj
|
49 |
+
- o_proj
|
50 |
+
|
51 |
+
# Used only if training option is true to save and load the fine-tuned model
|
52 |
+
model_saving:
|
53 |
+
model_name: llama_1b_qlora_uncensored
|
54 |
+
model_output_dir: models # model saved in {model_output_dir}/{model_name}_{index} # model saved in {model_output_dir}/{model_name}_{index} after fine-tuning completion
|
55 |
+
index: 1
|
configs/config_noUI.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Note: In order to train the model and then evaluate the same model, you need to
|
2 |
+
# make sure that the base_model to be the same in both the train and eval sections
|
3 |
+
|
4 |
+
# Evaluation configuration
|
5 |
+
eval:
|
6 |
+
base_model: meta-llama/Llama-3.2-1B
|
7 |
+
graph_path: datasets/quantum_graph.gexf
|
8 |
+
model_name: llama_1b_qlora_uncensored
|
9 |
+
|
10 |
+
# Training configuration
|
11 |
+
training:
|
12 |
+
graph_path: datasets/quantum_graph.gexf # path to the graph file to train on
|
13 |
+
base_model: meta-llama/Llama-3.2-1B
|
14 |
+
trainer_args:
|
15 |
+
per_device_train_batch_size: 4
|
16 |
+
warmup_steps: 100
|
17 |
+
num_train_epochs: 1
|
18 |
+
learning_rate: 0.0002
|
19 |
+
lr_scheduler_type: 'cosine'
|
20 |
+
fp16: true
|
21 |
+
logging_steps: 1
|
22 |
+
save_steps: 50
|
23 |
+
trainer_output_dir: trainer_outputs/
|
24 |
+
tokenizer:
|
25 |
+
max_length: 1024
|
26 |
+
qlora:
|
27 |
+
rank: 8
|
28 |
+
lora_alpha: 32
|
29 |
+
lora_dropout: 0.05
|
30 |
+
target_modules: # modules for which to train lora adapters
|
31 |
+
- q_proj
|
32 |
+
- k_proj
|
33 |
+
- v_proj
|
34 |
+
- o_proj
|
35 |
+
model_saving:
|
36 |
+
model_output_dir: models # model saved in {model_output_dir}/{model_name} after fine-tuning completion
|
37 |
+
model_name: llama_1b_qlora_uncensored
|
configs/latex_commands.yaml
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
verbatim_to_delete: [
|
2 |
+
'\IEEEpeerreviewmaketitle',
|
3 |
+
'\normalcolor',
|
4 |
+
'\ifCLASSOPTIONcaptionsoff',
|
5 |
+
'\pagebreak',
|
6 |
+
'\makeatletter',
|
7 |
+
'\makeatother',
|
8 |
+
'\maketitle',
|
9 |
+
'\preface',
|
10 |
+
'\eShell',
|
11 |
+
'\medskip',
|
12 |
+
'\tableofcontents',
|
13 |
+
'\begin{@twocolumnfalse}',
|
14 |
+
'\end{@twocolumnfalse}',
|
15 |
+
'\bgroup',
|
16 |
+
'\egroup',
|
17 |
+
'\ifnalpaper1',
|
18 |
+
'\let\thefootnote',
|
19 |
+
'\begin{spacing}{2.0}',
|
20 |
+
'\end{landscape}',
|
21 |
+
'\begin{landscape}',
|
22 |
+
'\begin{doublespacing}',
|
23 |
+
'\end{doublespacing}',
|
24 |
+
'\begin{spacing}',
|
25 |
+
'\end{spacing}',
|
26 |
+
'\printbibliography',
|
27 |
+
'\begin{sloppypar}',
|
28 |
+
'\end{sloppypar}',
|
29 |
+
'\ifbd',
|
30 |
+
'\iftr',
|
31 |
+
'\fussy',
|
32 |
+
'\sloppy',
|
33 |
+
'\emergencystretch',
|
34 |
+
'\hideLIPIcs',
|
35 |
+
'\tolerance',
|
36 |
+
'\hbadness',
|
37 |
+
'\bShell',
|
38 |
+
'\glsresetall',
|
39 |
+
'\copyrightnotice',
|
40 |
+
'\copyright',
|
41 |
+
'\centering',
|
42 |
+
'\immediate',
|
43 |
+
'\doublespacing',
|
44 |
+
'\flushbottom',
|
45 |
+
'\printAffiliationsAndNotice',
|
46 |
+
'\IEEEpubid',
|
47 |
+
'\twocolumn',
|
48 |
+
'\noindent',
|
49 |
+
'\indent',
|
50 |
+
'\onecolumn',
|
51 |
+
'\ignore',
|
52 |
+
'\selectfont',
|
53 |
+
'\raggedbottom',
|
54 |
+
'\IEEEoverridecommandlockouts',
|
55 |
+
'\newline',
|
56 |
+
'\tiny',
|
57 |
+
'\break',
|
58 |
+
'\mainmatter',
|
59 |
+
'\let\rc\rangle',
|
60 |
+
'\let\lc\langle',
|
61 |
+
'\acresetall',
|
62 |
+
'\acknowledgments',
|
63 |
+
'\begin{section}',
|
64 |
+
'\begin{small}',
|
65 |
+
'\end{small}',
|
66 |
+
'\relax',
|
67 |
+
'\ninept',
|
68 |
+
'\FloatBarrier',
|
69 |
+
'\boldmath',
|
70 |
+
'\end{section}',
|
71 |
+
'\Huge',
|
72 |
+
'\fancyhf',
|
73 |
+
'\fancyhead',
|
74 |
+
'\begin{frontmatter}',
|
75 |
+
'\end{frontmatter}',
|
76 |
+
'\clearpage',
|
77 |
+
'\huge',
|
78 |
+
'\newpage',
|
79 |
+
'\IEEEdisplaynontitleabstractindextext',
|
80 |
+
'\DontPrintSemicolon',
|
81 |
+
'\Large',
|
82 |
+
'\LARGE',
|
83 |
+
'\ifCLASSOPTIONcompsoc',
|
84 |
+
'\ifCLASSOPTIONonecolumn',
|
85 |
+
'\xspace',
|
86 |
+
'\large',
|
87 |
+
'\acmcopyr',
|
88 |
+
'\flushleft',
|
89 |
+
'\newpage',
|
90 |
+
'\protect',
|
91 |
+
'\begingroup',
|
92 |
+
'\endgroup',
|
93 |
+
'\bigskip',
|
94 |
+
'\smallskip',
|
95 |
+
'\small',
|
96 |
+
'\left',
|
97 |
+
'\right',
|
98 |
+
'\vfill',
|
99 |
+
'\hfill',
|
100 |
+
'\begin{appendices}',
|
101 |
+
'\IEEEQED',
|
102 |
+
'\leavevmode',
|
103 |
+
'\footnotesize',
|
104 |
+
'\nonumber',
|
105 |
+
'\scriptsize',
|
106 |
+
'\IEEEpubidadjcol',
|
107 |
+
'\balance',
|
108 |
+
'\normalsize',
|
109 |
+
'\ifloguseIEEEConf',
|
110 |
+
'\else',
|
111 |
+
'\fi',
|
112 |
+
'\bf ',
|
113 |
+
'\it ',
|
114 |
+
'\verb ',
|
115 |
+
'\tt ',
|
116 |
+
'\em ',
|
117 |
+
'\par '
|
118 |
+
]
|
119 |
+
|
120 |
+
two_arguments: [
|
121 |
+
'IEEEPARstart',
|
122 |
+
'pgfdeclareplotmark',
|
123 |
+
'setcounter',
|
124 |
+
'texorpdfstring',
|
125 |
+
'fontsize',
|
126 |
+
'addtocounter',
|
127 |
+
'addtolength'
|
128 |
+
]
|
129 |
+
|
130 |
+
three_arguments: [
|
131 |
+
'definecolor'
|
132 |
+
]
|
133 |
+
|
134 |
+
two_arguments_elaborate: [
|
135 |
+
'markboth',
|
136 |
+
'setlength',
|
137 |
+
'pgfdeclareplotmark',
|
138 |
+
'icmlsetsymbol',
|
139 |
+
'texorpdfstring',
|
140 |
+
'conferenceinfo',
|
141 |
+
'acrodef',
|
142 |
+
'icmlcorrespondingauthor',
|
143 |
+
'pdfbookmark',
|
144 |
+
'icmlaffiliation',
|
145 |
+
'icmlcorrespondingauthor'
|
146 |
+
]
|
147 |
+
|
148 |
+
three_arguments_elaborate: [
|
149 |
+
'ifthenelse',
|
150 |
+
'addcontentsline'
|
151 |
+
]
|
152 |
+
|
153 |
+
replace_comments: [
|
154 |
+
'def\\',
|
155 |
+
'def ',
|
156 |
+
'newglossaryentry',
|
157 |
+
'newtheorem',
|
158 |
+
'newcommand',
|
159 |
+
'renewcommand',
|
160 |
+
'newenvironment',
|
161 |
+
'renewenvironment'
|
162 |
+
]
|
datasets/computer_vision.gexf
ADDED
The diff for this file is too large to render.
See raw diff
|
|
datasets/economics.gexf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bb70ebc1810a487c967dc7d196b4b32ffb8237870502d85496633c7ec639e3a3
|
3 |
+
size 17470884
|
datasets/finance.gexf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4739b443038f71850022a93ecf64c97d2fde08c492f716ee0d3e7a9d657302a2
|
3 |
+
size 18446416
|
datasets/large_language_models.gexf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d6b8fa29877039cb8b4665889421cdd4488752bcad6f512ac0b05678a73107c1
|
3 |
+
size 18262076
|
datasets/quantative_biology.gexf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:99666f9881f04850511a4e424243ed28228ae7cfd28859fc2f9104faa4e8bb4b
|
3 |
+
size 20476813
|
datasets/quantum_physics.gexf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7abcb0d7aec2edb5461161ff613fe87176a4eb9c1522c2b0c1903fbccf66f74
|
3 |
+
size 20678301
|
datasets/robotics.gexf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:05681fb68e8168d530cc5bd037f5a6236135446be05862e2be9bf61b548d2c98
|
3 |
+
size 19601009
|
datasets/telecommunications.gexf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d0f89f57398abcc70065c2756e15fcd2792c34f6ac9137ba1a8521a5dd8cd22c
|
3 |
+
size 14749926
|
litbench_pipeline.py
ADDED
@@ -0,0 +1,746 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from train import *
|
2 |
+
from utils.utils import *
|
3 |
+
from utils.graph_utils import *
|
4 |
+
from utils.gradio_utils import *
|
5 |
+
from retriever.retriever import retriever
|
6 |
+
from tasks.abs_2_title import abs_2_title
|
7 |
+
from tasks.abs_completion import abs_completion
|
8 |
+
from tasks.citation_sentence import citation_sentence
|
9 |
+
from tasks.intro_2_abs import intro_2_abs
|
10 |
+
from tasks.link_pred import link_pred
|
11 |
+
from tasks.paper_retrieval import paper_retrieval
|
12 |
+
from tasks.influential_papers import influential_papers
|
13 |
+
from tasks.gen_related_work import gen_related_work
|
14 |
+
import random
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
import re
|
18 |
+
import networkx as nx
|
19 |
+
import tarfile
|
20 |
+
import gzip
|
21 |
+
import time
|
22 |
+
import urllib.request
|
23 |
+
from tqdm import tqdm
|
24 |
+
from colorama import Fore
|
25 |
+
import wandb
|
26 |
+
import gradio as gr
|
27 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList, TextIteratorStreamer, pipeline
|
28 |
+
from threading import Thread
|
29 |
+
import signal
|
30 |
+
import gzip
|
31 |
+
import time
|
32 |
+
import torch
|
33 |
+
from peft.peft_model import PeftModel
|
34 |
+
from datasets import load_dataset
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
# Function to determine the chatbot's first message based on user choices
|
39 |
+
def setup(download_option, train_option):
|
40 |
+
download_papers.value = (download_option == "Download Paper")
|
41 |
+
train_model.value = (train_option == "Train")
|
42 |
+
|
43 |
+
if download_option == "Download Paper":
|
44 |
+
initial_message = [{"role": "assistant", "content": "Hello, what domain are you interested in?"}]
|
45 |
+
elif download_option != "Download Paper" and train_option == "Train":
|
46 |
+
initial_message = [{"role": "assistant", "content": "What domain is your graph about?"}]
|
47 |
+
else:
|
48 |
+
initial_message = [{"role": "assistant", "content": "Please provide your task prompt."}]
|
49 |
+
|
50 |
+
return gr.update(visible=False), gr.update(visible=True), f"Download: {download_option}\nTrain: {train_option}", initial_message
|
51 |
+
|
52 |
+
|
53 |
+
# Function to toggle the selected task based on user input
|
54 |
+
def update_button_styles(selected_task):
|
55 |
+
"""Update button styles based on selection."""
|
56 |
+
return [gr.update(variant="primary" if selected_task == prompt else "secondary") for prompt in task_list]
|
57 |
+
|
58 |
+
|
59 |
+
# Fetch and store arXiv source files
|
60 |
+
def fetch_arxiv_papers(papers_to_download):
|
61 |
+
# Download the arXiv metadata file if it doesn't exist
|
62 |
+
dataset = 'datasets/arxiv-metadata-oai-snapshot.json'
|
63 |
+
data = []
|
64 |
+
if not os.path.exists(dataset):
|
65 |
+
os.system("wget https://huggingface.co/spaces/ddiddu/simsearch/resolve/main/arxiv-metadata-oai-snapshot.json -P ./datasets")
|
66 |
+
|
67 |
+
with open(dataset, 'r') as f:
|
68 |
+
for line in f:
|
69 |
+
data.append(json.loads(line))
|
70 |
+
|
71 |
+
papers = [d for d in data]
|
72 |
+
paper_ids = [d['id'] for d in data]
|
73 |
+
paper_titles = [
|
74 |
+
(
|
75 |
+
re.sub(r' +', ' ', re.sub(r'[\n]+', ' ', paper['title']))
|
76 |
+
.replace("\\emph", "")
|
77 |
+
.replace("\\emp", "")
|
78 |
+
.replace("\\em", "")
|
79 |
+
.replace(",", "")
|
80 |
+
.replace("{", "")
|
81 |
+
.replace("}", "")
|
82 |
+
.strip(".")
|
83 |
+
.strip()
|
84 |
+
.strip(".")
|
85 |
+
.lower()
|
86 |
+
)
|
87 |
+
for paper in papers
|
88 |
+
]
|
89 |
+
paper_dict = {
|
90 |
+
k:v
|
91 |
+
for k,v in zip(paper_titles, paper_ids)
|
92 |
+
}
|
93 |
+
|
94 |
+
|
95 |
+
total_papers = len(papers_to_download)
|
96 |
+
download_progress_bar=gr.Progress()
|
97 |
+
|
98 |
+
llm_resp = []
|
99 |
+
results = {
|
100 |
+
"Number of papers": 0,
|
101 |
+
"Number of latex papers": 0,
|
102 |
+
"Number of bib files": 0,
|
103 |
+
"Number of bbl files": 0,
|
104 |
+
"Number of inline files": 0,
|
105 |
+
"Number of introductions found": 0,
|
106 |
+
"Number of related works found": 0,
|
107 |
+
"Number of succesful finding of extracts": 0
|
108 |
+
}
|
109 |
+
num_papers, num_edges, t, iter_ind = 0, 0, 0, 0
|
110 |
+
graph = {}
|
111 |
+
|
112 |
+
arxiv_rate_lim = config['data_downloading']['processing']['arxiv_rate_limit']
|
113 |
+
for paper_name in tqdm(papers_to_download):
|
114 |
+
results["Number of papers"] += 1
|
115 |
+
print(
|
116 |
+
Fore.BLUE + "Number of papers processed: {} \n Number of edges found: {} \n Time of previous iter: {} \n Now processing paper: {} \n\n"
|
117 |
+
.format(num_papers, num_edges, time.time()-t, paper_name) + Fore.RESET
|
118 |
+
)
|
119 |
+
t = time.time()
|
120 |
+
num_papers += 1
|
121 |
+
|
122 |
+
# Prepare the paper name for downloading and saving
|
123 |
+
paper_name_download = paper_name
|
124 |
+
if re.search(r'[a-zA-Z]', paper_name) is not None:
|
125 |
+
paper_name = "".join(paper_name.split('/'))
|
126 |
+
tar_file_path = save_zip_directory + paper_name + '.tar.gz'
|
127 |
+
|
128 |
+
# Attempt to download the paper source files from arXiv
|
129 |
+
try:
|
130 |
+
# Track start time for download
|
131 |
+
t1 = time.time()
|
132 |
+
urllib.request.urlretrieve(
|
133 |
+
"https://arxiv.org/src/"+paper_name_download,
|
134 |
+
tar_file_path)
|
135 |
+
except Exception as e:
|
136 |
+
print("Couldn't download paper {}".format(paper_name))
|
137 |
+
# Skip to the next paper if download fails
|
138 |
+
continue
|
139 |
+
|
140 |
+
# Define the directory where the paper will be extracted
|
141 |
+
extracted_dir = save_directory + paper_name + '/'
|
142 |
+
isExist = os.path.exists(extracted_dir)
|
143 |
+
if not isExist:
|
144 |
+
os.makedirs(extracted_dir)
|
145 |
+
|
146 |
+
# Attempt to extract the tar.gz archive
|
147 |
+
try:
|
148 |
+
tar = tarfile.open(tar_file_path)
|
149 |
+
tar.extractall(extracted_dir)
|
150 |
+
tar.close()
|
151 |
+
except Exception as e:
|
152 |
+
# If tar extraction fails, attempt to read and extract using gzip
|
153 |
+
try:
|
154 |
+
with gzip.open(tar_file_path, 'rb') as f:
|
155 |
+
file_content = f.read()
|
156 |
+
|
157 |
+
# Save the extracted content as a .tex file
|
158 |
+
with open(extracted_dir+paper_name+'.tex', 'w') as f:
|
159 |
+
f.write(file_content.decode())
|
160 |
+
except Exception as e:
|
161 |
+
print("Could not extract paper id: {}".format(paper_name))
|
162 |
+
# Skip this paper if extraction fails
|
163 |
+
continue
|
164 |
+
|
165 |
+
try:
|
166 |
+
# Perform initial cleaning and get the main TeX file
|
167 |
+
initial_clean(extracted_dir, config=False)
|
168 |
+
main_file = get_main(extracted_dir)
|
169 |
+
|
170 |
+
# If no main TeX file is found, remove the downloaded archive and continue
|
171 |
+
if main_file == None:
|
172 |
+
print("No tex files found")
|
173 |
+
os.remove(tar_file_path)
|
174 |
+
continue
|
175 |
+
|
176 |
+
# Check if the main TeX file contains a valid LaTeX document
|
177 |
+
h = check_begin(main_file)
|
178 |
+
if h == True:
|
179 |
+
results["Number of latex papers"] += 1
|
180 |
+
# Flag to check for internal bibliography
|
181 |
+
check_internal = 0
|
182 |
+
# Dictionary to store bibliographic references
|
183 |
+
final_library = {}
|
184 |
+
|
185 |
+
# Identify bibliography files (.bib or .bbl)
|
186 |
+
bib_files = find_bib(extracted_dir)
|
187 |
+
if bib_files == []:
|
188 |
+
bbl_files = find_bbl(extracted_dir)
|
189 |
+
if bbl_files == []:
|
190 |
+
# No external bibliography found
|
191 |
+
check_internal = 1
|
192 |
+
else:
|
193 |
+
final_library = get_library_bbl(bbl_files)
|
194 |
+
results["Number of bbl files"] += 1
|
195 |
+
else:
|
196 |
+
results["Number of bib files"] += 1
|
197 |
+
final_library = get_library_bib(bib_files)
|
198 |
+
|
199 |
+
# Apply post-processing to clean the TeX document
|
200 |
+
main_file = post_processing(extracted_dir, main_file)
|
201 |
+
|
202 |
+
# Read the cleaned LaTeX document content
|
203 |
+
descr = main_file
|
204 |
+
content = read_tex_file(descr)
|
205 |
+
|
206 |
+
# If configured, store the raw content in the graph
|
207 |
+
if config['data_downloading']['processing']['keep_unstructured_content']:
|
208 |
+
graph[paper_name] = {'content': content}
|
209 |
+
else:
|
210 |
+
graph[paper_name] = {}
|
211 |
+
|
212 |
+
# Check for inline bibliography within the LaTeX document
|
213 |
+
if check_internal == 1:
|
214 |
+
beginning_bib = '\\begin{thebibliography}'
|
215 |
+
end_bib = '\\end{thebibliography}'
|
216 |
+
|
217 |
+
if content.find(beginning_bib) != -1 and content.find(end_bib) != -1:
|
218 |
+
bibliography = content[content.find(beginning_bib):content.find(end_bib) + len(end_bib)]
|
219 |
+
save_bbl = os.path.join(extracted_dir, "bibliography.bbl")
|
220 |
+
|
221 |
+
results["Number of inline files"] += 1
|
222 |
+
with open(save_bbl, "w") as f:
|
223 |
+
f.write(bibliography)
|
224 |
+
|
225 |
+
final_library = get_library_bbl([save_bbl])
|
226 |
+
|
227 |
+
# If no valid bibliography is found, skip processing citations
|
228 |
+
if final_library == {}:
|
229 |
+
print("No library found...")
|
230 |
+
continue
|
231 |
+
|
232 |
+
# Extract relevant sections such as "Related Work" and "Introduction"
|
233 |
+
related_works = get_related_works(content)
|
234 |
+
if related_works != '':
|
235 |
+
graph[paper_name]['Related Work'] = related_works
|
236 |
+
results["Number of intro/related found"] += 1
|
237 |
+
|
238 |
+
intro = get_intro(content)
|
239 |
+
if intro != '':
|
240 |
+
graph[paper_name]['Introduction'] = intro
|
241 |
+
results["Number of introductions found"] += 1
|
242 |
+
|
243 |
+
# Extract citation sentences from the introduction and related works
|
244 |
+
sentences_citing = get_citing_sentences(intro + '\n' + related_works)
|
245 |
+
|
246 |
+
# Map citations to corresponding papers
|
247 |
+
raw_sentences_citing = {}
|
248 |
+
for k,v in sentences_citing.items():
|
249 |
+
new_values = []
|
250 |
+
for item in v:
|
251 |
+
try:
|
252 |
+
new_values.append(paper_dict[final_library[item]['title']])
|
253 |
+
except Exception as e:
|
254 |
+
pass
|
255 |
+
if new_values != []:
|
256 |
+
raw_sentences_citing[k] = new_values
|
257 |
+
|
258 |
+
# Construct citation edges
|
259 |
+
edges_set = []
|
260 |
+
for k,v in raw_sentences_citing.items():
|
261 |
+
for item in v:
|
262 |
+
edges_set.append((paper_name_download, item, {"sentence":k}))
|
263 |
+
|
264 |
+
iter_ind +=1
|
265 |
+
if len(edges_set) !=0:
|
266 |
+
results["Number of succesful finding of extracts"] += 1
|
267 |
+
graph[paper_name]['Citations'] = edges_set
|
268 |
+
num_edges += len(edges_set)
|
269 |
+
|
270 |
+
# Save progress after every 10 iterations
|
271 |
+
if iter_ind % 10 == 0:
|
272 |
+
print("Saving graph now")
|
273 |
+
with open(save_path, 'w') as f:
|
274 |
+
json.dump(results, f)
|
275 |
+
with open(save_graph, 'w') as f:
|
276 |
+
json.dump(graph, f)
|
277 |
+
|
278 |
+
except Exception as e:
|
279 |
+
print("Could not get main paper {}".format(paper_name))
|
280 |
+
|
281 |
+
# Update the progress bar after processing each paper
|
282 |
+
download_progress_bar(num_papers / total_papers)
|
283 |
+
|
284 |
+
|
285 |
+
# Ensure a minimum time gap of 3 seconds between iterations to avoid bans from arXiv
|
286 |
+
t2 = time.time() # End time
|
287 |
+
elapsed_time = t2 - t1
|
288 |
+
if elapsed_time < arxiv_rate_lim:
|
289 |
+
time.sleep(arxiv_rate_lim - elapsed_time)
|
290 |
+
|
291 |
+
|
292 |
+
# Final saving of processed data
|
293 |
+
with open(save_graph, 'w') as f:
|
294 |
+
json.dump(graph, f)
|
295 |
+
with open(save_path, 'w') as f:
|
296 |
+
json.dump(results, f)
|
297 |
+
|
298 |
+
|
299 |
+
# Log final completion message
|
300 |
+
llm_resp.append("✅ Successfully downloaded and cleaned {} papers.".format(results["Number of latex papers"]))
|
301 |
+
return "\n".join(llm_resp)
|
302 |
+
|
303 |
+
|
304 |
+
# Chat prediction function
|
305 |
+
def predict(message, history, selected_task):
|
306 |
+
global model
|
307 |
+
# Initialize the conversation string
|
308 |
+
conversation = ""
|
309 |
+
|
310 |
+
# Parse the history: Gradio `type="messages"` uses dictionaries with 'role' and 'content'
|
311 |
+
for item in history:
|
312 |
+
if item["role"] == "assistant":
|
313 |
+
conversation += f"<bot>: {item['content']}\n"
|
314 |
+
elif item["role"] == "user":
|
315 |
+
conversation += f"<human>: {item['content']}\n"
|
316 |
+
|
317 |
+
# Add the user's current message to the conversation
|
318 |
+
conversation += f"<human>: {message}\n<bot>:"
|
319 |
+
|
320 |
+
# Handle preferences
|
321 |
+
if len(history) == 0:
|
322 |
+
if not download_papers.value and not train_model.value:
|
323 |
+
yield "✅ Using model from configuration file..."
|
324 |
+
|
325 |
+
adapter_path = config["inference"]["pretrained_model"]
|
326 |
+
peft_model = PeftModel.from_pretrained(model, adapter_path, torch_dtype=torch.float16)
|
327 |
+
|
328 |
+
# change the global model with peft model
|
329 |
+
model = peft_model
|
330 |
+
|
331 |
+
time.sleep(2.5)
|
332 |
+
|
333 |
+
if not (len(history) == 0 and (train_model.value or download_papers.value)):
|
334 |
+
# Streamer for generating responses
|
335 |
+
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
336 |
+
stop = StopOnTokens()
|
337 |
+
|
338 |
+
generate_kwargs = {
|
339 |
+
"streamer": streamer,
|
340 |
+
"max_new_tokens": config['inference']['generation_args']["max_new_tokens"],
|
341 |
+
"do_sample": config['inference']['generation_args']["do_sample"],
|
342 |
+
"top_p": config['inference']['generation_args']["top_p"],
|
343 |
+
"top_k": config['inference']['generation_args']["top_k"],
|
344 |
+
"temperature": config['inference']['generation_args']["temperature"],
|
345 |
+
"no_repeat_ngram_size": config['inference']['generation_args']["no_repeat_ngram_size"],
|
346 |
+
"num_beams": config['inference']['generation_args']["num_beams"],
|
347 |
+
"stopping_criteria": StoppingCriteriaList([stop]),
|
348 |
+
}
|
349 |
+
|
350 |
+
def generate_response(model, generate_kwargs, selected_task):
|
351 |
+
global advanced_tasks_out
|
352 |
+
has_predefined_template = generate_kwargs["streamer"].tokenizer.chat_template is not None
|
353 |
+
|
354 |
+
if selected_task == "Abstract Completion":
|
355 |
+
prompt = abs_completion(message, template, has_predefined_template)
|
356 |
+
elif selected_task == "Title Generation":
|
357 |
+
prompt = abs_2_title(message, template, has_predefined_template)
|
358 |
+
elif selected_task == "Citation Recommendation":
|
359 |
+
prompt = paper_retrieval(message, template, has_predefined_template)
|
360 |
+
elif selected_task == "Citation Sentence Generation":
|
361 |
+
prompt = citation_sentence(message, template, has_predefined_template)
|
362 |
+
elif selected_task == "Citation Link Prediction":
|
363 |
+
prompt = link_pred(message, template, has_predefined_template)
|
364 |
+
elif selected_task == "Introduction to Abstract":
|
365 |
+
prompt = intro_2_abs(message, template, tokenizer.model_max_length, has_predefined_template)
|
366 |
+
elif selected_task == "Influential Papers Recommendation":
|
367 |
+
if download_papers.value:
|
368 |
+
graph = nx.read_gexf(gexf_file)
|
369 |
+
advanced_tasks_out = influential_papers(message, graph)
|
370 |
+
else:
|
371 |
+
graph = nx.read_gexf(predef_graph)
|
372 |
+
advanced_tasks_out = influential_papers(message, graph)
|
373 |
+
elif selected_task == "Related Work Generation":
|
374 |
+
adapter_path = (
|
375 |
+
f"{config['model_saving']['model_output_dir']}/{config['model_saving']['model_name']}_{config['model_saving']['index']}_adapter_test_graph"
|
376 |
+
if train_model.value else config['inference']['pretrained_model']
|
377 |
+
)
|
378 |
+
if download_papers.value:
|
379 |
+
advanced_tasks_out = gen_related_work(message, gexf_file, adapter_path)
|
380 |
+
else:
|
381 |
+
advanced_tasks_out = gen_related_work(message, predef_graph, adapter_path)
|
382 |
+
else:
|
383 |
+
prompt = conversation + f"<human>: {message}\n<bot>:"
|
384 |
+
|
385 |
+
if selected_task != "Influential Papers Recommendation" and selected_task != "Related Work Generation":
|
386 |
+
if tokenizer.chat_template is not None:
|
387 |
+
response = model_pipeline(prompt, **generate_kwargs)
|
388 |
+
streamer.put(response[0]['generated_text'][-1])
|
389 |
+
else:
|
390 |
+
model_inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
391 |
+
generate_kwargs["inputs"] = model_inputs["input_ids"]
|
392 |
+
generate_kwargs["attention_mask"] = model_inputs["attention_mask"]
|
393 |
+
|
394 |
+
response = model.generate(**generate_kwargs)
|
395 |
+
streamer.put(response)
|
396 |
+
|
397 |
+
# Generate the response in a separate thread
|
398 |
+
t = Thread(target=generate_response,
|
399 |
+
kwargs={
|
400 |
+
"model": model,
|
401 |
+
"generate_kwargs": generate_kwargs,
|
402 |
+
"selected_task": selected_task
|
403 |
+
})
|
404 |
+
|
405 |
+
global advanced_tasks_out
|
406 |
+
advanced_tasks_out = None
|
407 |
+
t.start()
|
408 |
+
|
409 |
+
# Stream the partial response
|
410 |
+
if selected_task != "Influential Papers Recommendation" and selected_task != "Related Work Generation":
|
411 |
+
partial_message = ""
|
412 |
+
for new_token in streamer:
|
413 |
+
if new_token != '<': # Ignore placeholder tokens
|
414 |
+
partial_message += new_token
|
415 |
+
yield partial_message
|
416 |
+
else:
|
417 |
+
if selected_task == "Related Work Generation":
|
418 |
+
yield "🔍 Generating related work..."
|
419 |
+
while advanced_tasks_out == None:
|
420 |
+
time.sleep(0.1)
|
421 |
+
yield advanced_tasks_out
|
422 |
+
|
423 |
+
# Fetch arXiv papers if the user opted to download them
|
424 |
+
if len(history) == 0:
|
425 |
+
if download_papers.value:
|
426 |
+
# Fetch relevant papers
|
427 |
+
yield "🔍 Retrieving relevant papers..."
|
428 |
+
|
429 |
+
retrieve_progress = gr.Progress()
|
430 |
+
for percent in retriever(message, retrieval_nodes_path):
|
431 |
+
retrieve_progress(percent)
|
432 |
+
|
433 |
+
with open(retrieval_nodes_path, "r") as f:
|
434 |
+
data_download = json.load(f)
|
435 |
+
|
436 |
+
papers_to_download = list(data_download.keys())
|
437 |
+
|
438 |
+
yield f"📥 Fetching {len(papers_to_download)} arXiv papers' source files... Please wait."
|
439 |
+
|
440 |
+
content = fetch_arxiv_papers(papers_to_download)
|
441 |
+
yield content
|
442 |
+
time.sleep(2.5)
|
443 |
+
|
444 |
+
|
445 |
+
# Train the model with the retrieved graph
|
446 |
+
if len(history) == 0:
|
447 |
+
if train_model.value:
|
448 |
+
training_progress=gr.Progress()
|
449 |
+
|
450 |
+
training_progress(0.0)
|
451 |
+
|
452 |
+
# If the user opted to download papers, use the retrieved graph, else use the predefined graph
|
453 |
+
if download_papers.value:
|
454 |
+
yield "🚀 Training the model with the retrieved graph..."
|
455 |
+
|
456 |
+
with open(save_graph, "r") as f:
|
457 |
+
data_graph = json.load(f)
|
458 |
+
|
459 |
+
renamed_data = {
|
460 |
+
"/".join(re.match(r"([a-z-]+)([0-9]+)", key, re.I).groups()) if re.match(r"([a-z-]+)([0-9]+)", key, re.I) else key: value
|
461 |
+
for key, value in data_graph.items()
|
462 |
+
}
|
463 |
+
|
464 |
+
concept_data = load_dataset("AliMaatouk/arXiv_Topics", cache_dir="datasets/arxiv_topics")
|
465 |
+
id2topics = {
|
466 |
+
entry["paper_id"]: [entry["Level 1"], entry["Level 2"], entry["Level 3"]]
|
467 |
+
for entry in concept_data["train"]
|
468 |
+
}
|
469 |
+
|
470 |
+
dataset = 'datasets/arxiv-metadata-oai-snapshot.json'
|
471 |
+
data = []
|
472 |
+
if not os.path.exists(dataset):
|
473 |
+
os.system("wget https://huggingface.co/spaces/ddiddu/simsearch/resolve/main/arxiv-metadata-oai-snapshot.json -P ./datasets")
|
474 |
+
with open(dataset, 'r') as f:
|
475 |
+
for line in f:
|
476 |
+
data.append(json.loads(line))
|
477 |
+
papers = {d['id']: d for d in data}
|
478 |
+
|
479 |
+
G = nx.DiGraph()
|
480 |
+
for k in renamed_data:
|
481 |
+
if k not in G and k in papers:
|
482 |
+
if config['data_downloading']['processing']['keep_unstructured_content']:
|
483 |
+
G.add_node(
|
484 |
+
k,
|
485 |
+
title=papers[k]['title'],
|
486 |
+
abstract=papers[k]['abstract'],
|
487 |
+
introduction=renamed_data[k].get('Introduction', '') if renamed_data[k].get('Introduction', '') != '\n' else '',
|
488 |
+
related=renamed_data[k].get('Related Work', '') if renamed_data[k].get('Related Work', '') != '\n' else '',
|
489 |
+
concepts=", ".join(list(set(item for sublist in id2topics[k] for item in sublist))) if k in id2topics else '',
|
490 |
+
content=renamed_data[k].get('content', '') if k in renamed_data else ''
|
491 |
+
)
|
492 |
+
else:
|
493 |
+
G.add_node(
|
494 |
+
k,
|
495 |
+
title=papers[k]['title'],
|
496 |
+
abstract=papers[k]['abstract'],
|
497 |
+
introduction=renamed_data[k].get('Introduction', '') if renamed_data[k].get('Introduction', '') != '\n' else '',
|
498 |
+
related=renamed_data[k].get('Related Work', '') if renamed_data[k].get('Related Work', '') != '\n' else '',
|
499 |
+
concepts=", ".join(list(set(item for sublist in id2topics[k] for item in sublist))) if k in id2topics else ''
|
500 |
+
)
|
501 |
+
if 'Citations' in renamed_data[k]:
|
502 |
+
for citation in renamed_data[k]['Citations']:
|
503 |
+
source, target, metadata = citation
|
504 |
+
sentence = metadata.get('sentence', '') # Extract sentence or default to empty string
|
505 |
+
|
506 |
+
if target not in G and target in papers:
|
507 |
+
if config['data_downloading']['processing']['keep_unstructured_content']:
|
508 |
+
G.add_node(
|
509 |
+
target,
|
510 |
+
title=papers[target]['title'],
|
511 |
+
abstract=papers[target]['abstract'],
|
512 |
+
introduction=renamed_data[target].get('Introduction', '') if target in renamed_data and renamed_data[target].get('Introduction', '') != '\n' else '',
|
513 |
+
related=renamed_data[target].get('Related Work', '') if target in renamed_data and renamed_data[target].get('Related Work', '') != '\n' else '',
|
514 |
+
concepts=", ".join(list(set(item for sublist in concept_data[target].values() for item in sublist))) if target in concept_data else '',
|
515 |
+
content=renamed_data[target].get('content', '') if target in renamed_data else ''
|
516 |
+
)
|
517 |
+
else:
|
518 |
+
G.add_node(
|
519 |
+
target,
|
520 |
+
title=papers[target]['title'],
|
521 |
+
abstract=papers[target]['abstract'],
|
522 |
+
introduction=renamed_data[target].get('Introduction', '') if target in renamed_data and renamed_data[target].get('Introduction', '') != '\n' else '',
|
523 |
+
related=renamed_data[target].get('Related Work', '') if target in renamed_data and renamed_data[target].get('Related Work', '') != '\n' else '',
|
524 |
+
concepts=", ".join(list(set(item for sublist in concept_data[target].values() for item in sublist))) if target in concept_data else ''
|
525 |
+
)
|
526 |
+
|
527 |
+
G.add_edge(source, target, sentence=sentence)
|
528 |
+
|
529 |
+
G.remove_nodes_from(list(nx.isolates(G)))
|
530 |
+
|
531 |
+
nx.write_gexf(G, gexf_file)
|
532 |
+
print(f"Processed graph written to {gexf_file}")
|
533 |
+
else:
|
534 |
+
yield f"✅ Using predefined graph: {predef_graph}"
|
535 |
+
|
536 |
+
|
537 |
+
wandb.init(project='qlora_train')
|
538 |
+
|
539 |
+
if download_papers.value:
|
540 |
+
trainer = QloraTrainer_CS(config=config, use_predefined_graph=False)
|
541 |
+
else:
|
542 |
+
trainer = QloraTrainer_CS(config=config, use_predefined_graph=True)
|
543 |
+
|
544 |
+
print("Load base model")
|
545 |
+
trainer.load_base_model()
|
546 |
+
|
547 |
+
|
548 |
+
print("Start training")
|
549 |
+
def update_progress():
|
550 |
+
# Wait for the trainer to be initialized
|
551 |
+
while trainer.transformer_trainer is None:
|
552 |
+
time.sleep(0.5)
|
553 |
+
|
554 |
+
time.sleep(1.5)
|
555 |
+
# Update the progress bar until training is complete
|
556 |
+
while trainer.transformer_trainer.state.global_step != trainer.transformer_trainer.state.max_steps:
|
557 |
+
progress_bar = (
|
558 |
+
trainer.transformer_trainer.state.global_step /
|
559 |
+
trainer.transformer_trainer.state.max_steps
|
560 |
+
)
|
561 |
+
training_progress(progress_bar)
|
562 |
+
time.sleep(0.5)
|
563 |
+
training_progress(1.0)
|
564 |
+
|
565 |
+
t1 = Thread(target=trainer.train)
|
566 |
+
t1.start()
|
567 |
+
t2 = Thread(target=update_progress())
|
568 |
+
t2.start()
|
569 |
+
t1.join()
|
570 |
+
t2.join()
|
571 |
+
|
572 |
+
yield "🎉 Model training complete! Please provide your task prompt."
|
573 |
+
|
574 |
+
adapter_path = f"{config['model_saving']['model_output_dir']}/{config['model_saving']['model_name']}_{config['model_saving']['index']}_adapter_test_graph"
|
575 |
+
peft_model = PeftModel.from_pretrained(model, adapter_path, torch_dtype=torch.float16)
|
576 |
+
|
577 |
+
# change the global model with peft model
|
578 |
+
model = peft_model
|
579 |
+
|
580 |
+
|
581 |
+
|
582 |
+
if __name__ == "__main__":
|
583 |
+
print("This is running in a virtual environment: {}".format(is_venv()))
|
584 |
+
|
585 |
+
config = read_yaml_file("configs/config.yaml")
|
586 |
+
template_file_path = 'configs/alpaca.json'
|
587 |
+
template = json.load(open(template_file_path, "r"))
|
588 |
+
|
589 |
+
|
590 |
+
seed_no = config['data_downloading']['processing']['random_seed']
|
591 |
+
model_name = config['inference']['base_model']
|
592 |
+
working_dir = config['data_downloading']['download_directory']
|
593 |
+
save_zip_directory = working_dir + 'research_papers_zip/'
|
594 |
+
save_directory = working_dir + 'research_papers/'
|
595 |
+
save_description = working_dir + 'description/'
|
596 |
+
save_path = save_description + 'results.json'
|
597 |
+
save_graph = save_description + 'test_graph.json'
|
598 |
+
gexf_file = save_description + config['data_downloading']['gexf_file']
|
599 |
+
predef_graph = 'datasets/' + config['training']['predefined_graph_path']
|
600 |
+
retrieval_nodes_path = 'datasets/retrieval_nodes.json'
|
601 |
+
|
602 |
+
isExist = os.path.exists(save_zip_directory)
|
603 |
+
if not isExist:
|
604 |
+
os.makedirs(save_zip_directory)
|
605 |
+
isExist = os.path.exists(save_directory)
|
606 |
+
if not isExist:
|
607 |
+
os.makedirs(save_directory)
|
608 |
+
isExist = os.path.exists(save_description)
|
609 |
+
if not isExist:
|
610 |
+
os.makedirs(save_description)
|
611 |
+
|
612 |
+
|
613 |
+
random.seed(seed_no)
|
614 |
+
|
615 |
+
|
616 |
+
# Load model and tokenizer
|
617 |
+
bnb_config = BitsAndBytesConfig(
|
618 |
+
load_in_8bit=True,
|
619 |
+
bnb_8bit_use_double_quant=True,
|
620 |
+
bnb_8bit_quant_type="nf8",
|
621 |
+
bnb_8bit_compute_dtype=torch.bfloat16
|
622 |
+
)
|
623 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
624 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
|
625 |
+
if model.device.type != 'cuda':
|
626 |
+
model.to('cuda')
|
627 |
+
|
628 |
+
if tokenizer.chat_template is not None:
|
629 |
+
model_pipeline = pipeline(
|
630 |
+
"text-generation",
|
631 |
+
model=model_name,
|
632 |
+
model_kwargs={"torch_dtype": torch.bfloat16},
|
633 |
+
device_map="auto",
|
634 |
+
)
|
635 |
+
|
636 |
+
signal.signal(signal.SIGINT, signal_handler)
|
637 |
+
|
638 |
+
|
639 |
+
# Global States for User Preferences
|
640 |
+
download_papers = gr.State(value=True) # Default: Download papers
|
641 |
+
train_model = gr.State(value=True) # Default: Train the model
|
642 |
+
|
643 |
+
|
644 |
+
# Categorized Recommended Prompts
|
645 |
+
task_list = {
|
646 |
+
"Abstract Completion",
|
647 |
+
"Introduction to Abstract",
|
648 |
+
"Title Generation",
|
649 |
+
"Citation Recommendation",
|
650 |
+
"Citation Sentence Generation",
|
651 |
+
"Citation Link Prediction",
|
652 |
+
"Influential Papers Recommendation",
|
653 |
+
"Related Work Generation",
|
654 |
+
}
|
655 |
+
|
656 |
+
|
657 |
+
# CSS for Styling
|
658 |
+
css = """
|
659 |
+
body { background-color: #E0F7FA; margin: 0; padding: 0; }
|
660 |
+
.gradio-container { background-color: #E0F7FA; border-radius: 10px; }
|
661 |
+
#logo-container { display: flex; justify-content: center; align-items: center; margin: 0 auto; padding: 0; max-width: 120px; height: 120px; border-radius: 10px; overflow: hidden; }
|
662 |
+
#scroll-menu { max-height: 310px; overflow-y: auto; padding: 10px; background-color: #fff; margin-top: 10px;}
|
663 |
+
#task-header { background-color: #0288d1; color: white; font-size: 18px; padding: 8px; text-align: center; margin-bottom: 5px; margin-top: 40px; }
|
664 |
+
#category-header { background-color: #ecb939; font-size: 16px; padding: 8px; margin: 10px 0; }
|
665 |
+
"""
|
666 |
+
|
667 |
+
# State to store the selected task
|
668 |
+
selected_task = gr.State(value="")
|
669 |
+
|
670 |
+
|
671 |
+
# Gradio Interface
|
672 |
+
with gr.Blocks(theme="soft", css=css) as demo:
|
673 |
+
gr.HTML('<div id="logo-container"><img src="https://static.thenounproject.com/png/6480915-200.png" alt="Logo"></div>')
|
674 |
+
gr.Markdown("# LitBench Interface")
|
675 |
+
|
676 |
+
|
677 |
+
# Setup row for user preferences
|
678 |
+
with gr.Row(visible=True) as setup_row:
|
679 |
+
with gr.Column():
|
680 |
+
gr.Markdown("### Setup Your Preferences")
|
681 |
+
download_option = gr.Dropdown(
|
682 |
+
choices=["Download Paper", "Don't Download"],
|
683 |
+
value="Download Paper",
|
684 |
+
label="Download Option"
|
685 |
+
)
|
686 |
+
train_option = gr.Dropdown(
|
687 |
+
choices=["Train", "Don't Train"],
|
688 |
+
value="Train",
|
689 |
+
label="Training Option"
|
690 |
+
)
|
691 |
+
setup_button = gr.Button("Set Preferences and Proceed")
|
692 |
+
|
693 |
+
|
694 |
+
# Chatbot row for user interaction
|
695 |
+
with gr.Row(visible=False) as chatbot_row:
|
696 |
+
# Store the currently selected task
|
697 |
+
with gr.Column(scale=3):
|
698 |
+
gr.Markdown("### Start Chatting!")
|
699 |
+
chatbot = gr.ChatInterface(
|
700 |
+
predict,
|
701 |
+
chatbot=gr.Chatbot(
|
702 |
+
height=400,
|
703 |
+
type="messages",
|
704 |
+
avatar_images=[
|
705 |
+
"https://icons.veryicon.com/png/o/miscellaneous/user-avatar/user-avatar-male-5.png",
|
706 |
+
"https://cdn-icons-png.flaticon.com/512/8649/8649595.png"
|
707 |
+
],
|
708 |
+
),
|
709 |
+
textbox=gr.Textbox(placeholder="Type your message here..."),
|
710 |
+
additional_inputs=selected_task,
|
711 |
+
additional_inputs_accordion=gr.Accordion(visible=False, label="Additional Inputs", ),
|
712 |
+
)
|
713 |
+
|
714 |
+
# Store user preferences and selected task for display
|
715 |
+
preferences_output = gr.Textbox(value="", interactive=False, label="Your Preferences")
|
716 |
+
|
717 |
+
|
718 |
+
# Task selection buttons for user interaction
|
719 |
+
with gr.Column(scale=1):
|
720 |
+
gr.HTML('<div id="task-header">Tasks:</div>')
|
721 |
+
with gr.Column(elem_id="scroll-menu"):
|
722 |
+
# Create buttons
|
723 |
+
button_map = {prompt: gr.Button(prompt) for prompt in task_list}
|
724 |
+
|
725 |
+
for prompt in task_list:
|
726 |
+
button_map[prompt].click(
|
727 |
+
toggle_selection,
|
728 |
+
inputs=[selected_task, gr.State(value=prompt)], # Toggle task selection
|
729 |
+
outputs=selected_task
|
730 |
+
).then(
|
731 |
+
update_button_styles, # Update button appearances
|
732 |
+
inputs=[selected_task],
|
733 |
+
outputs=[button_map[p] for p in task_list] # Update all buttons
|
734 |
+
)
|
735 |
+
|
736 |
+
|
737 |
+
# Setup button to finalize user preferences and start chatbot
|
738 |
+
setup_button.click(
|
739 |
+
setup,
|
740 |
+
inputs=[download_option, train_option],
|
741 |
+
outputs=[setup_row, chatbot_row, preferences_output, chatbot.chatbot]
|
742 |
+
)
|
743 |
+
|
744 |
+
|
745 |
+
# Launch the interface
|
746 |
+
demo.launch(server_port=7880)
|
retriever/retriever.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModel
|
2 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
from tqdm import tqdm
|
6 |
+
import os
|
7 |
+
import pandas as pd
|
8 |
+
import numpy as np
|
9 |
+
from datasets import load_dataset
|
10 |
+
from utils.utils import read_yaml_file
|
11 |
+
|
12 |
+
|
13 |
+
def generate_topic_level_embeddings(model, tokenizer, paper_list, tmp_id_2_abs):
|
14 |
+
id2topics = {
|
15 |
+
entry["paper_id"]: [entry["Level 1"], entry["Level 2"], entry["Level 3"]]
|
16 |
+
for entry in tmp_id_2_abs['train']
|
17 |
+
}
|
18 |
+
|
19 |
+
for topic_level in ['Level 1', 'Level 2', 'Level 3']:
|
20 |
+
i = 0
|
21 |
+
batch_size = 2048
|
22 |
+
candidate_emb_list = []
|
23 |
+
pbar = tqdm(total=len(paper_list))
|
24 |
+
while i < len(paper_list):
|
25 |
+
yield i / len(paper_list) / 3 if topic_level == 'Level 1' else 0.33 + i / len(paper_list) / 3 if topic_level == 'Level 2' else 0.66 + i / len(paper_list) / 3
|
26 |
+
paper_batch = paper_list[i:i+batch_size]
|
27 |
+
paper_text_batch = []
|
28 |
+
for paper_id in paper_batch:
|
29 |
+
topics = id2topics[paper_id][int(topic_level[6])-1]
|
30 |
+
topic_text = ''
|
31 |
+
for t in topics:
|
32 |
+
topic_text += t + ','
|
33 |
+
paper_text_batch.append(topic_text)
|
34 |
+
inputs = tokenizer(paper_text_batch, return_tensors='pt', padding=True, truncation=True)
|
35 |
+
with torch.no_grad():
|
36 |
+
outputs = model(**inputs.to('cuda'))
|
37 |
+
candidate_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
|
38 |
+
candidate_embeddings = candidate_embeddings.reshape(-1, 1024)
|
39 |
+
candidate_emb_list.append(candidate_embeddings)
|
40 |
+
|
41 |
+
i += len(candidate_embeddings)
|
42 |
+
pbar.update(len(candidate_embeddings))
|
43 |
+
|
44 |
+
all_candidate_embs = torch.cat(candidate_emb_list, 0)
|
45 |
+
|
46 |
+
df = pd.DataFrame({
|
47 |
+
"paper_id": paper_list,
|
48 |
+
"embedding": list(all_candidate_embs.numpy())
|
49 |
+
})
|
50 |
+
|
51 |
+
if not os.path.exists('datasets/topic_level_embeds'):
|
52 |
+
os.makedirs('datasets/topic_level_embeds')
|
53 |
+
|
54 |
+
df.to_parquet(f'datasets/topic_level_embeds/{topic_level}_emb.parquet', engine='pyarrow', compression='snappy')
|
55 |
+
|
56 |
+
all_candidate_embs_L1 = torch.tensor(np.array(pd.read_parquet('datasets/topic_level_embeds/Level 1_emb.parquet')['embedding'].tolist()))
|
57 |
+
all_candidate_embs_L2 = torch.tensor(np.array(pd.read_parquet('datasets/topic_level_embeds/Level 2_emb.parquet')['embedding'].tolist()))
|
58 |
+
all_candidate_embs_L3 = torch.tensor(np.array(pd.read_parquet('datasets/topic_level_embeds/Level 3_emb.parquet')['embedding'].tolist()))
|
59 |
+
all_candidate_embs = all_candidate_embs_L1 + all_candidate_embs_L2 + all_candidate_embs_L3
|
60 |
+
|
61 |
+
df = pd.DataFrame({
|
62 |
+
"paper_id": paper_list,
|
63 |
+
"embedding": list(all_candidate_embs.numpy())
|
64 |
+
})
|
65 |
+
|
66 |
+
df.to_parquet('datasets/topic_level_embeds/arxiv_papers_embeds.parquet', engine='pyarrow', compression='snappy')
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
def retriever(query, retrieval_nodes_path):
|
71 |
+
yield 0
|
72 |
+
config = read_yaml_file('configs/config.yaml')
|
73 |
+
|
74 |
+
# Load the model and tokenizer to generate the embeddings
|
75 |
+
embedder_name = config['retriever']['embedder']
|
76 |
+
tokenizer = AutoTokenizer.from_pretrained(embedder_name)
|
77 |
+
model = AutoModel.from_pretrained(embedder_name).to(device='cuda', dtype=torch.float16)
|
78 |
+
|
79 |
+
|
80 |
+
# Load the arXiv dataset
|
81 |
+
tmp_id_2_abs = load_dataset("AliMaatouk/arXiv_Topics", cache_dir="datasets/arxiv_topics")
|
82 |
+
paper_list = list(tmp_id_2_abs['train']['paper_id'])
|
83 |
+
|
84 |
+
|
85 |
+
# Generate the query embeddings
|
86 |
+
inputs = tokenizer([query], return_tensors='pt', padding=True, truncation=True)
|
87 |
+
with torch.no_grad():
|
88 |
+
outputs = model(**inputs.to('cuda'))
|
89 |
+
query_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
|
90 |
+
|
91 |
+
# Generate the candidate embeddings
|
92 |
+
# Load the embeddings from the dataset, otherwise generate the embeddings and save them
|
93 |
+
if config['retriever']['load_arxiv_embeds']:
|
94 |
+
dataset = load_dataset("AliMaatouk/arXiv-Topics-Embeddings", cache_dir="datasets/topic_level_embeds")
|
95 |
+
table = dataset["train"].data # Get PyArrow Table
|
96 |
+
all_candidate_embs = table.column("embedding").to_numpy()
|
97 |
+
else:
|
98 |
+
# If the file does not exist, generate the embeddings, otherwise, load the embeddings
|
99 |
+
if not os.path.exists('datasets/topic_level_embeds/arxiv_papers_embeds.parquet'):
|
100 |
+
yield from generate_topic_level_embeddings(model, tokenizer, paper_list, tmp_id_2_abs)
|
101 |
+
|
102 |
+
all_candidate_embs = torch.tensor(np.array(pd.read_parquet('datasets/topic_level_embeds/arxiv_papers_embeds.parquet')['embedding'].tolist()))
|
103 |
+
all_candidate_embs = all_candidate_embs.cpu().numpy()
|
104 |
+
|
105 |
+
all_candidate_embs = np.stack(all_candidate_embs)
|
106 |
+
|
107 |
+
|
108 |
+
# Calculate the cosine similarity between the query and all candidate embeddings
|
109 |
+
query_embeddings = np.array(query_embeddings)
|
110 |
+
similarity_scores = cosine_similarity(query_embeddings, all_candidate_embs)[0]
|
111 |
+
|
112 |
+
|
113 |
+
# Sort the papers by similarity scores and select the top K papers
|
114 |
+
id_score_list = []
|
115 |
+
for i in range(len(paper_list)):
|
116 |
+
id_score_list.append([paper_list[i], similarity_scores[i]])
|
117 |
+
|
118 |
+
sorted_scores = sorted(id_score_list, key=lambda i: i[-1], reverse = True)
|
119 |
+
top_K_paper = [sample[0] for sample in sorted_scores[:config['retriever']['num_retrievals']]]
|
120 |
+
|
121 |
+
papers_results = {
|
122 |
+
paper: True
|
123 |
+
for paper in top_K_paper
|
124 |
+
}
|
125 |
+
|
126 |
+
with open(retrieval_nodes_path, 'w') as f:
|
127 |
+
json.dump(papers_results, f)
|
128 |
+
|
129 |
+
yield 1.0
|
tasks/abs_2_title.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Generate a prompt for generating the title of a paper based on its abstract.
|
3 |
+
|
4 |
+
Args:
|
5 |
+
usr_input (str): A string containing the title and abstract of the paper in the format "Title: <title> Abstract: <abstract>".
|
6 |
+
template (dict): A dictionary containing the template for the prompt with a key "prompt_input".
|
7 |
+
|
8 |
+
Returns:
|
9 |
+
str: A formatted string with the instruction and abstract to be used as input for generating the title.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def abs_2_title(usr_input, template, has_predefined_template=False):
|
13 |
+
instruction = "Please generate the title of paper based on its abstract"
|
14 |
+
|
15 |
+
if has_predefined_template:
|
16 |
+
res = [
|
17 |
+
{"role": "system", "content": instruction},
|
18 |
+
{"role": "user", "content": usr_input},
|
19 |
+
]
|
20 |
+
else:
|
21 |
+
res = template["prompt_input"].format(instruction=instruction, input=usr_input)
|
22 |
+
|
23 |
+
return res
|
tasks/abs_completion.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Generates a formatted prompt for completing the abstract of a paper.
|
3 |
+
|
4 |
+
Args:
|
5 |
+
usr_input (str): The user input containing the title and part of the abstract.
|
6 |
+
Expected format:
|
7 |
+
"Title: <title>\nAbstract: <abstract>"
|
8 |
+
template (dict): A dictionary containing the template for the prompt.
|
9 |
+
Expected format:
|
10 |
+
{"prompt_input": "<template_string>"}
|
11 |
+
The template string should contain placeholders for
|
12 |
+
'instruction' and 'input'.
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
str: A formatted string with the instruction and the input embedded in the template.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def abs_completion(usr_input, template, has_predefined_template=False):
|
19 |
+
instruction = "Please complete the abstract of a paper."
|
20 |
+
|
21 |
+
if has_predefined_template:
|
22 |
+
res = [
|
23 |
+
{"role": "system", "content": instruction},
|
24 |
+
{"role": "user", "content": usr_input},
|
25 |
+
]
|
26 |
+
else:
|
27 |
+
res = template["prompt_input"].format(instruction=instruction, input=usr_input)
|
28 |
+
|
29 |
+
return res
|
tasks/citation_sentence.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Generates a citation sentence based on the titles and abstracts of two papers.
|
3 |
+
|
4 |
+
Args:
|
5 |
+
usr_input (str): A string containing the titles and abstracts of Paper A and Paper B.
|
6 |
+
The format should be:
|
7 |
+
"Title A: <title of paper A>\nAbstract A: <abstract of paper A>\nTitle B: <title of paper B>\nAbstract B: <abstract of paper B>"
|
8 |
+
template (dict): A dictionary containing a template for the prompt input. The key "prompt_input" should map to a string with placeholders for the instruction and input.
|
9 |
+
|
10 |
+
Returns:
|
11 |
+
str: A formatted string that combines the instruction and the prompt input with the provided titles and abstracts.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def citation_sentence(usr_input, template, has_predefined_template=False):
|
15 |
+
instruction = "Please generate the citation sentence of how Paper A cites paper B in its related work section. \n"
|
16 |
+
|
17 |
+
if has_predefined_template:
|
18 |
+
res = [
|
19 |
+
{"role": "system", "content": instruction},
|
20 |
+
{"role": "user", "content": usr_input},
|
21 |
+
]
|
22 |
+
else:
|
23 |
+
res = template["prompt_input"].format(instruction=instruction, input=usr_input)
|
24 |
+
|
25 |
+
return res
|
tasks/gen_related_work.py
ADDED
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Generates the related work section for a given paper.
|
3 |
+
|
4 |
+
The input
|
5 |
+
- The input prompt is a string that contains the information of the paper for which the related work section needs to be generated.
|
6 |
+
- The input prompt should be in the following format:
|
7 |
+
Title of Paper: <title of the paper>
|
8 |
+
|
9 |
+
Abstract of Paper: <abstract of the paper>
|
10 |
+
The output
|
11 |
+
- The output is a string that contains the related work section for the given paper.
|
12 |
+
"""
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import json
|
16 |
+
import networkx as nx
|
17 |
+
import numpy as np
|
18 |
+
from tqdm import tqdm
|
19 |
+
from peft import PeftModel
|
20 |
+
from transformers import (AutoModel, AutoTokenizer, AutoModelForCausalLM, pipeline)
|
21 |
+
from tqdm import tqdm
|
22 |
+
import re
|
23 |
+
import pandas as pd
|
24 |
+
import os
|
25 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
26 |
+
from utils.utils import read_yaml_file
|
27 |
+
import datetime
|
28 |
+
|
29 |
+
|
30 |
+
class LitFM():
|
31 |
+
def __init__(self, graph_path, adapter_path):
|
32 |
+
self.graph_name = graph_path.split('.')[0].split('/')[-1] if '/' in graph_path else graph_path.split('.')[0]
|
33 |
+
self.batch_size = 32
|
34 |
+
self.neigh_num = 4
|
35 |
+
|
36 |
+
config = read_yaml_file('configs/config.yaml')
|
37 |
+
retrieval_graph_path = graph_path
|
38 |
+
|
39 |
+
self.pretrained_model = config['retriever']['embedder']
|
40 |
+
|
41 |
+
# define generation model
|
42 |
+
model_path = config['inference']["base_model"]
|
43 |
+
self.generation_tokenizer = AutoTokenizer.from_pretrained(model_path)
|
44 |
+
self.generation_tokenizer.model_max_length = 2048
|
45 |
+
if self.generation_tokenizer.pad_token is None:
|
46 |
+
self.generation_tokenizer.pad_token = self.generation_tokenizer.eos_token
|
47 |
+
self.generation_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
|
48 |
+
self.generation_model = PeftModel.from_pretrained(self.generation_model, adapter_path, adapter_name="instruction", torch_dtype=torch.float16)
|
49 |
+
self.model_pipeline = None
|
50 |
+
if self.generation_tokenizer.chat_template is not None:
|
51 |
+
self.model_pipeline = pipeline(
|
52 |
+
"text-generation",
|
53 |
+
model=model_path,
|
54 |
+
model_kwargs={"torch_dtype": torch.bfloat16},
|
55 |
+
device_map="auto",
|
56 |
+
)
|
57 |
+
|
58 |
+
# define instruction models
|
59 |
+
self.instruction_pipe = pipeline(
|
60 |
+
"text-generation",
|
61 |
+
model=config["inference"]["gen_related_work_instruct_model"],
|
62 |
+
model_kwargs={"torch_dtype": torch.bfloat16},
|
63 |
+
device_map="auto",
|
64 |
+
)
|
65 |
+
|
66 |
+
# load graph data for retrieval
|
67 |
+
def translate_graph(graph):
|
68 |
+
all_nodes = list(graph.nodes())
|
69 |
+
raw_id_2_id_dict = {}
|
70 |
+
id_2_raw_id_dict = {}
|
71 |
+
|
72 |
+
num = 0
|
73 |
+
for node in all_nodes:
|
74 |
+
raw_id_2_id_dict[node] = num
|
75 |
+
id_2_raw_id_dict[num] = node
|
76 |
+
num += 1
|
77 |
+
|
78 |
+
return raw_id_2_id_dict, id_2_raw_id_dict
|
79 |
+
|
80 |
+
whole_graph_data_raw = nx.read_gexf(retrieval_graph_path, node_type=None, relabel=False, version='1.2draft')
|
81 |
+
self.whole_graph_raw_id_2_id_dict, self.whole_graph_id_2_raw_id_dict = translate_graph(whole_graph_data_raw)
|
82 |
+
|
83 |
+
self.whole_graph_id_2_title_abs = dict()
|
84 |
+
for paper_id in whole_graph_data_raw.nodes():
|
85 |
+
title = whole_graph_data_raw.nodes()[paper_id]['title']
|
86 |
+
abstract = whole_graph_data_raw.nodes()[paper_id]['abstract']
|
87 |
+
self.whole_graph_id_2_title_abs[self.whole_graph_raw_id_2_id_dict[paper_id]] = [title, abstract]
|
88 |
+
|
89 |
+
# define prompt template
|
90 |
+
template_file_path = 'configs/alpaca.json'
|
91 |
+
with open(template_file_path) as fp:
|
92 |
+
self.template = json.load(fp)
|
93 |
+
self.human_instruction = ['### Input:', '### Response:']
|
94 |
+
|
95 |
+
|
96 |
+
def _generate_retrieval_prompt(self, data_point: dict):
|
97 |
+
instruction = "Please select the paper that is more likely to be cited by the paper from the list of candidate papers. Your answer MUST be **only the exact title** of the selected paper without generating ANY other text or section. Your answer MUST belong to the list of candidate papers.\n"
|
98 |
+
prompt_input = ""
|
99 |
+
prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
|
100 |
+
prompt_input = prompt_input + "candidate papers: " + "\n"
|
101 |
+
for i in range(len(data_point['nei_titles'])):
|
102 |
+
prompt_input = prompt_input + str(i) + '. ' + data_point['nei_titles'][i] + "\n"
|
103 |
+
|
104 |
+
if self.model_pipeline is not None:
|
105 |
+
res = [
|
106 |
+
{"role": "system", "content": instruction},
|
107 |
+
{"role": "user", "content": prompt_input},
|
108 |
+
]
|
109 |
+
else:
|
110 |
+
res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
111 |
+
|
112 |
+
return res
|
113 |
+
|
114 |
+
def _generate_sentence_prompt(self, data_point):
|
115 |
+
instruction = "Please generate the citation sentence of how the Paper cites paper B in its related work section."
|
116 |
+
|
117 |
+
prompt_input = ""
|
118 |
+
prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
|
119 |
+
prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] != None else 'Unknown') + "\n"
|
120 |
+
prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] != None else 'Unknown') + "\n"
|
121 |
+
|
122 |
+
if self.model_pipeline is not None:
|
123 |
+
res = [
|
124 |
+
{"role": "system", "content": instruction},
|
125 |
+
{"role": "user", "content": prompt_input},
|
126 |
+
]
|
127 |
+
else:
|
128 |
+
res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
129 |
+
|
130 |
+
return res
|
131 |
+
|
132 |
+
def _generate_topic_prompt(self, data_point):
|
133 |
+
prompt_input = ""
|
134 |
+
prompt_input = prompt_input + "Here are the information of the paper: \n"
|
135 |
+
prompt_input = prompt_input + data_point['usr_prompt'] + '\n'
|
136 |
+
prompt_input = prompt_input + "Directlty give me the topics you select.\n"
|
137 |
+
|
138 |
+
res = [
|
139 |
+
{"role": "system", "content": "I need to write the related work section for this paper. Could you suggest three most relevant topics to discuss in the related work section? Your answer should be strictly one topic after the other line by line with nothing else being generated and no further explanation/information.\n"},
|
140 |
+
{"role": "user", "content": prompt_input},
|
141 |
+
]
|
142 |
+
|
143 |
+
return res
|
144 |
+
|
145 |
+
def _generate_paragraph_prompt(self, data_point):
|
146 |
+
prompt_input = ""
|
147 |
+
prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
|
148 |
+
prompt_input = prompt_input + "Topic of this paragraph: " + data_point['topic'] + "\n"
|
149 |
+
prompt_input = prompt_input + "Papers that should be cited in paragraph: \n"
|
150 |
+
|
151 |
+
i = data_point['paper_citation_indicator']
|
152 |
+
for paper_idx in range(len(data_point['nei_title'])):
|
153 |
+
prompt_input = prompt_input + "[" + str(i) + "]. " + data_point['nei_title'][paper_idx][0] + '.' + " Citation sentence of this paper in the paragraph: " + data_point['nei_sentence'][paper_idx] + '\n'
|
154 |
+
i += 1
|
155 |
+
|
156 |
+
prompt_input = prompt_input + "All the above cited papers should be included and each cited paper should be indicated with its index number. Note that you should not include the title of any paper\n"
|
157 |
+
|
158 |
+
res = [
|
159 |
+
{"role": "system", "content": "Please write a paragraph that review the research relationships between this paper and other cited papers.\n"},
|
160 |
+
{"role": "user", "content": prompt_input},
|
161 |
+
]
|
162 |
+
|
163 |
+
return res
|
164 |
+
|
165 |
+
def _generate_summary_prompt(self, data_point):
|
166 |
+
prompt_input = ""
|
167 |
+
prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
|
168 |
+
prompt_input = prompt_input + "Paragraphs that should be combined: " + "\n"
|
169 |
+
|
170 |
+
i = 1
|
171 |
+
for para in data_point['paragraphs']:
|
172 |
+
prompt_input = prompt_input + " Paragraph " + str(i) + ": " + para + '\n'
|
173 |
+
i += 1
|
174 |
+
|
175 |
+
res = [
|
176 |
+
{"role": "system", "content": "Please combine the following paragraphs in a cohenrent way that also keeps the citations and make the flow between paragraphs more smoothly\nAdd a sentence at the beginning of each paragraph to clarify its connection to the previous ones. Do not include any other surrounding text and not add a references list at all\n"},
|
177 |
+
{"role": "user", "content": prompt_input},
|
178 |
+
]
|
179 |
+
|
180 |
+
return res
|
181 |
+
|
182 |
+
|
183 |
+
@staticmethod
|
184 |
+
def generate_text(prompt, tokenizer, model, temperature, top_p, repetition_penalty, max_new_tokens):
|
185 |
+
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
186 |
+
|
187 |
+
with torch.no_grad():
|
188 |
+
output = model.generate(
|
189 |
+
**inputs,
|
190 |
+
do_sample=True,
|
191 |
+
temperature=temperature,
|
192 |
+
top_p=top_p,
|
193 |
+
repetition_penalty=repetition_penalty,
|
194 |
+
max_new_tokens=max_new_tokens,
|
195 |
+
pad_token_id=tokenizer.pad_token_id,
|
196 |
+
eos_token_id=tokenizer.eos_token_id,
|
197 |
+
use_cache=True,
|
198 |
+
)
|
199 |
+
output_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
200 |
+
return output_text
|
201 |
+
|
202 |
+
def get_llm_response(self, prompt, model_type):
|
203 |
+
self.generation_model.set_adapter('instruction')
|
204 |
+
if model_type == 'zeroshot':
|
205 |
+
raw_output = self.instruction_pipe(
|
206 |
+
prompt,
|
207 |
+
max_new_tokens=8096,
|
208 |
+
temperature=0.9,
|
209 |
+
top_p=0.95,
|
210 |
+
repetition_penalty=1.15,
|
211 |
+
)[0]['generated_text'][-1]
|
212 |
+
|
213 |
+
if model_type == 'zeroshot_short':
|
214 |
+
raw_output = self.instruction_pipe(
|
215 |
+
prompt,
|
216 |
+
max_new_tokens=256,
|
217 |
+
temperature=0.9,
|
218 |
+
top_p=0.95,
|
219 |
+
repetition_penalty=1.15,
|
220 |
+
)[0]['generated_text'][-1]
|
221 |
+
|
222 |
+
if model_type == 'instruction':
|
223 |
+
self.generation_model.set_adapter('instruction')
|
224 |
+
if self.model_pipeline is not None:
|
225 |
+
raw_output = self.model_pipeline(
|
226 |
+
prompt,
|
227 |
+
temperature=0.9,
|
228 |
+
top_p=0.95,
|
229 |
+
repetition_penalty=1.15,
|
230 |
+
)[0]['generated_text'][-1]
|
231 |
+
else:
|
232 |
+
raw_output = self.generate_text(
|
233 |
+
prompt,
|
234 |
+
self.generation_tokenizer,
|
235 |
+
self.generation_model,
|
236 |
+
temperature=0.9,
|
237 |
+
top_p=0.95,
|
238 |
+
repetition_penalty=1.15,
|
239 |
+
max_new_tokens=256,
|
240 |
+
)
|
241 |
+
|
242 |
+
return raw_output
|
243 |
+
|
244 |
+
def single_paper_sentence_test(self, usr_prompt, t_title, t_abs):
|
245 |
+
datapoint = {'usr_prompt':usr_prompt, 't_title':t_title, 't_abs':t_abs}
|
246 |
+
prompt = self._generate_sentence_prompt(datapoint)
|
247 |
+
ans = self.get_llm_response(prompt, 'instruction')
|
248 |
+
res = ans.strip().split(self.human_instruction[1])[-1]
|
249 |
+
return res
|
250 |
+
|
251 |
+
def single_paper_retrieval_test(self, usr_prompt, candidates):
|
252 |
+
datapoint = {'usr_prompt':usr_prompt, 'nei_titles':list(candidates), 't_title': ''}
|
253 |
+
prompt = self._generate_retrieval_prompt(datapoint)
|
254 |
+
ans = self.get_llm_response(prompt, 'instruction')
|
255 |
+
res = ans.strip().split(self.human_instruction[1])[-1]
|
256 |
+
return res
|
257 |
+
|
258 |
+
def single_paper_topic_test(self, usr_prompt):
|
259 |
+
datapoint = {'usr_prompt': usr_prompt}
|
260 |
+
prompt = self._generate_topic_prompt(datapoint)
|
261 |
+
ans = self.get_llm_response(prompt, 'zeroshot_short')
|
262 |
+
res = ans['content']
|
263 |
+
res = res.replace('\n\n', '\n')
|
264 |
+
return res
|
265 |
+
|
266 |
+
def retrieval_for_one_query(self, id_2_title_abs, prompt):
|
267 |
+
if os.path.exists(f'datasets/{self.graph_name}_embeddings.parquet'):
|
268 |
+
all_query_embs = torch.tensor(np.array(pd.read_parquet(f'datasets/{self.graph_name}_embeddings.parquet')))
|
269 |
+
else:
|
270 |
+
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en-v1.5")
|
271 |
+
model = AutoModel.from_pretrained("BAAI/bge-large-en-v1.5").to(device='cuda', dtype=torch.float16)
|
272 |
+
model.eval()
|
273 |
+
|
274 |
+
paper_list = list(id_2_title_abs.keys())
|
275 |
+
|
276 |
+
all_query_embs = torch.zeros(len(paper_list), 1024)
|
277 |
+
i = 0
|
278 |
+
batch_size = 200
|
279 |
+
candidate_emb_list = []
|
280 |
+
pbar = tqdm(total=len(paper_list))
|
281 |
+
while i < len(paper_list):
|
282 |
+
paper_batch = paper_list[i:i+batch_size]
|
283 |
+
paper_text_batch = []
|
284 |
+
for paper_id in paper_batch:
|
285 |
+
prompt = id_2_title_abs[paper_id][0] + id_2_title_abs[paper_id][1]
|
286 |
+
paper_text_batch.append(prompt)
|
287 |
+
inputs = tokenizer(paper_text_batch, return_tensors='pt', padding=True, truncation=True)
|
288 |
+
|
289 |
+
with torch.no_grad():
|
290 |
+
outputs = model(**inputs.to('cuda'))
|
291 |
+
candidate_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
|
292 |
+
candidate_embeddings = candidate_embeddings.reshape(-1, 1024)
|
293 |
+
candidate_emb_list.append(candidate_embeddings)
|
294 |
+
|
295 |
+
i += len(candidate_embeddings)
|
296 |
+
pbar.update(len(candidate_embeddings))
|
297 |
+
|
298 |
+
all_query_embs = torch.cat(candidate_emb_list, 0)
|
299 |
+
pd.DataFrame(all_query_embs.numpy()).to_parquet(f'datasets/{self.graph_name}_embeddings.parquet')
|
300 |
+
|
301 |
+
# get the embeddings of the prompt
|
302 |
+
pretrained_model_name = self.pretrained_model
|
303 |
+
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
|
304 |
+
LLM_model = AutoModel.from_pretrained(pretrained_model_name).cuda()
|
305 |
+
LLM_model.eval()
|
306 |
+
|
307 |
+
encoded_input = tokenizer([prompt], padding = True, truncation=True, max_length=512 , return_tensors='pt')
|
308 |
+
with torch.no_grad():
|
309 |
+
output = LLM_model(**encoded_input.to('cuda'), output_hidden_states=True).hidden_states[-1]
|
310 |
+
sentence_embedding = output[:, 0, :]
|
311 |
+
|
312 |
+
tmp_scores = cosine_similarity(sentence_embedding.to("cpu"), all_query_embs.to("cpu"))[0]
|
313 |
+
_, idxs = torch.sort(torch.tensor(tmp_scores), descending=True)
|
314 |
+
top_10 = [int(k) for k in idxs[:10]]
|
315 |
+
|
316 |
+
return [id_2_title_abs[i][0] for i in top_10], [self.whole_graph_id_2_raw_id_dict[i] for i in top_10]
|
317 |
+
|
318 |
+
|
319 |
+
def single_paper_related_work_generation(self, usr_prompt):
|
320 |
+
citation_papers = []
|
321 |
+
nei_sentence = []
|
322 |
+
|
323 |
+
# Get topics
|
324 |
+
retrieval_query = self.single_paper_topic_test(usr_prompt)
|
325 |
+
|
326 |
+
# Split topics
|
327 |
+
topic_num = 3
|
328 |
+
try:
|
329 |
+
split_topics = retrieval_query.strip().split('\n')
|
330 |
+
if split_topics[0] == '':
|
331 |
+
split_topics = split_topics[1:]
|
332 |
+
split_topics = split_topics[:topic_num]
|
333 |
+
except:
|
334 |
+
split_topics = retrieval_query.strip().split(':')
|
335 |
+
split_topics = split_topics.strip().split(';')
|
336 |
+
split_topics = split_topics[:topic_num]
|
337 |
+
if len(split_topics) > topic_num:
|
338 |
+
return ["too many topics", split_topics]
|
339 |
+
|
340 |
+
|
341 |
+
# Get top-5 papers for each topic
|
342 |
+
for retrieval_query in split_topics:
|
343 |
+
# retrieve papers
|
344 |
+
candidate_citation_papers, candidate_raw_ids = self.retrieval_for_one_query(self.whole_graph_id_2_title_abs, retrieval_query)
|
345 |
+
topic_specific_citation_papers = []
|
346 |
+
# select top-5 papers
|
347 |
+
for _ in range(5):
|
348 |
+
# picking most likely to be cited paper
|
349 |
+
selected_paper = self.single_paper_retrieval_test(usr_prompt, candidate_citation_papers).replace(' \n','').replace('\n','')
|
350 |
+
|
351 |
+
words = selected_paper.strip().split(' ')
|
352 |
+
index = -1
|
353 |
+
for w in words:
|
354 |
+
try:
|
355 |
+
index = int(w)
|
356 |
+
except:
|
357 |
+
pass
|
358 |
+
|
359 |
+
if index != -1 and index < len(candidate_citation_papers):
|
360 |
+
paper_title = candidate_citation_papers[index]
|
361 |
+
candidate_citation_papers = list(set(candidate_citation_papers) - set([paper_title]))
|
362 |
+
topic_specific_citation_papers.append([paper_title, candidate_raw_ids[index]])
|
363 |
+
else:
|
364 |
+
for i, paper_title in enumerate(list(candidate_citation_papers)):
|
365 |
+
if paper_title.lower().replace(' ', '') in selected_paper.lower().replace(' ', '') or selected_paper.lower().replace(' ', '') in paper_title.lower().replace(' ', ''):
|
366 |
+
candidate_citation_papers = list(set(candidate_citation_papers) - set([paper_title]))
|
367 |
+
topic_specific_citation_papers.append([paper_title, candidate_raw_ids[i]])
|
368 |
+
break
|
369 |
+
citation_papers.append(topic_specific_citation_papers)
|
370 |
+
|
371 |
+
|
372 |
+
# Remove empty lists
|
373 |
+
citation_papers = [x for x in citation_papers if x != []]
|
374 |
+
|
375 |
+
|
376 |
+
# Generate citation sentences
|
377 |
+
for topic_idx in range(len(citation_papers)):
|
378 |
+
topic_specific_nei_sentence = []
|
379 |
+
for paper_idx in range(len(citation_papers[topic_idx])):
|
380 |
+
sentence = self.single_paper_sentence_test(usr_prompt, citation_papers[topic_idx][paper_idx][0], "")
|
381 |
+
# Match \cite{...}
|
382 |
+
sentence = re.sub(r'\\cite\{[^{}]+\}', "", sentence)
|
383 |
+
topic_specific_nei_sentence.append(sentence)
|
384 |
+
nei_sentence.append(topic_specific_nei_sentence)
|
385 |
+
|
386 |
+
|
387 |
+
# Generate paragraphs
|
388 |
+
paragraphs = []
|
389 |
+
references = [] # Store references for citation
|
390 |
+
paper_citation_indicator = 1 # Indicator for citation paper
|
391 |
+
for topic_idx in range(len(citation_papers)):
|
392 |
+
datapoint = {'usr_prompt': usr_prompt,
|
393 |
+
'nei_title': citation_papers[topic_idx],
|
394 |
+
'nei_sentence': nei_sentence[topic_idx],
|
395 |
+
'topic': split_topics[topic_idx],
|
396 |
+
'paper_citation_indicator': paper_citation_indicator}
|
397 |
+
|
398 |
+
prompt = self._generate_paragraph_prompt(datapoint)
|
399 |
+
ans = self.get_llm_response(prompt, 'zeroshot')
|
400 |
+
res = ans['content']
|
401 |
+
paragraphs.append(res)
|
402 |
+
|
403 |
+
# Store referencess
|
404 |
+
for ref_idx, paper in enumerate(citation_papers[topic_idx]):
|
405 |
+
# Extract year and month from raw_id
|
406 |
+
raw_id = re.sub(r'[a-zA-Z/]+', '', paper[1])
|
407 |
+
year = raw_id[:2]
|
408 |
+
year = '19' + year if int(year) > 70 else '20' + year
|
409 |
+
month = datetime.date(1900, int(raw_id[2:4]), 1).strftime('%B')
|
410 |
+
|
411 |
+
references.append(f"[{paper_citation_indicator + ref_idx}] {paper[0]}, arXiv {raw_id}, {month} {year}")
|
412 |
+
# Update paper_citation_indicator
|
413 |
+
paper_citation_indicator = paper_citation_indicator + len(nei_sentence[topic_idx])
|
414 |
+
|
415 |
+
|
416 |
+
# Generate summary
|
417 |
+
datapoint = {'usr_prompt': usr_prompt, 'paragraphs': paragraphs}
|
418 |
+
prompt = self._generate_summary_prompt(datapoint)
|
419 |
+
ans = self.get_llm_response(prompt, 'zeroshot')
|
420 |
+
summary = ans['content']
|
421 |
+
|
422 |
+
# Append references to summary
|
423 |
+
summary_with_references = summary + "\n\n### References\n" + "\n".join(references)
|
424 |
+
|
425 |
+
return summary_with_references
|
426 |
+
|
427 |
+
|
428 |
+
def gen_related_work(message, graph_path, adapter_path):
|
429 |
+
litfm_instance = LitFM(graph_path, adapter_path)
|
430 |
+
return litfm_instance.single_paper_related_work_generation(message)
|
tasks/influential_papers.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Influential Papers Task
|
3 |
+
|
4 |
+
This module provides functionality to identify the most influential papers in a citation graph.
|
5 |
+
|
6 |
+
Functions:
|
7 |
+
influential_papers(K, graph):
|
8 |
+
Given an integer K and a citation graph, returns the K most influential papers based on the number of citations.
|
9 |
+
The function returns the title and abstract of each of the K most influential papers in a formatted string.
|
10 |
+
|
11 |
+
Usage:
|
12 |
+
The script reads configuration from a YAML file, loads a citation graph from a GEXF file, and prints the K most influential papers.
|
13 |
+
"""
|
14 |
+
|
15 |
+
import datetime
|
16 |
+
import re
|
17 |
+
|
18 |
+
def influential_papers(message, graph):
|
19 |
+
# Get integer number from message
|
20 |
+
K = int(re.search(r'\d+', message).group())
|
21 |
+
|
22 |
+
in_degree = dict(graph.in_degree())
|
23 |
+
sorted_in_degree = sorted(in_degree.items(), key=lambda x: x[1], reverse=True)
|
24 |
+
|
25 |
+
most_cited_papers = []
|
26 |
+
for i in range(K):
|
27 |
+
node = sorted_in_degree[i]
|
28 |
+
paper = graph.nodes[node[0]]
|
29 |
+
most_cited_papers.append(paper)
|
30 |
+
|
31 |
+
resp = "Here are the most influential papers:\n"
|
32 |
+
for i, paper in enumerate(most_cited_papers):
|
33 |
+
full_paper_id = paper['label']
|
34 |
+
paper_id = re.sub(r'[a-zA-Z/]+', '', full_paper_id)
|
35 |
+
year = paper_id[:2]
|
36 |
+
year = '19' + year if int(year) > 70 else '20' + year
|
37 |
+
month = datetime.date(1900, int(paper_id[2:4]), 1).strftime('%B')
|
38 |
+
|
39 |
+
resp += f"{i+1}. Title: {paper['title']}, arXiv {full_paper_id}, {month} {year} \nAbstract: {paper['abstract']}\n"
|
40 |
+
|
41 |
+
return resp
|
tasks/intro_2_abs.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Generate the abstract of a paper based on its introduction section.
|
3 |
+
|
4 |
+
Args:
|
5 |
+
usr_prompt (str): The user-provided prompt containing the introduction section of the paper.
|
6 |
+
template (dict): A dictionary containing the template for generating the abstract.
|
7 |
+
context_window (int): The maximum length of the context window for the prompt input.
|
8 |
+
|
9 |
+
Returns:
|
10 |
+
str: The generated abstract based on the introduction section.
|
11 |
+
"""
|
12 |
+
|
13 |
+
|
14 |
+
def intro_2_abs(usr_prompt, template, context_window, has_predefined_template=False):
|
15 |
+
instruction = "Please generate the abstract of paper based on its introduction section."
|
16 |
+
|
17 |
+
# Reduce it to make it fit
|
18 |
+
prompt_input = usr_prompt[:int(context_window*2)]
|
19 |
+
|
20 |
+
if has_predefined_template:
|
21 |
+
res = [
|
22 |
+
{"role": "system", "content": instruction},
|
23 |
+
{"role": "user", "content": prompt_input},
|
24 |
+
]
|
25 |
+
else:
|
26 |
+
res = template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
27 |
+
|
28 |
+
return res
|
tasks/link_pred.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Determine if paper A will cite paper B.
|
3 |
+
|
4 |
+
Args:
|
5 |
+
usr_input (str): The user-provided input containing the titles and abstracts of papers A and B.
|
6 |
+
template (dict): A dictionary containing the template for generating the link prediction task.
|
7 |
+
|
8 |
+
Returns:
|
9 |
+
str: The generated link prediction task based on the user input.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def link_pred(usr_input, template, has_predefined_template=False):
|
13 |
+
instruction = "Determine if paper A will cite paper B."
|
14 |
+
|
15 |
+
if has_predefined_template:
|
16 |
+
res = [
|
17 |
+
{"role": "system", "content": instruction},
|
18 |
+
{"role": "user", "content": usr_input},
|
19 |
+
]
|
20 |
+
else:
|
21 |
+
res = template["prompt_input"].format(instruction=instruction, input=usr_input)
|
22 |
+
|
23 |
+
return res
|
tasks/paper_retrieval.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Retrieves the most likely paper to be cited by Paper A from a list of candidate papers based on user input.
|
3 |
+
Args:
|
4 |
+
usr_input (str): A string containing the title and abstract of Paper A followed by the titles and abstracts of candidate papers.
|
5 |
+
template (dict): A dictionary containing a template for formatting the prompt input.
|
6 |
+
Returns:
|
7 |
+
str: A string containing the prompt input for the user.
|
8 |
+
"""
|
9 |
+
|
10 |
+
def paper_retrieval(usr_input, template, has_predefined_template=False):
|
11 |
+
instruction = "Please select the paper that is more likely to be cited by paper A from candidate papers."
|
12 |
+
|
13 |
+
if has_predefined_template:
|
14 |
+
res = [
|
15 |
+
{"role": "system", "content": instruction},
|
16 |
+
{"role": "user", "content": usr_input},
|
17 |
+
]
|
18 |
+
else:
|
19 |
+
res = template["prompt_input"].format(instruction=instruction, input=usr_input)
|
20 |
+
|
21 |
+
return res
|
train.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import transformers
|
5 |
+
import networkx as nx
|
6 |
+
from tqdm import tqdm
|
7 |
+
from peft import (LoraConfig, get_peft_model,
|
8 |
+
prepare_model_for_kbit_training)
|
9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
class QloraTrainer_CS:
|
14 |
+
def __init__(self, config: dict, use_predefined_graph=False):
|
15 |
+
self.config = config
|
16 |
+
self.use_predefined_graph = use_predefined_graph
|
17 |
+
self.tokenizer = None
|
18 |
+
self.base_model = None
|
19 |
+
self.adapter_model = None
|
20 |
+
self.merged_model = None
|
21 |
+
self.transformer_trainer = None
|
22 |
+
self.test_data = None
|
23 |
+
|
24 |
+
template_file_path = 'configs/alpaca.json'
|
25 |
+
with open(template_file_path) as fp:
|
26 |
+
self.template = json.load(fp)
|
27 |
+
|
28 |
+
|
29 |
+
def load_base_model(self):
|
30 |
+
model_id = self.config['inference']["base_model"]
|
31 |
+
print(model_id)
|
32 |
+
|
33 |
+
bnb_config = BitsAndBytesConfig(
|
34 |
+
load_in_8bit=True,
|
35 |
+
bnb_8bit_use_double_quant=True,
|
36 |
+
bnb_8bit_quant_type="nf8",
|
37 |
+
bnb_8bit_compute_dtype=torch.bfloat16
|
38 |
+
)
|
39 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
40 |
+
tokenizer.model_max_length = self.config['training']['tokenizer']["max_length"]
|
41 |
+
if not tokenizer.pad_token:
|
42 |
+
tokenizer.pad_token = tokenizer.eos_token
|
43 |
+
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, torch_dtype=torch.bfloat16)
|
44 |
+
if model.device.type != 'cuda':
|
45 |
+
model.to('cuda')
|
46 |
+
|
47 |
+
model.gradient_checkpointing_enable()
|
48 |
+
model = prepare_model_for_kbit_training(model)
|
49 |
+
|
50 |
+
self.tokenizer = tokenizer
|
51 |
+
self.base_model = model
|
52 |
+
|
53 |
+
|
54 |
+
def train(self):
|
55 |
+
# Set up lora config or load pre-trained adapter
|
56 |
+
lora_config = LoraConfig(
|
57 |
+
r=self.config['training']['qlora']['rank'],
|
58 |
+
lora_alpha=self.config['training']['qlora']['lora_alpha'],
|
59 |
+
target_modules=self.config['training']['qlora']['target_modules'],
|
60 |
+
lora_dropout=self.config['training']['qlora']['lora_dropout'],
|
61 |
+
bias="none",
|
62 |
+
task_type="CAUSAL_LM",
|
63 |
+
)
|
64 |
+
model = get_peft_model(self.base_model, lora_config)
|
65 |
+
self._print_trainable_parameters(model)
|
66 |
+
|
67 |
+
print("Start data preprocessing")
|
68 |
+
train_data = self._process_data_instruction()
|
69 |
+
|
70 |
+
print('Length of dataset: ', len(train_data))
|
71 |
+
|
72 |
+
print("Start training")
|
73 |
+
self.transformer_trainer = transformers.Trainer(
|
74 |
+
model=model,
|
75 |
+
train_dataset=train_data,
|
76 |
+
args=transformers.TrainingArguments(
|
77 |
+
per_device_train_batch_size=self.config["training"]['trainer_args']["per_device_train_batch_size"],
|
78 |
+
gradient_accumulation_steps=self.config['model_saving']['index'],
|
79 |
+
warmup_steps=self.config["training"]['trainer_args']["warmup_steps"],
|
80 |
+
num_train_epochs=self.config["training"]['trainer_args']["num_train_epochs"],
|
81 |
+
learning_rate=self.config["training"]['trainer_args']["learning_rate"],
|
82 |
+
lr_scheduler_type=self.config["training"]['trainer_args']["lr_scheduler_type"],
|
83 |
+
fp16=self.config["training"]['trainer_args']["fp16"],
|
84 |
+
logging_steps=self.config["training"]['trainer_args']["logging_steps"],
|
85 |
+
output_dir=self.config["training"]['trainer_args']["trainer_output_dir"],
|
86 |
+
report_to="wandb",
|
87 |
+
save_steps=self.config["training"]['trainer_args']["save_steps"],
|
88 |
+
),
|
89 |
+
data_collator=transformers.DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
|
90 |
+
)
|
91 |
+
|
92 |
+
model.config.use_cache = False
|
93 |
+
|
94 |
+
self.transformer_trainer.train()
|
95 |
+
|
96 |
+
model_save_path = f"{self.config['model_saving']['model_output_dir']}/{self.config['model_saving']['model_name']}_{self.config['model_saving']['index']}_adapter_test_graph"
|
97 |
+
self.transformer_trainer.save_model(model_save_path)
|
98 |
+
|
99 |
+
self.adapter_model = model
|
100 |
+
print(f"Training complete, adapter model saved in {model_save_path}")
|
101 |
+
|
102 |
+
|
103 |
+
def _print_trainable_parameters(self, model):
|
104 |
+
"""
|
105 |
+
Prints the number of trainable parameters in the model.
|
106 |
+
"""
|
107 |
+
trainable_params = 0
|
108 |
+
all_param = 0
|
109 |
+
for _, param in model.named_parameters():
|
110 |
+
all_param += param.numel()
|
111 |
+
if param.requires_grad:
|
112 |
+
trainable_params += param.numel()
|
113 |
+
print(
|
114 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
115 |
+
)
|
116 |
+
|
117 |
+
|
118 |
+
def _process_data_instruction(self):
|
119 |
+
context_window = self.tokenizer.model_max_length
|
120 |
+
if self.use_predefined_graph:
|
121 |
+
graph_data = nx.read_gexf('datasets/' + self.config["training"]["predefined_graph_path"], node_type=None, relabel=False, version='1.2draft')
|
122 |
+
else:
|
123 |
+
graph_path = self.config['data_downloading']['download_directory'] + 'description/' + self.config['data_downloading']['gexf_file']
|
124 |
+
graph_data = nx.read_gexf(graph_path, node_type=None, relabel=False, version='1.2draft')
|
125 |
+
raw_graph = graph_data
|
126 |
+
|
127 |
+
test_set_size = len(graph_data.nodes()) // 10
|
128 |
+
|
129 |
+
all_test_nodes = set(list(graph_data.nodes())[:test_set_size])
|
130 |
+
all_train_nodes = set(list(graph_data.nodes())[test_set_size:])
|
131 |
+
|
132 |
+
raw_id_2_title_abs = dict()
|
133 |
+
for paper_id in list(graph_data.nodes())[test_set_size:]:
|
134 |
+
title = graph_data.nodes()[paper_id]['title']
|
135 |
+
abstract = graph_data.nodes()[paper_id]['abstract']
|
136 |
+
raw_id_2_title_abs[paper_id] = [title, abstract]
|
137 |
+
|
138 |
+
raw_id_2_intro = dict()
|
139 |
+
for paper_id in list(graph_data.nodes())[test_set_size:]:
|
140 |
+
if graph_data.nodes[paper_id]['introduction'] != '':
|
141 |
+
intro = graph_data.nodes[paper_id]['introduction']
|
142 |
+
raw_id_2_intro[paper_id] = intro
|
143 |
+
|
144 |
+
raw_id_pair_2_sentence = dict()
|
145 |
+
for edge in list(graph_data.edges()):
|
146 |
+
sentence = graph_data.edges()[edge]['sentence']
|
147 |
+
raw_id_pair_2_sentence[edge] = sentence
|
148 |
+
|
149 |
+
|
150 |
+
test_data = []
|
151 |
+
edge_list = []
|
152 |
+
for edge in list(raw_graph.edges()):
|
153 |
+
src, tar = edge
|
154 |
+
if src not in all_test_nodes and tar not in all_test_nodes:
|
155 |
+
edge_list.append(edge)
|
156 |
+
else:
|
157 |
+
test_data.append(edge)
|
158 |
+
train_num = int(len(edge_list))
|
159 |
+
|
160 |
+
data_LP = []
|
161 |
+
data_abstract_2_title = []
|
162 |
+
data_paper_retrieval = []
|
163 |
+
data_citation_sentence = []
|
164 |
+
data_abs_completion = []
|
165 |
+
data_title_2_abs = []
|
166 |
+
data_intro_2_abs = []
|
167 |
+
|
168 |
+
|
169 |
+
for sample in tqdm(random.sample(edge_list, train_num)):
|
170 |
+
source, target = sample[0], sample[1]
|
171 |
+
source_title, source_abs = raw_id_2_title_abs[source]
|
172 |
+
target_title, target_abs = raw_id_2_title_abs[target]
|
173 |
+
# LP prompt
|
174 |
+
rand_ind = random.choice(list(raw_id_2_title_abs.keys()))
|
175 |
+
neg_title, neg_abs = raw_id_2_title_abs[rand_ind]
|
176 |
+
data_LP.append({'s_title':source_title, 's_abs':source_abs, 't_title':target_title, 't_abs':target_abs, 'label':'yes'})
|
177 |
+
data_LP.append({'s_title':source_title, 's_abs':source_abs, 't_title':neg_title, 't_abs':neg_abs, 'label':'no'})
|
178 |
+
|
179 |
+
for sample in tqdm(random.sample(edge_list, train_num)):
|
180 |
+
source, target = sample[0], sample[1]
|
181 |
+
source_title, source_abs = raw_id_2_title_abs[source]
|
182 |
+
target_title, target_abs = raw_id_2_title_abs[target]
|
183 |
+
# abs_2_title prompt
|
184 |
+
data_abstract_2_title.append({'title':source_title, 'abs':source_abs})
|
185 |
+
data_abstract_2_title.append({'title':target_title, 'abs':target_abs})
|
186 |
+
|
187 |
+
for sample in tqdm(random.sample(edge_list, train_num)):
|
188 |
+
source, target = sample[0], sample[1]
|
189 |
+
source_title, source_abs = raw_id_2_title_abs[source]
|
190 |
+
target_title, target_abs = raw_id_2_title_abs[target]
|
191 |
+
# paper_retrieval prompt
|
192 |
+
neighbors = list(nx.all_neighbors(raw_graph, source))
|
193 |
+
sample_node_list = list(all_train_nodes - set(neighbors) - set([source]) - set([target]))
|
194 |
+
sampled_neg_nodes = random.sample(sample_node_list, 5) + [target]
|
195 |
+
random.shuffle(sampled_neg_nodes)
|
196 |
+
data_paper_retrieval.append({'title':source_title, 'abs':source_abs, 'sample_title': [raw_id_2_title_abs[node][0] for node in sampled_neg_nodes], 'right_title':target_title})
|
197 |
+
|
198 |
+
for sample in tqdm(random.sample(edge_list, train_num)):
|
199 |
+
source, target = sample[0], sample[1]
|
200 |
+
source_title, source_abs = raw_id_2_title_abs[source]
|
201 |
+
target_title, target_abs = raw_id_2_title_abs[target]
|
202 |
+
# citation_sentence prompt
|
203 |
+
citation_sentence = raw_id_pair_2_sentence[(source, target)] if (source, target) in raw_id_pair_2_sentence.keys() else raw_id_pair_2_sentence[(target, source)]
|
204 |
+
data_citation_sentence.append({'s_title':source_title, 's_abs':source_abs, 't_title':target_title, 't_abs':target_abs, 'sentence': citation_sentence})
|
205 |
+
|
206 |
+
for sample in tqdm(random.sample(edge_list, train_num)):
|
207 |
+
source, target = sample[0], sample[1]
|
208 |
+
source_title, source_abs = raw_id_2_title_abs[source]
|
209 |
+
target_title, target_abs = raw_id_2_title_abs[target]
|
210 |
+
# abs_complete prompt
|
211 |
+
data_abs_completion.append({'title':source_title, 'abs':source_abs})
|
212 |
+
data_abs_completion.append({'title':target_title, 'abs':target_abs})
|
213 |
+
|
214 |
+
for sample in tqdm(random.sample(edge_list, train_num)):
|
215 |
+
source, target = sample[0], sample[1]
|
216 |
+
source_title, source_abs = raw_id_2_title_abs[source]
|
217 |
+
target_title, target_abs = raw_id_2_title_abs[target]
|
218 |
+
# title_2_abs prompt
|
219 |
+
data_title_2_abs.append({'title':source_title, 'right_abs':source_abs})
|
220 |
+
data_title_2_abs.append({'title':target_title, 'right_abs':target_abs})
|
221 |
+
|
222 |
+
for sample in tqdm(random.sample(edge_list, train_num)):
|
223 |
+
source, target = sample[0], sample[1]
|
224 |
+
if source in raw_id_2_intro:
|
225 |
+
source_intro = raw_id_2_intro[source]
|
226 |
+
_, source_abs = raw_id_2_title_abs[source]
|
227 |
+
data_intro_2_abs.append({'intro':source_intro, 'abs':source_abs})
|
228 |
+
if target in raw_id_2_intro:
|
229 |
+
target_intro = raw_id_2_intro[target]
|
230 |
+
_, target_abs = raw_id_2_title_abs[target]
|
231 |
+
data_intro_2_abs.append({'intro':target_intro, 'abs':target_abs})
|
232 |
+
|
233 |
+
data_prompt = []
|
234 |
+
data_prompt += [self._generate_paper_retrieval_prompt(data_point) for data_point in data_paper_retrieval]
|
235 |
+
data_prompt += [self._generate_LP_prompt(data_point) for data_point in data_LP]
|
236 |
+
data_prompt += [self._generate_abstract_2_title_prompt(data_point) for data_point in data_abstract_2_title]
|
237 |
+
data_prompt += [self._generate_citation_sentence_prompt(data_point) for data_point in data_citation_sentence]
|
238 |
+
data_prompt += [self._generate_abstract_completion_prompt(data_point) for data_point in data_abs_completion]
|
239 |
+
data_prompt += [self._generate_title_2_abstract_prompt(data_point) for data_point in data_title_2_abs]
|
240 |
+
data_prompt += [self._generate_intro_2_abstract_prompt(data_point, context_window) for data_point in data_intro_2_abs]
|
241 |
+
|
242 |
+
print("Total prompts:", len(data_prompt))
|
243 |
+
random.shuffle(data_prompt)
|
244 |
+
if self.tokenizer.chat_template is None:
|
245 |
+
data_tokenized = [self.tokenizer(sample, max_length=context_window, truncation=True) for sample in tqdm(data_prompt)]
|
246 |
+
else:
|
247 |
+
data_tokenized = [self.tokenizer.apply_chat_template(sample, max_length=context_window, truncation=True, tokenize=False) for sample in tqdm(data_prompt)]
|
248 |
+
|
249 |
+
return data_tokenized
|
250 |
+
|
251 |
+
|
252 |
+
def _generate_LP_prompt(self, data_point: dict):
|
253 |
+
instruction = "Determine if paper A will cite paper B."
|
254 |
+
|
255 |
+
prompt_input = ""
|
256 |
+
prompt_input = prompt_input + "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] != None else 'Unknown') + "\n"
|
257 |
+
prompt_input = prompt_input + "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] != None else 'Unknown') + "\n"
|
258 |
+
prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] != None else 'Unknown') + "\n"
|
259 |
+
prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] != None else 'Unknown') + "\n"
|
260 |
+
|
261 |
+
if self.tokenizer.chat_template is None:
|
262 |
+
res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
263 |
+
res = f"{res}{data_point['label']}"
|
264 |
+
else:
|
265 |
+
res = [
|
266 |
+
{"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
|
267 |
+
{"role": "assistant", "content": data_point['label']}
|
268 |
+
]
|
269 |
+
|
270 |
+
return res
|
271 |
+
|
272 |
+
def _generate_abstract_2_title_prompt(self, data_point: dict):
|
273 |
+
instruction = "Please generate the title of paper based on its abstract."
|
274 |
+
|
275 |
+
prompt_input = ""
|
276 |
+
prompt_input = prompt_input + "Abstract: " + data_point['abs'] + "\n"
|
277 |
+
|
278 |
+
if self.tokenizer.chat_template is None:
|
279 |
+
res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
280 |
+
res = f"{res}{data_point['title']}"
|
281 |
+
else:
|
282 |
+
res = [
|
283 |
+
{"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
|
284 |
+
{"role": "assistant", "content": data_point['title']}
|
285 |
+
]
|
286 |
+
|
287 |
+
return res
|
288 |
+
|
289 |
+
def _generate_paper_retrieval_prompt(self, data_point: dict):
|
290 |
+
instruction = "Please select the paper that is more likely to be cited by paper A from candidate papers."
|
291 |
+
|
292 |
+
prompt_input = ""
|
293 |
+
prompt_input = prompt_input + "Title of the Paper A: " + data_point['title'] + "\n"
|
294 |
+
prompt_input = prompt_input + "Abstract of the Paper A: " + data_point['abs'] + "\n"
|
295 |
+
prompt_input = prompt_input + "candidate papers: " + "\n"
|
296 |
+
for i in range(len(data_point['sample_title'])):
|
297 |
+
prompt_input = prompt_input + str(i) + '. ' + data_point['sample_title'][i] + "\n"
|
298 |
+
|
299 |
+
if self.tokenizer.chat_template is None:
|
300 |
+
res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
301 |
+
res = f"{res}{data_point['right_title']}"
|
302 |
+
else:
|
303 |
+
res = [
|
304 |
+
{"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
|
305 |
+
{"role": "assistant", "content": data_point['right_title']}
|
306 |
+
]
|
307 |
+
|
308 |
+
return res
|
309 |
+
|
310 |
+
def _generate_citation_sentence_prompt(self, data_point: dict):
|
311 |
+
instruction = "Please generate the citation sentence of how Paper A cites paper B in its related work section."
|
312 |
+
|
313 |
+
prompt_input = ""
|
314 |
+
prompt_input = prompt_input + "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] != None else 'Unknown') + "\n"
|
315 |
+
prompt_input = prompt_input + "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] != None else 'Unknown') + "\n"
|
316 |
+
prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] != None else 'Unknown') + "\n"
|
317 |
+
prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] != None else 'Unknown') + "\n"
|
318 |
+
|
319 |
+
if self.tokenizer.chat_template is None:
|
320 |
+
res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
321 |
+
res = f"{res}{data_point['sentence']}"
|
322 |
+
else:
|
323 |
+
res = [
|
324 |
+
{"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
|
325 |
+
{"role": "assistant", "content": data_point['sentence']}
|
326 |
+
]
|
327 |
+
|
328 |
+
return res
|
329 |
+
|
330 |
+
def _generate_abstract_completion_prompt(self, data_point: dict):
|
331 |
+
instruction = "Please complete the abstract of a paper."
|
332 |
+
|
333 |
+
prompt_input = ""
|
334 |
+
prompt_input = prompt_input + "Title: " + data_point['title'] if data_point['title'] != None else 'Unknown' + "\n"
|
335 |
+
|
336 |
+
split_abs = data_point['abs'][: int(0.3*len(data_point['abs']))]
|
337 |
+
prompt_input = prompt_input + "Part of abstract: " + split_abs + "\n"
|
338 |
+
|
339 |
+
if self.tokenizer.chat_template is None:
|
340 |
+
res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
341 |
+
res = f"{res}{data_point['abs']}"
|
342 |
+
else:
|
343 |
+
res = [
|
344 |
+
{"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
|
345 |
+
{"role": "assistant", "content": data_point['abs']}
|
346 |
+
]
|
347 |
+
|
348 |
+
return res
|
349 |
+
|
350 |
+
def _generate_title_2_abstract_prompt(self, data_point: dict):
|
351 |
+
instruction = "Please generate the abstract of paper based on its title."
|
352 |
+
|
353 |
+
prompt_input = ""
|
354 |
+
prompt_input = prompt_input + "Title: " + data_point['title'] + "\n"
|
355 |
+
|
356 |
+
if self.tokenizer.chat_template is None:
|
357 |
+
res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
358 |
+
res = f"{res}{data_point['right_abs']}"
|
359 |
+
else:
|
360 |
+
res = [
|
361 |
+
{"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
|
362 |
+
{"role": "assistant", "content": data_point['right_abs']}
|
363 |
+
]
|
364 |
+
|
365 |
+
return res
|
366 |
+
|
367 |
+
def _generate_intro_2_abstract_prompt(self, data_point: dict, context_window):
|
368 |
+
instruction = "Please generate the abstract of paper based on its introduction section."
|
369 |
+
|
370 |
+
prompt_input = ""
|
371 |
+
prompt_input = prompt_input + "Introduction: " + data_point['intro'] + "\n"
|
372 |
+
|
373 |
+
# Reduce it to make it fit
|
374 |
+
prompt_input = prompt_input[:int(context_window*2)]
|
375 |
+
|
376 |
+
if self.tokenizer.chat_template is None:
|
377 |
+
res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
378 |
+
res = f"{res}{data_point['abs']}"
|
379 |
+
else:
|
380 |
+
res = [
|
381 |
+
{"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
|
382 |
+
{"role": "assistant", "content": data_point['abs']}
|
383 |
+
]
|
384 |
+
|
385 |
+
return res
|
utils/de-macro.py
ADDED
@@ -0,0 +1,1110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python -O
|
2 |
+
|
3 |
+
r"""
|
4 |
+
Copyright 2005-2020 Peter Gacs
|
5 |
+
Licensed under the Academic Free Licence version 2.1
|
6 |
+
|
7 |
+
DE-MACRO
|
8 |
+
|
9 |
+
Version 1.4.1 - A small typo corrected.
|
10 |
+
|
11 |
+
Version 1.4 - Luca Citi made it python2.7 and python3 compatible.
|
12 |
+
Peter Gacs improved the parsing of \input{<filename>},
|
13 |
+
and made @ a letter in the style files.
|
14 |
+
Version 1.3 - this version is much more conservative about deleting
|
15 |
+
comments and inserting or deleting blank space: tries to
|
16 |
+
leave in all comments, adds space only when necessary, and
|
17 |
+
tries not to delete space in the main text.
|
18 |
+
The motivating comments came from Daniel Webb.
|
19 |
+
Version 1.2 - a syntactical bug corrected, thanks Brian de Alwis!
|
20 |
+
|
21 |
+
|
22 |
+
PURPOSE
|
23 |
+
|
24 |
+
This program can eliminate most private macros from a LaTeX file.
|
25 |
+
Applications:
|
26 |
+
- your publisher has difficulty dealing with many private macros
|
27 |
+
- you cooperate with colleagues who do not understand your macros
|
28 |
+
- preprocessing before a system like latex2html, which is somewhat
|
29 |
+
unpredictable with private macros.
|
30 |
+
|
31 |
+
It cannot be used to eliminate more complex macros that rely on
|
32 |
+
more programming-like constructs in style files. In particular, it will
|
33 |
+
not replace style files that have options.
|
34 |
+
|
35 |
+
USAGE
|
36 |
+
|
37 |
+
de-macro [--defs <defs-db>] <tex-file-1>[.tex] [<tex-file-2>[.tex] ...]
|
38 |
+
|
39 |
+
Simplest example: de-macro testament
|
40 |
+
|
41 |
+
(As you see, the <> is used only in the notation of this documentation,
|
42 |
+
you should not type it.)
|
43 |
+
|
44 |
+
If <tex-file-i> contains a command \usepackage{<defs-file>-private}
|
45 |
+
then the file <defs-file>-private.sty will be read, and its macros will be
|
46 |
+
replaced in <tex-file-i> with their definitions.
|
47 |
+
The result is in <tex-file-i>-clean.tex.
|
48 |
+
|
49 |
+
Only newcommand, renewcommand, newenvironment, and renewenvironment are
|
50 |
+
understood (it does not matter, whether you write new or renew).
|
51 |
+
These can be nested but do not be too clever, since I do not
|
52 |
+
guarantee the same expansion order as in TeX.
|
53 |
+
|
54 |
+
FILES
|
55 |
+
|
56 |
+
<tex-file-1>.db
|
57 |
+
<tex-file>-clean.tex
|
58 |
+
<defs-file>-private.sty
|
59 |
+
|
60 |
+
For speed, a macro database file called <defs-file>.db is created.
|
61 |
+
If such a file exists already then it is used.
|
62 |
+
If <defs-file>-private.sty is older than <tex-file-1>.db then it will not
|
63 |
+
be used.
|
64 |
+
|
65 |
+
It is possible to specify another database filename via --defs <defs-db>.
|
66 |
+
Then <defs-db>.db will be used.
|
67 |
+
|
68 |
+
For each <tex-file-i>, a file <tex-file-i>-clean.tex will be produced.
|
69 |
+
If <tex-file-i>-clean.tex is newer than <tex-file-i>.tex then it stays.
|
70 |
+
|
71 |
+
INPUT COMMAND
|
72 |
+
|
73 |
+
If a tex file contains a command \input{<tex-file-j>} or \input <tex-file-j>
|
74 |
+
then <tex-file-j>.tex is processed recursively, and <tex-file-j>-clean.tex
|
75 |
+
will be inserted into the final output.
|
76 |
+
For speed, if <tex-file-j>-clean.tex is newer than <tex-file-j>.tex
|
77 |
+
then <tex-file-j>.tex will not be reprocessed.
|
78 |
+
|
79 |
+
The dependency checking is not sophisticated, so if you rewrite some macros
|
80 |
+
then remove all *-clean.tex files!
|
81 |
+
|
82 |
+
"""
|
83 |
+
|
84 |
+
import sys, os, re, shelve
|
85 |
+
|
86 |
+
# Utilities
|
87 |
+
|
88 |
+
class No_detail:
|
89 |
+
strerror = ""
|
90 |
+
|
91 |
+
no_detail = No_detail()
|
92 |
+
|
93 |
+
|
94 |
+
class Error(Exception):
|
95 |
+
"""Base class for exceptions in this module."""
|
96 |
+
pass
|
97 |
+
|
98 |
+
class Empty_text_error(Error):
|
99 |
+
"""Exception raised for errors in the input.
|
100 |
+
|
101 |
+
Attributes:
|
102 |
+
data -- data that was found empty
|
103 |
+
message
|
104 |
+
"""
|
105 |
+
|
106 |
+
def __init__(self, data, message):
|
107 |
+
self.data = data
|
108 |
+
self.message = message
|
109 |
+
|
110 |
+
def warn(error_message, detail = no_detail):
|
111 |
+
sys.stderr.write(error_message + "\n")
|
112 |
+
if no_detail != detail:
|
113 |
+
sys.stderr.write(detail.strerror + "\n")
|
114 |
+
|
115 |
+
def die(error_message, detail = no_detail):
|
116 |
+
warn(error_message, detail = no_detail)
|
117 |
+
sys.exit(1)
|
118 |
+
|
119 |
+
def getopt_map(one_letter_opts, long_optlist):
|
120 |
+
"Turns long options into an option map, using getopt."
|
121 |
+
import getopt
|
122 |
+
optlist, args = getopt.getopt(sys.argv[1:],
|
123 |
+
one_letter_opts, long_optlist)
|
124 |
+
opt_map = {}
|
125 |
+
for pair in optlist: opt_map[pair[0]] = pair[1] or 1
|
126 |
+
return opt_map, args
|
127 |
+
|
128 |
+
def newer(file1, file2):
|
129 |
+
|
130 |
+
if not os.path.isfile(file1):
|
131 |
+
return False
|
132 |
+
|
133 |
+
try:
|
134 |
+
stat_return = os.lstat(file1)
|
135 |
+
except OSError as detail:
|
136 |
+
die("lstat " + file1 + " failed:", detail)
|
137 |
+
time1 = stat_return.st_mtime
|
138 |
+
|
139 |
+
try:
|
140 |
+
stat_return = os.lstat(file2)
|
141 |
+
except OSError as detail:
|
142 |
+
die("lstat " + file2 + " failed:", detail)
|
143 |
+
time2 = stat_return.st_mtime
|
144 |
+
|
145 |
+
return time1 > time2
|
146 |
+
|
147 |
+
def cut_extension(filename, ext):
|
148 |
+
"""
|
149 |
+
If filename has extension ext (including the possible dot),
|
150 |
+
it will be cut off.
|
151 |
+
"""
|
152 |
+
file = filename
|
153 |
+
index = filename.rfind(ext)
|
154 |
+
if 0 <= index and len(file)-len(ext) == index:
|
155 |
+
file = file[:index]
|
156 |
+
return file
|
157 |
+
|
158 |
+
|
159 |
+
class Stream:
|
160 |
+
data = None
|
161 |
+
pos = None
|
162 |
+
item = None
|
163 |
+
|
164 |
+
def legal(self):
|
165 |
+
return 0 <= self.pos and self.pos < len(self.data)
|
166 |
+
|
167 |
+
def uplegal(self):
|
168 |
+
return self.pos < len(self.data)
|
169 |
+
|
170 |
+
def __init__(self, data_v = None):
|
171 |
+
self.data = data_v
|
172 |
+
if self.data:
|
173 |
+
self.pos = 0
|
174 |
+
self.item = self.data[self.pos]
|
175 |
+
|
176 |
+
def next(self):
|
177 |
+
self.pos += 1
|
178 |
+
if self.pos < len(self.data):
|
179 |
+
self.item = self.data[self.pos]
|
180 |
+
return self.item
|
181 |
+
|
182 |
+
def reset(self):
|
183 |
+
if self.data and 0 < len(self.data):
|
184 |
+
self.pos = 0
|
185 |
+
self.item = self.data[0]
|
186 |
+
return self.item
|
187 |
+
|
188 |
+
|
189 |
+
# Basic classes
|
190 |
+
|
191 |
+
blank_re = re.compile(r"\s")
|
192 |
+
blanked_filename_re = re.compile(r"^\s+(\w*)\s+")
|
193 |
+
braced_filename_re = re.compile(r"^\s*{\s*(\w*)\s*}")
|
194 |
+
blank_or_rbrace_re = re.compile(r"[\s}]")
|
195 |
+
pos_digit_re = re.compile(r"[1-9]")
|
196 |
+
|
197 |
+
def isletter(c, isatletter=False):
|
198 |
+
if "@" == c:
|
199 |
+
return isatletter
|
200 |
+
else:
|
201 |
+
return c.isalpha()
|
202 |
+
|
203 |
+
class Token:
|
204 |
+
"""Type 0 means ordinary character, types 1,2 mean escape sequence
|
205 |
+
(without the \ ), type 3 means comment.
|
206 |
+
"""
|
207 |
+
simple_ty = 0
|
208 |
+
esc_symb_ty = 1
|
209 |
+
esc_str_ty = 2
|
210 |
+
comment_ty = 3
|
211 |
+
|
212 |
+
type = simple_ty
|
213 |
+
val = " "
|
214 |
+
|
215 |
+
def __init__(self, type_v=simple_ty, val_v=" "):
|
216 |
+
self.type = type_v
|
217 |
+
self.val = val_v
|
218 |
+
|
219 |
+
def show(self):
|
220 |
+
out = ""
|
221 |
+
if simple_ty == self.type or comment_ty == self.type:
|
222 |
+
out = self.val
|
223 |
+
else:
|
224 |
+
out = "\\" + self.val
|
225 |
+
return out
|
226 |
+
|
227 |
+
|
228 |
+
# Constants
|
229 |
+
|
230 |
+
g_token = Token(0," ") # generic token
|
231 |
+
simple_ty = g_token.simple_ty
|
232 |
+
comment_ty = g_token.comment_ty
|
233 |
+
esc_symb_ty = g_token.esc_symb_ty
|
234 |
+
esc_str_ty = g_token.esc_str_ty
|
235 |
+
|
236 |
+
|
237 |
+
def detokenize(text, isatletter=False):
|
238 |
+
"""
|
239 |
+
Input is a list of tokens.
|
240 |
+
Output is a string.
|
241 |
+
"""
|
242 |
+
out = ""
|
243 |
+
if 0 == len(text):
|
244 |
+
return
|
245 |
+
pos = 0
|
246 |
+
out += text[pos].show()
|
247 |
+
pos += 1
|
248 |
+
while pos < len(text):
|
249 |
+
previtem = text[pos-1]
|
250 |
+
item = text[pos]
|
251 |
+
"""Insert a separating space after an escape sequence if it is a
|
252 |
+
string and is followed by a letter."""
|
253 |
+
if (esc_str_ty == previtem.type
|
254 |
+
and simple_ty == item.type and isletter(item.val[0], isatletter)):
|
255 |
+
out += " "
|
256 |
+
out += item.show()
|
257 |
+
pos += 1
|
258 |
+
return out
|
259 |
+
|
260 |
+
|
261 |
+
def strip_comments(text):
|
262 |
+
"""
|
263 |
+
Input is a list of tokens.
|
264 |
+
Output is the same list except the comment tokens.
|
265 |
+
"""
|
266 |
+
out = []
|
267 |
+
for token in text:
|
268 |
+
if not comment_ty == token.type:
|
269 |
+
out.append(token)
|
270 |
+
return out
|
271 |
+
|
272 |
+
class Group:
|
273 |
+
"""type 0 means a token, type 1 means contents of a group within {}
|
274 |
+
"""
|
275 |
+
token_ty = 0
|
276 |
+
group_ty = 1
|
277 |
+
type = token_ty
|
278 |
+
val = [] # Value is a token list.
|
279 |
+
|
280 |
+
def __init__(self, type_v, val_v):
|
281 |
+
self.type = type_v
|
282 |
+
self.val = val_v
|
283 |
+
|
284 |
+
def show(self):
|
285 |
+
if token_ty == self.type:
|
286 |
+
return self.val.show()
|
287 |
+
else:
|
288 |
+
return "{%s}" % detokenize(self.val)
|
289 |
+
|
290 |
+
# Constants
|
291 |
+
|
292 |
+
g_group = Group(0, [])
|
293 |
+
token_ty = g_group.token_ty
|
294 |
+
group_ty = g_group.group_ty
|
295 |
+
|
296 |
+
|
297 |
+
def tokenize(in_str, isatletter=False):
|
298 |
+
"""Returns a list of tokens.
|
299 |
+
"""
|
300 |
+
text = []
|
301 |
+
cs = Char_stream(in_str)
|
302 |
+
cs.reset()
|
303 |
+
if not cs.legal():
|
304 |
+
raise Error("No string to tokenize.")
|
305 |
+
while cs.uplegal():
|
306 |
+
if "%" == cs.item:
|
307 |
+
comment = cs.scan_comment_token()
|
308 |
+
text.append(Token(comment_ty, comment))
|
309 |
+
elif "\\" != cs.item:
|
310 |
+
text.append(Token(simple_ty, cs.item))
|
311 |
+
cs.next()
|
312 |
+
else:
|
313 |
+
cs.next()
|
314 |
+
name = cs.scan_escape_token(isatletter)
|
315 |
+
if isletter(name[0], isatletter):
|
316 |
+
token = Token(esc_str_ty, name)
|
317 |
+
else:
|
318 |
+
token = Token(esc_symb_ty, name)
|
319 |
+
text.append(token)
|
320 |
+
if "makeatletter" == name:
|
321 |
+
isatletter=True
|
322 |
+
elif "makeatother" == name:
|
323 |
+
isatletter=False
|
324 |
+
return text
|
325 |
+
|
326 |
+
|
327 |
+
class Command_def:
|
328 |
+
name = "1"
|
329 |
+
numargs = 0
|
330 |
+
body= ""
|
331 |
+
|
332 |
+
def __init__(self, name_v, numargs_v, body_v):
|
333 |
+
self.name = name_v
|
334 |
+
self.numargs = numargs_v
|
335 |
+
self.body = body_v
|
336 |
+
|
337 |
+
def show(self):
|
338 |
+
out = "\\newcommand{\\%s}" % (self.name)
|
339 |
+
if 0 < self.numargs:
|
340 |
+
out += "[%d]" % self.numargs
|
341 |
+
out += "{%s}" % detokenize(self.body)
|
342 |
+
return out
|
343 |
+
|
344 |
+
|
345 |
+
class Env_def:
|
346 |
+
name = "1"
|
347 |
+
numargs = 0
|
348 |
+
begin = ""
|
349 |
+
end = ""
|
350 |
+
|
351 |
+
def __init__(self, name_v, numargs_v, begin_v, end_v):
|
352 |
+
self.name = name_v
|
353 |
+
self.numargs = numargs_v
|
354 |
+
self.begin = begin_v
|
355 |
+
self.end = end_v
|
356 |
+
|
357 |
+
def show(self):
|
358 |
+
out = "\\newenvironment{%s}" % self.name
|
359 |
+
if 0 < self.numargs:
|
360 |
+
out += "[%d]" % self.numargs
|
361 |
+
out += "{%s}" % detokenize(self.begin)
|
362 |
+
out += "{%s}" % detokenize(self.end)
|
363 |
+
return out
|
364 |
+
|
365 |
+
|
366 |
+
class Command_instance:
|
367 |
+
name = "1"
|
368 |
+
args = []
|
369 |
+
|
370 |
+
def __init__(self, name_v, args_v):
|
371 |
+
self.name = name_v
|
372 |
+
self.args = args_v
|
373 |
+
|
374 |
+
def show(self):
|
375 |
+
out = "\\"+self.name
|
376 |
+
for arg in self.args:
|
377 |
+
out += "{%s}" % detokenize(arg)
|
378 |
+
return out
|
379 |
+
|
380 |
+
|
381 |
+
class Env_instance:
|
382 |
+
name = "1"
|
383 |
+
args = []
|
384 |
+
|
385 |
+
def __init__(self, name_v, args_v, body_v):
|
386 |
+
self.name = name_v
|
387 |
+
self.args = args_v
|
388 |
+
self.body = body_v
|
389 |
+
|
390 |
+
def show(self):
|
391 |
+
out = "\\begin{%s}" % self.name
|
392 |
+
for arg in self.args:
|
393 |
+
out += "{%s}" % detokenize(arg)
|
394 |
+
out += detokenize(self.body)
|
395 |
+
out += "\\end{%s}" % self.name
|
396 |
+
return out
|
397 |
+
|
398 |
+
class Char_stream(Stream):
|
399 |
+
|
400 |
+
def scan_escape_token(self, isatletter=False):
|
401 |
+
"""
|
402 |
+
Starts after the escape sign, assumes that it is scanning a symbol.
|
403 |
+
Returns a token-string.
|
404 |
+
"""
|
405 |
+
out = self.item # Continue only if this is a letter.
|
406 |
+
item = self.next()
|
407 |
+
if isletter(out, isatletter):
|
408 |
+
while self.uplegal() and isletter(item, isatletter):
|
409 |
+
out += item
|
410 |
+
item = self.next()
|
411 |
+
return out
|
412 |
+
|
413 |
+
def scan_comment_token(self):
|
414 |
+
"""
|
415 |
+
Starts at the comment sign %, assumes that it is scanning a comment.
|
416 |
+
Returns the whole comment string,
|
417 |
+
including the % and all empty space after it.
|
418 |
+
"""
|
419 |
+
comment = ""
|
420 |
+
while self.uplegal() and "\n" != self.item:
|
421 |
+
comment += self.item
|
422 |
+
self.next()
|
423 |
+
while self.uplegal() and blank_re.match(self.item):
|
424 |
+
comment += self.item
|
425 |
+
self.next()
|
426 |
+
return comment
|
427 |
+
|
428 |
+
def scan_input_filename(self):
|
429 |
+
"""We have just read an \input token. The next group or word will be
|
430 |
+
interpreted as a filename (possibly without .tex). Filenames should not begin with spaces.
|
431 |
+
Return the filename.
|
432 |
+
"""
|
433 |
+
item = self.item
|
434 |
+
file = ""
|
435 |
+
while self.uplegal() and blank_re.match(self.item):
|
436 |
+
item = self.next()
|
437 |
+
if "{" == item:
|
438 |
+
item = self.next()
|
439 |
+
while self.uplegal() and not "}" == item:
|
440 |
+
file += item
|
441 |
+
item = self.next()
|
442 |
+
self.next()
|
443 |
+
else:
|
444 |
+
while self.uplegal() and not blank_re.match(item):
|
445 |
+
file += item
|
446 |
+
item = self.next()
|
447 |
+
return file
|
448 |
+
|
449 |
+
def scan_package_filenames(self):
|
450 |
+
r"""We just read a \usepackage token. The next group will be
|
451 |
+
interpreted as a list of filenames (without .sty) separated by commas.
|
452 |
+
Return the list.
|
453 |
+
"""
|
454 |
+
item = self.item
|
455 |
+
while self.uplegal() and blank_re.match(item):
|
456 |
+
item = self.next()
|
457 |
+
file = ""
|
458 |
+
if not "{" == item:
|
459 |
+
raise Error("\\usepackage not followed by brace.")
|
460 |
+
item = self.next()
|
461 |
+
while self.uplegal() and not blank_or_rbrace_re.match(item):
|
462 |
+
file += item
|
463 |
+
item = self.next()
|
464 |
+
self.next()
|
465 |
+
return file.split(",")
|
466 |
+
|
467 |
+
|
468 |
+
class Tex_stream(Stream):
|
469 |
+
|
470 |
+
defs = ({}, {})
|
471 |
+
defs_db = "x"
|
472 |
+
defs_db_file = "x.db"
|
473 |
+
debug = False
|
474 |
+
|
475 |
+
def smart_tokenize(self, in_str, handle_inputs=False, isatletter=False):
|
476 |
+
"""Returns a list of tokens.
|
477 |
+
It may interpret and carry out all \input commands.
|
478 |
+
"""
|
479 |
+
self.data = []
|
480 |
+
text = self.data
|
481 |
+
cs = Char_stream(in_str)
|
482 |
+
cs.reset()
|
483 |
+
if not cs.legal():
|
484 |
+
raise Error("No string to tokenize.")
|
485 |
+
while cs.uplegal():
|
486 |
+
if "%" == cs.item:
|
487 |
+
comment = cs.scan_comment_token()
|
488 |
+
text.append(Token(comment_ty, comment))
|
489 |
+
elif "\\" != cs.item:
|
490 |
+
text.append(Token(simple_ty, cs.item))
|
491 |
+
cs.next()
|
492 |
+
else:
|
493 |
+
cs.next()
|
494 |
+
name = cs.scan_escape_token(isatletter)
|
495 |
+
if "input" == name and handle_inputs:
|
496 |
+
file = cs.scan_input_filename()
|
497 |
+
to_add = self.process_if_newer(file)
|
498 |
+
text.extend(to_add)
|
499 |
+
elif "usepackage" == name:
|
500 |
+
while cs.uplegal() and blank_re.match(cs.item):
|
501 |
+
cs.next()
|
502 |
+
if "[" == cs.item: # Packages with options will not be processed.
|
503 |
+
text.extend([Token(esc_str_ty, "usepackage"),
|
504 |
+
Token(simple_ty, "[")])
|
505 |
+
cs.next()
|
506 |
+
continue
|
507 |
+
files = cs.scan_package_filenames()
|
508 |
+
i = 0
|
509 |
+
while i < len(files): # process private packages
|
510 |
+
file = files[i]
|
511 |
+
p = file.rfind("-private")
|
512 |
+
if p < 0 or not len(file) - len("-private") == p:
|
513 |
+
i += 1
|
514 |
+
continue
|
515 |
+
defs_db_file = file+".db"
|
516 |
+
self.add_defs(file)
|
517 |
+
del files[i:(i+1)]
|
518 |
+
if files: # non-private packages left
|
519 |
+
group_content = ",".join(files)
|
520 |
+
to_add_str = "\\usepackage{%s}" % (group_content)
|
521 |
+
to_add = tokenize(to_add_str,isatletter)
|
522 |
+
text.extend(to_add)
|
523 |
+
else:
|
524 |
+
if isletter(name[0], isatletter):
|
525 |
+
token = Token(esc_str_ty, name)
|
526 |
+
else:
|
527 |
+
token = Token(esc_symb_ty, name)
|
528 |
+
text.append(token)
|
529 |
+
if "makeatletter" == name:
|
530 |
+
isatletter=True
|
531 |
+
elif "makeatother" == name:
|
532 |
+
isatletter=False
|
533 |
+
self.reset()
|
534 |
+
return self.data
|
535 |
+
|
536 |
+
def smart_detokenize(self,isatletter=False):
|
537 |
+
"""
|
538 |
+
Output is a string.
|
539 |
+
If the list contains an \input{file} then the content of file
|
540 |
+
file-clean.tex replaces it in the output.
|
541 |
+
"""
|
542 |
+
self.reset()
|
543 |
+
if not self.legal():
|
544 |
+
return ""
|
545 |
+
out = ""
|
546 |
+
previtem = None
|
547 |
+
while self.uplegal():
|
548 |
+
item = self.item
|
549 |
+
"""Insert a separating space after an escape sequence if it is a
|
550 |
+
string and is followed by a letter."""
|
551 |
+
if (None != previtem and esc_str_ty == previtem.type
|
552 |
+
and simple_ty == item.type and isletter(item.val[0], isatletter)):
|
553 |
+
out += " "
|
554 |
+
previtem = item
|
555 |
+
if not (esc_str_ty == item.type and "input" == item.val):
|
556 |
+
out += item.show()
|
557 |
+
self.next()
|
558 |
+
else:
|
559 |
+
self.next()
|
560 |
+
group = self.scan_group()
|
561 |
+
file = detokenize(group.val)
|
562 |
+
clean_file = "%s-clean.tex" % (file)
|
563 |
+
print("Reading file %s" % (clean_file))
|
564 |
+
fp = open(clean_file,"r")
|
565 |
+
content = fp.read()
|
566 |
+
fp.close()
|
567 |
+
out += content
|
568 |
+
return out
|
569 |
+
|
570 |
+
# Basic tex scanning
|
571 |
+
|
572 |
+
def skip_blank_tokens(self): # we also skip comment tokens.
|
573 |
+
item = self.item
|
574 |
+
while (self.uplegal() and
|
575 |
+
(comment_ty == item.type or
|
576 |
+
(simple_ty == item.type and blank_re.match(item.val)))):
|
577 |
+
item = self.next()
|
578 |
+
return item
|
579 |
+
|
580 |
+
def scan_group(self):
|
581 |
+
"""Returns group.
|
582 |
+
"""
|
583 |
+
if not self.legal():
|
584 |
+
raise Error("No group to scan.")
|
585 |
+
item = self.item
|
586 |
+
if not (simple_ty == item.type and "{" == item.val):
|
587 |
+
return Group(token_ty, [self.item])
|
588 |
+
count = 1
|
589 |
+
group = []
|
590 |
+
item = self.next()
|
591 |
+
while count and self.uplegal():
|
592 |
+
if simple_ty == item.type:
|
593 |
+
if "{" == item.val:
|
594 |
+
count += 1
|
595 |
+
elif "}" == item.val:
|
596 |
+
count -= 1
|
597 |
+
if count != 0:
|
598 |
+
group.append(item)
|
599 |
+
item = self.next()
|
600 |
+
return Group(group_ty, group)
|
601 |
+
|
602 |
+
# Command and environment definitions
|
603 |
+
|
604 |
+
def scan_command_name(self):
|
605 |
+
"""Returns name.
|
606 |
+
"""
|
607 |
+
if not self.legal():
|
608 |
+
raise Error("No command name to scan.")
|
609 |
+
item = self.item
|
610 |
+
name = ""
|
611 |
+
if item.type in [esc_symb_ty, esc_str_ty]:
|
612 |
+
name = item.val
|
613 |
+
else:
|
614 |
+
if not "{" == item.val:
|
615 |
+
raise Error("Command definition misses first {.")
|
616 |
+
self.next()
|
617 |
+
item = self.skip_blank_tokens()
|
618 |
+
if not item.type in [esc_symb_ty, esc_str_ty]:
|
619 |
+
raise Error("Command definition does not begin with control sequence.")
|
620 |
+
name = item.val
|
621 |
+
self.next()
|
622 |
+
item = self.skip_blank_tokens()
|
623 |
+
if not "}" == item.val:
|
624 |
+
raise Error("Definition for commmand %s misses first }., %s" %
|
625 |
+
(name, item.val))
|
626 |
+
self.next()
|
627 |
+
self.skip_blank_tokens()
|
628 |
+
return name
|
629 |
+
|
630 |
+
def scan_numargs(self, name):
|
631 |
+
"""
|
632 |
+
name is the name of the command or environment definition being
|
633 |
+
scanned.
|
634 |
+
Starts on a nonblank token.
|
635 |
+
Returns numargs
|
636 |
+
where numargs is the number of arguments in a command or environment
|
637 |
+
definition,
|
638 |
+
"""
|
639 |
+
if not self.legal():
|
640 |
+
raise Error("No numargs to scan.")
|
641 |
+
item = self.item
|
642 |
+
numargs = 0
|
643 |
+
if not simple_ty == item.type:
|
644 |
+
raise Error("Illegal command or environment definition: "+name)
|
645 |
+
if "[" == item.val:
|
646 |
+
if not 4 < len(self.data):
|
647 |
+
raise Error("Command or environment definition is illegal: "+name)
|
648 |
+
item = self.next()
|
649 |
+
if not simple_ty == item.type:
|
650 |
+
raise Error("Illegal command or environment definition: "+name)
|
651 |
+
numargs = item.val
|
652 |
+
if not pos_digit_re.match(numargs):
|
653 |
+
raise Error("%s must be argument number after %s" % (numargs, name))
|
654 |
+
numargs = int(numargs)
|
655 |
+
self.next()
|
656 |
+
item = self.skip_blank_tokens()
|
657 |
+
if not simple_ty == item.type:
|
658 |
+
raise Error("Illegal command definition: "+name)
|
659 |
+
if "]" != item.val:
|
660 |
+
raise Error("Illegal command definition: "+name)
|
661 |
+
self.next()
|
662 |
+
self.skip_blank_tokens()
|
663 |
+
return numargs
|
664 |
+
|
665 |
+
def scan_command_def(self):
|
666 |
+
"""Scan a command definition.
|
667 |
+
Return command_def.
|
668 |
+
Assumes that the number of arguments is at most 9.
|
669 |
+
"""
|
670 |
+
if not self.legal():
|
671 |
+
raise Error("No command definition to scan.")
|
672 |
+
item = self.item
|
673 |
+
if not 2 < len(self.data):
|
674 |
+
raise Error("Command definition is illegal.")
|
675 |
+
# newcommand or renewcommand
|
676 |
+
if not item.type in [esc_symb_ty, esc_str_ty]:
|
677 |
+
raise Error("Command definition should begin with control sequence: "+item.val)
|
678 |
+
if item.val not in ["newcommand", "renewcommand"]:
|
679 |
+
raise Error("Command definition should begin with control sequence.")
|
680 |
+
self.next()
|
681 |
+
self.skip_blank_tokens()
|
682 |
+
|
683 |
+
cmd_name = self.scan_command_name()
|
684 |
+
numargs = self.scan_numargs(cmd_name)
|
685 |
+
|
686 |
+
body_group = self.scan_group()
|
687 |
+
if group_ty != body_group.type:
|
688 |
+
raise Error("Command body missing: "+cmd_name)
|
689 |
+
body_val = strip_comments(body_group.val)
|
690 |
+
return Command_def(cmd_name, numargs, body_val)
|
691 |
+
|
692 |
+
def scan_env_name(self):
|
693 |
+
"""Starts on a {.
|
694 |
+
Returns name.
|
695 |
+
"""
|
696 |
+
if not self.legal():
|
697 |
+
raise Error("No environment name to scan.")
|
698 |
+
item = self.item
|
699 |
+
if not "{" == item.val:
|
700 |
+
raise Error("Env. definition begins with %s, not with {" % (item.val))
|
701 |
+
self.next()
|
702 |
+
item = self.skip_blank_tokens()
|
703 |
+
name = ""
|
704 |
+
if not simple_ty == item.type:
|
705 |
+
raise Error("1. Env. def. begins with cont. seq. %s, not with env.name."
|
706 |
+
% (item.val))
|
707 |
+
while self.uplegal() and not blank_or_rbrace_re.match(item.val):
|
708 |
+
name += item.val
|
709 |
+
item = self.next()
|
710 |
+
if not simple_ty == item.type:
|
711 |
+
raise Error("2. Env. def. begins with cont. seq. %s, not with env.name."
|
712 |
+
% (item.val))
|
713 |
+
item = self.skip_blank_tokens()
|
714 |
+
if not "}" == item.val:
|
715 |
+
raise Error("Command definition does not begin with control sequence.")
|
716 |
+
self.next()
|
717 |
+
self.skip_blank_tokens()
|
718 |
+
return name
|
719 |
+
|
720 |
+
def scan_env_def(self):
|
721 |
+
"""Scan an environment definition.
|
722 |
+
Return env_def
|
723 |
+
Assumes that the number of arguments is at most 9.
|
724 |
+
"""
|
725 |
+
if not self.legal():
|
726 |
+
raise Error("No environment definition to scan.")
|
727 |
+
item = self.item
|
728 |
+
if not 7 < len(self.data):
|
729 |
+
raise Error("Environment definition is illegal.")
|
730 |
+
pos = 0
|
731 |
+
|
732 |
+
if not item.type in [esc_symb_ty, esc_str_ty]:
|
733 |
+
raise Error("Env. definition does not begin with control sequence:"+
|
734 |
+
item.val)
|
735 |
+
if item.val not in ["newenvironment", "renewenvironment"]:
|
736 |
+
raise Error("Env. definition does not begin with control sequence.")
|
737 |
+
self.next()
|
738 |
+
self.skip_blank_tokens()
|
739 |
+
|
740 |
+
env_name = self.scan_env_name()
|
741 |
+
numargs = self.scan_numargs(env_name)
|
742 |
+
self.skip_blank_tokens()
|
743 |
+
|
744 |
+
begin_group = self.scan_group()
|
745 |
+
if group_ty != begin_group.type:
|
746 |
+
raise Error("Begin body missing: "+env_name)
|
747 |
+
begin_val = strip_comments(begin_group.val)
|
748 |
+
|
749 |
+
self.skip_blank_tokens()
|
750 |
+
|
751 |
+
end_group = self.scan_group()
|
752 |
+
if group_ty != end_group.type:
|
753 |
+
raise Error("End body missing:"+env_name)
|
754 |
+
end_val = strip_comments(end_group.val)
|
755 |
+
|
756 |
+
return Env_def(env_name, numargs, begin_val, end_val)
|
757 |
+
|
758 |
+
def scan_defs(self):
|
759 |
+
if not self.legal():
|
760 |
+
raise Error("No definitions to scan.")
|
761 |
+
self.reset()
|
762 |
+
command_defs, env_defs = self.defs
|
763 |
+
while self.uplegal():
|
764 |
+
if (esc_str_ty == self.item.type
|
765 |
+
and self.item.val in ["newcommand", "renewcommand"]):
|
766 |
+
def_start_pos = self.pos
|
767 |
+
command_def = self.scan_command_def()
|
768 |
+
command_defs[command_def.name] = command_def
|
769 |
+
def_end_pos = self.pos
|
770 |
+
for del_pos in range(def_start_pos,def_end_pos):
|
771 |
+
del self.data[def_start_pos]
|
772 |
+
self.pos = def_start_pos
|
773 |
+
self.item = self.data[self.pos]
|
774 |
+
elif (esc_str_ty == self.item.type and self.item.val
|
775 |
+
in ["newenvironment", "renewenvironment"]):
|
776 |
+
def_start_pos = self.pos
|
777 |
+
env_def = self.scan_env_def()
|
778 |
+
env_defs[env_def.name] = env_def
|
779 |
+
def_end_pos = self.pos
|
780 |
+
for del_pos in range(def_start_pos,def_end_pos):
|
781 |
+
del self.data[def_start_pos]
|
782 |
+
self.pos = def_start_pos
|
783 |
+
self.item = self.data[self.pos]
|
784 |
+
else:
|
785 |
+
self.next()
|
786 |
+
|
787 |
+
# Instances
|
788 |
+
|
789 |
+
def scan_args(self, command_or_env_def):
|
790 |
+
"""Scan the arguments of a command or environment.
|
791 |
+
Return [args].
|
792 |
+
"""
|
793 |
+
if not self.legal():
|
794 |
+
raise Error("No arguments to scan.")
|
795 |
+
numargs = command_or_env_def.numargs
|
796 |
+
name = command_or_env_def.name
|
797 |
+
|
798 |
+
args = []
|
799 |
+
for i in range(numargs):
|
800 |
+
arg = []
|
801 |
+
if not (simple_ty == self.item.type and "{" == self.item.val):
|
802 |
+
arg = [self.item]
|
803 |
+
self.next()
|
804 |
+
else:
|
805 |
+
group = self.scan_group()
|
806 |
+
arg = group.val
|
807 |
+
args.append(arg)
|
808 |
+
return args
|
809 |
+
|
810 |
+
def scan_command(self, command_def):
|
811 |
+
"""Scan the arguments of a command.
|
812 |
+
Return command_instance
|
813 |
+
"""
|
814 |
+
if not self.legal():
|
815 |
+
raise Error("No command to scan.")
|
816 |
+
if not self.item.type in [esc_symb_ty, esc_str_ty]:
|
817 |
+
raise Error("Command does not begin with control sequence.")
|
818 |
+
name = self.item.val
|
819 |
+
self.next()
|
820 |
+
if 0 < command_def.numargs:
|
821 |
+
self.skip_blank_tokens()
|
822 |
+
args = self.scan_args(command_def)
|
823 |
+
else:
|
824 |
+
args = []
|
825 |
+
return Command_instance(name, args)
|
826 |
+
|
827 |
+
def test_env_boundary(self, item):
|
828 |
+
"""Check whether an environment begin or end follows.
|
829 |
+
Return 1 if \begin, -1 if \end, 0 otherwise.
|
830 |
+
"""
|
831 |
+
d = 0
|
832 |
+
if esc_str_ty == item.type:
|
833 |
+
if "begin"==item.val:
|
834 |
+
d = 1
|
835 |
+
elif "end"==item.val:
|
836 |
+
d = -1
|
837 |
+
return d
|
838 |
+
|
839 |
+
def scan_env_begin(self):
|
840 |
+
"""Scan an environment name.
|
841 |
+
Return env_name.
|
842 |
+
"""
|
843 |
+
if not self.legal():
|
844 |
+
raise Error("No environment begin to scan.")
|
845 |
+
item = self.item
|
846 |
+
if not (esc_str_ty == item.type and "begin" == item.val):
|
847 |
+
raise Error("Environment does not begin with begin.")
|
848 |
+
self.next()
|
849 |
+
name_group = self.scan_group()
|
850 |
+
name = detokenize(name_group.val)
|
851 |
+
return name
|
852 |
+
|
853 |
+
def scan_env_end(self):
|
854 |
+
"""Scan an environment end.
|
855 |
+
Return env_name.
|
856 |
+
"""
|
857 |
+
if not self.legal():
|
858 |
+
raise Error("No environment end to scan.")
|
859 |
+
item = self.item
|
860 |
+
if not (esc_str_ty == item.type and "end" == item.val):
|
861 |
+
raise Error("Environment does not end with end.")
|
862 |
+
self.next()
|
863 |
+
name_group = self.scan_group()
|
864 |
+
name = detokenize(name_group.val)
|
865 |
+
return name
|
866 |
+
|
867 |
+
def scan_env_rest(self, env_def):
|
868 |
+
"""Scanning starts after \begin{envname}.
|
869 |
+
Returns env_instance.
|
870 |
+
"""
|
871 |
+
if not self.legal():
|
872 |
+
raise Error("No environment rest to scan.")
|
873 |
+
count = 1 # We are already within a boundary.
|
874 |
+
args = self.scan_args(env_def)
|
875 |
+
body = []
|
876 |
+
while count and self.uplegal():
|
877 |
+
old_pos = self.pos
|
878 |
+
d = self.test_env_boundary(self.item)
|
879 |
+
count += d
|
880 |
+
if 1 == d:
|
881 |
+
self.scan_env_begin()
|
882 |
+
elif -1 == d:
|
883 |
+
self.scan_env_end()
|
884 |
+
else:
|
885 |
+
self.next()
|
886 |
+
if 0 < count:
|
887 |
+
body.extend(self.data[old_pos : self.pos])
|
888 |
+
return Env_instance(env_def.name, args, body)
|
889 |
+
|
890 |
+
# Definitions
|
891 |
+
|
892 |
+
def restore_defs(self):
|
893 |
+
if os.path.isfile(self.defs_db_file):
|
894 |
+
print("Using defs db %s" % (self.defs_db_file))
|
895 |
+
db_h = shelve.open(self.defs_db)
|
896 |
+
self.defs = db_h["defs"]
|
897 |
+
db_h.close()
|
898 |
+
|
899 |
+
def save_defs(self):
|
900 |
+
db_h = shelve.open(self.defs_db)
|
901 |
+
if "defs" in db_h:
|
902 |
+
del db_h["defs"]
|
903 |
+
db_h["defs"] = self.defs
|
904 |
+
db_h.close()
|
905 |
+
|
906 |
+
def add_defs(self, defs_file):
|
907 |
+
defs_file_compl = defs_file + ".sty"
|
908 |
+
if not os.path.isfile(defs_file_compl):
|
909 |
+
raise Error("%s does not exist" % (defs_file_compl))
|
910 |
+
|
911 |
+
defs_db_file = self.defs_db_file
|
912 |
+
if newer(defs_db_file, defs_file_compl):
|
913 |
+
print("Using defs db %s for %s" % (defs_db_file, defs_file))
|
914 |
+
else:
|
915 |
+
defs_fp = open(defs_file_compl, "r")
|
916 |
+
defs_str = defs_fp.read()
|
917 |
+
defs_fp.close()
|
918 |
+
ds = Tex_stream()
|
919 |
+
ds.defs = self.defs
|
920 |
+
defs_text = ds.smart_tokenize(defs_str,isatletter=True)
|
921 |
+
# changing ds.defs will change self.defs
|
922 |
+
if self.debug:
|
923 |
+
defs_seen_file = "%s-seen.sty" % (defs_file)
|
924 |
+
defs_seen_fp = open(defs_seen_file, "w")
|
925 |
+
out = detokenize(defs_text,isatletter=True)
|
926 |
+
defs_seen_fp.write(out)
|
927 |
+
defs_seen_fp.close()
|
928 |
+
ds.scan_defs()
|
929 |
+
if self.debug:
|
930 |
+
out = ""
|
931 |
+
command_defs, env_defs = self.defs
|
932 |
+
for def_name in command_defs.keys():
|
933 |
+
out += command_defs[def_name].show() + "\n"
|
934 |
+
for def_name in env_defs.keys():
|
935 |
+
out += env_defs[def_name].show() +"\n"
|
936 |
+
print("Definitions after reading %s:" % (defs_file))
|
937 |
+
print(out)
|
938 |
+
|
939 |
+
# Applying definitions, recursively
|
940 |
+
# (maybe not quite in Knuth order, so avoid tricks!)
|
941 |
+
|
942 |
+
def subst_args(self, body, args):
|
943 |
+
out = []
|
944 |
+
pos = 0
|
945 |
+
while pos < len(body):
|
946 |
+
item = body[pos]
|
947 |
+
if not (simple_ty == item.type and "#" == item.val):
|
948 |
+
out.append(item)
|
949 |
+
pos += 1
|
950 |
+
continue
|
951 |
+
pos += 1
|
952 |
+
token = body[pos]
|
953 |
+
argnum = token.val
|
954 |
+
if not pos_digit_re.match(argnum):
|
955 |
+
raise Error("# is not followed by number.")
|
956 |
+
argnum = int(argnum)
|
957 |
+
if argnum > len(args):
|
958 |
+
raise Error("Too large argument number.")
|
959 |
+
arg = args[argnum-1]
|
960 |
+
out += arg
|
961 |
+
pos += 1
|
962 |
+
return out
|
963 |
+
|
964 |
+
def apply_command_recur(self, command_instance):
|
965 |
+
command_defs, env_defs = self.defs
|
966 |
+
name = command_instance.name
|
967 |
+
command_def = command_defs[name]
|
968 |
+
|
969 |
+
args = command_instance.args
|
970 |
+
body = command_def.body
|
971 |
+
result = self.subst_args(body, args)
|
972 |
+
try:
|
973 |
+
result = self.apply_all_recur(result)
|
974 |
+
except Empty_text_error as e:
|
975 |
+
raise Error("apply_all_recur fails on command instance %s: %s, %s" % \
|
976 |
+
(command_instance.show(), detokenize(e.data), e.message))
|
977 |
+
return result
|
978 |
+
|
979 |
+
def apply_env_recur(self, env_instance):
|
980 |
+
command_defs, env_defs = self.defs
|
981 |
+
name = env_instance.name
|
982 |
+
env_def = env_defs[name]
|
983 |
+
|
984 |
+
begin, end = env_def.begin, env_def.end
|
985 |
+
body, args = env_instance.body, env_instance.args
|
986 |
+
out = self.subst_args(begin, args) + body + self.subst_args(end, args)
|
987 |
+
return self.apply_all_recur(out)
|
988 |
+
|
989 |
+
|
990 |
+
def apply_all_recur(self, data, report=False):
|
991 |
+
ts = Tex_stream(data)
|
992 |
+
ts.defs = self.defs
|
993 |
+
command_defs, env_defs = self.defs
|
994 |
+
out = []
|
995 |
+
progress_step = 10000
|
996 |
+
progress = progress_step
|
997 |
+
if not ts.legal():
|
998 |
+
raise Empty_text_error(data, "No text to process.")
|
999 |
+
while ts.uplegal():
|
1000 |
+
if self.pos > progress:
|
1001 |
+
if report:
|
1002 |
+
print(self.pos)
|
1003 |
+
progress += progress_step
|
1004 |
+
if not ts.item.type in [esc_symb_ty, esc_str_ty]:
|
1005 |
+
out.append(ts.item)
|
1006 |
+
ts.next()
|
1007 |
+
continue
|
1008 |
+
if 1 == ts.test_env_boundary(ts.item):
|
1009 |
+
old_pos = ts.pos
|
1010 |
+
env_name = ts.scan_env_begin()
|
1011 |
+
if env_name not in env_defs:
|
1012 |
+
out.extend(ts.data[old_pos : ts.pos])
|
1013 |
+
continue
|
1014 |
+
else:
|
1015 |
+
env_def = env_defs[env_name]
|
1016 |
+
env_instance = ts.scan_env_rest(env_def)
|
1017 |
+
result = ts.apply_env_recur(env_instance)
|
1018 |
+
out.extend(result)
|
1019 |
+
elif ts.item.val not in command_defs:
|
1020 |
+
out.append(ts.item)
|
1021 |
+
ts.next()
|
1022 |
+
continue
|
1023 |
+
else:
|
1024 |
+
command_def = command_defs[ts.item.val]
|
1025 |
+
command_inst = ts.scan_command(command_def)
|
1026 |
+
result = ts.apply_command_recur(command_inst)
|
1027 |
+
out.extend(result)
|
1028 |
+
return out
|
1029 |
+
|
1030 |
+
|
1031 |
+
# Processing files
|
1032 |
+
|
1033 |
+
def process_file(self, file):
|
1034 |
+
"""Returns the new defs.
|
1035 |
+
"""
|
1036 |
+
file = cut_extension(file, ".tex")
|
1037 |
+
source_file = "%s.tex" % (file)
|
1038 |
+
print("File %s [" % (source_file))
|
1039 |
+
source_fp = open(source_file, "r")
|
1040 |
+
text_str = source_fp.read()
|
1041 |
+
source_fp.close()
|
1042 |
+
|
1043 |
+
self.smart_tokenize(text_str, handle_inputs=True)
|
1044 |
+
if not self.data:
|
1045 |
+
raise Error("Empty tokenization result.")
|
1046 |
+
self.reset()
|
1047 |
+
|
1048 |
+
if self.debug:
|
1049 |
+
source_seen_fname = "%s-seen.tex" % (file)
|
1050 |
+
source_seen_fp = open(source_seen_fname, "w")
|
1051 |
+
source_seen_fp.write(detokenize(self.data))
|
1052 |
+
source_seen_fp.close()
|
1053 |
+
self.scan_defs()
|
1054 |
+
self.data = self.apply_all_recur(self.data, report=True)
|
1055 |
+
result_fname = "%s-clean.tex" % (file)
|
1056 |
+
print("Writing %s [" % (result_fname))
|
1057 |
+
result_fp = open(result_fname, "w")
|
1058 |
+
result_fp.write(self.smart_detokenize())
|
1059 |
+
result_fp.close()
|
1060 |
+
print("] file %s" % (result_fname))
|
1061 |
+
print("] file %s" % (source_file))
|
1062 |
+
|
1063 |
+
def process_if_newer(self, file):
|
1064 |
+
"""
|
1065 |
+
\input{file} is added to the token list.
|
1066 |
+
If the input file is newer it is processed.
|
1067 |
+
Returns tokenized \input{file}.
|
1068 |
+
"""
|
1069 |
+
file = cut_extension(file, ".tex")
|
1070 |
+
tex_file = file+".tex"
|
1071 |
+
clean_tex_file = file+"-clean.tex"
|
1072 |
+
if newer(clean_tex_file, tex_file):
|
1073 |
+
print("Using %s." % (clean_tex_file))
|
1074 |
+
else:
|
1075 |
+
ts = Tex_stream()
|
1076 |
+
ts.data = []
|
1077 |
+
ts.defs = self.defs
|
1078 |
+
ts.process_file(file)
|
1079 |
+
to_add = "\\input{%s}" % (file)
|
1080 |
+
return tokenize(to_add)
|
1081 |
+
|
1082 |
+
# Main
|
1083 |
+
|
1084 |
+
long_optlist = ["debug","defs="]
|
1085 |
+
options, restargs = getopt_map("x", long_optlist)
|
1086 |
+
|
1087 |
+
debug = False
|
1088 |
+
if "--debug" in options:
|
1089 |
+
debug = True
|
1090 |
+
|
1091 |
+
root = restargs[0]
|
1092 |
+
root = cut_extension(root, ".tex")
|
1093 |
+
if "--defs" in options:
|
1094 |
+
defs_root = options["--defs"]
|
1095 |
+
else:
|
1096 |
+
defs_root = "%s" % (root)
|
1097 |
+
defs_db = defs_root
|
1098 |
+
defs_db_file = defs_root+".db"
|
1099 |
+
|
1100 |
+
ts = Tex_stream()
|
1101 |
+
ts.defs_db = defs_db
|
1102 |
+
ts.defs_db_file = defs_db_file
|
1103 |
+
ts.debug = debug
|
1104 |
+
|
1105 |
+
ts.restore_defs()
|
1106 |
+
for root in restargs:
|
1107 |
+
ts.process_file(root)
|
1108 |
+
|
1109 |
+
print("(Re)creating defs db %s" % (defs_db))
|
1110 |
+
ts.save_defs()
|
utils/def_handle.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import re
|
3 |
+
|
4 |
+
|
5 |
+
def main():
|
6 |
+
args = parse_command_line()
|
7 |
+
data = read(args.input)
|
8 |
+
data = convert(data)
|
9 |
+
write(args.output, data)
|
10 |
+
|
11 |
+
|
12 |
+
def parse_command_line():
|
13 |
+
parser = argparse.ArgumentParser(
|
14 |
+
description='Replace \\def with \\newcommand where possible.',
|
15 |
+
)
|
16 |
+
parser.add_argument(
|
17 |
+
'input',
|
18 |
+
help='TeX input file with \\def',
|
19 |
+
)
|
20 |
+
parser.add_argument(
|
21 |
+
'--output',
|
22 |
+
'-o',
|
23 |
+
required=True,
|
24 |
+
help='TeX output file with \\newcommand',
|
25 |
+
)
|
26 |
+
|
27 |
+
return parser.parse_args()
|
28 |
+
|
29 |
+
def read(path):
|
30 |
+
with open(path, mode='rb') as handle:
|
31 |
+
return handle.read()
|
32 |
+
|
33 |
+
|
34 |
+
def convert(data):
|
35 |
+
return re.sub(
|
36 |
+
rb'((?:\\(?:expandafter|global|long|outer|protected)'
|
37 |
+
rb'(?: +|\r?\n *)?)*)?'
|
38 |
+
rb'\\def *(\\[a-zA-Z]+) *(?:#+([0-9]))*\{',
|
39 |
+
replace,
|
40 |
+
data,
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
def replace(match):
|
45 |
+
prefix = match.group(1)
|
46 |
+
if (
|
47 |
+
prefix is not None and
|
48 |
+
(
|
49 |
+
b'expandafter' in prefix or
|
50 |
+
b'global' in prefix or
|
51 |
+
b'outer' in prefix or
|
52 |
+
b'protected' in prefix
|
53 |
+
)
|
54 |
+
):
|
55 |
+
pass #return match.group(0)
|
56 |
+
|
57 |
+
result = rb'\newcommand'
|
58 |
+
|
59 |
+
result += b'{' + match.group(2) + b'}'
|
60 |
+
if match.lastindex == 3:
|
61 |
+
result += b'[' + match.group(3) + b']'
|
62 |
+
|
63 |
+
result += b'{'
|
64 |
+
return result
|
65 |
+
|
66 |
+
|
67 |
+
def write(path, data):
|
68 |
+
with open(path, mode='wb') as handle:
|
69 |
+
handle.write(data)
|
70 |
+
|
71 |
+
print('=> File written: {0}'.format(path))
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == '__main__':
|
75 |
+
main()
|
utils/gradio_utils.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import StoppingCriteria
|
2 |
+
import sys
|
3 |
+
|
4 |
+
|
5 |
+
# Handle termination signal
|
6 |
+
def signal_handler(sig, frame):
|
7 |
+
print("\nTermination signal received. Shutting down Gradio interface.")
|
8 |
+
sys.exit(0)
|
9 |
+
|
10 |
+
# Custom stopping criteria
|
11 |
+
class StopOnTokens(StoppingCriteria):
|
12 |
+
def __call__(self, input_ids, scores, **kwargs):
|
13 |
+
stop_ids = [29, 0] # Define specific stop token IDs
|
14 |
+
return input_ids[0][-1] in stop_ids
|
15 |
+
|
16 |
+
# Toggle task selection
|
17 |
+
def toggle_selection(current_task, new_task):
|
18 |
+
"""Toggle task selection: deselect if clicked again, otherwise update selection."""
|
19 |
+
updated_task = "" if current_task == new_task else new_task
|
20 |
+
return updated_task
|
utils/graph_utils.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import regex
|
2 |
+
import re
|
3 |
+
|
4 |
+
def retrieve_text_cite(text, command):
|
5 |
+
base_pattern = (
|
6 |
+
r'\\' + command + r"(?:\[(?:.*?)\])*\{((?:[^{}]+|\{(?1)\})*)\}(?:\[(?:.*?)\])*"
|
7 |
+
)
|
8 |
+
|
9 |
+
def extract_text_inside_curly_braces(text):
|
10 |
+
pattern = r"\{((?:[^{}]|(?R))*)\}"
|
11 |
+
|
12 |
+
match = regex.search(pattern, text)
|
13 |
+
|
14 |
+
if match:
|
15 |
+
return match.group(1)
|
16 |
+
else:
|
17 |
+
return ""
|
18 |
+
|
19 |
+
found_texts = []
|
20 |
+
for match in regex.finditer(base_pattern, text):
|
21 |
+
temp_substring = text[match.span()[0] : match.span()[1]]
|
22 |
+
found_texts.append(extract_text_inside_curly_braces(temp_substring))
|
23 |
+
|
24 |
+
return found_texts
|
25 |
+
|
26 |
+
def get_citing_sentences(content):
|
27 |
+
content_new = re.sub(r'[\n]+', ' ', content) # keep only one \n
|
28 |
+
content_new = re.sub(r'e\.g\.' , 'eg', content_new)
|
29 |
+
content_new = re.sub(r'i\.e\.' , 'eg', content_new)
|
30 |
+
content_new = re.sub(r'etc\.' , 'etc', content_new)
|
31 |
+
content_new = re.sub(r' +', ' ', content_new)
|
32 |
+
sentences = [sentence + '.' for sentence in content_new.split('.')]
|
33 |
+
citing_sentences = [s for s in sentences if '\\cite' in s]
|
34 |
+
results = {}
|
35 |
+
for s in citing_sentences:
|
36 |
+
citations = retrieve_text_cite(s, 'cite')
|
37 |
+
final_citations = []
|
38 |
+
for cite in citations:
|
39 |
+
final_citations.extend(cite.split(','))
|
40 |
+
results[s] = final_citations
|
41 |
+
return results
|
42 |
+
|
43 |
+
def get_intro(content):
|
44 |
+
sections = retrieve_text_cite(content, 'section')
|
45 |
+
if sections == []:
|
46 |
+
return ''
|
47 |
+
try_intro = [x for x in sections if x.strip().lower() == 'introduction']
|
48 |
+
if try_intro == []:
|
49 |
+
return ''
|
50 |
+
else:
|
51 |
+
to_find = try_intro[0]
|
52 |
+
ind = sections.index(to_find)
|
53 |
+
if ind + 1 < len(sections):
|
54 |
+
start_marker = f'\\section{{{sections[ind]}}}'
|
55 |
+
end_marker = f'\\section{{{sections[ind+1]}}}'
|
56 |
+
start_point = content.find(start_marker)
|
57 |
+
end_point = content.find(end_marker)
|
58 |
+
return content[start_point+len(start_marker):end_point]
|
59 |
+
else:
|
60 |
+
return ''
|
61 |
+
|
62 |
+
def get_related_works(content):
|
63 |
+
sections = retrieve_text_cite(content, 'section')
|
64 |
+
if sections == []:
|
65 |
+
return ''
|
66 |
+
possible_related = [
|
67 |
+
"Literature Review",
|
68 |
+
"Related Work",
|
69 |
+
"Related Works",
|
70 |
+
"Prior Work",
|
71 |
+
"Prior Works",
|
72 |
+
"Related Research",
|
73 |
+
"Research Overview",
|
74 |
+
"Previous Work",
|
75 |
+
"Previous Works",
|
76 |
+
"Review of the Literature",
|
77 |
+
"Review of Related Literature",
|
78 |
+
"Survey of Related Work",
|
79 |
+
"Survey of Related Works",
|
80 |
+
"Background",
|
81 |
+
"Research Background",
|
82 |
+
"Review of Prior Research",
|
83 |
+
"Literature Survey",
|
84 |
+
"Overview of Literature",
|
85 |
+
"Existing Literature",
|
86 |
+
"Review of Existing Work",
|
87 |
+
"Review of Existing Works",
|
88 |
+
"Review of Previous Studies",
|
89 |
+
"Review of Prior Literature",
|
90 |
+
"Summary of Related Research",
|
91 |
+
"Survey of Existing Literature",
|
92 |
+
"Survey of Literature",
|
93 |
+
"Existing Research Overview",
|
94 |
+
"Prior Literature Review"
|
95 |
+
]
|
96 |
+
possible_sections = [x for x in sections if any([True for y in possible_related if y.lower() == x.strip().lower()])]
|
97 |
+
if possible_sections == []:
|
98 |
+
return ''
|
99 |
+
else:
|
100 |
+
to_find = possible_sections[0]
|
101 |
+
ind = sections.index(to_find)
|
102 |
+
|
103 |
+
if ind + 1 < len(sections):
|
104 |
+
start_marker = f'\\section{{{sections[ind]}}}'
|
105 |
+
end_marker = f'\\section{{{sections[ind+1]}}}'
|
106 |
+
start_point = content.find(start_marker)
|
107 |
+
end_point = content.find(end_marker)
|
108 |
+
return content[start_point+len(start_marker):end_point]
|
109 |
+
|
110 |
+
else:
|
111 |
+
return ''
|
utils/latexpand
ADDED
@@ -0,0 +1,713 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/perl
|
2 |
+
# Inspired by latexpand by D. Musliner, University of Michigan
|
3 |
+
# 2012-2023: Matthieu Moy <[email protected]>
|
4 |
+
# BSD License
|
5 |
+
|
6 |
+
use strict;
|
7 |
+
use Cwd;
|
8 |
+
use Getopt::Long;
|
9 |
+
use IO::Handle;
|
10 |
+
use File::Spec;
|
11 |
+
|
12 |
+
my $TEXINPUTS = $ENV{'TEXINPUTS'};
|
13 |
+
# By default, search in current directory. We use '.' and not getcwd()
|
14 |
+
# to avoid issues if the working directory contains a ':' character.
|
15 |
+
if (!$TEXINPUTS) { $TEXINPUTS = '.'; }
|
16 |
+
|
17 |
+
my $verbose;
|
18 |
+
my $keep_comments;
|
19 |
+
my $keep_includes;
|
20 |
+
my $empty_comments;
|
21 |
+
my $help;
|
22 |
+
my $long_help;
|
23 |
+
my %defines = ();
|
24 |
+
my $output;
|
25 |
+
my $explain;
|
26 |
+
my $show_graphics;
|
27 |
+
my $graphics_extensions = ":.pdf:.png:.jpg:.eps";
|
28 |
+
my $expand_usepackage;
|
29 |
+
my $expand_bbl;
|
30 |
+
my $biber;
|
31 |
+
my $fatal;
|
32 |
+
my $version;
|
33 |
+
my $makeatletter;
|
34 |
+
my $inside_import;
|
35 |
+
my $in_enc = "bytes";
|
36 |
+
my $out_enc = "bytes";
|
37 |
+
|
38 |
+
GetOptions (
|
39 |
+
'h' => \$help,
|
40 |
+
'help' => \$long_help,
|
41 |
+
'verbose|v' => \$verbose,
|
42 |
+
'keep-comments' => \$keep_comments,
|
43 |
+
'keep-includes' => \$keep_includes,
|
44 |
+
'empty-comments' => \$empty_comments,
|
45 |
+
'define|d=s%' => \%defines,
|
46 |
+
'output|o=s' => \$output,
|
47 |
+
'explain' => \$explain,
|
48 |
+
'show-graphics' => \$show_graphics,
|
49 |
+
'graphics-extensions' => \$graphics_extensions,
|
50 |
+
'expand-usepackage' => \$expand_usepackage,
|
51 |
+
'expand-bbl=s' => \$expand_bbl,
|
52 |
+
'biber=s' => \$biber,
|
53 |
+
'fatal' => \$fatal,
|
54 |
+
'version' => \$version,
|
55 |
+
'makeatletter' => \$makeatletter,
|
56 |
+
'in-encoding=s' => \$in_enc,
|
57 |
+
'out-encoding=s' => \$out_enc,
|
58 |
+
) or pod2usage_wrapper(2);
|
59 |
+
version() if $version;
|
60 |
+
pod2usage_wrapper(0) if $help;
|
61 |
+
pod2usage_wrapper(-exitstatus => 0, -output => \*STDOUT, -verbose => 2) if $long_help;
|
62 |
+
|
63 |
+
sub pod2usage_wrapper
|
64 |
+
{
|
65 |
+
# Like pod2usage, but fall back to a simpler implem in case
|
66 |
+
# pod2usage can't be found.
|
67 |
+
if (eval {require Pod::Usage;1;} ne 1) {
|
68 |
+
print "Please install perldoc and Pod::Usage to get proper help.\n";
|
69 |
+
my $started = 0;
|
70 |
+
open (my $in, '<', "$0") or die $!;
|
71 |
+
while (<$in>) {
|
72 |
+
if ($started) {
|
73 |
+
print;
|
74 |
+
}
|
75 |
+
if (/^__END__$/) {
|
76 |
+
$started = 1;
|
77 |
+
}
|
78 |
+
}
|
79 |
+
} else {
|
80 |
+
Pod::Usage->import();
|
81 |
+
pod2usage(@_);
|
82 |
+
}
|
83 |
+
}
|
84 |
+
|
85 |
+
sub get_version
|
86 |
+
{
|
87 |
+
# $VERSION's value will be substituted by 'make dist', but the
|
88 |
+
# next line won't (the string has to be broken to avoid it).
|
89 |
+
my $VERSION = 'v1.7.2';
|
90 |
+
if ($VERSION eq '@LATEXPAND' . '_VERSION@') {
|
91 |
+
my($vol,$dir,$file) = File::Spec->splitpath($0);
|
92 |
+
chdir($dir);
|
93 |
+
$VERSION = `git describe --tags HEAD 2>/dev/null`;
|
94 |
+
}
|
95 |
+
if ($VERSION eq '') {
|
96 |
+
$VERSION = '<unknown version>';
|
97 |
+
}
|
98 |
+
$VERSION =~ s/^\s+|\s+$//g;
|
99 |
+
return $VERSION;
|
100 |
+
}
|
101 |
+
|
102 |
+
sub version
|
103 |
+
{
|
104 |
+
print "latexpand version ". get_version() .".\n";
|
105 |
+
exit(0);
|
106 |
+
}
|
107 |
+
|
108 |
+
my $nl = "";
|
109 |
+
if ($empty_comments) {
|
110 |
+
$nl = "%\n";
|
111 |
+
}
|
112 |
+
|
113 |
+
if ($output && $output ne "-") {
|
114 |
+
open (my $OUTPUT, '>', "$output") or die $!;
|
115 |
+
STDOUT->fdopen(\*$OUTPUT, 'w') or die $!;
|
116 |
+
}
|
117 |
+
|
118 |
+
sub say
|
119 |
+
{
|
120 |
+
if ($verbose) {
|
121 |
+
print STDERR "$_[0]";
|
122 |
+
}
|
123 |
+
}
|
124 |
+
|
125 |
+
my $makeatletter_found;
|
126 |
+
my $in_preamble;
|
127 |
+
|
128 |
+
use open IN => ":$in_enc", OUT => ":$out_enc";
|
129 |
+
|
130 |
+
foreach my $file (@ARGV)
|
131 |
+
{
|
132 |
+
say "processing $file\n";
|
133 |
+
$makeatletter_found = 0;
|
134 |
+
$in_preamble = 1;
|
135 |
+
$inside_import = "";
|
136 |
+
if ($file =~ /\.bib$/) {
|
137 |
+
warn "WARNING: latexpand is not meant to be used on BibTeX files like '$file'.\n" .
|
138 |
+
" Run latexpand on your main .tex file, using '--expand-bbl FILE'\n" .
|
139 |
+
" or '--biber FILE' if needed to inline the generated bbl file.\n";
|
140 |
+
} elsif (not $file =~ /\.tex$/) {
|
141 |
+
warn "WARNING: latexpand is meant to be used on .tex files, which $file isn't.\n";
|
142 |
+
}
|
143 |
+
process_file($file, " ");
|
144 |
+
}
|
145 |
+
|
146 |
+
sub cat_file
|
147 |
+
{
|
148 |
+
my $file = shift;
|
149 |
+
open (my $INFILE, "<", $file) || die "could not open input file '$file'\n";
|
150 |
+
while (<$INFILE>) {
|
151 |
+
print;
|
152 |
+
}
|
153 |
+
close ($INFILE);
|
154 |
+
}
|
155 |
+
|
156 |
+
sub process_file
|
157 |
+
{
|
158 |
+
my $file = shift;
|
159 |
+
my $prefix = (shift || "");
|
160 |
+
my $in_comment = 0;
|
161 |
+
open(my $FILE, "<", $file) or die "could not open input file '$file'\n";
|
162 |
+
my $commented_newline = 0;
|
163 |
+
while (my $line = <$FILE>) {
|
164 |
+
if ($line =~ /^[ \t]*\\endinput/) {
|
165 |
+
# Surprisingly, text after \endinput on the
|
166 |
+
# same line is kept in output. Also, add a
|
167 |
+
# space (before %), automatically inserted by
|
168 |
+
# TeX at the end of file.
|
169 |
+
$line =~ s/\\endinput(.*)\n?/$1 % /;
|
170 |
+
$in_comment = 1;
|
171 |
+
process_line($line, $prefix, \$commented_newline);
|
172 |
+
last;
|
173 |
+
}
|
174 |
+
while (my ($k, $v) = each (%defines))
|
175 |
+
{
|
176 |
+
$line=~s!\\$k!$v!g;
|
177 |
+
}
|
178 |
+
process_line($line, $prefix, \$commented_newline, $file);
|
179 |
+
if ($line =~ /^%.*[^\n]\z/ || $line =~ /[^\\]%.*[^\n]\z/) {
|
180 |
+
# file ends with a comment not ending with a newline
|
181 |
+
print "\n";
|
182 |
+
}
|
183 |
+
# Garbage at end of line after \end{document} is
|
184 |
+
# ignored by LaTeX, but we don't allow anything before
|
185 |
+
# to avoid e.g. \verb|\end{document}| from terminating
|
186 |
+
# the file.
|
187 |
+
if (!$keep_comments && $line =~ /^[ \t]*\\end\{document\}/) {
|
188 |
+
last;
|
189 |
+
}
|
190 |
+
}
|
191 |
+
close($FILE);
|
192 |
+
return $in_comment;
|
193 |
+
}
|
194 |
+
|
195 |
+
sub process_line
|
196 |
+
{
|
197 |
+
my ($line, $prefix, $commented_newline, $file) = @_;
|
198 |
+
$_ = $line;
|
199 |
+
if ($$commented_newline) {
|
200 |
+
# Leading whitespaces after a comment is ignored.
|
201 |
+
# There's no space in:
|
202 |
+
# Line 1%
|
203 |
+
# Line 2.
|
204 |
+
# Match just space and tabs (\s would match \n)
|
205 |
+
s/^[ \t]*//;
|
206 |
+
if (/^$/) {
|
207 |
+
# Deal with:
|
208 |
+
#
|
209 |
+
# Line 1 % comment
|
210 |
+
#
|
211 |
+
# Line 2
|
212 |
+
#
|
213 |
+
# The newline after Line 1 is commented, but we still
|
214 |
+
# want a new paragraph. We strip the comment together
|
215 |
+
# with its newline, but re-add a newline to chnge
|
216 |
+
# paragraph here if needed:
|
217 |
+
print "\n";
|
218 |
+
}
|
219 |
+
}
|
220 |
+
$$commented_newline = 0;
|
221 |
+
# Consider \makeatletter only in preamble, because we do want
|
222 |
+
# to warn on \someCommand{\makeatletter\command@with@arobase}.
|
223 |
+
if ($in_preamble && /^[^%]*\\makeatletter/) {
|
224 |
+
$makeatletter_found = 1;
|
225 |
+
}
|
226 |
+
if ($in_preamble && /^[^%]*\\makeatother/) {
|
227 |
+
$makeatletter_found = 0;
|
228 |
+
}
|
229 |
+
my $command;
|
230 |
+
if (!$makeatletter && !$makeatletter_found
|
231 |
+
&& (($command) = /^[^%]*(\\[[:alpha:]]*@[[:alpha:]]*)/)
|
232 |
+
&& ($command ne '\@')) {
|
233 |
+
print STDERR "Warning: command $command containing @ found in\n";
|
234 |
+
print STDERR "Warning: $file.\n";
|
235 |
+
print STDERR "Warning: consider using --makeatletter if the result is not compilable.\n";
|
236 |
+
}
|
237 |
+
|
238 |
+
# non-comment is a sequence of:
|
239 |
+
# - escaped character (\\.), including \% and \\
|
240 |
+
# - neither '%' nor '\'.
|
241 |
+
my $NON_COMMENT = '([^\\\\%]|\\\\.)*';
|
242 |
+
|
243 |
+
unless ($keep_comments) {
|
244 |
+
# Special-case for \url{} commands, which may contain '%'
|
245 |
+
# characters. It's hard to catch them in $NON_COMMENT since we'd
|
246 |
+
# need a regexp so that "\url{foo" can't match as non-comment in
|
247 |
+
# the line \url{foo%bar}, but "\url{foo%bar}" would match.
|
248 |
+
# Escaping these '%' is not mandatory, but allowed, hence we can
|
249 |
+
# pre-process the line by escaping them, and let latexpand work
|
250 |
+
# as normal afterwards.
|
251 |
+
# Known limitation: latexpand doesn't do balanced braces
|
252 |
+
# recognition, and just refuses both { and } within \url{}
|
253 |
+
# argument for %-detection to work ([^{}%] below). Fix should be
|
254 |
+
# possible using
|
255 |
+
# https://stackoverflow.com/questions/15301708/perl-regular-expression-match-nested-brackets
|
256 |
+
# but is it worth the trouble? (file an issue or send a merge
|
257 |
+
# request if you think it is)
|
258 |
+
|
259 |
+
# While there are \url{URL} with unescaped % in URL ...
|
260 |
+
my $NON_PERCENT = '([^\\}]%|[^{}%])*';
|
261 |
+
while (/^(?<before>.*\\url\{)(?<url>$NON_PERCENT[^\\}]%$NON_PERCENT)(?<after>\}.*)$/) {
|
262 |
+
my ($before, $url, $after) = ($+{before}, $+{url}, $+{after});
|
263 |
+
# escape unescaped % in URL, if any
|
264 |
+
$url =~ s/([^\\])%/$1\\%/g;
|
265 |
+
$_ = $before . $url . $after ."\n";
|
266 |
+
}
|
267 |
+
if (!$empty_comments) {
|
268 |
+
# Include \n in pattern to avoid matching
|
269 |
+
# comments at end of files
|
270 |
+
|
271 |
+
# remove comments + whitespace-only lines completely
|
272 |
+
if (s/^\s*%.*\n//) {
|
273 |
+
$$commented_newline = 1;
|
274 |
+
}
|
275 |
+
|
276 |
+
# Special-case commands at end of line. We
|
277 |
+
# don't want "\\foo%\nbar" to become
|
278 |
+
# "\\foobar" (but we still want \@% to result
|
279 |
+
# in no space!)
|
280 |
+
if (s/^($NON_COMMENT\\([[:alpha:]]|[[:alpha:]@]{2,}))%.*\n/$1 /) {
|
281 |
+
$$commented_newline = 1;
|
282 |
+
} elsif (s/^($NON_COMMENT)%.*\n/$1/) {
|
283 |
+
# remove only the comment if the line has actual content
|
284 |
+
$$commented_newline = 1;
|
285 |
+
}
|
286 |
+
}
|
287 |
+
# Apply the "empty comments" treatment unconditionally
|
288 |
+
# for comments not matched above (it doesn't harm to
|
289 |
+
# keep an empty comment sometimes, but it may harm to
|
290 |
+
# leave a real comment if the goal was to strip them).
|
291 |
+
s/^(([^\\%]|\\.)*)%.*$/$1%/;
|
292 |
+
}
|
293 |
+
|
294 |
+
unless ($keep_includes) {
|
295 |
+
# \input{foo.tex}
|
296 |
+
my $ARGBRACES = '\{\\s*([^"}\\s][^}]*)(\\s*)\}';
|
297 |
+
# \input{"foo bar.tex"}
|
298 |
+
my $ARGQUOTED = '\{\\s*"([^"]*)"(\\s*)\}';
|
299 |
+
# \input foo.tex
|
300 |
+
my $ARGSPACES = '\\s([^\{\\s][^\\s]+?)\\s()';
|
301 |
+
my $ARGUMENT = "\\s*?(?|$ARGBRACES|$ARGQUOTED|$ARGSPACES)";
|
302 |
+
|
303 |
+
if (my ($before, $ignored, $full_filename, $trailing, $after)
|
304 |
+
= /^($NON_COMMENT)\\include$ARGUMENT(.*)$/) {
|
305 |
+
$full_filename = find_tex_file($full_filename . ".tex");
|
306 |
+
if ($full_filename) {
|
307 |
+
say $prefix . "Found include for file: $full_filename\n";
|
308 |
+
print $before . $nl;
|
309 |
+
print '\clearpage{}' . $nl;
|
310 |
+
print "% start include $full_filename\n" if ($explain);
|
311 |
+
my $in_comment = process_file($full_filename, $prefix . " ");
|
312 |
+
if ($explain) {
|
313 |
+
print " % end include $full_filename\n";
|
314 |
+
} elsif ($in_comment) {
|
315 |
+
print "\n";
|
316 |
+
}
|
317 |
+
print '\clearpage{}' . $nl;
|
318 |
+
print $nl . $after . "\n";
|
319 |
+
$_ = "";
|
320 |
+
}
|
321 |
+
} elsif (my ($before, $ignored, $full_filename, $trailing, $after)
|
322 |
+
= /^($NON_COMMENT)\\input$ARGUMENT(.*)$/) {
|
323 |
+
if ($inside_import) {
|
324 |
+
$full_filename = $inside_import . $full_filename;
|
325 |
+
}
|
326 |
+
$full_filename = find_tex_file($full_filename, ":.tex");
|
327 |
+
if ($full_filename) {
|
328 |
+
say $prefix . "Found input for file: $full_filename\n";
|
329 |
+
# Apparently, in some versions of LaTeX, a space
|
330 |
+
# after filename in \input{foo.tex } is inserted
|
331 |
+
# _before_ the inclusion. That was the case for
|
332 |
+
# me when 31fa806 (deal with space after
|
333 |
+
# filename in \input and \include, 2019-12-11)
|
334 |
+
# was written, but is not anymore, hence we just
|
335 |
+
# throw $trailing away.
|
336 |
+
print $before . $nl;
|
337 |
+
print "% start input $full_filename\n" if ($explain);
|
338 |
+
my $in_comment = process_file($full_filename, $prefix . " ");
|
339 |
+
if ($explain) {
|
340 |
+
print " % end input $full_filename\n";
|
341 |
+
} elsif ($in_comment) {
|
342 |
+
print "\n";
|
343 |
+
}
|
344 |
+
if ($after =~ /[^\s]/) {
|
345 |
+
# LaTeX produces this space, so let's do it also
|
346 |
+
print " " . $nl . $after . "\n";
|
347 |
+
} else {
|
348 |
+
print " ";
|
349 |
+
}
|
350 |
+
$_ = "";
|
351 |
+
}
|
352 |
+
} elsif (my ($before, $ignored, $dir, $ignored, $full_filename, $ignored, $after)
|
353 |
+
= /^($NON_COMMENT)\\(?:sub)?import$ARGUMENT$ARGUMENT(.*)$/) {
|
354 |
+
if ($explain) {
|
355 |
+
print "% dir " . $dir ."\n";
|
356 |
+
print "% full_filename " . $full_filename ."\n";
|
357 |
+
print "% after " . $after ."\n";
|
358 |
+
print "% inside_import $inside_import\n";
|
359 |
+
}
|
360 |
+
$full_filename = $dir . $full_filename;
|
361 |
+
if ($inside_import) {
|
362 |
+
$full_filename = $inside_import . $full_filename;
|
363 |
+
}
|
364 |
+
print "% cat(inside_import,dir,full_filename) " . $full_filename ."\n" if ($explain);
|
365 |
+
$full_filename = find_tex_file($full_filename, ":.tex");
|
366 |
+
if ($full_filename) {
|
367 |
+
say $prefix . "Found input for file: $full_filename\n";
|
368 |
+
print $before . $nl;
|
369 |
+
print "% start input $full_filename\n" if ($explain);
|
370 |
+
my $previous_import_dir = $inside_import;
|
371 |
+
$inside_import = $inside_import . $dir;
|
372 |
+
my $in_comment = process_file($full_filename, $prefix . " ");
|
373 |
+
$inside_import = $previous_import_dir;
|
374 |
+
if ($explain) {
|
375 |
+
print " % end input $full_filename\n";
|
376 |
+
} elsif ($in_comment) {
|
377 |
+
print "\n";
|
378 |
+
}
|
379 |
+
if ($after =~ /[^\s]/) {
|
380 |
+
# LaTeX produces this space, so let's do it also
|
381 |
+
print " " . $nl . $after . "\n";
|
382 |
+
} else {
|
383 |
+
print " ";
|
384 |
+
}
|
385 |
+
$_ = "";
|
386 |
+
}
|
387 |
+
} elsif (my ($before, $ignored, $args, $full_filename, $ignored, $after)
|
388 |
+
= /^($NON_COMMENT)\\includegraphics(\[[^\]]*?\]|)$ARGUMENT(.*)$/) {
|
389 |
+
if ($explain) {
|
390 |
+
print "% inside_import " . $inside_import ."\n";
|
391 |
+
print "% before " . $before ."\n";
|
392 |
+
print "% ignored " . $ignored ."\n";
|
393 |
+
print "% args " . $args ."\n";
|
394 |
+
print "% full_filename " . $full_filename ."\n";
|
395 |
+
print "% after " . $after ."\n";
|
396 |
+
}
|
397 |
+
if ($inside_import) {
|
398 |
+
$full_filename = $inside_import . $full_filename;
|
399 |
+
print "$before\\includegraphics" . "$args" . "{$full_filename}$after\n";
|
400 |
+
$_ = "";
|
401 |
+
}
|
402 |
+
} elsif (my ($before, $ignored, $args, $full_filename, $ignored, $after)
|
403 |
+
= /^($NON_COMMENT)\\lstinputlisting(\[[^\]]*?\]|)$ARGUMENT(.*)$/) {
|
404 |
+
if ($explain) {
|
405 |
+
print "% inside_import " . $inside_import ."\n";
|
406 |
+
print "% before " . $before ."\n";
|
407 |
+
print "% ignored " . $ignored ."\n";
|
408 |
+
print "% args " . $args ."\n";
|
409 |
+
print "% full_filename " . $full_filename ."\n";
|
410 |
+
print "% after " . $after ."\n";
|
411 |
+
}
|
412 |
+
if ($inside_import) {
|
413 |
+
$full_filename = $inside_import . $full_filename;
|
414 |
+
print "$before\\lstinputlisting" . "$args" . "{$full_filename}$after\n";
|
415 |
+
$_ = "";
|
416 |
+
}
|
417 |
+
}
|
418 |
+
}
|
419 |
+
if ($expand_usepackage) {
|
420 |
+
# Don't bother with before and after text, we just require the
|
421 |
+
# usepackage to be alone on its line.
|
422 |
+
if (my ($package_name) = /^\s*\\usepackage\{([^\}]*)\}\s*(%.*)?$/) {
|
423 |
+
my $full = find_file($package_name . ".sty", $TEXINPUTS);
|
424 |
+
if ($full) {
|
425 |
+
say $prefix . "Found package file: $full\n";
|
426 |
+
process_file($full, $prefix . " ");
|
427 |
+
$_ = "";
|
428 |
+
# Forget about any commented newline
|
429 |
+
# before the \usepackage:
|
430 |
+
$$commented_newline = 0;
|
431 |
+
} else {
|
432 |
+
say $prefix . "Not including external package $package_name\n";
|
433 |
+
}
|
434 |
+
}
|
435 |
+
}
|
436 |
+
if ($expand_bbl) {
|
437 |
+
if (my ($before, $bib_name, $after)
|
438 |
+
= /^(.*)\\(?:bibliography|bibselect)\{([^\}]*)\}(.*)$/) {
|
439 |
+
# The BBL file is not necessarily $bib_name.
|
440 |
+
# Take it from the command-line.
|
441 |
+
print $before . $nl;
|
442 |
+
say $prefix . "Expanding BBL file: $expand_bbl\n";
|
443 |
+
process_file($expand_bbl, $prefix . " ");
|
444 |
+
print " " . $nl . $after . "\n";
|
445 |
+
$_ = "";
|
446 |
+
}
|
447 |
+
}
|
448 |
+
if ($biber) {
|
449 |
+
if (my ($before, $after)
|
450 |
+
= /^(.*)\\(?:addbibresource)\{[^\}]*\}(.*)$/) {
|
451 |
+
# See https://tex.stackexchange.com/questions/166518/biblatex-include-bbl-problem-with-verb-field/166526#166526
|
452 |
+
my $biber_noext = $biber;
|
453 |
+
$biber_noext =~ s/.bbl//;
|
454 |
+
print $before . $nl;
|
455 |
+
say $prefix . "Expanding Biber BBL file: $biber\n";
|
456 |
+
print '\begin{filecontents*}{' . $biber . '}' . "\n";
|
457 |
+
cat_file($biber);
|
458 |
+
print "\n";
|
459 |
+
print '\end{filecontents*}
|
460 |
+
|
461 |
+
\usepackage{xpatch}
|
462 |
+
|
463 |
+
%Patch the biblatex input command.
|
464 |
+
%replace "testinput-bbl" if you change the name above.
|
465 |
+
%disable if you want to run biblatex/biber normally
|
466 |
+
\makeatletter
|
467 |
+
\patchcmd\blx@bblinput{\blx@blxinit}
|
468 |
+
{\blx@blxinit
|
469 |
+
\def\jobname{' . $biber_noext . '}%new jobname
|
470 |
+
}{}{\fail}
|
471 |
+
\makeatother
|
472 |
+
';
|
473 |
+
say $prefix . "End expansion of Biber BBL file: $biber\n";
|
474 |
+
print " " . $nl . $after . "\n";
|
475 |
+
$_ = "";
|
476 |
+
}
|
477 |
+
}
|
478 |
+
if ($show_graphics) {
|
479 |
+
if (/\\includegraphics(\[[^\]]*\])?{([^}]*)}/) {
|
480 |
+
my $full_filename = $2;
|
481 |
+
if ($inside_import) {
|
482 |
+
$full_filename = $inside_import . $full_filename;
|
483 |
+
}
|
484 |
+
my $full = find_tex_file($full_filename, $graphics_extensions);
|
485 |
+
say $prefix . "needs graphics file: ";
|
486 |
+
print STDERR "$full\n";
|
487 |
+
}
|
488 |
+
}
|
489 |
+
if (/^[ \t]*\\begin\{document\}/) {
|
490 |
+
$in_preamble = 0;
|
491 |
+
if ($makeatletter) {
|
492 |
+
print '\makeatletter' . $nl;
|
493 |
+
}
|
494 |
+
}
|
495 |
+
print;
|
496 |
+
}
|
497 |
+
|
498 |
+
sub unquote
|
499 |
+
{
|
500 |
+
my $str = shift;
|
501 |
+
my $x = substr($str, 0, 1);
|
502 |
+
my $y = substr($str, -1, 1);
|
503 |
+
if ($x eq $y && ($x eq '"' || $x eq "'")) {
|
504 |
+
$str = substr($str, 1, -1);
|
505 |
+
}
|
506 |
+
# There's a weird LaTeX syntax: \include{"file\space
|
507 |
+
# with\space spaces"}, so remove these \space when unquoting.
|
508 |
+
$str =~ s/\\space / /g;
|
509 |
+
return $str;
|
510 |
+
}
|
511 |
+
|
512 |
+
# search $1 in $TEXINPUTS, with possible extensions in $2
|
513 |
+
sub find_tex_file
|
514 |
+
{
|
515 |
+
my $file = unquote(shift);
|
516 |
+
my $extensions = (shift || ":");
|
517 |
+
foreach my $ext (split(':', $extensions, -1)) {
|
518 |
+
my $full = find_file_global($file . $ext);
|
519 |
+
if ($full) {
|
520 |
+
return $full;
|
521 |
+
}
|
522 |
+
}
|
523 |
+
if ($fatal) {
|
524 |
+
die "ERROR: Could not find file [$file]\n";
|
525 |
+
} else {
|
526 |
+
print STDERR "Warning: Could not find file [$file]\n";
|
527 |
+
return;
|
528 |
+
}
|
529 |
+
}
|
530 |
+
|
531 |
+
sub find_file_global
|
532 |
+
{
|
533 |
+
my $file = shift;
|
534 |
+
if (open(my $fh, "-|", "kpsewhich", $file)) {
|
535 |
+
my $full = <$fh>;
|
536 |
+
$full =~ s/\s+$//;
|
537 |
+
close($fh);
|
538 |
+
if ($full) {
|
539 |
+
return $full;
|
540 |
+
}
|
541 |
+
}
|
542 |
+
# Should be useless, but fall-back in case kpsewhich fails (or is not installed, or ...):
|
543 |
+
return find_file($file, $TEXINPUTS);
|
544 |
+
}
|
545 |
+
|
546 |
+
# Find files, not searching for global files (to allow not expanding global .sty packages)
|
547 |
+
sub find_file
|
548 |
+
{
|
549 |
+
my ($file, $path) = @_;
|
550 |
+
if (File::Spec->file_name_is_absolute($file)) {
|
551 |
+
if (-e "$file" && ! -d "$file") {
|
552 |
+
return $file;
|
553 |
+
} else {
|
554 |
+
return;
|
555 |
+
}
|
556 |
+
}
|
557 |
+
|
558 |
+
# TEXINPUTS=...: (trailing :) means "append default search
|
559 |
+
# directories". We don't want global directories here, but
|
560 |
+
# still add . that may be needed.
|
561 |
+
if (substr($path, -1) eq ':') {
|
562 |
+
$path .= '.';
|
563 |
+
}
|
564 |
+
foreach my $dir (split(':', $path)) {
|
565 |
+
if (-e "$dir/$file" && ! -d "$dir/$file") {
|
566 |
+
return("$dir/$file");
|
567 |
+
}
|
568 |
+
}
|
569 |
+
return;
|
570 |
+
}
|
571 |
+
|
572 |
+
|
573 |
+
__END__
|
574 |
+
|
575 |
+
=head1 NAME
|
576 |
+
|
577 |
+
latexpand - Flatten LaTeX file by expanding \include and \input, ... and remove comments
|
578 |
+
|
579 |
+
=head1 SYNOPSIS
|
580 |
+
|
581 |
+
latexpand [options] FILE...
|
582 |
+
|
583 |
+
=head2 Options:
|
584 |
+
|
585 |
+
--verbose show what's going on
|
586 |
+
--keep-comments don't strip comments (comments are lines
|
587 |
+
starting with %, and anything below
|
588 |
+
\end{document})
|
589 |
+
--empty-comments keep empty comments (i.e. % at end of lines) for clarity
|
590 |
+
--keep-includes don't expand \input and \include directives
|
591 |
+
--expand-usepackage
|
592 |
+
Expand \usepackage{...} directives if the
|
593 |
+
corresponding .sty file is found in
|
594 |
+
$TEXINPUTS (or the current directory if
|
595 |
+
$TEXINPUTS is not set)
|
596 |
+
--expand-bbl FILE
|
597 |
+
Expand the bibliography by inlining FILE
|
598 |
+
(should be a *.bbl file)
|
599 |
+
--biber FILE Include \bibliography{} with FILE's content,
|
600 |
+
as needed by biblatex with the biber backend.
|
601 |
+
(similar to --expand-bbl FILE, but for
|
602 |
+
biber+biblatex).
|
603 |
+
--help this help message
|
604 |
+
--define <key>=<val>, -d <key>=<val>
|
605 |
+
defines a macro key to be replaced by value, e.g.,
|
606 |
+
when called with -d foo=bar would replace all occurences
|
607 |
+
of \foo in the code with bar. Can be supplied multiple times.
|
608 |
+
--output <file>, -o <file>
|
609 |
+
generate output in <file>
|
610 |
+
--explain generate explanatory comments in output
|
611 |
+
--show-graphics show included graphics
|
612 |
+
--graphics_extensions
|
613 |
+
colon-separated list of possible graphics extensions
|
614 |
+
(used by --show-graphics to find the actual graphics files)
|
615 |
+
--fatal Die in case a file can't be found.
|
616 |
+
--makeatletter Insert a \makeatletter in the preamble. In some
|
617 |
+
rare cases it may break your document, but it
|
618 |
+
may help fixing bad interactions between
|
619 |
+
@-commands and inclusion (see BUGS section).
|
620 |
+
--in-encoding FMT, --out-encoding FMT
|
621 |
+
File encoding used by input and output files.
|
622 |
+
This uses the same syntax as PerlIO's layers.
|
623 |
+
Example:
|
624 |
+
--in-encoding 'encoding(UTF-8)'
|
625 |
+
The default is 'bytes' and should always work.
|
626 |
+
|
627 |
+
=head1 USES
|
628 |
+
|
629 |
+
The most common use of latexpand is to simplify distribution of source
|
630 |
+
LaTeX files, typically to satisfy the requirement of editors and
|
631 |
+
archival sites (springer, arXiv.org, ...) who force the authors to
|
632 |
+
submit sources. One does not necessarily want to submit sources with
|
633 |
+
comments, and uploading a document made of several files including
|
634 |
+
each other is a bit painful. By default, latexpand answers both
|
635 |
+
problems by outputing a single LaTeX file that contain no comment.
|
636 |
+
|
637 |
+
=head1 GETTING LATEXPAND
|
638 |
+
|
639 |
+
The latest version of latexpand is available here:
|
640 |
+
|
641 |
+
https://gitlab.com/latexpand/latexpand
|
642 |
+
|
643 |
+
Versions are uploaded to ctan.org from time to time:
|
644 |
+
|
645 |
+
http://www.ctan.org/pkg/latexpand
|
646 |
+
|
647 |
+
=head1 BUGS
|
648 |
+
|
649 |
+
Please, report bugs on the issue tracker on the project site:
|
650 |
+
|
651 |
+
https://gitlab.com/latexpand/latexpand/issues
|
652 |
+
|
653 |
+
=head2 Known bugs
|
654 |
+
|
655 |
+
=head3 Verbatim
|
656 |
+
|
657 |
+
latexpand currently ignores \begin{verbatim} ... \end{verbatim}, and
|
658 |
+
will therefore process any \include, \input, ... directives that
|
659 |
+
appear within verbatim environments (while it shouldn't).
|
660 |
+
|
661 |
+
LaTeX comments inside verbatim environments are also incorrectly
|
662 |
+
stripped. You can use --keep-comments as a workaround to avoid this.
|
663 |
+
|
664 |
+
=head3 Comment environment
|
665 |
+
|
666 |
+
It would be nice to remove code between \begin{comment} and
|
667 |
+
\end{comment} too if \usepackage{comment} is used.
|
668 |
+
|
669 |
+
Code like
|
670 |
+
|
671 |
+
foo%
|
672 |
+
\begin{comment}
|
673 |
+
|
674 |
+
will produce the incorrect
|
675 |
+
|
676 |
+
foo\begin{comment}
|
677 |
+
|
678 |
+
A workaround is to use --empty-comments when such tricky usage of the
|
679 |
+
comments package is done.
|
680 |
+
|
681 |
+
=head3 \makeatletter and use with transfig/xfig with \scalebox{}
|
682 |
+
|
683 |
+
If \input{} or \include{} appears as argument to a command, and the
|
684 |
+
file included contains \makeatletter, then after expansion, the
|
685 |
+
\makeatletter and the @-command appear as argument to the command,
|
686 |
+
which is forbidden because the argument is parsed (and the @-command
|
687 |
+
badly tokenized) before being executed.
|
688 |
+
|
689 |
+
This happens with
|
690 |
+
|
691 |
+
\scalebox{ \input{file-generated-by-xfig.pdf_t} }
|
692 |
+
|
693 |
+
Workaround: add \makeatletter before the scalebox manually in your
|
694 |
+
code, like
|
695 |
+
|
696 |
+
\makeatletter{}
|
697 |
+
\scalebox{ \input{file-generated-by-xfig.pdf_t} }
|
698 |
+
\makeatother{}
|
699 |
+
|
700 |
+
In the case of xfig generated files, it is necessary only for the
|
701 |
+
first occurence.
|
702 |
+
|
703 |
+
A more brute-force workaround is to use latexpand --makeatletter.
|
704 |
+
|
705 |
+
=head1 SEE ALSO
|
706 |
+
|
707 |
+
Instructions to include only the relevant .bib items (french):
|
708 |
+
|
709 |
+
https://lacl.fr/~caubert/notes/portabilite-du-tex.html#dependances
|
710 |
+
|
711 |
+
=head1 VERSION
|
712 |
+
|
713 |
+
This is latexpand version v1.7.2.
|
utils/utils.py
ADDED
@@ -0,0 +1,701 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import regex
|
3 |
+
import yaml
|
4 |
+
import shutil
|
5 |
+
import bibtexparser
|
6 |
+
from charset_normalizer import from_path
|
7 |
+
from langdetect import detect
|
8 |
+
import os
|
9 |
+
import subprocess
|
10 |
+
import numpy as np
|
11 |
+
import networkx as nx
|
12 |
+
import re
|
13 |
+
|
14 |
+
|
15 |
+
def is_venv():
|
16 |
+
return (hasattr(sys, 'real_prefix') or
|
17 |
+
(hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix))
|
18 |
+
|
19 |
+
def read_yaml_file(file_path):
|
20 |
+
with open(file_path, 'r') as file:
|
21 |
+
try:
|
22 |
+
data = yaml.safe_load(file)
|
23 |
+
return data
|
24 |
+
except yaml.YAMLError as e:
|
25 |
+
print(f"Error reading YAML file: {e}")
|
26 |
+
|
27 |
+
def read_tex_file(file_path):
|
28 |
+
with open(file_path, 'r', encoding='utf-8') as file:
|
29 |
+
tex_content = file.read()
|
30 |
+
return tex_content
|
31 |
+
|
32 |
+
def write_tex_file(file_path, s):
|
33 |
+
with open(file_path, 'w', encoding='utf-8') as file:
|
34 |
+
file.write(s)
|
35 |
+
|
36 |
+
def get_core(s):
|
37 |
+
start = '\\begin{document}'
|
38 |
+
end = '\\end{document}'
|
39 |
+
beginning_doc = s.find(start)
|
40 |
+
end_doc = s.rfind(end)
|
41 |
+
return s[beginning_doc+len(start):end_doc]
|
42 |
+
|
43 |
+
|
44 |
+
def retrieve_text(text, command, keep_text=False):
|
45 |
+
"""Removes '\\command{*}' from the string 'text'.
|
46 |
+
|
47 |
+
Regex `base_pattern` used to match balanced parentheses taken from:
|
48 |
+
https://stackoverflow.com/questions/546433/regular-expression-to-match-balanced-parentheses/35271017#35271017
|
49 |
+
"""
|
50 |
+
base_pattern = (
|
51 |
+
r'\\' + command + r"(?:\[(?:.*?)\])*\{((?:[^{}]+|\{(?1)\})*)\}(?:\[(?:.*?)\])*"
|
52 |
+
)
|
53 |
+
|
54 |
+
def extract_text_inside_curly_braces(text):
|
55 |
+
"""Extract text inside of {} from command string"""
|
56 |
+
pattern = r"\{((?:[^{}]|(?R))*)\}"
|
57 |
+
|
58 |
+
match = regex.search(pattern, text)
|
59 |
+
|
60 |
+
if match:
|
61 |
+
return match.group(1)
|
62 |
+
else:
|
63 |
+
return ""
|
64 |
+
|
65 |
+
# Loops in case of nested commands that need to retain text, e.g. \red{hello \red{world}}.
|
66 |
+
while True:
|
67 |
+
all_substitutions = []
|
68 |
+
has_match = False
|
69 |
+
for match in regex.finditer(base_pattern, text):
|
70 |
+
# In case there are only spaces or nothing up to the following newline,
|
71 |
+
# adds a percent, not to alter the newlines.
|
72 |
+
has_match = True
|
73 |
+
|
74 |
+
if not keep_text:
|
75 |
+
new_substring = ""
|
76 |
+
else:
|
77 |
+
temp_substring = text[match.span()[0] : match.span()[1]]
|
78 |
+
return extract_text_inside_curly_braces(temp_substring)
|
79 |
+
|
80 |
+
if match.span()[1] < len(text):
|
81 |
+
next_newline = text[match.span()[1] :].find("\n")
|
82 |
+
if next_newline != -1:
|
83 |
+
text_until_newline = text[
|
84 |
+
match.span()[1] : match.span()[1] + next_newline
|
85 |
+
]
|
86 |
+
if (
|
87 |
+
not text_until_newline or text_until_newline.isspace()
|
88 |
+
) and not keep_text:
|
89 |
+
new_substring = "%"
|
90 |
+
all_substitutions.append((match.span()[0], match.span()[1], new_substring))
|
91 |
+
|
92 |
+
for start, end, new_substring in reversed(all_substitutions):
|
93 |
+
text = text[:start] + new_substring + text[end:]
|
94 |
+
|
95 |
+
if not keep_text or not has_match:
|
96 |
+
break
|
97 |
+
|
98 |
+
|
99 |
+
def reduce_linebreaks(s):
|
100 |
+
return re.sub(r'(\n[ \t]*)+(\n[ \t]*)+', '\n\n', s)
|
101 |
+
|
102 |
+
|
103 |
+
def replace_percentage(s):
|
104 |
+
return re.sub(r'% *\n', '\n', s)
|
105 |
+
|
106 |
+
|
107 |
+
def reduce_spaces(s):
|
108 |
+
return re.sub(' +', ' ', s)
|
109 |
+
|
110 |
+
|
111 |
+
def delete_urls(s):
|
112 |
+
return re.sub(r'http\S+', '', s)
|
113 |
+
|
114 |
+
|
115 |
+
def remove_tilde(s):
|
116 |
+
s1 = re.sub(r'[~ ]\.', '.', s)
|
117 |
+
s2 = re.sub(r'[~ ],', ',', s1)
|
118 |
+
return re.sub(r'{}', '', s2)
|
119 |
+
|
120 |
+
|
121 |
+
def remove_verbatim_words(s):
|
122 |
+
with open("configs/latex_commands.yaml", "r") as stream:
|
123 |
+
read_config = yaml.safe_load(stream)
|
124 |
+
|
125 |
+
for command in read_config['verbatim_to_delete']:
|
126 |
+
s = s.replace(command, '')
|
127 |
+
|
128 |
+
for command in read_config['two_arguments']:
|
129 |
+
pattern = r'\\' + command + r'{[^}]*}' + r'{[^}]*}'
|
130 |
+
s = re.sub(pattern, '', s)
|
131 |
+
|
132 |
+
for command in read_config['three_arguments']:
|
133 |
+
pattern = r'\\' + command + r'{[^}]*}' + r'{[^}]*}' + r'{[^}]*}'
|
134 |
+
s = re.sub(pattern, '', s)
|
135 |
+
|
136 |
+
for command in read_config['two_arguments_elaborate']:
|
137 |
+
s = remove_multargument(s, '\\' + command, 2)
|
138 |
+
|
139 |
+
for command in read_config['three_arguments_elaborate']:
|
140 |
+
s = remove_multargument(s, '\\' + command, 3)
|
141 |
+
|
142 |
+
for command in read_config['replace_comments']:
|
143 |
+
pattern = r'\\' + command
|
144 |
+
s = re.sub(pattern, '%', s)
|
145 |
+
|
146 |
+
s = re.sub(
|
147 |
+
r'\\end{[\s]*abstract[\s]*}',
|
148 |
+
'',
|
149 |
+
s,
|
150 |
+
flags=re.IGNORECASE
|
151 |
+
)
|
152 |
+
|
153 |
+
s = re.sub(
|
154 |
+
r'\\begin{[\s]*abstract[\s]*}',
|
155 |
+
'Abstract\n\n',
|
156 |
+
s,
|
157 |
+
flags=re.IGNORECASE
|
158 |
+
)
|
159 |
+
return s
|
160 |
+
|
161 |
+
|
162 |
+
def yes_or_no(s):
|
163 |
+
return 1 if "Yes" == s[0:3] else 0 if "No" == s[0:2] else -1
|
164 |
+
|
165 |
+
|
166 |
+
def get_main(directory):
|
167 |
+
file_paths = []
|
168 |
+
for root, _, files in os.walk(directory):
|
169 |
+
for file in files:
|
170 |
+
file_path = os.path.join(root, file)
|
171 |
+
file_paths.append(file_path)
|
172 |
+
latex_paths = [f for f in file_paths if f.endswith('.tex')]
|
173 |
+
number_tex = len(latex_paths)
|
174 |
+
if number_tex == 0:
|
175 |
+
return None
|
176 |
+
if number_tex == 1:
|
177 |
+
return latex_paths[0]
|
178 |
+
adjacency = np.zeros((number_tex, number_tex))
|
179 |
+
keys = [os.path.basename(path) for path in latex_paths]
|
180 |
+
reg_ex = r'\\input{(.*?)}|\\include{(.*?)}|\\import{(.*?)}|\\subfile{(.*?)}|\\include[*]{(.*?)}|}'
|
181 |
+
for i,file in enumerate(latex_paths):
|
182 |
+
content = read_tex_file(file)
|
183 |
+
find_pattern_input = re.findall(reg_ex, content)
|
184 |
+
find_pattern_input = [tup for tup in find_pattern_input if not all(element == "" for element in tup)]
|
185 |
+
number_matches = len(find_pattern_input)
|
186 |
+
if number_matches == 0:
|
187 |
+
continue
|
188 |
+
else:
|
189 |
+
content = replace_imports(file, content)
|
190 |
+
reg_ex_clean = r'\\input{(.*?)}|\\include{(.*?)}'
|
191 |
+
find_pattern_input = re.findall(reg_ex_clean, content)
|
192 |
+
number_matches = len(find_pattern_input)
|
193 |
+
for j in range(number_matches):
|
194 |
+
match = find_pattern_input[j]
|
195 |
+
non_empty_match = [t for t in match if t]
|
196 |
+
for non_empty in non_empty_match:
|
197 |
+
base_match = os.path.basename(non_empty)
|
198 |
+
if not base_match.endswith('.tex'):
|
199 |
+
base_match = base_match + '.tex'
|
200 |
+
if base_match not in keys:
|
201 |
+
continue
|
202 |
+
ind = keys.index(base_match)
|
203 |
+
adjacency[i][ind] = 1
|
204 |
+
G = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
|
205 |
+
connected_components = list(nx.weakly_connected_components(G))
|
206 |
+
size_connected = [len(x) for x in connected_components]
|
207 |
+
maximum_size = max(size_connected)
|
208 |
+
biggest_connected = [x for x in connected_components if len(x) == maximum_size]
|
209 |
+
if len(biggest_connected)>1:
|
210 |
+
roots = [n for connected in biggest_connected for n in connected if not list(G.predecessors(n))]
|
211 |
+
_check = []
|
212 |
+
for r in roots:
|
213 |
+
try:
|
214 |
+
_check.append(check_begin(latex_paths[r]))
|
215 |
+
except Exception as e:
|
216 |
+
_check.append(False)
|
217 |
+
potentials_files = [latex_paths[x] for x, y in zip(roots, _check) if y == True]
|
218 |
+
sizes_files = [os.path.getsize(x) for x in potentials_files]
|
219 |
+
return potentials_files[sizes_files.index(max(sizes_files))]
|
220 |
+
|
221 |
+
else:
|
222 |
+
roots = [n for n in biggest_connected[0] if not list(G.predecessors(n))]
|
223 |
+
return latex_paths[roots[0]]
|
224 |
+
|
225 |
+
|
226 |
+
def initial_clean(directory, config):
|
227 |
+
config_cmd = ''
|
228 |
+
if config == True:
|
229 |
+
config_cmd = '--config configs/cleaning_config.yaml'
|
230 |
+
temp_dir = directory[:directory.rfind('/')] + '_temp' + '/'
|
231 |
+
shutil.copytree(directory, temp_dir)
|
232 |
+
try:
|
233 |
+
command_res = os.system('arxiv_latex_cleaner --keep_bib {} {}'.format(directory, config_cmd))
|
234 |
+
if command_res != 0:
|
235 |
+
raise Exception('Error cleaning')
|
236 |
+
else:
|
237 |
+
shutil.rmtree(temp_dir)
|
238 |
+
|
239 |
+
except Exception as e:
|
240 |
+
shutil.rmtree(directory)
|
241 |
+
os.rename(temp_dir, directory)
|
242 |
+
file_paths = []
|
243 |
+
for root, _, files in os.walk(directory):
|
244 |
+
for file in files:
|
245 |
+
file_path = os.path.join(root, file)
|
246 |
+
file_paths.append(file_path)
|
247 |
+
latex_paths = [f for f in file_paths if f.endswith('.tex')]
|
248 |
+
for p in latex_paths:
|
249 |
+
results = from_path(p)
|
250 |
+
with open(p, 'w', encoding='utf-8') as f:
|
251 |
+
f.write(str(results.best()))
|
252 |
+
os.system('arxiv_latex_cleaner --keep_bib {} {}'.format(directory, config_cmd))
|
253 |
+
cleaned_directory = directory[:directory.rfind('/')] + '_arXiv'
|
254 |
+
shutil.rmtree(directory)
|
255 |
+
os.rename(cleaned_directory, directory)
|
256 |
+
|
257 |
+
|
258 |
+
def check_begin(directory):
|
259 |
+
content = read_tex_file(directory)
|
260 |
+
english = detect(content) == 'en'
|
261 |
+
return True and english if re.findall(r'\\begin{document}', content) else False
|
262 |
+
|
263 |
+
|
264 |
+
def post_processing(extracted_dir, file):
|
265 |
+
_dir = os.path.dirname(file) + '/'
|
266 |
+
perl_expand(file)
|
267 |
+
file = _dir + 'merged_latexpand.tex'
|
268 |
+
try:
|
269 |
+
de_macro(file)
|
270 |
+
file = _dir + 'merged_latexpand-clean.tex'
|
271 |
+
except Exception as e:
|
272 |
+
pass
|
273 |
+
try:
|
274 |
+
def_handle(file)
|
275 |
+
except Exception as e:
|
276 |
+
pass
|
277 |
+
try:
|
278 |
+
declare_operator(file) # has additional add-ons
|
279 |
+
except Exception as e:
|
280 |
+
pass
|
281 |
+
try:
|
282 |
+
de_macro(file)
|
283 |
+
file = _dir + os.path.splitext(os.path.basename(file))[0] + '-clean' + '.tex'
|
284 |
+
except Exception as e:
|
285 |
+
pass
|
286 |
+
initial_clean(_dir, config=True)
|
287 |
+
initial_clean(_dir, config=False)
|
288 |
+
tex_content = read_tex_file(file)
|
289 |
+
final_tex = reduce_spaces(
|
290 |
+
delete_urls(
|
291 |
+
remove_tilde(
|
292 |
+
reduce_linebreaks(
|
293 |
+
replace_percentage(
|
294 |
+
remove_verbatim_words(
|
295 |
+
tex_content
|
296 |
+
)
|
297 |
+
)
|
298 |
+
)
|
299 |
+
)
|
300 |
+
)
|
301 |
+
).strip()
|
302 |
+
shutil.rmtree(extracted_dir)
|
303 |
+
os.makedirs(extracted_dir)
|
304 |
+
write_tex_file(extracted_dir + 'final_cleaned.tex', final_tex)
|
305 |
+
initial_clean(extracted_dir, config=False)
|
306 |
+
return extracted_dir + 'final_cleaned.tex'
|
307 |
+
|
308 |
+
|
309 |
+
def perl_expand(file):
|
310 |
+
# Save the current working directory
|
311 |
+
oldpwd = os.getcwd()
|
312 |
+
target_dir = os.path.dirname(file) + '/'
|
313 |
+
# Correctly construct the path
|
314 |
+
target = os.path.join(target_dir, 'latexpand')
|
315 |
+
src = './src/utils/latexpand'
|
316 |
+
# Copy the `latexpand` script to the target directory
|
317 |
+
shutil.copyfile(src, target)
|
318 |
+
# Change to the target directory
|
319 |
+
os.chdir(target_dir)
|
320 |
+
|
321 |
+
# Run the perl command without shell=True and handle redirection within Python
|
322 |
+
with open('merged_latexpand.tex', 'w') as output_file:
|
323 |
+
subprocess.run(['perl', 'latexpand', os.path.basename(file)],
|
324 |
+
stdout=output_file, stderr=subprocess.DEVNULL)
|
325 |
+
|
326 |
+
# Return to the original directory
|
327 |
+
os.chdir(oldpwd)
|
328 |
+
|
329 |
+
|
330 |
+
def de_macro(file):
|
331 |
+
# Save the current working directory\
|
332 |
+
oldpwd = os.getcwd()
|
333 |
+
target_dir = os.path.dirname(file) + '/'
|
334 |
+
# Construct the target path
|
335 |
+
target = os.path.join(target_dir, 'de-macro.py')
|
336 |
+
src = '.src/utils/de-macro.py'
|
337 |
+
|
338 |
+
# Copy the `de-macro.py` script to the target directory
|
339 |
+
shutil.copyfile(src, target)
|
340 |
+
# Change to the target directory
|
341 |
+
os.chdir(target_dir)
|
342 |
+
|
343 |
+
# Run the de-macro script without os.system and capture errors
|
344 |
+
try:
|
345 |
+
subprocess.run(['python3', 'de-macro.py', os.path.basename(file)],
|
346 |
+
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
|
347 |
+
except subprocess.CalledProcessError as e:
|
348 |
+
raise Exception(f"Error de-macro: {e}") from e
|
349 |
+
finally:
|
350 |
+
# Always return to the original directory
|
351 |
+
os.chdir(oldpwd)
|
352 |
+
|
353 |
+
|
354 |
+
def def_handle(file):
|
355 |
+
h = os.system('python3 src/utils/def_handle.py {} --output {}'.format(file, file))
|
356 |
+
if h != 0:
|
357 |
+
raise Exception('Error def handle')
|
358 |
+
|
359 |
+
|
360 |
+
def declare_operator(file):
|
361 |
+
s = read_tex_file(file)
|
362 |
+
## Operators
|
363 |
+
pattern = r'\\DeclareMathOperator'
|
364 |
+
s = re.sub(pattern, r'\\newcommand', s)
|
365 |
+
pattern = {
|
366 |
+
r'\\newcommand\*': r'\\newcommand',
|
367 |
+
r'\\providecommand\*': r'\\newcommand',
|
368 |
+
r'\\providecommand': r'\\newcommand',
|
369 |
+
r'\\renewcommand\*': r'\\renewcommand',
|
370 |
+
r'\\newenvironment\*': r'\\newenvironment',
|
371 |
+
r'\\renewenvironment\*': r'\\renewenvironment'
|
372 |
+
}
|
373 |
+
s = re.sub(r'\\end +', r'\\end', s)
|
374 |
+
for key in pattern:
|
375 |
+
s = re.sub(key, pattern[key], s)
|
376 |
+
## Title
|
377 |
+
start = '\\begin{document}'
|
378 |
+
beginning_doc = s.find(start)
|
379 |
+
pattern = {
|
380 |
+
r'\\icmltitlerunning\*': r'\\title',
|
381 |
+
r'\\icmltitlerunning': r'\\title',
|
382 |
+
r'\\inlinetitle\*': r'\\title',
|
383 |
+
r'\\icmltitle\*': r'\\title',
|
384 |
+
r'\\inlinetitle': r'\\title',
|
385 |
+
r'\\icmltitle': r'\\title',
|
386 |
+
r'\\titlerunning\*': r'\\title',
|
387 |
+
r'\\titlerunning': r'\\title',
|
388 |
+
r'\\toctitle': r'\\title',
|
389 |
+
r'\\title\*': r'\\title',
|
390 |
+
r'\\TITLE\*': r'\\title',
|
391 |
+
r'\\TITLE': r'\\title',
|
392 |
+
r'\\Title\*': r'\\title',
|
393 |
+
r'\\Title': r'\\title',
|
394 |
+
}
|
395 |
+
for key in pattern:
|
396 |
+
s = re.sub(key, pattern[key], s)
|
397 |
+
find_potential = s.find('\\title')
|
398 |
+
|
399 |
+
## Remove \\
|
400 |
+
title_content = retrieve_text(s, 'title', keep_text = True)
|
401 |
+
if title_content != None:
|
402 |
+
cleaned_title = re.sub(r'\\\\', ' ', title_content)
|
403 |
+
cleaned_title = re.sub(r'\n',' ', cleaned_title)
|
404 |
+
cleaned_title = re.sub(r'\~',' ', cleaned_title)
|
405 |
+
s = s.replace(title_content, cleaned_title)
|
406 |
+
if find_potential != -1 and find_potential < beginning_doc:
|
407 |
+
s = s.replace('\\maketitle', cleaned_title)
|
408 |
+
|
409 |
+
## Cite and ref commands
|
410 |
+
pattern = {
|
411 |
+
r'\\citep\*': r'\\cite',
|
412 |
+
r'\\citet\*': r'\\cite',
|
413 |
+
r'\\citep': r'\\cite',
|
414 |
+
r'\\citet': r'\\cite',
|
415 |
+
r'\\cite\*': r'\\cite',
|
416 |
+
r'\\citealt\*': r'\\cite',
|
417 |
+
r'\\citealt': r'\\cite',
|
418 |
+
r'\\citealtp\*': r'\\cite',
|
419 |
+
r'\\citealp': r'\\cite',
|
420 |
+
r'\\citeyear\*': r'\\cite',
|
421 |
+
r'\\citeyear': r'\\cite',
|
422 |
+
r'\\citeauthor\*': r'\\cite',
|
423 |
+
r'\\citeauthor': r'\\cite',
|
424 |
+
r'\\citenum\*': r'\\cite',
|
425 |
+
r'\\citenum': r'\\cite',
|
426 |
+
r'\\cref': r'\\ref',
|
427 |
+
r'\\Cref': r'\\ref',
|
428 |
+
r'\\factref': r'\\ref',
|
429 |
+
r'\\appref': r'\\ref',
|
430 |
+
r'\\thmref': r'\\ref',
|
431 |
+
r'\\secref': r'\\ref',
|
432 |
+
r'\\lemref': r'\\ref',
|
433 |
+
r'\\corref': r'\\ref',
|
434 |
+
r'\\eqref': r'\\ref',
|
435 |
+
r'\\autoref': r'\\ref',
|
436 |
+
r'begin{thm}': r'begin{theorem}',
|
437 |
+
r'begin{lem}': r'begin{lemma}',
|
438 |
+
r'begin{cor}': r'begin{corollary}',
|
439 |
+
r'begin{exm}': r'begin{example}',
|
440 |
+
r'begin{defi}': r'begin{definition}',
|
441 |
+
r'begin{rem}': r'begin{remark}',
|
442 |
+
r'begin{prop}': r'begin{proposition}',
|
443 |
+
r'end{thm}': r'end{theorem}',
|
444 |
+
r'end{lem}': r'end{lemma}',
|
445 |
+
r'end{cor}': r'end{corollary}',
|
446 |
+
r'end{exm}': r'end{example}',
|
447 |
+
r'end{defi}': r'end{definition}',
|
448 |
+
r'end{rem}': r'end{remark}',
|
449 |
+
r'end{prop}': r'end{proposition}',
|
450 |
+
}
|
451 |
+
|
452 |
+
for key in pattern:
|
453 |
+
s = re.sub(key, pattern[key], s)
|
454 |
+
|
455 |
+
|
456 |
+
pattern = {
|
457 |
+
r'subsubsection': r'section',
|
458 |
+
r'subsubsection ': r'section',
|
459 |
+
r'subsubsection\*': r'section',
|
460 |
+
r'subsubsection\* ': r'section',
|
461 |
+
r'subsection': r'section',
|
462 |
+
r'subsection ': r'section',
|
463 |
+
r'subsection\*': r'section',
|
464 |
+
r'subsection\* ': r'section',
|
465 |
+
r'section ': r'section',
|
466 |
+
r'section\*': r'section',
|
467 |
+
r'section\* ': r'section',
|
468 |
+
r'chapter': r'section',
|
469 |
+
r'chapter ': r'section',
|
470 |
+
r'chapter\*': r'section',
|
471 |
+
r'chapter\* ': r'section',
|
472 |
+
r'mysubsubsection': r'section',
|
473 |
+
r'mysubsection': r'section',
|
474 |
+
r'mysection': r'section',
|
475 |
+
}
|
476 |
+
|
477 |
+
for key in pattern:
|
478 |
+
s = re.sub(key, pattern[key], s)
|
479 |
+
|
480 |
+
# In case any new commands for appendix/appendices
|
481 |
+
s = re.sub(r'newcommand{\\appendix}', '', s)
|
482 |
+
s = re.sub(r'newcommand{\\appendices}', '', s)
|
483 |
+
s = get_core(s)
|
484 |
+
|
485 |
+
## In case of double titles being defined
|
486 |
+
title_content = retrieve_text(s, 'title', keep_text = True)
|
487 |
+
if title_content != None:
|
488 |
+
cleaned_title = re.sub(r'\\\\', ' ', title_content)
|
489 |
+
cleaned_title = re.sub(r'\n',' ', cleaned_title)
|
490 |
+
cleaned_title = re.sub(r'\~',' ', cleaned_title)
|
491 |
+
s = s.replace(title_content, cleaned_title)
|
492 |
+
write_tex_file(file, s)
|
493 |
+
|
494 |
+
|
495 |
+
def replace_imports(file, s):
|
496 |
+
regex_p1 = r'\\import{(.*?)}{(.*?)}'
|
497 |
+
s = re.sub(regex_p1, r"\\input{\1\2}", s)
|
498 |
+
regex_p2 = r'\\subfile{(.*?)}'
|
499 |
+
s = re.sub(regex_p2, r"\\input{\1}", s)
|
500 |
+
regex_p3 = r'\\include[*]{(.*?)}'
|
501 |
+
s = re.sub(regex_p3, r"\\input{\1}", s)
|
502 |
+
write_tex_file(file, s)
|
503 |
+
return s
|
504 |
+
|
505 |
+
|
506 |
+
def remove_multargument(s, target, k):
|
507 |
+
ind = s.find(target)
|
508 |
+
while ind != -1:
|
509 |
+
start_ind = ind + len(target)
|
510 |
+
stack_open = 0
|
511 |
+
stack_close = 0
|
512 |
+
track_arg = 0
|
513 |
+
for i, char in enumerate(s[start_ind:]):
|
514 |
+
if char == '{':
|
515 |
+
stack_open += 1
|
516 |
+
if char == '}':
|
517 |
+
stack_close += 1
|
518 |
+
if stack_open !=0 and stack_close !=0:
|
519 |
+
if stack_open == stack_close:
|
520 |
+
track_arg += 1
|
521 |
+
stack_open = 0
|
522 |
+
stack_close = 0
|
523 |
+
if track_arg == k:
|
524 |
+
break
|
525 |
+
s = s[:ind] + s[start_ind + i + 1:]
|
526 |
+
ind = s.find(target)
|
527 |
+
return s
|
528 |
+
|
529 |
+
|
530 |
+
def fix_citations(s):
|
531 |
+
pattern = {
|
532 |
+
r'\\citep\*': r'\\cite',
|
533 |
+
r'\\citet\*': r'\\cite',
|
534 |
+
r'\\citep': r'\\cite',
|
535 |
+
r'\\citet': r'\\cite',
|
536 |
+
r'\\cite\*': r'\\cite',
|
537 |
+
r'\\citealt\*': r'\\cite',
|
538 |
+
r'\\citealt': r'\\cite',
|
539 |
+
r'\\citealtp\*': r'\\cite',
|
540 |
+
r'\\citealp': r'\\cite',
|
541 |
+
r'\\citeyear\*': r'\\cite',
|
542 |
+
r'\\citeyear': r'\\cite',
|
543 |
+
r'\\citeauthor\*': r'\\cite',
|
544 |
+
r'\\citeauthor': r'\\cite',
|
545 |
+
r'\\citenum\*': r'\\cite',
|
546 |
+
r'\\citenum': r'\\cite'
|
547 |
+
}
|
548 |
+
for key in pattern:
|
549 |
+
s = re.sub(key, pattern[key], s)
|
550 |
+
return s
|
551 |
+
|
552 |
+
def find_bib(directory):
|
553 |
+
file_paths = []
|
554 |
+
for root, _, files in os.walk(directory):
|
555 |
+
for file in files:
|
556 |
+
file_path = os.path.join(root, file)
|
557 |
+
file_paths.append(file_path)
|
558 |
+
bib_paths = [f for f in file_paths if f.endswith('.bib')]
|
559 |
+
return bib_paths
|
560 |
+
|
561 |
+
def create_bib_from_bbl(bibfile):
|
562 |
+
with open(bibfile, 'r') as f:
|
563 |
+
content = f.read()
|
564 |
+
library_raw = bibtexparser.parse_string(content)
|
565 |
+
library = {}
|
566 |
+
for block in library_raw.blocks:
|
567 |
+
if isinstance(
|
568 |
+
block,
|
569 |
+
(bibtexparser.model.DuplicateBlockKeyBlock, bibtexparser.model.ParsingFailedBlock, bibtexparser.model.ImplicitComment)
|
570 |
+
):
|
571 |
+
continue
|
572 |
+
fields = {}
|
573 |
+
for field in block.fields:
|
574 |
+
fields[field.key] = field.value
|
575 |
+
|
576 |
+
## Get a good title one ##
|
577 |
+
field_content = fields["note"]
|
578 |
+
field_content = field_content.replace("\n", " ")
|
579 |
+
field_content = re.sub(" +", " ", field_content)
|
580 |
+
if field_content.find("``") != -1 and field_content.find("\'\'") != -1:
|
581 |
+
title = (
|
582 |
+
field_content[field_content.find("``") + 2 : field_content.find("\'\'")]
|
583 |
+
.replace("\\emph", "")
|
584 |
+
.replace("\\emp", "")
|
585 |
+
.replace("\\em", "")
|
586 |
+
.replace(",", "")
|
587 |
+
.replace("{", "")
|
588 |
+
.replace("}","")
|
589 |
+
.replace("``", "")
|
590 |
+
.replace("\'\'", "")
|
591 |
+
.strip(".")
|
592 |
+
.strip()
|
593 |
+
.strip(".")
|
594 |
+
.lower()
|
595 |
+
)
|
596 |
+
fields['title'] = title
|
597 |
+
else:
|
598 |
+
if field_content.count("\\newblock") == 2:
|
599 |
+
field_content = field_content.replace("\\newblock", "``", 1)
|
600 |
+
field_content = field_content.replace("\\newblock", "\'\'", 1)
|
601 |
+
if field_content.find("``") != -1 and field_content.find("\'\'") != -1:
|
602 |
+
title = (
|
603 |
+
field_content[field_content.find("``") + 2 : field_content.find("\'\'")]
|
604 |
+
.replace("\\emph", "")
|
605 |
+
.replace("\\emp", "")
|
606 |
+
.replace("\\em", "")
|
607 |
+
.replace(",", "")
|
608 |
+
.replace("{", "")
|
609 |
+
.replace("}","")
|
610 |
+
.replace("``", "")
|
611 |
+
.replace("\'\'", "")
|
612 |
+
.strip(".")
|
613 |
+
.strip()
|
614 |
+
.strip(".")
|
615 |
+
.lower()
|
616 |
+
)
|
617 |
+
fields['title'] = title
|
618 |
+
library[block.key] = fields
|
619 |
+
return library
|
620 |
+
|
621 |
+
|
622 |
+
def create_bib(bibfile):
|
623 |
+
with open(bibfile, 'r') as f:
|
624 |
+
content = f.read()
|
625 |
+
library_raw = bibtexparser.parse_string(content)
|
626 |
+
|
627 |
+
library = {}
|
628 |
+
for block in library_raw.blocks:
|
629 |
+
if isinstance(
|
630 |
+
block,
|
631 |
+
(bibtexparser.model.DuplicateBlockKeyBlock, bibtexparser.model.ParsingFailedBlock, bibtexparser.model.ImplicitComment)
|
632 |
+
):
|
633 |
+
continue
|
634 |
+
fields = {}
|
635 |
+
for field in block.fields:
|
636 |
+
fields[field.key] = field.value.replace('{', '').replace('}', '')
|
637 |
+
if field.key == 'title':
|
638 |
+
title = re.sub(r'[\n]+', ' ', field.value) # keep only one \n
|
639 |
+
title = re.sub(r' +', ' ', title)
|
640 |
+
fields[field.key] = (
|
641 |
+
title.replace("\\emph", "")
|
642 |
+
.replace("\\emp", "")
|
643 |
+
.replace("\\em", "")
|
644 |
+
.replace(",", "")
|
645 |
+
.replace("{", "")
|
646 |
+
.replace("}", "")
|
647 |
+
.strip(".")
|
648 |
+
.strip()
|
649 |
+
.strip(".")
|
650 |
+
.lower()
|
651 |
+
)
|
652 |
+
if 'title' not in fields:
|
653 |
+
continue
|
654 |
+
library[block.key] = fields
|
655 |
+
return library
|
656 |
+
|
657 |
+
|
658 |
+
def find_bbl(directory):
|
659 |
+
file_paths = []
|
660 |
+
for root, _, files in os.walk(directory):
|
661 |
+
for file in files:
|
662 |
+
file_path = os.path.join(root, file)
|
663 |
+
file_paths.append(file_path)
|
664 |
+
bib_paths = [f for f in file_paths if f.endswith('.bbl')]
|
665 |
+
return bib_paths
|
666 |
+
|
667 |
+
|
668 |
+
def textobib(file):
|
669 |
+
oldpwd = os.getcwd()
|
670 |
+
target_dir = os.path.dirname(file) + '/'
|
671 |
+
target = target_dir + 'tex2bib'
|
672 |
+
src = './tex2bib'
|
673 |
+
shutil.copyfile(src, target)
|
674 |
+
os.chdir(target_dir)
|
675 |
+
output_file = os.path.splitext(os.path.basename(file))[0] + '.bib'
|
676 |
+
os.system('perl tex2bib -i {} -o {}'.format(os.path.basename(file), output_file))
|
677 |
+
os.chdir(oldpwd)
|
678 |
+
return target_dir + output_file
|
679 |
+
|
680 |
+
|
681 |
+
def get_library_bib(bib_files):
|
682 |
+
library = []
|
683 |
+
for bib_file in bib_files:
|
684 |
+
library.append(create_bib(bib_file))
|
685 |
+
final_library = {}
|
686 |
+
for d in library:
|
687 |
+
final_library.update(d)
|
688 |
+
return final_library
|
689 |
+
|
690 |
+
|
691 |
+
def get_library_bbl(bbl_files):
|
692 |
+
bib_files = []
|
693 |
+
for bbl_file in bbl_files:
|
694 |
+
bib_files.append(textobib(bbl_file))
|
695 |
+
library = []
|
696 |
+
for bib_file in bib_files:
|
697 |
+
library.append(create_bib_from_bbl(bib_file))
|
698 |
+
final_library = {}
|
699 |
+
for d in library:
|
700 |
+
final_library.update(d)
|
701 |
+
return final_library
|