ragavsachdeva committed on
Commit
8e093d6
·
verified ·
1 Parent(s): 88cd128

Update modelling_magi.py

Browse files
Files changed (1) hide show
  1. modelling_magi.py +15 -1
modelling_magi.py CHANGED
@@ -242,12 +242,15 @@ class MagiModel(PreTrainedModel):
242
  file.write(transript)
243
  return transript
244
 
245
- def get_text_character_affinity_matrices_given_annotations(
246
  self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
247
  ):
248
  assert not self.config.disable_detections
249
  move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
250
 
 
 
 
251
  inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
252
  inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
253
  processed_targets = inputs_to_detection_transformer.pop("labels")
@@ -255,6 +258,7 @@ class MagiModel(PreTrainedModel):
255
  detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
256
  predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
257
  predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
 
258
 
259
  predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
260
  matching_dict = {
@@ -266,6 +270,7 @@ class MagiModel(PreTrainedModel):
266
  matched_char_obj_tokens_for_batch = []
267
  matched_text_obj_tokens_for_batch = []
268
  t2c_tokens_for_batch = []
 
269
 
270
  text_bboxes_for_batch = []
271
  character_bboxes_for_batch = []
@@ -288,6 +293,7 @@ class MagiModel(PreTrainedModel):
288
  matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
289
  matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
290
  t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
 
291
 
292
  text_character_affinity_matrices = self._get_text_character_affinity_matrices(
293
  character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
@@ -296,8 +302,16 @@ class MagiModel(PreTrainedModel):
296
  apply_sigmoid=apply_sigmoid,
297
  )
298
 
 
 
 
 
 
 
 
299
  return {
300
  "text_character_affinity_matrices": text_character_affinity_matrices,
 
301
  "text_bboxes_for_batch": text_bboxes_for_batch,
302
  "character_bboxes_for_batch": character_bboxes_for_batch,
303
  }
 
242
  file.write(transript)
243
  return transript
244
 
245
+ def get_affinity_matrices_given_annotations(
246
  self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
247
  ):
248
  assert not self.config.disable_detections
249
  move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
250
 
251
+ character_bboxes_in_batch = [[bbox for bbox, label in zip(a["bboxes_as_x1y1x2y2"], a["labels"]) if label == 0] for a in annotations]
252
+ crop_embeddings_for_batch = self.predict_crop_embeddings(images, character_bboxes_in_batch, move_to_device_fn)
253
+
254
  inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
255
  inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
256
  processed_targets = inputs_to_detection_transformer.pop("labels")
 
258
  detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
259
  predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
260
  predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
261
+ predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
262
 
263
  predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
264
  matching_dict = {
 
270
  matched_char_obj_tokens_for_batch = []
271
  matched_text_obj_tokens_for_batch = []
272
  t2c_tokens_for_batch = []
273
+ c2c_tokens_for_batch = []
274
 
275
  text_bboxes_for_batch = []
276
  character_bboxes_for_batch = []
 
293
  matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
294
  matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
295
  t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
296
+ c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])
297
 
298
  text_character_affinity_matrices = self._get_text_character_affinity_matrices(
299
  character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
 
302
  apply_sigmoid=apply_sigmoid,
303
  )
304
 
305
+ character_character_affinity_matrices = self._get_character_character_affinity_matrices(
306
+ character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
307
+ crop_embeddings_for_batch=crop_embeddings_for_batch,
308
+ c2c_tokens_for_batch=c2c_tokens_for_batch,
309
+ apply_sigmoid=apply_sigmoid,
310
+ )
311
+
312
  return {
313
  "text_character_affinity_matrices": text_character_affinity_matrices,
314
+ "character_character_affinity_matrices": character_character_affinity_matrices,
315
  "text_bboxes_for_batch": text_bboxes_for_batch,
316
  "character_bboxes_for_batch": character_bboxes_for_batch,
317
  }