Commit bcd84b9 · Parent(s): 3e16f65
first old model

Files changed:
- README.md +217 -3
- config.json +44 -0
- generation_config.json +8 -0
- generic_nel.py +192 -0
- handler.py +125 -0
- model.safetensors +3 -0
- requirements.txt +5 -0
- scheduler.pt +3 -0
- sentencepiece.bpe.model +3 -0
- special_tokens_map.json +51 -0
- tokenizer_config.json +55 -0
- trainer_state.json +207 -0
README.md
CHANGED
@@ -1,3 +1,217 @@

Removed: the previous three-line placeholder front matter (a bare `---` followed by two empty lines). It is replaced by the full model card below.
```yaml
---
library_name: transformers
language:
- multilingual
- af
- am
- ar
- as
- az
- be
- bg
- bm
- bn
- br
- bs
- ca
- cs
- cy
- da
- de
- el
- en
- eo
- es
- et
- eu
- fa
- ff
- fi
- fr
- fy
- ga
- gd
- gl
- gn
- gu
- ha
- he
- hi
- hr
- ht
- hu
- hy
- id
- ig
- is
- it
- ja
- jv
- ka
- kg
- kk
- km
- kn
- ko
- ku
- ky
- la
- lg
- ln
- lo
- lt
- lv
- mg
- mk
- ml
- mn
- mr
- ms
- my
- ne
- nl
- no
- om
- or
- pa
- pl
- ps
- pt
- qu
- ro
- ru
- sa
- sd
- si
- sk
- sl
- so
- sq
- sr
- ss
- su
- sv
- sw
- ta
- te
- th
- ti
- tl
- tn
- tr
- uk
- ur
- uz
- vi
- wo
- xh
- yo
- zh

license: agpl-3.0
tags:
- retrieval
- entity-retrieval
- named-entity-disambiguation
- entity-disambiguation
- named-entity-linking
- entity-linking
- text2text-generation
---
```

# Model Card for `impresso-project/nel-mgenre-multilingual`

The **Impresso multilingual named entity linking (NEL)** model is based on **mGENRE** (multilingual Generative ENtity REtrieval), proposed by [De Cao et al.](https://arxiv.org/abs/2103.12528): a sequence-to-sequence architecture for entity disambiguation built on [mBART](https://arxiv.org/abs/2001.08210). It uses **constrained generation** to output entity names that are mapped to Wikidata QIDs.

This model was adapted for historical texts and fine-tuned on the [HIPE-2022 dataset](https://github.com/hipe-eval/HIPE-2022-data), which covers a variety of historical document types and languages.

## Model Details

### Model Description

- **Developed by:** EPFL, within the [Impresso team](https://impresso-project.ch). Impresso is an interdisciplinary research project focused on historical media analysis across languages, time, and modalities, funded by the Swiss National Science Foundation ([CRSII5_173719](http://p3.snf.ch/project-173719), [CRSII5_213585](https://data.snf.ch/grants/grant/213585)) and the Luxembourg National Research Fund (grant No. 17498891).
- **Model type:** mBART-based sequence-to-sequence model with constrained beam search for named entity linking
- **Languages:** Multilingual (100+ languages, optimized for French, German, and English)
- **License:** [AGPL v3+](https://github.com/impresso/impresso-pyindexation/blob/master/LICENSE)
- **Finetuned from:** [`facebook/mgenre-wiki`](https://huggingface.co/facebook/mgenre-wiki)

### Model Architecture

- **Architecture:** mBART-based seq2seq with constrained beam search

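For orientation, the custom pipeline shipped in this commit (`generic_nel.py`, shown further down) ultimately performs a plain seq2seq generation call on a sentence in which the target mention is wrapped in `[START]`/`[END]` markers, and decodes a prediction of the form `Wikipedia page name >> language code`. The sketch below reproduces that core step with the vanilla `transformers` API; it illustrates the convention only and is not a replacement for the pipeline shown in "How to Use" below.

```python
# Minimal sketch of the generation step inside generic_nel.py (see that file in this commit).
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_NAME = "impresso-project/nel-mgenre-multilingual"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# The mention to disambiguate is wrapped in [START] ... [END].
sentence = "En 1894, [START] Dreyfus [END] est arrêté à Paris."

inputs = tokenizer([sentence], return_tensors="pt")
outputs = model.generate(**inputs, num_beams=1, max_new_tokens=30)

# The decoded string follows the mGENRE convention "<Wikipedia page name> >> <language>",
# e.g. something like "Alfred Dreyfus >> fr" (the actual output depends on the weights).
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```
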
## Training Details

### Training Data

The model was trained on the following datasets:

| Dataset alias | README | Document type | Languages | Suitable for | Project | License |
|---------------|--------|---------------|-----------|--------------|---------|---------|
| ajmc | [link](documentation/README-ajmc.md) | classical commentaries | de, fr, en | NERC-Coarse, NERC-Fine, EL | [AjMC](https://mromanello.github.io/ajax-multi-commentary/) | [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) |
| hipe2020 | [link](documentation/README-hipe2020.md) | historical newspapers | de, fr, en | NERC-Coarse, NERC-Fine, EL | [CLEF-HIPE-2020](https://impresso.github.io/CLEF-HIPE-2020) | [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) |
| topres19th | [link](documentation/README-topres19th.md) | historical newspapers | en | NERC-Coarse, EL | [Living with Machines](https://livingwithmachines.ac.uk/) | [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) |
| newseye | [link](documentation/README-newseye.md) | historical newspapers | de, fi, fr, sv | NERC-Coarse, NERC-Fine, EL | [NewsEye](https://www.newseye.eu/) | [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) |
| sonar | [link](documentation/README-sonar.md) | historical newspapers | de | NERC-Coarse, EL | [SoNAR](https://sonar.fh-potsdam.de/) | [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) |

## How to Use

```python
from transformers import AutoTokenizer, pipeline

NEL_MODEL_NAME = "impresso-project/nel-mgenre-multilingual"
nel_tokenizer = AutoTokenizer.from_pretrained(NEL_MODEL_NAME)

nel_pipeline = pipeline(
    "generic-nel",
    model=NEL_MODEL_NAME,
    tokenizer=nel_tokenizer,
    trust_remote_code=True,
    device="cpu",
)

# The mention to disambiguate is wrapped in [START] ... [END]; the OCR noise
# ("0ctobre", "Dreyfvs", "déch1ra", "fr4nçaise") is intentional and illustrates
# the model's robustness to noisy historical text.
sentence = "Le 0ctobre 1894, [START] Dreyfvs [END] est arrêté à Paris, accusé d'espionnage pour l'Allemagne — un événement qui déch1ra la société fr4nçaise pendant des années."
print(nel_pipeline(sentence))
```

### Output Format

```python
[
    {
        'surface': 'Dreyfvs',
        'wkd_id': 'Q171826',
        'wkpedia_pagename': 'Alfred Dreyfus',
        'wkpedia_url': 'https://fr.wikipedia.org/wiki/Alfred_Dreyfus',
        'type': 'UNK',
        'confidence_nel': 99.98,
        'lOffset': 24,
        'rOffset': 33
    }
]
```

The entity `type` is always `UNK` because the model was not trained to predict entity types. The `confidence_nel` score (0–100) indicates the model's confidence in the prediction.

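The `confidence_nel` value is derived from the generation scores rather than from a separate classifier: `generic_nel.py` in this commit sums the normalized per-token transition log-probabilities and exponentiates the result. A self-contained sketch of that computation (illustrative sentence, exact numbers depend on the weights):

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_NAME = "impresso-project/nel-mgenre-multilingual"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

text = "Le 0ctobre 1894, [START] Dreyfvs [END] est arrêté à Paris."
outputs = model.generate(
    **tokenizer([text], return_tensors="pt"),
    num_beams=1,
    max_new_tokens=30,
    return_dict_in_generate=True,
    output_scores=True,
)

# Sum the log-probabilities of the generated tokens and exponentiate: this is the
# probability of the whole generated sequence, reported as a percentage.
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
confidence = torch.exp(transition_scores[0].sum()).item() * 100.0
print(round(confidence, 2))  # e.g. 99.98 for an unambiguous mention
```
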
## Use Cases

- Entity disambiguation in noisy OCR settings
- Linking historical names to modern Wikidata entities
- Assisting downstream event extraction and biography generation from historical archives

## Limitations

- Sensitive to tokenisation and malformed spans
- Accuracy degrades on non-Wikidata entities or in highly ambiguous contexts
- Focused on historical entity mentions — performance may vary on modern texts

## Environmental Impact

- **Hardware:** 1x A100 (80GB) for finetuning
- **Training time:** ~12 hours
- **Estimated CO₂ Emissions:** ~2.3 kg CO₂eq

## Contact

- Website: [https://impresso-project.ch](https://impresso-project.ch)

<p align="center">
  <img src="https://github.com/impresso/impresso.github.io/blob/master/assets/images/3x1--Yellow-Impresso-Black-on-White--transparent.png?raw=true" width="300" alt="Impresso Logo"/>
</p>
config.json
ADDED
@@ -0,0 +1,44 @@

```json
{
  "_name_or_path": "facebook/mgenre-wiki",
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "architectures": [
    "MBartForConditionalGeneration"
  ],
  "custom_pipelines": {
    "generic-nel": {
      "impl": "generic_nel.NelPipeline",
      "pt": [
        "MBartForConditionalGeneration"
      ],
      "tf": []
    }
  },
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "max_position_embeddings": 1024,
  "model_type": "mbart",
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "scale_embedding": true,
  "torch_dtype": "float32",
  "transformers_version": "4.31.0",
  "use_cache": true,
  "vocab_size": 256001
}
```
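The `custom_pipelines` block is what lets `pipeline("generic-nel", ..., trust_remote_code=True)` resolve the task name to the `NelPipeline` class defined in `generic_nel.py` below. As a quick sanity check (a sketch, assuming Hub access; extra keys from `config.json` are normally kept as attributes on the loaded config, so `getattr` is used defensively):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("impresso-project/nel-mgenre-multilingual")
# Expected, per the file above: {'generic-nel': {'impl': 'generic_nel.NelPipeline', ...}}
print(getattr(config, "custom_pipelines", {}))
```
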
generation_config.json
ADDED
@@ -0,0 +1,8 @@

```json
{
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "pad_token_id": 1,
  "transformers_version": "4.46.0.dev0"
}
```
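These generation defaults can be inspected (or overridden at call time) through the standard `GenerationConfig` API; a minimal sketch, assuming Hub access:

```python
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("impresso-project/nel-mgenre-multilingual")
# Per generation_config.json above: decoder_start_token_id=2, forced_eos_token_id=2
print(gen_config.decoder_start_token_id, gen_config.forced_eos_token_id)
```
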
generic_nel.py
ADDED
@@ -0,0 +1,192 @@

```python
from transformers import Pipeline
import nltk
import requests
import torch

nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")

NEL_MODEL = "nel-mgenre-multilingual"


def get_wikipedia_page_props(input_str: str):
    """
    Retrieves the QID for a given Wikipedia page name from the specified language Wikipedia.
    If the lookup fails for any reason, "NIL" is returned instead of a QID.

    Args:
        input_str (str): The input string in the format "page_name >> language".

    Returns:
        tuple: (QID or "NIL", language code).
    """
    if ">>" not in input_str:
        page_name = input_str
        language = "en"
        print(
            f">> was not found in {input_str} so we are checking with these values: Page name: {page_name}, Language: {language}"
        )
    else:
        # Preprocess the input string
        try:
            page_name, language = input_str.split(">>")
            page_name = page_name.strip()
            language = language.strip()
        except ValueError:
            page_name = input_str
            language = "en"
            print(
                f">> was not found in {input_str} so we are checking with these values: Page name: {page_name}, Language: {language}"
            )
    wikipedia_url = f"https://{language}.wikipedia.org/w/api.php"
    wikipedia_params = {
        "action": "query",
        "prop": "pageprops",
        "format": "json",
        "titles": page_name,
    }

    qid = "NIL"
    try:
        # Attempt to fetch the QID from the Wikipedia API
        response = requests.get(wikipedia_url, params=wikipedia_params)
        response.raise_for_status()
        data = response.json()

        if "pages" in data["query"]:
            page_id = list(data["query"]["pages"].keys())[0]
            if "pageprops" in data["query"]["pages"][page_id]:
                page_props = data["query"]["pages"][page_id]["pageprops"]
                if "wikibase_item" in page_props:
                    return page_props["wikibase_item"], language
        return qid, language
    except Exception:
        return qid, language


def get_wikipedia_title(qid, language="en"):
    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbgetentities",
        "format": "json",
        "ids": qid,
        "props": "sitelinks/urls",
        "sitefilter": f"{language}wiki",
    }

    response = requests.get(url, params=params)
    try:
        response.raise_for_status()  # Raise an HTTPError if the response was not 2xx
        data = response.json()
    except requests.exceptions.RequestException as e:
        print(f"HTTP error: {e}")
        return "NIL", "None"
    except ValueError:  # Catch JSON decode errors
        print(f"Invalid JSON response: {response.text}")
        return "NIL", "None"

    try:
        title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
        url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
        return title, url
    except KeyError:
        return "NIL", "None"


class NelPipeline(Pipeline):

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        # Extract the entity between [START] and [END]
        start_token = "[START]"
        end_token = "[END]"

        if start_token in text and end_token in text:
            start_idx = text.index(start_token) + len(start_token)
            end_idx = text.index(end_token)
            enclosed_entity = text[start_idx:end_idx].strip()
            lOffset = start_idx  # left offset (start of the entity)
            rOffset = end_idx  # right offset (end of the entity)
        else:
            enclosed_entity = None
            lOffset = None
            rOffset = None

        # Generate predictions using the model
        outputs = self.model.generate(
            **self.tokenizer([text], return_tensors="pt").to(self.device),
            num_beams=1,
            num_return_sequences=1,
            max_new_tokens=30,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # Decode the prediction into readable text ("page name >> language")
        wikipedia_prediction = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )[0]

        # Compute the sequence-level confidence from the per-token transition scores
        transition_scores = self.model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        log_prob_sum = sum(transition_scores[0])

        # Probability of the whole sequence: exponentiate the sum of log-probabilities
        sequence_confidence = torch.exp(log_prob_sum)
        percentage = sequence_confidence.cpu().numpy() * 100.0

        # Return the prediction along with the extracted entity, lOffset, and rOffset
        return wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage

    def _forward(self, inputs):
        return inputs

    def postprocess(self, outputs, **kwargs):
        """
        Postprocess the model outputs: resolve the predicted Wikipedia page to a
        Wikidata QID, page title, and URL, and assemble the result dictionary.
        """
        wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage = outputs
        qid, language = get_wikipedia_page_props(wikipedia_prediction)
        title, url = get_wikipedia_title(qid, language=language)

        percentage = round(percentage, 2)

        results = [
            {
                "surface": enclosed_entity,
                "wkd_id": qid,
                "wkpedia_pagename": title,
                "wkpedia_url": url,
                "type": "UNK",
                "confidence_nel": percentage,
                "lOffset": lOffset,
                "rOffset": rOffset,
            }
        ]
        return results
```
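A usage illustration of the two helper functions above (a sketch: both hit the live Wikipedia and Wikidata APIs, so results depend on network access and current data):

```python
from generic_nel import get_wikipedia_page_props, get_wikipedia_title

# "Alfred Dreyfus >> fr" is the kind of string the model generates.
qid, language = get_wikipedia_page_props("Alfred Dreyfus >> fr")
title, url = get_wikipedia_title(qid, language=language)
print(qid, title, url)
# Expected (at the time of writing): Q171826 Alfred Dreyfus https://fr.wikipedia.org/wiki/Alfred_Dreyfus
```
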
handler.py
ADDED
@@ -0,0 +1,125 @@

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from typing import List, Dict, Any
import requests
import nltk

# Download required NLTK models
nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")

# Model name
NEL_MODEL = "nel-mgenre-multilingual"


# Note: the two helpers below duplicate those in generic_nel.py.
def get_wikipedia_page_props(input_str: str):
    if ">>" not in input_str:
        page_name = input_str
        language = "en"
    else:
        try:
            page_name, language = input_str.split(">>")
            page_name = page_name.strip()
            language = language.strip()
        except ValueError:
            page_name = input_str
            language = "en"
    wikipedia_url = f"https://{language}.wikipedia.org/w/api.php"
    wikipedia_params = {
        "action": "query",
        "prop": "pageprops",
        "format": "json",
        "titles": page_name,
    }

    qid = "NIL"
    try:
        response = requests.get(wikipedia_url, params=wikipedia_params)
        response.raise_for_status()
        data = response.json()

        if "pages" in data["query"]:
            page_id = list(data["query"]["pages"].keys())[0]
            if "pageprops" in data["query"]["pages"][page_id]:
                page_props = data["query"]["pages"][page_id]["pageprops"]
                if "wikibase_item" in page_props:
                    return page_props["wikibase_item"], language
        return qid, language
    except Exception:
        return qid, language


def get_wikipedia_title(qid, language="en"):
    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbgetentities",
        "format": "json",
        "ids": qid,
        "props": "sitelinks/urls",
        "sitefilter": f"{language}wiki",
    }

    response = requests.get(url, params=params)
    try:
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError):
        return "NIL", "None"

    try:
        title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
        url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
        return title, url
    except KeyError:
        return "NIL", "None"


class NelPipeline:
    """Thin wrapper that delegates to the custom "generic-nel" pipeline."""

    def __init__(self, model_dir: str = "."):
        self.model_name = NEL_MODEL
        print(f"Loading {model_dir}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = pipeline(
            "generic-nel",
            model="impresso-project/nel-mgenre-multilingual",
            tokenizer=self.tokenizer,
            trust_remote_code=True,
            device=self.device,
        )

    def preprocess(self, text: str):
        # Running the pipeline does all the work; see generic_nel.py
        linked_entity = self.model(text)
        return linked_entity

    def postprocess(self, outputs):
        return outputs


class EndpointHandler:
    def __init__(self, path: str = None):
        # Initialize the NelPipeline with the specified model
        self.pipeline = NelPipeline("impresso-project/nel-mgenre-multilingual")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Process incoming data
        inputs = data.get("inputs", "")
        if not isinstance(inputs, str):
            raise ValueError("Input must be a string.")

        # Preprocess (runs the pipeline) and postprocess
        preprocessed = self.pipeline.preprocess(inputs)
        results = self.pipeline.postprocess(preprocessed)

        return results
```
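A quick local smoke test of the handler (a sketch, assuming the repository files are importable and the model can be downloaded; the payload shape follows `EndpointHandler.__call__` above):

```python
from handler import EndpointHandler

handler = EndpointHandler()
payload = {"inputs": "Le 0ctobre 1894, [START] Dreyfvs [END] est arrêté à Paris."}
print(handler(payload))
# Expected: a list with one dict containing 'surface', 'wkd_id', 'wkpedia_pagename', etc.
```
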
model.safetensors
ADDED
@@ -0,0 +1,3 @@

Git LFS pointer file:

version https://git-lfs.github.com/spec/v1
oid sha256:f8cfb5bf9aa521336b586ae37eecac31ed7e86327a1be1802d32551472988633
size 2468961388
requirements.txt
ADDED
@@ -0,0 +1,5 @@

```
nltk
torch
transformers
requests
typing
```
scheduler.pt
ADDED
@@ -0,0 +1,3 @@

Git LFS pointer file:

version https://git-lfs.github.com/spec/v1
oid sha256:bc00424e1006d8552c992d3bde1acf8d4282909093c4e18a2112a6e6b087b217
size 1064
sentencepiece.bpe.model
ADDED
@@ -0,0 +1,3 @@

Git LFS pointer file:

version https://git-lfs.github.com/spec/v1
oid sha256:6ee4dc054a17c18fe81f76c0b1cda00e9fc1cfd9e0f1a16cb6d77009e2076653
size 4870365
special_tokens_map.json
ADDED
@@ -0,0 +1,51 @@

```json
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "cls_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "mask_token": {
    "content": "<mask>",
    "lstrip": true,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "sep_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
```
tokenizer_config.json
ADDED
@@ -0,0 +1,55 @@

```json
{
  "added_tokens_decoder": {
    "0": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256001": {
      "content": "<mask>",
      "lstrip": true,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "cls_token": "<s>",
  "eos_token": "</s>",
  "mask_token": "<mask>",
  "model_max_length": 512,
  "pad_token": "<pad>",
  "sep_token": "</s>",
  "sp_model_kwargs": {},
  "tokenizer_class": "XLMRobertaTokenizer",
  "unk_token": "<unk>"
}
```
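As a quick check that the tokenizer files above load as expected (a sketch, assuming Hub access; the printed values should mirror `tokenizer_config.json`, which registers `<mask>` as added token 256001 and caps `model_max_length` at 512):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("impresso-project/nel-mgenre-multilingual")
print(tok.mask_token, tok.mask_token_id)  # mask token and its id
print(tok.model_max_length)               # maximum input length in tokens
```
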
trainer_state.json
ADDED
@@ -0,0 +1,207 @@

```json
{
  "best_metric": null,
  "best_model_checkpoint": null,
  "epoch": 10.0,
  "eval_steps": 500,
  "global_step": 6480,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {
      "epoch": 0.7716049382716049,
      "grad_norm": 1.0553990602493286,
      "learning_rate": 1.846913580246914e-05,
      "loss": 0.9346,
      "step": 500
    },
    {
      "epoch": 1.0,
      "eval_bleu": 0.0,
      "eval_gen_len": 10.0959,
      "eval_loss": 0.15309424698352814,
      "eval_runtime": 7.8759,
      "eval_samples_per_second": 154.903,
      "eval_steps_per_second": 2.539,
      "step": 648
    },
    {
      "epoch": 1.5432098765432098,
      "grad_norm": 0.7297214269638062,
      "learning_rate": 1.6925925925925926e-05,
      "loss": 0.0763,
      "step": 1000
    },
    {
      "epoch": 2.0,
      "eval_bleu": 0.0,
      "eval_gen_len": 10.159,
      "eval_loss": 0.16104426980018616,
      "eval_runtime": 7.6405,
      "eval_samples_per_second": 159.674,
      "eval_steps_per_second": 2.618,
      "step": 1296
    },
    {
      "epoch": 2.314814814814815,
      "grad_norm": 0.44237253069877625,
      "learning_rate": 1.5382716049382717e-05,
      "loss": 0.0446,
      "step": 1500
    },
    {
      "epoch": 3.0,
      "eval_bleu": 0.0,
      "eval_gen_len": 10.0426,
      "eval_loss": 0.17489495873451233,
      "eval_runtime": 7.8385,
      "eval_samples_per_second": 155.642,
      "eval_steps_per_second": 2.552,
      "step": 1944
    },
    {
      "epoch": 3.0864197530864197,
      "grad_norm": 0.3801327049732208,
      "learning_rate": 1.3839506172839507e-05,
      "loss": 0.0275,
      "step": 2000
    },
    {
      "epoch": 3.8580246913580245,
      "grad_norm": 0.29495081305503845,
      "learning_rate": 1.2296296296296298e-05,
      "loss": 0.0162,
      "step": 2500
    },
    {
      "epoch": 4.0,
      "eval_bleu": 0.0,
      "eval_gen_len": 10.1139,
      "eval_loss": 0.1843736320734024,
      "eval_runtime": 7.7649,
      "eval_samples_per_second": 157.118,
      "eval_steps_per_second": 2.576,
      "step": 2592
    },
    {
      "epoch": 4.62962962962963,
      "grad_norm": 0.29735738039016724,
      "learning_rate": 1.0753086419753086e-05,
      "loss": 0.0106,
      "step": 3000
    },
    {
      "epoch": 5.0,
      "eval_bleu": 0.0,
      "eval_gen_len": 9.9508,
      "eval_loss": 0.19341909885406494,
      "eval_runtime": 7.6995,
      "eval_samples_per_second": 158.452,
      "eval_steps_per_second": 2.598,
      "step": 3240
    },
    {
      "epoch": 5.401234567901234,
      "grad_norm": 0.07027166336774826,
      "learning_rate": 9.209876543209877e-06,
      "loss": 0.0076,
      "step": 3500
    },
    {
      "epoch": 6.0,
      "eval_bleu": 0.0,
      "eval_gen_len": 9.9377,
      "eval_loss": 0.20017552375793457,
      "eval_runtime": 7.6996,
      "eval_samples_per_second": 158.45,
      "eval_steps_per_second": 2.598,
      "step": 3888
    },
    {
      "epoch": 6.172839506172839,
      "grad_norm": 0.1504916250705719,
      "learning_rate": 7.666666666666667e-06,
      "loss": 0.0059,
      "step": 4000
    },
    {
      "epoch": 6.944444444444445,
      "grad_norm": 0.24264627695083618,
      "learning_rate": 6.123456790123458e-06,
      "loss": 0.0043,
      "step": 4500
    },
    {
      "epoch": 7.0,
      "eval_bleu": 0.0,
      "eval_gen_len": 10.0279,
      "eval_loss": 0.20386986434459686,
      "eval_runtime": 7.7944,
      "eval_samples_per_second": 156.523,
      "eval_steps_per_second": 2.566,
      "step": 4536
    },
    {
      "epoch": 7.716049382716049,
      "grad_norm": 0.08363181352615356,
      "learning_rate": 4.580246913580247e-06,
      "loss": 0.0035,
      "step": 5000
    },
    {
      "epoch": 8.0,
      "eval_bleu": 0.0,
      "eval_gen_len": 10.1566,
      "eval_loss": 0.20531675219535828,
      "eval_runtime": 7.6989,
      "eval_samples_per_second": 158.465,
      "eval_steps_per_second": 2.598,
      "step": 5184
    },
    {
      "epoch": 8.487654320987655,
      "grad_norm": 0.13225023448467255,
      "learning_rate": 3.0370370370370372e-06,
      "loss": 0.0029,
      "step": 5500
    },
    {
      "epoch": 9.0,
      "eval_bleu": 0.0,
      "eval_gen_len": 10.0689,
      "eval_loss": 0.20702147483825684,
      "eval_runtime": 7.6619,
      "eval_samples_per_second": 159.23,
      "eval_steps_per_second": 2.61,
      "step": 5832
    },
    {
      "epoch": 9.25925925925926,
      "grad_norm": 0.022540247067809105,
      "learning_rate": 1.4938271604938272e-06,
      "loss": 0.003,
      "step": 6000
    }
  ],
  "logging_steps": 500,
  "max_steps": 6480,
  "num_input_tokens_seen": 0,
  "num_train_epochs": 10,
  "save_steps": 1000,
  "stateful_callbacks": {
    "TrainerControl": {
      "args": {
        "should_epoch_stop": false,
        "should_evaluate": false,
        "should_log": false,
        "should_save": true,
        "should_training_stop": true
      },
      "attributes": {}
    }
  },
  "total_flos": 4.288315707563704e+17,
  "train_batch_size": 64,
  "trial_name": null,
  "trial_params": null
}
```
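For convenience, the evaluation trajectory recorded above can be pulled out of the file programmatically; a minimal sketch, assuming `trainer_state.json` is in the current directory:

```python
import json

# Print the per-epoch eval_loss recorded by the Trainer in trainer_state.json.
with open("trainer_state.json") as f:
    state = json.load(f)

for entry in state["log_history"]:
    if "eval_loss" in entry:
        print(f"epoch {entry['epoch']:.0f}: eval_loss={entry['eval_loss']:.4f}")
```
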