Update modelling_magi.py
Browse files
- modelling_magi.py +15 -1
modelling_magi.py
CHANGED
@@ -242,12 +242,15 @@ class MagiModel(PreTrainedModel):
|
|
242 |
file.write(transript)
|
243 |
return transript
|
244 |
|
245 |
-
def
|
246 |
self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
|
247 |
):
|
248 |
assert not self.config.disable_detections
|
249 |
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
250 |
|
|
|
|
|
|
|
251 |
inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
|
252 |
inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
|
253 |
processed_targets = inputs_to_detection_transformer.pop("labels")
|
@@ -255,6 +258,7 @@ class MagiModel(PreTrainedModel):
|
|
255 |
detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
|
256 |
predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
|
257 |
predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
|
|
|
258 |
|
259 |
predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
|
260 |
matching_dict = {
|
@@ -266,6 +270,7 @@ class MagiModel(PreTrainedModel):
|
|
266 |
matched_char_obj_tokens_for_batch = []
|
267 |
matched_text_obj_tokens_for_batch = []
|
268 |
t2c_tokens_for_batch = []
|
|
|
269 |
|
270 |
text_bboxes_for_batch = []
|
271 |
character_bboxes_for_batch = []
|
@@ -288,6 +293,7 @@ class MagiModel(PreTrainedModel):
|
|
288 |
matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
|
289 |
matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
|
290 |
t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
|
|
|
291 |
|
292 |
text_character_affinity_matrices = self._get_text_character_affinity_matrices(
|
293 |
character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
|
@@ -296,8 +302,16 @@ class MagiModel(PreTrainedModel):
|
|
296 |
apply_sigmoid=apply_sigmoid,
|
297 |
)
|
298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
return {
|
300 |
"text_character_affinity_matrices": text_character_affinity_matrices,
|
|
|
301 |
"text_bboxes_for_batch": text_bboxes_for_batch,
|
302 |
"character_bboxes_for_batch": character_bboxes_for_batch,
|
303 |
}
|
|
|
242 |
file.write(transript)
|
243 |
return transript
|
244 |
|
245 |
+
def get_affinity_matrices_given_annotations(
|
246 |
self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
|
247 |
):
|
248 |
assert not self.config.disable_detections
|
249 |
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
250 |
|
251 |
+
character_bboxes_in_batch = [[bbox for bbox, label in zip(a["bboxes_as_x1y1x2y2"], a["labels"]) if label == 0] for a in annotations]
|
252 |
+
crop_embeddings_for_batch = self.predict_crop_embeddings(images, character_bboxes_in_batch, move_to_device_fn)
|
253 |
+
|
254 |
inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
|
255 |
inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
|
256 |
processed_targets = inputs_to_detection_transformer.pop("labels")
|
|
|
258 |
detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
|
259 |
predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
|
260 |
predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
|
261 |
+
predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
|
262 |
|
263 |
predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
|
264 |
matching_dict = {
|
|
|
270 |
matched_char_obj_tokens_for_batch = []
|
271 |
matched_text_obj_tokens_for_batch = []
|
272 |
t2c_tokens_for_batch = []
|
273 |
+
c2c_tokens_for_batch = []
|
274 |
|
275 |
text_bboxes_for_batch = []
|
276 |
character_bboxes_for_batch = []
|
|
|
293 |
matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
|
294 |
matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
|
295 |
t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
|
296 |
+
c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])
|
297 |
|
298 |
text_character_affinity_matrices = self._get_text_character_affinity_matrices(
|
299 |
character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
|
|
|
302 |
apply_sigmoid=apply_sigmoid,
|
303 |
)
|
304 |
|
305 |
+
character_character_affinity_matrices = self._get_character_character_affinity_matrices(
|
306 |
+
character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
|
307 |
+
crop_embeddings_for_batch=crop_embeddings_for_batch,
|
308 |
+
c2c_tokens_for_batch=c2c_tokens_for_batch,
|
309 |
+
apply_sigmoid=apply_sigmoid,
|
310 |
+
)
|
311 |
+
|
312 |
return {
|
313 |
"text_character_affinity_matrices": text_character_affinity_matrices,
|
314 |
+
"character_character_affinity_matrices": character_character_affinity_matrices,
|
315 |
"text_bboxes_for_batch": text_bboxes_for_batch,
|
316 |
"character_bboxes_for_batch": character_bboxes_for_batch,
|
317 |
}
|