Spaces:
Runtime error
Runtime error
Benjamin Bossan
commited on
Commit
·
126a4c6
1
Parent(s):
a240da9
Refactor ml model handling
Browse files- src/db.py +4 -1
- src/ml.py +116 -67
- src/webservice.py +6 -0
- src/worker.py +79 -38
src/db.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import logging
|
|
|
2 |
import sqlite3
|
3 |
from contextlib import contextmanager
|
4 |
from typing import Generator
|
@@ -6,6 +7,8 @@ from typing import Generator
|
|
6 |
logger = logging.getLogger(__name__)
|
7 |
logger.setLevel(logging.DEBUG)
|
8 |
|
|
|
|
|
9 |
|
10 |
schema_entries = """
|
11 |
CREATE TABLE entries
|
@@ -67,7 +70,7 @@ def _get_db_connection() -> sqlite3.Connection:
|
|
67 |
global TABLES_CREATED
|
68 |
|
69 |
# sqlite cannot deal with concurrent access, so we set a big timeout
|
70 |
-
conn = sqlite3.connect(
|
71 |
if TABLES_CREATED:
|
72 |
return conn
|
73 |
|
|
|
1 |
import logging
|
2 |
+
import os
|
3 |
import sqlite3
|
4 |
from contextlib import contextmanager
|
5 |
from typing import Generator
|
|
|
7 |
logger = logging.getLogger(__name__)
|
8 |
logger.setLevel(logging.DEBUG)
|
9 |
|
10 |
+
db_file = os.getenv("DB_FILE_NAME", "sqlite-data.db")
|
11 |
+
|
12 |
|
13 |
schema_entries = """
|
14 |
CREATE TABLE entries
|
|
|
70 |
global TABLES_CREATED
|
71 |
|
72 |
# sqlite cannot deal with concurrent access, so we set a big timeout
|
73 |
+
conn = sqlite3.connect(db_file, timeout=30)
|
74 |
if TABLES_CREATED:
|
75 |
return conn
|
76 |
|
src/ml.py
CHANGED
@@ -1,52 +1,126 @@
|
|
1 |
import abc
|
|
|
2 |
import logging
|
3 |
import re
|
4 |
|
5 |
import httpx
|
6 |
-
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
|
7 |
|
8 |
from base import JobInput
|
9 |
|
10 |
logger = logging.getLogger(__name__)
|
11 |
logger.setLevel(logging.DEBUG)
|
12 |
|
13 |
-
MODEL_NAME = "google/flan-t5-large"
|
14 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
15 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def __init__(self) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
self.template = "Summarize the text below in two sentences:\n\n{}"
|
21 |
-
self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
|
22 |
-
self.generation_config.max_new_tokens = 200
|
23 |
-
self.generation_config.min_new_tokens = 100
|
24 |
-
self.generation_config.top_k = 5
|
25 |
-
self.generation_config.repetition_penalty = 1.5
|
26 |
|
27 |
def __call__(self, x: str) -> str:
|
28 |
text = self.template.format(x)
|
29 |
-
inputs = tokenizer(text, return_tensors="pt")
|
30 |
-
outputs = model.generate(**inputs, generation_config=self.generation_config)
|
31 |
-
output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
32 |
assert isinstance(output, str)
|
33 |
return output
|
34 |
|
35 |
def get_name(self) -> str:
|
36 |
-
return f"
|
37 |
|
38 |
|
39 |
-
class Tagger:
|
40 |
-
def __init__(self) -> None:
|
|
|
|
|
|
|
|
|
|
|
41 |
self.template = (
|
42 |
"Create a list of tags for the text below. The tags should be high level "
|
43 |
"and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
|
44 |
)
|
45 |
-
self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
|
46 |
-
self.generation_config.max_new_tokens = 50
|
47 |
-
self.generation_config.min_new_tokens = 25
|
48 |
-
# increase the temperature to make the model more creative
|
49 |
-
self.generation_config.temperature = 1.5
|
50 |
|
51 |
def _extract_tags(self, text: str) -> list[str]:
|
52 |
tags = set()
|
@@ -57,46 +131,25 @@ class Tagger:
|
|
57 |
|
58 |
def __call__(self, x: str) -> list[str]:
|
59 |
text = self.template.format(x)
|
60 |
-
inputs = tokenizer(text, return_tensors="pt")
|
61 |
-
outputs = model.generate(**inputs, generation_config=self.generation_config)
|
62 |
-
output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
63 |
tags = self._extract_tags(output)
|
64 |
return tags
|
65 |
|
66 |
def get_name(self) -> str:
|
67 |
-
return f"
|
68 |
-
|
69 |
-
|
70 |
-
class Processor(abc.ABC):
|
71 |
-
def __call__(self, job: JobInput) -> str:
|
72 |
-
_id = job.id
|
73 |
-
logger.info(f"Processing {input} with {self.__class__.__name__} (id={_id[:8]})")
|
74 |
-
result = self.process(job)
|
75 |
-
logger.info(f"Finished processing input (id={_id[:8]})")
|
76 |
-
return result
|
77 |
-
|
78 |
-
def process(self, input: JobInput) -> str:
|
79 |
-
raise NotImplementedError
|
80 |
-
|
81 |
-
def match(self, input: JobInput) -> bool:
|
82 |
-
raise NotImplementedError
|
83 |
-
|
84 |
-
def get_name(self) -> str:
|
85 |
-
raise NotImplementedError
|
86 |
|
87 |
|
88 |
-
class
|
89 |
def match(self, input: JobInput) -> bool:
|
90 |
return True
|
91 |
|
92 |
def process(self, input: JobInput) -> str:
|
93 |
return input.content
|
94 |
|
95 |
-
def get_name(self) -> str:
|
96 |
-
return self.__class__.__name__
|
97 |
-
|
98 |
|
99 |
-
class
|
100 |
def __init__(self) -> None:
|
101 |
self.client = httpx.Client()
|
102 |
self.regex = re.compile(r"(https?://[^\s]+)")
|
@@ -118,26 +171,22 @@ class PlainUrlProcessor(Processor):
|
|
118 |
text = self.template.format(url=self.url, content=text)
|
119 |
return text
|
120 |
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
123 |
|
|
|
|
|
124 |
|
125 |
-
|
126 |
-
|
127 |
-
self.registry: list[Processor] = []
|
128 |
-
self.default_registry: list[Processor] = []
|
129 |
-
self.set_default_processors()
|
130 |
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
self.registry.append(processor)
|
136 |
-
|
137 |
-
def dispatch(self, input: JobInput) -> Processor:
|
138 |
-
for processor in self.registry + self.default_registry:
|
139 |
-
if processor.match(input):
|
140 |
-
return processor
|
141 |
|
142 |
-
|
143 |
-
|
|
|
1 |
import abc
|
2 |
+
from typing import Any
|
3 |
import logging
|
4 |
import re
|
5 |
|
6 |
import httpx
|
|
|
7 |
|
8 |
from base import JobInput
|
9 |
|
10 |
logger = logging.getLogger(__name__)
|
11 |
logger.setLevel(logging.DEBUG)
|
12 |
|
|
|
|
|
|
|
13 |
|
14 |
+
class Processor(abc.ABC):
|
15 |
+
def get_name(self) -> str:
|
16 |
+
return self.__class__.__name__
|
17 |
+
|
18 |
+
def __call__(self, job: JobInput) -> str:
|
19 |
+
_id = job.id
|
20 |
+
logger.info(f"Processing {input} with {self.__class__.__name__} (id={_id[:8]})")
|
21 |
+
result = self.process(job)
|
22 |
+
logger.info(f"Finished processing input (id={_id[:8]})")
|
23 |
+
return result
|
24 |
|
25 |
+
@abc.abstractmethod
|
26 |
+
def process(self, input: JobInput) -> str:
|
27 |
+
raise NotImplementedError
|
28 |
+
|
29 |
+
@abc.abstractmethod
|
30 |
+
def match(self, input: JobInput) -> bool:
|
31 |
+
raise NotImplementedError
|
32 |
+
|
33 |
+
|
34 |
+
class Summarizer(abc.ABC):
|
35 |
+
def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
|
36 |
+
raise NotImplementedError
|
37 |
+
|
38 |
+
def get_name(self) -> str:
|
39 |
+
raise NotImplementedError
|
40 |
+
|
41 |
+
@abc.abstractmethod
|
42 |
+
def __call__(self, x: str) -> str:
|
43 |
+
raise NotImplementedError
|
44 |
+
|
45 |
+
|
46 |
+
class Tagger(abc.ABC):
|
47 |
+
def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
|
48 |
+
raise NotImplementedError
|
49 |
+
|
50 |
+
def get_name(self) -> str:
|
51 |
+
raise NotImplementedError
|
52 |
+
|
53 |
+
@abc.abstractmethod
|
54 |
+
def __call__(self, x: str) -> list[str]:
|
55 |
+
raise NotImplementedError
|
56 |
+
|
57 |
+
|
58 |
+
class MlRegistry:
|
59 |
def __init__(self) -> None:
|
60 |
+
self.processors: list[Processor] = []
|
61 |
+
self.summerizer: Summarizer | None = None
|
62 |
+
self.tagger: Tagger | None = None
|
63 |
+
self.model = None
|
64 |
+
self.tokenizer = None
|
65 |
+
|
66 |
+
def register_processor(self, processor: Processor) -> None:
|
67 |
+
self.processors.append(processor)
|
68 |
+
|
69 |
+
def register_summarizer(self, summarizer: Summarizer) -> None:
|
70 |
+
self.summerizer = summarizer
|
71 |
+
|
72 |
+
def register_tagger(self, tagger: Tagger) -> None:
|
73 |
+
self.tagger = tagger
|
74 |
+
|
75 |
+
def get_processor(self, input: JobInput) -> Processor:
|
76 |
+
assert self.processors
|
77 |
+
for processor in self.processors:
|
78 |
+
if processor.match(input):
|
79 |
+
return processor
|
80 |
+
|
81 |
+
return RawTextProcessor()
|
82 |
+
|
83 |
+
def get_summarizer(self) -> Summarizer:
|
84 |
+
assert self.summerizer
|
85 |
+
return self.summerizer
|
86 |
+
|
87 |
+
def get_tagger(self) -> Tagger:
|
88 |
+
assert self.tagger
|
89 |
+
return self.tagger
|
90 |
+
|
91 |
+
|
92 |
+
class HfTransformersSummarizer(Summarizer):
|
93 |
+
def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
|
94 |
+
self.model_name = model_name
|
95 |
+
self.model = model
|
96 |
+
self.tokenizer = tokenizer
|
97 |
+
self.generation_config = generation_config
|
98 |
+
|
99 |
self.template = "Summarize the text below in two sentences:\n\n{}"
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
def __call__(self, x: str) -> str:
|
102 |
text = self.template.format(x)
|
103 |
+
inputs = self.tokenizer(text, return_tensors="pt")
|
104 |
+
outputs = self.model.generate(**inputs, generation_config=self.generation_config)
|
105 |
+
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
106 |
assert isinstance(output, str)
|
107 |
return output
|
108 |
|
109 |
def get_name(self) -> str:
|
110 |
+
return f"{self.__class__.__name__}({self.model_name})"
|
111 |
|
112 |
|
113 |
+
class HfTransformersTagger(Tagger):
|
114 |
+
def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
|
115 |
+
self.model_name = model_name
|
116 |
+
self.model = model
|
117 |
+
self.tokenizer = tokenizer
|
118 |
+
self.generation_config = generation_config
|
119 |
+
|
120 |
self.template = (
|
121 |
"Create a list of tags for the text below. The tags should be high level "
|
122 |
"and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
|
123 |
)
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
def _extract_tags(self, text: str) -> list[str]:
|
126 |
tags = set()
|
|
|
131 |
|
132 |
def __call__(self, x: str) -> list[str]:
|
133 |
text = self.template.format(x)
|
134 |
+
inputs = self.tokenizer(text, return_tensors="pt")
|
135 |
+
outputs = self.model.generate(**inputs, generation_config=self.generation_config)
|
136 |
+
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
137 |
tags = self._extract_tags(output)
|
138 |
return tags
|
139 |
|
140 |
def get_name(self) -> str:
|
141 |
+
return f"{self.__class__.__name__}({self.model_name})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
|
144 |
+
class RawTextProcessor(Processor):
|
145 |
def match(self, input: JobInput) -> bool:
|
146 |
return True
|
147 |
|
148 |
def process(self, input: JobInput) -> str:
|
149 |
return input.content
|
150 |
|
|
|
|
|
|
|
151 |
|
152 |
+
class DefaultUrlProcessor(Processor):
|
153 |
def __init__(self) -> None:
|
154 |
self.client = httpx.Client()
|
155 |
self.regex = re.compile(r"(https?://[^\s]+)")
|
|
|
171 |
text = self.template.format(url=self.url, content=text)
|
172 |
return text
|
173 |
|
174 |
+
# class ProcessorRegistry:
|
175 |
+
# def __init__(self) -> None:
|
176 |
+
# self.registry: list[Processor] = []
|
177 |
+
# self.default_registry: list[Processor] = []
|
178 |
+
# self.set_default_processors()
|
179 |
|
180 |
+
# def set_default_processors(self) -> None:
|
181 |
+
# self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])
|
182 |
|
183 |
+
# def register(self, processor: Processor) -> None:
|
184 |
+
# self.registry.append(processor)
|
|
|
|
|
|
|
185 |
|
186 |
+
# def dispatch(self, input: JobInput) -> Processor:
|
187 |
+
# for processor in self.registry + self.default_registry:
|
188 |
+
# if processor.match(input):
|
189 |
+
# return processor
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
|
191 |
+
# # should never be requires, but eh
|
192 |
+
# return RawProcessor()
|
src/webservice.py
CHANGED
@@ -14,6 +14,12 @@ logger.setLevel(logging.DEBUG)
|
|
14 |
app = FastAPI()
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
@app.post("/submit/")
|
18 |
def submit_job(input: RequestInput) -> str:
|
19 |
# submit a new job, poor man's job queue
|
|
|
14 |
app = FastAPI()
|
15 |
|
16 |
|
17 |
+
# status
|
18 |
+
@app.get("/status/")
|
19 |
+
def status() -> str:
|
20 |
+
return "OK"
|
21 |
+
|
22 |
+
|
23 |
@app.post("/submit/")
|
24 |
def submit_job(input: RequestInput) -> str:
|
25 |
# submit a new job, poor man's job queue
|
src/worker.py
CHANGED
@@ -1,18 +1,19 @@
|
|
1 |
import time
|
|
|
2 |
|
3 |
from base import JobInput
|
4 |
from db import get_db_cursor
|
5 |
-
from ml import
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
SLEEP_INTERVAL = 5
|
8 |
|
9 |
|
10 |
-
processor_registry = ProcessorRegistry()
|
11 |
-
summarizer = Summarizer()
|
12 |
-
tagger = Tagger()
|
13 |
-
print("loaded ML models")
|
14 |
-
|
15 |
-
|
16 |
def check_pending_jobs() -> list[JobInput]:
|
17 |
"""Check DB for pending jobs"""
|
18 |
with get_db_cursor() as cursor:
|
@@ -30,15 +31,38 @@ def check_pending_jobs() -> list[JobInput]:
|
|
30 |
]
|
31 |
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
with get_db_cursor() as cursor:
|
43 |
# write to entries, summary, tags tables
|
44 |
cursor.execute(
|
@@ -46,39 +70,23 @@ def store(
|
|
46 |
"INSERT INTO summaries (entry_id, summary, summarizer_name)"
|
47 |
" VALUES (?, ?, ?)"
|
48 |
),
|
49 |
-
(job.id, summary, summarizer_name),
|
50 |
)
|
51 |
cursor.executemany(
|
52 |
"INSERT INTO tags (entry_id, tag, tagger_name) VALUES (?, ?, ?)",
|
53 |
-
[(job.id, tag, tagger_name) for tag in tags],
|
54 |
)
|
55 |
|
56 |
|
57 |
-
def process_job(job: JobInput) -> None:
|
58 |
tic = time.perf_counter()
|
59 |
print(f"Processing job for (id={job.id[:8]})")
|
60 |
|
61 |
# care: acquire cursor (which leads to locking) as late as possible, since
|
62 |
# the processing and we don't want to block other workers during that time
|
63 |
try:
|
64 |
-
|
65 |
-
|
66 |
-
processed = processor(job)
|
67 |
-
|
68 |
-
tagger_name = tagger.get_name()
|
69 |
-
tags = tagger(processed)
|
70 |
-
|
71 |
-
summarizer_name = summarizer.get_name()
|
72 |
-
summary = summarizer(processed)
|
73 |
-
|
74 |
-
store(
|
75 |
-
job,
|
76 |
-
summary=summary,
|
77 |
-
tags=tags,
|
78 |
-
processor_name=processor_name,
|
79 |
-
summarizer_name=summarizer_name,
|
80 |
-
tagger_name=tagger_name,
|
81 |
-
)
|
82 |
# update job status to done
|
83 |
with get_db_cursor() as cursor:
|
84 |
cursor.execute(
|
@@ -96,7 +104,40 @@ def process_job(job: JobInput) -> None:
|
|
96 |
print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
def main() -> None:
|
|
|
|
|
|
|
100 |
while True:
|
101 |
jobs = check_pending_jobs()
|
102 |
if not jobs:
|
@@ -106,7 +147,7 @@ def main() -> None:
|
|
106 |
|
107 |
print(f"Found {len(jobs)} pending job(s), processing...")
|
108 |
for job in jobs:
|
109 |
-
process_job(job)
|
110 |
|
111 |
|
112 |
if __name__ == "__main__":
|
|
|
1 |
import time
|
2 |
+
from dataclasses import dataclass
|
3 |
|
4 |
from base import JobInput
|
5 |
from db import get_db_cursor
|
6 |
+
from ml import (
|
7 |
+
DefaultUrlProcessor,
|
8 |
+
HfTransformersSummarizer,
|
9 |
+
HfTransformersTagger,
|
10 |
+
MlRegistry,
|
11 |
+
RawTextProcessor,
|
12 |
+
)
|
13 |
|
14 |
SLEEP_INTERVAL = 5
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
def check_pending_jobs() -> list[JobInput]:
|
18 |
"""Check DB for pending jobs"""
|
19 |
with get_db_cursor() as cursor:
|
|
|
31 |
]
|
32 |
|
33 |
|
34 |
+
@dataclass
|
35 |
+
class JobOutput:
|
36 |
+
summary: str
|
37 |
+
tags: list[str]
|
38 |
+
processor_name: str
|
39 |
+
summarizer_name: str
|
40 |
+
tagger_name: str
|
41 |
+
|
42 |
+
|
43 |
+
def _process_job(job: JobInput, registry: MlRegistry) -> JobOutput:
|
44 |
+
processor = registry.get_processor(job)
|
45 |
+
processor_name = processor.get_name()
|
46 |
+
processed = processor(job)
|
47 |
+
|
48 |
+
tagger = registry.get_tagger()
|
49 |
+
tagger_name = tagger.get_name()
|
50 |
+
tags = tagger(processed)
|
51 |
+
|
52 |
+
summarizer = registry.get_summarizer()
|
53 |
+
summarizer_name = summarizer.get_name()
|
54 |
+
summary = summarizer(processed)
|
55 |
+
|
56 |
+
return JobOutput(
|
57 |
+
summary=summary,
|
58 |
+
tags=tags,
|
59 |
+
processor_name=processor_name,
|
60 |
+
summarizer_name=summarizer_name,
|
61 |
+
tagger_name=tagger_name,
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
def store(job: JobInput, output: JobOutput) -> None:
|
66 |
with get_db_cursor() as cursor:
|
67 |
# write to entries, summary, tags tables
|
68 |
cursor.execute(
|
|
|
70 |
"INSERT INTO summaries (entry_id, summary, summarizer_name)"
|
71 |
" VALUES (?, ?, ?)"
|
72 |
),
|
73 |
+
(job.id, output.summary, output.summarizer_name),
|
74 |
)
|
75 |
cursor.executemany(
|
76 |
"INSERT INTO tags (entry_id, tag, tagger_name) VALUES (?, ?, ?)",
|
77 |
+
[(job.id, tag, output.tagger_name) for tag in output.tags],
|
78 |
)
|
79 |
|
80 |
|
81 |
+
def process_job(job: JobInput, registry: MlRegistry) -> None:
|
82 |
tic = time.perf_counter()
|
83 |
print(f"Processing job for (id={job.id[:8]})")
|
84 |
|
85 |
# care: acquire cursor (which leads to locking) as late as possible, since
|
86 |
# the processing and we don't want to block other workers during that time
|
87 |
try:
|
88 |
+
output = _process_job(job, registry)
|
89 |
+
store(job, output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
# update job status to done
|
91 |
with get_db_cursor() as cursor:
|
92 |
cursor.execute(
|
|
|
104 |
print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
|
105 |
|
106 |
|
107 |
+
def load_mlregistry(model_name: str) -> MlRegistry:
|
108 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
|
109 |
+
|
110 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
111 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
112 |
+
|
113 |
+
config_summarizer = GenerationConfig.from_pretrained(model_name)
|
114 |
+
config_summarizer.max_new_tokens = 200
|
115 |
+
config_summarizer.min_new_tokens = 100
|
116 |
+
config_summarizer.top_k = 5
|
117 |
+
config_summarizer.repetition_penalty = 1.5
|
118 |
+
|
119 |
+
config_tagger = GenerationConfig.from_pretrained(model_name)
|
120 |
+
config_tagger.max_new_tokens = 50
|
121 |
+
config_tagger.min_new_tokens = 25
|
122 |
+
# increase the temperature to make the model more creative
|
123 |
+
config_tagger.temperature = 1.5
|
124 |
+
|
125 |
+
summarizer = HfTransformersSummarizer(model_name, model, tokenizer, config_summarizer)
|
126 |
+
tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)
|
127 |
+
|
128 |
+
registry = MlRegistry()
|
129 |
+
registry.register_processor(DefaultUrlProcessor())
|
130 |
+
registry.register_processor(RawTextProcessor())
|
131 |
+
registry.register_summarizer(summarizer)
|
132 |
+
registry.register_tagger(tagger)
|
133 |
+
|
134 |
+
return registry
|
135 |
+
|
136 |
+
|
137 |
def main() -> None:
|
138 |
+
model_name = "google/flan-t5-large"
|
139 |
+
registry = load_mlregistry(model_name)
|
140 |
+
|
141 |
while True:
|
142 |
jobs = check_pending_jobs()
|
143 |
if not jobs:
|
|
|
147 |
|
148 |
print(f"Found {len(jobs)} pending job(s), processing...")
|
149 |
for job in jobs:
|
150 |
+
process_job(job, registry)
|
151 |
|
152 |
|
153 |
if __name__ == "__main__":
|