Andreas99 committed
Commit 3647b6e · verified · Parent(s): 4b13db1

Upload 30 files
.gitattributes ADDED
@@ -0,0 +1,7 @@
+ datasets/economics.gexf filter=lfs diff=lfs merge=lfs -text
+ datasets/finance.gexf filter=lfs diff=lfs merge=lfs -text
+ datasets/large_language_models.gexf filter=lfs diff=lfs merge=lfs -text
+ datasets/quantative_biology.gexf filter=lfs diff=lfs merge=lfs -text
+ datasets/quantum_physics.gexf filter=lfs diff=lfs merge=lfs -text
+ datasets/robotics.gexf filter=lfs diff=lfs merge=lfs -text
+ datasets/telecommunications.gexf filter=lfs diff=lfs merge=lfs -text
configs/alpaca.json ADDED
@@ -0,0 +1,6 @@
+ {
+     "description": "Template used by Alpaca-LoRA.",
+     "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
+     "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
+     "response_split": "### Response:"
+ }
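For reference, the pipeline fills this template with plain str.format (see the tasks/ helpers and litbench_pipeline.py later in this commit); a minimal sketch, assuming the file sits at configs/alpaca.json:

import json

# Load the Alpaca-LoRA prompt template added above.
with open("configs/alpaca.json") as fp:
    template = json.load(fp)

# Fill the {instruction}/{input} placeholders the same way the tasks/ helpers do.
prompt = template["prompt_input"].format(
    instruction="Please generate the title of a paper based on its abstract",
    input="Abstract: We study large-scale citation graphs ...",
)
print(prompt)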
configs/cleaning_config.yaml ADDED
@@ -0,0 +1,203 @@
+ patterns_and_insertions:
+   [
+     {
+       "pattern" : '(?:\\figcomp{\s*)(?P<first>.*?)\s*}\s*{\s*(?P<second>.*?)\s*}\s*{\s*(?P<third>.*?)\s*}',
+       "insertion" : '\parbox[c]{{ {second} \linewidth}} {{ \includegraphics[width= {third} \linewidth]{{figures/{first} }} }}',
+       "description" : "Replace figcomp",
+     },
+   ]
+
+ verbose: False
+
+ commands_to_delete: [
+   'footnote',
+   'footnote ',
+   'crdata',
+   'appendixhead',
+   'selectlanguage',
+   'name',
+   'expandafter',
+   'copyrightyear',
+   'acmYear',
+   'acmBooktitle',
+   'acmPrice',
+   'authorcontributions',
+   'thanksref',
+   'funding',
+   'conflictsofinterest',
+   'externalbibliography',
+   'acmDOI',
+   'acmISBN',
+   'acmConference',
+   'titleheader',
+   'affil',
+   'authorrunning',
+   'pagenumbering',
+   'enlargethispage',
+   'author',
+   'AuthorNames',
+   'author\n',
+   'Author\n',
+   'Author',
+   'fntext',
+   'icmlauthor',
+   'icmlauthor\n',
+   'icmladdress',
+   'icmladdress\n',
+   'received',
+   'runninghead',
+   'bstctlcite',
+   'slugger',
+   'tocauthor',
+   'author\*',
+   'vspace\*',
+   '\write18',
+   'hspace\*',
+   'vspace',
+   'hspace',
+   'maketitle',
+   'institute',
+   'label',
+   'urlstyle',
+   'acks',
+   'tnoteref',
+   'Appendix',
+   'urlstyle',
+   'url',
+   'editor',
+   'ccsdesc',
+   'cortext',
+   'bibliography',
+   'permission',
+   'usetikzlibrary',
+   'thanks',
+   'thispagestyle',
+   'abovedisplayskip',
+   'belowdisplayskip',
+   'bibliographystyle',
+   'IfSubStringInString',
+   'hyphenation',
+   'theoremstyle',
+   'colorbox',
+   'textcolor',
+   'color',
+   'caption',
+   'thlabel',
+   'fancyfoot',
+   'captionof',
+   'settopmatter',
+   'IEEEtriggeratref',
+   'IEEEauthorblockN',
+   'IEEEauthorblockA',
+   'IEEEauthorblockN\n',
+   'IEEEauthorblockA\n',
+   'IEEEauthorrefmark',
+   'orcid',
+   'typeout',
+   'fancyhead',
+   'pagestyle',
+   'biboptions',
+   'affiliation',
+   'address',
+   'institution',
+   'printalgoIEEE',
+   'date',
+   'authornote',
+   'numberofauthors',
+   'footnotetext',
+   'email',
+   'reftitle',
+   'setcopyright',
+   'ead',
+   'deleted',
+   'includegraphics',
+   'comment',
+   'abstract',
+   'replaced',
+   'xspace',
+ ]
+
+ commands_only_to_delete: [
+   'titlerunning',
+   'runningtitle',
+   'title',
+   'title\*',
+   'accept',
+   'added',
+   'icmltitle',
+   'textsuperscript',
+   'texttt',
+   'textsc',
+   'textit',
+   'mathit',
+   'makebox',
+   'mbox',
+   'textbf',
+   'acl',
+   'textnormal',
+   'texttt ',
+   'textsc ',
+   'textit ',
+   'mathit ',
+   'textbf ',
+   'gls',
+   'Gls',
+   'glspl',
+   'textnormal ',
+   'inlinetitle',
+   'mbox',
+   'hl',
+   'highlight',
+   'IEEEraisesectionheading',
+   'IEEEtitleabstractindextext',
+   'IEEEPARstart',
+   'it',
+   'added',
+   'paragraph',
+   'paragraph\*',
+   'MakeLowercase',
+   'emph',
+   'emph ',
+   'text',
+   'acp',
+   'ac'
+ ]
+
+ environments_to_delete: [
+   'icmlauthorlist',
+   'tikzpicture',
+   'groupplot',
+   'biography',
+   'IEEEbiographynophoto',
+   'acronym',
+   'MSCcodes',
+   'IEEEbiography',
+   'figure',
+   'AMS',
+   'acknowledgement',
+   'acknowledgments',
+   'acknowledgements',
+   'figure\*',
+   'minipage',
+   'table',
+   'table\*',
+   'glist',
+   'tabular',
+   'tabular\*',
+   'center',
+   'remark',
+   'algorithm',
+   'algorithmic',
+   'CCSXML',
+   'acks',
+   'lstlisting',
+   'tabu',
+   'algorithm\*',
+   'algorithmic\*',
+   'longtable',
+   'sidewaystable\*',
+   'sidewaystable',
+   'appendices',
+   'wrapfigure',
+   'appendix'
+ ]
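The patterns_and_insertions rule above pairs a regex with named groups and an insertion string in which {{ escapes a literal brace; a minimal sketch of applying such a rule (the repo's actual cleaning code is not part of this commit, so treat this as an assumption about its behavior):

import re
import yaml

with open("configs/cleaning_config.yaml") as fp:
    cfg = yaml.safe_load(fp)

tex = r"\figcomp{plot.png}{0.5}{0.9}"
for rule in cfg["patterns_and_insertions"]:
    # Named groups (first/second/third) fill the {first}/{second}/{third} slots.
    tex = re.sub(rule["pattern"], lambda m: rule["insertion"].format(**m.groupdict()), tex)
print(tex)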
configs/config.yaml ADDED
@@ -0,0 +1,55 @@
+ data_downloading:
+   download_directory: "quant_bio_retrieval/" # directory where the papers will be downloaded and the graph will be saved
+   gexf_file: "test_graph.gexf" # name of the graph file that will be created, only if the downloading option is true
+   processing:
+     random_seed: 10
+     keep_unstructured_content: false # keep the unstructured content of the papers as a graph node attribute if true
+     arxiv_rate_limit: 3 # time in seconds to wait between arXiv API calls to avoid a ban
+
+ retriever:
+   embedder: BAAI/bge-large-en-v1.5
+   num_retrievals: 30000
+   load_arxiv_embeds: True # load arXiv embeddings from Hugging Face if true, else generate them
+
+ inference:
+   base_model: meta-llama/Meta-Llama-3-8B
+   pretrained_model: "models/Robotics/Meta-LLama-3-8B-Quantative-Robotics" # used only if the training option is false
+   generation_args:
+     max_new_tokens: 1000
+     do_sample: True
+     top_p: 0.9
+     top_k: 50
+     temperature: 0.7
+     no_repeat_ngram_size: 2
+     num_beams: 1
+   gen_related_work_instruct_model: meta-llama/Llama-3.1-8B-Instruct # model assisting with the generation of related work instructions
+
+ training:
+   predefined_graph_path: "robotics.gexf" # path to the graph dataset used for fine-tuning, only if the downloading option is false
+   trainer_args:
+     per_device_train_batch_size: 4
+     warmup_steps: 100
+     num_train_epochs: 1
+     learning_rate: 0.0002
+     lr_scheduler_type: 'cosine'
+     fp16: true
+     logging_steps: 1
+     save_steps: 50
+   trainer_output_dir: trainer_outputs/
+   tokenizer:
+     max_length: 1024
+   qlora:
+     rank: 8
+     lora_alpha: 32
+     lora_dropout: 0.05
+     target_modules: # modules for which to train LoRA adapters
+       - q_proj
+       - k_proj
+       - v_proj
+       - o_proj
+
+ # Used only if the training option is true, to save and load the fine-tuned model
+ model_saving:
+   model_name: llama_1b_qlora_uncensored
+   model_output_dir: models # model saved in {model_output_dir}/{model_name}_{index} after fine-tuning completion
+   index: 1
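litbench_pipeline.py reads this file through utils.utils.read_yaml_file; a plain yaml.safe_load sketch showing the access paths the code below relies on:

import yaml

with open("configs/config.yaml") as fp:
    config = yaml.safe_load(fp)

print(config["inference"]["base_model"])                              # meta-llama/Meta-Llama-3-8B
print(config["data_downloading"]["processing"]["arxiv_rate_limit"])  # 3
print(config["training"]["qlora"]["target_modules"])                 # ['q_proj', 'k_proj', 'v_proj', 'o_proj']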
configs/config_noUI.yaml ADDED
@@ -0,0 +1,37 @@
+ # Note: in order to train a model and then evaluate that same model, make sure
+ # that base_model is the same in both the training and eval sections.
+
+ # Evaluation configuration
+ eval:
+   base_model: meta-llama/Llama-3.2-1B
+   graph_path: datasets/quantum_graph.gexf
+   model_name: llama_1b_qlora_uncensored
+
+ # Training configuration
+ training:
+   graph_path: datasets/quantum_graph.gexf # path to the graph file to train on
+   base_model: meta-llama/Llama-3.2-1B
+   trainer_args:
+     per_device_train_batch_size: 4
+     warmup_steps: 100
+     num_train_epochs: 1
+     learning_rate: 0.0002
+     lr_scheduler_type: 'cosine'
+     fp16: true
+     logging_steps: 1
+     save_steps: 50
+   trainer_output_dir: trainer_outputs/
+   tokenizer:
+     max_length: 1024
+   qlora:
+     rank: 8
+     lora_alpha: 32
+     lora_dropout: 0.05
+     target_modules: # modules for which to train LoRA adapters
+       - q_proj
+       - k_proj
+       - v_proj
+       - o_proj
+ model_saving:
+   model_output_dir: models # model saved in {model_output_dir}/{model_name} after fine-tuning completion
+   model_name: llama_1b_qlora_uncensored
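The qlora block maps naturally onto a PEFT LoraConfig; a sketch of that mapping (an assumption about how the trainer consumes it, not the QloraTrainer_CS code itself):

from peft import LoraConfig

qlora = config["training"]["qlora"]  # `config` loaded as in the sketch above
lora_config = LoraConfig(
    r=qlora["rank"],
    lora_alpha=qlora["lora_alpha"],
    lora_dropout=qlora["lora_dropout"],
    target_modules=qlora["target_modules"],
    task_type="CAUSAL_LM",
)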
configs/latex_commands.yaml ADDED
@@ -0,0 +1,162 @@
+ verbatim_to_delete: [
+   '\IEEEpeerreviewmaketitle',
+   '\normalcolor',
+   '\ifCLASSOPTIONcaptionsoff',
+   '\pagebreak',
+   '\makeatletter',
+   '\makeatother',
+   '\maketitle',
+   '\preface',
+   '\eShell',
+   '\medskip',
+   '\tableofcontents',
+   '\begin{@twocolumnfalse}',
+   '\end{@twocolumnfalse}',
+   '\bgroup',
+   '\egroup',
+   '\ifnalpaper1',
+   '\let\thefootnote',
+   '\begin{spacing}{2.0}',
+   '\end{landscape}',
+   '\begin{landscape}',
+   '\begin{doublespacing}',
+   '\end{doublespacing}',
+   '\begin{spacing}',
+   '\end{spacing}',
+   '\printbibliography',
+   '\begin{sloppypar}',
+   '\end{sloppypar}',
+   '\ifbd',
+   '\iftr',
+   '\fussy',
+   '\sloppy',
+   '\emergencystretch',
+   '\hideLIPIcs',
+   '\tolerance',
+   '\hbadness',
+   '\bShell',
+   '\glsresetall',
+   '\copyrightnotice',
+   '\copyright',
+   '\centering',
+   '\immediate',
+   '\doublespacing',
+   '\flushbottom',
+   '\printAffiliationsAndNotice',
+   '\IEEEpubid',
+   '\twocolumn',
+   '\noindent',
+   '\indent',
+   '\onecolumn',
+   '\ignore',
+   '\selectfont',
+   '\raggedbottom',
+   '\IEEEoverridecommandlockouts',
+   '\newline',
+   '\tiny',
+   '\break',
+   '\mainmatter',
+   '\let\rc\rangle',
+   '\let\lc\langle',
+   '\acresetall',
+   '\acknowledgments',
+   '\begin{section}',
+   '\begin{small}',
+   '\end{small}',
+   '\relax',
+   '\ninept',
+   '\FloatBarrier',
+   '\boldmath',
+   '\end{section}',
+   '\Huge',
+   '\fancyhf',
+   '\fancyhead',
+   '\begin{frontmatter}',
+   '\end{frontmatter}',
+   '\clearpage',
+   '\huge',
+   '\newpage',
+   '\IEEEdisplaynontitleabstractindextext',
+   '\DontPrintSemicolon',
+   '\Large',
+   '\LARGE',
+   '\ifCLASSOPTIONcompsoc',
+   '\ifCLASSOPTIONonecolumn',
+   '\xspace',
+   '\large',
+   '\acmcopyr',
+   '\flushleft',
+   '\newpage',
+   '\protect',
+   '\begingroup',
+   '\endgroup',
+   '\bigskip',
+   '\smallskip',
+   '\small',
+   '\left',
+   '\right',
+   '\vfill',
+   '\hfill',
+   '\begin{appendices}',
+   '\IEEEQED',
+   '\leavevmode',
+   '\footnotesize',
+   '\nonumber',
+   '\scriptsize',
+   '\IEEEpubidadjcol',
+   '\balance',
+   '\normalsize',
+   '\ifloguseIEEEConf',
+   '\else',
+   '\fi',
+   '\bf ',
+   '\it ',
+   '\verb ',
+   '\tt ',
+   '\em ',
+   '\par '
+ ]
+
+ two_arguments: [
+   'IEEEPARstart',
+   'pgfdeclareplotmark',
+   'setcounter',
+   'texorpdfstring',
+   'fontsize',
+   'addtocounter',
+   'addtolength'
+ ]
+
+ three_arguments: [
+   'definecolor'
+ ]
+
+ two_arguments_elaborate: [
+   'markboth',
+   'setlength',
+   'pgfdeclareplotmark',
+   'icmlsetsymbol',
+   'texorpdfstring',
+   'conferenceinfo',
+   'acrodef',
+   'icmlcorrespondingauthor',
+   'pdfbookmark',
+   'icmlaffiliation',
+   'icmlcorrespondingauthor'
+ ]
+
+ three_arguments_elaborate: [
+   'ifthenelse',
+   'addcontentsline'
+ ]
+
+ replace_comments: [
+   'def\\',
+   'def ',
+   'newglossaryentry',
+   'newtheorem',
+   'newcommand',
+   'renewcommand',
+   'newenvironment',
+   'renewenvironment'
+ ]
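A sketch of how the verbatim_to_delete tokens could be stripped from raw LaTeX (again an assumption about the cleaning step, which this commit only configures):

import yaml

with open("configs/latex_commands.yaml") as fp:
    latex_cfg = yaml.safe_load(fp)

def strip_verbatim(tex: str) -> str:
    # Plain substring removal; order follows the config list.
    for token in latex_cfg["verbatim_to_delete"]:
        tex = tex.replace(token, "")
    return tex

print(strip_verbatim(r"\maketitle\noindent Hello \newline world"))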
datasets/computer_vision.gexf ADDED
The diff for this file is too large to render. See raw diff
 
datasets/economics.gexf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bb70ebc1810a487c967dc7d196b4b32ffb8237870502d85496633c7ec639e3a3
+ size 17470884
datasets/finance.gexf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4739b443038f71850022a93ecf64c97d2fde08c492f716ee0d3e7a9d657302a2
+ size 18446416
datasets/large_language_models.gexf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d6b8fa29877039cb8b4665889421cdd4488752bcad6f512ac0b05678a73107c1
+ size 18262076
datasets/quantative_biology.gexf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99666f9881f04850511a4e424243ed28228ae7cfd28859fc2f9104faa4e8bb4b
+ size 20476813
datasets/quantum_physics.gexf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7abcb0d7aec2edb5461161ff613fe87176a4eb9c1522c2b0c1903fbccf66f74
+ size 20678301
datasets/robotics.gexf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:05681fb68e8168d530cc5bd037f5a6236135446be05862e2be9bf61b548d2c98
+ size 19601009
datasets/telecommunications.gexf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0f89f57398abcc70065c2756e15fcd2792c34f6ac9137ba1a8521a5dd8cd22c
+ size 14749926
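These .gexf files are Git LFS pointers; after `git lfs pull` they load as ordinary GEXF citation graphs, which the pipeline reads with networkx:

import networkx as nx

G = nx.read_gexf("datasets/robotics.gexf")
print(G.number_of_nodes(), G.number_of_edges())
# Node attributes include title/abstract (see litbench_pipeline.py below).
some_node = next(iter(G.nodes))
print(G.nodes[some_node].get("title"))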
litbench_pipeline.py ADDED
@@ -0,0 +1,746 @@
+ from train import *
+ from utils.utils import *
+ from utils.graph_utils import *
+ from utils.gradio_utils import *
+ from retriever.retriever import retriever
+ from tasks.abs_2_title import abs_2_title
+ from tasks.abs_completion import abs_completion
+ from tasks.citation_sentence import citation_sentence
+ from tasks.intro_2_abs import intro_2_abs
+ from tasks.link_pred import link_pred
+ from tasks.paper_retrieval import paper_retrieval
+ from tasks.influential_papers import influential_papers
+ from tasks.gen_related_work import gen_related_work
+ import random
+ import json
+ import os
+ import re
+ import networkx as nx
+ import tarfile
+ import gzip
+ import time
+ import urllib.request
+ from tqdm import tqdm
+ from colorama import Fore
+ import wandb
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList, TextIteratorStreamer, pipeline
+ from threading import Thread
+ import signal
+ import torch
+ from peft.peft_model import PeftModel
+ from datasets import load_dataset
+
+
+ # Function to determine the chatbot's first message based on user choices
+ def setup(download_option, train_option):
+     download_papers.value = (download_option == "Download Paper")
+     train_model.value = (train_option == "Train")
+
+     if download_option == "Download Paper":
+         initial_message = [{"role": "assistant", "content": "Hello, what domain are you interested in?"}]
+     elif train_option == "Train":
+         initial_message = [{"role": "assistant", "content": "What domain is your graph about?"}]
+     else:
+         initial_message = [{"role": "assistant", "content": "Please provide your task prompt."}]
+
+     return gr.update(visible=False), gr.update(visible=True), f"Download: {download_option}\nTrain: {train_option}", initial_message
+
+
+ # Function to toggle the selected task based on user input
+ def update_button_styles(selected_task):
+     """Update button styles based on selection."""
+     return [gr.update(variant="primary" if selected_task == prompt else "secondary") for prompt in task_list]
+
+
+ # Fetch and store arXiv source files
+ def fetch_arxiv_papers(papers_to_download):
+     # Download the arXiv metadata file if it doesn't exist
+     dataset = 'datasets/arxiv-metadata-oai-snapshot.json'
+     data = []
+     if not os.path.exists(dataset):
+         os.system("wget https://huggingface.co/spaces/ddiddu/simsearch/resolve/main/arxiv-metadata-oai-snapshot.json -P ./datasets")
+
+     with open(dataset, 'r') as f:
+         for line in f:
+             data.append(json.loads(line))
+
+     papers = [d for d in data]
+     paper_ids = [d['id'] for d in data]
+     paper_titles = [
+         (
+             re.sub(r' +', ' ', re.sub(r'[\n]+', ' ', paper['title']))
+             .replace("\\emph", "")
+             .replace("\\emp", "")
+             .replace("\\em", "")
+             .replace(",", "")
+             .replace("{", "")
+             .replace("}", "")
+             .strip(".")
+             .strip()
+             .strip(".")
+             .lower()
+         )
+         for paper in papers
+     ]
+     paper_dict = {
+         k: v
+         for k, v in zip(paper_titles, paper_ids)
+     }
+
+     total_papers = len(papers_to_download)
+     download_progress_bar = gr.Progress()
+
+     llm_resp = []
+     results = {
+         "Number of papers": 0,
+         "Number of latex papers": 0,
+         "Number of bib files": 0,
+         "Number of bbl files": 0,
+         "Number of inline files": 0,
+         "Number of introductions found": 0,
+         "Number of related works found": 0,
+         "Number of successful finding of extracts": 0
+     }
+     num_papers, num_edges, t, iter_ind = 0, 0, 0, 0
+     graph = {}
+
+     arxiv_rate_lim = config['data_downloading']['processing']['arxiv_rate_limit']
+     for paper_name in tqdm(papers_to_download):
+         results["Number of papers"] += 1
+         print(
+             Fore.BLUE + "Number of papers processed: {} \n Number of edges found: {} \n Time of previous iter: {} \n Now processing paper: {} \n\n"
+             .format(num_papers, num_edges, time.time() - t, paper_name) + Fore.RESET
+         )
+         t = time.time()
+         num_papers += 1
+
+         # Prepare the paper name for downloading and saving
+         paper_name_download = paper_name
+         if re.search(r'[a-zA-Z]', paper_name) is not None:
+             paper_name = "".join(paper_name.split('/'))
+         tar_file_path = save_zip_directory + paper_name + '.tar.gz'
+
+         # Attempt to download the paper source files from arXiv
+         try:
+             # Track start time for download
+             t1 = time.time()
+             urllib.request.urlretrieve(
+                 "https://arxiv.org/src/" + paper_name_download,
+                 tar_file_path)
+         except Exception as e:
+             print("Couldn't download paper {}".format(paper_name))
+             # Skip to the next paper if the download fails
+             continue
+
+         # Define the directory where the paper will be extracted
+         extracted_dir = save_directory + paper_name + '/'
+         if not os.path.exists(extracted_dir):
+             os.makedirs(extracted_dir)
+
+         # Attempt to extract the tar.gz archive
+         try:
+             tar = tarfile.open(tar_file_path)
+             tar.extractall(extracted_dir)
+             tar.close()
+         except Exception as e:
+             # If tar extraction fails, attempt to read and extract using gzip
+             try:
+                 with gzip.open(tar_file_path, 'rb') as f:
+                     file_content = f.read()
+
+                 # Save the extracted content as a .tex file
+                 with open(extracted_dir + paper_name + '.tex', 'w') as f:
+                     f.write(file_content.decode())
+             except Exception as e:
+                 print("Could not extract paper id: {}".format(paper_name))
+                 # Skip this paper if extraction fails
+                 continue
+
+         try:
+             # Perform initial cleaning and get the main TeX file
+             initial_clean(extracted_dir, config=False)
+             main_file = get_main(extracted_dir)
+
+             # If no main TeX file is found, remove the downloaded archive and continue
+             if main_file is None:
+                 print("No tex files found")
+                 os.remove(tar_file_path)
+                 continue
+
+             # Check if the main TeX file contains a valid LaTeX document
+             if check_begin(main_file):
+                 results["Number of latex papers"] += 1
+                 # Flag to check for an internal bibliography
+                 check_internal = 0
+                 # Dictionary to store bibliographic references
+                 final_library = {}
+
+                 # Identify bibliography files (.bib or .bbl)
+                 bib_files = find_bib(extracted_dir)
+                 if bib_files == []:
+                     bbl_files = find_bbl(extracted_dir)
+                     if bbl_files == []:
+                         # No external bibliography found
+                         check_internal = 1
+                     else:
+                         final_library = get_library_bbl(bbl_files)
+                         results["Number of bbl files"] += 1
+                 else:
+                     results["Number of bib files"] += 1
+                     final_library = get_library_bib(bib_files)
+
+                 # Apply post-processing to clean the TeX document
+                 main_file = post_processing(extracted_dir, main_file)
+
+                 # Read the cleaned LaTeX document content
+                 content = read_tex_file(main_file)
+
+                 # If configured, store the raw content in the graph
+                 if config['data_downloading']['processing']['keep_unstructured_content']:
+                     graph[paper_name] = {'content': content}
+                 else:
+                     graph[paper_name] = {}
+
+                 # Check for an inline bibliography within the LaTeX document
+                 if check_internal == 1:
+                     beginning_bib = '\\begin{thebibliography}'
+                     end_bib = '\\end{thebibliography}'
+
+                     if content.find(beginning_bib) != -1 and content.find(end_bib) != -1:
+                         bibliography = content[content.find(beginning_bib):content.find(end_bib) + len(end_bib)]
+                         save_bbl = os.path.join(extracted_dir, "bibliography.bbl")
+
+                         results["Number of inline files"] += 1
+                         with open(save_bbl, "w") as f:
+                             f.write(bibliography)
+
+                         final_library = get_library_bbl([save_bbl])
+
+                 # If no valid bibliography is found, skip processing citations
+                 if final_library == {}:
+                     print("No library found...")
+                     continue
+
+                 # Extract relevant sections such as "Related Work" and "Introduction"
+                 related_works = get_related_works(content)
+                 if related_works != '':
+                     graph[paper_name]['Related Work'] = related_works
+                     results["Number of related works found"] += 1
+
+                 intro = get_intro(content)
+                 if intro != '':
+                     graph[paper_name]['Introduction'] = intro
+                     results["Number of introductions found"] += 1
+
+                 # Extract citation sentences from the introduction and related works
+                 sentences_citing = get_citing_sentences(intro + '\n' + related_works)
+
+                 # Map citations to corresponding papers
+                 raw_sentences_citing = {}
+                 for k, v in sentences_citing.items():
+                     new_values = []
+                     for item in v:
+                         try:
+                             new_values.append(paper_dict[final_library[item]['title']])
+                         except Exception as e:
+                             pass
+                     if new_values != []:
+                         raw_sentences_citing[k] = new_values
+
+                 # Construct citation edges
+                 edges_set = []
+                 for k, v in raw_sentences_citing.items():
+                     for item in v:
+                         edges_set.append((paper_name_download, item, {"sentence": k}))
+
+                 iter_ind += 1
+                 if len(edges_set) != 0:
+                     results["Number of successful finding of extracts"] += 1
+                     graph[paper_name]['Citations'] = edges_set
+                     num_edges += len(edges_set)
+
+                 # Save progress after every 10 iterations
+                 if iter_ind % 10 == 0:
+                     print("Saving graph now")
+                     with open(save_path, 'w') as f:
+                         json.dump(results, f)
+                     with open(save_graph, 'w') as f:
+                         json.dump(graph, f)
+
+         except Exception as e:
+             print("Could not get main paper {}".format(paper_name))
+
+         # Update the progress bar after processing each paper
+         download_progress_bar(num_papers / total_papers)
+
+         # Enforce the configured minimum time gap between iterations to avoid bans from arXiv
+         t2 = time.time()  # End time
+         elapsed_time = t2 - t1
+         if elapsed_time < arxiv_rate_lim:
+             time.sleep(arxiv_rate_lim - elapsed_time)
+
+     # Final saving of processed data
+     with open(save_graph, 'w') as f:
+         json.dump(graph, f)
+     with open(save_path, 'w') as f:
+         json.dump(results, f)
+
+     # Log the final completion message
+     llm_resp.append("✅ Successfully downloaded and cleaned {} papers.".format(results["Number of latex papers"]))
+     return "\n".join(llm_resp)
+
+
+ # Chat prediction function
+ def predict(message, history, selected_task):
+     global model
+     # Initialize the conversation string
+     conversation = ""
+
+     # Parse the history: Gradio `type="messages"` uses dictionaries with 'role' and 'content'
+     for item in history:
+         if item["role"] == "assistant":
+             conversation += f"<bot>: {item['content']}\n"
+         elif item["role"] == "user":
+             conversation += f"<human>: {item['content']}\n"
+
+     # Add the user's current message to the conversation
+     conversation += f"<human>: {message}\n<bot>:"
+
+     # Handle preferences
+     if len(history) == 0:
+         if not download_papers.value and not train_model.value:
+             yield "✅ Using model from configuration file..."
+
+             adapter_path = config["inference"]["pretrained_model"]
+             peft_model = PeftModel.from_pretrained(model, adapter_path, torch_dtype=torch.float16)
+
+             # Replace the global model with the PEFT model
+             model = peft_model
+
+             time.sleep(2.5)
+
+     if not (len(history) == 0 and (train_model.value or download_papers.value)):
+         # Streamer for generating responses
+         streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+         stop = StopOnTokens()
+
+         generate_kwargs = {
+             "streamer": streamer,
+             "max_new_tokens": config['inference']['generation_args']["max_new_tokens"],
+             "do_sample": config['inference']['generation_args']["do_sample"],
+             "top_p": config['inference']['generation_args']["top_p"],
+             "top_k": config['inference']['generation_args']["top_k"],
+             "temperature": config['inference']['generation_args']["temperature"],
+             "no_repeat_ngram_size": config['inference']['generation_args']["no_repeat_ngram_size"],
+             "num_beams": config['inference']['generation_args']["num_beams"],
+             "stopping_criteria": StoppingCriteriaList([stop]),
+         }
+
+         def generate_response(model, generate_kwargs, selected_task):
+             global advanced_tasks_out
+             has_predefined_template = generate_kwargs["streamer"].tokenizer.chat_template is not None
+
+             if selected_task == "Abstract Completion":
+                 prompt = abs_completion(message, template, has_predefined_template)
+             elif selected_task == "Title Generation":
+                 prompt = abs_2_title(message, template, has_predefined_template)
+             elif selected_task == "Citation Recommendation":
+                 prompt = paper_retrieval(message, template, has_predefined_template)
+             elif selected_task == "Citation Sentence Generation":
+                 prompt = citation_sentence(message, template, has_predefined_template)
+             elif selected_task == "Citation Link Prediction":
+                 prompt = link_pred(message, template, has_predefined_template)
+             elif selected_task == "Introduction to Abstract":
+                 prompt = intro_2_abs(message, template, tokenizer.model_max_length, has_predefined_template)
+             elif selected_task == "Influential Papers Recommendation":
+                 graph = nx.read_gexf(gexf_file if download_papers.value else predef_graph)
+                 advanced_tasks_out = influential_papers(message, graph)
+             elif selected_task == "Related Work Generation":
+                 adapter_path = (
+                     f"{config['model_saving']['model_output_dir']}/{config['model_saving']['model_name']}_{config['model_saving']['index']}_adapter_test_graph"
+                     if train_model.value else config['inference']['pretrained_model']
+                 )
+                 if download_papers.value:
+                     advanced_tasks_out = gen_related_work(message, gexf_file, adapter_path)
+                 else:
+                     advanced_tasks_out = gen_related_work(message, predef_graph, adapter_path)
+             else:
+                 # Free-form chat: the conversation already ends with the current message
+                 prompt = conversation
+
+             if selected_task != "Influential Papers Recommendation" and selected_task != "Related Work Generation":
+                 if tokenizer.chat_template is not None:
+                     response = model_pipeline(prompt, **generate_kwargs)
+                     streamer.put(response[0]['generated_text'][-1])
+                 else:
+                     model_inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+                     generate_kwargs["inputs"] = model_inputs["input_ids"]
+                     generate_kwargs["attention_mask"] = model_inputs["attention_mask"]
+
+                     response = model.generate(**generate_kwargs)
+                     streamer.put(response)
+
+         # Generate the response in a separate thread
+         t = Thread(target=generate_response,
+                    kwargs={
+                        "model": model,
+                        "generate_kwargs": generate_kwargs,
+                        "selected_task": selected_task
+                    })
+
+         global advanced_tasks_out
+         advanced_tasks_out = None
+         t.start()
+
+         # Stream the partial response
+         if selected_task != "Influential Papers Recommendation" and selected_task != "Related Work Generation":
+             partial_message = ""
+             for new_token in streamer:
+                 if new_token != '<':  # Ignore placeholder tokens
+                     partial_message += new_token
+                     yield partial_message
+         else:
+             if selected_task == "Related Work Generation":
+                 yield "🔍 Generating related work..."
+             while advanced_tasks_out is None:
+                 time.sleep(0.1)
+             yield advanced_tasks_out
+
+     # Fetch arXiv papers if the user opted to download them
+     if len(history) == 0:
+         if download_papers.value:
+             # Fetch relevant papers
+             yield "🔍 Retrieving relevant papers..."
+
+             retrieve_progress = gr.Progress()
+             for percent in retriever(message, retrieval_nodes_path):
+                 retrieve_progress(percent)
+
+             with open(retrieval_nodes_path, "r") as f:
+                 data_download = json.load(f)
+
+             papers_to_download = list(data_download.keys())
+
+             yield f"📥 Fetching {len(papers_to_download)} arXiv papers' source files... Please wait."
+
+             content = fetch_arxiv_papers(papers_to_download)
+             yield content
+             time.sleep(2.5)
+
+     # Train the model with the retrieved graph
+     if len(history) == 0:
+         if train_model.value:
+             training_progress = gr.Progress()
+
+             training_progress(0.0)
+
+             # If the user opted to download papers, use the retrieved graph, else use the predefined graph
+             if download_papers.value:
+                 yield "🚀 Training the model with the retrieved graph..."
+
+                 with open(save_graph, "r") as f:
+                     data_graph = json.load(f)
+
+                 renamed_data = {
+                     "/".join(re.match(r"([a-z-]+)([0-9]+)", key, re.I).groups()) if re.match(r"([a-z-]+)([0-9]+)", key, re.I) else key: value
+                     for key, value in data_graph.items()
+                 }
+
+                 concept_data = load_dataset("AliMaatouk/arXiv_Topics", cache_dir="datasets/arxiv_topics")
+                 id2topics = {
+                     entry["paper_id"]: [entry["Level 1"], entry["Level 2"], entry["Level 3"]]
+                     for entry in concept_data["train"]
+                 }
+
+                 dataset = 'datasets/arxiv-metadata-oai-snapshot.json'
+                 data = []
+                 if not os.path.exists(dataset):
+                     os.system("wget https://huggingface.co/spaces/ddiddu/simsearch/resolve/main/arxiv-metadata-oai-snapshot.json -P ./datasets")
+                 with open(dataset, 'r') as f:
+                     for line in f:
+                         data.append(json.loads(line))
+                 papers = {d['id']: d for d in data}
+
+                 G = nx.DiGraph()
+                 for k in renamed_data:
+                     if k not in G and k in papers:
+                         if config['data_downloading']['processing']['keep_unstructured_content']:
+                             G.add_node(
+                                 k,
+                                 title=papers[k]['title'],
+                                 abstract=papers[k]['abstract'],
+                                 introduction=renamed_data[k].get('Introduction', '') if renamed_data[k].get('Introduction', '') != '\n' else '',
+                                 related=renamed_data[k].get('Related Work', '') if renamed_data[k].get('Related Work', '') != '\n' else '',
+                                 concepts=", ".join(list(set(item for sublist in id2topics[k] for item in sublist))) if k in id2topics else '',
+                                 content=renamed_data[k].get('content', '') if k in renamed_data else ''
+                             )
+                         else:
+                             G.add_node(
+                                 k,
+                                 title=papers[k]['title'],
+                                 abstract=papers[k]['abstract'],
+                                 introduction=renamed_data[k].get('Introduction', '') if renamed_data[k].get('Introduction', '') != '\n' else '',
+                                 related=renamed_data[k].get('Related Work', '') if renamed_data[k].get('Related Work', '') != '\n' else '',
+                                 concepts=", ".join(list(set(item for sublist in id2topics[k] for item in sublist))) if k in id2topics else ''
+                             )
+                     if 'Citations' in renamed_data[k]:
+                         for citation in renamed_data[k]['Citations']:
+                             source, target, metadata = citation
+                             sentence = metadata.get('sentence', '')  # Extract the sentence or default to an empty string
+
+                             if target not in G and target in papers:
+                                 if config['data_downloading']['processing']['keep_unstructured_content']:
+                                     G.add_node(
+                                         target,
+                                         title=papers[target]['title'],
+                                         abstract=papers[target]['abstract'],
+                                         introduction=renamed_data[target].get('Introduction', '') if target in renamed_data and renamed_data[target].get('Introduction', '') != '\n' else '',
+                                         related=renamed_data[target].get('Related Work', '') if target in renamed_data and renamed_data[target].get('Related Work', '') != '\n' else '',
+                                         concepts=", ".join(list(set(item for sublist in id2topics[target] for item in sublist))) if target in id2topics else '',
+                                         content=renamed_data[target].get('content', '') if target in renamed_data else ''
+                                     )
+                                 else:
+                                     G.add_node(
+                                         target,
+                                         title=papers[target]['title'],
+                                         abstract=papers[target]['abstract'],
+                                         introduction=renamed_data[target].get('Introduction', '') if target in renamed_data and renamed_data[target].get('Introduction', '') != '\n' else '',
+                                         related=renamed_data[target].get('Related Work', '') if target in renamed_data and renamed_data[target].get('Related Work', '') != '\n' else '',
+                                         concepts=", ".join(list(set(item for sublist in id2topics[target] for item in sublist))) if target in id2topics else ''
+                                     )
+
+                             G.add_edge(source, target, sentence=sentence)
+
+                 G.remove_nodes_from(list(nx.isolates(G)))
+
+                 nx.write_gexf(G, gexf_file)
+                 print(f"Processed graph written to {gexf_file}")
+             else:
+                 yield f"✅ Using predefined graph: {predef_graph}"
+
+             wandb.init(project='qlora_train')
+
+             if download_papers.value:
+                 trainer = QloraTrainer_CS(config=config, use_predefined_graph=False)
+             else:
+                 trainer = QloraTrainer_CS(config=config, use_predefined_graph=True)
+
+             print("Load base model")
+             trainer.load_base_model()
+
+             print("Start training")
+             def update_progress():
+                 # Wait for the trainer to be initialized
+                 while trainer.transformer_trainer is None:
+                     time.sleep(0.5)
+
+                 time.sleep(1.5)
+                 # Update the progress bar until training is complete
+                 while trainer.transformer_trainer.state.global_step != trainer.transformer_trainer.state.max_steps:
+                     progress_bar = (
+                         trainer.transformer_trainer.state.global_step /
+                         trainer.transformer_trainer.state.max_steps
+                     )
+                     training_progress(progress_bar)
+                     time.sleep(0.5)
+                 training_progress(1.0)
+
+             t1 = Thread(target=trainer.train)
+             t1.start()
+             t2 = Thread(target=update_progress)  # pass the function itself, don't call it
+             t2.start()
+             t1.join()
+             t2.join()
+
+             yield "🎉 Model training complete! Please provide your task prompt."
+
+             adapter_path = f"{config['model_saving']['model_output_dir']}/{config['model_saving']['model_name']}_{config['model_saving']['index']}_adapter_test_graph"
+             peft_model = PeftModel.from_pretrained(model, adapter_path, torch_dtype=torch.float16)
+
+             # Replace the global model with the PEFT model
+             model = peft_model
+
+
+ if __name__ == "__main__":
+     print("This is running in a virtual environment: {}".format(is_venv()))
+
+     config = read_yaml_file("configs/config.yaml")
+     template_file_path = 'configs/alpaca.json'
+     template = json.load(open(template_file_path, "r"))
+
+     seed_no = config['data_downloading']['processing']['random_seed']
+     model_name = config['inference']['base_model']
+     working_dir = config['data_downloading']['download_directory']
+     save_zip_directory = working_dir + 'research_papers_zip/'
+     save_directory = working_dir + 'research_papers/'
+     save_description = working_dir + 'description/'
+     save_path = save_description + 'results.json'
+     save_graph = save_description + 'test_graph.json'
+     gexf_file = save_description + config['data_downloading']['gexf_file']
+     predef_graph = 'datasets/' + config['training']['predefined_graph_path']
+     retrieval_nodes_path = 'datasets/retrieval_nodes.json'
+
+     for directory in (save_zip_directory, save_directory, save_description):
+         if not os.path.exists(directory):
+             os.makedirs(directory)
+
+     random.seed(seed_no)
+
+     # Load the model and tokenizer
+     bnb_config = BitsAndBytesConfig(
+         load_in_8bit=True,
+         bnb_8bit_use_double_quant=True,
+         bnb_8bit_quant_type="nf8",
+         bnb_8bit_compute_dtype=torch.bfloat16
+     )
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
+     if model.device.type != 'cuda':
+         model.to('cuda')
+
+     if tokenizer.chat_template is not None:
+         model_pipeline = pipeline(
+             "text-generation",
+             model=model_name,
+             model_kwargs={"torch_dtype": torch.bfloat16},
+             device_map="auto",
+         )
+
+     signal.signal(signal.SIGINT, signal_handler)
+
+     # Global states for user preferences
+     download_papers = gr.State(value=True)  # Default: download papers
+     train_model = gr.State(value=True)  # Default: train the model
+
+     # Categorized recommended prompts
+     task_list = {
+         "Abstract Completion",
+         "Introduction to Abstract",
+         "Title Generation",
+         "Citation Recommendation",
+         "Citation Sentence Generation",
+         "Citation Link Prediction",
+         "Influential Papers Recommendation",
+         "Related Work Generation",
+     }
+
+     # CSS for styling
+     css = """
+     body { background-color: #E0F7FA; margin: 0; padding: 0; }
+     .gradio-container { background-color: #E0F7FA; border-radius: 10px; }
+     #logo-container { display: flex; justify-content: center; align-items: center; margin: 0 auto; padding: 0; max-width: 120px; height: 120px; border-radius: 10px; overflow: hidden; }
+     #scroll-menu { max-height: 310px; overflow-y: auto; padding: 10px; background-color: #fff; margin-top: 10px;}
+     #task-header { background-color: #0288d1; color: white; font-size: 18px; padding: 8px; text-align: center; margin-bottom: 5px; margin-top: 40px; }
+     #category-header { background-color: #ecb939; font-size: 16px; padding: 8px; margin: 10px 0; }
+     """
+
+     # State to store the selected task
+     selected_task = gr.State(value="")
+
+     # Gradio interface
+     with gr.Blocks(theme="soft", css=css) as demo:
+         gr.HTML('<div id="logo-container"><img src="https://static.thenounproject.com/png/6480915-200.png" alt="Logo"></div>')
+         gr.Markdown("# LitBench Interface")
+
+         # Setup row for user preferences
+         with gr.Row(visible=True) as setup_row:
+             with gr.Column():
+                 gr.Markdown("### Setup Your Preferences")
+                 download_option = gr.Dropdown(
+                     choices=["Download Paper", "Don't Download"],
+                     value="Download Paper",
+                     label="Download Option"
+                 )
+                 train_option = gr.Dropdown(
+                     choices=["Train", "Don't Train"],
+                     value="Train",
+                     label="Training Option"
+                 )
+                 setup_button = gr.Button("Set Preferences and Proceed")
+
+         # Chatbot row for user interaction
+         with gr.Row(visible=False) as chatbot_row:
+             with gr.Column(scale=3):
+                 gr.Markdown("### Start Chatting!")
+                 chatbot = gr.ChatInterface(
+                     predict,
+                     chatbot=gr.Chatbot(
+                         height=400,
+                         type="messages",
+                         avatar_images=[
+                             "https://icons.veryicon.com/png/o/miscellaneous/user-avatar/user-avatar-male-5.png",
+                             "https://cdn-icons-png.flaticon.com/512/8649/8649595.png"
+                         ],
+                     ),
+                     textbox=gr.Textbox(placeholder="Type your message here..."),
+                     additional_inputs=selected_task,
+                     additional_inputs_accordion=gr.Accordion(visible=False, label="Additional Inputs"),
+                 )
+
+                 # Store user preferences and the selected task for display
+                 preferences_output = gr.Textbox(value="", interactive=False, label="Your Preferences")
+
+             # Task selection buttons for user interaction
+             with gr.Column(scale=1):
+                 gr.HTML('<div id="task-header">Tasks:</div>')
+                 with gr.Column(elem_id="scroll-menu"):
+                     # Create buttons
+                     button_map = {prompt: gr.Button(prompt) for prompt in task_list}
+
+                     for prompt in task_list:
+                         button_map[prompt].click(
+                             toggle_selection,
+                             inputs=[selected_task, gr.State(value=prompt)],  # Toggle task selection
+                             outputs=selected_task
+                         ).then(
+                             update_button_styles,  # Update button appearances
+                             inputs=[selected_task],
+                             outputs=[button_map[p] for p in task_list]  # Update all buttons
+                         )
+
+         # Setup button to finalize user preferences and start the chatbot
+         setup_button.click(
+             setup,
+             inputs=[download_option, train_option],
+             outputs=[setup_row, chatbot_row, preferences_output, chatbot.chatbot]
+         )
+
+     # Launch the interface
+     demo.launch(server_port=7880)
retriever/retriever.py ADDED
@@ -0,0 +1,129 @@
+ from transformers import AutoTokenizer, AutoModel
+ from sklearn.metrics.pairwise import cosine_similarity
+ import json
+ import torch
+ from tqdm import tqdm
+ import os
+ import pandas as pd
+ import numpy as np
+ from datasets import load_dataset
+ from utils.utils import read_yaml_file
+
+
+ def generate_topic_level_embeddings(model, tokenizer, paper_list, tmp_id_2_abs):
+     id2topics = {
+         entry["paper_id"]: [entry["Level 1"], entry["Level 2"], entry["Level 3"]]
+         for entry in tmp_id_2_abs['train']
+     }
+
+     for topic_level in ['Level 1', 'Level 2', 'Level 3']:
+         i = 0
+         batch_size = 2048
+         candidate_emb_list = []
+         pbar = tqdm(total=len(paper_list))
+         while i < len(paper_list):
+             # Report overall progress: each topic level covers a third of the work
+             if topic_level == 'Level 1':
+                 yield i / len(paper_list) / 3
+             elif topic_level == 'Level 2':
+                 yield 0.33 + i / len(paper_list) / 3
+             else:
+                 yield 0.66 + i / len(paper_list) / 3
+             paper_batch = paper_list[i:i + batch_size]
+             paper_text_batch = []
+             for paper_id in paper_batch:
+                 topics = id2topics[paper_id][int(topic_level[6]) - 1]
+                 topic_text = ''
+                 for t in topics:
+                     topic_text += t + ','
+                 paper_text_batch.append(topic_text)
+             inputs = tokenizer(paper_text_batch, return_tensors='pt', padding=True, truncation=True)
+             with torch.no_grad():
+                 outputs = model(**inputs.to('cuda'))
+                 candidate_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
+             candidate_embeddings = candidate_embeddings.reshape(-1, 1024)
+             candidate_emb_list.append(candidate_embeddings)
+
+             i += len(candidate_embeddings)
+             pbar.update(len(candidate_embeddings))
+
+         all_candidate_embs = torch.cat(candidate_emb_list, 0)
+
+         df = pd.DataFrame({
+             "paper_id": paper_list,
+             "embedding": list(all_candidate_embs.numpy())
+         })
+
+         if not os.path.exists('datasets/topic_level_embeds'):
+             os.makedirs('datasets/topic_level_embeds')
+
+         df.to_parquet(f'datasets/topic_level_embeds/{topic_level}_emb.parquet', engine='pyarrow', compression='snappy')
+
+     # Sum the three per-level embeddings into a single embedding per paper
+     all_candidate_embs_L1 = torch.tensor(np.array(pd.read_parquet('datasets/topic_level_embeds/Level 1_emb.parquet')['embedding'].tolist()))
+     all_candidate_embs_L2 = torch.tensor(np.array(pd.read_parquet('datasets/topic_level_embeds/Level 2_emb.parquet')['embedding'].tolist()))
+     all_candidate_embs_L3 = torch.tensor(np.array(pd.read_parquet('datasets/topic_level_embeds/Level 3_emb.parquet')['embedding'].tolist()))
+     all_candidate_embs = all_candidate_embs_L1 + all_candidate_embs_L2 + all_candidate_embs_L3
+
+     df = pd.DataFrame({
+         "paper_id": paper_list,
+         "embedding": list(all_candidate_embs.numpy())
+     })
+
+     df.to_parquet('datasets/topic_level_embeds/arxiv_papers_embeds.parquet', engine='pyarrow', compression='snappy')
+
+
+ def retriever(query, retrieval_nodes_path):
+     yield 0
+     config = read_yaml_file('configs/config.yaml')
+
+     # Load the model and tokenizer used to generate the embeddings
+     embedder_name = config['retriever']['embedder']
+     tokenizer = AutoTokenizer.from_pretrained(embedder_name)
+     model = AutoModel.from_pretrained(embedder_name).to(device='cuda', dtype=torch.float16)
+
+     # Load the arXiv dataset
+     tmp_id_2_abs = load_dataset("AliMaatouk/arXiv_Topics", cache_dir="datasets/arxiv_topics")
+     paper_list = list(tmp_id_2_abs['train']['paper_id'])
+
+     # Generate the query embeddings
+     inputs = tokenizer([query], return_tensors='pt', padding=True, truncation=True)
+     with torch.no_grad():
+         outputs = model(**inputs.to('cuda'))
+         query_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
+
+     # Generate the candidate embeddings:
+     # load them from the published dataset, otherwise generate and save them
+     if config['retriever']['load_arxiv_embeds']:
+         dataset = load_dataset("AliMaatouk/arXiv-Topics-Embeddings", cache_dir="datasets/topic_level_embeds")
+         table = dataset["train"].data  # Get the PyArrow table
+         all_candidate_embs = table.column("embedding").to_numpy()
+     else:
+         # If the file does not exist, generate the embeddings; otherwise load them
+         if not os.path.exists('datasets/topic_level_embeds/arxiv_papers_embeds.parquet'):
+             yield from generate_topic_level_embeddings(model, tokenizer, paper_list, tmp_id_2_abs)
+
+         all_candidate_embs = torch.tensor(np.array(pd.read_parquet('datasets/topic_level_embeds/arxiv_papers_embeds.parquet')['embedding'].tolist()))
+         all_candidate_embs = all_candidate_embs.cpu().numpy()
+
+     all_candidate_embs = np.stack(all_candidate_embs)
+
+     # Compute the cosine similarity between the query and all candidate embeddings
+     query_embeddings = np.array(query_embeddings)
+     similarity_scores = cosine_similarity(query_embeddings, all_candidate_embs)[0]
+
+     # Sort the papers by similarity score and select the top K papers
+     id_score_list = []
+     for i in range(len(paper_list)):
+         id_score_list.append([paper_list[i], similarity_scores[i]])
+
+     sorted_scores = sorted(id_score_list, key=lambda i: i[-1], reverse=True)
+     top_K_paper = [sample[0] for sample in sorted_scores[:config['retriever']['num_retrievals']]]
+
+     papers_results = {
+         paper: True
+         for paper in top_K_paper
+     }
+
+     with open(retrieval_nodes_path, 'w') as f:
+         json.dump(papers_results, f)
+
+     yield 1.0
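retriever() is a generator: it yields progress values in [0, 1] and writes the selected paper ids to a JSON file as a side effect, which is exactly how litbench_pipeline.py drives it:

from retriever.retriever import retriever

for percent in retriever("quantum error correction", "datasets/retrieval_nodes.json"):
    print(f"retrieval progress: {percent:.0%}")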
tasks/abs_2_title.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate a prompt for generating the title of a paper based on its abstract.
3
+
4
+ Args:
5
+ usr_input (str): A string containing the title and abstract of the paper in the format "Title: <title> Abstract: <abstract>".
6
+ template (dict): A dictionary containing the template for the prompt with a key "prompt_input".
7
+
8
+ Returns:
9
+ str: A formatted string with the instruction and abstract to be used as input for generating the title.
10
+ """
11
+
12
+ def abs_2_title(usr_input, template, has_predefined_template=False):
13
+ instruction = "Please generate the title of paper based on its abstract"
14
+
15
+ if has_predefined_template:
16
+ res = [
17
+ {"role": "system", "content": instruction},
18
+ {"role": "user", "content": usr_input},
19
+ ]
20
+ else:
21
+ res = template["prompt_input"].format(instruction=instruction, input=usr_input)
22
+
23
+ return res
tasks/abs_completion.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generates a formatted prompt for completing the abstract of a paper.
3
+
4
+ Args:
5
+ usr_input (str): The user input containing the title and part of the abstract.
6
+ Expected format:
7
+ "Title: <title>\nAbstract: <abstract>"
8
+ template (dict): A dictionary containing the template for the prompt.
9
+ Expected format:
10
+ {"prompt_input": "<template_string>"}
11
+ The template string should contain placeholders for
12
+ 'instruction' and 'input'.
13
+
14
+ Returns:
15
+ str: A formatted string with the instruction and the input embedded in the template.
16
+ """
17
+
18
+ def abs_completion(usr_input, template, has_predefined_template=False):
19
+ instruction = "Please complete the abstract of a paper."
20
+
21
+ if has_predefined_template:
22
+ res = [
23
+ {"role": "system", "content": instruction},
24
+ {"role": "user", "content": usr_input},
25
+ ]
26
+ else:
27
+ res = template["prompt_input"].format(instruction=instruction, input=usr_input)
28
+
29
+ return res
tasks/citation_sentence.py ADDED
@@ -0,0 +1,25 @@
+ """
+ Generates a prompt for producing the citation sentence with which Paper A cites Paper B in its related work section.
+
+ Args:
+     usr_input (str): A string containing the titles and abstracts of Paper A and Paper B.
+         The format should be:
+             "Title A: <title of paper A>\nAbstract A: <abstract of paper A>\nTitle B: <title of paper B>\nAbstract B: <abstract of paper B>"
+     template (dict): A dictionary containing a template for the prompt input. The key "prompt_input" should map to a string with placeholders for the instruction and input.
+
+ Returns:
+     str or list: A formatted prompt combining the instruction with the provided titles and abstracts.
+ """
+
+ def citation_sentence(usr_input, template, has_predefined_template=False):
+     instruction = "Please generate the citation sentence of how Paper A cites Paper B in its related work section. \n"
+
+     if has_predefined_template:
+         res = [
+             {"role": "system", "content": instruction},
+             {"role": "user", "content": usr_input},
+         ]
+     else:
+         res = template["prompt_input"].format(instruction=instruction, input=usr_input)
+
+     return res
tasks/gen_related_work.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generates the related work section for a given paper.
3
+
4
+ The input
5
+ - The input prompt is a string that contains the information of the paper for which the related work section needs to be generated.
6
+ - The input prompt should be in the following format:
7
+ Title of Paper: <title of the paper>
8
+
9
+ Abstract of Paper: <abstract of the paper>
10
+ The output
11
+ - The output is a string that contains the related work section for the given paper.
12
+ """
13
+
14
+ import torch
15
+ import json
16
+ import networkx as nx
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+ from peft import PeftModel
20
+ from transformers import (AutoModel, AutoTokenizer, AutoModelForCausalLM, pipeline)
21
+ from tqdm import tqdm
22
+ import re
23
+ import pandas as pd
24
+ import os
25
+ from sklearn.metrics.pairwise import cosine_similarity
26
+ from utils.utils import read_yaml_file
27
+ import datetime
28
+
29
+
30
+ class LitFM():
31
+ def __init__(self, graph_path, adapter_path):
32
+ self.graph_name = graph_path.split('.')[0].split('/')[-1] if '/' in graph_path else graph_path.split('.')[0]
33
+ self.batch_size = 32
34
+ self.neigh_num = 4
35
+
36
+ config = read_yaml_file('configs/config.yaml')
37
+ retrieval_graph_path = graph_path
38
+
39
+ self.pretrained_model = config['retriever']['embedder']
40
+
41
+ # define generation model
42
+ model_path = config['inference']["base_model"]
43
+ self.generation_tokenizer = AutoTokenizer.from_pretrained(model_path)
44
+ self.generation_tokenizer.model_max_length = 2048
45
+ if self.generation_tokenizer.pad_token is None:
46
+ self.generation_tokenizer.pad_token = self.generation_tokenizer.eos_token
47
+ self.generation_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
48
+ self.generation_model = PeftModel.from_pretrained(self.generation_model, adapter_path, adapter_name="instruction", torch_dtype=torch.float16)
49
+ self.model_pipeline = None
50
+ if self.generation_tokenizer.chat_template is not None:
51
+ self.model_pipeline = pipeline(
52
+ "text-generation",
53
+ model=model_path,
54
+ model_kwargs={"torch_dtype": torch.bfloat16},
55
+ device_map="auto",
56
+ )
57
+
58
+ # define instruction models
59
+ self.instruction_pipe = pipeline(
60
+ "text-generation",
61
+ model=config["inference"]["gen_related_work_instruct_model"],
62
+ model_kwargs={"torch_dtype": torch.bfloat16},
63
+ device_map="auto",
64
+ )
65
+
66
+ # load graph data for retrieval
67
+ def translate_graph(graph):
68
+ all_nodes = list(graph.nodes())
69
+ raw_id_2_id_dict = {}
70
+ id_2_raw_id_dict = {}
71
+
72
+ num = 0
73
+ for node in all_nodes:
74
+ raw_id_2_id_dict[node] = num
75
+ id_2_raw_id_dict[num] = node
76
+ num += 1
77
+
78
+ return raw_id_2_id_dict, id_2_raw_id_dict
79
+
80
+ whole_graph_data_raw = nx.read_gexf(retrieval_graph_path, node_type=None, relabel=False, version='1.2draft')
81
+ self.whole_graph_raw_id_2_id_dict, self.whole_graph_id_2_raw_id_dict = translate_graph(whole_graph_data_raw)
82
+
83
+ self.whole_graph_id_2_title_abs = dict()
84
+ for paper_id in whole_graph_data_raw.nodes():
85
+ title = whole_graph_data_raw.nodes()[paper_id]['title']
86
+ abstract = whole_graph_data_raw.nodes()[paper_id]['abstract']
87
+ self.whole_graph_id_2_title_abs[self.whole_graph_raw_id_2_id_dict[paper_id]] = [title, abstract]
88
+
89
+ # define prompt template
90
+ template_file_path = 'configs/alpaca.json'
91
+ with open(template_file_path) as fp:
92
+ self.template = json.load(fp)
93
+ self.human_instruction = ['### Input:', '### Response:']
94
+
95
+
96
+ def _generate_retrieval_prompt(self, data_point: dict):
97
+ instruction = "Please select the paper that is more likely to be cited by the paper from the list of candidate papers. Your answer MUST be **only the exact title** of the selected paper without generating ANY other text or section. Your answer MUST belong to the list of candidate papers.\n"
98
+ prompt_input = ""
99
+ prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
100
+ prompt_input = prompt_input + "candidate papers: " + "\n"
101
+ for i in range(len(data_point['nei_titles'])):
102
+ prompt_input = prompt_input + str(i) + '. ' + data_point['nei_titles'][i] + "\n"
103
+
104
+ if self.model_pipeline is not None:
105
+ res = [
106
+ {"role": "system", "content": instruction},
107
+ {"role": "user", "content": prompt_input},
108
+ ]
109
+ else:
110
+ res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
111
+
112
+ return res
113
+
114
+ def _generate_sentence_prompt(self, data_point):
115
+ instruction = "Please generate the citation sentence of how the Paper cites paper B in its related work section."
116
+
117
+ prompt_input = ""
118
+ prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
119
+ prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] != None else 'Unknown') + "\n"
120
+ prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] != None else 'Unknown') + "\n"
121
+
122
+ if self.model_pipeline is not None:
123
+ res = [
124
+ {"role": "system", "content": instruction},
125
+ {"role": "user", "content": prompt_input},
126
+ ]
127
+ else:
128
+ res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
129
+
130
+ return res
131
+
132
+ def _generate_topic_prompt(self, data_point):
133
+ prompt_input = ""
134
+ prompt_input = prompt_input + "Here are the information of the paper: \n"
135
+ prompt_input = prompt_input + data_point['usr_prompt'] + '\n'
136
+ prompt_input = prompt_input + "Directlty give me the topics you select.\n"
137
+
138
+ res = [
139
+ {"role": "system", "content": "I need to write the related work section for this paper. Could you suggest three most relevant topics to discuss in the related work section? Your answer should be strictly one topic after the other line by line with nothing else being generated and no further explanation/information.\n"},
140
+ {"role": "user", "content": prompt_input},
141
+ ]
142
+
143
+ return res
144
+
145
+ def _generate_paragraph_prompt(self, data_point):
146
+ prompt_input = ""
147
+ prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
148
+ prompt_input = prompt_input + "Topic of this paragraph: " + data_point['topic'] + "\n"
149
+ prompt_input = prompt_input + "Papers that should be cited in paragraph: \n"
150
+
151
+ i = data_point['paper_citation_indicator']
152
+ for paper_idx in range(len(data_point['nei_title'])):
153
+ prompt_input = prompt_input + "[" + str(i) + "]. " + data_point['nei_title'][paper_idx][0] + '.' + " Citation sentence of this paper in the paragraph: " + data_point['nei_sentence'][paper_idx] + '\n'
154
+ i += 1
155
+
156
+ prompt_input = prompt_input + "All the above cited papers should be included and each cited paper should be indicated with its index number. Note that you should not include the title of any paper\n"
157
+
158
+ res = [
159
+ {"role": "system", "content": "Please write a paragraph that review the research relationships between this paper and other cited papers.\n"},
160
+ {"role": "user", "content": prompt_input},
161
+ ]
162
+
163
+ return res
164
+
165
+ def _generate_summary_prompt(self, data_point):
166
+ prompt_input = ""
167
+ prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
168
+ prompt_input = prompt_input + "Paragraphs that should be combined: " + "\n"
169
+
170
+ i = 1
171
+ for para in data_point['paragraphs']:
172
+ prompt_input = prompt_input + " Paragraph " + str(i) + ": " + para + '\n'
173
+ i += 1
174
+
175
+ res = [
176
+ {"role": "system", "content": "Please combine the following paragraphs in a cohenrent way that also keeps the citations and make the flow between paragraphs more smoothly\nAdd a sentence at the beginning of each paragraph to clarify its connection to the previous ones. Do not include any other surrounding text and not add a references list at all\n"},
177
+ {"role": "user", "content": prompt_input},
178
+ ]
179
+
180
+ return res
181
+
182
+
183
+ @staticmethod
184
+ def generate_text(prompt, tokenizer, model, temperature, top_p, repetition_penalty, max_new_tokens):
185
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
186
+
187
+ with torch.no_grad():
188
+ output = model.generate(
189
+ **inputs,
190
+ do_sample=True,
191
+ temperature=temperature,
192
+ top_p=top_p,
193
+ repetition_penalty=repetition_penalty,
194
+ max_new_tokens=max_new_tokens,
195
+ pad_token_id=tokenizer.pad_token_id,
196
+ eos_token_id=tokenizer.eos_token_id,
197
+ use_cache=True,
198
+ )
199
+ output_text = tokenizer.decode(output[0], skip_special_tokens=True)
200
+ return output_text
201
+
202
+ def get_llm_response(self, prompt, model_type):
203
+ self.generation_model.set_adapter('instruction')
204
+ if model_type == 'zeroshot':
205
+ raw_output = self.instruction_pipe(
206
+ prompt,
207
+ max_new_tokens=8096,
208
+ temperature=0.9,
209
+ top_p=0.95,
210
+ repetition_penalty=1.15,
211
+ )[0]['generated_text'][-1]
212
+
213
+ elif model_type == 'zeroshot_short':
214
+ raw_output = self.instruction_pipe(
215
+ prompt,
216
+ max_new_tokens=256,
217
+ temperature=0.9,
218
+ top_p=0.95,
219
+ repetition_penalty=1.15,
220
+ )[0]['generated_text'][-1]
221
+
222
+ elif model_type == 'instruction':
223
+ self.generation_model.set_adapter('instruction')
224
+ if self.model_pipeline is not None:
225
+ raw_output = self.model_pipeline(
226
+ prompt,
227
+ temperature=0.9,
228
+ top_p=0.95,
229
+ repetition_penalty=1.15,
230
+ )[0]['generated_text'][-1]
231
+ else:
232
+ raw_output = self.generate_text(
233
+ prompt,
234
+ self.generation_tokenizer,
235
+ self.generation_model,
236
+ temperature=0.9,
237
+ top_p=0.95,
238
+ repetition_penalty=1.15,
239
+ max_new_tokens=256,
240
+ )
241
+
242
+ return raw_output
243
+
244
+ def single_paper_sentence_test(self, usr_prompt, t_title, t_abs):
245
+ datapoint = {'usr_prompt':usr_prompt, 't_title':t_title, 't_abs':t_abs}
246
+ prompt = self._generate_sentence_prompt(datapoint)
247
+ ans = self.get_llm_response(prompt, 'instruction')
248
+ res = ans['content'] if isinstance(ans, dict) else ans.strip().split(self.human_instruction[1])[-1]
249
+ return res
250
+
251
+ def single_paper_retrieval_test(self, usr_prompt, candidates):
252
+ datapoint = {'usr_prompt':usr_prompt, 'nei_titles':list(candidates), 't_title': ''}
253
+ prompt = self._generate_retrieval_prompt(datapoint)
254
+ ans = self.get_llm_response(prompt, 'instruction')
255
+ res = ans['content'] if isinstance(ans, dict) else ans.strip().split(self.human_instruction[1])[-1]
256
+ return res
257
+
258
+ def single_paper_topic_test(self, usr_prompt):
259
+ datapoint = {'usr_prompt': usr_prompt}
260
+ prompt = self._generate_topic_prompt(datapoint)
261
+ ans = self.get_llm_response(prompt, 'zeroshot_short')
262
+ res = ans['content']
263
+ res = res.replace('\n\n', '\n')
264
+ return res
265
+
266
+ def retrieval_for_one_query(self, id_2_title_abs, prompt):
267
+ if os.path.exists(f'datasets/{self.graph_name}_embeddings.parquet'):
268
+ all_query_embs = torch.tensor(np.array(pd.read_parquet(f'datasets/{self.graph_name}_embeddings.parquet')))
269
+ else:
270
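+ # Assumption: candidate papers are embedded with the hard-coded BAAI/bge-large-en-v1.5,
+ # while the query below is embedded with config['retriever']['embedder']; the two
+ # models must match (1024-dim CLS embeddings) for the cosine scores to be meaningful.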
+ tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en-v1.5")
271
+ model = AutoModel.from_pretrained("BAAI/bge-large-en-v1.5").to(device='cuda', dtype=torch.float16)
272
+ model.eval()
273
+
274
+ paper_list = list(id_2_title_abs.keys())
275
+
276
+ all_query_embs = torch.zeros(len(paper_list), 1024)
277
+ i = 0
278
+ batch_size = 200
279
+ candidate_emb_list = []
280
+ pbar = tqdm(total=len(paper_list))
281
+ while i < len(paper_list):
282
+ paper_batch = paper_list[i:i+batch_size]
283
+ paper_text_batch = []
284
+ for paper_id in paper_batch:
285
+ paper_text = id_2_title_abs[paper_id][0] + id_2_title_abs[paper_id][1]  # do not shadow the query `prompt`
286
+ paper_text_batch.append(paper_text)
287
+ inputs = tokenizer(paper_text_batch, return_tensors='pt', padding=True, truncation=True)
288
+
289
+ with torch.no_grad():
290
+ outputs = model(**inputs.to('cuda'))
291
+ candidate_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
292
+ candidate_embeddings = candidate_embeddings.reshape(-1, 1024)
293
+ candidate_emb_list.append(candidate_embeddings)
294
+
295
+ i += len(candidate_embeddings)
296
+ pbar.update(len(candidate_embeddings))
297
+
298
+ all_query_embs = torch.cat(candidate_emb_list, 0)
299
+ pd.DataFrame(all_query_embs.numpy()).to_parquet(f'datasets/{self.graph_name}_embeddings.parquet')
300
+
301
+ # get the embeddings of the prompt
302
+ pretrained_model_name = self.pretrained_model
303
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
304
+ LLM_model = AutoModel.from_pretrained(pretrained_model_name).cuda()
305
+ LLM_model.eval()
306
+
307
+ encoded_input = tokenizer([prompt], padding = True, truncation=True, max_length=512 , return_tensors='pt')
308
+ with torch.no_grad():
309
+ output = LLM_model(**encoded_input.to('cuda'), output_hidden_states=True).hidden_states[-1]
310
+ sentence_embedding = output[:, 0, :]
311
+
312
+ tmp_scores = cosine_similarity(sentence_embedding.to("cpu"), all_query_embs.to("cpu"))[0]
313
+ _, idxs = torch.sort(torch.tensor(tmp_scores), descending=True)
314
+ top_10 = [int(k) for k in idxs[:10]]
315
+
316
+ return [id_2_title_abs[i][0] for i in top_10], [self.whole_graph_id_2_raw_id_dict[i] for i in top_10]
317
+
318
+
319
+ def single_paper_related_work_generation(self, usr_prompt):
320
+ citation_papers = []
321
+ nei_sentence = []
322
+
323
+ # Get topics
324
+ retrieval_query = self.single_paper_topic_test(usr_prompt)
325
+
326
+ # Split topics
327
+ topic_num = 3
328
+ try:
329
+ split_topics = retrieval_query.strip().split('\n')
330
+ if split_topics[0] == '':
331
+ split_topics = split_topics[1:]
332
+ split_topics = split_topics[:topic_num]
333
+ except Exception:
334
+ # Fallback: take the text after the last colon and split on semicolons
+ split_topics = retrieval_query.strip().split(':')[-1]
335
+ split_topics = split_topics.strip().split(';')
336
+ split_topics = split_topics[:topic_num]
337
+ if len(split_topics) > topic_num:
338
+ return ["too many topics", split_topics]
339
+
340
+
341
+ # Get top-5 papers for each topic
342
+ for retrieval_query in split_topics:
343
+ # retrieve papers
344
+ candidate_citation_papers, candidate_raw_ids = self.retrieval_for_one_query(self.whole_graph_id_2_title_abs, retrieval_query)
+ # Map titles to raw ids once: the candidate list is rebuilt with set
+ # operations below, which scrambles its order relative to candidate_raw_ids
+ title_2_raw_id = dict(zip(candidate_citation_papers, candidate_raw_ids))
345
+ topic_specific_citation_papers = []
346
+ # select top-5 papers
347
+ for _ in range(5):
348
+ # picking most likely to be cited paper
349
+ selected_paper = self.single_paper_retrieval_test(usr_prompt, candidate_citation_papers).replace(' \n','').replace('\n','')
350
+
351
+ words = selected_paper.strip().split(' ')
352
+ index = -1
353
+ for w in words:
354
+ try:
355
+ index = int(w)
356
+ except ValueError:
357
+ pass
358
+
359
+ if index != -1 and index < len(candidate_citation_papers):
360
+ paper_title = candidate_citation_papers[index]
361
+ candidate_citation_papers = list(set(candidate_citation_papers) - set([paper_title]))
362
+ topic_specific_citation_papers.append([paper_title, title_2_raw_id[paper_title]])
363
+ else:
364
+ for i, paper_title in enumerate(list(candidate_citation_papers)):
365
+ if paper_title.lower().replace(' ', '') in selected_paper.lower().replace(' ', '') or selected_paper.lower().replace(' ', '') in paper_title.lower().replace(' ', ''):
366
+ candidate_citation_papers = list(set(candidate_citation_papers) - set([paper_title]))
367
+ topic_specific_citation_papers.append([paper_title, title_2_raw_id[paper_title]])
368
+ break
369
+ citation_papers.append(topic_specific_citation_papers)
370
+
371
+
372
+ # Remove empty lists
373
+ citation_papers = [x for x in citation_papers if x != []]
374
+
375
+
376
+ # Generate citation sentences
377
+ for topic_idx in range(len(citation_papers)):
378
+ topic_specific_nei_sentence = []
379
+ for paper_idx in range(len(citation_papers[topic_idx])):
380
+ sentence = self.single_paper_sentence_test(usr_prompt, citation_papers[topic_idx][paper_idx][0], "")
381
+ # Strip any \cite{...} commands from the generated sentence
382
+ sentence = re.sub(r'\\cite\{[^{}]+\}', "", sentence)
383
+ topic_specific_nei_sentence.append(sentence)
384
+ nei_sentence.append(topic_specific_nei_sentence)
385
+
386
+
387
+ # Generate paragraphs
388
+ paragraphs = []
389
+ references = [] # Store references for citation
390
+ paper_citation_indicator = 1 # Indicator for citation paper
391
+ for topic_idx in range(len(citation_papers)):
392
+ datapoint = {'usr_prompt': usr_prompt,
393
+ 'nei_title': citation_papers[topic_idx],
394
+ 'nei_sentence': nei_sentence[topic_idx],
395
+ 'topic': split_topics[topic_idx],
396
+ 'paper_citation_indicator': paper_citation_indicator}
397
+
398
+ prompt = self._generate_paragraph_prompt(datapoint)
399
+ ans = self.get_llm_response(prompt, 'zeroshot')
400
+ res = ans['content']
401
+ paragraphs.append(res)
402
+
403
+ # Store references
404
+ for ref_idx, paper in enumerate(citation_papers[topic_idx]):
405
+ # Extract year and month from raw_id
406
+ raw_id = re.sub(r'[a-zA-Z/]+', '', paper[1])
407
+ year = raw_id[:2]
408
+ year = '19' + year if int(year) > 70 else '20' + year
409
+ month = datetime.date(1900, int(raw_id[2:4]), 1).strftime('%B')
410
+
411
+ references.append(f"[{paper_citation_indicator + ref_idx}] {paper[0]}, arXiv {raw_id}, {month} {year}")
412
+ # Update paper_citation_indicator
413
+ paper_citation_indicator = paper_citation_indicator + len(nei_sentence[topic_idx])
414
+
415
+
416
+ # Generate summary
417
+ datapoint = {'usr_prompt': usr_prompt, 'paragraphs': paragraphs}
418
+ prompt = self._generate_summary_prompt(datapoint)
419
+ ans = self.get_llm_response(prompt, 'zeroshot')
420
+ summary = ans['content']
421
+
422
+ # Append references to summary
423
+ summary_with_references = summary + "\n\n### References\n" + "\n".join(references)
424
+
425
+ return summary_with_references
426
+
427
+
428
+ def gen_related_work(message, graph_path, adapter_path):
429
+ litfm_instance = LitFM(graph_path, adapter_path)
430
+ return litfm_instance.single_paper_related_work_generation(message)
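+
+
+ if __name__ == "__main__":
+     # Minimal usage sketch (illustrative): the graph comes from datasets/ and
+     # the LoRA adapter path is a placeholder for an adapter trained by train.py.
+     related_work = gen_related_work(
+         "Title of Paper: <your title>\n\nAbstract of Paper: <your abstract>",
+         "datasets/large_language_models.gexf",
+         "<path-to-instruction-adapter>",
+     )
+     print(related_work)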
tasks/influential_papers.py ADDED
@@ -0,0 +1,41 @@
1
+ """
2
+ Influential Papers Task
3
+
4
+ This module provides functionality to identify the most influential papers in a citation graph.
5
+
6
+ Functions:
7
+ influential_papers(message, graph):
8
+ Given a message containing an integer K and a citation graph, returns the K most cited papers (by in-degree).
9
+ The result is a formatted string with the title, arXiv id, date, and abstract of each paper.
10
+
11
+ Usage:
12
+ The script reads configuration from a YAML file, loads a citation graph from a GEXF file, and prints the K most influential papers.
13
+ """
14
+
15
+ import datetime
16
+ import re
17
+
18
+ def influential_papers(message, graph):
19
+ # Get integer number from message
20
+ K = int(re.search(r'\d+', message).group())
21
+
22
+ in_degree = dict(graph.in_degree())
23
+ sorted_in_degree = sorted(in_degree.items(), key=lambda x: x[1], reverse=True)
24
+
25
+ most_cited_papers = []
26
+ for i in range(min(K, len(sorted_in_degree))):
27
+ node = sorted_in_degree[i]
28
+ paper = graph.nodes[node[0]]
29
+ most_cited_papers.append(paper)
30
+
31
+ resp = "Here are the most influential papers:\n"
32
+ for i, paper in enumerate(most_cited_papers):
33
+ full_paper_id = paper['label']
34
+ paper_id = re.sub(r'[a-zA-Z/]+', '', full_paper_id)
35
+ year = paper_id[:2]
36
+ year = '19' + year if int(year) > 70 else '20' + year
37
+ month = datetime.date(1900, int(paper_id[2:4]), 1).strftime('%B')
38
+
39
+ resp += f"{i+1}. Title: {paper['title']}, arXiv {full_paper_id}, {month} {year} \nAbstract: {paper['abstract']}\n"
40
+
41
+ return resp
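+
+
+ if __name__ == "__main__":
+     # Minimal usage sketch (illustrative): load one of the citation graphs
+     # shipped under datasets/ and ask for the five most cited papers.
+     import networkx as nx
+     graph = nx.read_gexf("datasets/robotics.gexf")
+     print(influential_papers("Show me the 5 most influential papers", graph))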
tasks/intro_2_abs.py ADDED
@@ -0,0 +1,28 @@
1
+ """
2
+ Generate the abstract of a paper based on its introduction section.
3
+
4
+ Args:
5
+ usr_prompt (str): The user-provided prompt containing the introduction section of the paper.
6
+ template (dict): A dictionary containing the template for generating the abstract.
7
+ context_window (int): The maximum length of the context window for the prompt input.
8
+
9
+ Returns:
10
+ str: The generated abstract based on the introduction section.
11
+ """
12
+
13
+
14
+ def intro_2_abs(usr_prompt, template, context_window, has_predefined_template=False):
15
+ instruction = "Please generate the abstract of paper based on its introduction section."
16
+
17
+ # Reduce it to make it fit
18
+ prompt_input = usr_prompt[:int(context_window*2)]
19
+
20
+ if has_predefined_template:
21
+ res = [
22
+ {"role": "system", "content": instruction},
23
+ {"role": "user", "content": prompt_input},
24
+ ]
25
+ else:
26
+ res = template["prompt_input"].format(instruction=instruction, input=prompt_input)
27
+
28
+ return res
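+
+
+ if __name__ == "__main__":
+     # Minimal usage sketch (illustrative placeholders); 2048 matches the
+     # tokenizer limit used elsewhere in this repo.
+     import json
+     with open("configs/alpaca.json") as fp:
+         template = json.load(fp)
+     usr_prompt = "Introduction: <introduction text of the paper>"
+     print(intro_2_abs(usr_prompt, template, context_window=2048))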
tasks/link_pred.py ADDED
@@ -0,0 +1,23 @@
1
+ """
2
+ Determine if paper A will cite paper B.
3
+
4
+ Args:
5
+ usr_input (str): The user-provided input containing the titles and abstracts of papers A and B.
6
+ template (dict): A dictionary containing the template for generating the link prediction task.
7
+
8
+ Returns:
9
+ str: The generated link prediction task based on the user input.
10
+ """
11
+
12
+ def link_pred(usr_input, template, has_predefined_template=False):
13
+ instruction = "Determine if paper A will cite paper B."
14
+
15
+ if has_predefined_template:
16
+ res = [
17
+ {"role": "system", "content": instruction},
18
+ {"role": "user", "content": usr_input},
19
+ ]
20
+ else:
21
+ res = template["prompt_input"].format(instruction=instruction, input=usr_input)
22
+
23
+ return res
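+
+
+ if __name__ == "__main__":
+     # Minimal usage sketch (illustrative placeholders for titles/abstracts).
+     import json
+     with open("configs/alpaca.json") as fp:
+         template = json.load(fp)
+     usr_input = ("Title of Paper A: <title A>\nAbstract of Paper A: <abstract A>\n"
+                  "Title of Paper B: <title B>\nAbstract of Paper B: <abstract B>\n")
+     print(link_pred(usr_input, template))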
tasks/paper_retrieval.py ADDED
@@ -0,0 +1,21 @@
1
+ """
2
+ Retrieves the most likely paper to be cited by Paper A from a list of candidate papers based on user input.
3
+ Args:
4
+ usr_input (str): A string containing the title and abstract of Paper A followed by the titles and abstracts of candidate papers.
5
+ template (dict): A dictionary containing a template for formatting the prompt input.
6
+ Returns:
7
+ str: A string containing the prompt input for the user.
8
+ """
9
+
10
+ def paper_retrieval(usr_input, template, has_predefined_template=False):
11
+ instruction = "Please select the paper that is more likely to be cited by paper A from candidate papers."
12
+
13
+ if has_predefined_template:
14
+ res = [
15
+ {"role": "system", "content": instruction},
16
+ {"role": "user", "content": usr_input},
17
+ ]
18
+ else:
19
+ res = template["prompt_input"].format(instruction=instruction, input=usr_input)
20
+
21
+ return res
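+
+
+ if __name__ == "__main__":
+     # Minimal usage sketch (illustrative), mirroring the candidate-list format
+     # used by the training prompts in train.py.
+     import json
+     with open("configs/alpaca.json") as fp:
+         template = json.load(fp)
+     usr_input = ("Title of the Paper A: <title A>\n"
+                  "Abstract of the Paper A: <abstract A>\n"
+                  "candidate papers: \n"
+                  "0. <candidate title 0>\n"
+                  "1. <candidate title 1>\n")
+     print(paper_retrieval(usr_input, template))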
train.py ADDED
@@ -0,0 +1,385 @@
1
+ import json
2
+ import torch
3
+ import random
4
+ import transformers
5
+ import networkx as nx
6
+ from tqdm import tqdm
7
+ from peft import (LoraConfig, get_peft_model,
8
+ prepare_model_for_kbit_training)
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
10
+
11
+
12
+
13
+ class QloraTrainer_CS:
14
+ def __init__(self, config: dict, use_predefined_graph=False):
15
+ self.config = config
16
+ self.use_predefined_graph = use_predefined_graph
17
+ self.tokenizer = None
18
+ self.base_model = None
19
+ self.adapter_model = None
20
+ self.merged_model = None
21
+ self.transformer_trainer = None
22
+ self.test_data = None
23
+
24
+ template_file_path = 'configs/alpaca.json'
25
+ with open(template_file_path) as fp:
26
+ self.template = json.load(fp)
27
+
28
+
29
+ def load_base_model(self):
30
+ model_id = self.config['inference']["base_model"]
31
+ print(model_id)
32
+
33
+ # Standard QLoRA 4-bit NF4 quantization config. Note: the original bnb_8bit_*
+ # keyword arguments do not exist on BitsAndBytesConfig and were silently ignored.
+ bnb_config = BitsAndBytesConfig(
34
+ load_in_4bit=True,
35
+ bnb_4bit_use_double_quant=True,
36
+ bnb_4bit_quant_type="nf4",
37
+ bnb_4bit_compute_dtype=torch.bfloat16
38
+ )
39
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
40
+ tokenizer.model_max_length = self.config['training']['tokenizer']["max_length"]
41
+ if not tokenizer.pad_token:
42
+ tokenizer.pad_token = tokenizer.eos_token
43
+ model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, torch_dtype=torch.bfloat16)
44
+ if model.device.type != 'cuda':
45
+ model.to('cuda')
46
+
47
+ model.gradient_checkpointing_enable()
48
+ model = prepare_model_for_kbit_training(model)
49
+
50
+ self.tokenizer = tokenizer
51
+ self.base_model = model
52
+
53
+
54
+ def train(self):
55
+ # Set up lora config or load pre-trained adapter
56
+ lora_config = LoraConfig(
57
+ r=self.config['training']['qlora']['rank'],
58
+ lora_alpha=self.config['training']['qlora']['lora_alpha'],
59
+ target_modules=self.config['training']['qlora']['target_modules'],
60
+ lora_dropout=self.config['training']['qlora']['lora_dropout'],
61
+ bias="none",
62
+ task_type="CAUSAL_LM",
63
+ )
64
+ model = get_peft_model(self.base_model, lora_config)
65
+ self._print_trainable_parameters(model)
66
+
67
+ print("Start data preprocessing")
68
+ train_data = self._process_data_instruction()
69
+
70
+ print('Length of dataset: ', len(train_data))
71
+
72
+ print("Start training")
73
+ self.transformer_trainer = transformers.Trainer(
74
+ model=model,
75
+ train_dataset=train_data,
76
+ args=transformers.TrainingArguments(
77
+ per_device_train_batch_size=self.config["training"]['trainer_args']["per_device_train_batch_size"],
78
+ gradient_accumulation_steps=self.config['model_saving']['index'],
79
+ warmup_steps=self.config["training"]['trainer_args']["warmup_steps"],
80
+ num_train_epochs=self.config["training"]['trainer_args']["num_train_epochs"],
81
+ learning_rate=self.config["training"]['trainer_args']["learning_rate"],
82
+ lr_scheduler_type=self.config["training"]['trainer_args']["lr_scheduler_type"],
83
+ fp16=self.config["training"]['trainer_args']["fp16"],
84
+ logging_steps=self.config["training"]['trainer_args']["logging_steps"],
85
+ output_dir=self.config["training"]['trainer_args']["trainer_output_dir"],
86
+ report_to="wandb",
87
+ save_steps=self.config["training"]['trainer_args']["save_steps"],
88
+ ),
89
+ data_collator=transformers.DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
90
+ )
91
+
92
+ model.config.use_cache = False
93
+
94
+ self.transformer_trainer.train()
95
+
96
+ model_save_path = f"{self.config['model_saving']['model_output_dir']}/{self.config['model_saving']['model_name']}_{self.config['model_saving']['index']}_adapter_test_graph"
97
+ self.transformer_trainer.save_model(model_save_path)
98
+
99
+ self.adapter_model = model
100
+ print(f"Training complete, adapter model saved in {model_save_path}")
101
+
102
+
103
+ def _print_trainable_parameters(self, model):
104
+ """
105
+ Prints the number of trainable parameters in the model.
106
+ """
107
+ trainable_params = 0
108
+ all_param = 0
109
+ for _, param in model.named_parameters():
110
+ all_param += param.numel()
111
+ if param.requires_grad:
112
+ trainable_params += param.numel()
113
+ print(
114
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
115
+ )
116
+
117
+
118
+ def _process_data_instruction(self):
119
+ context_window = self.tokenizer.model_max_length
120
+ if self.use_predefined_graph:
121
+ graph_data = nx.read_gexf('datasets/' + self.config["training"]["predefined_graph_path"], node_type=None, relabel=False, version='1.2draft')
122
+ else:
123
+ graph_path = self.config['data_downloading']['download_directory'] + 'description/' + self.config['data_downloading']['gexf_file']
124
+ graph_data = nx.read_gexf(graph_path, node_type=None, relabel=False, version='1.2draft')
125
+ raw_graph = graph_data
126
+
127
+ test_set_size = len(graph_data.nodes()) // 10
128
+
129
+ all_test_nodes = set(list(graph_data.nodes())[:test_set_size])
130
+ all_train_nodes = set(list(graph_data.nodes())[test_set_size:])
131
+
132
+ raw_id_2_title_abs = dict()
133
+ for paper_id in list(graph_data.nodes())[test_set_size:]:
134
+ title = graph_data.nodes()[paper_id]['title']
135
+ abstract = graph_data.nodes()[paper_id]['abstract']
136
+ raw_id_2_title_abs[paper_id] = [title, abstract]
137
+
138
+ raw_id_2_intro = dict()
139
+ for paper_id in list(graph_data.nodes())[test_set_size:]:
140
+ if graph_data.nodes[paper_id]['introduction'] != '':
141
+ intro = graph_data.nodes[paper_id]['introduction']
142
+ raw_id_2_intro[paper_id] = intro
143
+
144
+ raw_id_pair_2_sentence = dict()
145
+ for edge in list(graph_data.edges()):
146
+ sentence = graph_data.edges()[edge]['sentence']
147
+ raw_id_pair_2_sentence[edge] = sentence
148
+
149
+
150
+ test_data = []
151
+ edge_list = []
152
+ for edge in list(raw_graph.edges()):
153
+ src, tar = edge
154
+ if src not in all_test_nodes and tar not in all_test_nodes:
155
+ edge_list.append(edge)
156
+ else:
157
+ test_data.append(edge)
158
+ train_num = len(edge_list)
159
+
160
+ data_LP = []
161
+ data_abstract_2_title = []
162
+ data_paper_retrieval = []
163
+ data_citation_sentence = []
164
+ data_abs_completion = []
165
+ data_title_2_abs = []
166
+ data_intro_2_abs = []
167
+
168
+
169
+ for sample in tqdm(random.sample(edge_list, train_num)):
170
+ source, target = sample[0], sample[1]
171
+ source_title, source_abs = raw_id_2_title_abs[source]
172
+ target_title, target_abs = raw_id_2_title_abs[target]
173
+ # LP prompt
174
+ rand_ind = random.choice(list(raw_id_2_title_abs.keys()))
175
+ neg_title, neg_abs = raw_id_2_title_abs[rand_ind]
176
+ data_LP.append({'s_title':source_title, 's_abs':source_abs, 't_title':target_title, 't_abs':target_abs, 'label':'yes'})
177
+ data_LP.append({'s_title':source_title, 's_abs':source_abs, 't_title':neg_title, 't_abs':neg_abs, 'label':'no'})
178
+
179
+ for sample in tqdm(random.sample(edge_list, train_num)):
180
+ source, target = sample[0], sample[1]
181
+ source_title, source_abs = raw_id_2_title_abs[source]
182
+ target_title, target_abs = raw_id_2_title_abs[target]
183
+ # abs_2_title prompt
184
+ data_abstract_2_title.append({'title':source_title, 'abs':source_abs})
185
+ data_abstract_2_title.append({'title':target_title, 'abs':target_abs})
186
+
187
+ for sample in tqdm(random.sample(edge_list, train_num)):
188
+ source, target = sample[0], sample[1]
189
+ source_title, source_abs = raw_id_2_title_abs[source]
190
+ target_title, target_abs = raw_id_2_title_abs[target]
191
+ # paper_retrieval prompt
192
+ neighbors = list(nx.all_neighbors(raw_graph, source))
193
+ sample_node_list = list(all_train_nodes - set(neighbors) - set([source]) - set([target]))
194
+ sampled_neg_nodes = random.sample(sample_node_list, 5) + [target]
195
+ random.shuffle(sampled_neg_nodes)
196
+ data_paper_retrieval.append({'title':source_title, 'abs':source_abs, 'sample_title': [raw_id_2_title_abs[node][0] for node in sampled_neg_nodes], 'right_title':target_title})
197
+
198
+ for sample in tqdm(random.sample(edge_list, train_num)):
199
+ source, target = sample[0], sample[1]
200
+ source_title, source_abs = raw_id_2_title_abs[source]
201
+ target_title, target_abs = raw_id_2_title_abs[target]
202
+ # citation_sentence prompt
203
+ citation_sentence = raw_id_pair_2_sentence[(source, target)] if (source, target) in raw_id_pair_2_sentence.keys() else raw_id_pair_2_sentence[(target, source)]
204
+ data_citation_sentence.append({'s_title':source_title, 's_abs':source_abs, 't_title':target_title, 't_abs':target_abs, 'sentence': citation_sentence})
205
+
206
+ for sample in tqdm(random.sample(edge_list, train_num)):
207
+ source, target = sample[0], sample[1]
208
+ source_title, source_abs = raw_id_2_title_abs[source]
209
+ target_title, target_abs = raw_id_2_title_abs[target]
210
+ # abs_complete prompt
211
+ data_abs_completion.append({'title':source_title, 'abs':source_abs})
212
+ data_abs_completion.append({'title':target_title, 'abs':target_abs})
213
+
214
+ for sample in tqdm(random.sample(edge_list, train_num)):
215
+ source, target = sample[0], sample[1]
216
+ source_title, source_abs = raw_id_2_title_abs[source]
217
+ target_title, target_abs = raw_id_2_title_abs[target]
218
+ # title_2_abs prompt
219
+ data_title_2_abs.append({'title':source_title, 'right_abs':source_abs})
220
+ data_title_2_abs.append({'title':target_title, 'right_abs':target_abs})
221
+
222
+ for sample in tqdm(random.sample(edge_list, train_num)):
223
+ source, target = sample[0], sample[1]
224
+ if source in raw_id_2_intro:
225
+ source_intro = raw_id_2_intro[source]
226
+ _, source_abs = raw_id_2_title_abs[source]
227
+ data_intro_2_abs.append({'intro':source_intro, 'abs':source_abs})
228
+ if target in raw_id_2_intro:
229
+ target_intro = raw_id_2_intro[target]
230
+ _, target_abs = raw_id_2_title_abs[target]
231
+ data_intro_2_abs.append({'intro':target_intro, 'abs':target_abs})
232
+
233
+ data_prompt = []
234
+ data_prompt += [self._generate_paper_retrieval_prompt(data_point) for data_point in data_paper_retrieval]
235
+ data_prompt += [self._generate_LP_prompt(data_point) for data_point in data_LP]
236
+ data_prompt += [self._generate_abstract_2_title_prompt(data_point) for data_point in data_abstract_2_title]
237
+ data_prompt += [self._generate_citation_sentence_prompt(data_point) for data_point in data_citation_sentence]
238
+ data_prompt += [self._generate_abstract_completion_prompt(data_point) for data_point in data_abs_completion]
239
+ data_prompt += [self._generate_title_2_abstract_prompt(data_point) for data_point in data_title_2_abs]
240
+ data_prompt += [self._generate_intro_2_abstract_prompt(data_point, context_window) for data_point in data_intro_2_abs]
241
+
242
+ print("Total prompts:", len(data_prompt))
243
+ random.shuffle(data_prompt)
244
+ if self.tokenizer.chat_template is None:
245
+ data_tokenized = [self.tokenizer(sample, max_length=context_window, truncation=True) for sample in tqdm(data_prompt)]
246
+ else:
247
+ # Render the chat template to a string first, then tokenize so truncation applies
+ data_tokenized = [self.tokenizer(self.tokenizer.apply_chat_template(sample, tokenize=False), max_length=context_window, truncation=True) for sample in tqdm(data_prompt)]
248
+
249
+ return data_tokenized
250
+
251
+
252
+ def _generate_LP_prompt(self, data_point: dict):
253
+ instruction = "Determine if paper A will cite paper B."
254
+
255
+ prompt_input = ""
256
+ prompt_input = prompt_input + "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] != None else 'Unknown') + "\n"
257
+ prompt_input = prompt_input + "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] != None else 'Unknown') + "\n"
258
+ prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] != None else 'Unknown') + "\n"
259
+ prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] != None else 'Unknown') + "\n"
260
+
261
+ if self.tokenizer.chat_template is None:
262
+ res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
263
+ res = f"{res}{data_point['label']}"
264
+ else:
265
+ res = [
266
+ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
267
+ {"role": "assistant", "content": data_point['label']}
268
+ ]
269
+
270
+ return res
271
+
272
+ def _generate_abstract_2_title_prompt(self, data_point: dict):
273
+ instruction = "Please generate the title of paper based on its abstract."
274
+
275
+ prompt_input = ""
276
+ prompt_input = prompt_input + "Abstract: " + data_point['abs'] + "\n"
277
+
278
+ if self.tokenizer.chat_template is None:
279
+ res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
280
+ res = f"{res}{data_point['title']}"
281
+ else:
282
+ res = [
283
+ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
284
+ {"role": "assistant", "content": data_point['title']}
285
+ ]
286
+
287
+ return res
288
+
289
+ def _generate_paper_retrieval_prompt(self, data_point: dict):
290
+ instruction = "Please select the paper that is more likely to be cited by paper A from candidate papers."
291
+
292
+ prompt_input = ""
293
+ prompt_input = prompt_input + "Title of the Paper A: " + data_point['title'] + "\n"
294
+ prompt_input = prompt_input + "Abstract of the Paper A: " + data_point['abs'] + "\n"
295
+ prompt_input = prompt_input + "candidate papers: " + "\n"
296
+ for i in range(len(data_point['sample_title'])):
297
+ prompt_input = prompt_input + str(i) + '. ' + data_point['sample_title'][i] + "\n"
298
+
299
+ if self.tokenizer.chat_template is None:
300
+ res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
301
+ res = f"{res}{data_point['right_title']}"
302
+ else:
303
+ res = [
304
+ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
305
+ {"role": "assistant", "content": data_point['right_title']}
306
+ ]
307
+
308
+ return res
309
+
310
+ def _generate_citation_sentence_prompt(self, data_point: dict):
311
+ instruction = "Please generate the citation sentence of how Paper A cites paper B in its related work section."
312
+
313
+ prompt_input = ""
314
+ prompt_input = prompt_input + "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] != None else 'Unknown') + "\n"
315
+ prompt_input = prompt_input + "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] != None else 'Unknown') + "\n"
316
+ prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] != None else 'Unknown') + "\n"
317
+ prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] != None else 'Unknown') + "\n"
318
+
319
+ if self.tokenizer.chat_template is None:
320
+ res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
321
+ res = f"{res}{data_point['sentence']}"
322
+ else:
323
+ res = [
324
+ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
325
+ {"role": "assistant", "content": data_point['sentence']}
326
+ ]
327
+
328
+ return res
329
+
330
+ def _generate_abstract_completion_prompt(self, data_point: dict):
331
+ instruction = "Please complete the abstract of a paper."
332
+
333
+ prompt_input = ""
334
+ prompt_input = prompt_input + "Title: " + data_point['title'] if data_point['title'] != None else 'Unknown' + "\n"
335
+
336
+ split_abs = data_point['abs'][: int(0.3*len(data_point['abs']))]
337
+ prompt_input = prompt_input + "Part of abstract: " + split_abs + "\n"
338
+
339
+ if self.tokenizer.chat_template is None:
340
+ res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
341
+ res = f"{res}{data_point['abs']}"
342
+ else:
343
+ res = [
344
+ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
345
+ {"role": "assistant", "content": data_point['abs']}
346
+ ]
347
+
348
+ return res
349
+
350
+ def _generate_title_2_abstract_prompt(self, data_point: dict):
351
+ instruction = "Please generate the abstract of paper based on its title."
352
+
353
+ prompt_input = ""
354
+ prompt_input = prompt_input + "Title: " + data_point['title'] + "\n"
355
+
356
+ if self.tokenizer.chat_template is None:
357
+ res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
358
+ res = f"{res}{data_point['right_abs']}"
359
+ else:
360
+ res = [
361
+ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
362
+ {"role": "assistant", "content": data_point['right_abs']}
363
+ ]
364
+
365
+ return res
366
+
367
+ def _generate_intro_2_abstract_prompt(self, data_point: dict, context_window):
368
+ instruction = "Please generate the abstract of paper based on its introduction section."
369
+
370
+ prompt_input = ""
371
+ prompt_input = prompt_input + "Introduction: " + data_point['intro'] + "\n"
372
+
373
+ # Reduce it to make it fit
374
+ prompt_input = prompt_input[:int(context_window*2)]
375
+
376
+ if self.tokenizer.chat_template is None:
377
+ res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
378
+ res = f"{res}{data_point['abs']}"
379
+ else:
380
+ res = [
381
+ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
382
+ {"role": "assistant", "content": data_point['abs']}
383
+ ]
384
+
385
+ return res
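+
+
+ if __name__ == "__main__":
+     # Minimal usage sketch (assumes configs/config.yaml provides the keys read
+     # above, e.g. inference.base_model and the training/* sections).
+     from utils.utils import read_yaml_file
+     config = read_yaml_file("configs/config.yaml")
+     trainer = QloraTrainer_CS(config, use_predefined_graph=True)
+     trainer.load_base_model()
+     trainer.train()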
utils/de-macro.py ADDED
@@ -0,0 +1,1110 @@
1
+ #!/usr/bin/python -O
2
+
3
+ r"""
4
+ Copyright 2005-2020 Peter Gacs
5
+ Licensed under the Academic Free Licence version 2.1
6
+
7
+ DE-MACRO
8
+
9
+ Version 1.4.1 - A small typo corrected.
10
+
11
+ Version 1.4 - Luca Citi made it python2.7 and python3 compatible.
12
+ Peter Gacs improved the parsing of \input{<filename>},
13
+ and made @ a letter in the style files.
14
+ Version 1.3 - this version is much more conservative about deleting
15
+ comments and inserting or deleting blank space: tries to
16
+ leave in all comments, adds space only when necessary, and
17
+ tries not to delete space in the main text.
18
+ The motivating comments came from Daniel Webb.
19
+ Version 1.2 - a syntactical bug corrected, thanks Brian de Alwis!
20
+
21
+
22
+ PURPOSE
23
+
24
+ This program can eliminate most private macros from a LaTeX file.
25
+ Applications:
26
+ - your publisher has difficulty dealing with many private macros
27
+ - you cooperate with colleagues who do not understand your macros
28
+ - preprocessing before a system like latex2html, which is somewhat
29
+ unpredictable with private macros.
30
+
31
+ It cannot be used to eliminate more complex macros that rely on
32
+ more programming-like constructs in style files. In particular, it will
33
+ not replace style files that have options.
34
+
35
+ USAGE
36
+
37
+ de-macro [--defs <defs-db>] <tex-file-1>[.tex] [<tex-file-2>[.tex] ...]
38
+
39
+ Simplest example: de-macro testament
40
+
41
+ (As you see, the <> is used only in the notation of this documentation,
42
+ you should not type it.)
43
+
44
+ If <tex-file-i> contains a command \usepackage{<defs-file>-private}
45
+ then the file <defs-file>-private.sty will be read, and its macros will be
46
+ replaced in <tex-file-i> with their definitions.
47
+ The result is in <tex-file-i>-clean.tex.
48
+
49
+ Only newcommand, renewcommand, newenvironment, and renewenvironment are
50
+ understood (it does not matter, whether you write new or renew).
51
+ These can be nested but do not be too clever, since I do not
52
+ guarantee the same expansion order as in TeX.
53
+
54
+ FILES
55
+
56
+ <tex-file-1>.db
57
+ <tex-file>-clean.tex
58
+ <defs-file>-private.sty
59
+
60
+ For speed, a macro database file called <defs-file>.db is created.
61
+ If such a file exists already then it is used.
62
+ If <defs-file>-private.sty is older than <tex-file-1>.db then it will not
63
+ be used.
64
+
65
+ It is possible to specify another database filename via --defs <defs-db>.
66
+ Then <defs-db>.db will be used.
67
+
68
+ For each <tex-file-i>, a file <tex-file-i>-clean.tex will be produced.
69
+ If <tex-file-i>-clean.tex is newer than <tex-file-i>.tex then it stays.
70
+
71
+ INPUT COMMAND
72
+
73
+ If a tex file contains a command \input{<tex-file-j>} or \input <tex-file-j>
74
+ then <tex-file-j>.tex is processed recursively, and <tex-file-j>-clean.tex
75
+ will be inserted into the final output.
76
+ For speed, if <tex-file-j>-clean.tex is newer than <tex-file-j>.tex
77
+ then <tex-file-j>.tex will not be reprocessed.
78
+
79
+ The dependency checking is not sophisticated, so if you rewrite some macros
80
+ then remove all *-clean.tex files!
81
+
82
+ """
83
+
84
+ import sys, os, re, shelve
85
+
86
+ # Utilities
87
+
88
+ class No_detail:
89
+ strerror = ""
90
+
91
+ no_detail = No_detail()
92
+
93
+
94
+ class Error(Exception):
95
+ """Base class for exceptions in this module."""
96
+ pass
97
+
98
+ class Empty_text_error(Error):
99
+ """Exception raised for errors in the input.
100
+
101
+ Attributes:
102
+ data -- data that was found empty
103
+ message
104
+ """
105
+
106
+ def __init__(self, data, message):
107
+ self.data = data
108
+ self.message = message
109
+
110
+ def warn(error_message, detail = no_detail):
111
+ sys.stderr.write(error_message + "\n")
112
+ if no_detail != detail:
113
+ sys.stderr.write(detail.strerror + "\n")
114
+
115
+ def die(error_message, detail = no_detail):
116
+ warn(error_message, detail = detail)
117
+ sys.exit(1)
118
+
119
+ def getopt_map(one_letter_opts, long_optlist):
120
+ "Turns long options into an option map, using getopt."
121
+ import getopt
122
+ optlist, args = getopt.getopt(sys.argv[1:],
123
+ one_letter_opts, long_optlist)
124
+ opt_map = {}
125
+ for pair in optlist: opt_map[pair[0]] = pair[1] or 1
126
+ return opt_map, args
127
+
128
+ def newer(file1, file2):
129
+
130
+ if not os.path.isfile(file1):
131
+ return False
132
+
133
+ try:
134
+ stat_return = os.lstat(file1)
135
+ except OSError as detail:
136
+ die("lstat " + file1 + " failed:", detail)
137
+ time1 = stat_return.st_mtime
138
+
139
+ try:
140
+ stat_return = os.lstat(file2)
141
+ except OSError as detail:
142
+ die("lstat " + file2 + " failed:", detail)
143
+ time2 = stat_return.st_mtime
144
+
145
+ return time1 > time2
146
+
147
+ def cut_extension(filename, ext):
148
+ """
149
+ If filename has extension ext (including the possible dot),
150
+ it will be cut off.
151
+ """
152
+ file = filename
153
+ index = filename.rfind(ext)
154
+ if 0 <= index and len(file)-len(ext) == index:
155
+ file = file[:index]
156
+ return file
157
+
158
+
159
+ class Stream:
160
+ data = None
161
+ pos = None
162
+ item = None
163
+
164
+ def legal(self):
165
+ return 0 <= self.pos and self.pos < len(self.data)
166
+
167
+ def uplegal(self):
168
+ return self.pos < len(self.data)
169
+
170
+ def __init__(self, data_v = None):
171
+ self.data = data_v
172
+ if self.data:
173
+ self.pos = 0
174
+ self.item = self.data[self.pos]
175
+
176
+ def next(self):
177
+ self.pos += 1
178
+ if self.pos < len(self.data):
179
+ self.item = self.data[self.pos]
180
+ return self.item
181
+
182
+ def reset(self):
183
+ if self.data and 0 < len(self.data):
184
+ self.pos = 0
185
+ self.item = self.data[0]
186
+ return self.item
187
+
188
+
189
+ # Basic classes
190
+
191
+ blank_re = re.compile(r"\s")
192
+ blanked_filename_re = re.compile(r"^\s+(\w*)\s+")
193
+ braced_filename_re = re.compile(r"^\s*{\s*(\w*)\s*}")
194
+ blank_or_rbrace_re = re.compile(r"[\s}]")
195
+ pos_digit_re = re.compile(r"[1-9]")
196
+
197
+ def isletter(c, isatletter=False):
198
+ if "@" == c:
199
+ return isatletter
200
+ else:
201
+ return c.isalpha()
202
+
203
+ class Token:
204
+ """Type 0 means ordinary character, types 1,2 mean escape sequence
205
+ (without the \ ), type 3 means comment.
206
+ """
207
+ simple_ty = 0
208
+ esc_symb_ty = 1
209
+ esc_str_ty = 2
210
+ comment_ty = 3
211
+
212
+ type = simple_ty
213
+ val = " "
214
+
215
+ def __init__(self, type_v=simple_ty, val_v=" "):
216
+ self.type = type_v
217
+ self.val = val_v
218
+
219
+ def show(self):
220
+ out = ""
221
+ if simple_ty == self.type or comment_ty == self.type:
222
+ out = self.val
223
+ else:
224
+ out = "\\" + self.val
225
+ return out
226
+
227
+
228
+ # Constants
229
+
230
+ g_token = Token(0," ") # generic token
231
+ simple_ty = g_token.simple_ty
232
+ comment_ty = g_token.comment_ty
233
+ esc_symb_ty = g_token.esc_symb_ty
234
+ esc_str_ty = g_token.esc_str_ty
235
+
236
+
237
+ def detokenize(text, isatletter=False):
238
+ """
239
+ Input is a list of tokens.
240
+ Output is a string.
241
+ """
242
+ out = ""
243
+ if 0 == len(text):
244
+ return ""
245
+ pos = 0
246
+ out += text[pos].show()
247
+ pos += 1
248
+ while pos < len(text):
249
+ previtem = text[pos-1]
250
+ item = text[pos]
251
+ """Insert a separating space after an escape sequence if it is a
252
+ string and is followed by a letter."""
253
+ if (esc_str_ty == previtem.type
254
+ and simple_ty == item.type and isletter(item.val[0], isatletter)):
255
+ out += " "
256
+ out += item.show()
257
+ pos += 1
258
+ return out
259
+
260
+
261
+ def strip_comments(text):
262
+ """
263
+ Input is a list of tokens.
264
+ Output is the same list except the comment tokens.
265
+ """
266
+ out = []
267
+ for token in text:
268
+ if not comment_ty == token.type:
269
+ out.append(token)
270
+ return out
271
+
272
+ class Group:
273
+ """type 0 means a token, type 1 means contents of a group within {}
274
+ """
275
+ token_ty = 0
276
+ group_ty = 1
277
+ type = token_ty
278
+ val = [] # Value is a token list.
279
+
280
+ def __init__(self, type_v, val_v):
281
+ self.type = type_v
282
+ self.val = val_v
283
+
284
+ def show(self):
285
+ if token_ty == self.type:
286
+ return self.val.show()
287
+ else:
288
+ return "{%s}" % detokenize(self.val)
289
+
290
+ # Constants
291
+
292
+ g_group = Group(0, [])
293
+ token_ty = g_group.token_ty
294
+ group_ty = g_group.group_ty
295
+
296
+
297
+ def tokenize(in_str, isatletter=False):
298
+ """Returns a list of tokens.
299
+ """
300
+ text = []
301
+ cs = Char_stream(in_str)
302
+ cs.reset()
303
+ if not cs.legal():
304
+ raise Error("No string to tokenize.")
305
+ while cs.uplegal():
306
+ if "%" == cs.item:
307
+ comment = cs.scan_comment_token()
308
+ text.append(Token(comment_ty, comment))
309
+ elif "\\" != cs.item:
310
+ text.append(Token(simple_ty, cs.item))
311
+ cs.next()
312
+ else:
313
+ cs.next()
314
+ name = cs.scan_escape_token(isatletter)
315
+ if isletter(name[0], isatletter):
316
+ token = Token(esc_str_ty, name)
317
+ else:
318
+ token = Token(esc_symb_ty, name)
319
+ text.append(token)
320
+ if "makeatletter" == name:
321
+ isatletter=True
322
+ elif "makeatother" == name:
323
+ isatletter=False
324
+ return text
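+
+ # Illustrative note (not part of the original script): tokenize(r"\emph{hi} % x")
+ # yields one esc_str_ty token "emph", simple tokens "{", "h", "i", "}", " ",
+ # and a single comment_ty token for "% x" (comments run to the end of the line).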
325
+
326
+
327
+ class Command_def:
328
+ name = "1"
329
+ numargs = 0
330
+ body= ""
331
+
332
+ def __init__(self, name_v, numargs_v, body_v):
333
+ self.name = name_v
334
+ self.numargs = numargs_v
335
+ self.body = body_v
336
+
337
+ def show(self):
338
+ out = "\\newcommand{\\%s}" % (self.name)
339
+ if 0 < self.numargs:
340
+ out += "[%d]" % self.numargs
341
+ out += "{%s}" % detokenize(self.body)
342
+ return out
343
+
344
+
345
+ class Env_def:
346
+ name = "1"
347
+ numargs = 0
348
+ begin = ""
349
+ end = ""
350
+
351
+ def __init__(self, name_v, numargs_v, begin_v, end_v):
352
+ self.name = name_v
353
+ self.numargs = numargs_v
354
+ self.begin = begin_v
355
+ self.end = end_v
356
+
357
+ def show(self):
358
+ out = "\\newenvironment{%s}" % self.name
359
+ if 0 < self.numargs:
360
+ out += "[%d]" % self.numargs
361
+ out += "{%s}" % detokenize(self.begin)
362
+ out += "{%s}" % detokenize(self.end)
363
+ return out
364
+
365
+
366
+ class Command_instance:
367
+ name = "1"
368
+ args = []
369
+
370
+ def __init__(self, name_v, args_v):
371
+ self.name = name_v
372
+ self.args = args_v
373
+
374
+ def show(self):
375
+ out = "\\"+self.name
376
+ for arg in self.args:
377
+ out += "{%s}" % detokenize(arg)
378
+ return out
379
+
380
+
381
+ class Env_instance:
382
+ name = "1"
383
+ args = []
384
+
385
+ def __init__(self, name_v, args_v, body_v):
386
+ self.name = name_v
387
+ self.args = args_v
388
+ self.body = body_v
389
+
390
+ def show(self):
391
+ out = "\\begin{%s}" % self.name
392
+ for arg in self.args:
393
+ out += "{%s}" % detokenize(arg)
394
+ out += detokenize(self.body)
395
+ out += "\\end{%s}" % self.name
396
+ return out
397
+
398
+ class Char_stream(Stream):
399
+
400
+ def scan_escape_token(self, isatletter=False):
401
+ """
402
+ Starts after the escape sign, assumes that it is scanning a symbol.
403
+ Returns a token-string.
404
+ """
405
+ out = self.item # Continue only if this is a letter.
406
+ item = self.next()
407
+ if isletter(out, isatletter):
408
+ while self.uplegal() and isletter(item, isatletter):
409
+ out += item
410
+ item = self.next()
411
+ return out
412
+
413
+ def scan_comment_token(self):
414
+ """
415
+ Starts at the comment sign %, assumes that it is scanning a comment.
416
+ Returns the whole comment string,
417
+ including the % and all empty space after it.
418
+ """
419
+ comment = ""
420
+ while self.uplegal() and "\n" != self.item:
421
+ comment += self.item
422
+ self.next()
423
+ while self.uplegal() and blank_re.match(self.item):
424
+ comment += self.item
425
+ self.next()
426
+ return comment
427
+
428
+ def scan_input_filename(self):
429
+ """We have just read an \input token. The next group or word will be
430
+ interpreted as a filename (possibly without .tex). Filenames should not begin with spaces.
431
+ Return the filename.
432
+ """
433
+ item = self.item
434
+ file = ""
435
+ while self.uplegal() and blank_re.match(self.item):
436
+ item = self.next()
437
+ if "{" == item:
438
+ item = self.next()
439
+ while self.uplegal() and not "}" == item:
440
+ file += item
441
+ item = self.next()
442
+ self.next()
443
+ else:
444
+ while self.uplegal() and not blank_re.match(item):
445
+ file += item
446
+ item = self.next()
447
+ return file
448
+
449
+ def scan_package_filenames(self):
450
+ r"""We just read a \usepackage token. The next group will be
451
+ interpreted as a list of filenames (without .sty) separated by commas.
452
+ Return the list.
453
+ """
454
+ item = self.item
455
+ while self.uplegal() and blank_re.match(item):
456
+ item = self.next()
457
+ file = ""
458
+ if not "{" == item:
459
+ raise Error("\\usepackage not followed by brace.")
460
+ item = self.next()
461
+ while self.uplegal() and not blank_or_rbrace_re.match(item):
462
+ file += item
463
+ item = self.next()
464
+ self.next()
465
+ return file.split(",")
466
+
467
+
468
+ class Tex_stream(Stream):
469
+
470
+ defs = ({}, {})
471
+ defs_db = "x"
472
+ defs_db_file = "x.db"
473
+ debug = False
474
+
475
+ def smart_tokenize(self, in_str, handle_inputs=False, isatletter=False):
476
+ """Returns a list of tokens.
477
+ It may interpret and carry out all \input commands.
478
+ """
479
+ self.data = []
480
+ text = self.data
481
+ cs = Char_stream(in_str)
482
+ cs.reset()
483
+ if not cs.legal():
484
+ raise Error("No string to tokenize.")
485
+ while cs.uplegal():
486
+ if "%" == cs.item:
487
+ comment = cs.scan_comment_token()
488
+ text.append(Token(comment_ty, comment))
489
+ elif "\\" != cs.item:
490
+ text.append(Token(simple_ty, cs.item))
491
+ cs.next()
492
+ else:
493
+ cs.next()
494
+ name = cs.scan_escape_token(isatletter)
495
+ if "input" == name and handle_inputs:
496
+ file = cs.scan_input_filename()
497
+ to_add = self.process_if_newer(file)
498
+ text.extend(to_add)
499
+ elif "usepackage" == name:
500
+ while cs.uplegal() and blank_re.match(cs.item):
501
+ cs.next()
502
+ if "[" == cs.item: # Packages with options will not be processed.
503
+ text.extend([Token(esc_str_ty, "usepackage"),
504
+ Token(simple_ty, "[")])
505
+ cs.next()
506
+ continue
507
+ files = cs.scan_package_filenames()
508
+ i = 0
509
+ while i < len(files): # process private packages
510
+ file = files[i]
511
+ p = file.rfind("-private")
512
+ if p < 0 or not len(file) - len("-private") == p:
513
+ i += 1
514
+ continue
515
+ defs_db_file = file+".db"
516
+ self.add_defs(file)
517
+ del files[i:(i+1)]
518
+ if files: # non-private packages left
519
+ group_content = ",".join(files)
520
+ to_add_str = "\\usepackage{%s}" % (group_content)
521
+ to_add = tokenize(to_add_str,isatletter)
522
+ text.extend(to_add)
523
+ else:
524
+ if isletter(name[0], isatletter):
525
+ token = Token(esc_str_ty, name)
526
+ else:
527
+ token = Token(esc_symb_ty, name)
528
+ text.append(token)
529
+ if "makeatletter" == name:
530
+ isatletter=True
531
+ elif "makeatother" == name:
532
+ isatletter=False
533
+ self.reset()
534
+ return self.data
535
+
536
+ def smart_detokenize(self,isatletter=False):
537
+ """
538
+ Output is a string.
539
+ If the list contains an \input{file} then the content of file
540
+ file-clean.tex replaces it in the output.
541
+ """
542
+ self.reset()
543
+ if not self.legal():
544
+ return ""
545
+ out = ""
546
+ previtem = None
547
+ while self.uplegal():
548
+ item = self.item
549
+ """Insert a separating space after an escape sequence if it is a
550
+ string and is followed by a letter."""
551
+ if (None != previtem and esc_str_ty == previtem.type
552
+ and simple_ty == item.type and isletter(item.val[0], isatletter)):
553
+ out += " "
554
+ previtem = item
555
+ if not (esc_str_ty == item.type and "input" == item.val):
556
+ out += item.show()
557
+ self.next()
558
+ else:
559
+ self.next()
560
+ group = self.scan_group()
561
+ file = detokenize(group.val)
562
+ clean_file = "%s-clean.tex" % (file)
563
+ print("Reading file %s" % (clean_file))
564
+ fp = open(clean_file,"r")
565
+ content = fp.read()
566
+ fp.close()
567
+ out += content
568
+ return out
569
+
570
+ # Basic tex scanning
571
+
572
+ def skip_blank_tokens(self): # we also skip comment tokens.
573
+ item = self.item
574
+ while (self.uplegal() and
575
+ (comment_ty == item.type or
576
+ (simple_ty == item.type and blank_re.match(item.val)))):
577
+ item = self.next()
578
+ return item
579
+
580
+ def scan_group(self):
581
+ """Returns group.
582
+ """
583
+ if not self.legal():
584
+ raise Error("No group to scan.")
585
+ item = self.item
586
+ if not (simple_ty == item.type and "{" == item.val):
587
+ return Group(token_ty, [self.item])
588
+ count = 1
589
+ group = []
590
+ item = self.next()
591
+ while count and self.uplegal():
592
+ if simple_ty == item.type:
593
+ if "{" == item.val:
594
+ count += 1
595
+ elif "}" == item.val:
596
+ count -= 1
597
+ if count != 0:
598
+ group.append(item)
599
+ item = self.next()
600
+ return Group(group_ty, group)
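+
+ # Example: on the token stream of "{a{b}c}d", scan_group() consumes up
+ # to the balanced closing brace and returns a group_ty Group holding
+ # the tokens of "a{b}c"; if the stream does not start with "{", the
+ # current token is wrapped in a token_ty Group without advancing.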
601
+
602
+ # Command and environment definitions
603
+
604
+ def scan_command_name(self):
605
+ """Returns name.
606
+ """
607
+ if not self.legal():
608
+ raise Error("No command name to scan.")
609
+ item = self.item
610
+ name = ""
611
+ if item.type in [esc_symb_ty, esc_str_ty]:
612
+ name = item.val
613
+ else:
614
+ if not "{" == item.val:
615
+ raise Error("Command definition misses first {.")
616
+ self.next()
617
+ item = self.skip_blank_tokens()
618
+ if not item.type in [esc_symb_ty, esc_str_ty]:
619
+ raise Error("Command definition does not begin with control sequence.")
620
+ name = item.val
621
+ self.next()
622
+ item = self.skip_blank_tokens()
623
+ if not "}" == item.val:
624
+ raise Error("Definition for commmand %s misses first }., %s" %
625
+ (name, item.val))
626
+ self.next()
627
+ self.skip_blank_tokens()
628
+ return name
629
+
630
+ def scan_numargs(self, name):
631
+ """
632
+ name is the name of the command or environment definition being
633
+ scanned.
634
+ Starts on a nonblank token.
635
+ Returns numargs
636
+ where numargs is the number of arguments in a command or environment
637
+ definition,
638
+ """
639
+ if not self.legal():
640
+ raise Error("No numargs to scan.")
641
+ item = self.item
642
+ numargs = 0
643
+ if not simple_ty == item.type:
644
+ raise Error("Illegal command or environment definition: "+name)
645
+ if "[" == item.val:
646
+ if not 4 < len(self.data):
647
+ raise Error("Command or environment definition is illegal: "+name)
648
+ item = self.next()
649
+ if not simple_ty == item.type:
650
+ raise Error("Illegal command or environment definition: "+name)
651
+ numargs = item.val
652
+ if not pos_digit_re.match(numargs):
653
+ raise Error("%s must be argument number after %s" % (numargs, name))
654
+ numargs = int(numargs)
655
+ self.next()
656
+ item = self.skip_blank_tokens()
657
+ if not simple_ty == item.type:
658
+ raise Error("Illegal command definition: "+name)
659
+ if "]" != item.val:
660
+ raise Error("Illegal command definition: "+name)
661
+ self.next()
662
+ self.skip_blank_tokens()
663
+ return numargs
664
+
665
+ def scan_command_def(self):
666
+ """Scan a command definition.
667
+ Return command_def.
668
+ Assumes that the number of arguments is at most 9.
669
+ """
670
+ if not self.legal():
671
+ raise Error("No command definition to scan.")
672
+ item = self.item
673
+ if not 2 < len(self.data):
674
+ raise Error("Command definition is illegal.")
675
+ # newcommand or renewcommand
676
+ if not item.type in [esc_symb_ty, esc_str_ty]:
677
+ raise Error("Command definition should begin with control sequence: "+item.val)
678
+ if item.val not in ["newcommand", "renewcommand"]:
679
+ raise Error("Command definition should begin with control sequence.")
680
+ self.next()
681
+ self.skip_blank_tokens()
682
+
683
+ cmd_name = self.scan_command_name()
684
+ numargs = self.scan_numargs(cmd_name)
685
+
686
+ body_group = self.scan_group()
687
+ if group_ty != body_group.type:
688
+ raise Error("Command body missing: "+cmd_name)
689
+ body_val = strip_comments(body_group.val)
690
+ return Command_def(cmd_name, numargs, body_val)
691
+
692
+ def scan_env_name(self):
693
+ """Starts on a {.
694
+ Returns name.
695
+ """
696
+ if not self.legal():
697
+ raise Error("No environment name to scan.")
698
+ item = self.item
699
+ if not "{" == item.val:
700
+ raise Error("Env. definition begins with %s, not with {" % (item.val))
701
+ self.next()
702
+ item = self.skip_blank_tokens()
703
+ name = ""
704
+ if not simple_ty == item.type:
705
+ raise Error("1. Env. def. begins with cont. seq. %s, not with env.name."
706
+ % (item.val))
707
+ while self.uplegal() and not blank_or_rbrace_re.match(item.val):
708
+ name += item.val
709
+ item = self.next()
710
+ if not simple_ty == item.type:
711
+ raise Error("2. Env. def. begins with cont. seq. %s, not with env.name."
712
+ % (item.val))
713
+ item = self.skip_blank_tokens()
714
+ if not "}" == item.val:
715
+ raise Error("Command definition does not begin with control sequence.")
716
+ self.next()
717
+ self.skip_blank_tokens()
718
+ return name
719
+
720
+ def scan_env_def(self):
721
+ """Scan an environment definition.
722
+ Return env_def
723
+ Assumes that the number of arguments is at most 9.
724
+ """
725
+ if not self.legal():
726
+ raise Error("No environment definition to scan.")
727
+ item = self.item
728
+ if not 7 < len(self.data):
729
+ raise Error("Environment definition is illegal.")
730
+ pos = 0
731
+
732
+ if not item.type in [esc_symb_ty, esc_str_ty]:
733
+ raise Error("Env. definition does not begin with control sequence:"+
734
+ item.val)
735
+ if item.val not in ["newenvironment", "renewenvironment"]:
736
+ raise Error("Env. definition does not begin with control sequence.")
737
+ self.next()
738
+ self.skip_blank_tokens()
739
+
740
+ env_name = self.scan_env_name()
741
+ numargs = self.scan_numargs(env_name)
742
+ self.skip_blank_tokens()
743
+
744
+ begin_group = self.scan_group()
745
+ if group_ty != begin_group.type:
746
+ raise Error("Begin body missing: "+env_name)
747
+ begin_val = strip_comments(begin_group.val)
748
+
749
+ self.skip_blank_tokens()
750
+
751
+ end_group = self.scan_group()
752
+ if group_ty != end_group.type:
753
+ raise Error("End body missing:"+env_name)
754
+ end_val = strip_comments(end_group.val)
755
+
756
+ return Env_def(env_name, numargs, begin_val, end_val)
757
+
758
+ def scan_defs(self):
759
+ if not self.legal():
760
+ raise Error("No definitions to scan.")
761
+ self.reset()
762
+ command_defs, env_defs = self.defs
763
+ while self.uplegal():
764
+ if (esc_str_ty == self.item.type
765
+ and self.item.val in ["newcommand", "renewcommand"]):
766
+ def_start_pos = self.pos
767
+ command_def = self.scan_command_def()
768
+ command_defs[command_def.name] = command_def
769
+ def_end_pos = self.pos
770
+ for del_pos in range(def_start_pos,def_end_pos):
771
+ del self.data[def_start_pos]
772
+ self.pos = def_start_pos
773
+ self.item = self.data[self.pos]
774
+ elif (esc_str_ty == self.item.type and self.item.val
775
+ in ["newenvironment", "renewenvironment"]):
776
+ def_start_pos = self.pos
777
+ env_def = self.scan_env_def()
778
+ env_defs[env_def.name] = env_def
779
+ def_end_pos = self.pos
780
+ for del_pos in range(def_start_pos,def_end_pos):
781
+ del self.data[def_start_pos]
782
+ self.pos = def_start_pos
783
+ self.item = self.data[self.pos]
784
+ else:
785
+ self.next()
786
+
787
+ # Instances
788
+
789
+ def scan_args(self, command_or_env_def):
790
+ """Scan the arguments of a command or environment.
791
+ Return [args].
792
+ """
793
+ if not self.legal():
794
+ raise Error("No arguments to scan.")
795
+ numargs = command_or_env_def.numargs
796
+ name = command_or_env_def.name
797
+
798
+ args = []
799
+ for i in range(numargs):
800
+ arg = []
801
+ if not (simple_ty == self.item.type and "{" == self.item.val):
802
+ arg = [self.item]
803
+ self.next()
804
+ else:
805
+ group = self.scan_group()
806
+ arg = group.val
807
+ args.append(arg)
808
+ return args
809
+
810
+ def scan_command(self, command_def):
811
+ """Scan the arguments of a command.
812
+ Return command_instance
813
+ """
814
+ if not self.legal():
815
+ raise Error("No command to scan.")
816
+ if not self.item.type in [esc_symb_ty, esc_str_ty]:
817
+ raise Error("Command does not begin with control sequence.")
818
+ name = self.item.val
819
+ self.next()
820
+ if 0 < command_def.numargs:
821
+ self.skip_blank_tokens()
822
+ args = self.scan_args(command_def)
823
+ else:
824
+ args = []
825
+ return Command_instance(name, args)
826
+
827
+ def test_env_boundary(self, item):
828
+ """Check whether an environment begin or end follows.
829
+ Return 1 if \begin, -1 if \end, 0 otherwise.
830
+ """
831
+ d = 0
832
+ if esc_str_ty == item.type:
833
+ if "begin"==item.val:
834
+ d = 1
835
+ elif "end"==item.val:
836
+ d = -1
837
+ return d
838
+
839
+ def scan_env_begin(self):
840
+ """Scan an environment name.
841
+ Return env_name.
842
+ """
843
+ if not self.legal():
844
+ raise Error("No environment begin to scan.")
845
+ item = self.item
846
+ if not (esc_str_ty == item.type and "begin" == item.val):
847
+ raise Error("Environment does not begin with begin.")
848
+ self.next()
849
+ name_group = self.scan_group()
850
+ name = detokenize(name_group.val)
851
+ return name
852
+
853
+ def scan_env_end(self):
854
+ """Scan an environment end.
855
+ Return env_name.
856
+ """
857
+ if not self.legal():
858
+ raise Error("No environment end to scan.")
859
+ item = self.item
860
+ if not (esc_str_ty == item.type and "end" == item.val):
861
+ raise Error("Environment does not end with end.")
862
+ self.next()
863
+ name_group = self.scan_group()
864
+ name = detokenize(name_group.val)
865
+ return name
866
+
867
+ def scan_env_rest(self, env_def):
868
+ """Scanning starts after \begin{envname}.
869
+ Returns env_instance.
870
+ """
871
+ if not self.legal():
872
+ raise Error("No environment rest to scan.")
873
+ count = 1 # We are already within a boundary.
874
+ args = self.scan_args(env_def)
875
+ body = []
876
+ while count and self.uplegal():
877
+ old_pos = self.pos
878
+ d = self.test_env_boundary(self.item)
879
+ count += d
880
+ if 1 == d:
881
+ self.scan_env_begin()
882
+ elif -1 == d:
883
+ self.scan_env_end()
884
+ else:
885
+ self.next()
886
+ if 0 < count:
887
+ body.extend(self.data[old_pos : self.pos])
888
+ return Env_instance(env_def.name, args, body)
889
+
890
+ # Definitions
891
+
892
+ def restore_defs(self):
893
+ if os.path.isfile(self.defs_db_file):
894
+ print("Using defs db %s" % (self.defs_db_file))
895
+ db_h = shelve.open(self.defs_db)
896
+ self.defs = db_h["defs"]
897
+ db_h.close()
898
+
899
+ def save_defs(self):
900
+ db_h = shelve.open(self.defs_db)
901
+ if "defs" in db_h:
902
+ del db_h["defs"]
903
+ db_h["defs"] = self.defs
904
+ db_h.close()
905
+
906
+ def add_defs(self, defs_file):
907
+ defs_file_compl = defs_file + ".sty"
908
+ if not os.path.isfile(defs_file_compl):
909
+ raise Error("%s does not exist" % (defs_file_compl))
910
+
911
+ defs_db_file = self.defs_db_file
912
+ if newer(defs_db_file, defs_file_compl):
913
+ print("Using defs db %s for %s" % (defs_db_file, defs_file))
914
+ else:
915
+ defs_fp = open(defs_file_compl, "r")
916
+ defs_str = defs_fp.read()
917
+ defs_fp.close()
918
+ ds = Tex_stream()
919
+ ds.defs = self.defs
920
+ defs_text = ds.smart_tokenize(defs_str,isatletter=True)
921
+ # changing ds.defs will change self.defs
922
+ if self.debug:
923
+ defs_seen_file = "%s-seen.sty" % (defs_file)
924
+ defs_seen_fp = open(defs_seen_file, "w")
925
+ out = detokenize(defs_text,isatletter=True)
926
+ defs_seen_fp.write(out)
927
+ defs_seen_fp.close()
928
+ ds.scan_defs()
929
+ if self.debug:
930
+ out = ""
931
+ command_defs, env_defs = self.defs
932
+ for def_name in command_defs.keys():
933
+ out += command_defs[def_name].show() + "\n"
934
+ for def_name in env_defs.keys():
935
+ out += env_defs[def_name].show() +"\n"
936
+ print("Definitions after reading %s:" % (defs_file))
937
+ print(out)
938
+
939
+ # Applying definitions, recursively
940
+ # (maybe not quite in Knuth order, so avoid tricks!)
941
+
942
+ def subst_args(self, body, args):
943
+ out = []
944
+ pos = 0
945
+ while pos < len(body):
946
+ item = body[pos]
947
+ if not (simple_ty == item.type and "#" == item.val):
948
+ out.append(item)
949
+ pos += 1
950
+ continue
951
+ pos += 1
952
+ token = body[pos]
953
+ argnum = token.val
954
+ if not pos_digit_re.match(argnum):
955
+ raise Error("# is not followed by number.")
956
+ argnum = int(argnum)
957
+ if argnum > len(args):
958
+ raise Error("Too large argument number.")
959
+ arg = args[argnum-1]
960
+ out += arg
961
+ pos += 1
962
+ return out
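+
+ # Example: with body holding the tokens of "(#1,#2)" and args set to
+ # [tokens of "x", tokens of "y"], subst_args() returns the tokens of
+ # "(x,y)". Each "#" must be followed by a digit no larger than
+ # len(args), otherwise an Error is raised.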
963
+
964
+ def apply_command_recur(self, command_instance):
965
+ command_defs, env_defs = self.defs
966
+ name = command_instance.name
967
+ command_def = command_defs[name]
968
+
969
+ args = command_instance.args
970
+ body = command_def.body
971
+ result = self.subst_args(body, args)
972
+ try:
973
+ result = self.apply_all_recur(result)
974
+ except Empty_text_error as e:
975
+ raise Error("apply_all_recur fails on command instance %s: %s, %s" % \
976
+ (command_instance.show(), detokenize(e.data), e.message))
977
+ return result
978
+
979
+ def apply_env_recur(self, env_instance):
980
+ command_defs, env_defs = self.defs
981
+ name = env_instance.name
982
+ env_def = env_defs[name]
983
+
984
+ begin, end = env_def.begin, env_def.end
985
+ body, args = env_instance.body, env_instance.args
986
+ out = self.subst_args(begin, args) + body + self.subst_args(end, args)
987
+ return self.apply_all_recur(out)
988
+
989
+
990
+ def apply_all_recur(self, data, report=False):
991
+ ts = Tex_stream(data)
992
+ ts.defs = self.defs
993
+ command_defs, env_defs = self.defs
994
+ out = []
995
+ progress_step = 10000
996
+ progress = progress_step
997
+ if not ts.legal():
998
+ raise Empty_text_error(data, "No text to process.")
999
+ while ts.uplegal():
1000
+ if ts.pos > progress:
1001
+ if report:
1002
+ print(ts.pos)
1003
+ progress += progress_step
1004
+ if not ts.item.type in [esc_symb_ty, esc_str_ty]:
1005
+ out.append(ts.item)
1006
+ ts.next()
1007
+ continue
1008
+ if 1 == ts.test_env_boundary(ts.item):
1009
+ old_pos = ts.pos
1010
+ env_name = ts.scan_env_begin()
1011
+ if env_name not in env_defs:
1012
+ out.extend(ts.data[old_pos : ts.pos])
1013
+ continue
1014
+ else:
1015
+ env_def = env_defs[env_name]
1016
+ env_instance = ts.scan_env_rest(env_def)
1017
+ result = ts.apply_env_recur(env_instance)
1018
+ out.extend(result)
1019
+ elif ts.item.val not in command_defs:
1020
+ out.append(ts.item)
1021
+ ts.next()
1022
+ continue
1023
+ else:
1024
+ command_def = command_defs[ts.item.val]
1025
+ command_inst = ts.scan_command(command_def)
1026
+ result = ts.apply_command_recur(command_inst)
1027
+ out.extend(result)
1028
+ return out
1029
+
1030
+
1031
+ # Processing files
1032
+
1033
+ def process_file(self, file):
1034
+ """Returns the new defs.
1035
+ """
1036
+ file = cut_extension(file, ".tex")
1037
+ source_file = "%s.tex" % (file)
1038
+ print("File %s [" % (source_file))
1039
+ source_fp = open(source_file, "r")
1040
+ text_str = source_fp.read()
1041
+ source_fp.close()
1042
+
1043
+ self.smart_tokenize(text_str, handle_inputs=True)
1044
+ if not self.data:
1045
+ raise Error("Empty tokenization result.")
1046
+ self.reset()
1047
+
1048
+ if self.debug:
1049
+ source_seen_fname = "%s-seen.tex" % (file)
1050
+ source_seen_fp = open(source_seen_fname, "w")
1051
+ source_seen_fp.write(detokenize(self.data))
1052
+ source_seen_fp.close()
1053
+ self.scan_defs()
1054
+ self.data = self.apply_all_recur(self.data, report=True)
1055
+ result_fname = "%s-clean.tex" % (file)
1056
+ print("Writing %s [" % (result_fname))
1057
+ result_fp = open(result_fname, "w")
1058
+ result_fp.write(self.smart_detokenize())
1059
+ result_fp.close()
1060
+ print("] file %s" % (result_fname))
1061
+ print("] file %s" % (source_file))
1062
+
1063
+ def process_if_newer(self, file):
1064
+ """
1065
+ \input{file} is added to the token list.
1066
+ If the input file is newer than its -clean version, it is processed.
1067
+ Returns tokenized \input{file}.
1068
+ """
1069
+ file = cut_extension(file, ".tex")
1070
+ tex_file = file+".tex"
1071
+ clean_tex_file = file+"-clean.tex"
1072
+ if newer(clean_tex_file, tex_file):
1073
+ print("Using %s." % (clean_tex_file))
1074
+ else:
1075
+ ts = Tex_stream()
1076
+ ts.data = []
1077
+ ts.defs = self.defs
1078
+ ts.process_file(file)
1079
+ to_add = "\\input{%s}" % (file)
1080
+ return tokenize(to_add)
1081
+
1082
+ # Main
1083
+
1084
+ long_optlist = ["debug","defs="]
1085
+ options, restargs = getopt_map("x", long_optlist)
1086
+
1087
+ debug = False
1088
+ if "--debug" in options:
1089
+ debug = True
1090
+
1091
+ root = restargs[0]
1092
+ root = cut_extension(root, ".tex")
1093
+ if "--defs" in options:
1094
+ defs_root = options["--defs"]
1095
+ else:
1096
+ defs_root = "%s" % (root)
1097
+ defs_db = defs_root
1098
+ defs_db_file = defs_root+".db"
1099
+
1100
+ ts = Tex_stream()
1101
+ ts.defs_db = defs_db
1102
+ ts.defs_db_file = defs_db_file
1103
+ ts.debug = debug
1104
+
1105
+ ts.restore_defs()
1106
+ for root in restargs:
1107
+ ts.process_file(root)
1108
+
1109
+ print("(Re)creating defs db %s" % (defs_db))
1110
+ ts.save_defs()
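+
+ # Usage sketch (assumed invocation; argument handling is in the "Main"
+ # block above; the script name is illustrative):
+ #
+ #   python de_macro.py --defs mydefs paper.tex
+ #
+ # expands the definitions of every "*-private" package that paper.tex
+ # uses, writes paper-clean.tex, and caches the parsed definitions in
+ # the shelve db mydefs.db for later runs.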
utils/def_handle.py ADDED
@@ -0,0 +1,75 @@
1
+ import argparse
2
+ import re
3
+
4
+
5
+ def main():
6
+ args = parse_command_line()
7
+ data = read(args.input)
8
+ data = convert(data)
9
+ write(args.output, data)
10
+
11
+
12
+ def parse_command_line():
13
+ parser = argparse.ArgumentParser(
14
+ description='Replace \\def with \\newcommand where possible.',
15
+ )
16
+ parser.add_argument(
17
+ 'input',
18
+ help='TeX input file with \\def',
19
+ )
20
+ parser.add_argument(
21
+ '--output',
22
+ '-o',
23
+ required=True,
24
+ help='TeX output file with \\newcommand',
25
+ )
26
+
27
+ return parser.parse_args()
28
+
29
+ def read(path):
30
+ with open(path, mode='rb') as handle:
31
+ return handle.read()
32
+
33
+
34
+ def convert(data):
35
+ return re.sub(
36
+ rb'((?:\\(?:expandafter|global|long|outer|protected)'
37
+ rb'(?: +|\r?\n *)?)*)?'
38
+ rb'\\def *(\\[a-zA-Z]+) *(?:#+([0-9]))*\{',
39
+ replace,
40
+ data,
41
+ )
42
+
43
+
44
+ def replace(match):
45
+ prefix = match.group(1)
46
+ if (
47
+ prefix is not None and
48
+ (
49
+ b'expandafter' in prefix or
50
+ b'global' in prefix or
51
+ b'outer' in prefix or
52
+ b'protected' in prefix
53
+ )
54
+ ):
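+ # Note: the early return below is disabled, so a prefixed \def
+ # (e.g. \long\def or \global\def) is still converted and the
+ # prefix is dropped from the output.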
55
+ pass #return match.group(0)
56
+
57
+ result = rb'\newcommand'
58
+
59
+ result += b'{' + match.group(2) + b'}'
60
+ if match.lastindex == 3:
61
+ result += b'[' + match.group(3) + b']'
62
+
63
+ result += b'{'
64
+ return result
65
+
66
+
67
+ def write(path, data):
68
+ with open(path, mode='wb') as handle:
69
+ handle.write(data)
70
+
71
+ print('=> File written: {0}'.format(path))
72
+
73
+
74
+ if __name__ == '__main__':
75
+ main()
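+
+ # Example of the rewrite performed by convert() (illustrative input):
+ #
+ #   \def\vec#1{\mathbf{#1}}  ->  \newcommand{\vec}[1]{\mathbf{#1}}
+ #
+ # Only the "\def\name#1...{" head is rewritten; the body is copied
+ # through unchanged.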
utils/gradio_utils.py ADDED
@@ -0,0 +1,20 @@
1
+ from transformers import StoppingCriteria
2
+ import sys
3
+
4
+
5
+ # Handle termination signal
6
+ def signal_handler(sig, frame):
7
+ print("\nTermination signal received. Shutting down Gradio interface.")
8
+ sys.exit(0)
9
+
10
+ # Custom stopping criteria
11
+ class StopOnTokens(StoppingCriteria):
12
+ def __call__(self, input_ids, scores, **kwargs):
13
+ stop_ids = [29, 0]  # model-specific stop token IDs
14
+ return input_ids[0][-1] in stop_ids
15
+
16
+ # Toggle task selection
17
+ def toggle_selection(current_task, new_task):
18
+ """Toggle task selection: deselect if clicked again, otherwise update selection."""
19
+ updated_task = "" if current_task == new_task else new_task
20
+ return updated_task
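+
+ # Usage sketch (assumes a loaded `model` and prepared `input_ids`;
+ # those names are illustrative, not part of this module):
+ #
+ # import signal
+ # from transformers import StoppingCriteriaList
+ # signal.signal(signal.SIGINT, signal_handler)
+ # output = model.generate(
+ #     input_ids,
+ #     max_new_tokens=256,
+ #     stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
+ # )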
utils/graph_utils.py ADDED
@@ -0,0 +1,111 @@
1
+ import regex
2
+ import re
3
+
4
+ def retrieve_text_cite(text, command):
5
+ base_pattern = (
6
+ r'\\' + command + r"(?:\[(?:.*?)\])*\{((?:[^{}]+|\{(?1)\})*)\}(?:\[(?:.*?)\])*"
7
+ )
8
+
9
+ def extract_text_inside_curly_braces(text):
10
+ pattern = r"\{((?:[^{}]|(?R))*)\}"
11
+
12
+ match = regex.search(pattern, text)
13
+
14
+ if match:
15
+ return match.group(1)
16
+ else:
17
+ return ""
18
+
19
+ found_texts = []
20
+ for match in regex.finditer(base_pattern, text):
21
+ temp_substring = text[match.span()[0] : match.span()[1]]
22
+ found_texts.append(extract_text_inside_curly_braces(temp_substring))
23
+
24
+ return found_texts
25
+
26
+ def get_citing_sentences(content):
27
+ content_new = re.sub(r'[\n]+', ' ', content) # collapse newline runs into a single space
28
+ content_new = re.sub(r'e\.g\.', 'eg', content_new) # drop abbreviation periods so the '.' split below keeps sentences whole
29
+ content_new = re.sub(r'i\.e\.', 'ie', content_new)
30
+ content_new = re.sub(r'etc\.' , 'etc', content_new)
31
+ content_new = re.sub(r' +', ' ', content_new)
32
+ sentences = [sentence + '.' for sentence in content_new.split('.')]
33
+ citing_sentences = [s for s in sentences if '\\cite' in s]
34
+ results = {}
35
+ for s in citing_sentences:
36
+ citations = retrieve_text_cite(s, 'cite')
37
+ final_citations = []
38
+ for cite in citations:
39
+ final_citations.extend(cite.split(','))
40
+ results[s] = final_citations
41
+ return results
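+
+ # Example (illustrative):
+ #   get_citing_sentences("Prior work \\cite{a, b} shows X. No cites here.")
+ #   -> {'Prior work \\cite{a, b} shows X.': ['a', ' b']}
+ # Keys keep their original spacing and values are split on bare commas,
+ # so citation keys may carry leading whitespace.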
42
+
43
+ def get_intro(content):
44
+ sections = retrieve_text_cite(content, 'section')
45
+ if sections == []:
46
+ return ''
47
+ try_intro = [x for x in sections if x.strip().lower() == 'introduction']
48
+ if try_intro == []:
49
+ return ''
50
+ else:
51
+ to_find = try_intro[0]
52
+ ind = sections.index(to_find)
53
+ if ind + 1 < len(sections):
54
+ start_marker = f'\\section{{{sections[ind]}}}'
55
+ end_marker = f'\\section{{{sections[ind+1]}}}'
56
+ start_point = content.find(start_marker)
57
+ end_point = content.find(end_marker)
58
+ return content[start_point+len(start_marker):end_point]
59
+ else:
60
+ return ''
61
+
62
+ def get_related_works(content):
63
+ sections = retrieve_text_cite(content, 'section')
64
+ if sections == []:
65
+ return ''
66
+ possible_related = [
67
+ "Literature Review",
68
+ "Related Work",
69
+ "Related Works",
70
+ "Prior Work",
71
+ "Prior Works",
72
+ "Related Research",
73
+ "Research Overview",
74
+ "Previous Work",
75
+ "Previous Works",
76
+ "Review of the Literature",
77
+ "Review of Related Literature",
78
+ "Survey of Related Work",
79
+ "Survey of Related Works",
80
+ "Background",
81
+ "Research Background",
82
+ "Review of Prior Research",
83
+ "Literature Survey",
84
+ "Overview of Literature",
85
+ "Existing Literature",
86
+ "Review of Existing Work",
87
+ "Review of Existing Works",
88
+ "Review of Previous Studies",
89
+ "Review of Prior Literature",
90
+ "Summary of Related Research",
91
+ "Survey of Existing Literature",
92
+ "Survey of Literature",
93
+ "Existing Research Overview",
94
+ "Prior Literature Review"
95
+ ]
96
+ possible_sections = [x for x in sections if any([True for y in possible_related if y.lower() == x.strip().lower()])]
97
+ if possible_sections == []:
98
+ return ''
99
+ else:
100
+ to_find = possible_sections[0]
101
+ ind = sections.index(to_find)
102
+
103
+ if ind + 1 < len(sections):
104
+ start_marker = f'\\section{{{sections[ind]}}}'
105
+ end_marker = f'\\section{{{sections[ind+1]}}}'
106
+ start_point = content.find(start_marker)
107
+ end_point = content.find(end_marker)
108
+ return content[start_point+len(start_marker):end_point]
109
+
110
+ else:
111
+ return ''
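+
+ # Example (illustrative): get_intro() and get_related_works() both
+ # slice the source between consecutive \section markers, e.g.
+ #   src = "\\section{Introduction} intro text \\section{Method} ..."
+ #   get_intro(src) -> " intro text "
+ # A matching section that is the last one in the file yields '' because
+ # no following \section closes the slice.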
utils/latexpand ADDED
@@ -0,0 +1,713 @@
1
+ #!/usr/bin/perl
2
+ # Inspired by latexpand by D. Musliner, University of Michigan
3
+ # 2012-2023: Matthieu Moy <[email protected]>
4
+ # BSD License
5
+
6
+ use strict;
7
+ use Cwd;
8
+ use Getopt::Long;
9
+ use IO::Handle;
10
+ use File::Spec;
11
+
12
+ my $TEXINPUTS = $ENV{'TEXINPUTS'};
13
+ # By default, search in current directory. We use '.' and not getcwd()
14
+ # to avoid issues if the working directory contains a ':' character.
15
+ if (!$TEXINPUTS) { $TEXINPUTS = '.'; }
16
+
17
+ my $verbose;
18
+ my $keep_comments;
19
+ my $keep_includes;
20
+ my $empty_comments;
21
+ my $help;
22
+ my $long_help;
23
+ my %defines = ();
24
+ my $output;
25
+ my $explain;
26
+ my $show_graphics;
27
+ my $graphics_extensions = ":.pdf:.png:.jpg:.eps";
28
+ my $expand_usepackage;
29
+ my $expand_bbl;
30
+ my $biber;
31
+ my $fatal;
32
+ my $version;
33
+ my $makeatletter;
34
+ my $inside_import;
35
+ my $in_enc = "bytes";
36
+ my $out_enc = "bytes";
37
+
38
+ GetOptions (
39
+ 'h' => \$help,
40
+ 'help' => \$long_help,
41
+ 'verbose|v' => \$verbose,
42
+ 'keep-comments' => \$keep_comments,
43
+ 'keep-includes' => \$keep_includes,
44
+ 'empty-comments' => \$empty_comments,
45
+ 'define|d=s%' => \%defines,
46
+ 'output|o=s' => \$output,
47
+ 'explain' => \$explain,
48
+ 'show-graphics' => \$show_graphics,
49
+ 'graphics-extensions' => \$graphics_extensions,
50
+ 'expand-usepackage' => \$expand_usepackage,
51
+ 'expand-bbl=s' => \$expand_bbl,
52
+ 'biber=s' => \$biber,
53
+ 'fatal' => \$fatal,
54
+ 'version' => \$version,
55
+ 'makeatletter' => \$makeatletter,
56
+ 'in-encoding=s' => \$in_enc,
57
+ 'out-encoding=s' => \$out_enc,
58
+ ) or pod2usage_wrapper(2);
59
+ version() if $version;
60
+ pod2usage_wrapper(0) if $help;
61
+ pod2usage_wrapper(-exitstatus => 0, -output => \*STDOUT, -verbose => 2) if $long_help;
62
+
63
+ sub pod2usage_wrapper
64
+ {
65
+ # Like pod2usage, but fall back to a simpler implem in case
66
+ # pod2usage can't be found.
67
+ if (eval {require Pod::Usage;1;} ne 1) {
68
+ print "Please install perldoc and Pod::Usage to get proper help.\n";
69
+ my $started = 0;
70
+ open (my $in, '<', "$0") or die $!;
71
+ while (<$in>) {
72
+ if ($started) {
73
+ print;
74
+ }
75
+ if (/^__END__$/) {
76
+ $started = 1;
77
+ }
78
+ }
79
+ } else {
80
+ Pod::Usage->import();
81
+ pod2usage(@_);
82
+ }
83
+ }
84
+
85
+ sub get_version
86
+ {
87
+ # $VERSION's value will be substituted by 'make dist', but the
88
+ # next line won't (the string has to be broken to avoid it).
89
+ my $VERSION = 'v1.7.2';
90
+ if ($VERSION eq '@LATEXPAND' . '_VERSION@') {
91
+ my($vol,$dir,$file) = File::Spec->splitpath($0);
92
+ chdir($dir);
93
+ $VERSION = `git describe --tags HEAD 2>/dev/null`;
94
+ }
95
+ if ($VERSION eq '') {
96
+ $VERSION = '<unknown version>';
97
+ }
98
+ $VERSION =~ s/^\s+|\s+$//g;
99
+ return $VERSION;
100
+ }
101
+
102
+ sub version
103
+ {
104
+ print "latexpand version ". get_version() .".\n";
105
+ exit(0);
106
+ }
107
+
108
+ my $nl = "";
109
+ if ($empty_comments) {
110
+ $nl = "%\n";
111
+ }
112
+
113
+ if ($output && $output ne "-") {
114
+ open (my $OUTPUT, '>', "$output") or die $!;
115
+ STDOUT->fdopen(\*$OUTPUT, 'w') or die $!;
116
+ }
117
+
118
+ sub say
119
+ {
120
+ if ($verbose) {
121
+ print STDERR "$_[0]";
122
+ }
123
+ }
124
+
125
+ my $makeatletter_found;
126
+ my $in_preamble;
127
+
128
+ use open IN => ":$in_enc", OUT => ":$out_enc";
129
+
130
+ foreach my $file (@ARGV)
131
+ {
132
+ say "processing $file\n";
133
+ $makeatletter_found = 0;
134
+ $in_preamble = 1;
135
+ $inside_import = "";
136
+ if ($file =~ /\.bib$/) {
137
+ warn "WARNING: latexpand is not meant to be used on BibTeX files like '$file'.\n" .
138
+ " Run latexpand on your main .tex file, using '--expand-bbl FILE'\n" .
139
+ " or '--biber FILE' if needed to inline the generated bbl file.\n";
140
+ } elsif (not $file =~ /\.tex$/) {
141
+ warn "WARNING: latexpand is meant to be used on .tex files, which $file isn't.\n";
142
+ }
143
+ process_file($file, " ");
144
+ }
145
+
146
+ sub cat_file
147
+ {
148
+ my $file = shift;
149
+ open (my $INFILE, "<", $file) || die "could not open input file '$file'\n";
150
+ while (<$INFILE>) {
151
+ print;
152
+ }
153
+ close ($INFILE);
154
+ }
155
+
156
+ sub process_file
157
+ {
158
+ my $file = shift;
159
+ my $prefix = (shift || "");
160
+ my $in_comment = 0;
161
+ open(my $FILE, "<", $file) or die "could not open input file '$file'\n";
162
+ my $commented_newline = 0;
163
+ while (my $line = <$FILE>) {
164
+ if ($line =~ /^[ \t]*\\endinput/) {
165
+ # Surprisingly, text after \endinput on the
166
+ # same line is kept in output. Also, add a
167
+ # space (before %), automatically inserted by
168
+ # TeX at the end of file.
169
+ $line =~ s/\\endinput(.*)\n?/$1 % /;
170
+ $in_comment = 1;
171
+ process_line($line, $prefix, \$commented_newline);
172
+ last;
173
+ }
174
+ while (my ($k, $v) = each (%defines))
175
+ {
176
+ $line=~s!\\$k!$v!g;
177
+ }
178
+ process_line($line, $prefix, \$commented_newline, $file);
179
+ if ($line =~ /^%.*[^\n]\z/ || $line =~ /[^\\]%.*[^\n]\z/) {
180
+ # file ends with a comment not ending with a newline
181
+ print "\n";
182
+ }
183
+ # Garbage at end of line after \end{document} is
184
+ # ignored by LaTeX, but we don't allow anything before
185
+ # to avoid e.g. \verb|\end{document}| from terminating
186
+ # the file.
187
+ if (!$keep_comments && $line =~ /^[ \t]*\\end\{document\}/) {
188
+ last;
189
+ }
190
+ }
191
+ close($FILE);
192
+ return $in_comment;
193
+ }
194
+
195
+ sub process_line
196
+ {
197
+ my ($line, $prefix, $commented_newline, $file) = @_;
198
+ $_ = $line;
199
+ if ($$commented_newline) {
200
+ # Leading whitespaces after a comment is ignored.
201
+ # There's no space in:
202
+ # Line 1%
203
+ # Line 2.
204
+ # Match just space and tabs (\s would match \n)
205
+ s/^[ \t]*//;
206
+ if (/^$/) {
207
+ # Deal with:
208
+ #
209
+ # Line 1 % comment
210
+ #
211
+ # Line 2
212
+ #
213
+ # The newline after Line 1 is commented, but we still
214
+ # want a new paragraph. We strip the comment together
215
+ # with its newline, but re-add a newline to change
216
+ # paragraph here if needed:
217
+ print "\n";
218
+ }
219
+ }
220
+ $$commented_newline = 0;
221
+ # Consider \makeatletter only in preamble, because we do want
222
+ # to warn on \someCommand{\makeatletter\command@with@arobase}.
223
+ if ($in_preamble && /^[^%]*\\makeatletter/) {
224
+ $makeatletter_found = 1;
225
+ }
226
+ if ($in_preamble && /^[^%]*\\makeatother/) {
227
+ $makeatletter_found = 0;
228
+ }
229
+ my $command;
230
+ if (!$makeatletter && !$makeatletter_found
231
+ && (($command) = /^[^%]*(\\[[:alpha:]]*@[[:alpha:]]*)/)
232
+ && ($command ne '\@')) {
233
+ print STDERR "Warning: command $command containing @ found in\n";
234
+ print STDERR "Warning: $file.\n";
235
+ print STDERR "Warning: consider using --makeatletter if the result is not compilable.\n";
236
+ }
237
+
238
+ # non-comment is a sequence of:
239
+ # - escaped character (\\.), including \% and \\
240
+ # - neither '%' nor '\'.
241
+ my $NON_COMMENT = '([^\\\\%]|\\\\.)*';
242
+
243
+ unless ($keep_comments) {
244
+ # Special-case for \url{} commands, which may contain '%'
245
+ # characters. It's hard to catch them in $NON_COMMENT since we'd
246
+ # need a regexp so that "\url{foo" can't match as non-comment in
247
+ # the line \url{foo%bar}, but "\url{foo%bar}" would match.
248
+ # Escaping these '%' is not mandatory, but allowed, hence we can
249
+ # pre-process the line by escaping them, and let latexpand work
250
+ # as normal afterwards.
251
+ # Known limitation: latexpand doesn't do balanced braces
252
+ # recognition, and just refuses both { and } within \url{}
253
+ # argument for %-detection to work ([^{}%] below). Fix should be
254
+ # possible using
255
+ # https://stackoverflow.com/questions/15301708/perl-regular-expression-match-nested-brackets
256
+ # but is it worth the trouble? (file an issue or send a merge
257
+ # request if you think it is)
258
+
259
+ # While there are \url{URL} with unescaped % in URL ...
260
+ my $NON_PERCENT = '([^\\}]%|[^{}%])*';
261
+ while (/^(?<before>.*\\url\{)(?<url>$NON_PERCENT[^\\}]%$NON_PERCENT)(?<after>\}.*)$/) {
262
+ my ($before, $url, $after) = ($+{before}, $+{url}, $+{after});
263
+ # escape unescaped % in URL, if any
264
+ $url =~ s/([^\\])%/$1\\%/g;
265
+ $_ = $before . $url . $after ."\n";
266
+ }
267
+ if (!$empty_comments) {
268
+ # Include \n in pattern to avoid matching
269
+ # comments at end of files
270
+
271
+ # remove comments + whitespace-only lines completely
272
+ if (s/^\s*%.*\n//) {
273
+ $$commented_newline = 1;
274
+ }
275
+
276
+ # Special-case commands at end of line. We
277
+ # don't want "\\foo%\nbar" to become
278
+ # "\\foobar" (but we still want \@% to result
279
+ # in no space!)
280
+ if (s/^($NON_COMMENT\\([[:alpha:]]|[[:alpha:]@]{2,}))%.*\n/$1 /) {
281
+ $$commented_newline = 1;
282
+ } elsif (s/^($NON_COMMENT)%.*\n/$1/) {
283
+ # remove only the comment if the line has actual content
284
+ $$commented_newline = 1;
285
+ }
286
+ }
287
+ # Apply the "empty comments" treatment unconditionally
288
+ # for comments not matched above (it doesn't harm to
289
+ # keep an empty comment sometimes, but it may harm to
290
+ # leave a real comment if the goal was to strip them).
291
+ s/^(([^\\%]|\\.)*)%.*$/$1%/;
292
+ }
293
+
294
+ unless ($keep_includes) {
295
+ # \input{foo.tex}
296
+ my $ARGBRACES = '\{\\s*([^"}\\s][^}]*)(\\s*)\}';
297
+ # \input{"foo bar.tex"}
298
+ my $ARGQUOTED = '\{\\s*"([^"]*)"(\\s*)\}';
299
+ # \input foo.tex
300
+ my $ARGSPACES = '\\s([^\{\\s][^\\s]+?)\\s()';
301
+ my $ARGUMENT = "\\s*?(?|$ARGBRACES|$ARGQUOTED|$ARGSPACES)";
302
+
303
+ if (my ($before, $ignored, $full_filename, $trailing, $after)
304
+ = /^($NON_COMMENT)\\include$ARGUMENT(.*)$/) {
305
+ $full_filename = find_tex_file($full_filename . ".tex");
306
+ if ($full_filename) {
307
+ say $prefix . "Found include for file: $full_filename\n";
308
+ print $before . $nl;
309
+ print '\clearpage{}' . $nl;
310
+ print "% start include $full_filename\n" if ($explain);
311
+ my $in_comment = process_file($full_filename, $prefix . " ");
312
+ if ($explain) {
313
+ print " % end include $full_filename\n";
314
+ } elsif ($in_comment) {
315
+ print "\n";
316
+ }
317
+ print '\clearpage{}' . $nl;
318
+ print $nl . $after . "\n";
319
+ $_ = "";
320
+ }
321
+ } elsif (my ($before, $ignored, $full_filename, $trailing, $after)
322
+ = /^($NON_COMMENT)\\input$ARGUMENT(.*)$/) {
323
+ if ($inside_import) {
324
+ $full_filename = $inside_import . $full_filename;
325
+ }
326
+ $full_filename = find_tex_file($full_filename, ":.tex");
327
+ if ($full_filename) {
328
+ say $prefix . "Found input for file: $full_filename\n";
329
+ # Apparently, in some versions of LaTeX, a space
330
+ # after filename in \input{foo.tex } is inserted
331
+ # _before_ the inclusion. That was the case for
332
+ # me when 31fa806 (deal with space after
333
+ # filename in \input and \include, 2019-12-11)
334
+ # was written, but is not anymore, hence we just
335
+ # throw $trailing away.
336
+ print $before . $nl;
337
+ print "% start input $full_filename\n" if ($explain);
338
+ my $in_comment = process_file($full_filename, $prefix . " ");
339
+ if ($explain) {
340
+ print " % end input $full_filename\n";
341
+ } elsif ($in_comment) {
342
+ print "\n";
343
+ }
344
+ if ($after =~ /[^\s]/) {
345
+ # LaTeX produces this space, so let's do it also
346
+ print " " . $nl . $after . "\n";
347
+ } else {
348
+ print " ";
349
+ }
350
+ $_ = "";
351
+ }
352
+ } elsif (my ($before, $ignored, $dir, $ignored, $full_filename, $ignored, $after)
353
+ = /^($NON_COMMENT)\\(?:sub)?import$ARGUMENT$ARGUMENT(.*)$/) {
354
+ if ($explain) {
355
+ print "% dir " . $dir ."\n";
356
+ print "% full_filename " . $full_filename ."\n";
357
+ print "% after " . $after ."\n";
358
+ print "% inside_import $inside_import\n";
359
+ }
360
+ $full_filename = $dir . $full_filename;
361
+ if ($inside_import) {
362
+ $full_filename = $inside_import . $full_filename;
363
+ }
364
+ print "% cat(inside_import,dir,full_filename) " . $full_filename ."\n" if ($explain);
365
+ $full_filename = find_tex_file($full_filename, ":.tex");
366
+ if ($full_filename) {
367
+ say $prefix . "Found input for file: $full_filename\n";
368
+ print $before . $nl;
369
+ print "% start input $full_filename\n" if ($explain);
370
+ my $previous_import_dir = $inside_import;
371
+ $inside_import = $inside_import . $dir;
372
+ my $in_comment = process_file($full_filename, $prefix . " ");
373
+ $inside_import = $previous_import_dir;
374
+ if ($explain) {
375
+ print " % end input $full_filename\n";
376
+ } elsif ($in_comment) {
377
+ print "\n";
378
+ }
379
+ if ($after =~ /[^\s]/) {
380
+ # LaTeX produces this space, so let's do it also
381
+ print " " . $nl . $after . "\n";
382
+ } else {
383
+ print " ";
384
+ }
385
+ $_ = "";
386
+ }
387
+ } elsif (my ($before, $ignored, $args, $full_filename, $ignored, $after)
388
+ = /^($NON_COMMENT)\\includegraphics(\[[^\]]*?\]|)$ARGUMENT(.*)$/) {
389
+ if ($explain) {
390
+ print "% inside_import " . $inside_import ."\n";
391
+ print "% before " . $before ."\n";
392
+ print "% ignored " . $ignored ."\n";
393
+ print "% args " . $args ."\n";
394
+ print "% full_filename " . $full_filename ."\n";
395
+ print "% after " . $after ."\n";
396
+ }
397
+ if ($inside_import) {
398
+ $full_filename = $inside_import . $full_filename;
399
+ print "$before\\includegraphics" . "$args" . "{$full_filename}$after\n";
400
+ $_ = "";
401
+ }
402
+ } elsif (my ($before, $ignored, $args, $full_filename, $ignored, $after)
403
+ = /^($NON_COMMENT)\\lstinputlisting(\[[^\]]*?\]|)$ARGUMENT(.*)$/) {
404
+ if ($explain) {
405
+ print "% inside_import " . $inside_import ."\n";
406
+ print "% before " . $before ."\n";
407
+ print "% ignored " . $ignored ."\n";
408
+ print "% args " . $args ."\n";
409
+ print "% full_filename " . $full_filename ."\n";
410
+ print "% after " . $after ."\n";
411
+ }
412
+ if ($inside_import) {
413
+ $full_filename = $inside_import . $full_filename;
414
+ print "$before\\lstinputlisting" . "$args" . "{$full_filename}$after\n";
415
+ $_ = "";
416
+ }
417
+ }
418
+ }
419
+ if ($expand_usepackage) {
420
+ # Don't bother with before and after text, we just require the
421
+ # usepackage to be alone on its line.
422
+ if (my ($package_name) = /^\s*\\usepackage\{([^\}]*)\}\s*(%.*)?$/) {
423
+ my $full = find_file($package_name . ".sty", $TEXINPUTS);
424
+ if ($full) {
425
+ say $prefix . "Found package file: $full\n";
426
+ process_file($full, $prefix . " ");
427
+ $_ = "";
428
+ # Forget about any commented newline
429
+ # before the \usepackage:
430
+ $$commented_newline = 0;
431
+ } else {
432
+ say $prefix . "Not including external package $package_name\n";
433
+ }
434
+ }
435
+ }
436
+ if ($expand_bbl) {
437
+ if (my ($before, $bib_name, $after)
438
+ = /^(.*)\\(?:bibliography|bibselect)\{([^\}]*)\}(.*)$/) {
439
+ # The BBL file is not necessarily $bib_name.
440
+ # Take it from the command-line.
441
+ print $before . $nl;
442
+ say $prefix . "Expanding BBL file: $expand_bbl\n";
443
+ process_file($expand_bbl, $prefix . " ");
444
+ print " " . $nl . $after . "\n";
445
+ $_ = "";
446
+ }
447
+ }
448
+ if ($biber) {
449
+ if (my ($before, $after)
450
+ = /^(.*)\\(?:addbibresource)\{[^\}]*\}(.*)$/) {
451
+ # See https://tex.stackexchange.com/questions/166518/biblatex-include-bbl-problem-with-verb-field/166526#166526
452
+ my $biber_noext = $biber;
453
+ $biber_noext =~ s/.bbl//;
454
+ print $before . $nl;
455
+ say $prefix . "Expanding Biber BBL file: $biber\n";
456
+ print '\begin{filecontents*}{' . $biber . '}' . "\n";
457
+ cat_file($biber);
458
+ print "\n";
459
+ print '\end{filecontents*}
460
+
461
+ \usepackage{xpatch}
462
+
463
+ %Patch the biblatex input command.
464
+ %replace "testinput-bbl" if you change the name above.
465
+ %disable if you want to run biblatex/biber normally
466
+ \makeatletter
467
+ \patchcmd\blx@bblinput{\blx@blxinit}
468
+ {\blx@blxinit
469
+ \def\jobname{' . $biber_noext . '}%new jobname
470
+ }{}{\fail}
471
+ \makeatother
472
+ ';
473
+ say $prefix . "End expansion of Biber BBL file: $biber\n";
474
+ print " " . $nl . $after . "\n";
475
+ $_ = "";
476
+ }
477
+ }
478
+ if ($show_graphics) {
479
+ if (/\\includegraphics(\[[^\]]*\])?{([^}]*)}/) {
480
+ my $full_filename = $2;
481
+ if ($inside_import) {
482
+ $full_filename = $inside_import . $full_filename;
483
+ }
484
+ my $full = find_tex_file($full_filename, $graphics_extensions);
485
+ say $prefix . "needs graphics file: ";
486
+ print STDERR "$full\n";
487
+ }
488
+ }
489
+ if (/^[ \t]*\\begin\{document\}/) {
490
+ $in_preamble = 0;
491
+ if ($makeatletter) {
492
+ print '\makeatletter' . $nl;
493
+ }
494
+ }
495
+ print;
496
+ }
497
+
498
+ sub unquote
499
+ {
500
+ my $str = shift;
501
+ my $x = substr($str, 0, 1);
502
+ my $y = substr($str, -1, 1);
503
+ if ($x eq $y && ($x eq '"' || $x eq "'")) {
504
+ $str = substr($str, 1, -1);
505
+ }
506
+ # There's a weird LaTeX syntax: \include{"file\space
507
+ # with\space spaces"}, so remove these \space when unquoting.
508
+ $str =~ s/\\space / /g;
509
+ return $str;
510
+ }
511
+
512
+ # search $1 in $TEXINPUTS, with possible extensions in $2
513
+ sub find_tex_file
514
+ {
515
+ my $file = unquote(shift);
516
+ my $extensions = (shift || ":");
517
+ foreach my $ext (split(':', $extensions, -1)) {
518
+ my $full = find_file_global($file . $ext);
519
+ if ($full) {
520
+ return $full;
521
+ }
522
+ }
523
+ if ($fatal) {
524
+ die "ERROR: Could not find file [$file]\n";
525
+ } else {
526
+ print STDERR "Warning: Could not find file [$file]\n";
527
+ return;
528
+ }
529
+ }
530
+
531
+ sub find_file_global
532
+ {
533
+ my $file = shift;
534
+ if (open(my $fh, "-|", "kpsewhich", $file)) {
535
+ my $full = <$fh>;
536
+ $full =~ s/\s+$//;
537
+ close($fh);
538
+ if ($full) {
539
+ return $full;
540
+ }
541
+ }
542
+ # Should be useless, but fall-back in case kpsewhich fails (or is not installed, or ...):
543
+ return find_file($file, $TEXINPUTS);
544
+ }
545
+
546
+ # Find files, not searching for global files (to allow not expanding global .sty packages)
547
+ sub find_file
548
+ {
549
+ my ($file, $path) = @_;
550
+ if (File::Spec->file_name_is_absolute($file)) {
551
+ if (-e "$file" && ! -d "$file") {
552
+ return $file;
553
+ } else {
554
+ return;
555
+ }
556
+ }
557
+
558
+ # TEXINPUTS=...: (trailing :) means "append default search
559
+ # directories". We don't want global directories here, but
560
+ # still add . that may be needed.
561
+ if (substr($path, -1) eq ':') {
562
+ $path .= '.';
563
+ }
564
+ foreach my $dir (split(':', $path)) {
565
+ if (-e "$dir/$file" && ! -d "$dir/$file") {
566
+ return("$dir/$file");
567
+ }
568
+ }
569
+ return;
570
+ }
571
+
572
+
573
+ __END__
574
+
575
+ =head1 NAME
576
+
577
+ latexpand - Flatten LaTeX file by expanding \include and \input, ... and remove comments
578
+
579
+ =head1 SYNOPSIS
580
+
581
+ latexpand [options] FILE...
582
+
583
+ =head2 Options:
584
+
585
+ --verbose show what's going on
586
+ --keep-comments don't strip comments (comments are lines
587
+ starting with %, and anything below
588
+ \end{document})
589
+ --empty-comments keep empty comments (i.e. % at end of lines) for clarity
590
+ --keep-includes don't expand \input and \include directives
591
+ --expand-usepackage
592
+ Expand \usepackage{...} directives if the
593
+ corresponding .sty file is found in
594
+ $TEXINPUTS (or the current directory if
595
+ $TEXINPUTS is not set)
596
+ --expand-bbl FILE
597
+ Expand the bibliography by inlining FILE
598
+ (should be a *.bbl file)
599
+ --biber FILE Include \bibliography{} with FILE's content,
600
+ as needed by biblatex with the biber backend.
601
+ (similar to --expand-bbl FILE, but for
602
+ biber+biblatex).
603
+ --help this help message
604
+ --define <key>=<val>, -d <key>=<val>
605
+ defines a macro key to be replaced by value, e.g.,
606
+ when called with -d foo=bar would replace all occurrences
607
+ of \foo in the code with bar. Can be supplied multiple times.
608
+ --output <file>, -o <file>
609
+ generate output in <file>
610
+ --explain generate explanatory comments in output
611
+ --show-graphics show included graphics
612
+ --graphics_extensions
613
+ colon-separated list of possible graphics extensions
614
+ (used by --show-graphics to find the actual graphics files)
615
+ --fatal Die in case a file can't be found.
616
+ --makeatletter Insert a \makeatletter in the preamble. In some
617
+ rare cases it may break your document, but it
618
+ may help fixing bad interactions between
619
+ @-commands and inclusion (see BUGS section).
620
+ --in-encoding FMT, --out-encoding FMT
621
+ File encoding used by input and output files.
622
+ This uses the same syntax as PerlIO's layers.
623
+ Example:
624
+ --in-encoding 'encoding(UTF-8)'
625
+ The default is 'bytes' and should always work.
626
+
627
+ =head1 USES
628
+
629
+ The most common use of latexpand is to simplify distribution of source
630
+ LaTeX files, typically to satisfy the requirement of editors and
631
+ archival sites (springer, arXiv.org, ...) who force the authors to
632
+ submit sources. One does not necessarily want to submit sources with
633
+ comments, and uploading a document made of several files including
634
+ each other is a bit painful. By default, latexpand answers both
635
+ problems by outputing a single LaTeX file that contain no comment.
636
+
637
+ =head1 GETTING LATEXPAND
638
+
639
+ The latest version of latexpand is available here:
640
+
641
+ https://gitlab.com/latexpand/latexpand
642
+
643
+ Versions are uploaded to ctan.org from time to time:
644
+
645
+ http://www.ctan.org/pkg/latexpand
646
+
647
+ =head1 BUGS
648
+
649
+ Please, report bugs on the issue tracker on the project site:
650
+
651
+ https://gitlab.com/latexpand/latexpand/issues
652
+
653
+ =head2 Known bugs
654
+
655
+ =head3 Verbatim
656
+
657
+ latexpand currently ignores \begin{verbatim} ... \end{verbatim}, and
658
+ will therefore process any \include, \input, ... directives that
659
+ appear within verbatim environments (while it shouldn't).
660
+
661
+ LaTeX comments inside verbatim environments are also incorrectly
662
+ stripped. You can use --keep-comments as a workaround to avoid this.
663
+
664
+ =head3 Comment environment
665
+
666
+ It would be nice to remove code between \begin{comment} and
667
+ \end{comment} too if \usepackage{comment} is used.
668
+
669
+ Code like
670
+
671
+ foo%
672
+ \begin{comment}
673
+
674
+ will produce the incorrect
675
+
676
+ foo\begin{comment}
677
+
678
+ A workaround is to use --empty-comments when such tricky usage of the
679
+ comments package is done.
680
+
681
+ =head3 \makeatletter and use with transfig/xfig with \scalebox{}
682
+
683
+ If \input{} or \include{} appears as argument to a command, and the
684
+ file included contains \makeatletter, then after expansion, the
685
+ \makeatletter and the @-command appear as argument to the command,
686
+ which is forbidden because the argument is parsed (and the @-command
687
+ badly tokenized) before being executed.
688
+
689
+ This happens with
690
+
691
+ \scalebox{ \input{file-generated-by-xfig.pdf_t} }
692
+
693
+ Workaround: add \makeatletter before the scalebox manually in your
694
+ code, like
695
+
696
+ \makeatletter{}
697
+ \scalebox{ \input{file-generated-by-xfig.pdf_t} }
698
+ \makeatother{}
699
+
700
+ In the case of xfig generated files, it is necessary only for the
701
+ first occurrence.
702
+
703
+ A more brute-force workaround is to use latexpand --makeatletter.
704
+
705
+ =head1 SEE ALSO
706
+
707
+ Instructions to include only the relevant .bib items (french):
708
+
709
+ https://lacl.fr/~caubert/notes/portabilite-du-tex.html#dependances
710
+
711
+ =head1 VERSION
712
+
713
+ This is latexpand version v1.7.2.
utils/utils.py ADDED
@@ -0,0 +1,701 @@
1
+ import sys
2
+ import regex
3
+ import yaml
4
+ import shutil
5
+ import bibtexparser
6
+ from charset_normalizer import from_path
7
+ from langdetect import detect
8
+ import os
9
+ import subprocess
10
+ import numpy as np
11
+ import networkx as nx
12
+ import re
13
+
14
+
15
+ def is_venv():
16
+ return (hasattr(sys, 'real_prefix') or
17
+ (hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix))
18
+
19
+ def read_yaml_file(file_path):
20
+ with open(file_path, 'r') as file:
21
+ try:
22
+ data = yaml.safe_load(file)
23
+ return data
24
+ except yaml.YAMLError as e:
25
+ print(f"Error reading YAML file: {e}")
26
+
27
+ def read_tex_file(file_path):
28
+ with open(file_path, 'r', encoding='utf-8') as file:
29
+ tex_content = file.read()
30
+ return tex_content
31
+
32
+ def write_tex_file(file_path, s):
33
+ with open(file_path, 'w', encoding='utf-8') as file:
34
+ file.write(s)
35
+
36
+ def get_core(s):
37
+ start = '\\begin{document}'
38
+ end = '\\end{document}'
39
+ beginning_doc = s.find(start)
40
+ end_doc = s.rfind(end)
41
+ return s[beginning_doc+len(start):end_doc]
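+ # Example (illustrative):
+ #   get_core("\\documentclass{x}\\begin{document}Body\\end{document}")
+ #   -> "Body"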
42
+
43
+
44
+ def retrieve_text(text, command, keep_text=False):
45
+ """Removes '\\command{*}' from the string 'text'.
46
+
47
+ Regex `base_pattern` used to match balanced parentheses taken from:
48
+ https://stackoverflow.com/questions/546433/regular-expression-to-match-balanced-parentheses/35271017#35271017
49
+ """
50
+ base_pattern = (
51
+ r'\\' + command + r"(?:\[(?:.*?)\])*\{((?:[^{}]+|\{(?1)\})*)\}(?:\[(?:.*?)\])*"
52
+ )
53
+
54
+ def extract_text_inside_curly_braces(text):
55
+ """Extract text inside of {} from command string"""
56
+ pattern = r"\{((?:[^{}]|(?R))*)\}"
57
+
58
+ match = regex.search(pattern, text)
59
+
60
+ if match:
61
+ return match.group(1)
62
+ else:
63
+ return ""
64
+
65
+ # Loops in case of nested commands that need to retain text, e.g. \red{hello \red{world}}.
66
+ while True:
67
+ all_substitutions = []
68
+ has_match = False
69
+ for match in regex.finditer(base_pattern, text):
70
+ # If only spaces (or nothing) remain up to the following newline,
71
+ # a percent sign is inserted so the newline structure is preserved.
72
+ has_match = True
73
+
74
+ if not keep_text:
75
+ new_substring = ""
76
+ else:
77
+ temp_substring = text[match.span()[0] : match.span()[1]]
78
+ return extract_text_inside_curly_braces(temp_substring)
79
+
80
+ if match.span()[1] < len(text):
81
+ next_newline = text[match.span()[1] :].find("\n")
82
+ if next_newline != -1:
83
+ text_until_newline = text[
84
+ match.span()[1] : match.span()[1] + next_newline
85
+ ]
86
+ if (
87
+ not text_until_newline or text_until_newline.isspace()
88
+ ) and not keep_text:
89
+ new_substring = "%"
90
+ all_substitutions.append((match.span()[0], match.span()[1], new_substring))
91
+
92
+ for start, end, new_substring in reversed(all_substitutions):
93
+ text = text[:start] + new_substring + text[end:]
94
+
95
+ if not keep_text or not has_match:
96
+ break
+ return text  # the input with every \command{...} occurrence removed
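+
+ # Examples (illustrative):
+ #   retrieve_text(r"keep \red{this} text", "red", keep_text=True) -> "this"
+ #   retrieve_text(r"drop \todo{this} note", "todo") -> "drop  note"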
97
+
98
+
99
+ def reduce_linebreaks(s):
100
+ return re.sub(r'(\n[ \t]*)+(\n[ \t]*)+', '\n\n', s)
101
+
102
+
103
+ def replace_percentage(s):
104
+ return re.sub(r'% *\n', '\n', s)
105
+
106
+
107
+ def reduce_spaces(s):
108
+ return re.sub(' +', ' ', s)
109
+
110
+
111
+ def delete_urls(s):
112
+ return re.sub(r'http\S+', '', s)
113
+
114
+
115
+ def remove_tilde(s):
116
+ s1 = re.sub(r'[~ ]\.', '.', s)
117
+ s2 = re.sub(r'[~ ],', ',', s1)
118
+ return re.sub(r'{}', '', s2)
119
+
120
+
121
+ def remove_verbatim_words(s):
122
+ with open("configs/latex_commands.yaml", "r") as stream:
123
+ read_config = yaml.safe_load(stream)
124
+
125
+ for command in read_config['verbatim_to_delete']:
126
+ s = s.replace(command, '')
127
+
128
+ for command in read_config['two_arguments']:
129
+ pattern = r'\\' + command + r'{[^}]*}' + r'{[^}]*}'
130
+ s = re.sub(pattern, '', s)
131
+
132
+ for command in read_config['three_arguments']:
133
+ pattern = r'\\' + command + r'{[^}]*}' + r'{[^}]*}' + r'{[^}]*}'
134
+ s = re.sub(pattern, '', s)
135
+
136
+ for command in read_config['two_arguments_elaborate']:
137
+ s = remove_multargument(s, '\\' + command, 2)
138
+
139
+ for command in read_config['three_arguments_elaborate']:
140
+ s = remove_multargument(s, '\\' + command, 3)
141
+
142
+ for command in read_config['replace_comments']:
143
+ pattern = r'\\' + command
144
+ s = re.sub(pattern, '%', s)
145
+
146
+ s = re.sub(
147
+ r'\\end{[\s]*abstract[\s]*}',
148
+ '',
149
+ s,
150
+ flags=re.IGNORECASE
151
+ )
152
+
153
+ s = re.sub(
154
+ r'\\begin{[\s]*abstract[\s]*}',
155
+ 'Abstract\n\n',
156
+ s,
157
+ flags=re.IGNORECASE
158
+ )
159
+ return s
160
+
161
+
162
+ def yes_or_no(s):
163
+ return 1 if "Yes" == s[0:3] else 0 if "No" == s[0:2] else -1
164
+
165
+
166
+ def get_main(directory):
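+ # In brief: build a directed graph with an edge i -> j whenever .tex
+ # file i inputs/includes .tex file j, keep the largest weakly connected
+ # component, and return a root file that no other file inputs. Ties are
+ # broken by preferring roots that contain \begin{document} (and are
+ # detected as English), then by file size.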
167
+ file_paths = []
168
+ for root, _, files in os.walk(directory):
169
+ for file in files:
170
+ file_path = os.path.join(root, file)
171
+ file_paths.append(file_path)
172
+ latex_paths = [f for f in file_paths if f.endswith('.tex')]
173
+ number_tex = len(latex_paths)
174
+ if number_tex == 0:
175
+ return None
176
+ if number_tex == 1:
177
+ return latex_paths[0]
178
+ adjacency = np.zeros((number_tex, number_tex))
179
+ keys = [os.path.basename(path) for path in latex_paths]
180
+ reg_ex = r'\\input{(.*?)}|\\include{(.*?)}|\\import{(.*?)}|\\subfile{(.*?)}|\\include[*]{(.*?)}'
181
+ for i,file in enumerate(latex_paths):
182
+ content = read_tex_file(file)
183
+ find_pattern_input = re.findall(reg_ex, content)
184
+ find_pattern_input = [tup for tup in find_pattern_input if not all(element == "" for element in tup)]
185
+ number_matches = len(find_pattern_input)
186
+ if number_matches == 0:
187
+ continue
188
+ else:
189
+ content = replace_imports(file, content)
190
+ reg_ex_clean = r'\\input{(.*?)}|\\include{(.*?)}'
191
+ find_pattern_input = re.findall(reg_ex_clean, content)
192
+ number_matches = len(find_pattern_input)
193
+ for j in range(number_matches):
194
+ match = find_pattern_input[j]
195
+ non_empty_match = [t for t in match if t]
196
+ for non_empty in non_empty_match:
197
+ base_match = os.path.basename(non_empty)
198
+ if not base_match.endswith('.tex'):
199
+ base_match = base_match + '.tex'
200
+ if base_match not in keys:
201
+ continue
202
+ ind = keys.index(base_match)
203
+ adjacency[i][ind] = 1
204
+ G = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
205
+ connected_components = list(nx.weakly_connected_components(G))
206
+ size_connected = [len(x) for x in connected_components]
207
+ maximum_size = max(size_connected)
208
+ biggest_connected = [x for x in connected_components if len(x) == maximum_size]
209
+ if len(biggest_connected) > 1:
210
+ roots = [n for connected in biggest_connected for n in connected if not list(G.predecessors(n))]
211
+ _check = []
212
+ for r in roots:
213
+ try:
214
+ _check.append(check_begin(latex_paths[r]))
215
+ except Exception as e:
216
+ _check.append(False)
217
+ potential_files = [latex_paths[x] for x, y in zip(roots, _check) if y]
218
+ if not potential_files:
+ return None
+ sizes_files = [os.path.getsize(x) for x in potential_files]
219
+ return potential_files[sizes_files.index(max(sizes_files))]
220
+
221
+ else:
222
+ roots = [n for n in biggest_connected[0] if not list(G.predecessors(n))]
223
+ return latex_paths[roots[0]]
224
+
225
+
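+ # Worked sketch of the heuristic above (file names hypothetical): if main.tex
+ # contains \input{intro} and \input{results}, the adjacency matrix gains the
+ # edges main -> intro and main -> results, the largest weakly connected
+ # component is {main, intro, results}, and main.tex is returned as its only
+ # root (a node with no predecessors). When several components tie for the
+ # largest size, the largest (by file size) root that passes check_begin()
+ # is chosen.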
226
+ def initial_clean(directory, config):
227
+ config_cmd = ''
228
+ if config:
229
+ config_cmd = '--config configs/cleaning_config.yaml'
230
+ temp_dir = directory[:directory.rfind('/')] + '_temp' + '/'
231
+ shutil.copytree(directory, temp_dir)
232
+ try:
233
+ command_res = os.system('arxiv_latex_cleaner --keep_bib {} {}'.format(directory, config_cmd))
234
+ if command_res != 0:
235
+ raise Exception('Error cleaning')
236
+ else:
237
+ shutil.rmtree(temp_dir)
238
+
239
+ except Exception as e:
240
+ shutil.rmtree(directory)
241
+ os.rename(temp_dir, directory)
242
+ file_paths = []
243
+ for root, _, files in os.walk(directory):
244
+ for file in files:
245
+ file_path = os.path.join(root, file)
246
+ file_paths.append(file_path)
247
+ latex_paths = [f for f in file_paths if f.endswith('.tex')]
248
+ for p in latex_paths:
249
+ results = from_path(p)
250
+ with open(p, 'w', encoding='utf-8') as f:
251
+ f.write(str(results.best()))
252
+ os.system('arxiv_latex_cleaner --keep_bib {} {}'.format(directory, config_cmd))
253
+ cleaned_directory = directory[:directory.rfind('/')] + '_arXiv'
254
+ shutil.rmtree(directory)
255
+ os.rename(cleaned_directory, directory)
256
+
257
+
258
+ def check_begin(directory):
259
+ content = read_tex_file(directory)
260
+ english = detect(content) == 'en'
261
+ return english if re.findall(r'\\begin{document}', content) else False
262
+
263
+
264
+ def post_processing(extracted_dir, file):
265
+ _dir = os.path.dirname(file) + '/'
266
+ perl_expand(file)
267
+ file = _dir + 'merged_latexpand.tex'
268
+ try:
269
+ de_macro(file)
270
+ file = _dir + 'merged_latexpand-clean.tex'
271
+ except Exception as e:
272
+ pass
273
+ try:
274
+ def_handle(file)
275
+ except Exception as e:
276
+ pass
277
+ try:
278
+ declare_operator(file) # also normalizes titles, citation/ref commands, and section levels
279
+ except Exception as e:
280
+ pass
281
+ try:
282
+ de_macro(file)
283
+ file = _dir + os.path.splitext(os.path.basename(file))[0] + '-clean' + '.tex'
284
+ except Exception as e:
285
+ pass
286
+ initial_clean(_dir, config=True)
287
+ initial_clean(_dir, config=False)
288
+ tex_content = read_tex_file(file)
289
+ final_tex = reduce_spaces(
290
+ delete_urls(
291
+ remove_tilde(
292
+ reduce_linebreaks(
293
+ replace_percentage(
294
+ remove_verbatim_words(
295
+ tex_content
296
+ )
297
+ )
298
+ )
299
+ )
300
+ )
301
+ ).strip()
302
+ shutil.rmtree(extracted_dir)
303
+ os.makedirs(extracted_dir)
304
+ write_tex_file(extracted_dir + 'final_cleaned.tex', final_tex)
305
+ initial_clean(extracted_dir, config=False)
306
+ return extracted_dir + 'final_cleaned.tex'
307
+
308
+
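+ # Hedged usage sketch (paths hypothetical; assumes the src/utils scripts and
+ # arxiv_latex_cleaner are available):
+ #
+ #   main_tex = get_main('papers/2301.00001/')
+ #   if main_tex is not None:
+ #       cleaned_path = post_processing('papers/2301.00001/', main_tex)
+ #   # the source tree is replaced by a single final_cleaned.tex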
309
+ def perl_expand(file):
310
+ # Save the current working directory
311
+ oldpwd = os.getcwd()
312
+ target_dir = os.path.dirname(file) + '/'
313
+ # Construct the destination path for the latexpand copy
314
+ target = os.path.join(target_dir, 'latexpand')
315
+ src = './src/utils/latexpand'
316
+ # Copy the `latexpand` script to the target directory
317
+ shutil.copyfile(src, target)
318
+ # Change to the target directory
319
+ os.chdir(target_dir)
320
+
321
+ # Run the perl command without shell=True and handle redirection within Python
322
+ with open('merged_latexpand.tex', 'w') as output_file:
323
+ subprocess.run(['perl', 'latexpand', os.path.basename(file)],
324
+ stdout=output_file, stderr=subprocess.DEVNULL)
325
+
326
+ # Return to the original directory
327
+ os.chdir(oldpwd)
328
+
329
+
330
+ def de_macro(file):
331
+ # Save the current working directory
332
+ oldpwd = os.getcwd()
333
+ target_dir = os.path.dirname(file) + '/'
334
+ # Construct the target path
335
+ target = os.path.join(target_dir, 'de-macro.py')
336
+ src = './src/utils/de-macro.py'
337
+
338
+ # Copy the `de-macro.py` script to the target directory
339
+ shutil.copyfile(src, target)
340
+ # Change to the target directory
341
+ os.chdir(target_dir)
342
+
343
+ # Run the de-macro script without os.system and capture errors
344
+ try:
345
+ subprocess.run(['python3', 'de-macro.py', os.path.basename(file)],
346
+ stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
347
+ except subprocess.CalledProcessError as e:
348
+ raise Exception(f"Error de-macro: {e}") from e
349
+ finally:
350
+ # Always return to the original directory
351
+ os.chdir(oldpwd)
352
+
353
+
354
+ def def_handle(file):
355
+ h = os.system('python3 src/utils/def_handle.py {} --output {}'.format(file, file))
356
+ if h != 0:
357
+ raise Exception('Error def handle')
358
+
359
+
360
+ def declare_operator(file):
361
+ s = read_tex_file(file)
362
+ ## Operators
363
+ pattern = r'\\DeclareMathOperator'
364
+ s = re.sub(pattern, r'\\newcommand', s)
365
+ pattern = {
366
+ r'\\newcommand\*': r'\\newcommand',
367
+ r'\\providecommand\*': r'\\newcommand',
368
+ r'\\providecommand': r'\\newcommand',
369
+ r'\\renewcommand\*': r'\\renewcommand',
370
+ r'\\newenvironment\*': r'\\newenvironment',
371
+ r'\\renewenvironment\*': r'\\renewenvironment'
372
+ }
373
+ s = re.sub(r'\\end +', r'\\end', s)
374
+ for key in pattern:
375
+ s = re.sub(key, pattern[key], s)
376
+ ## Title
377
+ start = '\\begin{document}'
378
+ beginning_doc = s.find(start)
379
+ pattern = {
380
+ r'\\icmltitlerunning\*': r'\\title',
381
+ r'\\icmltitlerunning': r'\\title',
382
+ r'\\inlinetitle\*': r'\\title',
383
+ r'\\icmltitle\*': r'\\title',
384
+ r'\\inlinetitle': r'\\title',
385
+ r'\\icmltitle': r'\\title',
386
+ r'\\titlerunning\*': r'\\title',
387
+ r'\\titlerunning': r'\\title',
388
+ r'\\toctitle': r'\\title',
389
+ r'\\title\*': r'\\title',
390
+ r'\\TITLE\*': r'\\title',
391
+ r'\\TITLE': r'\\title',
392
+ r'\\Title\*': r'\\title',
393
+ r'\\Title': r'\\title',
394
+ }
395
+ for key in pattern:
396
+ s = re.sub(key, pattern[key], s)
397
+ find_potential = s.find('\\title')
398
+
399
+ ## Remove \\
400
+ title_content = retrieve_text(s, 'title', keep_text=True)
401
+ if title_content is not None:
402
+ cleaned_title = re.sub(r'\\\\', ' ', title_content)
403
+ cleaned_title = re.sub(r'\n',' ', cleaned_title)
404
+ cleaned_title = re.sub(r'\~',' ', cleaned_title)
405
+ s = s.replace(title_content, cleaned_title)
406
+ if find_potential != -1 and find_potential < beginning_doc:
407
+ s = s.replace('\\maketitle', cleaned_title)
408
+
409
+ ## Cite and ref commands
410
+ pattern = {
411
+ r'\\citep\*': r'\\cite',
412
+ r'\\citet\*': r'\\cite',
413
+ r'\\citep': r'\\cite',
414
+ r'\\citet': r'\\cite',
415
+ r'\\cite\*': r'\\cite',
416
+ r'\\citealt\*': r'\\cite',
417
+ r'\\citealt': r'\\cite',
418
+ r'\\citealtp\*': r'\\cite',
419
+ r'\\citealp': r'\\cite',
420
+ r'\\citeyear\*': r'\\cite',
421
+ r'\\citeyear': r'\\cite',
422
+ r'\\citeauthor\*': r'\\cite',
423
+ r'\\citeauthor': r'\\cite',
424
+ r'\\citenum\*': r'\\cite',
425
+ r'\\citenum': r'\\cite',
426
+ r'\\cref': r'\\ref',
427
+ r'\\Cref': r'\\ref',
428
+ r'\\factref': r'\\ref',
429
+ r'\\appref': r'\\ref',
430
+ r'\\thmref': r'\\ref',
431
+ r'\\secref': r'\\ref',
432
+ r'\\lemref': r'\\ref',
433
+ r'\\corref': r'\\ref',
434
+ r'\\eqref': r'\\ref',
435
+ r'\\autoref': r'\\ref',
436
+ r'begin{thm}': r'begin{theorem}',
437
+ r'begin{lem}': r'begin{lemma}',
438
+ r'begin{cor}': r'begin{corollary}',
439
+ r'begin{exm}': r'begin{example}',
440
+ r'begin{defi}': r'begin{definition}',
441
+ r'begin{rem}': r'begin{remark}',
442
+ r'begin{prop}': r'begin{proposition}',
443
+ r'end{thm}': r'end{theorem}',
444
+ r'end{lem}': r'end{lemma}',
445
+ r'end{cor}': r'end{corollary}',
446
+ r'end{exm}': r'end{example}',
447
+ r'end{defi}': r'end{definition}',
448
+ r'end{rem}': r'end{remark}',
449
+ r'end{prop}': r'end{proposition}',
450
+ }
451
+
452
+ for key in pattern:
453
+ s = re.sub(key, pattern[key], s)
454
+
455
+
456
+ pattern = {
457
+ r'subsubsection': r'section',
458
+ r'subsubsection ': r'section',
459
+ r'subsubsection\*': r'section',
460
+ r'subsubsection\* ': r'section',
461
+ r'subsection': r'section',
462
+ r'subsection ': r'section',
463
+ r'subsection\*': r'section',
464
+ r'subsection\* ': r'section',
465
+ r'section ': r'section',
466
+ r'section\*': r'section',
467
+ r'section\* ': r'section',
468
+ r'chapter': r'section',
469
+ r'chapter ': r'section',
470
+ r'chapter\*': r'section',
471
+ r'chapter\* ': r'section',
472
+ r'mysubsubsection': r'section',
473
+ r'mysubsection': r'section',
474
+ r'mysection': r'section',
475
+ }
476
+
477
+ for key in pattern:
478
+ s = re.sub(key, pattern[key], s)
479
+
480
+ # In case any new commands for appendix/appendices
481
+ s = re.sub(r'\\newcommand{\\appendix}', '', s)
482
+ s = re.sub(r'\\newcommand{\\appendices}', '', s)
483
+ s = get_core(s)
484
+
485
+ ## In case of double titles being defined
486
+ title_content = retrieve_text(s, 'title', keep_text=True)
487
+ if title_content is not None:
488
+ cleaned_title = re.sub(r'\\\\', ' ', title_content)
489
+ cleaned_title = re.sub(r'\n',' ', cleaned_title)
490
+ cleaned_title = re.sub(r'\~',' ', cleaned_title)
491
+ s = s.replace(title_content, cleaned_title)
492
+ write_tex_file(file, s)
493
+
494
+
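+ # Examples of the normalizations above (inputs hypothetical):
+ #   \DeclareMathOperator{\argmax}{arg\,max} -> \newcommand{\argmax}{arg\,max}
+ #   \icmltitle{A Study}                     -> \title{A Study}
+ #   \subsection*{Setup}                     -> \section{Setup}
+ #   \citep{xyz} / \cref{fig:a}              -> \cite{xyz} / \ref{fig:a}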
495
+ def replace_imports(file, s):
496
+ regex_p1 = r'\\import{(.*?)}{(.*?)}'
497
+ s = re.sub(regex_p1, r"\\input{\1\2}", s)
498
+ regex_p2 = r'\\subfile{(.*?)}'
499
+ s = re.sub(regex_p2, r"\\input{\1}", s)
500
+ regex_p3 = r'\\include[*]{(.*?)}'
501
+ s = re.sub(regex_p3, r"\\input{\1}", s)
502
+ write_tex_file(file, s)
503
+ return s
504
+
505
+
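+ # Examples of the rewrites above (inputs hypothetical):
+ #   \import{sections/}{intro.tex} -> \input{sections/intro.tex}
+ #   \subfile{appendix}            -> \input{appendix}
+ #   \include*{results}            -> \input{results}
+ # Note \import's two groups are simply concatenated, so the first argument
+ # must already end in a slash for the joined path to resolve.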
506
+ def remove_multargument(s, target, k):
507
+ ind = s.find(target)
508
+ while ind != -1:
509
+ start_ind = ind + len(target)
510
+ stack_open = 0
511
+ stack_close = 0
512
+ track_arg = 0
513
+ for i, char in enumerate(s[start_ind:]):
514
+ if char == '{':
515
+ stack_open += 1
516
+ if char == '}':
517
+ stack_close += 1
518
+ if stack_open != 0 and stack_close != 0:
519
+ if stack_open == stack_close:
520
+ track_arg += 1
521
+ stack_open = 0
522
+ stack_close = 0
523
+ if track_arg == k:
524
+ break
525
+ s = s[:ind] + s[start_ind + i + 1:]
526
+ ind = s.find(target)
527
+ return s
528
+
529
+
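+ # Worked example for the brace matcher above (input hypothetical):
+ #
+ #   s = r"before \texorpdfstring{$x^2$}{x squared} after"
+ #   remove_multargument(s, '\\texorpdfstring', 2)  # -> "before  after"
+ #
+ # The open/close counters only reset once a whole {...} group balances, so
+ # nested braces inside an argument are consumed correctly, which the plain
+ # {[^}]*} regex used for the simpler config entries cannot do.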
530
+ def fix_citations(s):
531
+ pattern = {
532
+ r'\\citep\*': r'\\cite',
533
+ r'\\citet\*': r'\\cite',
534
+ r'\\citep': r'\\cite',
535
+ r'\\citet': r'\\cite',
536
+ r'\\cite\*': r'\\cite',
537
+ r'\\citealt\*': r'\\cite',
538
+ r'\\citealt': r'\\cite',
539
+ r'\\citealtp\*': r'\\cite',
540
+ r'\\citealp': r'\\cite',
541
+ r'\\citeyear\*': r'\\cite',
542
+ r'\\citeyear': r'\\cite',
543
+ r'\\citeauthor\*': r'\\cite',
544
+ r'\\citeauthor': r'\\cite',
545
+ r'\\citenum\*': r'\\cite',
546
+ r'\\citenum': r'\\cite'
547
+ }
548
+ for key in pattern:
549
+ s = re.sub(key, pattern[key], s)
550
+ return s
551
+
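+ # Example of the normalization above (input hypothetical):
+ #   fix_citations(r"\citep{a} and \citeauthor*{b}") -> r"\cite{a} and \cite{b}"
+ # The starred variants are listed before their unstarred prefixes so that
+ # e.g. \citep* collapses to \cite rather than \cite*.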
552
+ def find_bib(directory):
553
+ file_paths = []
554
+ for root, _, files in os.walk(directory):
555
+ for file in files:
556
+ file_path = os.path.join(root, file)
557
+ file_paths.append(file_path)
558
+ bib_paths = [f for f in file_paths if f.endswith('.bib')]
559
+ return bib_paths
560
+
561
+ def create_bib_from_bbl(bibfile):
562
+ with open(bibfile, 'r') as f:
563
+ content = f.read()
564
+ library_raw = bibtexparser.parse_string(content)
565
+ library = {}
566
+ for block in library_raw.blocks:
567
+ if isinstance(
568
+ block,
569
+ (bibtexparser.model.DuplicateBlockKeyBlock, bibtexparser.model.ParsingFailedBlock, bibtexparser.model.ImplicitComment)
570
+ ):
571
+ continue
572
+ fields = {}
573
+ for field in block.fields:
574
+ fields[field.key] = field.value
575
+
576
+ ## Recover a clean title from the note field ##
577
+ field_content = fields["note"]
578
+ field_content = field_content.replace("\n", " ")
579
+ field_content = re.sub(" +", " ", field_content)
580
+ if field_content.find("``") != -1 and field_content.find("\'\'") != -1:
581
+ title = (
582
+ field_content[field_content.find("``") + 2 : field_content.find("\'\'")]
583
+ .replace("\\emph", "")
584
+ .replace("\\emp", "")
585
+ .replace("\\em", "")
586
+ .replace(",", "")
587
+ .replace("{", "")
588
+ .replace("}","")
589
+ .replace("``", "")
590
+ .replace("\'\'", "")
591
+ .strip(".")
592
+ .strip()
593
+ .strip(".")
594
+ .lower()
595
+ )
596
+ fields['title'] = title
597
+ else:
598
+ if field_content.count("\\newblock") == 2:
599
+ field_content = field_content.replace("\\newblock", "``", 1)
600
+ field_content = field_content.replace("\\newblock", "\'\'", 1)
601
+ if field_content.find("``") != -1 and field_content.find("\'\'") != -1:
602
+ title = (
603
+ field_content[field_content.find("``") + 2 : field_content.find("\'\'")]
604
+ .replace("\\emph", "")
605
+ .replace("\\emp", "")
606
+ .replace("\\em", "")
607
+ .replace(",", "")
608
+ .replace("{", "")
609
+ .replace("}","")
610
+ .replace("``", "")
611
+ .replace("\'\'", "")
612
+ .strip(".")
613
+ .strip()
614
+ .strip(".")
615
+ .lower()
616
+ )
617
+ fields['title'] = title
618
+ library[block.key] = fields
619
+ return library
620
+
621
+
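+ # Worked sketch of the title recovery above (entry hypothetical): a tex2bib
+ # entry whose note field is
+ #   "A. Author. \newblock Great {Results}, \newblock 2020."
+ # contains no ``...'' pair but exactly two \newblock markers, so they are
+ # rewritten to `` and '', and the extracted title becomes "great results".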
622
+ def create_bib(bibfile):
623
+ with open(bibfile, 'r') as f:
624
+ content = f.read()
625
+ library_raw = bibtexparser.parse_string(content)
626
+
627
+ library = {}
628
+ for block in library_raw.blocks:
629
+ if isinstance(
630
+ block,
631
+ (bibtexparser.model.DuplicateBlockKeyBlock, bibtexparser.model.ParsingFailedBlock, bibtexparser.model.ImplicitComment)
632
+ ):
633
+ continue
634
+ fields = {}
635
+ for field in block.fields:
636
+ fields[field.key] = field.value.replace('{', '').replace('}', '')
637
+ if field.key == 'title':
638
+ title = re.sub(r'[\n]+', ' ', field.value) # collapse newline runs into spaces
639
+ title = re.sub(r' +', ' ', title)
640
+ fields[field.key] = (
641
+ title.replace("\\emph", "")
642
+ .replace("\\emp", "")
643
+ .replace("\\em", "")
644
+ .replace(",", "")
645
+ .replace("{", "")
646
+ .replace("}", "")
647
+ .strip(".")
648
+ .strip()
649
+ .strip(".")
650
+ .lower()
651
+ )
652
+ if 'title' not in fields:
653
+ continue
654
+ library[block.key] = fields
655
+ return library
656
+
657
+
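+ # Hedged usage sketch (file name hypothetical; assumes bibtexparser v2, whose
+ # parse_string / block model the code above relies on):
+ #
+ #   library = create_bib('refs.bib')
+ #   # maps citation keys to field dicts with a normalized lower-cased
+ #   # 'title'; failed/duplicate blocks and entries without a title are skipped.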
658
+ def find_bbl(directory):
659
+ file_paths = []
660
+ for root, _, files in os.walk(directory):
661
+ for file in files:
662
+ file_path = os.path.join(root, file)
663
+ file_paths.append(file_path)
664
+ bib_paths = [f for f in file_paths if f.endswith('.bbl')]
665
+ return bib_paths
666
+
667
+
668
+ def textobib(file):
669
+ oldpwd = os.getcwd()
670
+ target_dir = os.path.dirname(file) + '/'
671
+ target = target_dir + 'tex2bib'
672
+ src = './tex2bib'
673
+ shutil.copyfile(src, target)
674
+ os.chdir(target_dir)
675
+ output_file = os.path.splitext(os.path.basename(file))[0] + '.bib'
676
+ os.system('perl tex2bib -i {} -o {}'.format(os.path.basename(file), output_file))
677
+ os.chdir(oldpwd)
678
+ return target_dir + output_file
679
+
680
+
681
+ def get_library_bib(bib_files):
682
+ library = []
683
+ for bib_file in bib_files:
684
+ library.append(create_bib(bib_file))
685
+ final_library = {}
686
+ for d in library:
687
+ final_library.update(d)
688
+ return final_library
689
+
690
+
691
+ def get_library_bbl(bbl_files):
692
+ bib_files = []
693
+ for bbl_file in bbl_files:
694
+ bib_files.append(textobib(bbl_file))
695
+ library = []
696
+ for bib_file in bib_files:
697
+ library.append(create_bib_from_bbl(bib_file))
698
+ final_library = {}
699
+ for d in library:
700
+ final_library.update(d)
701
+ return final_library