Spaces:
Running
on
Zero
Running
on
Zero
Update lib/graph_extract.py
Browse files
Formatting, prompt adherence, exception handling, flash attn attempt
- lib/graph_extract.py +25 -18
lib/graph_extract.py
CHANGED
@@ -43,7 +43,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
43 |
tokenizer = AutoTokenizer.from_pretrained(
|
44 |
"sciphi/triplex",
|
45 |
trust_remote_code=True,
|
46 |
-
attn_implementation="flash_attention_2",
|
47 |
torch_dtype=torch.bfloat16,
|
48 |
)
|
49 |
|
@@ -59,10 +59,13 @@ generation_config.pad_token_id = tokenizer.eos_token_id
|
|
59 |
@spaces.GPU
|
60 |
def triplextract(text, entity_types, predicates):
|
61 |
input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates. Return the result as a JSON object with an "entities_and_triples" key containing an array of entities and triples.
|
|
|
62 |
**Entity Types:**
|
63 |
{entity_types}
|
|
|
64 |
**Predicates:**
|
65 |
{predicates}
|
|
|
66 |
**Text:**
|
67 |
{text}
|
68 |
"""
|
@@ -103,7 +106,7 @@ def triplextract(text, entity_types, predicates):
|
|
103 |
return "Error: CUDA out of memory."
|
104 |
except Exception as e:
|
105 |
print(f"Error in generation: {e}")
|
106 |
-
return f"Error in generation: {str(e)}"
|
107 |
|
108 |
def parse_triples(prediction):
|
109 |
entities = {}
|
@@ -125,20 +128,24 @@ def parse_triples(prediction):
|
|
125 |
|
126 |
for item in items:
|
127 |
if isinstance(item, str):
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
|
|
|
|
|
|
144 |
return entities, relationships
|
|
|
43 |
tokenizer = AutoTokenizer.from_pretrained(
|
44 |
"sciphi/triplex",
|
45 |
trust_remote_code=True,
|
46 |
+
attn_implementation="flash_attention_2" if flash_attn_installed else None,
|
47 |
torch_dtype=torch.bfloat16,
|
48 |
)
|
49 |
|
|
|
59 |
@spaces.GPU
|
60 |
def triplextract(text, entity_types, predicates):
|
61 |
input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates. Return the result as a JSON object with an "entities_and_triples" key containing an array of entities and triples.
|
62 |
+
|
63 |
**Entity Types:**
|
64 |
{entity_types}
|
65 |
+
|
66 |
**Predicates:**
|
67 |
{predicates}
|
68 |
+
|
69 |
**Text:**
|
70 |
{text}
|
71 |
"""
|
|
|
106 |
return "Error: CUDA out of memory."
|
107 |
except Exception as e:
|
108 |
print(f"Error in generation: {e}")
|
109 |
+
return f"Error in generation, please try again: {str(e)}"
|
110 |
|
111 |
def parse_triples(prediction):
|
112 |
entities = {}
|
|
|
128 |
|
129 |
for item in items:
|
130 |
if isinstance(item, str):
|
131 |
+
try:
|
132 |
+
if ":" in item:
|
133 |
+
id, entity = item.split(",", 1)
|
134 |
+
id = id.strip("[]").strip()
|
135 |
+
entity_type, entity_value = entity.split(":", 1)
|
136 |
+
entities[id] = {
|
137 |
+
"type": entity_type.strip(),
|
138 |
+
"value": entity_value.strip(),
|
139 |
+
}
|
140 |
+
else:
|
141 |
+
parts = item.split()
|
142 |
+
if len(parts) >= 3:
|
143 |
+
source = parts[0].strip("[]")
|
144 |
+
relation = " ".join(parts[1:-1])
|
145 |
+
target = parts[-1].strip("[]")
|
146 |
+
relationships.append((source, relation.strip(), target))
|
147 |
+
except Exception as e:
|
148 |
+
# TODO: Handle gracefully
|
149 |
+
print(f"Error in processing: {item}: {e}")
|
150 |
+
|
151 |
return entities, relationships
|