Remsky committed on
Commit
25b0fd7
·
verified ·
1 Parent(s): d0f6106

Update lib/graph_extract.py

Browse files

Formatting, prompt adherence, exception handling, flash attn attempt

Files changed (1) hide show
  1. lib/graph_extract.py +25 -18
lib/graph_extract.py CHANGED
@@ -43,7 +43,7 @@ model = AutoModelForCausalLM.from_pretrained(
43
  tokenizer = AutoTokenizer.from_pretrained(
44
  "sciphi/triplex",
45
  trust_remote_code=True,
46
- attn_implementation="flash_attention_2",
47
  torch_dtype=torch.bfloat16,
48
  )
49
 
@@ -59,10 +59,13 @@ generation_config.pad_token_id = tokenizer.eos_token_id
59
  @spaces.GPU
60
  def triplextract(text, entity_types, predicates):
61
  input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates. Return the result as a JSON object with an "entities_and_triples" key containing an array of entities and triples.
 
62
  **Entity Types:**
63
  {entity_types}
 
64
  **Predicates:**
65
  {predicates}
 
66
  **Text:**
67
  {text}
68
  """
@@ -103,7 +106,7 @@ def triplextract(text, entity_types, predicates):
103
  return "Error: CUDA out of memory."
104
  except Exception as e:
105
  print(f"Error in generation: {e}")
106
- return f"Error in generation: {str(e)}"
107
 
108
  def parse_triples(prediction):
109
  entities = {}
@@ -125,20 +128,24 @@ def parse_triples(prediction):
125
 
126
  for item in items:
127
  if isinstance(item, str):
128
- if ":" in item:
129
- id, entity = item.split(",", 1)
130
- id = id.strip("[]").strip()
131
- entity_type, entity_value = entity.split(":", 1)
132
- entities[id] = {
133
- "type": entity_type.strip(),
134
- "value": entity_value.strip(),
135
- }
136
- else:
137
- parts = item.split()
138
- if len(parts) >= 3:
139
- source = parts[0].strip("[]")
140
- relation = " ".join(parts[1:-1])
141
- target = parts[-1].strip("[]")
142
- relationships.append((source, relation.strip(), target))
143
-
 
 
 
 
144
  return entities, relationships
 
43
  tokenizer = AutoTokenizer.from_pretrained(
44
  "sciphi/triplex",
45
  trust_remote_code=True,
46
+ attn_implementation="flash_attention_2" if flash_attn_installed else None,
47
  torch_dtype=torch.bfloat16,
48
  )
49
 
 
59
  @spaces.GPU
60
  def triplextract(text, entity_types, predicates):
61
  input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates. Return the result as a JSON object with an "entities_and_triples" key containing an array of entities and triples.
62
+
63
  **Entity Types:**
64
  {entity_types}
65
+
66
  **Predicates:**
67
  {predicates}
68
+
69
  **Text:**
70
  {text}
71
  """
 
106
  return "Error: CUDA out of memory."
107
  except Exception as e:
108
  print(f"Error in generation: {e}")
109
+ return f"Error in generation, please try again: {str(e)}"
110
 
111
  def parse_triples(prediction):
112
  entities = {}
 
128
 
129
  for item in items:
130
  if isinstance(item, str):
131
+ try:
132
+ if ":" in item:
133
+ id, entity = item.split(",", 1)
134
+ id = id.strip("[]").strip()
135
+ entity_type, entity_value = entity.split(":", 1)
136
+ entities[id] = {
137
+ "type": entity_type.strip(),
138
+ "value": entity_value.strip(),
139
+ }
140
+ else:
141
+ parts = item.split()
142
+ if len(parts) >= 3:
143
+ source = parts[0].strip("[]")
144
+ relation = " ".join(parts[1:-1])
145
+ target = parts[-1].strip("[]")
146
+ relationships.append((source, relation.strip(), target))
147
+ except Exception as e:
148
+ # TODO: Handle gracefully
149
+ print(f"Error in processing: {item}: {e}")
150
+
151
  return entities, relationships