Spaces:
Runtime error
Runtime error
| import copy | |
| from itertools import chain | |
| from typing import Dict, Optional, Sequence, Type | |
| import torch | |
| from pie_modules.annotations import BinaryCorefRelation | |
| from pie_modules.document.processing.text_pair import shift_span | |
| from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations | |
| from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule | |
| from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter | |
| from pie_modules.taskmodules.re_text_classification_with_indices import MarkerFactory | |
| from pie_modules.taskmodules.re_text_classification_with_indices import ( | |
| ModelTargetType as REModelTargetType, | |
| ) | |
| from pie_modules.taskmodules.re_text_classification_with_indices import ( | |
| TaskOutputType as RETaskOutputType, | |
| ) | |
| from pytorch_ie import Document, TaskModule | |
| from pytorch_ie.annotations import LabeledSpan | |
| from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations | |
class SharpBracketMarkerFactory(MarkerFactory):
    """Marker factory that renders argument markers with angle ("sharp") brackets.

    Start / end markers have the form ``<ROLE>`` / ``</ROLE>``, or ``<ROLE:label>`` /
    ``</ROLE:label>`` when a label is given. Append markers have the form ``<ROLE>`` or
    ``<ROLE=label>``. ``ROLE`` is whatever the inherited ``_get_role_marker`` returns for
    the given role.
    """

    def _get_marker(self, role: str, is_start: bool, label: Optional[str] = None) -> str:
        """Build a start marker (``is_start=True``) or an end marker for ``role``."""
        closing_slash = "" if is_start else "/"
        label_suffix = "" if label is None else f":{label}"
        return f"<{closing_slash}{self._get_role_marker(role)}{label_suffix}>"

    def get_append_marker(self, role: str, label: Optional[str] = None) -> str:
        """Build the append marker for ``role``, optionally carrying ``label``."""
        marker = self._get_role_marker(role)
        return f"<{marker}>" if label is None else f"<{marker}={label}>"
class RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers(
    RETextClassificationWithIndicesTaskModule
):
    """RE text-classification task module with optional angle-bracket argument markers.

    Args:
        use_sharp_marker: if True, markers are produced by ``SharpBracketMarkerFactory``,
            otherwise by the plain ``MarkerFactory``.
        **kwargs: forwarded unchanged to ``RETextClassificationWithIndicesTaskModule``.
    """

    def __init__(self, use_sharp_marker: bool = False, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): this attribute is set only after the parent __init__ has run, so the
        # parent constructor must not call get_marker_factory() — confirm if that changes.
        self.use_sharp_marker = use_sharp_marker

    def get_marker_factory(self) -> MarkerFactory:
        """Return a marker factory configured with ``self.argument_role_to_marker``."""
        factory_class = SharpBracketMarkerFactory if self.use_sharp_marker else MarkerFactory
        return factory_class(role_to_marker=self.argument_role_to_marker)
def construct_text_document_from_text_pair_coref_document(
    document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
    glue_text: str,
    no_relation_label: str,
    relation_label_mapping: Optional[Dict[str, str]] = None,
    add_span_mapping_to_metadata: bool = False,
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
    """Flatten a text-pair coref document into a single-text binary-relation document.

    If ``document.text`` equals ``document.text_pair``, the new document reuses that text.
    The spans of both span layers are copied onto it, and copies that compare equal are
    merged into one span. Otherwise, the new text is ``text + glue_text + text_pair`` and
    the spans of ``labeled_spans_pair`` are shifted by ``len(text) + len(glue_text)``.

    Every coref relation becomes a binary relation between the corresponding new spans:

    - Relations with ``score > 0.0`` keep their label; all others get ``no_relation_label``.
    - The resulting label is then optionally renamed via ``relation_label_mapping``.
    - All new relations get ``score=1.0``.

    Args:
        document: the text-pair document to convert.
        glue_text: text put between ``text`` and ``text_pair``. It is only used when the
            two texts differ.
        no_relation_label: label assigned to relations whose score is not positive.
        relation_label_mapping: optional label renaming. Labels missing from the mapping
            are kept as they are.
        add_span_mapping_to_metadata: if True, the mapping from original spans to new
            spans is stored under ``new_doc.metadata["span_mapping"]``.

    Returns:
        The converted ``TextDocumentWithLabeledSpansAndBinaryRelations``.
    """
    old2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
    if document.text == document.text_pair:
        new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
            id=document.id, metadata=copy.deepcopy(document.metadata), text=document.text
        )
        new2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
        for old_span in chain(document.labeled_spans, document.labeled_spans_pair):
            # A detached copy may be equal to a copy already created for the other layer.
            # Reuse that one so both original spans point to the same new span.
            new_span = new2new_spans.setdefault(old_span.copy(), old_span.copy())
            old2new_spans[old_span] = new_span
    else:
        new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
            text=document.text + glue_text + document.text_pair,
            id=document.id,
            metadata=copy.deepcopy(document.metadata),
        )
        old2new_spans.update({span: span.copy() for span in document.labeled_spans})
        # spans of the second text now start after the first text and the glue text
        offset = len(document.text) + len(glue_text)
        old2new_spans.update(
            {span: shift_span(span.copy(), offset) for span in document.labeled_spans_pair}
        )

    # Several original spans may map to the very same new span object (see above). Add
    # each object only once; otherwise the layer would receive the same annotation twice.
    unique_new_spans = list({id(span): span for span in old2new_spans.values()}.values())
    # sort to make order deterministic
    new_doc.labeled_spans.extend(
        sorted(unique_new_spans, key=lambda s: (s.start, s.end, s.label))
    )

    for old_rel in document.binary_coref_relations:
        # non-positive scores mark negative examples
        label = old_rel.label if old_rel.score > 0.0 else no_relation_label
        if relation_label_mapping is not None:
            label = relation_label_mapping.get(label, label)
        new_rel = old_rel.copy(
            head=old2new_spans[old_rel.head],
            tail=old2new_spans[old_rel.tail],
            label=label,
            score=1.0,
        )
        new_doc.binary_relations.append(new_rel)

    if add_span_mapping_to_metadata:
        new_doc.metadata["span_mapping"] = old2new_spans
    return new_doc
class CrossTextBinaryCorefByRETextClassificationTaskModule(
    TaskModuleWithDocumentConverter,
    RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers,
):
    """Predict cross-text coreference by casting it as relation classification.

    Workflow:

    1. ``_convert_document`` converts each
       ``TextPairDocumentWithLabeledSpansAndBinaryCorefRelations`` into a
       ``TextDocumentWithLabeledSpansAndBinaryRelations``.
    2. The inherited RE text-classification machinery handles the converted document.
    3. ``_integrate_predictions_from_converted_document`` maps the predicted relations
       back onto the original spans as ``BinaryCorefRelation`` predictions.

    Args:
        coref_relation_label: label used for coreference relations in the converted
            documents. Original ``"coref"`` relations are renamed to it.
        relation_annotation: must be ``"binary_relations"``.
        probability_threshold: predictions whose score is below this value are not
            added to the original document.
        **kwargs: forwarded to the parent task modules.

    Raises:
        ValueError: if ``relation_annotation`` is not ``"binary_relations"``.
    """

    def __init__(
        self,
        coref_relation_label: str,
        relation_annotation: str = "binary_relations",
        probability_threshold: float = 0.0,
        **kwargs,
    ):
        # the converted documents keep their relations in the "binary_relations" layer
        if relation_annotation != "binary_relations":
            raise ValueError(
                f"{type(self).__name__} requires relation_annotation='binary_relations', "
                f"but it is: {relation_annotation}"
            )
        super().__init__(relation_annotation=relation_annotation, **kwargs)
        self.coref_relation_label = coref_relation_label
        self.probability_threshold = probability_threshold

    @property
    def document_type(self) -> Optional[Type[Document]]:
        """Document type accepted by this task module (before conversion).

        This is a property, like in the parent task modules, so that
        ``self.document_type`` yields the class itself and not a bound method.
        """
        return TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def _get_glue_text(self) -> str:
        """Decode the ids from ``_get_glue_token_ids`` into the text put between both texts."""
        return self.tokenizer.decode(self._get_glue_token_ids())

    def _convert_document(
        self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
    ) -> TextDocumentWithLabeledSpansAndBinaryRelations:
        """Flatten ``document`` into a single-text document.

        The original-to-converted span mapping is stored in the metadata so that
        predictions can be mapped back later.
        """
        return construct_text_document_from_text_pair_coref_document(
            document,
            glue_text=self._get_glue_text(),
            relation_label_mapping={"coref": self.coref_relation_label},
            no_relation_label=self.none_label,
            add_span_mapping_to_metadata=True,
        )

    def _integrate_predictions_from_converted_document(
        self,
        document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
        converted_document: TextDocumentWithLabeledSpansAndBinaryRelations,
    ) -> None:
        """Add the relation predictions of ``converted_document`` to ``document``.

        Predictions are added as ``"coref"`` relations between the original spans. Only
        predictions with ``score >= self.probability_threshold`` are kept.

        Raises:
            ValueError: if a predicted relation does not carry ``self.coref_relation_label``.
            KeyError: if a predicted relation references a span missing from the span mapping.
        """
        original2converted_span = converted_document.metadata["span_mapping"]
        # NOTE: if several original spans were mapped to the same converted span (possible
        # when text == text_pair), this inversion keeps only the last of them.
        converted2original_span = {
            converted_span: original_span
            for original_span, converted_span in original2converted_span.items()
        }
        for rel in converted_document.binary_relations.predictions:
            original_head = converted2original_span[rel.head]
            original_tail = converted2original_span[rel.tail]
            if rel.label != self.coref_relation_label:
                raise ValueError(f"unexpected label: {rel.label}")
            if rel.score >= self.probability_threshold:
                original_predicted_rel = BinaryCorefRelation(
                    head=original_head, tail=original_tail, label="coref", score=rel.score
                )
                document.binary_coref_relations.predictions.append(original_predicted_rel)

    def unbatch_output(self, model_output: REModelTargetType) -> Sequence[RETaskOutputType]:
        """Unbatch the model output as if every example had been predicted as coref.

        Only the ``"labels"`` entry is overwritten, and only in a shallow copy. The
        caller's ``model_output`` is not modified.
        """
        coref_relation_idx = self.label_to_id[self.coref_relation_label]
        # we are just concerned with the coref class, so we overwrite the labels field
        model_output = copy.copy(model_output)
        model_output["labels"] = torch.ones_like(model_output["labels"]) * coref_relation_idx
        return super().unbatch_output(model_output=model_output)